mockforge_http/
reality_proxy.rs1use axum::{
29 body::{to_bytes, Body},
30 extract::Request,
31 http::{HeaderName, HeaderValue, Method, StatusCode, Uri},
32 middleware::Next,
33 response::Response,
34};
35use mockforge_core::consistency::UnifiedState;
36use std::sync::Arc;
37use std::time::Duration;
38use tracing::warn;
39
40#[derive(Clone)]
45pub struct RealityProxyConfig {
46 pub upstream_base: String,
49 pub client: reqwest::Client,
51}
52
53impl RealityProxyConfig {
54 pub fn from_env() -> Option<Arc<Self>> {
58 let base = std::env::var("MOCKFORGE_PROXY_UPSTREAM").ok()?;
59 let trimmed = base.trim().trim_end_matches('/');
60 if trimmed.is_empty() {
61 return None;
62 }
63 let client = match reqwest::Client::builder().timeout(Duration::from_secs(30)).build() {
64 Ok(c) => c,
65 Err(e) => {
66 warn!(error = %e, "RealityProxy HTTP client init failed; middleware will no-op");
67 return None;
68 }
69 };
70 Some(Arc::new(Self {
71 upstream_base: trimmed.to_string(),
72 client,
73 }))
74 }
75}
76
77pub async fn reality_proxy_middleware(
81 config: Arc<RealityProxyConfig>,
82 req: Request,
83 next: Next,
84) -> Response {
85 let ratio = req
86 .extensions()
87 .get::<UnifiedState>()
88 .map(|s| s.reality_continuum_ratio)
89 .unwrap_or(0.0);
90
91 if ratio <= 0.0 {
93 return next.run(req).await;
94 }
95
96 let should_proxy = if ratio >= 1.0 {
97 true
98 } else {
99 rand::random::<f64>() < ratio
100 };
101
102 if !should_proxy {
103 return next.run(req).await;
104 }
105
106 match forward_to_upstream(&config, req).await {
107 Ok(resp) => resp,
108 Err(err) => {
109 warn!(error = %err, "Reality proxy upstream request failed");
114 let body = serde_json::to_vec(&serde_json::json!({
115 "error": "reality_proxy_upstream_failed",
116 "message": err.to_string(),
117 }))
118 .unwrap_or_default();
119 let mut resp = Response::new(Body::from(body));
120 *resp.status_mut() = StatusCode::BAD_GATEWAY;
121 resp.headers_mut().insert(
122 axum::http::header::CONTENT_TYPE,
123 HeaderValue::from_static("application/json"),
124 );
125 resp
126 }
127 }
128}
129
130async fn forward_to_upstream(
131 config: &RealityProxyConfig,
132 req: Request,
133) -> Result<Response, ProxyError> {
134 let (parts, body) = req.into_parts();
135 const MAX_BODY: usize = 16 * 1024 * 1024;
139 let body_bytes = to_bytes(body, MAX_BODY)
140 .await
141 .map_err(|e| ProxyError::ReadBody(e.to_string()))?;
142
143 let upstream_uri = build_upstream_uri(&config.upstream_base, &parts.uri)?;
144 let method = reqwest_method(&parts.method);
145 let mut req_builder = config.client.request(method, &upstream_uri);
146
147 for (name, value) in parts.headers.iter() {
149 if is_hop_by_hop(name) {
150 continue;
151 }
152 if name == axum::http::header::HOST {
153 continue;
154 }
155 req_builder = req_builder.header(name.as_str(), value);
156 }
157
158 if !body_bytes.is_empty() {
159 req_builder = req_builder.body(body_bytes);
160 }
161
162 let upstream_resp = req_builder.send().await.map_err(ProxyError::Send)?;
163 let status = upstream_resp.status();
164 let headers = upstream_resp.headers().clone();
165 let resp_bytes = upstream_resp.bytes().await.map_err(ProxyError::ReadResponse)?;
166
167 let mut response = Response::builder().status(status.as_u16());
168 {
169 let response_headers = response.headers_mut().expect("Response builder must have headers");
170 for (name, value) in headers.iter() {
171 if is_hop_by_hop_str(name.as_str()) {
172 continue;
173 }
174 if let Ok(hname) = HeaderName::from_bytes(name.as_str().as_bytes()) {
175 if let Ok(hval) = HeaderValue::from_bytes(value.as_bytes()) {
176 response_headers.insert(hname, hval);
177 }
178 }
179 }
180 response_headers.insert(
181 HeaderName::from_static("x-mockforge-source"),
182 HeaderValue::from_static("upstream"),
183 );
184 }
185 response
186 .body(Body::from(resp_bytes))
187 .map_err(|e| ProxyError::BuildResponse(e.to_string()))
188}
189
190fn build_upstream_uri(base: &str, original: &Uri) -> Result<String, ProxyError> {
191 let path = original.path();
192 let query = original.query().map(|q| format!("?{}", q)).unwrap_or_default();
193 Ok(format!("{}{}{}", base, path, query))
194}
195
196fn reqwest_method(m: &Method) -> reqwest::Method {
197 reqwest::Method::from_bytes(m.as_str().as_bytes()).unwrap_or(reqwest::Method::GET)
198}
199
200fn is_hop_by_hop(name: &HeaderName) -> bool {
201 is_hop_by_hop_str(name.as_str())
202}
203
/// Returns `true` when `name` (compared ASCII case-insensitively) is a header that must not be
/// copied between the client-side and upstream-side connections.
///
/// The comparison is allocation-free. An empty or unknown name yields `false`.
fn is_hop_by_hop_str(name: &str) -> bool {
    const HOP_BY_HOP: &[&str] = &[
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        // Non-standard, but still sent by some clients.
        "proxy-connection",
        "te",
        // The real header field is `Trailer`; RFC 2616 listed it as "Trailers" by mistake.
        "trailer",
        // Kept so every name matched before is still matched.
        "trailers",
        "transfer-encoding",
        "upgrade",
        // Not strictly hop-by-hop: dropped because bodies are re-buffered, so the length is
        // recomputed when the message is re-sent.
        "content-length",
    ];
    HOP_BY_HOP
        .iter()
        .any(|candidate| name.eq_ignore_ascii_case(candidate))
}
218
219#[derive(Debug, thiserror::Error)]
220enum ProxyError {
221 #[error("failed to read request body: {0}")]
222 ReadBody(String),
223 #[error("upstream request send failed: {0}")]
224 Send(reqwest::Error),
225 #[error("upstream response read failed: {0}")]
226 ReadResponse(reqwest::Error),
227 #[error("response build failed: {0}")]
228 BuildResponse(String),
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests that mutate `MOCKFORGE_PROXY_UPSTREAM`.
    ///
    /// The environment is process-wide and the test harness runs tests on parallel threads, so
    /// without this lock one test can observe another test's value.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Takes the environment lock, recovering it if a previous test panicked while holding it
    /// so that one failure does not cascade into the other env tests.
    fn env_guard() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn from_env_disabled_when_unset() {
        let _env = env_guard();
        std::env::remove_var("MOCKFORGE_PROXY_UPSTREAM");
        assert!(RealityProxyConfig::from_env().is_none());
    }

    #[test]
    fn from_env_disabled_when_blank() {
        let _env = env_guard();
        std::env::set_var("MOCKFORGE_PROXY_UPSTREAM", "   ");
        assert!(RealityProxyConfig::from_env().is_none());
        std::env::remove_var("MOCKFORGE_PROXY_UPSTREAM");
    }

    #[test]
    fn from_env_strips_trailing_slash() {
        let _env = env_guard();
        std::env::set_var("MOCKFORGE_PROXY_UPSTREAM", "https://api.example.com/");
        let cfg = RealityProxyConfig::from_env().expect("config");
        assert_eq!(cfg.upstream_base, "https://api.example.com");
        std::env::remove_var("MOCKFORGE_PROXY_UPSTREAM");
    }

    #[test]
    fn build_upstream_uri_preserves_path_and_query() {
        let base = "https://api.example.com";
        let uri: Uri = "/users/42?role=admin".parse().unwrap();
        let result = build_upstream_uri(base, &uri).unwrap();
        assert_eq!(result, "https://api.example.com/users/42?role=admin");
    }

    #[test]
    fn build_upstream_uri_no_query() {
        let base = "https://api.example.com";
        let uri: Uri = "/health".parse().unwrap();
        let result = build_upstream_uri(base, &uri).unwrap();
        assert_eq!(result, "https://api.example.com/health");
    }

    #[test]
    fn hop_by_hop_headers_are_filtered() {
        assert!(is_hop_by_hop_str("Connection"));
        assert!(is_hop_by_hop_str("transfer-encoding"));
        assert!(is_hop_by_hop_str("UPGRADE"));
        assert!(!is_hop_by_hop_str("authorization"));
        assert!(!is_hop_by_hop_str("x-custom-header"));
    }
}