Skip to main content

llm_optimizer_decision/
ab_testing.rs

1//! A/B testing engine with Thompson Sampling
2//!
3//! This module provides the main A/B testing engine that combines
4//! Thompson Sampling, statistical significance testing, and experiment management.
5
6use llm_optimizer_config::OptimizerConfig;
7use llm_optimizer_types::{
8    experiments::*,
9    models::ModelConfig,
10};
11use std::sync::Arc;
12use tracing::{debug, info, warn};
13use uuid::Uuid;
14
15use crate::{
16    errors::{DecisionError, Result},
17    experiment_manager::{ExperimentManager, ExperimentStatistics},
18    statistical::SampleSizeCalculator,
19    variant_generator::{VariantGenerator, VariantStrategy},
20};
21
22/// A/B testing engine
23pub struct ABTestEngine {
24    /// Experiment manager
25    manager: Arc<ExperimentManager>,
26    
27    /// Configuration
28    min_sample_size: usize,
29    significance_level: f64,
30    max_duration_seconds: u64,
31}
32
33impl ABTestEngine {
34    /// Create a new A/B testing engine
35    pub fn new(config: &OptimizerConfig) -> Self {
36        let ab_config = &config.strategies.ab_testing;
37        
38        Self {
39            manager: Arc::new(ExperimentManager::new()),
40            min_sample_size: ab_config.min_sample_size,
41            significance_level: ab_config.significance_level,
42            max_duration_seconds: ab_config.max_duration_seconds,
43        }
44    }
45
46    /// Create a new experiment with generated variants
47    pub fn create_experiment_from_strategy(
48        &self,
49        name: impl Into<String>,
50        base_config: &ModelConfig,
51        strategy: &VariantStrategy,
52    ) -> Result<Uuid> {
53        info!("Creating experiment with strategy: {:?}", strategy);
54        
55        // Generate variant configurations
56        let configs = VariantGenerator::generate(base_config, strategy)?;
57        
58        if configs.len() < 2 {
59            return Err(DecisionError::InvalidConfig(
60                "Must have at least 2 variants".to_string()
61            ));
62        }
63        
64        // Validate all configurations
65        for config in &configs {
66            VariantGenerator::validate_config(config)?;
67        }
68        
69        // Create variants with equal traffic allocation
70        let allocation = 1.0 / configs.len() as f64;
71        let variants: Vec<Variant> = configs.into_iter()
72            .enumerate()
73            .map(|(i, config)| {
74                let name = if i == 0 {
75                    "control".to_string()
76                } else {
77                    format!("variant_{}", i)
78                };
79                Variant::new(name, config, allocation)
80            })
81            .collect();
82        
83        // Create experiment
84        let exp_id = self.manager.create_experiment(name, variants, vec![])?;
85        
86        info!("Created experiment {}", exp_id);
87        
88        Ok(exp_id)
89    }
90
91    /// Create a custom experiment with specific variants
92    pub fn create_experiment(
93        &self,
94        name: impl Into<String>,
95        variants: Vec<Variant>,
96    ) -> Result<Uuid> {
97        // Validate variants
98        for variant in &variants {
99            VariantGenerator::validate_config(&variant.config)?;
100        }
101        
102        self.manager.create_experiment(name, variants, vec![])
103    }
104
105    /// Start an experiment
106    pub fn start(&self, experiment_id: &Uuid) -> Result<()> {
107        info!("Starting experiment {}", experiment_id);
108        self.manager.start_experiment(experiment_id)
109    }
110
111    /// Pause an experiment
112    pub fn pause(&self, experiment_id: &Uuid) -> Result<()> {
113        info!("Pausing experiment {}", experiment_id);
114        self.manager.pause_experiment(experiment_id)
115    }
116
117    /// Resume an experiment
118    pub fn resume(&self, experiment_id: &Uuid) -> Result<()> {
119        info!("Resuming experiment {}", experiment_id);
120        self.manager.resume_experiment(experiment_id)
121    }
122
123    /// Assign a variant to a request
124    pub fn assign_variant(&self, experiment_id: &Uuid) -> Result<(Uuid, ModelConfig)> {
125        let variant_id = self.manager.select_variant(experiment_id)?;
126        
127        let experiment = self.manager.get_experiment(experiment_id)
128            .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
129        
130        let variant = experiment.variants.iter()
131            .find(|v| v.id == variant_id)
132            .ok_or_else(|| DecisionError::VariantNotFound(variant_id.to_string()))?;
133        
134        debug!("Assigned variant {} for experiment {}", variant.name, experiment_id);
135        
136        Ok((variant_id, variant.config.clone()))
137    }
138
139    /// Record the outcome of a request
140    pub fn record_outcome(
141        &self,
142        experiment_id: &Uuid,
143        variant_id: &Uuid,
144        success: bool,
145        quality: f64,
146        cost: f64,
147        latency_ms: f64,
148    ) -> Result<()> {
149        self.manager.record_result(
150            experiment_id,
151            variant_id,
152            success,
153            quality,
154            cost,
155            latency_ms,
156        )?;
157        
158        debug!(
159            "Recorded result for variant {} in experiment {}: success={}, quality={:.2}, cost={:.4}",
160            variant_id, experiment_id, success, quality, cost
161        );
162        
163        // Check if experiment should conclude
164        self.check_experiment_conclusion(experiment_id)?;
165        
166        Ok(())
167    }
168
169    /// Check if an experiment should conclude
170    fn check_experiment_conclusion(&self, experiment_id: &Uuid) -> Result<()> {
171        let should_conclude = self.manager.should_conclude(
172            experiment_id,
173            self.min_sample_size,
174            self.significance_level,
175        )?;
176        
177        if should_conclude {
178            info!("Experiment {} has reached statistical significance", experiment_id);
179            self.conclude(experiment_id)?;
180        } else {
181            // Check if max duration exceeded
182            let stats = self.manager.get_statistics(experiment_id)?;
183            if stats.duration_seconds >= self.max_duration_seconds {
184                warn!(
185                    "Experiment {} exceeded max duration ({} seconds), concluding without significance",
186                    experiment_id, self.max_duration_seconds
187                );
188                self.conclude(experiment_id)?;
189            }
190        }
191        
192        Ok(())
193    }
194
195    /// Manually conclude an experiment
196    pub fn conclude(&self, experiment_id: &Uuid) -> Result<Experiment> {
197        info!("Concluding experiment {}", experiment_id);
198        
199        self.manager.conclude_experiment(experiment_id, self.significance_level)?;
200        
201        let experiment = self.manager.get_experiment(experiment_id)
202            .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
203        
204        if let Some(results) = &experiment.results {
205            if let Some(winner_id) = results.statistical_analysis.winner_variant_id {
206                let winner = experiment.variants.iter()
207                    .find(|v| v.id == winner_id)
208                    .map(|v| &v.name);
209                
210                info!(
211                    "Experiment {} concluded. Winner: {:?}, p-value: {:.4}, effect size: {:.4}",
212                    experiment_id,
213                    winner,
214                    results.statistical_analysis.p_value,
215                    results.statistical_analysis.effect_size
216                );
217            } else {
218                info!(
219                    "Experiment {} concluded with no significant winner (p-value: {:.4})",
220                    experiment_id,
221                    results.statistical_analysis.p_value
222                );
223            }
224        }
225        
226        Ok(experiment)
227    }
228
229    /// Get experiment statistics
230    pub fn get_statistics(&self, experiment_id: &Uuid) -> Result<ExperimentStatistics> {
231        self.manager.get_statistics(experiment_id)
232    }
233
234    /// Get experiment details
235    pub fn get_experiment(&self, experiment_id: &Uuid) -> Option<Experiment> {
236        self.manager.get_experiment(experiment_id)
237    }
238
239    /// List all experiments
240    pub fn list_experiments(&self) -> Vec<Experiment> {
241        self.manager.list_experiments()
242    }
243
244    /// List active experiments
245    pub fn list_active_experiments(&self) -> Vec<Experiment> {
246        self.manager.list_active_experiments()
247    }
248
249    /// Calculate required sample size for an experiment
250    pub fn calculate_sample_size(
251        &self,
252        baseline_rate: f64,
253        min_effect: f64,
254        power: f64,
255    ) -> Result<usize> {
256        let calculator = SampleSizeCalculator::new(
257            baseline_rate,
258            min_effect,
259            power,
260            self.significance_level,
261        )?;
262        
263        calculator.calculate()
264    }
265
266    /// Get winning variant configuration (if experiment concluded)
267    pub fn get_winner_config(&self, experiment_id: &Uuid) -> Result<Option<ModelConfig>> {
268        let experiment = self.manager.get_experiment(experiment_id)
269            .ok_or_else(|| DecisionError::ExperimentNotFound(experiment_id.to_string()))?;
270        
271        if experiment.status != ExperimentStatus::Completed {
272            return Ok(None);
273        }
274        
275        let winner = experiment.get_winner();
276        Ok(winner.map(|v| v.config.clone()))
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use llm_optimizer_config::{ServiceConfig, DatabaseConfig, IntegrationConfig, ObservabilityConfig};
    use llm_optimizer_types::strategies::StrategyConfig;

    /// Optimizer configuration built entirely from component defaults.
    fn test_config() -> OptimizerConfig {
        OptimizerConfig {
            service: ServiceConfig::default(),
            database: DatabaseConfig::default(),
            integrations: IntegrationConfig::default(),
            strategies: StrategyConfig::default(),
            observability: ObservabilityConfig::default(),
        }
    }

    fn base_model_config() -> ModelConfig {
        ModelConfig::default()
    }

    /// Engine built from the default test configuration.
    fn test_engine() -> ABTestEngine {
        ABTestEngine::new(&test_config())
    }

    /// Two-variant temperature strategy shared by most tests.
    fn two_temperatures() -> VariantStrategy {
        VariantStrategy::Temperature(vec![0.3, 0.7])
    }

    #[test]
    fn test_create_engine() {
        let engine = test_engine();

        assert_eq!(engine.min_sample_size, 1000);
        assert_eq!(engine.significance_level, 0.05);
    }

    #[test]
    fn test_create_experiment_from_strategy() {
        let engine = test_engine();
        let strategy = VariantStrategy::Temperature(vec![0.0, 0.7, 1.0]);

        let exp_id = engine
            .create_experiment_from_strategy("Temperature Test", &base_model_config(), &strategy)
            .unwrap();

        let experiment = engine.get_experiment(&exp_id).unwrap();
        assert_eq!(experiment.variants.len(), 3);
        assert_eq!(experiment.name, "Temperature Test");
        // The first generated configuration is registered as the control.
        assert!(experiment.variants.iter().any(|v| v.name == "control"));
    }

    #[test]
    fn test_full_experiment_lifecycle() {
        let engine = test_engine();

        let exp_id = engine
            .create_experiment_from_strategy("Test", &base_model_config(), &two_temperatures())
            .unwrap();

        engine.start(&exp_id).unwrap();

        // Resolve the first variant once instead of cloning the whole experiment
        // and scanning its variants on every iteration.
        let first_variant_id = engine.get_experiment(&exp_id).unwrap().variants[0].id;

        for i in 0..100 {
            let (variant_id, _config) = engine.assign_variant(&exp_id).unwrap();

            // Deterministic outcome pattern keyed on the loop counter: the first
            // variant succeeds when `i % 10 < 8`, the other when `i % 10 < 6`.
            // Realised per-variant rates depend on which variant gets assigned.
            let success = if variant_id == first_variant_id {
                i % 10 < 8
            } else {
                i % 10 < 6
            };

            engine.record_outcome(&exp_id, &variant_id, success, 0.9, 0.05, 1000.0).unwrap();
        }

        let stats = engine.get_statistics(&exp_id).unwrap();
        assert!(stats.total_requests > 0);
    }

    #[test]
    fn test_pause_resume() {
        let engine = test_engine();
        let exp_id = engine
            .create_experiment_from_strategy("Test", &base_model_config(), &two_temperatures())
            .unwrap();

        engine.start(&exp_id).unwrap();
        engine.pause(&exp_id).unwrap();

        let exp = engine.get_experiment(&exp_id).unwrap();
        assert_eq!(exp.status, ExperimentStatus::Paused);

        // Can't assign variant when paused
        assert!(engine.assign_variant(&exp_id).is_err());

        engine.resume(&exp_id).unwrap();

        // Can assign after resume
        assert!(engine.assign_variant(&exp_id).is_ok());
    }

    #[test]
    fn test_sample_size_calculation() {
        let engine = test_engine();

        let sample_size = engine
            .calculate_sample_size(
                0.1, // 10% baseline
                0.2, // 20% relative improvement
                0.8, // 80% power
            )
            .unwrap();

        assert!(sample_size > 100);
        assert!(sample_size < 100000);
    }

    #[test]
    fn test_list_experiments() {
        let engine = test_engine();
        let base = base_model_config();
        let strategy = two_temperatures();

        for name in ["Test 1", "Test 2"] {
            engine.create_experiment_from_strategy(name, &base, &strategy).unwrap();
        }

        assert_eq!(engine.list_experiments().len(), 2);
    }

    #[test]
    fn test_list_active_experiments() {
        let engine = test_engine();
        let exp_id = engine
            .create_experiment_from_strategy("Test", &base_model_config(), &two_temperatures())
            .unwrap();

        engine.start(&exp_id).unwrap();

        assert_eq!(engine.list_active_experiments().len(), 1);
    }
}