use std::collections::{HashMap, HashSet};
use std::path::PathBuf;

use serde::{Deserialize, Serialize};

/// Number of index mutations after which the IDF table is recomputed right away.
///
/// Below this count the table is only marked dirty and rebuilt before the next search.
const IDF_REBUILD_BATCH: u32 = 100;

9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct SemanticCacheEntry {
11 pub path: String,
12 pub tfidf_vector: Vec<(String, f64)>,
13 pub token_count: usize,
14 pub access_count: u32,
15 pub last_session: String,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, Default)]
19pub struct SemanticCacheIndex {
20 pub entries: Vec<SemanticCacheEntry>,
21 pub idf: HashMap<String, f64>,
22 pub total_docs: usize,
23 #[serde(default)]
25 pub term_document_freq: HashMap<String, usize>,
26 #[serde(default)]
27 idf_dirty: bool,
28 #[serde(default)]
29 mutations_since_idf_rebuild: u32,
30}
31
32impl SemanticCacheIndex {
33 pub fn add_file(&mut self, path: &str, content: &str, session_id: &str) {
34 let tf = compute_tf(content);
35 let token_count = content.split_whitespace().count();
36
37 if let Some(existing) = self.entries.iter_mut().find(|e| e.path == path) {
38 remove_doc_terms(&mut self.term_document_freq, &existing.tfidf_vector);
39 existing.tfidf_vector = tf.iter().map(|(k, v)| (k.clone(), *v)).collect();
40 existing.token_count = token_count;
41 existing.access_count += 1;
42 existing.last_session = session_id.to_string();
43 add_doc_terms(&mut self.term_document_freq, &existing.tfidf_vector);
44 } else {
45 let tf_vec: Vec<(String, f64)> = tf.iter().map(|(k, v)| (k.clone(), *v)).collect();
46 add_doc_terms(&mut self.term_document_freq, &tf_vec);
47 self.entries.push(SemanticCacheEntry {
48 path: path.to_string(),
49 tfidf_vector: tf_vec,
50 token_count,
51 access_count: 1,
52 last_session: session_id.to_string(),
53 });
54 }
55
56 self.total_docs = self.entries.len();
57 self.note_idf_mutation();
58 }
59
60 fn note_idf_mutation(&mut self) {
61 self.idf_dirty = true;
62 self.mutations_since_idf_rebuild = self.mutations_since_idf_rebuild.saturating_add(1);
63 if self.mutations_since_idf_rebuild >= IDF_REBUILD_BATCH {
64 self.recompute_idf_from_df();
65 self.idf_dirty = false;
66 self.mutations_since_idf_rebuild = 0;
67 }
68 }
69
70 fn recompute_idf_from_df(&mut self) {
71 self.idf.clear();
72 let n = self.total_docs as f64;
73 if n <= 0.0 {
74 return;
75 }
76 for (term, count) in &self.term_document_freq {
77 let idf = (n / (*count as f64 + 1.0)).ln() + 1.0;
78 self.idf.insert(term.clone(), idf);
79 }
80 }
81
82 fn rebuild_df_from_entries(&mut self) {
83 self.term_document_freq.clear();
84 for entry in &self.entries {
85 add_doc_terms(&mut self.term_document_freq, &entry.tfidf_vector);
86 }
87 }
88
89 fn repair_after_deserialize(&mut self) {
90 self.total_docs = self.entries.len();
91 if self.term_document_freq.is_empty() && !self.entries.is_empty() {
92 self.rebuild_df_from_entries();
93 self.idf_dirty = true;
94 }
95 }
96
97 fn ensure_idf_for_search(&mut self) {
98 if self.idf_dirty {
99 self.recompute_idf_from_df();
100 self.idf_dirty = false;
101 self.mutations_since_idf_rebuild = 0;
102 }
103 }
104
105 pub fn find_similar(&mut self, content: &str, threshold: f64) -> Vec<(String, f64)> {
106 self.ensure_idf_for_search();
107
108 let query_tf = compute_tf(content);
109 let query_vec = self.tfidf_vector(&query_tf);
110
111 let mut results: Vec<(String, f64)> = self
112 .entries
113 .iter()
114 .filter_map(|entry| {
115 let entry_vec = self.tfidf_vector_from_stored(&entry.tfidf_vector);
116 let sim = cosine_similarity(&query_vec, &entry_vec);
117 if sim >= threshold {
118 Some((entry.path.clone(), sim))
119 } else {
120 None
121 }
122 })
123 .collect();
124
125 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
126 results
127 }
128
129 pub fn suggest_warmup(&self, top_n: usize) -> Vec<String> {
130 let mut ranked: Vec<(&SemanticCacheEntry, f64)> = self
131 .entries
132 .iter()
133 .map(|e| {
134 let score = e.access_count as f64 * 0.6 + e.token_count as f64 * 0.0001;
135 (e, score)
136 })
137 .collect();
138
139 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
140
141 ranked
142 .into_iter()
143 .take(top_n)
144 .map(|(e, _)| e.path.clone())
145 .collect()
146 }
147
148 fn tfidf_vector(&self, tf: &HashMap<String, f64>) -> HashMap<String, f64> {
149 tf.iter()
150 .map(|(term, freq)| {
151 let idf = self.idf.get(term).copied().unwrap_or(1.0);
152 (term.clone(), freq * idf)
153 })
154 .collect()
155 }
156
157 fn tfidf_vector_from_stored(&self, stored: &[(String, f64)]) -> HashMap<String, f64> {
158 stored
159 .iter()
160 .map(|(term, freq)| {
161 let idf = self.idf.get(term).copied().unwrap_or(1.0);
162 (term.clone(), freq * idf)
163 })
164 .collect()
165 }
166
167 pub fn save(&self, project_root: &str) -> Result<(), String> {
168 let path = index_path(project_root);
169 if let Some(dir) = path.parent() {
170 std::fs::create_dir_all(dir).map_err(|e| e.to_string())?;
171 }
172 let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
173 std::fs::write(&path, json).map_err(|e| e.to_string())
174 }
175
176 pub fn load(project_root: &str) -> Option<Self> {
177 let path = index_path(project_root);
178 let content = std::fs::read_to_string(&path)
179 .or_else(|_| {
180 let legacy = legacy_index_path(project_root);
181 if legacy == path {
182 return Err(std::io::Error::new(
183 std::io::ErrorKind::NotFound,
184 "same path",
185 ));
186 }
187 let data = std::fs::read_to_string(&legacy)?;
188 let _ = std::fs::copy(&legacy, &path);
189 Ok(data)
190 })
191 .ok()?;
192 let mut index: SemanticCacheIndex = serde_json::from_str(&content).ok()?;
193 index.repair_after_deserialize();
194 Some(index)
195 }
196
197 pub fn load_or_create(project_root: &str) -> Self {
198 Self::load(project_root).unwrap_or_default()
199 }
200}
201
/// Removes one document's contribution from the document-frequency table.
///
/// Each distinct term in `tf_vec` lowers its count by one (repeated terms count once); a term
/// whose count reaches zero is dropped from `df`. Terms absent from `df` are ignored.
fn remove_doc_terms(df: &mut HashMap<String, usize>, tf_vec: &[(String, f64)]) {
    let distinct: HashSet<&str> = tf_vec.iter().map(|(term, _)| term.as_str()).collect();
    for term in distinct {
        let emptied = match df.get_mut(term) {
            Some(count) => {
                *count = count.saturating_sub(1);
                *count == 0
            }
            None => false,
        };
        if emptied {
            df.remove(term);
        }
    }
}

/// Adds one document's contribution to the document-frequency table.
///
/// Each distinct term in `tf_vec` raises its count by one; repeated terms count once.
fn add_doc_terms(df: &mut HashMap<String, usize>, tf_vec: &[(String, f64)]) {
    let distinct: HashSet<&str> = tf_vec.iter().map(|(term, _)| term.as_str()).collect();
    for term in distinct {
        // Look up by `&str` first so an owned key is only allocated for unseen terms.
        match df.get_mut(term) {
            Some(count) => *count += 1,
            None => {
                df.insert(term.to_string(), 1);
            }
        }
    }
}

/// Computes relative term frequencies for `content`.
///
/// The text is split on every character that is neither alphanumeric nor `_`; fragments are
/// lowercased, and those shorter than two bytes (UTF-8 length) are discarded. Each remaining
/// term maps to `occurrences / total kept terms`. Returns an empty map when nothing is kept.
fn compute_tf(content: &str) -> HashMap<String, f64> {
    let mut counts: HashMap<String, usize> = HashMap::new();
    let mut total = 0usize;

    for fragment in content.split(|c: char| !c.is_alphanumeric() && c != '_') {
        // Empty and one-byte fragments can never pass the length filter below, so skip them
        // before paying for a lowercase allocation.
        if fragment.len() < 2 {
            continue;
        }
        let term = fragment.to_lowercase();
        // Lowercasing can shrink a fragment (e.g. the Kelvin sign becomes `k`), so the filter
        // still has to be applied to the lowercased form.
        if term.len() < 2 {
            continue;
        }
        *counts.entry(term).or_default() += 1;
        total += 1;
    }

    // With no kept terms `counts` is empty and so is the result; no division takes place.
    let total = total as f64;
    counts
        .into_iter()
        .map(|(term, count)| (term, count as f64 / total))
        .collect()
}

/// Cosine similarity between two sparse term-weight vectors.
///
/// Returns 0.0 when the product of the two norms is below `1e-10` (for example when either
/// vector is empty). The quotient is clamped to `[-1, 1]` so floating-point rounding cannot
/// report a similarity just outside the valid range.
fn cosine_similarity(a: &HashMap<String, f64>, b: &HashMap<String, f64>) -> f64 {
    let mut dot = 0.0f64;
    let mut norm_a_sq = 0.0f64;
    for (term, weight) in a {
        norm_a_sq += weight * weight;
        if let Some(other) = b.get(term) {
            dot += weight * other;
        }
    }
    let norm_b_sq = b.values().fold(0.0f64, |acc, weight| acc + weight * weight);

    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom < 1e-10 {
        return 0.0;
    }
    (dot / denom).clamp(-1.0, 1.0)
}

265fn index_path(project_root: &str) -> PathBuf {
266 let hash = crate::core::project_hash::hash_project_root(project_root);
267 crate::core::data_dir::lean_ctx_data_dir()
268 .unwrap_or_default()
269 .join("semantic_cache")
270 .join(format!("{hash}.json"))
271}
272
273fn legacy_index_path(project_root: &str) -> PathBuf {
274 use md5::{Digest, Md5};
275 let mut hasher = Md5::new();
276 hasher.update(project_root.as_bytes());
277 let hash = format!("{:x}", hasher.finalize());
278 crate::core::data_dir::lean_ctx_data_dir()
279 .unwrap_or_default()
280 .join("semantic_cache")
281 .join(format!("{hash}.json"))
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Terms that occur in the text show up with a positive frequency.
    #[test]
    fn compute_tf_basic() {
        let tf = compute_tf("fn handle_request request response handle");
        assert!(tf.contains_key("handle"));
        assert!(tf.contains_key("request"));
        assert!(tf["handle"] > 0.0);
    }

    /// A vector compared with itself has similarity 1.
    #[test]
    fn cosine_identical() {
        let mut a = HashMap::new();
        a.insert("hello".to_string(), 1.0);
        a.insert("world".to_string(), 0.5);
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 0.001);
    }

    /// Vectors without shared terms have similarity 0.
    #[test]
    fn cosine_orthogonal() {
        let mut a = HashMap::new();
        a.insert("hello".to_string(), 1.0);
        let mut b = HashMap::new();
        b.insert("world".to_string(), 1.0);
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 0.001);
    }

    /// The file sharing terms with the query is ranked first.
    #[test]
    fn add_and_find_similar() {
        let mut index = SemanticCacheIndex::default();
        index.add_file(
            "auth.rs",
            "fn validate_token check jwt expiry auth login",
            "s1",
        );
        index.add_file(
            "db.rs",
            "fn connect_database pool query insert delete",
            "s1",
        );

        let results = index.find_similar("validate auth token jwt", 0.1);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "auth.rs");
    }

    /// Frequently accessed files win the warm-up ranking.
    #[test]
    fn warmup_suggestions() {
        let mut index = SemanticCacheIndex::default();
        index.add_file("hot.rs", "frequently accessed file", "s1");
        index.entries[0].access_count = 50;
        index.add_file("cold.rs", "rarely used", "s1");

        let warmup = index.suggest_warmup(1);
        assert_eq!(warmup.len(), 1);
        assert_eq!(warmup[0], "hot.rs");
    }

    /// Re-adding a path updates the entry in place and swaps its document-frequency votes.
    #[test]
    fn readding_a_path_replaces_its_terms() {
        let mut index = SemanticCacheIndex::default();
        index.add_file("a.rs", "alpha beta", "s1");
        index.add_file("a.rs", "beta gamma", "s2");

        assert_eq!(index.entries.len(), 1);
        assert_eq!(index.total_docs, 1);
        assert_eq!(index.entries[0].access_count, 2);
        assert_eq!(index.entries[0].last_session, "s2");
        assert!(!index.term_document_freq.contains_key("alpha"));
        assert_eq!(index.term_document_freq.get("beta"), Some(&1));
        assert_eq!(index.term_document_freq.get("gamma"), Some(&1));
    }
}