mockforge_http/
proxy_server.rs

1//! Browser/Mobile Proxy Server
2//!
3//! Provides an intercepting proxy for frontend/mobile clients with HTTPS support,
4//! certificate injection, and comprehensive request/response logging.
5
6use axum::{
7    extract::Request,
8    http::{HeaderMap, Method, StatusCode, Uri},
9    middleware::Next,
10    response::Response,
11    routing::{any, get},
12    Router,
13};
14use mockforge_core::proxy::{
15    body_transform::BodyTransformationMiddleware, config::ProxyConfig, handler::ProxyHandler,
16};
17use serde::{Deserialize, Serialize};
18use std::net::SocketAddr;
19use std::sync::Arc;
20use tokio::sync::RwLock;
21use tracing::{debug, error, info, warn};
22
23/// Proxy server state
24#[derive(Debug)]
25pub struct ProxyServer {
26    /// Proxy configuration
27    config: Arc<RwLock<ProxyConfig>>,
28    /// Request logging enabled
29    log_requests: bool,
30    /// Response logging enabled
31    log_responses: bool,
32    /// Request counter for logging
33    request_counter: Arc<RwLock<u64>>,
34}
35
36impl ProxyServer {
37    /// Create a new proxy server
38    pub fn new(config: ProxyConfig, log_requests: bool, log_responses: bool) -> Self {
39        Self {
40            config: Arc::new(RwLock::new(config)),
41            log_requests,
42            log_responses,
43            request_counter: Arc::new(RwLock::new(0)),
44        }
45    }
46
47    /// Get the Axum router for the proxy server
48    pub fn router(self) -> Router {
49        let state = Arc::new(self);
50        let state_for_middleware = state.clone();
51
52        Router::new()
53            // Health check endpoint
54            .route("/proxy/health", get(health_check))
55            // Catch-all proxy handler - use fallback for all methods
56            .fallback(proxy_handler)
57            .with_state(state)
58            .layer(axum::middleware::from_fn_with_state(state_for_middleware, logging_middleware))
59    }
60}
61
62/// Health check endpoint for the proxy
63async fn health_check() -> Result<Response<String>, StatusCode> {
64    // Response builder should never fail with known-good values, but handle errors gracefully
65    Response::builder()
66        .status(StatusCode::OK)
67        .header("Content-Type", "application/json")
68        .body(r#"{"status":"healthy","service":"mockforge-proxy"}"#.to_string())
69        .map_err(|e| {
70            tracing::error!("Failed to build health check response: {}", e);
71            StatusCode::INTERNAL_SERVER_ERROR
72        })
73}
74
75/// Main proxy handler that intercepts and forwards requests
76async fn proxy_handler(
77    axum::extract::State(state): axum::extract::State<Arc<ProxyServer>>,
78    request: axum::http::Request<axum::body::Body>,
79) -> Result<Response<String>, StatusCode> {
80    // Extract client address from request extensions (set by ConnectInfo middleware)
81    let client_addr = request
82        .extensions()
83        .get::<SocketAddr>()
84        .copied()
85        .unwrap_or_else(|| std::net::SocketAddr::from(([0, 0, 0, 0], 0)));
86
87    let method = request.method().clone();
88    let uri = request.uri().clone();
89    let headers = request.headers().clone();
90
91    // Read request body early for conditional evaluation (consume the body)
92    let body_bytes = match axum::body::to_bytes(request.into_body(), usize::MAX).await {
93        Ok(bytes) => Some(bytes.to_vec()),
94        Err(e) => {
95            error!("Failed to read request body: {}", e);
96            None
97        }
98    };
99
100    let config = state.config.read().await;
101
102    // Check if proxy is enabled
103    if !config.enabled {
104        return Err(StatusCode::SERVICE_UNAVAILABLE);
105    }
106
107    // Determine if this request should be proxied (with conditional evaluation)
108    if !config.should_proxy_with_condition(&method, &uri, &headers, body_bytes.as_deref()) {
109        return Err(StatusCode::NOT_FOUND);
110    }
111
112    // Get the stripped path (without proxy prefix)
113    let stripped_path = config.strip_prefix(uri.path());
114
115    // Get the base upstream URL and construct the full URL
116    let base_upstream_url = config.get_upstream_url(uri.path());
117    let full_upstream_url =
118        if stripped_path.starts_with("http://") || stripped_path.starts_with("https://") {
119            stripped_path.clone()
120        } else {
121            let base = base_upstream_url.trim_end_matches('/');
122            let path = stripped_path.trim_start_matches('/');
123            let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
124            if path.is_empty() || path == "/" {
125                format!("{}{}", base, query)
126            } else {
127                format!("{}/{}", base, path) + &query
128            }
129        };
130
131    // Create a new URI with the full upstream URL for the proxy handler
132    let modified_uri = full_upstream_url.parse::<axum::http::Uri>().unwrap_or_else(|_| uri.clone());
133
134    // Log the request if enabled
135    if state.log_requests {
136        let mut counter = state.request_counter.write().await;
137        *counter += 1;
138        let request_id = *counter;
139
140        info!(
141            request_id = request_id,
142            method = %method,
143            path = %uri.path(),
144            upstream = %full_upstream_url,
145            client_ip = %client_addr.ip(),
146            "Proxy request intercepted"
147        );
148    }
149
150    // Convert headers to HashMap for the proxy handler
151    let mut header_map = std::collections::HashMap::new();
152    for (key, value) in &headers {
153        if let Ok(value_str) = value.to_str() {
154            header_map.insert(key.to_string(), value_str.to_string());
155        }
156    }
157
158    // Use ProxyClient directly with the full upstream URL to bypass ProxyHandler's URL construction
159    use mockforge_core::proxy::client::ProxyClient;
160    let proxy_client = ProxyClient::new();
161
162    // Convert method to reqwest method
163    let reqwest_method = match method.as_str() {
164        "GET" => reqwest::Method::GET,
165        "POST" => reqwest::Method::POST,
166        "PUT" => reqwest::Method::PUT,
167        "DELETE" => reqwest::Method::DELETE,
168        "HEAD" => reqwest::Method::HEAD,
169        "OPTIONS" => reqwest::Method::OPTIONS,
170        "PATCH" => reqwest::Method::PATCH,
171        _ => {
172            error!("Unsupported HTTP method: {}", method);
173            return Err(StatusCode::METHOD_NOT_ALLOWED);
174        }
175    };
176
177    // Add any configured headers
178    for (key, value) in &config.headers {
179        header_map.insert(key.clone(), value.clone());
180    }
181
182    // Apply request body transformations if configured
183    let mut transformed_request_body = body_bytes.clone();
184    if !config.request_replacements.is_empty() {
185        let transform_middleware = BodyTransformationMiddleware::new(
186            config.request_replacements.clone(),
187            Vec::new(), // No response rules needed here
188        );
189        if let Err(e) =
190            transform_middleware.transform_request_body(uri.path(), &mut transformed_request_body)
191        {
192            warn!("Failed to transform request body: {}", e);
193            // Continue with original body if transformation fails
194        }
195    }
196
197    match proxy_client
198        .send_request(
199            reqwest_method,
200            &full_upstream_url,
201            &header_map,
202            transformed_request_body.as_deref(),
203        )
204        .await
205    {
206        Ok(response) => {
207            let status = StatusCode::from_u16(response.status().as_u16())
208                .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
209
210            // Log the response if enabled
211            if state.log_responses {
212                info!(
213                    method = %method,
214                    path = %uri.path(),
215                    status = status.as_u16(),
216                    "Proxy response sent"
217                );
218            }
219
220            // Convert response headers
221            let mut response_headers = axum::http::HeaderMap::new();
222            for (name, value) in response.headers() {
223                if let (Ok(header_name), Ok(header_value)) = (
224                    axum::http::HeaderName::try_from(name.as_str()),
225                    axum::http::HeaderValue::try_from(value.as_bytes()),
226                ) {
227                    response_headers.insert(header_name, header_value);
228                }
229            }
230
231            // Read response body
232            let response_body_bytes = response.bytes().await.map_err(|e| {
233                error!("Failed to read proxy response body: {}", e);
234                StatusCode::BAD_GATEWAY
235            })?;
236
237            // Apply response body transformations if configured
238            let mut final_body_bytes = response_body_bytes.to_vec();
239            {
240                let config_for_response = state.config.read().await;
241                if !config_for_response.response_replacements.is_empty() {
242                    let transform_middleware = BodyTransformationMiddleware::new(
243                        Vec::new(), // No request rules needed here
244                        config_for_response.response_replacements.clone(),
245                    );
246                    let mut body_option = Some(final_body_bytes.clone());
247                    if let Err(e) = transform_middleware.transform_response_body(
248                        uri.path(),
249                        status.as_u16(),
250                        &mut body_option,
251                    ) {
252                        warn!("Failed to transform response body: {}", e);
253                        // Continue with original body if transformation fails
254                    } else if let Some(transformed_body) = body_option {
255                        final_body_bytes = transformed_body;
256                    }
257                }
258            }
259
260            let body_string = String::from_utf8_lossy(&final_body_bytes).to_string();
261
262            // Build Axum response
263            let mut response_builder = Response::builder().status(status);
264            for (name, value) in response_headers.iter() {
265                response_builder = response_builder.header(name, value);
266            }
267
268            response_builder
269                .body(body_string)
270                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
271        }
272        Err(e) => {
273            error!("Proxy request failed: {}", e);
274            Err(StatusCode::BAD_GATEWAY)
275        }
276    }
277}
278
279/// Middleware for logging requests and responses
280async fn logging_middleware(
281    axum::extract::State(_state): axum::extract::State<Arc<ProxyServer>>,
282    request: Request,
283    next: Next,
284) -> Response {
285    let start = std::time::Instant::now();
286    let method = request.method().clone();
287    let uri = request.uri().clone();
288
289    // Extract client address from request extensions
290    let client_addr = request
291        .extensions()
292        .get::<SocketAddr>()
293        .copied()
294        .unwrap_or_else(|| std::net::SocketAddr::from(([0, 0, 0, 0], 0)));
295
296    debug!(
297        method = %method,
298        uri = %uri,
299        client_ip = %client_addr.ip(),
300        "Request received"
301    );
302
303    let response = next.run(request).await;
304    let duration = start.elapsed();
305
306    debug!(
307        method = %method,
308        uri = %uri,
309        status = %response.status(),
310        duration_ms = duration.as_millis(),
311        "Response sent"
312    );
313
314    response
315}
316
317/// Proxy statistics for monitoring
318#[derive(Debug, Serialize)]
319pub struct ProxyStats {
320    /// Total requests processed
321    pub total_requests: u64,
322    /// Requests per second
323    pub requests_per_second: f64,
324    /// Average response time in milliseconds
325    pub avg_response_time_ms: f64,
326    /// Error rate percentage
327    pub error_rate_percent: f64,
328}
329
330/// Get proxy statistics
331pub async fn get_proxy_stats(state: &ProxyServer) -> ProxyStats {
332    let total_requests = *state.request_counter.read().await;
333
334    // For now, return basic stats. In a real implementation,
335    // you'd track more detailed metrics over time.
336    ProxyStats {
337        total_requests,
338        requests_per_second: 0.0,  // Would need time-based tracking
339        avg_response_time_ms: 0.0, // Would need timing data
340        error_rate_percent: 0.0,   // Would need error tracking
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use mockforge_core::proxy::config::ProxyConfig;

    /// Builds a server with the default config and both logging switches set to `logging`.
    fn default_server(logging: bool) -> ProxyServer {
        ProxyServer::new(ProxyConfig::default(), logging, logging)
    }

    #[tokio::test]
    async fn test_proxy_server_creation() {
        let proxy = default_server(true);

        // Both logging switches must reflect the constructor arguments.
        assert!(proxy.log_requests);
        assert!(proxy.log_responses);
    }

    #[tokio::test]
    async fn test_health_check() {
        let response = health_check()
            .await
            .expect("health check response should build");
        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body();
        for expected in ["healthy", "mockforge-proxy"] {
            assert!(body.contains(expected), "body {body:?} lacks {expected:?}");
        }
    }

    #[tokio::test]
    async fn test_proxy_stats() {
        let stats = get_proxy_stats(&default_server(false)).await;

        assert_eq!(stats.total_requests, 0);
        for placeholder in [
            stats.requests_per_second,
            stats.avg_response_time_ms,
            stats.error_rate_percent,
        ] {
            assert_eq!(placeholder, 0.0);
        }
    }
}