llm_optimizer_decision/
thompson_sampling.rs1use rand_distr::{Beta, Distribution};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use uuid::Uuid;
11
12use crate::errors::{DecisionError, Result};
13
14#[derive(Debug, Clone)]
16pub struct ThompsonSampling {
17 arms: HashMap<Uuid, BanditArm>,
19 total_samples: u64,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct BanditArm {
26 pub variant_id: Uuid,
28 pub successes: f64,
30 pub failures: f64,
32 pub trials: u64,
34}
35
36impl BanditArm {
37 pub fn new(variant_id: Uuid) -> Self {
39 Self {
40 variant_id,
41 successes: 1.0, failures: 1.0,
43 trials: 0,
44 }
45 }
46
47 pub fn update(&mut self, success: bool) {
49 if success {
50 self.successes += 1.0;
51 } else {
52 self.failures += 1.0;
53 }
54 self.trials += 1;
55 }
56
57 pub fn conversion_rate(&self) -> f64 {
59 self.successes / (self.successes + self.failures)
60 }
61
62 pub fn credible_interval(&self, confidence: f64) -> (f64, f64) {
64 use statrs::distribution::Beta as BetaDist;
65
66 let _beta = BetaDist::new(self.successes, self.failures).unwrap();
67 let lower = (1.0 - confidence) / 2.0;
68 let _upper = 1.0 - lower;
69
70 let mean = self.conversion_rate();
72 let std = (self.successes * self.failures /
73 ((self.successes + self.failures).powi(2) *
74 (self.successes + self.failures + 1.0))).sqrt();
75
76 (
77 (mean - 1.96 * std).max(0.0),
78 (mean + 1.96 * std).min(1.0),
79 )
80 }
81
82 pub fn sample(&self) -> Result<f64> {
84 let beta = Beta::new(self.successes, self.failures)
85 .map_err(|e| DecisionError::StatisticalError(
86 format!("Failed to create Beta distribution: {}", e)
87 ))?;
88
89 let mut rng = rand::thread_rng();
90 Ok(beta.sample(&mut rng))
91 }
92}
93
94impl ThompsonSampling {
95 pub fn new() -> Self {
97 Self {
98 arms: HashMap::new(),
99 total_samples: 0,
100 }
101 }
102
103 pub fn add_variant(&mut self, variant_id: Uuid) {
105 self.arms.insert(variant_id, BanditArm::new(variant_id));
106 }
107
108 pub fn remove_variant(&mut self, variant_id: &Uuid) {
110 self.arms.remove(variant_id);
111 }
112
113 pub fn select_variant(&self) -> Result<Uuid> {
118 if self.arms.is_empty() {
119 return Err(DecisionError::InvalidState(
120 "No variants available for selection".to_string()
121 ));
122 }
123
124 let mut best_variant = None;
125 let mut best_sample = f64::MIN;
126
127 for (variant_id, arm) in &self.arms {
128 let sample = arm.sample()?;
129 if sample > best_sample {
130 best_sample = sample;
131 best_variant = Some(*variant_id);
132 }
133 }
134
135 best_variant.ok_or_else(||
136 DecisionError::AllocationError("Failed to select variant".to_string())
137 )
138 }
139
140 pub fn update(&mut self, variant_id: &Uuid, success: bool) -> Result<()> {
142 let arm = self.arms.get_mut(variant_id)
143 .ok_or_else(|| DecisionError::VariantNotFound(variant_id.to_string()))?;
144
145 arm.update(success);
146 self.total_samples += 1;
147 Ok(())
148 }
149
150 pub fn get_conversion_rates(&self) -> HashMap<Uuid, f64> {
152 self.arms.iter()
153 .map(|(id, arm)| (*id, arm.conversion_rate()))
154 .collect()
155 }
156
157 pub fn get_arm(&self, variant_id: &Uuid) -> Option<&BanditArm> {
159 self.arms.get(variant_id)
160 }
161
162 pub fn get_arms(&self) -> &HashMap<Uuid, BanditArm> {
164 &self.arms
165 }
166
167 pub fn calculate_regret(&self) -> f64 {
169 if self.arms.is_empty() || self.total_samples == 0 {
170 return 0.0;
171 }
172
173 let best_rate = self.arms.values()
175 .map(|arm| arm.conversion_rate())
176 .max_by(|a, b| a.partial_cmp(b).unwrap())
177 .unwrap_or(0.0);
178
179 let actual_conversions: f64 = self.arms.values()
181 .map(|arm| (arm.successes - 1.0).max(0.0))
182 .sum();
183
184 let expected_conversions = best_rate * self.total_samples as f64;
186
187 (expected_conversions - actual_conversions).max(0.0)
189 }
190
191 pub fn total_samples(&self) -> u64 {
193 self.total_samples
194 }
195
196 pub fn has_converged(&self, threshold: f64) -> bool {
198 if self.total_samples < 100 {
199 return false;
200 }
201
202 let regret = self.calculate_regret();
203 let regret_rate = regret / self.total_samples as f64;
204
205 regret_rate < threshold
206 }
207}
208
209impl Default for ThompsonSampling {
210 fn default() -> Self {
211 Self::new()
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    // A new arm starts from the 1.0 / 1.0 prior with no trials recorded.
    #[test]
    fn test_bandit_arm_creation() {
        let fresh = BanditArm::new(Uuid::new_v4());

        assert_eq!(fresh.successes, 1.0);
        assert_eq!(fresh.failures, 1.0);
        assert_eq!(fresh.trials, 0);
        assert_eq!(fresh.conversion_rate(), 0.5);
    }

    // Each outcome bumps the matching counter and the trial count.
    #[test]
    fn test_bandit_arm_update() {
        let mut arm = BanditArm::new(Uuid::new_v4());

        arm.update(true);
        assert_eq!(arm.successes, 2.0);
        assert_eq!(arm.trials, 1);

        arm.update(false);
        assert_eq!(arm.failures, 2.0);
        assert_eq!(arm.trials, 2);
    }

    // 7 successes and 3 failures on top of the prior give 8 / 12.
    #[test]
    fn test_bandit_arm_conversion_rate() {
        let mut arm = BanditArm::new(Uuid::new_v4());

        (0..7).for_each(|_| arm.update(true));
        (0..3).for_each(|_| arm.update(false));

        let observed = arm.conversion_rate();
        assert!((observed - 0.666).abs() < 0.01);
    }

    // A new bandit has neither samples nor arms.
    #[test]
    fn test_thompson_sampling_creation() {
        let bandit = ThompsonSampling::new();

        assert_eq!(bandit.total_samples(), 0);
        assert!(bandit.get_arms().is_empty());
    }

    // Adding then removing a variant leaves the bandit empty again.
    #[test]
    fn test_add_remove_variant() {
        let mut bandit = ThompsonSampling::new();
        let variant = Uuid::new_v4();

        bandit.add_variant(variant);
        assert_eq!(bandit.get_arms().len(), 1);
        assert!(bandit.get_arm(&variant).is_some());

        bandit.remove_variant(&variant);
        assert_eq!(bandit.get_arms().len(), 0);
    }

    // Selection fails without variants and otherwise returns a registered id.
    #[test]
    fn test_select_variant() {
        let mut bandit = ThompsonSampling::new();

        assert!(bandit.select_variant().is_err());

        let first = Uuid::new_v4();
        let second = Uuid::new_v4();
        bandit.add_variant(first);
        bandit.add_variant(second);

        let chosen = bandit.select_variant().unwrap();
        assert!(chosen == first || chosen == second);
    }

    // Updating through the bandit touches both the arm and the global counter.
    #[test]
    fn test_update_variant() {
        let mut bandit = ThompsonSampling::new();
        let variant = Uuid::new_v4();
        bandit.add_variant(variant);

        bandit.update(&variant, true).unwrap();
        assert_eq!(bandit.total_samples(), 1);

        let arm = bandit.get_arm(&variant).unwrap();
        assert_eq!(arm.successes, 2.0);
        assert_eq!(arm.trials, 1);
    }

    // Rates reflect the recorded outcomes plus the prior pseudo-counts.
    #[test]
    fn test_conversion_rates() {
        let mut bandit = ThompsonSampling::new();
        let strong = Uuid::new_v4();
        let weak = Uuid::new_v4();

        bandit.add_variant(strong);
        bandit.add_variant(weak);

        // strong: 8 successes, 2 failures -> (8 + 1) / (10 + 2) = 0.75
        (0..8).for_each(|_| bandit.update(&strong, true).unwrap());
        (0..2).for_each(|_| bandit.update(&strong, false).unwrap());

        // weak: 3 successes, 7 failures -> (3 + 1) / (10 + 2) = 0.333
        (0..3).for_each(|_| bandit.update(&weak, true).unwrap());
        (0..7).for_each(|_| bandit.update(&weak, false).unwrap());

        let rates = bandit.get_conversion_rates();

        assert!(rates[&strong] > rates[&weak]);

        assert!((rates[&strong] - 0.75).abs() < 0.01);
        assert!((rates[&weak] - 0.333).abs() < 0.01);
    }

    // Regret never goes negative, before or after recording outcomes.
    #[test]
    fn test_regret_calculation() {
        let mut bandit = ThompsonSampling::new();
        let first = Uuid::new_v4();
        let second = Uuid::new_v4();

        bandit.add_variant(first);
        bandit.add_variant(second);

        let initial_regret = bandit.calculate_regret();
        assert!(
            initial_regret >= -0.01,
            "Initial regret should be >= 0, got: {}",
            initial_regret
        );

        (0..10).for_each(|_| bandit.update(&first, true).unwrap());

        let regret = bandit.calculate_regret();
        assert!(regret >= -0.01, "Regret should be >= 0, got: {}", regret);
    }

    // Simulated traffic: the 80% variant should end up clearly ahead of the 20% one.
    #[test]
    fn test_thompson_sampling_convergence() {
        let mut bandit = ThompsonSampling::new();
        let good_variant = Uuid::new_v4();
        let bad_variant = Uuid::new_v4();

        bandit.add_variant(good_variant);
        bandit.add_variant(bad_variant);

        let mut rng = rand::thread_rng();

        for _ in 0..1000 {
            let selected = bandit.select_variant().unwrap();

            // True success probability of whichever variant was picked.
            let success_probability = if selected == good_variant { 0.8 } else { 0.2 };
            let success = rng.gen::<f64>() < success_probability;

            bandit.update(&selected, success).unwrap();
        }

        let rates = bandit.get_conversion_rates();
        assert!(rates[&good_variant] > rates[&bad_variant]);

        assert!((rates[&good_variant] - 0.8).abs() < 0.1);
    }

    // After 100 trials the interval brackets the mean and is reasonably tight.
    #[test]
    fn test_credible_interval() {
        let mut arm = BanditArm::new(Uuid::new_v4());

        (0..70).for_each(|_| arm.update(true));
        (0..30).for_each(|_| arm.update(false));

        let (lower, upper) = arm.credible_interval(0.95);

        let mean = arm.conversion_rate();
        assert!(lower < mean && mean < upper);

        assert!(lower > 0.0 && upper < 1.0);
        assert!(upper - lower < 0.2);
    }
}