Skip to main content

ratel_ai_core/
registry.rs

1use std::sync::Arc;
2use std::time::Instant;
3
4use bm25::{Document, Language, SearchEngineBuilder};
5
6use crate::indexing::searchable_text;
7use crate::tool::Tool;
8use crate::trace::{
9    ChurnKind, NoopSink, Origin, SearchHitTrace, SearchStage, TraceEvent, TraceSink,
10};
11
// BM25 ranking parameters, tuned for short tool descriptions; see ADR-0004.

/// BM25 `k1`: how quickly repeated query terms stop adding to a document's score.
const BM25_K1: f32 = 0.9_f32;
/// BM25 `b`: how strongly scores are normalized by document length.
const BM25_B: f32 = 0.4_f32;
15
/// One ranked result returned by a registry search.
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, copy and compare hits.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchHit {
    /// Identifier of the matched tool.
    pub tool_id: String,
    /// Relevance score reported by the ranking engine (higher ranks earlier).
    pub score: f32,
}
20
/// In-memory collection of tools that can be ranked against a text query with BM25.
///
/// Registration and search activity is reported to a pluggable trace sink.
pub struct ToolRegistry {
    /// Registered tools, in the order they were passed to `register`.
    tools: Vec<Tool>,
    /// Destination for trace events; shared via `Arc` so callers can keep a handle to it.
    sink: Arc<dyn TraceSink>,
}
25
26impl Default for ToolRegistry {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl ToolRegistry {
33    pub fn new() -> Self {
34        Self {
35            tools: Vec::new(),
36            sink: Arc::new(NoopSink),
37        }
38    }
39
40    pub fn with_trace_sink(sink: Arc<dyn TraceSink>) -> Self {
41        Self {
42            tools: Vec::new(),
43            sink,
44        }
45    }
46
47    pub fn set_trace_sink(&mut self, sink: Arc<dyn TraceSink>) {
48        self.sink = sink;
49    }
50
51    pub fn record_event(&self, event: TraceEvent) {
52        self.sink.record(event);
53    }
54
55    pub fn register(&mut self, tool: Tool) {
56        let tool_id = tool.id.clone();
57        self.tools.push(tool);
58        self.sink.record(TraceEvent::IndexChurn {
59            kind: ChurnKind::Add,
60            tool_id,
61        });
62    }
63
64    pub fn search(&self, query: &str, top_k: usize) -> Vec<SearchHit> {
65        self.search_with_origin(query, top_k, Origin::Direct)
66    }
67
68    pub fn search_with_origin(&self, query: &str, top_k: usize, origin: Origin) -> Vec<SearchHit> {
69        let started = Instant::now();
70        let hits: Vec<SearchHit> = if self.tools.is_empty() {
71            Vec::new()
72        } else {
73            let docs = self.tools.iter().map(|t| Document {
74                id: t.id.clone(),
75                contents: searchable_text(t),
76            });
77            let engine = SearchEngineBuilder::<String>::with_documents(Language::English, docs)
78                .k1(BM25_K1)
79                .b(BM25_B)
80                .build();
81            engine
82                .search(query, top_k)
83                .into_iter()
84                .map(|r| SearchHit {
85                    tool_id: r.document.id,
86                    score: r.score,
87                })
88                .collect()
89        };
90        let took_ms = started.elapsed().as_millis() as u64;
91        let top_score = hits.first().map(|h| h.score as f64);
92        self.sink.record(TraceEvent::Search {
93            query: query.to_string(),
94            origin,
95            top_k: top_k as u32,
96            hits: hits
97                .iter()
98                .map(|h| SearchHitTrace {
99                    tool_id: h.tool_id.clone(),
100                    score: h.score as f64,
101                })
102                .collect(),
103            stages: vec![SearchStage {
104                name: "bm25".into(),
105                took_ms,
106                top_score,
107            }],
108            took_ms,
109        });
110        hits
111    }
112}