1use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29
30use crate::error::{AiError, Result};
31
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
34pub enum KnowledgeDomain {
35 Programming,
37 Blockchain,
39 Security,
41 Content,
43 SocialMedia,
45 General,
47}
48
49impl KnowledgeDomain {
50 #[must_use]
52 pub fn name(&self) -> &'static str {
53 match self {
54 KnowledgeDomain::Programming => "Programming",
55 KnowledgeDomain::Blockchain => "Blockchain",
56 KnowledgeDomain::Security => "Security",
57 KnowledgeDomain::Content => "Content",
58 KnowledgeDomain::SocialMedia => "Social Media",
59 KnowledgeDomain::General => "General",
60 }
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct KnowledgeEntry {
67 pub id: String,
69 pub domain: KnowledgeDomain,
71 pub title: String,
73 pub content: String,
75 pub tags: Vec<String>,
77 pub source: Option<String>,
79 pub confidence: f64,
81 pub created_at: chrono::DateTime<chrono::Utc>,
83 pub updated_at: chrono::DateTime<chrono::Utc>,
85}
86
87impl KnowledgeEntry {
88 pub fn new(
90 domain: KnowledgeDomain,
91 title: impl Into<String>,
92 content: impl Into<String>,
93 ) -> Self {
94 let now = chrono::Utc::now();
95 let title = title.into();
96 Self {
97 id: uuid::Uuid::new_v4().to_string(),
98 domain,
99 title: title.clone(),
100 content: content.into(),
101 tags: Self::extract_tags(&title),
102 source: None,
103 confidence: 1.0,
104 created_at: now,
105 updated_at: now,
106 }
107 }
108
109 pub fn with_id(
111 id: impl Into<String>,
112 domain: KnowledgeDomain,
113 title: impl Into<String>,
114 content: impl Into<String>,
115 ) -> Self {
116 let now = chrono::Utc::now();
117 let title = title.into();
118 Self {
119 id: id.into(),
120 domain,
121 title: title.clone(),
122 content: content.into(),
123 tags: Self::extract_tags(&title),
124 source: None,
125 confidence: 1.0,
126 created_at: now,
127 updated_at: now,
128 }
129 }
130
131 #[must_use]
133 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
134 self.tags = tags;
135 self
136 }
137
138 #[must_use]
140 pub fn with_source(mut self, source: impl Into<String>) -> Self {
141 self.source = Some(source.into());
142 self
143 }
144
145 #[must_use]
147 pub fn with_confidence(mut self, confidence: f64) -> Self {
148 self.confidence = confidence.clamp(0.0, 1.0);
149 self
150 }
151
152 fn extract_tags(text: &str) -> Vec<String> {
154 text.to_lowercase()
155 .split_whitespace()
156 .filter(|w| w.len() > 3)
157 .take(10)
158 .map(std::string::ToString::to_string)
159 .collect()
160 }
161
162 pub fn update_content(&mut self, content: impl Into<String>) {
164 self.content = content.into();
165 self.updated_at = chrono::Utc::now();
166 }
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct KnowledgeBase {
172 entries: HashMap<String, KnowledgeEntry>,
173 domain_index: HashMap<KnowledgeDomain, Vec<String>>,
174}
175
176impl Default for KnowledgeBase {
177 fn default() -> Self {
178 Self::new()
179 }
180}
181
182impl KnowledgeBase {
183 #[must_use]
185 pub fn new() -> Self {
186 Self {
187 entries: HashMap::new(),
188 domain_index: HashMap::new(),
189 }
190 }
191
192 pub fn add_entry(&mut self, entry: KnowledgeEntry) -> Result<()> {
194 let id = entry.id.clone();
195 let domain = entry.domain;
196
197 self.entries.insert(id.clone(), entry);
199
200 self.domain_index.entry(domain).or_default().push(id);
202
203 Ok(())
204 }
205
206 #[must_use]
208 pub fn get_entry(&self, id: &str) -> Option<&KnowledgeEntry> {
209 self.entries.get(id)
210 }
211
212 pub fn get_entry_mut(&mut self, id: &str) -> Option<&mut KnowledgeEntry> {
214 self.entries.get_mut(id)
215 }
216
217 pub fn remove_entry(&mut self, id: &str) -> Option<KnowledgeEntry> {
219 if let Some(entry) = self.entries.remove(id) {
220 if let Some(ids) = self.domain_index.get_mut(&entry.domain) {
222 ids.retain(|i| i != id);
223 }
224 Some(entry)
225 } else {
226 None
227 }
228 }
229
230 #[must_use]
232 pub fn search(&self, query: &str) -> Vec<&KnowledgeEntry> {
233 let query_lower = query.to_lowercase();
234 let query_words: Vec<&str> = query_lower.split_whitespace().collect();
235
236 let mut results: Vec<(f64, &KnowledgeEntry)> = self
237 .entries
238 .values()
239 .filter_map(|entry| {
240 let score = self.calculate_relevance(entry, &query_words, &query_lower);
241 if score > 0.0 {
242 Some((score, entry))
243 } else {
244 None
245 }
246 })
247 .collect();
248
249 results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
251
252 results.into_iter().map(|(_, entry)| entry).collect()
253 }
254
255 fn calculate_relevance(
257 &self,
258 entry: &KnowledgeEntry,
259 query_words: &[&str],
260 query_full: &str,
261 ) -> f64 {
262 let mut score = 0.0;
263
264 let title_lower = entry.title.to_lowercase();
265 let content_lower = entry.content.to_lowercase();
266
267 if title_lower.contains(query_full) {
269 score += 10.0;
270 }
271
272 if content_lower.contains(query_full) {
274 score += 5.0;
275 }
276
277 for word in query_words {
279 if title_lower.contains(word) {
280 score += 2.0;
281 }
282 }
283
284 for word in query_words {
286 if content_lower.contains(word) {
287 score += 0.5;
288 }
289 }
290
291 for tag in &entry.tags {
293 for word in query_words {
294 if tag.contains(word) {
295 score += 1.0;
296 }
297 }
298 }
299
300 score * entry.confidence
302 }
303
304 #[must_use]
306 pub fn get_by_domain(&self, domain: KnowledgeDomain) -> Vec<&KnowledgeEntry> {
307 self.domain_index
308 .get(&domain)
309 .map(|ids| ids.iter().filter_map(|id| self.entries.get(id)).collect())
310 .unwrap_or_default()
311 }
312
313 #[must_use]
315 pub fn all_entries(&self) -> Vec<&KnowledgeEntry> {
316 self.entries.values().collect()
317 }
318
319 #[must_use]
321 pub fn len(&self) -> usize {
322 self.entries.len()
323 }
324
325 #[must_use]
327 pub fn is_empty(&self) -> bool {
328 self.entries.is_empty()
329 }
330
331 pub fn save_to_file(&self, path: impl AsRef<std::path::Path>) -> Result<()> {
333 let json = serde_json::to_string_pretty(self)
334 .map_err(|e| AiError::Internal(format!("Failed to serialize knowledge base: {e}")))?;
335
336 std::fs::write(path, json)
337 .map_err(|e| AiError::Internal(format!("Failed to write knowledge base: {e}")))?;
338
339 Ok(())
340 }
341
342 pub fn load_from_file(path: impl AsRef<std::path::Path>) -> Result<Self> {
344 let json = std::fs::read_to_string(path)
345 .map_err(|e| AiError::Internal(format!("Failed to read knowledge base: {e}")))?;
346
347 let kb: KnowledgeBase = serde_json::from_str(&json)
348 .map_err(|e| AiError::Internal(format!("Failed to deserialize knowledge base: {e}")))?;
349
350 Ok(kb)
351 }
352
353 pub fn clear(&mut self) {
355 self.entries.clear();
356 self.domain_index.clear();
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_knowledge_entry_creation() {
        let entry = KnowledgeEntry::new(
            KnowledgeDomain::Programming,
            "Rust Best Practices",
            "Always use Result for error handling",
        );

        assert_eq!(entry.domain, KnowledgeDomain::Programming);
        assert_eq!(entry.title, "Rust Best Practices");
        assert!(!entry.tags.is_empty());
    }

    #[test]
    fn test_knowledge_base_add_and_get() {
        let mut kb = KnowledgeBase::new();

        let entry = KnowledgeEntry::new(
            KnowledgeDomain::Programming,
            "Rust Ownership",
            "Ownership is a key concept in Rust",
        );

        let id = entry.id.clone();
        kb.add_entry(entry).unwrap();

        assert_eq!(kb.len(), 1);
        assert!(kb.get_entry(&id).is_some());
    }

    #[test]
    fn test_knowledge_base_search() {
        let mut kb = KnowledgeBase::new();

        kb.add_entry(KnowledgeEntry::new(
            KnowledgeDomain::Programming,
            "Rust Ownership",
            "Ownership prevents memory issues",
        ))
        .unwrap();

        kb.add_entry(KnowledgeEntry::new(
            KnowledgeDomain::Programming,
            "Python GIL",
            "Global Interpreter Lock in Python",
        ))
        .unwrap();

        let results = kb.search("rust ownership");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].title, "Rust Ownership");
    }

    #[test]
    fn test_knowledge_base_by_domain() {
        let mut kb = KnowledgeBase::new();

        kb.add_entry(KnowledgeEntry::new(
            KnowledgeDomain::Programming,
            "Test1",
            "Content1",
        ))
        .unwrap();

        kb.add_entry(KnowledgeEntry::new(
            KnowledgeDomain::Blockchain,
            "Test2",
            "Content2",
        ))
        .unwrap();

        let prog_entries = kb.get_by_domain(KnowledgeDomain::Programming);
        assert_eq!(prog_entries.len(), 1);

        let blockchain_entries = kb.get_by_domain(KnowledgeDomain::Blockchain);
        assert_eq!(blockchain_entries.len(), 1);
    }

    #[test]
    fn test_knowledge_base_persistence() {
        let mut kb = KnowledgeBase::new();

        kb.add_entry(KnowledgeEntry::new(
            KnowledgeDomain::Security,
            "Security Best Practices",
            "Always validate input",
        ))
        .unwrap();

        // Use the platform temp dir (not a hard-coded `/tmp`) and a per-process
        // file name so concurrent test runs cannot clobber each other's file.
        let temp_path =
            std::env::temp_dir().join(format!("kb_test_{}.json", std::process::id()));
        kb.save_to_file(&temp_path).unwrap();

        let loaded = KnowledgeBase::load_from_file(&temp_path);
        // Clean up before asserting so a failed load does not leak the file.
        let _ = std::fs::remove_file(&temp_path);

        let kb2 = loaded.unwrap();
        assert_eq!(kb2.len(), 1);
        assert_eq!(kb2.get_by_domain(KnowledgeDomain::Security).len(), 1);
    }

    #[test]
    fn test_entry_update() {
        let mut kb = KnowledgeBase::new();

        let entry = KnowledgeEntry::new(KnowledgeDomain::Programming, "Test", "Original content");

        let id = entry.id.clone();
        kb.add_entry(entry).unwrap();

        if let Some(entry) = kb.get_entry_mut(&id) {
            entry.update_content("Updated content");
        }

        let updated = kb.get_entry(&id).unwrap();
        assert_eq!(updated.content, "Updated content");
    }

    #[test]
    fn test_entry_removal() {
        let mut kb = KnowledgeBase::new();

        let entry = KnowledgeEntry::new(KnowledgeDomain::Programming, "Test", "Content");

        let id = entry.id.clone();
        kb.add_entry(entry).unwrap();

        assert_eq!(kb.len(), 1);

        let removed = kb.remove_entry(&id);
        assert!(removed.is_some());
        assert_eq!(kb.len(), 0);
        assert!(kb.get_by_domain(KnowledgeDomain::Programming).is_empty());
    }
}