1use std::collections::{HashMap, HashSet};
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5
/// Number of index mutations after which the IDF table is rebuilt eagerly
/// instead of waiting for the next similarity search.
const IDF_REBUILD_BATCH: u32 = 100;
8
/// One cached file together with the statistics used for similarity search
/// and warm-up ranking.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticCacheEntry {
    /// Path the file was indexed under; used as the lookup key when a file is re-added.
    pub path: String,
    /// Per-term weights for the file. Despite the name, the index stores the raw
    /// term frequencies here and applies IDF weighting at query time.
    pub tfidf_vector: Vec<(String, f64)>,
    /// Number of whitespace-separated tokens in the indexed content.
    pub token_count: usize,
    /// How many times this path has been added to the index (starts at 1).
    pub access_count: u32,
    /// Session identifier supplied with the most recent add.
    pub last_session: String,
}
17
/// Persistent TF-IDF index over cached files.
///
/// Document frequencies are maintained incrementally, while the IDF table is
/// rebuilt lazily from them; the two private fields track that staleness.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SemanticCacheIndex {
    /// Indexed files, at most one entry per path.
    pub entries: Vec<SemanticCacheEntry>,
    /// Inverse document frequency per term, derived from `term_document_freq`.
    pub idf: HashMap<String, f64>,
    /// Number of indexed documents; kept equal to `entries.len()`.
    pub total_docs: usize,
    /// Number of entries whose vector contains each term. Defaults to empty when
    /// absent from the serialized form and is then rebuilt from `entries` on load.
    #[serde(default)]
    pub term_document_freq: HashMap<String, usize>,
    /// True when `idf` may be out of date with respect to `term_document_freq`.
    #[serde(default)]
    idf_dirty: bool,
    /// Mutations recorded since `idf` was last rebuilt.
    #[serde(default)]
    mutations_since_idf_rebuild: u32,
}
31
32impl SemanticCacheIndex {
33 pub fn add_file(&mut self, path: &str, content: &str, session_id: &str) {
34 let tf = compute_tf(content);
35 let token_count = content.split_whitespace().count();
36
37 if let Some(existing) = self.entries.iter_mut().find(|e| e.path == path) {
38 remove_doc_terms(&mut self.term_document_freq, &existing.tfidf_vector);
39 existing.tfidf_vector = tf.iter().map(|(k, v)| (k.clone(), *v)).collect();
40 existing.token_count = token_count;
41 existing.access_count += 1;
42 existing.last_session = session_id.to_string();
43 add_doc_terms(&mut self.term_document_freq, &existing.tfidf_vector);
44 } else {
45 let tf_vec: Vec<(String, f64)> = tf.iter().map(|(k, v)| (k.clone(), *v)).collect();
46 add_doc_terms(&mut self.term_document_freq, &tf_vec);
47 self.entries.push(SemanticCacheEntry {
48 path: path.to_string(),
49 tfidf_vector: tf_vec,
50 token_count,
51 access_count: 1,
52 last_session: session_id.to_string(),
53 });
54 }
55
56 self.total_docs = self.entries.len();
57 self.note_idf_mutation();
58 }
59
60 fn note_idf_mutation(&mut self) {
61 self.idf_dirty = true;
62 self.mutations_since_idf_rebuild = self.mutations_since_idf_rebuild.saturating_add(1);
63 if self.mutations_since_idf_rebuild >= IDF_REBUILD_BATCH {
64 self.recompute_idf_from_df();
65 self.idf_dirty = false;
66 self.mutations_since_idf_rebuild = 0;
67 }
68 }
69
70 fn recompute_idf_from_df(&mut self) {
71 self.idf.clear();
72 let n = self.total_docs as f64;
73 if n <= 0.0 {
74 return;
75 }
76 for (term, count) in &self.term_document_freq {
77 let idf = (n / (*count as f64 + 1.0)).ln() + 1.0;
78 self.idf.insert(term.clone(), idf);
79 }
80 }
81
82 fn rebuild_df_from_entries(&mut self) {
83 self.term_document_freq.clear();
84 for entry in &self.entries {
85 add_doc_terms(&mut self.term_document_freq, &entry.tfidf_vector);
86 }
87 }
88
89 fn repair_after_deserialize(&mut self) {
90 self.total_docs = self.entries.len();
91 if self.term_document_freq.is_empty() && !self.entries.is_empty() {
92 self.rebuild_df_from_entries();
93 self.idf_dirty = true;
94 }
95 }
96
97 fn ensure_idf_for_search(&mut self) {
98 if self.idf_dirty {
99 self.recompute_idf_from_df();
100 self.idf_dirty = false;
101 self.mutations_since_idf_rebuild = 0;
102 }
103 }
104
105 pub fn find_similar(&mut self, content: &str, threshold: f64) -> Vec<(String, f64)> {
106 const MAX_ENTRIES_FOR_SEARCH: usize = 200;
107
108 if self.entries.len() > MAX_ENTRIES_FOR_SEARCH {
109 return Vec::new();
110 }
111
112 self.ensure_idf_for_search();
113
114 let query_tf = compute_tf(content);
115 let query_vec = self.tfidf_vector(&query_tf);
116
117 let mut results: Vec<(String, f64)> = self
118 .entries
119 .iter()
120 .filter_map(|entry| {
121 let entry_vec = self.tfidf_vector_from_stored(&entry.tfidf_vector);
122 let sim = cosine_similarity(&query_vec, &entry_vec);
123 if sim >= threshold {
124 Some((entry.path.clone(), sim))
125 } else {
126 None
127 }
128 })
129 .collect();
130
131 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
132 results
133 }
134
135 pub fn suggest_warmup(&self, top_n: usize) -> Vec<String> {
136 let mut ranked: Vec<(&SemanticCacheEntry, f64)> = self
137 .entries
138 .iter()
139 .map(|e| {
140 let score = e.access_count as f64 * 0.6 + e.token_count as f64 * 0.0001;
141 (e, score)
142 })
143 .collect();
144
145 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
146
147 ranked
148 .into_iter()
149 .take(top_n)
150 .map(|(e, _)| e.path.clone())
151 .collect()
152 }
153
154 fn tfidf_vector(&self, tf: &HashMap<String, f64>) -> HashMap<String, f64> {
155 tf.iter()
156 .map(|(term, freq)| {
157 let idf = self.idf.get(term).copied().unwrap_or(1.0);
158 (term.clone(), freq * idf)
159 })
160 .collect()
161 }
162
163 fn tfidf_vector_from_stored(&self, stored: &[(String, f64)]) -> HashMap<String, f64> {
164 stored
165 .iter()
166 .map(|(term, freq)| {
167 let idf = self.idf.get(term).copied().unwrap_or(1.0);
168 (term.clone(), freq * idf)
169 })
170 .collect()
171 }
172
173 pub fn save(&self, project_root: &str) -> Result<(), String> {
174 let path = index_path(project_root);
175 if let Some(dir) = path.parent() {
176 std::fs::create_dir_all(dir).map_err(|e| e.to_string())?;
177 }
178 let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
179 std::fs::write(&path, json).map_err(|e| e.to_string())
180 }
181
182 pub fn load(project_root: &str) -> Option<Self> {
183 let path = index_path(project_root);
184 let content = std::fs::read_to_string(&path)
185 .or_else(|_| {
186 let legacy = legacy_index_path(project_root);
187 if legacy == path {
188 return Err(std::io::Error::new(
189 std::io::ErrorKind::NotFound,
190 "same path",
191 ));
192 }
193 let data = std::fs::read_to_string(&legacy)?;
194 let _ = std::fs::copy(&legacy, &path);
195 Ok(data)
196 })
197 .ok()?;
198 let mut index: SemanticCacheIndex = serde_json::from_str(&content).ok()?;
199 index.repair_after_deserialize();
200 Some(index)
201 }
202
203 pub fn load_or_create(project_root: &str) -> Self {
204 Self::load(project_root).unwrap_or_default()
205 }
206}
207
/// Removes one document's contribution from the document-frequency table.
///
/// Each distinct term in `tf_vec` has its count lowered by one; a term whose
/// count reaches zero is dropped from `df`. Terms not present in `df` are ignored.
fn remove_doc_terms(df: &mut HashMap<String, usize>, tf_vec: &[(String, f64)]) {
    let mut distinct: HashSet<&str> = HashSet::with_capacity(tf_vec.len());
    distinct.extend(tf_vec.iter().map(|pair| pair.0.as_str()));

    for term in distinct {
        let remaining = match df.get(term) {
            Some(&count) => count.saturating_sub(1),
            None => continue,
        };
        if remaining == 0 {
            df.remove(term);
        } else if let Some(slot) = df.get_mut(term) {
            *slot = remaining;
        }
    }
}
219
/// Adds one document's contribution to the document-frequency table.
///
/// Every distinct term in `tf_vec` has its count raised by one, no matter how
/// many times the term appears in the slice.
fn add_doc_terms(df: &mut HashMap<String, usize>, tf_vec: &[(String, f64)]) {
    let mut seen: HashSet<&str> = HashSet::with_capacity(tf_vec.len());
    for (term, _) in tf_vec {
        // `insert` reports false for repeats, so each term is counted once.
        if seen.insert(term.as_str()) {
            *df.entry(term.clone()).or_insert(0) += 1;
        }
    }
}
226
/// Computes relative term frequencies for `content`.
///
/// The text is split on every character that is neither alphanumeric nor `_`,
/// tokens are lowercased, and tokens shorter than two bytes are discarded.
/// Each remaining term maps to its occurrence count divided by the number of
/// kept tokens. Returns an empty map when no token survives.
fn compute_tf(content: &str) -> HashMap<String, f64> {
    let mut occurrences: HashMap<String, usize> = HashMap::new();

    let tokens = content
        .split(|ch: char| !(ch.is_alphanumeric() || ch == '_'))
        .map(str::to_lowercase)
        .filter(|token| token.len() >= 2);
    for token in tokens {
        *occurrences.entry(token).or_insert(0) += 1;
    }

    // The sum of all counts is exactly the number of tokens that were kept.
    let kept: usize = occurrences.values().sum();
    if kept == 0 {
        return HashMap::new();
    }

    occurrences
        .into_iter()
        .map(|(term, hits)| (term, hits as f64 / kept as f64))
        .collect()
}
248
/// Cosine similarity between two sparse vectors keyed by term.
///
/// Returns 0.0 when the product of the two norms is below `1e-10`, which
/// covers empty and all-zero vectors.
fn cosine_similarity(a: &HashMap<String, f64>, b: &HashMap<String, f64>) -> f64 {
    // Only terms present in both vectors contribute to the dot product.
    let dot = a
        .iter()
        .filter_map(|(term, weight)| b.get(term).map(|other| weight * other))
        .fold(0.0f64, |acc, product| acc + product);
    let squared_norm_a = a.values().fold(0.0f64, |acc, weight| acc + weight * weight);
    let squared_norm_b = b.values().fold(0.0f64, |acc, weight| acc + weight * weight);

    let denom = squared_norm_a.sqrt() * squared_norm_b.sqrt();
    if denom < 1e-10 {
        0.0
    } else {
        dot / denom
    }
}
270
271fn index_path(project_root: &str) -> PathBuf {
272 let hash = crate::core::project_hash::hash_project_root(project_root);
273 crate::core::data_dir::lean_ctx_data_dir()
274 .unwrap_or_default()
275 .join("semantic_cache")
276 .join(format!("{hash}.json"))
277}
278
279fn legacy_index_path(project_root: &str) -> PathBuf {
280 use md5::{Digest, Md5};
281 let mut hasher = Md5::new();
282 hasher.update(project_root.as_bytes());
283 let hash = format!("{:x}", hasher.finalize());
284 crate::core::data_dir::lean_ctx_data_dir()
285 .unwrap_or_default()
286 .join("semantic_cache")
287 .join(format!("{hash}.json"))
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sparse term vector from literal `(term, weight)` pairs.
    fn sparse(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs
            .iter()
            .map(|(term, weight)| (term.to_string(), *weight))
            .collect()
    }

    /// Repeated terms are present in the frequency map with a positive weight.
    #[test]
    fn compute_tf_basic() {
        let tf = compute_tf("fn handle_request request response handle");
        for term in ["handle", "request"] {
            assert!(tf.contains_key(term));
        }
        assert!(tf["handle"] > 0.0);
    }

    /// A vector compared with itself has similarity 1.
    #[test]
    fn cosine_identical() {
        let v = sparse(&[("hello", 1.0), ("world", 0.5)]);
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 0.001);
    }

    /// Vectors with no shared terms have similarity 0.
    #[test]
    fn cosine_orthogonal() {
        let left = sparse(&[("hello", 1.0)]);
        let right = sparse(&[("world", 1.0)]);
        let sim = cosine_similarity(&left, &right);
        assert!(sim.abs() < 0.001);
    }

    /// A query sharing vocabulary with one file ranks that file first.
    #[test]
    fn add_and_find_similar() {
        let mut index = SemanticCacheIndex::default();
        let files = [
            ("auth.rs", "fn validate_token check jwt expiry auth login"),
            ("db.rs", "fn connect_database pool query insert delete"),
        ];
        for (path, content) in files {
            index.add_file(path, content, "s1");
        }

        let results = index.find_similar("validate auth token jwt", 0.1);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "auth.rs");
    }

    /// The most frequently accessed file is the top warm-up suggestion.
    #[test]
    fn warmup_suggestions() {
        let mut index = SemanticCacheIndex::default();
        index.add_file("hot.rs", "frequently accessed file", "s1");
        index.entries[0].access_count = 50;
        index.add_file("cold.rs", "rarely used", "s1");

        let warmup = index.suggest_warmup(1);
        assert_eq!(warmup.len(), 1);
        assert_eq!(warmup[0], "hot.rs");
    }
}