use axum::{
body::{to_bytes, Body},
extract::Request,
http::{HeaderName, HeaderValue, Method, StatusCode, Uri},
middleware::Next,
response::Response,
};
use mockforge_core::consistency::UnifiedState;
use std::sync::Arc;
use std::time::Duration;
use tracing::warn;
/// Settings shared by the reality-proxy middleware: where to forward requests and the HTTP
/// client used to do so.
#[derive(Clone)]
pub struct RealityProxyConfig {
/// Base URL that the incoming request path and query are appended to when forwarding.
/// `from_env` stores it trimmed and without trailing `/` characters.
pub upstream_base: String,
/// `reqwest` client used to send the forwarded request.
pub client: reqwest::Client,
}
impl RealityProxyConfig {
    /// Builds the proxy configuration from the `MOCKFORGE_PROXY_UPSTREAM` environment variable.
    ///
    /// Returns `None` when the variable is unset (or not valid Unicode), when it is empty after
    /// trimming surrounding whitespace and trailing slashes, or when the HTTP client cannot be
    /// constructed (a warning is logged in that last case).
    pub fn from_env() -> Option<Arc<Self>> {
        let raw = std::env::var("MOCKFORGE_PROXY_UPSTREAM").ok()?;

        // Normalise so that joining with a request path never yields a double slash.
        let upstream_base = raw.trim().trim_end_matches('/');
        if upstream_base.is_empty() {
            return None;
        }

        // Every upstream call made through this client is capped at 30 seconds.
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(|e| {
                warn!(error = %e, "RealityProxy HTTP client init failed; middleware will no-op");
            })
            .ok()?;

        Some(Arc::new(Self {
            upstream_base: upstream_base.to_string(),
            client,
        }))
    }
}
/// Middleware that sends a share of requests to the real upstream instead of the mock handlers.
///
/// The share is read from `reality_continuum_ratio` on the request's [`UnifiedState`] extension
/// (treated as `0.0` when the extension is absent):
///
/// * ratio `<= 0.0` — always run the inner handler;
/// * ratio `>= 1.0` — always forward upstream;
/// * otherwise — forward when a uniformly random `f64` is below the ratio.
///
/// If forwarding fails, a `502 Bad Gateway` JSON body describing the error is returned.
pub async fn reality_proxy_middleware(
    config: Arc<RealityProxyConfig>,
    req: Request,
    next: Next,
) -> Response {
    let ratio = match req.extensions().get::<UnifiedState>() {
        Some(state) => state.reality_continuum_ratio,
        None => 0.0,
    };

    if ratio <= 0.0 {
        return next.run(req).await;
    }

    // The random draw is only taken for ratios strictly between 0 and 1.
    let proxy_this_request = ratio >= 1.0 || rand::random::<f64>() < ratio;
    if !proxy_this_request {
        return next.run(req).await;
    }

    let err = match forward_to_upstream(&config, req).await {
        Ok(upstream_response) => return upstream_response,
        Err(err) => err,
    };

    warn!(error = %err, "Reality proxy upstream request failed");

    let payload = serde_json::json!({
        "error": "reality_proxy_upstream_failed",
        "message": err.to_string(),
    });
    let encoded = serde_json::to_vec(&payload).unwrap_or_default();

    let mut failure = Response::new(Body::from(encoded));
    *failure.status_mut() = StatusCode::BAD_GATEWAY;
    failure.headers_mut().insert(
        axum::http::header::CONTENT_TYPE,
        HeaderValue::from_static("application/json"),
    );
    failure
}
/// Replays `req` against the configured upstream and converts the reply into an axum
/// [`Response`].
///
/// The request body is buffered (up to 16 MiB), the method, path and query are preserved, and
/// all request headers except hop-by-hop headers and `Host` are forwarded. The upstream body is
/// read fully, its non-hop-by-hop headers are copied over, and an `x-mockforge-source: upstream`
/// header is set on the result.
///
/// # Errors
///
/// * [`ProxyError::ReadBody`] — the incoming body could not be read or exceeded the limit.
/// * [`ProxyError::Send`] — the upstream request failed to send.
/// * [`ProxyError::ReadResponse`] — the upstream body could not be read.
/// * [`ProxyError::BuildResponse`] — the outgoing response could not be assembled.
async fn forward_to_upstream(
    config: &RealityProxyConfig,
    req: Request,
) -> Result<Response, ProxyError> {
    // Upper bound on how much of the incoming body is buffered in memory.
    const MAX_BODY: usize = 16 * 1024 * 1024;

    let (parts, body) = req.into_parts();
    let body_bytes = to_bytes(body, MAX_BODY)
        .await
        .map_err(|e| ProxyError::ReadBody(e.to_string()))?;

    let upstream_uri = build_upstream_uri(&config.upstream_base, &parts.uri)?;
    let method = reqwest_method(&parts.method);
    let mut req_builder = config.client.request(method, &upstream_uri);

    for (name, value) in parts.headers.iter() {
        // `Host` is skipped so the client derives it from the upstream URL instead.
        if is_hop_by_hop(name) || name == axum::http::header::HOST {
            continue;
        }
        req_builder = req_builder.header(name.as_str(), value);
    }
    if !body_bytes.is_empty() {
        req_builder = req_builder.body(body_bytes);
    }

    let upstream_resp = req_builder.send().await.map_err(ProxyError::Send)?;
    let status = upstream_resp.status();
    let headers = upstream_resp.headers().clone();
    let resp_bytes = upstream_resp.bytes().await.map_err(ProxyError::ReadResponse)?;

    let mut response = Response::builder().status(status.as_u16());
    {
        // `headers_mut` yields `None` only when the builder already holds an error (for
        // example a rejected status code); report that instead of panicking.
        let response_headers = response.headers_mut().ok_or_else(|| {
            ProxyError::BuildResponse("response builder rejected upstream status".to_string())
        })?;
        for (name, value) in headers.iter() {
            if is_hop_by_hop_str(name.as_str()) {
                continue;
            }
            if let Ok(hname) = HeaderName::from_bytes(name.as_str().as_bytes()) {
                if let Ok(hval) = HeaderValue::from_bytes(value.as_bytes()) {
                    // `append`, not `insert`: repeated headers such as `Set-Cookie` must all
                    // survive, and `insert` would keep only the last value.
                    response_headers.append(hname, hval);
                }
            }
        }
        // Always a single marker value, replacing anything the upstream sent under this name.
        response_headers.insert(
            HeaderName::from_static("x-mockforge-source"),
            HeaderValue::from_static("upstream"),
        );
    }

    response
        .body(Body::from(resp_bytes))
        .map_err(|e| ProxyError::BuildResponse(e.to_string()))
}
/// Joins the upstream base URL with the path and (if present) query string of `original`.
///
/// `base` is used verbatim, so it should not end with `/`. The current implementation never
/// returns an error.
fn build_upstream_uri(base: &str, original: &Uri) -> Result<String, ProxyError> {
    let mut target = String::from(base);
    target.push_str(original.path());
    if let Some(query) = original.query() {
        target.push('?');
        target.push_str(query);
    }
    Ok(target)
}
/// Converts an axum/http [`Method`] into the equivalent `reqwest::Method` by round-tripping
/// through its textual form; falls back to `GET` if `reqwest` rejects the method string.
fn reqwest_method(m: &Method) -> reqwest::Method {
    let method_text = m.as_str();
    match reqwest::Method::from_bytes(method_text.as_bytes()) {
        Ok(converted) => converted,
        Err(_) => reqwest::Method::GET,
    }
}
/// Typed convenience wrapper around [`is_hop_by_hop_str`] for an already-parsed header name.
fn is_hop_by_hop(name: &HeaderName) -> bool {
    let header = name.as_str();
    is_hop_by_hop_str(header)
}
/// Returns `true` when `name` (compared ASCII case-insensitively) is a header the proxy must
/// not copy between the client and upstream connections.
///
/// The list covers the connection-scoped headers (`Connection`, `Keep-Alive`,
/// `Proxy-Authenticate`, `Proxy-Authorization`, `TE`, `Trailer`, `Transfer-Encoding`,
/// `Upgrade`) plus `Content-Length`. Both `trailer` (the actual header field name) and the
/// historical spelling `trailers` are recognised.
fn is_hop_by_hop_str(name: &str) -> bool {
    const HOP_BY_HOP: [&str; 10] = [
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
        "content-length",
    ];
    // Case-insensitive comparison avoids allocating a lowercased copy for every header.
    HOP_BY_HOP
        .iter()
        .any(|candidate| name.eq_ignore_ascii_case(candidate))
}
/// Failures that can occur while forwarding a request to the upstream server.
#[derive(Debug, thiserror::Error)]
enum ProxyError {
/// The incoming request body could not be buffered; carries the error text.
#[error("failed to read request body: {0}")]
ReadBody(String),
/// Sending the request to the upstream failed.
#[error("upstream request send failed: {0}")]
Send(reqwest::Error),
/// The upstream responded but its body could not be read.
#[error("upstream response read failed: {0}")]
ReadResponse(reqwest::Error),
/// The outgoing response could not be assembled; carries the error text.
#[error("response build failed: {0}")]
BuildResponse(String),
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serialises tests that touch `MOCKFORGE_PROXY_UPSTREAM`. The environment is
    /// process-global and the test harness runs tests on parallel threads, so without this
    /// the set/remove calls of one test can interleave with the assertions of another.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning so one failed test does not cascade
    /// into the others.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn from_env_disabled_when_unset() {
        let _guard = lock_env();
        std::env::remove_var("MOCKFORGE_PROXY_UPSTREAM");
        assert!(RealityProxyConfig::from_env().is_none());
    }

    #[test]
    fn from_env_disabled_when_blank() {
        let _guard = lock_env();
        std::env::set_var("MOCKFORGE_PROXY_UPSTREAM", " ");
        assert!(RealityProxyConfig::from_env().is_none());
        std::env::remove_var("MOCKFORGE_PROXY_UPSTREAM");
    }

    #[test]
    fn from_env_strips_trailing_slash() {
        let _guard = lock_env();
        std::env::set_var("MOCKFORGE_PROXY_UPSTREAM", "https://api.example.com/");
        let cfg = RealityProxyConfig::from_env().expect("config");
        assert_eq!(cfg.upstream_base, "https://api.example.com");
        std::env::remove_var("MOCKFORGE_PROXY_UPSTREAM");
    }

    #[test]
    fn build_upstream_uri_preserves_path_and_query() {
        let base = "https://api.example.com";
        let uri: Uri = "/users/42?role=admin".parse().unwrap();
        let result = build_upstream_uri(base, &uri).unwrap();
        assert_eq!(result, "https://api.example.com/users/42?role=admin");
    }

    #[test]
    fn build_upstream_uri_no_query() {
        let base = "https://api.example.com";
        let uri: Uri = "/health".parse().unwrap();
        let result = build_upstream_uri(base, &uri).unwrap();
        assert_eq!(result, "https://api.example.com/health");
    }

    #[test]
    fn hop_by_hop_headers_are_filtered() {
        assert!(is_hop_by_hop_str("Connection"));
        assert!(is_hop_by_hop_str("transfer-encoding"));
        assert!(is_hop_by_hop_str("UPGRADE"));
        assert!(!is_hop_by_hop_str("authorization"));
        assert!(!is_hop_by_hop_str("x-custom-header"));
    }
}