Skip to main content

lean_ctx/core/
adaptive_mode_policy.rs

1use std::collections::{BTreeMap, HashMap};
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5
6use crate::core::llm_feedback::LlmFeedbackEvent;
7
/// File name of the persisted policy; `policy_path` joins it onto the data directory.
const POLICY_FILE: &str = "adaptive_mode_policy.json";
/// Smoothing factor of the exponential moving average: the weight given to the
/// newest badness sample (the previous average keeps `1.0 - EMA_ALPHA`).
const EMA_ALPHA: f64 = 0.2;
10
11#[derive(Debug, Clone, Serialize, Deserialize, Default)]
12pub struct AdaptiveModePolicyStore {
13    pub global: ModePenaltyTable,
14    #[serde(default)]
15    pub by_intent: HashMap<String, ModePenaltyTable>,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, Default)]
19pub struct ModePenaltyTable {
20    #[serde(default)]
21    pub modes: BTreeMap<String, ModePenalty>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize, Default)]
25pub struct ModePenalty {
26    pub ema_badness: f64,
27    pub samples: u64,
28    pub last_ts: Option<String>,
29}
30
31impl AdaptiveModePolicyStore {
32    pub fn load() -> Self {
33        let path = policy_path();
34        let Ok(s) = std::fs::read_to_string(&path) else {
35            return Self::default();
36        };
37        serde_json::from_str(&s).unwrap_or_default()
38    }
39
40    pub fn save(&self) -> Result<(), String> {
41        let path = policy_path();
42        if let Some(parent) = path.parent() {
43            std::fs::create_dir_all(parent)
44                .map_err(|e| format!("create_dir_all {}: {e}", parent.display()))?;
45        }
46        let json = serde_json::to_string_pretty(self).map_err(|e| format!("serialize: {e}"))?;
47        let tmp = path.with_extension("tmp");
48        std::fs::write(&tmp, json).map_err(|e| format!("write {}: {e}", tmp.display()))?;
49        std::fs::rename(&tmp, &path).map_err(|e| format!("rename {}: {e}", path.display()))?;
50        Ok(())
51    }
52
53    pub fn reset() -> Result<(), String> {
54        let path = policy_path();
55        if path.exists() {
56            std::fs::remove_file(&path).map_err(|e| format!("remove {}: {e}", path.display()))?;
57        }
58        Ok(())
59    }
60
61    pub fn update_from_feedback(&mut self, ev: &LlmFeedbackEvent) {
62        let ratio = ev.llm_output_tokens as f64 / ev.llm_input_tokens.max(1) as f64;
63        let mut badness = ((ratio - 1.2) / 1.2).clamp(0.0, 1.0);
64        if ev.llm_output_tokens >= 6000 {
65            badness = badness.max(0.8);
66        } else if ev.llm_output_tokens >= 3000 {
67            badness = badness.max(0.5);
68        }
69        if badness <= 0.0 {
70            return;
71        }
72
73        let modes = ev.ctx_read_modes.clone().unwrap_or_default();
74        if modes.is_empty() {
75            if let Some(m) = ev.ctx_read_last_mode.as_ref() {
76                if let Some(key) = normalize_mode_key(m) {
77                    Self::apply_update(&mut self.global, key, badness, ev.timestamp.as_str());
78                    if let Some(k) = normalized_intent_key(ev.intent.as_deref()) {
79                        let table = self.by_intent.entry(k).or_default();
80                        Self::apply_update(table, key, badness, ev.timestamp.as_str());
81                    }
82                }
83            }
84            return;
85        }
86
87        let total: u64 = modes.values().sum();
88        if total == 0 {
89            return;
90        }
91
92        for (mode, count) in modes {
93            let Some(key) = normalize_mode_key(&mode) else {
94                continue;
95            };
96            let w = count as f64 / total as f64;
97            let b = badness * w;
98            Self::apply_update(&mut self.global, key, b, ev.timestamp.as_str());
99            if let Some(k) = normalized_intent_key(ev.intent.as_deref()) {
100                let table = self.by_intent.entry(k).or_default();
101                Self::apply_update(table, key, b, ev.timestamp.as_str());
102            }
103        }
104    }
105
106    fn apply_update(table: &mut ModePenaltyTable, mode: &str, badness: f64, ts: &str) {
107        let entry = table.modes.entry(mode.to_string()).or_default();
108        entry.ema_badness = entry.ema_badness * (1.0 - EMA_ALPHA) + badness * EMA_ALPHA;
109        entry.samples = entry.samples.saturating_add(1);
110        entry.last_ts = Some(ts.to_string());
111    }
112
113    pub fn penalty(&self, intent: Option<&str>, mode: &str) -> f64 {
114        let Some(key) = normalize_mode_key(mode) else {
115            return 0.0;
116        };
117        if let Some(k) = normalized_intent_key(intent) {
118            if let Some(t) = self.by_intent.get(&k) {
119                if let Some(p) = t.modes.get(key) {
120                    return p.ema_badness.clamp(0.0, 1.0);
121                }
122            }
123        }
124        self.global
125            .modes
126            .get(key)
127            .map_or(0.0, |p| p.ema_badness.clamp(0.0, 1.0))
128    }
129
130    pub fn choose_auto_mode(&self, intent: Option<&str>, predicted: &str) -> String {
131        let candidates = auto_candidates(predicted);
132        let mut best_mode = predicted.to_string();
133        let mut best_score = f64::NEG_INFINITY;
134        for (i, mode) in candidates.into_iter().enumerate() {
135            let base = 1.0 - (i as f64 * 0.05);
136            let score = base - self.penalty(intent, &mode);
137            if score > best_score {
138                best_score = score;
139                best_mode = mode;
140            }
141        }
142        best_mode
143    }
144}
145
/// Maps a mode name to the key its penalties are tracked under.
///
/// `"entropy"` and `"aggressive"` share the `"aggressive"` key; every other
/// mode is its own key.
fn mode_group(mode: &str) -> &str {
    if matches!(mode, "entropy" | "aggressive") {
        "aggressive"
    } else {
        mode
    }
}
152
153fn normalize_mode_key(mode: &str) -> Option<&str> {
154    if mode == "diff" {
155        return None;
156    }
157    if mode.starts_with("lines:") {
158        return None;
159    }
160    Some(mode_group(mode))
161}
162
/// Returns the modes to consider for `predicted`, in order of preference.
///
/// The predicted mode always comes first, so it is kept unless feedback has
/// penalized it. In particular `"entropy"` is ranked ahead of `"aggressive"`
/// when it is the predicted mode: the two share one penalty key, so whichever
/// is listed second can never outscore the first.
fn auto_candidates(predicted: &str) -> Vec<String> {
    // Fallbacks tried after the predicted mode, most preferred first.
    let fallbacks: &[&str] = match predicted {
        "aggressive" => &["entropy", "map", "signatures", "full"],
        "entropy" => &["aggressive", "map", "signatures", "full"],
        "map" => &["signatures", "full", "aggressive"],
        "signatures" => &["map", "full", "aggressive"],
        "full" => &["map", "signatures", "aggressive"],
        _ => &["map", "signatures", "full", "aggressive"],
    };
    std::iter::once(predicted)
        .chain(fallbacks.iter().copied())
        .map(str::to_string)
        .collect()
}
199
/// Normalizes an intent into a lookup key: trimmed, lowercased and cut to
/// at most 80 characters. Returns `None` for a missing or blank intent.
fn normalized_intent_key(intent: Option<&str>) -> Option<String> {
    intent
        .map(str::trim)
        .filter(|text| !text.is_empty())
        .map(|text| text.to_lowercase().chars().take(80).collect())
}
208
209fn policy_path() -> PathBuf {
210    crate::core::data_dir::lean_ctx_data_dir()
211        .unwrap_or_else(|_| PathBuf::from("."))
212        .join(POLICY_FILE)
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// A store without any recorded feedback reports no penalty.
    #[test]
    fn penalty_defaults_to_zero() {
        let empty = AdaptiveModePolicyStore::default();
        let got = empty.penalty(None, "aggressive");
        assert_eq!(got, 0.0);
    }

    /// Once the aggressive group is penalized globally, auto selection moves
    /// away from both modes in that group.
    #[test]
    fn choose_auto_avoids_penalized_aggressive() {
        let mut policy = AdaptiveModePolicyStore::default();
        AdaptiveModePolicyStore::apply_update(&mut policy.global, "aggressive", 1.0, "t");

        let picked = policy.choose_auto_mode(Some("fix bug"), "aggressive");

        assert_ne!(picked, "aggressive");
        assert_ne!(picked, "entropy");
    }
}