Skip to main content

lean_ctx/core/
knowledge_embedding.rs

1//! Embedding-based Knowledge Retrieval for `ctx_knowledge`.
2//!
3//! Wraps `ProjectKnowledge` with a vector index for semantic recall.
4//! Facts are automatically embedded on `remember` and searched via
5//! cosine similarity on `recall`, with hybrid exact + semantic ranking.
6
7use std::path::PathBuf;
8
9use serde::{Deserialize, Serialize};
10
11use super::knowledge::{KnowledgeFact, ProjectKnowledge};
12
13#[cfg(feature = "embeddings")]
14use super::embeddings::{cosine_similarity, EmbeddingEngine};
15
// Hybrid ranking weights (they sum to 1.0): semantic similarity dominates,
// with fact confidence and recency as secondary signals.
const ALPHA_SEMANTIC: f32 = 0.6;
const BETA_CONFIDENCE: f32 = 0.25;
const GAMMA_RECENCY: f32 = 0.15;
// Age (in days) at which a fact's linear recency score reaches zero.
const MAX_RECENCY_DAYS: f32 = 90.0;
20
/// One stored embedding vector for a knowledge fact, addressed by
/// its `(category, key)` pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactEmbedding {
    pub category: String,
    pub key: String,
    pub embedding: Vec<f32>,
}
27
/// Per-project vector index over knowledge facts, persisted as JSON
/// (`embeddings.json`) under the project's data directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeEmbeddingIndex {
    pub project_hash: String,
    pub entries: Vec<FactEmbedding>,
}
33
34impl KnowledgeEmbeddingIndex {
35    pub fn new(project_hash: &str) -> Self {
36        Self {
37            project_hash: project_hash.to_string(),
38            entries: Vec::new(),
39        }
40    }
41
42    pub fn upsert(&mut self, category: &str, key: &str, embedding: Vec<f32>) {
43        if let Some(existing) = self
44            .entries
45            .iter_mut()
46            .find(|e| e.category == category && e.key == key)
47        {
48            existing.embedding = embedding;
49        } else {
50            self.entries.push(FactEmbedding {
51                category: category.to_string(),
52                key: key.to_string(),
53                embedding,
54            });
55        }
56    }
57
58    pub fn remove(&mut self, category: &str, key: &str) {
59        self.entries
60            .retain(|e| !(e.category == category && e.key == key));
61    }
62
63    #[cfg(feature = "embeddings")]
64    pub fn semantic_search(
65        &self,
66        query_embedding: &[f32],
67        top_k: usize,
68    ) -> Vec<(&FactEmbedding, f32)> {
69        let mut scored: Vec<(&FactEmbedding, f32)> = self
70            .entries
71            .iter()
72            .map(|e| {
73                let sim = cosine_similarity(query_embedding, &e.embedding);
74                (e, sim)
75            })
76            .collect();
77
78        scored.sort_by(|a, b| {
79            b.1.partial_cmp(&a.1)
80                .unwrap_or(std::cmp::Ordering::Equal)
81                .then_with(|| a.0.category.cmp(&b.0.category))
82                .then_with(|| a.0.key.cmp(&b.0.key))
83        });
84        scored.truncate(top_k);
85        scored
86    }
87
88    fn index_path(project_hash: &str) -> Option<PathBuf> {
89        let dir = crate::core::data_dir::lean_ctx_data_dir()
90            .ok()?
91            .join("knowledge")
92            .join(project_hash);
93        Some(dir.join("embeddings.json"))
94    }
95
96    pub fn load(project_hash: &str) -> Option<Self> {
97        let path = Self::index_path(project_hash)?;
98        let data = std::fs::read_to_string(path).ok()?;
99        serde_json::from_str(&data).ok()
100    }
101
102    pub fn save(&self) -> Result<(), String> {
103        let path = Self::index_path(&self.project_hash)
104            .ok_or_else(|| "Cannot determine data directory".to_string())?;
105        if let Some(dir) = path.parent() {
106            std::fs::create_dir_all(dir).map_err(|e| format!("{e}"))?;
107        }
108        let json = serde_json::to_string(self).map_err(|e| format!("{e}"))?;
109        std::fs::write(path, json).map_err(|e| format!("{e}"))
110    }
111}
112
113pub fn reset(project_hash: &str) -> Result<(), String> {
114    let path = KnowledgeEmbeddingIndex::index_path(project_hash)
115        .ok_or_else(|| "Cannot determine data directory".to_string())?;
116    if path.exists() {
117        std::fs::remove_file(&path).map_err(|e| format!("{e}"))?;
118    }
119    Ok(())
120}
121
/// A knowledge fact paired with its combined recall score and the
/// individual signals that produced it.
#[derive(Debug)]
pub struct ScoredFact<'a> {
    pub fact: &'a KnowledgeFact,
    // Weighted blend: ALPHA_SEMANTIC * semantic + BETA_CONFIDENCE * confidence
    // + GAMMA_RECENCY * recency.
    pub score: f32,
    pub semantic_score: f32,
    pub confidence_score: f32,
    pub recency_score: f32,
}
130
/// Hybrid recall: blends semantic similarity with fact confidence and
/// recency, then merges in exact lexical matches at full score.
///
/// Falls back to plain lexical recall when the query cannot be embedded.
/// Returns at most `top_k` facts, best first, with deterministic tie-breaks.
#[cfg(feature = "embeddings")]
pub fn semantic_recall<'a>(
    knowledge: &'a ProjectKnowledge,
    index: &KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    let query_embedding = match engine.embed(query) {
        Ok(v) => v,
        // Embedding failed — degrade gracefully to lexical-only recall.
        Err(_) => return lexical_fallback(knowledge, query, top_k),
    };

    // Over-fetch (2x) so entries dropped below (archived or missing facts)
    // still leave enough candidates for the final top_k cut.
    let mut results: Vec<ScoredFact<'a>> = index
        .semantic_search(&query_embedding, top_k * 2)
        .into_iter()
        .filter_map(|(entry, sim)| {
            let fact = knowledge.facts.iter().find(|f| {
                f.category == entry.category && f.key == entry.key && f.is_current()
            })?;
            let confidence_score = fact.confidence;
            let recency_score = recency_decay(fact);
            Some(ScoredFact {
                fact,
                score: ALPHA_SEMANTIC * sim
                    + BETA_CONFIDENCE * confidence_score
                    + GAMMA_RECENCY * recency_score,
                semantic_score: sim,
                confidence_score,
                recency_score,
            })
        })
        .collect();

    // Exact lexical matches are authoritative: pin them at full score unless
    // the same (category, key) already arrived via the semantic pass.
    for fact in knowledge.recall(query) {
        let seen = results
            .iter()
            .any(|r| r.fact.category == fact.category && r.fact.key == fact.key);
        if !seen {
            results.push(ScoredFact {
                fact,
                score: 1.0,
                semantic_score: 1.0,
                confidence_score: fact.confidence,
                recency_score: recency_decay(fact),
            });
        }
    }

    // Best score first; ties fall through confidence, recency, then the
    // fact's identity fields so ordering is fully deterministic.
    results.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| {
                b.confidence_score
                    .partial_cmp(&a.confidence_score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .then_with(|| {
                b.recency_score
                    .partial_cmp(&a.recency_score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .then_with(|| a.fact.category.cmp(&b.fact.category))
            .then_with(|| a.fact.key.cmp(&b.fact.key))
            .then_with(|| a.fact.value.cmp(&b.fact.value))
    });
    results.truncate(top_k);
    results
}
207
208pub fn compact_against_knowledge(
209    index: &mut KnowledgeEmbeddingIndex,
210    knowledge: &ProjectKnowledge,
211) {
212    use std::collections::HashMap;
213
214    let mut current: HashMap<(&str, &str), &KnowledgeFact> = HashMap::new();
215    for f in &knowledge.facts {
216        if f.is_current() {
217            current.insert((f.category.as_str(), f.key.as_str()), f);
218        }
219    }
220
221    let mut kept: Vec<(FactEmbedding, &KnowledgeFact)> = index
222        .entries
223        .iter()
224        .filter_map(|e| {
225            current
226                .get(&(e.category.as_str(), e.key.as_str()))
227                .map(|f| (e.clone(), *f))
228        })
229        .collect();
230
231    kept.sort_by(|(ea, fa), (eb, fb)| {
232        fb.confidence
233            .partial_cmp(&fa.confidence)
234            .unwrap_or(std::cmp::Ordering::Equal)
235            .then_with(|| fb.last_confirmed.cmp(&fa.last_confirmed))
236            .then_with(|| fb.retrieval_count.cmp(&fa.retrieval_count))
237            .then_with(|| ea.category.cmp(&eb.category))
238            .then_with(|| ea.key.cmp(&eb.key))
239    });
240
241    let max = crate::core::budgets::KNOWLEDGE_EMBEDDINGS_MAX_FACTS;
242    if kept.len() > max {
243        kept.truncate(max);
244    }
245
246    index.entries = kept.into_iter().map(|(e, _)| e).collect();
247}
248
249fn lexical_fallback<'a>(
250    knowledge: &'a ProjectKnowledge,
251    query: &str,
252    top_k: usize,
253) -> Vec<ScoredFact<'a>> {
254    knowledge
255        .recall(query)
256        .into_iter()
257        .take(top_k)
258        .map(|fact| ScoredFact {
259            fact,
260            score: fact.confidence,
261            semantic_score: 0.0,
262            confidence_score: fact.confidence,
263            recency_score: recency_decay(fact),
264        })
265        .collect()
266}
267
268fn recency_decay(fact: &KnowledgeFact) -> f32 {
269    let days_old = chrono::Utc::now()
270        .signed_duration_since(fact.last_confirmed)
271        .num_days() as f32;
272    (1.0 - days_old / MAX_RECENCY_DAYS).max(0.0)
273}
274
/// Embeds the text `"{category} {key}: {value}"` and upserts the resulting
/// vector into the index under `(category, key)`.
///
/// Returns the engine's error (stringified) when embedding fails; the index
/// is left untouched in that case.
#[cfg(feature = "embeddings")]
pub fn embed_and_store(
    index: &mut KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    category: &str,
    key: &str,
    value: &str,
) -> Result<(), String> {
    let text = format!("{category} {key}: {value}");
    match engine.embed(&text) {
        Ok(embedding) => {
            index.upsert(category, key, embedding);
            Ok(())
        }
        Err(e) => Err(format!("{e}")),
    }
}
288
289pub fn format_scored_facts(results: &[ScoredFact<'_>]) -> String {
290    if results.is_empty() {
291        return "No matching facts found.".to_string();
292    }
293
294    let mut output = String::new();
295    for (i, scored) in results.iter().enumerate() {
296        let f = scored.fact;
297        let stars = if f.confidence >= 0.9 {
298            "★★★★"
299        } else if f.confidence >= 0.7 {
300            "★★★"
301        } else if f.confidence >= 0.5 {
302            "★★"
303        } else {
304            "★"
305        };
306
307        if i > 0 {
308            output.push('|');
309        }
310        output.push_str(&format!(
311            "{}:{}={}{} [s:{:.0}%]",
312            f.category,
313            f.key,
314            f.value,
315            stars,
316            scored.score * 100.0
317        ));
318    }
319    output
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    // save() then reset() round-trip against a temp data dir; load() must
    // succeed before reset and fail after. NOTE: the env-var lock, set_var
    // before save, and remove_var at the end are order-sensitive — the data
    // dir is resolved from LEAN_CTX_DATA_DIR on every call.
    #[test]
    fn reset_removes_index_file() {
        let _lock = crate::core::data_dir::test_env_lock();
        let tmp = tempfile::tempdir().expect("tempdir");
        std::env::set_var(
            "LEAN_CTX_DATA_DIR",
            tmp.path().to_string_lossy().to_string(),
        );

        let idx = KnowledgeEmbeddingIndex {
            project_hash: "projhash".to_string(),
            entries: vec![FactEmbedding {
                category: "arch".to_string(),
                key: "db".to_string(),
                embedding: vec![1.0, 0.0, 0.0],
            }],
        };
        idx.save().expect("save");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_some());

        reset("projhash").expect("reset");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_none());

        std::env::remove_var("LEAN_CTX_DATA_DIR");
    }

    // Compaction keeps only entries backed by current facts: "old" is
    // archived (valid_until set) and "ops/deploy" has no fact at all, so
    // only "arch/db" should survive.
    #[test]
    fn compact_drops_missing_or_archived_facts() {
        let mut knowledge = ProjectKnowledge::new("/tmp/project");
        let now = chrono::Utc::now();
        knowledge.facts.push(KnowledgeFact {
            category: "arch".to_string(),
            key: "db".to_string(),
            value: "Postgres".to_string(),
            source_session: "s".to_string(),
            confidence: 0.9,
            created_at: now,
            last_confirmed: now,
            retrieval_count: 5,
            last_retrieved: None,
            valid_from: None,
            valid_until: None,
            supersedes: None,
            confirmation_count: 1,
        });
        knowledge.facts.push(KnowledgeFact {
            category: "arch".to_string(),
            key: "old".to_string(),
            value: "Old".to_string(),
            source_session: "s".to_string(),
            confidence: 0.9,
            created_at: now,
            last_confirmed: now,
            retrieval_count: 0,
            last_retrieved: None,
            valid_from: None,
            // Archived as of `now` — must be dropped by compaction.
            valid_until: Some(now),
            supersedes: None,
            confirmation_count: 1,
        });

        let mut idx = KnowledgeEmbeddingIndex::new(&knowledge.project_hash);
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        idx.upsert("arch", "old", vec![0.0, 1.0, 0.0]);
        idx.upsert("ops", "deploy", vec![0.0, 0.0, 1.0]);

        compact_against_knowledge(&mut idx, &knowledge);
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].category, "arch");
        assert_eq!(idx.entries[0].key, "db");
    }

    // upsert replaces in place (same (category, key) keeps len at 1), a new
    // key appends, and remove deletes exactly the matching entry.
    #[test]
    fn index_upsert_and_remove() {
        let mut idx = KnowledgeEmbeddingIndex::new("test");
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        assert_eq!(idx.entries.len(), 1);

        idx.upsert("arch", "db", vec![0.0, 1.0, 0.0]);
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].embedding[1], 1.0);

        idx.upsert("arch", "cache", vec![0.0, 0.0, 1.0]);
        assert_eq!(idx.entries.len(), 2);

        idx.remove("arch", "db");
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].key, "cache");
    }

    // A fact confirmed right now should score (near) the top of the
    // recency range.
    #[test]
    fn recency_decay_recent() {
        let fact = KnowledgeFact {
            category: "test".to_string(),
            key: "k".to_string(),
            value: "v".to_string(),
            source_session: "s".to_string(),
            confidence: 0.9,
            created_at: chrono::Utc::now(),
            last_confirmed: chrono::Utc::now(),
            retrieval_count: 0,
            last_retrieved: None,
            valid_from: None,
            valid_until: None,
            supersedes: None,
            confirmation_count: 1,
        };
        let decay = recency_decay(&fact);
        assert!(
            decay > 0.95,
            "Recent fact should have high recency: {decay}"
        );
    }

    // Past MAX_RECENCY_DAYS (90), the floor kicks in and recency is
    // exactly 0.0.
    #[test]
    fn recency_decay_old() {
        let old_date = chrono::Utc::now() - chrono::Duration::days(100);
        let fact = KnowledgeFact {
            category: "test".to_string(),
            key: "k".to_string(),
            value: "v".to_string(),
            source_session: "s".to_string(),
            confidence: 0.5,
            created_at: old_date,
            last_confirmed: old_date,
            retrieval_count: 0,
            last_retrieved: None,
            valid_from: None,
            valid_until: None,
            supersedes: None,
            confirmation_count: 1,
        };
        let decay = recency_decay(&fact);
        assert_eq!(decay, 0.0, "100-day-old fact should have 0 recency");
    }

    // With an orthonormal-ish toy index, the entry collinear with the query
    // must rank first, and top_k=2 must cap the result length.
    #[cfg(feature = "embeddings")]
    #[test]
    fn semantic_search_ranking() {
        let mut idx = KnowledgeEmbeddingIndex::new("test");
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        idx.upsert("arch", "cache", vec![0.0, 1.0, 0.0]);
        idx.upsert("ops", "deploy", vec![0.5, 0.5, 0.0]);

        let query = vec![1.0, 0.0, 0.0];
        let results = idx.semantic_search(&query, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0.key, "db");
    }

    // Empty input takes the early-return path with the fixed message.
    #[test]
    fn format_scored_empty() {
        assert_eq!(format_scored_facts(&[]), "No matching facts found.");
    }

    // Spot-checks the rendered record: identity triple, 4-star tier for
    // confidence 0.95, and the blended score as a whole percent.
    #[test]
    fn format_scored_output() {
        let fact = KnowledgeFact {
            category: "arch".to_string(),
            key: "db".to_string(),
            value: "PostgreSQL".to_string(),
            source_session: "s1".to_string(),
            confidence: 0.95,
            created_at: chrono::Utc::now(),
            last_confirmed: chrono::Utc::now(),
            retrieval_count: 0,
            last_retrieved: None,
            valid_from: None,
            valid_until: None,
            supersedes: None,
            confirmation_count: 3,
        };
        let scored = vec![ScoredFact {
            fact: &fact,
            score: 0.85,
            semantic_score: 0.9,
            confidence_score: 0.95,
            recency_score: 1.0,
        }];
        let output = format_scored_facts(&scored);
        assert!(output.contains("arch:db=PostgreSQL"));
        assert!(output.contains("★★★★"));
        assert!(output.contains("[s:85%]"));
    }
}