1use std::collections::HashMap;
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5
/// One cached file's semantic fingerprint plus usage statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticCacheEntry {
    // File path as supplied by the caller (used as the entry's unique key).
    pub path: String,
    // Sparse term vector. NOTE(review): despite the name this stores raw
    // term frequencies; IDF weighting is applied at query time by
    // `SemanticCacheIndex::tfidf_vector_from_stored`.
    pub tfidf_vector: Vec<(String, f64)>,
    // Whitespace-separated token count of the file content at last add.
    pub token_count: usize,
    // Incremented each time the same path is re-added to the index.
    pub access_count: u32,
    // Session identifier passed on the most recent add for this path.
    pub last_session: String,
}
14
/// In-memory TF-IDF index over cached files, persisted as JSON.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SemanticCacheIndex {
    // One entry per distinct file path; linear-scanned on lookup.
    pub entries: Vec<SemanticCacheEntry>,
    // Smoothed inverse-document-frequency weight per term, rebuilt on add.
    pub idf: HashMap<String, f64>,
    // Number of documents (kept equal to `entries.len()` by `add_file`).
    pub total_docs: usize,
}
21
22impl SemanticCacheIndex {
23 pub fn add_file(&mut self, path: &str, content: &str, session_id: &str) {
24 let tf = compute_tf(content);
25 let token_count = content.split_whitespace().count();
26
27 if let Some(existing) = self.entries.iter_mut().find(|e| e.path == path) {
28 existing.tfidf_vector = tf.iter().map(|(k, v)| (k.clone(), *v)).collect();
29 existing.token_count = token_count;
30 existing.access_count += 1;
31 existing.last_session = session_id.to_string();
32 } else {
33 self.entries.push(SemanticCacheEntry {
34 path: path.to_string(),
35 tfidf_vector: tf.iter().map(|(k, v)| (k.clone(), *v)).collect(),
36 token_count,
37 access_count: 1,
38 last_session: session_id.to_string(),
39 });
40 }
41
42 self.total_docs = self.entries.len();
43 self.rebuild_idf();
44 }
45
46 fn rebuild_idf(&mut self) {
47 let mut df: HashMap<String, usize> = HashMap::new();
48 for entry in &self.entries {
49 let unique_terms: std::collections::HashSet<&str> =
50 entry.tfidf_vector.iter().map(|(k, _)| k.as_str()).collect();
51 for term in unique_terms {
52 *df.entry(term.to_string()).or_default() += 1;
53 }
54 }
55
56 self.idf.clear();
57 let n = self.total_docs as f64;
58 for (term, count) in &df {
59 let idf = (n / (*count as f64 + 1.0)).ln() + 1.0;
60 self.idf.insert(term.clone(), idf);
61 }
62 }
63
64 pub fn find_similar(&self, content: &str, threshold: f64) -> Vec<(String, f64)> {
65 let query_tf = compute_tf(content);
66 let query_vec = self.tfidf_vector(&query_tf);
67
68 let mut results: Vec<(String, f64)> = self
69 .entries
70 .iter()
71 .filter_map(|entry| {
72 let entry_vec = self.tfidf_vector_from_stored(&entry.tfidf_vector);
73 let sim = cosine_similarity(&query_vec, &entry_vec);
74 if sim >= threshold {
75 Some((entry.path.clone(), sim))
76 } else {
77 None
78 }
79 })
80 .collect();
81
82 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
83 results
84 }
85
86 pub fn suggest_warmup(&self, top_n: usize) -> Vec<String> {
87 let mut ranked: Vec<(&SemanticCacheEntry, f64)> = self
88 .entries
89 .iter()
90 .map(|e| {
91 let score = e.access_count as f64 * 0.6 + e.token_count as f64 * 0.0001;
92 (e, score)
93 })
94 .collect();
95
96 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
97
98 ranked
99 .into_iter()
100 .take(top_n)
101 .map(|(e, _)| e.path.clone())
102 .collect()
103 }
104
105 fn tfidf_vector(&self, tf: &HashMap<String, f64>) -> HashMap<String, f64> {
106 tf.iter()
107 .map(|(term, freq)| {
108 let idf = self.idf.get(term).copied().unwrap_or(1.0);
109 (term.clone(), freq * idf)
110 })
111 .collect()
112 }
113
114 fn tfidf_vector_from_stored(&self, stored: &[(String, f64)]) -> HashMap<String, f64> {
115 stored
116 .iter()
117 .map(|(term, freq)| {
118 let idf = self.idf.get(term).copied().unwrap_or(1.0);
119 (term.clone(), freq * idf)
120 })
121 .collect()
122 }
123
124 pub fn save(&self, project_root: &str) -> Result<(), String> {
125 let path = index_path(project_root);
126 if let Some(dir) = path.parent() {
127 std::fs::create_dir_all(dir).map_err(|e| e.to_string())?;
128 }
129 let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
130 std::fs::write(&path, json).map_err(|e| e.to_string())
131 }
132
133 pub fn load(project_root: &str) -> Option<Self> {
134 let path = index_path(project_root);
135 let content = std::fs::read_to_string(path).ok()?;
136 serde_json::from_str(&content).ok()
137 }
138
139 pub fn load_or_create(project_root: &str) -> Self {
140 Self::load(project_root).unwrap_or_default()
141 }
142}
143
/// Tokenize `content` and return each term's relative frequency.
///
/// Tokens are maximal runs of alphanumerics/underscores, lowercased;
/// tokens shorter than two bytes (after lowercasing) are discarded.
/// Returns an empty map when nothing survives the filter.
fn compute_tf(content: &str) -> HashMap<String, f64> {
    let mut counts: HashMap<String, usize> = HashMap::new();

    for token in content.split(|c: char| !c.is_alphanumeric() && c != '_') {
        let lowered = token.to_lowercase();
        if lowered.len() >= 2 {
            *counts.entry(lowered).or_default() += 1;
        }
    }

    // Total kept tokens is just the sum of the counts.
    let total: usize = counts.values().sum();
    if total == 0 {
        return HashMap::new();
    }

    let denom = total as f64;
    counts
        .into_iter()
        .map(|(term, count)| (term, count as f64 / denom))
        .collect()
}
165
/// Cosine similarity between two sparse term-weight vectors.
///
/// The dot product only needs terms present in both maps, so we probe `b`
/// for each key of `a`. Returns 0.0 when either vector has (near-)zero
/// magnitude, avoiding a division by zero.
fn cosine_similarity(a: &HashMap<String, f64>, b: &HashMap<String, f64>) -> f64 {
    let dot: f64 = a
        .iter()
        .filter_map(|(term, &wa)| b.get(term).map(|&wb| wa * wb))
        .sum();
    let mag_a = a.values().map(|w| w * w).sum::<f64>().sqrt();
    let mag_b = b.values().map(|w| w * w).sum::<f64>().sqrt();

    let denom = mag_a * mag_b;
    if denom < 1e-10 {
        0.0
    } else {
        dot / denom
    }
}
187
188fn index_path(project_root: &str) -> PathBuf {
189 use md5::{Digest, Md5};
190 let mut hasher = Md5::new();
191 hasher.update(project_root.as_bytes());
192 let hash = format!("{:x}", hasher.finalize());
193 dirs::home_dir()
194 .unwrap_or_default()
195 .join(".lean-ctx")
196 .join("semantic_cache")
197 .join(format!("{hash}.json"))
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    // Tokenizer produces entries for each distinct term with positive TF.
    #[test]
    fn compute_tf_basic() {
        let tf = compute_tf("fn handle_request request response handle");
        assert!(tf.contains_key("handle"));
        assert!(tf.contains_key("request"));
        assert!(tf["handle"] > 0.0);
    }

    // A vector compared against itself must have similarity ~1.
    #[test]
    fn cosine_identical() {
        let mut a = HashMap::new();
        a.insert("hello".to_string(), 1.0);
        a.insert("world".to_string(), 0.5);
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 0.001);
    }

    // Vectors with no shared terms must have similarity ~0.
    #[test]
    fn cosine_orthogonal() {
        let mut a = HashMap::new();
        a.insert("hello".to_string(), 1.0);
        let mut b = HashMap::new();
        b.insert("world".to_string(), 1.0);
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 0.001);
    }

    // End-to-end: the auth-themed query should rank auth.rs first.
    #[test]
    fn add_and_find_similar() {
        let mut index = SemanticCacheIndex::default();
        index.add_file(
            "auth.rs",
            "fn validate_token check jwt expiry auth login",
            "s1",
        );
        index.add_file(
            "db.rs",
            "fn connect_database pool query insert delete",
            "s1",
        );

        let results = index.find_similar("validate auth token jwt", 0.1);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "auth.rs");
    }

    // Warmup ranking is dominated by access_count, so the entry forced to
    // 50 accesses must win the single suggestion slot.
    #[test]
    fn warmup_suggestions() {
        let mut index = SemanticCacheIndex::default();
        index.add_file("hot.rs", "frequently accessed file", "s1");
        index.entries[0].access_count = 50;
        index.add_file("cold.rs", "rarely used", "s1");

        let warmup = index.suggest_warmup(1);
        assert_eq!(warmup.len(), 1);
        assert_eq!(warmup[0], "hot.rs");
    }
}