1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::Instant;
4
/// File name, inside the nebu_ctx data directory, used to persist predictor history.
const STATS_FILE: &str = "mode_stats.json";
/// Throttle interval, in seconds, that `ModePredictor::save` waits for before writing to disk.
const PREDICTOR_FLUSH_SECS: u64 = 10;

/// Process-wide cached predictor, paired with the `Instant` that `ModePredictor::save`
/// compares against `PREDICTOR_FLUSH_SECS` to decide whether to write to disk.
/// `None` until the first `ModePredictor::new` or `ModePredictor::save` call.
static PREDICTOR_BUFFER: Mutex<Option<(ModePredictor, Instant)>> = Mutex::new(None);
9
/// One recorded result of reading a file with a particular mode, used as training data
/// for `ModePredictor`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ModeOutcome {
    /// Name of the mode that was used (e.g. "full", "map").
    pub mode: String,
    /// Input token count; `efficiency` treats 0 as 1.
    /// NOTE(review): presumably the file's size before the mode was applied — confirm.
    pub tokens_in: usize,
    /// Output token count; an outcome with 0 output tokens has efficiency 0.0.
    pub tokens_out: usize,
    /// Score that `efficiency` divides by the output/input token ratio; higher values
    /// raise efficiency. NOTE(review): range/meaning set by the caller — confirm.
    pub density: f64,
}
17
18impl ModeOutcome {
19 pub fn efficiency(&self) -> f64 {
20 if self.tokens_out == 0 {
21 return 0.0;
22 }
23 self.density / (self.tokens_out as f64 / self.tokens_in.max(1) as f64)
24 }
25}
26
/// Key that groups prediction history: a file extension plus a coarse size bucket.
#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct FileSignature {
    /// File extension without the leading dot, as produced by `from_path`
    /// (empty when the path has no UTF-8 extension).
    pub ext: String,
    /// Size class 0–4 derived from the token count by `from_path`.
    pub size_bucket: u8,
}
32
33impl FileSignature {
34 pub fn from_path(path: &str, token_count: usize) -> Self {
35 let ext = std::path::Path::new(path)
36 .extension()
37 .and_then(|e| e.to_str())
38 .unwrap_or("")
39 .to_string();
40 let size_bucket = match token_count {
41 0..=500 => 0,
42 501..=2000 => 1,
43 2001..=5000 => 2,
44 5001..=20000 => 3,
45 _ => 4,
46 };
47 Self { ext, size_bucket }
48 }
49}
50
51#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
52pub struct ModePredictor {
53 history: HashMap<FileSignature, Vec<ModeOutcome>>,
54}
55
56impl ModePredictor {
57 pub fn new() -> Self {
58 let mut guard = PREDICTOR_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
59 if let Some((ref predictor, _)) = *guard {
60 return predictor.clone();
61 }
62 let loaded = Self::load_from_disk().unwrap_or_default();
63 *guard = Some((loaded.clone(), Instant::now()));
64 loaded
65 }
66
67 pub fn record(&mut self, sig: FileSignature, outcome: ModeOutcome) {
68 let entries = self.history.entry(sig).or_default();
69 entries.push(outcome);
70 if entries.len() > 100 {
71 entries.drain(0..50);
72 }
73 }
74
75 pub fn predict_best_mode(&self, sig: &FileSignature) -> Option<String> {
78 if let Some(local) = self.predict_from_local(sig) {
79 return Some(local);
80 }
81 Self::predict_from_defaults(sig)
82 }
83
84 fn predict_from_local(&self, sig: &FileSignature) -> Option<String> {
85 let entries = self.history.get(sig)?;
86 if entries.len() < 3 {
87 return None;
88 }
89
90 let mut mode_scores: HashMap<&str, (f64, usize)> = HashMap::new();
91 for entry in entries {
92 let (sum, count) = mode_scores.entry(&entry.mode).or_insert((0.0, 0));
93 *sum += entry.efficiency();
94 *count += 1;
95 }
96
97 mode_scores
98 .into_iter()
99 .max_by(|a, b| {
100 let avg_a = a.1 .0 / a.1 .1 as f64;
101 let avg_b = b.1 .0 / b.1 .1 as f64;
102 avg_a
103 .partial_cmp(&avg_b)
104 .unwrap_or(std::cmp::Ordering::Equal)
105 })
106 .map(|(mode, _)| mode.to_string())
107 }
108
109 fn predict_from_defaults(sig: &FileSignature) -> Option<String> {
113 let mode = match (sig.ext.as_str(), sig.size_bucket) {
114 (_, 0) => return None,
116
117 ("json" | "yaml" | "yml" | "toml" | "xml" | "csv", _) => "aggressive",
119
120 ("lock", _) => "signatures",
122
123 (
126 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
127 | "swift" | "kt" | "cs" | "vue" | "svelte",
128 1,
129 ) => return None,
130
131 (
133 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
134 | "swift" | "kt" | "cs" | "vue" | "svelte",
135 2,
136 ) => "map",
137
138 (
140 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
141 | "swift" | "kt" | "cs" | "vue" | "svelte",
142 3,
143 ) => "map",
144
145 (
147 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
148 | "swift" | "kt" | "cs" | "vue" | "svelte",
149 4..,
150 ) => "signatures",
151
152 ("md" | "mdx" | "rst" | "txt" | "html" | "astro", 1..=2) => return None,
154 ("md" | "mdx" | "rst" | "txt" | "html" | "astro", 3..) => "aggressive",
155
156 ("css" | "scss" | "less" | "sass", 2..) => "aggressive",
158
159 ("sql", 2..) => "map",
161
162 (_, 3..) => "aggressive",
164
165 _ => return None,
166 };
167 Some(mode.to_string())
168 }
169
170 pub fn save(&self) {
171 let mut guard = PREDICTOR_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
172 let should_flush = match *guard {
173 Some((_, ref last_flush)) => last_flush.elapsed().as_secs() >= PREDICTOR_FLUSH_SECS,
174 None => true,
175 };
176 *guard = Some((self.clone(), Instant::now()));
177 if should_flush {
178 self.save_to_disk();
179 }
180 }
181
182 fn save_to_disk(&self) {
183 let dir = match crate::core::data_dir::nebu_ctx_data_dir() {
184 Ok(d) => d,
185 Err(_) => return,
186 };
187 let _ = std::fs::create_dir_all(&dir);
188 let path = dir.join(STATS_FILE);
189 if let Ok(json) = serde_json::to_string_pretty(self) {
190 let tmp = dir.join(".mode_stats.tmp");
191 if std::fs::write(&tmp, &json).is_ok() {
192 let _ = std::fs::rename(&tmp, &path);
193 }
194 }
195 }
196
197 pub fn flush() {
198 let guard = PREDICTOR_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
199 if let Some((ref predictor, _)) = *guard {
200 predictor.save_to_disk();
201 }
202 }
203
204 fn load_from_disk() -> Option<Self> {
205 let path = crate::core::data_dir::nebu_ctx_data_dir()
206 .ok()?
207 .join(STATS_FILE);
208 let data = std::fs::read_to_string(path).ok()?;
209 serde_json::from_str(&data).ok()
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for the outcomes recorded in these tests.
    fn outcome(mode: &str, tokens_in: usize, tokens_out: usize, density: f64) -> ModeOutcome {
        ModeOutcome {
            mode: mode.to_string(),
            tokens_in,
            tokens_out,
            density,
        }
    }

    /// Default-rule prediction for a file path of the given token size.
    fn default_mode(path: &str, tokens: usize) -> Option<String> {
        ModePredictor::predict_from_defaults(&FileSignature::from_path(path, tokens))
    }

    #[test]
    fn file_signature_buckets() {
        let cases = [(100, 0), (1000, 1), (3000, 2), (10000, 3), (50000, 4)];
        for (tokens, expected) in cases {
            assert_eq!(FileSignature::from_path("main.rs", tokens).size_bucket, expected);
        }
    }

    #[test]
    fn predict_returns_none_without_history() {
        let sig = FileSignature::from_path("test.zzz", 500);
        assert!(ModePredictor::default().predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_returns_none_with_too_few_entries() {
        let sig = FileSignature::from_path("test.zzz", 500);
        let mut predictor = ModePredictor::default();
        predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_learns_best_mode() {
        let sig = FileSignature::from_path("big.rs", 5000);
        let mut predictor = ModePredictor::default();
        (0..5).for_each(|_| {
            predictor.record(sig.clone(), outcome("full", 5000, 5000, 0.3));
            predictor.record(sig.clone(), outcome("map", 5000, 800, 0.6));
        });
        assert_eq!(predictor.predict_best_mode(&sig), Some("map".to_string()));
    }

    #[test]
    fn history_caps_at_100() {
        let sig = FileSignature::from_path("test.rs", 100);
        let mut predictor = ModePredictor::default();
        for _ in 0..120 {
            predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        }
        let kept = predictor.history[&sig].len();
        assert!(kept <= 100);
    }

    #[test]
    fn defaults_return_none_for_small_files() {
        assert_eq!(default_mode("small.rs", 200), None);
    }

    #[test]
    fn defaults_recommend_map_for_medium_code() {
        assert_eq!(default_mode("medium.rs", 3000).as_deref(), Some("map"));
    }

    #[test]
    fn defaults_recommend_aggressive_for_json() {
        assert_eq!(default_mode("config.json", 1000).as_deref(), Some("aggressive"));
    }

    #[test]
    fn defaults_recommend_signatures_for_huge_code() {
        assert_eq!(default_mode("huge.ts", 25000).as_deref(), Some("signatures"));
    }

    #[test]
    fn defaults_recommend_aggressive_for_large_unknown() {
        assert_eq!(default_mode("data.xyz", 8000).as_deref(), Some("aggressive"));
    }

    #[test]
    fn mode_outcome_efficiency() {
        assert!(outcome("map", 1000, 200, 0.6).efficiency() > 0.0);
    }
}