// lean_ctx/core/mode_predictor.rs

use std::collections::HashMap;

/// File name of the persisted mode statistics; it is joined onto the
/// `.lean-ctx` directory under the user's home directory.
const STATS_FILE: &str = "mode_stats.json";
4
5#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
6pub struct ModeOutcome {
7    pub mode: String,
8    pub tokens_in: usize,
9    pub tokens_out: usize,
10    pub density: f64,
11}
12
13impl ModeOutcome {
14    pub fn efficiency(&self) -> f64 {
15        if self.tokens_out == 0 {
16            return 0.0;
17        }
18        self.density / (self.tokens_out as f64 / self.tokens_in.max(1) as f64)
19    }
20}
21
22#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
23pub struct FileSignature {
24    pub ext: String,
25    pub size_bucket: u8,
26}
27
28impl FileSignature {
29    pub fn from_path(path: &str, token_count: usize) -> Self {
30        let ext = std::path::Path::new(path)
31            .extension()
32            .and_then(|e| e.to_str())
33            .unwrap_or("")
34            .to_string();
35        let size_bucket = match token_count {
36            0..=500 => 0,
37            501..=2000 => 1,
38            2001..=5000 => 2,
39            5001..=20000 => 3,
40            _ => 4,
41        };
42        Self { ext, size_bucket }
43    }
44}
45
46#[derive(Debug, Default, serde::Serialize, serde::Deserialize)]
47pub struct ModePredictor {
48    history: HashMap<FileSignature, Vec<ModeOutcome>>,
49}
50
51impl ModePredictor {
52    pub fn new() -> Self {
53        Self::load().unwrap_or_default()
54    }
55
56    pub fn record(&mut self, sig: FileSignature, outcome: ModeOutcome) {
57        let entries = self.history.entry(sig).or_default();
58        entries.push(outcome);
59        if entries.len() > 100 {
60            entries.drain(0..50);
61        }
62    }
63
64    /// Returns the best mode based on historical efficiency.
65    /// Chain: local history -> Pro adaptive models -> built-in defaults.
66    pub fn predict_best_mode(&self, sig: &FileSignature) -> Option<String> {
67        if let Some(local) = self.predict_from_local(sig) {
68            return Some(local);
69        }
70        if let Some(pro) = self.predict_from_pro(sig) {
71            return Some(pro);
72        }
73        Self::predict_from_defaults(sig)
74    }
75
76    fn predict_from_local(&self, sig: &FileSignature) -> Option<String> {
77        let entries = self.history.get(sig)?;
78        if entries.len() < 3 {
79            return None;
80        }
81
82        let mut mode_scores: HashMap<&str, (f64, usize)> = HashMap::new();
83        for entry in entries {
84            let (sum, count) = mode_scores.entry(&entry.mode).or_insert((0.0, 0));
85            *sum += entry.efficiency();
86            *count += 1;
87        }
88
89        mode_scores
90            .into_iter()
91            .max_by(|a, b| {
92                let avg_a = a.1 .0 / a.1 .1 as f64;
93                let avg_b = b.1 .0 / b.1 .1 as f64;
94                avg_a
95                    .partial_cmp(&avg_b)
96                    .unwrap_or(std::cmp::Ordering::Equal)
97            })
98            .map(|(mode, _)| mode.to_string())
99    }
100
101    /// Loads Pro adaptive models (requires Pro subscription).
102    /// Pro models are cached locally and auto-updated for Pro users.
103    fn predict_from_pro(&self, sig: &FileSignature) -> Option<String> {
104        let data = crate::cloud_client::load_pro_models()?;
105        let models = data["models"].as_array()?;
106
107        let ext_with_dot = format!(".{}", sig.ext);
108        let bucket_name = match sig.size_bucket {
109            0 => "0-500",
110            1 => "500-2k",
111            2 => "2k-10k",
112            3 => "10k+",
113            _ => "10k+",
114        };
115
116        let mut best: Option<(&str, f64)> = None;
117
118        for model in models {
119            let m_ext = model["file_ext"].as_str().unwrap_or("");
120            let m_bucket = model["size_bucket"].as_str().unwrap_or("");
121            let confidence = model["confidence"].as_f64().unwrap_or(0.0);
122
123            if m_ext == ext_with_dot && m_bucket == bucket_name && confidence > 0.5 {
124                if let Some(mode) = model["recommended_mode"].as_str() {
125                    if best.is_none() || confidence > best.unwrap().1 {
126                        best = Some((mode, confidence));
127                    }
128                }
129            }
130        }
131
132        if let Some((mode, _)) = best {
133            return Some(mode.to_string());
134        }
135
136        for model in models {
137            let m_ext = model["file_ext"].as_str().unwrap_or("");
138            let confidence = model["confidence"].as_f64().unwrap_or(0.0);
139            if m_ext == ext_with_dot && confidence > 0.5 {
140                return model["recommended_mode"].as_str().map(|s| s.to_string());
141            }
142        }
143
144        None
145    }
146
147    /// Built-in defaults for common file types and sizes.
148    /// Ensures reasonable compression even without local history or Pro models.
149    /// Respects Kolmogorov-Gate: files with K>0.7 skip aggressive modes.
150    fn predict_from_defaults(sig: &FileSignature) -> Option<String> {
151        let mode = match (sig.ext.as_str(), sig.size_bucket) {
152            // Tiny files (0-500 tokens): always full — compression overhead not worth it
153            (_, 0) => return None,
154
155            // Config / data files: aggressive strips comments and whitespace
156            ("json" | "yaml" | "yml" | "toml" | "xml" | "csv", _) => "aggressive",
157
158            // Lock files: signatures only (just versions matter)
159            ("lock", _) => "signatures",
160
161            // Code files by size bucket
162            // 500-2k tokens: full is fine
163            (
164                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
165                | "swift" | "kt" | "cs" | "vue" | "svelte",
166                1,
167            ) => return None,
168
169            // 2k-5k tokens: map gives structure without bloat
170            (
171                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
172                | "swift" | "kt" | "cs" | "vue" | "svelte",
173                2,
174            ) => "map",
175
176            // 5k-20k tokens: map is strongly preferred
177            (
178                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
179                | "swift" | "kt" | "cs" | "vue" | "svelte",
180                3,
181            ) => "map",
182
183            // 20k+ tokens: signatures only — too large for full context
184            (
185                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
186                | "swift" | "kt" | "cs" | "vue" | "svelte",
187                4..,
188            ) => "signatures",
189
190            // Markup / docs: aggressive for large, map for medium
191            ("md" | "mdx" | "rst" | "txt" | "html" | "astro", 1..=2) => return None,
192            ("md" | "mdx" | "rst" | "txt" | "html" | "astro", 3..) => "aggressive",
193
194            // CSS / styles: aggressive strips whitespace well
195            ("css" | "scss" | "less" | "sass", 2..) => "aggressive",
196
197            // SQL: map for medium+
198            ("sql", 2..) => "map",
199
200            // Unknown large files: aggressive as safe fallback
201            (_, 3..) => "aggressive",
202
203            _ => return None,
204        };
205        Some(mode.to_string())
206    }
207
208    pub fn save(&self) {
209        let dir = match dirs::home_dir() {
210            Some(d) => d.join(".lean-ctx"),
211            None => return,
212        };
213        let _ = std::fs::create_dir_all(&dir);
214        let path = dir.join(STATS_FILE);
215        if let Ok(json) = serde_json::to_string_pretty(self) {
216            let _ = std::fs::write(path, json);
217        }
218    }
219
220    fn load() -> Option<Self> {
221        let path = dirs::home_dir()?.join(".lean-ctx").join(STATS_FILE);
222        let data = std::fs::read_to_string(path).ok()?;
223        serde_json::from_str(&data).ok()
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ModeOutcome` fixture.
    fn outcome(mode: &str, tokens_in: usize, tokens_out: usize, density: f64) -> ModeOutcome {
        ModeOutcome {
            mode: mode.to_string(),
            tokens_in,
            tokens_out,
            density,
        }
    }

    /// Runs the built-in defaults for a path / token-count pair.
    fn default_for(path: &str, token_count: usize) -> Option<String> {
        let sig = FileSignature::from_path(path, token_count);
        ModePredictor::predict_from_defaults(&sig)
    }

    #[test]
    fn file_signature_buckets() {
        let expectations: [(usize, u8); 5] =
            [(100, 0), (1000, 1), (3000, 2), (10000, 3), (50000, 4)];
        for &(token_count, bucket) in expectations.iter() {
            assert_eq!(
                FileSignature::from_path("main.rs", token_count).size_bucket,
                bucket
            );
        }
    }

    #[test]
    fn predict_returns_none_without_history() {
        let sig = FileSignature::from_path("test.zzz", 500);
        let predictor = ModePredictor::default();
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_returns_none_with_too_few_entries() {
        let sig = FileSignature::from_path("test.zzz", 500);
        let mut predictor = ModePredictor::default();
        predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_learns_best_mode() {
        let sig = FileSignature::from_path("big.rs", 5000);
        let mut predictor = ModePredictor::default();
        for _ in 0..5 {
            predictor.record(sig.clone(), outcome("full", 5000, 5000, 0.3));
            predictor.record(sig.clone(), outcome("map", 5000, 800, 0.6));
        }
        assert_eq!(predictor.predict_best_mode(&sig), Some("map".to_string()));
    }

    #[test]
    fn history_caps_at_100() {
        let sig = FileSignature::from_path("test.rs", 100);
        let mut predictor = ModePredictor::default();
        for _ in 0..120 {
            predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        }
        let kept = predictor.history.get(&sig).unwrap().len();
        assert!(kept <= 100);
    }

    #[test]
    fn defaults_return_none_for_small_files() {
        assert!(default_for("small.rs", 200).is_none());
    }

    #[test]
    fn defaults_recommend_map_for_medium_code() {
        assert_eq!(default_for("medium.rs", 3000), Some("map".to_string()));
    }

    #[test]
    fn defaults_recommend_aggressive_for_json() {
        assert_eq!(
            default_for("config.json", 1000),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn defaults_recommend_signatures_for_huge_code() {
        assert_eq!(
            default_for("huge.ts", 25000),
            Some("signatures".to_string())
        );
    }

    #[test]
    fn defaults_recommend_aggressive_for_large_unknown() {
        assert_eq!(
            default_for("data.xyz", 8000),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn mode_outcome_efficiency() {
        let sample = outcome("map", 1000, 200, 0.6);
        assert!(sample.efficiency() > 0.0);
    }
}