lean_ctx/core/
provider_bandit.rs1use std::collections::HashMap;
14
15use serde::{Deserialize, Serialize};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ProviderArm {
20 pub name: String,
21 pub alpha: f64,
22 pub beta: f64,
23 pub pulls: u64,
24}
25
26impl ProviderArm {
27 pub fn sample(&self) -> f64 {
28 beta_sample(self.alpha, self.beta)
29 }
30
31 pub fn update_success(&mut self) {
32 self.alpha += 1.0;
33 self.pulls += 1;
34 }
35
36 pub fn update_failure(&mut self) {
37 self.beta += 1.0;
38 self.pulls += 1;
39 }
40
41 pub fn mean(&self) -> f64 {
42 self.alpha / (self.alpha + self.beta)
43 }
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ProviderBandit {
49 pub arms: HashMap<String, ProviderArm>,
50}
51
52impl Default for ProviderBandit {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl ProviderBandit {
59 pub fn new() -> Self {
60 Self {
61 arms: HashMap::new(),
62 }
63 }
64
65 pub fn select_provider(
68 &mut self,
69 task_type: &str,
70 available_providers: &[String],
71 ) -> Option<String> {
72 if available_providers.is_empty() {
73 return None;
74 }
75
76 if available_providers.len() == 1 {
77 return Some(available_providers[0].clone());
78 }
79
80 let mut best_sample = f64::NEG_INFINITY;
81 let mut best_provider = &available_providers[0];
82
83 for provider_id in available_providers {
84 let key = arm_key(task_type, provider_id);
85 let arm = self.arms.entry(key).or_insert_with(|| ProviderArm {
86 name: provider_id.clone(),
87 alpha: 1.0,
88 beta: 1.0,
89 pulls: 0,
90 });
91
92 let sample = arm.sample();
93 if sample > best_sample {
94 best_sample = sample;
95 best_provider = provider_id;
96 }
97 }
98
99 Some(best_provider.clone())
100 }
101
102 pub fn update(&mut self, task_type: &str, provider_id: &str, was_useful: bool) {
104 let key = arm_key(task_type, provider_id);
105 let arm = self.arms.entry(key).or_insert_with(|| ProviderArm {
106 name: provider_id.to_string(),
107 alpha: 1.0,
108 beta: 1.0,
109 pulls: 0,
110 });
111
112 if was_useful {
113 arm.update_success();
114 } else {
115 arm.update_failure();
116 }
117 }
118
119 pub fn estimated_probability(&self, task_type: &str, provider_id: &str) -> f64 {
121 let key = arm_key(task_type, provider_id);
122 self.arms.get(&key).map_or(0.5, ProviderArm::mean)
123 }
124
125 pub fn load(project_root: &str) -> Self {
129 let path = provider_bandit_path(project_root);
130 if let Ok(content) = std::fs::read_to_string(&path) {
131 if let Ok(bandit) = serde_json::from_str::<ProviderBandit>(&content) {
132 return bandit;
133 }
134 }
135 Self::new()
136 }
137
138 pub fn save(&self, project_root: &str) -> Result<(), String> {
140 let path = provider_bandit_path(project_root);
141 if let Some(parent) = path.parent() {
142 std::fs::create_dir_all(parent).map_err(|e| e.to_string())?;
143 }
144 let json = serde_json::to_string_pretty(self).map_err(|e| e.to_string())?;
145 std::fs::write(path, json).map_err(|e| e.to_string())
146 }
147
148 pub fn format_report(&self) -> String {
150 let mut out = String::from("Provider Bandit Arms:\n");
151 let mut keys: Vec<_> = self.arms.keys().collect();
152 keys.sort();
153
154 for key in keys {
155 let arm = &self.arms[key];
156 out.push_str(&format!(
157 " {} — alpha={:.1} beta={:.1} mean={:.3} pulls={}\n",
158 key,
159 arm.alpha,
160 arm.beta,
161 arm.mean(),
162 arm.pulls,
163 ));
164 }
165 out
166 }
167}
168
/// Builds the `arms` map key for `provider_id` under `task_type`, formatted as `task:provider`.
fn arm_key(task_type: &str, provider_id: &str) -> String {
    [task_type, provider_id].join(":")
}
173fn provider_bandit_path(project_root: &str) -> std::path::PathBuf {
174 let hash = crate::core::project_hash::hash_project_root(project_root);
175 crate::core::data_dir::lean_ctx_data_dir()
176 .unwrap_or_else(|_| std::path::PathBuf::from("."))
177 .join("projects")
178 .join(hash)
179 .join("provider_bandit.json")
180}
181
182fn beta_sample(alpha: f64, beta: f64) -> f64 {
184 let x = gamma_sample(alpha);
185 let y = gamma_sample(beta);
186 if x + y == 0.0 {
187 return 0.5;
188 }
189 (x / (x + y)).clamp(0.0, 1.0)
190}
191
192#[allow(clippy::many_single_char_names)]
194fn gamma_sample(shape: f64) -> f64 {
195 if shape < 1.0 {
196 return gamma_sample(shape + 1.0) * rng_f64().powf(1.0 / shape);
197 }
198 let d = shape - 1.0 / 3.0;
199 let c = 1.0 / (9.0 * d).sqrt();
200 loop {
201 let x = standard_normal();
202 let v_base = 1.0 + c * x;
203 if v_base <= 0.0 {
204 continue;
205 }
206 let v = v_base * v_base * v_base;
207 let u = rng_f64();
208 if u < 1.0 - 0.0331 * (x * x) * (x * x) || u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
209 return d * v;
210 }
211 }
212}
213
214fn standard_normal() -> f64 {
215 let u1: f64 = rng_f64().max(1e-10);
216 let u2: f64 = rng_f64();
217 (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos()
218}
219
/// Returns a pseudo-random `f64` uniformly distributed in `[0, 1)`.
///
/// Each call hashes a process-wide, monotonically increasing counter with a
/// freshly built `RandomState` hasher, whose keys are randomly seeded.
/// Successive draws therefore differ even when the system clock has not
/// advanced between calls.
///
/// This is not suitable for cryptographic use.
fn rng_f64() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};

    static CALLS: AtomicUsize = AtomicUsize::new(0);

    let mut hasher = RandomState::new().build_hasher();
    CALLS.fetch_add(1, Ordering::Relaxed).hash(&mut hasher);
    // Keep the top 53 bits, which is exactly an f64 mantissa. The result is
    // uniform on [0, 1) and avoids the rounding bias of `u64 as f64`.
    (hasher.finish() >> 11) as f64 * (1.0 / (1u64 << 53) as f64)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single candidate is returned as-is.
    #[test]
    fn select_from_single_provider() {
        let mut bandit = ProviderBandit::new();
        let providers: Vec<String> = vec!["github".into()];

        assert_eq!(
            bandit.select_provider("bugfix", &providers).as_deref(),
            Some("github")
        );
    }

    /// No candidates means no selection.
    #[test]
    fn select_from_empty_returns_none() {
        let mut bandit = ProviderBandit::new();
        assert!(bandit.select_provider("bugfix", &[]).is_none());
    }

    /// Consistent feedback should dominate both the posterior mean and the sampled choice.
    #[test]
    fn update_shifts_distribution() {
        let mut bandit = ProviderBandit::new();
        let providers: Vec<String> = vec!["github".into(), "jira".into()];

        for _ in 0..20 {
            bandit.update("bugfix", "github", true);
            bandit.update("bugfix", "jira", false);
        }

        assert!(bandit.estimated_probability("bugfix", "github") > 0.8);
        assert!(bandit.estimated_probability("bugfix", "jira") < 0.2);

        let github_selected = (0..100)
            .filter(|_| bandit.select_provider("bugfix", &providers).unwrap() == "github")
            .count();
        assert!(github_selected > 80);
    }

    /// Feedback for one task type must not leak into another.
    #[test]
    fn different_task_types_have_independent_arms() {
        let mut bandit = ProviderBandit::new();
        bandit.update("bugfix", "github", true);
        bandit.update("feature", "jira", true);

        assert!(bandit.estimated_probability("bugfix", "github") > 0.5);
        assert!(bandit.estimated_probability("feature", "jira") > 0.5);
        let untouched = bandit.estimated_probability("bugfix", "jira");
        assert!((untouched - 0.5).abs() < f64::EPSILON);
    }

    /// Every recorded arm key appears in the report.
    #[test]
    fn format_report_shows_all_arms() {
        let mut bandit = ProviderBandit::new();
        bandit.update("bugfix", "github", true);
        bandit.update("bugfix", "jira", false);

        let report = bandit.format_report();
        for key in ["bugfix:github", "bugfix:jira"] {
            assert!(report.contains(key));
        }
    }

    /// Saved state survives a reload, and an unseen project starts fresh.
    #[test]
    fn persistence_roundtrip_preserves_learning() {
        let _env = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        std::env::set_var("LEAN_CTX_DATA_DIR", data_dir.path());
        let project = "/tmp/provider-bandit-roundtrip";

        let mut bandit = ProviderBandit::new();
        (0..10).for_each(|_| bandit.update("bugfix", "github", true));
        bandit.save(project).expect("save");

        let reloaded = ProviderBandit::load(project);
        assert!(
            reloaded.estimated_probability("bugfix", "github") > 0.8,
            "reloaded bandit must retain the learned preference"
        );

        let fresh = ProviderBandit::load("/tmp/provider-bandit-unseen");
        let prior = fresh.estimated_probability("bugfix", "github");
        assert!((prior - 0.5).abs() < f64::EPSILON);

        std::env::remove_var("LEAN_CTX_DATA_DIR");
    }
}