//! TF-IDF based semantic cache index (`lean_ctx/core/semantic_cache.rs`).

1use std::collections::{HashMap, HashSet};
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5
/// Number of index mutations (calls to `SemanticCacheIndex::add_file`) after which
/// the global IDF table is rebuilt eagerly. Between rebuilds the index is only
/// flagged as `idf_dirty`, and the next search recomputes IDF on demand.
const IDF_REBUILD_BATCH: u32 = 100;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct SemanticCacheEntry {
11    pub path: String,
12    pub tfidf_vector: Vec<(String, f64)>,
13    pub token_count: usize,
14    pub access_count: u32,
15    pub last_session: String,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, Default)]
19pub struct SemanticCacheIndex {
20    pub entries: Vec<SemanticCacheEntry>,
21    pub idf: HashMap<String, f64>,
22    pub total_docs: usize,
23    /// Documents containing each term (unique terms per entry).
24    #[serde(default)]
25    pub term_document_freq: HashMap<String, usize>,
26    #[serde(default)]
27    idf_dirty: bool,
28    #[serde(default)]
29    mutations_since_idf_rebuild: u32,
30}
31
32impl SemanticCacheIndex {
33    pub fn add_file(&mut self, path: &str, content: &str, session_id: &str) {
34        let tf = compute_tf(content);
35        let token_count = content.split_whitespace().count();
36
37        if let Some(existing) = self.entries.iter_mut().find(|e| e.path == path) {
38            remove_doc_terms(&mut self.term_document_freq, &existing.tfidf_vector);
39            existing.tfidf_vector = tf.iter().map(|(k, v)| (k.clone(), *v)).collect();
40            existing.token_count = token_count;
41            existing.access_count += 1;
42            existing.last_session = session_id.to_string();
43            add_doc_terms(&mut self.term_document_freq, &existing.tfidf_vector);
44        } else {
45            let tf_vec: Vec<(String, f64)> = tf.iter().map(|(k, v)| (k.clone(), *v)).collect();
46            add_doc_terms(&mut self.term_document_freq, &tf_vec);
47            self.entries.push(SemanticCacheEntry {
48                path: path.to_string(),
49                tfidf_vector: tf_vec,
50                token_count,
51                access_count: 1,
52                last_session: session_id.to_string(),
53            });
54        }
55
56        self.total_docs = self.entries.len();
57        self.note_idf_mutation();
58    }
59
60    fn note_idf_mutation(&mut self) {
61        self.idf_dirty = true;
62        self.mutations_since_idf_rebuild = self.mutations_since_idf_rebuild.saturating_add(1);
63        if self.mutations_since_idf_rebuild >= IDF_REBUILD_BATCH {
64            self.recompute_idf_from_df();
65            self.idf_dirty = false;
66            self.mutations_since_idf_rebuild = 0;
67        }
68    }
69
70    fn recompute_idf_from_df(&mut self) {
71        self.idf.clear();
72        let n = self.total_docs as f64;
73        if n <= 0.0 {
74            return;
75        }
76        for (term, count) in &self.term_document_freq {
77            let idf = (n / (*count as f64 + 1.0)).ln() + 1.0;
78            self.idf.insert(term.clone(), idf);
79        }
80    }
81
82    fn rebuild_df_from_entries(&mut self) {
83        self.term_document_freq.clear();
84        for entry in &self.entries {
85            add_doc_terms(&mut self.term_document_freq, &entry.tfidf_vector);
86        }
87    }
88
89    fn repair_after_deserialize(&mut self) {
90        self.total_docs = self.entries.len();
91        if self.term_document_freq.is_empty() && !self.entries.is_empty() {
92            self.rebuild_df_from_entries();
93            self.idf_dirty = true;
94        }
95    }
96
97    fn ensure_idf_for_search(&mut self) {
98        if self.idf_dirty {
99            self.recompute_idf_from_df();
100            self.idf_dirty = false;
101            self.mutations_since_idf_rebuild = 0;
102        }
103    }
104
105    pub fn find_similar(&mut self, content: &str, threshold: f64) -> Vec<(String, f64)> {
106        const MAX_ENTRIES_FOR_SEARCH: usize = 200;
107
108        if self.entries.len() > MAX_ENTRIES_FOR_SEARCH {
109            return Vec::new();
110        }
111
112        self.ensure_idf_for_search();
113
114        let query_tf = compute_tf(content);
115        let query_vec = self.tfidf_vector(&query_tf);
116
117        let mut results: Vec<(String, f64)> = self
118            .entries
119            .iter()
120            .filter_map(|entry| {
121                let entry_vec = self.tfidf_vector_from_stored(&entry.tfidf_vector);
122                let sim = cosine_similarity(&query_vec, &entry_vec);
123                if sim >= threshold {
124                    Some((entry.path.clone(), sim))
125                } else {
126                    None
127                }
128            })
129            .collect();
130
131        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
132        results
133    }
134
135    pub fn suggest_warmup(&self, top_n: usize) -> Vec<String> {
136        let mut ranked: Vec<(&SemanticCacheEntry, f64)> = self
137            .entries
138            .iter()
139            .map(|e| {
140                let score = e.access_count as f64 * 0.6 + e.token_count as f64 * 0.0001;
141                (e, score)
142            })
143            .collect();
144
145        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
146
147        ranked
148            .into_iter()
149            .take(top_n)
150            .map(|(e, _)| e.path.clone())
151            .collect()
152    }
153
154    fn tfidf_vector(&self, tf: &HashMap<String, f64>) -> HashMap<String, f64> {
155        tf.iter()
156            .map(|(term, freq)| {
157                let idf = self.idf.get(term).copied().unwrap_or(1.0);
158                (term.clone(), freq * idf)
159            })
160            .collect()
161    }
162
163    fn tfidf_vector_from_stored(&self, stored: &[(String, f64)]) -> HashMap<String, f64> {
164        stored
165            .iter()
166            .map(|(term, freq)| {
167                let idf = self.idf.get(term).copied().unwrap_or(1.0);
168                (term.clone(), freq * idf)
169            })
170            .collect()
171    }
172
173    pub fn save(&self, project_root: &str) -> Result<(), String> {
174        let path = index_path(project_root);
175        if let Some(dir) = path.parent() {
176            std::fs::create_dir_all(dir).map_err(|e| e.to_string())?;
177        }
178        let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
179        std::fs::write(&path, json).map_err(|e| e.to_string())
180    }
181
182    pub fn load(project_root: &str) -> Option<Self> {
183        let path = index_path(project_root);
184        let content = std::fs::read_to_string(&path)
185            .or_else(|_| {
186                let legacy = legacy_index_path(project_root);
187                if legacy == path {
188                    return Err(std::io::Error::new(
189                        std::io::ErrorKind::NotFound,
190                        "same path",
191                    ));
192                }
193                let data = std::fs::read_to_string(&legacy)?;
194                let _ = std::fs::copy(&legacy, &path);
195                Ok(data)
196            })
197            .ok()?;
198        let mut index: SemanticCacheIndex = serde_json::from_str(&content).ok()?;
199        index.repair_after_deserialize();
200        Some(index)
201    }
202
203    pub fn load_or_create(project_root: &str) -> Self {
204        Self::load(project_root).unwrap_or_default()
205    }
206}
207
/// Decrements the document frequency of every distinct term in `tf_vec` by one.
/// Terms whose count reaches zero are dropped from `df`; terms not in `df` are
/// ignored.
fn remove_doc_terms(df: &mut HashMap<String, usize>, tf_vec: &[(String, f64)]) {
    let distinct_terms: HashSet<&str> = tf_vec.iter().map(|(term, _)| term.as_str()).collect();
    for term in distinct_terms {
        let exhausted = match df.get_mut(term) {
            Some(count) => {
                *count = count.saturating_sub(1);
                *count == 0
            }
            None => false,
        };
        if exhausted {
            df.remove(term);
        }
    }
}
219
/// Increments the document frequency of each distinct term in `tf_vec` by one,
/// inserting terms not yet present in `df`.
fn add_doc_terms(df: &mut HashMap<String, usize>, tf_vec: &[(String, f64)]) {
    let distinct_terms: HashSet<&String> = tf_vec.iter().map(|(term, _)| term).collect();
    distinct_terms
        .into_iter()
        .for_each(|term| *df.entry(term.clone()).or_insert(0) += 1);
}
226
/// Computes normalised term frequencies for `content`.
///
/// The text is split on every character that is neither alphanumeric nor `_`.
/// Each piece is lower-cased and kept only if it has at least two characters,
/// counted as Unicode scalar values rather than bytes. Each kept term maps to
/// `occurrences / total_kept_terms`, so the values sum to 1 (up to rounding).
/// Returns an empty map when no term qualifies.
fn compute_tf(content: &str) -> HashMap<String, f64> {
    const MIN_TERM_CHARS: usize = 2;

    let mut counts: HashMap<String, usize> = HashMap::new();
    let mut total = 0usize;

    for word in content.split(|c: char| !c.is_alphanumeric() && c != '_') {
        // Runs of separators produce empty pieces; skip them before lowercasing.
        if word.is_empty() {
            continue;
        }
        let w = word.to_lowercase();
        // `str::len` counts UTF-8 bytes. A byte-length check would keep single
        // non-ASCII letters ("é", "я") while dropping single ASCII letters, so
        // count characters instead.
        if w.chars().count() >= MIN_TERM_CHARS {
            *counts.entry(w).or_default() += 1;
            total += 1;
        }
    }

    if total == 0 {
        return HashMap::new();
    }

    let total = total as f64;
    counts
        .into_iter()
        .map(|(term, count)| (term, count as f64 / total))
        .collect()
}
248
/// Cosine similarity of two sparse term-weight vectors.
///
/// Returns `0.0` when the product of the two Euclidean norms is below `1e-10`,
/// e.g. when either vector is empty or all zeros.
fn cosine_similarity(a: &HashMap<String, f64>, b: &HashMap<String, f64>) -> f64 {
    // Explicit folds from 0.0 keep the summation order and starting value
    // identical to a plain accumulation loop.
    let dot = a
        .iter()
        .filter_map(|(term, x)| b.get(term).map(|y| x * y))
        .fold(0.0f64, |acc, p| acc + p);
    let squared_a = a.values().fold(0.0f64, |acc, x| acc + x * x);
    let squared_b = b.values().fold(0.0f64, |acc, y| acc + y * y);

    let denom = squared_a.sqrt() * squared_b.sqrt();
    if denom < 1e-10 {
        0.0
    } else {
        dot / denom
    }
}
270
271fn index_path(project_root: &str) -> PathBuf {
272    let hash = crate::core::project_hash::hash_project_root(project_root);
273    crate::core::data_dir::lean_ctx_data_dir()
274        .unwrap_or_default()
275        .join("semantic_cache")
276        .join(format!("{hash}.json"))
277}
278
279fn legacy_index_path(project_root: &str) -> PathBuf {
280    use md5::{Digest, Md5};
281    let mut hasher = Md5::new();
282    hasher.update(project_root.as_bytes());
283    let hash = format!("{:x}", hasher.finalize());
284    crate::core::data_dir::lean_ctx_data_dir()
285        .unwrap_or_default()
286        .join("semantic_cache")
287        .join(format!("{hash}.json"))
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn compute_tf_basic() {
        let tf = compute_tf("fn handle_request request response handle");
        assert!(tf.contains_key("handle"));
        assert!(tf.contains_key("request"));
        assert!(tf["handle"] > 0.0);
    }

    #[test]
    fn compute_tf_is_normalised_and_case_folded() {
        let tf = compute_tf("Foo foo_bar FOO x");
        assert_eq!(tf.len(), 2);
        assert!((tf["foo"] - 2.0 / 3.0).abs() < 1e-12);
        assert!((tf["foo_bar"] - 1.0 / 3.0).abs() < 1e-12);
        assert!(compute_tf("").is_empty());
    }

    #[test]
    fn cosine_identical() {
        let mut a = HashMap::new();
        a.insert("hello".to_string(), 1.0);
        a.insert("world".to_string(), 0.5);
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 0.001);
    }

    #[test]
    fn cosine_orthogonal() {
        let mut a = HashMap::new();
        a.insert("hello".to_string(), 1.0);
        let mut b = HashMap::new();
        b.insert("world".to_string(), 1.0);
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 0.001);
    }

    #[test]
    fn add_and_find_similar() {
        let mut index = SemanticCacheIndex::default();
        index.add_file(
            "auth.rs",
            "fn validate_token check jwt expiry auth login",
            "s1",
        );
        index.add_file(
            "db.rs",
            "fn connect_database pool query insert delete",
            "s1",
        );

        let results = index.find_similar("validate auth token jwt", 0.1);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "auth.rs");
    }

    #[test]
    fn re_adding_same_path_updates_in_place() {
        let mut index = SemanticCacheIndex::default();
        index.add_file("a.rs", "alpha beta", "s1");
        index.add_file("a.rs", "beta gamma", "s2");

        assert_eq!(index.entries.len(), 1);
        assert_eq!(index.total_docs, 1);
        let entry = &index.entries[0];
        assert_eq!(entry.access_count, 2);
        assert_eq!(entry.last_session, "s2");
        // Terms from the replaced content must no longer count as present.
        assert!(!index.term_document_freq.contains_key("alpha"));
        assert_eq!(index.term_document_freq.get("beta"), Some(&1));
        assert_eq!(index.term_document_freq.get("gamma"), Some(&1));
    }

    #[test]
    fn repair_rebuilds_missing_document_frequencies() {
        let mut index = SemanticCacheIndex::default();
        index.add_file("a.rs", "alpha beta", "s1");
        index.add_file("b.rs", "beta gamma", "s1");
        let expected = index.term_document_freq.clone();

        index.term_document_freq.clear();
        index.repair_after_deserialize();

        assert_eq!(index.term_document_freq, expected);
        assert_eq!(index.total_docs, 2);
    }

    #[test]
    fn find_similar_gives_up_above_entry_cap() {
        let mut index = SemanticCacheIndex::default();
        for i in 0..201 {
            index.add_file(&format!("f{i}.rs"), "shared words here", "s1");
        }
        assert!(index.find_similar("shared words here", 0.0).is_empty());
    }

    #[test]
    fn warmup_suggestions() {
        let mut index = SemanticCacheIndex::default();
        index.add_file("hot.rs", "frequently accessed file", "s1");
        index.entries[0].access_count = 50;
        index.add_file("cold.rs", "rarely used", "s1");

        let warmup = index.suggest_warmup(1);
        assert_eq!(warmup.len(), 1);
        assert_eq!(warmup[0], "hot.rs");
    }
}