1use axum::{
7 extract::Request, http::StatusCode, middleware::Next, response::Response, routing::get, Router,
8};
9use mockforge_core::proxy::{body_transform::BodyTransformationMiddleware, config::ProxyConfig};
10use serde::Serialize;
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, error, info, warn};
15
16#[derive(Debug)]
18pub struct ProxyServer {
19 config: Arc<RwLock<ProxyConfig>>,
21 log_requests: bool,
23 log_responses: bool,
25 request_counter: Arc<RwLock<u64>>,
27}
28
29impl ProxyServer {
30 pub fn new(config: ProxyConfig, log_requests: bool, log_responses: bool) -> Self {
32 Self {
33 config: Arc::new(RwLock::new(config)),
34 log_requests,
35 log_responses,
36 request_counter: Arc::new(RwLock::new(0)),
37 }
38 }
39
40 pub fn router(self) -> Router {
42 let state = Arc::new(self);
43 let state_for_middleware = state.clone();
44
45 Router::new()
46 .route("/proxy/health", get(health_check))
48 .fallback(proxy_handler)
50 .with_state(state)
51 .layer(axum::middleware::from_fn_with_state(state_for_middleware, logging_middleware))
52 }
53}
54
55async fn health_check() -> Result<Response<String>, StatusCode> {
57 Response::builder()
59 .status(StatusCode::OK)
60 .header("Content-Type", "application/json")
61 .body(r#"{"status":"healthy","service":"mockforge-proxy"}"#.to_string())
62 .map_err(|e| {
63 tracing::error!("Failed to build health check response: {}", e);
64 StatusCode::INTERNAL_SERVER_ERROR
65 })
66}
67
68async fn proxy_handler(
70 axum::extract::State(state): axum::extract::State<Arc<ProxyServer>>,
71 request: axum::http::Request<axum::body::Body>,
72) -> Result<Response<String>, StatusCode> {
73 let client_addr = request
75 .extensions()
76 .get::<SocketAddr>()
77 .copied()
78 .unwrap_or_else(|| std::net::SocketAddr::from(([0, 0, 0, 0], 0)));
79
80 let method = request.method().clone();
81 let uri = request.uri().clone();
82 let headers = request.headers().clone();
83
84 let body_bytes = match axum::body::to_bytes(request.into_body(), usize::MAX).await {
86 Ok(bytes) => Some(bytes.to_vec()),
87 Err(e) => {
88 error!("Failed to read request body: {}", e);
89 None
90 }
91 };
92
93 let config = state.config.read().await;
94
95 if !config.enabled {
97 return Err(StatusCode::SERVICE_UNAVAILABLE);
98 }
99
100 if !config.should_proxy_with_condition(&method, &uri, &headers, body_bytes.as_deref()) {
102 return Err(StatusCode::NOT_FOUND);
103 }
104
105 let stripped_path = config.strip_prefix(uri.path());
107
108 let base_upstream_url = config.get_upstream_url(uri.path());
110 let full_upstream_url =
111 if stripped_path.starts_with("http://") || stripped_path.starts_with("https://") {
112 stripped_path.clone()
113 } else {
114 let base = base_upstream_url.trim_end_matches('/');
115 let path = stripped_path.trim_start_matches('/');
116 let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
117 if path.is_empty() || path == "/" {
118 format!("{}{}", base, query)
119 } else {
120 format!("{}/{}", base, path) + &query
121 }
122 };
123
124 let modified_uri = full_upstream_url.parse::<axum::http::Uri>().unwrap_or_else(|_| uri.clone());
126
127 if state.log_requests {
129 let mut counter = state.request_counter.write().await;
130 *counter += 1;
131 let request_id = *counter;
132
133 info!(
134 request_id = request_id,
135 method = %method,
136 path = %uri.path(),
137 upstream = %full_upstream_url,
138 client_ip = %client_addr.ip(),
139 "Proxy request intercepted"
140 );
141 }
142
143 let mut header_map = std::collections::HashMap::new();
145 for (key, value) in &headers {
146 if let Ok(value_str) = value.to_str() {
147 header_map.insert(key.to_string(), value_str.to_string());
148 }
149 }
150
151 use mockforge_core::proxy::client::ProxyClient;
153 let proxy_client = ProxyClient::new();
154
155 let reqwest_method = match method.as_str() {
157 "GET" => reqwest::Method::GET,
158 "POST" => reqwest::Method::POST,
159 "PUT" => reqwest::Method::PUT,
160 "DELETE" => reqwest::Method::DELETE,
161 "HEAD" => reqwest::Method::HEAD,
162 "OPTIONS" => reqwest::Method::OPTIONS,
163 "PATCH" => reqwest::Method::PATCH,
164 _ => {
165 error!("Unsupported HTTP method: {}", method);
166 return Err(StatusCode::METHOD_NOT_ALLOWED);
167 }
168 };
169
170 for (key, value) in &config.headers {
172 header_map.insert(key.clone(), value.clone());
173 }
174
175 let mut transformed_request_body = body_bytes.clone();
177 if !config.request_replacements.is_empty() {
178 let transform_middleware = BodyTransformationMiddleware::new(
179 config.request_replacements.clone(),
180 Vec::new(), );
182 if let Err(e) =
183 transform_middleware.transform_request_body(uri.path(), &mut transformed_request_body)
184 {
185 warn!("Failed to transform request body: {}", e);
186 }
188 }
189
190 match proxy_client
191 .send_request(
192 reqwest_method,
193 &full_upstream_url,
194 &header_map,
195 transformed_request_body.as_deref(),
196 )
197 .await
198 {
199 Ok(response) => {
200 let status = StatusCode::from_u16(response.status().as_u16())
201 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
202
203 if state.log_responses {
205 info!(
206 method = %method,
207 path = %uri.path(),
208 status = status.as_u16(),
209 "Proxy response sent"
210 );
211 }
212
213 let mut response_headers = axum::http::HeaderMap::new();
215 for (name, value) in response.headers() {
216 if let (Ok(header_name), Ok(header_value)) = (
217 axum::http::HeaderName::try_from(name.as_str()),
218 axum::http::HeaderValue::try_from(value.as_bytes()),
219 ) {
220 response_headers.insert(header_name, header_value);
221 }
222 }
223
224 let response_body_bytes = response.bytes().await.map_err(|e| {
226 error!("Failed to read proxy response body: {}", e);
227 StatusCode::BAD_GATEWAY
228 })?;
229
230 let mut final_body_bytes = response_body_bytes.to_vec();
232 {
233 let config_for_response = state.config.read().await;
234 if !config_for_response.response_replacements.is_empty() {
235 let transform_middleware = BodyTransformationMiddleware::new(
236 Vec::new(), config_for_response.response_replacements.clone(),
238 );
239 let mut body_option = Some(final_body_bytes.clone());
240 if let Err(e) = transform_middleware.transform_response_body(
241 uri.path(),
242 status.as_u16(),
243 &mut body_option,
244 ) {
245 warn!("Failed to transform response body: {}", e);
246 } else if let Some(transformed_body) = body_option {
248 final_body_bytes = transformed_body;
249 }
250 }
251 }
252
253 let body_string = String::from_utf8_lossy(&final_body_bytes).to_string();
254
255 let mut response_builder = Response::builder().status(status);
257 for (name, value) in response_headers.iter() {
258 response_builder = response_builder.header(name, value);
259 }
260
261 response_builder
262 .body(body_string)
263 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
264 }
265 Err(e) => {
266 error!("Proxy request failed: {}", e);
267 Err(StatusCode::BAD_GATEWAY)
268 }
269 }
270}
271
272async fn logging_middleware(
274 axum::extract::State(_state): axum::extract::State<Arc<ProxyServer>>,
275 request: Request,
276 next: Next,
277) -> Response {
278 let start = std::time::Instant::now();
279 let method = request.method().clone();
280 let uri = request.uri().clone();
281
282 let client_addr = request
284 .extensions()
285 .get::<SocketAddr>()
286 .copied()
287 .unwrap_or_else(|| std::net::SocketAddr::from(([0, 0, 0, 0], 0)));
288
289 debug!(
290 method = %method,
291 uri = %uri,
292 client_ip = %client_addr.ip(),
293 "Request received"
294 );
295
296 let response = next.run(request).await;
297 let duration = start.elapsed();
298
299 debug!(
300 method = %method,
301 uri = %uri,
302 status = %response.status(),
303 duration_ms = duration.as_millis(),
304 "Response sent"
305 );
306
307 response
308}
309
310#[derive(Debug, Serialize)]
312pub struct ProxyStats {
313 pub total_requests: u64,
315 pub requests_per_second: f64,
317 pub avg_response_time_ms: f64,
319 pub error_rate_percent: f64,
321}
322
323pub async fn get_proxy_stats(state: &ProxyServer) -> ProxyStats {
325 let total_requests = *state.request_counter.read().await;
326
327 ProxyStats {
330 total_requests,
331 requests_per_second: 0.0, avg_response_time_ms: 0.0, error_rate_percent: 0.0, }
335}
336
#[cfg(test)]
mod tests {
    //! Unit tests for server construction, the health endpoint and the stats snapshot.

    use super::*;
    use axum::http::StatusCode;
    use mockforge_core::proxy::config::ProxyConfig;

    #[tokio::test]
    async fn test_proxy_server_creation() {
        let config = ProxyConfig::default();
        let server = ProxyServer::new(config, true, true);

        assert!(server.log_requests);
        assert!(server.log_responses);
    }

    #[tokio::test]
    async fn test_health_check() {
        let response = health_check().await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("Content-Type").unwrap(),
            "application/json"
        );

        // `health_check` returns `Response<String>`, so the body is directly inspectable.
        let body = response.into_body();

        assert!(body.contains("healthy"));
        assert!(body.contains("mockforge-proxy"));
    }

    #[tokio::test]
    async fn test_proxy_stats() {
        let config = ProxyConfig::default();
        let server = ProxyServer::new(config, false, false);

        let stats = get_proxy_stats(&server).await;
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.requests_per_second, 0.0);
        assert_eq!(stats.avg_response_time_ms, 0.0);
        assert_eq!(stats.error_rate_percent, 0.0);
    }

    #[tokio::test]
    async fn test_proxy_stats_reflects_request_counter() {
        let server = ProxyServer::new(ProxyConfig::default(), false, false);

        *server.request_counter.write().await += 3;

        assert_eq!(get_proxy_stats(&server).await.total_requests, 3);
    }
}