Skip to main content

mockforge_http/
reality_proxy.rs

1//! Reality-slider-driven mock/proxy switching middleware (#222).
2//!
3//! When `MOCKFORGE_PROXY_UPSTREAM` is set on the process, this middleware
4//! probabilistically forwards a fraction of incoming requests to that URL
5//! based on the active workspace's `reality_continuum_ratio`. The fraction
6//! is per-request: ratio 0.0 = always-mock, 1.0 = always-proxy, 0.5 =
7//! coin-flip per request.
8//!
9//! ## Design
10//!
11//! The middleware is a no-op when:
12//!   - `MOCKFORGE_PROXY_UPSTREAM` is unset (e.g., local dev)
13//!   - the request has no associated `UnifiedState` extension (set by the
14//!     consistency middleware upstream of this one)
15//!   - the resolved ratio is exactly 0.0
16//!
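//! When proxying, the middleware reconstructs the request against the
//! upstream base URL preserving method, path, query, body, and end-to-end
//! headers (hop-by-hop headers and `Host` are dropped), then buffers the
//! upstream response and returns it to the caller. Once a request has been
//! routed upstream its body has been consumed, so there is no fallback to
//! the mock chain: an upstream failure is reported to the caller as a
//! `502 Bad Gateway` JSON error instead of silently being served a mock.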
22//!
23//! Wiring: insert the layer between `consistency_middleware` (which
24//! injects `UnifiedState`) and the route handlers. The dependency on
25//! `UnifiedState` is read-only, so ordering relative to recording or
26//! tracing layers doesn't matter.
27
28use axum::{
29    body::{to_bytes, Body},
30    extract::Request,
31    http::{HeaderName, HeaderValue, Method, StatusCode, Uri},
32    middleware::Next,
33    response::Response,
34};
35use mockforge_core::consistency::UnifiedState;
36use std::sync::Arc;
37use std::time::Duration;
38use tracing::warn;
39
40/// Cheap-to-clone handle holding the upstream base URL and a shared
41/// reqwest client. Constructed once at server startup; the layer closure
42/// holds an `Arc<RealityProxyConfig>` so per-request work is just an
43/// arc-clone.
44#[derive(Clone)]
45pub struct RealityProxyConfig {
46    /// Base URL — protocol + host + (optional) port, no trailing slash.
47    /// Path/query are taken from the incoming request.
48    pub upstream_base: String,
49    /// Shared HTTP client used for all upstream requests.
50    pub client: reqwest::Client,
51}
52
53impl RealityProxyConfig {
54    /// Construct from `MOCKFORGE_PROXY_UPSTREAM`. Returns None when the
55    /// env var is missing or empty (no-op middleware) or when the HTTP
56    /// client can't be built (logged as a warning).
57    pub fn from_env() -> Option<Arc<Self>> {
58        let base = std::env::var("MOCKFORGE_PROXY_UPSTREAM").ok()?;
59        let trimmed = base.trim().trim_end_matches('/');
60        if trimmed.is_empty() {
61            return None;
62        }
63        let client = match reqwest::Client::builder().timeout(Duration::from_secs(30)).build() {
64            Ok(c) => c,
65            Err(e) => {
66                warn!(error = %e, "RealityProxy HTTP client init failed; middleware will no-op");
67                return None;
68            }
69        };
70        Some(Arc::new(Self {
71            upstream_base: trimmed.to_string(),
72            client,
73        }))
74    }
75}
76
77/// The middleware function. Reads `reality_continuum_ratio` from the
78/// `UnifiedState` request extension, rolls a per-request RNG, and either
79/// forwards to upstream or hands off to the next layer (mock chain).
80pub async fn reality_proxy_middleware(
81    config: Arc<RealityProxyConfig>,
82    req: Request,
83    next: Next,
84) -> Response {
85    let ratio = req
86        .extensions()
87        .get::<UnifiedState>()
88        .map(|s| s.reality_continuum_ratio)
89        .unwrap_or(0.0);
90
91    // Fast path: no upstream desired for this request.
92    if ratio <= 0.0 {
93        return next.run(req).await;
94    }
95
96    let should_proxy = if ratio >= 1.0 {
97        true
98    } else {
99        rand::random::<f64>() < ratio
100    };
101
102    if !should_proxy {
103        return next.run(req).await;
104    }
105
106    match forward_to_upstream(&config, req).await {
107        Ok(resp) => resp,
108        Err(err) => {
109            // We've already consumed the request body to forward it, so
110            // we can't fall back to the mock chain. Surface 502 — the
111            // alternative (silent retry / synthetic mock) would hide
112            // real upstream incidents from operators.
113            warn!(error = %err, "Reality proxy upstream request failed");
114            let body = serde_json::to_vec(&serde_json::json!({
115                "error": "reality_proxy_upstream_failed",
116                "message": err.to_string(),
117            }))
118            .unwrap_or_default();
119            let mut resp = Response::new(Body::from(body));
120            *resp.status_mut() = StatusCode::BAD_GATEWAY;
121            resp.headers_mut().insert(
122                axum::http::header::CONTENT_TYPE,
123                HeaderValue::from_static("application/json"),
124            );
125            resp
126        }
127    }
128}
129
130async fn forward_to_upstream(
131    config: &RealityProxyConfig,
132    req: Request,
133) -> Result<Response, ProxyError> {
134    let (parts, body) = req.into_parts();
135    // Cap at 16 MiB — same as Axum's default request size limit.
136    // Anything larger and we'd be holding too much in memory for a
137    // simple proxy hop; better to fail loudly than swap-thrash.
138    const MAX_BODY: usize = 16 * 1024 * 1024;
139    let body_bytes = to_bytes(body, MAX_BODY)
140        .await
141        .map_err(|e| ProxyError::ReadBody(e.to_string()))?;
142
143    let upstream_uri = build_upstream_uri(&config.upstream_base, &parts.uri)?;
144    let method = reqwest_method(&parts.method);
145    let mut req_builder = config.client.request(method, &upstream_uri);
146
147    // Copy headers, dropping hop-by-hop / Host so reqwest sets correct ones.
148    for (name, value) in parts.headers.iter() {
149        if is_hop_by_hop(name) {
150            continue;
151        }
152        if name == axum::http::header::HOST {
153            continue;
154        }
155        req_builder = req_builder.header(name.as_str(), value);
156    }
157
158    if !body_bytes.is_empty() {
159        req_builder = req_builder.body(body_bytes);
160    }
161
162    let upstream_resp = req_builder.send().await.map_err(ProxyError::Send)?;
163    let status = upstream_resp.status();
164    let headers = upstream_resp.headers().clone();
165    let resp_bytes = upstream_resp.bytes().await.map_err(ProxyError::ReadResponse)?;
166
167    let mut response = Response::builder().status(status.as_u16());
168    {
169        let response_headers = response.headers_mut().expect("Response builder must have headers");
170        for (name, value) in headers.iter() {
171            if is_hop_by_hop_str(name.as_str()) {
172                continue;
173            }
174            if let Ok(hname) = HeaderName::from_bytes(name.as_str().as_bytes()) {
175                if let Ok(hval) = HeaderValue::from_bytes(value.as_bytes()) {
176                    response_headers.insert(hname, hval);
177                }
178            }
179        }
180        response_headers.insert(
181            HeaderName::from_static("x-mockforge-source"),
182            HeaderValue::from_static("upstream"),
183        );
184    }
185    response
186        .body(Body::from(resp_bytes))
187        .map_err(|e| ProxyError::BuildResponse(e.to_string()))
188}
189
190fn build_upstream_uri(base: &str, original: &Uri) -> Result<String, ProxyError> {
191    let path = original.path();
192    let query = original.query().map(|q| format!("?{}", q)).unwrap_or_default();
193    Ok(format!("{}{}{}", base, path, query))
194}
195
196fn reqwest_method(m: &Method) -> reqwest::Method {
197    reqwest::Method::from_bytes(m.as_str().as_bytes()).unwrap_or(reqwest::Method::GET)
198}
199
200fn is_hop_by_hop(name: &HeaderName) -> bool {
201    is_hop_by_hop_str(name.as_str())
202}
203
/// Returns `true` for header names that must not be copied across the
/// proxy hop. Matching is ASCII case-insensitive.
///
/// Covers the classic hop-by-hop set (`Connection`, `Keep-Alive`,
/// `Proxy-Authenticate`, `Proxy-Authorization`, `TE`, `Trailer`,
/// `Transfer-Encoding`, `Upgrade`) and the legacy `Proxy-Connection`.
/// It also covers `Content-Length`, which is not hop-by-hop but is dropped
/// because the proxy re-sends each body from a buffered copy.
fn is_hop_by_hop_str(name: &str) -> bool {
    // `trailer` is the real field name. `trailers` comes from an erratum in
    // RFC 2616 §13.5.1 and is kept so nothing previously filtered now leaks.
    const HOP_BY_HOP: &[&str] = &[
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
        "content-length",
    ];
    HOP_BY_HOP.iter().any(|hop| hop.eq_ignore_ascii_case(name))
}
218
219#[derive(Debug, thiserror::Error)]
220enum ProxyError {
221    #[error("failed to read request body: {0}")]
222    ReadBody(String),
223    #[error("upstream request send failed: {0}")]
224    Send(reqwest::Error),
225    #[error("upstream response read failed: {0}")]
226    ReadResponse(reqwest::Error),
227    #[error("response build failed: {0}")]
228    BuildResponse(String),
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that mutate `MOCKFORGE_PROXY_UPSTREAM`.
    ///
    /// The test harness runs tests on parallel threads that share a single
    /// process environment. Without this lock, one test's `set_var` /
    /// `remove_var` can be observed by another test's `from_env` call.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire [`ENV_LOCK`], recovering from poisoning so a single failing
    /// test does not cascade into failures of the others.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn from_env_disabled_when_unset() {
        let _guard = lock_env();
        std::env::remove_var("MOCKFORGE_PROXY_UPSTREAM");
        assert!(RealityProxyConfig::from_env().is_none());
    }

    #[test]
    fn from_env_disabled_when_blank() {
        let _guard = lock_env();
        std::env::set_var("MOCKFORGE_PROXY_UPSTREAM", "   ");
        let cfg = RealityProxyConfig::from_env();
        // Clean up before asserting so a failure doesn't leak the value.
        std::env::remove_var("MOCKFORGE_PROXY_UPSTREAM");
        assert!(cfg.is_none());
    }

    #[test]
    fn from_env_strips_trailing_slash() {
        let _guard = lock_env();
        std::env::set_var("MOCKFORGE_PROXY_UPSTREAM", "https://api.example.com/");
        let cfg = RealityProxyConfig::from_env();
        std::env::remove_var("MOCKFORGE_PROXY_UPSTREAM");
        let cfg = cfg.expect("config");
        assert_eq!(cfg.upstream_base, "https://api.example.com");
    }

    #[test]
    fn build_upstream_uri_preserves_path_and_query() {
        let base = "https://api.example.com";
        let uri: Uri = "/users/42?role=admin".parse().unwrap();
        let result = build_upstream_uri(base, &uri).unwrap();
        assert_eq!(result, "https://api.example.com/users/42?role=admin");
    }

    #[test]
    fn build_upstream_uri_no_query() {
        let base = "https://api.example.com";
        let uri: Uri = "/health".parse().unwrap();
        let result = build_upstream_uri(base, &uri).unwrap();
        assert_eq!(result, "https://api.example.com/health");
    }

    #[test]
    fn hop_by_hop_headers_are_filtered() {
        assert!(is_hop_by_hop_str("Connection"));
        assert!(is_hop_by_hop_str("transfer-encoding"));
        assert!(is_hop_by_hop_str("UPGRADE"));
        assert!(!is_hop_by_hop_str("authorization"));
        assert!(!is_hop_by_hop_str("x-custom-header"));
    }
}