Skip to main content

lean_ctx/core/
provider_bandit.rs

//! Provider Bandit — Thompson Sampling for provider selection.
//!
//! Extends the existing bandit system to learn which data providers are
//! most informative for different task types. When multiple providers
//! are available, the bandit draws one sample from each provider's Beta
//! distribution and selects the provider with the highest draw, i.e. the
//! provider most likely to yield useful context.
//!
//! Scientific basis: Dopaminergic prediction errors (Schultz 1997;
//! Nature Neurosci 2025). By analogy, a positive outcome (the provider was
//! useful) increments the Beta alpha parameter, and a negative outcome
//! increments the Beta beta parameter, lowering the provider's expected
//! usefulness.
12
13use std::collections::HashMap;
14
15use serde::{Deserialize, Serialize};
16
17/// A provider-specific bandit arm (simplified Beta-Bernoulli).
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ProviderArm {
20    pub name: String,
21    pub alpha: f64,
22    pub beta: f64,
23    pub pulls: u64,
24}
25
26impl ProviderArm {
27    pub fn sample(&self) -> f64 {
28        beta_sample(self.alpha, self.beta)
29    }
30
31    pub fn update_success(&mut self) {
32        self.alpha += 1.0;
33        self.pulls += 1;
34    }
35
36    pub fn update_failure(&mut self) {
37        self.beta += 1.0;
38        self.pulls += 1;
39    }
40
41    pub fn mean(&self) -> f64 {
42        self.alpha / (self.alpha + self.beta)
43    }
44}
45
46/// Per-provider arms, keyed by task type (e.g., "bugfix", "feature", "refactor").
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ProviderBandit {
49    pub arms: HashMap<String, ProviderArm>,
50}
51
52impl Default for ProviderBandit {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl ProviderBandit {
59    pub fn new() -> Self {
60        Self {
61            arms: HashMap::new(),
62        }
63    }
64
65    /// Select the best provider for a given task type using Thompson Sampling.
66    /// Returns the provider_id with the highest sampled value.
67    pub fn select_provider(
68        &mut self,
69        task_type: &str,
70        available_providers: &[String],
71    ) -> Option<String> {
72        if available_providers.is_empty() {
73            return None;
74        }
75
76        if available_providers.len() == 1 {
77            return Some(available_providers[0].clone());
78        }
79
80        let mut best_sample = f64::NEG_INFINITY;
81        let mut best_provider = &available_providers[0];
82
83        for provider_id in available_providers {
84            let key = arm_key(task_type, provider_id);
85            let arm = self.arms.entry(key).or_insert_with(|| ProviderArm {
86                name: provider_id.clone(),
87                alpha: 1.0,
88                beta: 1.0,
89                pulls: 0,
90            });
91
92            let sample = arm.sample();
93            if sample > best_sample {
94                best_sample = sample;
95                best_provider = provider_id;
96            }
97        }
98
99        Some(best_provider.clone())
100    }
101
102    /// Update the bandit after observing the outcome of a provider query.
103    pub fn update(&mut self, task_type: &str, provider_id: &str, was_useful: bool) {
104        let key = arm_key(task_type, provider_id);
105        let arm = self.arms.entry(key).or_insert_with(|| ProviderArm {
106            name: provider_id.to_string(),
107            alpha: 1.0,
108            beta: 1.0,
109            pulls: 0,
110        });
111
112        if was_useful {
113            arm.update_success();
114        } else {
115            arm.update_failure();
116        }
117    }
118
119    /// Get the estimated success probability for a provider on a task type.
120    pub fn estimated_probability(&self, task_type: &str, provider_id: &str) -> f64 {
121        let key = arm_key(task_type, provider_id);
122        self.arms.get(&key).map_or(0.5, ProviderArm::mean)
123    }
124
125    /// Load the persisted bandit for a project, or a fresh one if none exists.
126    /// Persistence is what turns the preloader from a per-call heuristic into a
127    /// model that genuinely learns which providers pay off for which task types.
128    pub fn load(project_root: &str) -> Self {
129        let path = provider_bandit_path(project_root);
130        if let Ok(content) = std::fs::read_to_string(&path) {
131            if let Ok(bandit) = serde_json::from_str::<ProviderBandit>(&content) {
132                return bandit;
133            }
134        }
135        Self::new()
136    }
137
138    /// Persist the bandit's learned arms for this project.
139    pub fn save(&self, project_root: &str) -> Result<(), String> {
140        let path = provider_bandit_path(project_root);
141        if let Some(parent) = path.parent() {
142            std::fs::create_dir_all(parent).map_err(|e| e.to_string())?;
143        }
144        let json = serde_json::to_string_pretty(self).map_err(|e| e.to_string())?;
145        std::fs::write(path, json).map_err(|e| e.to_string())
146    }
147
148    /// Format a summary of all arms for debugging/logging.
149    pub fn format_report(&self) -> String {
150        let mut out = String::from("Provider Bandit Arms:\n");
151        let mut keys: Vec<_> = self.arms.keys().collect();
152        keys.sort();
153
154        for key in keys {
155            let arm = &self.arms[key];
156            out.push_str(&format!(
157                "  {} — alpha={:.1} beta={:.1} mean={:.3} pulls={}\n",
158                key,
159                arm.alpha,
160                arm.beta,
161                arm.mean(),
162                arm.pulls,
163            ));
164        }
165        out
166    }
167}
168
/// Key under which the arm for `(task_type, provider_id)` is stored:
/// the two parts joined by a single `:`.
fn arm_key(task_type: &str, provider_id: &str) -> String {
    [task_type, provider_id].join(":")
}
172
173fn provider_bandit_path(project_root: &str) -> std::path::PathBuf {
174    let hash = crate::core::project_hash::hash_project_root(project_root);
175    crate::core::data_dir::lean_ctx_data_dir()
176        .unwrap_or_else(|_| std::path::PathBuf::from("."))
177        .join("projects")
178        .join(hash)
179        .join("provider_bandit.json")
180}
181
182/// Simple Beta distribution sample using the ratio of two Gamma samples.
183fn beta_sample(alpha: f64, beta: f64) -> f64 {
184    let x = gamma_sample(alpha);
185    let y = gamma_sample(beta);
186    if x + y == 0.0 {
187        return 0.5;
188    }
189    (x / (x + y)).clamp(0.0, 1.0)
190}
191
192/// Gamma(shape, 1) sample using Marsaglia & Tsang's method.
193#[allow(clippy::many_single_char_names)]
194fn gamma_sample(shape: f64) -> f64 {
195    if shape < 1.0 {
196        return gamma_sample(shape + 1.0) * rng_f64().powf(1.0 / shape);
197    }
198    let d = shape - 1.0 / 3.0;
199    let c = 1.0 / (9.0 * d).sqrt();
200    loop {
201        let x = standard_normal();
202        let v_base = 1.0 + c * x;
203        if v_base <= 0.0 {
204            continue;
205        }
206        let v = v_base * v_base * v_base;
207        let u = rng_f64();
208        if u < 1.0 - 0.0331 * (x * x) * (x * x) || u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
209            return d * v;
210        }
211    }
212}
213
214fn standard_normal() -> f64 {
215    let u1: f64 = rng_f64().max(1e-10);
216    let u2: f64 = rng_f64();
217    (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos()
218}
219
/// Uniform sample in `[0, 1)` from a per-thread SplitMix64 generator.
///
/// The generator state is seeded once per thread by hashing the current time
/// and thread id with `RandomState` (randomly keyed), then advanced on every
/// call. Because the state always advances, consecutive calls never repeat
/// even when the system clock has coarse resolution.
///
/// The output is the top 53 bits scaled by 2^-53. That value is exactly
/// representable as an `f64` and is strictly below `1.0`.
fn rng_f64() -> f64 {
    thread_local! {
        static STATE: std::cell::Cell<u64> = std::cell::Cell::new({
            use std::hash::{BuildHasher, Hash, Hasher};
            let mut hasher = std::collections::hash_map::RandomState::new().build_hasher();
            std::time::Instant::now().hash(&mut hasher);
            std::thread::current().id().hash(&mut hasher);
            hasher.finish()
        });
    }

    // 2^-53: maps a 53-bit integer onto [0, 1).
    const INV_2_POW_53: f64 = 1.0 / 9_007_199_254_740_992.0;

    let bits = STATE.with(|state| {
        // SplitMix64: Weyl-sequence increment followed by an avalanche mix.
        let next = state.get().wrapping_add(0x9E37_79B9_7F4A_7C15);
        state.set(next);
        let mut z = next;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    });
    (bits >> 11) as f64 * INV_2_POW_53
}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned provider list from string literals.
    fn providers(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|id| (*id).to_string()).collect()
    }

    #[test]
    fn select_from_single_provider() {
        let mut bandit = ProviderBandit::new();
        let choice = bandit.select_provider("bugfix", &providers(&["github"]));
        assert_eq!(choice.as_deref(), Some("github"));
    }

    #[test]
    fn select_from_empty_returns_none() {
        let mut bandit = ProviderBandit::new();
        assert!(bandit.select_provider("bugfix", &[]).is_none());
    }

    #[test]
    fn update_shifts_distribution() {
        let mut bandit = ProviderBandit::new();
        let candidates = providers(&["github", "jira"]);

        // Train: github always helps on bugfix tasks, jira never does.
        for _ in 0..20 {
            bandit.update("bugfix", "github", true);
            bandit.update("bugfix", "jira", false);
        }

        assert!(bandit.estimated_probability("bugfix", "github") > 0.8);
        assert!(bandit.estimated_probability("bugfix", "jira") < 0.2);

        // Thompson sampling should now pick github for the vast majority of draws.
        let github_wins = (0..100)
            .filter(|_| bandit.select_provider("bugfix", &candidates).unwrap() == "github")
            .count();
        assert!(github_wins > 80);
    }

    #[test]
    fn different_task_types_have_independent_arms() {
        let mut bandit = ProviderBandit::new();
        bandit.update("bugfix", "github", true);
        bandit.update("feature", "jira", true);

        assert!(bandit.estimated_probability("bugfix", "github") > 0.5);
        assert!(bandit.estimated_probability("feature", "jira") > 0.5);
        // jira was only trained on "feature", so its "bugfix" arm is still unseen.
        let untouched = bandit.estimated_probability("bugfix", "jira");
        assert!((untouched - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn format_report_shows_all_arms() {
        let mut bandit = ProviderBandit::new();
        bandit.update("bugfix", "github", true);
        bandit.update("bugfix", "jira", false);

        let report = bandit.format_report();
        for key in ["bugfix:github", "bugfix:jira"] {
            assert!(report.contains(key));
        }
    }

    #[test]
    fn persistence_roundtrip_preserves_learning() {
        // Serialize access to the process-wide data-dir env var across tests.
        let _env = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        std::env::set_var("LEAN_CTX_DATA_DIR", data_dir.path());
        let project = "/tmp/provider-bandit-roundtrip";

        let mut bandit = ProviderBandit::new();
        (0..10).for_each(|_| bandit.update("bugfix", "github", true));
        bandit.save(project).expect("save");

        let reloaded = ProviderBandit::load(project);
        assert!(
            reloaded.estimated_probability("bugfix", "github") > 0.8,
            "reloaded bandit must retain the learned preference"
        );

        // A project that was never saved starts unbiased (no cross-project leakage).
        let fresh = ProviderBandit::load("/tmp/provider-bandit-unseen");
        let prior = fresh.estimated_probability("bugfix", "github");
        assert!((prior - 0.5).abs() < f64::EPSILON);

        // NOTE(review): skipped if an assertion above panics.
        std::env::remove_var("LEAN_CTX_DATA_DIR");
    }
}