mockforge_core/ab_testing/
middleware.rs

1//! A/B testing middleware for variant selection
2//!
3//! This module provides middleware functionality for selecting and applying
4//! mock variants based on A/B test configuration.
5
6use crate::ab_testing::manager::VariantManager;
7use crate::ab_testing::types::{ABTestConfig, MockVariant, VariantSelectionStrategy};
8use crate::error::Result;
9use axum::body::Body;
10use axum::http::{HeaderMap, StatusCode};
11use axum::response::Response;
12use rand::Rng;
13use std::sync::Arc;
14use tracing::warn;
15
16/// State for A/B testing middleware
17#[derive(Clone)]
18pub struct ABTestingMiddlewareState {
19    /// Variant manager
20    pub variant_manager: Arc<VariantManager>,
21}
22
23impl ABTestingMiddlewareState {
24    /// Create new middleware state
25    pub fn new(variant_manager: Arc<VariantManager>) -> Self {
26        Self { variant_manager }
27    }
28}
29
30/// Select a variant for a request based on A/B test configuration
31///
32/// This function extracts all needed data from the request before selection
33pub async fn select_variant(
34    config: &ABTestConfig,
35    request_headers: &HeaderMap,
36    request_uri: &str,
37    variant_manager: &VariantManager,
38) -> Result<Option<MockVariant>> {
39    // Check if test is enabled and within time window
40    if !config.enabled {
41        return Ok(None);
42    }
43
44    let now = chrono::Utc::now();
45    if let Some(start_time) = config.start_time {
46        if now < start_time {
47            return Ok(None);
48        }
49    }
50    if let Some(end_time) = config.end_time {
51        if now > end_time {
52            return Ok(None);
53        }
54    }
55
56    // Select variant based on strategy
57    let variant_id = match config.strategy {
58        VariantSelectionStrategy::Random => select_variant_random(&config.allocations)?,
59        VariantSelectionStrategy::ConsistentHash => {
60            select_variant_consistent_hash(config, request_headers, request_uri)?
61        }
62        VariantSelectionStrategy::RoundRobin => {
63            select_variant_round_robin(config, variant_manager).await?
64        }
65        VariantSelectionStrategy::StickySession => {
66            select_variant_sticky_session(config, request_headers)?
67        }
68    };
69
70    // Find the selected variant
71    let variant = config.variants.iter().find(|v| v.variant_id == variant_id).cloned();
72
73    if variant.is_none() {
74        warn!("Selected variant '{}' not found in test '{}'", variant_id, config.test_name);
75    }
76
77    Ok(variant)
78}
79
80/// Select variant using random allocation
81fn select_variant_random(
82    allocations: &[crate::ab_testing::types::VariantAllocation],
83) -> Result<String> {
84    let mut rng = rand::thread_rng();
85    let random_value = rng.gen_range(0.0..100.0);
86    let mut cumulative = 0.0;
87
88    for allocation in allocations {
89        cumulative += allocation.percentage;
90        if random_value <= cumulative {
91            return Ok(allocation.variant_id.clone());
92        }
93    }
94
95    // Fallback to last variant if rounding errors
96    allocations.last().map(|a| Ok(a.variant_id.clone())).unwrap_or_else(|| {
97        Err(crate::error::Error::validation("No allocations defined".to_string()))
98    })
99}
100
101/// Select variant using consistent hashing
102fn select_variant_consistent_hash(
103    config: &ABTestConfig,
104    request_headers: &HeaderMap,
105    request_uri: &str,
106) -> Result<String> {
107    // Try to extract a consistent attribute (e.g., user ID, IP address)
108    let attribute = extract_hash_attribute(request_headers, request_uri);
109
110    // Hash the attribute to get a value between 0-100
111    let hash_value = VariantManager::consistent_hash(&attribute, 100) as f64;
112
113    // Find which allocation bucket this hash falls into
114    let mut cumulative = 0.0;
115    for allocation in &config.allocations {
116        cumulative += allocation.percentage;
117        if hash_value <= cumulative {
118            return Ok(allocation.variant_id.clone());
119        }
120    }
121
122    // Fallback
123    config.allocations.last().map(|a| Ok(a.variant_id.clone())).unwrap_or_else(|| {
124        Err(crate::error::Error::validation("No allocations defined".to_string()))
125    })
126}
127
128/// Select variant using round-robin
129async fn select_variant_round_robin(
130    config: &ABTestConfig,
131    variant_manager: &VariantManager,
132) -> Result<String> {
133    let index = variant_manager
134        .increment_round_robin(&config.method, &config.endpoint_path, config.allocations.len())
135        .await;
136
137    config
138        .allocations
139        .get(index)
140        .map(|a| Ok(a.variant_id.clone()))
141        .unwrap_or_else(|| {
142            Err(crate::error::Error::validation("Invalid allocation index".to_string()))
143        })
144}
145
146/// Select variant using sticky session
147fn select_variant_sticky_session(
148    config: &ABTestConfig,
149    request_headers: &HeaderMap,
150) -> Result<String> {
151    // Try to get session ID from cookie or header
152    let session_id = extract_session_id(request_headers);
153
154    // Use consistent hashing on session ID
155    let hash_value = VariantManager::consistent_hash(&session_id, 100) as f64;
156
157    let mut cumulative = 0.0;
158    for allocation in &config.allocations {
159        cumulative += allocation.percentage;
160        if hash_value <= cumulative {
161            return Ok(allocation.variant_id.clone());
162        }
163    }
164
165    // Fallback
166    config.allocations.last().map(|a| Ok(a.variant_id.clone())).unwrap_or_else(|| {
167        Err(crate::error::Error::validation("No allocations defined".to_string()))
168    })
169}
170
171/// Extract a consistent attribute for hashing from request
172fn extract_hash_attribute(request_headers: &HeaderMap, request_uri: &str) -> String {
173    // Try to get user ID from headers
174    if let Some(user_id) = request_headers.get("X-User-ID") {
175        if let Ok(user_id_str) = user_id.to_str() {
176            return format!("user:{}", user_id_str);
177        }
178    }
179
180    // Try to get user ID from query parameters
181    if let Some(query_start) = request_uri.find('?') {
182        let query = &request_uri[query_start + 1..];
183        for param in query.split('&') {
184            if let Some((key, value)) = param.split_once('=') {
185                if key == "user_id" || key == "userId" {
186                    return format!("user:{}", value);
187                }
188            }
189        }
190    }
191
192    // Fallback to IP address
193    if let Some(ip) = request_headers.get("X-Forwarded-For") {
194        if let Ok(ip_str) = ip.to_str() {
195            return format!("ip:{}", ip_str.split(',').next().unwrap_or("unknown"));
196        }
197    }
198
199    // Final fallback: use a random value (not ideal for consistent hashing)
200    format!("random:{}", uuid::Uuid::new_v4())
201}
202
203/// Extract session ID from request
204fn extract_session_id(request_headers: &HeaderMap) -> String {
205    // Try to get session ID from cookie
206    if let Some(cookie_header) = request_headers.get("Cookie") {
207        if let Ok(cookie_str) = cookie_header.to_str() {
208            for cookie in cookie_str.split(';') {
209                let cookie = cookie.trim();
210                if let Some((key, value)) = cookie.split_once('=') {
211                    if key == "session_id" || key == "sessionId" || key == "JSESSIONID" {
212                        return value.to_string();
213                    }
214                }
215            }
216        }
217    }
218
219    // Try to get session ID from header
220    if let Some(session_id) = request_headers.get("X-Session-ID") {
221        if let Ok(session_id_str) = session_id.to_str() {
222            return session_id_str.to_string();
223        }
224    }
225
226    // Fallback: generate a session ID based on IP
227    // We need to pass a dummy URI since extract_hash_attribute needs it
228    extract_hash_attribute(request_headers, "")
229}
230
231/// Apply variant to response
232pub fn apply_variant_to_response(
233    variant: &MockVariant,
234    _response: Response<Body>,
235) -> Response<Body> {
236    // Create new response with variant body
237    let mut response_builder = Response::builder().status(
238        StatusCode::from_u16(variant.status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
239    );
240
241    // Add variant headers
242    for (key, value) in &variant.headers {
243        if let (Ok(key), Ok(value)) = (
244            axum::http::HeaderName::try_from(key.as_str()),
245            axum::http::HeaderValue::try_from(value.as_str()),
246        ) {
247            response_builder = response_builder.header(key, value);
248        }
249    }
250
251    // Add variant ID header for tracking
252    if let Ok(header_name) = axum::http::HeaderName::try_from("X-MockForge-Variant") {
253        if let Ok(header_value) = axum::http::HeaderValue::try_from(variant.variant_id.as_str()) {
254            response_builder = response_builder.header(header_name, header_value);
255        }
256    }
257
258    // Convert variant body to response body
259    let body = match serde_json::to_string(&variant.body) {
260        Ok(json_str) => Body::from(json_str),
261        Err(_) => Body::from("{}"), // Fallback to empty JSON
262    };
263
264    response_builder.body(body).unwrap_or_else(|_| {
265        // Fallback response if building fails
266        Response::builder()
267            .status(StatusCode::INTERNAL_SERVER_ERROR)
268            .body(Body::from("{}"))
269            .unwrap()
270    })
271}