Skip to main content

lean_ctx/core/
knowledge_embedding.rs

1//! Embedding-based Knowledge Retrieval for `ctx_knowledge`.
2//!
3//! Wraps `ProjectKnowledge` with a vector index for semantic recall.
4//! Facts are automatically embedded on `remember` and searched via
5//! cosine similarity on `recall`, with hybrid exact + semantic ranking.
6
7use std::path::PathBuf;
8
9use serde::{Deserialize, Serialize};
10
11use super::knowledge::{KnowledgeFact, ProjectKnowledge};
12use crate::core::memory_policy::MemoryPolicy;
13
14#[cfg(feature = "embeddings")]
15use super::embeddings::{cosine_similarity, EmbeddingEngine};
16
// Weights of the blended recall score. The three weights sum to 1.0.

/// Weight applied to the cosine similarity between the query and a fact embedding.
const ALPHA_SEMANTIC: f32 = 0.6;
/// Weight applied to the fact's `quality_score()`.
const BETA_CONFIDENCE: f32 = 0.25;
/// Weight applied to the fact's recency score.
const GAMMA_RECENCY: f32 = 0.15;
/// Age in days (since `last_confirmed`) at which the recency score reaches zero.
const MAX_RECENCY_DAYS: f32 = 90.0;
21
/// Embedding vector stored for a single knowledge fact.
///
/// An entry is matched to its fact through the `(category, key)` pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactEmbedding {
    /// Category of the fact this embedding belongs to.
    pub category: String,
    /// Key of the fact within its category.
    pub key: String,
    /// Embedding vector produced for the fact's text.
    pub embedding: Vec<f32>,
}
28
/// Per-project collection of fact embeddings, persisted as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeEmbeddingIndex {
    /// Project identifier; also names the directory the index file is stored in.
    pub project_hash: String,
    /// Stored embeddings; `upsert` keeps at most one entry per `(category, key)`.
    pub entries: Vec<FactEmbedding>,
}
34
35impl KnowledgeEmbeddingIndex {
36    pub fn new(project_hash: &str) -> Self {
37        Self {
38            project_hash: project_hash.to_string(),
39            entries: Vec::new(),
40        }
41    }
42
43    pub fn upsert(&mut self, category: &str, key: &str, embedding: Vec<f32>) {
44        if let Some(existing) = self
45            .entries
46            .iter_mut()
47            .find(|e| e.category == category && e.key == key)
48        {
49            existing.embedding = embedding;
50        } else {
51            self.entries.push(FactEmbedding {
52                category: category.to_string(),
53                key: key.to_string(),
54                embedding,
55            });
56        }
57    }
58
59    pub fn remove(&mut self, category: &str, key: &str) {
60        self.entries
61            .retain(|e| !(e.category == category && e.key == key));
62    }
63
64    #[cfg(feature = "embeddings")]
65    pub fn semantic_search(
66        &self,
67        query_embedding: &[f32],
68        top_k: usize,
69    ) -> Vec<(&FactEmbedding, f32)> {
70        let mut scored: Vec<(&FactEmbedding, f32)> = self
71            .entries
72            .iter()
73            .map(|e| {
74                let sim = cosine_similarity(query_embedding, &e.embedding);
75                (e, sim)
76            })
77            .collect();
78
79        scored.sort_by(|a, b| {
80            b.1.partial_cmp(&a.1)
81                .unwrap_or(std::cmp::Ordering::Equal)
82                .then_with(|| a.0.category.cmp(&b.0.category))
83                .then_with(|| a.0.key.cmp(&b.0.key))
84        });
85        scored.truncate(top_k);
86        scored
87    }
88
89    fn index_path(project_hash: &str) -> Option<PathBuf> {
90        let dir = crate::core::data_dir::lean_ctx_data_dir()
91            .ok()?
92            .join("knowledge")
93            .join(project_hash);
94        Some(dir.join("embeddings.json"))
95    }
96
97    pub fn load(project_hash: &str) -> Option<Self> {
98        let path = Self::index_path(project_hash)?;
99        let data = std::fs::read_to_string(path).ok()?;
100        serde_json::from_str(&data).ok()
101    }
102
103    pub fn save(&self) -> Result<(), String> {
104        let path = Self::index_path(&self.project_hash)
105            .ok_or_else(|| "Cannot determine data directory".to_string())?;
106        if let Some(dir) = path.parent() {
107            std::fs::create_dir_all(dir).map_err(|e| format!("{e}"))?;
108        }
109        let json = serde_json::to_string(self).map_err(|e| format!("{e}"))?;
110        std::fs::write(path, json).map_err(|e| format!("{e}"))
111    }
112}
113
114pub fn reset(project_hash: &str) -> Result<(), String> {
115    let path = KnowledgeEmbeddingIndex::index_path(project_hash)
116        .ok_or_else(|| "Cannot determine data directory".to_string())?;
117    if path.exists() {
118        std::fs::remove_file(&path).map_err(|e| format!("{e}"))?;
119    }
120    Ok(())
121}
122
/// A recalled fact together with its ranking score and the score's components.
#[derive(Debug)]
pub struct ScoredFact<'a> {
    /// The fact that was recalled.
    pub fact: &'a KnowledgeFact,
    /// Final score used for ranking.
    pub score: f32,
    /// Similarity component: cosine similarity for semantic hits, `1.0` for
    /// exact matches added by `semantic_recall`, `0.0` in the lexical fallback.
    pub semantic_score: f32,
    /// Confidence/quality component of the fact.
    pub confidence_score: f32,
    /// Recency component computed by `recency_decay`.
    pub recency_score: f32,
}
131
/// Compares two scores for a descending sort, placing NaN after every number.
///
/// Unlike `partial_cmp(..).unwrap_or(Equal)`, this is a total order, which
/// `sort_by` requires (it may panic otherwise).
#[cfg(feature = "embeddings")]
fn cmp_score_desc(a: f32, b: f32) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
    }
}

/// Resolves semantic hits to current facts and computes their blended score:
/// `ALPHA_SEMANTIC * similarity + BETA_CONFIDENCE * quality + GAMMA_RECENCY * recency`.
///
/// Hits without a matching current fact in `knowledge` are dropped.
#[cfg(feature = "embeddings")]
fn score_semantic_hits<'a>(
    knowledge: &'a ProjectKnowledge,
    hits: &[(&FactEmbedding, f32)],
) -> Vec<ScoredFact<'a>> {
    hits.iter()
        .filter_map(|&(entry, sim)| {
            let fact = knowledge.facts.iter().find(|f| {
                f.category == entry.category && f.key == entry.key && f.is_current()
            })?;
            let confidence_score = fact.quality_score();
            let recency_score = recency_decay(fact);
            Some(ScoredFact {
                fact,
                score: ALPHA_SEMANTIC * sim
                    + BETA_CONFIDENCE * confidence_score
                    + GAMMA_RECENCY * recency_score,
                semantic_score: sim,
                confidence_score,
                recency_score,
            })
        })
        .collect()
}

/// Sorts `results` best-first and keeps the first `top_k`.
///
/// Order: score, then confidence, then recency (all descending), then
/// category, key and value (ascending) so the output is deterministic.
#[cfg(feature = "embeddings")]
fn rank_scored(results: &mut Vec<ScoredFact<'_>>, top_k: usize) {
    results.sort_by(|a, b| {
        cmp_score_desc(a.score, b.score)
            .then_with(|| cmp_score_desc(a.confidence_score, b.confidence_score))
            .then_with(|| cmp_score_desc(a.recency_score, b.recency_score))
            .then_with(|| a.fact.category.cmp(&b.fact.category))
            .then_with(|| a.fact.key.cmp(&b.fact.key))
            .then_with(|| a.fact.value.cmp(&b.fact.value))
    });
    results.truncate(top_k);
}

/// Hybrid recall: semantic hits from `index` merged with the exact matches
/// returned by `knowledge.recall(query)`.
///
/// The query is embedded with `engine`; when embedding fails the function
/// falls back to `lexical_fallback`. Semantic hits receive the blended score;
/// exact matches that were not already found semantically are added with a
/// score of `1.0`. At most `top_k` results are returned, best first.
#[cfg(feature = "embeddings")]
pub fn semantic_recall<'a>(
    knowledge: &'a ProjectKnowledge,
    index: &KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    let Ok(query_embedding) = engine.embed(query) else {
        return lexical_fallback(knowledge, query, top_k);
    };

    // Over-fetch so hits dropped for lacking a current fact do not starve the
    // result; saturate so a very large `top_k` cannot overflow.
    let semantic_hits = index.semantic_search(&query_embedding, top_k.saturating_mul(2));
    let mut results = score_semantic_hits(knowledge, &semantic_hits);

    // NOTE(review): an exact match that is also a semantic hit keeps its
    // blended score rather than 1.0, so it can rank below exact matches that
    // were not found semantically — confirm this is intended.
    for fact in knowledge.recall(query) {
        let already_included = results
            .iter()
            .any(|r| r.fact.category == fact.category && r.fact.key == fact.key);
        if !already_included {
            results.push(ScoredFact {
                fact,
                score: 1.0,
                semantic_score: 1.0,
                confidence_score: fact.quality_score(),
                recency_score: recency_decay(fact),
            });
        }
    }

    rank_scored(&mut results, top_k);
    results
}

/// Semantic-only recall: like `semantic_recall` but without the exact-match
/// merge and without a lexical fallback.
///
/// Returns an empty vector when the query cannot be embedded; otherwise at
/// most `top_k` results, best first.
#[cfg(feature = "embeddings")]
pub fn semantic_recall_semantic_only<'a>(
    knowledge: &'a ProjectKnowledge,
    index: &KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    let Ok(query_embedding) = engine.embed(query) else {
        return Vec::new();
    };

    let semantic_hits = index.semantic_search(&query_embedding, top_k.saturating_mul(2));
    let mut results = score_semantic_hits(knowledge, &semantic_hits);

    rank_scored(&mut results, top_k);
    results
}
266
267pub fn compact_against_knowledge(
268    index: &mut KnowledgeEmbeddingIndex,
269    knowledge: &ProjectKnowledge,
270    policy: &MemoryPolicy,
271) {
272    use std::collections::HashMap;
273
274    let mut current: HashMap<(&str, &str), &KnowledgeFact> = HashMap::new();
275    for f in &knowledge.facts {
276        if f.is_current() {
277            current.insert((f.category.as_str(), f.key.as_str()), f);
278        }
279    }
280
281    let mut kept: Vec<(FactEmbedding, &KnowledgeFact)> = index
282        .entries
283        .iter()
284        .filter_map(|e| {
285            current
286                .get(&(e.category.as_str(), e.key.as_str()))
287                .map(|f| (e.clone(), *f))
288        })
289        .collect();
290
291    kept.sort_by(|(ea, fa), (eb, fb)| {
292        fb.confidence
293            .partial_cmp(&fa.confidence)
294            .unwrap_or(std::cmp::Ordering::Equal)
295            .then_with(|| fb.last_confirmed.cmp(&fa.last_confirmed))
296            .then_with(|| fb.retrieval_count.cmp(&fa.retrieval_count))
297            .then_with(|| ea.category.cmp(&eb.category))
298            .then_with(|| ea.key.cmp(&eb.key))
299    });
300
301    let max = policy.embeddings.max_facts;
302    if kept.len() > max {
303        kept.truncate(max);
304    }
305
306    index.entries = kept.into_iter().map(|(e, _)| e).collect();
307}
308
309fn lexical_fallback<'a>(
310    knowledge: &'a ProjectKnowledge,
311    query: &str,
312    top_k: usize,
313) -> Vec<ScoredFact<'a>> {
314    knowledge
315        .recall(query)
316        .into_iter()
317        .take(top_k)
318        .map(|fact| ScoredFact {
319            fact,
320            score: fact.confidence,
321            semantic_score: 0.0,
322            confidence_score: fact.confidence,
323            recency_score: recency_decay(fact),
324        })
325        .collect()
326}
327
328fn recency_decay(fact: &KnowledgeFact) -> f32 {
329    let days_old = chrono::Utc::now()
330        .signed_duration_since(fact.last_confirmed)
331        .num_days() as f32;
332    (1.0 - days_old / MAX_RECENCY_DAYS).max(0.0)
333}
334
/// Embeds a fact and stores the vector in `index` under `(category, key)`.
///
/// The embedded text is `"{category} {key}: {value}"`.
///
/// # Errors
///
/// Returns the engine's error message when embedding fails; the index is left
/// untouched in that case.
#[cfg(feature = "embeddings")]
pub fn embed_and_store(
    index: &mut KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    category: &str,
    key: &str,
    value: &str,
) -> Result<(), String> {
    let text = format!("{category} {key}: {value}");
    match engine.embed(&text) {
        Ok(vector) => {
            index.upsert(category, key, vector);
            Ok(())
        }
        Err(err) => Err(format!("{err}")),
    }
}
348
349pub fn format_scored_facts(results: &[ScoredFact<'_>]) -> String {
350    if results.is_empty() {
351        return "No matching facts found.".to_string();
352    }
353
354    let mut output = String::new();
355    for (i, scored) in results.iter().enumerate() {
356        let f = scored.fact;
357        let stars = if f.confidence >= 0.9 {
358            "★★★★"
359        } else if f.confidence >= 0.7 {
360            "★★★"
361        } else if f.confidence >= 0.5 {
362            "★★"
363        } else {
364            "★"
365        };
366
367        if i > 0 {
368            output.push('|');
369        }
370        output.push_str(&format!(
371            "{}:{}={}{} [s:{:.0}%]",
372            f.category,
373            f.key,
374            f.value,
375            stars,
376            scored.score * 100.0
377        ));
378    }
379    output
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::knowledge::KnowledgeArchetype;

    /// Builds a current fact confirmed "now" with neutral defaults.
    ///
    /// Tests override individual fields with struct-update syntax, so a new
    /// `KnowledgeFact` field only has to be added here.
    fn sample_fact(category: &str, key: &str, value: &str) -> KnowledgeFact {
        let now = chrono::Utc::now();
        KnowledgeFact {
            category: category.to_string(),
            key: key.to_string(),
            value: value.to_string(),
            source_session: "s".to_string(),
            confidence: 0.9,
            created_at: now,
            last_confirmed: now,
            retrieval_count: 0,
            last_retrieved: None,
            valid_from: None,
            valid_until: None,
            supersedes: None,
            confirmation_count: 1,
            feedback_up: 0,
            feedback_down: 0,
            last_feedback: None,
            privacy: crate::core::memory_boundary::FactPrivacy::default(),
            imported_from: None,
            archetype: KnowledgeArchetype::default(),
            fidelity: None,
        }
    }

    /// Clears `LEAN_CTX_DATA_DIR` on drop so a failing assertion cannot leak
    /// the override into later tests.
    struct DataDirEnvGuard;

    impl Drop for DataDirEnvGuard {
        fn drop(&mut self) {
            std::env::remove_var("LEAN_CTX_DATA_DIR");
        }
    }

    #[test]
    fn reset_removes_index_file() {
        let _lock = crate::core::data_dir::test_env_lock();
        let tmp = tempfile::tempdir().expect("tempdir");
        std::env::set_var(
            "LEAN_CTX_DATA_DIR",
            tmp.path().to_string_lossy().to_string(),
        );
        // Declared after `_lock` and `tmp`, so it is dropped before them.
        let _env = DataDirEnvGuard;

        let idx = KnowledgeEmbeddingIndex {
            project_hash: "projhash".to_string(),
            entries: vec![FactEmbedding {
                category: "arch".to_string(),
                key: "db".to_string(),
                embedding: vec![1.0, 0.0, 0.0],
            }],
        };
        idx.save().expect("save");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_some());

        reset("projhash").expect("reset");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_none());
    }

    #[test]
    fn compact_drops_missing_or_archived_facts() {
        let mut knowledge = ProjectKnowledge::new("/tmp/project");
        knowledge.facts.push(KnowledgeFact {
            retrieval_count: 5,
            ..sample_fact("arch", "db", "Postgres")
        });
        // `valid_until` set to "now" marks the fact as no longer current.
        knowledge.facts.push(KnowledgeFact {
            valid_until: Some(chrono::Utc::now()),
            ..sample_fact("arch", "old", "Old")
        });

        let mut idx = KnowledgeEmbeddingIndex::new(&knowledge.project_hash);
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        idx.upsert("arch", "old", vec![0.0, 1.0, 0.0]);
        // No fact exists for this entry at all.
        idx.upsert("ops", "deploy", vec![0.0, 0.0, 1.0]);

        compact_against_knowledge(&mut idx, &knowledge, &MemoryPolicy::default());
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].category, "arch");
        assert_eq!(idx.entries[0].key, "db");
    }

    #[test]
    fn index_upsert_and_remove() {
        let mut idx = KnowledgeEmbeddingIndex::new("test");
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        assert_eq!(idx.entries.len(), 1);

        // Same (category, key): the vector is replaced, not appended.
        idx.upsert("arch", "db", vec![0.0, 1.0, 0.0]);
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].embedding[1], 1.0);

        idx.upsert("arch", "cache", vec![0.0, 0.0, 1.0]);
        assert_eq!(idx.entries.len(), 2);

        idx.remove("arch", "db");
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].key, "cache");
    }

    #[test]
    fn recency_decay_recent() {
        let fact = sample_fact("test", "k", "v");
        let decay = recency_decay(&fact);
        assert!(
            decay > 0.95,
            "Recent fact should have high recency: {decay}"
        );
    }

    #[test]
    fn recency_decay_old() {
        let old_date = chrono::Utc::now() - chrono::Duration::days(100);
        let fact = KnowledgeFact {
            confidence: 0.5,
            created_at: old_date,
            last_confirmed: old_date,
            ..sample_fact("test", "k", "v")
        };
        let decay = recency_decay(&fact);
        assert_eq!(decay, 0.0, "100-day-old fact should have 0 recency");
    }

    #[cfg(feature = "embeddings")]
    #[test]
    fn semantic_search_ranking() {
        let mut idx = KnowledgeEmbeddingIndex::new("test");
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        idx.upsert("arch", "cache", vec![0.0, 1.0, 0.0]);
        idx.upsert("ops", "deploy", vec![0.5, 0.5, 0.0]);

        let query = vec![1.0, 0.0, 0.0];
        let results = idx.semantic_search(&query, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0.key, "db");
    }

    #[test]
    fn format_scored_empty() {
        assert_eq!(format_scored_facts(&[]), "No matching facts found.");
    }

    #[test]
    fn format_scored_output() {
        let fact = KnowledgeFact {
            source_session: "s1".to_string(),
            confidence: 0.95,
            confirmation_count: 3,
            ..sample_fact("arch", "db", "PostgreSQL")
        };
        let scored = vec![ScoredFact {
            fact: &fact,
            score: 0.85,
            semantic_score: 0.9,
            confidence_score: 0.95,
            recency_score: 1.0,
        }];
        let output = format_scored_facts(&scored);
        assert!(output.contains("arch:db=PostgreSQL"));
        assert!(output.contains("★★★★"));
        assert!(output.contains("[s:85%]"));
    }
}