Skip to main content

lean_ctx/core/
knowledge_embedding.rs

1//! Embedding-based Knowledge Retrieval for `ctx_knowledge`.
2//!
3//! Wraps `ProjectKnowledge` with a vector index for semantic recall.
4//! Facts are automatically embedded on `remember` and searched via
5//! cosine similarity on `recall`, with hybrid exact + semantic ranking.
6
7use std::path::PathBuf;
8
9use serde::{Deserialize, Serialize};
10
11use super::knowledge::{KnowledgeFact, ProjectKnowledge};
12use crate::core::memory_policy::MemoryPolicy;
13
14#[cfg(feature = "embeddings")]
15use super::embeddings::{cosine_similarity, EmbeddingEngine};
16
// Hybrid recall score = ALPHA_SEMANTIC * semantic + BETA_CONFIDENCE * confidence
// + GAMMA_RECENCY * recency. The three weights sum to 1.0.

/// Weight of the query/fact cosine similarity in the hybrid recall score.
const ALPHA_SEMANTIC: f32 = 0.6;
/// Weight of the fact's quality score (`KnowledgeFact::quality_score`) in the hybrid score.
const BETA_CONFIDENCE: f32 = 0.25;
/// Weight of the recency component (`recency_decay`) in the hybrid score.
const GAMMA_RECENCY: f32 = 0.15;
/// Age in whole days since `last_confirmed` at which `recency_decay` reaches 0.0
/// (the decay is linear from 1.0 at age 0).
const MAX_RECENCY_DAYS: f32 = 90.0;
21
/// One stored embedding vector for a knowledge fact, identified by `(category, key)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactEmbedding {
    /// Category of the fact this embedding belongs to.
    pub category: String,
    /// Key of the fact within its category.
    pub key: String,
    /// Embedding vector; `embed_and_store` produces it from the text
    /// `"{category} {key}: {value}"`.
    pub embedding: Vec<f32>,
}
28
/// Per-project vector index over knowledge facts.
///
/// Persisted as JSON at `<data dir>/knowledge/<project_hash>/embeddings.json`
/// (see `index_path`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeEmbeddingIndex {
    /// Identifies the project; also names the on-disk directory of the index.
    pub project_hash: String,
    /// Stored embeddings. `upsert` keeps at most one entry per `(category, key)`.
    pub entries: Vec<FactEmbedding>,
}
34
35impl KnowledgeEmbeddingIndex {
36    pub fn new(project_hash: &str) -> Self {
37        Self {
38            project_hash: project_hash.to_string(),
39            entries: Vec::new(),
40        }
41    }
42
43    pub fn upsert(&mut self, category: &str, key: &str, embedding: Vec<f32>) {
44        if let Some(existing) = self
45            .entries
46            .iter_mut()
47            .find(|e| e.category == category && e.key == key)
48        {
49            existing.embedding = embedding;
50        } else {
51            self.entries.push(FactEmbedding {
52                category: category.to_string(),
53                key: key.to_string(),
54                embedding,
55            });
56        }
57    }
58
59    pub fn remove(&mut self, category: &str, key: &str) {
60        self.entries
61            .retain(|e| !(e.category == category && e.key == key));
62    }
63
64    #[cfg(feature = "embeddings")]
65    pub fn semantic_search(
66        &self,
67        query_embedding: &[f32],
68        top_k: usize,
69    ) -> Vec<(&FactEmbedding, f32)> {
70        let mut scored: Vec<(&FactEmbedding, f32)> = self
71            .entries
72            .iter()
73            .map(|e| {
74                let sim = cosine_similarity(query_embedding, &e.embedding);
75                (e, sim)
76            })
77            .collect();
78
79        scored.sort_by(|a, b| {
80            b.1.partial_cmp(&a.1)
81                .unwrap_or(std::cmp::Ordering::Equal)
82                .then_with(|| a.0.category.cmp(&b.0.category))
83                .then_with(|| a.0.key.cmp(&b.0.key))
84        });
85        scored.truncate(top_k);
86        scored
87    }
88
89    fn index_path(project_hash: &str) -> Option<PathBuf> {
90        let dir = crate::core::data_dir::lean_ctx_data_dir()
91            .ok()?
92            .join("knowledge")
93            .join(project_hash);
94        Some(dir.join("embeddings.json"))
95    }
96
97    pub fn load(project_hash: &str) -> Option<Self> {
98        let path = Self::index_path(project_hash)?;
99        let data = std::fs::read_to_string(path).ok()?;
100        serde_json::from_str(&data).ok()
101    }
102
103    pub fn save(&self) -> Result<(), String> {
104        let path = Self::index_path(&self.project_hash)
105            .ok_or_else(|| "Cannot determine data directory".to_string())?;
106        if let Some(dir) = path.parent() {
107            std::fs::create_dir_all(dir).map_err(|e| format!("{e}"))?;
108        }
109        let json = serde_json::to_string(self).map_err(|e| format!("{e}"))?;
110        std::fs::write(path, json).map_err(|e| format!("{e}"))
111    }
112}
113
114pub fn reset(project_hash: &str) -> Result<(), String> {
115    let path = KnowledgeEmbeddingIndex::index_path(project_hash)
116        .ok_or_else(|| "Cannot determine data directory".to_string())?;
117    if path.exists() {
118        std::fs::remove_file(&path).map_err(|e| format!("{e}"))?;
119    }
120    Ok(())
121}
122
/// A knowledge fact paired with the scores used to rank it in recall results.
#[derive(Debug)]
pub struct ScoredFact<'a> {
    /// The fact being ranked, borrowed from the `ProjectKnowledge` it came from.
    pub fact: &'a KnowledgeFact,
    /// Final ranking score; results are sorted by it, highest first.
    ///
    /// For semantic hits it is `ALPHA_SEMANTIC * semantic_score + BETA_CONFIDENCE *
    /// confidence_score + GAMMA_RECENCY * recency_score`. Exact recall matches are
    /// pinned to 1.0. The lexical fallback uses the fact's raw `confidence`.
    pub score: f32,
    /// Cosine similarity between query and fact embeddings. It is 1.0 for exact recall
    /// matches and 0.0 on the lexical fallback.
    pub semantic_score: f32,
    /// Confidence component: `fact.quality_score()` on the semantic/exact paths,
    /// `fact.confidence` on the lexical fallback.
    pub confidence_score: f32,
    /// Recency component from `recency_decay`. It is 1.0 for a fact confirmed today and
    /// decreases linearly to 0.0 at `MAX_RECENCY_DAYS`.
    pub recency_score: f32,
}
131
/// Hybrid recall: ranks current facts by semantic similarity to `query`, plus exact
/// (`ProjectKnowledge::recall`) matches, and returns at most `top_k` results.
///
/// Semantic hits are scored as `ALPHA_SEMANTIC * similarity + BETA_CONFIDENCE * quality
/// + GAMMA_RECENCY * recency`. Exact matches not already among the semantic hits are
/// added with score 1.0. If the query cannot be embedded, falls back to
/// `lexical_fallback`.
///
/// The final order is by score, then confidence, then recency (all descending; NaN
/// ranks last), then `category`, `key`, `value` ascending for determinism.
#[cfg(feature = "embeddings")]
pub fn semantic_recall<'a>(
    knowledge: &'a ProjectKnowledge,
    index: &KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    use std::cmp::Ordering;
    use std::collections::HashSet;

    /// Descending order on scores with NaN treated as lowest. This keeps the comparator
    /// a total order, which `sort_by` requires (it may panic otherwise).
    fn desc(a: f32, b: f32) -> Ordering {
        let rank = |x: f32| if x.is_nan() { f32::NEG_INFINITY } else { x };
        rank(b).partial_cmp(&rank(a)).unwrap_or(Ordering::Equal)
    }

    if top_k == 0 {
        return Vec::new();
    }

    let Ok(query_embedding) = engine.embed_query(query) else {
        return lexical_fallback(knowledge, query, top_k);
    };

    // Over-fetch so stale (removed / no longer current) index entries can be skipped
    // without starving the result list; saturate to avoid overflow on huge `top_k`.
    let semantic_hits = index.semantic_search(&query_embedding, top_k.saturating_mul(2));

    let mut results: Vec<ScoredFact<'a>> = Vec::with_capacity(semantic_hits.len());
    for &(entry, sim) in &semantic_hits {
        let Some(fact) = knowledge
            .facts
            .iter()
            .find(|f| f.category == entry.category && f.key == entry.key && f.is_current())
        else {
            continue;
        };
        let confidence_score = fact.quality_score();
        let recency_score = recency_decay(fact);
        results.push(ScoredFact {
            fact,
            score: ALPHA_SEMANTIC * sim
                + BETA_CONFIDENCE * confidence_score
                + GAMMA_RECENCY * recency_score,
            semantic_score: sim,
            confidence_score,
            recency_score,
        });
    }

    // Track (category, key) pairs already present so each exact match is added once.
    let mut seen: HashSet<(&'a str, &'a str)> = HashSet::with_capacity(results.len());
    for r in &results {
        let fact: &'a KnowledgeFact = r.fact;
        seen.insert((fact.category.as_str(), fact.key.as_str()));
    }
    for fact in knowledge.recall(query) {
        if seen.insert((fact.category.as_str(), fact.key.as_str())) {
            results.push(ScoredFact {
                fact,
                score: 1.0,
                semantic_score: 1.0,
                confidence_score: fact.quality_score(),
                recency_score: recency_decay(fact),
            });
        }
    }

    results.sort_by(|a, b| {
        desc(a.score, b.score)
            .then_with(|| desc(a.confidence_score, b.confidence_score))
            .then_with(|| desc(a.recency_score, b.recency_score))
            .then_with(|| a.fact.category.cmp(&b.fact.category))
            .then_with(|| a.fact.key.cmp(&b.fact.key))
            .then_with(|| a.fact.value.cmp(&b.fact.value))
    });
    results.truncate(top_k);
    results
}
207
/// Semantic-only recall: like `semantic_recall` but without exact-match merging or a
/// lexical fallback.
///
/// Returns an empty list if the query cannot be embedded. Scoring and ordering match
/// `semantic_recall`: score, confidence, recency descending with NaN last, then
/// `category`, `key`, `value` ascending. At most `top_k` results are returned.
#[cfg(feature = "embeddings")]
pub fn semantic_recall_semantic_only<'a>(
    knowledge: &'a ProjectKnowledge,
    index: &KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    use std::cmp::Ordering;

    /// Descending order on scores with NaN treated as lowest. This keeps the comparator
    /// a total order, which `sort_by` requires (it may panic otherwise).
    fn desc(a: f32, b: f32) -> Ordering {
        let rank = |x: f32| if x.is_nan() { f32::NEG_INFINITY } else { x };
        rank(b).partial_cmp(&rank(a)).unwrap_or(Ordering::Equal)
    }

    if top_k == 0 {
        return Vec::new();
    }

    let Ok(query_embedding) = engine.embed_query(query) else {
        return Vec::new();
    };

    // Over-fetch so stale index entries can be skipped; saturate to avoid overflow.
    let semantic_hits = index.semantic_search(&query_embedding, top_k.saturating_mul(2));

    let mut results: Vec<ScoredFact<'a>> = Vec::with_capacity(semantic_hits.len());
    for &(entry, sim) in &semantic_hits {
        let Some(fact) = knowledge
            .facts
            .iter()
            .find(|f| f.category == entry.category && f.key == entry.key && f.is_current())
        else {
            continue;
        };
        let confidence_score = fact.quality_score();
        let recency_score = recency_decay(fact);
        results.push(ScoredFact {
            fact,
            score: ALPHA_SEMANTIC * sim
                + BETA_CONFIDENCE * confidence_score
                + GAMMA_RECENCY * recency_score,
            semantic_score: sim,
            confidence_score,
            recency_score,
        });
    }

    results.sort_by(|a, b| {
        desc(a.score, b.score)
            .then_with(|| desc(a.confidence_score, b.confidence_score))
            .then_with(|| desc(a.recency_score, b.recency_score))
            .then_with(|| a.fact.category.cmp(&b.fact.category))
            .then_with(|| a.fact.key.cmp(&b.fact.key))
            .then_with(|| a.fact.value.cmp(&b.fact.value))
    });
    results.truncate(top_k);
    results
}
266
267pub fn compact_against_knowledge(
268    index: &mut KnowledgeEmbeddingIndex,
269    knowledge: &ProjectKnowledge,
270    policy: &MemoryPolicy,
271) {
272    use std::collections::HashMap;
273
274    let mut current: HashMap<(&str, &str), &KnowledgeFact> = HashMap::new();
275    for f in &knowledge.facts {
276        if f.is_current() {
277            current.insert((f.category.as_str(), f.key.as_str()), f);
278        }
279    }
280
281    let mut kept: Vec<(FactEmbedding, &KnowledgeFact)> = index
282        .entries
283        .iter()
284        .filter_map(|e| {
285            current
286                .get(&(e.category.as_str(), e.key.as_str()))
287                .map(|f| (e.clone(), *f))
288        })
289        .collect();
290
291    kept.sort_by(|(ea, fa), (eb, fb)| {
292        fb.confidence
293            .partial_cmp(&fa.confidence)
294            .unwrap_or(std::cmp::Ordering::Equal)
295            .then_with(|| fb.last_confirmed.cmp(&fa.last_confirmed))
296            .then_with(|| fb.retrieval_count.cmp(&fa.retrieval_count))
297            .then_with(|| ea.category.cmp(&eb.category))
298            .then_with(|| ea.key.cmp(&eb.key))
299    });
300
301    let max = policy.embeddings.max_facts;
302    if kept.len() > max {
303        kept.truncate(max);
304    }
305
306    index.entries = kept.into_iter().map(|(e, _)| e).collect();
307}
308
309fn lexical_fallback<'a>(
310    knowledge: &'a ProjectKnowledge,
311    query: &str,
312    top_k: usize,
313) -> Vec<ScoredFact<'a>> {
314    knowledge
315        .recall(query)
316        .into_iter()
317        .take(top_k)
318        .map(|fact| ScoredFact {
319            fact,
320            score: fact.confidence,
321            semantic_score: 0.0,
322            confidence_score: fact.confidence,
323            recency_score: recency_decay(fact),
324        })
325        .collect()
326}
327
328fn recency_decay(fact: &KnowledgeFact) -> f32 {
329    let days_old = chrono::Utc::now()
330        .signed_duration_since(fact.last_confirmed)
331        .num_days() as f32;
332    (1.0 - days_old / MAX_RECENCY_DAYS).max(0.0)
333}
334
/// Embeds a fact as the text `"{category} {key}: {value}"` and upserts the vector
/// into `index` under `(category, key)`.
///
/// # Errors
/// Returns the embedding engine's error message if embedding fails; the index is left
/// unchanged in that case.
#[cfg(feature = "embeddings")]
pub fn embed_and_store(
    index: &mut KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    category: &str,
    key: &str,
    value: &str,
) -> Result<(), String> {
    let document = format!("{category} {key}: {value}");
    match engine.embed(&document) {
        Ok(vector) => {
            index.upsert(category, key, vector);
            Ok(())
        }
        Err(err) => Err(format!("{err}")),
    }
}
348
349pub fn format_scored_facts(results: &[ScoredFact<'_>]) -> String {
350    if results.is_empty() {
351        return "No matching facts found.".to_string();
352    }
353
354    let mut output = String::new();
355    for (i, scored) in results.iter().enumerate() {
356        let f = scored.fact;
357        let stars = if f.confidence >= 0.9 {
358            "★★★★"
359        } else if f.confidence >= 0.7 {
360            "★★★"
361        } else if f.confidence >= 0.5 {
362            "★★"
363        } else {
364            "★"
365        };
366
367        if i > 0 {
368            output.push('|');
369        }
370        output.push_str(&format!(
371            "{}:{}={}{} [s:{:.0}%]",
372            f.category,
373            f.key,
374            f.value,
375            stars,
376            scored.score * 100.0
377        ));
378    }
379    output
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::knowledge::KnowledgeArchetype;

    /// Builds a current, un-superseded fact with neutral feedback.
    ///
    /// `created_at` and `last_confirmed` are both set to `confirmed`. Tests override
    /// individual fields with struct-update syntax.
    fn fact(
        category: &str,
        key: &str,
        value: &str,
        confidence: f32,
        confirmed: chrono::DateTime<chrono::Utc>,
    ) -> KnowledgeFact {
        KnowledgeFact {
            category: category.to_string(),
            key: key.to_string(),
            value: value.to_string(),
            source_session: "s".to_string(),
            confidence,
            created_at: confirmed,
            last_confirmed: confirmed,
            retrieval_count: 0,
            last_retrieved: None,
            valid_from: None,
            valid_until: None,
            supersedes: None,
            confirmation_count: 1,
            feedback_up: 0,
            feedback_down: 0,
            last_feedback: None,
            privacy: crate::core::memory_boundary::FactPrivacy::default(),
            imported_from: None,
            archetype: KnowledgeArchetype::default(),
            fidelity: None,
            revision_count: 0,
        }
    }

    /// Points `LEAN_CTX_DATA_DIR` at a temp dir and restores the previous value on drop.
    /// This runs even when an assertion panics, so the override cannot leak into later
    /// tests.
    struct DataDirGuard(Option<std::ffi::OsString>);

    impl DataDirGuard {
        fn set(path: &std::path::Path) -> Self {
            let previous = std::env::var_os("LEAN_CTX_DATA_DIR");
            std::env::set_var("LEAN_CTX_DATA_DIR", path);
            Self(previous)
        }
    }

    impl Drop for DataDirGuard {
        fn drop(&mut self) {
            match self.0.take() {
                Some(previous) => std::env::set_var("LEAN_CTX_DATA_DIR", previous),
                None => std::env::remove_var("LEAN_CTX_DATA_DIR"),
            }
        }
    }

    #[test]
    fn reset_removes_index_file() {
        // Declaration order matters: the env guard drops first, then the temp dir,
        // and the env lock is released last.
        let _lock = crate::core::data_dir::test_env_lock();
        let tmp = tempfile::tempdir().expect("tempdir");
        let _env = DataDirGuard::set(tmp.path());

        let idx = KnowledgeEmbeddingIndex {
            project_hash: "projhash".to_string(),
            entries: vec![FactEmbedding {
                category: "arch".to_string(),
                key: "db".to_string(),
                embedding: vec![1.0, 0.0, 0.0],
            }],
        };
        idx.save().expect("save");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_some());

        reset("projhash").expect("reset");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_none());
    }

    #[test]
    fn compact_drops_missing_or_archived_facts() {
        let mut knowledge = ProjectKnowledge::new("/tmp/project");
        let now = chrono::Utc::now();
        knowledge.facts.push(KnowledgeFact {
            retrieval_count: 5,
            ..fact("arch", "db", "Postgres", 0.9, now)
        });
        // Archived: `valid_until` is set, so the fact is no longer current.
        knowledge.facts.push(KnowledgeFact {
            valid_until: Some(now),
            ..fact("arch", "old", "Old", 0.9, now)
        });

        let mut idx = KnowledgeEmbeddingIndex::new(&knowledge.project_hash);
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        idx.upsert("arch", "old", vec![0.0, 1.0, 0.0]);
        // No matching fact at all.
        idx.upsert("ops", "deploy", vec![0.0, 0.0, 1.0]);

        compact_against_knowledge(&mut idx, &knowledge, &MemoryPolicy::default());
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].category, "arch");
        assert_eq!(idx.entries[0].key, "db");
    }

    #[test]
    fn index_upsert_and_remove() {
        let mut idx = KnowledgeEmbeddingIndex::new("test");
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        assert_eq!(idx.entries.len(), 1);

        idx.upsert("arch", "db", vec![0.0, 1.0, 0.0]);
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].embedding[1], 1.0);

        idx.upsert("arch", "cache", vec![0.0, 0.0, 1.0]);
        assert_eq!(idx.entries.len(), 2);

        idx.remove("arch", "db");
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].key, "cache");
    }

    #[test]
    fn recency_decay_recent() {
        let recent = fact("test", "k", "v", 0.9, chrono::Utc::now());
        let decay = recency_decay(&recent);
        assert!(
            decay > 0.95,
            "Recent fact should have high recency: {decay}"
        );
    }

    #[test]
    fn recency_decay_old() {
        let old_date = chrono::Utc::now() - chrono::Duration::days(100);
        let old = fact("test", "k", "v", 0.5, old_date);
        let decay = recency_decay(&old);
        assert_eq!(decay, 0.0, "100-day-old fact should have 0 recency");
    }

    #[cfg(feature = "embeddings")]
    #[test]
    fn semantic_search_ranking() {
        let mut idx = KnowledgeEmbeddingIndex::new("test");
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        idx.upsert("arch", "cache", vec![0.0, 1.0, 0.0]);
        idx.upsert("ops", "deploy", vec![0.5, 0.5, 0.0]);

        let query = vec![1.0, 0.0, 0.0];
        let results = idx.semantic_search(&query, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0.key, "db");
    }

    #[test]
    fn format_scored_empty() {
        assert_eq!(format_scored_facts(&[]), "No matching facts found.");
    }

    #[test]
    fn format_scored_output() {
        let pg = KnowledgeFact {
            source_session: "s1".to_string(),
            confirmation_count: 3,
            ..fact("arch", "db", "PostgreSQL", 0.95, chrono::Utc::now())
        };
        let scored = vec![ScoredFact {
            fact: &pg,
            score: 0.85,
            semantic_score: 0.9,
            confidence_score: 0.95,
            recency_score: 1.0,
        }];
        let output = format_scored_facts(&scored);
        assert!(output.contains("arch:db=PostgreSQL"));
        assert!(output.contains("★★★★"));
        assert!(output.contains("[s:85%]"));
    }
}