mockforge_http/
proxy_server.rs

1//! Browser/Mobile Proxy Server
2//!
3//! Provides an intercepting proxy for frontend/mobile clients with HTTPS support,
4//! certificate injection, and comprehensive request/response logging.
5
6use axum::{
7    extract::Request,
8    http::{HeaderMap, Method, StatusCode, Uri},
9    middleware::Next,
10    response::Response,
11    routing::{any, get},
12    Router,
13};
14use mockforge_core::proxy::{config::ProxyConfig, handler::ProxyHandler};
15use serde::{Deserialize, Serialize};
16use std::net::SocketAddr;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::{debug, error, info, warn};
20
21/// Proxy server state
22#[derive(Debug)]
23pub struct ProxyServer {
24    /// Proxy configuration
25    config: Arc<RwLock<ProxyConfig>>,
26    /// Request logging enabled
27    log_requests: bool,
28    /// Response logging enabled
29    log_responses: bool,
30    /// Request counter for logging
31    request_counter: Arc<RwLock<u64>>,
32}
33
34impl ProxyServer {
35    /// Create a new proxy server
36    pub fn new(config: ProxyConfig, log_requests: bool, log_responses: bool) -> Self {
37        Self {
38            config: Arc::new(RwLock::new(config)),
39            log_requests,
40            log_responses,
41            request_counter: Arc::new(RwLock::new(0)),
42        }
43    }
44
45    /// Get the Axum router for the proxy server
46    pub fn router(self) -> Router {
47        let state = Arc::new(self);
48        let state_for_middleware = state.clone();
49
50        Router::new()
51            // Health check endpoint
52            .route("/proxy/health", get(health_check))
53            // Catch-all proxy handler - use fallback for all methods
54            .fallback(proxy_handler)
55            .with_state(state)
56            .layer(axum::middleware::from_fn_with_state(state_for_middleware, logging_middleware))
57    }
58}
59
60/// Health check endpoint for the proxy
61async fn health_check() -> Result<Response<String>, StatusCode> {
62    // Response builder should never fail with known-good values, but handle errors gracefully
63    Response::builder()
64        .status(StatusCode::OK)
65        .header("Content-Type", "application/json")
66        .body(r#"{"status":"healthy","service":"mockforge-proxy"}"#.to_string())
67        .map_err(|e| {
68            tracing::error!("Failed to build health check response: {}", e);
69            StatusCode::INTERNAL_SERVER_ERROR
70        })
71}
72
73/// Main proxy handler that intercepts and forwards requests
74async fn proxy_handler(
75    axum::extract::State(state): axum::extract::State<Arc<ProxyServer>>,
76    request: axum::http::Request<axum::body::Body>,
77) -> Result<Response<String>, StatusCode> {
78    // Extract client address from request extensions (set by ConnectInfo middleware)
79    let client_addr = request
80        .extensions()
81        .get::<SocketAddr>()
82        .copied()
83        .unwrap_or_else(|| std::net::SocketAddr::from(([0, 0, 0, 0], 0)));
84
85    let method = request.method().clone();
86    let uri = request.uri().clone();
87    let headers = request.headers().clone();
88
89    let config = state.config.read().await;
90
91    // Check if proxy is enabled
92    if !config.enabled {
93        return Err(StatusCode::SERVICE_UNAVAILABLE);
94    }
95
96    // Determine if this request should be proxied
97    if !config.should_proxy(&method, uri.path()) {
98        return Err(StatusCode::NOT_FOUND);
99    }
100
101    // Get the stripped path (without proxy prefix)
102    let stripped_path = config.strip_prefix(uri.path());
103
104    // Get the base upstream URL and construct the full URL
105    let base_upstream_url = config.get_upstream_url(uri.path());
106    let full_upstream_url =
107        if stripped_path.starts_with("http://") || stripped_path.starts_with("https://") {
108            stripped_path.clone()
109        } else {
110            let base = base_upstream_url.trim_end_matches('/');
111            let path = stripped_path.trim_start_matches('/');
112            let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
113            if path.is_empty() || path == "/" {
114                format!("{}{}", base, query)
115            } else {
116                format!("{}/{}", base, path) + &query
117            }
118        };
119
120    // Create a new URI with the full upstream URL for the proxy handler
121    let modified_uri = full_upstream_url.parse::<axum::http::Uri>().unwrap_or_else(|_| uri.clone());
122
123    // Log the request if enabled
124    if state.log_requests {
125        let mut counter = state.request_counter.write().await;
126        *counter += 1;
127        let request_id = *counter;
128
129        info!(
130            request_id = request_id,
131            method = %method,
132            path = %uri.path(),
133            upstream = %full_upstream_url,
134            client_ip = %client_addr.ip(),
135            "Proxy request intercepted"
136        );
137    }
138
139    // Convert headers to HashMap for the proxy handler
140    let mut header_map = std::collections::HashMap::new();
141    for (key, value) in &headers {
142        if let Ok(value_str) = value.to_str() {
143            header_map.insert(key.to_string(), value_str.to_string());
144        }
145    }
146
147    // Read request body (consume the request)
148    let body_bytes = match axum::body::to_bytes(request.into_body(), usize::MAX).await {
149        Ok(bytes) => Some(bytes.to_vec()),
150        Err(e) => {
151            error!("Failed to read request body: {}", e);
152            None
153        }
154    };
155
156    // Use ProxyClient directly with the full upstream URL to bypass ProxyHandler's URL construction
157    use mockforge_core::proxy::client::ProxyClient;
158    let proxy_client = ProxyClient::new();
159
160    // Convert headers to HashMap
161    let mut header_map = std::collections::HashMap::new();
162    for (key, value) in &headers {
163        if let Ok(value_str) = value.to_str() {
164            header_map.insert(key.to_string(), value_str.to_string());
165        }
166    }
167
168    // Convert method to reqwest method
169    let reqwest_method = match method.as_str() {
170        "GET" => reqwest::Method::GET,
171        "POST" => reqwest::Method::POST,
172        "PUT" => reqwest::Method::PUT,
173        "DELETE" => reqwest::Method::DELETE,
174        "HEAD" => reqwest::Method::HEAD,
175        "OPTIONS" => reqwest::Method::OPTIONS,
176        "PATCH" => reqwest::Method::PATCH,
177        _ => {
178            error!("Unsupported HTTP method: {}", method);
179            return Err(StatusCode::METHOD_NOT_ALLOWED);
180        }
181    };
182
183    // Add any configured headers
184    for (key, value) in &config.headers {
185        header_map.insert(key.clone(), value.clone());
186    }
187
188    match proxy_client
189        .send_request(reqwest_method, &full_upstream_url, &header_map, body_bytes.as_deref())
190        .await
191    {
192        Ok(response) => {
193            let status = StatusCode::from_u16(response.status().as_u16())
194                .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
195
196            // Log the response if enabled
197            if state.log_responses {
198                info!(
199                    method = %method,
200                    path = %uri.path(),
201                    status = status.as_u16(),
202                    "Proxy response sent"
203                );
204            }
205
206            // Convert response headers
207            let mut response_headers = axum::http::HeaderMap::new();
208            for (name, value) in response.headers() {
209                if let (Ok(header_name), Ok(header_value)) = (
210                    axum::http::HeaderName::try_from(name.as_str()),
211                    axum::http::HeaderValue::try_from(value.as_bytes()),
212                ) {
213                    response_headers.insert(header_name, header_value);
214                }
215            }
216
217            // Read response body
218            let body_bytes = response.bytes().await.map_err(|e| {
219                error!("Failed to read proxy response body: {}", e);
220                StatusCode::BAD_GATEWAY
221            })?;
222
223            let body_string = String::from_utf8_lossy(&body_bytes).to_string();
224
225            // Build Axum response
226            let mut response_builder = Response::builder().status(status);
227            for (name, value) in response_headers.iter() {
228                response_builder = response_builder.header(name, value);
229            }
230
231            response_builder
232                .body(body_string)
233                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
234        }
235        Err(e) => {
236            error!("Proxy request failed: {}", e);
237            Err(StatusCode::BAD_GATEWAY)
238        }
239    }
240}
241
242/// Middleware for logging requests and responses
243async fn logging_middleware(
244    axum::extract::State(_state): axum::extract::State<Arc<ProxyServer>>,
245    request: Request,
246    next: Next,
247) -> Response {
248    let start = std::time::Instant::now();
249    let method = request.method().clone();
250    let uri = request.uri().clone();
251
252    // Extract client address from request extensions
253    let client_addr = request
254        .extensions()
255        .get::<SocketAddr>()
256        .copied()
257        .unwrap_or_else(|| std::net::SocketAddr::from(([0, 0, 0, 0], 0)));
258
259    debug!(
260        method = %method,
261        uri = %uri,
262        client_ip = %client_addr.ip(),
263        "Request received"
264    );
265
266    let response = next.run(request).await;
267    let duration = start.elapsed();
268
269    debug!(
270        method = %method,
271        uri = %uri,
272        status = %response.status(),
273        duration_ms = duration.as_millis(),
274        "Response sent"
275    );
276
277    response
278}
279
280/// Proxy statistics for monitoring
281#[derive(Debug, Serialize)]
282pub struct ProxyStats {
283    /// Total requests processed
284    pub total_requests: u64,
285    /// Requests per second
286    pub requests_per_second: f64,
287    /// Average response time in milliseconds
288    pub avg_response_time_ms: f64,
289    /// Error rate percentage
290    pub error_rate_percent: f64,
291}
292
293/// Get proxy statistics
294pub async fn get_proxy_stats(state: &ProxyServer) -> ProxyStats {
295    let total_requests = *state.request_counter.read().await;
296
297    // For now, return basic stats. In a real implementation,
298    // you'd track more detailed metrics over time.
299    ProxyStats {
300        total_requests,
301        requests_per_second: 0.0,  // Would need time-based tracking
302        avg_response_time_ms: 0.0, // Would need timing data
303        error_rate_percent: 0.0,   // Would need error tracking
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use mockforge_core::proxy::config::ProxyConfig;

    /// The constructor stores both logging switches exactly as given.
    #[tokio::test]
    async fn test_proxy_server_creation() {
        let server = ProxyServer::new(ProxyConfig::default(), true, true);

        assert!(server.log_requests, "request logging flag should be stored");
        assert!(server.log_responses, "response logging flag should be stored");
    }

    /// The health endpoint answers 200 with a JSON body naming the service.
    #[tokio::test]
    async fn test_health_check() {
        let response = health_check()
            .await
            .expect("health check response should build");
        assert_eq!(response.status(), StatusCode::OK);

        // The body type is `String`, so it can be inspected directly.
        let body = response.into_body();
        for expected in ["healthy", "mockforge-proxy"] {
            assert!(body.contains(expected), "body should mention {expected:?}: {body}");
        }
    }

    /// A freshly created server reports an all-zero statistics snapshot.
    #[tokio::test]
    async fn test_proxy_stats() {
        let server = ProxyServer::new(ProxyConfig::default(), false, false);
        let stats = get_proxy_stats(&server).await;

        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.requests_per_second, 0.0);
        assert_eq!(stats.avg_response_time_ms, 0.0);
        assert_eq!(stats.error_rate_percent, 0.0);
    }
}