// converge_knowledge/core/knowledge_base.rs
//! Knowledge base implementation using ruvector.

use std::path::Path;
use std::sync::Arc;

use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tracing::{debug, info, instrument};
use uuid::Uuid;

use super::{KnowledgeEntry, SearchOptions, SearchResult};
use crate::embedding::EmbeddingEngine;
use crate::error::{Error, Result};
use crate::learning::LearningEngine;
use crate::storage::StorageBackend;

/// Configuration for the knowledge base.
///
/// Construct via [`Default`] and refine with the builder-style methods
/// (`with_path`, `with_dimensions`, `without_learning`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeBaseConfig {
    /// Embedding dimension size. Must match the embedding model in use.
    pub dimensions: usize,

    /// Path to storage file.
    pub storage_path: String,

    /// Enable self-learning features (query recording, feedback, re-ranking).
    pub learning_enabled: bool,

    /// Learning rate for GNN updates.
    pub learning_rate: f32,

    /// Number of HNSW neighbors (M parameter).
    pub hnsw_m: usize,

    /// HNSW ef_construction parameter (candidate-list size at index build).
    pub hnsw_ef_construction: usize,

    /// HNSW ef_search parameter (candidate-list size at query time).
    pub hnsw_ef_search: usize,

    /// Batch size for bulk operations such as `add_entries`.
    pub batch_size: usize,
}

45impl Default for KnowledgeBaseConfig {
46    fn default() -> Self {
47        Self {
48            dimensions: 384,
49            storage_path: "./knowledge.db".to_string(),
50            learning_enabled: true,
51            learning_rate: 0.01,
52            hnsw_m: 16,
53            hnsw_ef_construction: 200,
54            hnsw_ef_search: 100,
55            batch_size: 1000,
56        }
57    }
58}
59
60impl KnowledgeBaseConfig {
61    /// Create config with custom storage path.
62    pub fn with_path(mut self, path: impl Into<String>) -> Self {
63        self.storage_path = path.into();
64        self
65    }
66
67    /// Set embedding dimensions.
68    pub fn with_dimensions(mut self, dims: usize) -> Self {
69        self.dimensions = dims;
70        self
71    }
72
73    /// Disable learning features.
74    pub fn without_learning(mut self) -> Self {
75        self.learning_enabled = false;
76        self
77    }
78}
79
/// A self-learning knowledge base powered by ruvector.
///
/// Internal collections use concurrency-safe types (`DashMap`,
/// `parking_lot::RwLock`), so a `KnowledgeBase` can be shared across tasks
/// behind an `Arc`.
pub struct KnowledgeBase {
    /// Configuration (immutable after construction).
    config: KnowledgeBaseConfig,

    /// Storage backend for persistence.
    storage: Arc<StorageBackend>,

    /// Embedding engine for text vectorization.
    embeddings: Arc<EmbeddingEngine>,

    /// Learning engine for self-improvement; `None` when learning is disabled.
    learning: Option<Arc<RwLock<LearningEngine>>>,

    /// In-memory entry cache (id -> entry).
    entries: DashMap<Uuid, KnowledgeEntry>,

    /// Vector index (id -> embedding), kept in lock-step with `entries`.
    vectors: DashMap<Uuid, Vec<f32>>,

    /// Cached entry count, maintained alongside `entries` mutations.
    count: Arc<RwLock<usize>>,
}

104impl KnowledgeBase {
105    /// Open or create a knowledge base at the given path.
106    #[instrument(skip_all)]
107    pub async fn open(path: impl AsRef<Path>) -> Result<Self> {
108        let config = KnowledgeBaseConfig::default().with_path(path.as_ref().to_string_lossy());
109        Self::with_config(config).await
110    }
111
112    /// Create a knowledge base with custom configuration.
113    #[instrument(skip_all, fields(path = %config.storage_path))]
114    pub async fn with_config(config: KnowledgeBaseConfig) -> Result<Self> {
115        info!("Initializing knowledge base at {}", config.storage_path);
116
117        let storage = Arc::new(StorageBackend::open(&config.storage_path).await?);
118        let embeddings = Arc::new(EmbeddingEngine::new(config.dimensions));
119
120        let learning = if config.learning_enabled {
121            Some(Arc::new(RwLock::new(LearningEngine::new(
122                config.dimensions,
123                config.learning_rate,
124            ))))
125        } else {
126            None
127        };
128
129        let kb = Self {
130            config,
131            storage,
132            embeddings,
133            learning,
134            entries: DashMap::new(),
135            vectors: DashMap::new(),
136            count: Arc::new(RwLock::new(0)),
137        };
138
139        // Load existing entries from storage
140        kb.load_entries().await?;
141
142        info!("Knowledge base initialized with {} entries", kb.len());
143        Ok(kb)
144    }
145
146    /// Load entries from storage.
147    async fn load_entries(&self) -> Result<()> {
148        let stored = self.storage.load_all().await?;
149
150        for (entry, embedding) in stored {
151            self.entries.insert(entry.id, entry.clone());
152            self.vectors.insert(entry.id, embedding);
153        }
154
155        *self.count.write() = self.entries.len();
156        Ok(())
157    }
158
159    /// Get the number of entries.
160    pub fn len(&self) -> usize {
161        *self.count.read()
162    }
163
164    /// Check if the knowledge base is empty.
165    pub fn is_empty(&self) -> bool {
166        self.len() == 0
167    }
168
169    /// Get configuration.
170    pub fn config(&self) -> &KnowledgeBaseConfig {
171        &self.config
172    }
173
174    /// Add a new knowledge entry.
175    #[instrument(skip(self, entry), fields(title = %entry.title))]
176    pub async fn add_entry(&self, entry: KnowledgeEntry) -> Result<Uuid> {
177        let id = entry.id;
178
179        // Generate embedding from content
180        let text = entry.embedding_text();
181        let embedding = self.embeddings.embed(&text)?;
182
183        // Store in memory
184        self.entries.insert(id, entry.clone());
185        self.vectors.insert(id, embedding.clone());
186
187        // Persist to storage
188        self.storage.save_entry(&entry, &embedding).await?;
189
190        *self.count.write() += 1;
191        debug!("Added entry {}", id);
192
193        Ok(id)
194    }
195
196    /// Add multiple entries in batch.
197    #[instrument(skip(self, entries), fields(count = entries.len()))]
198    pub async fn add_entries(&self, entries: Vec<KnowledgeEntry>) -> Result<Vec<Uuid>> {
199        let mut ids = Vec::with_capacity(entries.len());
200
201        for chunk in entries.chunks(self.config.batch_size) {
202            let batch: Vec<_> = chunk
203                .iter()
204                .map(|entry| {
205                    let text = entry.embedding_text();
206                    let embedding = self.embeddings.embed(&text)?;
207                    Ok((entry.clone(), embedding))
208                })
209                .collect::<Result<Vec<_>>>()?;
210
211            for (entry, embedding) in &batch {
212                self.entries.insert(entry.id, entry.clone());
213                self.vectors.insert(entry.id, embedding.clone());
214                ids.push(entry.id);
215            }
216
217            self.storage.save_batch(&batch).await?;
218        }
219
220        *self.count.write() += ids.len();
221        info!("Added {} entries in batch", ids.len());
222
223        Ok(ids)
224    }
225
226    /// Get an entry by ID.
227    pub fn get(&self, id: Uuid) -> Option<KnowledgeEntry> {
228        self.entries.get(&id).map(|e| e.clone())
229    }
230
231    /// Update an existing entry.
232    #[instrument(skip(self, entry), fields(id = %entry.id))]
233    pub async fn update_entry(&self, entry: KnowledgeEntry) -> Result<()> {
234        let id = entry.id;
235
236        if !self.entries.contains_key(&id) {
237            return Err(Error::not_found(id.to_string()));
238        }
239
240        // Regenerate embedding
241        let text = entry.embedding_text();
242        let embedding = self.embeddings.embed(&text)?;
243
244        // Update in memory
245        self.entries.insert(id, entry.clone());
246        self.vectors.insert(id, embedding.clone());
247
248        // Persist
249        self.storage.save_entry(&entry, &embedding).await?;
250
251        debug!("Updated entry {}", id);
252        Ok(())
253    }
254
255    /// Delete an entry.
256    #[instrument(skip(self), fields(id = %id))]
257    pub async fn delete_entry(&self, id: Uuid) -> Result<()> {
258        if self.entries.remove(&id).is_none() {
259            return Err(Error::not_found(id.to_string()));
260        }
261
262        self.vectors.remove(&id);
263        self.storage.delete_entry(id).await?;
264
265        *self.count.write() -= 1;
266        debug!("Deleted entry {}", id);
267
268        Ok(())
269    }
270
271    /// Search the knowledge base.
272    #[instrument(skip(self), fields(k = options.limit))]
273    pub async fn search(&self, query: &str, options: SearchOptions) -> Result<Vec<SearchResult>> {
274        // Generate query embedding
275        let query_embedding = self.embeddings.embed(query)?;
276
277        // Find similar vectors using brute force for now
278        // (ruvector HNSW would be used in production)
279        let mut candidates: Vec<(Uuid, f32)> = self
280            .vectors
281            .iter()
282            .map(|entry| {
283                let id = *entry.key();
284                let distance = cosine_distance(&query_embedding, entry.value());
285                (id, distance)
286            })
287            .collect();
288
289        // Sort by distance (ascending)
290        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
291
292        // Apply learning-based re-ranking if enabled
293        if options.use_learning {
294            if let Some(learning) = &self.learning {
295                let learning = learning.read();
296                candidates = learning.rerank(&query_embedding, candidates, &self.vectors);
297            }
298        }
299
300        // Build results
301        let mut results = Vec::new();
302
303        for (id, distance) in candidates.into_iter().take(options.limit * 2) {
304            if let Some(entry) = self.entries.get(&id) {
305                let entry = entry.clone();
306
307                // Apply filters
308                if let Some(ref cat) = options.category {
309                    if entry.category.as_ref() != Some(cat) {
310                        continue;
311                    }
312                }
313
314                if !options.tags.is_empty()
315                    && !options
316                        .tags
317                        .iter()
318                        .any(|t| entry.tags.iter().any(|et| et == t))
319                {
320                    continue;
321                }
322
323                let similarity = 1.0 - distance;
324                if similarity < options.min_similarity {
325                    continue;
326                }
327
328                results.push(SearchResult::new(entry, similarity, distance));
329
330                if results.len() >= options.limit {
331                    break;
332                }
333            }
334        }
335
336        // Apply MMR diversity if requested
337        if options.diversity > 0.0 {
338            results = apply_mmr(results, options.diversity);
339        }
340
341        // Record query for learning
342        if let Some(learning) = &self.learning {
343            let mut learning = learning.write();
344            learning.record_query(&query_embedding, &results);
345        }
346
347        debug!("Search returned {} results", results.len());
348        Ok(results)
349    }
350
351    /// Simple search with default options.
352    pub async fn search_simple(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
353        self.search(query, SearchOptions::new(limit)).await
354    }
355
356    /// Record user feedback on a search result.
357    #[instrument(skip(self))]
358    pub async fn record_feedback(&self, entry_id: Uuid, positive: bool) -> Result<()> {
359        if let Some(mut entry) = self.entries.get_mut(&entry_id) {
360            let boost = if positive { 0.1 } else { -0.05 };
361            entry.record_access(1.0 + boost);
362
363            // Update learning engine
364            if let Some(learning) = &self.learning {
365                let mut learning = learning.write();
366                if let Some(embedding) = self.vectors.get(&entry_id) {
367                    learning.record_feedback(&embedding, positive);
368                }
369            }
370
371            // Persist updated entry
372            let entry = entry.clone();
373            if let Some(embedding) = self.vectors.get(&entry_id) {
374                self.storage.save_entry(&entry, &embedding).await?;
375            }
376        }
377
378        Ok(())
379    }
380
381    /// Get entries related to a given entry.
382    pub fn get_related(&self, id: Uuid, limit: usize) -> Vec<KnowledgeEntry> {
383        if let Some(entry) = self.entries.get(&id) {
384            entry
385                .related_entries
386                .iter()
387                .take(limit)
388                .filter_map(|rel_id| self.entries.get(rel_id).map(|e| e.clone()))
389                .collect()
390        } else {
391            Vec::new()
392        }
393    }
394
395    /// Link two entries as related.
396    pub async fn link_entries(&self, id1: Uuid, id2: Uuid) -> Result<()> {
397        if let Some(mut entry1) = self.entries.get_mut(&id1) {
398            if !entry1.related_entries.contains(&id2) {
399                entry1.related_entries.push(id2);
400            }
401        } else {
402            return Err(Error::not_found(id1.to_string()));
403        }
404
405        if let Some(mut entry2) = self.entries.get_mut(&id2) {
406            if !entry2.related_entries.contains(&id1) {
407                entry2.related_entries.push(id1);
408            }
409        }
410
411        Ok(())
412    }
413
414    /// Get all entries (for export/backup).
415    pub fn all_entries(&self) -> Vec<KnowledgeEntry> {
416        self.entries.iter().map(|e| e.value().clone()).collect()
417    }
418
419    /// Get statistics about the knowledge base.
420    pub fn stats(&self) -> KnowledgeBaseStats {
421        let total = self.len();
422        let categories: std::collections::HashSet<_> = self
423            .entries
424            .iter()
425            .filter_map(|e| e.category.clone())
426            .collect();
427
428        let tags: std::collections::HashSet<_> =
429            self.entries.iter().flat_map(|e| e.tags.clone()).collect();
430
431        let total_access: u64 = self.entries.iter().map(|e| e.access_count).sum();
432
433        KnowledgeBaseStats {
434            total_entries: total,
435            unique_categories: categories.len(),
436            unique_tags: tags.len(),
437            total_access_count: total_access,
438            dimensions: self.config.dimensions,
439            learning_enabled: self.config.learning_enabled,
440        }
441    }
442
443    /// Flush all pending writes to storage.
444    pub async fn flush(&self) -> Result<()> {
445        self.storage.flush().await
446    }
447}
448
/// Point-in-time statistics about the knowledge base.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeBaseStats {
    /// Total number of entries currently stored.
    pub total_entries: usize,
    /// Number of distinct categories in use.
    pub unique_categories: usize,
    /// Number of distinct tags in use.
    pub unique_tags: usize,
    /// Sum of access counts across all entries.
    pub total_access_count: u64,
    /// Embedding dimensionality.
    pub dimensions: usize,
    /// Whether self-learning features are enabled.
    pub learning_enabled: bool,
}

/// Cosine distance between two vectors: `1 - cos(a, b)`.
///
/// Returns 1.0 (maximally distant) when either vector has zero magnitude,
/// since the cosine is undefined there.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate dot product and both squared norms in a single pass.
    let (dot, sq_a, sq_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        1.0
    } else {
        1.0 - dot / (norm_a * norm_b)
    }
}

473/// Apply Maximal Marginal Relevance for diversity.
474fn apply_mmr(mut results: Vec<SearchResult>, lambda: f32) -> Vec<SearchResult> {
475    if results.len() <= 1 {
476        return results;
477    }
478
479    let mut selected = vec![results.remove(0)];
480
481    while !results.is_empty() && selected.len() < results.len() + selected.len() {
482        let mut best_idx = 0;
483        let mut best_score = f32::NEG_INFINITY;
484
485        for (i, candidate) in results.iter().enumerate() {
486            // Relevance term
487            let relevance = candidate.similarity;
488
489            // Diversity term: max similarity to already selected
490            let max_sim = selected
491                .iter()
492                .map(|s| {
493                    // Simplified: use score similarity
494                    1.0 - (s.score - candidate.score).abs()
495                })
496                .max_by(|a, b| a.partial_cmp(b).unwrap())
497                .unwrap_or(0.0);
498
499            // MMR score
500            let mmr = lambda * relevance - (1.0 - lambda) * max_sim;
501
502            if mmr > best_score {
503                best_score = mmr;
504                best_idx = i;
505            }
506        }
507
508        selected.push(results.remove(best_idx));
509    }
510
511    selected
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_distance() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_distance(&a, &b) - 0.0).abs() < 1e-6);

        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_distance(&a, &c) - 1.0).abs() < 1e-6);
    }

    // Added coverage for edge cases the original suite missed.
    #[test]
    fn test_cosine_distance_edge_cases() {
        let a = vec![1.0, 2.0, 3.0];

        // Zero vectors are defined as maximally distant (distance 1.0).
        let zero = vec![0.0, 0.0, 0.0];
        assert!((cosine_distance(&zero, &a) - 1.0).abs() < 1e-6);
        assert!((cosine_distance(&zero, &zero) - 1.0).abs() < 1e-6);

        // Cosine distance is scale-invariant.
        let scaled: Vec<f32> = a.iter().map(|x| x * 10.0).collect();
        assert!(cosine_distance(&a, &scaled).abs() < 1e-6);

        // Opposite vectors sit at the maximum distance of 2.0.
        let neg: Vec<f32> = a.iter().map(|x| -x).collect();
        assert!((cosine_distance(&a, &neg) - 2.0).abs() < 1e-6);
    }
}