1use std::collections::{BTreeMap, HashMap};
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5
6use crate::core::llm_feedback::LlmFeedbackEvent;
7
8const POLICY_FILE: &str = "adaptive_mode_policy.json";
9const EMA_ALPHA: f64 = 0.2;
10
11#[derive(Debug, Clone, Serialize, Deserialize, Default)]
12pub struct AdaptiveModePolicyStore {
13 pub global: ModePenaltyTable,
14 #[serde(default)]
15 pub by_intent: HashMap<String, ModePenaltyTable>,
16}
17
/// A set of learned penalties, keyed by normalized mode name.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ModePenaltyTable {
    // BTreeMap keeps key order (and thus the serialized JSON) deterministic.
    #[serde(default)]
    pub modes: BTreeMap<String, ModePenalty>,
}
23
/// Exponentially-weighted "badness" statistics for a single mode.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ModePenalty {
    /// Exponential moving average of observed badness (alpha = EMA_ALPHA);
    /// readers clamp it to [0, 1].
    pub ema_badness: f64,
    /// Number of samples folded into the EMA (saturating counter).
    pub samples: u64,
    /// Timestamp string of the most recent sample, if any.
    pub last_ts: Option<String>,
}
30
31impl AdaptiveModePolicyStore {
32 pub fn load() -> Self {
33 let path = policy_path();
34 let Ok(s) = std::fs::read_to_string(&path) else {
35 return Self::default();
36 };
37 serde_json::from_str(&s).unwrap_or_default()
38 }
39
40 pub fn save(&self) -> Result<(), String> {
41 let path = policy_path();
42 if let Some(parent) = path.parent() {
43 std::fs::create_dir_all(parent)
44 .map_err(|e| format!("create_dir_all {}: {e}", parent.display()))?;
45 }
46 let json = serde_json::to_string_pretty(self).map_err(|e| format!("serialize: {e}"))?;
47 let tmp = path.with_extension("tmp");
48 std::fs::write(&tmp, json).map_err(|e| format!("write {}: {e}", tmp.display()))?;
49 std::fs::rename(&tmp, &path).map_err(|e| format!("rename {}: {e}", path.display()))?;
50 Ok(())
51 }
52
53 pub fn reset() -> Result<(), String> {
54 let path = policy_path();
55 if path.exists() {
56 std::fs::remove_file(&path).map_err(|e| format!("remove {}: {e}", path.display()))?;
57 }
58 Ok(())
59 }
60
61 pub fn update_from_feedback(&mut self, ev: &LlmFeedbackEvent) {
62 let ratio = ev.llm_output_tokens as f64 / ev.llm_input_tokens.max(1) as f64;
63 let mut badness = ((ratio - 1.2) / 1.2).clamp(0.0, 1.0);
64 if ev.llm_output_tokens >= 6000 {
65 badness = badness.max(0.8);
66 } else if ev.llm_output_tokens >= 3000 {
67 badness = badness.max(0.5);
68 }
69 if badness <= 0.0 {
70 return;
71 }
72
73 let modes = ev.ctx_read_modes.as_ref().cloned().unwrap_or_default();
74 if modes.is_empty() {
75 if let Some(m) = ev.ctx_read_last_mode.as_ref() {
76 if let Some(key) = normalize_mode_key(m) {
77 Self::apply_update(&mut self.global, key, badness, ev.timestamp.as_str());
78 if let Some(k) = normalized_intent_key(ev.intent.as_deref()) {
79 let table = self.by_intent.entry(k).or_default();
80 Self::apply_update(table, key, badness, ev.timestamp.as_str());
81 }
82 }
83 }
84 return;
85 }
86
87 let total: u64 = modes.values().sum();
88 if total == 0 {
89 return;
90 }
91
92 for (mode, count) in modes {
93 let Some(key) = normalize_mode_key(&mode) else {
94 continue;
95 };
96 let w = count as f64 / total as f64;
97 let b = badness * w;
98 Self::apply_update(&mut self.global, key, b, ev.timestamp.as_str());
99 if let Some(k) = normalized_intent_key(ev.intent.as_deref()) {
100 let table = self.by_intent.entry(k).or_default();
101 Self::apply_update(table, key, b, ev.timestamp.as_str());
102 }
103 }
104 }
105
106 fn apply_update(table: &mut ModePenaltyTable, mode: &str, badness: f64, ts: &str) {
107 let entry = table.modes.entry(mode.to_string()).or_default();
108 entry.ema_badness = entry.ema_badness * (1.0 - EMA_ALPHA) + badness * EMA_ALPHA;
109 entry.samples = entry.samples.saturating_add(1);
110 entry.last_ts = Some(ts.to_string());
111 }
112
113 pub fn penalty(&self, intent: Option<&str>, mode: &str) -> f64 {
114 let Some(key) = normalize_mode_key(mode) else {
115 return 0.0;
116 };
117 if let Some(k) = normalized_intent_key(intent) {
118 if let Some(t) = self.by_intent.get(&k) {
119 if let Some(p) = t.modes.get(key) {
120 return p.ema_badness.clamp(0.0, 1.0);
121 }
122 }
123 }
124 self.global
125 .modes
126 .get(key)
127 .map(|p| p.ema_badness.clamp(0.0, 1.0))
128 .unwrap_or(0.0)
129 }
130
131 pub fn choose_auto_mode(&self, intent: Option<&str>, predicted: &str) -> String {
132 let candidates = auto_candidates(predicted);
133 let mut best_mode = predicted.to_string();
134 let mut best_score = f64::NEG_INFINITY;
135 for (i, mode) in candidates.into_iter().enumerate() {
136 let base = 1.0 - (i as f64 * 0.05);
137 let score = base - self.penalty(intent, &mode);
138 if score > best_score {
139 best_score = score;
140 best_mode = mode;
141 }
142 }
143 best_mode
144 }
145}
146
/// Collapse aliases that share penalty statistics: "entropy" and
/// "aggressive" are tracked under one "aggressive" bucket; every other mode
/// maps to itself.
fn mode_group(mode: &str) -> &str {
    if mode == "entropy" || mode == "aggressive" {
        "aggressive"
    } else {
        mode
    }
}
153
/// Map a raw mode name to the penalty-table key it is tracked under.
///
/// Returns `None` for modes that are never penalized: "diff" and any
/// "lines:*" selector. "entropy" and "aggressive" share one "aggressive"
/// bucket; everything else keys on itself.
fn normalize_mode_key(mode: &str) -> Option<&str> {
    if mode == "diff" || mode.starts_with("lines:") {
        return None;
    }
    Some(match mode {
        "entropy" | "aggressive" => "aggressive",
        other => other,
    })
}
163
/// Ordered candidate modes to consider when auto-selecting, starting from
/// the predicted mode. Earlier entries are preferred by the caller's
/// rank-based base score.
fn auto_candidates(predicted: &str) -> Vec<String> {
    let ordered: Vec<&str> = match predicted {
        "aggressive" | "entropy" => vec!["aggressive", "entropy", "map", "signatures", "full"],
        "map" => vec!["map", "signatures", "full", "aggressive"],
        "signatures" => vec!["signatures", "map", "full", "aggressive"],
        "full" => vec!["full", "map", "signatures", "aggressive"],
        other => vec![other, "map", "signatures", "full", "aggressive"],
    };
    ordered.into_iter().map(str::to_string).collect()
}
200
/// Canonical lookup key for an intent string: trimmed, lowercased, and
/// truncated to at most 80 characters. `None` when the intent is absent or
/// blank after trimming.
fn normalized_intent_key(intent: Option<&str>) -> Option<String> {
    let trimmed = intent?.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_lowercase().chars().take(80).collect())
}
209
210fn policy_path() -> PathBuf {
211 crate::core::data_dir::lean_ctx_data_dir()
212 .unwrap_or_else(|_| PathBuf::from("."))
213 .join(POLICY_FILE)
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh store has no samples, so every mode's penalty is zero.
    #[test]
    fn penalty_defaults_to_zero() {
        let store = AdaptiveModePolicyStore::default();
        assert_eq!(store.penalty(None, "aggressive"), 0.0);
    }

    // After a max-badness sample on "aggressive", auto-selection must steer
    // away from it — and from "entropy", which shares the same bucket.
    #[test]
    fn choose_auto_avoids_penalized_aggressive() {
        let mut store = AdaptiveModePolicyStore::default();
        AdaptiveModePolicyStore::apply_update(&mut store.global, "aggressive", 1.0, "t");
        let picked = store.choose_auto_mode(Some("fix bug"), "aggressive");
        for rejected in ["aggressive", "entropy"] {
            assert_ne!(picked, rejected);
        }
    }
}