1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use std::time::Instant;
4
/// File name of the persisted predictor state; joined onto the lean-ctx data directory when
/// loading and saving.
const STATS_FILE: &str = "mode_stats.json";
/// Threshold, in seconds, that `ModePredictor::save` compares the buffered timestamp against
/// before it writes to disk.
const PREDICTOR_FLUSH_SECS: u64 = 10;

/// Process-wide buffer: the most recent predictor snapshot together with the `Instant` that
/// `ModePredictor::save` checks against `PREDICTOR_FLUSH_SECS`. `None` until the first
/// `ModePredictor::new` or `ModePredictor::save` call populates it.
static PREDICTOR_BUFFER: Mutex<Option<(Arc<ModePredictor>, Instant)>> = Mutex::new(None);
9
/// One recorded observation of a mode being applied to a file.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ModeOutcome {
    /// Name of the mode that produced this outcome (e.g. `"full"`, `"map"`).
    pub mode: String,
    /// Input-side token count — presumably the file's size before the mode was applied;
    /// confirm against the code that records outcomes.
    pub tokens_in: usize,
    /// Output-side token count — presumably what the mode emitted; `efficiency` treats `0`
    /// as "no score".
    pub tokens_out: usize,
    /// Score used as the numerator of `efficiency`. NOTE(review): its exact definition and
    /// range are not visible in this file.
    pub density: f64,
}
18
19impl ModeOutcome {
20 pub fn efficiency(&self) -> f64 {
22 if self.tokens_out == 0 {
23 return 0.0;
24 }
25 self.density / (self.tokens_out as f64 / self.tokens_in.max(1) as f64)
26 }
27}
28
/// Coarse key under which outcomes are grouped: file extension plus a size class.
#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct FileSignature {
    /// File extension without the leading dot; empty when `from_path` finds none.
    pub ext: String,
    /// Size class as assigned by `from_path`: `0` (≤500 tokens), `1` (≤2000), `2` (≤5000),
    /// `3` (≤20000), `4` (larger).
    pub size_bucket: u8,
}
35
36impl FileSignature {
37 pub fn from_path(path: &str, token_count: usize) -> Self {
39 let ext = std::path::Path::new(path)
40 .extension()
41 .and_then(|e| e.to_str())
42 .unwrap_or("")
43 .to_string();
44 let size_bucket = match token_count {
45 0..=500 => 0,
46 501..=2000 => 1,
47 2001..=5000 => 2,
48 5001..=20000 => 3,
49 _ => 4,
50 };
51 Self { ext, size_bucket }
52 }
53}
54
/// Learns which mode works best per [`FileSignature`] from recorded outcomes.
#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModePredictor {
    /// Recorded outcomes per signature; `record` trims each list once it grows past 100.
    history: HashMap<FileSignature, Vec<ModeOutcome>>,
    /// Root directory handed to the bandit store when predicting; `new` fills it from the
    /// current working directory when the loaded state has none.
    project_root: Option<String>,
}
61
62impl ModePredictor {
63 pub fn new() -> Self {
65 let mut guard = PREDICTOR_BUFFER
66 .lock()
67 .unwrap_or_else(std::sync::PoisonError::into_inner);
68 if let Some((ref predictor, _)) = *guard {
69 return Self {
70 history: predictor.history.clone(),
71 project_root: predictor.project_root.clone(),
72 };
73 }
74 let mut loaded = Self::load_from_disk().unwrap_or_default();
75 if loaded.project_root.is_none() {
76 loaded.project_root = std::env::current_dir()
77 .ok()
78 .map(|p| p.to_string_lossy().to_string());
79 }
80 *guard = Some((Arc::new(loaded.clone()), Instant::now()));
81 loaded
82 }
83
84 pub fn with_project_root(mut self, root: &str) -> Self {
85 self.project_root = Some(root.to_string());
86 self
87 }
88
89 pub fn set_project_root(&mut self, root: &str) {
90 self.project_root = Some(root.to_string());
91 }
92
93 pub fn record(&mut self, sig: FileSignature, outcome: ModeOutcome) {
95 let entries = self.history.entry(sig).or_default();
96 entries.push(outcome);
97 if entries.len() > 100 {
98 entries.drain(0..50);
99 }
100 }
101
102 pub fn predict_best_mode(&self, sig: &FileSignature) -> Option<String> {
105 let default_mode = Self::predict_from_defaults(sig);
106
107 let allow_override = |candidate: &str| -> bool {
108 let Some(def) = default_mode.as_deref() else {
109 return true;
110 };
111 if candidate == "full" {
112 return false;
113 }
114 if (def == "map" || def == "signatures")
116 && (candidate == "aggressive" || candidate == "entropy")
117 {
118 return false;
119 }
120 true
121 };
122
123 if let Some(local) = self.predict_from_local(sig) {
124 if allow_override(&local) {
125 return Some(local);
126 }
127 }
128 if let Some(bandit) = self.predict_from_bandit(sig) {
129 if allow_override(&bandit) {
130 return Some(bandit);
131 }
132 }
133 if let Some(cloud) = self.predict_from_cloud(sig) {
134 if allow_override(&cloud) {
135 return Some(cloud);
136 }
137 }
138 default_mode
139 }
140
141 fn predict_from_bandit(&self, sig: &FileSignature) -> Option<String> {
142 let key = format!("{}_feedback", sig.ext);
143 let store =
144 crate::core::bandit::BanditStore::load(self.project_root.as_deref().unwrap_or("."));
145 let bandit = store.bandits.get(&key)?;
146 if bandit.total_pulls < 5 {
147 return None;
148 }
149 let best_arm = bandit.arms.iter().max_by(|a, b| {
150 a.mean()
151 .partial_cmp(&b.mean())
152 .unwrap_or(std::cmp::Ordering::Equal)
153 })?;
154 let mode = match best_arm.name.as_str() {
155 "conservative" => "full",
156 "balanced" => "signatures",
157 "aggressive" => "aggressive",
158 _ => return None,
159 };
160 Some(mode.to_string())
161 }
162
163 fn predict_from_local(&self, sig: &FileSignature) -> Option<String> {
164 let entries = self.history.get(sig)?;
165 if entries.len() < 3 {
166 return None;
167 }
168
169 let mut mode_scores: HashMap<&str, (f64, usize)> = HashMap::new();
170 for entry in entries {
171 let (sum, count) = mode_scores.entry(&entry.mode).or_insert((0.0, 0));
172 *sum += entry.efficiency();
173 *count += 1;
174 }
175
176 mode_scores
177 .into_iter()
178 .max_by(|a, b| {
179 let avg_a = a.1 .0 / a.1 .1 as f64;
180 let avg_b = b.1 .0 / b.1 .1 as f64;
181 avg_a
182 .partial_cmp(&avg_b)
183 .unwrap_or(std::cmp::Ordering::Equal)
184 })
185 .map(|(mode, _)| mode.to_string())
186 }
187
188 #[allow(clippy::unused_self)]
191 fn predict_from_cloud(&self, sig: &FileSignature) -> Option<String> {
192 let data = crate::cloud_client::load_cloud_models()?;
193 let models = data["models"].as_array()?;
194
195 let ext_with_dot = format!(".{}", sig.ext);
196 let bucket_name = match sig.size_bucket {
197 0 => "0-500",
198 1 => "500-2k",
199 2 => "2k-10k",
200 _ => "10k+",
201 };
202
203 let mut best: Option<(&str, f64)> = None;
204
205 for model in models {
206 let m_ext = model["file_ext"].as_str().unwrap_or("");
207 let m_bucket = model["size_bucket"].as_str().unwrap_or("");
208 let confidence = model["confidence"].as_f64().unwrap_or(0.0);
209
210 if m_ext == ext_with_dot && m_bucket == bucket_name && confidence > 0.5 {
211 if let Some(mode) = model["recommended_mode"].as_str() {
212 if best.is_none_or(|(_, c)| confidence > c) {
213 best = Some((mode, confidence));
214 }
215 }
216 }
217 }
218
219 if let Some((mode, _)) = best {
220 return Some(mode.to_string());
221 }
222
223 for model in models {
224 let m_ext = model["file_ext"].as_str().unwrap_or("");
225 let confidence = model["confidence"].as_f64().unwrap_or(0.0);
226 if m_ext == ext_with_dot && confidence > 0.5 {
227 return model["recommended_mode"]
228 .as_str()
229 .map(std::string::ToString::to_string);
230 }
231 }
232
233 None
234 }
235
236 fn predict_from_defaults(sig: &FileSignature) -> Option<String> {
240 if sig.size_bucket == 0 {
241 return None;
242 }
243 if matches!(sig.ext.as_str(), "md" | "mdx" | "txt" | "rst") {
244 return None;
245 }
246
247 let mode = match (sig.ext.as_str(), sig.size_bucket) {
248 ("lock", _)
250 | (
251 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
252 | "swift" | "kt" | "cs" | "vue" | "svelte",
253 4..,
254 ) => "signatures",
255
256 (
258 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
259 | "swift" | "kt" | "cs" | "vue" | "svelte",
260 2 | 3,
261 )
262 | ("sql", 2..) => "map",
263
264 ("json" | "yaml" | "yml" | "toml" | "xml" | "csv", _)
266 | ("css" | "scss" | "less" | "sass", 2..)
267 | (_, 3..) => "aggressive",
268
269 _ => return None,
270 };
271 Some(mode.to_string())
272 }
273
274 pub fn save(&self) {
276 let mut guard = PREDICTOR_BUFFER
277 .lock()
278 .unwrap_or_else(std::sync::PoisonError::into_inner);
279 let should_flush = match *guard {
280 Some((_, ref last_flush)) => last_flush.elapsed().as_secs() >= PREDICTOR_FLUSH_SECS,
281 None => true,
282 };
283 *guard = Some((Arc::new(self.clone()), Instant::now()));
284 if should_flush {
285 self.save_to_disk();
286 }
287 }
288
289 fn save_to_disk(&self) {
290 let Ok(dir) = crate::core::data_dir::lean_ctx_data_dir() else {
291 return;
292 };
293 let _ = std::fs::create_dir_all(&dir);
294 let path = dir.join(STATS_FILE);
295 if let Ok(json) = serde_json::to_string_pretty(self) {
296 let tmp = dir.join(".mode_stats.tmp");
297 if std::fs::write(&tmp, &json).is_ok() {
298 let _ = std::fs::rename(&tmp, &path);
299 }
300 }
301 }
302
303 pub fn flush() {
305 let guard = PREDICTOR_BUFFER
306 .lock()
307 .unwrap_or_else(std::sync::PoisonError::into_inner);
308 if let Some((ref predictor, _)) = *guard {
309 predictor.save_to_disk();
310 }
311 }
312
313 fn load_from_disk() -> Option<Self> {
314 let path = crate::core::data_dir::lean_ctx_data_dir()
315 .ok()?
316 .join(STATS_FILE);
317 let data = std::fs::read_to_string(path).ok()?;
318 serde_json::from_str(&data).ok()
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor so each test states only the values it cares about.
    fn outcome(mode: &str, tokens_in: usize, tokens_out: usize, density: f64) -> ModeOutcome {
        ModeOutcome {
            mode: mode.to_string(),
            tokens_in,
            tokens_out,
            density,
        }
    }

    /// Static default for a file called `path` holding `tokens` tokens.
    fn default_for(path: &str, tokens: usize) -> Option<String> {
        ModePredictor::predict_from_defaults(&FileSignature::from_path(path, tokens))
    }

    #[test]
    fn file_signature_buckets() {
        let expectations = [(100, 0), (1000, 1), (3000, 2), (10000, 3), (50000, 4)];
        for (tokens, bucket) in expectations {
            assert_eq!(FileSignature::from_path("main.rs", tokens).size_bucket, bucket);
        }
    }

    #[test]
    fn predict_returns_none_without_history() {
        let sig = FileSignature::from_path("test.zzz", 500);
        let predictor = ModePredictor::default();
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_returns_none_with_too_few_entries() {
        let sig = FileSignature::from_path("test.zzz", 500);
        let mut predictor = ModePredictor::default();
        predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_learns_best_mode() {
        let sig = FileSignature::from_path("big.rs", 5000);
        let mut predictor = ModePredictor::default();
        for _ in 0..5 {
            predictor.record(sig.clone(), outcome("full", 5000, 5000, 0.3));
            predictor.record(sig.clone(), outcome("map", 5000, 800, 0.6));
        }
        assert_eq!(predictor.predict_best_mode(&sig), Some("map".to_string()));
    }

    #[test]
    fn history_caps_at_100() {
        let sig = FileSignature::from_path("test.rs", 100);
        let mut predictor = ModePredictor::default();
        for _ in 0..120 {
            predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        }
        assert!(predictor.history[&sig].len() <= 100);
    }

    #[test]
    fn defaults_return_none_for_small_files() {
        assert!(default_for("small.rs", 200).is_none());
    }

    #[test]
    fn defaults_recommend_map_for_medium_code() {
        assert_eq!(default_for("medium.rs", 3000), Some("map".to_string()));
    }

    #[test]
    fn defaults_recommend_aggressive_for_json() {
        assert_eq!(
            default_for("config.json", 1000),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn defaults_recommend_signatures_for_huge_code() {
        assert_eq!(
            default_for("huge.ts", 25000),
            Some("signatures".to_string())
        );
    }

    #[test]
    fn defaults_recommend_aggressive_for_large_unknown() {
        assert_eq!(
            default_for("data.xyz", 8000),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn defaults_never_compress_markdown() {
        for tokens in [600, 3000, 8000, 25000] {
            assert!(
                default_for("SKILL.md", tokens).is_none(),
                "SKILL.md at {tokens} tokens should get full (None), not compressed"
            );
        }
        assert!(default_for("AGENTS.md", 5000).is_none());
        assert!(default_for("README.md", 12000).is_none());
    }

    #[test]
    fn mode_outcome_efficiency() {
        assert!(outcome("map", 1000, 200, 0.6).efficiency() > 0.0);
    }
}
471}