Skip to main content

agent_knowledge_base/
lib.rs

/*!
agent-knowledge-base: a small keyword-searchable knowledge base for LLM agents.

Entries are stored by id, optionally with tags. [`KnowledgeBase::search`] scores
each entry by the fraction of query keywords that occur, case-insensitively, in
its text, and returns the matches best-first.

```rust
use agent_knowledge_base::KnowledgeBase;

let mut kb = KnowledgeBase::new();
kb.add("rust_overview", "Rust is a systems language focused on safety and performance.");
kb.add("python_overview", "Python is a high-level interpreted language.");
let results = kb.search("rust safety");
assert!(!results.is_empty());
assert!(results[0].1.contains("Rust"));
```
*/
15
16use std::collections::HashMap;
17
/// A single knowledge base entry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Entry {
    /// Identifier; also the key the entry is stored under in [`KnowledgeBase`].
    pub id: String,
    /// Free-form text that [`KnowledgeBase::search`] matches keywords against.
    pub text: String,
    /// Labels matched exactly by [`KnowledgeBase::by_tag`]; empty for entries
    /// added with [`KnowledgeBase::add`].
    pub tags: Vec<String>,
}

/// Simple keyword-scored knowledge base.
///
/// Entries live in a `HashMap` keyed by id. There is no inverted index: every
/// [`search`](KnowledgeBase::search) scans all entries.
#[derive(Debug, Clone, Default)]
pub struct KnowledgeBase {
    entries: HashMap<String, Entry>,
}

impl KnowledgeBase {
    /// Creates an empty knowledge base.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds an untagged entry, replacing any existing entry (text and tags)
    /// stored under the same `id`.
    pub fn add(&mut self, id: &str, text: &str) {
        self.add_tagged(id, text, &[]);
    }

    /// Adds an entry with `tags`, replacing any existing entry stored under
    /// the same `id`.
    pub fn add_tagged(&mut self, id: &str, text: &str, tags: &[&str]) {
        let entry = Entry {
            id: id.to_string(),
            text: text.to_string(),
            tags: tags.iter().map(|&t| t.to_owned()).collect(),
        };
        self.entries.insert(id.to_string(), entry);
    }

    /// Returns the entry stored under `id`, if any.
    pub fn get(&self, id: &str) -> Option<&Entry> {
        self.entries.get(id)
    }

    /// Removes and returns the entry stored under `id`, if any.
    pub fn remove(&mut self, id: &str) -> Option<Entry> {
        self.entries.remove(id)
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Searches by whitespace-separated keywords.
    ///
    /// Matching is case-insensitive *substring* matching. A keyword matches an
    /// entry when it occurs anywhere in the lowercased text, so `"rust"` also
    /// matches `"trust"`. A keyword repeated in the query is counted once.
    ///
    /// Returns `(id, text, score)` for every entry matching at least one
    /// keyword. `score` is the fraction of distinct query keywords found,
    /// in `(0.0, 1.0]`. Results are sorted by score descending, then by id
    /// ascending, so equal scores come back in a deterministic order. An
    /// empty or whitespace-only query returns an empty vector.
    pub fn search(&self, query: &str) -> Vec<(&str, &str, f64)> {
        let mut keywords: Vec<String> = query.split_whitespace().map(str::to_lowercase).collect();
        // Deduplicate so a repeated keyword cannot outweigh the others.
        // Keyword order is irrelevant to scoring, so sorting is harmless.
        keywords.sort_unstable();
        keywords.dedup();
        if keywords.is_empty() {
            return Vec::new();
        }
        let total = keywords.len() as f64;

        let mut scored: Vec<(&str, &str, f64)> = self
            .entries
            .values()
            .filter_map(|e| {
                let text_lower = e.text.to_lowercase();
                let matched = keywords
                    .iter()
                    .filter(|k| text_lower.contains(k.as_str()))
                    .count();
                (matched > 0).then(|| (e.id.as_str(), e.text.as_str(), matched as f64 / total))
            })
            .collect();

        // Scores are finite (`total >= 1`), so `total_cmp` yields the same
        // order as `partial_cmp` without an `unwrap` panic path.
        scored.sort_by(|a, b| b.2.total_cmp(&a.2).then_with(|| a.0.cmp(b.0)));
        scored
    }

    /// Returns all entries carrying exactly `tag` (case-sensitive), sorted by id.
    pub fn by_tag(&self, tag: &str) -> Vec<&Entry> {
        let mut v: Vec<&Entry> = self
            .entries
            .values()
            .filter(|e| e.tags.iter().any(|t| t == tag))
            .collect();
        // Ids are unique map keys, so an unstable sort is deterministic.
        v.sort_unstable_by(|a, b| a.id.cmp(&b.id));
        v
    }

    /// All entry ids, sorted ascending.
    pub fn ids(&self) -> Vec<&str> {
        let mut v: Vec<&str> = self.entries.keys().map(String::as_str).collect();
        v.sort_unstable();
        v
    }
}
94
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn add_and_get() {
        let mut kb = KnowledgeBase::new();
        kb.add("doc1", "Rust is safe");
        assert_eq!(kb.get("doc1").unwrap().text, "Rust is safe");
    }

    #[test]
    fn search_finds_match() {
        let mut kb = KnowledgeBase::new();
        kb.add("doc1", "Rust is a systems programming language");
        let r = kb.search("rust");
        assert!(!r.is_empty());
    }

    #[test]
    fn search_returns_empty_for_no_match() {
        let mut kb = KnowledgeBase::new();
        kb.add("doc1", "Rust programming");
        let r = kb.search("python");
        assert!(r.is_empty());
    }

    #[test]
    fn search_scores_multi_keyword() {
        let mut kb = KnowledgeBase::new();
        kb.add("doc1", "rust safety performance");
        kb.add("doc2", "rust language");
        let r = kb.search("rust safety");
        // doc1 matches both keywords, doc2 matches one.
        assert_eq!(r.len(), 2);
        assert_eq!(r[0].0, "doc1");
    }

    #[test]
    fn search_partial_match_scores_half() {
        let mut kb = KnowledgeBase::new();
        kb.add("doc1", "rust safety performance");
        kb.add("doc2", "rust language");
        let r = kb.search("rust safety");
        assert_eq!(r.len(), 2);
        assert_eq!(r[1].0, "doc2");
        assert!((r[1].2 - 0.5).abs() < 1e-9);
    }

    #[test]
    fn search_ties_broken_by_id() {
        let mut kb = KnowledgeBase::new();
        kb.add("b", "rust");
        kb.add("a", "rust");
        kb.add("c", "rust");
        let hits = kb.search("rust");
        let ids: Vec<&str> = hits.iter().map(|h| h.0).collect();
        assert_eq!(ids, vec!["a", "b", "c"]);
    }

    #[test]
    fn case_insensitive_search() {
        let mut kb = KnowledgeBase::new();
        kb.add("doc1", "Rust Is Great");
        let r = kb.search("rust");
        assert!(!r.is_empty());
    }

    #[test]
    fn add_tagged_and_by_tag() {
        let mut kb = KnowledgeBase::new();
        kb.add_tagged("doc1", "content", &["lang", "systems"]);
        kb.add_tagged("doc2", "content", &["scripting"]);
        assert_eq!(kb.by_tag("lang").len(), 1);
        assert_eq!(kb.by_tag("scripting").len(), 1);
    }

    #[test]
    fn by_tag_sorted_by_id_and_excludes_untagged() {
        let mut kb = KnowledgeBase::new();
        kb.add_tagged("z", "t", &["lang"]);
        kb.add_tagged("m", "t", &["lang"]);
        kb.add("a", "t");
        let tagged = kb.by_tag("lang");
        let ids: Vec<&str> = tagged.iter().map(|e| e.id.as_str()).collect();
        assert_eq!(ids, vec!["m", "z"]);
        assert!(kb.by_tag("missing").is_empty());
    }

    #[test]
    fn add_replaces_tags() {
        let mut kb = KnowledgeBase::new();
        kb.add_tagged("x", "old", &["t"]);
        kb.add("x", "new");
        assert!(kb.get("x").unwrap().tags.is_empty());
        assert!(kb.by_tag("t").is_empty());
    }

    #[test]
    fn remove() {
        let mut kb = KnowledgeBase::new();
        kb.add("x", "text");
        assert!(kb.remove("x").is_some());
        assert!(kb.get("x").is_none());
    }

    #[test]
    fn remove_missing_returns_none() {
        let mut kb = KnowledgeBase::new();
        kb.add("x", "text");
        assert!(kb.remove("y").is_none());
        assert_eq!(kb.len(), 1);
    }

    #[test]
    fn len_and_empty() {
        let mut kb = KnowledgeBase::new();
        assert!(kb.is_empty());
        kb.add("x", "y");
        assert_eq!(kb.len(), 1);
        assert!(!kb.is_empty());
    }

    #[test]
    fn ids_sorted() {
        let mut kb = KnowledgeBase::new();
        kb.add("c", "c");
        kb.add("a", "a");
        kb.add("b", "b");
        assert_eq!(kb.ids(), vec!["a", "b", "c"]);
    }

    #[test]
    fn search_empty_query() {
        let mut kb = KnowledgeBase::new();
        kb.add("x", "text");
        assert!(kb.search("").is_empty());
    }

    #[test]
    fn search_whitespace_only_query() {
        let mut kb = KnowledgeBase::new();
        kb.add("x", "text");
        assert!(kb.search(" \t\n ").is_empty());
    }

    #[test]
    fn add_replaces_existing() {
        let mut kb = KnowledgeBase::new();
        kb.add("x", "old");
        kb.add("x", "new");
        assert_eq!(kb.get("x").unwrap().text, "new");
        assert_eq!(kb.len(), 1);
    }

    #[test]
    fn search_score_is_fraction() {
        let mut kb = KnowledgeBase::new();
        kb.add("doc", "rust safety memory");
        let r = kb.search("rust memory");
        assert_eq!(r.len(), 1);
        // Both keywords match.
        assert!((r[0].2 - 1.0).abs() < 1e-9);
    }
}