lean_ctx/core/
mdl_selector.rs1use super::compressor::aggressive_compress;
4use super::entropy::entropy_compress;
5use super::signatures::{extract_file_map, extract_signatures, Signature};
6use super::tokens::count_tokens;
7
/// One candidate output mode together with its fixed scoring overhead.
#[derive(Copy, Clone)]
struct ModeSpec {
    /// Mode identifier; this is the string `select_mode` hands back to callers
    /// and the key `compressed_tokens_for` dispatches on.
    name: &'static str,
    /// Constant added to the mode's compressed token count when it is scored.
    model_cost: usize,
}
14
15const MODES: [ModeSpec; 5] = [
16 ModeSpec {
17 name: "full",
18 model_cost: 0,
19 },
20 ModeSpec {
21 name: "map",
22 model_cost: 50,
23 },
24 ModeSpec {
25 name: "signatures",
26 model_cost: 80,
27 },
28 ModeSpec {
29 name: "aggressive",
30 model_cost: 120,
31 },
32 ModeSpec {
33 name: "entropy",
34 model_cost: 140,
35 },
36];
37
/// Guesses a placeholder file name for `content` so that extension-driven
/// helpers have something to work with.
///
/// * `"snippet.py"` when the first line, ignoring leading whitespace, starts
///   with `def `.
/// * `"snippet.go"` when the text contains `package ` or `func ` anywhere.
/// * `"snippet.rs"` otherwise (including for empty input).
fn synthetic_path_for(content: &str) -> &'static str {
    // Only the very first line is inspected for the Python check.
    let first_line = content.lines().next().unwrap_or("");
    if first_line.trim_start().starts_with("def ") {
        return "snippet.py";
    }

    // A line starting with `func ` necessarily contains `func `, so a plain
    // substring search covers that case as well.
    let looks_like_go = ["package ", "func "]
        .iter()
        .any(|keyword| content.contains(*keyword));

    if looks_like_go {
        "snippet.go"
    } else {
        "snippet.rs"
    }
}
55
/// Returns everything after the last `.` in `path`, or `"rs"` when `path`
/// contains no dot at all.
///
/// A trailing dot yields an empty extension; no path-separator handling is
/// performed.
fn ext_from_path(path: &str) -> &str {
    match path.rsplit_once('.') {
        Some((_stem, extension)) => extension,
        None => "rs",
    }
}
59
/// Concatenates the compact signature lines into one string, placing a single
/// `\n` between consecutive entries (no leading or trailing newline is added).
fn render_signatures(compact: &[String]) -> String {
    // Exact output size: all line bytes plus one separator between neighbours.
    let total_len =
        compact.iter().map(String::len).sum::<usize>() + compact.len().saturating_sub(1);
    let mut rendered = String::with_capacity(total_len);

    for (position, line) in compact.iter().enumerate() {
        if position > 0 {
            rendered.push('\n');
        }
        rendered.push_str(line);
    }
    rendered
}
63
64fn compressed_tokens_for(mode: &str, content: &str, path: &str) -> usize {
65 let ext = ext_from_path(path);
66 match mode {
67 "map" => count_tokens(&extract_file_map(path, content)),
68 "signatures" => {
69 let sigs = extract_signatures(content, ext);
70 let lines: Vec<String> = sigs.iter().map(Signature::to_compact).collect();
71 count_tokens(&render_signatures(&lines))
72 }
73 "aggressive" => count_tokens(&aggressive_compress(content, Some(ext))),
74 "entropy" => count_tokens(&entropy_compress(content).output),
75 _ => count_tokens(content),
76 }
77}
78
79pub fn select_mode(content: &str, budget_tokens: usize) -> &'static str {
81 if content.is_empty() {
82 return "full";
83 }
84 let path = synthetic_path_for(content);
85
86 let mut best_feasible: Option<(&'static str, usize)> = None;
87 let mut best_fallback: Option<(&'static str, usize)> = None;
88
89 for m in &MODES {
90 let ct = compressed_tokens_for(m.name, content, path);
91 let dl = ct.saturating_add(m.model_cost);
92
93 let cand_opt = best_fallback.map_or(Some((m.name, dl)), |(bn, bd)| {
94 Some(if dl < bd || (dl == bd && m.name < bn) {
95 (m.name, dl)
96 } else {
97 (bn, bd)
98 })
99 });
100 best_fallback = cand_opt;
101
102 if ct <= budget_tokens {
103 best_feasible = Some(match best_feasible {
104 None => (m.name, dl),
105 Some((bn, bd)) => {
106 if dl < bd || (dl == bd && m.name < bn) {
107 (m.name, dl)
108 } else {
109 (bn, bd)
110 }
111 }
112 });
113 }
114 }
115
116 best_feasible.or(best_fallback).map_or("full", |(n, _)| n)
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Every mode name the selector is permitted to return.
    const KNOWN_MODES: [&str; 5] = ["full", "map", "signatures", "aggressive", "entropy"];

    /// Empty input takes the early-return path.
    #[test]
    fn empty_returns_full() {
        let chosen = select_mode("", 100);
        assert_eq!(chosen, "full");
    }

    /// With a budget far above the raw size, the result is still a known mode.
    #[test]
    fn large_budget_picks_some_mode() {
        let source = "pub fn foo() -> i32 { 1 }\npub fn bar(x: u32) {}\n";
        let generous_budget = count_tokens(source) + 50_000;
        let chosen = select_mode(source, generous_budget);
        assert!(KNOWN_MODES.contains(&chosen));
    }

    /// Highly repetitive text under a tiny budget should not come back as-is.
    #[test]
    fn tight_budget_avoids_full() {
        let repetitive = "a ".repeat(200);
        let chosen = select_mode(&repetitive, 5);
        assert_ne!(chosen, "full");
    }

    /// When the budget equals the full token count, the chosen mode's output
    /// must not exceed that count.
    #[test]
    fn respects_budget_when_full_fits() {
        let python_source = "def foo():\n pass\n";
        let full_tokens = compressed_tokens_for("full", python_source, "snippet.py");
        let chosen = select_mode(python_source, full_tokens);
        let chosen_tokens = compressed_tokens_for(chosen, python_source, "snippet.py");
        assert!(chosen_tokens <= full_tokens);
    }
}