// lean_ctx/core/knowledge_embedding.rs

//! Embedding-based knowledge retrieval for `ctx_knowledge`.
//!
//! Wraps `ProjectKnowledge` with a vector index for semantic recall.
//! Facts are automatically embedded on `remember` and searched via
//! cosine similarity on `recall`, with hybrid exact + semantic ranking.
6
7use std::path::PathBuf;
8
9use serde::{Deserialize, Serialize};
10
11use super::knowledge::{KnowledgeFact, ProjectKnowledge};
12
13#[cfg(feature = "embeddings")]
14use super::embeddings::{cosine_similarity, EmbeddingEngine};
15
/// Weight of the cosine-similarity term in the blended recall score.
///
/// Together with `BETA_CONFIDENCE` and `GAMMA_RECENCY` the weights sum to 1.0.
#[cfg(feature = "embeddings")]
const ALPHA_SEMANTIC: f32 = 0.6;
/// Weight of the fact's stored confidence in the blended recall score.
#[cfg(feature = "embeddings")]
const BETA_CONFIDENCE: f32 = 0.25;
/// Weight of the recency term in the blended recall score.
#[cfg(feature = "embeddings")]
const GAMMA_RECENCY: f32 = 0.15;
/// Age, in days since last confirmation, at which the recency term reaches zero.
#[cfg(feature = "embeddings")]
const MAX_RECENCY_DAYS: f32 = 90.0;
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct FactEmbedding {
27    pub category: String,
28    pub key: String,
29    pub embedding: Vec<f32>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct KnowledgeEmbeddingIndex {
34    pub project_hash: String,
35    pub entries: Vec<FactEmbedding>,
36}
37
38impl KnowledgeEmbeddingIndex {
39    pub fn new(project_hash: &str) -> Self {
40        Self {
41            project_hash: project_hash.to_string(),
42            entries: Vec::new(),
43        }
44    }
45
46    pub fn upsert(&mut self, category: &str, key: &str, embedding: Vec<f32>) {
47        if let Some(existing) = self
48            .entries
49            .iter_mut()
50            .find(|e| e.category == category && e.key == key)
51        {
52            existing.embedding = embedding;
53        } else {
54            self.entries.push(FactEmbedding {
55                category: category.to_string(),
56                key: key.to_string(),
57                embedding,
58            });
59        }
60    }
61
62    pub fn remove(&mut self, category: &str, key: &str) {
63        self.entries
64            .retain(|e| !(e.category == category && e.key == key));
65    }
66
67    #[cfg(feature = "embeddings")]
68    pub fn semantic_search(
69        &self,
70        query_embedding: &[f32],
71        top_k: usize,
72    ) -> Vec<(&FactEmbedding, f32)> {
73        let mut scored: Vec<(&FactEmbedding, f32)> = self
74            .entries
75            .iter()
76            .map(|e| {
77                let sim = cosine_similarity(query_embedding, &e.embedding);
78                (e, sim)
79            })
80            .collect();
81
82        scored.sort_by(|a, b| {
83            b.1.partial_cmp(&a.1)
84                .unwrap_or(std::cmp::Ordering::Equal)
85                .then_with(|| a.0.category.cmp(&b.0.category))
86                .then_with(|| a.0.key.cmp(&b.0.key))
87        });
88        scored.truncate(top_k);
89        scored
90    }
91
92    fn index_path(project_hash: &str) -> Option<PathBuf> {
93        let dir = crate::core::data_dir::nebu_ctx_data_dir()
94            .ok()?
95            .join("knowledge")
96            .join(project_hash);
97        Some(dir.join("embeddings.json"))
98    }
99
100    pub fn load(project_hash: &str) -> Option<Self> {
101        let path = Self::index_path(project_hash)?;
102        let data = std::fs::read_to_string(path).ok()?;
103        serde_json::from_str(&data).ok()
104    }
105
106    pub fn save(&self) -> Result<(), String> {
107        let path = Self::index_path(&self.project_hash)
108            .ok_or_else(|| "Cannot determine data directory".to_string())?;
109        if let Some(dir) = path.parent() {
110            std::fs::create_dir_all(dir).map_err(|e| format!("{e}"))?;
111        }
112        let json = serde_json::to_string(self).map_err(|e| format!("{e}"))?;
113        std::fs::write(path, json).map_err(|e| format!("{e}"))
114    }
115}
116
117pub fn reset(project_hash: &str) -> Result<(), String> {
118    let path = KnowledgeEmbeddingIndex::index_path(project_hash)
119        .ok_or_else(|| "Cannot determine data directory".to_string())?;
120    if path.exists() {
121        std::fs::remove_file(&path).map_err(|e| format!("{e}"))?;
122    }
123    Ok(())
124}
125
126#[derive(Debug)]
127pub struct ScoredFact<'a> {
128    pub fact: &'a KnowledgeFact,
129    pub score: f32,
130    pub semantic_score: f32,
131    pub confidence_score: f32,
132    pub recency_score: f32,
133}
134
/// Recalls up to `top_k` facts for `query`, combining semantic and exact matches.
///
/// The query is embedded with `engine` and compared against `index`. Each hit
/// that maps to a current fact in `knowledge` is scored as
/// `ALPHA_SEMANTIC * similarity + BETA_CONFIDENCE * confidence +
/// GAMMA_RECENCY * recency`. Facts returned by `knowledge.recall(query)` that
/// were not already found semantically are added with a score of `1.0`.
///
/// Results are ordered by score, then confidence, then recency (all
/// descending), then category, key and value (ascending). If embedding the
/// query fails, the result of `lexical_fallback` is returned instead.
#[cfg(feature = "embeddings")]
pub fn semantic_recall<'a>(
    knowledge: &'a ProjectKnowledge,
    index: &KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    let query_embedding = match engine.embed(query) {
        Ok(embedding) => embedding,
        Err(_) => return lexical_fallback(knowledge, query, top_k),
    };

    // Over-fetch so that hits whose fact is gone or no longer current can be
    // dropped without starving the result. Saturate so that a very large
    // `top_k` cannot overflow.
    let semantic_hits = index.semantic_search(&query_embedding, top_k.saturating_mul(2));

    let mut results: Vec<ScoredFact<'a>> = Vec::with_capacity(semantic_hits.len());

    for (entry, sim) in semantic_hits {
        if let Some(fact) = knowledge
            .facts
            .iter()
            .find(|f| f.category == entry.category && f.key == entry.key && f.is_current())
        {
            let confidence_score = fact.confidence;
            let recency_score = recency_decay(fact);
            let score = ALPHA_SEMANTIC * sim
                + BETA_CONFIDENCE * confidence_score
                + GAMMA_RECENCY * recency_score;

            results.push(ScoredFact {
                fact,
                score,
                semantic_score: sim,
                confidence_score,
                recency_score,
            });
        }
    }

    let exact_matches = knowledge.recall(query);
    for fact in exact_matches {
        let already_included = results
            .iter()
            .any(|r| r.fact.category == fact.category && r.fact.key == fact.key);
        if !already_included {
            results.push(ScoredFact {
                fact,
                score: 1.0,
                semantic_score: 1.0,
                confidence_score: fact.confidence,
                recency_score: recency_decay(fact),
            });
        }
    }

    // Descending float comparison that ranks NaN last. Comparing raw values
    // with `partial_cmp(..).unwrap_or(Equal)` is not a total order when a NaN
    // is present, and `sort_by` may then panic or order arbitrarily.
    let desc = |a: f32, b: f32| {
        let rank = |x: f32| if x.is_nan() { f32::NEG_INFINITY } else { x };
        rank(b)
            .partial_cmp(&rank(a))
            .unwrap_or(std::cmp::Ordering::Equal)
    };

    results.sort_by(|a, b| {
        desc(a.score, b.score)
            .then_with(|| desc(a.confidence_score, b.confidence_score))
            .then_with(|| desc(a.recency_score, b.recency_score))
            .then_with(|| a.fact.category.cmp(&b.fact.category))
            .then_with(|| a.fact.key.cmp(&b.fact.key))
            .then_with(|| a.fact.value.cmp(&b.fact.value))
    });
    results.truncate(top_k);
    results
}
211
212pub fn compact_against_knowledge(
213    index: &mut KnowledgeEmbeddingIndex,
214    knowledge: &ProjectKnowledge,
215) {
216    use std::collections::HashMap;
217
218    let mut current: HashMap<(&str, &str), &KnowledgeFact> = HashMap::new();
219    for f in &knowledge.facts {
220        if f.is_current() {
221            current.insert((f.category.as_str(), f.key.as_str()), f);
222        }
223    }
224
225    let mut kept: Vec<(FactEmbedding, &KnowledgeFact)> = index
226        .entries
227        .iter()
228        .filter_map(|e| {
229            current
230                .get(&(e.category.as_str(), e.key.as_str()))
231                .map(|f| (e.clone(), *f))
232        })
233        .collect();
234
235    kept.sort_by(|(ea, fa), (eb, fb)| {
236        fb.confidence
237            .partial_cmp(&fa.confidence)
238            .unwrap_or(std::cmp::Ordering::Equal)
239            .then_with(|| fb.last_confirmed.cmp(&fa.last_confirmed))
240            .then_with(|| fb.retrieval_count.cmp(&fa.retrieval_count))
241            .then_with(|| ea.category.cmp(&eb.category))
242            .then_with(|| ea.key.cmp(&eb.key))
243    });
244
245    let max = crate::core::budgets::KNOWLEDGE_EMBEDDINGS_MAX_FACTS;
246    if kept.len() > max {
247        kept.truncate(max);
248    }
249
250    index.entries = kept.into_iter().map(|(e, _)| e).collect();
251}
252
/// Non-semantic recall used when the query cannot be embedded.
///
/// Takes the first `top_k` facts returned by `knowledge.recall(query)`, in
/// that order, scoring each by its confidence and reporting a semantic score
/// of zero.
#[cfg(feature = "embeddings")]
fn lexical_fallback<'a>(
    knowledge: &'a ProjectKnowledge,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    let mut ranked = Vec::new();
    for fact in knowledge.recall(query).into_iter().take(top_k) {
        let confidence = fact.confidence;
        ranked.push(ScoredFact {
            fact,
            score: confidence,
            semantic_score: 0.0,
            confidence_score: confidence,
            recency_score: recency_decay(fact),
        });
    }
    ranked
}
272
/// Linear recency score in `[0.0, 1.0]` for `fact`.
///
/// The score falls from 1.0 for a fact confirmed today to 0.0 once
/// `MAX_RECENCY_DAYS` whole days have passed since `last_confirmed`.
#[cfg(feature = "embeddings")]
fn recency_decay(fact: &KnowledgeFact) -> f32 {
    let days_old = chrono::Utc::now()
        .signed_duration_since(fact.last_confirmed)
        .num_days() as f32;
    // Clamp both ends: a `last_confirmed` in the future (e.g. clock skew)
    // gives a negative age, which would otherwise push the score above 1.0.
    (1.0 - days_old / MAX_RECENCY_DAYS).clamp(0.0, 1.0)
}
280
/// Embeds the text `"{category} {key}: {value}"` with `engine` and upserts
/// the resulting vector into `index` under `(category, key)`.
///
/// # Errors
///
/// Returns the stringified engine error when embedding fails; `index` is
/// left untouched in that case.
#[cfg(feature = "embeddings")]
pub fn embed_and_store(
    index: &mut KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    category: &str,
    key: &str,
    value: &str,
) -> Result<(), String> {
    let text = format!("{category} {key}: {value}");
    match engine.embed(&text) {
        Ok(embedding) => {
            index.upsert(category, key, embedding);
            Ok(())
        }
        Err(e) => Err(e.to_string()),
    }
}
294
295pub fn format_scored_facts(results: &[ScoredFact<'_>]) -> String {
296    if results.is_empty() {
297        return "No matching facts found.".to_string();
298    }
299
300    let mut output = String::new();
301    for (i, scored) in results.iter().enumerate() {
302        let f = scored.fact;
303        let stars = if f.confidence >= 0.9 {
304            "★★★★"
305        } else if f.confidence >= 0.7 {
306            "★★★"
307        } else if f.confidence >= 0.5 {
308            "★★"
309        } else {
310            "★"
311        };
312
313        if i > 0 {
314            output.push('|');
315        }
316        output.push_str(&format!(
317            "{}:{}={}{} [s:{:.0}%]",
318            f.category,
319            f.key,
320            f.value,
321            stars,
322            scored.score * 100.0
323        ));
324    }
325    output
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Points `NEBU_CTX_DATA_DIR` at a directory for the lifetime of the
    /// guard and clears it on drop, so a failing assertion cannot leak the
    /// override into later tests.
    struct DataDirOverride;

    impl DataDirOverride {
        fn set(path: &std::path::Path) -> Self {
            std::env::set_var("NEBU_CTX_DATA_DIR", path.to_string_lossy().to_string());
            Self
        }
    }

    impl Drop for DataDirOverride {
        fn drop(&mut self) {
            std::env::remove_var("NEBU_CTX_DATA_DIR");
        }
    }

    /// Builds a current fact with neutral defaults; tests override individual
    /// fields with struct-update syntax where they need something else.
    fn fact(
        category: &str,
        key: &str,
        value: &str,
        confidence: f32,
        confirmed_at: chrono::DateTime<chrono::Utc>,
    ) -> KnowledgeFact {
        KnowledgeFact {
            category: category.to_string(),
            key: key.to_string(),
            value: value.to_string(),
            source_session: "s".to_string(),
            confidence,
            created_at: confirmed_at,
            last_confirmed: confirmed_at,
            retrieval_count: 0,
            last_retrieved: None,
            valid_from: None,
            valid_until: None,
            supersedes: None,
            confirmation_count: 1,
        }
    }

    #[test]
    fn reset_removes_index_file() {
        let _lock = crate::core::data_dir::test_env_lock();
        let tmp = tempfile::tempdir().expect("tempdir");
        // Declared last so it is dropped first: the variable is cleared before
        // the temp dir is deleted and the lock released.
        let _env = DataDirOverride::set(tmp.path());

        let idx = KnowledgeEmbeddingIndex {
            project_hash: "projhash".to_string(),
            entries: vec![FactEmbedding {
                category: "arch".to_string(),
                key: "db".to_string(),
                embedding: vec![1.0, 0.0, 0.0],
            }],
        };
        idx.save().expect("save");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_some());

        reset("projhash").expect("reset");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_none());
    }

    #[test]
    fn compact_drops_missing_or_archived_facts() {
        let mut knowledge = ProjectKnowledge::new("/tmp/project");
        let now = chrono::Utc::now();
        knowledge.facts.push(KnowledgeFact {
            retrieval_count: 5,
            ..fact("arch", "db", "Postgres", 0.9, now)
        });
        // `valid_until` is set, so this fact is expected to be treated as no
        // longer current.
        knowledge.facts.push(KnowledgeFact {
            valid_until: Some(now),
            ..fact("arch", "old", "Old", 0.9, now)
        });

        let mut idx = KnowledgeEmbeddingIndex::new(&knowledge.project_hash);
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        idx.upsert("arch", "old", vec![0.0, 1.0, 0.0]);
        idx.upsert("ops", "deploy", vec![0.0, 0.0, 1.0]);

        compact_against_knowledge(&mut idx, &knowledge);
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].category, "arch");
        assert_eq!(idx.entries[0].key, "db");
    }

    #[test]
    fn index_upsert_and_remove() {
        let mut idx = KnowledgeEmbeddingIndex::new("test");
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        assert_eq!(idx.entries.len(), 1);

        idx.upsert("arch", "db", vec![0.0, 1.0, 0.0]);
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].embedding[1], 1.0);

        idx.upsert("arch", "cache", vec![0.0, 0.0, 1.0]);
        assert_eq!(idx.entries.len(), 2);

        idx.remove("arch", "db");
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].key, "cache");
    }

    #[cfg(feature = "embeddings")]
    #[test]
    fn recency_decay_recent() {
        let recent = fact("test", "k", "v", 0.9, chrono::Utc::now());
        let decay = recency_decay(&recent);
        assert!(
            decay > 0.95,
            "Recent fact should have high recency: {decay}"
        );
    }

    #[cfg(feature = "embeddings")]
    #[test]
    fn recency_decay_old() {
        let old_date = chrono::Utc::now() - chrono::Duration::days(100);
        let stale = fact("test", "k", "v", 0.5, old_date);
        let decay = recency_decay(&stale);
        assert_eq!(decay, 0.0, "100-day-old fact should have 0 recency");
    }

    #[cfg(feature = "embeddings")]
    #[test]
    fn semantic_search_ranking() {
        let mut idx = KnowledgeEmbeddingIndex::new("test");
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        idx.upsert("arch", "cache", vec![0.0, 1.0, 0.0]);
        idx.upsert("ops", "deploy", vec![0.5, 0.5, 0.0]);

        let query = vec![1.0, 0.0, 0.0];
        let results = idx.semantic_search(&query, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0.key, "db");
    }

    #[test]
    fn format_scored_empty() {
        assert_eq!(format_scored_facts(&[]), "No matching facts found.");
    }

    #[test]
    fn format_scored_output() {
        let fact = KnowledgeFact {
            source_session: "s1".to_string(),
            confirmation_count: 3,
            ..fact("arch", "db", "PostgreSQL", 0.95, chrono::Utc::now())
        };
        let scored = vec![ScoredFact {
            fact: &fact,
            score: 0.85,
            semantic_score: 0.9,
            confidence_score: 0.95,
            recency_score: 1.0,
        }];
        let output = format_scored_facts(&scored);
        assert!(output.contains("arch:db=PostgreSQL"));
        assert!(output.contains("★★★★"));
        assert!(output.contains("[s:85%]"));
    }
}