1use std::collections::HashMap;
2
/// Name of the statistics file that `ModePredictor` reads and writes inside the
/// `.lean-ctx` directory under the user's home directory.
const STATS_FILE: &str = "mode_stats.json";
4
/// One recorded result of applying a mode to a file.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ModeOutcome {
    /// Name of the mode that was applied (the tests in this file use `"full"` and `"map"`).
    pub mode: String,
    /// Token count on the input side — presumably the size of the original file; TODO confirm.
    pub tokens_in: usize,
    /// Token count on the output side — presumably what the mode emitted; TODO confirm.
    pub tokens_out: usize,
    /// Density score of the output. `efficiency` divides it by the output/input token
    /// ratio; how the score itself is computed is not visible in this file.
    pub density: f64,
}
12
13impl ModeOutcome {
14 pub fn efficiency(&self) -> f64 {
15 if self.tokens_out == 0 {
16 return 0.0;
17 }
18 self.density / (self.tokens_out as f64 / self.tokens_in.max(1) as f64)
19 }
20}
21
/// Key under which outcomes are grouped: a file extension plus a coarse size class.
#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct FileSignature {
    /// File extension without the leading dot; empty when `from_path` finds none.
    pub ext: String,
    /// Size class derived from the token count; `from_path` produces values 0 through 4.
    pub size_bucket: u8,
}
27
28impl FileSignature {
29 pub fn from_path(path: &str, token_count: usize) -> Self {
30 let ext = std::path::Path::new(path)
31 .extension()
32 .and_then(|e| e.to_str())
33 .unwrap_or("")
34 .to_string();
35 let size_bucket = match token_count {
36 0..=500 => 0,
37 501..=2000 => 1,
38 2001..=5000 => 2,
39 5001..=20000 => 3,
40 _ => 4,
41 };
42 Self { ext, size_bucket }
43 }
44}
45
/// Recommends a mode for a file from previously recorded outcomes.
#[derive(Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct ModePredictor {
    /// Recorded outcomes per file signature, oldest first: `record` appends at the end
    /// and trims from the front.
    history: HashMap<FileSignature, Vec<ModeOutcome>>,
}
50
51impl ModePredictor {
52 pub fn new() -> Self {
53 Self::load().unwrap_or_default()
54 }
55
56 pub fn record(&mut self, sig: FileSignature, outcome: ModeOutcome) {
57 let entries = self.history.entry(sig).or_default();
58 entries.push(outcome);
59 if entries.len() > 100 {
60 entries.drain(0..50);
61 }
62 }
63
64 pub fn predict_best_mode(&self, sig: &FileSignature) -> Option<String> {
67 if let Some(local) = self.predict_from_local(sig) {
68 return Some(local);
69 }
70 if let Some(pro) = self.predict_from_pro(sig) {
71 return Some(pro);
72 }
73 Self::predict_from_defaults(sig)
74 }
75
76 fn predict_from_local(&self, sig: &FileSignature) -> Option<String> {
77 let entries = self.history.get(sig)?;
78 if entries.len() < 3 {
79 return None;
80 }
81
82 let mut mode_scores: HashMap<&str, (f64, usize)> = HashMap::new();
83 for entry in entries {
84 let (sum, count) = mode_scores.entry(&entry.mode).or_insert((0.0, 0));
85 *sum += entry.efficiency();
86 *count += 1;
87 }
88
89 mode_scores
90 .into_iter()
91 .max_by(|a, b| {
92 let avg_a = a.1 .0 / a.1 .1 as f64;
93 let avg_b = b.1 .0 / b.1 .1 as f64;
94 avg_a
95 .partial_cmp(&avg_b)
96 .unwrap_or(std::cmp::Ordering::Equal)
97 })
98 .map(|(mode, _)| mode.to_string())
99 }
100
101 fn predict_from_pro(&self, sig: &FileSignature) -> Option<String> {
104 let data = crate::cloud_client::load_pro_models()?;
105 let models = data["models"].as_array()?;
106
107 let ext_with_dot = format!(".{}", sig.ext);
108 let bucket_name = match sig.size_bucket {
109 0 => "0-500",
110 1 => "500-2k",
111 2 => "2k-10k",
112 3 => "10k+",
113 _ => "10k+",
114 };
115
116 let mut best: Option<(&str, f64)> = None;
117
118 for model in models {
119 let m_ext = model["file_ext"].as_str().unwrap_or("");
120 let m_bucket = model["size_bucket"].as_str().unwrap_or("");
121 let confidence = model["confidence"].as_f64().unwrap_or(0.0);
122
123 if m_ext == ext_with_dot && m_bucket == bucket_name && confidence > 0.5 {
124 if let Some(mode) = model["recommended_mode"].as_str() {
125 if best.is_none() || confidence > best.unwrap().1 {
126 best = Some((mode, confidence));
127 }
128 }
129 }
130 }
131
132 if let Some((mode, _)) = best {
133 return Some(mode.to_string());
134 }
135
136 for model in models {
137 let m_ext = model["file_ext"].as_str().unwrap_or("");
138 let confidence = model["confidence"].as_f64().unwrap_or(0.0);
139 if m_ext == ext_with_dot && confidence > 0.5 {
140 return model["recommended_mode"].as_str().map(|s| s.to_string());
141 }
142 }
143
144 None
145 }
146
147 fn predict_from_defaults(sig: &FileSignature) -> Option<String> {
151 let mode = match (sig.ext.as_str(), sig.size_bucket) {
152 (_, 0) => return None,
154
155 ("json" | "yaml" | "yml" | "toml" | "xml" | "csv", _) => "aggressive",
157
158 ("lock", _) => "signatures",
160
161 (
164 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
165 | "swift" | "kt" | "cs" | "vue" | "svelte",
166 1,
167 ) => return None,
168
169 (
171 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
172 | "swift" | "kt" | "cs" | "vue" | "svelte",
173 2,
174 ) => "map",
175
176 (
178 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
179 | "swift" | "kt" | "cs" | "vue" | "svelte",
180 3,
181 ) => "map",
182
183 (
185 "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "java" | "c" | "cpp" | "rb"
186 | "swift" | "kt" | "cs" | "vue" | "svelte",
187 4..,
188 ) => "signatures",
189
190 ("md" | "mdx" | "rst" | "txt" | "html" | "astro", 1..=2) => return None,
192 ("md" | "mdx" | "rst" | "txt" | "html" | "astro", 3..) => "aggressive",
193
194 ("css" | "scss" | "less" | "sass", 2..) => "aggressive",
196
197 ("sql", 2..) => "map",
199
200 (_, 3..) => "aggressive",
202
203 _ => return None,
204 };
205 Some(mode.to_string())
206 }
207
208 pub fn save(&self) {
209 let dir = match dirs::home_dir() {
210 Some(d) => d.join(".lean-ctx"),
211 None => return,
212 };
213 let _ = std::fs::create_dir_all(&dir);
214 let path = dir.join(STATS_FILE);
215 if let Ok(json) = serde_json::to_string_pretty(self) {
216 let _ = std::fs::write(path, json);
217 }
218 }
219
220 fn load() -> Option<Self> {
221 let path = dirs::home_dir()?.join(".lean-ctx").join(STATS_FILE);
222 let data = std::fs::read_to_string(path).ok()?;
223 serde_json::from_str(&data).ok()
224 }
225}
226
227#[cfg(test)]
228mod tests {
229 use super::*;
230
231 #[test]
232 fn file_signature_buckets() {
233 assert_eq!(FileSignature::from_path("main.rs", 100).size_bucket, 0);
234 assert_eq!(FileSignature::from_path("main.rs", 1000).size_bucket, 1);
235 assert_eq!(FileSignature::from_path("main.rs", 3000).size_bucket, 2);
236 assert_eq!(FileSignature::from_path("main.rs", 10000).size_bucket, 3);
237 assert_eq!(FileSignature::from_path("main.rs", 50000).size_bucket, 4);
238 }
239
240 #[test]
241 fn predict_returns_none_without_history() {
242 let predictor = ModePredictor::default();
243 let sig = FileSignature::from_path("test.zzz", 500);
244 assert!(predictor.predict_from_local(&sig).is_none());
245 }
246
247 #[test]
248 fn predict_returns_none_with_too_few_entries() {
249 let mut predictor = ModePredictor::default();
250 let sig = FileSignature::from_path("test.zzz", 500);
251 predictor.record(
252 sig.clone(),
253 ModeOutcome {
254 mode: "full".to_string(),
255 tokens_in: 100,
256 tokens_out: 100,
257 density: 0.5,
258 },
259 );
260 assert!(predictor.predict_from_local(&sig).is_none());
261 }
262
263 #[test]
264 fn predict_learns_best_mode() {
265 let mut predictor = ModePredictor::default();
266 let sig = FileSignature::from_path("big.rs", 5000);
267 for _ in 0..5 {
268 predictor.record(
269 sig.clone(),
270 ModeOutcome {
271 mode: "full".to_string(),
272 tokens_in: 5000,
273 tokens_out: 5000,
274 density: 0.3,
275 },
276 );
277 predictor.record(
278 sig.clone(),
279 ModeOutcome {
280 mode: "map".to_string(),
281 tokens_in: 5000,
282 tokens_out: 800,
283 density: 0.6,
284 },
285 );
286 }
287 let best = predictor.predict_best_mode(&sig);
288 assert_eq!(best, Some("map".to_string()));
289 }
290
291 #[test]
292 fn history_caps_at_100() {
293 let mut predictor = ModePredictor::default();
294 let sig = FileSignature::from_path("test.rs", 100);
295 for _ in 0..120 {
296 predictor.record(
297 sig.clone(),
298 ModeOutcome {
299 mode: "full".to_string(),
300 tokens_in: 100,
301 tokens_out: 100,
302 density: 0.5,
303 },
304 );
305 }
306 assert!(predictor.history.get(&sig).unwrap().len() <= 100);
307 }
308
309 #[test]
310 fn defaults_return_none_for_small_files() {
311 let sig = FileSignature::from_path("small.rs", 200);
312 assert!(ModePredictor::predict_from_defaults(&sig).is_none());
313 }
314
315 #[test]
316 fn defaults_recommend_map_for_medium_code() {
317 let sig = FileSignature::from_path("medium.rs", 3000);
318 assert_eq!(
319 ModePredictor::predict_from_defaults(&sig),
320 Some("map".to_string())
321 );
322 }
323
324 #[test]
325 fn defaults_recommend_aggressive_for_json() {
326 let sig = FileSignature::from_path("config.json", 1000);
327 assert_eq!(
328 ModePredictor::predict_from_defaults(&sig),
329 Some("aggressive".to_string())
330 );
331 }
332
333 #[test]
334 fn defaults_recommend_signatures_for_huge_code() {
335 let sig = FileSignature::from_path("huge.ts", 25000);
336 assert_eq!(
337 ModePredictor::predict_from_defaults(&sig),
338 Some("signatures".to_string())
339 );
340 }
341
342 #[test]
343 fn defaults_recommend_aggressive_for_large_unknown() {
344 let sig = FileSignature::from_path("data.xyz", 8000);
345 assert_eq!(
346 ModePredictor::predict_from_defaults(&sig),
347 Some("aggressive".to_string())
348 );
349 }
350
351 #[test]
352 fn mode_outcome_efficiency() {
353 let o = ModeOutcome {
354 mode: "map".to_string(),
355 tokens_in: 1000,
356 tokens_out: 200,
357 density: 0.6,
358 };
359 assert!(o.efficiency() > 0.0);
360 }
361}