Skip to main content

lean_ctx/core/
semantic_cache.rs

1use std::collections::{HashMap, HashSet};
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5
/// Number of index mutations after which the global IDF table is recomputed
/// eagerly.
///
/// Below this threshold the table is only marked stale (`idf_dirty`) and is
/// refreshed lazily by the next search.
const IDF_REBUILD_BATCH: u32 = 100;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct SemanticCacheEntry {
11    pub path: String,
12    pub tfidf_vector: Vec<(String, f64)>,
13    pub token_count: usize,
14    pub access_count: u32,
15    pub last_session: String,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, Default)]
19pub struct SemanticCacheIndex {
20    pub entries: Vec<SemanticCacheEntry>,
21    pub idf: HashMap<String, f64>,
22    pub total_docs: usize,
23    /// Documents containing each term (unique terms per entry).
24    #[serde(default)]
25    pub term_document_freq: HashMap<String, usize>,
26    #[serde(default)]
27    idf_dirty: bool,
28    #[serde(default)]
29    mutations_since_idf_rebuild: u32,
30}
31
32impl SemanticCacheIndex {
33    pub fn add_file(&mut self, path: &str, content: &str, session_id: &str) {
34        let tf = compute_tf(content);
35        let token_count = content.split_whitespace().count();
36
37        if let Some(existing) = self.entries.iter_mut().find(|e| e.path == path) {
38            remove_doc_terms(&mut self.term_document_freq, &existing.tfidf_vector);
39            existing.tfidf_vector = tf.iter().map(|(k, v)| (k.clone(), *v)).collect();
40            existing.token_count = token_count;
41            existing.access_count += 1;
42            existing.last_session = session_id.to_string();
43            add_doc_terms(&mut self.term_document_freq, &existing.tfidf_vector);
44        } else {
45            let tf_vec: Vec<(String, f64)> = tf.iter().map(|(k, v)| (k.clone(), *v)).collect();
46            add_doc_terms(&mut self.term_document_freq, &tf_vec);
47            self.entries.push(SemanticCacheEntry {
48                path: path.to_string(),
49                tfidf_vector: tf_vec,
50                token_count,
51                access_count: 1,
52                last_session: session_id.to_string(),
53            });
54        }
55
56        self.total_docs = self.entries.len();
57        self.note_idf_mutation();
58    }
59
60    fn note_idf_mutation(&mut self) {
61        self.idf_dirty = true;
62        self.mutations_since_idf_rebuild = self.mutations_since_idf_rebuild.saturating_add(1);
63        if self.mutations_since_idf_rebuild >= IDF_REBUILD_BATCH {
64            self.recompute_idf_from_df();
65            self.idf_dirty = false;
66            self.mutations_since_idf_rebuild = 0;
67        }
68    }
69
70    fn recompute_idf_from_df(&mut self) {
71        self.idf.clear();
72        let n = self.total_docs as f64;
73        if n <= 0.0 {
74            return;
75        }
76        for (term, count) in &self.term_document_freq {
77            let idf = (n / (*count as f64 + 1.0)).ln() + 1.0;
78            self.idf.insert(term.clone(), idf);
79        }
80    }
81
82    fn rebuild_df_from_entries(&mut self) {
83        self.term_document_freq.clear();
84        for entry in &self.entries {
85            add_doc_terms(&mut self.term_document_freq, &entry.tfidf_vector);
86        }
87    }
88
89    fn repair_after_deserialize(&mut self) {
90        self.total_docs = self.entries.len();
91        if self.term_document_freq.is_empty() && !self.entries.is_empty() {
92            self.rebuild_df_from_entries();
93            self.idf_dirty = true;
94        }
95    }
96
97    fn ensure_idf_for_search(&mut self) {
98        if self.idf_dirty {
99            self.recompute_idf_from_df();
100            self.idf_dirty = false;
101            self.mutations_since_idf_rebuild = 0;
102        }
103    }
104
105    pub fn find_similar(&mut self, content: &str, threshold: f64) -> Vec<(String, f64)> {
106        self.ensure_idf_for_search();
107
108        let query_tf = compute_tf(content);
109        let query_vec = self.tfidf_vector(&query_tf);
110
111        let mut results: Vec<(String, f64)> = self
112            .entries
113            .iter()
114            .filter_map(|entry| {
115                let entry_vec = self.tfidf_vector_from_stored(&entry.tfidf_vector);
116                let sim = cosine_similarity(&query_vec, &entry_vec);
117                if sim >= threshold {
118                    Some((entry.path.clone(), sim))
119                } else {
120                    None
121                }
122            })
123            .collect();
124
125        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
126        results
127    }
128
129    pub fn suggest_warmup(&self, top_n: usize) -> Vec<String> {
130        let mut ranked: Vec<(&SemanticCacheEntry, f64)> = self
131            .entries
132            .iter()
133            .map(|e| {
134                let score = e.access_count as f64 * 0.6 + e.token_count as f64 * 0.0001;
135                (e, score)
136            })
137            .collect();
138
139        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
140
141        ranked
142            .into_iter()
143            .take(top_n)
144            .map(|(e, _)| e.path.clone())
145            .collect()
146    }
147
148    fn tfidf_vector(&self, tf: &HashMap<String, f64>) -> HashMap<String, f64> {
149        tf.iter()
150            .map(|(term, freq)| {
151                let idf = self.idf.get(term).copied().unwrap_or(1.0);
152                (term.clone(), freq * idf)
153            })
154            .collect()
155    }
156
157    fn tfidf_vector_from_stored(&self, stored: &[(String, f64)]) -> HashMap<String, f64> {
158        stored
159            .iter()
160            .map(|(term, freq)| {
161                let idf = self.idf.get(term).copied().unwrap_or(1.0);
162                (term.clone(), freq * idf)
163            })
164            .collect()
165    }
166
167    pub fn save(&self, project_root: &str) -> Result<(), String> {
168        let path = index_path(project_root);
169        if let Some(dir) = path.parent() {
170            std::fs::create_dir_all(dir).map_err(|e| e.to_string())?;
171        }
172        let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
173        std::fs::write(&path, json).map_err(|e| e.to_string())
174    }
175
176    pub fn load(project_root: &str) -> Option<Self> {
177        let path = index_path(project_root);
178        let content = std::fs::read_to_string(&path)
179            .or_else(|_| {
180                let legacy = legacy_index_path(project_root);
181                if legacy == path {
182                    return Err(std::io::Error::new(
183                        std::io::ErrorKind::NotFound,
184                        "same path",
185                    ));
186                }
187                let data = std::fs::read_to_string(&legacy)?;
188                let _ = std::fs::copy(&legacy, &path);
189                Ok(data)
190            })
191            .ok()?;
192        let mut index: SemanticCacheIndex = serde_json::from_str(&content).ok()?;
193        index.repair_after_deserialize();
194        Some(index)
195    }
196
197    pub fn load_or_create(project_root: &str) -> Self {
198        Self::load(project_root).unwrap_or_default()
199    }
200}
201
/// Removes one document's contribution from the document-frequency table.
///
/// Every distinct term in `tf_vec` has its count decremented once (never below
/// zero); terms whose count reaches zero are dropped from `df`. Terms that are
/// not present in `df` are ignored.
fn remove_doc_terms(df: &mut HashMap<String, usize>, tf_vec: &[(String, f64)]) {
    let mut seen: HashSet<&str> = HashSet::new();
    for (term, _) in tf_vec {
        let term = term.as_str();
        // A repeated term must only be counted once per document.
        if !seen.insert(term) {
            continue;
        }
        let remaining = match df.get_mut(term) {
            Some(count) => {
                *count = count.saturating_sub(1);
                *count
            }
            None => continue,
        };
        if remaining == 0 {
            df.remove(term);
        }
    }
}
213
/// Adds one document's contribution to the document-frequency table: every
/// distinct term in `tf_vec` has its count incremented exactly once.
fn add_doc_terms(df: &mut HashMap<String, usize>, tf_vec: &[(String, f64)]) {
    let mut seen: HashSet<&str> = HashSet::new();
    for (term, _) in tf_vec {
        // Only the first occurrence of a term counts for this document.
        if seen.insert(term.as_str()) {
            *df.entry(term.clone()).or_insert(0) += 1;
        }
    }
}
220
/// Computes normalized term frequencies for `content`.
///
/// The text is split on every character that is neither alphanumeric nor `_`;
/// each piece is lowercased, and pieces whose lowercased form is shorter than
/// two bytes are discarded. The returned weight of a term is its occurrence
/// count divided by the total number of kept pieces, so the weights sum to 1.
/// The map is empty when nothing is kept.
fn compute_tf(content: &str) -> HashMap<String, f64> {
    let is_separator = |c: char| !(c.is_alphanumeric() || c == '_');

    let mut occurrences: HashMap<String, usize> = HashMap::new();
    let mut kept = 0usize;
    let tokens = content
        .split(is_separator)
        .map(str::to_lowercase)
        .filter(|token| token.len() >= 2);
    for token in tokens {
        *occurrences.entry(token).or_insert(0) += 1;
        kept += 1;
    }

    // With no kept tokens `occurrences` is empty, so no division takes place
    // and the result is an empty map.
    let denom = kept as f64;
    occurrences
        .into_iter()
        .map(|(term, count)| (term, count as f64 / denom))
        .collect()
}
242
/// Cosine similarity between two sparse term-weight vectors.
///
/// Terms missing from a map count as weight zero. Returns `0.0` when the
/// product of the two vector norms is below `1e-10` (for example when either
/// vector is empty).
fn cosine_similarity(a: &HashMap<String, f64>, b: &HashMap<String, f64>) -> f64 {
    // Only terms present in both vectors contribute to the dot product.
    let dot = a
        .iter()
        .filter_map(|(term, x)| b.get(term).map(|y| x * y))
        .fold(0.0_f64, |acc, product| acc + product);
    let sq_norm_a = a.values().fold(0.0_f64, |acc, x| acc + x * x);
    let sq_norm_b = b.values().fold(0.0_f64, |acc, y| acc + y * y);

    let magnitude = sq_norm_a.sqrt() * sq_norm_b.sqrt();
    if magnitude < 1e-10 {
        0.0
    } else {
        dot / magnitude
    }
}
264
265fn index_path(project_root: &str) -> PathBuf {
266    let hash = crate::core::project_hash::hash_project_root(project_root);
267    crate::core::data_dir::lean_ctx_data_dir()
268        .unwrap_or_default()
269        .join("semantic_cache")
270        .join(format!("{hash}.json"))
271}
272
273fn legacy_index_path(project_root: &str) -> PathBuf {
274    use md5::{Digest, Md5};
275    let mut hasher = Md5::new();
276    hasher.update(project_root.as_bytes());
277    let hash = format!("{:x}", hasher.finalize());
278    crate::core::data_dir::lean_ctx_data_dir()
279        .unwrap_or_default()
280        .join("semantic_cache")
281        .join(format!("{hash}.json"))
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned term-to-weight map from literal pairs.
    fn weights(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs
            .iter()
            .map(|&(term, weight)| (term.to_string(), weight))
            .collect()
    }

    /// Repeated words are present in the TF map with a positive weight.
    #[test]
    fn compute_tf_basic() {
        let tf = compute_tf("fn handle_request request response handle");
        for term in ["handle", "request"] {
            assert!(tf.contains_key(term));
        }
        assert!(tf["handle"] > 0.0);
    }

    /// A vector compared with itself has similarity 1.
    #[test]
    fn cosine_identical() {
        let v = weights(&[("hello", 1.0), ("world", 0.5)]);
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 0.001);
    }

    /// Vectors without shared terms have similarity 0.
    #[test]
    fn cosine_orthogonal() {
        let left = weights(&[("hello", 1.0)]);
        let right = weights(&[("world", 1.0)]);
        let sim = cosine_similarity(&left, &right);
        assert!(sim.abs() < 0.001);
    }

    /// The file sharing vocabulary with the query is ranked first.
    #[test]
    fn add_and_find_similar() {
        let mut index = SemanticCacheIndex::default();
        let files = [
            ("auth.rs", "fn validate_token check jwt expiry auth login"),
            ("db.rs", "fn connect_database pool query insert delete"),
        ];
        for (path, content) in files {
            index.add_file(path, content, "s1");
        }

        let results = index.find_similar("validate auth token jwt", 0.1);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "auth.rs");
    }

    /// The most frequently accessed file is the first warm-up suggestion.
    #[test]
    fn warmup_suggestions() {
        let mut index = SemanticCacheIndex::default();
        index.add_file("hot.rs", "frequently accessed file", "s1");
        index.entries[0].access_count = 50;
        index.add_file("cold.rs", "rarely used", "s1");

        let warmup = index.suggest_warmup(1);
        assert_eq!(warmup.len(), 1);
        assert_eq!(warmup[0], "hot.rs");
    }
}
345}