// lean_ctx/core/provider_bandit.rs
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ProviderArm {
20 pub name: String,
21 pub alpha: f64,
22 pub beta: f64,
23 pub pulls: u64,
24}
25
26impl ProviderArm {
27 pub fn sample(&self) -> f64 {
28 beta_sample(self.alpha, self.beta)
29 }
30
31 pub fn update_success(&mut self) {
32 self.alpha += 1.0;
33 self.pulls += 1;
34 }
35
36 pub fn update_failure(&mut self) {
37 self.beta += 1.0;
38 self.pulls += 1;
39 }
40
41 pub fn mean(&self) -> f64 {
42 self.alpha / (self.alpha + self.beta)
43 }
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ProviderBandit {
49 pub arms: HashMap<String, ProviderArm>,
50}
51
52impl Default for ProviderBandit {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl ProviderBandit {
59 pub fn new() -> Self {
60 Self {
61 arms: HashMap::new(),
62 }
63 }
64
65 pub fn select_provider(
68 &mut self,
69 task_type: &str,
70 available_providers: &[String],
71 ) -> Option<String> {
72 if available_providers.is_empty() {
73 return None;
74 }
75
76 if available_providers.len() == 1 {
77 return Some(available_providers[0].clone());
78 }
79
80 let mut best_sample = f64::NEG_INFINITY;
81 let mut best_provider = &available_providers[0];
82
83 for provider_id in available_providers {
84 let key = arm_key(task_type, provider_id);
85 let arm = self.arms.entry(key).or_insert_with(|| ProviderArm {
86 name: provider_id.clone(),
87 alpha: 1.0,
88 beta: 1.0,
89 pulls: 0,
90 });
91
92 let sample = arm.sample();
93 if sample > best_sample {
94 best_sample = sample;
95 best_provider = provider_id;
96 }
97 }
98
99 Some(best_provider.clone())
100 }
101
102 pub fn update(&mut self, task_type: &str, provider_id: &str, was_useful: bool) {
104 let key = arm_key(task_type, provider_id);
105 let arm = self.arms.entry(key).or_insert_with(|| ProviderArm {
106 name: provider_id.to_string(),
107 alpha: 1.0,
108 beta: 1.0,
109 pulls: 0,
110 });
111
112 if was_useful {
113 arm.update_success();
114 } else {
115 arm.update_failure();
116 }
117 }
118
119 pub fn estimated_probability(&self, task_type: &str, provider_id: &str) -> f64 {
121 let key = arm_key(task_type, provider_id);
122 self.arms.get(&key).map_or(0.5, ProviderArm::mean)
123 }
124
125 pub fn format_report(&self) -> String {
127 let mut out = String::from("Provider Bandit Arms:\n");
128 let mut keys: Vec<_> = self.arms.keys().collect();
129 keys.sort();
130
131 for key in keys {
132 let arm = &self.arms[key];
133 out.push_str(&format!(
134 " {} — alpha={:.1} beta={:.1} mean={:.3} pulls={}\n",
135 key,
136 arm.alpha,
137 arm.beta,
138 arm.mean(),
139 arm.pulls,
140 ));
141 }
142 out
143 }
144}
145
/// Builds the map key for an arm: `task_type` and `provider_id` joined by `:`.
///
/// NOTE(review): neither part is escaped. A `:` inside either part can make two
/// different pairs share a key, e.g. `("a:b", "c")` and `("a", "b:c")`.
/// Changing the format would orphan keys in any previously serialized bandit,
/// so it is kept as-is.
fn arm_key(task_type: &str, provider_id: &str) -> String {
    [task_type, provider_id].join(":")
}

150fn beta_sample(alpha: f64, beta: f64) -> f64 {
152 let x = gamma_sample(alpha);
153 let y = gamma_sample(beta);
154 if x + y == 0.0 {
155 return 0.5;
156 }
157 (x / (x + y)).clamp(0.0, 1.0)
158}
159
160#[allow(clippy::many_single_char_names)]
162fn gamma_sample(shape: f64) -> f64 {
163 if shape < 1.0 {
164 return gamma_sample(shape + 1.0) * rng_f64().powf(1.0 / shape);
165 }
166 let d = shape - 1.0 / 3.0;
167 let c = 1.0 / (9.0 * d).sqrt();
168 loop {
169 let x = standard_normal();
170 let v_base = 1.0 + c * x;
171 if v_base <= 0.0 {
172 continue;
173 }
174 let v = v_base * v_base * v_base;
175 let u = rng_f64();
176 if u < 1.0 - 0.0331 * (x * x) * (x * x) || u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
177 return d * v;
178 }
179 }
180}
181
182fn standard_normal() -> f64 {
183 let u1: f64 = rng_f64().max(1e-10);
184 let u2: f64 = rng_f64();
185 (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos()
186}
187
/// Returns a pseudo-random `f64` uniformly distributed in `[0, 1)`.
///
/// Each thread owns a SplitMix64 generator. It is seeded once, on first use,
/// from a randomly keyed `RandomState` hasher mixed with the current time and
/// the thread id. This replaces hashing `Instant::now()` on every call, which
/// returned identical values for back-to-back calls within one clock tick and
/// could produce exactly `1.0`. Not suitable for cryptographic use.
fn rng_f64() -> f64 {
    use std::cell::Cell;
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};

    thread_local! {
        static STATE: Cell<u64> = Cell::new({
            let mut hasher = RandomState::new().build_hasher();
            std::time::Instant::now().hash(&mut hasher);
            std::thread::current().id().hash(&mut hasher);
            hasher.finish()
        });
    }

    let bits = STATE.with(|state| {
        // SplitMix64: advance by the golden-ratio increment, then finalize.
        let next = state.get().wrapping_add(0x9E37_79B9_7F4A_7C15);
        state.set(next);
        let mut z = next;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    });

    // The top 53 bits fit an f64 mantissa exactly, so the result is < 1.0.
    (bits >> 11) as f64 * (1.0 / (1_u64 << 53) as f64)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned provider list `select_provider` expects.
    fn names(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|id| (*id).to_string()).collect()
    }

    #[test]
    fn select_from_single_provider() {
        let mut bandit = ProviderBandit::new();
        let only_github = names(&["github"]);

        let chosen = bandit.select_provider("bugfix", &only_github);
        assert_eq!(chosen.as_deref(), Some("github"));
    }

    #[test]
    fn select_from_empty_returns_none() {
        let mut bandit = ProviderBandit::new();
        assert!(bandit.select_provider("bugfix", &[]).is_none());
    }

    #[test]
    fn update_shifts_distribution() {
        let mut bandit = ProviderBandit::new();
        // Twenty rounds where github always helps and jira never does.
        (0..20).for_each(|_| {
            bandit.update("bugfix", "github", true);
            bandit.update("bugfix", "jira", false);
        });

        assert!(bandit.estimated_probability("bugfix", "github") > 0.8);
        assert!(bandit.estimated_probability("bugfix", "jira") < 0.2);

        let candidates = names(&["github", "jira"]);
        let github_wins = (0..100)
            .filter(|_| bandit.select_provider("bugfix", &candidates).unwrap() == "github")
            .count();
        assert!(github_wins > 80);
    }

    #[test]
    fn different_task_types_have_independent_arms() {
        let mut bandit = ProviderBandit::new();
        bandit.update("bugfix", "github", true);
        bandit.update("feature", "jira", true);

        assert!(bandit.estimated_probability("bugfix", "github") > 0.5);
        assert!(bandit.estimated_probability("feature", "jira") > 0.5);
        // jira was only trained under "feature", so "bugfix" still reports the prior.
        let untouched = bandit.estimated_probability("bugfix", "jira");
        assert!((untouched - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn format_report_shows_all_arms() {
        let mut bandit = ProviderBandit::new();
        bandit.update("bugfix", "github", true);
        bandit.update("bugfix", "jira", false);

        let report = bandit.format_report();
        for expected in ["bugfix:github", "bugfix:jira"] {
            assert!(report.contains(expected));
        }
    }
}