1use super::{KnowledgeEntry, SearchOptions, SearchResult};
4use crate::embedding::EmbeddingEngine;
5use crate::error::{Error, Result};
6use crate::learning::LearningEngine;
7use crate::storage::StorageBackend;
8
9use dashmap::DashMap;
10use parking_lot::RwLock;
11use serde::{Deserialize, Serialize};
12use std::path::Path;
13use std::sync::Arc;
14use tracing::{debug, info, instrument};
15use uuid::Uuid;
16
/// Tunable settings used when constructing a [`KnowledgeBase`].
///
/// Start from [`Default::default`] and adjust with the `with_*` / `without_*`
/// builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeBaseConfig {
    /// Embedding dimensionality handed to the embedding and learning engines.
    pub dimensions: usize,

    /// Location passed to the storage backend when the knowledge base is opened.
    pub storage_path: String,

    /// When `false`, no learning engine is created.
    pub learning_enabled: bool,

    /// Learning rate handed to the learning engine (only used when learning is enabled).
    pub learning_rate: f32,

    /// HNSW `M` parameter.
    // NOTE(review): not read anywhere in this file — presumably consumed by an index elsewhere.
    pub hnsw_m: usize,

    /// HNSW `ef_construction` parameter.
    // NOTE(review): not read anywhere in this file — confirm where it is consumed.
    pub hnsw_ef_construction: usize,

    /// HNSW `ef_search` parameter.
    // NOTE(review): not read anywhere in this file — confirm where it is consumed.
    pub hnsw_ef_search: usize,

    /// Number of entries embedded and persisted together by `KnowledgeBase::add_entries`.
    pub batch_size: usize,
}
44
45impl Default for KnowledgeBaseConfig {
46 fn default() -> Self {
47 Self {
48 dimensions: 384,
49 storage_path: "./knowledge.db".to_string(),
50 learning_enabled: true,
51 learning_rate: 0.01,
52 hnsw_m: 16,
53 hnsw_ef_construction: 200,
54 hnsw_ef_search: 100,
55 batch_size: 1000,
56 }
57 }
58}
59
60impl KnowledgeBaseConfig {
61 pub fn with_path(mut self, path: impl Into<String>) -> Self {
63 self.storage_path = path.into();
64 self
65 }
66
67 pub fn with_dimensions(mut self, dims: usize) -> Self {
69 self.dimensions = dims;
70 self
71 }
72
73 pub fn without_learning(mut self) -> Self {
75 self.learning_enabled = false;
76 self
77 }
78}
79
/// A persistent collection of knowledge entries with embedding-based search.
///
/// Entries and their embeddings are mirrored in memory and written through to
/// the storage backend.
pub struct KnowledgeBase {
    /// Configuration this instance was built with.
    config: KnowledgeBaseConfig,

    /// Backend used to load, save and delete entries.
    storage: Arc<StorageBackend>,

    /// Engine that turns text into embedding vectors.
    embeddings: Arc<EmbeddingEngine>,

    /// Learning engine; `None` when `config.learning_enabled` is `false`.
    learning: Option<Arc<RwLock<LearningEngine>>>,

    /// In-memory entries keyed by id.
    entries: DashMap<Uuid, KnowledgeEntry>,

    /// In-memory embedding vectors keyed by entry id.
    vectors: DashMap<Uuid, Vec<f32>>,

    /// Cached entry count reported by `len()`.
    count: Arc<RwLock<usize>>,
}
103
104impl KnowledgeBase {
105 #[instrument(skip_all)]
107 pub async fn open(path: impl AsRef<Path>) -> Result<Self> {
108 let config = KnowledgeBaseConfig::default().with_path(path.as_ref().to_string_lossy());
109 Self::with_config(config).await
110 }
111
112 #[instrument(skip_all, fields(path = %config.storage_path))]
114 pub async fn with_config(config: KnowledgeBaseConfig) -> Result<Self> {
115 info!("Initializing knowledge base at {}", config.storage_path);
116
117 let storage = Arc::new(StorageBackend::open(&config.storage_path).await?);
118 let embeddings = Arc::new(EmbeddingEngine::new(config.dimensions));
119
120 let learning = if config.learning_enabled {
121 Some(Arc::new(RwLock::new(LearningEngine::new(
122 config.dimensions,
123 config.learning_rate,
124 ))))
125 } else {
126 None
127 };
128
129 let kb = Self {
130 config,
131 storage,
132 embeddings,
133 learning,
134 entries: DashMap::new(),
135 vectors: DashMap::new(),
136 count: Arc::new(RwLock::new(0)),
137 };
138
139 kb.load_entries().await?;
141
142 info!("Knowledge base initialized with {} entries", kb.len());
143 Ok(kb)
144 }
145
146 async fn load_entries(&self) -> Result<()> {
148 let stored = self.storage.load_all().await?;
149
150 for (entry, embedding) in stored {
151 self.entries.insert(entry.id, entry.clone());
152 self.vectors.insert(entry.id, embedding);
153 }
154
155 *self.count.write() = self.entries.len();
156 Ok(())
157 }
158
159 pub fn len(&self) -> usize {
161 *self.count.read()
162 }
163
164 pub fn is_empty(&self) -> bool {
166 self.len() == 0
167 }
168
169 pub fn config(&self) -> &KnowledgeBaseConfig {
171 &self.config
172 }
173
174 #[instrument(skip(self, entry), fields(title = %entry.title))]
176 pub async fn add_entry(&self, entry: KnowledgeEntry) -> Result<Uuid> {
177 let id = entry.id;
178
179 let text = entry.embedding_text();
181 let embedding = self.embeddings.embed(&text).await?;
182
183 self.entries.insert(id, entry.clone());
185 self.vectors.insert(id, embedding.clone());
186
187 self.storage.save_entry(&entry, &embedding).await?;
189
190 *self.count.write() += 1;
191 debug!("Added entry {}", id);
192
193 Ok(id)
194 }
195
196 #[instrument(skip(self, entries), fields(count = entries.len()))]
198 pub async fn add_entries(&self, entries: Vec<KnowledgeEntry>) -> Result<Vec<Uuid>> {
199 let mut ids = Vec::with_capacity(entries.len());
200
201 for chunk in entries.chunks(self.config.batch_size) {
202 let mut batch = Vec::with_capacity(chunk.len());
203 for entry in chunk {
204 let text = entry.embedding_text();
205 let embedding = self.embeddings.embed(&text).await?;
206 batch.push((entry.clone(), embedding));
207 }
208
209 for (entry, embedding) in &batch {
210 self.entries.insert(entry.id, entry.clone());
211 self.vectors.insert(entry.id, embedding.clone());
212 ids.push(entry.id);
213 }
214
215 self.storage.save_batch(&batch).await?;
216 }
217
218 *self.count.write() += ids.len();
219 info!("Added {} entries in batch", ids.len());
220
221 Ok(ids)
222 }
223
224 pub fn get(&self, id: Uuid) -> Option<KnowledgeEntry> {
226 self.entries.get(&id).map(|e| e.clone())
227 }
228
229 #[instrument(skip(self, entry), fields(id = %entry.id))]
231 pub async fn update_entry(&self, entry: KnowledgeEntry) -> Result<()> {
232 let id = entry.id;
233
234 if !self.entries.contains_key(&id) {
235 return Err(Error::not_found(id.to_string()));
236 }
237
238 let text = entry.embedding_text();
240 let embedding = self.embeddings.embed(&text).await?;
241
242 self.entries.insert(id, entry.clone());
244 self.vectors.insert(id, embedding.clone());
245
246 self.storage.save_entry(&entry, &embedding).await?;
248
249 debug!("Updated entry {}", id);
250 Ok(())
251 }
252
253 #[instrument(skip(self), fields(id = %id))]
255 pub async fn delete_entry(&self, id: Uuid) -> Result<()> {
256 if self.entries.remove(&id).is_none() {
257 return Err(Error::not_found(id.to_string()));
258 }
259
260 self.vectors.remove(&id);
261 self.storage.delete_entry(id).await?;
262
263 *self.count.write() -= 1;
264 debug!("Deleted entry {}", id);
265
266 Ok(())
267 }
268
269 #[instrument(skip(self), fields(k = options.limit))]
271 pub async fn search(&self, query: &str, options: SearchOptions) -> Result<Vec<SearchResult>> {
272 let query_embedding = self.embeddings.embed(query).await?;
274
275 let mut candidates: Vec<(Uuid, f32)> = self
278 .vectors
279 .iter()
280 .map(|entry| {
281 let id = *entry.key();
282 let distance = cosine_distance(&query_embedding, entry.value());
283 (id, distance)
284 })
285 .collect();
286
287 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
289
290 if options.use_learning
292 && let Some(learning) = &self.learning
293 {
294 let learning = learning.read();
295 candidates = learning.rerank(&query_embedding, candidates, &self.vectors);
296 }
297
298 let mut results = Vec::new();
300
301 for (id, distance) in candidates.into_iter().take(options.limit * 2) {
302 if let Some(entry) = self.entries.get(&id) {
303 let entry = entry.clone();
304
305 if let Some(ref cat) = options.category
307 && entry.category.as_ref() != Some(cat)
308 {
309 continue;
310 }
311
312 if !options.tags.is_empty()
313 && !options
314 .tags
315 .iter()
316 .any(|t| entry.tags.iter().any(|et| et == t))
317 {
318 continue;
319 }
320
321 let similarity = 1.0 - distance;
322 if similarity < options.min_similarity {
323 continue;
324 }
325
326 results.push(SearchResult::new(entry, similarity, distance));
327
328 if results.len() >= options.limit {
329 break;
330 }
331 }
332 }
333
334 if options.diversity > 0.0 {
336 results = apply_mmr(results, options.diversity);
337 }
338
339 if let Some(learning) = &self.learning {
341 let mut learning = learning.write();
342 learning.record_query(&query_embedding, &results);
343 }
344
345 debug!("Search returned {} results", results.len());
346 Ok(results)
347 }
348
349 pub async fn search_simple(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
351 self.search(query, SearchOptions::new(limit)).await
352 }
353
354 #[instrument(skip(self))]
356 pub async fn record_feedback(&self, entry_id: Uuid, positive: bool) -> Result<()> {
357 if let Some(mut entry) = self.entries.get_mut(&entry_id) {
358 let boost = if positive { 0.1 } else { -0.05 };
359 entry.record_access(1.0 + boost);
360
361 if let Some(learning) = &self.learning {
363 let mut learning = learning.write();
364 if let Some(embedding) = self.vectors.get(&entry_id) {
365 learning.record_feedback(&embedding, positive);
366 }
367 }
368
369 let entry = entry.clone();
371 if let Some(embedding) = self.vectors.get(&entry_id) {
372 self.storage.save_entry(&entry, &embedding).await?;
373 }
374 }
375
376 Ok(())
377 }
378
379 pub fn get_related(&self, id: Uuid, limit: usize) -> Vec<KnowledgeEntry> {
381 if let Some(entry) = self.entries.get(&id) {
382 entry
383 .related_entries
384 .iter()
385 .take(limit)
386 .filter_map(|rel_id| self.entries.get(rel_id).map(|e| e.clone()))
387 .collect()
388 } else {
389 Vec::new()
390 }
391 }
392
393 #[allow(clippy::unused_async)]
395 pub async fn link_entries(&self, id1: Uuid, id2: Uuid) -> Result<()> {
396 if let Some(mut entry1) = self.entries.get_mut(&id1) {
397 if !entry1.related_entries.contains(&id2) {
398 entry1.related_entries.push(id2);
399 }
400 } else {
401 return Err(Error::not_found(id1.to_string()));
402 }
403
404 if let Some(mut entry2) = self.entries.get_mut(&id2)
405 && !entry2.related_entries.contains(&id1)
406 {
407 entry2.related_entries.push(id1);
408 }
409
410 Ok(())
411 }
412
413 pub fn all_entries(&self) -> Vec<KnowledgeEntry> {
415 self.entries.iter().map(|e| e.value().clone()).collect()
416 }
417
418 pub fn stats(&self) -> KnowledgeBaseStats {
420 let total = self.len();
421 let categories: std::collections::HashSet<_> = self
422 .entries
423 .iter()
424 .filter_map(|e| e.category.clone())
425 .collect();
426
427 let tags: std::collections::HashSet<_> =
428 self.entries.iter().flat_map(|e| e.tags.clone()).collect();
429
430 let total_access: u64 = self.entries.iter().map(|e| e.access_count).sum();
431
432 KnowledgeBaseStats {
433 total_entries: total,
434 unique_categories: categories.len(),
435 unique_tags: tags.len(),
436 total_access_count: total_access,
437 dimensions: self.config.dimensions,
438 learning_enabled: self.config.learning_enabled,
439 }
440 }
441
442 pub async fn flush(&self) -> Result<()> {
444 self.storage.flush().await
445 }
446}
447
/// Summary statistics produced by `KnowledgeBase::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeBaseStats {
    /// Number of entries reported by the knowledge base.
    pub total_entries: usize,
    /// Number of distinct categories across all entries.
    pub unique_categories: usize,
    /// Number of distinct tags across all entries.
    pub unique_tags: usize,
    /// Sum of every entry's access count.
    pub total_access_count: u64,
    /// Embedding dimensionality from the configuration.
    pub dimensions: usize,
    /// Whether learning is enabled in the configuration.
    pub learning_enabled: bool,
}
458
/// Cosine distance (`1 - cosine similarity`) between two vectors, in `[0, 2]`.
///
/// Returns `1.0` when either vector has zero magnitude (including empty
/// slices). Sums are accumulated in `f64` so very large or very small `f32`
/// components cannot overflow/underflow the squared norms into a NaN, and the
/// cosine is clamped to `[-1, 1]` so rounding cannot push the result outside
/// the valid range. NaN or infinite components still yield NaN.
///
/// If the slices differ in length, the dot product covers only the common
/// prefix while each norm covers its whole slice.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm_sq_a: f64 = a.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
    let norm_sq_b: f64 = b.iter().map(|&y| f64::from(y) * f64::from(y)).sum();

    let denom = norm_sq_a.sqrt() * norm_sq_b.sqrt();
    if denom == 0.0 {
        return 1.0;
    }

    let cosine = (dot / denom).clamp(-1.0, 1.0);
    // Narrowing back to f32 is intentional: callers work with f32 distances.
    (1.0 - cosine) as f32
}
471
472fn apply_mmr(mut results: Vec<SearchResult>, lambda: f32) -> Vec<SearchResult> {
474 if results.len() <= 1 {
475 return results;
476 }
477
478 let mut selected = vec![results.remove(0)];
479
480 while !results.is_empty() && selected.len() < results.len() + selected.len() {
481 let mut best_idx = 0;
482 let mut best_score = f32::NEG_INFINITY;
483
484 for (i, candidate) in results.iter().enumerate() {
485 let relevance = candidate.similarity;
487
488 let max_sim = selected
490 .iter()
491 .map(|s| {
492 1.0 - (s.score - candidate.score).abs()
494 })
495 .max_by(|a, b| a.partial_cmp(b).unwrap())
496 .unwrap_or(0.0);
497
498 let mmr = lambda * relevance - (1.0 - lambda) * max_sim;
500
501 if mmr > best_score {
502 best_score = mmr;
503 best_idx = i;
504 }
505 }
506
507 selected.push(results.remove(best_idx));
508 }
509
510 selected
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::KnowledgeEntry;
    use tempfile::tempdir;

    /// Default configuration stored at `path`, shrunk to 32 dimensions.
    fn small_config(path: &Path) -> KnowledgeBaseConfig {
        KnowledgeBaseConfig::default()
            .with_path(path.to_string_lossy())
            .with_dimensions(32)
    }

    /// Opens a 32-dimensional knowledge base at `<dir>/kb.db`.
    async fn small_kb(dir: &Path) -> KnowledgeBase {
        KnowledgeBase::with_config(small_config(&dir.join("kb.db")))
            .await
            .unwrap()
    }

    #[test]
    fn test_cosine_distance() {
        let x_axis = vec![1.0, 0.0, 0.0];

        // Identical direction -> distance 0.
        let same = vec![1.0, 0.0, 0.0];
        assert!((cosine_distance(&x_axis, &same) - 0.0).abs() < 1e-6);

        // Orthogonal direction -> distance 1.
        let y_axis = vec![0.0, 1.0, 0.0];
        assert!((cosine_distance(&x_axis, &y_axis) - 1.0).abs() < 1e-6);

        // Zero vector -> distance 1.
        let zero = vec![0.0, 0.0, 0.0];
        assert!((cosine_distance(&x_axis, &zero) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn config_builder_sets_fields() {
        let built = KnowledgeBaseConfig::default()
            .with_path("/tmp/x.db")
            .with_dimensions(64)
            .without_learning();

        assert_eq!(built.storage_path, "/tmp/x.db");
        assert_eq!(built.dimensions, 64);
        assert!(!built.learning_enabled);
    }

    #[tokio::test]
    async fn open_creates_empty_kb() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("kb.db");

        let kb = KnowledgeBase::open(db_path).await.unwrap();

        assert_eq!(kb.len(), 0);
        assert!(kb.is_empty());
        assert_eq!(kb.config().dimensions, 384);
    }

    #[tokio::test]
    async fn add_get_update_delete_roundtrip() {
        let dir = tempdir().unwrap();
        let kb = small_kb(dir.path()).await;

        // Add.
        let original = KnowledgeEntry::new("Title", "body text").with_category("docs");
        let id = kb.add_entry(original).await.unwrap();
        assert_eq!(kb.len(), 1);
        assert!(!kb.is_empty());

        // Get.
        let mut fetched = kb.get(id).expect("entry should exist");
        assert_eq!(fetched.title, "Title");

        // Update.
        fetched.content = "new body".into();
        kb.update_entry(fetched).await.unwrap();
        assert_eq!(kb.get(id).unwrap().content, "new body");

        // Delete.
        kb.delete_entry(id).await.unwrap();
        assert_eq!(kb.len(), 0);
        assert!(kb.get(id).is_none());
    }

    #[tokio::test]
    async fn update_missing_entry_errors() {
        let dir = tempdir().unwrap();
        let kb = small_kb(dir.path()).await;

        let result = kb.update_entry(KnowledgeEntry::new("ghost", "body")).await;

        assert!(matches!(result.unwrap_err(), Error::NotFound(_)));
    }

    #[tokio::test]
    async fn delete_missing_entry_errors() {
        let dir = tempdir().unwrap();
        let kb = small_kb(dir.path()).await;

        let result = kb.delete_entry(Uuid::new_v4()).await;

        assert!(matches!(result.unwrap_err(), Error::NotFound(_)));
    }

    #[tokio::test]
    async fn add_entries_batch_persists() {
        let dir = tempdir().unwrap();
        let kb = small_kb(dir.path()).await;

        let mut batch = Vec::new();
        for i in 0..5 {
            batch.push(KnowledgeEntry::new(format!("t{i}"), format!("body {i}")));
        }

        let ids = kb.add_entries(batch).await.unwrap();

        assert_eq!(ids.len(), 5);
        assert_eq!(kb.len(), 5);
        kb.flush().await.unwrap();
    }

    #[tokio::test]
    async fn search_filters_and_results() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("kb.db");
        let cfg = KnowledgeBaseConfig::default()
            .with_path(db_path.to_string_lossy())
            .with_dimensions(128);
        let kb = KnowledgeBase::with_config(cfg).await.unwrap();

        let rust_entry = KnowledgeEntry::new("rust ownership", "borrow checker introduction")
            .with_category("rust")
            .with_tags(["ownership"]);
        kb.add_entry(rust_entry).await.unwrap();

        let python_entry = KnowledgeEntry::new("python decorators", "functions wrapping functions")
            .with_category("python")
            .with_tags(["meta"]);
        kb.add_entry(python_entry).await.unwrap();

        // Plain search just has to succeed.
        let _ = kb.search_simple("borrow", 10).await.unwrap();

        // Category filter.
        let rust_only = SearchOptions::new(10)
            .with_category("rust")
            .without_learning();
        let only_rust = kb.search("wrapping", rust_only).await.unwrap();
        assert!(
            only_rust
                .iter()
                .all(|r| r.entry.category.as_deref() == Some("rust"))
        );

        // Tag filter.
        let tagged = SearchOptions::new(10).with_tags(["ownership"]);
        let by_tag = kb.search("anything", tagged).await.unwrap();
        assert!(
            by_tag
                .iter()
                .all(|r| r.entry.tags.iter().any(|t| t == "ownership"))
        );

        // Diversity re-ranking just has to succeed.
        let diverse = SearchOptions::new(5).with_diversity(0.5);
        let _ = kb.search("functions", diverse).await.unwrap();

        // An exact-match threshold filters everything out.
        let strict = SearchOptions::new(10).with_min_similarity(1.0);
        let none = kb.search("borrow", strict).await.unwrap();
        assert!(none.is_empty());
    }

    #[tokio::test]
    async fn record_feedback_and_stats() {
        let dir = tempdir().unwrap();
        let kb = small_kb(dir.path()).await;

        let entry = KnowledgeEntry::new("a", "alpha")
            .with_category("c")
            .with_tags(["t"]);
        let id = kb.add_entry(entry).await.unwrap();

        kb.record_feedback(id, true).await.unwrap();
        kb.record_feedback(id, false).await.unwrap();
        // Feedback for an unknown id is accepted silently.
        kb.record_feedback(Uuid::new_v4(), true).await.unwrap();

        let stats = kb.stats();
        assert_eq!(stats.total_entries, 1);
        assert_eq!(stats.unique_categories, 1);
        assert_eq!(stats.unique_tags, 1);
        assert!(stats.learning_enabled);
        assert_eq!(stats.dimensions, 32);
        assert!(stats.total_access_count >= 2);
    }

    #[tokio::test]
    async fn linking_and_related() {
        let dir = tempdir().unwrap();
        let kb = small_kb(dir.path()).await;
        let a = kb.add_entry(KnowledgeEntry::new("a", "x")).await.unwrap();
        let b = kb.add_entry(KnowledgeEntry::new("b", "y")).await.unwrap();

        // Linking twice must not duplicate the relation.
        for _ in 0..2 {
            kb.link_entries(a, b).await.unwrap();
        }

        let related = kb.get_related(a, 5);
        assert_eq!(related.len(), 1);
        assert_eq!(related[0].id, b);

        // Unknown source id is rejected.
        let result = kb.link_entries(Uuid::new_v4(), b).await;
        assert!(matches!(result.unwrap_err(), Error::NotFound(_)));

        // Unknown id has no relations.
        assert!(kb.get_related(Uuid::new_v4(), 5).is_empty());

        assert_eq!(kb.all_entries().len(), 2);
    }

    #[tokio::test]
    async fn reopens_with_existing_entries() {
        let dir = tempdir().unwrap();

        {
            let kb = small_kb(dir.path()).await;
            kb.add_entry(KnowledgeEntry::new("persist", "me"))
                .await
                .unwrap();
            kb.flush().await.unwrap();
            // `kb` is dropped here, before the store is reopened.
        }

        let reopened = small_kb(dir.path()).await;
        assert_eq!(reopened.len(), 1);
        assert_eq!(reopened.all_entries()[0].title, "persist");
    }

    #[tokio::test]
    async fn learning_disabled_skips_engine() {
        let dir = tempdir().unwrap();
        let cfg = small_config(&dir.path().join("kb.db")).without_learning();
        let kb = KnowledgeBase::with_config(cfg).await.unwrap();

        let id = kb.add_entry(KnowledgeEntry::new("t", "c")).await.unwrap();
        let _ = kb.search_simple("t", 5).await.unwrap();
        kb.record_feedback(id, true).await.unwrap();

        assert!(!kb.stats().learning_enabled);
    }

    #[test]
    fn mmr_short_circuits_short_lists() {
        // A single result is returned as-is.
        let single = SearchResult::new(KnowledgeEntry::new("t", "c"), 0.5, 0.5);
        assert_eq!(apply_mmr(vec![single], 0.5).len(), 1);

        // An empty list stays empty.
        let empty: Vec<SearchResult> = apply_mmr(Vec::new(), 0.5);
        assert!(empty.is_empty());

        // Longer lists go through the re-ranking loop.
        let many: Vec<SearchResult> = (0..3)
            .map(|i| {
                let e = KnowledgeEntry::new(format!("t{i}"), "c");
                SearchResult::new(e, 0.9 - i as f32 * 0.1, 0.1 * i as f32)
            })
            .collect();
        let picked = apply_mmr(many, 0.7);
        assert!(!picked.is_empty());
    }
}