Skip to main content

llm_optimizer_decision/
experiment_manager.rs

//! Experiment lifecycle management
//!
//! This module manages the full lifecycle of A/B experiments, including
//! creation, execution, monitoring, and conclusion.
5
6use chrono::Utc;
7use dashmap::DashMap;
8use llm_optimizer_types::experiments::*;
9use std::sync::Arc;
10use uuid::Uuid;
11
12use crate::{
13    errors::{DecisionError, Result},
14    statistical::{StatisticalTest, ZTest},
15    thompson_sampling::ThompsonSampling,
16};
17
18/// Experiment lifecycle manager
19pub struct ExperimentManager {
20    /// Active experiments
21    experiments: Arc<DashMap<Uuid, Experiment>>,
22    /// Thompson Sampling bandits per experiment
23    bandits: Arc<DashMap<Uuid, ThompsonSampling>>,
24}
25
26impl ExperimentManager {
27    /// Create a new experiment manager
28    pub fn new() -> Self {
29        Self {
30            experiments: Arc::new(DashMap::new()),
31            bandits: Arc::new(DashMap::new()),
32        }
33    }
34
35    /// Create a new experiment
36    pub fn create_experiment(
37        &self,
38        name: impl Into<String>,
39        variants: Vec<Variant>,
40        _metrics: Vec<MetricDefinition>,
41    ) -> Result<Uuid> {
42        if variants.len() < 2 {
43            return Err(DecisionError::InvalidConfig(
44                "Experiment must have at least 2 variants".to_string()
45            ));
46        }
47
48        // Validate traffic allocation sums to 1.0
49        let total_allocation: f64 = variants.iter()
50            .map(|v| v.traffic_allocation)
51            .sum();
52        
53        if (total_allocation - 1.0).abs() > 0.01 {
54            return Err(DecisionError::InvalidConfig(
55                format!("Traffic allocation must sum to 1.0, got {}", total_allocation)
56            ));
57        }
58
59        let experiment = Experiment::new(name, variants.clone());
60        let experiment_id = experiment.id;
61
62        // Create Thompson Sampling bandit
63        let mut bandit = ThompsonSampling::new();
64        for variant in &variants {
65            bandit.add_variant(variant.id);
66        }
67
68        self.experiments.insert(experiment_id, experiment);
69        self.bandits.insert(experiment_id, bandit);
70
71        Ok(experiment_id)
72    }
73
74    /// Start an experiment
75    pub fn start_experiment(&self, experiment_id: &Uuid) -> Result<()> {
76        let mut entry = self.experiments.get_mut(experiment_id)
77            .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
78        
79        entry.start();
80        Ok(())
81    }
82
83    /// Pause an experiment
84    pub fn pause_experiment(&self, experiment_id: &Uuid) -> Result<()> {
85        let mut entry = self.experiments.get_mut(experiment_id)
86            .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
87        
88        if entry.status == ExperimentStatus::Running {
89            entry.status = ExperimentStatus::Paused;
90            Ok(())
91        } else {
92            Err(DecisionError::InvalidState(
93                format!("Cannot pause experiment in state {:?}", entry.status)
94            ))
95        }
96    }
97
98    /// Resume an experiment
99    pub fn resume_experiment(&self, experiment_id: &Uuid) -> Result<()> {
100        let mut entry = self.experiments.get_mut(experiment_id)
101            .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
102        
103        if entry.status == ExperimentStatus::Paused {
104            entry.status = ExperimentStatus::Running;
105            Ok(())
106        } else {
107            Err(DecisionError::InvalidState(
108                format!("Cannot resume experiment in state {:?}", entry.status)
109            ))
110        }
111    }
112
113    /// Select a variant for a request using Thompson Sampling
114    pub fn select_variant(&self, experiment_id: &Uuid) -> Result<Uuid> {
115        // Check experiment is running
116        let experiment = self.experiments.get(experiment_id)
117            .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
118        
119        if experiment.status != ExperimentStatus::Running {
120            return Err(DecisionError::InvalidState(
121                format!("Experiment is not running: {:?}", experiment.status)
122            ));
123        }
124
125        // Use Thompson Sampling to select variant
126        let bandit = self.bandits.get(experiment_id)
127            .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
128        
129        bandit.select_variant()
130    }
131
132    /// Record result for a variant
133    pub fn record_result(
134        &self,
135        experiment_id: &Uuid,
136        variant_id: &Uuid,
137        success: bool,
138        quality: f64,
139        cost: f64,
140        latency_ms: f64,
141    ) -> Result<()> {
142        // Update Thompson Sampling bandit
143        if let Some(mut bandit) = self.bandits.get_mut(experiment_id) {
144            bandit.update(variant_id, success)?;
145        }
146
147        // Update variant results
148        let mut experiment = self.experiments.get_mut(experiment_id)
149            .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
150        
151        let variant = experiment.variants.iter_mut()
152            .find(|v| v.id == *variant_id)
153            .ok_or_else(|| DecisionError::VariantNotFound(variant_id.to_string()))?;
154
155        // Initialize results if needed
156        if variant.results.is_none() {
157            variant.results = Some(VariantResults {
158                total_requests: 0,
159                conversions: 0,
160                avg_quality: 0.0,
161                avg_cost: 0.0,
162                avg_latency_ms: 0.0,
163                metrics: Default::default(),
164            });
165        }
166
167        // Update results
168        if let Some(results) = &mut variant.results {
169            let n = results.total_requests as f64;
170            
171            // Update running averages
172            results.avg_quality = (results.avg_quality * n + quality) / (n + 1.0);
173            results.avg_cost = (results.avg_cost * n + cost) / (n + 1.0);
174            results.avg_latency_ms = (results.avg_latency_ms * n + latency_ms) / (n + 1.0);
175            
176            results.total_requests += 1;
177            if success {
178                results.conversions += 1;
179            }
180        }
181
182        Ok(())
183    }
184
185    /// Check if experiment should conclude
186    pub fn should_conclude(&self, experiment_id: &Uuid, min_sample_size: usize, significance_level: f64) -> Result<bool> {
187        let experiment = self.experiments.get(experiment_id)
188            .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
189        
190        if experiment.variants.len() != 2 {
191            // For now, only support 2-variant experiments
192            return Ok(false);
193        }
194
195        let variant1 = &experiment.variants[0];
196        let variant2 = &experiment.variants[1];
197
198        // Check minimum sample size
199        let results1 = variant1.results.as_ref();
200        let results2 = variant2.results.as_ref();
201
202        if results1.is_none() || results2.is_none() {
203            return Ok(false);
204        }
205
206        let r1 = results1.unwrap();
207        let r2 = results2.unwrap();
208
209        if r1.total_requests < min_sample_size as u64 || r2.total_requests < min_sample_size as u64 {
210            return Ok(false);
211        }
212
213        // Perform statistical test
214        let z_test = ZTest::new(
215            r1.conversions,
216            r1.total_requests,
217            r2.conversions,
218            r2.total_requests,
219        );
220
221        let is_significant = z_test.is_significant(significance_level)?;
222
223        Ok(is_significant)
224    }
225
226    /// Conclude experiment and determine winner
227    pub fn conclude_experiment(&self, experiment_id: &Uuid, significance_level: f64) -> Result<()> {
228        let mut experiment = self.experiments.get_mut(experiment_id)
229            .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
230        
231        if experiment.variants.len() != 2 {
232            return Err(DecisionError::InvalidConfig(
233                "Can only conclude 2-variant experiments".to_string()
234            ));
235        }
236
237        let variant1 = &experiment.variants[0];
238        let variant2 = &experiment.variants[1];
239
240        let results1 = variant1.results.as_ref()
241            .ok_or_else(|| DecisionError::InsufficientData("Variant 1 has no results".to_string()))?;
242        let results2 = variant2.results.as_ref()
243            .ok_or_else(|| DecisionError::InsufficientData("Variant 2 has no results".to_string()))?;
244
245        // Perform statistical analysis
246        let z_test = ZTest::new(
247            results1.conversions,
248            results1.total_requests,
249            results2.conversions,
250            results2.total_requests,
251        );
252
253        let p_value = z_test.test()?;
254        let is_significant = p_value < significance_level;
255        let effect_size = z_test.effect_size();
256
257        // Determine winner
258        let winner_variant_id = if is_significant {
259            if results1.conversions as f64 / results1.total_requests as f64 >
260               results2.conversions as f64 / results2.total_requests as f64 {
261                Some(variant1.id)
262            } else {
263                Some(variant2.id)
264            }
265        } else {
266            None
267        };
268
269        // Create statistical analysis
270        let analysis = StatisticalAnalysis {
271            winner_variant_id,
272            p_value,
273            confidence_level: 1.0 - significance_level,
274            effect_size,
275            is_significant,
276            method: "Two-proportion z-test".to_string(),
277        };
278
279        // Create experiment results
280        let mut variant_details = std::collections::HashMap::new();
281        variant_details.insert(variant1.id, results1.clone());
282        variant_details.insert(variant2.id, results2.clone());
283
284        let duration_seconds = (Utc::now() - experiment.start_time).num_seconds() as u64;
285        let total_sample_size = results1.total_requests + results2.total_requests;
286
287        let results = ExperimentResults {
288            statistical_analysis: analysis,
289            variant_details,
290            total_sample_size,
291            duration_seconds,
292        };
293
294        experiment.complete(results);
295
296        Ok(())
297    }
298
299    /// Get experiment by ID
300    pub fn get_experiment(&self, experiment_id: &Uuid) -> Option<Experiment> {
301        self.experiments.get(experiment_id).map(|e| e.clone())
302    }
303
304    /// Get all experiments
305    pub fn list_experiments(&self) -> Vec<Experiment> {
306        self.experiments.iter().map(|e| e.value().clone()).collect()
307    }
308
309    /// Get active experiments
310    pub fn list_active_experiments(&self) -> Vec<Experiment> {
311        self.experiments.iter()
312            .filter(|e| e.status == ExperimentStatus::Running)
313            .map(|e| e.value().clone())
314            .collect()
315    }
316
317    /// Get experiment statistics
318    pub fn get_statistics(&self, experiment_id: &Uuid) -> Result<ExperimentStatistics> {
319        let experiment = self.experiments.get(experiment_id)
320            .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
321        
322        let bandit = self.bandits.get(experiment_id);
323
324        let total_requests: u64 = experiment.variants.iter()
325            .filter_map(|v| v.results.as_ref())
326            .map(|r| r.total_requests)
327            .sum();
328
329        let conversion_rates = experiment.variants.iter()
330            .map(|v| {
331                let rate = v.conversion_rate().unwrap_or(0.0);
332                (v.id, rate)
333            })
334            .collect();
335
336        let bandit_regret = bandit.as_ref().map(|b| b.calculate_regret());
337
338        Ok(ExperimentStatistics {
339            experiment_id: *experiment_id,
340            status: experiment.status,
341            total_requests,
342            conversion_rates,
343            bandit_regret,
344            duration_seconds: (Utc::now() - experiment.start_time).num_seconds() as u64,
345        })
346    }
347}
348
349impl Default for ExperimentManager {
350    fn default() -> Self {
351        Self::new()
352    }
353}
354
355/// Experiment statistics
356#[derive(Debug, Clone)]
357pub struct ExperimentStatistics {
358    pub experiment_id: Uuid,
359    pub status: ExperimentStatus,
360    pub total_requests: u64,
361    pub conversion_rates: std::collections::HashMap<Uuid, f64>,
362    pub bandit_regret: Option<f64>,
363    pub duration_seconds: u64,
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use llm_optimizer_types::models::ModelConfig;

    /// Two variants with an even 50/50 traffic split.
    fn create_test_variants() -> Vec<Variant> {
        let control = Variant::new("control", ModelConfig::default(), 0.5);
        let challenger = Variant::new("variant_a", ModelConfig::default(), 0.5);
        vec![control, challenger]
    }

    /// A manager holding one started experiment named "Test", plus that
    /// experiment's id.
    fn running_experiment() -> (ExperimentManager, Uuid) {
        let manager = ExperimentManager::new();
        let exp_id = manager
            .create_experiment("Test", create_test_variants(), vec![])
            .unwrap();
        manager.start_experiment(&exp_id).unwrap();
        (manager, exp_id)
    }

    /// A freshly created experiment keeps its name and variants and starts out
    /// as a draft.
    #[test]
    fn test_create_experiment() {
        let manager = ExperimentManager::new();

        let exp_id = manager
            .create_experiment("Test Experiment", create_test_variants(), vec![])
            .unwrap();

        let stored = manager.get_experiment(&exp_id).unwrap();
        assert_eq!(stored.name, "Test Experiment");
        assert_eq!(stored.variants.len(), 2);
        assert_eq!(stored.status, ExperimentStatus::Draft);
    }

    /// A single variant is not enough to form an experiment.
    #[test]
    fn test_invalid_variant_count() {
        let manager = ExperimentManager::new();
        let lone_variant = vec![Variant::new("control", ModelConfig::default(), 1.0)];

        let outcome = manager.create_experiment("Test", lone_variant, vec![]);
        assert!(outcome.is_err());
    }

    /// Allocations that sum to 0.8 instead of 1.0 are rejected.
    #[test]
    fn test_invalid_traffic_allocation() {
        let manager = ExperimentManager::new();
        let under_allocated = vec![
            Variant::new("control", ModelConfig::default(), 0.3),
            Variant::new("variant_a", ModelConfig::default(), 0.5),
        ];

        let outcome = manager.create_experiment("Test", under_allocated, vec![]);
        assert!(outcome.is_err());
    }

    /// Starting an experiment leaves it in the running state.
    #[test]
    fn test_start_experiment() {
        let (manager, exp_id) = running_experiment();

        let stored = manager.get_experiment(&exp_id).unwrap();
        assert_eq!(stored.status, ExperimentStatus::Running);
    }

    /// A running experiment hands out a non-nil variant id.
    #[test]
    fn test_select_variant() {
        let (manager, exp_id) = running_experiment();

        let chosen = manager.select_variant(&exp_id).unwrap();
        assert_ne!(chosen, Uuid::nil());
    }

    /// One successful observation shows up as one request and one conversion.
    #[test]
    fn test_record_result() {
        let (manager, exp_id) = running_experiment();
        let chosen = manager.select_variant(&exp_id).unwrap();

        manager
            .record_result(&exp_id, &chosen, true, 0.9, 0.05, 1200.0)
            .unwrap();

        let stored = manager.get_experiment(&exp_id).unwrap();
        let variant = stored.variants.iter().find(|v| v.id == chosen).unwrap();

        let results = variant
            .results
            .as_ref()
            .expect("results should exist after the first recorded observation");
        assert_eq!(results.total_requests, 1);
        assert_eq!(results.conversions, 1);
    }

    /// Pausing and resuming toggles the status between paused and running.
    #[test]
    fn test_pause_resume() {
        let (manager, exp_id) = running_experiment();

        manager.pause_experiment(&exp_id).unwrap();
        let paused = manager.get_experiment(&exp_id).unwrap();
        assert_eq!(paused.status, ExperimentStatus::Paused);

        manager.resume_experiment(&exp_id).unwrap();
        let resumed = manager.get_experiment(&exp_id).unwrap();
        assert_eq!(resumed.status, ExperimentStatus::Running);
    }

    /// With 100/100 conversions against 30/100 the experiment is ready to be
    /// concluded; with no data it is not.
    #[test]
    fn test_should_conclude() {
        let (manager, exp_id) = running_experiment();

        // Not enough data yet
        let ready_before = manager.should_conclude(&exp_id, 100, 0.05).unwrap();
        assert!(!ready_before);

        // First variant: every one of its 100 requests converts.
        let var1_id = manager.get_experiment(&exp_id).unwrap().variants[0].id;
        for _ in 0..100 {
            manager
                .record_result(&exp_id, &var1_id, true, 0.9, 0.05, 1000.0)
                .unwrap();
        }

        // Second variant: 30 conversions followed by 70 failures.
        let var2_id = manager.get_experiment(&exp_id).unwrap().variants[1].id;
        for _ in 0..30 {
            manager
                .record_result(&exp_id, &var2_id, true, 0.7, 0.05, 1000.0)
                .unwrap();
        }
        for _ in 0..70 {
            manager
                .record_result(&exp_id, &var2_id, false, 0.5, 0.05, 1000.0)
                .unwrap();
        }

        // Should conclude (significant difference)
        let ready_after = manager.should_conclude(&exp_id, 100, 0.05).unwrap();
        assert!(ready_after);
    }

    /// Statistics for a just-started experiment report no traffic.
    #[test]
    fn test_get_statistics() {
        let (manager, exp_id) = running_experiment();

        let stats = manager.get_statistics(&exp_id).unwrap();
        assert_eq!(stats.experiment_id, exp_id);
        assert_eq!(stats.status, ExperimentStatus::Running);
        assert_eq!(stats.total_requests, 0);
    }
}