Skip to main content

ratel_ai_core/
skill_registry.rs

1use std::sync::Arc;
2use std::time::Instant;
3
4use crate::search::bm25_search;
5use crate::skill::Skill;
6use crate::skill_indexing::searchable_text;
7use crate::trace::{
8    ChurnKind, NoopSink, Origin, SearchStage, SkillHitTrace, TraceEvent, TraceSink,
9};
10
/// One ranked result from [`SkillRegistry::search`].
///
/// Derives `Debug`/`Clone`/`PartialEq` so callers can log, copy, and compare
/// hits (e.g. with `assert_eq!`) like any other plain result value.
#[derive(Debug, Clone, PartialEq)]
pub struct SkillHit {
    /// Id of the matched skill (the `id` it was registered under).
    pub skill_id: String,
    /// BM25 relevance score for the query; higher means more relevant.
    pub score: f32,
}
15
16/// Retrieval index over [`Skill`]s — the on-demand analog of
17/// [`crate::ToolRegistry`]. Same BM25 engine and tuning; a parallel type keeps
18/// the tool path untouched and lets skill telemetry stand on its own.
19pub struct SkillRegistry {
20    skills: Vec<Skill>,
21    sink: Arc<dyn TraceSink>,
22}
23
24impl Default for SkillRegistry {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30impl SkillRegistry {
31    pub fn new() -> Self {
32        Self {
33            skills: Vec::new(),
34            sink: Arc::new(NoopSink),
35        }
36    }
37
38    pub fn with_trace_sink(sink: Arc<dyn TraceSink>) -> Self {
39        Self {
40            skills: Vec::new(),
41            sink,
42        }
43    }
44
45    pub fn set_trace_sink(&mut self, sink: Arc<dyn TraceSink>) {
46        self.sink = sink;
47    }
48
49    pub fn record_event(&self, event: TraceEvent) {
50        self.sink.record(event);
51    }
52
53    pub fn register(&mut self, skill: Skill) {
54        let skill_id = skill.id.clone();
55        self.skills.push(skill);
56        self.sink.record(TraceEvent::SkillChurn {
57            kind: ChurnKind::Add,
58            skill_id,
59        });
60    }
61
62    pub fn search(&self, query: &str, top_k: usize) -> Vec<SkillHit> {
63        self.search_with_origin(query, top_k, Origin::Direct)
64    }
65
66    pub fn search_with_origin(&self, query: &str, top_k: usize, origin: Origin) -> Vec<SkillHit> {
67        let started = Instant::now();
68        let hits: Vec<SkillHit> = bm25_search(
69            self.skills
70                .iter()
71                .map(|s| (s.id.clone(), searchable_text(s))),
72            query,
73            top_k,
74        )
75        .into_iter()
76        .map(|(skill_id, score)| SkillHit { skill_id, score })
77        .collect();
78        let took_ms = started.elapsed().as_millis() as u64;
79        let top_score = hits.first().map(|h| h.score as f64);
80        self.sink.record(TraceEvent::SkillSearch {
81            query: query.to_string(),
82            origin,
83            top_k: top_k as u32,
84            hits: hits
85                .iter()
86                .map(|h| SkillHitTrace {
87                    skill_id: h.skill_id.clone(),
88                    score: h.score as f64,
89                })
90                .collect(),
91            stages: vec![SearchStage {
92                name: "bm25".into(),
93                took_ms,
94                top_score,
95            }],
96            took_ms,
97        });
98        hits
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::MemorySink;

    /// Builds a minimal [`Skill`] with no tools, empty metadata and a stub body.
    fn skill(id: &str, name: &str, description: &str, tags: &[&str]) -> Skill {
        Skill {
            id: id.into(),
            name: name.into(),
            description: description.into(),
            tags: tags.iter().map(|t| (*t).into()).collect(),
            tools: vec![],
            metadata: std::collections::HashMap::new(),
            body: format!("# {name}\n\nbody"),
        }
    }

    /// Two-skill catalog (one frontend, one backend) on the default no-op sink.
    fn catalog() -> SkillRegistry {
        let mut reg = SkillRegistry::new();
        reg.register(skill(
            "frontend-slides",
            "frontend-slides",
            "Build animation-rich HTML presentations from scratch",
            &["frontend", "presentations"],
        ));
        reg.register(skill(
            "api-design",
            "api-design",
            "REST API design patterns: resource naming, status codes, pagination",
            &["backend", "api"],
        ));
        reg
    }

    #[test]
    fn search_ranks_the_relevant_skill_first() {
        let reg = catalog();
        let hits = reg.search("design a REST endpoint with pagination", 5);
        assert_eq!(
            hits.first().map(|h| h.skill_id.as_str()),
            Some("api-design")
        );
    }

    #[test]
    fn search_on_empty_registry_returns_no_hits() {
        let reg = SkillRegistry::new();
        assert!(reg.search("anything", 5).is_empty());
    }

    #[test]
    fn register_and_search_emit_trace_events() {
        let sink = Arc::new(MemorySink::new("test-session"));
        let mut reg = SkillRegistry::with_trace_sink(sink.clone());
        reg.register(skill(
            "api-design",
            "api-design",
            "REST API design",
            &["api"],
        ));
        reg.search_with_origin("api design", 5, Origin::Agent);

        let events = sink.drain();
        assert!(events.iter().any(|e| matches!(
            e.event,
            TraceEvent::SkillChurn {
                kind: ChurnKind::Add,
                ..
            }
        )));
        assert!(events.iter().any(|e| matches!(
            &e.event,
            TraceEvent::SkillSearch { origin: Origin::Agent, hits, .. } if !hits.is_empty()
        )));
    }

    #[test]
    fn search_trace_mirrors_returned_hits() {
        let sink = Arc::new(MemorySink::new("test-session"));
        let mut reg = catalog();
        // Attach after `catalog()` has registered, so only the search is traced.
        reg.set_trace_sink(sink.clone());
        let results = reg.search("REST API pagination", 3);
        let returned: Vec<&str> = results.iter().map(|h| h.skill_id.as_str()).collect();

        let events = sink.drain();
        let traced = events.iter().find_map(|e| match &e.event {
            TraceEvent::SkillSearch {
                origin: Origin::Direct,
                top_k,
                hits,
                ..
            } => Some((
                *top_k,
                hits.iter().map(|h| h.skill_id.clone()).collect::<Vec<_>>(),
            )),
            _ => None,
        });
        let (top_k, traced_ids) =
            traced.expect("`search` should emit a SkillSearch event with Origin::Direct");
        assert_eq!(top_k, 3);
        assert_eq!(traced_ids, returned);
    }
}