1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::sync::RwLock;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct Arm {
9 pub id: String,
10 pub name: String,
11 pub description: String,
12 pub config: serde_json::Value,
13 pub pulls: u64,
14 pub total_reward: f64,
15 pub mean_reward: f64,
16}
17
18impl Arm {
19 pub fn new(id: String, name: String, config: serde_json::Value) -> Self {
20 Self {
21 id,
22 name,
23 description: String::new(),
24 config,
25 pulls: 0,
26 total_reward: 0.0,
27 mean_reward: 0.0,
28 }
29 }
30
31 pub fn update(&mut self, reward: f64) {
32 self.pulls += 1;
33 self.total_reward += reward;
34 self.mean_reward = self.total_reward / self.pulls as f64;
35 }
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ThompsonSampling {
41 pub arms: HashMap<String, BetaDistribution>,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct BetaDistribution {
46 pub alpha: f64, pub beta: f64, }
49
50impl ThompsonSampling {
51 pub fn new(arm_ids: &[String]) -> Self {
52 let mut arms = HashMap::new();
53 for id in arm_ids {
54 arms.insert(
55 id.clone(),
56 BetaDistribution {
57 alpha: 1.0,
58 beta: 1.0,
59 },
60 );
61 }
62 Self { arms }
63 }
64
65 pub fn select_arm(&self) -> String {
66 let mut best_arm = String::new();
67 let mut best_sample = f64::NEG_INFINITY;
68
69 for (arm_id, dist) in &self.arms {
70 let sample = self.sample_beta(dist.alpha, dist.beta);
71 if sample > best_sample {
72 best_sample = sample;
73 best_arm = arm_id.clone();
74 }
75 }
76
77 best_arm
78 }
79
80 pub fn update(&mut self, arm_id: &str, reward: f64) {
81 if let Some(dist) = self.arms.get_mut(arm_id) {
82 if reward > 0.5 {
83 dist.alpha += 1.0;
84 } else {
85 dist.beta += 1.0;
86 }
87 }
88 }
89
90 fn sample_beta(&self, alpha: f64, beta: f64) -> f64 {
92 let x = self.sample_gamma(alpha, 1.0);
93 let y = self.sample_gamma(beta, 1.0);
94 x / (x + y)
95 }
96
97 fn sample_gamma(&self, shape: f64, scale: f64) -> f64 {
99 if shape < 1.0 {
100 return self.sample_gamma(shape + 1.0, scale) * rand::random::<f64>().powf(1.0 / shape);
101 }
102
103 let d = shape - 1.0 / 3.0;
104 let c = 1.0 / (9.0 * d).sqrt();
105
106 loop {
107 let x = self.sample_normal();
108 let v = (1.0 + c * x).powi(3);
109
110 if v > 0.0 {
111 let u = rand::random::<f64>();
112 if u < 1.0 - 0.0331 * x.powi(4) {
113 return d * v * scale;
114 }
115 if u.ln() < 0.5 * x.powi(2) + d * (1.0 - v + v.ln()) {
116 return d * v * scale;
117 }
118 }
119 }
120 }
121
122 fn sample_normal(&self) -> f64 {
123 let u1 = rand::random::<f64>();
125 let u2 = rand::random::<f64>();
126 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
127 }
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct UCB1 {
133 pub arms: HashMap<String, ArmStats>,
134 pub total_pulls: u64,
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct ArmStats {
139 pub pulls: u64,
140 pub total_reward: f64,
141 pub mean_reward: f64,
142}
143
144impl UCB1 {
145 pub fn new(arm_ids: &[String]) -> Self {
146 let mut arms = HashMap::new();
147 for id in arm_ids {
148 arms.insert(
149 id.clone(),
150 ArmStats {
151 pulls: 0,
152 total_reward: 0.0,
153 mean_reward: 0.0,
154 },
155 );
156 }
157 Self {
158 arms,
159 total_pulls: 0,
160 }
161 }
162
163 pub fn select_arm(&self) -> String {
164 for (arm_id, stats) in &self.arms {
166 if stats.pulls == 0 {
167 return arm_id.clone();
168 }
169 }
170
171 let mut best_arm = String::new();
173 let mut best_ucb = f64::NEG_INFINITY;
174
175 for (arm_id, stats) in &self.arms {
176 let ucb = stats.mean_reward
177 + (2.0 * (self.total_pulls as f64).ln() / stats.pulls as f64).sqrt();
178
179 if ucb > best_ucb {
180 best_ucb = ucb;
181 best_arm = arm_id.clone();
182 }
183 }
184
185 best_arm
186 }
187
188 pub fn update(&mut self, arm_id: &str, reward: f64) {
189 self.total_pulls += 1;
190
191 if let Some(stats) = self.arms.get_mut(arm_id) {
192 stats.pulls += 1;
193 stats.total_reward += reward;
194 stats.mean_reward = stats.total_reward / stats.pulls as f64;
195 }
196 }
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
201pub enum BanditStrategy {
202 ThompsonSampling,
203 UCB1,
204 EpsilonGreedy { epsilon: f64 },
205}
206
207pub struct MultiArmedBandit {
209 arms: Arc<RwLock<HashMap<String, Arm>>>,
210 strategy: BanditStrategy,
211 thompson_sampling: Arc<RwLock<Option<ThompsonSampling>>>,
212 ucb1: Arc<RwLock<Option<UCB1>>>,
213 epsilon: f64,
214}
215
216impl MultiArmedBandit {
217 pub fn new(arms: Vec<Arm>, strategy: BanditStrategy) -> Self {
218 let arm_ids: Vec<String> = arms.iter().map(|a| a.id.clone()).collect();
219
220 let (thompson_sampling, ucb1, epsilon) = match &strategy {
221 BanditStrategy::ThompsonSampling => (Some(ThompsonSampling::new(&arm_ids)), None, 0.0),
222 BanditStrategy::UCB1 => (None, Some(UCB1::new(&arm_ids)), 0.0),
223 BanditStrategy::EpsilonGreedy { epsilon } => (None, None, *epsilon),
224 };
225
226 let arms_map: HashMap<String, Arm> = arms.into_iter().map(|a| (a.id.clone(), a)).collect();
227
228 Self {
229 arms: Arc::new(RwLock::new(arms_map)),
230 strategy,
231 thompson_sampling: Arc::new(RwLock::new(thompson_sampling)),
232 ucb1: Arc::new(RwLock::new(ucb1)),
233 epsilon,
234 }
235 }
236
237 pub async fn select_arm(&self) -> String {
239 match &self.strategy {
240 BanditStrategy::ThompsonSampling => {
241 let ts = self.thompson_sampling.read().await;
242 ts.as_ref().unwrap().select_arm()
243 }
244 BanditStrategy::UCB1 => {
245 let ucb = self.ucb1.read().await;
246 ucb.as_ref().unwrap().select_arm()
247 }
248 BanditStrategy::EpsilonGreedy { .. } => {
249 if rand::random::<f64>() < self.epsilon {
250 self.random_arm().await
252 } else {
253 self.best_arm().await
255 }
256 }
257 }
258 }
259
260 async fn random_arm(&self) -> String {
261 let arms = self.arms.read().await;
262 let keys: Vec<_> = arms.keys().collect();
263 if keys.is_empty() {
264 return String::new();
265 }
266 use rand::Rng;
267 let mut rng = rand::rng();
268 let idx = rng.random_range(0..keys.len());
269 keys[idx].clone()
270 }
271
272 async fn best_arm(&self) -> String {
273 let arms = self.arms.read().await;
274 let mut best_arm = String::new();
275 let mut best_reward = f64::NEG_INFINITY;
276
277 for (id, arm) in arms.iter() {
278 if arm.mean_reward > best_reward {
279 best_reward = arm.mean_reward;
280 best_arm = id.clone();
281 }
282 }
283
284 best_arm
285 }
286
287 pub async fn update(&self, arm_id: &str, reward: f64) {
289 {
291 let mut arms = self.arms.write().await;
292 if let Some(arm) = arms.get_mut(arm_id) {
293 arm.update(reward);
294 }
295 }
296
297 match &self.strategy {
299 BanditStrategy::ThompsonSampling => {
300 let mut ts = self.thompson_sampling.write().await;
301 if let Some(ts) = ts.as_mut() {
302 ts.update(arm_id, reward);
303 }
304 }
305 BanditStrategy::UCB1 => {
306 let mut ucb = self.ucb1.write().await;
307 if let Some(ucb) = ucb.as_mut() {
308 ucb.update(arm_id, reward);
309 }
310 }
311 BanditStrategy::EpsilonGreedy { .. } => {
312 }
314 }
315 }
316
317 pub async fn get_arm(&self, arm_id: &str) -> Option<Arm> {
319 let arms = self.arms.read().await;
320 arms.get(arm_id).cloned()
321 }
322
323 pub async fn get_all_arms(&self) -> Vec<Arm> {
325 let arms = self.arms.read().await;
326 arms.values().cloned().collect()
327 }
328
329 pub async fn get_report(&self) -> BanditReport {
331 let arms = self.arms.read().await;
332
333 let mut arm_reports: Vec<_> = arms
334 .values()
335 .map(|arm| ArmReport {
336 id: arm.id.clone(),
337 name: arm.name.clone(),
338 pulls: arm.pulls,
339 mean_reward: arm.mean_reward,
340 total_reward: arm.total_reward,
341 confidence_interval: self.calculate_confidence_interval(arm),
342 })
343 .collect();
344
345 arm_reports.sort_by(|a, b| b.mean_reward.partial_cmp(&a.mean_reward).unwrap());
346
347 let total_pulls: u64 = arms.values().map(|a| a.pulls).sum();
348 let best_arm = arm_reports.first().map(|r| r.id.clone());
349
350 BanditReport {
351 total_pulls,
352 arms: arm_reports,
353 best_arm,
354 strategy: format!("{:?}", self.strategy),
355 }
356 }
357
358 fn calculate_confidence_interval(&self, arm: &Arm) -> (f64, f64) {
359 if arm.pulls < 2 {
360 return (0.0, 1.0);
361 }
362
363 let z = 1.96; let std_error = (arm.mean_reward * (1.0 - arm.mean_reward) / arm.pulls as f64).sqrt();
366 let margin = z * std_error;
367
368 ((arm.mean_reward - margin).max(0.0), (arm.mean_reward + margin).min(1.0))
369 }
370}
371
372#[derive(Debug, Clone, Serialize, Deserialize)]
373pub struct BanditReport {
374 pub total_pulls: u64,
375 pub arms: Vec<ArmReport>,
376 pub best_arm: Option<String>,
377 pub strategy: String,
378}
379
380#[derive(Debug, Clone, Serialize, Deserialize)]
381pub struct ArmReport {
382 pub id: String,
383 pub name: String,
384 pub pulls: u64,
385 pub mean_reward: f64,
386 pub total_reward: f64,
387 pub confidence_interval: (f64, f64),
388}
389
390pub struct TrafficAllocator {
392 bandit: Arc<MultiArmedBandit>,
393 update_interval: std::time::Duration,
394 min_samples: u64,
395}
396
397impl TrafficAllocator {
398 pub fn new(bandit: Arc<MultiArmedBandit>, update_interval: std::time::Duration) -> Self {
399 Self {
400 bandit,
401 update_interval,
402 min_samples: 100,
403 }
404 }
405
406 pub async fn get_allocation(&self) -> HashMap<String, f64> {
408 let arms = self.bandit.get_all_arms().await;
409 let total_pulls: u64 = arms.iter().map(|a| a.pulls).sum();
410
411 if total_pulls < self.min_samples {
412 let equal_share = 1.0 / arms.len() as f64;
414 return arms.iter().map(|a| (a.id.clone(), equal_share)).collect();
415 }
416
417 let total_reward: f64 = arms.iter().map(|a| a.mean_reward).sum();
419
420 if total_reward == 0.0 {
421 let equal_share = 1.0 / arms.len() as f64;
422 return arms.iter().map(|a| (a.id.clone(), equal_share)).collect();
423 }
424
425 arms.iter()
426 .map(|a| {
427 let allocation = a.mean_reward / total_reward;
428 (a.id.clone(), allocation)
429 })
430 .collect()
431 }
432
433 pub async fn start_auto_allocation(&self) {
435 let _bandit = self.bandit.clone();
436 let interval = self.update_interval;
437
438 tokio::spawn(async move {
439 let mut ticker = tokio::time::interval(interval);
440 loop {
441 ticker.tick().await;
442 }
445 });
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an arm with an empty JSON config.
    fn arm(id: &str, name: &str) -> Arm {
        Arm::new(id.to_string(), name.to_string(), serde_json::json!({}))
    }

    #[tokio::test]
    async fn test_thompson_sampling() {
        let arms = vec![
            arm("v1", "Variant 1"),
            arm("v2", "Variant 2"),
            arm("v3", "Variant 3"),
        ];
        let bandit = MultiArmedBandit::new(arms, BanditStrategy::ThompsonSampling);

        // v2 pays 0.8 (counted as a success), the others 0.3 (a failure).
        for _ in 0..100 {
            let arm_id = bandit.select_arm().await;
            let reward = if arm_id == "v2" { 0.8 } else { 0.3 };
            bandit.update(&arm_id, reward).await;
        }

        let report = bandit.get_report().await;
        assert_eq!(report.total_pulls, 100);
        assert_eq!(report.best_arm, Some("v2".to_string()));
    }

    #[tokio::test]
    async fn test_ucb1() {
        let arms = vec![arm("a", "Arm A"), arm("b", "Arm B")];
        let bandit = MultiArmedBandit::new(arms, BanditStrategy::UCB1);

        for _ in 0..50 {
            let arm_id = bandit.select_arm().await;
            let reward = if arm_id == "a" { 0.9 } else { 0.1 };
            bandit.update(&arm_id, reward).await;
        }

        let report = bandit.get_report().await;
        assert_eq!(report.total_pulls, 50);
        // UCB1 tries every arm before exploiting.
        assert!(report.arms.iter().all(|r| r.pulls >= 1));
        assert_eq!(report.best_arm, Some("a".to_string()));
    }

    #[tokio::test]
    async fn test_epsilon_greedy() {
        let arms = vec![arm("a", "Arm A"), arm("b", "Arm B")];
        let bandit = MultiArmedBandit::new(arms, BanditStrategy::EpsilonGreedy { epsilon: 0.5 });

        // With epsilon = 0.5, "a" going unexplored for 200 rounds has
        // probability around 0.75^200, so the best-arm assertion is stable.
        for _ in 0..200 {
            let arm_id = bandit.select_arm().await;
            let reward = if arm_id == "a" { 0.9 } else { 0.1 };
            bandit.update(&arm_id, reward).await;
        }

        let report = bandit.get_report().await;
        assert_eq!(report.total_pulls, 200);
        assert_eq!(report.best_arm, Some("a".to_string()));
    }

    #[tokio::test]
    async fn test_allocation_is_uniform_before_min_samples() {
        let arms = vec![arm("a", "Arm A"), arm("b", "Arm B")];
        let bandit = Arc::new(MultiArmedBandit::new(arms, BanditStrategy::UCB1));
        let allocator = TrafficAllocator::new(bandit, std::time::Duration::from_secs(60));

        let allocation = allocator.get_allocation().await;
        assert_eq!(allocation.len(), 2);
        assert_eq!(allocation["a"], 0.5);
        assert_eq!(allocation["b"], 0.5);
    }
}