Skip to main content

lean_ctx/core/
knowledge_embedding.rs

//! Embedding-based Knowledge Retrieval for `ctx_knowledge`.
//!
//! Wraps `ProjectKnowledge` with a vector index for semantic recall.
//! Facts are automatically embedded on `remember` and searched via
//! cosine similarity on `recall`, with hybrid exact + semantic ranking.
6
7use std::path::PathBuf;
8
9use serde::{Deserialize, Serialize};
10
11use super::knowledge::{KnowledgeFact, ProjectKnowledge};
12use crate::core::embedding_quant::{self, QuantizedVector};
13use crate::core::memory_policy::MemoryPolicy;
14
15#[cfg(feature = "embeddings")]
16use super::embeddings::EmbeddingEngine;
17
/// Weight of the embedding-similarity term in the hybrid recall score.
const ALPHA_SEMANTIC: f32 = 0.6_f32;
/// Weight of the fact-quality (confidence) term in the hybrid recall score.
const BETA_CONFIDENCE: f32 = 0.25_f32;
/// Weight of the recency term in the hybrid recall score.
const GAMMA_RECENCY: f32 = 0.15_f32;
/// Age in days at which a fact's recency weight has decayed to zero.
const MAX_RECENCY_DAYS: f32 = 90.0_f32;
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct FactEmbedding {
25    pub category: String,
26    pub key: String,
27    /// Legacy full-precision vector (indices written before int8 quantization).
28    /// Migrated to `quant` transparently on load and then emptied, so it only
29    /// appears in files written by older binaries.
30    #[serde(default, skip_serializing_if = "Vec::is_empty")]
31    pub embedding: Vec<f32>,
32    /// int8-quantized representation (turbovec-derived) — 4× smaller on disk and
33    /// the canonical storage for every entry written by current binaries.
34    #[serde(default, skip_serializing_if = "Option::is_none")]
35    pub quant: Option<QuantizedVector>,
36}
37
38impl FactEmbedding {
39    /// Similarity against a full-precision (L2-normalized) query. Scores directly
40    /// against the int8 codes when available; falls back to the legacy f32 vector
41    /// for not-yet-migrated entries.
42    fn similarity(&self, query: &[f32]) -> f32 {
43        match &self.quant {
44            Some(q) => embedding_quant::dot_quant(query, q),
45            None => embedding_quant::dot_f32(query, &self.embedding),
46        }
47    }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct KnowledgeEmbeddingIndex {
52    pub project_hash: String,
53    pub entries: Vec<FactEmbedding>,
54}
55
56impl KnowledgeEmbeddingIndex {
57    pub fn new(project_hash: &str) -> Self {
58        Self {
59            project_hash: project_hash.to_string(),
60            entries: Vec::new(),
61        }
62    }
63
64    pub fn upsert(&mut self, category: &str, key: &str, embedding: &[f32]) {
65        let quant = Some(embedding_quant::quantize(embedding));
66        if let Some(existing) = self
67            .entries
68            .iter_mut()
69            .find(|e| e.category == category && e.key == key)
70        {
71            existing.quant = quant;
72            existing.embedding = Vec::new();
73        } else {
74            self.entries.push(FactEmbedding {
75                category: category.to_string(),
76                key: key.to_string(),
77                embedding: Vec::new(),
78                quant,
79            });
80        }
81    }
82
83    /// Upgrades any legacy full-precision entries to int8 in place. Returns true
84    /// if anything changed (so the caller can persist the smaller form once).
85    fn migrate_legacy_entries(&mut self) -> bool {
86        let mut changed = false;
87        for e in &mut self.entries {
88            if e.quant.is_none() && !e.embedding.is_empty() {
89                e.quant = Some(embedding_quant::quantize(&e.embedding));
90                e.embedding = Vec::new();
91                changed = true;
92            }
93        }
94        changed
95    }
96
97    pub fn remove(&mut self, category: &str, key: &str) {
98        self.entries
99            .retain(|e| !(e.category == category && e.key == key));
100    }
101
102    #[cfg(feature = "embeddings")]
103    pub fn semantic_search(
104        &self,
105        query_embedding: &[f32],
106        top_k: usize,
107    ) -> Vec<(&FactEmbedding, f32)> {
108        let mut scored: Vec<(&FactEmbedding, f32)> = self
109            .entries
110            .iter()
111            .map(|e| {
112                let sim = e.similarity(query_embedding);
113                (e, sim)
114            })
115            .collect();
116
117        scored.sort_by(|a, b| {
118            b.1.partial_cmp(&a.1)
119                .unwrap_or(std::cmp::Ordering::Equal)
120                .then_with(|| a.0.category.cmp(&b.0.category))
121                .then_with(|| a.0.key.cmp(&b.0.key))
122        });
123        scored.truncate(top_k);
124        scored
125    }
126
127    fn index_path(project_hash: &str) -> Option<PathBuf> {
128        let dir = crate::core::data_dir::lean_ctx_data_dir()
129            .ok()?
130            .join("knowledge")
131            .join(project_hash);
132        Some(dir.join("embeddings.json"))
133    }
134
135    pub fn load(project_hash: &str) -> Option<Self> {
136        let path = Self::index_path(project_hash)?;
137        let data = std::fs::read_to_string(path).ok()?;
138        let mut index: Self = serde_json::from_str(&data).ok()?;
139        // Pay the one-time int8 migration cost on first load by an upgraded binary,
140        // then persist so subsequent loads read the 4×-smaller form.
141        if index.migrate_legacy_entries() {
142            let _ = index.save();
143        }
144        Some(index)
145    }
146
147    pub fn save(&self) -> Result<(), String> {
148        let path = Self::index_path(&self.project_hash)
149            .ok_or_else(|| "Cannot determine data directory".to_string())?;
150        if let Some(dir) = path.parent() {
151            std::fs::create_dir_all(dir).map_err(|e| format!("{e}"))?;
152        }
153        let json = serde_json::to_string(self).map_err(|e| format!("{e}"))?;
154        std::fs::write(path, json).map_err(|e| format!("{e}"))
155    }
156}
157
158pub fn reset(project_hash: &str) -> Result<(), String> {
159    let path = KnowledgeEmbeddingIndex::index_path(project_hash)
160        .ok_or_else(|| "Cannot determine data directory".to_string())?;
161    if path.exists() {
162        std::fs::remove_file(&path).map_err(|e| format!("{e}"))?;
163    }
164    Ok(())
165}
166
/// A knowledge fact paired with the score components used to rank it in recall.
#[derive(Debug)]
pub struct ScoredFact<'a> {
    /// The fact being ranked.
    pub fact: &'a KnowledgeFact,
    /// Final ranking score (the primary descending sort key in the semantic paths).
    pub score: f32,
    /// Similarity component: embedding similarity for semantic hits, `1.0` for
    /// exact matches merged by `semantic_recall`, `0.0` in the lexical fallback.
    pub semantic_score: f32,
    /// Confidence component: `quality_score()` in the semantic paths, raw
    /// `confidence` in the lexical fallback.
    pub confidence_score: f32,
    /// Recency component, as computed by `recency_decay`.
    pub recency_score: f32,
}
175
/// Hybrid recall: embedding similarity blended with fact quality and recency,
/// merged with exact lexical matches.
///
/// Up to `2 * top_k` nearest index entries are resolved to current facts. Each
/// is scored as `ALPHA_SEMANTIC * sim + BETA_CONFIDENCE * quality + GAMMA_RECENCY * recency`.
/// Facts returned by `ProjectKnowledge::recall` that are not already present
/// are added with a fixed score of `1.0`. The merged list is sorted by score,
/// confidence and recency (all descending), then by category/key/value, and
/// truncated to `top_k`.
///
/// If the query cannot be embedded, this falls back to lexical-only recall.
#[cfg(feature = "embeddings")]
pub fn semantic_recall<'a>(
    knowledge: &'a ProjectKnowledge,
    index: &KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    // Descending comparison that ranks NaN below every real value, so the sort
    // comparator is a total order (`sort_by` may panic on one that is not).
    fn desc(a: f32, b: f32) -> std::cmp::Ordering {
        let rank = |s: f32| if s.is_nan() { f32::NEG_INFINITY } else { s };
        rank(b)
            .partial_cmp(&rank(a))
            .unwrap_or(std::cmp::Ordering::Equal)
    }

    let Ok(query_embedding) = engine.embed_query(query) else {
        return lexical_fallback(knowledge, query, top_k);
    };

    // Fetch twice `top_k` candidates; hits without a current fact are skipped
    // below. Saturate rather than overflow for very large `top_k`.
    let semantic_hits = index.semantic_search(&query_embedding, top_k.saturating_mul(2));

    let mut results: Vec<ScoredFact<'a>> = Vec::new();

    for (entry, sim) in &semantic_hits {
        if let Some(fact) = knowledge
            .facts
            .iter()
            .find(|f| f.category == entry.category && f.key == entry.key && f.is_current())
        {
            let confidence_score = fact.quality_score();
            let recency_score = recency_decay(fact);
            let score = ALPHA_SEMANTIC * sim
                + BETA_CONFIDENCE * confidence_score
                + GAMMA_RECENCY * recency_score;

            results.push(ScoredFact {
                fact,
                score,
                semantic_score: *sim,
                confidence_score,
                recency_score,
            });
        }
    }

    let exact_matches = knowledge.recall(query);
    for fact in exact_matches {
        let already_included = results
            .iter()
            .any(|r| r.fact.category == fact.category && r.fact.key == fact.key);
        if !already_included {
            results.push(ScoredFact {
                fact,
                score: 1.0,
                semantic_score: 1.0,
                confidence_score: fact.quality_score(),
                recency_score: recency_decay(fact),
            });
        }
    }

    results.sort_by(|a, b| {
        desc(a.score, b.score)
            .then_with(|| desc(a.confidence_score, b.confidence_score))
            .then_with(|| desc(a.recency_score, b.recency_score))
            .then_with(|| a.fact.category.cmp(&b.fact.category))
            .then_with(|| a.fact.key.cmp(&b.fact.key))
            .then_with(|| a.fact.value.cmp(&b.fact.value))
    });
    results.truncate(top_k);
    results
}
251
/// Purely embedding-based recall: the semantic half of `semantic_recall`,
/// without the lexical fallback or the exact-match merge.
///
/// Returns an empty list if the query cannot be embedded. Otherwise, up to
/// `2 * top_k` nearest entries are resolved to current facts and scored as
/// `ALPHA_SEMANTIC * sim + BETA_CONFIDENCE * quality + GAMMA_RECENCY * recency`.
/// They are sorted by score, confidence and recency (all descending), then by
/// category/key/value, and truncated to `top_k`.
#[cfg(feature = "embeddings")]
pub fn semantic_recall_semantic_only<'a>(
    knowledge: &'a ProjectKnowledge,
    index: &KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    query: &str,
    top_k: usize,
) -> Vec<ScoredFact<'a>> {
    // Descending comparison that ranks NaN below every real value, so the sort
    // comparator is a total order (`sort_by` may panic on one that is not).
    fn desc(a: f32, b: f32) -> std::cmp::Ordering {
        let rank = |s: f32| if s.is_nan() { f32::NEG_INFINITY } else { s };
        rank(b)
            .partial_cmp(&rank(a))
            .unwrap_or(std::cmp::Ordering::Equal)
    }

    let Ok(query_embedding) = engine.embed_query(query) else {
        return Vec::new();
    };

    // Fetch twice `top_k` candidates; hits without a current fact are skipped
    // below. Saturate rather than overflow for very large `top_k`.
    let semantic_hits = index.semantic_search(&query_embedding, top_k.saturating_mul(2));
    let mut results: Vec<ScoredFact<'a>> = Vec::new();

    for (entry, sim) in &semantic_hits {
        if let Some(fact) = knowledge
            .facts
            .iter()
            .find(|f| f.category == entry.category && f.key == entry.key && f.is_current())
        {
            let confidence_score = fact.quality_score();
            let recency_score = recency_decay(fact);
            let score = ALPHA_SEMANTIC * sim
                + BETA_CONFIDENCE * confidence_score
                + GAMMA_RECENCY * recency_score;

            results.push(ScoredFact {
                fact,
                score,
                semantic_score: *sim,
                confidence_score,
                recency_score,
            });
        }
    }

    results.sort_by(|a, b| {
        desc(a.score, b.score)
            .then_with(|| desc(a.confidence_score, b.confidence_score))
            .then_with(|| desc(a.recency_score, b.recency_score))
            .then_with(|| a.fact.category.cmp(&b.fact.category))
            .then_with(|| a.fact.key.cmp(&b.fact.key))
            .then_with(|| a.fact.value.cmp(&b.fact.value))
    });
    results.truncate(top_k);
    results
}
310
311pub fn compact_against_knowledge(
312    index: &mut KnowledgeEmbeddingIndex,
313    knowledge: &ProjectKnowledge,
314    policy: &MemoryPolicy,
315) {
316    use std::collections::HashMap;
317
318    let mut current: HashMap<(&str, &str), &KnowledgeFact> = HashMap::new();
319    for f in &knowledge.facts {
320        if f.is_current() {
321            current.insert((f.category.as_str(), f.key.as_str()), f);
322        }
323    }
324
325    let mut kept: Vec<(FactEmbedding, &KnowledgeFact)> = index
326        .entries
327        .iter()
328        .filter_map(|e| {
329            current
330                .get(&(e.category.as_str(), e.key.as_str()))
331                .map(|f| (e.clone(), *f))
332        })
333        .collect();
334
335    kept.sort_by(|(ea, fa), (eb, fb)| {
336        fb.confidence
337            .partial_cmp(&fa.confidence)
338            .unwrap_or(std::cmp::Ordering::Equal)
339            .then_with(|| fb.last_confirmed.cmp(&fa.last_confirmed))
340            .then_with(|| fb.retrieval_count.cmp(&fa.retrieval_count))
341            .then_with(|| ea.category.cmp(&eb.category))
342            .then_with(|| ea.key.cmp(&eb.key))
343    });
344
345    let max = policy.embeddings.max_facts;
346    if kept.len() > max {
347        kept.truncate(max);
348    }
349
350    index.entries = kept.into_iter().map(|(e, _)| e).collect();
351}
352
353fn lexical_fallback<'a>(
354    knowledge: &'a ProjectKnowledge,
355    query: &str,
356    top_k: usize,
357) -> Vec<ScoredFact<'a>> {
358    knowledge
359        .recall(query)
360        .into_iter()
361        .take(top_k)
362        .map(|fact| ScoredFact {
363            fact,
364            score: fact.confidence,
365            semantic_score: 0.0,
366            confidence_score: fact.confidence,
367            recency_score: recency_decay(fact),
368        })
369        .collect()
370}
371
372fn recency_decay(fact: &KnowledgeFact) -> f32 {
373    let days_old = chrono::Utc::now()
374        .signed_duration_since(fact.last_confirmed)
375        .num_days() as f32;
376    (1.0 - days_old / MAX_RECENCY_DAYS).max(0.0)
377}
378
/// Embeds a fact as the text `"{category} {key}: {value}"` and stores the vector
/// in `index` under `(category, key)`, replacing any previous vector.
///
/// # Errors
/// Returns the embedding engine's error rendered as a string.
#[cfg(feature = "embeddings")]
pub fn embed_and_store(
    index: &mut KnowledgeEmbeddingIndex,
    engine: &EmbeddingEngine,
    category: &str,
    key: &str,
    value: &str,
) -> Result<(), String> {
    let document = format!("{category} {key}: {value}");
    match engine.embed(&document) {
        Ok(vector) => {
            index.upsert(category, key, &vector);
            Ok(())
        }
        Err(err) => Err(format!("{err}")),
    }
}
392
393pub fn format_scored_facts(results: &[ScoredFact<'_>]) -> String {
394    if results.is_empty() {
395        return "No matching facts found.".to_string();
396    }
397
398    let mut output = String::new();
399    for (i, scored) in results.iter().enumerate() {
400        let f = scored.fact;
401        let stars = if f.confidence >= 0.9 {
402            "★★★★"
403        } else if f.confidence >= 0.7 {
404            "★★★"
405        } else if f.confidence >= 0.5 {
406            "★★"
407        } else {
408            "★"
409        };
410
411        if i > 0 {
412            output.push('|');
413        }
414        output.push_str(&format!(
415            "{}:{}={}{} [s:{:.0}%]",
416            f.category,
417            f.key,
418            f.value,
419            stars,
420            scored.score * 100.0
421        ));
422    }
423    output
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::knowledge::KnowledgeArchetype;

    /// Fact fixture with `created_at == last_confirmed == at`. Every other field
    /// starts at a neutral value; individual tests override what they need.
    fn fact_at(
        category: &str,
        key: &str,
        value: &str,
        source_session: &str,
        confidence: f32,
        at: chrono::DateTime<chrono::Utc>,
    ) -> KnowledgeFact {
        KnowledgeFact {
            category: category.to_string(),
            key: key.to_string(),
            value: value.to_string(),
            source_session: source_session.to_string(),
            confidence,
            created_at: at,
            last_confirmed: at,
            retrieval_count: 0,
            last_retrieved: None,
            valid_from: None,
            valid_until: None,
            supersedes: None,
            confirmation_count: 1,
            feedback_up: 0,
            feedback_down: 0,
            last_feedback: None,
            privacy: crate::core::memory_boundary::FactPrivacy::default(),
            imported_from: None,
            archetype: KnowledgeArchetype::default(),
            fidelity: None,
            revision_count: 0,
        }
    }

    #[test]
    fn reset_removes_index_file() {
        let _lock = crate::core::data_dir::test_env_lock();
        let tmp = tempfile::tempdir().expect("tempdir");
        let data_dir = tmp.path().to_string_lossy().to_string();
        std::env::set_var("LEAN_CTX_DATA_DIR", data_dir);

        let legacy_entry = FactEmbedding {
            category: "arch".to_string(),
            key: "db".to_string(),
            embedding: vec![1.0, 0.0, 0.0],
            quant: None,
        };
        let idx = KnowledgeEmbeddingIndex {
            project_hash: "projhash".to_string(),
            entries: vec![legacy_entry],
        };
        idx.save().expect("save");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_some());

        reset("projhash").expect("reset");
        assert!(KnowledgeEmbeddingIndex::load("projhash").is_none());

        std::env::remove_var("LEAN_CTX_DATA_DIR");
    }

    #[test]
    fn compact_drops_missing_or_archived_facts() {
        let mut knowledge = ProjectKnowledge::new("/tmp/project");
        let now = chrono::Utc::now();

        let mut live = fact_at("arch", "db", "Postgres", "s", 0.9, now);
        live.retrieval_count = 5;
        let mut archived = fact_at("arch", "old", "Old", "s", 0.9, now);
        archived.valid_until = Some(now);
        knowledge.facts.push(live);
        knowledge.facts.push(archived);

        let mut idx = KnowledgeEmbeddingIndex::new(&knowledge.project_hash);
        idx.upsert("arch", "db", &[1.0, 0.0, 0.0]);
        idx.upsert("arch", "old", &[0.0, 1.0, 0.0]);
        idx.upsert("ops", "deploy", &[0.0, 0.0, 1.0]);

        compact_against_knowledge(&mut idx, &knowledge, &MemoryPolicy::default());
        assert_eq!(idx.entries.len(), 1);
        let survivor = &idx.entries[0];
        assert_eq!(survivor.category, "arch");
        assert_eq!(survivor.key, "db");
    }

    #[test]
    fn index_upsert_and_remove() {
        let mut idx = KnowledgeEmbeddingIndex::new("test");
        idx.upsert("arch", "db", &[1.0, 0.0, 0.0]);
        assert_eq!(idx.entries.len(), 1);

        // Re-upserting the same key replaces the vector instead of adding one.
        idx.upsert("arch", "db", &[0.0, 1.0, 0.0]);
        assert_eq!(idx.entries.len(), 1);
        // Stored quantized now: the dominant axis reconstructs to ~1.0.
        let quant = idx.entries[0].quant.as_ref().expect("quantized");
        let recon = quant.dequantize();
        assert!((recon[1] - 1.0).abs() < 1e-6);

        idx.upsert("arch", "cache", &[0.0, 0.0, 1.0]);
        assert_eq!(idx.entries.len(), 2);

        idx.remove("arch", "db");
        assert_eq!(idx.entries.len(), 1);
        assert_eq!(idx.entries[0].key, "cache");
    }

    #[test]
    fn recency_decay_recent() {
        let fact = fact_at("test", "k", "v", "s", 0.9, chrono::Utc::now());
        let decay = recency_decay(&fact);
        assert!(
            decay > 0.95,
            "Recent fact should have high recency: {decay}"
        );
    }

    #[test]
    fn recency_decay_old() {
        let old_date = chrono::Utc::now() - chrono::Duration::days(100);
        let fact = fact_at("test", "k", "v", "s", 0.5, old_date);
        let decay = recency_decay(&fact);
        assert_eq!(decay, 0.0, "100-day-old fact should have 0 recency");
    }

    #[cfg(feature = "embeddings")]
    #[test]
    fn semantic_search_ranking() {
        let mut idx = KnowledgeEmbeddingIndex::new("test");
        idx.upsert("arch", "db", &[1.0, 0.0, 0.0]);
        idx.upsert("arch", "cache", &[0.0, 1.0, 0.0]);
        idx.upsert("ops", "deploy", &[0.5, 0.5, 0.0]);

        let query = vec![1.0, 0.0, 0.0];
        let results = idx.semantic_search(&query, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0.key, "db");
    }

    #[test]
    fn format_scored_empty() {
        assert_eq!(format_scored_facts(&[]), "No matching facts found.");
    }

    #[test]
    fn format_scored_output() {
        let mut fact = fact_at("arch", "db", "PostgreSQL", "s1", 0.95, chrono::Utc::now());
        fact.confirmation_count = 3;
        let scored = vec![ScoredFact {
            fact: &fact,
            score: 0.85,
            semantic_score: 0.9,
            confidence_score: 0.95,
            recency_score: 1.0,
        }];

        let output = format_scored_facts(&scored);
        assert!(output.contains("arch:db=PostgreSQL"));
        assert!(output.contains("★★★★"));
        assert!(output.contains("[s:85%]"));
    }
}