//! lean_ctx/core/semantic_cache.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5
/// One cached document: its path plus a sparse term vector and the
/// bookkeeping counters used for similarity search and warm-up ranking.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticCacheEntry {
    // Path of the cached file, exactly as passed to `add_file`.
    pub path: String,
    // NOTE: despite the name, this stores raw term *frequencies* from
    // `compute_tf`; IDF weighting is applied at query time
    // (see `tfidf_vector_from_stored`).
    pub tfidf_vector: Vec<(String, f64)>,
    // Whitespace-token count of the file content at last `add_file`.
    pub token_count: usize,
    // Number of times `add_file` has been called for this path.
    pub access_count: u32,
    // Session id from the most recent `add_file` call.
    pub last_session: String,
}
14
/// JSON-persisted index over all cached documents, plus the shared IDF
/// table rebuilt after every `add_file`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SemanticCacheIndex {
    pub entries: Vec<SemanticCacheEntry>,
    // term -> smoothed inverse document frequency (see `rebuild_idf`).
    pub idf: HashMap<String, f64>,
    // Kept equal to `entries.len()` by `add_file`.
    pub total_docs: usize,
}
21
22impl SemanticCacheIndex {
23    pub fn add_file(&mut self, path: &str, content: &str, session_id: &str) {
24        let tf = compute_tf(content);
25        let token_count = content.split_whitespace().count();
26
27        if let Some(existing) = self.entries.iter_mut().find(|e| e.path == path) {
28            existing.tfidf_vector = tf.iter().map(|(k, v)| (k.clone(), *v)).collect();
29            existing.token_count = token_count;
30            existing.access_count += 1;
31            existing.last_session = session_id.to_string();
32        } else {
33            self.entries.push(SemanticCacheEntry {
34                path: path.to_string(),
35                tfidf_vector: tf.iter().map(|(k, v)| (k.clone(), *v)).collect(),
36                token_count,
37                access_count: 1,
38                last_session: session_id.to_string(),
39            });
40        }
41
42        self.total_docs = self.entries.len();
43        self.rebuild_idf();
44    }
45
46    fn rebuild_idf(&mut self) {
47        let mut df: HashMap<String, usize> = HashMap::new();
48        for entry in &self.entries {
49            let unique_terms: std::collections::HashSet<&str> =
50                entry.tfidf_vector.iter().map(|(k, _)| k.as_str()).collect();
51            for term in unique_terms {
52                *df.entry(term.to_string()).or_default() += 1;
53            }
54        }
55
56        self.idf.clear();
57        let n = self.total_docs as f64;
58        for (term, count) in &df {
59            let idf = (n / (*count as f64 + 1.0)).ln() + 1.0;
60            self.idf.insert(term.clone(), idf);
61        }
62    }
63
64    pub fn find_similar(&self, content: &str, threshold: f64) -> Vec<(String, f64)> {
65        let query_tf = compute_tf(content);
66        let query_vec = self.tfidf_vector(&query_tf);
67
68        let mut results: Vec<(String, f64)> = self
69            .entries
70            .iter()
71            .filter_map(|entry| {
72                let entry_vec = self.tfidf_vector_from_stored(&entry.tfidf_vector);
73                let sim = cosine_similarity(&query_vec, &entry_vec);
74                if sim >= threshold {
75                    Some((entry.path.clone(), sim))
76                } else {
77                    None
78                }
79            })
80            .collect();
81
82        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
83        results
84    }
85
86    pub fn suggest_warmup(&self, top_n: usize) -> Vec<String> {
87        let mut ranked: Vec<(&SemanticCacheEntry, f64)> = self
88            .entries
89            .iter()
90            .map(|e| {
91                let score = e.access_count as f64 * 0.6 + e.token_count as f64 * 0.0001;
92                (e, score)
93            })
94            .collect();
95
96        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
97
98        ranked
99            .into_iter()
100            .take(top_n)
101            .map(|(e, _)| e.path.clone())
102            .collect()
103    }
104
105    fn tfidf_vector(&self, tf: &HashMap<String, f64>) -> HashMap<String, f64> {
106        tf.iter()
107            .map(|(term, freq)| {
108                let idf = self.idf.get(term).copied().unwrap_or(1.0);
109                (term.clone(), freq * idf)
110            })
111            .collect()
112    }
113
114    fn tfidf_vector_from_stored(&self, stored: &[(String, f64)]) -> HashMap<String, f64> {
115        stored
116            .iter()
117            .map(|(term, freq)| {
118                let idf = self.idf.get(term).copied().unwrap_or(1.0);
119                (term.clone(), freq * idf)
120            })
121            .collect()
122    }
123
124    pub fn save(&self, project_root: &str) -> Result<(), String> {
125        let path = index_path(project_root);
126        if let Some(dir) = path.parent() {
127            std::fs::create_dir_all(dir).map_err(|e| e.to_string())?;
128        }
129        let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
130        std::fs::write(&path, json).map_err(|e| e.to_string())
131    }
132
133    pub fn load(project_root: &str) -> Option<Self> {
134        let path = index_path(project_root);
135        let content = std::fs::read_to_string(path).ok()?;
136        serde_json::from_str(&content).ok()
137    }
138
139    pub fn load_or_create(project_root: &str) -> Self {
140        Self::load(project_root).unwrap_or_default()
141    }
142}
143
/// Normalized term frequencies for `content`.
///
/// Terms are lowercased runs of alphanumeric/underscore characters; any
/// term shorter than two bytes is dropped. Returns an empty map when no
/// qualifying term exists (avoids division by zero).
fn compute_tf(content: &str) -> HashMap<String, f64> {
    let mut counts: HashMap<String, usize> = HashMap::new();
    let mut total = 0usize;

    content
        .split(|c: char| !c.is_alphanumeric() && c != '_')
        .map(str::to_lowercase)
        .filter(|w| w.len() >= 2)
        .for_each(|w| {
            *counts.entry(w).or_default() += 1;
            total += 1;
        });

    if total == 0 {
        return HashMap::new();
    }

    let denom = total as f64;
    counts
        .into_iter()
        .map(|(term, count)| (term, count as f64 / denom))
        .collect()
}
165
/// Cosine similarity between two sparse vectors keyed by term.
///
/// Only terms present in both maps contribute to the dot product.
/// Returns 0.0 when either vector has (near-)zero magnitude.
fn cosine_similarity(a: &HashMap<String, f64>, b: &HashMap<String, f64>) -> f64 {
    let dot: f64 = a
        .iter()
        .filter_map(|(term, av)| b.get(term).map(|bv| av * bv))
        .sum();
    let norm_a: f64 = a.values().map(|v| v * v).sum();
    let norm_b: f64 = b.values().map(|v| v * v).sum();

    let denom = norm_a.sqrt() * norm_b.sqrt();
    // Guard against division by ~0 for empty or degenerate vectors.
    if denom < 1e-10 {
        0.0
    } else {
        dot / denom
    }
}
187
188fn index_path(project_root: &str) -> PathBuf {
189    use md5::{Digest, Md5};
190    let mut hasher = Md5::new();
191    hasher.update(project_root.as_bytes());
192    let hash = format!("{:x}", hasher.finalize());
193    dirs::home_dir()
194        .unwrap_or_default()
195        .join(".lean-ctx")
196        .join("semantic_cache")
197        .join(format!("{hash}.json"))
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a sparse term vector from literal pairs, for cosine tests.
    fn vec_of(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    #[test]
    fn compute_tf_basic() {
        let tf = compute_tf("fn handle_request request response handle");
        for key in ["handle", "request"] {
            assert!(tf.contains_key(key));
        }
        assert!(tf["handle"] > 0.0);
    }

    #[test]
    fn cosine_identical() {
        let a = vec_of(&[("hello", 1.0), ("world", 0.5)]);
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 0.001);
    }

    #[test]
    fn cosine_orthogonal() {
        let a = vec_of(&[("hello", 1.0)]);
        let b = vec_of(&[("world", 1.0)]);
        assert!(cosine_similarity(&a, &b).abs() < 0.001);
    }

    #[test]
    fn add_and_find_similar() {
        let mut index = SemanticCacheIndex::default();
        index.add_file(
            "auth.rs",
            "fn validate_token check jwt expiry auth login",
            "s1",
        );
        index.add_file(
            "db.rs",
            "fn connect_database pool query insert delete",
            "s1",
        );

        let results = index.find_similar("validate auth token jwt", 0.1);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "auth.rs");
    }

    #[test]
    fn warmup_suggestions() {
        let mut index = SemanticCacheIndex::default();
        index.add_file("hot.rs", "frequently accessed file", "s1");
        index.entries[0].access_count = 50;
        index.add_file("cold.rs", "rarely used", "s1");

        let warmup = index.suggest_warmup(1);
        assert_eq!(warmup, vec!["hot.rs"]);
    }
}