Skip to main content

batuta/agent/tool/
rag.rs

//! RAG tool — wraps `oracle::rag::RagOracle` for document retrieval.
//!
//! Gives agents access to the indexed Sovereign AI Stack documentation
//! through the existing RAG pipeline (hybrid BM25 + dense retrieval, fused
//! with reciprocal rank fusion, RRF).
//!
//! Feature-gated: requires the `rag` feature for `RagOracle` access.
7
8use async_trait::async_trait;
9use std::sync::Arc;
10
11use super::{Tool, ToolResult};
12use crate::agent::capability::Capability;
13use crate::agent::driver::ToolDefinition;
14use crate::oracle::rag::RagOracle;
15
16/// Tool that wraps `RagOracle` for agent document retrieval.
17pub struct RagTool {
18    oracle: Arc<RagOracle>,
19    max_results: usize,
20}
21
22impl RagTool {
23    /// Create a new RAG tool wrapping an existing oracle instance.
24    pub fn new(oracle: Arc<RagOracle>, max_results: usize) -> Self {
25        Self { oracle, max_results }
26    }
27}
28
29#[async_trait]
30impl Tool for RagTool {
31    fn name(&self) -> &'static str {
32        "rag"
33    }
34
35    fn definition(&self) -> ToolDefinition {
36        ToolDefinition {
37            name: "rag".into(),
38            description: "Search indexed Sovereign AI Stack documentation".into(),
39            input_schema: serde_json::json!({
40                "type": "object",
41                "properties": {
42                    "query": {
43                        "type": "string",
44                        "description": "Search query for documentation"
45                    }
46                },
47                "required": ["query"]
48            }),
49        }
50    }
51
52    async fn execute(&self, input: serde_json::Value) -> ToolResult {
53        let Some(query) = input.get("query").and_then(|q| q.as_str()) else {
54            return ToolResult::error("missing required field: query");
55        };
56
57        let results = self.oracle.query(query);
58        let truncated: Vec<_> = results.into_iter().take(self.max_results).collect();
59
60        if truncated.is_empty() {
61            return ToolResult::success("No results found for the given query.");
62        }
63
64        let formatted = format_results(&truncated);
65        ToolResult::success(formatted)
66    }
67
68    fn required_capability(&self) -> Capability {
69        Capability::Rag
70    }
71
72    fn timeout(&self) -> std::time::Duration {
73        std::time::Duration::from_secs(120)
74    }
75}
76
77/// Format retrieval results as markdown for LLM consumption.
78fn format_results(results: &[crate::oracle::rag::RetrievalResult]) -> String {
79    use std::fmt::Write;
80    let mut out = String::with_capacity(results.len() * 256);
81    for (i, r) in results.iter().enumerate() {
82        let _ = writeln!(out, "### Result {} (score: {:.3})", i + 1, r.score);
83        let _ = write!(
84            out,
85            "**Source:** {} ({}:{}–{})\n\n",
86            r.source, r.component, r.start_line, r.end_line
87        );
88        out.push_str(&r.content);
89        out.push_str("\n\n---\n\n");
90    }
91    out
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oracle::rag::{RetrievalResult, ScoreBreakdown};

    /// Separator `format_results` appends after every result.
    const SEPARATOR: &str = "\n\n---\n\n";

    /// Build a fully populated result; only the textual fields vary between
    /// callers, the score and line range are fixed.
    fn sample_result(
        id: &'static str,
        component: &'static str,
        source: &'static str,
        content: &'static str,
    ) -> RetrievalResult {
        RetrievalResult {
            id: id.into(),
            component: component.into(),
            source: source.into(),
            content: content.into(),
            score: 0.95,
            start_line: 1,
            end_line: 10,
            score_breakdown: ScoreBreakdown {
                bm25_score: 0.5,
                dense_score: 0.45,
                rrf_score: 0.95,
                rerank_score: None,
            },
        }
    }

    /// No hits must render as an empty string (no stray heading or separator).
    #[test]
    fn test_format_results_empty() {
        let results: Vec<RetrievalResult> = Vec::new();
        assert_eq!(format_results(&results), "");
    }

    /// A single hit renders the heading, the source line and the content,
    /// and ends with the separator.
    #[test]
    fn test_format_results_single() {
        let results = vec![sample_result(
            "doc-1",
            "trueno",
            "src/lib.rs",
            "SIMD compute primitives",
        )];
        let formatted = format_results(&results);

        assert!(formatted.contains("Result 1"));
        assert!(formatted.contains("0.950"));
        assert!(formatted.contains("trueno"));
        assert!(formatted.contains("SIMD compute primitives"));

        // Pin the exact layout, not just the presence of fragments.
        assert!(formatted.starts_with("### Result 1 (score: 0.950)\n"));
        assert!(formatted.contains("**Source:** src/lib.rs (trueno:1–10)\n\n"));
        assert!(formatted.ends_with(SEPARATOR));
    }

    /// Several hits are numbered from 1 in input order, each followed by
    /// exactly one separator.
    #[test]
    fn test_format_results_multiple() {
        let results = vec![
            sample_result("doc-1", "trueno", "src/lib.rs", "first chunk"),
            sample_result("doc-2", "batuta", "README.md", "second chunk"),
        ];
        let formatted = format_results(&results);

        let first = formatted.find("### Result 1 ").expect("first heading present");
        let second = formatted.find("### Result 2 ").expect("second heading present");
        assert!(first < second, "results must keep input order");
        assert!(!formatted.contains("### Result 3"));
        assert_eq!(formatted.matches(SEPARATOR).count(), 2);
        assert!(formatted.contains("first chunk"));
        assert!(formatted.contains("second chunk"));
    }

    #[test]
    fn test_rag_tool_metadata() {
        // Cannot construct RagOracle without a full index, so `RagTool`
        // itself is not exercised here.
        // NOTE(review): this compares a value with itself and therefore only
        // checks that `Capability::Rag` exists and is comparable; replace it
        // once a test oracle can be constructed.
        assert_eq!(Capability::Rag, Capability::Rag, "Rag capability match");
    }

    #[test]
    fn test_tool_definition_schema() {
        // NOTE(review): this validates a local copy of the schema, not the
        // value returned by `RagTool::definition()`, because a `RagTool`
        // cannot be built without a `RagOracle`.
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query for documentation"
                }
            },
            "required": ["query"]
        });
        assert!(schema.get("properties").is_some());
        assert!(schema.get("required").is_some());
    }
}