// lean_ctx/core/bandit.rs

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct BanditArm {
6    pub name: String,
7    pub alpha: f64,
8    pub beta: f64,
9    pub entropy_threshold: f64,
10    pub jaccard_threshold: f64,
11    pub budget_ratio: f64,
12}
13
14impl BanditArm {
15    fn sample(&self) -> f64 {
16        beta_sample(self.alpha, self.beta)
17    }
18
19    pub fn update_from_feedback(&mut self, outcome: &crate::core::feedback::CompressionOutcome) {
20        let efficiency = if outcome.tokens_original > 0 {
21            outcome.tokens_saved as f64 / outcome.tokens_original as f64
22        } else {
23            0.0
24        };
25        let success = efficiency > 0.3 && outcome.task_completed;
26        if success {
27            self.update_success();
28        } else {
29            self.update_failure();
30        }
31    }
32
33    pub fn update_success(&mut self) {
34        self.alpha += 1.0;
35    }
36
37    pub fn update_failure(&mut self) {
38        self.beta += 1.0;
39    }
40
41    pub fn decay(&mut self, factor: f64) {
42        self.alpha = (self.alpha * factor).max(1.0);
43        self.beta = (self.beta * factor).max(1.0);
44    }
45
46    pub fn mean(&self) -> f64 {
47        self.alpha / (self.alpha + self.beta)
48    }
49}
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ThresholdBandit {
53    pub arms: Vec<BanditArm>,
54    pub total_pulls: u64,
55}
56
57impl Default for ThresholdBandit {
58    fn default() -> Self {
59        Self {
60            arms: vec![
61                BanditArm {
62                    name: "conservative".to_string(),
63                    alpha: 2.0,
64                    beta: 1.0,
65                    entropy_threshold: 1.2,
66                    jaccard_threshold: 0.8,
67                    budget_ratio: 0.5,
68                },
69                BanditArm {
70                    name: "balanced".to_string(),
71                    alpha: 2.0,
72                    beta: 1.0,
73                    entropy_threshold: 0.9,
74                    jaccard_threshold: 0.7,
75                    budget_ratio: 0.35,
76                },
77                BanditArm {
78                    name: "aggressive".to_string(),
79                    alpha: 2.0,
80                    beta: 1.0,
81                    entropy_threshold: 0.6,
82                    jaccard_threshold: 0.55,
83                    budget_ratio: 0.2,
84                },
85            ],
86            total_pulls: 0,
87        }
88    }
89}
90
91impl ThresholdBandit {
92    pub fn select_arm(&mut self) -> &BanditArm {
93        self.total_pulls += 1;
94
95        let epsilon = (0.1 / (1.0 + self.total_pulls as f64 / 100.0)).max(0.02);
96        if rng_f64() < epsilon {
97            let idx = rng_usize(self.arms.len());
98            return &self.arms[idx];
99        }
100
101        let samples: Vec<f64> = self.arms.iter().map(BanditArm::sample).collect();
102        let best_idx = samples
103            .iter()
104            .enumerate()
105            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
106            .map_or(0, |(i, _)| i);
107
108        &self.arms[best_idx]
109    }
110
111    pub fn update(&mut self, arm_name: &str, success: bool) {
112        if let Some(arm) = self.arms.iter_mut().find(|a| a.name == arm_name) {
113            if success {
114                arm.update_success();
115            } else {
116                arm.update_failure();
117            }
118        }
119    }
120
121    pub fn decay_all(&mut self, factor: f64) {
122        for arm in &mut self.arms {
123            arm.decay(factor);
124        }
125    }
126
127    pub fn update_from_session(&mut self, outcomes: &[crate::core::feedback::CompressionOutcome]) {
128        for outcome in outcomes {
129            let efficiency = if outcome.tokens_original > 0 {
130                outcome.tokens_saved as f64 / outcome.tokens_original as f64
131            } else {
132                0.0
133            };
134            let success = efficiency > 0.3 && outcome.task_completed;
135
136            let arm_name = if outcome.entropy_threshold >= 1.0 {
137                "conservative"
138            } else if outcome.entropy_threshold >= 0.7 {
139                "balanced"
140            } else {
141                "aggressive"
142            };
143
144            self.update(arm_name, success);
145        }
146
147        if !outcomes.is_empty() {
148            self.decay_all(0.98);
149        }
150    }
151}
153#[derive(Debug, Clone, Serialize, Deserialize, Default)]
154pub struct BanditStore {
155    pub bandits: HashMap<String, ThresholdBandit>,
156}
157
158impl BanditStore {
159    pub fn get_or_create(&mut self, key: &str) -> &mut ThresholdBandit {
160        self.bandits.entry(key.to_string()).or_default()
161    }
162
163    pub fn load(project_root: &str) -> Self {
164        let path = bandit_path(project_root);
165        if path.exists() {
166            if let Ok(content) = std::fs::read_to_string(&path) {
167                if let Ok(store) = serde_json::from_str::<BanditStore>(&content) {
168                    return store;
169                }
170            }
171        }
172        Self::default()
173    }
174
175    pub fn save(&self, project_root: &str) -> Result<(), String> {
176        let path = bandit_path(project_root);
177        if let Some(parent) = path.parent() {
178            std::fs::create_dir_all(parent).map_err(|e| e.to_string())?;
179        }
180        let json = serde_json::to_string_pretty(self).map_err(|e| e.to_string())?;
181        std::fs::write(path, json).map_err(|e| e.to_string())
182    }
183
184    pub fn format_report(&self) -> String {
185        if self.bandits.is_empty() {
186            return "No bandit data yet.".to_string();
187        }
188        let mut lines = vec!["Threshold Bandits (Thompson Sampling):".to_string()];
189        for (key, bandit) in &self.bandits {
190            lines.push(format!("  {key} (pulls: {}):", bandit.total_pulls));
191            for arm in &bandit.arms {
192                let mean = arm.mean();
193                lines.push(format!(
194                    "    {}: α={:.1} β={:.1} mean={:.0}% entropy={:.2} jaccard={:.2} budget={:.0}%",
195                    arm.name,
196                    arm.alpha,
197                    arm.beta,
198                    mean * 100.0,
199                    arm.entropy_threshold,
200                    arm.jaccard_threshold,
201                    arm.budget_ratio * 100.0
202                ));
203            }
204        }
205        lines.join("\n")
206    }
207}
209fn bandit_path(project_root: &str) -> std::path::PathBuf {
210    let hash = {
211        use std::hash::{Hash, Hasher};
212        let mut hasher = std::collections::hash_map::DefaultHasher::new();
213        project_root.hash(&mut hasher);
214        format!("{:x}", hasher.finish())
215    };
216    crate::core::data_dir::lean_ctx_data_dir()
217        .unwrap_or_else(|_| std::path::PathBuf::from("."))
218        .join("projects")
219        .join(hash)
220        .join("bandits.json")
221}
223fn rng_f64() -> f64 {
224    let mut bytes = [0u8; 8];
225    getrandom::fill(&mut bytes).unwrap_or(());
226    let val = u64::from_le_bytes(bytes);
227    (val >> 11) as f64 / ((1u64 << 53) as f64)
228}
229
230fn rng_usize(bound: usize) -> usize {
231    if bound == 0 {
232        return 0;
233    }
234    let mut bytes = [0u8; 8];
235    getrandom::fill(&mut bytes).unwrap_or(());
236    let val = u64::from_le_bytes(bytes);
237    (val as usize) % bound
238}
240fn beta_sample(alpha: f64, beta: f64) -> f64 {
241    let x = gamma_sample(alpha);
242    let y = gamma_sample(beta);
243    if x + y == 0.0 {
244        return 0.5;
245    }
246    x / (x + y)
247}
249#[allow(clippy::many_single_char_names)] // Marsaglia's algorithm uses standard math notation
250fn gamma_sample(shape: f64) -> f64 {
251    if shape < 1.0 {
252        let u = rng_f64().max(1e-10);
253        gamma_sample(shape + 1.0) * u.powf(1.0 / shape)
254    } else {
255        let d = shape - 1.0 / 3.0;
256        let c = 1.0 / (9.0_f64 * d).sqrt();
257        loop {
258            let x = standard_normal();
259            let v = (1.0 + c * x).powi(3);
260            if v <= 0.0 {
261                continue;
262            }
263            let u = rng_f64().max(1e-10);
264            if u < 1.0 - 0.0331 * x.powi(4) || u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
265                return d * v;
266            }
267        }
268    }
269}
271fn standard_normal() -> f64 {
272    let u1: f64 = rng_f64().max(1e-10);
273    let u2: f64 = rng_f64();
274    (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos()
275}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bandit_default_has_three_arms() {
        let b = ThresholdBandit::default();
        assert_eq!(b.arms.len(), 3);
        assert_eq!(b.arms[0].name, "conservative");
        assert_eq!(b.arms[1].name, "balanced");
        assert_eq!(b.arms[2].name, "aggressive");
    }

    #[test]
    fn bandit_selection_works() {
        let mut b = ThresholdBandit::default();
        for _ in 0..10 {
            let arm = b.select_arm();
            assert!(
                ["conservative", "balanced", "aggressive"].contains(&arm.name.as_str()),
                "unexpected arm {}",
                arm.name
            );
        }
        assert_eq!(b.total_pulls, 10);
    }

    #[test]
    fn bandit_update_shifts_distribution() {
        let mut b = ThresholdBandit::default();
        for _ in 0..20 {
            b.update("aggressive", true);
        }
        for _ in 0..20 {
            b.update("conservative", false);
        }
        let agg = b.arms.iter().find(|a| a.name == "aggressive").unwrap();
        let con = b.arms.iter().find(|a| a.name == "conservative").unwrap();
        assert!(agg.mean() > con.mean());
    }

    #[test]
    fn update_with_unknown_arm_is_ignored() {
        let mut b = ThresholdBandit::default();
        b.update("does-not-exist", true);
        assert!(b
            .arms
            .iter()
            .all(|a| (a.alpha - 2.0).abs() < 1e-12 && (a.beta - 1.0).abs() < 1e-12));
    }

    #[test]
    fn decay_never_drops_counts_below_one() {
        let mut b = ThresholdBandit::default();
        b.decay_all(0.1);
        for arm in &b.arms {
            assert!((arm.alpha - 1.0).abs() < 1e-12, "alpha={}", arm.alpha);
            assert!((arm.beta - 1.0).abs() < 1e-12, "beta={}", arm.beta);
        }
    }

    #[test]
    fn rng_helpers_stay_in_range() {
        assert_eq!(rng_usize(0), 0);
        for _ in 0..100 {
            let x = rng_f64();
            assert!((0.0..1.0).contains(&x), "got {x}");
            assert!(rng_usize(3) < 3);
        }
    }

    #[test]
    fn beta_sample_in_range() {
        for _ in 0..100 {
            let s = beta_sample(2.0, 2.0);
            assert!((0.0..=1.0).contains(&s), "got {s}");
        }
    }

    #[test]
    fn store_save_load_roundtrip() {
        // `save`/`load` only hash `root` into a file under the lean-ctx data directory.
        // The file to clean up is therefore `bandit_path(&root)`; `root` itself is
        // never created. The pid keeps concurrent test runs from sharing that file.
        let root = std::env::temp_dir()
            .join(format!("bandit-test-{}", std::process::id()))
            .to_string_lossy()
            .to_string();
        let saved_path = bandit_path(&root);

        let mut store = BanditStore::default();
        store.get_or_create("rs_medium");
        store.save(&root).unwrap();
        let loaded = BanditStore::load(&root);

        // Clean up before asserting so a failure does not leave files behind.
        // `remove_dir` only succeeds on an empty directory, so nothing else is touched.
        let _ = std::fs::remove_file(&saved_path);
        if let Some(project_dir) = saved_path.parent() {
            let _ = std::fs::remove_dir(project_dir);
        }
        assert!(loaded.bandits.contains_key("rs_medium"));
    }
}