Skip to main content

lean_ctx/core/
mode_predictor.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::Instant;
4
/// File name, inside the lean-ctx data directory, that `save_to_disk` writes and
/// `load_from_disk` reads.
const STATS_FILE: &str = "mode_stats.json";
/// Threshold, in whole seconds, that `ModePredictor::save` compares the buffered
/// timestamp's age against before writing to disk.
const PREDICTOR_FLUSH_SECS: u64 = 10;

/// Process-wide buffer holding the most recently loaded or saved predictor together
/// with the `Instant` that `save` checks against `PREDICTOR_FLUSH_SECS`.
/// `None` until the first `ModePredictor::new` or `ModePredictor::save` call.
static PREDICTOR_BUFFER: Mutex<Option<(ModePredictor, Instant)>> = Mutex::new(None);
9
/// One recorded result of reading a file in a given mode; the unit stored in
/// `ModePredictor`'s history and scored by [`ModeOutcome::efficiency`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ModeOutcome {
    /// Name of the read mode that was used (the tests use `"full"` and `"map"`).
    pub mode: String,
    /// Token count on the input side; clamped to at least 1 when used as the
    /// divisor in `efficiency`. Presumably the size of the original file — confirm.
    pub tokens_in: usize,
    /// Token count on the output side; `efficiency` returns 0.0 when this is 0.
    /// Presumably the size of what the mode emitted — confirm.
    pub tokens_out: usize,
    /// Numerator of the efficiency score. NOTE(review): its scale and exact
    /// meaning are not visible in this file — confirm with the producer.
    pub density: f64,
}
17
18impl ModeOutcome {
19    pub fn efficiency(&self) -> f64 {
20        if self.tokens_out == 0 {
21            return 0.0;
22        }
23        self.density / (self.tokens_out as f64 / self.tokens_in.max(1) as f64)
24    }
25}
26
/// Coarse key describing a file: its extension plus a size class.
/// Used as the `HashMap` key for `ModePredictor`'s history.
#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct FileSignature {
    /// File extension without the leading dot; empty when the path has none
    /// (see `from_path`). Stored as found — no case normalisation is applied.
    pub ext: String,
    /// Size class derived from the token count by `from_path`:
    /// 0 = 0..=500, 1 = 501..=2000, 2 = 2001..=5000, 3 = 5001..=20000, 4 = larger.
    pub size_bucket: u8,
}
32
33impl FileSignature {
34    pub fn from_path(path: &str, token_count: usize) -> Self {
35        let ext = std::path::Path::new(path)
36            .extension()
37            .and_then(|e| e.to_str())
38            .unwrap_or("")
39            .to_string();
40        let size_bucket = match token_count {
41            0..=500 => 0,
42            501..=2000 => 1,
43            2001..=5000 => 2,
44            5001..=20000 => 3,
45            _ => 4,
46        };
47        Self { ext, size_bucket }
48    }
49}
50
/// Learns which read mode works best per file signature from recorded outcomes.
///
/// Serialised to and from JSON by `save_to_disk` / `load_from_disk`.
#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModePredictor {
    /// Recorded outcomes per signature, oldest first; `record` trims each list
    /// once it grows past 100 entries.
    history: HashMap<FileSignature, Vec<ModeOutcome>>,
}
55
56impl ModePredictor {
57    pub fn new() -> Self {
58        let mut guard = PREDICTOR_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
59        if let Some((ref predictor, _)) = *guard {
60            return predictor.clone();
61        }
62        let loaded = Self::load_from_disk().unwrap_or_default();
63        *guard = Some((loaded.clone(), Instant::now()));
64        loaded
65    }
66
67    pub fn record(&mut self, sig: FileSignature, outcome: ModeOutcome) {
68        let entries = self.history.entry(sig).or_default();
69        entries.push(outcome);
70        if entries.len() > 100 {
71            entries.drain(0..50);
72        }
73    }
74
75    /// Returns the best mode based on historical efficiency.
76    /// Chain: local history -> cloud adaptive models -> built-in defaults.
77    pub fn predict_best_mode(&self, sig: &FileSignature) -> Option<String> {
78        if let Some(local) = self.predict_from_local(sig) {
79            return Some(local);
80        }
81        if let Some(cloud) = self.predict_from_cloud(sig) {
82            return Some(cloud);
83        }
84        Self::predict_from_defaults(sig)
85    }
86
87    fn predict_from_local(&self, sig: &FileSignature) -> Option<String> {
88        let entries = self.history.get(sig)?;
89        if entries.len() < 3 {
90            return None;
91        }
92
93        let mut mode_scores: HashMap<&str, (f64, usize)> = HashMap::new();
94        for entry in entries {
95            let (sum, count) = mode_scores.entry(&entry.mode).or_insert((0.0, 0));
96            *sum += entry.efficiency();
97            *count += 1;
98        }
99
100        mode_scores
101            .into_iter()
102            .max_by(|a, b| {
103                let avg_a = a.1 .0 / a.1 .1 as f64;
104                let avg_b = b.1 .0 / b.1 .1 as f64;
105                avg_a
106                    .partial_cmp(&avg_b)
107                    .unwrap_or(std::cmp::Ordering::Equal)
108            })
109            .map(|(mode, _)| mode.to_string())
110    }
111
112    /// Loads cloud adaptive models (synced from LeanCTX Cloud).
113    /// Models are cached locally and auto-updated for cloud users.
114    fn predict_from_cloud(&self, sig: &FileSignature) -> Option<String> {
115        let data = crate::cloud_client::load_cloud_models()?;
116        let models = data["models"].as_array()?;
117
118        let ext_with_dot = format!(".{}", sig.ext);
119        let bucket_name = match sig.size_bucket {
120            0 => "0-500",
121            1 => "500-2k",
122            2 => "2k-10k",
123            3 => "10k+",
124            _ => "10k+",
125        };
126
127        let mut best: Option<(&str, f64)> = None;
128
129        for model in models {
130            let m_ext = model["file_ext"].as_str().unwrap_or("");
131            let m_bucket = model["size_bucket"].as_str().unwrap_or("");
132            let confidence = model["confidence"].as_f64().unwrap_or(0.0);
133
134            if m_ext == ext_with_dot && m_bucket == bucket_name && confidence > 0.5 {
135                if let Some(mode) = model["recommended_mode"].as_str() {
136                    if best.is_none() || confidence > best.unwrap().1 {
137                        best = Some((mode, confidence));
138                    }
139                }
140            }
141        }
142
143        if let Some((mode, _)) = best {
144            return Some(mode.to_string());
145        }
146
147        for model in models {
148            let m_ext = model["file_ext"].as_str().unwrap_or("");
149            let confidence = model["confidence"].as_f64().unwrap_or(0.0);
150            if m_ext == ext_with_dot && confidence > 0.5 {
151                return model["recommended_mode"].as_str().map(|s| s.to_string());
152            }
153        }
154
155        None
156    }
157
158    /// Built-in defaults for common file types and sizes.
159    /// Ensures reasonable compression even without local history or cloud models.
160    /// Respects Kolmogorov-Gate: files with K>0.7 skip aggressive modes.
161    fn predict_from_defaults(sig: &FileSignature) -> Option<String> {
162        let mode = match (sig.ext.as_str(), sig.size_bucket) {
163            // Tiny files (0-500 tokens): always full — compression overhead not worth it
164            (_, 0) => return None,
165
166            // Config / data files: aggressive strips comments and whitespace
167            ("json" | "yaml" | "yml" | "toml" | "xml" | "csv", _) => "aggressive",
168
169            // Lock files: signatures only (just versions matter)
170            ("lock", _) => "signatures",
171
172            // Code files by size bucket
173            // 500-2k tokens: full is fine
174            (
175                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
176                | "swift" | "kt" | "cs" | "vue" | "svelte",
177                1,
178            ) => return None,
179
180            // 2k-5k tokens: map gives structure without bloat
181            (
182                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
183                | "swift" | "kt" | "cs" | "vue" | "svelte",
184                2,
185            ) => "map",
186
187            // 5k-20k tokens: map is strongly preferred
188            (
189                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
190                | "swift" | "kt" | "cs" | "vue" | "svelte",
191                3,
192            ) => "map",
193
194            // 20k+ tokens: signatures only — too large for full context
195            (
196                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
197                | "swift" | "kt" | "cs" | "vue" | "svelte",
198                4..,
199            ) => "signatures",
200
201            // Markup / docs: aggressive for large, map for medium
202            ("md" | "mdx" | "rst" | "txt" | "html" | "astro", 1..=2) => return None,
203            ("md" | "mdx" | "rst" | "txt" | "html" | "astro", 3..) => "aggressive",
204
205            // CSS / styles: aggressive strips whitespace well
206            ("css" | "scss" | "less" | "sass", 2..) => "aggressive",
207
208            // SQL: map for medium+
209            ("sql", 2..) => "map",
210
211            // Unknown large files: aggressive as safe fallback
212            (_, 3..) => "aggressive",
213
214            _ => return None,
215        };
216        Some(mode.to_string())
217    }
218
219    pub fn save(&self) {
220        let mut guard = PREDICTOR_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
221        let should_flush = match *guard {
222            Some((_, ref last_flush)) => last_flush.elapsed().as_secs() >= PREDICTOR_FLUSH_SECS,
223            None => true,
224        };
225        *guard = Some((self.clone(), Instant::now()));
226        if should_flush {
227            self.save_to_disk();
228        }
229    }
230
231    fn save_to_disk(&self) {
232        let dir = match crate::core::data_dir::lean_ctx_data_dir() {
233            Ok(d) => d,
234            Err(_) => return,
235        };
236        let _ = std::fs::create_dir_all(&dir);
237        let path = dir.join(STATS_FILE);
238        if let Ok(json) = serde_json::to_string_pretty(self) {
239            let tmp = dir.join(".mode_stats.tmp");
240            if std::fs::write(&tmp, &json).is_ok() {
241                let _ = std::fs::rename(&tmp, &path);
242            }
243        }
244    }
245
246    pub fn flush() {
247        let guard = PREDICTOR_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
248        if let Some((ref predictor, _)) = *guard {
249            predictor.save_to_disk();
250        }
251    }
252
253    fn load_from_disk() -> Option<Self> {
254        let path = crate::core::data_dir::lean_ctx_data_dir()
255            .ok()?
256            .join(STATS_FILE);
257        let data = std::fs::read_to_string(path).ok()?;
258        serde_json::from_str(&data).ok()
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ModeOutcome` without repeating the struct literal in every test.
    fn outcome(mode: &str, tokens_in: usize, tokens_out: usize, density: f64) -> ModeOutcome {
        ModeOutcome {
            mode: mode.to_string(),
            tokens_in,
            tokens_out,
            density,
        }
    }

    /// Built-in default recommendation for `path` holding `tokens` tokens.
    fn default_for(path: &str, tokens: usize) -> Option<String> {
        ModePredictor::predict_from_defaults(&FileSignature::from_path(path, tokens))
    }

    #[test]
    fn file_signature_buckets() {
        let expected = [(100, 0), (1000, 1), (3000, 2), (10000, 3), (50000, 4)];
        for (tokens, bucket) in expected {
            assert_eq!(
                FileSignature::from_path("main.rs", tokens).size_bucket,
                bucket
            );
        }
    }

    #[test]
    fn predict_returns_none_without_history() {
        let sig = FileSignature::from_path("test.zzz", 500);
        let predictor = ModePredictor::default();
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_returns_none_with_too_few_entries() {
        let sig = FileSignature::from_path("test.zzz", 500);
        let mut predictor = ModePredictor::default();
        predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_learns_best_mode() {
        let sig = FileSignature::from_path("big.rs", 5000);
        let mut predictor = ModePredictor::default();
        for _ in 0..5 {
            predictor.record(sig.clone(), outcome("full", 5000, 5000, 0.3));
            predictor.record(sig.clone(), outcome("map", 5000, 800, 0.6));
        }
        assert_eq!(predictor.predict_best_mode(&sig), Some("map".to_string()));
    }

    #[test]
    fn history_caps_at_100() {
        let sig = FileSignature::from_path("test.rs", 100);
        let mut predictor = ModePredictor::default();
        for _ in 0..120 {
            predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        }
        assert!(predictor.history.get(&sig).unwrap().len() <= 100);
    }

    #[test]
    fn defaults_return_none_for_small_files() {
        assert!(default_for("small.rs", 200).is_none());
    }

    #[test]
    fn defaults_recommend_map_for_medium_code() {
        assert_eq!(default_for("medium.rs", 3000), Some("map".to_string()));
    }

    #[test]
    fn defaults_recommend_aggressive_for_json() {
        assert_eq!(
            default_for("config.json", 1000),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn defaults_recommend_signatures_for_huge_code() {
        assert_eq!(
            default_for("huge.ts", 25000),
            Some("signatures".to_string())
        );
    }

    #[test]
    fn defaults_recommend_aggressive_for_large_unknown() {
        assert_eq!(
            default_for("data.xyz", 8000),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn mode_outcome_efficiency() {
        let sample = outcome("map", 1000, 200, 0.6);
        assert!(sample.efficiency() > 0.0);
    }
}