ratel_ai_core/
registry.rs1use std::sync::Arc;
2use std::time::Instant;
3
4use bm25::{Document, Language, SearchEngineBuilder};
5
6use crate::indexing::searchable_text;
7use crate::tool::Tool;
8use crate::trace::{
9 ChurnKind, NoopSink, Origin, SearchHitTrace, SearchStage, TraceEvent, TraceSink,
10};
11
/// BM25 `k1` parameter handed to the search engine builder.
/// (In the usual BM25 formulation `k1` governs term-frequency saturation.)
const BM25_K1: f32 = 0.9;
/// BM25 `b` parameter handed to the search engine builder.
/// (In the usual BM25 formulation `b` governs document-length normalisation.)
const BM25_B: f32 = 0.4;
15
/// A single search result: which tool matched and how strongly.
///
/// Derives `Debug`, `Clone` and `PartialEq` so hits can be logged, copied and
/// compared in tests. `Eq`/`Hash` are intentionally absent because `score` is
/// a float.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchHit {
    /// Identifier of the matched tool.
    pub tool_id: String,
    /// Score reported by the search engine for this tool.
    pub score: f32,
}
20
/// Holds the registered tools together with the sink that receives trace events.
pub struct ToolRegistry {
    /// Registered tools, in the order they were passed to `register`.
    tools: Vec<Tool>,
    /// Destination for every trace event the registry records.
    sink: Arc<dyn TraceSink>,
}
25
26impl Default for ToolRegistry {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl ToolRegistry {
33 pub fn new() -> Self {
34 Self {
35 tools: Vec::new(),
36 sink: Arc::new(NoopSink),
37 }
38 }
39
40 pub fn with_trace_sink(sink: Arc<dyn TraceSink>) -> Self {
41 Self {
42 tools: Vec::new(),
43 sink,
44 }
45 }
46
47 pub fn set_trace_sink(&mut self, sink: Arc<dyn TraceSink>) {
48 self.sink = sink;
49 }
50
51 pub fn record_event(&self, event: TraceEvent) {
52 self.sink.record(event);
53 }
54
55 pub fn register(&mut self, tool: Tool) {
56 let tool_id = tool.id.clone();
57 self.tools.push(tool);
58 self.sink.record(TraceEvent::IndexChurn {
59 kind: ChurnKind::Add,
60 tool_id,
61 });
62 }
63
64 pub fn search(&self, query: &str, top_k: usize) -> Vec<SearchHit> {
65 self.search_with_origin(query, top_k, Origin::Direct)
66 }
67
68 pub fn search_with_origin(&self, query: &str, top_k: usize, origin: Origin) -> Vec<SearchHit> {
69 let started = Instant::now();
70 let hits: Vec<SearchHit> = if self.tools.is_empty() {
71 Vec::new()
72 } else {
73 let docs = self.tools.iter().map(|t| Document {
74 id: t.id.clone(),
75 contents: searchable_text(t),
76 });
77 let engine = SearchEngineBuilder::<String>::with_documents(Language::English, docs)
78 .k1(BM25_K1)
79 .b(BM25_B)
80 .build();
81 engine
82 .search(query, top_k)
83 .into_iter()
84 .map(|r| SearchHit {
85 tool_id: r.document.id,
86 score: r.score,
87 })
88 .collect()
89 };
90 let took_ms = started.elapsed().as_millis() as u64;
91 let top_score = hits.first().map(|h| h.score as f64);
92 self.sink.record(TraceEvent::Search {
93 query: query.to_string(),
94 origin,
95 top_k: top_k as u32,
96 hits: hits
97 .iter()
98 .map(|h| SearchHitTrace {
99 tool_id: h.tool_id.clone(),
100 score: h.score as f64,
101 })
102 .collect(),
103 stages: vec![SearchStage {
104 name: "bm25".into(),
105 took_ms,
106 top_score,
107 }],
108 took_ms,
109 });
110 hits
111 }
112}