Skip to main content

lean_ctx/core/
knowledge_embedding.rs

//! Embedding-based knowledge retrieval for `ctx_knowledge`.
//!
//! Wraps `ProjectKnowledge` with a vector index for semantic recall.
//! Facts are automatically embedded on `remember` and searched via
//! cosine similarity on `recall`, with hybrid exact + semantic ranking.
6
7use std::path::PathBuf;
8
9use serde::{Deserialize, Serialize};
10
11use super::knowledge::{KnowledgeFact, ProjectKnowledge};
12use crate::core::memory_policy::MemoryPolicy;
13
14#[cfg(feature = "embeddings")]
15use super::embeddings::{cosine_similarity, EmbeddingEngine};
16
// Hybrid recall score weights:
//   score = ALPHA_SEMANTIC * similarity + BETA_CONFIDENCE * quality + GAMMA_RECENCY * recency
// The three weights sum to 1.0.
//
// These are only referenced from `embeddings`-gated code (directly, or via
// `recency_decay`), so without that feature they would trigger dead-code warnings.

/// Weight of the query/fact cosine similarity in the hybrid recall score.
#[cfg_attr(not(feature = "embeddings"), allow(dead_code))]
const ALPHA_SEMANTIC: f32 = 0.6;
/// Weight of the fact's quality score in the hybrid recall score.
#[cfg_attr(not(feature = "embeddings"), allow(dead_code))]
const BETA_CONFIDENCE: f32 = 0.25;
/// Weight of the recency term in the hybrid recall score.
#[cfg_attr(not(feature = "embeddings"), allow(dead_code))]
const GAMMA_RECENCY: f32 = 0.15;
/// Age in whole days at which a fact's recency weight reaches zero.
#[cfg_attr(not(feature = "embeddings"), allow(dead_code))]
const MAX_RECENCY_DAYS: f32 = 90.0;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct FactEmbedding {
24    pub category: String,
25    pub key: String,
26    pub embedding: Vec<f32>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct KnowledgeEmbeddingIndex {
31    pub project_hash: String,
32    pub entries: Vec<FactEmbedding>,
33}
34
35impl KnowledgeEmbeddingIndex {
36    pub fn new(project_hash: &str) -> Self {
37        Self {
38            project_hash: project_hash.to_string(),
39            entries: Vec::new(),
40        }
41    }
42
43    pub fn upsert(&mut self, category: &str, key: &str, embedding: Vec<f32>) {
44        if let Some(existing) = self
45            .entries
46            .iter_mut()
47            .find(|e| e.category == category && e.key == key)
48        {
49            existing.embedding = embedding;
50        } else {
51            self.entries.push(FactEmbedding {
52                category: category.to_string(),
53                key: key.to_string(),
54                embedding,
55            });
56        }
57    }
58
59    pub fn remove(&mut self, category: &str, key: &str) {
60        self.entries
61            .retain(|e| !(e.category == category && e.key == key));
62    }
63
64    #[cfg(feature = "embeddings")]
65    pub fn semantic_search(
66        &self,
67        query_embedding: &[f32],
68        top_k: usize,
69    ) -> Vec<(&FactEmbedding, f32)> {
70        let mut scored: Vec<(&FactEmbedding, f32)> = self
71            .entries
72            .iter()
73            .map(|e| {
74                let sim = cosine_similarity(query_embedding, &e.embedding);
75                (e, sim)
76            })
77            .collect();
78
79        scored.sort_by(|a, b| {
80            b.1.partial_cmp(&a.1)
81                .unwrap_or(std::cmp::Ordering::Equal)
82                .then_with(|| a.0.category.cmp(&b.0.category))
83                .then_with(|| a.0.key.cmp(&b.0.key))
84        });
85        scored.truncate(top_k);
86        scored
87    }
88
89    fn index_path(project_hash: &str) -> Option<PathBuf> {
90        let dir = crate::core::data_dir::lean_ctx_data_dir()
91            .ok()?
92            .join("knowledge")
93            .join(project_hash);
94        Some(dir.join("embeddings.json"))
95    }
96
97    pub fn load(project_hash: &str) -> Option<Self> {
98        let path = Self::index_path(project_hash)?;
99        let data = std::fs::read_to_string(path).ok()?;
100        serde_json::from_str(&data).ok()
101    }
102
103    pub fn save(&self) -> Result<(), String> {
104        let path = Self::index_path(&self.project_hash)
105            .ok_or_else(|| "Cannot determine data directory".to_string())?;
106        if let Some(dir) = path.parent() {
107            std::fs::create_dir_all(dir).map_err(|e| format!("{e}"))?;
108        }
109        let json = serde_json::to_string(self).map_err(|e| format!("{e}"))?;
110        std::fs::write(path, json).map_err(|e| format!("{e}"))
111    }
112}
113
114pub fn reset(project_hash: &str) -> Result<(), String> {
115    let path = KnowledgeEmbeddingIndex::index_path(project_hash)
116        .ok_or_else(|| "Cannot determine data directory".to_string())?;
117    if path.exists() {
118        std::fs::remove_file(&path).map_err(|e| format!("{e}"))?;
119    }
120    Ok(())
121}
122
123#[derive(Debug)]
124pub struct ScoredFact<'a> {
125    pub fact: &'a KnowledgeFact,
126    pub score: f32,
127    pub semantic_score: f32,
128    pub confidence_score: f32,
129    pub recency_score: f32,
130}
131
/// Hybrid (semantic + exact) recall over `knowledge`.
///
/// Embeds `query` and takes the `2 * top_k` nearest index entries. The
/// multiplication saturates, so a huge `top_k` cannot overflow. Each entry that
/// still maps to a current fact is scored as
/// `ALPHA_SEMANTIC * similarity + BETA_CONFIDENCE * quality + GAMMA_RECENCY * recency`.
///
/// Facts returned by `ProjectKnowledge::recall` that are not already present
/// are added with a fixed score and semantic score of 1.0. The merged list is
/// then ranked and cut to `top_k`.
///
/// If the query cannot be embedded, falls back to lexical-only recall.
#[cfg(feature = "embeddings")]
pub fn semantic_recall<'a>(
    knowledge: &'a ProjectKnowledge,
    index: &KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    let Ok(query_embedding) = engine.embed(query) else {
        return lexical_fallback(knowledge, query, top_k);
    };

    let semantic_hits = index.semantic_search(&query_embedding, top_k.saturating_mul(2));
    let mut results = score_semantic_hits(knowledge, &semantic_hits);

    for fact in knowledge.recall(query) {
        let already_included = results
            .iter()
            .any(|r| r.fact.category == fact.category && r.fact.key == fact.key);
        if !already_included {
            results.push(ScoredFact {
                fact,
                score: 1.0,
                semantic_score: 1.0,
                confidence_score: fact.quality_score(),
                recency_score: recency_decay(fact),
            });
        }
    }

    rank_and_truncate(&mut results, top_k);
    results
}

/// Semantic-only recall: like [`semantic_recall`] but without merging exact
/// lexical matches.
///
/// Returns an empty list if the query cannot be embedded.
#[cfg(feature = "embeddings")]
pub fn semantic_recall_semantic_only<'a>(
    knowledge: &'a ProjectKnowledge,
    index: &KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    let Ok(query_embedding) = engine.embed(query) else {
        return Vec::new();
    };

    let semantic_hits = index.semantic_search(&query_embedding, top_k.saturating_mul(2));
    let mut results = score_semantic_hits(knowledge, &semantic_hits);
    rank_and_truncate(&mut results, top_k);
    results
}

/// Turns semantic `hits` into scored facts from `knowledge`.
///
/// Hits whose `(category, key)` has no current fact are skipped, since the index
/// may be stale. If the index holds several entries for the same fact, only the
/// first (best-ranked) hit is kept, so a fact is never returned twice.
#[cfg(feature = "embeddings")]
fn score_semantic_hits<'a>(
    knowledge: &'a ProjectKnowledge,
    hits: &[(&FactEmbedding, f32)],
) -> Vec<ScoredFact<'a>> {
    let mut results: Vec<ScoredFact<'a>> = Vec::with_capacity(hits.len());

    for &(entry, sim) in hits {
        let Some(fact) = knowledge
            .facts
            .iter()
            .find(|f| f.category == entry.category && f.key == entry.key && f.is_current())
        else {
            continue;
        };
        if results
            .iter()
            .any(|r| r.fact.category == fact.category && r.fact.key == fact.key)
        {
            continue;
        }

        let confidence_score = fact.quality_score();
        let recency_score = recency_decay(fact);
        let score = ALPHA_SEMANTIC * sim
            + BETA_CONFIDENCE * confidence_score
            + GAMMA_RECENCY * recency_score;

        results.push(ScoredFact {
            fact,
            score,
            semantic_score: sim,
            confidence_score,
            recency_score,
        });
    }

    results
}

/// Sorts `results` best-first and keeps at most `top_k` of them.
///
/// Ordering: score, then confidence, then recency (all descending), then
/// category, key, and value (ascending) for a deterministic order.
#[cfg(feature = "embeddings")]
fn rank_and_truncate(results: &mut Vec<ScoredFact<'_>>, top_k: usize) {
    results.sort_by(|a, b| {
        cmp_desc_nan_last(a.score, b.score)
            .then_with(|| cmp_desc_nan_last(a.confidence_score, b.confidence_score))
            .then_with(|| cmp_desc_nan_last(a.recency_score, b.recency_score))
            .then_with(|| a.fact.category.cmp(&b.fact.category))
            .then_with(|| a.fact.key.cmp(&b.fact.key))
            .then_with(|| a.fact.value.cmp(&b.fact.value))
    });
    results.truncate(top_k);
}

/// Descending float comparison that ranks NaN last.
///
/// Treating NaN as "equal to everything" (the `partial_cmp(..).unwrap_or(Equal)`
/// idiom) is not a total order, and `slice::sort_by` may panic on such comparators.
/// Mapping NaN to -inf keeps every pair comparable.
#[cfg(feature = "embeddings")]
fn cmp_desc_nan_last(a: f32, b: f32) -> std::cmp::Ordering {
    let rank = |x: f32| if x.is_nan() { f32::NEG_INFINITY } else { x };
    rank(b)
        .partial_cmp(&rank(a))
        .unwrap_or(std::cmp::Ordering::Equal)
}
266
267pub fn compact_against_knowledge(
268    index: &mut KnowledgeEmbeddingIndex,
269    knowledge: &ProjectKnowledge,
270    policy: &MemoryPolicy,
271) {
272    use std::collections::HashMap;
273
274    let mut current: HashMap<(&str, &str), &KnowledgeFact> = HashMap::new();
275    for f in &knowledge.facts {
276        if f.is_current() {
277            current.insert((f.category.as_str(), f.key.as_str()), f);
278        }
279    }
280
281    let mut kept: Vec<(FactEmbedding, &KnowledgeFact)> = index
282        .entries
283        .iter()
284        .filter_map(|e| {
285            current
286                .get(&(e.category.as_str(), e.key.as_str()))
287                .map(|f| (e.clone(), *f))
288        })
289        .collect();
290
291    kept.sort_by(|(ea, fa), (eb, fb)| {
292        fb.confidence
293            .partial_cmp(&fa.confidence)
294            .unwrap_or(std::cmp::Ordering::Equal)
295            .then_with(|| fb.last_confirmed.cmp(&fa.last_confirmed))
296            .then_with(|| fb.retrieval_count.cmp(&fa.retrieval_count))
297            .then_with(|| ea.category.cmp(&eb.category))
298            .then_with(|| ea.key.cmp(&eb.key))
299    });
300
301    let max = policy.embeddings.max_facts;
302    if kept.len() > max {
303        kept.truncate(max);
304    }
305
306    index.entries = kept.into_iter().map(|(e, _)| e).collect();
307}
308
/// Lexical-only recall, used by `semantic_recall` when the query cannot be embedded.
///
/// Returns the first `top_k` facts from `ProjectKnowledge::recall`, in the order
/// `recall` yields them; no re-sorting happens here.
///
/// `score` and `confidence_score` are the fact's raw `confidence`. Note that the
/// semantic paths use `quality_score()` for `confidence_score` instead.
/// `semantic_score` is 0.0.
///
/// Its only caller exists solely with the `embeddings` feature, so it is gated the
/// same way. Without the gate it would be dead code (a warning, and an error
/// under `-D warnings`) in builds without that feature.
#[cfg(feature = "embeddings")]
fn lexical_fallback<'a>(
    knowledge: &'a ProjectKnowledge,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    knowledge
        .recall(query)
        .into_iter()
        .take(top_k)
        .map(|fact| ScoredFact {
            fact,
            score: fact.confidence,
            semantic_score: 0.0,
            confidence_score: fact.confidence,
            recency_score: recency_decay(fact),
        })
        .collect()
}
327
328fn recency_decay(fact: &KnowledgeFact) -> f32 {
329    let days_old = chrono::Utc::now()
330        .signed_duration_since(fact.last_confirmed)
331        .num_days() as f32;
332    (1.0 - days_old / MAX_RECENCY_DAYS).max(0.0)
333}
334
/// Embeds a fact and records the vector in `index` under `(category, key)`.
///
/// The embedded text is `"<category> <key>: <value>"`. An existing entry for the
/// same pair is replaced.
///
/// # Errors
/// Returns the engine's error, formatted as a string, if embedding fails. In that
/// case `index` is left untouched.
#[cfg(feature = "embeddings")]
pub fn embed_and_store(
    index: &mut KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    category: &str,
    key: &str,
    value: &str,
) -> Result<(), String> {
    let document = format!("{category} {key}: {value}");
    engine
        .embed(&document)
        .map(|vector| index.upsert(category, key, vector))
        .map_err(|err| format!("{err}"))
}
348
349pub fn format_scored_facts(results: &[ScoredFact<'_>]) -> String {
350    if results.is_empty() {
351        return "No matching facts found.".to_string();
352    }
353
354    let mut output = String::new();
355    for (i, scored) in results.iter().enumerate() {
356        let f = scored.fact;
357        let stars = if f.confidence >= 0.9 {
358            "★★★★"
359        } else if f.confidence >= 0.7 {
360            "★★★"
361        } else if f.confidence >= 0.5 {
362            "★★"
363        } else {
364            "★"
365        };
366
367        if i > 0 {
368            output.push('|');
369        }
370        output.push_str(&format!(
371            "{}:{}={}{} [s:{:.0}%]",
372            f.category,
373            f.key,
374            f.value,
375            stars,
376            scored.score * 100.0
377        ));
378    }
379    output
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fact created and confirmed at `at`.
    ///
    /// Defaults: session `"s"`, one confirmation, no retrievals, no feedback, and
    /// no validity window.
    fn sample_fact(
        category: &str,
        key: &str,
        value: &str,
        confidence: f32,
        at: chrono::DateTime<chrono::Utc>,
    ) -> KnowledgeFact {
        KnowledgeFact {
            category: category.to_string(),
            key: key.to_string(),
            value: value.to_string(),
            source_session: "s".to_string(),
            confidence,
            created_at: at,
            last_confirmed: at,
            retrieval_count: 0,
            last_retrieved: None,
            valid_from: None,
            valid_until: None,
            supersedes: None,
            confirmation_count: 1,
            feedback_up: 0,
            feedback_down: 0,
            last_feedback: None,
        }
    }

    #[test]
    fn reset_removes_index_file() {
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().expect("tempdir");
        std::env::set_var(
            "LEAN_CTX_DATA_DIR",
            data_dir.path().to_string_lossy().to_string(),
        );

        let mut index = KnowledgeEmbeddingIndex::new("projhash");
        index.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        index.save().expect("save");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_some());

        reset("projhash").expect("reset");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_none());

        std::env::remove_var("LEAN_CTX_DATA_DIR");
    }

    #[test]
    fn compact_drops_missing_or_archived_facts() {
        let mut knowledge = ProjectKnowledge::new("/tmp/project");
        let now = chrono::Utc::now();
        // Current fact with an embedding: must survive compaction.
        knowledge.facts.push(KnowledgeFact {
            retrieval_count: 5,
            ..sample_fact("arch", "db", "Postgres", 0.9, now)
        });
        // `valid_until` is set, so this fact is expected to be non-current and dropped.
        knowledge.facts.push(KnowledgeFact {
            valid_until: Some(now),
            ..sample_fact("arch", "old", "Old", 0.9, now)
        });

        let mut index = KnowledgeEmbeddingIndex::new(&knowledge.project_hash);
        for (category, key, embedding) in [
            ("arch", "db", vec![1.0, 0.0, 0.0]),
            ("arch", "old", vec![0.0, 1.0, 0.0]),
            // No fact at all for this one.
            ("ops", "deploy", vec![0.0, 0.0, 1.0]),
        ] {
            index.upsert(category, key, embedding);
        }

        compact_against_knowledge(&mut index, &knowledge, &MemoryPolicy::default());
        assert_eq!(index.entries.len(), 1);
        assert_eq!(index.entries[0].category, "arch");
        assert_eq!(index.entries[0].key, "db");
    }

    #[test]
    fn index_upsert_and_remove() {
        let mut index = KnowledgeEmbeddingIndex::new("test");

        index.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        assert_eq!(index.entries.len(), 1);

        // Upserting an existing pair replaces its vector instead of adding an entry.
        index.upsert("arch", "db", vec![0.0, 1.0, 0.0]);
        assert_eq!(index.entries.len(), 1);
        assert_eq!(index.entries[0].embedding[1], 1.0);

        index.upsert("arch", "cache", vec![0.0, 0.0, 1.0]);
        assert_eq!(index.entries.len(), 2);

        index.remove("arch", "db");
        assert_eq!(index.entries.len(), 1);
        assert_eq!(index.entries[0].key, "cache");
    }

    #[test]
    fn recency_decay_recent() {
        let fact = sample_fact("test", "k", "v", 0.9, chrono::Utc::now());
        let decay = recency_decay(&fact);
        assert!(
            decay > 0.95,
            "Recent fact should have high recency: {decay}"
        );
    }

    #[test]
    fn recency_decay_old() {
        let hundred_days_ago = chrono::Utc::now() - chrono::Duration::days(100);
        let fact = sample_fact("test", "k", "v", 0.5, hundred_days_ago);
        assert_eq!(
            recency_decay(&fact),
            0.0,
            "100-day-old fact should have 0 recency"
        );
    }

    #[cfg(feature = "embeddings")]
    #[test]
    fn semantic_search_ranking() {
        let mut index = KnowledgeEmbeddingIndex::new("test");
        index.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        index.upsert("arch", "cache", vec![0.0, 1.0, 0.0]);
        index.upsert("ops", "deploy", vec![0.5, 0.5, 0.0]);

        let hits = index.semantic_search(&[1.0, 0.0, 0.0], 2);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].0.key, "db");
    }

    #[test]
    fn format_scored_empty() {
        assert_eq!(format_scored_facts(&[]), "No matching facts found.");
    }

    #[test]
    fn format_scored_output() {
        let fact = KnowledgeFact {
            source_session: "s1".to_string(),
            confirmation_count: 3,
            ..sample_fact("arch", "db", "PostgreSQL", 0.95, chrono::Utc::now())
        };
        let scored = [ScoredFact {
            fact: &fact,
            score: 0.85,
            semantic_score: 0.9,
            confidence_score: 0.95,
            recency_score: 1.0,
        }];

        let rendered = format_scored_facts(&scored);
        assert!(rendered.contains("arch:db=PostgreSQL"));
        assert!(rendered.contains("★★★★"));
        assert!(rendered.contains("[s:85%]"));
    }
}