kaccy_ai/
knowledge_base.rs

1//! Knowledge base for domain-specific information
2//!
3//! This module provides a system for storing and retrieving domain-specific knowledge
4//! that can enhance AI evaluations and verifications.
5//!
6//! # Examples
7//!
//! ```
//! use kaccy_ai::knowledge_base::{KnowledgeBase, KnowledgeDomain, KnowledgeEntry};
//!
//! let mut kb = KnowledgeBase::new();
//!
//! // Add knowledge about Rust programming
//! let entry = KnowledgeEntry::new(
//!     KnowledgeDomain::Programming,
//!     "rust_best_practices",
//!     "Rust best practices include using the type system, avoiding unwrap() in production, \
//!      and using Result for error handling.",
//! );
//! kb.add_entry(entry).expect("adding an entry to an in-memory knowledge base succeeds");
//!
//! // Query knowledge: results are ordered from most to least relevant
//! let results = kb.search("rust error handling");
//! assert!(!results.is_empty());
//! ```
26
27use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29
30use crate::error::{AiError, Result};
31
/// Knowledge domain categories.
///
/// Every [`KnowledgeEntry`] belongs to exactly one domain. The knowledge base keeps a
/// per-domain index of entry IDs keyed by this type, which is why it derives `Hash` and `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum KnowledgeDomain {
    /// Programming and software development
    Programming,
    /// Blockchain and crypto
    Blockchain,
    /// Security and fraud detection
    Security,
    /// Content quality and writing
    Content,
    /// Social media and marketing
    SocialMedia,
    /// General domain knowledge
    General,
}
48
49impl KnowledgeDomain {
50    /// Get domain name
51    #[must_use]
52    pub fn name(&self) -> &'static str {
53        match self {
54            KnowledgeDomain::Programming => "Programming",
55            KnowledgeDomain::Blockchain => "Blockchain",
56            KnowledgeDomain::Security => "Security",
57            KnowledgeDomain::Content => "Content",
58            KnowledgeDomain::SocialMedia => "Social Media",
59            KnowledgeDomain::General => "General",
60        }
61    }
62}
63
/// A single piece of domain knowledge stored in a [`KnowledgeBase`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeEntry {
    /// Unique identifier: a UUID v4 string when built with [`KnowledgeEntry::new`],
    /// caller-supplied with [`KnowledgeEntry::with_id`]. Used as the storage key in the
    /// knowledge base.
    pub id: String,
    /// Domain category
    pub domain: KnowledgeDomain,
    /// Entry title/key; a full-query match here earns the highest search score.
    pub title: String,
    /// Knowledge content
    pub content: String,
    /// Keyword tags used by search. The constructors derive lowercase tags from the
    /// title; `with_tags` replaces them verbatim.
    pub tags: Vec<String>,
    /// Optional source/reference
    pub source: Option<String>,
    /// Confidence score, nominally `0.0..=1.0`; search relevance is multiplied by it.
    /// `with_confidence` clamps the value, but direct assignment to this public field is
    /// not checked.
    pub confidence: f64,
    /// Creation timestamp (UTC)
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Last-updated timestamp (UTC); refreshed by `update_content`.
    pub updated_at: chrono::DateTime<chrono::Utc>,
}
86
87impl KnowledgeEntry {
88    /// Create a new knowledge entry
89    pub fn new(
90        domain: KnowledgeDomain,
91        title: impl Into<String>,
92        content: impl Into<String>,
93    ) -> Self {
94        let now = chrono::Utc::now();
95        let title = title.into();
96        Self {
97            id: uuid::Uuid::new_v4().to_string(),
98            domain,
99            title: title.clone(),
100            content: content.into(),
101            tags: Self::extract_tags(&title),
102            source: None,
103            confidence: 1.0,
104            created_at: now,
105            updated_at: now,
106        }
107    }
108
109    /// Create with explicit ID
110    pub fn with_id(
111        id: impl Into<String>,
112        domain: KnowledgeDomain,
113        title: impl Into<String>,
114        content: impl Into<String>,
115    ) -> Self {
116        let now = chrono::Utc::now();
117        let title = title.into();
118        Self {
119            id: id.into(),
120            domain,
121            title: title.clone(),
122            content: content.into(),
123            tags: Self::extract_tags(&title),
124            source: None,
125            confidence: 1.0,
126            created_at: now,
127            updated_at: now,
128        }
129    }
130
131    /// Add tags
132    #[must_use]
133    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
134        self.tags = tags;
135        self
136    }
137
138    /// Add source
139    #[must_use]
140    pub fn with_source(mut self, source: impl Into<String>) -> Self {
141        self.source = Some(source.into());
142        self
143    }
144
145    /// Set confidence
146    #[must_use]
147    pub fn with_confidence(mut self, confidence: f64) -> Self {
148        self.confidence = confidence.clamp(0.0, 1.0);
149        self
150    }
151
152    /// Extract tags from text (simple word tokenization)
153    fn extract_tags(text: &str) -> Vec<String> {
154        text.to_lowercase()
155            .split_whitespace()
156            .filter(|w| w.len() > 3)
157            .take(10)
158            .map(std::string::ToString::to_string)
159            .collect()
160    }
161
162    /// Update content
163    pub fn update_content(&mut self, content: impl Into<String>) {
164        self.content = content.into();
165        self.updated_at = chrono::Utc::now();
166    }
167}
168
/// Knowledge base storage and retrieval.
///
/// Holds entries in memory keyed by ID and supports keyword search, per-domain lookup,
/// and JSON persistence via `save_to_file` / `load_from_file`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeBase {
    /// All entries, keyed by `KnowledgeEntry::id`.
    entries: HashMap<String, KnowledgeEntry>,
    /// Entry IDs grouped by domain, appended in insertion order; backs `get_by_domain`.
    domain_index: HashMap<KnowledgeDomain, Vec<String>>,
}
175
176impl Default for KnowledgeBase {
177    fn default() -> Self {
178        Self::new()
179    }
180}
181
182impl KnowledgeBase {
183    /// Create a new knowledge base
184    #[must_use]
185    pub fn new() -> Self {
186        Self {
187            entries: HashMap::new(),
188            domain_index: HashMap::new(),
189        }
190    }
191
192    /// Add a knowledge entry
193    pub fn add_entry(&mut self, entry: KnowledgeEntry) -> Result<()> {
194        let id = entry.id.clone();
195        let domain = entry.domain;
196
197        // Add to main storage
198        self.entries.insert(id.clone(), entry);
199
200        // Add to domain index
201        self.domain_index.entry(domain).or_default().push(id);
202
203        Ok(())
204    }
205
206    /// Get entry by ID
207    #[must_use]
208    pub fn get_entry(&self, id: &str) -> Option<&KnowledgeEntry> {
209        self.entries.get(id)
210    }
211
212    /// Get mutable entry by ID
213    pub fn get_entry_mut(&mut self, id: &str) -> Option<&mut KnowledgeEntry> {
214        self.entries.get_mut(id)
215    }
216
217    /// Remove entry by ID
218    pub fn remove_entry(&mut self, id: &str) -> Option<KnowledgeEntry> {
219        if let Some(entry) = self.entries.remove(id) {
220            // Remove from domain index
221            if let Some(ids) = self.domain_index.get_mut(&entry.domain) {
222                ids.retain(|i| i != id);
223            }
224            Some(entry)
225        } else {
226            None
227        }
228    }
229
230    /// Search for entries matching query
231    #[must_use]
232    pub fn search(&self, query: &str) -> Vec<&KnowledgeEntry> {
233        let query_lower = query.to_lowercase();
234        let query_words: Vec<&str> = query_lower.split_whitespace().collect();
235
236        let mut results: Vec<(f64, &KnowledgeEntry)> = self
237            .entries
238            .values()
239            .filter_map(|entry| {
240                let score = self.calculate_relevance(entry, &query_words, &query_lower);
241                if score > 0.0 {
242                    Some((score, entry))
243                } else {
244                    None
245                }
246            })
247            .collect();
248
249        // Sort by relevance score
250        results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
251
252        results.into_iter().map(|(_, entry)| entry).collect()
253    }
254
255    /// Calculate relevance score for an entry
256    fn calculate_relevance(
257        &self,
258        entry: &KnowledgeEntry,
259        query_words: &[&str],
260        query_full: &str,
261    ) -> f64 {
262        let mut score = 0.0;
263
264        let title_lower = entry.title.to_lowercase();
265        let content_lower = entry.content.to_lowercase();
266
267        // Exact title match
268        if title_lower.contains(query_full) {
269            score += 10.0;
270        }
271
272        // Exact content match
273        if content_lower.contains(query_full) {
274            score += 5.0;
275        }
276
277        // Word matches in title
278        for word in query_words {
279            if title_lower.contains(word) {
280                score += 2.0;
281            }
282        }
283
284        // Word matches in content
285        for word in query_words {
286            if content_lower.contains(word) {
287                score += 0.5;
288            }
289        }
290
291        // Tag matches
292        for tag in &entry.tags {
293            for word in query_words {
294                if tag.contains(word) {
295                    score += 1.0;
296                }
297            }
298        }
299
300        // Apply confidence multiplier
301        score * entry.confidence
302    }
303
304    /// Get entries by domain
305    #[must_use]
306    pub fn get_by_domain(&self, domain: KnowledgeDomain) -> Vec<&KnowledgeEntry> {
307        self.domain_index
308            .get(&domain)
309            .map(|ids| ids.iter().filter_map(|id| self.entries.get(id)).collect())
310            .unwrap_or_default()
311    }
312
313    /// Get all entries
314    #[must_use]
315    pub fn all_entries(&self) -> Vec<&KnowledgeEntry> {
316        self.entries.values().collect()
317    }
318
319    /// Get total entry count
320    #[must_use]
321    pub fn len(&self) -> usize {
322        self.entries.len()
323    }
324
325    /// Check if empty
326    #[must_use]
327    pub fn is_empty(&self) -> bool {
328        self.entries.is_empty()
329    }
330
331    /// Save knowledge base to file
332    pub fn save_to_file(&self, path: impl AsRef<std::path::Path>) -> Result<()> {
333        let json = serde_json::to_string_pretty(self)
334            .map_err(|e| AiError::Internal(format!("Failed to serialize knowledge base: {e}")))?;
335
336        std::fs::write(path, json)
337            .map_err(|e| AiError::Internal(format!("Failed to write knowledge base: {e}")))?;
338
339        Ok(())
340    }
341
342    /// Load knowledge base from file
343    pub fn load_from_file(path: impl AsRef<std::path::Path>) -> Result<Self> {
344        let json = std::fs::read_to_string(path)
345            .map_err(|e| AiError::Internal(format!("Failed to read knowledge base: {e}")))?;
346
347        let kb: KnowledgeBase = serde_json::from_str(&json)
348            .map_err(|e| AiError::Internal(format!("Failed to deserialize knowledge base: {e}")))?;
349
350        Ok(kb)
351    }
352
353    /// Clear all entries
354    pub fn clear(&mut self) {
355        self.entries.clear();
356        self.domain_index.clear();
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_knowledge_entry_creation() {
        let entry = KnowledgeEntry::new(
            KnowledgeDomain::Programming,
            "Rust Best Practices",
            "Always use Result for error handling",
        );

        assert_eq!(entry.domain, KnowledgeDomain::Programming);
        assert_eq!(entry.title, "Rust Best Practices");
        assert!(!entry.tags.is_empty());
    }

    #[test]
    fn test_knowledge_base_add_and_get() {
        let mut kb = KnowledgeBase::new();

        let entry = KnowledgeEntry::new(
            KnowledgeDomain::Programming,
            "Rust Ownership",
            "Ownership is a key concept in Rust",
        );

        let id = entry.id.clone();
        kb.add_entry(entry).unwrap();

        assert_eq!(kb.len(), 1);
        assert!(kb.get_entry(&id).is_some());
    }

    #[test]
    fn test_knowledge_base_search() {
        let mut kb = KnowledgeBase::new();

        kb.add_entry(KnowledgeEntry::new(
            KnowledgeDomain::Programming,
            "Rust Ownership",
            "Ownership prevents memory issues",
        ))
        .unwrap();

        kb.add_entry(KnowledgeEntry::new(
            KnowledgeDomain::Programming,
            "Python GIL",
            "Global Interpreter Lock in Python",
        ))
        .unwrap();

        let results = kb.search("rust ownership");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].title, "Rust Ownership");
    }

    #[test]
    fn test_knowledge_base_by_domain() {
        let mut kb = KnowledgeBase::new();

        kb.add_entry(KnowledgeEntry::new(
            KnowledgeDomain::Programming,
            "Test1",
            "Content1",
        ))
        .unwrap();

        kb.add_entry(KnowledgeEntry::new(
            KnowledgeDomain::Blockchain,
            "Test2",
            "Content2",
        ))
        .unwrap();

        let prog_entries = kb.get_by_domain(KnowledgeDomain::Programming);
        assert_eq!(prog_entries.len(), 1);

        let blockchain_entries = kb.get_by_domain(KnowledgeDomain::Blockchain);
        assert_eq!(blockchain_entries.len(), 1);
    }

    #[test]
    fn test_knowledge_base_persistence() {
        let mut kb = KnowledgeBase::new();

        kb.add_entry(KnowledgeEntry::new(
            KnowledgeDomain::Security,
            "Security Best Practices",
            "Always validate input",
        ))
        .unwrap();

        // A unique file in the platform temp dir: portable (no hard-coded /tmp) and safe
        // when several test runs execute concurrently.
        let temp_path =
            std::env::temp_dir().join(format!("kb_test_{}.json", uuid::Uuid::new_v4()));
        kb.save_to_file(&temp_path).unwrap();

        let loaded = KnowledgeBase::load_from_file(&temp_path);
        // Clean up before asserting so that a failed assertion does not leave the file behind.
        let _ = std::fs::remove_file(&temp_path);

        let kb2 = loaded.unwrap();
        assert_eq!(kb2.len(), 1);
        assert_eq!(kb2.get_by_domain(KnowledgeDomain::Security).len(), 1);
    }

    #[test]
    fn test_entry_update() {
        let mut kb = KnowledgeBase::new();

        let entry = KnowledgeEntry::new(KnowledgeDomain::Programming, "Test", "Original content");

        let id = entry.id.clone();
        kb.add_entry(entry).unwrap();

        if let Some(entry) = kb.get_entry_mut(&id) {
            entry.update_content("Updated content");
        }

        let updated = kb.get_entry(&id).unwrap();
        assert_eq!(updated.content, "Updated content");
    }

    #[test]
    fn test_entry_removal() {
        let mut kb = KnowledgeBase::new();

        let entry = KnowledgeEntry::new(KnowledgeDomain::Programming, "Test", "Content");

        let id = entry.id.clone();
        kb.add_entry(entry).unwrap();

        assert_eq!(kb.len(), 1);

        let removed = kb.remove_entry(&id);
        assert!(removed.is_some());
        assert_eq!(kb.len(), 0);
        assert!(kb.get_by_domain(KnowledgeDomain::Programming).is_empty());
    }
}