// lean_ctx/core/adaptive_mode_policy.rs

1use std::collections::{BTreeMap, HashMap};
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5
6use crate::core::llm_feedback::LlmFeedbackEvent;
7
/// File name (inside the lean_ctx data dir) where the policy is persisted.
const POLICY_FILE: &str = "adaptive_mode_policy.json";
/// Smoothing factor for the exponential moving average of per-mode badness.
const EMA_ALPHA: f64 = 0.2;
10
/// Persisted adaptive-mode policy: EMA "badness" penalties per read mode,
/// kept both globally and, optionally, per normalized user intent.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AdaptiveModePolicyStore {
    // Penalties aggregated across all intents.
    pub global: ModePenaltyTable,
    // Penalties keyed by normalized intent string (see `normalized_intent_key`).
    #[serde(default)]
    pub by_intent: HashMap<String, ModePenaltyTable>,
}
17
/// Per-mode penalty entries, keyed by normalized mode name.
/// `BTreeMap` keeps the serialized JSON deterministically ordered.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ModePenaltyTable {
    #[serde(default)]
    pub modes: BTreeMap<String, ModePenalty>,
}
23
/// Exponentially-smoothed "badness" statistics for a single mode.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ModePenalty {
    // EMA of observed badness; readers clamp to [0, 1] (see `penalty`).
    pub ema_badness: f64,
    // Number of feedback samples folded into the EMA.
    pub samples: u64,
    // Timestamp string of the most recent sample, if any.
    pub last_ts: Option<String>,
}
30
31impl AdaptiveModePolicyStore {
32    pub fn load() -> Self {
33        let path = policy_path();
34        let Ok(s) = std::fs::read_to_string(&path) else {
35            return Self::default();
36        };
37        serde_json::from_str(&s).unwrap_or_default()
38    }
39
40    pub fn save(&self) -> Result<(), String> {
41        let path = policy_path();
42        if let Some(parent) = path.parent() {
43            std::fs::create_dir_all(parent)
44                .map_err(|e| format!("create_dir_all {}: {e}", parent.display()))?;
45        }
46        let json = serde_json::to_string_pretty(self).map_err(|e| format!("serialize: {e}"))?;
47        let tmp = path.with_extension("tmp");
48        std::fs::write(&tmp, json).map_err(|e| format!("write {}: {e}", tmp.display()))?;
49        std::fs::rename(&tmp, &path).map_err(|e| format!("rename {}: {e}", path.display()))?;
50        Ok(())
51    }
52
53    pub fn reset() -> Result<(), String> {
54        let path = policy_path();
55        if path.exists() {
56            std::fs::remove_file(&path).map_err(|e| format!("remove {}: {e}", path.display()))?;
57        }
58        Ok(())
59    }
60
61    pub fn update_from_feedback(&mut self, ev: &LlmFeedbackEvent) {
62        let ratio = ev.llm_output_tokens as f64 / ev.llm_input_tokens.max(1) as f64;
63        let mut badness = ((ratio - 1.2) / 1.2).clamp(0.0, 1.0);
64        if ev.llm_output_tokens >= 6000 {
65            badness = badness.max(0.8);
66        } else if ev.llm_output_tokens >= 3000 {
67            badness = badness.max(0.5);
68        }
69        if badness <= 0.0 {
70            return;
71        }
72
73        let modes = ev.ctx_read_modes.as_ref().cloned().unwrap_or_default();
74        if modes.is_empty() {
75            if let Some(m) = ev.ctx_read_last_mode.as_ref() {
76                if let Some(key) = normalize_mode_key(m) {
77                    Self::apply_update(&mut self.global, key, badness, ev.timestamp.as_str());
78                    if let Some(k) = normalized_intent_key(ev.intent.as_deref()) {
79                        let table = self.by_intent.entry(k).or_default();
80                        Self::apply_update(table, key, badness, ev.timestamp.as_str());
81                    }
82                }
83            }
84            return;
85        }
86
87        let total: u64 = modes.values().sum();
88        if total == 0 {
89            return;
90        }
91
92        for (mode, count) in modes {
93            let Some(key) = normalize_mode_key(&mode) else {
94                continue;
95            };
96            let w = count as f64 / total as f64;
97            let b = badness * w;
98            Self::apply_update(&mut self.global, key, b, ev.timestamp.as_str());
99            if let Some(k) = normalized_intent_key(ev.intent.as_deref()) {
100                let table = self.by_intent.entry(k).or_default();
101                Self::apply_update(table, key, b, ev.timestamp.as_str());
102            }
103        }
104    }
105
106    fn apply_update(table: &mut ModePenaltyTable, mode: &str, badness: f64, ts: &str) {
107        let entry = table.modes.entry(mode.to_string()).or_default();
108        entry.ema_badness = entry.ema_badness * (1.0 - EMA_ALPHA) + badness * EMA_ALPHA;
109        entry.samples = entry.samples.saturating_add(1);
110        entry.last_ts = Some(ts.to_string());
111    }
112
113    pub fn penalty(&self, intent: Option<&str>, mode: &str) -> f64 {
114        let Some(key) = normalize_mode_key(mode) else {
115            return 0.0;
116        };
117        if let Some(k) = normalized_intent_key(intent) {
118            if let Some(t) = self.by_intent.get(&k) {
119                if let Some(p) = t.modes.get(key) {
120                    return p.ema_badness.clamp(0.0, 1.0);
121                }
122            }
123        }
124        self.global
125            .modes
126            .get(key)
127            .map(|p| p.ema_badness.clamp(0.0, 1.0))
128            .unwrap_or(0.0)
129    }
130
131    pub fn choose_auto_mode(&self, intent: Option<&str>, predicted: &str) -> String {
132        let candidates = auto_candidates(predicted);
133        let mut best_mode = predicted.to_string();
134        let mut best_score = f64::NEG_INFINITY;
135        for (i, mode) in candidates.into_iter().enumerate() {
136            let base = 1.0 - (i as f64 * 0.05);
137            let score = base - self.penalty(intent, &mode);
138            if score > best_score {
139                best_score = score;
140                best_mode = mode;
141            }
142        }
143        best_mode
144    }
145}
146
/// Collapse mode aliases that share one penalty bucket: "entropy" is
/// treated as a flavor of "aggressive"; all other modes map to themselves.
fn mode_group(mode: &str) -> &str {
    if matches!(mode, "entropy" | "aggressive") {
        "aggressive"
    } else {
        mode
    }
}
153
154fn normalize_mode_key(mode: &str) -> Option<&str> {
155    if mode == "diff" {
156        return None;
157    }
158    if mode.starts_with("lines:") {
159        return None;
160    }
161    Some(mode_group(mode))
162}
163
/// Candidate modes to score for auto-selection, ordered by preference:
/// the predicted mode family first, then the standard fallbacks.
fn auto_candidates(predicted: &str) -> Vec<String> {
    let order: &[&str] = match predicted {
        "aggressive" | "entropy" => &["aggressive", "entropy", "map", "signatures", "full"],
        "map" => &["map", "signatures", "full", "aggressive"],
        "signatures" => &["signatures", "map", "full", "aggressive"],
        "full" => &["full", "map", "signatures", "aggressive"],
        // Unknown modes keep themselves first, followed by every fallback.
        other => return [other, "map", "signatures", "full", "aggressive"]
            .iter()
            .map(|s| s.to_string())
            .collect(),
    };
    order.iter().map(|s| s.to_string()).collect()
}
200
/// Normalize a free-form intent into a stable table key: trimmed,
/// lowercased, truncated to 80 characters. Empty/absent intents yield `None`.
fn normalized_intent_key(intent: Option<&str>) -> Option<String> {
    let trimmed = intent.map(str::trim).filter(|s| !s.is_empty())?;
    Some(trimmed.to_lowercase().chars().take(80).collect())
}
209
210fn policy_path() -> PathBuf {
211    crate::core::data_dir::lean_ctx_data_dir()
212        .unwrap_or_else(|_| PathBuf::from("."))
213        .join(POLICY_FILE)
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn penalty_defaults_to_zero() {
        // A freshly-constructed store has learned nothing yet.
        let store = AdaptiveModePolicyStore::default();
        assert_eq!(store.penalty(None, "aggressive"), 0.0);
    }

    #[test]
    fn choose_auto_avoids_penalized_aggressive() {
        // Teach the global table that "aggressive" has been maximally bad.
        let mut store = AdaptiveModePolicyStore::default();
        AdaptiveModePolicyStore::apply_update(&mut store.global, "aggressive", 1.0, "t");
        // Both "aggressive" and its alias "entropy" must now be skipped.
        let chosen = store.choose_auto_mode(Some("fix bug"), "aggressive");
        assert!(chosen != "aggressive" && chosen != "entropy");
    }
}