Skip to main content

lean_ctx/core/
provider_bandit.rs

1//! Provider Bandit — Thompson Sampling for provider selection.
2//!
3//! Extends the existing bandit system to learn which data providers are
4//! most informative for different task types. When multiple providers
5//! are available, the bandit samples from Beta distributions to select
6//! the provider most likely to yield useful context.
7//!
8//! Scientific basis: Dopaminergic prediction errors (Schultz 1997;
9//! Nature Neurosci 2025). Positive prediction errors (provider was more
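//! useful than expected) increase the Beta alpha parameter. Negative
//! errors (provider was less useful than expected) increase the Beta beta
//! parameter instead; neither parameter is ever decreased.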
12
13use std::collections::HashMap;
14
15use serde::{Deserialize, Serialize};
16
/// A provider-specific bandit arm (simplified Beta-Bernoulli).
///
/// The arm's belief that its provider yields useful context is modelled as a
/// Beta(`alpha`, `beta`) distribution that is updated by counting outcomes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderArm {
    /// Identifier of the provider this arm tracks (the provider id only, not the task type).
    pub name: String,
    /// First Beta shape parameter; incremented by one for every useful outcome.
    pub alpha: f64,
    /// Second Beta shape parameter; incremented by one for every non-useful outcome.
    pub beta: f64,
    /// Number of recorded outcomes (useful + non-useful); drawing a sample does not change it.
    pub pulls: u64,
}
25
26impl ProviderArm {
27    pub fn sample(&self) -> f64 {
28        beta_sample(self.alpha, self.beta)
29    }
30
31    pub fn update_success(&mut self) {
32        self.alpha += 1.0;
33        self.pulls += 1;
34    }
35
36    pub fn update_failure(&mut self) {
37        self.beta += 1.0;
38        self.pulls += 1;
39    }
40
41    pub fn mean(&self) -> f64 {
42        self.alpha / (self.alpha + self.beta)
43    }
44}
45
/// Thompson-sampling bandit over data providers, holding one independent arm
/// per (task type, provider) pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderBandit {
    /// Arms keyed by `"<task_type>:<provider_id>"` (e.g. `"bugfix:github"`),
    /// so every task type learns its own provider preferences.
    pub arms: HashMap<String, ProviderArm>,
}
51
52impl Default for ProviderBandit {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl ProviderBandit {
59    pub fn new() -> Self {
60        Self {
61            arms: HashMap::new(),
62        }
63    }
64
65    /// Select the best provider for a given task type using Thompson Sampling.
66    /// Returns the provider_id with the highest sampled value.
67    pub fn select_provider(
68        &mut self,
69        task_type: &str,
70        available_providers: &[String],
71    ) -> Option<String> {
72        if available_providers.is_empty() {
73            return None;
74        }
75
76        if available_providers.len() == 1 {
77            return Some(available_providers[0].clone());
78        }
79
80        let mut best_sample = f64::NEG_INFINITY;
81        let mut best_provider = &available_providers[0];
82
83        for provider_id in available_providers {
84            let key = arm_key(task_type, provider_id);
85            let arm = self.arms.entry(key).or_insert_with(|| ProviderArm {
86                name: provider_id.clone(),
87                alpha: 1.0,
88                beta: 1.0,
89                pulls: 0,
90            });
91
92            let sample = arm.sample();
93            if sample > best_sample {
94                best_sample = sample;
95                best_provider = provider_id;
96            }
97        }
98
99        Some(best_provider.clone())
100    }
101
102    /// Update the bandit after observing the outcome of a provider query.
103    pub fn update(&mut self, task_type: &str, provider_id: &str, was_useful: bool) {
104        let key = arm_key(task_type, provider_id);
105        let arm = self.arms.entry(key).or_insert_with(|| ProviderArm {
106            name: provider_id.to_string(),
107            alpha: 1.0,
108            beta: 1.0,
109            pulls: 0,
110        });
111
112        if was_useful {
113            arm.update_success();
114        } else {
115            arm.update_failure();
116        }
117    }
118
119    /// Get the estimated success probability for a provider on a task type.
120    pub fn estimated_probability(&self, task_type: &str, provider_id: &str) -> f64 {
121        let key = arm_key(task_type, provider_id);
122        self.arms.get(&key).map_or(0.5, ProviderArm::mean)
123    }
124
125    /// Format a summary of all arms for debugging/logging.
126    pub fn format_report(&self) -> String {
127        let mut out = String::from("Provider Bandit Arms:\n");
128        let mut keys: Vec<_> = self.arms.keys().collect();
129        keys.sort();
130
131        for key in keys {
132            let arm = &self.arms[key];
133            out.push_str(&format!(
134                "  {} — alpha={:.1} beta={:.1} mean={:.3} pulls={}\n",
135                key,
136                arm.alpha,
137                arm.beta,
138                arm.mean(),
139                arm.pulls,
140            ));
141        }
142        out
143    }
144}
145
/// Builds the `arms` map key for a (task type, provider) pair: `"<task_type>:<provider_id>"`.
///
/// NOTE: the `:` separator is not escaped, so a task type containing `:` can
/// yield the same key as a different pair (e.g. `("a:b", "c")` and `("a", "b:c")`).
fn arm_key(task_type: &str, provider_id: &str) -> String {
    [task_type, provider_id].join(":")
}
149
150/// Simple Beta distribution sample using the ratio of two Gamma samples.
151fn beta_sample(alpha: f64, beta: f64) -> f64 {
152    let x = gamma_sample(alpha);
153    let y = gamma_sample(beta);
154    if x + y == 0.0 {
155        return 0.5;
156    }
157    (x / (x + y)).clamp(0.0, 1.0)
158}
159
/// Gamma(shape, 1) sample using Marsaglia & Tsang's method.
///
/// For `shape >= 1` this is the squeeze/rejection sampler of Marsaglia & Tsang
/// (2000). For `shape < 1` it uses the boost identity
/// `Gamma(shape) = Gamma(shape + 1) * U^(1 / shape)` with `U` uniform.
///
/// NOTE(review): the input is not validated. The method is only meaningful
/// for a finite `shape > 0`. A NaN `shape` makes `d` NaN, so the result is NaN
/// once the squeeze test passes. Confirm callers (e.g. deserialized
/// `alpha`/`beta`) cannot supply such values.
// Single-character names mirror the notation of the paper (d, c, x, v, u).
#[allow(clippy::many_single_char_names)]
fn gamma_sample(shape: f64) -> f64 {
    if shape < 1.0 {
        // Boost small shapes into the `shape >= 1` regime the main sampler requires.
        return gamma_sample(shape + 1.0) * rng_f64().powf(1.0 / shape);
    }
    let d = shape - 1.0 / 3.0;
    let c = 1.0 / (9.0 * d).sqrt();
    loop {
        let x = standard_normal();
        let v_base = 1.0 + c * x;
        // The candidate v = (1 + c*x)^3 must be positive; otherwise redraw.
        if v_base <= 0.0 {
            continue;
        }
        let v = v_base * v_base * v_base;
        let u = rng_f64();
        // Cheap squeeze test first; `||` only evaluates the log test when the squeeze fails.
        if u < 1.0 - 0.0331 * (x * x) * (x * x) || u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
            return d * v;
        }
    }
}
181
182fn standard_normal() -> f64 {
183    let u1: f64 = rng_f64().max(1e-10);
184    let u2: f64 = rng_f64();
185    (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos()
186}
187
/// Returns a pseudo-random `f64` uniformly distributed in `[0, 1)`.
///
/// Uses a per-thread SplitMix64 generator. Its state is seeded once per thread
/// from a randomly keyed `RandomState` hash of the current time and thread id,
/// then advanced on every call. Consecutive calls therefore never repeat a
/// value just because the clock has not ticked in between. (The previous
/// implementation re-hashed `Instant::now()` on every call and did repeat.)
///
/// Not cryptographically secure; intended only for Thompson sampling.
fn rng_f64() -> f64 {
    /// Per-thread seed: randomly keyed SipHash of the current time and thread id.
    fn seed() -> u64 {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hash, Hasher};

        let mut hasher = RandomState::new().build_hasher();
        std::time::Instant::now().hash(&mut hasher);
        std::thread::current().id().hash(&mut hasher);
        hasher.finish()
    }

    thread_local! {
        static STATE: std::cell::Cell<u64> = std::cell::Cell::new(seed());
    }

    // 2^-53: scales a 53-bit integer into [0, 1) exactly (f64 has a 53-bit significand).
    const SCALE: f64 = 1.0 / (1_u64 << 53) as f64;

    let bits = STATE.with(|state| {
        // SplitMix64: advance a Weyl sequence, then apply its output finalizer.
        let next = state.get().wrapping_add(0x9E37_79B9_7F4A_7C15);
        state.set(next);
        let mut z = next;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    });

    // Keep the top 53 bits so the conversion is exact and the result is strictly below 1.0.
    (bits >> 11) as f64 * SCALE
}
195
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn select_from_single_provider() {
        let mut bandit = ProviderBandit::new();
        let providers = vec!["github".into()];

        let selected = bandit.select_provider("bugfix", &providers);
        assert_eq!(selected.as_deref(), Some("github"));
    }

    #[test]
    fn select_from_empty_returns_none() {
        let mut bandit = ProviderBandit::new();
        let selected = bandit.select_provider("bugfix", &[]);
        assert!(selected.is_none());
    }

    #[test]
    fn update_shifts_distribution() {
        let mut bandit = ProviderBandit::new();
        let providers = vec!["github".into(), "jira".into()];

        // Train: github is always useful for bugfix
        for _ in 0..20 {
            bandit.update("bugfix", "github", true);
            bandit.update("bugfix", "jira", false);
        }

        let gh_prob = bandit.estimated_probability("bugfix", "github");
        let jira_prob = bandit.estimated_probability("bugfix", "jira");
        assert!(gh_prob > 0.8);
        assert!(jira_prob < 0.2);

        // Should strongly prefer github for bugfix tasks.
        let mut github_selected = 0;
        for _ in 0..100 {
            let selected = bandit.select_provider("bugfix", &providers).unwrap();
            if selected == "github" {
                github_selected += 1;
            }
        }
        assert!(github_selected > 80);
    }

    #[test]
    fn different_task_types_have_independent_arms() {
        let mut bandit = ProviderBandit::new();

        bandit.update("bugfix", "github", true);
        bandit.update("feature", "jira", true);

        assert!(bandit.estimated_probability("bugfix", "github") > 0.5);
        assert!(bandit.estimated_probability("feature", "jira") > 0.5);
        assert!((bandit.estimated_probability("bugfix", "jira") - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn format_report_shows_all_arms() {
        let mut bandit = ProviderBandit::new();
        bandit.update("bugfix", "github", true);
        bandit.update("bugfix", "jira", false);

        let report = bandit.format_report();
        assert!(report.contains("bugfix:github"));
        assert!(report.contains("bugfix:jira"));
    }

    /// Outcomes are counted exactly: 3 successes + 1 failure on a Beta(1, 1)
    /// prior gives Beta(4, 2), whose mean is 4/6.
    #[test]
    fn update_counts_pulls_and_tracks_posterior_mean() {
        let mut bandit = ProviderBandit::new();
        for _ in 0..3 {
            bandit.update("refactor", "github", true);
        }
        bandit.update("refactor", "github", false);

        let arm = &bandit.arms["refactor:github"];
        assert_eq!(arm.name, "github");
        assert_eq!(arm.pulls, 4);
        assert!((arm.alpha - 4.0).abs() < f64::EPSILON);
        assert!((arm.beta - 2.0).abs() < f64::EPSILON);

        let expected = 4.0 / 6.0;
        assert!((bandit.estimated_probability("refactor", "github") - expected).abs() < 1e-12);
    }

    /// Thompson samples are probabilities, including for shapes below 1
    /// (which exercise the Gamma boost path).
    #[test]
    fn samples_stay_in_unit_interval() {
        let arm = ProviderArm {
            name: "github".into(),
            alpha: 0.5,
            beta: 3.0,
            pulls: 0,
        };
        for _ in 0..200 {
            let s = arm.sample();
            assert!((0.0..=1.0).contains(&s), "sample out of range: {s}");
        }
    }

    /// With several candidates, selection registers a prior arm for each one
    /// and returns one of the candidates.
    #[test]
    fn select_provider_registers_arms_for_every_candidate() {
        let mut bandit = ProviderBandit::new();
        let providers = vec!["github".to_string(), "jira".to_string()];

        let selected = bandit.select_provider("feature", &providers).unwrap();
        assert!(providers.contains(&selected));
        assert!(bandit.arms.contains_key("feature:github"));
        assert!(bandit.arms.contains_key("feature:jira"));
        assert_eq!(bandit.arms["feature:github"].pulls, 0);
    }
}