use anyhow::Result;
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// User-supplied definition of an experiment: the competing variants plus the
/// limits that [`Experiment`] validates and consults during its lifecycle.
///
/// Field order is significant for serialization output and is kept stable.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ExperimentConfig {
    /// Human-readable experiment name.
    pub name: String,
    /// Free-form description of what the experiment tests.
    pub description: String,
    /// The control (baseline) variant.
    pub control_variant: Variant,
    /// Treatment variants; `Experiment::new` requires at least one.
    pub treatment_variants: Vec<Variant>,
    /// Traffic share in percent; `Experiment::new` requires it to be greater than
    /// 0 and at most 100. How it is applied to routing is not shown in this module.
    pub traffic_percentage: f64,
    /// Request-count threshold used by `Experiment::should_auto_conclude`;
    /// must be greater than 0.
    pub min_sample_size: usize,
    /// Running time (in hours since start) after which
    /// `Experiment::should_auto_conclude` reports true.
    pub max_duration_hours: u64,
}
/// One arm of an experiment: a label, the model it serves, and optional
/// JSON configuration overrides.
///
/// Fields are private; use [`Variant::new`] / [`Variant::with_config`] and the
/// accessor methods.
#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct Variant {
    /// Variant label; `Experiment` keys its request counters by this string.
    name: String,
    /// Identifier of the model this variant points at.
    model_id: String,
    /// Optional JSON overrides; how they are consumed is not shown in this module.
    config_overrides: Option<serde_json::Value>,
}
impl Variant {
    /// Creates a variant named `name` for `model_id` with no config overrides.
    pub fn new(name: &str, model_id: &str) -> Self {
        Self::assemble(name, model_id, None)
    }

    /// Creates a variant named `name` for `model_id` carrying `config` as its
    /// overrides.
    pub fn with_config(name: &str, model_id: &str, config: serde_json::Value) -> Self {
        Self::assemble(name, model_id, Some(config))
    }

    /// Shared constructor body: copies both strings into owned fields.
    fn assemble(
        name: &str,
        model_id: &str,
        config_overrides: Option<serde_json::Value>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            model_id: model_id.to_owned(),
            config_overrides,
        }
    }

    /// The variant's label.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// The model identifier this variant serves.
    pub fn model_id(&self) -> &str {
        self.model_id.as_str()
    }

    /// Borrowed config overrides, or `None` if the variant was built with
    /// [`Variant::new`].
    pub fn config_overrides(&self) -> Option<&serde_json::Value> {
        match &self.config_overrides {
            Some(overrides) => Some(overrides),
            None => None,
        }
    }
}
/// A single experiment: its validated configuration plus mutable lifecycle state.
///
/// Construct with `Experiment::new`; state changes go through the lifecycle
/// methods (`start`, `pause`, `resume`, `conclude`, `cancel`).
#[derive(Clone, Debug)]
pub struct Experiment {
    /// Identifier assigned at construction.
    id: Uuid,
    /// Configuration the experiment was created from.
    config: ExperimentConfig,
    /// Current lifecycle state.
    status: ExperimentStatus,
    /// Set when the experiment is started; `None` while in Draft.
    start_time: Option<DateTime<Utc>>,
    /// Set when the experiment is concluded or cancelled.
    end_time: Option<DateTime<Utc>>,
    /// Request counters and other bookkeeping.
    metadata: ExperimentMetadata,
}
/// Mutable bookkeeping attached to an [`Experiment`].
#[derive(Clone, Debug, Default)]
pub struct ExperimentMetadata {
    /// Number of requests recorded per variant name.
    pub request_counts: std::collections::HashMap<String, usize>,
    /// Timestamp of the most recent lifecycle change or recorded request.
    pub last_updated: Option<DateTime<Utc>>,
    // Not read by `Experiment`; the allow silences the dead-code lint for it.
    #[allow(dead_code)]
    pub tags: Vec<String>,
    // Not read by `Experiment`; the allow silences the dead-code lint for it.
    #[allow(dead_code)]
    pub owner: Option<String>,
}
/// Lifecycle state of an [`Experiment`].
///
/// Transitions enforced by `Experiment`: Draft -> Running (`start`),
/// Running <-> Paused (`pause` / `resume`), Running or Paused -> Concluded
/// (`conclude`), and Draft, Running or Paused -> Cancelled (`cancel`).
///
/// `Eq` and `Hash` are derived (the enum has no data, so equality is total) so
/// statuses can be used as `HashSet`/`HashMap` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExperimentStatus {
    /// Created but not yet started.
    Draft,
    /// Actively running.
    Running,
    /// Temporarily halted; can be resumed.
    Paused,
    /// Finished via `conclude`; terminal.
    Concluded,
    /// Stopped via `cancel`; terminal.
    Cancelled,
}
impl Experiment {
pub fn new(config: ExperimentConfig) -> Result<Self> {
if config.traffic_percentage <= 0.0 || config.traffic_percentage > 100.0 {
anyhow::bail!("Traffic percentage must be between 0 and 100");
}
if config.treatment_variants.is_empty() {
anyhow::bail!("At least one treatment variant is required");
}
if config.min_sample_size == 0 {
anyhow::bail!("Minimum sample size must be greater than 0");
}
Ok(Self {
id: Uuid::new_v4(),
config,
status: ExperimentStatus::Draft,
start_time: None,
end_time: None,
metadata: ExperimentMetadata::default(),
})
}
pub fn id(&self) -> &Uuid {
&self.id
}
pub fn config(&self) -> &ExperimentConfig {
&self.config
}
pub fn status(&self) -> ExperimentStatus {
self.status.clone()
}
pub fn start(&mut self) -> Result<()> {
if self.status != ExperimentStatus::Draft {
anyhow::bail!("Can only start experiments in Draft status");
}
self.status = ExperimentStatus::Running;
self.start_time = Some(Utc::now());
self.metadata.last_updated = Some(Utc::now());
Ok(())
}
pub fn pause(&mut self) -> Result<()> {
if self.status != ExperimentStatus::Running {
anyhow::bail!("Can only pause running experiments");
}
self.status = ExperimentStatus::Paused;
self.metadata.last_updated = Some(Utc::now());
Ok(())
}
pub fn resume(&mut self) -> Result<()> {
if self.status != ExperimentStatus::Paused {
anyhow::bail!("Can only resume paused experiments");
}
self.status = ExperimentStatus::Running;
self.metadata.last_updated = Some(Utc::now());
Ok(())
}
pub fn conclude(&mut self) -> Result<()> {
if self.status != ExperimentStatus::Running && self.status != ExperimentStatus::Paused {
anyhow::bail!("Can only conclude running or paused experiments");
}
self.status = ExperimentStatus::Concluded;
self.end_time = Some(Utc::now());
self.metadata.last_updated = Some(Utc::now());
Ok(())
}
pub fn cancel(&mut self) -> Result<()> {
if self.status == ExperimentStatus::Concluded || self.status == ExperimentStatus::Cancelled
{
anyhow::bail!("Cannot cancel concluded or already cancelled experiments");
}
self.status = ExperimentStatus::Cancelled;
self.end_time = Some(Utc::now());
self.metadata.last_updated = Some(Utc::now());
Ok(())
}
pub fn should_auto_conclude(&self) -> bool {
if self.status != ExperimentStatus::Running {
return false;
}
if let Some(start_time) = self.start_time {
let elapsed = Utc::now() - start_time;
if elapsed > Duration::hours(self.config.max_duration_hours as i64) {
return true;
}
}
let min_count = self.metadata.request_counts.values().min().copied().unwrap_or(0);
min_count >= self.config.min_sample_size
}
pub fn all_variants(&self) -> Vec<&Variant> {
let mut variants = vec![&self.config.control_variant];
variants.extend(self.config.treatment_variants.iter());
variants
}
pub fn increment_request_count(&mut self, variant_name: &str) {
*self.metadata.request_counts.entry(variant_name.to_string()).or_insert(0) += 1;
self.metadata.last_updated = Some(Utc::now());
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-arm config (control "model-v1" vs. treatment "model-v2") at 50%
    /// traffic with a 24-hour cap; only the labels and sample size vary.
    fn two_arm_config(name: &str, description: &str, min_sample_size: usize) -> ExperimentConfig {
        ExperimentConfig {
            name: name.to_string(),
            description: description.to_string(),
            control_variant: Variant::new("control", "model-v1"),
            treatment_variants: vec![Variant::new("treatment", "model-v2")],
            traffic_percentage: 50.0,
            min_sample_size,
            max_duration_hours: 24,
        }
    }

    #[test]
    fn test_experiment_lifecycle() {
        let cfg = two_arm_config("Test Experiment", "Testing lifecycle", 100);
        let mut exp = Experiment::new(cfg).expect("operation failed in test");
        assert_eq!(exp.status(), ExperimentStatus::Draft);

        exp.start().expect("operation failed in test");
        assert_eq!(exp.status(), ExperimentStatus::Running);
        assert!(exp.start_time.is_some());

        exp.pause().expect("operation failed in test");
        assert_eq!(exp.status(), ExperimentStatus::Paused);

        exp.resume().expect("operation failed in test");
        assert_eq!(exp.status(), ExperimentStatus::Running);

        exp.conclude().expect("operation failed in test");
        assert_eq!(exp.status(), ExperimentStatus::Concluded);
        assert!(exp.end_time.is_some());
    }

    #[test]
    fn test_variant_creation() {
        let plain = Variant::new("test", "model-123");
        assert_eq!(plain.name(), "test");
        assert_eq!(plain.model_id(), "model-123");
        assert!(plain.config_overrides().is_none());

        let overrides = serde_json::json!({
            "batch_size": 32,
            "temperature": 0.7
        });
        let tuned = Variant::with_config("test2", "model-456", overrides.clone());
        assert_eq!(tuned.config_overrides(), Some(&overrides));
    }

    #[test]
    fn test_auto_conclude() {
        let cfg = two_arm_config("Auto Conclude Test", "Testing auto conclusion", 2);
        let mut exp = Experiment::new(cfg).expect("operation failed in test");
        exp.start().expect("operation failed in test");
        assert!(!exp.should_auto_conclude());

        for arm in ["control", "control", "treatment", "treatment"].iter() {
            exp.increment_request_count(arm);
        }
        assert!(exp.should_auto_conclude());
    }
}