1use axum::{
7 extract::Request,
8 http::{HeaderMap, Method, StatusCode, Uri},
9 middleware::Next,
10 response::Response,
11 routing::{any, get},
12 Router,
13};
14use mockforge_core::proxy::{
15 body_transform::BodyTransformationMiddleware, config::ProxyConfig, handler::ProxyHandler,
16};
17use serde::{Deserialize, Serialize};
18use std::net::SocketAddr;
19use std::sync::Arc;
20use tokio::sync::RwLock;
21use tracing::{debug, error, info, warn};
22
23#[derive(Debug)]
25pub struct ProxyServer {
26 config: Arc<RwLock<ProxyConfig>>,
28 log_requests: bool,
30 log_responses: bool,
32 request_counter: Arc<RwLock<u64>>,
34}
35
36impl ProxyServer {
37 pub fn new(config: ProxyConfig, log_requests: bool, log_responses: bool) -> Self {
39 Self {
40 config: Arc::new(RwLock::new(config)),
41 log_requests,
42 log_responses,
43 request_counter: Arc::new(RwLock::new(0)),
44 }
45 }
46
47 pub fn router(self) -> Router {
49 let state = Arc::new(self);
50 let state_for_middleware = state.clone();
51
52 Router::new()
53 .route("/proxy/health", get(health_check))
55 .fallback(proxy_handler)
57 .with_state(state)
58 .layer(axum::middleware::from_fn_with_state(state_for_middleware, logging_middleware))
59 }
60}
61
62async fn health_check() -> Result<Response<String>, StatusCode> {
64 Response::builder()
66 .status(StatusCode::OK)
67 .header("Content-Type", "application/json")
68 .body(r#"{"status":"healthy","service":"mockforge-proxy"}"#.to_string())
69 .map_err(|e| {
70 tracing::error!("Failed to build health check response: {}", e);
71 StatusCode::INTERNAL_SERVER_ERROR
72 })
73}
74
75async fn proxy_handler(
77 axum::extract::State(state): axum::extract::State<Arc<ProxyServer>>,
78 request: axum::http::Request<axum::body::Body>,
79) -> Result<Response<String>, StatusCode> {
80 let client_addr = request
82 .extensions()
83 .get::<SocketAddr>()
84 .copied()
85 .unwrap_or_else(|| std::net::SocketAddr::from(([0, 0, 0, 0], 0)));
86
87 let method = request.method().clone();
88 let uri = request.uri().clone();
89 let headers = request.headers().clone();
90
91 let body_bytes = match axum::body::to_bytes(request.into_body(), usize::MAX).await {
93 Ok(bytes) => Some(bytes.to_vec()),
94 Err(e) => {
95 error!("Failed to read request body: {}", e);
96 None
97 }
98 };
99
100 let config = state.config.read().await;
101
102 if !config.enabled {
104 return Err(StatusCode::SERVICE_UNAVAILABLE);
105 }
106
107 if !config.should_proxy_with_condition(&method, &uri, &headers, body_bytes.as_deref()) {
109 return Err(StatusCode::NOT_FOUND);
110 }
111
112 let stripped_path = config.strip_prefix(uri.path());
114
115 let base_upstream_url = config.get_upstream_url(uri.path());
117 let full_upstream_url =
118 if stripped_path.starts_with("http://") || stripped_path.starts_with("https://") {
119 stripped_path.clone()
120 } else {
121 let base = base_upstream_url.trim_end_matches('/');
122 let path = stripped_path.trim_start_matches('/');
123 let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
124 if path.is_empty() || path == "/" {
125 format!("{}{}", base, query)
126 } else {
127 format!("{}/{}", base, path) + &query
128 }
129 };
130
131 let modified_uri = full_upstream_url.parse::<axum::http::Uri>().unwrap_or_else(|_| uri.clone());
133
134 if state.log_requests {
136 let mut counter = state.request_counter.write().await;
137 *counter += 1;
138 let request_id = *counter;
139
140 info!(
141 request_id = request_id,
142 method = %method,
143 path = %uri.path(),
144 upstream = %full_upstream_url,
145 client_ip = %client_addr.ip(),
146 "Proxy request intercepted"
147 );
148 }
149
150 let mut header_map = std::collections::HashMap::new();
152 for (key, value) in &headers {
153 if let Ok(value_str) = value.to_str() {
154 header_map.insert(key.to_string(), value_str.to_string());
155 }
156 }
157
158 use mockforge_core::proxy::client::ProxyClient;
160 let proxy_client = ProxyClient::new();
161
162 let reqwest_method = match method.as_str() {
164 "GET" => reqwest::Method::GET,
165 "POST" => reqwest::Method::POST,
166 "PUT" => reqwest::Method::PUT,
167 "DELETE" => reqwest::Method::DELETE,
168 "HEAD" => reqwest::Method::HEAD,
169 "OPTIONS" => reqwest::Method::OPTIONS,
170 "PATCH" => reqwest::Method::PATCH,
171 _ => {
172 error!("Unsupported HTTP method: {}", method);
173 return Err(StatusCode::METHOD_NOT_ALLOWED);
174 }
175 };
176
177 for (key, value) in &config.headers {
179 header_map.insert(key.clone(), value.clone());
180 }
181
182 let mut transformed_request_body = body_bytes.clone();
184 if !config.request_replacements.is_empty() {
185 let transform_middleware = BodyTransformationMiddleware::new(
186 config.request_replacements.clone(),
187 Vec::new(), );
189 if let Err(e) =
190 transform_middleware.transform_request_body(uri.path(), &mut transformed_request_body)
191 {
192 warn!("Failed to transform request body: {}", e);
193 }
195 }
196
197 match proxy_client
198 .send_request(
199 reqwest_method,
200 &full_upstream_url,
201 &header_map,
202 transformed_request_body.as_deref(),
203 )
204 .await
205 {
206 Ok(response) => {
207 let status = StatusCode::from_u16(response.status().as_u16())
208 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
209
210 if state.log_responses {
212 info!(
213 method = %method,
214 path = %uri.path(),
215 status = status.as_u16(),
216 "Proxy response sent"
217 );
218 }
219
220 let mut response_headers = axum::http::HeaderMap::new();
222 for (name, value) in response.headers() {
223 if let (Ok(header_name), Ok(header_value)) = (
224 axum::http::HeaderName::try_from(name.as_str()),
225 axum::http::HeaderValue::try_from(value.as_bytes()),
226 ) {
227 response_headers.insert(header_name, header_value);
228 }
229 }
230
231 let response_body_bytes = response.bytes().await.map_err(|e| {
233 error!("Failed to read proxy response body: {}", e);
234 StatusCode::BAD_GATEWAY
235 })?;
236
237 let mut final_body_bytes = response_body_bytes.to_vec();
239 {
240 let config_for_response = state.config.read().await;
241 if !config_for_response.response_replacements.is_empty() {
242 let transform_middleware = BodyTransformationMiddleware::new(
243 Vec::new(), config_for_response.response_replacements.clone(),
245 );
246 let mut body_option = Some(final_body_bytes.clone());
247 if let Err(e) = transform_middleware.transform_response_body(
248 uri.path(),
249 status.as_u16(),
250 &mut body_option,
251 ) {
252 warn!("Failed to transform response body: {}", e);
253 } else if let Some(transformed_body) = body_option {
255 final_body_bytes = transformed_body;
256 }
257 }
258 }
259
260 let body_string = String::from_utf8_lossy(&final_body_bytes).to_string();
261
262 let mut response_builder = Response::builder().status(status);
264 for (name, value) in response_headers.iter() {
265 response_builder = response_builder.header(name, value);
266 }
267
268 response_builder
269 .body(body_string)
270 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
271 }
272 Err(e) => {
273 error!("Proxy request failed: {}", e);
274 Err(StatusCode::BAD_GATEWAY)
275 }
276 }
277}
278
279async fn logging_middleware(
281 axum::extract::State(_state): axum::extract::State<Arc<ProxyServer>>,
282 request: Request,
283 next: Next,
284) -> Response {
285 let start = std::time::Instant::now();
286 let method = request.method().clone();
287 let uri = request.uri().clone();
288
289 let client_addr = request
291 .extensions()
292 .get::<SocketAddr>()
293 .copied()
294 .unwrap_or_else(|| std::net::SocketAddr::from(([0, 0, 0, 0], 0)));
295
296 debug!(
297 method = %method,
298 uri = %uri,
299 client_ip = %client_addr.ip(),
300 "Request received"
301 );
302
303 let response = next.run(request).await;
304 let duration = start.elapsed();
305
306 debug!(
307 method = %method,
308 uri = %uri,
309 status = %response.status(),
310 duration_ms = duration.as_millis(),
311 "Response sent"
312 );
313
314 response
315}
316
317#[derive(Debug, Serialize)]
319pub struct ProxyStats {
320 pub total_requests: u64,
322 pub requests_per_second: f64,
324 pub avg_response_time_ms: f64,
326 pub error_rate_percent: f64,
328}
329
330pub async fn get_proxy_stats(state: &ProxyServer) -> ProxyStats {
332 let total_requests = *state.request_counter.read().await;
333
334 ProxyStats {
337 total_requests,
338 requests_per_second: 0.0, avg_response_time_ms: 0.0, error_rate_percent: 0.0, }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use mockforge_core::proxy::config::ProxyConfig;

    /// The constructor stores both logging switches as given.
    #[tokio::test]
    async fn test_proxy_server_creation() {
        let proxy = ProxyServer::new(ProxyConfig::default(), true, true);

        assert!(proxy.log_requests);
        assert!(proxy.log_responses);
    }

    /// The health endpoint answers 200 with a JSON body naming the service.
    #[tokio::test]
    async fn test_health_check() {
        let reply = health_check()
            .await
            .expect("health check response should build");
        assert_eq!(reply.status(), StatusCode::OK);

        let text = reply.into_body();
        for needle in ["healthy", "mockforge-proxy"] {
            assert!(text.contains(needle), "body {text:?} is missing {needle:?}");
        }
    }

    /// A fresh server reports zero requests and zeroed placeholder metrics.
    #[tokio::test]
    async fn test_proxy_stats() {
        let proxy = ProxyServer::new(ProxyConfig::default(), false, false);

        let snapshot = get_proxy_stats(&proxy).await;
        assert_eq!(snapshot.total_requests, 0);
        for metric in [
            snapshot.requests_per_second,
            snapshot.avg_response_time_ms,
            snapshot.error_rate_percent,
        ] {
            assert_eq!(metric, 0.0);
        }
    }
}