Skip to main content

lean_ctx/core/
mode_predictor.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::Instant;
4
5const STATS_FILE: &str = "mode_stats.json";
6const PREDICTOR_FLUSH_SECS: u64 = 10;
7
8static PREDICTOR_BUFFER: Mutex<Option<(ModePredictor, Instant)>> = Mutex::new(None);
9
10/// Observed outcome of a read mode: tokens in/out and information density.
11#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
12pub struct ModeOutcome {
13    pub mode: String,
14    pub tokens_in: usize,
15    pub tokens_out: usize,
16    pub density: f64,
17}
18
19impl ModeOutcome {
20    /// Computes an efficiency score: density / compression ratio.
21    pub fn efficiency(&self) -> f64 {
22        if self.tokens_out == 0 {
23            return 0.0;
24        }
25        self.density / (self.tokens_out as f64 / self.tokens_in.max(1) as f64)
26    }
27}
28
29/// File identity for mode prediction: extension + token-count size bucket.
30#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
31pub struct FileSignature {
32    pub ext: String,
33    pub size_bucket: u8,
34}
35
36impl FileSignature {
37    /// Creates a file signature from its path and token count.
38    pub fn from_path(path: &str, token_count: usize) -> Self {
39        let ext = std::path::Path::new(path)
40            .extension()
41            .and_then(|e| e.to_str())
42            .unwrap_or("")
43            .to_string();
44        let size_bucket = match token_count {
45            0..=500 => 0,
46            501..=2000 => 1,
47            2001..=5000 => 2,
48            5001..=20000 => 3,
49            _ => 4,
50        };
51        Self { ext, size_bucket }
52    }
53}
54
55/// Learns the best read mode per file signature from historical outcomes.
56#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
57pub struct ModePredictor {
58    history: HashMap<FileSignature, Vec<ModeOutcome>>,
59    project_root: Option<String>,
60}
61
62impl ModePredictor {
63    /// Loads or creates the predictor, using an in-memory buffer for caching.
64    pub fn new() -> Self {
65        let mut guard = PREDICTOR_BUFFER
66            .lock()
67            .unwrap_or_else(std::sync::PoisonError::into_inner);
68        if let Some((ref predictor, _)) = *guard {
69            return predictor.clone();
70        }
71        let mut loaded = Self::load_from_disk().unwrap_or_default();
72        if loaded.project_root.is_none() {
73            loaded.project_root = std::env::current_dir()
74                .ok()
75                .map(|p| p.to_string_lossy().to_string());
76        }
77        *guard = Some((loaded.clone(), Instant::now()));
78        loaded
79    }
80
81    pub fn with_project_root(mut self, root: &str) -> Self {
82        self.project_root = Some(root.to_string());
83        self
84    }
85
86    pub fn set_project_root(&mut self, root: &str) {
87        self.project_root = Some(root.to_string());
88    }
89
90    /// Records a mode outcome for a file signature (capped at 100 entries).
91    pub fn record(&mut self, sig: FileSignature, outcome: ModeOutcome) {
92        let entries = self.history.entry(sig).or_default();
93        entries.push(outcome);
94        if entries.len() > 100 {
95            entries.drain(0..50);
96        }
97    }
98
99    /// Returns the best mode based on historical efficiency.
100    /// Chain: local history -> cloud adaptive models -> built-in defaults.
101    pub fn predict_best_mode(&self, sig: &FileSignature) -> Option<String> {
102        let default_mode = Self::predict_from_defaults(sig);
103
104        let allow_override = |candidate: &str| -> bool {
105            let Some(def) = default_mode.as_deref() else {
106                return true;
107            };
108            if candidate == "full" {
109                return false;
110            }
111            // For code-structured defaults, never override to lossy modes.
112            if (def == "map" || def == "signatures")
113                && (candidate == "aggressive" || candidate == "entropy")
114            {
115                return false;
116            }
117            true
118        };
119
120        if let Some(local) = self.predict_from_local(sig) {
121            if allow_override(&local) {
122                return Some(local);
123            }
124        }
125        if let Some(bandit) = self.predict_from_bandit(sig) {
126            if allow_override(&bandit) {
127                return Some(bandit);
128            }
129        }
130        if let Some(cloud) = self.predict_from_cloud(sig) {
131            if allow_override(&cloud) {
132                return Some(cloud);
133            }
134        }
135        default_mode
136    }
137
138    fn predict_from_bandit(&self, sig: &FileSignature) -> Option<String> {
139        let key = format!("{}_feedback", sig.ext);
140        let store =
141            crate::core::bandit::BanditStore::load(self.project_root.as_deref().unwrap_or("."));
142        let bandit = store.bandits.get(&key)?;
143        if bandit.total_pulls < 5 {
144            return None;
145        }
146        let best_arm = bandit.arms.iter().max_by(|a, b| {
147            a.mean()
148                .partial_cmp(&b.mean())
149                .unwrap_or(std::cmp::Ordering::Equal)
150        })?;
151        let mode = match best_arm.name.as_str() {
152            "conservative" => "full",
153            "balanced" => "signatures",
154            "aggressive" => "aggressive",
155            _ => return None,
156        };
157        Some(mode.to_string())
158    }
159
160    fn predict_from_local(&self, sig: &FileSignature) -> Option<String> {
161        let entries = self.history.get(sig)?;
162        if entries.len() < 3 {
163            return None;
164        }
165
166        let mut mode_scores: HashMap<&str, (f64, usize)> = HashMap::new();
167        for entry in entries {
168            let (sum, count) = mode_scores.entry(&entry.mode).or_insert((0.0, 0));
169            *sum += entry.efficiency();
170            *count += 1;
171        }
172
173        mode_scores
174            .into_iter()
175            .max_by(|a, b| {
176                let avg_a = a.1 .0 / a.1 .1 as f64;
177                let avg_b = b.1 .0 / b.1 .1 as f64;
178                avg_a
179                    .partial_cmp(&avg_b)
180                    .unwrap_or(std::cmp::Ordering::Equal)
181            })
182            .map(|(mode, _)| mode.to_string())
183    }
184
185    /// Loads cloud adaptive models (synced from LeanCTX Cloud).
186    /// Models are cached locally and auto-updated for cloud users.
187    #[allow(clippy::unused_self)]
188    fn predict_from_cloud(&self, sig: &FileSignature) -> Option<String> {
189        let data = crate::cloud_client::load_cloud_models()?;
190        let models = data["models"].as_array()?;
191
192        let ext_with_dot = format!(".{}", sig.ext);
193        let bucket_name = match sig.size_bucket {
194            0 => "0-500",
195            1 => "500-2k",
196            2 => "2k-10k",
197            _ => "10k+",
198        };
199
200        let mut best: Option<(&str, f64)> = None;
201
202        for model in models {
203            let m_ext = model["file_ext"].as_str().unwrap_or("");
204            let m_bucket = model["size_bucket"].as_str().unwrap_or("");
205            let confidence = model["confidence"].as_f64().unwrap_or(0.0);
206
207            if m_ext == ext_with_dot && m_bucket == bucket_name && confidence > 0.5 {
208                if let Some(mode) = model["recommended_mode"].as_str() {
209                    if best.is_none_or(|(_, c)| confidence > c) {
210                        best = Some((mode, confidence));
211                    }
212                }
213            }
214        }
215
216        if let Some((mode, _)) = best {
217            return Some(mode.to_string());
218        }
219
220        for model in models {
221            let m_ext = model["file_ext"].as_str().unwrap_or("");
222            let confidence = model["confidence"].as_f64().unwrap_or(0.0);
223            if m_ext == ext_with_dot && confidence > 0.5 {
224                return model["recommended_mode"]
225                    .as_str()
226                    .map(std::string::ToString::to_string);
227            }
228        }
229
230        None
231    }
232
233    /// Built-in defaults for common file types and sizes.
234    /// Ensures reasonable compression even without local history or cloud models.
235    /// Respects Kolmogorov-Gate: files with K>0.7 skip aggressive modes.
236    fn predict_from_defaults(sig: &FileSignature) -> Option<String> {
237        let mode = match (sig.ext.as_str(), sig.size_bucket) {
238            // Tiny files (0-500 tokens): always full — compression overhead not worth it
239            (_, 0) => return None,
240
241            // Lock / large code files: signatures only
242            ("lock", _)
243            | (
244                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
245                | "swift" | "kt" | "cs" | "vue" | "svelte",
246                4..,
247            ) => "signatures",
248
249            // Code files 2k-10k / SQL: map gives structure without bloat
250            (
251                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
252                | "swift" | "kt" | "cs" | "vue" | "svelte",
253                2 | 3,
254            )
255            | ("sql", 2..) => "map",
256
257            // Config/data, CSS, and large unknown files: aggressive
258            ("json" | "yaml" | "yml" | "toml" | "xml" | "csv", _)
259            | ("css" | "scss" | "less" | "sass", 2..)
260            | (_, 3..) => "aggressive",
261
262            _ => return None,
263        };
264        Some(mode.to_string())
265    }
266
267    /// Saves to the in-memory buffer and flushes to disk if the interval elapsed.
268    pub fn save(&self) {
269        let mut guard = PREDICTOR_BUFFER
270            .lock()
271            .unwrap_or_else(std::sync::PoisonError::into_inner);
272        let should_flush = match *guard {
273            Some((_, ref last_flush)) => last_flush.elapsed().as_secs() >= PREDICTOR_FLUSH_SECS,
274            None => true,
275        };
276        *guard = Some((self.clone(), Instant::now()));
277        if should_flush {
278            self.save_to_disk();
279        }
280    }
281
282    fn save_to_disk(&self) {
283        let Ok(dir) = crate::core::data_dir::lean_ctx_data_dir() else {
284            return;
285        };
286        let _ = std::fs::create_dir_all(&dir);
287        let path = dir.join(STATS_FILE);
288        if let Ok(json) = serde_json::to_string_pretty(self) {
289            let tmp = dir.join(".mode_stats.tmp");
290            if std::fs::write(&tmp, &json).is_ok() {
291                let _ = std::fs::rename(&tmp, &path);
292            }
293        }
294    }
295
296    /// Forces an immediate write of the buffered predictor state to disk.
297    pub fn flush() {
298        let guard = PREDICTOR_BUFFER
299            .lock()
300            .unwrap_or_else(std::sync::PoisonError::into_inner);
301        if let Some((ref predictor, _)) = *guard {
302            predictor.save_to_disk();
303        }
304    }
305
306    fn load_from_disk() -> Option<Self> {
307        let path = crate::core::data_dir::lean_ctx_data_dir()
308            .ok()?
309            .join(STATS_FILE);
310        let data = std::fs::read_to_string(path).ok()?;
311        serde_json::from_str(&data).ok()
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ModeOutcome` in tests.
    fn outcome(mode: &str, tokens_in: usize, tokens_out: usize, density: f64) -> ModeOutcome {
        ModeOutcome {
            mode: mode.to_string(),
            tokens_in,
            tokens_out,
            density,
        }
    }

    #[test]
    fn file_signature_buckets() {
        // (token_count, expected bucket), covering both edges of every bucket.
        let cases = [
            (0, 0),
            (100, 0),
            (500, 0),
            (501, 1),
            (1000, 1),
            (2000, 1),
            (2001, 2),
            (3000, 2),
            (5000, 2),
            (5001, 3),
            (10000, 3),
            (20000, 3),
            (20001, 4),
            (50000, 4),
        ];
        for (tokens, bucket) in cases {
            assert_eq!(
                FileSignature::from_path("main.rs", tokens).size_bucket,
                bucket,
                "token_count = {tokens}"
            );
        }
    }

    #[test]
    fn file_signature_extension() {
        assert_eq!(FileSignature::from_path("src/main.rs", 10).ext, "rs");
        assert_eq!(FileSignature::from_path("archive.tar.gz", 10).ext, "gz");
        assert_eq!(FileSignature::from_path("Makefile", 10).ext, "");
    }

    #[test]
    fn predict_returns_none_without_history() {
        let predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.zzz", 500);
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_returns_none_with_too_few_entries() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.zzz", 500);
        for _ in 0..2 {
            predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
            assert!(predictor.predict_from_local(&sig).is_none());
        }
        // The third entry reaches the minimum sample size.
        predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        assert_eq!(predictor.predict_from_local(&sig), Some("full".to_string()));
    }

    #[test]
    fn predict_learns_best_mode() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("big.rs", 5000);
        for _ in 0..5 {
            predictor.record(sig.clone(), outcome("full", 5000, 5000, 0.3));
            predictor.record(sig.clone(), outcome("map", 5000, 800, 0.6));
        }
        assert_eq!(predictor.predict_from_local(&sig), Some("map".to_string()));
        assert_eq!(predictor.predict_best_mode(&sig), Some("map".to_string()));
    }

    #[test]
    fn history_caps_at_100() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.rs", 100);
        for i in 0..120 {
            // Tag each entry with its insertion index so eviction order can be checked.
            predictor.record(sig.clone(), outcome("full", i, 100, 0.5));
        }
        let entries = &predictor.history[&sig];
        assert!(entries.len() <= 100);
        // Entry #101 evicts the 50 oldest (indices 0..50); 19 more are appended afterwards.
        assert_eq!(entries.len(), 70);
        assert_eq!(entries.first().map(|e| e.tokens_in), Some(50));
        assert_eq!(entries.last().map(|e| e.tokens_in), Some(119));
    }

    #[test]
    fn defaults_return_none_for_small_files() {
        let sig = FileSignature::from_path("small.rs", 200);
        assert!(ModePredictor::predict_from_defaults(&sig).is_none());
    }

    #[test]
    fn defaults_return_none_for_medium_unknown() {
        let sig = FileSignature::from_path("notes.md", 1000);
        assert!(ModePredictor::predict_from_defaults(&sig).is_none());
    }

    #[test]
    fn defaults_recommend_map_for_medium_code() {
        let sig = FileSignature::from_path("medium.rs", 3000);
        assert_eq!(
            ModePredictor::predict_from_defaults(&sig),
            Some("map".to_string())
        );
        let sig = FileSignature::from_path("large.rs", 10000);
        assert_eq!(
            ModePredictor::predict_from_defaults(&sig),
            Some("map".to_string())
        );
    }

    #[test]
    fn defaults_recommend_aggressive_for_json() {
        let sig = FileSignature::from_path("config.json", 1000);
        assert_eq!(
            ModePredictor::predict_from_defaults(&sig),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn defaults_recommend_signatures_for_huge_code() {
        let sig = FileSignature::from_path("huge.ts", 25000);
        assert_eq!(
            ModePredictor::predict_from_defaults(&sig),
            Some("signatures".to_string())
        );
    }

    #[test]
    fn defaults_recommend_signatures_for_lock_files() {
        let sig = FileSignature::from_path("Cargo.lock", 1000);
        assert_eq!(
            ModePredictor::predict_from_defaults(&sig),
            Some("signatures".to_string())
        );
    }

    #[test]
    fn defaults_sql_and_css_require_medium_size() {
        let defaults = |path: &str, tokens: usize| {
            ModePredictor::predict_from_defaults(&FileSignature::from_path(path, tokens))
        };
        assert_eq!(defaults("schema.sql", 3000), Some("map".to_string()));
        assert_eq!(defaults("schema.sql", 1000), None);
        assert_eq!(defaults("style.css", 3000), Some("aggressive".to_string()));
        assert_eq!(defaults("style.css", 1000), None);
    }

    #[test]
    fn defaults_recommend_aggressive_for_large_unknown() {
        let sig = FileSignature::from_path("data.xyz", 8000);
        assert_eq!(
            ModePredictor::predict_from_defaults(&sig),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn mode_outcome_efficiency() {
        // density / (tokens_out / tokens_in) = 0.6 / 0.2
        assert!((outcome("map", 1000, 200, 0.6).efficiency() - 3.0).abs() < 1e-9);
        // Empty output scores zero.
        assert_eq!(outcome("map", 1000, 0, 0.6).efficiency(), 0.0);
        // Zero input tokens are treated as one.
        assert!((outcome("map", 0, 4, 0.8).efficiency() - 0.2).abs() < 1e-12);
    }
}
449}