1use async_trait::async_trait;
9use std::sync::Arc;
10
11use super::{Tool, ToolResult};
12use crate::agent::capability::Capability;
13use crate::agent::driver::ToolDefinition;
14use crate::oracle::rag::RagOracle;
15
16pub struct RagTool {
18 oracle: Arc<RagOracle>,
19 max_results: usize,
20}
21
22impl RagTool {
23 pub fn new(oracle: Arc<RagOracle>, max_results: usize) -> Self {
25 Self { oracle, max_results }
26 }
27}
28
29#[async_trait]
30impl Tool for RagTool {
31 fn name(&self) -> &'static str {
32 "rag"
33 }
34
35 fn definition(&self) -> ToolDefinition {
36 ToolDefinition {
37 name: "rag".into(),
38 description: "Search indexed Sovereign AI Stack documentation".into(),
39 input_schema: serde_json::json!({
40 "type": "object",
41 "properties": {
42 "query": {
43 "type": "string",
44 "description": "Search query for documentation"
45 }
46 },
47 "required": ["query"]
48 }),
49 }
50 }
51
52 async fn execute(&self, input: serde_json::Value) -> ToolResult {
53 let Some(query) = input.get("query").and_then(|q| q.as_str()) else {
54 return ToolResult::error("missing required field: query");
55 };
56
57 let results = self.oracle.query(query);
58 let truncated: Vec<_> = results.into_iter().take(self.max_results).collect();
59
60 if truncated.is_empty() {
61 return ToolResult::success("No results found for the given query.");
62 }
63
64 let formatted = format_results(&truncated);
65 ToolResult::success(formatted)
66 }
67
68 fn required_capability(&self) -> Capability {
69 Capability::Rag
70 }
71
72 fn timeout(&self) -> std::time::Duration {
73 std::time::Duration::from_secs(120)
74 }
75}
76
77fn format_results(results: &[crate::oracle::rag::RetrievalResult]) -> String {
79 use std::fmt::Write;
80 let mut out = String::with_capacity(results.len() * 256);
81 for (i, r) in results.iter().enumerate() {
82 let _ = writeln!(out, "### Result {} (score: {:.3})", i + 1, r.score);
83 let _ = write!(
84 out,
85 "**Source:** {} ({}:{}–{})\n\n",
86 r.source, r.component, r.start_line, r.end_line
87 );
88 out.push_str(&r.content);
89 out.push_str("\n\n---\n\n");
90 }
91 out
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oracle::rag::{RetrievalResult, ScoreBreakdown};

    /// Builds a result with a fixed score (0.5) and line range (1–10) whose
    /// identifying text fields vary per call.
    fn sample(id: &'static str, source: &'static str, content: &'static str) -> RetrievalResult {
        RetrievalResult {
            id: id.into(),
            component: "trueno".into(),
            source: source.into(),
            content: content.into(),
            score: 0.5,
            start_line: 1,
            end_line: 10,
            score_breakdown: ScoreBreakdown {
                bm25_score: 0.25,
                dense_score: 0.25,
                rrf_score: 0.5,
                rerank_score: None,
            },
        }
    }

    #[test]
    fn test_format_results_empty() {
        let results = vec![];
        assert_eq!(format_results(&results), "");
    }

    #[test]
    fn test_format_results_single() {
        let results = vec![RetrievalResult {
            id: "doc-1".into(),
            component: "trueno".into(),
            source: "src/lib.rs".into(),
            content: "SIMD compute primitives".into(),
            score: 0.95,
            start_line: 1,
            end_line: 10,
            score_breakdown: ScoreBreakdown {
                bm25_score: 0.5,
                dense_score: 0.45,
                rrf_score: 0.95,
                rerank_score: None,
            },
        }];
        // Pin the exact rendering (heading, source line, content, separator)
        // rather than only checking substrings.
        assert_eq!(
            format_results(&results),
            "### Result 1 (score: 0.950)\n\
             **Source:** src/lib.rs (trueno:1–10)\n\n\
             SIMD compute primitives\n\n---\n\n"
        );
    }

    #[test]
    fn test_format_results_multiple_preserves_order() {
        let results = vec![
            sample("doc-1", "src/a.rs", "first chunk"),
            sample("doc-2", "src/b.rs", "second chunk"),
        ];
        let formatted = format_results(&results);

        assert!(formatted.starts_with("### Result 1 (score: 0.500)\n"));
        assert!(formatted.contains("**Source:** src/b.rs (trueno:1–10)"));

        let first_content = formatted.find("first chunk").expect("first hit rendered");
        let second_heading = formatted.find("### Result 2").expect("second hit numbered");
        let second_content = formatted.find("second chunk").expect("second hit rendered");
        assert!(first_content < second_heading && second_heading < second_content);

        // One separator per hit, including after the last one.
        assert_eq!(formatted.matches("\n---\n").count(), 2);
        assert!(formatted.ends_with("\n\n---\n\n"));
    }

    #[test]
    fn test_rag_tool_metadata() {
        // NOTE(review): this compares a constant with itself and never calls
        // `RagTool::required_capability`. Replace it once a `RagTool` can be
        // built cheaply in tests (construction needs an `Arc<RagOracle>`).
        assert_eq!(Capability::Rag, Capability::Rag, "Rag capability match");
    }

    #[test]
    fn test_tool_definition_schema() {
        // NOTE(review): this checks a local copy of the schema, not
        // `RagTool::definition`, so it cannot detect drift between the two.
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query for documentation"
                }
            },
            "required": ["query"]
        });
        assert!(schema.get("properties").is_some());
        assert!(schema.get("required").is_some());
    }
}