Skip to main content

mockforge_http/
proxy_server.rs

1//! Browser/Mobile Proxy Server
2//!
3//! Provides an intercepting proxy for frontend/mobile clients with HTTPS support,
4//! certificate injection, and comprehensive request/response logging.
5
6use axum::{
7    extract::Request, http::StatusCode, middleware::Next, response::Response, routing::get, Router,
8};
9use mockforge_core::proxy::{body_transform::BodyTransformationMiddleware, config::ProxyConfig};
10use serde::Serialize;
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, error, info, warn};
15
16/// Proxy server state
17#[derive(Debug)]
18pub struct ProxyServer {
19    /// Proxy configuration
20    config: Arc<RwLock<ProxyConfig>>,
21    /// Request logging enabled
22    log_requests: bool,
23    /// Response logging enabled
24    log_responses: bool,
25    /// Request counter for logging
26    request_counter: Arc<RwLock<u64>>,
27    /// Server start time for uptime and rate calculations
28    start_time: std::time::Instant,
29    /// Total response time in milliseconds for average calculation
30    total_response_time_ms: Arc<RwLock<u64>>,
31    /// Error counter for error rate calculation
32    error_counter: Arc<RwLock<u64>>,
33}
34
35impl ProxyServer {
36    /// Create a new proxy server
37    pub fn new(config: ProxyConfig, log_requests: bool, log_responses: bool) -> Self {
38        Self {
39            config: Arc::new(RwLock::new(config)),
40            log_requests,
41            log_responses,
42            request_counter: Arc::new(RwLock::new(0)),
43            start_time: std::time::Instant::now(),
44            total_response_time_ms: Arc::new(RwLock::new(0)),
45            error_counter: Arc::new(RwLock::new(0)),
46        }
47    }
48
49    /// Get the Axum router for the proxy server
50    pub fn router(self) -> Router {
51        let state = Arc::new(self);
52        let state_for_middleware = state.clone();
53
54        Router::new()
55            // Health check endpoint
56            .route("/proxy/health", get(health_check))
57            // Catch-all proxy handler - use fallback for all methods
58            .fallback(proxy_handler)
59            .with_state(state)
60            .layer(axum::middleware::from_fn_with_state(state_for_middleware, logging_middleware))
61    }
62}
63
64/// Health check endpoint for the proxy
65async fn health_check() -> Result<Response<String>, StatusCode> {
66    // Response builder should never fail with known-good values, but handle errors gracefully
67    Response::builder()
68        .status(StatusCode::OK)
69        .header("Content-Type", "application/json")
70        .body(r#"{"status":"healthy","service":"mockforge-proxy"}"#.to_string())
71        .map_err(|e| {
72            tracing::error!("Failed to build health check response: {}", e);
73            StatusCode::INTERNAL_SERVER_ERROR
74        })
75}
76
77/// Main proxy handler that intercepts and forwards requests
78async fn proxy_handler(
79    axum::extract::State(state): axum::extract::State<Arc<ProxyServer>>,
80    request: http::Request<axum::body::Body>,
81) -> Result<Response<String>, StatusCode> {
82    // Extract client address from request extensions (set by ConnectInfo middleware)
83    let client_addr = request
84        .extensions()
85        .get::<SocketAddr>()
86        .copied()
87        .unwrap_or_else(|| SocketAddr::from(([0, 0, 0, 0], 0)));
88
89    let method = request.method().clone();
90    let uri = request.uri().clone();
91    let headers = request.headers().clone();
92
93    // Read request body early for conditional evaluation (consume the body)
94    let body_bytes = match axum::body::to_bytes(request.into_body(), usize::MAX).await {
95        Ok(bytes) => Some(bytes.to_vec()),
96        Err(e) => {
97            error!("Failed to read request body: {}", e);
98            None
99        }
100    };
101
102    let config = state.config.read().await;
103
104    // Check if proxy is enabled
105    if !config.enabled {
106        return Err(StatusCode::SERVICE_UNAVAILABLE);
107    }
108
109    // Determine if this request should be proxied (with conditional evaluation)
110    if !config.should_proxy_with_condition(&method, &uri, &headers, body_bytes.as_deref()) {
111        return Err(StatusCode::NOT_FOUND);
112    }
113
114    // Get the stripped path (without proxy prefix)
115    let stripped_path = config.strip_prefix(uri.path());
116
117    // Get the base upstream URL and construct the full URL
118    let base_upstream_url = config.get_upstream_url(uri.path());
119    let full_upstream_url =
120        if stripped_path.starts_with("http://") || stripped_path.starts_with("https://") {
121            stripped_path.clone()
122        } else {
123            let base = base_upstream_url.trim_end_matches('/');
124            let path = stripped_path.trim_start_matches('/');
125            let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
126            if path.is_empty() || path == "/" {
127                format!("{}{}", base, query)
128            } else {
129                format!("{}/{}", base, path) + &query
130            }
131        };
132
133    // Create a new URI with the full upstream URL for the proxy handler
134    let _modified_uri = full_upstream_url.parse::<http::Uri>().unwrap_or_else(|_| uri.clone());
135
136    // Log the request if enabled
137    if state.log_requests {
138        let mut counter = state.request_counter.write().await;
139        *counter += 1;
140        let request_id = *counter;
141
142        info!(
143            request_id = request_id,
144            method = %method,
145            path = %uri.path(),
146            upstream = %full_upstream_url,
147            client_ip = %client_addr.ip(),
148            "Proxy request intercepted"
149        );
150    }
151
152    // Convert headers to HashMap for the proxy handler
153    let mut header_map = std::collections::HashMap::new();
154    for (key, value) in &headers {
155        if let Ok(value_str) = value.to_str() {
156            header_map.insert(key.to_string(), value_str.to_string());
157        }
158    }
159
160    // Use ProxyClient directly with the full upstream URL to bypass ProxyHandler's URL construction
161    use mockforge_core::proxy::client::ProxyClient;
162    let proxy_client = ProxyClient::new();
163
164    // Convert method to reqwest method
165    let reqwest_method = match method.as_str() {
166        "GET" => reqwest::Method::GET,
167        "POST" => reqwest::Method::POST,
168        "PUT" => reqwest::Method::PUT,
169        "DELETE" => reqwest::Method::DELETE,
170        "HEAD" => reqwest::Method::HEAD,
171        "OPTIONS" => reqwest::Method::OPTIONS,
172        "PATCH" => reqwest::Method::PATCH,
173        _ => {
174            error!("Unsupported HTTP method: {}", method);
175            return Err(StatusCode::METHOD_NOT_ALLOWED);
176        }
177    };
178
179    // Add any configured headers
180    for (key, value) in &config.headers {
181        header_map.insert(key.clone(), value.clone());
182    }
183
184    // Apply request body transformations if configured
185    let mut transformed_request_body = body_bytes.clone();
186    if !config.request_replacements.is_empty() {
187        let transform_middleware = BodyTransformationMiddleware::new(
188            config.request_replacements.clone(),
189            Vec::new(), // No response rules needed here
190        );
191        if let Err(e) =
192            transform_middleware.transform_request_body(uri.path(), &mut transformed_request_body)
193        {
194            warn!("Failed to transform request body: {}", e);
195            // Continue with original body if transformation fails
196        }
197    }
198
199    match proxy_client
200        .send_request(
201            reqwest_method,
202            &full_upstream_url,
203            &header_map,
204            transformed_request_body.as_deref(),
205        )
206        .await
207    {
208        Ok(response) => {
209            let status = StatusCode::from_u16(response.status().as_u16())
210                .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
211
212            // Log the response if enabled
213            if state.log_responses {
214                info!(
215                    method = %method,
216                    path = %uri.path(),
217                    status = status.as_u16(),
218                    "Proxy response sent"
219                );
220            }
221
222            // Convert response headers
223            let mut response_headers = http::HeaderMap::new();
224            for (name, value) in response.headers() {
225                if let (Ok(header_name), Ok(header_value)) = (
226                    http::HeaderName::try_from(name.as_str()),
227                    http::HeaderValue::try_from(value.as_bytes()),
228                ) {
229                    response_headers.insert(header_name, header_value);
230                }
231            }
232
233            // Read response body
234            let response_body_bytes = response.bytes().await.map_err(|e| {
235                error!("Failed to read proxy response body: {}", e);
236                StatusCode::BAD_GATEWAY
237            })?;
238
239            // Apply response body transformations if configured
240            let mut final_body_bytes = response_body_bytes.to_vec();
241            {
242                let config_for_response = state.config.read().await;
243                if !config_for_response.response_replacements.is_empty() {
244                    let transform_middleware = BodyTransformationMiddleware::new(
245                        Vec::new(), // No request rules needed here
246                        config_for_response.response_replacements.clone(),
247                    );
248                    let mut body_option = Some(final_body_bytes.clone());
249                    if let Err(e) = transform_middleware.transform_response_body(
250                        uri.path(),
251                        status.as_u16(),
252                        &mut body_option,
253                    ) {
254                        warn!("Failed to transform response body: {}", e);
255                        // Continue with original body if transformation fails
256                    } else if let Some(transformed_body) = body_option {
257                        final_body_bytes = transformed_body;
258                    }
259                }
260            }
261
262            let body_string = String::from_utf8_lossy(&final_body_bytes).to_string();
263
264            // Build Axum response
265            let mut response_builder = Response::builder().status(status);
266            for (name, value) in response_headers.iter() {
267                response_builder = response_builder.header(name, value);
268            }
269
270            response_builder
271                .body(body_string)
272                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
273        }
274        Err(e) => {
275            error!("Proxy request failed: {}", e);
276            Err(StatusCode::BAD_GATEWAY)
277        }
278    }
279}
280
281/// Middleware for logging requests and responses
282async fn logging_middleware(
283    axum::extract::State(state): axum::extract::State<Arc<ProxyServer>>,
284    request: Request,
285    next: Next,
286) -> Response {
287    let start = std::time::Instant::now();
288    let method = request.method().clone();
289    let uri = request.uri().clone();
290
291    // Extract client address from request extensions
292    let client_addr = request
293        .extensions()
294        .get::<SocketAddr>()
295        .copied()
296        .unwrap_or_else(|| SocketAddr::from(([0, 0, 0, 0], 0)));
297
298    debug!(
299        method = %method,
300        uri = %uri,
301        client_ip = %client_addr.ip(),
302        "Request received"
303    );
304
305    let response = next.run(request).await;
306    let duration = start.elapsed();
307
308    // Track response time
309    {
310        let mut total_time = state.total_response_time_ms.write().await;
311        *total_time += duration.as_millis() as u64;
312    }
313
314    // Track server errors
315    if response.status().is_server_error() {
316        let mut errors = state.error_counter.write().await;
317        *errors += 1;
318    }
319
320    debug!(
321        method = %method,
322        uri = %uri,
323        status = %response.status(),
324        duration_ms = duration.as_millis(),
325        "Response sent"
326    );
327
328    response
329}
330
331/// Proxy statistics for monitoring
332#[derive(Debug, Serialize)]
333pub struct ProxyStats {
334    /// Total requests processed
335    pub total_requests: u64,
336    /// Requests per second
337    pub requests_per_second: f64,
338    /// Average response time in milliseconds
339    pub avg_response_time_ms: f64,
340    /// Error rate percentage
341    pub error_rate_percent: f64,
342}
343
344/// Get proxy statistics
345pub async fn get_proxy_stats(state: &ProxyServer) -> ProxyStats {
346    let total_requests = *state.request_counter.read().await;
347    let total_response_time_ms = *state.total_response_time_ms.read().await;
348    let error_count = *state.error_counter.read().await;
349
350    let elapsed_secs = state.start_time.elapsed().as_secs_f64();
351    let requests_per_second = if elapsed_secs > 0.0 {
352        total_requests as f64 / elapsed_secs
353    } else {
354        0.0
355    };
356
357    let avg_response_time_ms = if total_requests > 0 {
358        total_response_time_ms as f64 / total_requests as f64
359    } else {
360        0.0
361    };
362
363    let error_rate_percent = if total_requests > 0 {
364        (error_count as f64 / total_requests as f64) * 100.0
365    } else {
366        0.0
367    };
368
369    ProxyStats {
370        total_requests,
371        requests_per_second,
372        avg_response_time_ms,
373        error_rate_percent,
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use mockforge_core::proxy::config::ProxyConfig;

    /// Build a server from the default configuration with the given logging switches.
    fn server_with_logging(log_requests: bool, log_responses: bool) -> ProxyServer {
        ProxyServer::new(ProxyConfig::default(), log_requests, log_responses)
    }

    #[tokio::test]
    async fn test_proxy_server_creation() {
        let server = server_with_logging(true, true);

        // The constructor stores both logging switches unchanged.
        assert_eq!((server.log_requests, server.log_responses), (true, true));
    }

    #[tokio::test]
    async fn test_health_check() {
        let response = health_check()
            .await
            .expect("health check response should build");
        assert_eq!(response.status(), StatusCode::OK);

        let body: String = response.into_body();
        for needle in ["healthy", "mockforge-proxy"] {
            assert!(body.contains(needle), "health body missing {needle:?}: {body}");
        }
    }

    #[tokio::test]
    async fn test_proxy_stats() {
        let server = server_with_logging(false, false);
        let stats = get_proxy_stats(&server).await;

        // No traffic yet: the total and every per-request ratio are zero.
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.avg_response_time_ms, 0.0);
        assert_eq!(stats.error_rate_percent, 0.0);
    }
}