Skip to main content

mockforge_http/
reality_proxy.rs

//! Reality-slider-driven mock/proxy switching middleware (#222).
//!
//! When `MOCKFORGE_PROXY_UPSTREAM` is set on the process, this middleware
//! probabilistically forwards a fraction of incoming requests to that URL
//! based on the active workspace's `reality_continuum_ratio`. The decision
//! is made per request: ratio 0.0 = always mock, 1.0 = always proxy,
//! 0.5 = coin flip per request.
//!
//! ## Design
//!
//! The middleware is a no-op (every request goes to the mock chain) when:
//!   - `MOCKFORGE_PROXY_UPSTREAM` is unset or blank (e.g., local dev), in
//!     which case `RealityProxyConfig::from_env` returns `None`
//!   - the request has no associated `UnifiedState` extension (set by the
//!     consistency middleware upstream of this one)
//!   - the resolved ratio is 0.0 or below
//!
//! When proxying, the middleware reconstructs the request against the
//! upstream base URL. It preserves the method, path, query, body, and
//! end-to-end headers; hop-by-hop headers and `Host` are dropped. Both the
//! request body (capped at 16 MiB) and the upstream response are buffered
//! in memory before being relayed, so nothing is streamed.
//!
//! Once a request has been selected for proxying, it does not fall back to
//! the mock chain. Any failure is returned to the caller as a JSON error
//! response (e.g. `502 Bad Gateway`), so upstream incidents stay visible to
//! operators.
//!
//! Wiring: insert the layer between `consistency_middleware` (which
//! injects `UnifiedState`) and the route handlers. Proxied requests never
//! reach the inner layers. Layers placed inside this one (e.g. recording or
//! tracing) therefore only observe mock-path requests; place them outside
//! it if they must see every request.
27
28use axum::{
29    body::{to_bytes, Body},
30    extract::Request,
31    http::{
32        header::{CONTENT_TYPE, HOST},
33        HeaderName, HeaderValue, Method, StatusCode, Uri,
34    },
35    middleware::Next,
36    response::Response,
37};
38use mockforge_core::consistency::UnifiedState;
39use reqwest::Method as ReqwestMethod;
40use std::sync::Arc;
41use std::time::Duration;
42use tracing::warn;
43
44/// Cheap-to-clone handle holding the upstream base URL and a shared
45/// reqwest client. Constructed once at server startup; the layer closure
46/// holds an `Arc<RealityProxyConfig>` so per-request work is just an
47/// arc-clone.
48#[derive(Clone)]
49pub struct RealityProxyConfig {
50    /// Base URL — protocol + host + (optional) port, no trailing slash.
51    /// Path/query are taken from the incoming request.
52    pub upstream_base: String,
53    /// Shared HTTP client used for all upstream requests.
54    pub client: reqwest::Client,
55}
56
57impl RealityProxyConfig {
58    /// Construct from `MOCKFORGE_PROXY_UPSTREAM`. Returns None when the
59    /// env var is missing or empty (no-op middleware) or when the HTTP
60    /// client can't be built (logged as a warning).
61    pub fn from_env() -> Option<Arc<Self>> {
62        let base = std::env::var("MOCKFORGE_PROXY_UPSTREAM").ok()?;
63        let trimmed = base.trim().trim_end_matches('/');
64        if trimmed.is_empty() {
65            return None;
66        }
67        let client = match reqwest::Client::builder().timeout(Duration::from_secs(30)).build() {
68            Ok(c) => c,
69            Err(e) => {
70                warn!(error = %e, "RealityProxy HTTP client init failed; middleware will no-op");
71                return None;
72            }
73        };
74        Some(Arc::new(Self {
75            upstream_base: trimmed.to_string(),
76            client,
77        }))
78    }
79}
80
81/// The middleware function. Reads `reality_continuum_ratio` from the
82/// `UnifiedState` request extension, rolls a per-request RNG, and either
83/// forwards to upstream or hands off to the next layer (mock chain).
84pub async fn reality_proxy_middleware(
85    config: Arc<RealityProxyConfig>,
86    req: Request,
87    next: Next,
88) -> Response {
89    let ratio = req
90        .extensions()
91        .get::<UnifiedState>()
92        .map(|s| s.reality_continuum_ratio)
93        .unwrap_or(0.0);
94
95    // Fast path: no upstream desired for this request.
96    if ratio <= 0.0 {
97        return next.run(req).await;
98    }
99
100    let should_proxy = if ratio >= 1.0 {
101        true
102    } else {
103        rand::random::<f64>() < ratio
104    };
105
106    if !should_proxy {
107        return next.run(req).await;
108    }
109
110    match forward_to_upstream(&config, req).await {
111        Ok(resp) => resp,
112        Err(err) => {
113            // We've already consumed the request body to forward it, so
114            // we can't fall back to the mock chain. Surface 502 — the
115            // alternative (silent retry / synthetic mock) would hide
116            // real upstream incidents from operators.
117            warn!(error = %err, "Reality proxy upstream request failed");
118            let body = serde_json::to_vec(&serde_json::json!({
119                "error": "reality_proxy_upstream_failed",
120                "message": err.to_string(),
121            }))
122            .unwrap_or_default();
123            let mut resp = Response::new(Body::from(body));
124            *resp.status_mut() = StatusCode::BAD_GATEWAY;
125            resp.headers_mut()
126                .insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
127            resp
128        }
129    }
130}
131
132async fn forward_to_upstream(
133    config: &RealityProxyConfig,
134    req: Request,
135) -> Result<Response, ProxyError> {
136    let (parts, body) = req.into_parts();
137    // Cap at 16 MiB — same as Axum's default request size limit.
138    // Anything larger and we'd be holding too much in memory for a
139    // simple proxy hop; better to fail loudly than swap-thrash.
140    const MAX_BODY: usize = 16 * 1024 * 1024;
141    let body_bytes = to_bytes(body, MAX_BODY)
142        .await
143        .map_err(|e| ProxyError::ReadBody(e.to_string()))?;
144
145    let upstream_uri = build_upstream_uri(&config.upstream_base, &parts.uri)?;
146    let method = reqwest_method(&parts.method);
147    let mut req_builder = config.client.request(method, &upstream_uri);
148
149    // Copy headers, dropping hop-by-hop / Host so reqwest sets correct ones.
150    for (name, value) in parts.headers.iter() {
151        if is_hop_by_hop(name) {
152            continue;
153        }
154        if name == HOST {
155            continue;
156        }
157        req_builder = req_builder.header(name.as_str(), value);
158    }
159
160    if !body_bytes.is_empty() {
161        req_builder = req_builder.body(body_bytes);
162    }
163
164    let upstream_resp = req_builder.send().await.map_err(ProxyError::Send)?;
165    let status = upstream_resp.status();
166    let headers = upstream_resp.headers().clone();
167    let resp_bytes = upstream_resp.bytes().await.map_err(ProxyError::ReadResponse)?;
168
169    let mut response = Response::builder().status(status.as_u16());
170    {
171        let response_headers = response.headers_mut().expect("Response builder must have headers");
172        for (name, value) in headers.iter() {
173            if is_hop_by_hop_str(name.as_str()) {
174                continue;
175            }
176            if let Ok(hname) = HeaderName::from_bytes(name.as_str().as_bytes()) {
177                if let Ok(hval) = HeaderValue::from_bytes(value.as_bytes()) {
178                    response_headers.insert(hname, hval);
179                }
180            }
181        }
182        response_headers.insert(
183            HeaderName::from_static("x-mockforge-source"),
184            HeaderValue::from_static("upstream"),
185        );
186    }
187    response
188        .body(Body::from(resp_bytes))
189        .map_err(|e| ProxyError::BuildResponse(e.to_string()))
190}
191
192fn build_upstream_uri(base: &str, original: &Uri) -> Result<String, ProxyError> {
193    let path = original.path();
194    let query = original.query().map(|q| format!("?{}", q)).unwrap_or_default();
195    Ok(format!("{}{}{}", base, path, query))
196}
197
198fn reqwest_method(m: &Method) -> ReqwestMethod {
199    ReqwestMethod::from_bytes(m.as_str().as_bytes()).unwrap_or(ReqwestMethod::GET)
200}
201
202fn is_hop_by_hop(name: &HeaderName) -> bool {
203    is_hop_by_hop_str(name.as_str())
204}
205
/// Header names that are never copied across the proxy hop in either
/// direction.
///
/// - Connection-scoped fields per RFC 9110 §7.6.1, including the
///   non-standard but common `proxy-connection`.
/// - `trailer`, the real field name. `trailers` is kept for compatibility;
///   it comes from a typo in RFC 2616 §13.5.1. Bodies are buffered, so
///   announced trailers would never be sent anyway.
/// - `content-length`, dropped so the outgoing length reflects the
///   re-buffered body.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
    "content-length",
];

/// Returns `true` when `name` matches an entry in [`HOP_BY_HOP_HEADERS`],
/// compared ASCII case-insensitively and without allocating.
fn is_hop_by_hop_str(name: &str) -> bool {
    HOP_BY_HOP_HEADERS
        .iter()
        .any(|candidate| candidate.eq_ignore_ascii_case(name))
}
220
221#[derive(Debug, thiserror::Error)]
222enum ProxyError {
223    #[error("failed to read request body: {0}")]
224    ReadBody(String),
225    #[error("upstream request send failed: {0}")]
226    Send(reqwest::Error),
227    #[error("upstream response read failed: {0}")]
228    ReadResponse(reqwest::Error),
229    #[error("response build failed: {0}")]
230    BuildResponse(String),
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test that mutates `MOCKFORGE_PROXY_UPSTREAM`.
    ///
    /// The test harness runs tests on parallel threads and the process
    /// environment is shared, so unsynchronized `set_var`/`remove_var` calls
    /// race and make these tests flaky.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Take [`ENV_LOCK`], recovering from poisoning so one failed test does
    /// not cascade into the other env tests.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn from_env_disabled_when_unset() {
        let _guard = env_lock();
        std::env::remove_var("MOCKFORGE_PROXY_UPSTREAM");
        assert!(RealityProxyConfig::from_env().is_none());
    }

    #[test]
    fn from_env_disabled_when_blank() {
        let _guard = env_lock();
        std::env::set_var("MOCKFORGE_PROXY_UPSTREAM", "   ");
        assert!(RealityProxyConfig::from_env().is_none());
        std::env::remove_var("MOCKFORGE_PROXY_UPSTREAM");
    }

    #[test]
    fn from_env_strips_trailing_slash() {
        let _guard = env_lock();
        std::env::set_var("MOCKFORGE_PROXY_UPSTREAM", "https://api.example.com/");
        let cfg = RealityProxyConfig::from_env().expect("config");
        assert_eq!(cfg.upstream_base, "https://api.example.com");
        std::env::remove_var("MOCKFORGE_PROXY_UPSTREAM");
    }

    #[test]
    fn build_upstream_uri_preserves_path_and_query() {
        let base = "https://api.example.com";
        let uri: Uri = "/users/42?role=admin".parse().unwrap();
        let result = build_upstream_uri(base, &uri).unwrap();
        assert_eq!(result, "https://api.example.com/users/42?role=admin");
    }

    #[test]
    fn build_upstream_uri_no_query() {
        let base = "https://api.example.com";
        let uri: Uri = "/health".parse().unwrap();
        let result = build_upstream_uri(base, &uri).unwrap();
        assert_eq!(result, "https://api.example.com/health");
    }

    #[test]
    fn hop_by_hop_headers_are_filtered() {
        assert!(is_hop_by_hop_str("Connection"));
        assert!(is_hop_by_hop_str("transfer-encoding"));
        assert!(is_hop_by_hop_str("UPGRADE"));
        assert!(!is_hop_by_hop_str("authorization"));
        assert!(!is_hop_by_hop_str("x-custom-header"));
    }
}