mockforge_proxy/
reality.rs1use axum::{
35 body::{to_bytes, Body},
36 extract::Request,
37 http::{
38 header::{CONTENT_TYPE, HOST},
39 HeaderName, HeaderValue, Method, StatusCode, Uri,
40 },
41 middleware::Next,
42 response::Response,
43};
44use mockforge_core::consistency::UnifiedState;
45use reqwest::Method as ReqwestMethod;
46use std::sync::Arc;
47use std::time::Duration;
48use tracing::warn;
49
50#[derive(Clone)]
55pub struct RealityProxyConfig {
56 pub upstream_base: String,
59 pub client: reqwest::Client,
61}
62
63impl RealityProxyConfig {
64 pub fn from_env() -> Option<Arc<Self>> {
68 let base = std::env::var("MOCKFORGE_PROXY_UPSTREAM").ok()?;
69 let trimmed = base.trim().trim_end_matches('/');
70 if trimmed.is_empty() {
71 return None;
72 }
73 let client = match reqwest::Client::builder().timeout(Duration::from_secs(30)).build() {
74 Ok(c) => c,
75 Err(e) => {
76 warn!(error = %e, "RealityProxy HTTP client init failed; middleware will no-op");
77 return None;
78 }
79 };
80 Some(Arc::new(Self {
81 upstream_base: trimmed.to_string(),
82 client,
83 }))
84 }
85}
86
87pub async fn reality_proxy_middleware(
91 config: Arc<RealityProxyConfig>,
92 req: Request,
93 next: Next,
94) -> Response {
95 let ratio = req
96 .extensions()
97 .get::<UnifiedState>()
98 .map(|s| s.reality_continuum_ratio)
99 .unwrap_or(0.0);
100
101 if ratio <= 0.0 {
103 return next.run(req).await;
104 }
105
106 let should_proxy = if ratio >= 1.0 {
107 true
108 } else {
109 rand::random::<f64>() < ratio
110 };
111
112 if !should_proxy {
113 return next.run(req).await;
114 }
115
116 match forward_to_upstream(&config, req).await {
117 Ok(resp) => resp,
118 Err(err) => {
119 warn!(error = %err, "Reality proxy upstream request failed");
124 let body = serde_json::to_vec(&serde_json::json!({
125 "error": "reality_proxy_upstream_failed",
126 "message": err.to_string(),
127 }))
128 .unwrap_or_default();
129 let mut resp = Response::new(Body::from(body));
130 *resp.status_mut() = StatusCode::BAD_GATEWAY;
131 resp.headers_mut()
132 .insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
133 resp
134 }
135 }
136}
137
138async fn forward_to_upstream(
139 config: &RealityProxyConfig,
140 req: Request,
141) -> Result<Response, ProxyError> {
142 let (parts, body) = req.into_parts();
143 const MAX_BODY: usize = 16 * 1024 * 1024;
147 let body_bytes = to_bytes(body, MAX_BODY)
148 .await
149 .map_err(|e| ProxyError::ReadBody(e.to_string()))?;
150
151 let upstream_uri = build_upstream_uri(&config.upstream_base, &parts.uri)?;
152 let method = reqwest_method(&parts.method);
153 let mut req_builder = config.client.request(method, &upstream_uri);
154
155 for (name, value) in parts.headers.iter() {
157 if is_hop_by_hop(name) {
158 continue;
159 }
160 if name == HOST {
161 continue;
162 }
163 req_builder = req_builder.header(name.as_str(), value);
164 }
165
166 if !body_bytes.is_empty() {
167 req_builder = req_builder.body(body_bytes);
168 }
169
170 let upstream_resp = req_builder.send().await.map_err(ProxyError::Send)?;
171 let status = upstream_resp.status();
172 let headers = upstream_resp.headers().clone();
173 let resp_bytes = upstream_resp.bytes().await.map_err(ProxyError::ReadResponse)?;
174
175 let mut response = Response::builder().status(status.as_u16());
176 {
177 let response_headers = response.headers_mut().expect("Response builder must have headers");
178 for (name, value) in headers.iter() {
179 if is_hop_by_hop_str(name.as_str()) {
180 continue;
181 }
182 if let Ok(hname) = HeaderName::from_bytes(name.as_str().as_bytes()) {
183 if let Ok(hval) = HeaderValue::from_bytes(value.as_bytes()) {
184 response_headers.insert(hname, hval);
185 }
186 }
187 }
188 response_headers.insert(
189 HeaderName::from_static("x-mockforge-source"),
190 HeaderValue::from_static("upstream"),
191 );
192 }
193 response
194 .body(Body::from(resp_bytes))
195 .map_err(|e| ProxyError::BuildResponse(e.to_string()))
196}
197
198fn build_upstream_uri(base: &str, original: &Uri) -> Result<String, ProxyError> {
199 let path = original.path();
200 let query = original.query().map(|q| format!("?{}", q)).unwrap_or_default();
201 Ok(format!("{}{}{}", base, path, query))
202}
203
204fn reqwest_method(m: &Method) -> ReqwestMethod {
205 ReqwestMethod::from_bytes(m.as_str().as_bytes()).unwrap_or(ReqwestMethod::GET)
206}
207
208fn is_hop_by_hop(name: &HeaderName) -> bool {
209 is_hop_by_hop_str(name.as_str())
210}
211
/// Returns `true` for headers the proxy must not copy between connections.
///
/// The comparison is ASCII case-insensitive. The list holds the connection-specific
/// headers plus `content-length`; that last one is dropped because the proxy
/// re-frames bodies itself.
/// Both `trailer` (the real field name, RFC 9110 §6.6.2) and `trailers` (the
/// spelling in RFC 2616's hop-by-hop list) are matched.
fn is_hop_by_hop_str(name: &str) -> bool {
    const HOP_BY_HOP: &[&str] = &[
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
        "content-length",
    ];
    // `eq_ignore_ascii_case` avoids allocating a lowercased copy per call.
    HOP_BY_HOP.iter().any(|h| h.eq_ignore_ascii_case(name))
}
226
227#[derive(Debug, thiserror::Error)]
228enum ProxyError {
229 #[error("failed to read request body: {0}")]
230 ReadBody(String),
231 #[error("upstream request send failed: {0}")]
232 Send(reqwest::Error),
233 #[error("upstream response read failed: {0}")]
234 ReadResponse(reqwest::Error),
235 #[error("response build failed: {0}")]
236 BuildResponse(String),
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    const UPSTREAM_VAR: &str = "MOCKFORGE_PROXY_UPSTREAM";

    /// Serializes every test that touches `MOCKFORGE_PROXY_UPSTREAM`.
    ///
    /// The test harness runs tests on parallel threads inside one process, and the
    /// environment is process-wide. Without this lock, one test's `set_var` can be
    /// observed by another test's `from_env` call, which makes the suite flaky.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds [`ENV_LOCK`] for the duration of a test and clears the variable on
    /// drop, so a failing assertion cannot leak state into later tests.
    struct EnvGuard {
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Acquires the lock, then sets (`Some`) or removes (`None`) the variable.
        fn set(value: Option<&str>) -> Self {
            // A test that panics while holding the lock poisons it. The guarded data
            // is `()`, so it is safe to keep going.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            match value {
                Some(v) => std::env::set_var(UPSTREAM_VAR, v),
                None => std::env::remove_var(UPSTREAM_VAR),
            }
            Self { _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // `drop` runs before the `_lock` field is released, so the lock is
            // still held here.
            std::env::remove_var(UPSTREAM_VAR);
        }
    }

    #[test]
    fn from_env_disabled_when_unset() {
        let _env = EnvGuard::set(None);
        assert!(RealityProxyConfig::from_env().is_none());
    }

    #[test]
    fn from_env_disabled_when_blank() {
        let _env = EnvGuard::set(Some(" "));
        assert!(RealityProxyConfig::from_env().is_none());
    }

    #[test]
    fn from_env_strips_trailing_slash() {
        let _env = EnvGuard::set(Some("https://api.example.com/"));
        let cfg = RealityProxyConfig::from_env().expect("config");
        assert_eq!(cfg.upstream_base, "https://api.example.com");
    }

    #[test]
    fn build_upstream_uri_preserves_path_and_query() {
        let base = "https://api.example.com";
        let uri: Uri = "/users/42?role=admin".parse().unwrap();
        let result = build_upstream_uri(base, &uri).unwrap();
        assert_eq!(result, "https://api.example.com/users/42?role=admin");
    }

    #[test]
    fn build_upstream_uri_no_query() {
        let base = "https://api.example.com";
        let uri: Uri = "/health".parse().unwrap();
        let result = build_upstream_uri(base, &uri).unwrap();
        assert_eq!(result, "https://api.example.com/health");
    }

    #[test]
    fn hop_by_hop_headers_are_filtered() {
        assert!(is_hop_by_hop_str("Connection"));
        assert!(is_hop_by_hop_str("transfer-encoding"));
        assert!(is_hop_by_hop_str("UPGRADE"));
        assert!(!is_hop_by_hop_str("authorization"));
        assert!(!is_hop_by_hop_str("x-custom-header"));
    }
}