mockforge_core/ab_testing/
manager.rs1use crate::ab_testing::types::{ABTestConfig, MockVariant, VariantAnalytics};
6use crate::error::{Error, Result};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tracing::{debug, info, warn};
11
12#[derive(Debug, Clone)]
14pub struct VariantManager {
15 tests: Arc<RwLock<HashMap<String, ABTestConfig>>>,
17 analytics: Arc<RwLock<HashMap<String, HashMap<String, VariantAnalytics>>>>,
19 round_robin_counters: Arc<RwLock<HashMap<String, usize>>>,
21}
22
23impl VariantManager {
24 pub fn new() -> Self {
26 Self {
27 tests: Arc::new(RwLock::new(HashMap::new())),
28 analytics: Arc::new(RwLock::new(HashMap::new())),
29 round_robin_counters: Arc::new(RwLock::new(HashMap::new())),
30 }
31 }
32
33 pub async fn register_test(&self, config: ABTestConfig) -> Result<()> {
35 config.validate_allocations().map_err(|e| Error::validation(e))?;
37
38 let key = Self::endpoint_key(&config.method, &config.endpoint_path);
39 let mut tests = self.tests.write().await;
40 tests.insert(key.clone(), config.clone());
41
42 let mut analytics = self.analytics.write().await;
44 let variant_analytics = analytics.entry(key).or_insert_with(HashMap::new);
45 for variant in &config.variants {
46 variant_analytics.insert(
47 variant.variant_id.clone(),
48 VariantAnalytics::new(variant.variant_id.clone()),
49 );
50 }
51
52 info!(
53 "Registered A/B test '{}' for {} {} with {} variants",
54 config.test_name,
55 config.method,
56 config.endpoint_path,
57 config.variants.len()
58 );
59
60 Ok(())
61 }
62
63 pub async fn get_test(&self, method: &str, path: &str) -> Option<ABTestConfig> {
65 let key = Self::endpoint_key(method, path);
66 let tests = self.tests.read().await;
67 tests.get(&key).cloned()
68 }
69
70 pub async fn list_tests(&self) -> Vec<ABTestConfig> {
72 let tests = self.tests.read().await;
73 tests.values().cloned().collect()
74 }
75
76 pub async fn remove_test(&self, method: &str, path: &str) -> Result<()> {
78 let key = Self::endpoint_key(method, path);
79 let mut tests = self.tests.write().await;
80 tests.remove(&key);
81
82 info!("Removed A/B test for {} {}", method, path);
86 Ok(())
87 }
88
89 pub async fn get_variant(
91 &self,
92 method: &str,
93 path: &str,
94 variant_id: &str,
95 ) -> Option<MockVariant> {
96 if let Some(config) = self.get_test(method, path).await {
97 config.variants.iter().find(|v| v.variant_id == variant_id).cloned()
98 } else {
99 None
100 }
101 }
102
103 pub async fn record_request(
105 &self,
106 method: &str,
107 path: &str,
108 variant_id: &str,
109 status_code: u16,
110 response_time_ms: f64,
111 ) {
112 let key = Self::endpoint_key(method, path);
113 let mut analytics = self.analytics.write().await;
114 if let Some(variant_analytics) = analytics.get_mut(&key) {
115 if let Some(analytics_data) = variant_analytics.get_mut(variant_id) {
116 analytics_data.record_request(status_code, response_time_ms);
117 } else {
118 let mut new_analytics = VariantAnalytics::new(variant_id.to_string());
120 new_analytics.record_request(status_code, response_time_ms);
121 variant_analytics.insert(variant_id.to_string(), new_analytics);
122 }
123 }
124 }
125
126 pub async fn get_variant_analytics(
128 &self,
129 method: &str,
130 path: &str,
131 variant_id: &str,
132 ) -> Option<VariantAnalytics> {
133 let key = Self::endpoint_key(method, path);
134 let analytics = self.analytics.read().await;
135 analytics.get(&key)?.get(variant_id).cloned()
136 }
137
138 pub async fn get_endpoint_analytics(
140 &self,
141 method: &str,
142 path: &str,
143 ) -> HashMap<String, VariantAnalytics> {
144 let key = Self::endpoint_key(method, path);
145 let analytics = self.analytics.read().await;
146 analytics.get(&key).cloned().unwrap_or_default()
147 }
148
149 pub async fn get_round_robin_index(&self, method: &str, path: &str) -> usize {
151 let key = Self::endpoint_key(method, path);
152 let mut counters = self.round_robin_counters.write().await;
153 let counter = counters.entry(key).or_insert(0);
154 *counter
155 }
156
157 pub async fn increment_round_robin(&self, method: &str, path: &str, max: usize) -> usize {
159 let key = Self::endpoint_key(method, path);
160 let mut counters = self.round_robin_counters.write().await;
161 let counter = counters.entry(key).or_insert(0);
162 let current = *counter;
163 *counter = (*counter + 1) % max;
164 current
165 }
166
/// Maps `attribute` to a variant index in `0..num_variants`.
///
/// The same `attribute` and `num_variants` always give the same index within
/// one build of the program, which is what keeps a user pinned to a variant.
/// The hash comes from the standard library's `DefaultHasher`, whose
/// algorithm may change between Rust releases, so indices should not be
/// persisted across builds.
///
/// When `num_variants` is `0` there is no valid index; the function returns
/// `0` instead of panicking on a modulo by zero.
pub fn consistent_hash(attribute: &str, num_variants: usize) -> usize {
    use std::hash::{Hash, Hasher};

    if num_variants == 0 {
        return 0;
    }

    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    attribute.hash(&mut hasher);

    // Reduce in u64 before narrowing: the remainder is below `num_variants`,
    // so the final cast is lossless and the upper hash bits are not thrown
    // away on 32-bit targets. On 64-bit targets the result is unchanged.
    (hasher.finish() % (num_variants as u64)) as usize
}
177
/// Builds the map key of an endpoint: the upper-cased method, one space,
/// then the path exactly as given.
///
/// Upper-casing the method makes lookups insensitive to the method's case.
fn endpoint_key(method: &str, path: &str) -> String {
    let verb = method.to_uppercase();
    let mut key = String::with_capacity(verb.len() + 1 + path.len());
    key.push_str(&verb);
    key.push(' ');
    key.push_str(path);
    key
}
182}
183
184impl Default for VariantManager {
185 fn default() -> Self {
186 Self::new()
187 }
188}