use axum::body::Body;
use axum::extract::Request;
use axum::http::{HeaderValue, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use mockforge_core::deceptive_canary::DeceptiveCanaryRouter;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, warn};
/// Shared handle to the [`DeceptiveCanaryRouter`] consulted by
/// [`deceptive_canary_middleware`].
///
/// Cloning is cheap: every clone points at the same router behind an
/// `Arc<RwLock<_>>`, so a change made through one handle is visible through all.
#[derive(Clone)]
pub struct DeceptiveCanaryState {
    /// The router, guarded by an async (tokio) read/write lock.
    pub router: Arc<RwLock<DeceptiveCanaryRouter>>,
}

impl DeceptiveCanaryState {
    /// Wraps `router` in a fresh `Arc<RwLock<_>>` so it can be shared across requests.
    pub fn new(router: DeceptiveCanaryRouter) -> Self {
        let guarded = RwLock::new(router);
        Self {
            router: Arc::new(guarded),
        }
    }
}
pub async fn deceptive_canary_middleware(req: Request, next: Next) -> Response {
let state = req.extensions().get::<DeceptiveCanaryState>().cloned().unwrap_or_else(|| {
DeceptiveCanaryState::new(DeceptiveCanaryRouter::default())
});
let user_agent = req
.headers()
.get("user-agent")
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string());
let ip_address = req
.extensions()
.get::<std::net::SocketAddr>()
.map(|addr| addr.ip().to_string())
.or_else(|| {
req.headers()
.get("x-forwarded-for")
.or_else(|| req.headers().get("x-real-ip"))
.and_then(|h| h.to_str().ok())
.map(|s| s.split(',').next().unwrap_or(s).trim().to_string())
});
let mut headers_map = HashMap::new();
for (key, value) in req.headers() {
let key_str = key.as_str().to_string();
if let Ok(value_str) = value.to_str() {
headers_map.insert(key_str, value_str.to_string());
}
}
let mut query_params = HashMap::new();
if let Some(query) = req.uri().query() {
for pair in query.split('&') {
if let Some((key, value)) = pair.split_once('=') {
query_params.insert(key.to_string(), value.to_string());
}
}
}
let user_id = req
.headers()
.get("x-user-id")
.or_else(|| req.headers().get("authorization"))
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string());
let router = state.router.read().await;
let should_route = router.should_route_to_canary(
user_agent.as_deref(),
ip_address.as_deref(),
&headers_map,
&query_params,
user_id.as_deref(),
);
if should_route {
debug!("Routing request to deceptive canary: {} {}", req.method(), req.uri().path());
let canary_url = router.config().deceptive_deploy_url.clone();
drop(router);
if !canary_url.is_empty() {
return proxy_to_canary(&canary_url, req).await;
}
} else {
drop(router); }
next.run(req).await
}
/// Forwards `req` to the canary deployment at `canary_url` and relays its response.
///
/// The original URI's path and query are appended to `canary_url`, with trailing
/// slashes trimmed. Hop-by-hop headers (see [`is_hop_by_hop`]) and `Host` are not
/// forwarded. The request body is buffered up to 10 MiB. The returned response carries
/// an extra `X-Deceptive-Canary: true` header.
///
/// Failures are mapped to:
/// - `502 Bad Gateway` for an unsupported method, an unreadable or oversized body, an
///   unreachable canary, or an unreadable canary response;
/// - `500 Internal Server Error` if the response cannot be assembled.
async fn proxy_to_canary(canary_url: &str, req: Request) -> Response {
    // Largest request body buffered for forwarding; anything bigger is answered with 502.
    const MAX_BODY_BYTES: usize = 10 * 1024 * 1024;

    // `reqwest::Client` owns a connection pool. Building one per request would discard
    // the pool every time and force a fresh TCP/TLS handshake for each proxied request.
    // A single process-wide client keeps connections to the canary alive.
    static CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
    let client = CLIENT.get_or_init(reqwest::Client::new);

    let target_url = format!(
        "{}{}{}",
        canary_url.trim_end_matches('/'),
        req.uri().path(),
        req.uri().query().map(|q| format!("?{q}")).unwrap_or_default()
    );

    // Round-trip through the string form: axum and reqwest may use different `http`
    // crate versions, so their `Method` types are not interchangeable.
    let method: reqwest::Method = match req.method().as_str().parse() {
        Ok(m) => m,
        Err(_) => {
            warn!("Canary proxy: unsupported HTTP method {}", req.method());
            return StatusCode::BAD_GATEWAY.into_response();
        }
    };

    let mut proxy_req = client.request(method, &target_url);
    for (name, value) in req.headers() {
        let name = name.as_str();
        // `Host` must describe the canary, which reqwest derives from `target_url`.
        if name != "host" && !is_hop_by_hop(name) {
            // Forward the raw bytes so values that are not valid UTF-8 are kept too.
            proxy_req = proxy_req.header(name, value.as_bytes());
        }
    }

    let body_bytes = match axum::body::to_bytes(req.into_body(), MAX_BODY_BYTES).await {
        Ok(b) => b,
        Err(e) => {
            warn!("Canary proxy: failed to read request body: {e}");
            return StatusCode::BAD_GATEWAY.into_response();
        }
    };
    if !body_bytes.is_empty() {
        // `Bytes` converts into a reqwest body without copying the payload.
        proxy_req = proxy_req.body(body_bytes);
    }

    let canary_response = match proxy_req.send().await {
        Ok(r) => r,
        Err(e) => {
            warn!("Canary proxy: request to {target_url} failed: {e}");
            return StatusCode::BAD_GATEWAY.into_response();
        }
    };

    let status =
        StatusCode::from_u16(canary_response.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
    let mut builder = Response::builder().status(status);
    for (name, value) in canary_response.headers() {
        let name = name.as_str();
        if !is_hop_by_hop(name) {
            if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) {
                // `Builder::header` appends, so multi-valued headers (e.g. Set-Cookie) survive.
                builder = builder.header(name, v);
            }
        }
    }
    builder = builder.header("X-Deceptive-Canary", "true");

    let response_body = match canary_response.bytes().await {
        Ok(b) => b,
        Err(e) => {
            warn!("Canary proxy: failed to read response body: {e}");
            return StatusCode::BAD_GATEWAY.into_response();
        }
    };
    builder
        .body(Body::from(response_body))
        .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
}

/// Returns `true` for hop-by-hop header names.
///
/// `name` must already be lowercase, as produced by `HeaderName::as_str`. These headers
/// apply to a single connection and must not be relayed by a proxy (RFC 9110 §7.6.1,
/// RFC 2616 §13.5.1), in either direction.
fn is_hop_by_hop(name: &str) -> bool {
    matches!(
        name,
        "connection"
            | "keep-alive"
            | "proxy-authenticate"
            | "proxy-authorization"
            | "proxy-connection"
            | "te"
            | "trailer"
            | "transfer-encoding"
            | "upgrade"
    )
}