Skip to main content

llm_optimizer_decision/
thompson_sampling.rs

1//! Thompson Sampling implementation for multi-armed bandit optimization
2//!
3//! This module implements Thompson Sampling using Beta distributions for
4//! traffic allocation in A/B tests. It provides adaptive traffic routing
5//! that balances exploration and exploitation.
6
7use rand_distr::{Beta, Distribution};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use uuid::Uuid;
11
12use crate::errors::{DecisionError, Result};
13
14/// Thompson Sampling bandit for variant selection
15#[derive(Debug, Clone)]
16pub struct ThompsonSampling {
17    /// Arms (variants) in the bandit
18    arms: HashMap<Uuid, BanditArm>,
19    /// Total samples drawn
20    total_samples: u64,
21}
22
23/// A single arm in the multi-armed bandit
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct BanditArm {
26    /// Variant ID
27    pub variant_id: Uuid,
28    /// Number of successes (conversions, positive outcomes)
29    pub successes: f64,
30    /// Number of failures (non-conversions, negative outcomes)
31    pub failures: f64,
32    /// Total trials
33    pub trials: u64,
34}
35
36impl BanditArm {
37    /// Create a new bandit arm with prior
38    pub fn new(variant_id: Uuid) -> Self {
39        Self {
40            variant_id,
41            successes: 1.0, // Prior: Beta(1, 1) is uniform distribution
42            failures: 1.0,
43            trials: 0,
44        }
45    }
46
47    /// Update arm with observation
48    pub fn update(&mut self, success: bool) {
49        if success {
50            self.successes += 1.0;
51        } else {
52            self.failures += 1.0;
53        }
54        self.trials += 1;
55    }
56
57    /// Get conversion rate (mean of Beta distribution)
58    pub fn conversion_rate(&self) -> f64 {
59        self.successes / (self.successes + self.failures)
60    }
61
62    /// Get credible interval (Bayesian confidence interval)
63    pub fn credible_interval(&self, confidence: f64) -> (f64, f64) {
64        use statrs::distribution::Beta as BetaDist;
65
66        let _beta = BetaDist::new(self.successes, self.failures).unwrap();
67        let lower = (1.0 - confidence) / 2.0;
68        let _upper = 1.0 - lower;
69
70        // Approximate quantiles (for production, use proper quantile function)
71        let mean = self.conversion_rate();
72        let std = (self.successes * self.failures /
73                   ((self.successes + self.failures).powi(2) *
74                    (self.successes + self.failures + 1.0))).sqrt();
75
76        (
77            (mean - 1.96 * std).max(0.0),
78            (mean + 1.96 * std).min(1.0),
79        )
80    }
81
82    /// Sample from the Beta distribution
83    pub fn sample(&self) -> Result<f64> {
84        let beta = Beta::new(self.successes, self.failures)
85            .map_err(|e| DecisionError::StatisticalError(
86                format!("Failed to create Beta distribution: {}", e)
87            ))?;
88        
89        let mut rng = rand::thread_rng();
90        Ok(beta.sample(&mut rng))
91    }
92}
93
94impl ThompsonSampling {
95    /// Create a new Thompson Sampling instance
96    pub fn new() -> Self {
97        Self {
98            arms: HashMap::new(),
99            total_samples: 0,
100        }
101    }
102
103    /// Add a new variant
104    pub fn add_variant(&mut self, variant_id: Uuid) {
105        self.arms.insert(variant_id, BanditArm::new(variant_id));
106    }
107
108    /// Remove a variant
109    pub fn remove_variant(&mut self, variant_id: &Uuid) {
110        self.arms.remove(variant_id);
111    }
112
113    /// Select a variant using Thompson Sampling
114    ///
115    /// This samples from each arm's Beta distribution and selects
116    /// the arm with the highest sampled value.
117    pub fn select_variant(&self) -> Result<Uuid> {
118        if self.arms.is_empty() {
119            return Err(DecisionError::InvalidState(
120                "No variants available for selection".to_string()
121            ));
122        }
123
124        let mut best_variant = None;
125        let mut best_sample = f64::MIN;
126
127        for (variant_id, arm) in &self.arms {
128            let sample = arm.sample()?;
129            if sample > best_sample {
130                best_sample = sample;
131                best_variant = Some(*variant_id);
132            }
133        }
134
135        best_variant.ok_or_else(|| 
136            DecisionError::AllocationError("Failed to select variant".to_string())
137        )
138    }
139
140    /// Update a variant with observation
141    pub fn update(&mut self, variant_id: &Uuid, success: bool) -> Result<()> {
142        let arm = self.arms.get_mut(variant_id)
143            .ok_or_else(|| DecisionError::VariantNotFound(variant_id.to_string()))?;
144        
145        arm.update(success);
146        self.total_samples += 1;
147        Ok(())
148    }
149
150    /// Get current conversion rates for all variants
151    pub fn get_conversion_rates(&self) -> HashMap<Uuid, f64> {
152        self.arms.iter()
153            .map(|(id, arm)| (*id, arm.conversion_rate()))
154            .collect()
155    }
156
157    /// Get arm statistics
158    pub fn get_arm(&self, variant_id: &Uuid) -> Option<&BanditArm> {
159        self.arms.get(variant_id)
160    }
161
162    /// Get all arms
163    pub fn get_arms(&self) -> &HashMap<Uuid, BanditArm> {
164        &self.arms
165    }
166
167    /// Calculate regret (difference from optimal arm)
168    pub fn calculate_regret(&self) -> f64 {
169        if self.arms.is_empty() || self.total_samples == 0 {
170            return 0.0;
171        }
172
173        // Best possible conversion rate
174        let best_rate = self.arms.values()
175            .map(|arm| arm.conversion_rate())
176            .max_by(|a, b| a.partial_cmp(b).unwrap())
177            .unwrap_or(0.0);
178
179        // Actual conversions (excluding priors)
180        let actual_conversions: f64 = self.arms.values()
181            .map(|arm| (arm.successes - 1.0).max(0.0))
182            .sum();
183
184        // Expected conversions if we always chose best arm
185        let expected_conversions = best_rate * self.total_samples as f64;
186
187        // Regret is the difference
188        (expected_conversions - actual_conversions).max(0.0)
189    }
190
191    /// Get total number of samples
192    pub fn total_samples(&self) -> u64 {
193        self.total_samples
194    }
195
196    /// Check if bandit has converged (low regret relative to samples)
197    pub fn has_converged(&self, threshold: f64) -> bool {
198        if self.total_samples < 100 {
199            return false;
200        }
201
202        let regret = self.calculate_regret();
203        let regret_rate = regret / self.total_samples as f64;
204        
205        regret_rate < threshold
206    }
207}
208
209impl Default for ThompsonSampling {
210    fn default() -> Self {
211        Self::new()
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    #[test]
    fn test_bandit_arm_creation() {
        let arm = BanditArm::new(Uuid::new_v4());
        assert_eq!(arm.successes, 1.0);
        assert_eq!(arm.failures, 1.0);
        assert_eq!(arm.trials, 0);
        assert_eq!(arm.conversion_rate(), 0.5); // Beta(1,1) mean
    }

    #[test]
    fn test_bandit_arm_update() {
        let mut arm = BanditArm::new(Uuid::new_v4());

        arm.update(true);
        assert_eq!(arm.successes, 2.0);
        assert_eq!(arm.trials, 1);

        arm.update(false);
        assert_eq!(arm.failures, 2.0);
        assert_eq!(arm.trials, 2);
    }

    #[test]
    fn test_bandit_arm_conversion_rate() {
        let mut arm = BanditArm::new(Uuid::new_v4());

        // 7 successes out of 10 trials
        for _ in 0..7 {
            arm.update(true);
        }
        for _ in 0..3 {
            arm.update(false);
        }

        // (1 + 7) / (1 + 7 + 1 + 3) = 8/12 = 0.666...
        let rate = arm.conversion_rate();
        assert!((rate - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_thompson_sampling_creation() {
        let ts = ThompsonSampling::new();
        assert_eq!(ts.total_samples(), 0);
        assert!(ts.get_arms().is_empty());
    }

    #[test]
    fn test_add_remove_variant() {
        let mut ts = ThompsonSampling::new();
        let id = Uuid::new_v4();

        ts.add_variant(id);
        assert_eq!(ts.get_arms().len(), 1);
        assert!(ts.get_arm(&id).is_some());

        ts.remove_variant(&id);
        assert_eq!(ts.get_arms().len(), 0);
    }

    #[test]
    fn test_select_variant() {
        let mut ts = ThompsonSampling::new();

        // Should fail with no variants
        assert!(ts.select_variant().is_err());

        // Add variants
        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        ts.add_variant(id1);
        ts.add_variant(id2);

        // Should select one
        let selected = ts.select_variant().unwrap();
        assert!(selected == id1 || selected == id2);
    }

    #[test]
    fn test_update_variant() {
        let mut ts = ThompsonSampling::new();
        let id = Uuid::new_v4();
        ts.add_variant(id);

        // Update with success
        ts.update(&id, true).unwrap();
        assert_eq!(ts.total_samples(), 1);

        let arm = ts.get_arm(&id).unwrap();
        assert_eq!(arm.successes, 2.0);
        assert_eq!(arm.trials, 1);
    }

    #[test]
    fn test_update_unknown_variant_is_rejected() {
        let mut ts = ThompsonSampling::new();
        ts.add_variant(Uuid::new_v4());

        assert!(ts.update(&Uuid::new_v4(), true).is_err());
        // A rejected update must not be counted.
        assert_eq!(ts.total_samples(), 0);
    }

    #[test]
    fn test_conversion_rates() {
        let mut ts = ThompsonSampling::new();
        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();

        ts.add_variant(id1);
        ts.add_variant(id2);

        // Variant 1: 8/10 success rate
        for _ in 0..8 {
            ts.update(&id1, true).unwrap();
        }
        for _ in 0..2 {
            ts.update(&id1, false).unwrap();
        }

        // Variant 2: 3/10 success rate
        for _ in 0..3 {
            ts.update(&id2, true).unwrap();
        }
        for _ in 0..7 {
            ts.update(&id2, false).unwrap();
        }

        let rates = ts.get_conversion_rates();

        // Variant 1 should have higher rate
        assert!(rates[&id1] > rates[&id2]);

        // Check approximate values (with priors)
        // id1: (1+8)/(1+8+1+2) = 9/12 = 0.75
        // id2: (1+3)/(1+3+1+7) = 4/12 = 0.333
        assert!((rates[&id1] - 0.75).abs() < 0.01);
        assert!((rates[&id2] - 0.333).abs() < 0.01);
    }

    #[test]
    fn test_regret_calculation() {
        let mut ts = ThompsonSampling::new();
        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();

        ts.add_variant(id1);
        ts.add_variant(id2);

        // No observations yet, so there is nothing to regret.
        let initial_regret = ts.calculate_regret();
        assert_eq!(initial_regret, 0.0, "Initial regret should be 0, got: {}", initial_regret);

        for _ in 0..10 {
            ts.update(&id1, true).unwrap();
        }

        // 10 observed conversions exceed best posterior mean (11/12) * 10 pulls,
        // so the clamped regret estimate is exactly zero.
        let regret = ts.calculate_regret();
        assert_eq!(regret, 0.0, "Regret should be 0, got: {}", regret);
    }

    #[test]
    fn test_has_converged_requires_minimum_samples() {
        let mut ts = ThompsonSampling::new();
        let id = Uuid::new_v4();
        ts.add_variant(id);

        for _ in 0..99 {
            ts.update(&id, true).unwrap();
        }
        // Below the 100-observation minimum, even a perfect arm has not converged.
        assert!(!ts.has_converged(1.0));

        ts.update(&id, true).unwrap();
        // 100/100 conversions: estimated regret is zero, so a small threshold is met.
        assert!(ts.has_converged(0.01));
    }

    #[test]
    fn test_thompson_sampling_convergence() {
        let mut ts = ThompsonSampling::new();
        let good_variant = Uuid::new_v4();
        let bad_variant = Uuid::new_v4();

        ts.add_variant(good_variant);
        ts.add_variant(bad_variant);

        // Simulate: good variant has 80% success, bad has 20%
        let mut rng = rand::thread_rng();

        for _ in 0..1000 {
            let selected = ts.select_variant().unwrap();

            let success = if selected == good_variant {
                rng.gen::<f64>() < 0.8
            } else {
                rng.gen::<f64>() < 0.2
            };

            ts.update(&selected, success).unwrap();
        }
        assert_eq!(ts.total_samples(), 1000);

        // After many trials, the good variant should have been selected more often.
        let good_trials = ts.get_arm(&good_variant).unwrap().trials;
        let bad_trials = ts.get_arm(&bad_variant).unwrap().trials;
        assert!(
            good_trials > bad_trials,
            "good variant selected {} times, bad variant {} times",
            good_trials,
            bad_trials
        );

        let rates = ts.get_conversion_rates();
        assert!(rates[&good_variant] > rates[&bad_variant]);

        // Check that conversion rates are approximately correct
        assert!((rates[&good_variant] - 0.8).abs() < 0.1);
    }

    #[test]
    fn test_credible_interval() {
        let mut arm = BanditArm::new(Uuid::new_v4());

        // Add data: 70 successes, 30 failures
        for _ in 0..70 {
            arm.update(true);
        }
        for _ in 0..30 {
            arm.update(false);
        }

        let (lower, upper) = arm.credible_interval(0.95);

        // Interval should contain the posterior mean
        let mean = arm.conversion_rate();
        assert!(lower < mean && mean < upper);

        // Interval should be reasonable
        assert!(lower > 0.0 && upper < 1.0);
        assert!(upper - lower < 0.2); // Width < 20%
    }
}