Skip to main content

lean_ctx/core/
knowledge_embedding.rs

1//! Embedding-based Knowledge Retrieval for `ctx_knowledge`.
2//!
3//! Wraps `ProjectKnowledge` with a vector index for semantic recall.
4//! Facts are automatically embedded on `remember` and searched via
5//! cosine similarity on `recall`, with hybrid exact + semantic ranking.
6
7use std::path::PathBuf;
8
9use serde::{Deserialize, Serialize};
10
11use super::knowledge::{KnowledgeFact, ProjectKnowledge};
12use crate::core::memory_policy::MemoryPolicy;
13
14#[cfg(feature = "embeddings")]
15use super::embeddings::{cosine_similarity, EmbeddingEngine};
16
// Weights of the three terms in a semantic hit's combined score; together they sum to 1.0.

/// Weight of the cosine-similarity term.
const ALPHA_SEMANTIC: f32 = 0.6;
/// Weight of the fact-quality term (`KnowledgeFact::quality_score`).
const BETA_CONFIDENCE: f32 = 0.25;
/// Weight of the recency term (`recency_decay`).
const GAMMA_RECENCY: f32 = 0.15;
/// Age, in whole days since `last_confirmed`, at which `recency_decay` reaches zero.
const MAX_RECENCY_DAYS: f32 = 90.0;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct FactEmbedding {
24    pub category: String,
25    pub key: String,
26    pub embedding: Vec<f32>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct KnowledgeEmbeddingIndex {
31    pub project_hash: String,
32    pub entries: Vec<FactEmbedding>,
33}
34
35impl KnowledgeEmbeddingIndex {
36    pub fn new(project_hash: &str) -> Self {
37        Self {
38            project_hash: project_hash.to_string(),
39            entries: Vec::new(),
40        }
41    }
42
43    pub fn upsert(&mut self, category: &str, key: &str, embedding: Vec<f32>) {
44        if let Some(existing) = self
45            .entries
46            .iter_mut()
47            .find(|e| e.category == category && e.key == key)
48        {
49            existing.embedding = embedding;
50        } else {
51            self.entries.push(FactEmbedding {
52                category: category.to_string(),
53                key: key.to_string(),
54                embedding,
55            });
56        }
57    }
58
59    pub fn remove(&mut self, category: &str, key: &str) {
60        self.entries
61            .retain(|e| !(e.category == category && e.key == key));
62    }
63
64    #[cfg(feature = "embeddings")]
65    pub fn semantic_search(
66        &self,
67        query_embedding: &[f32],
68        top_k: usize,
69    ) -> Vec<(&FactEmbedding, f32)> {
70        let mut scored: Vec<(&FactEmbedding, f32)> = self
71            .entries
72            .iter()
73            .map(|e| {
74                let sim = cosine_similarity(query_embedding, &e.embedding);
75                (e, sim)
76            })
77            .collect();
78
79        scored.sort_by(|a, b| {
80            b.1.partial_cmp(&a.1)
81                .unwrap_or(std::cmp::Ordering::Equal)
82                .then_with(|| a.0.category.cmp(&b.0.category))
83                .then_with(|| a.0.key.cmp(&b.0.key))
84        });
85        scored.truncate(top_k);
86        scored
87    }
88
89    fn index_path(project_hash: &str) -> Option<PathBuf> {
90        let dir = crate::core::data_dir::lean_ctx_data_dir()
91            .ok()?
92            .join("knowledge")
93            .join(project_hash);
94        Some(dir.join("embeddings.json"))
95    }
96
97    pub fn load(project_hash: &str) -> Option<Self> {
98        let path = Self::index_path(project_hash)?;
99        let data = std::fs::read_to_string(path).ok()?;
100        serde_json::from_str(&data).ok()
101    }
102
103    pub fn save(&self) -> Result<(), String> {
104        let path = Self::index_path(&self.project_hash)
105            .ok_or_else(|| "Cannot determine data directory".to_string())?;
106        if let Some(dir) = path.parent() {
107            std::fs::create_dir_all(dir).map_err(|e| format!("{e}"))?;
108        }
109        let json = serde_json::to_string(self).map_err(|e| format!("{e}"))?;
110        std::fs::write(path, json).map_err(|e| format!("{e}"))
111    }
112}
113
114pub fn reset(project_hash: &str) -> Result<(), String> {
115    let path = KnowledgeEmbeddingIndex::index_path(project_hash)
116        .ok_or_else(|| "Cannot determine data directory".to_string())?;
117    if path.exists() {
118        std::fs::remove_file(&path).map_err(|e| format!("{e}"))?;
119    }
120    Ok(())
121}
122
123#[derive(Debug)]
124pub struct ScoredFact<'a> {
125    pub fact: &'a KnowledgeFact,
126    pub score: f32,
127    pub semantic_score: f32,
128    pub confidence_score: f32,
129    pub recency_score: f32,
130}
131
/// Hybrid recall: semantic nearest neighbours from `index`, merged with the lexical matches
/// of `ProjectKnowledge::recall`, ranked and truncated to `top_k`.
///
/// Semantic hits (current facts only) score
/// `ALPHA_SEMANTIC * similarity + BETA_CONFIDENCE * quality + GAMMA_RECENCY * recency`.
/// Lexical-only matches score `1.0`. A lexical match that is also a semantic hit keeps its
/// real similarity, but its score is raised to at least `1.0`. Without that, matching on
/// both signals would rank it *below* a fact that only matched lexically.
///
/// If the query cannot be embedded, this falls back to lexical-only recall.
#[cfg(feature = "embeddings")]
pub fn semantic_recall<'a>(
    knowledge: &'a ProjectKnowledge,
    index: &KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    let Ok(query_embedding) = engine.embed(query) else {
        return lexical_fallback(knowledge, query, top_k);
    };

    let mut results = score_semantic_hits(knowledge, index, &query_embedding, top_k);

    for fact in knowledge.recall(query) {
        if let Some(existing) = results
            .iter_mut()
            .find(|r| r.fact.category == fact.category && r.fact.key == fact.key)
        {
            existing.score = existing.score.max(1.0);
        } else {
            results.push(ScoredFact {
                fact,
                score: 1.0,
                semantic_score: 1.0,
                confidence_score: fact.quality_score(),
                recency_score: recency_decay(fact),
            });
        }
    }

    rank_and_truncate(&mut results, top_k);
    results
}

/// Semantic-only recall: like `semantic_recall` but without lexical matches or a lexical
/// fallback. Returns an empty list if the query cannot be embedded.
#[cfg(feature = "embeddings")]
pub fn semantic_recall_semantic_only<'a>(
    knowledge: &'a ProjectKnowledge,
    index: &KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    let Ok(query_embedding) = engine.embed(query) else {
        return Vec::new();
    };

    let mut results = score_semantic_hits(knowledge, index, &query_embedding, top_k);
    rank_and_truncate(&mut results, top_k);
    results
}

/// Scores the `2 * top_k` nearest index entries that still map to a current fact.
///
/// Over-fetching by 2x leaves room for entries whose fact was archived or removed.
#[cfg(feature = "embeddings")]
fn score_semantic_hits<'a>(
    knowledge: &'a ProjectKnowledge,
    index: &KnowledgeEmbeddingIndex,
    query_embedding: &[f32],
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    let facts = &knowledge.facts;
    index
        // `saturating_mul`: a caller passing a huge `top_k` must not overflow.
        .semantic_search(query_embedding, top_k.saturating_mul(2))
        .into_iter()
        .filter_map(|(entry, sim)| {
            let fact = facts.iter().find(|f| {
                f.category == entry.category && f.key == entry.key && f.is_current()
            })?;
            let confidence_score = fact.quality_score();
            let recency_score = recency_decay(fact);
            Some(ScoredFact {
                fact,
                score: ALPHA_SEMANTIC * sim
                    + BETA_CONFIDENCE * confidence_score
                    + GAMMA_RECENCY * recency_score,
                semantic_score: sim,
                confidence_score,
                recency_score,
            })
        })
        .collect()
}

/// Sorts by score, then confidence, then recency (all descending). Remaining ties are
/// broken by category, key, and value ascending. Keeps only the first `top_k`.
#[cfg(feature = "embeddings")]
fn rank_and_truncate(results: &mut Vec<ScoredFact<'_>>, top_k: usize) {
    results.sort_by(|a, b| {
        score_desc(a.score, b.score)
            .then_with(|| score_desc(a.confidence_score, b.confidence_score))
            .then_with(|| score_desc(a.recency_score, b.recency_score))
            .then_with(|| a.fact.category.cmp(&b.fact.category))
            .then_with(|| a.fact.key.cmp(&b.fact.key))
            .then_with(|| a.fact.value.cmp(&b.fact.value))
    });
    results.truncate(top_k);
}

/// Descending float comparison with NaN ordered last.
///
/// This is a total order, unlike `partial_cmp(..).unwrap_or(Equal)`; `sort_by` may panic
/// on non-total comparators (Rust >= 1.81).
#[cfg(feature = "embeddings")]
fn score_desc(a: f32, b: f32) -> std::cmp::Ordering {
    use std::cmp::Ordering;
    match (a.is_nan(), b.is_nan()) {
        (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (true, true) => Ordering::Equal,
    }
}
266
267pub fn compact_against_knowledge(
268    index: &mut KnowledgeEmbeddingIndex,
269    knowledge: &ProjectKnowledge,
270    policy: &MemoryPolicy,
271) {
272    use std::collections::HashMap;
273
274    let mut current: HashMap<(&str, &str), &KnowledgeFact> = HashMap::new();
275    for f in &knowledge.facts {
276        if f.is_current() {
277            current.insert((f.category.as_str(), f.key.as_str()), f);
278        }
279    }
280
281    let mut kept: Vec<(FactEmbedding, &KnowledgeFact)> = index
282        .entries
283        .iter()
284        .filter_map(|e| {
285            current
286                .get(&(e.category.as_str(), e.key.as_str()))
287                .map(|f| (e.clone(), *f))
288        })
289        .collect();
290
291    kept.sort_by(|(ea, fa), (eb, fb)| {
292        fb.confidence
293            .partial_cmp(&fa.confidence)
294            .unwrap_or(std::cmp::Ordering::Equal)
295            .then_with(|| fb.last_confirmed.cmp(&fa.last_confirmed))
296            .then_with(|| fb.retrieval_count.cmp(&fa.retrieval_count))
297            .then_with(|| ea.category.cmp(&eb.category))
298            .then_with(|| ea.key.cmp(&eb.key))
299    });
300
301    let max = policy.embeddings.max_facts;
302    if kept.len() > max {
303        kept.truncate(max);
304    }
305
306    index.entries = kept.into_iter().map(|(e, _)| e).collect();
307}
308
309fn lexical_fallback<'a>(
310    knowledge: &'a ProjectKnowledge,
311    query: &str,
312    top_k: usize,
313) -> Vec<ScoredFact<'a>> {
314    knowledge
315        .recall(query)
316        .into_iter()
317        .take(top_k)
318        .map(|fact| ScoredFact {
319            fact,
320            score: fact.confidence,
321            semantic_score: 0.0,
322            confidence_score: fact.confidence,
323            recency_score: recency_decay(fact),
324        })
325        .collect()
326}
327
328fn recency_decay(fact: &KnowledgeFact) -> f32 {
329    let days_old = chrono::Utc::now()
330        .signed_duration_since(fact.last_confirmed)
331        .num_days() as f32;
332    (1.0 - days_old / MAX_RECENCY_DAYS).max(0.0)
333}
334
/// Embeds a fact as the text `"{category} {key}: {value}"` and upserts the resulting vector
/// into `index` under `(category, key)`.
///
/// # Errors
/// Returns the engine's error rendered as a string if embedding fails. The index is left
/// untouched in that case.
#[cfg(feature = "embeddings")]
pub fn embed_and_store(
    index: &mut KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    category: &str,
    key: &str,
    value: &str,
) -> Result<(), String> {
    let document = format!("{category} {key}: {value}");
    match engine.embed(&document) {
        Ok(vector) => {
            index.upsert(category, key, vector);
            Ok(())
        }
        Err(err) => Err(format!("{err}")),
    }
}
348
349pub fn format_scored_facts(results: &[ScoredFact<'_>]) -> String {
350    if results.is_empty() {
351        return "No matching facts found.".to_string();
352    }
353
354    let mut output = String::new();
355    for (i, scored) in results.iter().enumerate() {
356        let f = scored.fact;
357        let stars = if f.confidence >= 0.9 {
358            "★★★★"
359        } else if f.confidence >= 0.7 {
360            "★★★"
361        } else if f.confidence >= 0.5 {
362            "★★"
363        } else {
364            "★"
365        };
366
367        if i > 0 {
368            output.push('|');
369        }
370        output.push_str(&format!(
371            "{}:{}={}{} [s:{:.0}%]",
372            f.category,
373            f.key,
374            f.value,
375            stars,
376            scored.score * 100.0
377        ));
378    }
379    output
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a current (non-archived) fact confirmed at `at`, with neutral bookkeeping
    /// fields. Tests override individual fields with struct-update syntax.
    fn make_fact(
        category: &str,
        key: &str,
        value: &str,
        confidence: f32,
        at: chrono::DateTime<chrono::Utc>,
    ) -> KnowledgeFact {
        KnowledgeFact {
            category: category.to_string(),
            key: key.to_string(),
            value: value.to_string(),
            source_session: "s".to_string(),
            confidence,
            created_at: at,
            last_confirmed: at,
            retrieval_count: 0,
            last_retrieved: None,
            valid_from: None,
            valid_until: None,
            supersedes: None,
            confirmation_count: 1,
            feedback_up: 0,
            feedback_down: 0,
            last_feedback: None,
            privacy: crate::core::memory_boundary::FactPrivacy::default(),
            imported_from: None,
        }
    }

    #[test]
    fn reset_removes_index_file() {
        let _lock = crate::core::data_dir::test_env_lock();
        let tmp = tempfile::tempdir().expect("tempdir");
        std::env::set_var(
            "LEAN_CTX_DATA_DIR",
            tmp.path().to_string_lossy().to_string(),
        );

        let idx = KnowledgeEmbeddingIndex {
            project_hash: "projhash".to_string(),
            entries: vec![FactEmbedding {
                category: "arch".to_string(),
                key: "db".to_string(),
                embedding: vec![1.0, 0.0, 0.0],
            }],
        };
        idx.save().expect("save");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_some());

        reset("projhash").expect("reset");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_none());

        std::env::remove_var("LEAN_CTX_DATA_DIR");
    }

    #[test]
    fn reset_without_index_is_ok() {
        let _lock = crate::core::data_dir::test_env_lock();
        let tmp = tempfile::tempdir().expect("tempdir");
        std::env::set_var(
            "LEAN_CTX_DATA_DIR",
            tmp.path().to_string_lossy().to_string(),
        );

        let result = reset("never-saved");

        std::env::remove_var("LEAN_CTX_DATA_DIR");
        assert!(
            result.is_ok(),
            "reset on a missing index should succeed: {result:?}"
        );
    }

    #[test]
    fn compact_drops_missing_or_archived_facts() {
        let mut knowledge = ProjectKnowledge::new("/tmp/project");
        let now = chrono::Utc::now();
        knowledge.facts.push(KnowledgeFact {
            retrieval_count: 5,
            ..make_fact("arch", "db", "Postgres", 0.9, now)
        });
        // Archived: `valid_until` is set, so it is no longer current.
        knowledge.facts.push(KnowledgeFact {
            valid_until: Some(now),
            ..make_fact("arch", "old", "Old", 0.9, now)
        });

        let mut idx = KnowledgeEmbeddingIndex::new(&knowledge.project_hash);
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        idx.upsert("arch", "old", vec![0.0, 1.0, 0.0]);
        idx.upsert("ops", "deploy", vec![0.0, 0.0, 1.0]);

        compact_against_knowledge(&mut idx, &knowledge, &MemoryPolicy::default());
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].category, "arch");
        assert_eq!(idx.entries[0].key, "db");
    }

    #[test]
    fn compact_keeps_highest_confidence_within_cap() {
        let mut knowledge = ProjectKnowledge::new("/tmp/project");
        let now = chrono::Utc::now();
        knowledge
            .facts
            .push(make_fact("arch", "db", "Postgres", 0.9, now));
        knowledge
            .facts
            .push(make_fact("arch", "cache", "Redis", 0.5, now));

        let mut idx = KnowledgeEmbeddingIndex::new(&knowledge.project_hash);
        idx.upsert("arch", "cache", vec![0.0, 1.0, 0.0]);
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);

        let mut policy = MemoryPolicy::default();
        policy.embeddings.max_facts = 1;
        compact_against_knowledge(&mut idx, &knowledge, &policy);

        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].key, "db");
    }

    #[test]
    fn index_upsert_and_remove() {
        let mut idx = KnowledgeEmbeddingIndex::new("test");
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        assert_eq!(idx.entries.len(), 1);

        // Upserting the same pair replaces the vector instead of adding an entry.
        idx.upsert("arch", "db", vec![0.0, 1.0, 0.0]);
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].embedding[1], 1.0);

        idx.upsert("arch", "cache", vec![0.0, 0.0, 1.0]);
        assert_eq!(idx.entries.len(), 2);

        idx.remove("arch", "db");
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].key, "cache");
    }

    #[test]
    fn recency_decay_recent() {
        let fact = make_fact("test", "k", "v", 0.9, chrono::Utc::now());
        let decay = recency_decay(&fact);
        assert!(
            decay > 0.95,
            "Recent fact should have high recency: {decay}"
        );
    }

    #[test]
    fn recency_decay_old() {
        let old_date = chrono::Utc::now() - chrono::Duration::days(100);
        let fact = make_fact("test", "k", "v", 0.5, old_date);
        let decay = recency_decay(&fact);
        assert_eq!(decay, 0.0, "100-day-old fact should have 0 recency");
    }

    #[cfg(feature = "embeddings")]
    #[test]
    fn semantic_search_ranking() {
        let mut idx = KnowledgeEmbeddingIndex::new("test");
        idx.upsert("arch", "db", vec![1.0, 0.0, 0.0]);
        idx.upsert("arch", "cache", vec![0.0, 1.0, 0.0]);
        idx.upsert("ops", "deploy", vec![0.5, 0.5, 0.0]);

        let query = vec![1.0, 0.0, 0.0];
        let results = idx.semantic_search(&query, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0.key, "db");
    }

    #[test]
    fn format_scored_empty() {
        assert_eq!(format_scored_facts(&[]), "No matching facts found.");
    }

    #[test]
    fn format_scored_output() {
        let fact = KnowledgeFact {
            source_session: "s1".to_string(),
            confirmation_count: 3,
            ..make_fact("arch", "db", "PostgreSQL", 0.95, chrono::Utc::now())
        };
        let scored = vec![ScoredFact {
            fact: &fact,
            score: 0.85,
            semantic_score: 0.9,
            confidence_score: 0.95,
            recency_score: 1.0,
        }];
        let output = format_scored_facts(&scored);
        assert!(output.contains("arch:db=PostgreSQL"));
        assert!(output.contains("★★★★"));
        assert!(output.contains("[s:85%]"));
    }
}