1use std::collections::{BTreeMap, HashMap};
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5
6use crate::core::llm_feedback::LlmFeedbackEvent;
7
/// File name of the persisted policy, joined onto the lean-ctx data directory
/// (or the current directory when that cannot be resolved).
const POLICY_FILE: &str = "adaptive_mode_policy.json";
/// Smoothing factor of the per-mode badness EMA: the weight given to the newest sample.
const EMA_ALPHA: f64 = 0.2;
10
11#[derive(Debug, Clone, Serialize, Deserialize, Default)]
12pub struct AdaptiveModePolicyStore {
13 pub global: ModePenaltyTable,
14 #[serde(default)]
15 pub by_intent: HashMap<String, ModePenaltyTable>,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, Default)]
19pub struct ModePenaltyTable {
20 #[serde(default)]
21 pub modes: BTreeMap<String, ModePenalty>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize, Default)]
25pub struct ModePenalty {
26 pub ema_badness: f64,
27 pub samples: u64,
28 pub last_ts: Option<String>,
29}
30
31impl AdaptiveModePolicyStore {
32 pub fn load() -> Self {
33 let path = policy_path();
34 let Ok(s) = std::fs::read_to_string(&path) else {
35 return Self::default();
36 };
37 serde_json::from_str(&s).unwrap_or_default()
38 }
39
40 pub fn save(&self) -> Result<(), String> {
41 let path = policy_path();
42 if let Some(parent) = path.parent() {
43 std::fs::create_dir_all(parent)
44 .map_err(|e| format!("create_dir_all {}: {e}", parent.display()))?;
45 }
46 let json = serde_json::to_string_pretty(self).map_err(|e| format!("serialize: {e}"))?;
47 let tmp = path.with_extension("tmp");
48 std::fs::write(&tmp, json).map_err(|e| format!("write {}: {e}", tmp.display()))?;
49 std::fs::rename(&tmp, &path).map_err(|e| format!("rename {}: {e}", path.display()))?;
50 Ok(())
51 }
52
53 pub fn reset() -> Result<(), String> {
54 let path = policy_path();
55 if path.exists() {
56 std::fs::remove_file(&path).map_err(|e| format!("remove {}: {e}", path.display()))?;
57 }
58 Ok(())
59 }
60
61 pub fn update_from_feedback(&mut self, ev: &LlmFeedbackEvent) {
62 let ratio = ev.llm_output_tokens as f64 / ev.llm_input_tokens.max(1) as f64;
63 let mut badness = ((ratio - 1.2) / 1.2).clamp(0.0, 1.0);
64 if ev.llm_output_tokens >= 6000 {
65 badness = badness.max(0.8);
66 } else if ev.llm_output_tokens >= 3000 {
67 badness = badness.max(0.5);
68 }
69 if badness <= 0.0 {
70 return;
71 }
72
73 let modes = ev.ctx_read_modes.clone().unwrap_or_default();
74 if modes.is_empty() {
75 if let Some(m) = ev.ctx_read_last_mode.as_ref() {
76 if let Some(key) = normalize_mode_key(m) {
77 Self::apply_update(&mut self.global, key, badness, ev.timestamp.as_str());
78 if let Some(k) = normalized_intent_key(ev.intent.as_deref()) {
79 let table = self.by_intent.entry(k).or_default();
80 Self::apply_update(table, key, badness, ev.timestamp.as_str());
81 }
82 }
83 }
84 return;
85 }
86
87 let total: u64 = modes.values().sum();
88 if total == 0 {
89 return;
90 }
91
92 for (mode, count) in modes {
93 let Some(key) = normalize_mode_key(&mode) else {
94 continue;
95 };
96 let w = count as f64 / total as f64;
97 let b = badness * w;
98 Self::apply_update(&mut self.global, key, b, ev.timestamp.as_str());
99 if let Some(k) = normalized_intent_key(ev.intent.as_deref()) {
100 let table = self.by_intent.entry(k).or_default();
101 Self::apply_update(table, key, b, ev.timestamp.as_str());
102 }
103 }
104 }
105
106 fn apply_update(table: &mut ModePenaltyTable, mode: &str, badness: f64, ts: &str) {
107 let entry = table.modes.entry(mode.to_string()).or_default();
108 entry.ema_badness = entry.ema_badness * (1.0 - EMA_ALPHA) + badness * EMA_ALPHA;
109 entry.samples = entry.samples.saturating_add(1);
110 entry.last_ts = Some(ts.to_string());
111 }
112
113 pub fn penalty(&self, intent: Option<&str>, mode: &str) -> f64 {
114 let Some(key) = normalize_mode_key(mode) else {
115 return 0.0;
116 };
117 if let Some(k) = normalized_intent_key(intent) {
118 if let Some(t) = self.by_intent.get(&k) {
119 if let Some(p) = t.modes.get(key) {
120 return p.ema_badness.clamp(0.0, 1.0);
121 }
122 }
123 }
124 self.global
125 .modes
126 .get(key)
127 .map_or(0.0, |p| p.ema_badness.clamp(0.0, 1.0))
128 }
129
130 pub fn choose_auto_mode(&self, intent: Option<&str>, predicted: &str) -> String {
131 let candidates = auto_candidates(predicted);
132 let mut best_mode = predicted.to_string();
133 let mut best_score = f64::NEG_INFINITY;
134 for (i, mode) in candidates.into_iter().enumerate() {
135 let base = 1.0 - (i as f64 * 0.05);
136 let score = base - self.penalty(intent, &mode);
137 if score > best_score {
138 best_score = score;
139 best_mode = mode;
140 }
141 }
142 best_mode
143 }
144}
145
/// Collapses modes that share one penalty bucket: `entropy` is tracked together
/// with `aggressive`; every other mode is its own group.
fn mode_group(mode: &str) -> &str {
    if matches!(mode, "entropy" | "aggressive") {
        "aggressive"
    } else {
        mode
    }
}
152
153fn normalize_mode_key(mode: &str) -> Option<&str> {
154 if mode == "diff" {
155 return None;
156 }
157 if mode.starts_with("lines:") {
158 return None;
159 }
160 Some(mode_group(mode))
161}
162
/// Returns the modes to consider for `predicted`, most preferred first.
///
/// `predicted` (or its group leader, for `entropy`) always leads the list. The
/// remaining general-purpose modes follow in a fixed order.
fn auto_candidates(predicted: &str) -> Vec<String> {
    let order: Vec<&str> = match predicted {
        "aggressive" | "entropy" => vec!["aggressive", "entropy", "map", "signatures", "full"],
        "map" => vec!["map", "signatures", "full", "aggressive"],
        "signatures" => vec!["signatures", "map", "full", "aggressive"],
        "full" => vec!["full", "map", "signatures", "aggressive"],
        unknown => vec![unknown, "map", "signatures", "full", "aggressive"],
    };
    order.into_iter().map(String::from).collect()
}
199
/// Normalizes a free-form intent into a table key.
///
/// The intent is trimmed, lowercased, and capped at 80 characters. Blank or
/// missing intents yield `None`.
fn normalized_intent_key(intent: Option<&str>) -> Option<String> {
    intent
        .map(str::trim)
        .filter(|trimmed| !trimmed.is_empty())
        .map(|trimmed| trimmed.to_lowercase().chars().take(80).collect())
}
208
209fn policy_path() -> PathBuf {
210 crate::core::data_dir::lean_ctx_data_dir()
211 .unwrap_or_else(|_| PathBuf::from("."))
212 .join(POLICY_FILE)
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn penalty_defaults_to_zero() {
        let p = AdaptiveModePolicyStore::default();
        assert_eq!(p.penalty(None, "aggressive"), 0.0);
    }

    #[test]
    fn choose_auto_avoids_penalized_aggressive() {
        let mut store = AdaptiveModePolicyStore::default();
        AdaptiveModePolicyStore::apply_update(&mut store.global, "aggressive", 1.0, "t");
        let chosen = store.choose_auto_mode(Some("fix bug"), "aggressive");
        assert_ne!(chosen, "aggressive");
        assert_ne!(chosen, "entropy");
    }

    #[test]
    fn diff_and_line_range_modes_are_never_penalized() {
        let mut store = AdaptiveModePolicyStore::default();
        // Even a stored entry for "diff" must not surface as a penalty.
        AdaptiveModePolicyStore::apply_update(&mut store.global, "diff", 1.0, "t");
        assert_eq!(store.penalty(None, "diff"), 0.0);
        assert_eq!(store.penalty(None, "lines:1-20"), 0.0);
    }

    #[test]
    fn entropy_shares_penalty_with_aggressive() {
        let mut store = AdaptiveModePolicyStore::default();
        AdaptiveModePolicyStore::apply_update(&mut store.global, "aggressive", 1.0, "t");
        let entropy = store.penalty(None, "entropy");
        assert!((entropy - EMA_ALPHA).abs() < 1e-12);
        assert_eq!(entropy, store.penalty(None, "aggressive"));
    }

    #[test]
    fn intent_table_overrides_global_and_falls_back() {
        let mut store = AdaptiveModePolicyStore::default();
        AdaptiveModePolicyStore::apply_update(&mut store.global, "map", 1.0, "t");
        let table = store.by_intent.entry("fix bug".to_string()).or_default();
        AdaptiveModePolicyStore::apply_update(table, "map", 0.5, "t");

        // Intent lookup is trimmed and case-insensitive.
        let specific = store.penalty(Some("  Fix Bug "), "map");
        assert!((specific - 0.5 * EMA_ALPHA).abs() < 1e-12);

        // An intent without its own entry falls back to the global table.
        let global = store.penalty(Some("refactor"), "map");
        assert!((global - EMA_ALPHA).abs() < 1e-12);
    }

    #[test]
    fn intent_key_is_trimmed_lowercased_and_truncated() {
        assert_eq!(normalized_intent_key(None), None);
        assert_eq!(normalized_intent_key(Some("   ")), None);
        assert_eq!(
            normalized_intent_key(Some(" Fix Bug ")).as_deref(),
            Some("fix bug")
        );
        let long = "x".repeat(200);
        assert_eq!(
            normalized_intent_key(Some(&long)).map(|k| k.chars().count()),
            Some(80)
        );
    }
}