Skip to main content

lean_ctx/core/
mode_predictor.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::Instant;
4
/// File name, joined onto the lean-ctx data directory, under which predictor
/// statistics are persisted.
const STATS_FILE: &str = "mode_stats.json";
/// Threshold in seconds compared against the buffered timestamp when `save`
/// decides whether to also write the predictor to disk.
const PREDICTOR_FLUSH_SECS: u64 = 10;

/// Process-wide buffer: the most recently loaded or saved predictor together
/// with the `Instant` that `save` consults to throttle disk writes.
/// `None` until the first `ModePredictor::new` or `ModePredictor::save` call.
static PREDICTOR_BUFFER: Mutex<Option<(ModePredictor, Instant)>> = Mutex::new(None);
9
/// Observed outcome of a read mode: tokens in/out and information density.
///
/// One record is stored per observation; `efficiency` turns a record into
/// a single comparable score.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ModeOutcome {
    /// Name of the read mode that produced this outcome (e.g. "full", "map").
    pub mode: String,
    /// Input token count; the denominator of the compression ratio used by
    /// `efficiency` (clamped to at least 1 there).
    pub tokens_in: usize,
    /// Output token count; the numerator of the compression ratio. A value of
    /// zero makes `efficiency` return 0.0.
    pub tokens_out: usize,
    /// Information density of the output; the numerator of the efficiency
    /// score. NOTE(review): the expected value range is not visible here.
    pub density: f64,
}
18
19impl ModeOutcome {
20    /// Computes an efficiency score: density / compression ratio.
21    pub fn efficiency(&self) -> f64 {
22        if self.tokens_out == 0 {
23            return 0.0;
24        }
25        self.density / (self.tokens_out as f64 / self.tokens_in.max(1) as f64)
26    }
27}
28
/// File identity for mode prediction: extension + token-count size bucket.
///
/// Used as the key of the predictor's history map, hence the `Hash`/`Eq`
/// derives.
#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct FileSignature {
    /// File extension without the leading dot; empty when the path has no
    /// (UTF-8) extension.
    pub ext: String,
    /// Size class derived from the token count by `from_path`:
    /// 0 = 0..=500, 1 = 501..=2000, 2 = 2001..=5000, 3 = 5001..=20000,
    /// 4 = anything larger.
    pub size_bucket: u8,
}
35
36impl FileSignature {
37    /// Creates a file signature from its path and token count.
38    pub fn from_path(path: &str, token_count: usize) -> Self {
39        let ext = std::path::Path::new(path)
40            .extension()
41            .and_then(|e| e.to_str())
42            .unwrap_or("")
43            .to_string();
44        let size_bucket = match token_count {
45            0..=500 => 0,
46            501..=2000 => 1,
47            2001..=5000 => 2,
48            5001..=20000 => 3,
49            _ => 4,
50        };
51        Self { ext, size_bucket }
52    }
53}
54
/// Learns the best read mode per file signature from historical outcomes.
///
/// NOTE: `history` is keyed by a struct. `serde_json` only accepts
/// string-like map keys, so serialising a non-empty predictor to JSON through
/// the derived `Serialize` impl returns an error.
#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModePredictor {
    /// Recorded outcomes per file signature, oldest first (new outcomes are
    /// pushed to the end by `record`).
    history: HashMap<FileSignature, Vec<ModeOutcome>>,
}
60
61impl ModePredictor {
62    /// Loads or creates the predictor, using an in-memory buffer for caching.
63    pub fn new() -> Self {
64        let mut guard = PREDICTOR_BUFFER
65            .lock()
66            .unwrap_or_else(std::sync::PoisonError::into_inner);
67        if let Some((ref predictor, _)) = *guard {
68            return predictor.clone();
69        }
70        let loaded = Self::load_from_disk().unwrap_or_default();
71        *guard = Some((loaded.clone(), Instant::now()));
72        loaded
73    }
74
75    /// Records a mode outcome for a file signature (capped at 100 entries).
76    pub fn record(&mut self, sig: FileSignature, outcome: ModeOutcome) {
77        let entries = self.history.entry(sig).or_default();
78        entries.push(outcome);
79        if entries.len() > 100 {
80            entries.drain(0..50);
81        }
82    }
83
84    /// Returns the best mode based on historical efficiency.
85    /// Chain: local history -> cloud adaptive models -> built-in defaults.
86    pub fn predict_best_mode(&self, sig: &FileSignature) -> Option<String> {
87        if let Some(local) = self.predict_from_local(sig) {
88            return Some(local);
89        }
90        if let Some(cloud) = self.predict_from_cloud(sig) {
91            return Some(cloud);
92        }
93        Self::predict_from_defaults(sig)
94    }
95
96    fn predict_from_local(&self, sig: &FileSignature) -> Option<String> {
97        let entries = self.history.get(sig)?;
98        if entries.len() < 3 {
99            return None;
100        }
101
102        let mut mode_scores: HashMap<&str, (f64, usize)> = HashMap::new();
103        for entry in entries {
104            let (sum, count) = mode_scores.entry(&entry.mode).or_insert((0.0, 0));
105            *sum += entry.efficiency();
106            *count += 1;
107        }
108
109        mode_scores
110            .into_iter()
111            .max_by(|a, b| {
112                let avg_a = a.1 .0 / a.1 .1 as f64;
113                let avg_b = b.1 .0 / b.1 .1 as f64;
114                avg_a
115                    .partial_cmp(&avg_b)
116                    .unwrap_or(std::cmp::Ordering::Equal)
117            })
118            .map(|(mode, _)| mode.to_string())
119    }
120
121    /// Loads cloud adaptive models (synced from LeanCTX Cloud).
122    /// Models are cached locally and auto-updated for cloud users.
123    #[allow(clippy::unused_self)]
124    fn predict_from_cloud(&self, sig: &FileSignature) -> Option<String> {
125        let data = crate::cloud_client::load_cloud_models()?;
126        let models = data["models"].as_array()?;
127
128        let ext_with_dot = format!(".{}", sig.ext);
129        let bucket_name = match sig.size_bucket {
130            0 => "0-500",
131            1 => "500-2k",
132            2 => "2k-10k",
133            _ => "10k+",
134        };
135
136        let mut best: Option<(&str, f64)> = None;
137
138        for model in models {
139            let m_ext = model["file_ext"].as_str().unwrap_or("");
140            let m_bucket = model["size_bucket"].as_str().unwrap_or("");
141            let confidence = model["confidence"].as_f64().unwrap_or(0.0);
142
143            if m_ext == ext_with_dot && m_bucket == bucket_name && confidence > 0.5 {
144                if let Some(mode) = model["recommended_mode"].as_str() {
145                    if best.is_none_or(|(_, c)| confidence > c) {
146                        best = Some((mode, confidence));
147                    }
148                }
149            }
150        }
151
152        if let Some((mode, _)) = best {
153            return Some(mode.to_string());
154        }
155
156        for model in models {
157            let m_ext = model["file_ext"].as_str().unwrap_or("");
158            let confidence = model["confidence"].as_f64().unwrap_or(0.0);
159            if m_ext == ext_with_dot && confidence > 0.5 {
160                return model["recommended_mode"]
161                    .as_str()
162                    .map(std::string::ToString::to_string);
163            }
164        }
165
166        None
167    }
168
169    /// Built-in defaults for common file types and sizes.
170    /// Ensures reasonable compression even without local history or cloud models.
171    /// Respects Kolmogorov-Gate: files with K>0.7 skip aggressive modes.
172    fn predict_from_defaults(sig: &FileSignature) -> Option<String> {
173        let mode = match (sig.ext.as_str(), sig.size_bucket) {
174            // Tiny files (0-500 tokens): always full — compression overhead not worth it
175            (_, 0) => return None,
176
177            // Lock / large code files: signatures only
178            ("lock", _)
179            | (
180                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
181                | "swift" | "kt" | "cs" | "vue" | "svelte",
182                4..,
183            ) => "signatures",
184
185            // Code files 2k-10k / SQL: map gives structure without bloat
186            (
187                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
188                | "swift" | "kt" | "cs" | "vue" | "svelte",
189                2 | 3,
190            )
191            | ("sql", 2..) => "map",
192
193            // Config/data, CSS, and large unknown files: aggressive
194            ("json" | "yaml" | "yml" | "toml" | "xml" | "csv", _)
195            | ("css" | "scss" | "less" | "sass", 2..)
196            | (_, 3..) => "aggressive",
197
198            _ => return None,
199        };
200        Some(mode.to_string())
201    }
202
203    /// Saves to the in-memory buffer and flushes to disk if the interval elapsed.
204    pub fn save(&self) {
205        let mut guard = PREDICTOR_BUFFER
206            .lock()
207            .unwrap_or_else(std::sync::PoisonError::into_inner);
208        let should_flush = match *guard {
209            Some((_, ref last_flush)) => last_flush.elapsed().as_secs() >= PREDICTOR_FLUSH_SECS,
210            None => true,
211        };
212        *guard = Some((self.clone(), Instant::now()));
213        if should_flush {
214            self.save_to_disk();
215        }
216    }
217
218    fn save_to_disk(&self) {
219        let Ok(dir) = crate::core::data_dir::lean_ctx_data_dir() else {
220            return;
221        };
222        let _ = std::fs::create_dir_all(&dir);
223        let path = dir.join(STATS_FILE);
224        if let Ok(json) = serde_json::to_string_pretty(self) {
225            let tmp = dir.join(".mode_stats.tmp");
226            if std::fs::write(&tmp, &json).is_ok() {
227                let _ = std::fs::rename(&tmp, &path);
228            }
229        }
230    }
231
232    /// Forces an immediate write of the buffered predictor state to disk.
233    pub fn flush() {
234        let guard = PREDICTOR_BUFFER
235            .lock()
236            .unwrap_or_else(std::sync::PoisonError::into_inner);
237        if let Some((ref predictor, _)) = *guard {
238            predictor.save_to_disk();
239        }
240    }
241
242    fn load_from_disk() -> Option<Self> {
243        let path = crate::core::data_dir::lean_ctx_data_dir()
244            .ok()?
245            .join(STATS_FILE);
246        let data = std::fs::read_to_string(path).ok()?;
247        serde_json::from_str(&data).ok()
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ModeOutcome` fixture without the struct-literal noise.
    fn outcome(mode: &str, tokens_in: usize, tokens_out: usize, density: f64) -> ModeOutcome {
        ModeOutcome {
            mode: mode.to_string(),
            tokens_in,
            tokens_out,
            density,
        }
    }

    /// Runs the built-in defaults for a path / token-count pair.
    fn default_mode(path: &str, token_count: usize) -> Option<String> {
        ModePredictor::predict_from_defaults(&FileSignature::from_path(path, token_count))
    }

    #[test]
    fn file_signature_buckets() {
        let expectations = [(100, 0), (1000, 1), (3000, 2), (10000, 3), (50000, 4)];
        for (token_count, bucket) in expectations {
            assert_eq!(
                FileSignature::from_path("main.rs", token_count).size_bucket,
                bucket
            );
        }
    }

    #[test]
    fn predict_returns_none_without_history() {
        let sig = FileSignature::from_path("test.zzz", 500);
        let predictor = ModePredictor::default();
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_returns_none_with_too_few_entries() {
        let sig = FileSignature::from_path("test.zzz", 500);
        let mut predictor = ModePredictor::default();
        predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_learns_best_mode() {
        let sig = FileSignature::from_path("big.rs", 5000);
        let mut predictor = ModePredictor::default();
        for _ in 0..5 {
            predictor.record(sig.clone(), outcome("full", 5000, 5000, 0.3));
            predictor.record(sig.clone(), outcome("map", 5000, 800, 0.6));
        }
        assert_eq!(predictor.predict_best_mode(&sig), Some("map".to_string()));
    }

    #[test]
    fn history_caps_at_100() {
        let sig = FileSignature::from_path("test.rs", 100);
        let mut predictor = ModePredictor::default();
        for _ in 0..120 {
            predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        }
        assert!(predictor.history.get(&sig).unwrap().len() <= 100);
    }

    #[test]
    fn defaults_return_none_for_small_files() {
        assert!(default_mode("small.rs", 200).is_none());
    }

    #[test]
    fn defaults_recommend_map_for_medium_code() {
        assert_eq!(default_mode("medium.rs", 3000), Some("map".to_string()));
    }

    #[test]
    fn defaults_recommend_aggressive_for_json() {
        assert_eq!(
            default_mode("config.json", 1000),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn defaults_recommend_signatures_for_huge_code() {
        assert_eq!(
            default_mode("huge.ts", 25000),
            Some("signatures".to_string())
        );
    }

    #[test]
    fn defaults_recommend_aggressive_for_large_unknown() {
        assert_eq!(
            default_mode("data.xyz", 8000),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn mode_outcome_efficiency() {
        let sample = outcome("map", 1000, 200, 0.6);
        assert!(sample.efficiency() > 0.0);
    }
}
385}