1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct BanditArm {
6 pub name: String,
7 pub alpha: f64,
8 pub beta: f64,
9 pub entropy_threshold: f64,
10 pub jaccard_threshold: f64,
11 pub budget_ratio: f64,
12}
13
14impl BanditArm {
15 fn sample(&self) -> f64 {
16 beta_sample(self.alpha, self.beta)
17 }
18
19 pub fn update_from_feedback(&mut self, outcome: &crate::core::feedback::CompressionOutcome) {
20 let efficiency = if outcome.tokens_original > 0 {
21 outcome.tokens_saved as f64 / outcome.tokens_original as f64
22 } else {
23 0.0
24 };
25 let success = efficiency > 0.3 && outcome.task_completed;
26 if success {
27 self.update_success();
28 } else {
29 self.update_failure();
30 }
31 }
32
33 pub fn update_success(&mut self) {
34 self.alpha += 1.0;
35 }
36
37 pub fn update_failure(&mut self) {
38 self.beta += 1.0;
39 }
40
41 pub fn decay(&mut self, factor: f64) {
42 self.alpha = (self.alpha * factor).max(1.0);
43 self.beta = (self.beta * factor).max(1.0);
44 }
45
46 pub fn mean(&self) -> f64 {
47 self.alpha / (self.alpha + self.beta)
48 }
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ThresholdBandit {
53 pub arms: Vec<BanditArm>,
54 pub total_pulls: u64,
55}
56
57impl Default for ThresholdBandit {
58 fn default() -> Self {
59 Self {
60 arms: vec![
61 BanditArm {
62 name: "conservative".to_string(),
63 alpha: 2.0,
64 beta: 1.0,
65 entropy_threshold: 1.2,
66 jaccard_threshold: 0.8,
67 budget_ratio: 0.5,
68 },
69 BanditArm {
70 name: "balanced".to_string(),
71 alpha: 2.0,
72 beta: 1.0,
73 entropy_threshold: 0.9,
74 jaccard_threshold: 0.7,
75 budget_ratio: 0.35,
76 },
77 BanditArm {
78 name: "aggressive".to_string(),
79 alpha: 2.0,
80 beta: 1.0,
81 entropy_threshold: 0.6,
82 jaccard_threshold: 0.55,
83 budget_ratio: 0.2,
84 },
85 ],
86 total_pulls: 0,
87 }
88 }
89}
90
91impl ThresholdBandit {
92 pub fn select_arm(&mut self) -> &BanditArm {
93 self.total_pulls += 1;
94
95 let epsilon = (0.1 / (1.0 + self.total_pulls as f64 / 100.0)).max(0.02);
96 if rng_f64() < epsilon {
97 let idx = rng_usize(self.arms.len());
98 return &self.arms[idx];
99 }
100
101 let samples: Vec<f64> = self.arms.iter().map(BanditArm::sample).collect();
102 let best_idx = samples
103 .iter()
104 .enumerate()
105 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
106 .map_or(0, |(i, _)| i);
107
108 &self.arms[best_idx]
109 }
110
111 pub fn update(&mut self, arm_name: &str, success: bool) {
112 if let Some(arm) = self.arms.iter_mut().find(|a| a.name == arm_name) {
113 if success {
114 arm.update_success();
115 } else {
116 arm.update_failure();
117 }
118 }
119 }
120
121 pub fn decay_all(&mut self, factor: f64) {
122 for arm in &mut self.arms {
123 arm.decay(factor);
124 }
125 }
126
127 pub fn update_from_session(&mut self, outcomes: &[crate::core::feedback::CompressionOutcome]) {
128 for outcome in outcomes {
129 let efficiency = if outcome.tokens_original > 0 {
130 outcome.tokens_saved as f64 / outcome.tokens_original as f64
131 } else {
132 0.0
133 };
134 let success = efficiency > 0.3 && outcome.task_completed;
135
136 let arm_name = if outcome.entropy_threshold >= 1.0 {
137 "conservative"
138 } else if outcome.entropy_threshold >= 0.7 {
139 "balanced"
140 } else {
141 "aggressive"
142 };
143
144 self.update(arm_name, success);
145 }
146
147 if !outcomes.is_empty() {
148 self.decay_all(0.98);
149 }
150 }
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize, Default)]
154pub struct BanditStore {
155 pub bandits: HashMap<String, ThresholdBandit>,
156}
157
158impl BanditStore {
159 pub fn get_or_create(&mut self, key: &str) -> &mut ThresholdBandit {
160 self.bandits.entry(key.to_string()).or_default()
161 }
162
163 pub fn load(project_root: &str) -> Self {
164 let path = bandit_path(project_root);
165 if path.exists() {
166 if let Ok(content) = std::fs::read_to_string(&path) {
167 if let Ok(store) = serde_json::from_str::<BanditStore>(&content) {
168 return store;
169 }
170 }
171 }
172 Self::default()
173 }
174
175 pub fn save(&self, project_root: &str) -> Result<(), String> {
176 let path = bandit_path(project_root);
177 if let Some(parent) = path.parent() {
178 std::fs::create_dir_all(parent).map_err(|e| e.to_string())?;
179 }
180 let json = serde_json::to_string_pretty(self).map_err(|e| e.to_string())?;
181 std::fs::write(path, json).map_err(|e| e.to_string())
182 }
183
184 pub fn format_report(&self) -> String {
185 if self.bandits.is_empty() {
186 return "No bandit data yet.".to_string();
187 }
188 let mut lines = vec!["Threshold Bandits (Thompson Sampling):".to_string()];
189 for (key, bandit) in &self.bandits {
190 lines.push(format!(" {key} (pulls: {}):", bandit.total_pulls));
191 for arm in &bandit.arms {
192 let mean = arm.mean();
193 lines.push(format!(
194 " {}: α={:.1} β={:.1} mean={:.0}% entropy={:.2} jaccard={:.2} budget={:.0}%",
195 arm.name,
196 arm.alpha,
197 arm.beta,
198 mean * 100.0,
199 arm.entropy_threshold,
200 arm.jaccard_threshold,
201 arm.budget_ratio * 100.0
202 ));
203 }
204 }
205 lines.join("\n")
206 }
207}
208
209fn bandit_path(project_root: &str) -> std::path::PathBuf {
210 let hash = crate::core::project_hash::hash_project_root(project_root);
211 crate::core::data_dir::lean_ctx_data_dir()
212 .unwrap_or_else(|_| std::path::PathBuf::from("."))
213 .join("projects")
214 .join(hash)
215 .join("bandits.json")
216}
217
218fn rng_f64() -> f64 {
219 let mut bytes = [0u8; 8];
220 getrandom::fill(&mut bytes).unwrap_or(());
221 let val = u64::from_le_bytes(bytes);
222 (val >> 11) as f64 / ((1u64 << 53) as f64)
223}
224
225fn rng_usize(bound: usize) -> usize {
226 if bound == 0 {
227 return 0;
228 }
229 let mut bytes = [0u8; 8];
230 getrandom::fill(&mut bytes).unwrap_or(());
231 let val = u64::from_le_bytes(bytes);
232 (val as usize) % bound
233}
234
235fn beta_sample(alpha: f64, beta: f64) -> f64 {
236 let x = gamma_sample(alpha);
237 let y = gamma_sample(beta);
238 if x + y == 0.0 {
239 return 0.5;
240 }
241 x / (x + y)
242}
243
244#[allow(clippy::many_single_char_names)] fn gamma_sample(shape: f64) -> f64 {
246 if shape < 1.0 {
247 let u = rng_f64().max(1e-10);
248 gamma_sample(shape + 1.0) * u.powf(1.0 / shape)
249 } else {
250 let d = shape - 1.0 / 3.0;
251 let c = 1.0 / (9.0_f64 * d).sqrt();
252 loop {
253 let x = standard_normal();
254 let v = (1.0 + c * x).powi(3);
255 if v <= 0.0 {
256 continue;
257 }
258 let u = rng_f64().max(1e-10);
259 if u < 1.0 - 0.0331 * x.powi(4) || u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
260 return d * v;
261 }
262 }
263 }
264}
265
266fn standard_normal() -> f64 {
267 let u1: f64 = rng_f64().max(1e-10);
268 let u2: f64 = rng_f64();
269 (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos()
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// The default bandit exposes exactly the three named arms, in order.
    #[test]
    fn bandit_default_has_three_arms() {
        let bandit = ThresholdBandit::default();
        assert_eq!(bandit.arms.len(), 3);

        let names: Vec<&str> = bandit.arms.iter().map(|arm| arm.name.as_str()).collect();
        assert_eq!(names, ["conservative", "balanced", "aggressive"]);
    }

    /// Every call to `select_arm` yields an arm and counts as one pull.
    #[test]
    fn bandit_selection_works() {
        let mut bandit = ThresholdBandit::default();
        (0..10).for_each(|_| {
            let _ = bandit.select_arm().name.clone();
        });
        assert_eq!(bandit.total_pulls, 10);
    }

    /// Repeated successes raise an arm's mean above one fed only failures.
    #[test]
    fn bandit_update_shifts_distribution() {
        let mut bandit = ThresholdBandit::default();
        for (arm_name, success) in [("aggressive", true), ("conservative", false)] {
            for _ in 0..20 {
                bandit.update(arm_name, success);
            }
        }

        let mean_of = |wanted: &str| {
            bandit
                .arms
                .iter()
                .find(|arm| arm.name == wanted)
                .unwrap()
                .mean()
        };
        assert!(mean_of("aggressive") > mean_of("conservative"));
    }

    /// Beta samples always land in the closed unit interval.
    #[test]
    fn beta_sample_in_range() {
        for _ in 0..100 {
            let s = beta_sample(2.0, 2.0);
            assert!((0.0..=1.0).contains(&s), "got {s}");
        }
    }

    /// A saved store can be loaded back with its keys intact.
    #[test]
    fn store_save_load_roundtrip() {
        // Hold the shared env lock while LEAN_CTX_DATA_DIR points at a temp dir.
        let _env = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        std::env::set_var("LEAN_CTX_DATA_DIR", data_dir.path());

        let project = tempfile::tempdir().unwrap();
        let root = project.path().to_string_lossy().to_string();

        let mut store = BanditStore::default();
        store.get_or_create("rs_medium");
        store.save(&root).unwrap();

        let reloaded = BanditStore::load(&root);
        assert!(reloaded.bandits.contains_key("rs_medium"));
    }
}