Skip to main content

lean_ctx/core/
mode_predictor.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::Instant;
4
/// File name (inside the data directory) that the predictor history is
/// serialized to and loaded from.
const STATS_FILE: &str = "mode_stats.json";
/// Age, in seconds, that the timestamp stored in `PREDICTOR_BUFFER` must reach
/// before `ModePredictor::save` also writes the predictor to disk.
const PREDICTOR_FLUSH_SECS: u64 = 10;

/// Process-wide buffered predictor plus an `Instant` used to throttle disk
/// writes. `None` until the first `ModePredictor::new` or `save` call.
static PREDICTOR_BUFFER: Mutex<Option<(ModePredictor, Instant)>> = Mutex::new(None);
9
/// One recorded result of applying a mode to a file.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ModeOutcome {
    /// Name of the mode that was used (e.g. "full", "map", "aggressive", "signatures").
    pub mode: String,
    /// Token count on the input side — presumably before the mode was applied; confirm with callers.
    pub tokens_in: usize,
    /// Token count on the output side — presumably after the mode was applied; confirm with callers.
    pub tokens_out: usize,
    /// Density score used as the numerator of `efficiency`; its scale is not visible here.
    pub density: f64,
}
17
18impl ModeOutcome {
19    pub fn efficiency(&self) -> f64 {
20        if self.tokens_out == 0 {
21            return 0.0;
22        }
23        self.density / (self.tokens_out as f64 / self.tokens_in.max(1) as f64)
24    }
25}
26
/// Key used to group recorded outcomes: file extension plus a coarse size class.
#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct FileSignature {
    /// File extension without the leading dot; empty when the path has none.
    pub ext: String,
    /// Size class derived from the token count; `from_path` produces values 0 through 4.
    pub size_bucket: u8,
}
32
33impl FileSignature {
34    pub fn from_path(path: &str, token_count: usize) -> Self {
35        let ext = std::path::Path::new(path)
36            .extension()
37            .and_then(|e| e.to_str())
38            .unwrap_or("")
39            .to_string();
40        let size_bucket = match token_count {
41            0..=500 => 0,
42            501..=2000 => 1,
43            2001..=5000 => 2,
44            5001..=20000 => 3,
45            _ => 4,
46        };
47        Self { ext, size_bucket }
48    }
49}
50
/// Learns which mode works best per `FileSignature` from recorded outcomes.
#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModePredictor {
    /// Recorded outcomes per signature, oldest first; `record` trims each list
    /// once it grows past 100 entries.
    history: HashMap<FileSignature, Vec<ModeOutcome>>,
}
55
56impl ModePredictor {
57    pub fn new() -> Self {
58        let mut guard = PREDICTOR_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
59        if let Some((ref predictor, _)) = *guard {
60            return predictor.clone();
61        }
62        let loaded = Self::load_from_disk().unwrap_or_default();
63        *guard = Some((loaded.clone(), Instant::now()));
64        loaded
65    }
66
67    pub fn record(&mut self, sig: FileSignature, outcome: ModeOutcome) {
68        let entries = self.history.entry(sig).or_default();
69        entries.push(outcome);
70        if entries.len() > 100 {
71            entries.drain(0..50);
72        }
73    }
74
75    /// Returns the best mode based on historical efficiency.
76    /// Chain: local history -> built-in defaults.
77    pub fn predict_best_mode(&self, sig: &FileSignature) -> Option<String> {
78        if let Some(local) = self.predict_from_local(sig) {
79            return Some(local);
80        }
81        Self::predict_from_defaults(sig)
82    }
83
84    fn predict_from_local(&self, sig: &FileSignature) -> Option<String> {
85        let entries = self.history.get(sig)?;
86        if entries.len() < 3 {
87            return None;
88        }
89
90        let mut mode_scores: HashMap<&str, (f64, usize)> = HashMap::new();
91        for entry in entries {
92            let (sum, count) = mode_scores.entry(&entry.mode).or_insert((0.0, 0));
93            *sum += entry.efficiency();
94            *count += 1;
95        }
96
97        mode_scores
98            .into_iter()
99            .max_by(|a, b| {
100                let avg_a = a.1 .0 / a.1 .1 as f64;
101                let avg_b = b.1 .0 / b.1 .1 as f64;
102                avg_a
103                    .partial_cmp(&avg_b)
104                    .unwrap_or(std::cmp::Ordering::Equal)
105            })
106            .map(|(mode, _)| mode.to_string())
107    }
108
109    /// Built-in defaults for common file types and sizes.
110    /// Ensures reasonable compression even without local history or cloud models.
111    /// Respects Kolmogorov-Gate: files with K>0.7 skip aggressive modes.
112    fn predict_from_defaults(sig: &FileSignature) -> Option<String> {
113        let mode = match (sig.ext.as_str(), sig.size_bucket) {
114            // Tiny files (0-500 tokens): always full — compression overhead not worth it
115            (_, 0) => return None,
116
117            // Config / data files: aggressive strips comments and whitespace
118            ("json" | "yaml" | "yml" | "toml" | "xml" | "csv", _) => "aggressive",
119
120            // Lock files: signatures only (just versions matter)
121            ("lock", _) => "signatures",
122
123            // Code files by size bucket
124            // 500-2k tokens: full is fine
125            (
126                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
127                | "swift" | "kt" | "cs" | "vue" | "svelte",
128                1,
129            ) => return None,
130
131            // 2k-5k tokens: map gives structure without bloat
132            (
133                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
134                | "swift" | "kt" | "cs" | "vue" | "svelte",
135                2,
136            ) => "map",
137
138            // 5k-20k tokens: map is strongly preferred
139            (
140                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
141                | "swift" | "kt" | "cs" | "vue" | "svelte",
142                3,
143            ) => "map",
144
145            // 20k+ tokens: signatures only — too large for full context
146            (
147                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
148                | "swift" | "kt" | "cs" | "vue" | "svelte",
149                4..,
150            ) => "signatures",
151
152            // Markup / docs: aggressive for large, map for medium
153            ("md" | "mdx" | "rst" | "txt" | "html" | "astro", 1..=2) => return None,
154            ("md" | "mdx" | "rst" | "txt" | "html" | "astro", 3..) => "aggressive",
155
156            // CSS / styles: aggressive strips whitespace well
157            ("css" | "scss" | "less" | "sass", 2..) => "aggressive",
158
159            // SQL: map for medium+
160            ("sql", 2..) => "map",
161
162            // Unknown large files: aggressive as safe fallback
163            (_, 3..) => "aggressive",
164
165            _ => return None,
166        };
167        Some(mode.to_string())
168    }
169
170    pub fn save(&self) {
171        let mut guard = PREDICTOR_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
172        let should_flush = match *guard {
173            Some((_, ref last_flush)) => last_flush.elapsed().as_secs() >= PREDICTOR_FLUSH_SECS,
174            None => true,
175        };
176        *guard = Some((self.clone(), Instant::now()));
177        if should_flush {
178            self.save_to_disk();
179        }
180    }
181
182    fn save_to_disk(&self) {
183        let dir = match crate::core::data_dir::nebu_ctx_data_dir() {
184            Ok(d) => d,
185            Err(_) => return,
186        };
187        let _ = std::fs::create_dir_all(&dir);
188        let path = dir.join(STATS_FILE);
189        if let Ok(json) = serde_json::to_string_pretty(self) {
190            let tmp = dir.join(".mode_stats.tmp");
191            if std::fs::write(&tmp, &json).is_ok() {
192                let _ = std::fs::rename(&tmp, &path);
193            }
194        }
195    }
196
197    pub fn flush() {
198        let guard = PREDICTOR_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
199        if let Some((ref predictor, _)) = *guard {
200            predictor.save_to_disk();
201        }
202    }
203
204    fn load_from_disk() -> Option<Self> {
205        let path = crate::core::data_dir::nebu_ctx_data_dir()
206            .ok()?
207            .join(STATS_FILE);
208        let data = std::fs::read_to_string(path).ok()?;
209        serde_json::from_str(&data).ok()
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ModeOutcome` fixture.
    fn outcome(mode: &str, tokens_in: usize, tokens_out: usize, density: f64) -> ModeOutcome {
        ModeOutcome {
            mode: mode.to_string(),
            tokens_in,
            tokens_out,
            density,
        }
    }

    /// Built-in default recommendation for `path` at the given token count.
    fn default_mode(path: &str, tokens: usize) -> Option<String> {
        ModePredictor::predict_from_defaults(&FileSignature::from_path(path, tokens))
    }

    #[test]
    fn file_signature_buckets() {
        let expectations = [(100, 0), (1000, 1), (3000, 2), (10000, 3), (50000, 4)];
        for (tokens, bucket) in expectations {
            assert_eq!(FileSignature::from_path("main.rs", tokens).size_bucket, bucket);
        }
    }

    #[test]
    fn predict_returns_none_without_history() {
        let sig = FileSignature::from_path("test.zzz", 500);
        let predictor = ModePredictor::default();
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_returns_none_with_too_few_entries() {
        let sig = FileSignature::from_path("test.zzz", 500);
        let mut predictor = ModePredictor::default();
        predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_learns_best_mode() {
        let sig = FileSignature::from_path("big.rs", 5000);
        let mut predictor = ModePredictor::default();
        for _ in 0..5 {
            predictor.record(sig.clone(), outcome("full", 5000, 5000, 0.3));
            predictor.record(sig.clone(), outcome("map", 5000, 800, 0.6));
        }
        assert_eq!(predictor.predict_best_mode(&sig), Some("map".to_string()));
    }

    #[test]
    fn history_caps_at_100() {
        let sig = FileSignature::from_path("test.rs", 100);
        let mut predictor = ModePredictor::default();
        for _ in 0..120 {
            predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        }
        let stored = predictor.history.get(&sig).unwrap().len();
        assert!(stored <= 100);
    }

    #[test]
    fn defaults_return_none_for_small_files() {
        assert!(default_mode("small.rs", 200).is_none());
    }

    #[test]
    fn defaults_recommend_map_for_medium_code() {
        assert_eq!(default_mode("medium.rs", 3000), Some("map".to_string()));
    }

    #[test]
    fn defaults_recommend_aggressive_for_json() {
        assert_eq!(default_mode("config.json", 1000), Some("aggressive".to_string()));
    }

    #[test]
    fn defaults_recommend_signatures_for_huge_code() {
        assert_eq!(default_mode("huge.ts", 25000), Some("signatures".to_string()));
    }

    #[test]
    fn defaults_recommend_aggressive_for_large_unknown() {
        assert_eq!(default_mode("data.xyz", 8000), Some("aggressive".to_string()));
    }

    #[test]
    fn mode_outcome_efficiency() {
        let sample = outcome("map", 1000, 200, 0.6);
        assert!(sample.efficiency() > 0.0);
    }
}