mockforge_core/ab_testing/
middleware.rs

1//! A/B testing middleware for variant selection
2//!
3//! This module provides middleware functionality for selecting and applying
4//! mock variants based on A/B test configuration.
5
6use crate::ab_testing::manager::VariantManager;
7use crate::ab_testing::types::{ABTestConfig, MockVariant, VariantSelectionStrategy};
8use crate::error::Result;
9use axum::body::Body;
10use axum::http::{HeaderMap, StatusCode};
11use axum::response::Response;
12use rand::Rng;
13use std::collections::HashMap;
14use std::sync::Arc;
15use tracing::{debug, trace, warn};
16
17/// State for A/B testing middleware
18#[derive(Clone)]
19pub struct ABTestingMiddlewareState {
20    /// Variant manager
21    pub variant_manager: Arc<VariantManager>,
22}
23
24impl ABTestingMiddlewareState {
25    /// Create new middleware state
26    pub fn new(variant_manager: Arc<VariantManager>) -> Self {
27        Self { variant_manager }
28    }
29}
30
31/// Select a variant for a request based on A/B test configuration
32///
33/// This function extracts all needed data from the request before selection
34pub async fn select_variant(
35    config: &ABTestConfig,
36    request_headers: &HeaderMap,
37    request_uri: &str,
38    variant_manager: &VariantManager,
39) -> Result<Option<MockVariant>> {
40    // Check if test is enabled and within time window
41    if !config.enabled {
42        return Ok(None);
43    }
44
45    let now = chrono::Utc::now();
46    if let Some(start_time) = config.start_time {
47        if now < start_time {
48            return Ok(None);
49        }
50    }
51    if let Some(end_time) = config.end_time {
52        if now > end_time {
53            return Ok(None);
54        }
55    }
56
57    // Select variant based on strategy
58    let variant_id = match config.strategy {
59        VariantSelectionStrategy::Random => select_variant_random(&config.allocations)?,
60        VariantSelectionStrategy::ConsistentHash => {
61            select_variant_consistent_hash(config, request_headers, request_uri)?
62        }
63        VariantSelectionStrategy::RoundRobin => {
64            select_variant_round_robin(config, variant_manager).await?
65        }
66        VariantSelectionStrategy::StickySession => {
67            select_variant_sticky_session(config, request_headers)?
68        }
69    };
70
71    // Find the selected variant
72    let variant = config.variants.iter().find(|v| v.variant_id == variant_id).cloned();
73
74    if variant.is_none() {
75        warn!("Selected variant '{}' not found in test '{}'", variant_id, config.test_name);
76    }
77
78    Ok(variant)
79}
80
81/// Select variant using random allocation
82fn select_variant_random(
83    allocations: &[crate::ab_testing::types::VariantAllocation],
84) -> Result<String> {
85    let mut rng = rand::thread_rng();
86    let random_value = rng.gen_range(0.0..100.0);
87    let mut cumulative = 0.0;
88
89    for allocation in allocations {
90        cumulative += allocation.percentage;
91        if random_value <= cumulative {
92            return Ok(allocation.variant_id.clone());
93        }
94    }
95
96    // Fallback to last variant if rounding errors
97    allocations.last().map(|a| Ok(a.variant_id.clone())).unwrap_or_else(|| {
98        Err(crate::error::Error::validation("No allocations defined".to_string()))
99    })
100}
101
102/// Select variant using consistent hashing
103fn select_variant_consistent_hash(
104    config: &ABTestConfig,
105    request_headers: &HeaderMap,
106    request_uri: &str,
107) -> Result<String> {
108    // Try to extract a consistent attribute (e.g., user ID, IP address)
109    let attribute = extract_hash_attribute(request_headers, request_uri);
110
111    // Hash the attribute to get a value between 0-100
112    let hash_value = VariantManager::consistent_hash(&attribute, 100) as f64;
113
114    // Find which allocation bucket this hash falls into
115    let mut cumulative = 0.0;
116    for allocation in &config.allocations {
117        cumulative += allocation.percentage;
118        if hash_value <= cumulative {
119            return Ok(allocation.variant_id.clone());
120        }
121    }
122
123    // Fallback
124    config.allocations.last().map(|a| Ok(a.variant_id.clone())).unwrap_or_else(|| {
125        Err(crate::error::Error::validation("No allocations defined".to_string()))
126    })
127}
128
129/// Select variant using round-robin
130async fn select_variant_round_robin(
131    config: &ABTestConfig,
132    variant_manager: &VariantManager,
133) -> Result<String> {
134    let index = variant_manager
135        .increment_round_robin(&config.method, &config.endpoint_path, config.allocations.len())
136        .await;
137
138    config
139        .allocations
140        .get(index)
141        .map(|a| Ok(a.variant_id.clone()))
142        .unwrap_or_else(|| {
143            Err(crate::error::Error::validation("Invalid allocation index".to_string()))
144        })
145}
146
147/// Select variant using sticky session
148fn select_variant_sticky_session(
149    config: &ABTestConfig,
150    request_headers: &HeaderMap,
151) -> Result<String> {
152    // Try to get session ID from cookie or header
153    let session_id = extract_session_id(request_headers);
154
155    // Use consistent hashing on session ID
156    let hash_value = VariantManager::consistent_hash(&session_id, 100) as f64;
157
158    let mut cumulative = 0.0;
159    for allocation in &config.allocations {
160        cumulative += allocation.percentage;
161        if hash_value <= cumulative {
162            return Ok(allocation.variant_id.clone());
163        }
164    }
165
166    // Fallback
167    config.allocations.last().map(|a| Ok(a.variant_id.clone())).unwrap_or_else(|| {
168        Err(crate::error::Error::validation("No allocations defined".to_string()))
169    })
170}
171
172/// Extract a consistent attribute for hashing from request
173fn extract_hash_attribute(request_headers: &HeaderMap, request_uri: &str) -> String {
174    // Try to get user ID from headers
175    if let Some(user_id) = request_headers.get("X-User-ID") {
176        if let Ok(user_id_str) = user_id.to_str() {
177            return format!("user:{}", user_id_str);
178        }
179    }
180
181    // Try to get user ID from query parameters
182    if let Some(query_start) = request_uri.find('?') {
183        let query = &request_uri[query_start + 1..];
184        for param in query.split('&') {
185            if let Some((key, value)) = param.split_once('=') {
186                if key == "user_id" || key == "userId" {
187                    return format!("user:{}", value);
188                }
189            }
190        }
191    }
192
193    // Fallback to IP address
194    if let Some(ip) = request_headers.get("X-Forwarded-For") {
195        if let Ok(ip_str) = ip.to_str() {
196            return format!("ip:{}", ip_str.split(',').next().unwrap_or("unknown"));
197        }
198    }
199
200    // Final fallback: use a random value (not ideal for consistent hashing)
201    format!("random:{}", uuid::Uuid::new_v4())
202}
203
204/// Extract session ID from request
205fn extract_session_id(request_headers: &HeaderMap) -> String {
206    // Try to get session ID from cookie
207    if let Some(cookie_header) = request_headers.get("Cookie") {
208        if let Ok(cookie_str) = cookie_header.to_str() {
209            for cookie in cookie_str.split(';') {
210                let cookie = cookie.trim();
211                if let Some((key, value)) = cookie.split_once('=') {
212                    if key == "session_id" || key == "sessionId" || key == "JSESSIONID" {
213                        return value.to_string();
214                    }
215                }
216            }
217        }
218    }
219
220    // Try to get session ID from header
221    if let Some(session_id) = request_headers.get("X-Session-ID") {
222        if let Ok(session_id_str) = session_id.to_str() {
223            return session_id_str.to_string();
224        }
225    }
226
227    // Fallback: generate a session ID based on IP
228    // We need to pass a dummy URI since extract_hash_attribute needs it
229    extract_hash_attribute(request_headers, "")
230}
231
232/// Apply variant to response
233pub fn apply_variant_to_response(
234    variant: &MockVariant,
235    _response: Response<Body>,
236) -> Response<Body> {
237    // Create new response with variant body
238    let mut response_builder = Response::builder().status(
239        StatusCode::from_u16(variant.status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
240    );
241
242    // Add variant headers
243    for (key, value) in &variant.headers {
244        if let (Ok(key), Ok(value)) = (
245            axum::http::HeaderName::try_from(key.as_str()),
246            axum::http::HeaderValue::try_from(value.as_str()),
247        ) {
248            response_builder = response_builder.header(key, value);
249        }
250    }
251
252    // Add variant ID header for tracking
253    if let Ok(header_name) = axum::http::HeaderName::try_from("X-MockForge-Variant") {
254        if let Ok(header_value) = axum::http::HeaderValue::try_from(variant.variant_id.as_str()) {
255            response_builder = response_builder.header(header_name, header_value);
256        }
257    }
258
259    // Convert variant body to response body
260    let body = match serde_json::to_string(&variant.body) {
261        Ok(json_str) => Body::from(json_str),
262        Err(_) => Body::from("{}"), // Fallback to empty JSON
263    };
264
265    response_builder.body(body).unwrap_or_else(|_| {
266        // Fallback response if building fails
267        Response::builder()
268            .status(StatusCode::INTERNAL_SERVER_ERROR)
269            .body(Body::from("{}"))
270            .unwrap()
271    })
272}