use super::{ABTestingError, Result};
use std::collections::HashMap;
use uuid::Uuid;
/// Lifecycle state of an [`Experiment`].
///
/// The transitions implemented in this file are: `Draft`/`Paused` -> `Active`
/// via `Experiment::start`, `Active` -> `Paused` via `Experiment::pause`,
/// any state -> `Completed` via `Experiment::complete`, and any state ->
/// `Archived` via `Experiment::archive`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExperimentState {
/// Initial state assigned by `Experiment::new`.
Draft,
/// Set by `Experiment::start`; the only state for which `is_active` is true.
Active,
/// Set by `Experiment::pause`; can be resumed with `Experiment::start`.
Paused,
/// Set by `Experiment::complete`.
Completed,
/// Set by `Experiment::archive`.
Archived,
}
/// An A/B experiment: a named set of branches, the share of traffic assigned
/// to each branch, and lifecycle bookkeeping (state, timestamps, winner).
#[derive(Debug, Clone)]
pub struct Experiment {
/// Identifier; `Experiment::new` assigns a freshly generated v4 UUID.
pub id: Uuid,
/// Name supplied to `Experiment::new`.
pub name: String,
/// Optional free-form description; `None` until `with_description` is called.
pub description: Option<String>,
/// Branch names in the order given to `Experiment::new`. The first entry is
/// what `control_branch` returns.
pub branches: Vec<String>,
/// Per-branch share keyed by branch name. `Experiment::new` divides 100
/// across the branches, and `with_allocation` requires the values to sum to
/// 100. Branches without an entry are reported as 0 by `get_allocation`.
pub allocation: HashMap<String, u32>,
/// Current lifecycle state; starts as `ExperimentState::Draft`.
pub state: ExperimentState,
/// UTC time at which `Experiment::new` built this value.
pub created_at: chrono::DateTime<chrono::Utc>,
/// UTC time recorded by `start`; `None` until the experiment has been started.
pub started_at: Option<chrono::DateTime<chrono::Utc>>,
/// UTC time recorded by `complete`; `None` until then.
pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
/// Winning branch passed to `complete`, if any. `complete` rejects names
/// that are not in `branches`.
pub winner: Option<String>,
/// Tuning options; `ExperimentConfig::default()` unless replaced via
/// `with_config`.
pub config: ExperimentConfig,
/// Free-form tags; empty unless set via `with_tags`.
pub tags: Vec<String>,
/// Optional owner; `None` until `with_owner` is called.
pub owner: Option<String>,
}
impl Experiment {
    /// Creates a draft experiment that splits 100 evenly across `branches`.
    ///
    /// Every branch receives `100 / n`; the remainder of that integer division
    /// is added to the first branch so the shares total exactly 100 (for
    /// example three branches get 34/33/33).
    ///
    /// An empty `branches` list produces an experiment with an empty allocation
    /// map instead of panicking on a division by zero.
    ///
    /// NOTE(review): duplicate branch names collapse into a single map entry,
    /// so the allocation then sums to less than 100. Callers are expected to
    /// pass distinct names.
    pub fn new(name: impl Into<String>, branches: Vec<String>) -> Self {
        let branch_count = branches.len();
        let mut allocation = HashMap::with_capacity(branch_count);

        // Guard the division: with no branches there is nothing to allocate,
        // and `100 / 0` would panic.
        if branch_count > 0 {
            // Dividing in `usize` avoids a truncating cast of the length; both
            // results are at most 100, so narrowing them to `u32` is lossless.
            let base_allocation = (100 / branch_count) as u32;
            let remainder = (100 % branch_count) as u32;
            for (i, branch) in branches.iter().enumerate() {
                let extra = if i == 0 { remainder } else { 0 };
                allocation.insert(branch.clone(), base_allocation + extra);
            }
        }

        Self {
            id: Uuid::new_v4(),
            name: name.into(),
            description: None,
            branches,
            allocation,
            state: ExperimentState::Draft,
            created_at: chrono::Utc::now(),
            started_at: None,
            completed_at: None,
            winner: None,
            config: ExperimentConfig::default(),
            tags: Vec::new(),
            owner: None,
        }
    }

    /// Replaces the allocation map after validating it.
    ///
    /// Branches that are left out of `allocation` are allowed; they are
    /// reported as 0 by [`Experiment::get_allocation`].
    ///
    /// # Errors
    ///
    /// * `ABTestingError::Configuration` if the values do not sum to 100.
    /// * `ABTestingError::BranchNotFound` if a key is not one of this
    ///   experiment's branches.
    pub fn with_allocation(mut self, allocation: HashMap<String, u32>) -> Result<Self> {
        // Sum in `u64`: adding raw `u32` shares can overflow, which panics in
        // debug builds and wraps in release builds (a wrapped total could even
        // land on 100 and pass validation).
        let total: u64 = allocation.values().map(|&share| u64::from(share)).sum();
        if total != 100 {
            return Err(ABTestingError::Configuration(format!(
                "Allocation must sum to 100, got {}",
                total
            )));
        }

        if let Some(unknown) = allocation
            .keys()
            .find(|branch| !self.branches.contains(*branch))
        {
            return Err(ABTestingError::BranchNotFound(unknown.clone()));
        }

        self.allocation = allocation;
        Ok(self)
    }

    /// Sets the description and returns the experiment for chaining.
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    /// Replaces the tag list and returns the experiment for chaining.
    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
        self.tags = tags;
        self
    }

    /// Sets the owner and returns the experiment for chaining.
    pub fn with_owner(mut self, owner: impl Into<String>) -> Self {
        self.owner = Some(owner.into());
        self
    }

    /// Replaces the configuration and returns the experiment for chaining.
    pub fn with_config(mut self, config: ExperimentConfig) -> Self {
        self.config = config;
        self
    }

    /// Returns the first branch, which this type treats as the control, or
    /// `None` when the experiment has no branches.
    pub fn control_branch(&self) -> Option<&String> {
        self.branches.first()
    }

    /// Returns the share allocated to `branch`, or 0 when the branch has no
    /// entry in the allocation map.
    pub fn get_allocation(&self, branch: &str) -> u32 {
        self.allocation.get(branch).copied().unwrap_or(0)
    }

    /// Returns `true` only while the experiment is in the `Active` state.
    pub fn is_active(&self) -> bool {
        self.state == ExperimentState::Active
    }

    /// Returns `true` while the experiment is `Active` or `Paused`.
    pub fn accepts_traffic(&self) -> bool {
        matches!(
            self.state,
            ExperimentState::Active | ExperimentState::Paused
        )
    }

    /// Moves a `Draft` or `Paused` experiment to `Active`.
    ///
    /// `started_at` is recorded the first time the experiment is started and is
    /// left untouched when resuming from `Paused`, so [`Experiment::duration`]
    /// keeps measuring from the original start.
    ///
    /// # Errors
    ///
    /// `ABTestingError::Configuration` if the experiment is in any other state.
    pub fn start(&mut self) -> Result<()> {
        if !matches!(
            self.state,
            ExperimentState::Draft | ExperimentState::Paused
        ) {
            return Err(ABTestingError::Configuration(format!(
                "Cannot start experiment in {:?} state",
                self.state
            )));
        }

        self.state = ExperimentState::Active;
        // Resuming must not reset the original start time.
        if self.started_at.is_none() {
            self.started_at = Some(chrono::Utc::now());
        }
        Ok(())
    }

    /// Moves an `Active` experiment to `Paused`.
    ///
    /// # Errors
    ///
    /// `ABTestingError::Configuration` if the experiment is not `Active`.
    pub fn pause(&mut self) -> Result<()> {
        if self.state != ExperimentState::Active {
            return Err(ABTestingError::Configuration(format!(
                "Cannot pause experiment in {:?} state",
                self.state
            )));
        }

        self.state = ExperimentState::Paused;
        Ok(())
    }

    /// Marks the experiment `Completed`, recording the completion time and the
    /// optional `winner`.
    ///
    /// The winner is validated before anything is mutated, so a failed call
    /// leaves the experiment unchanged.
    ///
    /// NOTE(review): unlike `start` and `pause`, the current state is not
    /// checked, so completing again overwrites `completed_at` and `winner`.
    ///
    /// # Errors
    ///
    /// `ABTestingError::BranchNotFound` if `winner` is not one of the branches.
    pub fn complete(&mut self, winner: Option<String>) -> Result<()> {
        if let Some(name) = winner.as_ref() {
            if !self.branches.contains(name) {
                return Err(ABTestingError::BranchNotFound(name.clone()));
            }
        }

        self.state = ExperimentState::Completed;
        self.completed_at = Some(chrono::Utc::now());
        self.winner = winner;
        Ok(())
    }

    /// Moves the experiment to `Archived` from whatever state it is in.
    pub fn archive(&mut self) {
        self.state = ExperimentState::Archived;
    }

    /// Returns the time elapsed since the experiment was started, or `None` if
    /// it has never been started.
    ///
    /// The interval ends at `completed_at` when that is set and at the current
    /// UTC time otherwise.
    pub fn duration(&self) -> Option<chrono::Duration> {
        self.started_at.map(|start| {
            let end = self.completed_at.unwrap_or_else(chrono::Utc::now);
            end.signed_duration_since(start)
        })
    }
}
/// Tunable options attached to an experiment.
///
/// This file only stores these values; how they are interpreted is decided by
/// code outside it, so the per-field notes below are limited to the defaults
/// and hedged guesses at intent.
#[derive(Debug, Clone)]
pub struct ExperimentConfig {
    /// Defaults to 100. Presumably the number of samples required before
    /// results are evaluated (confirm against the consumer).
    pub min_sample_size: u64,
    /// Defaults to 0.95. Stored as given; no range validation happens here.
    pub confidence_level: f64,
    /// Defaults to `false`. There is no builder method for this field.
    pub allow_reassignment: bool,
    /// Defaults to `true`.
    pub sticky_sessions: bool,
    /// Defaults to empty.
    pub excluded_groups: Vec<String>,
    /// Defaults to empty.
    pub included_groups: Vec<String>,
    /// Defaults to `false`.
    pub auto_complete: bool,
    /// Defaults to `None`; `with_max_duration` stores a number of hours.
    pub max_duration_hours: Option<u32>,
}

impl Default for ExperimentConfig {
    /// Builds the baseline configuration: 100 samples, 0.95 confidence, sticky
    /// sessions on, everything else off or empty.
    fn default() -> Self {
        ExperimentConfig {
            // Numeric thresholds.
            min_sample_size: 100,
            confidence_level: 0.95,
            max_duration_hours: None,
            // Behaviour switches.
            allow_reassignment: false,
            sticky_sessions: true,
            auto_complete: false,
            // Group filters start out empty.
            excluded_groups: Vec::new(),
            included_groups: Vec::new(),
        }
    }
}

/// Chainable setters: each consumes the config, overrides one field and hands
/// the config back with every other field untouched.
impl ExperimentConfig {
    /// Overrides `min_sample_size`.
    pub fn with_min_sample_size(self, size: u64) -> Self {
        Self {
            min_sample_size: size,
            ..self
        }
    }

    /// Overrides `confidence_level`; the value is stored without validation.
    pub fn with_confidence_level(self, level: f64) -> Self {
        Self {
            confidence_level: level,
            ..self
        }
    }

    /// Overrides `sticky_sessions`.
    pub fn with_sticky_sessions(self, sticky: bool) -> Self {
        Self {
            sticky_sessions: sticky,
            ..self
        }
    }

    /// Replaces the whole `excluded_groups` list.
    pub fn with_excluded_groups(self, groups: Vec<String>) -> Self {
        Self {
            excluded_groups: groups,
            ..self
        }
    }

    /// Replaces the whole `included_groups` list.
    pub fn with_included_groups(self, groups: Vec<String>) -> Self {
        Self {
            included_groups: groups,
            ..self
        }
    }

    /// Overrides `auto_complete`.
    pub fn with_auto_complete(self, auto: bool) -> Self {
        Self {
            auto_complete: auto,
            ..self
        }
    }

    /// Sets `max_duration_hours` to `Some(hours)`.
    pub fn with_max_duration(self, hours: u32) -> Self {
        Self {
            max_duration_hours: Some(hours),
            ..self
        }
    }
}
/// Plain data record describing one variant of an experiment.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// this type, so the field notes below are assumptions to confirm against its
/// users.
#[derive(Debug, Clone)]
pub struct Variant {
/// Variant name; presumably matches an entry in `Experiment::branches`.
pub name: String,
/// Presumably the same out-of-100 share stored in `Experiment::allocation`
/// (TODO confirm).
pub allocation: u32,
/// Optional free-form description.
pub description: Option<String>,
/// Presumably `true` for the control variant (confirm).
pub is_control: bool,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the two-branch draft experiment that most tests start from.
    fn two_branch_experiment() -> Experiment {
        Experiment::new(
            "test",
            vec!["control".to_string(), "treatment".to_string()],
        )
    }

    #[test]
    fn test_experiment_new() {
        let exp = two_branch_experiment();
        assert_eq!(exp.name, "test");
        assert_eq!(exp.branches.len(), 2);
        assert_eq!(exp.state, ExperimentState::Draft);
        assert_eq!(exp.get_allocation("control"), 50);
        assert_eq!(exp.get_allocation("treatment"), 50);
        // Branches without an allocation entry are reported as zero.
        assert_eq!(exp.get_allocation("missing"), 0);
    }

    #[test]
    fn test_experiment_uneven_branches() {
        let exp = Experiment::new(
            "test",
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
        );
        let total: u32 = exp.allocation.values().sum();
        assert_eq!(total, 100);
        // The remainder of 100 / 3 goes to the first branch.
        assert_eq!(exp.get_allocation("a"), 34);
        assert_eq!(exp.get_allocation("b"), 33);
        assert_eq!(exp.get_allocation("c"), 33);
    }

    #[test]
    fn test_custom_allocation() {
        let mut alloc = HashMap::new();
        alloc.insert("control".to_string(), 80);
        alloc.insert("treatment".to_string(), 20);
        let exp = two_branch_experiment().with_allocation(alloc).unwrap();
        assert_eq!(exp.get_allocation("control"), 80);
        assert_eq!(exp.get_allocation("treatment"), 20);
    }

    #[test]
    fn test_invalid_allocation() {
        let mut alloc = HashMap::new();
        alloc.insert("control".to_string(), 60);
        alloc.insert("treatment".to_string(), 60);
        let result = two_branch_experiment().with_allocation(alloc);
        assert!(result.is_err());
    }

    #[test]
    fn test_allocation_unknown_branch() {
        // The sum is valid, so only the unknown key can cause the rejection.
        let mut alloc = HashMap::new();
        alloc.insert("control".to_string(), 50);
        alloc.insert("other".to_string(), 50);
        let result = two_branch_experiment().with_allocation(alloc);
        assert!(matches!(result, Err(ABTestingError::BranchNotFound(_))));
    }

    #[test]
    fn test_experiment_lifecycle() {
        let mut exp = two_branch_experiment();
        exp.start().unwrap();
        assert_eq!(exp.state, ExperimentState::Active);
        assert!(exp.started_at.is_some());
        exp.pause().unwrap();
        assert_eq!(exp.state, ExperimentState::Paused);
        exp.start().unwrap();
        assert_eq!(exp.state, ExperimentState::Active);
        exp.complete(Some("treatment".to_string())).unwrap();
        assert_eq!(exp.state, ExperimentState::Completed);
        assert_eq!(exp.winner, Some("treatment".to_string()));
        assert!(exp.completed_at.is_some());
    }

    #[test]
    fn test_invalid_transitions() {
        let mut exp = two_branch_experiment();

        // A draft cannot be paused, and a failed call leaves the state alone.
        assert!(exp.pause().is_err());
        assert_eq!(exp.state, ExperimentState::Draft);

        // An active experiment cannot be started again.
        exp.start().unwrap();
        assert!(exp.start().is_err());
        assert_eq!(exp.state, ExperimentState::Active);

        // A completed experiment can be neither started nor paused.
        exp.complete(None).unwrap();
        assert!(exp.start().is_err());
        assert!(exp.pause().is_err());
        assert_eq!(exp.state, ExperimentState::Completed);
    }

    #[test]
    fn test_complete_unknown_winner() {
        let mut exp = two_branch_experiment();
        exp.start().unwrap();
        let result = exp.complete(Some("nope".to_string()));
        assert!(matches!(result, Err(ABTestingError::BranchNotFound(_))));
        // Validation happens before any mutation.
        assert_eq!(exp.state, ExperimentState::Active);
        assert!(exp.completed_at.is_none());
        assert!(exp.winner.is_none());
    }

    #[test]
    fn test_state_predicates() {
        let mut exp = two_branch_experiment();
        assert!(!exp.is_active());
        assert!(!exp.accepts_traffic());

        exp.start().unwrap();
        assert!(exp.is_active());
        assert!(exp.accepts_traffic());

        exp.pause().unwrap();
        assert!(!exp.is_active());
        assert!(exp.accepts_traffic());

        exp.archive();
        assert_eq!(exp.state, ExperimentState::Archived);
        assert!(!exp.accepts_traffic());
    }

    #[test]
    fn test_duration_requires_start() {
        let mut exp = two_branch_experiment();
        assert!(exp.duration().is_none());
        exp.start().unwrap();
        assert!(exp.duration().is_some());
    }

    #[test]
    fn test_metadata_builders() {
        let exp = two_branch_experiment()
            .with_description("desc")
            .with_tags(vec!["checkout".to_string()])
            .with_owner("team-a")
            .with_config(ExperimentConfig::default().with_max_duration(24));
        assert_eq!(exp.description.as_deref(), Some("desc"));
        assert_eq!(exp.tags, vec!["checkout".to_string()]);
        assert_eq!(exp.owner.as_deref(), Some("team-a"));
        assert_eq!(exp.config.max_duration_hours, Some(24));
    }

    #[test]
    fn test_control_branch() {
        let exp = two_branch_experiment();
        assert_eq!(exp.control_branch(), Some(&"control".to_string()));
    }

    #[test]
    fn test_experiment_config() {
        let config = ExperimentConfig::default()
            .with_min_sample_size(1000)
            .with_confidence_level(0.99)
            .with_sticky_sessions(true)
            .with_auto_complete(true);
        assert_eq!(config.min_sample_size, 1000);
        assert_eq!(config.confidence_level, 0.99);
        assert!(config.sticky_sessions);
        assert!(config.auto_complete);
    }
}