Skip to main content

lean_ctx/core/
mode_predictor.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::Instant;
4
5const STATS_FILE: &str = "mode_stats.json";
6const PREDICTOR_FLUSH_SECS: u64 = 10;
7
8static PREDICTOR_BUFFER: Mutex<Option<(ModePredictor, Instant)>> = Mutex::new(None);
9
10/// Observed outcome of a read mode: tokens in/out and information density.
11#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
12pub struct ModeOutcome {
13    pub mode: String,
14    pub tokens_in: usize,
15    pub tokens_out: usize,
16    pub density: f64,
17}
18
19impl ModeOutcome {
20    /// Computes an efficiency score: density / compression ratio.
21    pub fn efficiency(&self) -> f64 {
22        if self.tokens_out == 0 {
23            return 0.0;
24        }
25        self.density / (self.tokens_out as f64 / self.tokens_in.max(1) as f64)
26    }
27}
28
29/// File identity for mode prediction: extension + token-count size bucket.
30#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
31pub struct FileSignature {
32    pub ext: String,
33    pub size_bucket: u8,
34}
35
36impl FileSignature {
37    /// Creates a file signature from its path and token count.
38    pub fn from_path(path: &str, token_count: usize) -> Self {
39        let ext = std::path::Path::new(path)
40            .extension()
41            .and_then(|e| e.to_str())
42            .unwrap_or("")
43            .to_string();
44        let size_bucket = match token_count {
45            0..=500 => 0,
46            501..=2000 => 1,
47            2001..=5000 => 2,
48            5001..=20000 => 3,
49            _ => 4,
50        };
51        Self { ext, size_bucket }
52    }
53}
54
55/// Learns the best read mode per file signature from historical outcomes.
56#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
57pub struct ModePredictor {
58    history: HashMap<FileSignature, Vec<ModeOutcome>>,
59    project_root: Option<String>,
60}
61
62impl ModePredictor {
63    /// Loads or creates the predictor, using an in-memory buffer for caching.
64    pub fn new() -> Self {
65        let mut guard = PREDICTOR_BUFFER
66            .lock()
67            .unwrap_or_else(std::sync::PoisonError::into_inner);
68        if let Some((ref predictor, _)) = *guard {
69            return predictor.clone();
70        }
71        let mut loaded = Self::load_from_disk().unwrap_or_default();
72        if loaded.project_root.is_none() {
73            loaded.project_root = std::env::current_dir()
74                .ok()
75                .map(|p| p.to_string_lossy().to_string());
76        }
77        *guard = Some((loaded.clone(), Instant::now()));
78        loaded
79    }
80
81    pub fn with_project_root(mut self, root: &str) -> Self {
82        self.project_root = Some(root.to_string());
83        self
84    }
85
86    pub fn set_project_root(&mut self, root: &str) {
87        self.project_root = Some(root.to_string());
88    }
89
90    /// Records a mode outcome for a file signature (capped at 100 entries).
91    pub fn record(&mut self, sig: FileSignature, outcome: ModeOutcome) {
92        let entries = self.history.entry(sig).or_default();
93        entries.push(outcome);
94        if entries.len() > 100 {
95            entries.drain(0..50);
96        }
97    }
98
99    /// Returns the best mode based on historical efficiency.
100    /// Chain: local history -> cloud adaptive models -> built-in defaults.
101    pub fn predict_best_mode(&self, sig: &FileSignature) -> Option<String> {
102        if let Some(local) = self.predict_from_local(sig) {
103            return Some(local);
104        }
105        if let Some(bandit) = self.predict_from_bandit(sig) {
106            return Some(bandit);
107        }
108        if let Some(cloud) = self.predict_from_cloud(sig) {
109            return Some(cloud);
110        }
111        Self::predict_from_defaults(sig)
112    }
113
114    fn predict_from_bandit(&self, sig: &FileSignature) -> Option<String> {
115        let key = format!("{}_feedback", sig.ext);
116        let store =
117            crate::core::bandit::BanditStore::load(self.project_root.as_deref().unwrap_or("."));
118        let bandit = store.bandits.get(&key)?;
119        if bandit.total_pulls < 5 {
120            return None;
121        }
122        let best_arm = bandit.arms.iter().max_by(|a, b| {
123            a.mean()
124                .partial_cmp(&b.mean())
125                .unwrap_or(std::cmp::Ordering::Equal)
126        })?;
127        let mode = match best_arm.name.as_str() {
128            "conservative" => "full",
129            "balanced" => "signatures",
130            "aggressive" => "aggressive",
131            _ => return None,
132        };
133        Some(mode.to_string())
134    }
135
136    fn predict_from_local(&self, sig: &FileSignature) -> Option<String> {
137        let entries = self.history.get(sig)?;
138        if entries.len() < 3 {
139            return None;
140        }
141
142        let mut mode_scores: HashMap<&str, (f64, usize)> = HashMap::new();
143        for entry in entries {
144            let (sum, count) = mode_scores.entry(&entry.mode).or_insert((0.0, 0));
145            *sum += entry.efficiency();
146            *count += 1;
147        }
148
149        mode_scores
150            .into_iter()
151            .max_by(|a, b| {
152                let avg_a = a.1 .0 / a.1 .1 as f64;
153                let avg_b = b.1 .0 / b.1 .1 as f64;
154                avg_a
155                    .partial_cmp(&avg_b)
156                    .unwrap_or(std::cmp::Ordering::Equal)
157            })
158            .map(|(mode, _)| mode.to_string())
159    }
160
161    /// Loads cloud adaptive models (synced from LeanCTX Cloud).
162    /// Models are cached locally and auto-updated for cloud users.
163    #[allow(clippy::unused_self)]
164    fn predict_from_cloud(&self, sig: &FileSignature) -> Option<String> {
165        let data = crate::cloud_client::load_cloud_models()?;
166        let models = data["models"].as_array()?;
167
168        let ext_with_dot = format!(".{}", sig.ext);
169        let bucket_name = match sig.size_bucket {
170            0 => "0-500",
171            1 => "500-2k",
172            2 => "2k-10k",
173            _ => "10k+",
174        };
175
176        let mut best: Option<(&str, f64)> = None;
177
178        for model in models {
179            let m_ext = model["file_ext"].as_str().unwrap_or("");
180            let m_bucket = model["size_bucket"].as_str().unwrap_or("");
181            let confidence = model["confidence"].as_f64().unwrap_or(0.0);
182
183            if m_ext == ext_with_dot && m_bucket == bucket_name && confidence > 0.5 {
184                if let Some(mode) = model["recommended_mode"].as_str() {
185                    if best.is_none_or(|(_, c)| confidence > c) {
186                        best = Some((mode, confidence));
187                    }
188                }
189            }
190        }
191
192        if let Some((mode, _)) = best {
193            return Some(mode.to_string());
194        }
195
196        for model in models {
197            let m_ext = model["file_ext"].as_str().unwrap_or("");
198            let confidence = model["confidence"].as_f64().unwrap_or(0.0);
199            if m_ext == ext_with_dot && confidence > 0.5 {
200                return model["recommended_mode"]
201                    .as_str()
202                    .map(std::string::ToString::to_string);
203            }
204        }
205
206        None
207    }
208
209    /// Built-in defaults for common file types and sizes.
210    /// Ensures reasonable compression even without local history or cloud models.
211    /// Respects Kolmogorov-Gate: files with K>0.7 skip aggressive modes.
212    fn predict_from_defaults(sig: &FileSignature) -> Option<String> {
213        let mode = match (sig.ext.as_str(), sig.size_bucket) {
214            // Tiny files (0-500 tokens): always full — compression overhead not worth it
215            (_, 0) => return None,
216
217            // Lock / large code files: signatures only
218            ("lock", _)
219            | (
220                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
221                | "swift" | "kt" | "cs" | "vue" | "svelte",
222                4..,
223            ) => "signatures",
224
225            // Code files 2k-10k / SQL: map gives structure without bloat
226            (
227                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
228                | "swift" | "kt" | "cs" | "vue" | "svelte",
229                2 | 3,
230            )
231            | ("sql", 2..) => "map",
232
233            // Config/data, CSS, and large unknown files: aggressive
234            ("json" | "yaml" | "yml" | "toml" | "xml" | "csv", _)
235            | ("css" | "scss" | "less" | "sass", 2..)
236            | (_, 3..) => "aggressive",
237
238            _ => return None,
239        };
240        Some(mode.to_string())
241    }
242
243    /// Saves to the in-memory buffer and flushes to disk if the interval elapsed.
244    pub fn save(&self) {
245        let mut guard = PREDICTOR_BUFFER
246            .lock()
247            .unwrap_or_else(std::sync::PoisonError::into_inner);
248        let should_flush = match *guard {
249            Some((_, ref last_flush)) => last_flush.elapsed().as_secs() >= PREDICTOR_FLUSH_SECS,
250            None => true,
251        };
252        *guard = Some((self.clone(), Instant::now()));
253        if should_flush {
254            self.save_to_disk();
255        }
256    }
257
258    fn save_to_disk(&self) {
259        let Ok(dir) = crate::core::data_dir::lean_ctx_data_dir() else {
260            return;
261        };
262        let _ = std::fs::create_dir_all(&dir);
263        let path = dir.join(STATS_FILE);
264        if let Ok(json) = serde_json::to_string_pretty(self) {
265            let tmp = dir.join(".mode_stats.tmp");
266            if std::fs::write(&tmp, &json).is_ok() {
267                let _ = std::fs::rename(&tmp, &path);
268            }
269        }
270    }
271
272    /// Forces an immediate write of the buffered predictor state to disk.
273    pub fn flush() {
274        let guard = PREDICTOR_BUFFER
275            .lock()
276            .unwrap_or_else(std::sync::PoisonError::into_inner);
277        if let Some((ref predictor, _)) = *guard {
278            predictor.save_to_disk();
279        }
280    }
281
282    fn load_from_disk() -> Option<Self> {
283        let path = crate::core::data_dir::lean_ctx_data_dir()
284            .ok()?
285            .join(STATS_FILE);
286        let data = std::fs::read_to_string(path).ok()?;
287        serde_json::from_str(&data).ok()
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ModeOutcome`.
    fn outcome(mode: &str, tokens_in: usize, tokens_out: usize, density: f64) -> ModeOutcome {
        ModeOutcome {
            mode: mode.to_string(),
            tokens_in,
            tokens_out,
            density,
        }
    }

    /// Built-in default recommendation for a path and token count.
    fn default_mode(path: &str, tokens: usize) -> Option<String> {
        ModePredictor::predict_from_defaults(&FileSignature::from_path(path, tokens))
    }

    #[test]
    fn file_signature_buckets() {
        let cases: [(usize, u8); 5] = [(100, 0), (1000, 1), (3000, 2), (10000, 3), (50000, 4)];
        for (tokens, expected) in cases {
            assert_eq!(
                FileSignature::from_path("main.rs", tokens).size_bucket,
                expected
            );
        }
    }

    #[test]
    fn predict_returns_none_without_history() {
        let predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.zzz", 500);
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_returns_none_with_too_few_entries() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.zzz", 500);
        predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_learns_best_mode() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("big.rs", 5000);
        for _ in 0..5 {
            predictor.record(sig.clone(), outcome("full", 5000, 5000, 0.3));
            predictor.record(sig.clone(), outcome("map", 5000, 800, 0.6));
        }
        let best = predictor.predict_best_mode(&sig);
        assert_eq!(best, Some("map".to_string()));
    }

    #[test]
    fn history_caps_at_100() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.rs", 100);
        for _ in 0..120 {
            predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        }
        assert!(predictor.history.get(&sig).unwrap().len() <= 100);
    }

    #[test]
    fn defaults_return_none_for_small_files() {
        assert!(default_mode("small.rs", 200).is_none());
    }

    #[test]
    fn defaults_recommend_map_for_medium_code() {
        assert_eq!(default_mode("medium.rs", 3000), Some("map".to_string()));
    }

    #[test]
    fn defaults_recommend_aggressive_for_json() {
        assert_eq!(
            default_mode("config.json", 1000),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn defaults_recommend_signatures_for_huge_code() {
        assert_eq!(
            default_mode("huge.ts", 25000),
            Some("signatures".to_string())
        );
    }

    #[test]
    fn defaults_recommend_aggressive_for_large_unknown() {
        assert_eq!(
            default_mode("data.xyz", 8000),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn mode_outcome_efficiency() {
        let o = outcome("map", 1000, 200, 0.6);
        assert!(o.efficiency() > 0.0);
    }
}