// mnemos/core/knowledge_base.rs
1//! Knowledge base implementation using ruvector.
2
3use super::{KnowledgeEntry, SearchOptions, SearchResult};
4use crate::embedding::EmbeddingEngine;
5use crate::error::{Error, Result};
6use crate::learning::LearningEngine;
7use crate::storage::StorageBackend;
8
9use dashmap::DashMap;
10use parking_lot::RwLock;
11use serde::{Deserialize, Serialize};
12use std::path::Path;
13use std::sync::Arc;
14use tracing::{debug, info, instrument};
15use uuid::Uuid;
16
/// Configuration for the knowledge base.
///
/// Construct via [`Default`] and refine with the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeBaseConfig {
    /// Embedding dimension size.
    pub dimensions: usize,

    /// Path to storage file.
    pub storage_path: String,

    /// Enable self-learning features.
    pub learning_enabled: bool,

    /// Learning rate for GNN updates.
    pub learning_rate: f32,

    /// Number of HNSW neighbors (M parameter).
    ///
    /// NOTE(review): the current search path is brute-force; the HNSW
    /// parameters below are held for the planned ruvector index.
    pub hnsw_m: usize,

    /// HNSW ef_construction parameter.
    pub hnsw_ef_construction: usize,

    /// HNSW ef_search parameter.
    pub hnsw_ef_search: usize,

    /// Batch size for bulk operations.
    pub batch_size: usize,
}
44
45impl Default for KnowledgeBaseConfig {
46    fn default() -> Self {
47        Self {
48            dimensions: 384,
49            storage_path: "./knowledge.db".to_string(),
50            learning_enabled: true,
51            learning_rate: 0.01,
52            hnsw_m: 16,
53            hnsw_ef_construction: 200,
54            hnsw_ef_search: 100,
55            batch_size: 1000,
56        }
57    }
58}
59
60impl KnowledgeBaseConfig {
61    /// Create config with custom storage path.
62    pub fn with_path(mut self, path: impl Into<String>) -> Self {
63        self.storage_path = path.into();
64        self
65    }
66
67    /// Set embedding dimensions.
68    pub fn with_dimensions(mut self, dims: usize) -> Self {
69        self.dimensions = dims;
70        self
71    }
72
73    /// Disable learning features.
74    pub fn without_learning(mut self) -> Self {
75        self.learning_enabled = false;
76        self
77    }
78}
79
/// A self-learning knowledge base powered by ruvector.
///
/// Entries and their embeddings are cached in concurrent maps and mirrored
/// to a storage backend; an optional learning engine re-ranks search
/// results and absorbs feedback signals.
pub struct KnowledgeBase {
    /// Configuration.
    config: KnowledgeBaseConfig,

    /// Storage backend for persistence.
    storage: Arc<StorageBackend>,

    /// Embedding engine for text vectorization.
    embeddings: Arc<EmbeddingEngine>,

    /// Learning engine for self-improvement.
    /// `None` when `config.learning_enabled` is false.
    learning: Option<Arc<RwLock<LearningEngine>>>,

    /// In-memory entry cache (id -> entry).
    entries: DashMap<Uuid, KnowledgeEntry>,

    /// Vector index (id -> embedding).
    vectors: DashMap<Uuid, Vec<f32>>,

    /// Entry count, maintained alongside `entries`.
    /// NOTE(review): tracked separately from `entries.len()`; the two can
    /// drift on duplicate-id inserts — confirm which is authoritative.
    count: Arc<RwLock<usize>>,
}
103
104impl KnowledgeBase {
105    /// Open or create a knowledge base at the given path.
106    #[instrument(skip_all)]
107    pub async fn open(path: impl AsRef<Path>) -> Result<Self> {
108        let config = KnowledgeBaseConfig::default().with_path(path.as_ref().to_string_lossy());
109        Self::with_config(config).await
110    }
111
112    /// Create a knowledge base with custom configuration.
113    #[instrument(skip_all, fields(path = %config.storage_path))]
114    pub async fn with_config(config: KnowledgeBaseConfig) -> Result<Self> {
115        info!("Initializing knowledge base at {}", config.storage_path);
116
117        let storage = Arc::new(StorageBackend::open(&config.storage_path).await?);
118        let embeddings = Arc::new(EmbeddingEngine::new(config.dimensions));
119
120        let learning = if config.learning_enabled {
121            Some(Arc::new(RwLock::new(LearningEngine::new(
122                config.dimensions,
123                config.learning_rate,
124            ))))
125        } else {
126            None
127        };
128
129        let kb = Self {
130            config,
131            storage,
132            embeddings,
133            learning,
134            entries: DashMap::new(),
135            vectors: DashMap::new(),
136            count: Arc::new(RwLock::new(0)),
137        };
138
139        // Load existing entries from storage
140        kb.load_entries().await?;
141
142        info!("Knowledge base initialized with {} entries", kb.len());
143        Ok(kb)
144    }
145
146    /// Load entries from storage.
147    async fn load_entries(&self) -> Result<()> {
148        let stored = self.storage.load_all().await?;
149
150        for (entry, embedding) in stored {
151            self.entries.insert(entry.id, entry.clone());
152            self.vectors.insert(entry.id, embedding);
153        }
154
155        *self.count.write() = self.entries.len();
156        Ok(())
157    }
158
159    /// Get the number of entries.
160    pub fn len(&self) -> usize {
161        *self.count.read()
162    }
163
    /// Check if the knowledge base is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get configuration.
    ///
    /// Returns a borrow of the config this base was constructed with.
    pub fn config(&self) -> &KnowledgeBaseConfig {
        &self.config
    }
173
174    /// Add a new knowledge entry.
175    #[instrument(skip(self, entry), fields(title = %entry.title))]
176    pub async fn add_entry(&self, entry: KnowledgeEntry) -> Result<Uuid> {
177        let id = entry.id;
178
179        // Generate embedding from content
180        let text = entry.embedding_text();
181        let embedding = self.embeddings.embed(&text).await?;
182
183        // Store in memory
184        self.entries.insert(id, entry.clone());
185        self.vectors.insert(id, embedding.clone());
186
187        // Persist to storage
188        self.storage.save_entry(&entry, &embedding).await?;
189
190        *self.count.write() += 1;
191        debug!("Added entry {}", id);
192
193        Ok(id)
194    }
195
196    /// Add multiple entries in batch.
197    #[instrument(skip(self, entries), fields(count = entries.len()))]
198    pub async fn add_entries(&self, entries: Vec<KnowledgeEntry>) -> Result<Vec<Uuid>> {
199        let mut ids = Vec::with_capacity(entries.len());
200
201        for chunk in entries.chunks(self.config.batch_size) {
202            let mut batch = Vec::with_capacity(chunk.len());
203            for entry in chunk {
204                let text = entry.embedding_text();
205                let embedding = self.embeddings.embed(&text).await?;
206                batch.push((entry.clone(), embedding));
207            }
208
209            for (entry, embedding) in &batch {
210                self.entries.insert(entry.id, entry.clone());
211                self.vectors.insert(entry.id, embedding.clone());
212                ids.push(entry.id);
213            }
214
215            self.storage.save_batch(&batch).await?;
216        }
217
218        *self.count.write() += ids.len();
219        info!("Added {} entries in batch", ids.len());
220
221        Ok(ids)
222    }
223
224    /// Get an entry by ID.
225    pub fn get(&self, id: Uuid) -> Option<KnowledgeEntry> {
226        self.entries.get(&id).map(|e| e.clone())
227    }
228
229    /// Update an existing entry.
230    #[instrument(skip(self, entry), fields(id = %entry.id))]
231    pub async fn update_entry(&self, entry: KnowledgeEntry) -> Result<()> {
232        let id = entry.id;
233
234        if !self.entries.contains_key(&id) {
235            return Err(Error::not_found(id.to_string()));
236        }
237
238        // Regenerate embedding
239        let text = entry.embedding_text();
240        let embedding = self.embeddings.embed(&text).await?;
241
242        // Update in memory
243        self.entries.insert(id, entry.clone());
244        self.vectors.insert(id, embedding.clone());
245
246        // Persist
247        self.storage.save_entry(&entry, &embedding).await?;
248
249        debug!("Updated entry {}", id);
250        Ok(())
251    }
252
253    /// Delete an entry.
254    #[instrument(skip(self), fields(id = %id))]
255    pub async fn delete_entry(&self, id: Uuid) -> Result<()> {
256        if self.entries.remove(&id).is_none() {
257            return Err(Error::not_found(id.to_string()));
258        }
259
260        self.vectors.remove(&id);
261        self.storage.delete_entry(id).await?;
262
263        *self.count.write() -= 1;
264        debug!("Deleted entry {}", id);
265
266        Ok(())
267    }
268
    /// Search the knowledge base.
    ///
    /// Embeds `query`, ranks every stored vector by cosine distance
    /// (brute force), optionally re-ranks through the learning engine,
    /// applies the category/tag/min-similarity filters from `options`,
    /// and returns at most `options.limit` results.
    ///
    /// NOTE(review): filtering only inspects the top `limit * 2` ranked
    /// candidates, so with strict filters the result set can be smaller
    /// than the number of matching entries further down the ranking —
    /// confirm that trade-off is intended.
    ///
    /// # Errors
    /// Returns any error from embedding the query.
    #[instrument(skip(self), fields(k = options.limit))]
    pub async fn search(&self, query: &str, options: SearchOptions) -> Result<Vec<SearchResult>> {
        // Generate query embedding
        let query_embedding = self.embeddings.embed(query).await?;

        // Find similar vectors using brute force for now
        // (ruvector HNSW would be used in production)
        let mut candidates: Vec<(Uuid, f32)> = self
            .vectors
            .iter()
            .map(|entry| {
                let id = *entry.key();
                let distance = cosine_distance(&query_embedding, entry.value());
                (id, distance)
            })
            .collect();

        // Sort by distance (ascending); NaN distances compare as equal
        // rather than panicking.
        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        // Apply learning-based re-ranking if enabled
        if options.use_learning
            && let Some(learning) = &self.learning
        {
            let learning = learning.read();
            candidates = learning.rerank(&query_embedding, candidates, &self.vectors);
        }

        // Build results
        let mut results = Vec::new();

        // Over-fetch 2x the limit so post-ranking filters have headroom.
        for (id, distance) in candidates.into_iter().take(options.limit * 2) {
            if let Some(entry) = self.entries.get(&id) {
                let entry = entry.clone();

                // Apply filters
                if let Some(ref cat) = options.category
                    && entry.category.as_ref() != Some(cat)
                {
                    continue;
                }

                // Tag filter: keep the entry if it carries ANY requested tag.
                if !options.tags.is_empty()
                    && !options
                        .tags
                        .iter()
                        .any(|t| entry.tags.iter().any(|et| et == t))
                {
                    continue;
                }

                // Similarity is the complement of cosine distance.
                let similarity = 1.0 - distance;
                if similarity < options.min_similarity {
                    continue;
                }

                results.push(SearchResult::new(entry, similarity, distance));

                if results.len() >= options.limit {
                    break;
                }
            }
        }

        // Apply MMR diversity if requested
        if options.diversity > 0.0 {
            results = apply_mmr(results, options.diversity);
        }

        // Record query for learning
        if let Some(learning) = &self.learning {
            let mut learning = learning.write();
            learning.record_query(&query_embedding, &results);
        }

        debug!("Search returned {} results", results.len());
        Ok(results)
    }
348
349    /// Simple search with default options.
350    pub async fn search_simple(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
351        self.search(query, SearchOptions::new(limit)).await
352    }
353
354    /// Record user feedback on a search result.
355    #[instrument(skip(self))]
356    pub async fn record_feedback(&self, entry_id: Uuid, positive: bool) -> Result<()> {
357        if let Some(mut entry) = self.entries.get_mut(&entry_id) {
358            let boost = if positive { 0.1 } else { -0.05 };
359            entry.record_access(1.0 + boost);
360
361            // Update learning engine
362            if let Some(learning) = &self.learning {
363                let mut learning = learning.write();
364                if let Some(embedding) = self.vectors.get(&entry_id) {
365                    learning.record_feedback(&embedding, positive);
366                }
367            }
368
369            // Persist updated entry
370            let entry = entry.clone();
371            if let Some(embedding) = self.vectors.get(&entry_id) {
372                self.storage.save_entry(&entry, &embedding).await?;
373            }
374        }
375
376        Ok(())
377    }
378
379    /// Get entries related to a given entry.
380    pub fn get_related(&self, id: Uuid, limit: usize) -> Vec<KnowledgeEntry> {
381        if let Some(entry) = self.entries.get(&id) {
382            entry
383                .related_entries
384                .iter()
385                .take(limit)
386                .filter_map(|rel_id| self.entries.get(rel_id).map(|e| e.clone()))
387                .collect()
388        } else {
389            Vec::new()
390        }
391    }
392
393    /// Link two entries as related.
394    #[allow(clippy::unused_async)]
395    pub async fn link_entries(&self, id1: Uuid, id2: Uuid) -> Result<()> {
396        if let Some(mut entry1) = self.entries.get_mut(&id1) {
397            if !entry1.related_entries.contains(&id2) {
398                entry1.related_entries.push(id2);
399            }
400        } else {
401            return Err(Error::not_found(id1.to_string()));
402        }
403
404        if let Some(mut entry2) = self.entries.get_mut(&id2)
405            && !entry2.related_entries.contains(&id1)
406        {
407            entry2.related_entries.push(id1);
408        }
409
410        Ok(())
411    }
412
413    /// Get all entries (for export/backup).
414    pub fn all_entries(&self) -> Vec<KnowledgeEntry> {
415        self.entries.iter().map(|e| e.value().clone()).collect()
416    }
417
418    /// Get statistics about the knowledge base.
419    pub fn stats(&self) -> KnowledgeBaseStats {
420        let total = self.len();
421        let categories: std::collections::HashSet<_> = self
422            .entries
423            .iter()
424            .filter_map(|e| e.category.clone())
425            .collect();
426
427        let tags: std::collections::HashSet<_> =
428            self.entries.iter().flat_map(|e| e.tags.clone()).collect();
429
430        let total_access: u64 = self.entries.iter().map(|e| e.access_count).sum();
431
432        KnowledgeBaseStats {
433            total_entries: total,
434            unique_categories: categories.len(),
435            unique_tags: tags.len(),
436            total_access_count: total_access,
437            dimensions: self.config.dimensions,
438            learning_enabled: self.config.learning_enabled,
439        }
440    }
441
    /// Flush all pending writes to storage.
    ///
    /// Delegates to the storage backend; call before dropping the base
    /// when durability matters (see the persistence tests, which flush
    /// before reopening).
    pub async fn flush(&self) -> Result<()> {
        self.storage.flush().await
    }
446}
447
/// Statistics about the knowledge base.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeBaseStats {
    /// Total number of stored entries.
    pub total_entries: usize,
    /// Count of distinct category labels across entries.
    pub unique_categories: usize,
    /// Count of distinct tags across entries.
    pub unique_tags: usize,
    /// Sum of `access_count` over all entries.
    pub total_access_count: u64,
    /// Embedding dimensionality from the active config.
    pub dimensions: usize,
    /// Whether the learning engine was enabled.
    pub learning_enabled: bool,
}
458
/// Calculate cosine distance between two vectors.
///
/// Returns `1 - cos(a, b)`, i.e. 0 for parallel vectors, 1 for orthogonal,
/// 2 for opposite. A zero-norm input yields 1.0 (maximum "unknown" distance).
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate dot product and both squared norms in one pass.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    if sq_a == 0.0 || sq_b == 0.0 {
        1.0
    } else {
        1.0 - dot / (sq_a.sqrt() * sq_b.sqrt())
    }
}
471
472/// Apply Maximal Marginal Relevance for diversity.
473fn apply_mmr(mut results: Vec<SearchResult>, lambda: f32) -> Vec<SearchResult> {
474    if results.len() <= 1 {
475        return results;
476    }
477
478    let mut selected = vec![results.remove(0)];
479
480    while !results.is_empty() && selected.len() < results.len() + selected.len() {
481        let mut best_idx = 0;
482        let mut best_score = f32::NEG_INFINITY;
483
484        for (i, candidate) in results.iter().enumerate() {
485            // Relevance term
486            let relevance = candidate.similarity;
487
488            // Diversity term: max similarity to already selected
489            let max_sim = selected
490                .iter()
491                .map(|s| {
492                    // Simplified: use score similarity
493                    1.0 - (s.score - candidate.score).abs()
494                })
495                .max_by(|a, b| a.partial_cmp(b).unwrap())
496                .unwrap_or(0.0);
497
498            // MMR score
499            let mmr = lambda * relevance - (1.0 - lambda) * max_sim;
500
501            if mmr > best_score {
502                best_score = mmr;
503                best_idx = i;
504            }
505        }
506
507        selected.push(results.remove(best_idx));
508    }
509
510    selected
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::KnowledgeEntry;
    use tempfile::tempdir;

    // Shared fixture: default config pointed at `path`, shrunk to 32
    // dimensions so embedding work stays cheap in tests.
    fn small_config(path: &Path) -> KnowledgeBaseConfig {
        KnowledgeBaseConfig::default()
            .with_path(path.to_string_lossy())
            .with_dimensions(32)
    }

    #[test]
    fn test_cosine_distance() {
        // Identical vectors: distance 0.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_distance(&a, &b) - 0.0).abs() < 1e-6);

        // Orthogonal vectors: distance 1.
        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_distance(&a, &c) - 1.0).abs() < 1e-6);

        // Zero-norm path returns 1.0 (max distance).
        let z = vec![0.0, 0.0, 0.0];
        assert!((cosine_distance(&a, &z) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn config_builder_sets_fields() {
        let cfg = KnowledgeBaseConfig::default()
            .with_path("/tmp/x.db")
            .with_dimensions(64)
            .without_learning();
        assert_eq!(cfg.storage_path, "/tmp/x.db");
        assert_eq!(cfg.dimensions, 64);
        assert!(!cfg.learning_enabled);
    }

    #[tokio::test]
    async fn open_creates_empty_kb() {
        let dir = tempdir().unwrap();
        let kb = KnowledgeBase::open(dir.path().join("kb.db")).await.unwrap();
        assert_eq!(kb.len(), 0);
        assert!(kb.is_empty());
        assert_eq!(kb.config().dimensions, 384);
    }

    #[tokio::test]
    async fn add_get_update_delete_roundtrip() {
        let dir = tempdir().unwrap();
        let kb = KnowledgeBase::with_config(small_config(&dir.path().join("kb.db")))
            .await
            .unwrap();

        let entry = KnowledgeEntry::new("Title", "body text").with_category("docs");
        let id = kb.add_entry(entry.clone()).await.unwrap();
        assert_eq!(kb.len(), 1);
        assert!(!kb.is_empty());

        let fetched = kb.get(id).expect("entry should exist");
        assert_eq!(fetched.title, "Title");

        let mut updated = fetched;
        updated.content = "new body".into();
        kb.update_entry(updated.clone()).await.unwrap();
        assert_eq!(kb.get(id).unwrap().content, "new body");

        kb.delete_entry(id).await.unwrap();
        assert_eq!(kb.len(), 0);
        assert!(kb.get(id).is_none());
    }

    #[tokio::test]
    async fn update_missing_entry_errors() {
        let dir = tempdir().unwrap();
        let kb = KnowledgeBase::with_config(small_config(&dir.path().join("kb.db")))
            .await
            .unwrap();
        let stranger = KnowledgeEntry::new("ghost", "body");
        let err = kb.update_entry(stranger).await.unwrap_err();
        assert!(matches!(err, Error::NotFound(_)));
    }

    #[tokio::test]
    async fn delete_missing_entry_errors() {
        let dir = tempdir().unwrap();
        let kb = KnowledgeBase::with_config(small_config(&dir.path().join("kb.db")))
            .await
            .unwrap();
        let err = kb.delete_entry(Uuid::new_v4()).await.unwrap_err();
        assert!(matches!(err, Error::NotFound(_)));
    }

    #[tokio::test]
    async fn add_entries_batch_persists() {
        let dir = tempdir().unwrap();
        let kb = KnowledgeBase::with_config(small_config(&dir.path().join("kb.db")))
            .await
            .unwrap();
        let batch: Vec<_> = (0..5)
            .map(|i| KnowledgeEntry::new(format!("t{i}"), format!("body {i}")))
            .collect();
        let ids = kb.add_entries(batch).await.unwrap();
        assert_eq!(ids.len(), 5);
        assert_eq!(kb.len(), 5);
        kb.flush().await.unwrap();
    }

    #[tokio::test]
    async fn search_filters_and_results() {
        let dir = tempdir().unwrap();
        // Larger dims so hash-embedder collisions don't make small-corpus
        // searches flaky.
        let cfg = KnowledgeBaseConfig::default()
            .with_path(dir.path().join("kb.db").to_string_lossy())
            .with_dimensions(128);
        let kb = KnowledgeBase::with_config(cfg).await.unwrap();
        kb.add_entry(
            KnowledgeEntry::new("rust ownership", "borrow checker introduction")
                .with_category("rust")
                .with_tags(["ownership"]),
        )
        .await
        .unwrap();
        kb.add_entry(
            KnowledgeEntry::new("python decorators", "functions wrapping functions")
                .with_category("python")
                .with_tags(["meta"]),
        )
        .await
        .unwrap();

        // search_simple returns Ok (results may be empty if hash embedding
        // has no positive overlap; we only assert the call path succeeds).
        let _ = kb.search_simple("borrow", 10).await.unwrap();

        // Category filter — only rust-categorised entries (or none).
        let only_rust = kb
            .search(
                "wrapping",
                SearchOptions::new(10)
                    .with_category("rust")
                    .without_learning(),
            )
            .await
            .unwrap();
        for r in &only_rust {
            assert_eq!(r.entry.category.as_deref(), Some("rust"));
        }

        // Tag filter — every result must carry the requested tag.
        let by_tag = kb
            .search("anything", SearchOptions::new(10).with_tags(["ownership"]))
            .await
            .unwrap();
        for r in &by_tag {
            assert!(r.entry.tags.iter().any(|t| t == "ownership"));
        }

        // Diversity branch — exercises apply_mmr.
        let _ = kb
            .search("functions", SearchOptions::new(5).with_diversity(0.5))
            .await
            .unwrap();

        // min_similarity above the achievable maximum filters everything out.
        let none = kb
            .search("borrow", SearchOptions::new(10).with_min_similarity(1.0))
            .await
            .unwrap();
        assert!(none.is_empty());
    }

    #[tokio::test]
    async fn record_feedback_and_stats() {
        let dir = tempdir().unwrap();
        let kb = KnowledgeBase::with_config(small_config(&dir.path().join("kb.db")))
            .await
            .unwrap();
        let id = kb
            .add_entry(
                KnowledgeEntry::new("a", "alpha")
                    .with_category("c")
                    .with_tags(["t"]),
            )
            .await
            .unwrap();
        kb.record_feedback(id, true).await.unwrap();
        kb.record_feedback(id, false).await.unwrap();
        kb.record_feedback(Uuid::new_v4(), true).await.unwrap(); // unknown id is ok

        let stats = kb.stats();
        assert_eq!(stats.total_entries, 1);
        assert_eq!(stats.unique_categories, 1);
        assert_eq!(stats.unique_tags, 1);
        assert!(stats.learning_enabled);
        assert_eq!(stats.dimensions, 32);
        assert!(stats.total_access_count >= 2);
    }

    #[tokio::test]
    async fn linking_and_related() {
        let dir = tempdir().unwrap();
        let kb = KnowledgeBase::with_config(small_config(&dir.path().join("kb.db")))
            .await
            .unwrap();
        let a = kb.add_entry(KnowledgeEntry::new("a", "x")).await.unwrap();
        let b = kb.add_entry(KnowledgeEntry::new("b", "y")).await.unwrap();

        kb.link_entries(a, b).await.unwrap();
        // idempotent
        kb.link_entries(a, b).await.unwrap();

        let related = kb.get_related(a, 5);
        assert_eq!(related.len(), 1);
        assert_eq!(related[0].id, b);

        // Unknown source id errors.
        let err = kb.link_entries(Uuid::new_v4(), b).await.unwrap_err();
        assert!(matches!(err, Error::NotFound(_)));

        // get_related on unknown id returns empty.
        assert!(kb.get_related(Uuid::new_v4(), 5).is_empty());

        // all_entries surfaces every entry.
        assert_eq!(kb.all_entries().len(), 2);
    }

    #[tokio::test]
    async fn reopens_with_existing_entries() {
        // Persistence round-trip: flush, drop, reopen from the same file.
        let dir = tempdir().unwrap();
        let path = dir.path().join("kb.db");
        let kb = KnowledgeBase::with_config(small_config(&path))
            .await
            .unwrap();
        kb.add_entry(KnowledgeEntry::new("persist", "me"))
            .await
            .unwrap();
        kb.flush().await.unwrap();
        drop(kb);

        let kb2 = KnowledgeBase::with_config(small_config(&path))
            .await
            .unwrap();
        assert_eq!(kb2.len(), 1);
        assert_eq!(kb2.all_entries()[0].title, "persist");
    }

    #[tokio::test]
    async fn learning_disabled_skips_engine() {
        let dir = tempdir().unwrap();
        let cfg = small_config(&dir.path().join("kb.db")).without_learning();
        let kb = KnowledgeBase::with_config(cfg).await.unwrap();
        let id = kb.add_entry(KnowledgeEntry::new("t", "c")).await.unwrap();
        // Search and feedback both no-op the learning branch.
        let _ = kb.search_simple("t", 5).await.unwrap();
        kb.record_feedback(id, true).await.unwrap();
        assert!(!kb.stats().learning_enabled);
    }

    #[test]
    fn mmr_short_circuits_short_lists() {
        // Lists of length <= 1 are returned unchanged.
        let entry = KnowledgeEntry::new("t", "c");
        let r = SearchResult::new(entry, 0.5, 0.5);
        let one = apply_mmr(vec![r.clone()], 0.5);
        assert_eq!(one.len(), 1);
        let empty: Vec<SearchResult> = apply_mmr(Vec::new(), 0.5);
        assert!(empty.is_empty());

        // Multiple results pass through MMR selection loop.
        let mut many = Vec::new();
        for i in 0..3 {
            let e = KnowledgeEntry::new(format!("t{i}"), "c");
            many.push(SearchResult::new(e, 0.9 - i as f32 * 0.1, 0.1 * i as f32));
        }
        let picked = apply_mmr(many, 0.7);
        assert!(!picked.is_empty());
    }
}