1use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use serde::Deserialize;
8use serde_json::json;
9
10use crate::auth::TenantScope;
11use crate::error::Error;
12use crate::llm::types::ToolDefinition;
13use crate::tool::{Tool, ToolOutput};
14
15use super::KnowledgeBase;
16
17pub fn knowledge_tools(kb: Arc<dyn KnowledgeBase>, scope: TenantScope) -> Vec<Arc<dyn Tool>> {
26 vec![Arc::new(KnowledgeSearchTool { kb, scope })]
27}
28
/// Number of results requested when the caller omits `limit`.
///
/// Referenced by name from the `#[serde(default = "default_limit")]`
/// attribute on `SearchInput::limit`.
fn default_limit() -> usize {
    const DEFAULT_LIMIT: usize = 5;
    DEFAULT_LIMIT
}
32
33struct KnowledgeSearchTool {
34 kb: Arc<dyn KnowledgeBase>,
35 scope: TenantScope,
36}
37
38#[derive(Deserialize)]
39struct SearchInput {
40 query: String,
41 source_filter: Option<String>,
42 #[serde(default = "default_limit")]
43 limit: usize,
44}
45
46impl Tool for KnowledgeSearchTool {
47 fn definition(&self) -> ToolDefinition {
48 ToolDefinition {
49 name: "knowledge_search".into(),
50 description: "Search the knowledge base for relevant documentation, code examples, \
51 and reference material. Use this when you need to find specific \
52 information from project docs, API references, or other indexed sources."
53 .into(),
54 input_schema: json!({
55 "type": "object",
56 "properties": {
57 "query": {
58 "type": "string",
59 "description": "Free-text search query describing what you're looking for"
60 },
61 "source_filter": {
62 "type": "string",
63 "description": "Optional URI prefix to restrict results to specific sources (e.g. 'docs/' or 'https://api.example.com')"
64 },
65 "limit": {
66 "type": "integer",
67 "minimum": 1,
68 "maximum": 20,
69 "default": 5,
70 "description": "Maximum number of results to return"
71 }
72 },
73 "required": ["query"]
74 }),
75 }
76 }
77
78 fn execute(
79 &self,
80 _ctx: &crate::ExecutionContext,
81 input: serde_json::Value,
82 ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>> {
83 Box::pin(async move {
84 let input: SearchInput =
85 serde_json::from_value(input).map_err(|e| Error::Agent(e.to_string()))?;
86
87 let limit = input.limit.clamp(1, 20);
88
89 let results = self
90 .kb
91 .search(
92 &self.scope,
93 super::KnowledgeQuery {
94 text: input.query,
95 source_filter: input.source_filter,
96 limit,
97 },
98 )
99 .await?;
100
101 if results.is_empty() {
102 return Ok(ToolOutput::success(
103 "No matching documents found in the knowledge base.",
104 ));
105 }
106
107 let formatted = results
108 .iter()
109 .enumerate()
110 .map(|(i, r)| {
111 format!(
112 "--- Result {} (source: {}, matches: {}) ---\n{}",
113 i + 1,
114 r.chunk.source.uri,
115 r.match_count,
116 r.chunk.content,
117 )
118 })
119 .collect::<Vec<_>>()
120 .join("\n\n");
121
122 Ok(ToolOutput::success(format!(
123 "Found {} result(s):\n\n{}",
124 results.len(),
125 formatted,
126 )))
127 })
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::knowledge::in_memory::InMemoryKnowledgeBase;
    use crate::knowledge::{Chunk, DocumentSource};

    /// Tenant scope used by every test: the default scope.
    fn s() -> TenantScope {
        TenantScope::default()
    }

    /// Creates an in-memory knowledge base together with the tools built on it.
    fn setup() -> (Arc<dyn KnowledgeBase>, Vec<Arc<dyn Tool>>) {
        let kb: Arc<dyn KnowledgeBase> = Arc::new(InMemoryKnowledgeBase::new());
        let tools = knowledge_tools(Arc::clone(&kb), s());
        (kb, tools)
    }

    /// Returns the tool whose definition carries `name`, panicking if absent.
    fn find_tool<'a>(tools: &'a [Arc<dyn Tool>], name: &str) -> &'a Arc<dyn Tool> {
        for tool in tools {
            if tool.definition().name == name {
                return tool;
            }
        }
        panic!("tool {name} not found")
    }

    /// Builds a chunk with `chunk_index` 0 and no tenant id.
    fn chunk(
        id: impl Into<String>,
        content: impl Into<String>,
        uri: &'static str,
        title: &'static str,
    ) -> Chunk {
        Chunk {
            id: id.into(),
            content: content.into(),
            source: DocumentSource {
                uri: uri.into(),
                title: title.into(),
            },
            chunk_index: 0,
            tenant_id: None,
        }
    }

    /// Indexes `chunk` under the test scope, panicking on failure.
    async fn index_chunk(kb: &Arc<dyn KnowledgeBase>, chunk: Chunk) {
        kb.index(&s(), chunk).await.unwrap();
    }

    /// Executes the `knowledge_search` tool with `input` and a default context.
    async fn run_search(
        tools: &[Arc<dyn Tool>],
        input: serde_json::Value,
    ) -> Result<ToolOutput, Error> {
        find_tool(tools, "knowledge_search")
            .execute(&crate::ExecutionContext::default(), input)
            .await
    }

    #[test]
    fn creates_one_tool() {
        let (_kb, tools) = setup();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].definition().name, "knowledge_search");
    }

    #[test]
    fn tool_definition_has_valid_schema() {
        let (_kb, tools) = setup();
        let def = tools[0].definition();
        assert!(!def.name.is_empty());
        assert!(!def.description.is_empty());

        let schema = &def.input_schema;
        assert!(schema.is_object());
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["query"].is_object());

        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&json!("query")));
    }

    #[tokio::test]
    async fn search_returns_formatted_results() {
        let (kb, tools) = setup();
        index_chunk(
            &kb,
            chunk(
                "c1",
                "Rust provides memory safety without garbage collection.",
                "docs/rust.md",
                "Rust Guide",
            ),
        )
        .await;

        let output = run_search(&tools, json!({"query": "rust memory"}))
            .await
            .unwrap();

        assert!(!output.is_error);
        assert!(output.content.contains("Found 1 result"));
        assert!(output.content.contains("docs/rust.md"));
        assert!(output.content.contains("memory safety"));
    }

    #[tokio::test]
    async fn search_empty_results_returns_message() {
        let (_kb, tools) = setup();

        let output = run_search(&tools, json!({"query": "nonexistent topic xyz"}))
            .await
            .unwrap();

        assert!(!output.is_error);
        assert!(output.content.contains("No matching documents"));
    }

    #[tokio::test]
    async fn search_with_source_filter() {
        let (kb, tools) = setup();
        index_chunk(&kb, chunk("c1", "Rust API reference", "api/rust.md", "API")).await;
        index_chunk(
            &kb,
            chunk("c2", "Rust tutorial docs", "docs/tutorial.md", "Tutorial"),
        )
        .await;

        let output = run_search(&tools, json!({"query": "rust", "source_filter": "api/"}))
            .await
            .unwrap();

        assert!(!output.is_error);
        assert!(output.content.contains("api/rust.md"));
        assert!(!output.content.contains("docs/tutorial.md"));
    }

    #[tokio::test]
    async fn search_with_limit() {
        let (kb, tools) = setup();
        for i in 0..10 {
            let doc = Chunk {
                chunk_index: i,
                ..chunk(
                    format!("c{i}"),
                    format!("Rust document {i}"),
                    "docs/rust.md",
                    "Rust",
                )
            };
            index_chunk(&kb, doc).await;
        }

        let output = run_search(&tools, json!({"query": "rust", "limit": 3}))
            .await
            .unwrap();

        assert!(!output.is_error);
        assert!(output.content.contains("Found 3 result"));
    }

    #[tokio::test]
    async fn search_rejects_missing_query() {
        let (_kb, tools) = setup();
        let outcome = run_search(&tools, json!({})).await;
        assert!(outcome.is_err(), "should fail on missing required 'query'");
    }

    #[tokio::test]
    async fn search_default_limit_is_five() {
        let (kb, tools) = setup();
        for i in 0..10 {
            let doc = Chunk {
                chunk_index: i,
                ..chunk(format!("c{i}"), format!("Rust item {i}"), "f.md", "F")
            };
            index_chunk(&kb, doc).await;
        }

        let output = run_search(&tools, json!({"query": "rust"}))
            .await
            .unwrap();

        assert!(!output.is_error);
        assert!(output.content.contains("Found 5 result"));
    }
}