1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::Instant;
4
/// File name, inside the lean-ctx data directory, where predictor history is persisted.
const STATS_FILE: &str = "mode_stats.json";
/// Minimum age, in seconds, of the buffered timestamp before `ModePredictor::save`
/// writes the buffered predictor to disk.
const PREDICTOR_FLUSH_SECS: u64 = 10;

/// Process-wide cache of the predictor plus the `Instant` that `ModePredictor::save`
/// compares against `PREDICTOR_FLUSH_SECS` to decide whether to write to disk.
///
/// Filled on first use by `ModePredictor::new`; replaced by `ModePredictor::save` and
/// written out by `ModePredictor::save` / `ModePredictor::flush`.
static PREDICTOR_BUFFER: Mutex<Option<(ModePredictor, Instant)>> = Mutex::new(None);
9
/// One recorded result of reading a file with a particular mode.
///
/// Outcomes are grouped per [`FileSignature`] by `ModePredictor::record` and scored with
/// [`ModeOutcome::efficiency`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ModeOutcome {
    /// Name of the mode that was used (e.g. `"full"`, `"map"`).
    pub mode: String,
    /// Number of tokens going into the mode.
    pub tokens_in: usize,
    /// Number of tokens the mode produced.
    pub tokens_out: usize,
    /// Density score of the produced output; NOTE(review): computed by the caller — confirm its scale.
    pub density: f64,
}
18
19impl ModeOutcome {
20 pub fn efficiency(&self) -> f64 {
22 if self.tokens_out == 0 {
23 return 0.0;
24 }
25 self.density / (self.tokens_out as f64 / self.tokens_in.max(1) as f64)
26 }
27}
28
/// Key used to group recorded outcomes: a file extension plus a coarse size bucket.
#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct FileSignature {
    /// File extension without the leading dot; `from_path` leaves it empty when the path has none.
    pub ext: String,
    /// Token-count bucket (0-4) as assigned by `FileSignature::from_path`.
    pub size_bucket: u8,
}
35
36impl FileSignature {
37 pub fn from_path(path: &str, token_count: usize) -> Self {
39 let ext = std::path::Path::new(path)
40 .extension()
41 .and_then(|e| e.to_str())
42 .unwrap_or("")
43 .to_string();
44 let size_bucket = match token_count {
45 0..=500 => 0,
46 501..=2000 => 1,
47 2001..=5000 => 2,
48 5001..=20000 => 3,
49 _ => 4,
50 };
51 Self { ext, size_bucket }
52 }
53}
54
/// Learns which read mode works best for each [`FileSignature`].
///
/// Local history is capped per signature by `record`. The whole predictor is serialized
/// to `mode_stats.json` in the lean-ctx data directory.
#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModePredictor {
    /// Recorded outcomes per signature, oldest first.
    history: HashMap<FileSignature, Vec<ModeOutcome>>,
}
60
61impl ModePredictor {
62 pub fn new() -> Self {
64 let mut guard = PREDICTOR_BUFFER
65 .lock()
66 .unwrap_or_else(std::sync::PoisonError::into_inner);
67 if let Some((ref predictor, _)) = *guard {
68 return predictor.clone();
69 }
70 let loaded = Self::load_from_disk().unwrap_or_default();
71 *guard = Some((loaded.clone(), Instant::now()));
72 loaded
73 }
74
75 pub fn record(&mut self, sig: FileSignature, outcome: ModeOutcome) {
77 let entries = self.history.entry(sig).or_default();
78 entries.push(outcome);
79 if entries.len() > 100 {
80 entries.drain(0..50);
81 }
82 }
83
84 pub fn predict_best_mode(&self, sig: &FileSignature) -> Option<String> {
87 if let Some(local) = self.predict_from_local(sig) {
88 return Some(local);
89 }
90 if let Some(cloud) = self.predict_from_cloud(sig) {
91 return Some(cloud);
92 }
93 Self::predict_from_defaults(sig)
94 }
95
96 fn predict_from_local(&self, sig: &FileSignature) -> Option<String> {
97 let entries = self.history.get(sig)?;
98 if entries.len() < 3 {
99 return None;
100 }
101
102 let mut mode_scores: HashMap<&str, (f64, usize)> = HashMap::new();
103 for entry in entries {
104 let (sum, count) = mode_scores.entry(&entry.mode).or_insert((0.0, 0));
105 *sum += entry.efficiency();
106 *count += 1;
107 }
108
109 mode_scores
110 .into_iter()
111 .max_by(|a, b| {
112 let avg_a = a.1 .0 / a.1 .1 as f64;
113 let avg_b = b.1 .0 / b.1 .1 as f64;
114 avg_a
115 .partial_cmp(&avg_b)
116 .unwrap_or(std::cmp::Ordering::Equal)
117 })
118 .map(|(mode, _)| mode.to_string())
119 }
120
121 #[allow(clippy::unused_self)]
124 fn predict_from_cloud(&self, sig: &FileSignature) -> Option<String> {
125 let data = crate::cloud_client::load_cloud_models()?;
126 let models = data["models"].as_array()?;
127
128 let ext_with_dot = format!(".{}", sig.ext);
129 let bucket_name = match sig.size_bucket {
130 0 => "0-500",
131 1 => "500-2k",
132 2 => "2k-10k",
133 _ => "10k+",
134 };
135
136 let mut best: Option<(&str, f64)> = None;
137
138 for model in models {
139 let m_ext = model["file_ext"].as_str().unwrap_or("");
140 let m_bucket = model["size_bucket"].as_str().unwrap_or("");
141 let confidence = model["confidence"].as_f64().unwrap_or(0.0);
142
143 if m_ext == ext_with_dot && m_bucket == bucket_name && confidence > 0.5 {
144 if let Some(mode) = model["recommended_mode"].as_str() {
145 if best.is_none_or(|(_, c)| confidence > c) {
146 best = Some((mode, confidence));
147 }
148 }
149 }
150 }
151
152 if let Some((mode, _)) = best {
153 return Some(mode.to_string());
154 }
155
156 for model in models {
157 let m_ext = model["file_ext"].as_str().unwrap_or("");
158 let confidence = model["confidence"].as_f64().unwrap_or(0.0);
159 if m_ext == ext_with_dot && confidence > 0.5 {
160 return model["recommended_mode"]
161 .as_str()
162 .map(std::string::ToString::to_string);
163 }
164 }
165
166 None
167 }
168
169 fn predict_from_defaults(sig: &FileSignature) -> Option<String> {
173 let mode = match (sig.ext.as_str(), sig.size_bucket) {
174 (_, 0) => return None,
176
177 ("lock", _)
179 | (
180 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
181 | "swift" | "kt" | "cs" | "vue" | "svelte",
182 4..,
183 ) => "signatures",
184
185 (
187 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
188 | "swift" | "kt" | "cs" | "vue" | "svelte",
189 2 | 3,
190 )
191 | ("sql", 2..) => "map",
192
193 ("json" | "yaml" | "yml" | "toml" | "xml" | "csv", _)
195 | ("css" | "scss" | "less" | "sass", 2..)
196 | (_, 3..) => "aggressive",
197
198 _ => return None,
199 };
200 Some(mode.to_string())
201 }
202
203 pub fn save(&self) {
205 let mut guard = PREDICTOR_BUFFER
206 .lock()
207 .unwrap_or_else(std::sync::PoisonError::into_inner);
208 let should_flush = match *guard {
209 Some((_, ref last_flush)) => last_flush.elapsed().as_secs() >= PREDICTOR_FLUSH_SECS,
210 None => true,
211 };
212 *guard = Some((self.clone(), Instant::now()));
213 if should_flush {
214 self.save_to_disk();
215 }
216 }
217
218 fn save_to_disk(&self) {
219 let Ok(dir) = crate::core::data_dir::lean_ctx_data_dir() else {
220 return;
221 };
222 let _ = std::fs::create_dir_all(&dir);
223 let path = dir.join(STATS_FILE);
224 if let Ok(json) = serde_json::to_string_pretty(self) {
225 let tmp = dir.join(".mode_stats.tmp");
226 if std::fs::write(&tmp, &json).is_ok() {
227 let _ = std::fs::rename(&tmp, &path);
228 }
229 }
230 }
231
232 pub fn flush() {
234 let guard = PREDICTOR_BUFFER
235 .lock()
236 .unwrap_or_else(std::sync::PoisonError::into_inner);
237 if let Some((ref predictor, _)) = *guard {
238 predictor.save_to_disk();
239 }
240 }
241
242 fn load_from_disk() -> Option<Self> {
243 let path = crate::core::data_dir::lean_ctx_data_dir()
244 .ok()?
245 .join(STATS_FILE);
246 let data = std::fs::read_to_string(path).ok()?;
247 serde_json::from_str(&data).ok()
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ModeOutcome`.
    fn outcome(mode: &str, tokens_in: usize, tokens_out: usize, density: f64) -> ModeOutcome {
        ModeOutcome {
            mode: mode.to_string(),
            tokens_in,
            tokens_out,
            density,
        }
    }

    /// Runs only the built-in default rules for a path and token count.
    fn default_mode(path: &str, tokens: usize) -> Option<String> {
        ModePredictor::predict_from_defaults(&FileSignature::from_path(path, tokens))
    }

    #[test]
    fn file_signature_buckets() {
        assert_eq!(FileSignature::from_path("main.rs", 100).size_bucket, 0);
        assert_eq!(FileSignature::from_path("main.rs", 1000).size_bucket, 1);
        assert_eq!(FileSignature::from_path("main.rs", 3000).size_bucket, 2);
        assert_eq!(FileSignature::from_path("main.rs", 10000).size_bucket, 3);
        assert_eq!(FileSignature::from_path("main.rs", 50000).size_bucket, 4);
    }

    #[test]
    fn file_signature_bucket_limits_are_inclusive() {
        let bucket = |tokens| FileSignature::from_path("main.rs", tokens).size_bucket;
        assert_eq!(bucket(0), 0);
        assert_eq!(bucket(500), 0);
        assert_eq!(bucket(501), 1);
        assert_eq!(bucket(2000), 1);
        assert_eq!(bucket(2001), 2);
        assert_eq!(bucket(5000), 2);
        assert_eq!(bucket(5001), 3);
        assert_eq!(bucket(20000), 3);
        assert_eq!(bucket(20001), 4);
    }

    #[test]
    fn file_signature_extension() {
        assert_eq!(FileSignature::from_path("src/main.rs", 1).ext, "rs");
        assert_eq!(FileSignature::from_path("archive.tar.gz", 1).ext, "gz");
        assert_eq!(FileSignature::from_path("Makefile", 1).ext, "");
        assert_eq!(FileSignature::from_path(".gitignore", 1).ext, "");
    }

    #[test]
    fn predict_returns_none_without_history() {
        let predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.zzz", 500);
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_returns_none_with_too_few_entries() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.zzz", 500);
        predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_uses_local_history_at_threshold() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.zzz", 500);
        for _ in 0..3 {
            predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        }
        assert_eq!(predictor.predict_from_local(&sig), Some("full".to_string()));
    }

    #[test]
    fn predict_learns_best_mode() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("big.rs", 5000);
        for _ in 0..5 {
            predictor.record(sig.clone(), outcome("full", 5000, 5000, 0.3));
            predictor.record(sig.clone(), outcome("map", 5000, 800, 0.6));
        }
        let best = predictor.predict_best_mode(&sig);
        assert_eq!(best, Some("map".to_string()));
    }

    #[test]
    fn history_caps_at_100() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.rs", 100);
        for _ in 0..120 {
            predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        }
        // 101st record prunes 50 (leaving 51), then 19 more are appended.
        assert_eq!(predictor.history[&sig].len(), 70);
    }

    #[test]
    fn history_drops_oldest_entries_when_full() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.rs", 100);
        for i in 0..=100 {
            predictor.record(sig.clone(), outcome("full", i, 100, 0.5));
        }
        let entries = &predictor.history[&sig];
        assert_eq!(entries.len(), 51);
        assert_eq!(entries.first().map(|e| e.tokens_in), Some(50));
        assert_eq!(entries.last().map(|e| e.tokens_in), Some(100));
    }

    #[test]
    fn defaults_return_none_for_small_files() {
        let sig = FileSignature::from_path("small.rs", 200);
        assert!(ModePredictor::predict_from_defaults(&sig).is_none());
    }

    #[test]
    fn defaults_recommend_map_for_medium_code() {
        assert_eq!(default_mode("medium.rs", 3000), Some("map".to_string()));
        assert_eq!(default_mode("medium.rs", 10000), Some("map".to_string()));
        assert_eq!(default_mode("medium.rs", 1000), None);
    }

    #[test]
    fn defaults_recommend_aggressive_for_json() {
        assert_eq!(default_mode("config.json", 1000), Some("aggressive".to_string()));
        assert_eq!(default_mode("config.json", 100), None);
    }

    #[test]
    fn defaults_recommend_signatures_for_huge_code() {
        assert_eq!(default_mode("huge.ts", 25000), Some("signatures".to_string()));
    }

    #[test]
    fn defaults_recommend_aggressive_for_large_unknown() {
        assert_eq!(default_mode("data.xyz", 8000), Some("aggressive".to_string()));
        assert_eq!(default_mode("README", 10000), Some("aggressive".to_string()));
    }

    #[test]
    fn defaults_cover_lock_sql_and_styles() {
        assert_eq!(default_mode("Cargo.lock", 1000), Some("signatures".to_string()));
        assert_eq!(default_mode("schema.sql", 3000), Some("map".to_string()));
        assert_eq!(default_mode("schema.sql", 50000), Some("map".to_string()));
        assert_eq!(default_mode("schema.sql", 1000), None);
        assert_eq!(default_mode("style.css", 3000), Some("aggressive".to_string()));
        assert_eq!(default_mode("style.css", 1000), None);
    }

    #[test]
    fn mode_outcome_efficiency() {
        let o = outcome("map", 1000, 200, 0.6);
        // density / (out / in) = 0.6 / 0.2
        assert!((o.efficiency() - 3.0).abs() < 1e-9);
    }

    #[test]
    fn efficiency_is_zero_without_output() {
        assert!(outcome("map", 1000, 0, 0.6).efficiency().abs() < f64::EPSILON);
    }

    #[test]
    fn efficiency_treats_empty_input_as_one_token() {
        assert!((outcome("map", 0, 10, 0.5).efficiency() - 0.05).abs() < 1e-12);
    }
}
385}