Skip to main content

heartbit_core/knowledge/
tools.rs

1//! Knowledge base tool definitions for agent document retrieval.
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use serde::Deserialize;
8use serde_json::json;
9
10use crate::auth::TenantScope;
11use crate::error::Error;
12use crate::llm::types::ToolDefinition;
13use crate::tool::{Tool, ToolOutput};
14
15use super::KnowledgeBase;
16
17/// Create knowledge tools for agent access to the knowledge base.
18///
19/// Returns 1 tool:
20/// - `knowledge_search` — search the knowledge base for relevant documentation
21///
22/// SECURITY (F-KB-1): the `scope` is baked into the tool so every search is
23/// filtered by tenant. A shared `Arc<dyn KnowledgeBase>` across tenants would
24/// otherwise leak documents cross-tenant via `knowledge_search`.
25pub fn knowledge_tools(kb: Arc<dyn KnowledgeBase>, scope: TenantScope) -> Vec<Arc<dyn Tool>> {
26    vec![Arc::new(KnowledgeSearchTool { kb, scope })]
27}
28
/// Number of results requested when the caller omits `limit`.
///
/// Referenced by name from the `#[serde(default = "default_limit")]`
/// attribute on `SearchInput::limit`.
fn default_limit() -> usize {
    const DEFAULT_LIMIT: usize = 5;
    DEFAULT_LIMIT
}
32
33struct KnowledgeSearchTool {
34    kb: Arc<dyn KnowledgeBase>,
35    scope: TenantScope,
36}
37
38#[derive(Deserialize)]
39struct SearchInput {
40    query: String,
41    source_filter: Option<String>,
42    #[serde(default = "default_limit")]
43    limit: usize,
44}
45
46impl Tool for KnowledgeSearchTool {
47    fn definition(&self) -> ToolDefinition {
48        ToolDefinition {
49            name: "knowledge_search".into(),
50            description: "Search the knowledge base for relevant documentation, code examples, \
51                          and reference material. Use this when you need to find specific \
52                          information from project docs, API references, or other indexed sources."
53                .into(),
54            input_schema: json!({
55                "type": "object",
56                "properties": {
57                    "query": {
58                        "type": "string",
59                        "description": "Free-text search query describing what you're looking for"
60                    },
61                    "source_filter": {
62                        "type": "string",
63                        "description": "Optional URI prefix to restrict results to specific sources (e.g. 'docs/' or 'https://api.example.com')"
64                    },
65                    "limit": {
66                        "type": "integer",
67                        "minimum": 1,
68                        "maximum": 20,
69                        "default": 5,
70                        "description": "Maximum number of results to return"
71                    }
72                },
73                "required": ["query"]
74            }),
75        }
76    }
77
78    fn execute(
79        &self,
80        _ctx: &crate::ExecutionContext,
81        input: serde_json::Value,
82    ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>> {
83        Box::pin(async move {
84            let input: SearchInput =
85                serde_json::from_value(input).map_err(|e| Error::Agent(e.to_string()))?;
86
87            let limit = input.limit.clamp(1, 20);
88
89            let results = self
90                .kb
91                .search(
92                    &self.scope,
93                    super::KnowledgeQuery {
94                        text: input.query,
95                        source_filter: input.source_filter,
96                        limit,
97                    },
98                )
99                .await?;
100
101            if results.is_empty() {
102                return Ok(ToolOutput::success(
103                    "No matching documents found in the knowledge base.",
104                ));
105            }
106
107            let formatted = results
108                .iter()
109                .enumerate()
110                .map(|(i, r)| {
111                    format!(
112                        "--- Result {} (source: {}, matches: {}) ---\n{}",
113                        i + 1,
114                        r.chunk.source.uri,
115                        r.match_count,
116                        r.chunk.content,
117                    )
118                })
119                .collect::<Vec<_>>()
120                .join("\n\n");
121
122            Ok(ToolOutput::success(format!(
123                "Found {} result(s):\n\n{}",
124                results.len(),
125                formatted,
126            )))
127        })
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::knowledge::in_memory::InMemoryKnowledgeBase;
    use crate::knowledge::{Chunk, DocumentSource};

    /// Tenant scope used for both indexing and searching in every test.
    fn s() -> TenantScope {
        TenantScope::default()
    }

    /// Fresh in-memory knowledge base plus the tools built on top of it.
    fn setup() -> (Arc<dyn KnowledgeBase>, Vec<Arc<dyn Tool>>) {
        let kb: Arc<dyn KnowledgeBase> = Arc::new(InMemoryKnowledgeBase::new());
        let tools = knowledge_tools(kb.clone(), s());
        (kb, tools)
    }

    /// Look a tool up by name, panicking with the name when it is absent.
    fn find_tool<'a>(tools: &'a [Arc<dyn Tool>], name: &str) -> &'a Arc<dyn Tool> {
        tools
            .iter()
            .find(|t| t.definition().name == name)
            .unwrap_or_else(|| panic!("tool {name} not found"))
    }

    /// Build a chunk with `chunk_index` 0 and no tenant id; tests that need a
    /// different index override it with struct-update syntax.
    fn chunk(id: &str, content: &str, uri: &str, title: &str) -> Chunk {
        Chunk {
            id: id.into(),
            content: content.into(),
            source: DocumentSource {
                uri: uri.into(),
                title: title.into(),
            },
            chunk_index: 0,
            tenant_id: None,
        }
    }

    /// Index `chunk` under the shared test scope, panicking on failure.
    async fn index_chunk(kb: &Arc<dyn KnowledgeBase>, chunk: Chunk) {
        kb.index(&s(), chunk).await.unwrap();
    }

    /// Execute `knowledge_search` with `input` and unwrap the tool output.
    async fn run_search(tools: &[Arc<dyn Tool>], input: serde_json::Value) -> ToolOutput {
        find_tool(tools, "knowledge_search")
            .execute(&crate::ExecutionContext::default(), input)
            .await
            .unwrap()
    }

    #[test]
    fn creates_one_tool() {
        let (_kb, tools) = setup();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].definition().name, "knowledge_search");
    }

    #[test]
    fn tool_definition_has_valid_schema() {
        let (_kb, tools) = setup();
        let def = tools[0].definition();
        assert!(!def.name.is_empty());
        assert!(!def.description.is_empty());
        assert!(def.input_schema.is_object());
        assert_eq!(def.input_schema["type"], "object");
        assert!(def.input_schema["properties"]["query"].is_object());
        let required = def.input_schema["required"].as_array().unwrap();
        assert!(required.contains(&json!("query")));
    }

    #[tokio::test]
    async fn search_returns_formatted_results() {
        let (kb, tools) = setup();
        index_chunk(
            &kb,
            chunk(
                "c1",
                "Rust provides memory safety without garbage collection.",
                "docs/rust.md",
                "Rust Guide",
            ),
        )
        .await;

        let result = run_search(&tools, json!({"query": "rust memory"})).await;

        assert!(!result.is_error);
        assert!(result.content.contains("Found 1 result"));
        assert!(result.content.contains("docs/rust.md"));
        assert!(result.content.contains("memory safety"));
    }

    #[tokio::test]
    async fn search_empty_results_returns_message() {
        let (_kb, tools) = setup();

        let result = run_search(&tools, json!({"query": "nonexistent topic xyz"})).await;

        assert!(!result.is_error);
        assert!(result.content.contains("No matching documents"));
    }

    #[tokio::test]
    async fn search_with_source_filter() {
        let (kb, tools) = setup();
        index_chunk(&kb, chunk("c1", "Rust API reference", "api/rust.md", "API")).await;
        index_chunk(
            &kb,
            chunk("c2", "Rust tutorial docs", "docs/tutorial.md", "Tutorial"),
        )
        .await;

        let result = run_search(&tools, json!({"query": "rust", "source_filter": "api/"})).await;

        assert!(!result.is_error);
        assert!(result.content.contains("api/rust.md"));
        assert!(!result.content.contains("docs/tutorial.md"));
    }

    #[tokio::test]
    async fn search_with_limit() {
        let (kb, tools) = setup();
        for i in 0..10 {
            let numbered = Chunk {
                chunk_index: i,
                ..chunk(
                    &format!("c{i}"),
                    &format!("Rust document {i}"),
                    "docs/rust.md",
                    "Rust",
                )
            };
            index_chunk(&kb, numbered).await;
        }

        let result = run_search(&tools, json!({"query": "rust", "limit": 3})).await;

        assert!(!result.is_error);
        assert!(result.content.contains("Found 3 result"));
    }

    #[tokio::test]
    async fn search_rejects_missing_query() {
        let (_kb, tools) = setup();
        let search = find_tool(&tools, "knowledge_search");
        let result = search
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await;
        assert!(result.is_err(), "should fail on missing required 'query'");
    }

    #[tokio::test]
    async fn search_default_limit_is_five() {
        let (kb, tools) = setup();
        for i in 0..10 {
            let numbered = Chunk {
                chunk_index: i,
                ..chunk(&format!("c{i}"), &format!("Rust item {i}"), "f.md", "F")
            };
            index_chunk(&kb, numbered).await;
        }

        let result = run_search(&tools, json!({"query": "rust"})).await;

        assert!(!result.is_error);
        assert!(result.content.contains("Found 5 result"));
    }

    /// Out-of-range limits are clamped to 1..=20 rather than rejected or
    /// passed through to the knowledge base unchanged.
    #[tokio::test]
    async fn search_clamps_limit_to_supported_range() {
        let (kb, tools) = setup();
        for i in 0..25 {
            let numbered = Chunk {
                chunk_index: i,
                ..chunk(
                    &format!("c{i}"),
                    &format!("Rust note {i}"),
                    "notes.md",
                    "Notes",
                )
            };
            index_chunk(&kb, numbered).await;
        }

        let over = run_search(&tools, json!({"query": "rust", "limit": 100})).await;
        assert!(!over.is_error);
        assert!(over.content.contains("Found 20 result"));

        let under = run_search(&tools, json!({"query": "rust", "limit": 0})).await;
        assert!(!under.is_error);
        assert!(under.content.contains("Found 1 result"));
    }
}