1use chrono::Utc;
7use dashmap::DashMap;
8use llm_optimizer_types::experiments::*;
9use std::sync::Arc;
10use uuid::Uuid;
11
12use crate::{
13 errors::{DecisionError, Result},
14 statistical::{StatisticalTest, ZTest},
15 thompson_sampling::ThompsonSampling,
16};
17
18pub struct ExperimentManager {
20 experiments: Arc<DashMap<Uuid, Experiment>>,
22 bandits: Arc<DashMap<Uuid, ThompsonSampling>>,
24}
25
26impl ExperimentManager {
27 pub fn new() -> Self {
29 Self {
30 experiments: Arc::new(DashMap::new()),
31 bandits: Arc::new(DashMap::new()),
32 }
33 }
34
35 pub fn create_experiment(
37 &self,
38 name: impl Into<String>,
39 variants: Vec<Variant>,
40 _metrics: Vec<MetricDefinition>,
41 ) -> Result<Uuid> {
42 if variants.len() < 2 {
43 return Err(DecisionError::InvalidConfig(
44 "Experiment must have at least 2 variants".to_string()
45 ));
46 }
47
48 let total_allocation: f64 = variants.iter()
50 .map(|v| v.traffic_allocation)
51 .sum();
52
53 if (total_allocation - 1.0).abs() > 0.01 {
54 return Err(DecisionError::InvalidConfig(
55 format!("Traffic allocation must sum to 1.0, got {}", total_allocation)
56 ));
57 }
58
59 let experiment = Experiment::new(name, variants.clone());
60 let experiment_id = experiment.id;
61
62 let mut bandit = ThompsonSampling::new();
64 for variant in &variants {
65 bandit.add_variant(variant.id);
66 }
67
68 self.experiments.insert(experiment_id, experiment);
69 self.bandits.insert(experiment_id, bandit);
70
71 Ok(experiment_id)
72 }
73
74 pub fn start_experiment(&self, experiment_id: &Uuid) -> Result<()> {
76 let mut entry = self.experiments.get_mut(experiment_id)
77 .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
78
79 entry.start();
80 Ok(())
81 }
82
83 pub fn pause_experiment(&self, experiment_id: &Uuid) -> Result<()> {
85 let mut entry = self.experiments.get_mut(experiment_id)
86 .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
87
88 if entry.status == ExperimentStatus::Running {
89 entry.status = ExperimentStatus::Paused;
90 Ok(())
91 } else {
92 Err(DecisionError::InvalidState(
93 format!("Cannot pause experiment in state {:?}", entry.status)
94 ))
95 }
96 }
97
98 pub fn resume_experiment(&self, experiment_id: &Uuid) -> Result<()> {
100 let mut entry = self.experiments.get_mut(experiment_id)
101 .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
102
103 if entry.status == ExperimentStatus::Paused {
104 entry.status = ExperimentStatus::Running;
105 Ok(())
106 } else {
107 Err(DecisionError::InvalidState(
108 format!("Cannot resume experiment in state {:?}", entry.status)
109 ))
110 }
111 }
112
113 pub fn select_variant(&self, experiment_id: &Uuid) -> Result<Uuid> {
115 let experiment = self.experiments.get(experiment_id)
117 .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
118
119 if experiment.status != ExperimentStatus::Running {
120 return Err(DecisionError::InvalidState(
121 format!("Experiment is not running: {:?}", experiment.status)
122 ));
123 }
124
125 let bandit = self.bandits.get(experiment_id)
127 .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
128
129 bandit.select_variant()
130 }
131
132 pub fn record_result(
134 &self,
135 experiment_id: &Uuid,
136 variant_id: &Uuid,
137 success: bool,
138 quality: f64,
139 cost: f64,
140 latency_ms: f64,
141 ) -> Result<()> {
142 if let Some(mut bandit) = self.bandits.get_mut(experiment_id) {
144 bandit.update(variant_id, success)?;
145 }
146
147 let mut experiment = self.experiments.get_mut(experiment_id)
149 .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
150
151 let variant = experiment.variants.iter_mut()
152 .find(|v| v.id == *variant_id)
153 .ok_or_else(|| DecisionError::VariantNotFound(variant_id.to_string()))?;
154
155 if variant.results.is_none() {
157 variant.results = Some(VariantResults {
158 total_requests: 0,
159 conversions: 0,
160 avg_quality: 0.0,
161 avg_cost: 0.0,
162 avg_latency_ms: 0.0,
163 metrics: Default::default(),
164 });
165 }
166
167 if let Some(results) = &mut variant.results {
169 let n = results.total_requests as f64;
170
171 results.avg_quality = (results.avg_quality * n + quality) / (n + 1.0);
173 results.avg_cost = (results.avg_cost * n + cost) / (n + 1.0);
174 results.avg_latency_ms = (results.avg_latency_ms * n + latency_ms) / (n + 1.0);
175
176 results.total_requests += 1;
177 if success {
178 results.conversions += 1;
179 }
180 }
181
182 Ok(())
183 }
184
185 pub fn should_conclude(&self, experiment_id: &Uuid, min_sample_size: usize, significance_level: f64) -> Result<bool> {
187 let experiment = self.experiments.get(experiment_id)
188 .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
189
190 if experiment.variants.len() != 2 {
191 return Ok(false);
193 }
194
195 let variant1 = &experiment.variants[0];
196 let variant2 = &experiment.variants[1];
197
198 let results1 = variant1.results.as_ref();
200 let results2 = variant2.results.as_ref();
201
202 if results1.is_none() || results2.is_none() {
203 return Ok(false);
204 }
205
206 let r1 = results1.unwrap();
207 let r2 = results2.unwrap();
208
209 if r1.total_requests < min_sample_size as u64 || r2.total_requests < min_sample_size as u64 {
210 return Ok(false);
211 }
212
213 let z_test = ZTest::new(
215 r1.conversions,
216 r1.total_requests,
217 r2.conversions,
218 r2.total_requests,
219 );
220
221 let is_significant = z_test.is_significant(significance_level)?;
222
223 Ok(is_significant)
224 }
225
226 pub fn conclude_experiment(&self, experiment_id: &Uuid, significance_level: f64) -> Result<()> {
228 let mut experiment = self.experiments.get_mut(experiment_id)
229 .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
230
231 if experiment.variants.len() != 2 {
232 return Err(DecisionError::InvalidConfig(
233 "Can only conclude 2-variant experiments".to_string()
234 ));
235 }
236
237 let variant1 = &experiment.variants[0];
238 let variant2 = &experiment.variants[1];
239
240 let results1 = variant1.results.as_ref()
241 .ok_or_else(|| DecisionError::InsufficientData("Variant 1 has no results".to_string()))?;
242 let results2 = variant2.results.as_ref()
243 .ok_or_else(|| DecisionError::InsufficientData("Variant 2 has no results".to_string()))?;
244
245 let z_test = ZTest::new(
247 results1.conversions,
248 results1.total_requests,
249 results2.conversions,
250 results2.total_requests,
251 );
252
253 let p_value = z_test.test()?;
254 let is_significant = p_value < significance_level;
255 let effect_size = z_test.effect_size();
256
257 let winner_variant_id = if is_significant {
259 if results1.conversions as f64 / results1.total_requests as f64 >
260 results2.conversions as f64 / results2.total_requests as f64 {
261 Some(variant1.id)
262 } else {
263 Some(variant2.id)
264 }
265 } else {
266 None
267 };
268
269 let analysis = StatisticalAnalysis {
271 winner_variant_id,
272 p_value,
273 confidence_level: 1.0 - significance_level,
274 effect_size,
275 is_significant,
276 method: "Two-proportion z-test".to_string(),
277 };
278
279 let mut variant_details = std::collections::HashMap::new();
281 variant_details.insert(variant1.id, results1.clone());
282 variant_details.insert(variant2.id, results2.clone());
283
284 let duration_seconds = (Utc::now() - experiment.start_time).num_seconds() as u64;
285 let total_sample_size = results1.total_requests + results2.total_requests;
286
287 let results = ExperimentResults {
288 statistical_analysis: analysis,
289 variant_details,
290 total_sample_size,
291 duration_seconds,
292 };
293
294 experiment.complete(results);
295
296 Ok(())
297 }
298
299 pub fn get_experiment(&self, experiment_id: &Uuid) -> Option<Experiment> {
301 self.experiments.get(experiment_id).map(|e| e.clone())
302 }
303
304 pub fn list_experiments(&self) -> Vec<Experiment> {
306 self.experiments.iter().map(|e| e.value().clone()).collect()
307 }
308
309 pub fn list_active_experiments(&self) -> Vec<Experiment> {
311 self.experiments.iter()
312 .filter(|e| e.status == ExperimentStatus::Running)
313 .map(|e| e.value().clone())
314 .collect()
315 }
316
317 pub fn get_statistics(&self, experiment_id: &Uuid) -> Result<ExperimentStatistics> {
319 let experiment = self.experiments.get(experiment_id)
320 .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
321
322 let bandit = self.bandits.get(experiment_id);
323
324 let total_requests: u64 = experiment.variants.iter()
325 .filter_map(|v| v.results.as_ref())
326 .map(|r| r.total_requests)
327 .sum();
328
329 let conversion_rates = experiment.variants.iter()
330 .map(|v| {
331 let rate = v.conversion_rate().unwrap_or(0.0);
332 (v.id, rate)
333 })
334 .collect();
335
336 let bandit_regret = bandit.as_ref().map(|b| b.calculate_regret());
337
338 Ok(ExperimentStatistics {
339 experiment_id: *experiment_id,
340 status: experiment.status,
341 total_requests,
342 conversion_rates,
343 bandit_regret,
344 duration_seconds: (Utc::now() - experiment.start_time).num_seconds() as u64,
345 })
346 }
347}
348
349impl Default for ExperimentManager {
350 fn default() -> Self {
351 Self::new()
352 }
353}
354
355#[derive(Debug, Clone)]
357pub struct ExperimentStatistics {
358 pub experiment_id: Uuid,
359 pub status: ExperimentStatus,
360 pub total_requests: u64,
361 pub conversion_rates: std::collections::HashMap<Uuid, f64>,
362 pub bandit_regret: Option<f64>,
363 pub duration_seconds: u64,
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use llm_optimizer_types::models::ModelConfig;

    /// Builds a `control` / `variant_a` pair that splits traffic 50/50.
    fn create_test_variants() -> Vec<Variant> {
        ["control", "variant_a"]
            .iter()
            .map(|&label| Variant::new(label, ModelConfig::default(), 0.5))
            .collect()
    }

    /// Creates an experiment named "Test" from the default variants and starts it.
    fn start_test_experiment(manager: &ExperimentManager) -> Uuid {
        let id = manager
            .create_experiment("Test", create_test_variants(), vec![])
            .unwrap();
        manager.start_experiment(&id).unwrap();
        id
    }

    #[test]
    fn test_create_experiment() {
        let manager = ExperimentManager::new();
        let id = manager
            .create_experiment("Test Experiment", create_test_variants(), vec![])
            .unwrap();

        let stored = manager.get_experiment(&id).unwrap();
        assert_eq!(stored.name, "Test Experiment");
        assert_eq!(stored.variants.len(), 2);
        assert_eq!(stored.status, ExperimentStatus::Draft);
    }

    #[test]
    fn test_invalid_variant_count() {
        let manager = ExperimentManager::new();
        let lone_variant = vec![Variant::new("control", ModelConfig::default(), 1.0)];

        assert!(manager.create_experiment("Test", lone_variant, vec![]).is_err());
    }

    #[test]
    fn test_invalid_traffic_allocation() {
        let manager = ExperimentManager::new();
        // 0.3 + 0.5 = 0.8, well outside the tolerance around 1.0.
        let underallocated = vec![
            Variant::new("control", ModelConfig::default(), 0.3),
            Variant::new("variant_a", ModelConfig::default(), 0.5),
        ];

        assert!(manager.create_experiment("Test", underallocated, vec![]).is_err());
    }

    #[test]
    fn test_start_experiment() {
        let manager = ExperimentManager::new();
        let id = start_test_experiment(&manager);

        assert_eq!(
            manager.get_experiment(&id).unwrap().status,
            ExperimentStatus::Running
        );
    }

    #[test]
    fn test_select_variant() {
        let manager = ExperimentManager::new();
        let id = start_test_experiment(&manager);

        let chosen = manager.select_variant(&id).unwrap();
        assert_ne!(chosen, Uuid::nil());
    }

    #[test]
    fn test_record_result() {
        let manager = ExperimentManager::new();
        let id = start_test_experiment(&manager);
        let served = manager.select_variant(&id).unwrap();

        manager
            .record_result(&id, &served, true, 0.9, 0.05, 1200.0)
            .unwrap();

        let snapshot = manager.get_experiment(&id).unwrap();
        let variant = snapshot.variants.iter().find(|v| v.id == served).unwrap();
        let results = variant
            .results
            .as_ref()
            .expect("results should exist after recording an outcome");
        assert_eq!(results.total_requests, 1);
        assert_eq!(results.conversions, 1);
    }

    #[test]
    fn test_pause_resume() {
        let manager = ExperimentManager::new();
        let id = start_test_experiment(&manager);
        let status = || manager.get_experiment(&id).unwrap().status;

        manager.pause_experiment(&id).unwrap();
        assert_eq!(status(), ExperimentStatus::Paused);

        manager.resume_experiment(&id).unwrap();
        assert_eq!(status(), ExperimentStatus::Running);
    }

    #[test]
    fn test_should_conclude() {
        let manager = ExperimentManager::new();
        let id = start_test_experiment(&manager);

        // Nothing recorded yet, so there is nothing to conclude.
        assert!(!manager.should_conclude(&id, 100, 0.05).unwrap());

        let ids: Vec<Uuid> = manager
            .get_experiment(&id)
            .unwrap()
            .variants
            .iter()
            .map(|v| v.id)
            .collect();
        let (always_converts, sometimes_converts) = (ids[0], ids[1]);

        // First arm: 100 conversions out of 100.
        for _ in 0..100 {
            manager
                .record_result(&id, &always_converts, true, 0.9, 0.05, 1000.0)
                .unwrap();
        }

        // Second arm: 30 conversions out of 100.
        for _ in 0..30 {
            manager
                .record_result(&id, &sometimes_converts, true, 0.7, 0.05, 1000.0)
                .unwrap();
        }
        for _ in 0..70 {
            manager
                .record_result(&id, &sometimes_converts, false, 0.5, 0.05, 1000.0)
                .unwrap();
        }

        // 100% vs 30% over 100 samples each is expected to be significant at 5%.
        assert!(manager.should_conclude(&id, 100, 0.05).unwrap());
    }

    #[test]
    fn test_get_statistics() {
        let manager = ExperimentManager::new();
        let id = start_test_experiment(&manager);

        let stats = manager.get_statistics(&id).unwrap();
        assert_eq!(stats.experiment_id, id);
        assert_eq!(stats.status, ExperimentStatus::Running);
        assert_eq!(stats.total_requests, 0);
    }
}
516}