mockforge_core/ab_testing/
middleware.rs1use crate::ab_testing::manager::VariantManager;
7use crate::ab_testing::types::{ABTestConfig, MockVariant, VariantSelectionStrategy};
8use crate::error::Result;
9use axum::body::Body;
10use axum::http::{HeaderMap, StatusCode};
11use axum::response::Response;
12use rand::Rng;
13use std::collections::HashMap;
14use std::sync::Arc;
15use tracing::{debug, trace, warn};
16
17#[derive(Clone)]
19pub struct ABTestingMiddlewareState {
20 pub variant_manager: Arc<VariantManager>,
22}
23
24impl ABTestingMiddlewareState {
25 pub fn new(variant_manager: Arc<VariantManager>) -> Self {
27 Self { variant_manager }
28 }
29}
30
31pub async fn select_variant(
35 config: &ABTestConfig,
36 request_headers: &HeaderMap,
37 request_uri: &str,
38 variant_manager: &VariantManager,
39) -> Result<Option<MockVariant>> {
40 if !config.enabled {
42 return Ok(None);
43 }
44
45 let now = chrono::Utc::now();
46 if let Some(start_time) = config.start_time {
47 if now < start_time {
48 return Ok(None);
49 }
50 }
51 if let Some(end_time) = config.end_time {
52 if now > end_time {
53 return Ok(None);
54 }
55 }
56
57 let variant_id = match config.strategy {
59 VariantSelectionStrategy::Random => select_variant_random(&config.allocations)?,
60 VariantSelectionStrategy::ConsistentHash => {
61 select_variant_consistent_hash(config, request_headers, request_uri)?
62 }
63 VariantSelectionStrategy::RoundRobin => {
64 select_variant_round_robin(config, variant_manager).await?
65 }
66 VariantSelectionStrategy::StickySession => {
67 select_variant_sticky_session(config, request_headers)?
68 }
69 };
70
71 let variant = config.variants.iter().find(|v| v.variant_id == variant_id).cloned();
73
74 if variant.is_none() {
75 warn!("Selected variant '{}' not found in test '{}'", variant_id, config.test_name);
76 }
77
78 Ok(variant)
79}
80
81fn select_variant_random(
83 allocations: &[crate::ab_testing::types::VariantAllocation],
84) -> Result<String> {
85 let mut rng = rand::thread_rng();
86 let random_value = rng.gen_range(0.0..100.0);
87 let mut cumulative = 0.0;
88
89 for allocation in allocations {
90 cumulative += allocation.percentage;
91 if random_value <= cumulative {
92 return Ok(allocation.variant_id.clone());
93 }
94 }
95
96 allocations.last().map(|a| Ok(a.variant_id.clone())).unwrap_or_else(|| {
98 Err(crate::error::Error::validation("No allocations defined".to_string()))
99 })
100}
101
102fn select_variant_consistent_hash(
104 config: &ABTestConfig,
105 request_headers: &HeaderMap,
106 request_uri: &str,
107) -> Result<String> {
108 let attribute = extract_hash_attribute(request_headers, request_uri);
110
111 let hash_value = VariantManager::consistent_hash(&attribute, 100) as f64;
113
114 let mut cumulative = 0.0;
116 for allocation in &config.allocations {
117 cumulative += allocation.percentage;
118 if hash_value <= cumulative {
119 return Ok(allocation.variant_id.clone());
120 }
121 }
122
123 config.allocations.last().map(|a| Ok(a.variant_id.clone())).unwrap_or_else(|| {
125 Err(crate::error::Error::validation("No allocations defined".to_string()))
126 })
127}
128
129async fn select_variant_round_robin(
131 config: &ABTestConfig,
132 variant_manager: &VariantManager,
133) -> Result<String> {
134 let index = variant_manager
135 .increment_round_robin(&config.method, &config.endpoint_path, config.allocations.len())
136 .await;
137
138 config
139 .allocations
140 .get(index)
141 .map(|a| Ok(a.variant_id.clone()))
142 .unwrap_or_else(|| {
143 Err(crate::error::Error::validation("Invalid allocation index".to_string()))
144 })
145}
146
147fn select_variant_sticky_session(
149 config: &ABTestConfig,
150 request_headers: &HeaderMap,
151) -> Result<String> {
152 let session_id = extract_session_id(request_headers);
154
155 let hash_value = VariantManager::consistent_hash(&session_id, 100) as f64;
157
158 let mut cumulative = 0.0;
159 for allocation in &config.allocations {
160 cumulative += allocation.percentage;
161 if hash_value <= cumulative {
162 return Ok(allocation.variant_id.clone());
163 }
164 }
165
166 config.allocations.last().map(|a| Ok(a.variant_id.clone())).unwrap_or_else(|| {
168 Err(crate::error::Error::validation("No allocations defined".to_string()))
169 })
170}
171
172fn extract_hash_attribute(request_headers: &HeaderMap, request_uri: &str) -> String {
174 if let Some(user_id) = request_headers.get("X-User-ID") {
176 if let Ok(user_id_str) = user_id.to_str() {
177 return format!("user:{}", user_id_str);
178 }
179 }
180
181 if let Some(query_start) = request_uri.find('?') {
183 let query = &request_uri[query_start + 1..];
184 for param in query.split('&') {
185 if let Some((key, value)) = param.split_once('=') {
186 if key == "user_id" || key == "userId" {
187 return format!("user:{}", value);
188 }
189 }
190 }
191 }
192
193 if let Some(ip) = request_headers.get("X-Forwarded-For") {
195 if let Ok(ip_str) = ip.to_str() {
196 return format!("ip:{}", ip_str.split(',').next().unwrap_or("unknown"));
197 }
198 }
199
200 format!("random:{}", uuid::Uuid::new_v4())
202}
203
204fn extract_session_id(request_headers: &HeaderMap) -> String {
206 if let Some(cookie_header) = request_headers.get("Cookie") {
208 if let Ok(cookie_str) = cookie_header.to_str() {
209 for cookie in cookie_str.split(';') {
210 let cookie = cookie.trim();
211 if let Some((key, value)) = cookie.split_once('=') {
212 if key == "session_id" || key == "sessionId" || key == "JSESSIONID" {
213 return value.to_string();
214 }
215 }
216 }
217 }
218 }
219
220 if let Some(session_id) = request_headers.get("X-Session-ID") {
222 if let Ok(session_id_str) = session_id.to_str() {
223 return session_id_str.to_string();
224 }
225 }
226
227 extract_hash_attribute(request_headers, "")
230}
231
232pub fn apply_variant_to_response(
234 variant: &MockVariant,
235 _response: Response<Body>,
236) -> Response<Body> {
237 let mut response_builder = Response::builder().status(
239 StatusCode::from_u16(variant.status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
240 );
241
242 for (key, value) in &variant.headers {
244 if let (Ok(key), Ok(value)) = (
245 axum::http::HeaderName::try_from(key.as_str()),
246 axum::http::HeaderValue::try_from(value.as_str()),
247 ) {
248 response_builder = response_builder.header(key, value);
249 }
250 }
251
252 if let Ok(header_name) = axum::http::HeaderName::try_from("X-MockForge-Variant") {
254 if let Ok(header_value) = axum::http::HeaderValue::try_from(variant.variant_id.as_str()) {
255 response_builder = response_builder.header(header_name, header_value);
256 }
257 }
258
259 let body = match serde_json::to_string(&variant.body) {
261 Ok(json_str) => Body::from(json_str),
262 Err(_) => Body::from("{}"), };
264
265 response_builder.body(body).unwrap_or_else(|_| {
266 Response::builder()
268 .status(StatusCode::INTERNAL_SERVER_ERROR)
269 .body(Body::from("{}"))
270 .unwrap()
271 })
272}