mockforge_http/
proxy_server.rs

1//! Browser/Mobile Proxy Server
2//!
3//! Provides an intercepting proxy for frontend/mobile clients with HTTPS support,
4//! certificate injection, and comprehensive request/response logging.
5
6use axum::{
7    extract::Request, http::StatusCode, middleware::Next, response::Response, routing::get, Router,
8};
9use mockforge_core::proxy::{body_transform::BodyTransformationMiddleware, config::ProxyConfig};
10use serde::Serialize;
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, error, info, warn};
15
16/// Proxy server state
17#[derive(Debug)]
18pub struct ProxyServer {
19    /// Proxy configuration
20    config: Arc<RwLock<ProxyConfig>>,
21    /// Request logging enabled
22    log_requests: bool,
23    /// Response logging enabled
24    log_responses: bool,
25    /// Request counter for logging
26    request_counter: Arc<RwLock<u64>>,
27}
28
29impl ProxyServer {
30    /// Create a new proxy server
31    pub fn new(config: ProxyConfig, log_requests: bool, log_responses: bool) -> Self {
32        Self {
33            config: Arc::new(RwLock::new(config)),
34            log_requests,
35            log_responses,
36            request_counter: Arc::new(RwLock::new(0)),
37        }
38    }
39
40    /// Get the Axum router for the proxy server
41    pub fn router(self) -> Router {
42        let state = Arc::new(self);
43        let state_for_middleware = state.clone();
44
45        Router::new()
46            // Health check endpoint
47            .route("/proxy/health", get(health_check))
48            // Catch-all proxy handler - use fallback for all methods
49            .fallback(proxy_handler)
50            .with_state(state)
51            .layer(axum::middleware::from_fn_with_state(state_for_middleware, logging_middleware))
52    }
53}
54
55/// Health check endpoint for the proxy
56async fn health_check() -> Result<Response<String>, StatusCode> {
57    // Response builder should never fail with known-good values, but handle errors gracefully
58    Response::builder()
59        .status(StatusCode::OK)
60        .header("Content-Type", "application/json")
61        .body(r#"{"status":"healthy","service":"mockforge-proxy"}"#.to_string())
62        .map_err(|e| {
63            tracing::error!("Failed to build health check response: {}", e);
64            StatusCode::INTERNAL_SERVER_ERROR
65        })
66}
67
68/// Main proxy handler that intercepts and forwards requests
69async fn proxy_handler(
70    axum::extract::State(state): axum::extract::State<Arc<ProxyServer>>,
71    request: axum::http::Request<axum::body::Body>,
72) -> Result<Response<String>, StatusCode> {
73    // Extract client address from request extensions (set by ConnectInfo middleware)
74    let client_addr = request
75        .extensions()
76        .get::<SocketAddr>()
77        .copied()
78        .unwrap_or_else(|| std::net::SocketAddr::from(([0, 0, 0, 0], 0)));
79
80    let method = request.method().clone();
81    let uri = request.uri().clone();
82    let headers = request.headers().clone();
83
84    // Read request body early for conditional evaluation (consume the body)
85    let body_bytes = match axum::body::to_bytes(request.into_body(), usize::MAX).await {
86        Ok(bytes) => Some(bytes.to_vec()),
87        Err(e) => {
88            error!("Failed to read request body: {}", e);
89            None
90        }
91    };
92
93    let config = state.config.read().await;
94
95    // Check if proxy is enabled
96    if !config.enabled {
97        return Err(StatusCode::SERVICE_UNAVAILABLE);
98    }
99
100    // Determine if this request should be proxied (with conditional evaluation)
101    if !config.should_proxy_with_condition(&method, &uri, &headers, body_bytes.as_deref()) {
102        return Err(StatusCode::NOT_FOUND);
103    }
104
105    // Get the stripped path (without proxy prefix)
106    let stripped_path = config.strip_prefix(uri.path());
107
108    // Get the base upstream URL and construct the full URL
109    let base_upstream_url = config.get_upstream_url(uri.path());
110    let full_upstream_url =
111        if stripped_path.starts_with("http://") || stripped_path.starts_with("https://") {
112            stripped_path.clone()
113        } else {
114            let base = base_upstream_url.trim_end_matches('/');
115            let path = stripped_path.trim_start_matches('/');
116            let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
117            if path.is_empty() || path == "/" {
118                format!("{}{}", base, query)
119            } else {
120                format!("{}/{}", base, path) + &query
121            }
122        };
123
124    // Create a new URI with the full upstream URL for the proxy handler
125    let modified_uri = full_upstream_url.parse::<axum::http::Uri>().unwrap_or_else(|_| uri.clone());
126
127    // Log the request if enabled
128    if state.log_requests {
129        let mut counter = state.request_counter.write().await;
130        *counter += 1;
131        let request_id = *counter;
132
133        info!(
134            request_id = request_id,
135            method = %method,
136            path = %uri.path(),
137            upstream = %full_upstream_url,
138            client_ip = %client_addr.ip(),
139            "Proxy request intercepted"
140        );
141    }
142
143    // Convert headers to HashMap for the proxy handler
144    let mut header_map = std::collections::HashMap::new();
145    for (key, value) in &headers {
146        if let Ok(value_str) = value.to_str() {
147            header_map.insert(key.to_string(), value_str.to_string());
148        }
149    }
150
151    // Use ProxyClient directly with the full upstream URL to bypass ProxyHandler's URL construction
152    use mockforge_core::proxy::client::ProxyClient;
153    let proxy_client = ProxyClient::new();
154
155    // Convert method to reqwest method
156    let reqwest_method = match method.as_str() {
157        "GET" => reqwest::Method::GET,
158        "POST" => reqwest::Method::POST,
159        "PUT" => reqwest::Method::PUT,
160        "DELETE" => reqwest::Method::DELETE,
161        "HEAD" => reqwest::Method::HEAD,
162        "OPTIONS" => reqwest::Method::OPTIONS,
163        "PATCH" => reqwest::Method::PATCH,
164        _ => {
165            error!("Unsupported HTTP method: {}", method);
166            return Err(StatusCode::METHOD_NOT_ALLOWED);
167        }
168    };
169
170    // Add any configured headers
171    for (key, value) in &config.headers {
172        header_map.insert(key.clone(), value.clone());
173    }
174
175    // Apply request body transformations if configured
176    let mut transformed_request_body = body_bytes.clone();
177    if !config.request_replacements.is_empty() {
178        let transform_middleware = BodyTransformationMiddleware::new(
179            config.request_replacements.clone(),
180            Vec::new(), // No response rules needed here
181        );
182        if let Err(e) =
183            transform_middleware.transform_request_body(uri.path(), &mut transformed_request_body)
184        {
185            warn!("Failed to transform request body: {}", e);
186            // Continue with original body if transformation fails
187        }
188    }
189
190    match proxy_client
191        .send_request(
192            reqwest_method,
193            &full_upstream_url,
194            &header_map,
195            transformed_request_body.as_deref(),
196        )
197        .await
198    {
199        Ok(response) => {
200            let status = StatusCode::from_u16(response.status().as_u16())
201                .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
202
203            // Log the response if enabled
204            if state.log_responses {
205                info!(
206                    method = %method,
207                    path = %uri.path(),
208                    status = status.as_u16(),
209                    "Proxy response sent"
210                );
211            }
212
213            // Convert response headers
214            let mut response_headers = axum::http::HeaderMap::new();
215            for (name, value) in response.headers() {
216                if let (Ok(header_name), Ok(header_value)) = (
217                    axum::http::HeaderName::try_from(name.as_str()),
218                    axum::http::HeaderValue::try_from(value.as_bytes()),
219                ) {
220                    response_headers.insert(header_name, header_value);
221                }
222            }
223
224            // Read response body
225            let response_body_bytes = response.bytes().await.map_err(|e| {
226                error!("Failed to read proxy response body: {}", e);
227                StatusCode::BAD_GATEWAY
228            })?;
229
230            // Apply response body transformations if configured
231            let mut final_body_bytes = response_body_bytes.to_vec();
232            {
233                let config_for_response = state.config.read().await;
234                if !config_for_response.response_replacements.is_empty() {
235                    let transform_middleware = BodyTransformationMiddleware::new(
236                        Vec::new(), // No request rules needed here
237                        config_for_response.response_replacements.clone(),
238                    );
239                    let mut body_option = Some(final_body_bytes.clone());
240                    if let Err(e) = transform_middleware.transform_response_body(
241                        uri.path(),
242                        status.as_u16(),
243                        &mut body_option,
244                    ) {
245                        warn!("Failed to transform response body: {}", e);
246                        // Continue with original body if transformation fails
247                    } else if let Some(transformed_body) = body_option {
248                        final_body_bytes = transformed_body;
249                    }
250                }
251            }
252
253            let body_string = String::from_utf8_lossy(&final_body_bytes).to_string();
254
255            // Build Axum response
256            let mut response_builder = Response::builder().status(status);
257            for (name, value) in response_headers.iter() {
258                response_builder = response_builder.header(name, value);
259            }
260
261            response_builder
262                .body(body_string)
263                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
264        }
265        Err(e) => {
266            error!("Proxy request failed: {}", e);
267            Err(StatusCode::BAD_GATEWAY)
268        }
269    }
270}
271
272/// Middleware for logging requests and responses
273async fn logging_middleware(
274    axum::extract::State(_state): axum::extract::State<Arc<ProxyServer>>,
275    request: Request,
276    next: Next,
277) -> Response {
278    let start = std::time::Instant::now();
279    let method = request.method().clone();
280    let uri = request.uri().clone();
281
282    // Extract client address from request extensions
283    let client_addr = request
284        .extensions()
285        .get::<SocketAddr>()
286        .copied()
287        .unwrap_or_else(|| std::net::SocketAddr::from(([0, 0, 0, 0], 0)));
288
289    debug!(
290        method = %method,
291        uri = %uri,
292        client_ip = %client_addr.ip(),
293        "Request received"
294    );
295
296    let response = next.run(request).await;
297    let duration = start.elapsed();
298
299    debug!(
300        method = %method,
301        uri = %uri,
302        status = %response.status(),
303        duration_ms = duration.as_millis(),
304        "Response sent"
305    );
306
307    response
308}
309
310/// Proxy statistics for monitoring
311#[derive(Debug, Serialize)]
312pub struct ProxyStats {
313    /// Total requests processed
314    pub total_requests: u64,
315    /// Requests per second
316    pub requests_per_second: f64,
317    /// Average response time in milliseconds
318    pub avg_response_time_ms: f64,
319    /// Error rate percentage
320    pub error_rate_percent: f64,
321}
322
323/// Get proxy statistics
324pub async fn get_proxy_stats(state: &ProxyServer) -> ProxyStats {
325    let total_requests = *state.request_counter.read().await;
326
327    // For now, return basic stats. In a real implementation,
328    // you'd track more detailed metrics over time.
329    ProxyStats {
330        total_requests,
331        requests_per_second: 0.0,  // Would need time-based tracking
332        avg_response_time_ms: 0.0, // Would need timing data
333        error_rate_percent: 0.0,   // Would need error tracking
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use mockforge_core::proxy::config::ProxyConfig;
    use std::net::SocketAddr;

    /// A freshly constructed server keeps the logging flags it was given.
    #[tokio::test]
    async fn test_proxy_server_creation() {
        let server = ProxyServer::new(ProxyConfig::default(), true, true);

        assert!(server.log_requests);
        assert!(server.log_responses);
    }

    /// The health endpoint answers 200 with a body naming the service.
    #[tokio::test]
    async fn test_health_check() {
        let response = health_check().await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        // The body is already a `String`, so it can be inspected directly.
        let body = response.into_body();
        for expected in ["healthy", "mockforge-proxy"] {
            assert!(body.contains(expected));
        }
    }

    /// Stats for an untouched server are all zero.
    #[tokio::test]
    async fn test_proxy_stats() {
        let server = ProxyServer::new(ProxyConfig::default(), false, false);

        let ProxyStats {
            total_requests,
            requests_per_second,
            avg_response_time_ms,
            error_rate_percent,
        } = get_proxy_stats(&server).await;

        assert_eq!(total_requests, 0);
        assert_eq!(requests_per_second, 0.0);
        assert_eq!(avg_response_time_ms, 0.0);
        assert_eq!(error_rate_percent, 0.0);
    }
}