1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct BanditArm {
6 pub name: String,
7 pub alpha: f64,
8 pub beta: f64,
9 pub entropy_threshold: f64,
10 pub jaccard_threshold: f64,
11 pub budget_ratio: f64,
12}
13
14impl BanditArm {
15 fn sample(&self) -> f64 {
16 beta_sample(self.alpha, self.beta)
17 }
18
19 pub fn update_success(&mut self) {
20 self.alpha += 1.0;
21 }
22
23 pub fn update_failure(&mut self) {
24 self.beta += 1.0;
25 }
26
27 pub fn decay(&mut self, factor: f64) {
28 self.alpha = (self.alpha * factor).max(1.0);
29 self.beta = (self.beta * factor).max(1.0);
30 }
31
32 pub fn mean(&self) -> f64 {
33 self.alpha / (self.alpha + self.beta)
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ThresholdBandit {
39 pub arms: Vec<BanditArm>,
40 pub total_pulls: u64,
41}
42
43impl Default for ThresholdBandit {
44 fn default() -> Self {
45 Self {
46 arms: vec![
47 BanditArm {
48 name: "conservative".to_string(),
49 alpha: 2.0,
50 beta: 1.0,
51 entropy_threshold: 1.2,
52 jaccard_threshold: 0.8,
53 budget_ratio: 0.5,
54 },
55 BanditArm {
56 name: "balanced".to_string(),
57 alpha: 2.0,
58 beta: 1.0,
59 entropy_threshold: 0.9,
60 jaccard_threshold: 0.7,
61 budget_ratio: 0.35,
62 },
63 BanditArm {
64 name: "aggressive".to_string(),
65 alpha: 2.0,
66 beta: 1.0,
67 entropy_threshold: 0.6,
68 jaccard_threshold: 0.55,
69 budget_ratio: 0.2,
70 },
71 ],
72 total_pulls: 0,
73 }
74 }
75}
76
77impl ThresholdBandit {
78 pub fn select_arm(&mut self) -> &BanditArm {
79 self.total_pulls += 1;
80
81 let epsilon = (0.1 / (1.0 + self.total_pulls as f64 / 100.0)).max(0.02);
82 if rng_f64() < epsilon {
83 let idx = rng_usize(self.arms.len());
84 return &self.arms[idx];
85 }
86
87 let samples: Vec<f64> = self.arms.iter().map(|a| a.sample()).collect();
88 let best_idx = samples
89 .iter()
90 .enumerate()
91 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
92 .map(|(i, _)| i)
93 .unwrap_or(0);
94
95 &self.arms[best_idx]
96 }
97
98 pub fn update(&mut self, arm_name: &str, success: bool) {
99 if let Some(arm) = self.arms.iter_mut().find(|a| a.name == arm_name) {
100 if success {
101 arm.update_success();
102 } else {
103 arm.update_failure();
104 }
105 }
106 }
107
108 pub fn decay_all(&mut self, factor: f64) {
109 for arm in &mut self.arms {
110 arm.decay(factor);
111 }
112 }
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize, Default)]
116pub struct BanditStore {
117 pub bandits: HashMap<String, ThresholdBandit>,
118}
119
120impl BanditStore {
121 pub fn get_or_create(&mut self, key: &str) -> &mut ThresholdBandit {
122 self.bandits.entry(key.to_string()).or_default()
123 }
124
125 pub fn load(project_root: &str) -> Self {
126 let path = bandit_path(project_root);
127 if path.exists() {
128 if let Ok(content) = std::fs::read_to_string(&path) {
129 if let Ok(store) = serde_json::from_str::<BanditStore>(&content) {
130 return store;
131 }
132 }
133 }
134 Self::default()
135 }
136
137 pub fn save(&self, project_root: &str) -> Result<(), String> {
138 let path = bandit_path(project_root);
139 if let Some(parent) = path.parent() {
140 std::fs::create_dir_all(parent).map_err(|e| e.to_string())?;
141 }
142 let json = serde_json::to_string_pretty(self).map_err(|e| e.to_string())?;
143 std::fs::write(path, json).map_err(|e| e.to_string())
144 }
145
146 pub fn format_report(&self) -> String {
147 if self.bandits.is_empty() {
148 return "No bandit data yet.".to_string();
149 }
150 let mut lines = vec!["Threshold Bandits (Thompson Sampling):".to_string()];
151 for (key, bandit) in &self.bandits {
152 lines.push(format!(" {key} (pulls: {}):", bandit.total_pulls));
153 for arm in &bandit.arms {
154 let mean = arm.mean();
155 lines.push(format!(
156 " {}: α={:.1} β={:.1} mean={:.0}% entropy={:.2} jaccard={:.2} budget={:.0}%",
157 arm.name,
158 arm.alpha,
159 arm.beta,
160 mean * 100.0,
161 arm.entropy_threshold,
162 arm.jaccard_threshold,
163 arm.budget_ratio * 100.0
164 ));
165 }
166 }
167 lines.join("\n")
168 }
169}
170
171fn bandit_path(project_root: &str) -> std::path::PathBuf {
172 let hash = {
173 use std::hash::{Hash, Hasher};
174 let mut hasher = std::collections::hash_map::DefaultHasher::new();
175 project_root.hash(&mut hasher);
176 format!("{:x}", hasher.finish())
177 };
178 crate::core::data_dir::lean_ctx_data_dir()
179 .unwrap_or_else(|_| std::path::PathBuf::from("."))
180 .join("projects")
181 .join(hash)
182 .join("bandits.json")
183}
184
185fn rng_f64() -> f64 {
186 let mut bytes = [0u8; 8];
187 getrandom::fill(&mut bytes).unwrap_or(());
188 let val = u64::from_le_bytes(bytes);
189 (val >> 11) as f64 / ((1u64 << 53) as f64)
190}
191
192fn rng_usize(bound: usize) -> usize {
193 if bound == 0 {
194 return 0;
195 }
196 let mut bytes = [0u8; 8];
197 getrandom::fill(&mut bytes).unwrap_or(());
198 let val = u64::from_le_bytes(bytes);
199 (val as usize) % bound
200}
201
202fn beta_sample(alpha: f64, beta: f64) -> f64 {
203 let x = gamma_sample(alpha);
204 let y = gamma_sample(beta);
205 if x + y == 0.0 {
206 return 0.5;
207 }
208 x / (x + y)
209}
210
211fn gamma_sample(shape: f64) -> f64 {
212 if shape < 1.0 {
213 let u = rng_f64().max(1e-10);
214 gamma_sample(shape + 1.0) * u.powf(1.0 / shape)
215 } else {
216 let d = shape - 1.0 / 3.0;
217 let c = 1.0 / (9.0_f64 * d).sqrt();
218 loop {
219 let x = standard_normal();
220 let v = (1.0 + c * x).powi(3);
221 if v <= 0.0 {
222 continue;
223 }
224 let u = rng_f64().max(1e-10);
225 if u < 1.0 - 0.0331 * x.powi(4) || u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
226 return d * v;
227 }
228 }
229 }
230}
231
232fn standard_normal() -> f64 {
233 let u1: f64 = rng_f64().max(1e-10);
234 let u2: f64 = rng_f64();
235 (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos()
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bandit_default_has_three_arms() {
        let b = ThresholdBandit::default();
        assert_eq!(b.arms.len(), 3);
        assert_eq!(b.arms[0].name, "conservative");
        assert_eq!(b.arms[1].name, "balanced");
        assert_eq!(b.arms[2].name, "aggressive");
    }

    #[test]
    fn bandit_selection_works() {
        let mut b = ThresholdBandit::default();
        for _ in 0..10 {
            let arm = b.select_arm();
            assert!(
                ["conservative", "balanced", "aggressive"].contains(&arm.name.as_str()),
                "unexpected arm {}",
                arm.name
            );
        }
        assert_eq!(b.total_pulls, 10);
    }

    #[test]
    fn bandit_update_shifts_distribution() {
        let mut b = ThresholdBandit::default();
        for _ in 0..20 {
            b.update("aggressive", true);
        }
        for _ in 0..20 {
            b.update("conservative", false);
        }
        let agg = b.arms.iter().find(|a| a.name == "aggressive").unwrap();
        let con = b.arms.iter().find(|a| a.name == "conservative").unwrap();
        assert!(agg.mean() > con.mean());
    }

    #[test]
    fn decay_scales_counts_but_never_below_one() {
        let mut b = ThresholdBandit::default();
        for _ in 0..8 {
            b.update("balanced", true);
        }
        b.decay_all(0.5);
        let bal = b.arms.iter().find(|a| a.name == "balanced").unwrap();
        assert_eq!(bal.alpha, 5.0); // (2 + 8) * 0.5
        assert_eq!(bal.beta, 1.0); // 1 * 0.5, floored at 1.0
    }

    #[test]
    fn beta_sample_in_range() {
        for _ in 0..100 {
            let s = beta_sample(2.0, 2.0);
            assert!((0.0..=1.0).contains(&s), "got {s}");
        }
    }

    #[test]
    fn store_save_load_roundtrip() {
        // `root` is only hashed to choose a directory under the lean-ctx data dir; nothing
        // is created at this temp path itself, so clean up where `save` actually writes.
        let root = std::env::temp_dir()
            .join("bandit-test")
            .to_string_lossy()
            .to_string();
        let path = bandit_path(&root);

        let mut store = BanditStore::default();
        store.get_or_create("rs_medium");
        store.save(&root).unwrap();
        let loaded = BanditStore::load(&root);

        // Clean up before asserting so a failed assertion does not leak the file.
        let _ = std::fs::remove_file(&path);
        if let Some(project_dir) = path.parent() {
            // `remove_dir` only succeeds on an empty directory.
            let _ = std::fs::remove_dir(project_dir);
        }

        assert!(loaded.bandits.contains_key("rs_medium"));
    }
}