Skip to main content

lean_ctx/core/
mdl_selector.rs

1//! Minimum Description Length–style selection among abstract read modes (proxy compressed lengths).
2
3use super::compressor::aggressive_compress;
4use super::entropy::entropy_compress;
5use super::signatures::{extract_file_map, extract_signatures, Signature};
6use super::tokens::count_tokens;
7
/// One candidate read mode together with its fixed complexity penalty.
#[derive(Clone, Copy)]
struct ModeSpec {
    /// Identifier of the mode; this is the string handed back to callers.
    name: &'static str,
    /// Prior cost, in tokens, added on top of the compressed token estimate
    /// to account for the complexity of the mode itself.
    model_cost: usize,
}
14
15const MODES: [ModeSpec; 5] = [
16    ModeSpec {
17        name: "full",
18        model_cost: 0,
19    },
20    ModeSpec {
21        name: "map",
22        model_cost: 50,
23    },
24    ModeSpec {
25        name: "signatures",
26        model_cost: 80,
27    },
28    ModeSpec {
29        name: "aggressive",
30        model_cost: 120,
31    },
32    ModeSpec {
33        name: "entropy",
34        model_cost: 140,
35    },
36];
37
/// Guesses a placeholder file name for `content` so extension-driven helpers can be used.
///
/// Checks run in this order:
/// 1. the first line, ignoring leading whitespace, starts with `def ` -> `"snippet.py"`;
/// 2. the text contains `package ` or `func ` anywhere -> `"snippet.go"`;
/// 3. anything else (including empty input) -> `"snippet.rs"`.
fn synthetic_path_for(content: &str) -> &'static str {
    // A first line that begins with `def ` already guarantees the text contains `def `,
    // so no separate whole-text search is needed.
    let opens_with_def = content
        .lines()
        .next()
        .map(str::trim_start)
        .is_some_and(|first| first.starts_with("def "));
    if opens_with_def {
        return "snippet.py";
    }

    // A line starting with `func ` is a special case of the text containing `func `,
    // so the substring search covers it.
    let mentions_go_keyword = content.contains("package ") || content.contains("func ");
    if mentions_go_keyword {
        return "snippet.go";
    }

    "snippet.rs"
}
55
/// Returns the extension of the final component of `path`, or `"rs"` when that
/// component contains no `.`.
///
/// Only the text after the last `/` or `\` is inspected, so a dot inside a directory
/// name (e.g. `pkg.v2/Makefile`) is not mistaken for an extension separator.
/// The result is whatever follows the last `.` of the final component; it may be
/// empty for a name ending in `.`.
fn ext_from_path(path: &str) -> &str {
    // `rsplit` always yields at least one item, so the fallback is never used in practice.
    let file_name = path
        .rsplit(|c: char| matches!(c, '/' | '\\'))
        .next()
        .unwrap_or(path);
    file_name.rsplit_once('.').map_or("rs", |(_, ext)| ext)
}
59
/// Concatenates the compact signature lines, placing a single `\n` between
/// consecutive entries (no leading or trailing newline is added).
fn render_signatures(compact: &[String]) -> String {
    let mut rendered = String::new();
    for (idx, line) in compact.iter().enumerate() {
        // The separator goes before every entry except the first.
        if idx > 0 {
            rendered.push('\n');
        }
        rendered.push_str(line);
    }
    rendered
}
63
64fn compressed_tokens_for(mode: &str, content: &str, path: &str) -> usize {
65    let ext = ext_from_path(path);
66    match mode {
67        "map" => count_tokens(&extract_file_map(path, content)),
68        "signatures" => {
69            let sigs = extract_signatures(content, ext);
70            let lines: Vec<String> = sigs.iter().map(Signature::to_compact).collect();
71            count_tokens(&render_signatures(&lines))
72        }
73        "aggressive" => count_tokens(&aggressive_compress(content, Some(ext))),
74        "entropy" => count_tokens(&entropy_compress(content).output),
75        _ => count_tokens(content),
76    }
77}
78
79/// Pick read mode minimizing MDL proxy `compressed_tokens + model_cost` among modes whose compressed size fits `budget_tokens`.
80pub fn select_mode(content: &str, budget_tokens: usize) -> &'static str {
81    if content.is_empty() {
82        return "full";
83    }
84    let path = synthetic_path_for(content);
85
86    let mut best_feasible: Option<(&'static str, usize)> = None;
87    let mut best_fallback: Option<(&'static str, usize)> = None;
88
89    for m in &MODES {
90        let ct = compressed_tokens_for(m.name, content, path);
91        let dl = ct.saturating_add(m.model_cost);
92
93        let cand_opt = best_fallback.map_or(Some((m.name, dl)), |(bn, bd)| {
94            Some(if dl < bd || (dl == bd && m.name < bn) {
95                (m.name, dl)
96            } else {
97                (bn, bd)
98            })
99        });
100        best_fallback = cand_opt;
101
102        if ct <= budget_tokens {
103            best_feasible = Some(match best_feasible {
104                None => (m.name, dl),
105                Some((bn, bd)) => {
106                    if dl < bd || (dl == bd && m.name < bn) {
107                        (m.name, dl)
108                    } else {
109                        (bn, bd)
110                    }
111                }
112            });
113        }
114    }
115
116    best_feasible.or(best_fallback).map_or("full", |(n, _)| n)
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Every mode name `select_mode` may return.
    const KNOWN_MODES: [&str; 5] = ["full", "map", "signatures", "aggressive", "entropy"];

    #[test]
    fn empty_returns_full() {
        // Empty input short-circuits before any mode is scored.
        let chosen = select_mode("", 100);
        assert_eq!(chosen, "full");
    }

    #[test]
    fn large_budget_picks_some_mode() {
        let code = "pub fn foo() -> i32 { 1 }\npub fn bar(x: u32) {}\n";
        // A budget far above the raw token count.
        let generous_budget = count_tokens(code) + 50_000;
        let chosen = select_mode(code, generous_budget);
        assert!(KNOWN_MODES.contains(&chosen));
    }

    #[test]
    fn tight_budget_avoids_full() {
        let repetitive = "a ".repeat(200);
        assert_ne!(select_mode(&repetitive, 5), "full");
    }

    #[test]
    fn respects_budget_when_full_fits() {
        let py = "def foo():\n    pass\n";
        // Budget set to exactly the uncompressed size, so "full" is feasible.
        let full_tokens = compressed_tokens_for("full", py, "snippet.py");
        let chosen = select_mode(py, full_tokens);
        let chosen_tokens = compressed_tokens_for(chosen, py, "snippet.py");
        assert!(chosen_tokens <= full_tokens);
    }
}
154}