1use axum::{
7 extract::Request,
8 http::{HeaderMap, Method, StatusCode, Uri},
9 middleware::Next,
10 response::Response,
11 routing::{any, get},
12 Router,
13};
14use mockforge_core::proxy::{config::ProxyConfig, handler::ProxyHandler};
15use serde::{Deserialize, Serialize};
16use std::net::SocketAddr;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::{debug, error, info, warn};
20
21#[derive(Debug)]
23pub struct ProxyServer {
24 config: Arc<RwLock<ProxyConfig>>,
26 log_requests: bool,
28 log_responses: bool,
30 request_counter: Arc<RwLock<u64>>,
32}
33
34impl ProxyServer {
35 pub fn new(config: ProxyConfig, log_requests: bool, log_responses: bool) -> Self {
37 Self {
38 config: Arc::new(RwLock::new(config)),
39 log_requests,
40 log_responses,
41 request_counter: Arc::new(RwLock::new(0)),
42 }
43 }
44
45 pub fn router(self) -> Router {
47 let state = Arc::new(self);
48 let state_for_middleware = state.clone();
49
50 Router::new()
51 .route("/proxy/health", get(health_check))
53 .fallback(proxy_handler)
55 .with_state(state)
56 .layer(axum::middleware::from_fn_with_state(state_for_middleware, logging_middleware))
57 }
58}
59
60async fn health_check() -> Result<Response<String>, StatusCode> {
62 Response::builder()
64 .status(StatusCode::OK)
65 .header("Content-Type", "application/json")
66 .body(r#"{"status":"healthy","service":"mockforge-proxy"}"#.to_string())
67 .map_err(|e| {
68 tracing::error!("Failed to build health check response: {}", e);
69 StatusCode::INTERNAL_SERVER_ERROR
70 })
71}
72
73async fn proxy_handler(
75 axum::extract::State(state): axum::extract::State<Arc<ProxyServer>>,
76 request: axum::http::Request<axum::body::Body>,
77) -> Result<Response<String>, StatusCode> {
78 let client_addr = request
80 .extensions()
81 .get::<SocketAddr>()
82 .copied()
83 .unwrap_or_else(|| std::net::SocketAddr::from(([0, 0, 0, 0], 0)));
84
85 let method = request.method().clone();
86 let uri = request.uri().clone();
87 let headers = request.headers().clone();
88
89 let config = state.config.read().await;
90
91 if !config.enabled {
93 return Err(StatusCode::SERVICE_UNAVAILABLE);
94 }
95
96 if !config.should_proxy(&method, uri.path()) {
98 return Err(StatusCode::NOT_FOUND);
99 }
100
101 let stripped_path = config.strip_prefix(uri.path());
103
104 let base_upstream_url = config.get_upstream_url(uri.path());
106 let full_upstream_url =
107 if stripped_path.starts_with("http://") || stripped_path.starts_with("https://") {
108 stripped_path.clone()
109 } else {
110 let base = base_upstream_url.trim_end_matches('/');
111 let path = stripped_path.trim_start_matches('/');
112 let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
113 if path.is_empty() || path == "/" {
114 format!("{}{}", base, query)
115 } else {
116 format!("{}/{}", base, path) + &query
117 }
118 };
119
120 let modified_uri = full_upstream_url.parse::<axum::http::Uri>().unwrap_or_else(|_| uri.clone());
122
123 if state.log_requests {
125 let mut counter = state.request_counter.write().await;
126 *counter += 1;
127 let request_id = *counter;
128
129 info!(
130 request_id = request_id,
131 method = %method,
132 path = %uri.path(),
133 upstream = %full_upstream_url,
134 client_ip = %client_addr.ip(),
135 "Proxy request intercepted"
136 );
137 }
138
139 let mut header_map = std::collections::HashMap::new();
141 for (key, value) in &headers {
142 if let Ok(value_str) = value.to_str() {
143 header_map.insert(key.to_string(), value_str.to_string());
144 }
145 }
146
147 let body_bytes = match axum::body::to_bytes(request.into_body(), usize::MAX).await {
149 Ok(bytes) => Some(bytes.to_vec()),
150 Err(e) => {
151 error!("Failed to read request body: {}", e);
152 None
153 }
154 };
155
156 use mockforge_core::proxy::client::ProxyClient;
158 let proxy_client = ProxyClient::new();
159
160 let mut header_map = std::collections::HashMap::new();
162 for (key, value) in &headers {
163 if let Ok(value_str) = value.to_str() {
164 header_map.insert(key.to_string(), value_str.to_string());
165 }
166 }
167
168 let reqwest_method = match method.as_str() {
170 "GET" => reqwest::Method::GET,
171 "POST" => reqwest::Method::POST,
172 "PUT" => reqwest::Method::PUT,
173 "DELETE" => reqwest::Method::DELETE,
174 "HEAD" => reqwest::Method::HEAD,
175 "OPTIONS" => reqwest::Method::OPTIONS,
176 "PATCH" => reqwest::Method::PATCH,
177 _ => {
178 error!("Unsupported HTTP method: {}", method);
179 return Err(StatusCode::METHOD_NOT_ALLOWED);
180 }
181 };
182
183 for (key, value) in &config.headers {
185 header_map.insert(key.clone(), value.clone());
186 }
187
188 match proxy_client
189 .send_request(reqwest_method, &full_upstream_url, &header_map, body_bytes.as_deref())
190 .await
191 {
192 Ok(response) => {
193 let status = StatusCode::from_u16(response.status().as_u16())
194 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
195
196 if state.log_responses {
198 info!(
199 method = %method,
200 path = %uri.path(),
201 status = status.as_u16(),
202 "Proxy response sent"
203 );
204 }
205
206 let mut response_headers = axum::http::HeaderMap::new();
208 for (name, value) in response.headers() {
209 if let (Ok(header_name), Ok(header_value)) = (
210 axum::http::HeaderName::try_from(name.as_str()),
211 axum::http::HeaderValue::try_from(value.as_bytes()),
212 ) {
213 response_headers.insert(header_name, header_value);
214 }
215 }
216
217 let body_bytes = response.bytes().await.map_err(|e| {
219 error!("Failed to read proxy response body: {}", e);
220 StatusCode::BAD_GATEWAY
221 })?;
222
223 let body_string = String::from_utf8_lossy(&body_bytes).to_string();
224
225 let mut response_builder = Response::builder().status(status);
227 for (name, value) in response_headers.iter() {
228 response_builder = response_builder.header(name, value);
229 }
230
231 response_builder
232 .body(body_string)
233 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
234 }
235 Err(e) => {
236 error!("Proxy request failed: {}", e);
237 Err(StatusCode::BAD_GATEWAY)
238 }
239 }
240}
241
242async fn logging_middleware(
244 axum::extract::State(_state): axum::extract::State<Arc<ProxyServer>>,
245 request: Request,
246 next: Next,
247) -> Response {
248 let start = std::time::Instant::now();
249 let method = request.method().clone();
250 let uri = request.uri().clone();
251
252 let client_addr = request
254 .extensions()
255 .get::<SocketAddr>()
256 .copied()
257 .unwrap_or_else(|| std::net::SocketAddr::from(([0, 0, 0, 0], 0)));
258
259 debug!(
260 method = %method,
261 uri = %uri,
262 client_ip = %client_addr.ip(),
263 "Request received"
264 );
265
266 let response = next.run(request).await;
267 let duration = start.elapsed();
268
269 debug!(
270 method = %method,
271 uri = %uri,
272 status = %response.status(),
273 duration_ms = duration.as_millis(),
274 "Response sent"
275 );
276
277 response
278}
279
280#[derive(Debug, Serialize)]
282pub struct ProxyStats {
283 pub total_requests: u64,
285 pub requests_per_second: f64,
287 pub avg_response_time_ms: f64,
289 pub error_rate_percent: f64,
291}
292
293pub async fn get_proxy_stats(state: &ProxyServer) -> ProxyStats {
295 let total_requests = *state.request_counter.read().await;
296
297 ProxyStats {
300 total_requests,
301 requests_per_second: 0.0, avg_response_time_ms: 0.0, error_rate_percent: 0.0, }
305}
306
#[cfg(test)]
mod tests {
    // The glob also brings in the parent's imports (`ProxyConfig`, `StatusCode`, ...), so they
    // are not re-imported here. The previous explicit `SocketAddr` import was unused.
    use super::*;

    #[tokio::test]
    async fn test_proxy_server_creation() {
        let config = ProxyConfig::default();
        let server = ProxyServer::new(config, true, true);

        assert!(server.log_requests);
        assert!(server.log_responses);
    }

    #[tokio::test]
    async fn test_health_check() {
        let response = health_check().await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response
                .headers()
                .get(axum::http::header::CONTENT_TYPE)
                .and_then(|value| value.to_str().ok()),
            Some("application/json")
        );

        let body = response.into_body();
        assert!(body.contains("healthy"));
        assert!(body.contains("mockforge-proxy"));
    }

    #[tokio::test]
    async fn test_proxy_stats() {
        let config = ProxyConfig::default();
        let server = ProxyServer::new(config, false, false);

        let stats = get_proxy_stats(&server).await;
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.requests_per_second, 0.0);
        assert_eq!(stats.avg_response_time_ms, 0.0);
        assert_eq!(stats.error_rate_percent, 0.0);
    }

    #[tokio::test]
    async fn test_proxy_stats_reports_request_counter() {
        let server = ProxyServer::new(ProxyConfig::default(), false, false);
        *server.request_counter.write().await = 3;

        let stats = get_proxy_stats(&server).await;
        assert_eq!(stats.total_requests, 3);
    }
}