1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::Instant;
4
/// File name of the persisted predictor state, joined onto the lean-ctx data directory.
const STATS_FILE: &str = "mode_stats.json";
/// Age threshold, in seconds, that `ModePredictor::save` compares the buffered
/// timestamp against when deciding whether to write to disk.
const PREDICTOR_FLUSH_SECS: u64 = 10;

/// Process-wide buffer holding the most recently loaded or saved predictor together
/// with the `Instant` that `ModePredictor::save` checks; `None` until first use.
static PREDICTOR_BUFFER: Mutex<Option<(ModePredictor, Instant)>> = Mutex::new(None);
9
/// One recorded observation for a mode; stored in `ModePredictor::history`
/// and scored through `ModeOutcome::efficiency`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ModeOutcome {
    /// Name of the mode this observation belongs to (e.g. `"full"`, `"map"`).
    pub mode: String,
    /// Denominator of the output/input ratio used by `efficiency`
    /// (presumably the token count before processing; confirm with callers).
    pub tokens_in: usize,
    /// Numerator of the output/input ratio used by `efficiency`
    /// (presumably the token count after processing; confirm with callers).
    pub tokens_out: usize,
    /// Score that `efficiency` divides by the output/input ratio; how it is
    /// measured is not visible in this file.
    pub density: f64,
}
18
19impl ModeOutcome {
20 pub fn efficiency(&self) -> f64 {
22 if self.tokens_out == 0 {
23 return 0.0;
24 }
25 self.density / (self.tokens_out as f64 / self.tokens_in.max(1) as f64)
26 }
27}
28
/// Key under which outcomes are grouped: a file extension plus a coarse size bucket.
#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct FileSignature {
    /// File extension without the leading dot; empty when `from_path` finds none.
    pub ext: String,
    /// Size class; `from_path` assigns values 0 through 4 from the token count.
    pub size_bucket: u8,
}
35
36impl FileSignature {
37 pub fn from_path(path: &str, token_count: usize) -> Self {
39 let ext = std::path::Path::new(path)
40 .extension()
41 .and_then(|e| e.to_str())
42 .unwrap_or("")
43 .to_string();
44 let size_bucket = match token_count {
45 0..=500 => 0,
46 501..=2000 => 1,
47 2001..=5000 => 2,
48 5001..=20000 => 3,
49 _ => 4,
50 };
51 Self { ext, size_bucket }
52 }
53}
54
/// Suggests a mode for a file signature from recorded outcomes, bandit
/// feedback, cloud models and a built-in defaults table.
#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModePredictor {
    /// Recorded outcomes per signature; `record` trims a list once it exceeds 100 entries.
    history: HashMap<FileSignature, Vec<ModeOutcome>>,
    /// Root handed to `BanditStore::load` (falls back to `"."` there); `new`
    /// fills it with the current directory when nothing was loaded.
    project_root: Option<String>,
}
61
62impl ModePredictor {
63 pub fn new() -> Self {
65 let mut guard = PREDICTOR_BUFFER
66 .lock()
67 .unwrap_or_else(std::sync::PoisonError::into_inner);
68 if let Some((ref predictor, _)) = *guard {
69 return predictor.clone();
70 }
71 let mut loaded = Self::load_from_disk().unwrap_or_default();
72 if loaded.project_root.is_none() {
73 loaded.project_root = std::env::current_dir()
74 .ok()
75 .map(|p| p.to_string_lossy().to_string());
76 }
77 *guard = Some((loaded.clone(), Instant::now()));
78 loaded
79 }
80
81 pub fn with_project_root(mut self, root: &str) -> Self {
82 self.project_root = Some(root.to_string());
83 self
84 }
85
86 pub fn set_project_root(&mut self, root: &str) {
87 self.project_root = Some(root.to_string());
88 }
89
90 pub fn record(&mut self, sig: FileSignature, outcome: ModeOutcome) {
92 let entries = self.history.entry(sig).or_default();
93 entries.push(outcome);
94 if entries.len() > 100 {
95 entries.drain(0..50);
96 }
97 }
98
99 pub fn predict_best_mode(&self, sig: &FileSignature) -> Option<String> {
102 let default_mode = Self::predict_from_defaults(sig);
103
104 let allow_override = |candidate: &str| -> bool {
105 let Some(def) = default_mode.as_deref() else {
106 return true;
107 };
108 if candidate == "full" {
109 return false;
110 }
111 if (def == "map" || def == "signatures")
113 && (candidate == "aggressive" || candidate == "entropy")
114 {
115 return false;
116 }
117 true
118 };
119
120 if let Some(local) = self.predict_from_local(sig) {
121 if allow_override(&local) {
122 return Some(local);
123 }
124 }
125 if let Some(bandit) = self.predict_from_bandit(sig) {
126 if allow_override(&bandit) {
127 return Some(bandit);
128 }
129 }
130 if let Some(cloud) = self.predict_from_cloud(sig) {
131 if allow_override(&cloud) {
132 return Some(cloud);
133 }
134 }
135 default_mode
136 }
137
138 fn predict_from_bandit(&self, sig: &FileSignature) -> Option<String> {
139 let key = format!("{}_feedback", sig.ext);
140 let store =
141 crate::core::bandit::BanditStore::load(self.project_root.as_deref().unwrap_or("."));
142 let bandit = store.bandits.get(&key)?;
143 if bandit.total_pulls < 5 {
144 return None;
145 }
146 let best_arm = bandit.arms.iter().max_by(|a, b| {
147 a.mean()
148 .partial_cmp(&b.mean())
149 .unwrap_or(std::cmp::Ordering::Equal)
150 })?;
151 let mode = match best_arm.name.as_str() {
152 "conservative" => "full",
153 "balanced" => "signatures",
154 "aggressive" => "aggressive",
155 _ => return None,
156 };
157 Some(mode.to_string())
158 }
159
160 fn predict_from_local(&self, sig: &FileSignature) -> Option<String> {
161 let entries = self.history.get(sig)?;
162 if entries.len() < 3 {
163 return None;
164 }
165
166 let mut mode_scores: HashMap<&str, (f64, usize)> = HashMap::new();
167 for entry in entries {
168 let (sum, count) = mode_scores.entry(&entry.mode).or_insert((0.0, 0));
169 *sum += entry.efficiency();
170 *count += 1;
171 }
172
173 mode_scores
174 .into_iter()
175 .max_by(|a, b| {
176 let avg_a = a.1 .0 / a.1 .1 as f64;
177 let avg_b = b.1 .0 / b.1 .1 as f64;
178 avg_a
179 .partial_cmp(&avg_b)
180 .unwrap_or(std::cmp::Ordering::Equal)
181 })
182 .map(|(mode, _)| mode.to_string())
183 }
184
185 #[allow(clippy::unused_self)]
188 fn predict_from_cloud(&self, sig: &FileSignature) -> Option<String> {
189 let data = crate::cloud_client::load_cloud_models()?;
190 let models = data["models"].as_array()?;
191
192 let ext_with_dot = format!(".{}", sig.ext);
193 let bucket_name = match sig.size_bucket {
194 0 => "0-500",
195 1 => "500-2k",
196 2 => "2k-10k",
197 _ => "10k+",
198 };
199
200 let mut best: Option<(&str, f64)> = None;
201
202 for model in models {
203 let m_ext = model["file_ext"].as_str().unwrap_or("");
204 let m_bucket = model["size_bucket"].as_str().unwrap_or("");
205 let confidence = model["confidence"].as_f64().unwrap_or(0.0);
206
207 if m_ext == ext_with_dot && m_bucket == bucket_name && confidence > 0.5 {
208 if let Some(mode) = model["recommended_mode"].as_str() {
209 if best.is_none_or(|(_, c)| confidence > c) {
210 best = Some((mode, confidence));
211 }
212 }
213 }
214 }
215
216 if let Some((mode, _)) = best {
217 return Some(mode.to_string());
218 }
219
220 for model in models {
221 let m_ext = model["file_ext"].as_str().unwrap_or("");
222 let confidence = model["confidence"].as_f64().unwrap_or(0.0);
223 if m_ext == ext_with_dot && confidence > 0.5 {
224 return model["recommended_mode"]
225 .as_str()
226 .map(std::string::ToString::to_string);
227 }
228 }
229
230 None
231 }
232
233 fn predict_from_defaults(sig: &FileSignature) -> Option<String> {
237 if sig.size_bucket == 0 {
238 return None;
239 }
240 if matches!(sig.ext.as_str(), "md" | "mdx" | "txt" | "rst") {
241 return None;
242 }
243
244 let mode = match (sig.ext.as_str(), sig.size_bucket) {
245 ("lock", _)
247 | (
248 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
249 | "swift" | "kt" | "cs" | "vue" | "svelte",
250 4..,
251 ) => "signatures",
252
253 (
255 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
256 | "swift" | "kt" | "cs" | "vue" | "svelte",
257 2 | 3,
258 )
259 | ("sql", 2..) => "map",
260
261 ("json" | "yaml" | "yml" | "toml" | "xml" | "csv", _)
263 | ("css" | "scss" | "less" | "sass", 2..)
264 | (_, 3..) => "aggressive",
265
266 _ => return None,
267 };
268 Some(mode.to_string())
269 }
270
271 pub fn save(&self) {
273 let mut guard = PREDICTOR_BUFFER
274 .lock()
275 .unwrap_or_else(std::sync::PoisonError::into_inner);
276 let should_flush = match *guard {
277 Some((_, ref last_flush)) => last_flush.elapsed().as_secs() >= PREDICTOR_FLUSH_SECS,
278 None => true,
279 };
280 *guard = Some((self.clone(), Instant::now()));
281 if should_flush {
282 self.save_to_disk();
283 }
284 }
285
286 fn save_to_disk(&self) {
287 let Ok(dir) = crate::core::data_dir::lean_ctx_data_dir() else {
288 return;
289 };
290 let _ = std::fs::create_dir_all(&dir);
291 let path = dir.join(STATS_FILE);
292 if let Ok(json) = serde_json::to_string_pretty(self) {
293 let tmp = dir.join(".mode_stats.tmp");
294 if std::fs::write(&tmp, &json).is_ok() {
295 let _ = std::fs::rename(&tmp, &path);
296 }
297 }
298 }
299
300 pub fn flush() {
302 let guard = PREDICTOR_BUFFER
303 .lock()
304 .unwrap_or_else(std::sync::PoisonError::into_inner);
305 if let Some((ref predictor, _)) = *guard {
306 predictor.save_to_disk();
307 }
308 }
309
310 fn load_from_disk() -> Option<Self> {
311 let path = crate::core::data_dir::lean_ctx_data_dir()
312 .ok()?
313 .join(STATS_FILE);
314 let data = std::fs::read_to_string(path).ok()?;
315 serde_json::from_str(&data).ok()
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ModeOutcome` fixture.
    fn outcome(mode: &str, tokens_in: usize, tokens_out: usize, density: f64) -> ModeOutcome {
        ModeOutcome {
            mode: mode.to_string(),
            tokens_in,
            tokens_out,
            density,
        }
    }

    /// Runs the defaults table for the signature derived from `path` and `tokens`.
    fn default_mode_for(path: &str, tokens: usize) -> Option<String> {
        ModePredictor::predict_from_defaults(&FileSignature::from_path(path, tokens))
    }

    #[test]
    fn file_signature_buckets() {
        let expectations = [(100, 0), (1000, 1), (3000, 2), (10000, 3), (50000, 4)];
        for (tokens, bucket) in expectations {
            assert_eq!(FileSignature::from_path("main.rs", tokens).size_bucket, bucket);
        }
    }

    #[test]
    fn predict_returns_none_without_history() {
        let sig = FileSignature::from_path("test.zzz", 500);
        let predictor = ModePredictor::default();
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_returns_none_with_too_few_entries() {
        let sig = FileSignature::from_path("test.zzz", 500);
        let mut predictor = ModePredictor::default();
        predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_learns_best_mode() {
        let sig = FileSignature::from_path("big.rs", 5000);
        let mut predictor = ModePredictor::default();
        for _ in 0..5 {
            predictor.record(sig.clone(), outcome("full", 5000, 5000, 0.3));
            predictor.record(sig.clone(), outcome("map", 5000, 800, 0.6));
        }
        let best = predictor.predict_best_mode(&sig);
        assert_eq!(best, Some("map".to_string()));
    }

    #[test]
    fn history_caps_at_100() {
        let sig = FileSignature::from_path("test.rs", 100);
        let mut predictor = ModePredictor::default();
        for _ in 0..120 {
            predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        }
        let stored = predictor.history[&sig].len();
        assert!(stored <= 100);
    }

    #[test]
    fn defaults_return_none_for_small_files() {
        assert!(default_mode_for("small.rs", 200).is_none());
    }

    #[test]
    fn defaults_recommend_map_for_medium_code() {
        assert_eq!(default_mode_for("medium.rs", 3000), Some("map".to_string()));
    }

    #[test]
    fn defaults_recommend_aggressive_for_json() {
        assert_eq!(
            default_mode_for("config.json", 1000),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn defaults_recommend_signatures_for_huge_code() {
        assert_eq!(
            default_mode_for("huge.ts", 25000),
            Some("signatures".to_string())
        );
    }

    #[test]
    fn defaults_recommend_aggressive_for_large_unknown() {
        assert_eq!(
            default_mode_for("data.xyz", 8000),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn defaults_never_compress_markdown() {
        for tokens in [600, 3000, 8000, 25000] {
            assert!(
                default_mode_for("SKILL.md", tokens).is_none(),
                "SKILL.md at {tokens} tokens should get full (None), not compressed"
            );
        }
        assert!(default_mode_for("AGENTS.md", 5000).is_none());
        assert!(default_mode_for("README.md", 12000).is_none());
    }

    #[test]
    fn mode_outcome_efficiency() {
        let sample = outcome("map", 1000, 200, 0.6);
        assert!(sample.efficiency() > 0.0);
    }
}