use crate::handlers::ab_testing::ABTestingState;
use axum::{body::Body, extract::Request, middleware::Next, response::Response};
use mockforge_core::ab_testing::{apply_variant_to_response, select_variant};
use std::time::Instant;
use tracing::{debug, trace};
/// Axum middleware that routes requests through a configured A/B test, if any.
///
/// When an [`ABTestingState`] is present in the request extensions and its variant
/// manager has a test registered for the request's method and path, a variant is
/// chosen with [`select_variant`]. The downstream handler then runs, and the variant
/// is applied to its response via [`apply_variant_to_response`]. Any `latency_ms`
/// configured on the variant is injected. Finally the request (status code and
/// elapsed time) is recorded against the variant through the variant manager.
///
/// In every other case the request is forwarded unchanged:
/// - no A/B state is installed;
/// - no test matches the route;
/// - no variant is selected;
/// - variant selection returns an error (logged at `debug` level).
pub async fn ab_testing_middleware(req: Request, next: Next) -> Response<Body> {
    let start_time = Instant::now();

    // No A/B testing state installed for this request: pass straight through.
    let Some(state) = req.extensions().get::<ABTestingState>().cloned() else {
        return next.run(req).await;
    };

    let method = req.method().to_string();
    let path = req.uri().path().to_string();

    let Some(test_config) = state.variant_manager.get_test(&method, &path).await else {
        return next.run(req).await;
    };
    trace!("A/B test found for {} {}", method, path);

    // Only pay for cloning the header map and URI once a test actually applies to
    // this route, instead of on every request that passes through the middleware.
    let uri = req.uri().to_string();
    let headers = req.headers().clone();

    let selection = select_variant(&test_config, &headers, &uri, &state.variant_manager).await;
    let variant = match selection {
        Ok(Some(variant)) => variant,
        Ok(None) => {
            trace!("No variant selected for {} {}", method, path);
            return next.run(req).await;
        }
        Err(e) => {
            // Selection failures are non-fatal: serve the unmodified response.
            debug!("Error selecting variant for {} {}: {}", method, path, e);
            return next.run(req).await;
        }
    };
    debug!("Selected variant '{}' for {} {}", variant.variant_id, method, path);

    let response = next.run(req).await;
    let response = apply_variant_to_response(&variant, response);

    // Inject the variant's artificial latency *before* measuring, so the response
    // time recorded for this variant reflects what the client actually observes.
    if let Some(latency_ms) = variant.latency_ms {
        tokio::time::sleep(tokio::time::Duration::from_millis(latency_ms)).await;
    }

    let response_time_ms = start_time.elapsed().as_millis() as f64;
    let status_code = response.status().as_u16();
    state
        .variant_manager
        .record_request(&method, &path, &variant.variant_id, status_code, response_time_ms)
        .await;

    response
}