1use axum::{
7 extract::Request, http::StatusCode, middleware::Next, response::Response, routing::get, Router,
8};
9use mockforge_core::proxy::{body_transform::BodyTransformationMiddleware, config::ProxyConfig};
10use serde::Serialize;
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, error, info, warn};
15
/// Shared state for the standalone HTTP proxy: the live configuration, the logging switches,
/// and the counters that `get_proxy_stats` turns into a `ProxyStats` snapshot.
///
/// `ProxyServer::router` moves the server into an `Arc`, which is shared by the proxy fallback
/// handler and the logging middleware.
#[derive(Debug)]
pub struct ProxyServer {
    /// Proxy configuration behind an async `RwLock`; the proxy handler reads it per request.
    config: Arc<RwLock<ProxyConfig>>,
    /// When `true`, the proxy handler emits an `info!` line for each intercepted request.
    log_requests: bool,
    /// When `true`, the proxy handler emits an `info!` line for each upstream response.
    log_responses: bool,
    /// Request count reported as `total_requests` by `get_proxy_stats`; incremented by the
    /// proxy handler.
    request_counter: Arc<RwLock<u64>>,
    /// Creation time of the server; `get_proxy_stats` uses it to compute requests per second.
    start_time: std::time::Instant,
    /// Sum of per-request wall-clock durations in whole milliseconds, accumulated by the
    /// logging middleware.
    total_response_time_ms: Arc<RwLock<u64>>,
    /// Number of responses with a 5xx status, counted by the logging middleware.
    error_counter: Arc<RwLock<u64>>,
}
34
35impl ProxyServer {
36 pub fn new(config: ProxyConfig, log_requests: bool, log_responses: bool) -> Self {
38 Self {
39 config: Arc::new(RwLock::new(config)),
40 log_requests,
41 log_responses,
42 request_counter: Arc::new(RwLock::new(0)),
43 start_time: std::time::Instant::now(),
44 total_response_time_ms: Arc::new(RwLock::new(0)),
45 error_counter: Arc::new(RwLock::new(0)),
46 }
47 }
48
49 pub fn router(self) -> Router {
51 let state = Arc::new(self);
52 let state_for_middleware = state.clone();
53
54 Router::new()
55 .route("/proxy/health", get(health_check))
57 .fallback(proxy_handler)
59 .with_state(state)
60 .layer(axum::middleware::from_fn_with_state(state_for_middleware, logging_middleware))
61 }
62}
63
64async fn health_check() -> Result<Response<String>, StatusCode> {
66 Response::builder()
68 .status(StatusCode::OK)
69 .header("Content-Type", "application/json")
70 .body(r#"{"status":"healthy","service":"mockforge-proxy"}"#.to_string())
71 .map_err(|e| {
72 tracing::error!("Failed to build health check response: {}", e);
73 StatusCode::INTERNAL_SERVER_ERROR
74 })
75}
76
77async fn proxy_handler(
79 axum::extract::State(state): axum::extract::State<Arc<ProxyServer>>,
80 request: http::Request<axum::body::Body>,
81) -> Result<Response<String>, StatusCode> {
82 let client_addr = request
84 .extensions()
85 .get::<SocketAddr>()
86 .copied()
87 .unwrap_or_else(|| SocketAddr::from(([0, 0, 0, 0], 0)));
88
89 let method = request.method().clone();
90 let uri = request.uri().clone();
91 let headers = request.headers().clone();
92
93 let body_bytes = match axum::body::to_bytes(request.into_body(), usize::MAX).await {
95 Ok(bytes) => Some(bytes.to_vec()),
96 Err(e) => {
97 error!("Failed to read request body: {}", e);
98 None
99 }
100 };
101
102 let config = state.config.read().await;
103
104 if !config.enabled {
106 return Err(StatusCode::SERVICE_UNAVAILABLE);
107 }
108
109 if !config.should_proxy_with_condition(&method, &uri, &headers, body_bytes.as_deref()) {
111 return Err(StatusCode::NOT_FOUND);
112 }
113
114 let stripped_path = config.strip_prefix(uri.path());
116
117 let base_upstream_url = config.get_upstream_url(uri.path());
119 let full_upstream_url =
120 if stripped_path.starts_with("http://") || stripped_path.starts_with("https://") {
121 stripped_path.clone()
122 } else {
123 let base = base_upstream_url.trim_end_matches('/');
124 let path = stripped_path.trim_start_matches('/');
125 let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
126 if path.is_empty() || path == "/" {
127 format!("{}{}", base, query)
128 } else {
129 format!("{}/{}", base, path) + &query
130 }
131 };
132
133 let _modified_uri = full_upstream_url.parse::<http::Uri>().unwrap_or_else(|_| uri.clone());
135
136 if state.log_requests {
138 let mut counter = state.request_counter.write().await;
139 *counter += 1;
140 let request_id = *counter;
141
142 info!(
143 request_id = request_id,
144 method = %method,
145 path = %uri.path(),
146 upstream = %full_upstream_url,
147 client_ip = %client_addr.ip(),
148 "Proxy request intercepted"
149 );
150 }
151
152 let mut header_map = std::collections::HashMap::new();
154 for (key, value) in &headers {
155 if let Ok(value_str) = value.to_str() {
156 header_map.insert(key.to_string(), value_str.to_string());
157 }
158 }
159
160 use mockforge_core::proxy::client::ProxyClient;
162 let proxy_client = ProxyClient::new();
163
164 let reqwest_method = match method.as_str() {
166 "GET" => reqwest::Method::GET,
167 "POST" => reqwest::Method::POST,
168 "PUT" => reqwest::Method::PUT,
169 "DELETE" => reqwest::Method::DELETE,
170 "HEAD" => reqwest::Method::HEAD,
171 "OPTIONS" => reqwest::Method::OPTIONS,
172 "PATCH" => reqwest::Method::PATCH,
173 _ => {
174 error!("Unsupported HTTP method: {}", method);
175 return Err(StatusCode::METHOD_NOT_ALLOWED);
176 }
177 };
178
179 for (key, value) in &config.headers {
181 header_map.insert(key.clone(), value.clone());
182 }
183
184 let mut transformed_request_body = body_bytes.clone();
186 if !config.request_replacements.is_empty() {
187 let transform_middleware = BodyTransformationMiddleware::new(
188 config.request_replacements.clone(),
189 Vec::new(), );
191 if let Err(e) =
192 transform_middleware.transform_request_body(uri.path(), &mut transformed_request_body)
193 {
194 warn!("Failed to transform request body: {}", e);
195 }
197 }
198
199 match proxy_client
200 .send_request(
201 reqwest_method,
202 &full_upstream_url,
203 &header_map,
204 transformed_request_body.as_deref(),
205 )
206 .await
207 {
208 Ok(response) => {
209 let status = StatusCode::from_u16(response.status().as_u16())
210 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
211
212 if state.log_responses {
214 info!(
215 method = %method,
216 path = %uri.path(),
217 status = status.as_u16(),
218 "Proxy response sent"
219 );
220 }
221
222 let mut response_headers = http::HeaderMap::new();
224 for (name, value) in response.headers() {
225 if let (Ok(header_name), Ok(header_value)) = (
226 http::HeaderName::try_from(name.as_str()),
227 http::HeaderValue::try_from(value.as_bytes()),
228 ) {
229 response_headers.insert(header_name, header_value);
230 }
231 }
232
233 let response_body_bytes = response.bytes().await.map_err(|e| {
235 error!("Failed to read proxy response body: {}", e);
236 StatusCode::BAD_GATEWAY
237 })?;
238
239 let mut final_body_bytes = response_body_bytes.to_vec();
241 {
242 let config_for_response = state.config.read().await;
243 if !config_for_response.response_replacements.is_empty() {
244 let transform_middleware = BodyTransformationMiddleware::new(
245 Vec::new(), config_for_response.response_replacements.clone(),
247 );
248 let mut body_option = Some(final_body_bytes.clone());
249 if let Err(e) = transform_middleware.transform_response_body(
250 uri.path(),
251 status.as_u16(),
252 &mut body_option,
253 ) {
254 warn!("Failed to transform response body: {}", e);
255 } else if let Some(transformed_body) = body_option {
257 final_body_bytes = transformed_body;
258 }
259 }
260 }
261
262 let body_string = String::from_utf8_lossy(&final_body_bytes).to_string();
263
264 let mut response_builder = Response::builder().status(status);
266 for (name, value) in response_headers.iter() {
267 response_builder = response_builder.header(name, value);
268 }
269
270 response_builder
271 .body(body_string)
272 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
273 }
274 Err(e) => {
275 error!("Proxy request failed: {}", e);
276 Err(StatusCode::BAD_GATEWAY)
277 }
278 }
279}
280
281async fn logging_middleware(
283 axum::extract::State(state): axum::extract::State<Arc<ProxyServer>>,
284 request: Request,
285 next: Next,
286) -> Response {
287 let start = std::time::Instant::now();
288 let method = request.method().clone();
289 let uri = request.uri().clone();
290
291 let client_addr = request
293 .extensions()
294 .get::<SocketAddr>()
295 .copied()
296 .unwrap_or_else(|| SocketAddr::from(([0, 0, 0, 0], 0)));
297
298 debug!(
299 method = %method,
300 uri = %uri,
301 client_ip = %client_addr.ip(),
302 "Request received"
303 );
304
305 let response = next.run(request).await;
306 let duration = start.elapsed();
307
308 {
310 let mut total_time = state.total_response_time_ms.write().await;
311 *total_time += duration.as_millis() as u64;
312 }
313
314 if response.status().is_server_error() {
316 let mut errors = state.error_counter.write().await;
317 *errors += 1;
318 }
319
320 debug!(
321 method = %method,
322 uri = %uri,
323 status = %response.status(),
324 duration_ms = duration.as_millis(),
325 "Response sent"
326 );
327
328 response
329}
330
/// Point-in-time proxy statistics produced by `get_proxy_stats`; serializable via serde.
#[derive(Debug, Serialize)]
pub struct ProxyStats {
    /// Current value of the server's request counter.
    pub total_requests: u64,
    /// `total_requests` divided by the seconds elapsed since the server was created
    /// (`0.0` when no time has elapsed).
    pub requests_per_second: f64,
    /// Accumulated response time in milliseconds divided by `total_requests`
    /// (`0.0` when there are no requests).
    pub avg_response_time_ms: f64,
    /// 5xx responses as a percentage of `total_requests` (`0.0` when there are no requests).
    pub error_rate_percent: f64,
}
343
344pub async fn get_proxy_stats(state: &ProxyServer) -> ProxyStats {
346 let total_requests = *state.request_counter.read().await;
347 let total_response_time_ms = *state.total_response_time_ms.read().await;
348 let error_count = *state.error_counter.read().await;
349
350 let elapsed_secs = state.start_time.elapsed().as_secs_f64();
351 let requests_per_second = if elapsed_secs > 0.0 {
352 total_requests as f64 / elapsed_secs
353 } else {
354 0.0
355 };
356
357 let avg_response_time_ms = if total_requests > 0 {
358 total_response_time_ms as f64 / total_requests as f64
359 } else {
360 0.0
361 };
362
363 let error_rate_percent = if total_requests > 0 {
364 (error_count as f64 / total_requests as f64) * 100.0
365 } else {
366 0.0
367 };
368
369 ProxyStats {
370 total_requests,
371 requests_per_second,
372 avg_response_time_ms,
373 error_rate_percent,
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use mockforge_core::proxy::config::ProxyConfig;

    /// Builds a server from the default config with the given logging switches.
    fn server_with_logging(log_requests: bool, log_responses: bool) -> ProxyServer {
        ProxyServer::new(ProxyConfig::default(), log_requests, log_responses)
    }

    #[tokio::test]
    async fn test_proxy_server_creation() {
        let server = server_with_logging(true, true);

        assert!(server.log_requests);
        assert!(server.log_responses);
    }

    #[tokio::test]
    async fn test_health_check() {
        let response = health_check()
            .await
            .expect("health check should build a response");
        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body();
        for needle in ["healthy", "mockforge-proxy"] {
            assert!(body.contains(needle), "health body {body:?} is missing {needle:?}");
        }
    }

    #[tokio::test]
    async fn test_proxy_stats() {
        let server = server_with_logging(false, false);

        let stats = get_proxy_stats(&server).await;

        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.avg_response_time_ms, 0.0);
        assert_eq!(stats.error_rate_percent, 0.0);
    }
}