mockforge_core/ab_testing/
middleware.rs1use crate::ab_testing::manager::VariantManager;
7use crate::ab_testing::types::{ABTestConfig, MockVariant, VariantSelectionStrategy};
8use crate::error::Result;
9use axum::body::Body;
10use axum::http::{HeaderMap, StatusCode};
11use axum::response::Response;
12use rand::Rng;
13use std::sync::Arc;
14use tracing::warn;
15
16#[derive(Clone)]
18pub struct ABTestingMiddlewareState {
19 pub variant_manager: Arc<VariantManager>,
21}
22
23impl ABTestingMiddlewareState {
24 pub fn new(variant_manager: Arc<VariantManager>) -> Self {
26 Self { variant_manager }
27 }
28}
29
30pub async fn select_variant(
34 config: &ABTestConfig,
35 request_headers: &HeaderMap,
36 request_uri: &str,
37 variant_manager: &VariantManager,
38) -> Result<Option<MockVariant>> {
39 if !config.enabled {
41 return Ok(None);
42 }
43
44 let now = chrono::Utc::now();
45 if let Some(start_time) = config.start_time {
46 if now < start_time {
47 return Ok(None);
48 }
49 }
50 if let Some(end_time) = config.end_time {
51 if now > end_time {
52 return Ok(None);
53 }
54 }
55
56 let variant_id = match config.strategy {
58 VariantSelectionStrategy::Random => select_variant_random(&config.allocations)?,
59 VariantSelectionStrategy::ConsistentHash => {
60 select_variant_consistent_hash(config, request_headers, request_uri)?
61 }
62 VariantSelectionStrategy::RoundRobin => {
63 select_variant_round_robin(config, variant_manager).await?
64 }
65 VariantSelectionStrategy::StickySession => {
66 select_variant_sticky_session(config, request_headers)?
67 }
68 };
69
70 let variant = config.variants.iter().find(|v| v.variant_id == variant_id).cloned();
72
73 if variant.is_none() {
74 warn!("Selected variant '{}' not found in test '{}'", variant_id, config.test_name);
75 }
76
77 Ok(variant)
78}
79
80fn select_variant_random(
82 allocations: &[crate::ab_testing::types::VariantAllocation],
83) -> Result<String> {
84 let mut rng = rand::thread_rng();
85 let random_value = rng.gen_range(0.0..100.0);
86 let mut cumulative = 0.0;
87
88 for allocation in allocations {
89 cumulative += allocation.percentage;
90 if random_value <= cumulative {
91 return Ok(allocation.variant_id.clone());
92 }
93 }
94
95 allocations.last().map(|a| Ok(a.variant_id.clone())).unwrap_or_else(|| {
97 Err(crate::error::Error::validation("No allocations defined".to_string()))
98 })
99}
100
101fn select_variant_consistent_hash(
103 config: &ABTestConfig,
104 request_headers: &HeaderMap,
105 request_uri: &str,
106) -> Result<String> {
107 let attribute = extract_hash_attribute(request_headers, request_uri);
109
110 let hash_value = VariantManager::consistent_hash(&attribute, 100) as f64;
112
113 let mut cumulative = 0.0;
115 for allocation in &config.allocations {
116 cumulative += allocation.percentage;
117 if hash_value <= cumulative {
118 return Ok(allocation.variant_id.clone());
119 }
120 }
121
122 config.allocations.last().map(|a| Ok(a.variant_id.clone())).unwrap_or_else(|| {
124 Err(crate::error::Error::validation("No allocations defined".to_string()))
125 })
126}
127
128async fn select_variant_round_robin(
130 config: &ABTestConfig,
131 variant_manager: &VariantManager,
132) -> Result<String> {
133 let index = variant_manager
134 .increment_round_robin(&config.method, &config.endpoint_path, config.allocations.len())
135 .await;
136
137 config
138 .allocations
139 .get(index)
140 .map(|a| Ok(a.variant_id.clone()))
141 .unwrap_or_else(|| {
142 Err(crate::error::Error::validation("Invalid allocation index".to_string()))
143 })
144}
145
146fn select_variant_sticky_session(
148 config: &ABTestConfig,
149 request_headers: &HeaderMap,
150) -> Result<String> {
151 let session_id = extract_session_id(request_headers);
153
154 let hash_value = VariantManager::consistent_hash(&session_id, 100) as f64;
156
157 let mut cumulative = 0.0;
158 for allocation in &config.allocations {
159 cumulative += allocation.percentage;
160 if hash_value <= cumulative {
161 return Ok(allocation.variant_id.clone());
162 }
163 }
164
165 config.allocations.last().map(|a| Ok(a.variant_id.clone())).unwrap_or_else(|| {
167 Err(crate::error::Error::validation("No allocations defined".to_string()))
168 })
169}
170
171fn extract_hash_attribute(request_headers: &HeaderMap, request_uri: &str) -> String {
173 if let Some(user_id) = request_headers.get("X-User-ID") {
175 if let Ok(user_id_str) = user_id.to_str() {
176 return format!("user:{}", user_id_str);
177 }
178 }
179
180 if let Some(query_start) = request_uri.find('?') {
182 let query = &request_uri[query_start + 1..];
183 for param in query.split('&') {
184 if let Some((key, value)) = param.split_once('=') {
185 if key == "user_id" || key == "userId" {
186 return format!("user:{}", value);
187 }
188 }
189 }
190 }
191
192 if let Some(ip) = request_headers.get("X-Forwarded-For") {
194 if let Ok(ip_str) = ip.to_str() {
195 return format!("ip:{}", ip_str.split(',').next().unwrap_or("unknown"));
196 }
197 }
198
199 format!("random:{}", uuid::Uuid::new_v4())
201}
202
203fn extract_session_id(request_headers: &HeaderMap) -> String {
205 if let Some(cookie_header) = request_headers.get("Cookie") {
207 if let Ok(cookie_str) = cookie_header.to_str() {
208 for cookie in cookie_str.split(';') {
209 let cookie = cookie.trim();
210 if let Some((key, value)) = cookie.split_once('=') {
211 if key == "session_id" || key == "sessionId" || key == "JSESSIONID" {
212 return value.to_string();
213 }
214 }
215 }
216 }
217 }
218
219 if let Some(session_id) = request_headers.get("X-Session-ID") {
221 if let Ok(session_id_str) = session_id.to_str() {
222 return session_id_str.to_string();
223 }
224 }
225
226 extract_hash_attribute(request_headers, "")
229}
230
231pub fn apply_variant_to_response(
233 variant: &MockVariant,
234 _response: Response<Body>,
235) -> Response<Body> {
236 let mut response_builder = Response::builder().status(
238 StatusCode::from_u16(variant.status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
239 );
240
241 for (key, value) in &variant.headers {
243 if let (Ok(key), Ok(value)) = (
244 axum::http::HeaderName::try_from(key.as_str()),
245 axum::http::HeaderValue::try_from(value.as_str()),
246 ) {
247 response_builder = response_builder.header(key, value);
248 }
249 }
250
251 if let Ok(header_name) = axum::http::HeaderName::try_from("X-MockForge-Variant") {
253 if let Ok(header_value) = axum::http::HeaderValue::try_from(variant.variant_id.as_str()) {
254 response_builder = response_builder.header(header_name, header_value);
255 }
256 }
257
258 let body = match serde_json::to_string(&variant.body) {
260 Ok(json_str) => Body::from(json_str),
261 Err(_) => Body::from("{}"), };
263
264 response_builder.body(body).unwrap_or_else(|_| {
265 Response::builder()
267 .status(StatusCode::INTERNAL_SERVER_ERROR)
268 .body(Body::from("{}"))
269 .unwrap()
270 })
271}