Skip to main content

lean_ctx/core/
mode_predictor.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use std::time::Instant;
4
5const STATS_FILE: &str = "mode_stats.json";
6const PREDICTOR_FLUSH_SECS: u64 = 10;
7
8static PREDICTOR_BUFFER: Mutex<Option<(Arc<ModePredictor>, Instant)>> = Mutex::new(None);
9
10/// Observed outcome of a read mode: tokens in/out and information density.
11#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
12pub struct ModeOutcome {
13    pub mode: String,
14    pub tokens_in: usize,
15    pub tokens_out: usize,
16    pub density: f64,
17}
18
19impl ModeOutcome {
20    /// Computes an efficiency score: density / compression ratio.
21    pub fn efficiency(&self) -> f64 {
22        if self.tokens_out == 0 {
23            return 0.0;
24        }
25        self.density / (self.tokens_out as f64 / self.tokens_in.max(1) as f64)
26    }
27}
28
29/// File identity for mode prediction: extension + token-count size bucket.
30#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
31pub struct FileSignature {
32    pub ext: String,
33    pub size_bucket: u8,
34}
35
36impl FileSignature {
37    /// Creates a file signature from its path and token count.
38    pub fn from_path(path: &str, token_count: usize) -> Self {
39        let ext = std::path::Path::new(path)
40            .extension()
41            .and_then(|e| e.to_str())
42            .unwrap_or("")
43            .to_string();
44        let size_bucket = match token_count {
45            0..=500 => 0,
46            501..=2000 => 1,
47            2001..=5000 => 2,
48            5001..=20000 => 3,
49            _ => 4,
50        };
51        Self { ext, size_bucket }
52    }
53}
54
55/// Learns the best read mode per file signature from historical outcomes.
56#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
57pub struct ModePredictor {
58    history: HashMap<FileSignature, Vec<ModeOutcome>>,
59    project_root: Option<String>,
60}
61
62impl ModePredictor {
63    /// Loads or creates the predictor, using an in-memory buffer for caching.
64    pub fn new() -> Self {
65        let mut guard = PREDICTOR_BUFFER
66            .lock()
67            .unwrap_or_else(std::sync::PoisonError::into_inner);
68        if let Some((ref predictor, _)) = *guard {
69            return Self {
70                history: predictor.history.clone(),
71                project_root: predictor.project_root.clone(),
72            };
73        }
74        let mut loaded = Self::load_from_disk().unwrap_or_default();
75        if loaded.project_root.is_none() {
76            loaded.project_root = std::env::current_dir()
77                .ok()
78                .map(|p| p.to_string_lossy().to_string());
79        }
80        *guard = Some((Arc::new(loaded.clone()), Instant::now()));
81        loaded
82    }
83
84    pub fn with_project_root(mut self, root: &str) -> Self {
85        self.project_root = Some(root.to_string());
86        self
87    }
88
89    pub fn set_project_root(&mut self, root: &str) {
90        self.project_root = Some(root.to_string());
91    }
92
93    /// Records a mode outcome for a file signature (capped at 100 entries).
94    pub fn record(&mut self, sig: FileSignature, outcome: ModeOutcome) {
95        let entries = self.history.entry(sig).or_default();
96        entries.push(outcome);
97        if entries.len() > 100 {
98            entries.drain(0..50);
99        }
100    }
101
102    /// Returns the best mode based on historical efficiency.
103    /// Chain: local history -> cloud adaptive models -> built-in defaults.
104    pub fn predict_best_mode(&self, sig: &FileSignature) -> Option<String> {
105        let default_mode = Self::predict_from_defaults(sig);
106
107        let allow_override = |candidate: &str| -> bool {
108            let Some(def) = default_mode.as_deref() else {
109                return true;
110            };
111            if candidate == "full" {
112                return false;
113            }
114            // For code-structured defaults, never override to lossy modes.
115            if (def == "map" || def == "signatures")
116                && (candidate == "aggressive" || candidate == "entropy")
117            {
118                return false;
119            }
120            true
121        };
122
123        if let Some(local) = self.predict_from_local(sig) {
124            if allow_override(&local) {
125                return Some(local);
126            }
127        }
128        if let Some(bandit) = self.predict_from_bandit(sig) {
129            if allow_override(&bandit) {
130                return Some(bandit);
131            }
132        }
133        if let Some(cloud) = self.predict_from_cloud(sig) {
134            if allow_override(&cloud) {
135                return Some(cloud);
136            }
137        }
138        default_mode
139    }
140
141    fn predict_from_bandit(&self, sig: &FileSignature) -> Option<String> {
142        let key = format!("{}_feedback", sig.ext);
143        let store =
144            crate::core::bandit::BanditStore::load(self.project_root.as_deref().unwrap_or("."));
145        let bandit = store.bandits.get(&key)?;
146        if bandit.total_pulls < 5 {
147            return None;
148        }
149        let best_arm = bandit.arms.iter().max_by(|a, b| {
150            a.mean()
151                .partial_cmp(&b.mean())
152                .unwrap_or(std::cmp::Ordering::Equal)
153        })?;
154        let mode = match best_arm.name.as_str() {
155            "conservative" => "full",
156            "balanced" => "signatures",
157            "aggressive" => "aggressive",
158            _ => return None,
159        };
160        Some(mode.to_string())
161    }
162
163    fn predict_from_local(&self, sig: &FileSignature) -> Option<String> {
164        let entries = self.history.get(sig)?;
165        if entries.len() < 3 {
166            return None;
167        }
168
169        let mut mode_scores: HashMap<&str, (f64, usize)> = HashMap::new();
170        for entry in entries {
171            let (sum, count) = mode_scores.entry(&entry.mode).or_insert((0.0, 0));
172            *sum += entry.efficiency();
173            *count += 1;
174        }
175
176        mode_scores
177            .into_iter()
178            .max_by(|a, b| {
179                let avg_a = a.1 .0 / a.1 .1 as f64;
180                let avg_b = b.1 .0 / b.1 .1 as f64;
181                avg_a
182                    .partial_cmp(&avg_b)
183                    .unwrap_or(std::cmp::Ordering::Equal)
184            })
185            .map(|(mode, _)| mode.to_string())
186    }
187
188    /// Loads cloud adaptive models (synced from LeanCTX Cloud).
189    /// Models are cached locally and auto-updated for cloud users.
190    #[allow(clippy::unused_self)]
191    fn predict_from_cloud(&self, sig: &FileSignature) -> Option<String> {
192        let data = crate::cloud_client::load_cloud_models()?;
193        let models = data["models"].as_array()?;
194
195        let ext_with_dot = format!(".{}", sig.ext);
196        let bucket_name = match sig.size_bucket {
197            0 => "0-500",
198            1 => "500-2k",
199            2 => "2k-10k",
200            _ => "10k+",
201        };
202
203        let mut best: Option<(&str, f64)> = None;
204
205        for model in models {
206            let m_ext = model["file_ext"].as_str().unwrap_or("");
207            let m_bucket = model["size_bucket"].as_str().unwrap_or("");
208            let confidence = model["confidence"].as_f64().unwrap_or(0.0);
209
210            if m_ext == ext_with_dot && m_bucket == bucket_name && confidence > 0.5 {
211                if let Some(mode) = model["recommended_mode"].as_str() {
212                    if best.is_none_or(|(_, c)| confidence > c) {
213                        best = Some((mode, confidence));
214                    }
215                }
216            }
217        }
218
219        if let Some((mode, _)) = best {
220            return Some(mode.to_string());
221        }
222
223        for model in models {
224            let m_ext = model["file_ext"].as_str().unwrap_or("");
225            let confidence = model["confidence"].as_f64().unwrap_or(0.0);
226            if m_ext == ext_with_dot && confidence > 0.5 {
227                return model["recommended_mode"]
228                    .as_str()
229                    .map(std::string::ToString::to_string);
230            }
231        }
232
233        None
234    }
235
236    /// Built-in defaults for common file types and sizes.
237    /// Ensures reasonable compression even without local history or cloud models.
238    /// Respects Kolmogorov-Gate: files with K>0.7 skip aggressive modes.
239    fn predict_from_defaults(sig: &FileSignature) -> Option<String> {
240        if sig.size_bucket == 0 {
241            return None;
242        }
243        if matches!(sig.ext.as_str(), "md" | "mdx" | "txt" | "rst") {
244            return None;
245        }
246
247        let mode = match (sig.ext.as_str(), sig.size_bucket) {
248            // Large code files: signatures only
249            (
250                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
251                | "swift" | "kt" | "cs" | "vue" | "svelte",
252                4..,
253            ) => "signatures",
254
255            // Code 2k-10k, SQL, lock, config/data: structured map
256            ("lock" | "json" | "yaml" | "yml" | "toml", _)
257            | (
258                "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
259                | "swift" | "kt" | "cs" | "vue" | "svelte",
260                2 | 3,
261            )
262            | ("sql", 2..) => "map",
263
264            // CSS, XML/CSV, and large unknown files: aggressive
265            ("xml" | "csv", _) | ("css" | "scss" | "less" | "sass", 2..) | (_, 3..) => "aggressive",
266
267            _ => return None,
268        };
269        Some(mode.to_string())
270    }
271
272    /// Saves to the in-memory buffer and flushes to disk if the interval elapsed.
273    pub fn save(&self) {
274        let mut guard = PREDICTOR_BUFFER
275            .lock()
276            .unwrap_or_else(std::sync::PoisonError::into_inner);
277        let should_flush = match *guard {
278            Some((_, ref last_flush)) => last_flush.elapsed().as_secs() >= PREDICTOR_FLUSH_SECS,
279            None => true,
280        };
281        *guard = Some((Arc::new(self.clone()), Instant::now()));
282        if should_flush {
283            self.save_to_disk();
284        }
285    }
286
287    fn save_to_disk(&self) {
288        let Ok(dir) = crate::core::data_dir::lean_ctx_data_dir() else {
289            return;
290        };
291        let _ = std::fs::create_dir_all(&dir);
292        let path = dir.join(STATS_FILE);
293        if let Ok(json) = serde_json::to_string_pretty(self) {
294            let tmp = dir.join(".mode_stats.tmp");
295            if std::fs::write(&tmp, &json).is_ok() {
296                let _ = std::fs::rename(&tmp, &path);
297            }
298        }
299    }
300
301    /// Forces an immediate write of the buffered predictor state to disk.
302    pub fn flush() {
303        let guard = PREDICTOR_BUFFER
304            .lock()
305            .unwrap_or_else(std::sync::PoisonError::into_inner);
306        if let Some((ref predictor, _)) = *guard {
307            predictor.save_to_disk();
308        }
309    }
310
311    fn load_from_disk() -> Option<Self> {
312        let path = crate::core::data_dir::lean_ctx_data_dir()
313            .ok()?
314            .join(STATS_FILE);
315        let data = std::fs::read_to_string(path).ok()?;
316        serde_json::from_str(&data).ok()
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ModeOutcome` fixture.
    fn outcome(mode: &str, tokens_in: usize, tokens_out: usize, density: f64) -> ModeOutcome {
        ModeOutcome {
            mode: mode.to_string(),
            tokens_in,
            tokens_out,
            density,
        }
    }

    /// Built-in default for a file named `path` holding `tokens` tokens.
    fn default_for(path: &str, tokens: usize) -> Option<String> {
        ModePredictor::predict_from_defaults(&FileSignature::from_path(path, tokens))
    }

    #[test]
    fn file_signature_buckets() {
        for (tokens, expected) in [(100, 0), (1000, 1), (3000, 2), (10000, 3), (50000, 4)] {
            assert_eq!(
                FileSignature::from_path("main.rs", tokens).size_bucket,
                expected
            );
        }
    }

    #[test]
    fn predict_returns_none_without_history() {
        let predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.zzz", 500);
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_returns_none_with_too_few_entries() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.zzz", 500);
        predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        assert!(predictor.predict_from_local(&sig).is_none());
    }

    #[test]
    fn predict_learns_best_mode() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("big.rs", 5000);
        for _ in 0..5 {
            predictor.record(sig.clone(), outcome("full", 5000, 5000, 0.3));
            predictor.record(sig.clone(), outcome("map", 5000, 800, 0.6));
        }
        let best = predictor.predict_best_mode(&sig);
        assert_eq!(best, Some("map".to_string()));
    }

    #[test]
    fn history_caps_at_100() {
        let mut predictor = ModePredictor::default();
        let sig = FileSignature::from_path("test.rs", 100);
        for _ in 0..120 {
            predictor.record(sig.clone(), outcome("full", 100, 100, 0.5));
        }
        assert!(predictor.history.get(&sig).unwrap().len() <= 100);
    }

    #[test]
    fn defaults_return_none_for_small_files() {
        assert!(default_for("small.rs", 200).is_none());
    }

    #[test]
    fn defaults_recommend_map_for_medium_code() {
        assert_eq!(default_for("medium.rs", 3000), Some("map".to_string()));
    }

    #[test]
    fn defaults_recommend_map_for_json() {
        assert_eq!(default_for("config.json", 1000), Some("map".to_string()));
    }

    #[test]
    fn defaults_recommend_signatures_for_huge_code() {
        assert_eq!(
            default_for("huge.ts", 25000),
            Some("signatures".to_string())
        );
    }

    #[test]
    fn defaults_recommend_aggressive_for_large_unknown() {
        assert_eq!(
            default_for("data.xyz", 8000),
            Some("aggressive".to_string())
        );
    }

    #[test]
    fn defaults_never_compress_markdown() {
        for tokens in [600, 3000, 8000, 25000] {
            assert!(
                default_for("SKILL.md", tokens).is_none(),
                "SKILL.md at {tokens} tokens should get full (None), not compressed"
            );
        }
        assert!(default_for("AGENTS.md", 5000).is_none());
        assert!(default_for("README.md", 12000).is_none());
    }

    #[test]
    fn mode_outcome_efficiency() {
        let sample = outcome("map", 1000, 200, 0.6);
        assert!(sample.efficiency() > 0.0);
    }
}
469}