1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::Instant;
4
/// File name, inside the lean-ctx data directory, that predictor stats are
/// written to and read back from.
const STATS_FILE: &str = "mode_stats.json";
/// Throttle interval, in seconds, for the disk writes made by
/// `ModePredictor::save`.
const PREDICTOR_FLUSH_SECS: u64 = 10;

/// Process-wide cached predictor, paired with the `Instant` that
/// `ModePredictor::save` compares against `PREDICTOR_FLUSH_SECS` to decide
/// whether to write to disk. `None` until `ModePredictor::new` or
/// `ModePredictor::save` first stores a value.
static PREDICTOR_BUFFER: Mutex<Option<(ModePredictor, Instant)>> = Mutex::new(None);
9
/// One recorded result of applying a mode to a file; the raw material
/// `ModePredictor` averages per `FileSignature`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ModeOutcome {
    /// Name of the mode that produced this outcome; predictions group and
    /// return outcomes by this string.
    pub mode: String,
    /// Input token count; `efficiency` uses it as the denominator of the
    /// `tokens_out / tokens_in` ratio (clamped to at least 1).
    pub tokens_in: usize,
    /// Output token count; an outcome with 0 here has an efficiency of 0.
    pub tokens_out: usize,
    /// Score divided by the output ratio in `efficiency`, so a higher value
    /// makes this mode more likely to be predicted. NOTE(review): its scale and
    /// origin are set by the caller — confirm before relying on a range.
    pub density: f64,
}
17
18impl ModeOutcome {
19 pub fn efficiency(&self) -> f64 {
20 if self.tokens_out == 0 {
21 return 0.0;
22 }
23 self.density / (self.tokens_out as f64 / self.tokens_in.max(1) as f64)
24 }
25}
26
/// Lookup key grouping files by extension and token-count bucket; normally
/// built with `FileSignature::from_path`.
#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct FileSignature {
    /// Extension without the leading dot, as produced by `from_path` (empty
    /// when the path has none).
    pub ext: String,
    /// Token-count bucket; `from_path` produces 0 (smallest) through 4 (largest).
    pub size_bucket: u8,
}
32
33impl FileSignature {
34 pub fn from_path(path: &str, token_count: usize) -> Self {
35 let ext = std::path::Path::new(path)
36 .extension()
37 .and_then(|e| e.to_str())
38 .unwrap_or("")
39 .to_string();
40 let size_bucket = match token_count {
41 0..=500 => 0,
42 501..=2000 => 1,
43 2001..=5000 => 2,
44 5001..=20000 => 3,
45 _ => 4,
46 };
47 Self { ext, size_bucket }
48 }
49}
50
51#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
52pub struct ModePredictor {
53 history: HashMap<FileSignature, Vec<ModeOutcome>>,
54}
55
56impl ModePredictor {
57 pub fn new() -> Self {
58 let mut guard = PREDICTOR_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
59 if let Some((ref predictor, _)) = *guard {
60 return predictor.clone();
61 }
62 let loaded = Self::load_from_disk().unwrap_or_default();
63 *guard = Some((loaded.clone(), Instant::now()));
64 loaded
65 }
66
67 pub fn record(&mut self, sig: FileSignature, outcome: ModeOutcome) {
68 let entries = self.history.entry(sig).or_default();
69 entries.push(outcome);
70 if entries.len() > 100 {
71 entries.drain(0..50);
72 }
73 }
74
75 pub fn predict_best_mode(&self, sig: &FileSignature) -> Option<String> {
78 if let Some(local) = self.predict_from_local(sig) {
79 return Some(local);
80 }
81 if let Some(cloud) = self.predict_from_cloud(sig) {
82 return Some(cloud);
83 }
84 Self::predict_from_defaults(sig)
85 }
86
87 fn predict_from_local(&self, sig: &FileSignature) -> Option<String> {
88 let entries = self.history.get(sig)?;
89 if entries.len() < 3 {
90 return None;
91 }
92
93 let mut mode_scores: HashMap<&str, (f64, usize)> = HashMap::new();
94 for entry in entries {
95 let (sum, count) = mode_scores.entry(&entry.mode).or_insert((0.0, 0));
96 *sum += entry.efficiency();
97 *count += 1;
98 }
99
100 mode_scores
101 .into_iter()
102 .max_by(|a, b| {
103 let avg_a = a.1 .0 / a.1 .1 as f64;
104 let avg_b = b.1 .0 / b.1 .1 as f64;
105 avg_a
106 .partial_cmp(&avg_b)
107 .unwrap_or(std::cmp::Ordering::Equal)
108 })
109 .map(|(mode, _)| mode.to_string())
110 }
111
112 fn predict_from_cloud(&self, sig: &FileSignature) -> Option<String> {
115 let data = crate::cloud_client::load_cloud_models()?;
116 let models = data["models"].as_array()?;
117
118 let ext_with_dot = format!(".{}", sig.ext);
119 let bucket_name = match sig.size_bucket {
120 0 => "0-500",
121 1 => "500-2k",
122 2 => "2k-10k",
123 3 => "10k+",
124 _ => "10k+",
125 };
126
127 let mut best: Option<(&str, f64)> = None;
128
129 for model in models {
130 let m_ext = model["file_ext"].as_str().unwrap_or("");
131 let m_bucket = model["size_bucket"].as_str().unwrap_or("");
132 let confidence = model["confidence"].as_f64().unwrap_or(0.0);
133
134 if m_ext == ext_with_dot && m_bucket == bucket_name && confidence > 0.5 {
135 if let Some(mode) = model["recommended_mode"].as_str() {
136 if best.is_none() || confidence > best.unwrap().1 {
137 best = Some((mode, confidence));
138 }
139 }
140 }
141 }
142
143 if let Some((mode, _)) = best {
144 return Some(mode.to_string());
145 }
146
147 for model in models {
148 let m_ext = model["file_ext"].as_str().unwrap_or("");
149 let confidence = model["confidence"].as_f64().unwrap_or(0.0);
150 if m_ext == ext_with_dot && confidence > 0.5 {
151 return model["recommended_mode"].as_str().map(|s| s.to_string());
152 }
153 }
154
155 None
156 }
157
158 fn predict_from_defaults(sig: &FileSignature) -> Option<String> {
162 let mode = match (sig.ext.as_str(), sig.size_bucket) {
163 (_, 0) => return None,
165
166 ("json" | "yaml" | "yml" | "toml" | "xml" | "csv", _) => "aggressive",
168
169 ("lock", _) => "signatures",
171
172 (
175 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
176 | "swift" | "kt" | "cs" | "vue" | "svelte",
177 1,
178 ) => return None,
179
180 (
182 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
183 | "swift" | "kt" | "cs" | "vue" | "svelte",
184 2,
185 ) => "map",
186
187 (
189 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
190 | "swift" | "kt" | "cs" | "vue" | "svelte",
191 3,
192 ) => "map",
193
194 (
196 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
197 | "swift" | "kt" | "cs" | "vue" | "svelte",
198 4..,
199 ) => "signatures",
200
201 ("md" | "mdx" | "rst" | "txt" | "html" | "astro", 1..=2) => return None,
203 ("md" | "mdx" | "rst" | "txt" | "html" | "astro", 3..) => "aggressive",
204
205 ("css" | "scss" | "less" | "sass", 2..) => "aggressive",
207
208 ("sql", 2..) => "map",
210
211 (_, 3..) => "aggressive",
213
214 _ => return None,
215 };
216 Some(mode.to_string())
217 }
218
219 pub fn save(&self) {
220 let mut guard = PREDICTOR_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
221 let should_flush = match *guard {
222 Some((_, ref last_flush)) => last_flush.elapsed().as_secs() >= PREDICTOR_FLUSH_SECS,
223 None => true,
224 };
225 *guard = Some((self.clone(), Instant::now()));
226 if should_flush {
227 self.save_to_disk();
228 }
229 }
230
231 fn save_to_disk(&self) {
232 let dir = match crate::core::data_dir::lean_ctx_data_dir() {
233 Ok(d) => d,
234 Err(_) => return,
235 };
236 let _ = std::fs::create_dir_all(&dir);
237 let path = dir.join(STATS_FILE);
238 if let Ok(json) = serde_json::to_string_pretty(self) {
239 let tmp = dir.join(".mode_stats.tmp");
240 if std::fs::write(&tmp, &json).is_ok() {
241 let _ = std::fs::rename(&tmp, &path);
242 }
243 }
244 }
245
246 pub fn flush() {
247 let guard = PREDICTOR_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
248 if let Some((ref predictor, _)) = *guard {
249 predictor.save_to_disk();
250 }
251 }
252
253 fn load_from_disk() -> Option<Self> {
254 let path = crate::core::data_dir::lean_ctx_data_dir()
255 .ok()?
256 .join(STATS_FILE);
257 let data = std::fs::read_to_string(path).ok()?;
258 serde_json::from_str(&data).ok()
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    fn outcome(mode: &str, tokens_in: usize, tokens_out: usize, density: f64) -> ModeOutcome {
        ModeOutcome {
            mode: mode.to_string(),
            tokens_in,
            tokens_out,
            density,
        }
    }

    fn default_for(path: &str, tokens: usize) -> Option<String> {
        ModePredictor::predict_from_defaults(&FileSignature::from_path(path, tokens))
    }

    #[test]
    fn file_signature_buckets() {
        assert_eq!(FileSignature::from_path("main.rs", 100).size_bucket, 0);
        assert_eq!(FileSignature::from_path("main.rs", 1000).size_bucket, 1);
        assert_eq!(FileSignature::from_path("main.rs", 3000).size_bucket, 2);
        assert_eq!(FileSignature::from_path("main.rs", 10000).size_bucket, 3);
        assert_eq!(FileSignature::from_path("main.rs", 50000).size_bucket, 4);
    }

    #[test]
    fn predict_returns_none_without_history() {
        let predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.zzz", 500);
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_returns_none_with_too_few_entries() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.zzz", 500);
        predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_learns_best_mode() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("big.rs", 5000);
        for _ in 0..5 {
            predictor.record(sig.clone(), outcome("full", 5000, 5000, 0.3));
            predictor.record(sig.clone(), outcome("map", 5000, 800, 0.6));
        }
        let best = predictor.predict_best_mode(&sig);
        assert_eq!(best, Some("map".to_string()));
    }

    #[test]
    fn history_caps_at_100() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.rs", 100);
        for _ in 0..120 {
            predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        }
        // The 101st record trims the list to 51; 19 more records bring it to 70.
        assert_eq!(predictor.history.get(&sig).unwrap().len(), 70);
    }

    #[test]
    fn history_drops_oldest_entries() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.rs", 100);
        for i in 0..120 {
            predictor.record(sig.clone(), outcome("full", i, 100, 0.5));
        }
        let kept = predictor.history.get(&sig).unwrap();
        assert_eq!(kept.first().unwrap().tokens_in, 50);
        assert_eq!(kept.last().unwrap().tokens_in, 119);
    }

    #[test]
    fn defaults_return_none_for_small_files() {
        let sig = FileSignature::from_path("small.rs", 200);
        assert!(ModePredictor::predict_from_defaults(&sig).is_none());
    }

    #[test]
    fn defaults_recommend_map_for_medium_code() {
        let sig = FileSignature::from_path("medium.rs", 3000);
        assert_eq!(
            ModePredictor::predict_from_defaults(&sig),
            Some("map".to_string())
        );
    }

    #[test]
    fn defaults_recommend_aggressive_for_json() {
        let sig = FileSignature::from_path("config.json", 1000);
        assert_eq!(
            ModePredictor::predict_from_defaults(&sig),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn defaults_recommend_signatures_for_huge_code() {
        let sig = FileSignature::from_path("huge.ts", 25000);
        assert_eq!(
            ModePredictor::predict_from_defaults(&sig),
            Some("signatures".to_string())
        );
    }

    #[test]
    fn defaults_recommend_aggressive_for_large_unknown() {
        let sig = FileSignature::from_path("data.xyz", 8000);
        assert_eq!(
            ModePredictor::predict_from_defaults(&sig),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn defaults_cover_remaining_extension_families() {
        let some = |s: &str| Some(s.to_string());

        assert_eq!(default_for("Cargo.lock", 1000), some("signatures"));

        assert_eq!(default_for("main.rs", 1000), None);
        assert_eq!(default_for("main.rs", 10000), some("map"));

        assert_eq!(default_for("README.md", 3000), None);
        assert_eq!(default_for("README.md", 10000), some("aggressive"));

        assert_eq!(default_for("style.css", 1000), None);
        assert_eq!(default_for("style.css", 3000), some("aggressive"));

        assert_eq!(default_for("schema.sql", 1000), None);
        assert_eq!(default_for("schema.sql", 3000), some("map"));

        assert_eq!(default_for("notes.xyz", 3000), None);
        assert_eq!(default_for("Makefile", 10000), some("aggressive"));
    }

    #[test]
    fn mode_outcome_efficiency() {
        let o = outcome("map", 1000, 200, 0.6);
        assert!(o.efficiency() > 0.0);
        assert_eq!(outcome("map", 1000, 0, 0.6).efficiency(), 0.0);
    }
}