use super::experiment::ExperimentStatus;
use super::{Experiment, Variant};
use anyhow::Result;
use scirs2_core::random::*;
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
#[derive(Debug, Clone)]
pub enum RoutingStrategy {
HashBased,
RoundRobin,
SegmentBased(Vec<UserSegment>),
WeightedRandom(HashMap<String, f64>),
Sticky,
}
#[derive(Debug, Clone)]
pub struct UserSegment {
pub name: String,
pub condition: SegmentCondition,
pub variant_name: String,
}
#[derive(Debug, Clone)]
pub enum SegmentCondition {
UserIdPattern(String),
HasAttribute(String, String),
GeoRegion(String),
Platform(Platform),
Custom(String), }
#[derive(Debug, Clone, PartialEq)]
pub enum Platform {
Ios,
Android,
Web,
Desktop,
}
pub struct TrafficSplitter {
default_strategy: RoutingStrategy,
sticky_cache: parking_lot::RwLock<HashMap<String, String>>,
round_robin_counters: parking_lot::RwLock<HashMap<String, usize>>,
}
impl Default for TrafficSplitter {
fn default() -> Self {
Self::new()
}
}
impl TrafficSplitter {
pub fn new() -> Self {
Self {
default_strategy: RoutingStrategy::HashBased,
sticky_cache: parking_lot::RwLock::new(HashMap::new()),
round_robin_counters: parking_lot::RwLock::new(HashMap::new()),
}
}
pub fn with_strategy(strategy: RoutingStrategy) -> Self {
Self {
default_strategy: strategy,
sticky_cache: parking_lot::RwLock::new(HashMap::new()),
round_robin_counters: parking_lot::RwLock::new(HashMap::new()),
}
}
pub fn route(&self, experiment: &Experiment, user_id: &str) -> Result<Variant> {
if experiment.status() != ExperimentStatus::Running {
return Ok(experiment.config().control_variant.clone());
}
match &self.default_strategy {
RoutingStrategy::HashBased => self.route_hash_based(experiment, user_id),
RoutingStrategy::RoundRobin => self.route_round_robin(experiment),
RoutingStrategy::SegmentBased(segments) => {
self.route_segment_based(experiment, user_id, segments)
},
RoutingStrategy::WeightedRandom(weights) => {
self.route_weighted_random(experiment, weights)
},
RoutingStrategy::Sticky => self.route_sticky(experiment, user_id),
}
}
#[allow(dead_code)]
fn should_include_in_experiment(&self, experiment: &Experiment, user_id: &str) -> bool {
let hash = self.hash_user_id(user_id, &experiment.id().to_string());
let threshold = (experiment.config().traffic_percentage / 100.0 * u64::MAX as f64) as u64;
hash < threshold
}
fn route_hash_based(&self, experiment: &Experiment, user_id: &str) -> Result<Variant> {
let hash = self.hash_user_id(user_id, &experiment.id().to_string());
let control_variant = &experiment.config().control_variant;
let treatment_variants = &experiment.config().treatment_variants;
if treatment_variants.is_empty() {
return Ok(control_variant.clone());
}
let traffic_percentage = experiment.config().traffic_percentage;
let threshold = (traffic_percentage / 100.0 * u64::MAX as f64) as u64;
if hash < threshold {
let treatment_index = (hash as usize) % treatment_variants.len();
Ok(treatment_variants[treatment_index].clone())
} else {
Ok(control_variant.clone())
}
}
fn route_round_robin(&self, experiment: &Experiment) -> Result<Variant> {
let variants = experiment.all_variants();
let mut counters = self.round_robin_counters.write();
let counter = counters.entry(experiment.id().to_string()).or_insert(0);
let index = *counter % variants.len();
*counter += 1;
Ok(variants[index].clone())
}
fn route_segment_based(
&self,
experiment: &Experiment,
user_id: &str,
segments: &[UserSegment],
) -> Result<Variant> {
for segment in segments {
if !self.matches_segment(user_id, &segment.condition) {
continue;
}
for variant in experiment.all_variants() {
if variant.name() == segment.variant_name {
return Ok(variant.clone());
}
}
}
self.route_hash_based(experiment, user_id)
}
fn route_weighted_random(
&self,
experiment: &Experiment,
weights: &HashMap<String, f64>,
) -> Result<Variant> {
let variants = experiment.all_variants();
let total_weight: f64 = weights.values().sum();
if total_weight == 0.0 {
return self.route_hash_based(experiment, &uuid::Uuid::new_v4().to_string());
}
let mut rng = thread_rng();
let random_value: f64 = rng.random_range(0.0..total_weight);
let mut cumulative_weight = 0.0;
for variant in variants {
let weight = weights.get(variant.name()).unwrap_or(&1.0);
cumulative_weight += weight;
if random_value < cumulative_weight {
return Ok(variant.clone());
}
}
Ok(experiment.config().control_variant.clone())
}
fn route_sticky(&self, experiment: &Experiment, user_id: &str) -> Result<Variant> {
let cache_key = format!("{}:{}", experiment.id(), user_id);
if let Some(cached_variant) = self.get_cached_variant(experiment, &cache_key) {
return Ok(cached_variant);
}
let variant = self.route_hash_based(experiment, user_id)?;
{
let mut cache = self.sticky_cache.write();
cache.insert(cache_key, variant.name().to_string());
}
Ok(variant)
}
fn get_cached_variant(&self, experiment: &Experiment, cache_key: &str) -> Option<Variant> {
let cache = self.sticky_cache.read();
let variant_name = cache.get(cache_key)?;
for variant in experiment.all_variants() {
if variant.name() == variant_name {
return Some(variant.clone());
}
}
None
}
fn matches_segment(&self, user_id: &str, condition: &SegmentCondition) -> bool {
match condition {
SegmentCondition::UserIdPattern(pattern) => {
user_id.contains(pattern)
},
SegmentCondition::HasAttribute(_, _) => {
false
},
SegmentCondition::GeoRegion(_) => {
false
},
SegmentCondition::Platform(_) => {
false
},
SegmentCondition::Custom(_) => {
false
},
}
}
fn hash_user_id(&self, user_id: &str, experiment_id: &str) -> u64 {
let mut hasher = DefaultHasher::new();
user_id.hash(&mut hasher);
experiment_id.hash(&mut hasher);
hasher.finish()
}
pub fn clear_sticky_cache(&self, experiment_id: &str) {
let mut cache = self.sticky_cache.write();
cache.retain(|k, _| !k.starts_with(&format!("{}:", experiment_id)));
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ab_testing::ExperimentConfig;
fn create_test_experiment() -> Experiment {
let config = ExperimentConfig {
name: "Test".to_string(),
description: "Test".to_string(),
control_variant: Variant::new("control", "v1"),
treatment_variants: vec![Variant::new("treatment", "v2")],
traffic_percentage: 100.0,
min_sample_size: 100,
max_duration_hours: 24,
};
let mut exp = Experiment::new(config).expect("operation failed in test");
exp.start().expect("operation failed in test");
exp
}
#[test]
fn test_hash_based_routing() {
let splitter = TrafficSplitter::new();
let experiment = create_test_experiment();
let user_id = "test-user-123";
let variant1 = splitter.route(&experiment, user_id).expect("operation failed in test");
let variant2 = splitter.route(&experiment, user_id).expect("operation failed in test");
assert_eq!(variant1, variant2);
}
#[test]
fn test_round_robin_routing() {
let splitter = TrafficSplitter::with_strategy(RoutingStrategy::RoundRobin);
let experiment = create_test_experiment();
let mut control_count = 0;
let mut treatment_count = 0;
for _ in 0..10 {
let variant =
splitter.route(&experiment, "any-user").expect("operation failed in test");
match variant.name() {
"control" => control_count += 1,
"treatment" => treatment_count += 1,
name => panic!("Unexpected variant name: {}", name),
}
}
assert_eq!(control_count, 5);
assert_eq!(treatment_count, 5);
}
#[test]
fn test_sticky_routing() {
let splitter = TrafficSplitter::with_strategy(RoutingStrategy::Sticky);
let experiment = create_test_experiment();
let user_id = "sticky-user";
let first_variant = splitter.route(&experiment, user_id).expect("operation failed in test");
for _ in 0..10 {
let variant = splitter.route(&experiment, user_id).expect("operation failed in test");
assert_eq!(variant, first_variant);
}
}
#[test]
fn test_traffic_percentage() {
let config = ExperimentConfig {
name: "Test".to_string(),
description: "Test".to_string(),
control_variant: Variant::new("control", "v1"),
treatment_variants: vec![Variant::new("treatment", "v2")],
traffic_percentage: 10.0, min_sample_size: 100,
max_duration_hours: 24,
};
let mut exp = Experiment::new(config).expect("operation failed in test");
exp.start().expect("operation failed in test");
let splitter = TrafficSplitter::new();
let mut included_count = 0;
for i in 0..1000 {
let user_id = format!("user-{}", i);
let variant = splitter.route(&exp, &user_id).expect("operation failed in test");
if variant.name() != "control" || splitter.should_include_in_experiment(&exp, &user_id)
{
included_count += 1;
}
}
let inclusion_rate = included_count as f64 / 1000.0;
assert!((inclusion_rate - 0.1).abs() < 0.05);
}
}