Skip to main content

lean_ctx/core/
knowledge_embedding.rs

1//! Embedding-based Knowledge Retrieval for `ctx_knowledge`.
2//!
3//! Wraps `ProjectKnowledge` with a vector index for semantic recall.
4//! Facts are automatically embedded on `remember` and searched via
5//! cosine similarity on `recall`, with hybrid exact + semantic ranking.
6
7use std::path::PathBuf;
8
9use serde::{Deserialize, Serialize};
10
11use super::knowledge::{KnowledgeFact, ProjectKnowledge};
12
13#[cfg(feature = "embeddings")]
14use super::embeddings::{cosine_similarity, EmbeddingEngine};
15
/// Weight of the cosine-similarity component in the hybrid recall score.
const ALPHA_SEMANTIC: f32 = 0.6;
/// Weight of the fact's stored confidence in the hybrid recall score.
const BETA_CONFIDENCE: f32 = 0.25;
/// Weight of the recency component in the hybrid recall score.
///
/// Together with `ALPHA_SEMANTIC` and `BETA_CONFIDENCE` the weights sum to 1.0,
/// so a hybrid score built from components in `[0, 1]` also lies in `[0, 1]`.
const GAMMA_RECENCY: f32 = 0.15;
/// Age (in whole days since `last_confirmed`) at which the linear recency
/// component reaches zero.
const MAX_RECENCY_DAYS: f32 = 90.0;
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct FactEmbedding {
23    pub category: String,
24    pub key: String,
25    pub embedding: Vec<f32>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct KnowledgeEmbeddingIndex {
30    pub project_hash: String,
31    pub entries: Vec<FactEmbedding>,
32}
33
34impl KnowledgeEmbeddingIndex {
35    pub fn new(project_hash: &str) -> Self {
36        Self {
37            project_hash: project_hash.to_string(),
38            entries: Vec::new(),
39        }
40    }
41
42    pub fn upsert(&mut self, category: &str, key: &str, embedding: Vec<f32>) {
43        if let Some(existing) = self
44            .entries
45            .iter_mut()
46            .find(|e| e.category == category && e.key == key)
47        {
48            existing.embedding = embedding;
49        } else {
50            self.entries.push(FactEmbedding {
51                category: category.to_string(),
52                key: key.to_string(),
53                embedding,
54            });
55        }
56    }
57
58    pub fn remove(&mut self, category: &str, key: &str) {
59        self.entries
60            .retain(|e| !(e.category == category && e.key == key));
61    }
62
63    #[cfg(feature = "embeddings")]
64    pub fn semantic_search(
65        &self,
66        query_embedding: &[f32],
67        top_k: usize,
68    ) -> Vec<(&FactEmbedding, f32)> {
69        let mut scored: Vec<(&FactEmbedding, f32)> = self
70            .entries
71            .iter()
72            .map(|e| {
73                let sim = cosine_similarity(query_embedding, &e.embedding);
74                (e, sim)
75            })
76            .collect();
77
78        scored.sort_by(|a, b| {
79            b.1.partial_cmp(&a.1)
80                .unwrap_or(std::cmp::Ordering::Equal)
81                .then_with(|| a.0.category.cmp(&b.0.category))
82                .then_with(|| a.0.key.cmp(&b.0.key))
83        });
84        scored.truncate(top_k);
85        scored
86    }
87
88    fn index_path(project_hash: &str) -> Option<PathBuf> {
89        let dir = crate::core::data_dir::lean_ctx_data_dir()
90            .ok()?
91            .join("knowledge")
92            .join(project_hash);
93        Some(dir.join("embeddings.json"))
94    }
95
96    pub fn load(project_hash: &str) -> Option<Self> {
97        let path = Self::index_path(project_hash)?;
98        let data = std::fs::read_to_string(path).ok()?;
99        serde_json::from_str(&data).ok()
100    }
101
102    pub fn save(&self) -> Result<(), String> {
103        let path = Self::index_path(&self.project_hash)
104            .ok_or_else(|| "Cannot determine data directory".to_string())?;
105        if let Some(dir) = path.parent() {
106            std::fs::create_dir_all(dir).map_err(|e| format!("{e}"))?;
107        }
108        let json = serde_json::to_string(self).map_err(|e| format!("{e}"))?;
109        std::fs::write(path, json).map_err(|e| format!("{e}"))
110    }
111}
112
113pub fn reset(project_hash: &str) -> Result<(), String> {
114    let path = KnowledgeEmbeddingIndex::index_path(project_hash)
115        .ok_or_else(|| "Cannot determine data directory".to_string())?;
116    if path.exists() {
117        std::fs::remove_file(&path).map_err(|e| format!("{e}"))?;
118    }
119    Ok(())
120}
121
122#[derive(Debug)]
123pub struct ScoredFact<'a> {
124    pub fact: &'a KnowledgeFact,
125    pub score: f32,
126    pub semantic_score: f32,
127    pub confidence_score: f32,
128    pub recency_score: f32,
129}
130
/// Hybrid recall: semantic nearest neighbours plus exact (lexical) matches.
///
/// The query is embedded with `engine`. If that fails, this falls back to
/// lexical recall. Otherwise it does the following:
///
/// 1. Fetches `2 * top_k` nearest index entries. The over-fetch leaves room
///    for stale entries whose fact was deleted or is no longer current.
/// 2. Scores each hit that resolves to a current fact as
///    `ALPHA_SEMANTIC * sim + BETA_CONFIDENCE * confidence + GAMMA_RECENCY * recency`.
/// 3. Adds exact matches from `ProjectKnowledge::recall` with the exact-match
///    score of 1.0, which equals the maximum hybrid score.
///
/// A fact found both semantically and exactly is promoted to at least that
/// score. Otherwise it would rank below facts that only matched exactly.
///
/// Results are sorted by score, then confidence, recency, category, key and
/// value, and truncated to `top_k`.
#[cfg(feature = "embeddings")]
pub fn semantic_recall<'a>(
    knowledge: &'a ProjectKnowledge,
    index: &KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    use std::cmp::Ordering;
    use std::collections::HashMap;

    /// Score given to exact (lexical) matches; equal to the sum of the hybrid weights.
    const EXACT_MATCH_SCORE: f32 = 1.0;

    /// Descending comparison that is a total order even with NaN (NaN sorts
    /// last), as `sort_by` requires.
    fn desc(a: f32, b: f32) -> Ordering {
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }

    let Ok(query_embedding) = engine.embed(query) else {
        return lexical_fallback(knowledge, query, top_k);
    };

    // `saturating_mul` so callers passing a huge `top_k` ("unbounded") cannot overflow.
    let semantic_hits = index.semantic_search(&query_embedding, top_k.saturating_mul(2));

    // First current fact per (category, key), equivalent to the linear `find`
    // per hit but built once.
    let mut current: HashMap<(&str, &str), &'a KnowledgeFact> = HashMap::new();
    for fact in knowledge.facts.iter().filter(|f| f.is_current()) {
        current
            .entry((fact.category.as_str(), fact.key.as_str()))
            .or_insert(fact);
    }

    let mut results: Vec<ScoredFact<'a>> = Vec::new();
    // Position of each (category, key) in `results`, for O(1) dedup/promotion.
    let mut position: HashMap<(&'a str, &'a str), usize> = HashMap::new();

    for (entry, sim) in semantic_hits {
        let Some(&fact) = current.get(&(entry.category.as_str(), entry.key.as_str())) else {
            // Stale index entry: the fact is gone or no longer current.
            continue;
        };
        let key = (fact.category.as_str(), fact.key.as_str());
        if position.contains_key(&key) {
            // Duplicate index entry; hits arrive best-first, so keep the first.
            continue;
        }

        let confidence_score = fact.confidence;
        let recency_score = recency_decay(fact);
        let score = ALPHA_SEMANTIC * sim
            + BETA_CONFIDENCE * confidence_score
            + GAMMA_RECENCY * recency_score;

        position.insert(key, results.len());
        results.push(ScoredFact {
            fact,
            score,
            semantic_score: sim,
            confidence_score,
            recency_score,
        });
    }

    for fact in knowledge.recall(query) {
        let key = (fact.category.as_str(), fact.key.as_str());
        if let Some(&i) = position.get(&key) {
            // Exact match already found semantically: give it at least the
            // exact-match score so it cannot rank below exact-only matches.
            let existing = &mut results[i];
            existing.score = existing.score.max(EXACT_MATCH_SCORE);
        } else {
            position.insert(key, results.len());
            results.push(ScoredFact {
                fact,
                score: EXACT_MATCH_SCORE,
                semantic_score: 1.0,
                confidence_score: fact.confidence,
                recency_score: recency_decay(fact),
            });
        }
    }

    results.sort_by(|a, b| {
        desc(a.score, b.score)
            .then_with(|| desc(a.confidence_score, b.confidence_score))
            .then_with(|| desc(a.recency_score, b.recency_score))
            .then_with(|| a.fact.category.cmp(&b.fact.category))
            .then_with(|| a.fact.key.cmp(&b.fact.key))
            .then_with(|| a.fact.value.cmp(&b.fact.value))
    });
    results.truncate(top_k);
    results
}
206
207pub fn compact_against_knowledge(
208    index: &mut KnowledgeEmbeddingIndex,
209    knowledge: &ProjectKnowledge,
210) {
211    use std::collections::HashMap;
212
213    let mut current: HashMap<(&str, &str), &KnowledgeFact> = HashMap::new();
214    for f in &knowledge.facts {
215        if f.is_current() {
216            current.insert((f.category.as_str(), f.key.as_str()), f);
217        }
218    }
219
220    let mut kept: Vec<(FactEmbedding, &KnowledgeFact)> = index
221        .entries
222        .iter()
223        .filter_map(|e| {
224            current
225                .get(&(e.category.as_str(), e.key.as_str()))
226                .map(|f| (e.clone(), *f))
227        })
228        .collect();
229
230    kept.sort_by(|(ea, fa), (eb, fb)| {
231        fb.confidence
232            .partial_cmp(&fa.confidence)
233            .unwrap_or(std::cmp::Ordering::Equal)
234            .then_with(|| fb.last_confirmed.cmp(&fa.last_confirmed))
235            .then_with(|| fb.retrieval_count.cmp(&fa.retrieval_count))
236            .then_with(|| ea.category.cmp(&eb.category))
237            .then_with(|| ea.key.cmp(&eb.key))
238    });
239
240    let max = crate::core::budgets::KNOWLEDGE_EMBEDDINGS_MAX_FACTS;
241    if kept.len() > max {
242        kept.truncate(max);
243    }
244
245    index.entries = kept.into_iter().map(|(e, _)| e).collect();
246}
247
248fn lexical_fallback<'a>(
249    knowledge: &'a ProjectKnowledge,
250    query: &str,
251    top_k: usize,
252) -> Vec<ScoredFact<'a>> {
253    knowledge
254        .recall(query)
255        .into_iter()
256        .take(top_k)
257        .map(|fact| ScoredFact {
258            fact,
259            score: fact.confidence,
260            semantic_score: 0.0,
261            confidence_score: fact.confidence,
262            recency_score: recency_decay(fact),
263        })
264        .collect()
265}
266
267fn recency_decay(fact: &KnowledgeFact) -> f32 {
268    let days_old = chrono::Utc::now()
269        .signed_duration_since(fact.last_confirmed)
270        .num_days() as f32;
271    (1.0 - days_old / MAX_RECENCY_DAYS).max(0.0)
272}
273
/// Embeds a fact as the text `"{category} {key}: {value}"` and upserts the
/// resulting vector into `index`.
///
/// # Errors
/// Returns the embedding engine's error, stringified; `index` is left
/// untouched in that case.
#[cfg(feature = "embeddings")]
pub fn embed_and_store(
    index: &mut KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    category: &str,
    key: &str,
    value: &str,
) -> Result<(), String> {
    let document = format!("{category} {key}: {value}");
    let vector = match engine.embed(&document) {
        Ok(v) => v,
        Err(err) => return Err(err.to_string()),
    };
    index.upsert(category, key, vector);
    Ok(())
}
287
288pub fn format_scored_facts(results: &[ScoredFact<'_>]) -> String {
289    if results.is_empty() {
290        return "No matching facts found.".to_string();
291    }
292
293    let mut output = String::new();
294    for (i, scored) in results.iter().enumerate() {
295        let f = scored.fact;
296        let stars = if f.confidence >= 0.9 {
297            "★★★★"
298        } else if f.confidence >= 0.7 {
299            "★★★"
300        } else if f.confidence >= 0.5 {
301            "★★"
302        } else {
303            "★"
304        };
305
306        if i > 0 {
307            output.push('|');
308        }
309        output.push_str(&format!(
310            "{}:{}={}{} [s:{:.0}%]",
311            f.category,
312            f.key,
313            f.value,
314            stars,
315            scored.score * 100.0
316        ));
317    }
318    output
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fact with no validity window, no retrievals and one
    /// confirmation, created and last confirmed at `at`. Tests override the
    /// fields they care about with struct-update syntax.
    fn fact_at(
        category: &str,
        key: &str,
        value: &str,
        confidence: f32,
        at: chrono::DateTime<chrono::Utc>,
    ) -> KnowledgeFact {
        KnowledgeFact {
            category: category.to_string(),
            key: key.to_string(),
            value: value.to_string(),
            source_session: "s".to_string(),
            confidence,
            created_at: at,
            last_confirmed: at,
            retrieval_count: 0,
            last_retrieved: None,
            valid_from: None,
            valid_until: None,
            supersedes: None,
            confirmation_count: 1,
        }
    }

    /// Points `LEAN_CTX_DATA_DIR` at a directory while the guard is alive.
    ///
    /// The variable is removed on drop, including when an assertion panics,
    /// so a failing test cannot leak the override into later tests.
    struct DataDirOverride;

    impl DataDirOverride {
        fn set(dir: &std::path::Path) -> Self {
            std::env::set_var("LEAN_CTX_DATA_DIR", dir.to_string_lossy().to_string());
            Self
        }
    }

    impl Drop for DataDirOverride {
        fn drop(&mut self) {
            std::env::remove_var("LEAN_CTX_DATA_DIR");
        }
    }

    #[test]
    fn reset_removes_index_file() {
        // Locals drop in reverse order: the override is cleared and the temp
        // dir deleted while the env lock is still held.
        let _lock = crate::core::data_dir::test_env_lock();
        let tmp = tempfile::tempdir().expect("tempdir");
        let _data_dir = DataDirOverride::set(tmp.path());

        let idx = KnowledgeEmbeddingIndex {
            project_hash: "projhash".to_string(),
            entries: vec![FactEmbedding {
                category: "arch".to_string(),
                key: "db".to_string(),
                embedding: vec![1.0, 0.0, 0.0],
            }],
        };
        idx.save().expect("save");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_some());

        reset("projhash").expect("reset");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_none());

        // Resetting an index that no longer exists is a no-op, not an error.
        reset("projhash").expect("reset of missing index");
    }

    #[test]
    fn compact_drops_missing_or_archived_facts() {
        let mut knowledge = ProjectKnowledge::new("/tmp/project");
        let now = chrono::Utc::now();
        knowledge.facts.push(KnowledgeFact {
            retrieval_count: 5,
            ..fact_at("arch", "db", "Postgres", 0.9, now)
        });
        knowledge.facts.push(KnowledgeFact {
            valid_until: Some(now),
            ..fact_at("arch", "old", "Old", 0.9, now)
        });

        let mut idx = KnowledgeEmbeddingIndex::new(&knowledge.project_hash);
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        idx.upsert("arch", "old", vec![0.0, 1.0, 0.0]);
        idx.upsert("ops", "deploy", vec![0.0, 0.0, 1.0]);

        compact_against_knowledge(&mut idx, &knowledge);
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].category, "arch");
        assert_eq!(idx.entries[0].key, "db");
    }

    #[test]
    fn index_upsert_and_remove() {
        let mut idx = KnowledgeEmbeddingIndex::new("test");
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        assert_eq!(idx.entries.len(), 1);

        idx.upsert("arch", "db", vec![0.0, 1.0, 0.0]);
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].embedding[1], 1.0);

        idx.upsert("arch", "cache", vec![0.0, 0.0, 1.0]);
        assert_eq!(idx.entries.len(), 2);

        idx.remove("arch", "db");
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].key, "cache");
    }

    #[test]
    fn recency_decay_recent() {
        let fact = fact_at("test", "k", "v", 0.9, chrono::Utc::now());
        let decay = recency_decay(&fact);
        assert!(
            decay > 0.95,
            "Recent fact should have high recency: {decay}"
        );
    }

    #[test]
    fn recency_decay_old() {
        let old_date = chrono::Utc::now() - chrono::Duration::days(100);
        let fact = fact_at("test", "k", "v", 0.5, old_date);
        let decay = recency_decay(&fact);
        assert_eq!(decay, 0.0, "100-day-old fact should have 0 recency");
    }

    #[cfg(feature = "embeddings")]
    #[test]
    fn semantic_search_ranking() {
        let mut idx = KnowledgeEmbeddingIndex::new("test");
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        idx.upsert("arch", "cache", vec![0.0, 1.0, 0.0]);
        idx.upsert("ops", "deploy", vec![0.5, 0.5, 0.0]);

        let query = vec![1.0, 0.0, 0.0];
        let results = idx.semantic_search(&query, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0.key, "db");
        // [0.5, 0.5, 0] is closer to the query than the orthogonal [0, 1, 0].
        assert_eq!(results[1].0.key, "deploy");
    }

    #[test]
    fn format_scored_empty() {
        assert_eq!(format_scored_facts(&[]), "No matching facts found.");
    }

    #[test]
    fn format_scored_output() {
        let fact = KnowledgeFact {
            source_session: "s1".to_string(),
            confirmation_count: 3,
            ..fact_at("arch", "db", "PostgreSQL", 0.95, chrono::Utc::now())
        };
        let scored = vec![ScoredFact {
            fact: &fact,
            score: 0.85,
            semantic_score: 0.9,
            confidence_score: 0.95,
            recency_score: 1.0,
        }];
        let output = format_scored_facts(&scored);
        assert!(output.contains("arch:db=PostgreSQL"));
        assert!(output.contains("★★★★"));
        assert!(output.contains("[s:85%]"));
    }
}