// mockforge_chaos/multi_armed_bandit.rs

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

6/// Arm (variant) in multi-armed bandit
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct Arm {
9    pub id: String,
10    pub name: String,
11    pub description: String,
12    pub config: serde_json::Value,
13    pub pulls: u64,
14    pub total_reward: f64,
15    pub mean_reward: f64,
16}
17
18impl Arm {
19    pub fn new(id: String, name: String, config: serde_json::Value) -> Self {
20        Self {
21            id,
22            name,
23            description: String::new(),
24            config,
25            pulls: 0,
26            total_reward: 0.0,
27            mean_reward: 0.0,
28        }
29    }
30
31    pub fn update(&mut self, reward: f64) {
32        self.pulls += 1;
33        self.total_reward += reward;
34        self.mean_reward = self.total_reward / self.pulls as f64;
35    }
36}
37
38/// Thompson Sampling strategy
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ThompsonSampling {
41    pub arms: HashMap<String, BetaDistribution>,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct BetaDistribution {
46    pub alpha: f64, // Successes
47    pub beta: f64,  // Failures
48}
49
50impl ThompsonSampling {
51    pub fn new(arm_ids: &[String]) -> Self {
52        let mut arms = HashMap::new();
53        for id in arm_ids {
54            arms.insert(
55                id.clone(),
56                BetaDistribution {
57                    alpha: 1.0,
58                    beta: 1.0,
59                },
60            );
61        }
62        Self { arms }
63    }
64
65    pub fn select_arm(&self) -> String {
66        let mut best_arm = String::new();
67        let mut best_sample = f64::NEG_INFINITY;
68
69        for (arm_id, dist) in &self.arms {
70            let sample = self.sample_beta(dist.alpha, dist.beta);
71            if sample > best_sample {
72                best_sample = sample;
73                best_arm = arm_id.clone();
74            }
75        }
76
77        best_arm
78    }
79
80    pub fn update(&mut self, arm_id: &str, reward: f64) {
81        if let Some(dist) = self.arms.get_mut(arm_id) {
82            if reward > 0.5 {
83                dist.alpha += 1.0;
84            } else {
85                dist.beta += 1.0;
86            }
87        }
88    }
89
90    // Simple beta distribution sampling using gamma distributions
91    fn sample_beta(&self, alpha: f64, beta: f64) -> f64 {
92        let x = self.sample_gamma(alpha, 1.0);
93        let y = self.sample_gamma(beta, 1.0);
94        x / (x + y)
95    }
96
97    // Marsaglia and Tsang's method for gamma distribution
98    fn sample_gamma(&self, shape: f64, scale: f64) -> f64 {
99        if shape < 1.0 {
100            return self.sample_gamma(shape + 1.0, scale) * rand::random::<f64>().powf(1.0 / shape);
101        }
102
103        let d = shape - 1.0 / 3.0;
104        let c = 1.0 / (9.0 * d).sqrt();
105
106        loop {
107            let x = self.sample_normal();
108            let v = (1.0 + c * x).powi(3);
109
110            if v > 0.0 {
111                let u = rand::random::<f64>();
112                if u < 1.0 - 0.0331 * x.powi(4) {
113                    return d * v * scale;
114                }
115                if u.ln() < 0.5 * x.powi(2) + d * (1.0 - v + v.ln()) {
116                    return d * v * scale;
117                }
118            }
119        }
120    }
121
122    fn sample_normal(&self) -> f64 {
123        // Box-Muller transform
124        let u1 = rand::random::<f64>();
125        let u2 = rand::random::<f64>();
126        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
127    }
128}
129
130/// UCB1 (Upper Confidence Bound) strategy
131#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct UCB1 {
133    pub arms: HashMap<String, ArmStats>,
134    pub total_pulls: u64,
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct ArmStats {
139    pub pulls: u64,
140    pub total_reward: f64,
141    pub mean_reward: f64,
142}
143
144impl UCB1 {
145    pub fn new(arm_ids: &[String]) -> Self {
146        let mut arms = HashMap::new();
147        for id in arm_ids {
148            arms.insert(
149                id.clone(),
150                ArmStats {
151                    pulls: 0,
152                    total_reward: 0.0,
153                    mean_reward: 0.0,
154                },
155            );
156        }
157        Self {
158            arms,
159            total_pulls: 0,
160        }
161    }
162
163    pub fn select_arm(&self) -> String {
164        // First pull all arms at least once
165        for (arm_id, stats) in &self.arms {
166            if stats.pulls == 0 {
167                return arm_id.clone();
168            }
169        }
170
171        // Calculate UCB for each arm
172        let mut best_arm = String::new();
173        let mut best_ucb = f64::NEG_INFINITY;
174
175        for (arm_id, stats) in &self.arms {
176            let ucb = stats.mean_reward
177                + (2.0 * (self.total_pulls as f64).ln() / stats.pulls as f64).sqrt();
178
179            if ucb > best_ucb {
180                best_ucb = ucb;
181                best_arm = arm_id.clone();
182            }
183        }
184
185        best_arm
186    }
187
188    pub fn update(&mut self, arm_id: &str, reward: f64) {
189        self.total_pulls += 1;
190
191        if let Some(stats) = self.arms.get_mut(arm_id) {
192            stats.pulls += 1;
193            stats.total_reward += reward;
194            stats.mean_reward = stats.total_reward / stats.pulls as f64;
195        }
196    }
197}
198
199/// Strategy for selecting arms
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub enum BanditStrategy {
202    ThompsonSampling,
203    UCB1,
204    EpsilonGreedy { epsilon: f64 },
205}
206
207/// Multi-Armed Bandit for A/B/C/D/... testing
208pub struct MultiArmedBandit {
209    arms: Arc<RwLock<HashMap<String, Arm>>>,
210    strategy: BanditStrategy,
211    thompson_sampling: Arc<RwLock<Option<ThompsonSampling>>>,
212    ucb1: Arc<RwLock<Option<UCB1>>>,
213    epsilon: f64,
214}
215
216impl MultiArmedBandit {
217    pub fn new(arms: Vec<Arm>, strategy: BanditStrategy) -> Self {
218        let arm_ids: Vec<String> = arms.iter().map(|a| a.id.clone()).collect();
219
220        let (thompson_sampling, ucb1, epsilon) = match &strategy {
221            BanditStrategy::ThompsonSampling => (Some(ThompsonSampling::new(&arm_ids)), None, 0.0),
222            BanditStrategy::UCB1 => (None, Some(UCB1::new(&arm_ids)), 0.0),
223            BanditStrategy::EpsilonGreedy { epsilon } => (None, None, *epsilon),
224        };
225
226        let arms_map: HashMap<String, Arm> = arms.into_iter().map(|a| (a.id.clone(), a)).collect();
227
228        Self {
229            arms: Arc::new(RwLock::new(arms_map)),
230            strategy,
231            thompson_sampling: Arc::new(RwLock::new(thompson_sampling)),
232            ucb1: Arc::new(RwLock::new(ucb1)),
233            epsilon,
234        }
235    }
236
237    /// Select arm based on strategy
238    pub async fn select_arm(&self) -> String {
239        match &self.strategy {
240            BanditStrategy::ThompsonSampling => {
241                let ts = self.thompson_sampling.read().await;
242                ts.as_ref().unwrap().select_arm()
243            }
244            BanditStrategy::UCB1 => {
245                let ucb = self.ucb1.read().await;
246                ucb.as_ref().unwrap().select_arm()
247            }
248            BanditStrategy::EpsilonGreedy { .. } => {
249                if rand::random::<f64>() < self.epsilon {
250                    // Explore: random arm
251                    self.random_arm().await
252                } else {
253                    // Exploit: best arm
254                    self.best_arm().await
255                }
256            }
257        }
258    }
259
260    async fn random_arm(&self) -> String {
261        let arms = self.arms.read().await;
262        let keys: Vec<_> = arms.keys().collect();
263        if keys.is_empty() {
264            return String::new();
265        }
266        use rand::Rng;
267        let mut rng = rand::rng();
268        let idx = rng.random_range(0..keys.len());
269        keys[idx].clone()
270    }
271
272    async fn best_arm(&self) -> String {
273        let arms = self.arms.read().await;
274        let mut best_arm = String::new();
275        let mut best_reward = f64::NEG_INFINITY;
276
277        for (id, arm) in arms.iter() {
278            if arm.mean_reward > best_reward {
279                best_reward = arm.mean_reward;
280                best_arm = id.clone();
281            }
282        }
283
284        best_arm
285    }
286
287    /// Update arm with observed reward
288    pub async fn update(&self, arm_id: &str, reward: f64) {
289        // Update arm statistics
290        {
291            let mut arms = self.arms.write().await;
292            if let Some(arm) = arms.get_mut(arm_id) {
293                arm.update(reward);
294            }
295        }
296
297        // Update strategy-specific state
298        match &self.strategy {
299            BanditStrategy::ThompsonSampling => {
300                let mut ts = self.thompson_sampling.write().await;
301                if let Some(ts) = ts.as_mut() {
302                    ts.update(arm_id, reward);
303                }
304            }
305            BanditStrategy::UCB1 => {
306                let mut ucb = self.ucb1.write().await;
307                if let Some(ucb) = ucb.as_mut() {
308                    ucb.update(arm_id, reward);
309                }
310            }
311            BanditStrategy::EpsilonGreedy { .. } => {
312                // No additional state to update
313            }
314        }
315    }
316
317    /// Get arm by ID
318    pub async fn get_arm(&self, arm_id: &str) -> Option<Arm> {
319        let arms = self.arms.read().await;
320        arms.get(arm_id).cloned()
321    }
322
323    /// Get all arms with statistics
324    pub async fn get_all_arms(&self) -> Vec<Arm> {
325        let arms = self.arms.read().await;
326        arms.values().cloned().collect()
327    }
328
329    /// Get performance report
330    pub async fn get_report(&self) -> BanditReport {
331        let arms = self.arms.read().await;
332
333        let mut arm_reports: Vec<_> = arms
334            .values()
335            .map(|arm| ArmReport {
336                id: arm.id.clone(),
337                name: arm.name.clone(),
338                pulls: arm.pulls,
339                mean_reward: arm.mean_reward,
340                total_reward: arm.total_reward,
341                confidence_interval: self.calculate_confidence_interval(arm),
342            })
343            .collect();
344
345        arm_reports.sort_by(|a, b| b.mean_reward.partial_cmp(&a.mean_reward).unwrap());
346
347        let total_pulls: u64 = arms.values().map(|a| a.pulls).sum();
348        let best_arm = arm_reports.first().map(|r| r.id.clone());
349
350        BanditReport {
351            total_pulls,
352            arms: arm_reports,
353            best_arm,
354            strategy: format!("{:?}", self.strategy),
355        }
356    }
357
358    fn calculate_confidence_interval(&self, arm: &Arm) -> (f64, f64) {
359        if arm.pulls < 2 {
360            return (0.0, 1.0);
361        }
362
363        // 95% confidence interval using normal approximation
364        let z = 1.96; // 95% confidence
365        let std_error = (arm.mean_reward * (1.0 - arm.mean_reward) / arm.pulls as f64).sqrt();
366        let margin = z * std_error;
367
368        ((arm.mean_reward - margin).max(0.0), (arm.mean_reward + margin).min(1.0))
369    }
370}
371
372#[derive(Debug, Clone, Serialize, Deserialize)]
373pub struct BanditReport {
374    pub total_pulls: u64,
375    pub arms: Vec<ArmReport>,
376    pub best_arm: Option<String>,
377    pub strategy: String,
378}
379
380#[derive(Debug, Clone, Serialize, Deserialize)]
381pub struct ArmReport {
382    pub id: String,
383    pub name: String,
384    pub pulls: u64,
385    pub mean_reward: f64,
386    pub total_reward: f64,
387    pub confidence_interval: (f64, f64),
388}
389
390/// Automatic traffic allocator
391pub struct TrafficAllocator {
392    bandit: Arc<MultiArmedBandit>,
393    update_interval: std::time::Duration,
394    min_samples: u64,
395}
396
397impl TrafficAllocator {
398    pub fn new(bandit: Arc<MultiArmedBandit>, update_interval: std::time::Duration) -> Self {
399        Self {
400            bandit,
401            update_interval,
402            min_samples: 100,
403        }
404    }
405
406    /// Get traffic allocation percentages
407    pub async fn get_allocation(&self) -> HashMap<String, f64> {
408        let arms = self.bandit.get_all_arms().await;
409        let total_pulls: u64 = arms.iter().map(|a| a.pulls).sum();
410
411        if total_pulls < self.min_samples {
412            // Equal allocation during exploration phase
413            let equal_share = 1.0 / arms.len() as f64;
414            return arms.iter().map(|a| (a.id.clone(), equal_share)).collect();
415        }
416
417        // Allocate based on performance
418        let total_reward: f64 = arms.iter().map(|a| a.mean_reward).sum();
419
420        if total_reward == 0.0 {
421            let equal_share = 1.0 / arms.len() as f64;
422            return arms.iter().map(|a| (a.id.clone(), equal_share)).collect();
423        }
424
425        arms.iter()
426            .map(|a| {
427                let allocation = a.mean_reward / total_reward;
428                (a.id.clone(), allocation)
429            })
430            .collect()
431    }
432
433    /// Start automatic reallocation
434    pub async fn start_auto_allocation(&self) {
435        let _bandit = self.bandit.clone();
436        let interval = self.update_interval;
437
438        tokio::spawn(async move {
439            let mut ticker = tokio::time::interval(interval);
440            loop {
441                ticker.tick().await;
442                // Allocation is recalculated on-demand via get_allocation()
443                // This task can trigger webhooks or notifications when allocation changes significantly
444            }
445        });
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_thompson_sampling() {
        // Three variants; v2 consistently pays more than the others.
        let variants: Vec<Arm> = ["v1", "v2", "v3"]
            .iter()
            .enumerate()
            .map(|(i, id)| {
                Arm::new(
                    id.to_string(),
                    format!("Variant {}", i + 1),
                    serde_json::json!({}),
                )
            })
            .collect();

        let bandit = MultiArmedBandit::new(variants, BanditStrategy::ThompsonSampling);

        for _ in 0..100 {
            let chosen = bandit.select_arm().await;
            let reward = if chosen == "v2" { 0.8 } else { 0.3 };
            bandit.update(&chosen, reward).await;
        }

        // The consistently better arm should surface as best.
        let report = bandit.get_report().await;
        assert_eq!(report.best_arm, Some("v2".to_string()));
    }

    #[tokio::test]
    async fn test_ucb1() {
        let arms = vec![
            Arm::new("a".to_string(), "Arm A".to_string(), serde_json::json!({})),
            Arm::new("b".to_string(), "Arm B".to_string(), serde_json::json!({})),
        ];

        let bandit = MultiArmedBandit::new(arms, BanditStrategy::UCB1);

        for _ in 0..50 {
            let chosen = bandit.select_arm().await;
            bandit
                .update(&chosen, if chosen == "a" { 0.9 } else { 0.1 })
                .await;
        }

        assert!(bandit.get_report().await.total_pulls > 0);
    }
}