Skip to main content

lean_ctx/core/
bandit.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct BanditArm {
6    pub name: String,
7    pub alpha: f64,
8    pub beta: f64,
9    pub entropy_threshold: f64,
10    pub jaccard_threshold: f64,
11    pub budget_ratio: f64,
12}
13
14impl BanditArm {
15    fn sample(&self) -> f64 {
16        beta_sample(self.alpha, self.beta)
17    }
18
19    pub fn update_success(&mut self) {
20        self.alpha += 1.0;
21    }
22
23    pub fn update_failure(&mut self) {
24        self.beta += 1.0;
25    }
26
27    pub fn decay(&mut self, factor: f64) {
28        self.alpha = (self.alpha * factor).max(1.0);
29        self.beta = (self.beta * factor).max(1.0);
30    }
31
32    pub fn mean(&self) -> f64 {
33        self.alpha / (self.alpha + self.beta)
34    }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ThresholdBandit {
39    pub arms: Vec<BanditArm>,
40    pub total_pulls: u64,
41}
42
43impl Default for ThresholdBandit {
44    fn default() -> Self {
45        Self {
46            arms: vec![
47                BanditArm {
48                    name: "conservative".to_string(),
49                    alpha: 2.0,
50                    beta: 1.0,
51                    entropy_threshold: 1.2,
52                    jaccard_threshold: 0.8,
53                    budget_ratio: 0.5,
54                },
55                BanditArm {
56                    name: "balanced".to_string(),
57                    alpha: 2.0,
58                    beta: 1.0,
59                    entropy_threshold: 0.9,
60                    jaccard_threshold: 0.7,
61                    budget_ratio: 0.35,
62                },
63                BanditArm {
64                    name: "aggressive".to_string(),
65                    alpha: 2.0,
66                    beta: 1.0,
67                    entropy_threshold: 0.6,
68                    jaccard_threshold: 0.55,
69                    budget_ratio: 0.2,
70                },
71            ],
72            total_pulls: 0,
73        }
74    }
75}
76
77impl ThresholdBandit {
78    pub fn select_arm(&mut self) -> &BanditArm {
79        self.total_pulls += 1;
80
81        let epsilon = (0.1 / (1.0 + self.total_pulls as f64 / 100.0)).max(0.02);
82        if rng_f64() < epsilon {
83            let idx = rng_usize(self.arms.len());
84            return &self.arms[idx];
85        }
86
87        let samples: Vec<f64> = self.arms.iter().map(BanditArm::sample).collect();
88        let best_idx = samples
89            .iter()
90            .enumerate()
91            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
92            .map_or(0, |(i, _)| i);
93
94        &self.arms[best_idx]
95    }
96
97    pub fn update(&mut self, arm_name: &str, success: bool) {
98        if let Some(arm) = self.arms.iter_mut().find(|a| a.name == arm_name) {
99            if success {
100                arm.update_success();
101            } else {
102                arm.update_failure();
103            }
104        }
105    }
106
107    pub fn decay_all(&mut self, factor: f64) {
108        for arm in &mut self.arms {
109            arm.decay(factor);
110        }
111    }
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize, Default)]
115pub struct BanditStore {
116    pub bandits: HashMap<String, ThresholdBandit>,
117}
118
119impl BanditStore {
120    pub fn get_or_create(&mut self, key: &str) -> &mut ThresholdBandit {
121        self.bandits.entry(key.to_string()).or_default()
122    }
123
124    pub fn load(project_root: &str) -> Self {
125        let path = bandit_path(project_root);
126        if path.exists() {
127            if let Ok(content) = std::fs::read_to_string(&path) {
128                if let Ok(store) = serde_json::from_str::<BanditStore>(&content) {
129                    return store;
130                }
131            }
132        }
133        Self::default()
134    }
135
136    pub fn save(&self, project_root: &str) -> Result<(), String> {
137        let path = bandit_path(project_root);
138        if let Some(parent) = path.parent() {
139            std::fs::create_dir_all(parent).map_err(|e| e.to_string())?;
140        }
141        let json = serde_json::to_string_pretty(self).map_err(|e| e.to_string())?;
142        std::fs::write(path, json).map_err(|e| e.to_string())
143    }
144
145    pub fn format_report(&self) -> String {
146        if self.bandits.is_empty() {
147            return "No bandit data yet.".to_string();
148        }
149        let mut lines = vec!["Threshold Bandits (Thompson Sampling):".to_string()];
150        for (key, bandit) in &self.bandits {
151            lines.push(format!("  {key} (pulls: {}):", bandit.total_pulls));
152            for arm in &bandit.arms {
153                let mean = arm.mean();
154                lines.push(format!(
155                    "    {}: α={:.1} β={:.1} mean={:.0}% entropy={:.2} jaccard={:.2} budget={:.0}%",
156                    arm.name,
157                    arm.alpha,
158                    arm.beta,
159                    mean * 100.0,
160                    arm.entropy_threshold,
161                    arm.jaccard_threshold,
162                    arm.budget_ratio * 100.0
163                ));
164            }
165        }
166        lines.join("\n")
167    }
168}
169
170fn bandit_path(project_root: &str) -> std::path::PathBuf {
171    let hash = {
172        use std::hash::{Hash, Hasher};
173        let mut hasher = std::collections::hash_map::DefaultHasher::new();
174        project_root.hash(&mut hasher);
175        format!("{:x}", hasher.finish())
176    };
177    crate::core::data_dir::lean_ctx_data_dir()
178        .unwrap_or_else(|_| std::path::PathBuf::from("."))
179        .join("projects")
180        .join(hash)
181        .join("bandits.json")
182}
183
184fn rng_f64() -> f64 {
185    let mut bytes = [0u8; 8];
186    getrandom::fill(&mut bytes).unwrap_or(());
187    let val = u64::from_le_bytes(bytes);
188    (val >> 11) as f64 / ((1u64 << 53) as f64)
189}
190
191fn rng_usize(bound: usize) -> usize {
192    if bound == 0 {
193        return 0;
194    }
195    let mut bytes = [0u8; 8];
196    getrandom::fill(&mut bytes).unwrap_or(());
197    let val = u64::from_le_bytes(bytes);
198    (val as usize) % bound
199}
200
201fn beta_sample(alpha: f64, beta: f64) -> f64 {
202    let x = gamma_sample(alpha);
203    let y = gamma_sample(beta);
204    if x + y == 0.0 {
205        return 0.5;
206    }
207    x / (x + y)
208}
209
210#[allow(clippy::many_single_char_names)] // Marsaglia's algorithm uses standard math notation
211fn gamma_sample(shape: f64) -> f64 {
212    if shape < 1.0 {
213        let u = rng_f64().max(1e-10);
214        gamma_sample(shape + 1.0) * u.powf(1.0 / shape)
215    } else {
216        let d = shape - 1.0 / 3.0;
217        let c = 1.0 / (9.0_f64 * d).sqrt();
218        loop {
219            let x = standard_normal();
220            let v = (1.0 + c * x).powi(3);
221            if v <= 0.0 {
222                continue;
223            }
224            let u = rng_f64().max(1e-10);
225            if u < 1.0 - 0.0331 * x.powi(4) || u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
226                return d * v;
227            }
228        }
229    }
230}
231
232fn standard_normal() -> f64 {
233    let u1: f64 = rng_f64().max(1e-10);
234    let u2: f64 = rng_f64();
235    (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos()
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// The default bandit exposes the three stock arms in a fixed order.
    #[test]
    fn bandit_default_has_three_arms() {
        let b = ThresholdBandit::default();
        assert_eq!(b.arms.len(), 3);
        assert_eq!(b.arms[0].name, "conservative");
        assert_eq!(b.arms[1].name, "balanced");
        assert_eq!(b.arms[2].name, "aggressive");
    }

    /// Every selection returns a real arm and is counted as a pull.
    #[test]
    fn bandit_selection_works() {
        let mut b = ThresholdBandit::default();
        for _ in 0..10 {
            let arm = b.select_arm();
            assert!(!arm.name.is_empty());
        }
        assert_eq!(b.total_pulls, 10);
    }

    /// Successes raise an arm's mean and failures lower it.
    #[test]
    fn bandit_update_shifts_distribution() {
        let mut b = ThresholdBandit::default();
        for _ in 0..20 {
            b.update("aggressive", true);
        }
        for _ in 0..20 {
            b.update("conservative", false);
        }
        let agg = b.arms.iter().find(|a| a.name == "aggressive").unwrap();
        let con = b.arms.iter().find(|a| a.name == "conservative").unwrap();
        assert!(agg.mean() > con.mean());
    }

    /// Beta samples always fall inside the unit interval.
    #[test]
    fn beta_sample_in_range() {
        for _ in 0..100 {
            let s = beta_sample(2.0, 2.0);
            assert!((0.0..=1.0).contains(&s), "got {s}");
        }
    }

    /// A saved store can be read back with its keys intact.
    #[test]
    fn store_save_load_roundtrip() {
        // The project root is only hashed into a directory name below the
        // data dir; nothing is ever written under `dir` itself. The process
        // id keeps concurrent test runs from sharing a file.
        let dir = std::env::temp_dir().join(format!("bandit-test-{}", std::process::id()));
        let root = dir.to_string_lossy().to_string();
        let mut store = BanditStore::default();
        store.get_or_create("rs_medium");
        store.save(&root).unwrap();
        let loaded = BanditStore::load(&root);

        // Remove what `save` really wrote, before asserting, so that a
        // failing assertion does not leave the file behind.
        let stored = bandit_path(&root);
        let _ = std::fs::remove_file(&stored);
        if let Some(parent) = stored.parent() {
            // Non-recursive: only succeeds when the directory is now empty.
            let _ = std::fs::remove_dir(parent);
        }

        assert!(loaded.bandits.contains_key("rs_medium"));
    }
}
294}