// phantom_frame/proxy.rs

use std::sync::Arc;
use std::sync::OnceLock;

use axum::{
    body::Body,
    extract::Extension,
    http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
};
use hyper_util::rt::TokioIo;

use crate::cache::{CacheStore, CachedResponse};
use crate::path_matcher::should_cache_path;
use crate::CreateProxyConfig;

12#[derive(Clone)]
13pub struct ProxyState {
14    cache: CacheStore,
15    config: CreateProxyConfig,
16}
17
18impl ProxyState {
19    pub fn new(cache: CacheStore, config: CreateProxyConfig) -> Self {
20        Self { cache, config }
21    }
22}
23
24/// Check if the request is a WebSocket or other upgrade request
25/// 
26/// WebSocket and other protocol upgrades are detected by checking for:
27/// - `Connection: Upgrade` header (case-insensitive)
28/// - Presence of `Upgrade` header
29/// 
30/// These requests will bypass caching and use direct TCP tunneling instead.
31fn is_upgrade_request(headers: &HeaderMap) -> bool {
32    headers
33        .get(axum::http::header::CONNECTION)
34        .and_then(|v| v.to_str().ok())
35        .map(|v| v.to_lowercase().contains("upgrade"))
36        .unwrap_or(false)
37        || headers.contains_key(axum::http::header::UPGRADE)
38}
39
40/// Main proxy handler that serves prerendered content from cache
41/// or fetches from backend if not cached
42pub async fn proxy_handler(
43    Extension(state): Extension<Arc<ProxyState>>,
44    req: Request<Body>,
45) -> Result<Response<Body>, StatusCode> {
46    // Check for upgrade requests FIRST (before consuming anything from the request)
47    // This is critical for WebSocket to work properly
48    let is_upgrade = is_upgrade_request(req.headers());
49    
50    if is_upgrade {
51        let method_str = req.method().as_str();
52        let path = req.uri().path();
53        
54        if state.config.enable_websocket {
55            tracing::debug!("Upgrade request detected for {} {}, establishing direct proxy tunnel", method_str, path);
56            return handle_upgrade_request(state, req).await;
57        } else {
58            tracing::warn!("Upgrade request detected for {} {} but WebSocket support is disabled", method_str, path);
59            return Err(StatusCode::NOT_IMPLEMENTED);
60        }
61    }
62    
63    // Extract request details (only after we know it's not an upgrade request)
64    let method = req.method().clone();
65    let method_str = method.as_str();
66    let uri = req.uri().clone();
67    let path = uri.path();
68    let query = uri.query().unwrap_or("");
69    let headers = req.headers().clone();
70    
71    // Check if only GET requests are allowed
72    if state.config.forward_get_only && method != axum::http::Method::GET {
73        tracing::warn!("Non-GET request {} {} rejected (forward_get_only is enabled)", method_str, path);
74        return Err(StatusCode::METHOD_NOT_ALLOWED);
75    }
76    
77    // Check if this path should be cached based on include/exclude patterns
78    let should_cache = should_cache_path(
79        method_str,
80        path,
81        &state.config.include_paths,
82        &state.config.exclude_paths,
83    );
84    
85    // Generate cache key using the configured function
86    let req_info = crate::RequestInfo {
87        method: method_str,
88        path,
89        query,
90        headers: &headers,
91    };
92    let cache_key = (state.config.cache_key_fn)(&req_info);
93
94    // Try to get from cache first (only if caching is enabled for this path)
95    if should_cache {
96        if let Some(cached) = state.cache.get(&cache_key).await {
97            tracing::debug!("Cache hit for: {} {}", method_str, cache_key);
98            return Ok(build_response_from_cache(cached));
99        }
100        tracing::debug!("Cache miss for: {} {}, fetching from backend", method_str, cache_key);
101    } else {
102        tracing::debug!("{} {} not cacheable (filtered), proxying directly", method_str, path);
103    }
104    
105    // Convert body to bytes to forward it
106    let body_bytes = match axum::body::to_bytes(req.into_body(), usize::MAX).await {
107        Ok(bytes) => bytes,
108        Err(e) => {
109            tracing::error!("Failed to read request body: {}", e);
110            return Err(StatusCode::BAD_REQUEST);
111        }
112    };
113
114    // Fetch from backend (proxy_url)
115    let target_url = format!("{}{}", state.config.proxy_url, uri);
116    let client = reqwest::Client::new();
117
118    let response = match client
119        .request(method.clone(), &target_url)
120        .headers(convert_headers(&headers))
121        .body(body_bytes.to_vec())
122        .send()
123        .await
124    {
125        Ok(resp) => resp,
126        Err(e) => {
127            tracing::error!("Failed to fetch from backend: {}", e);
128            return Err(StatusCode::BAD_GATEWAY);
129        }
130    };
131
132    // Cache the response (only if caching is enabled for this path)
133    let status = response.status().as_u16();
134    let response_headers = response.headers().clone();
135    let body_bytes = match response.bytes().await {
136        Ok(bytes) => bytes.to_vec(),
137        Err(e) => {
138            tracing::error!("Failed to read response body: {}", e);
139            return Err(StatusCode::BAD_GATEWAY);
140        }
141    };
142
143    let cached_response = CachedResponse {
144        body: body_bytes.clone(),
145        headers: convert_headers_to_map(&response_headers),
146        status,
147    };
148
149    if should_cache {
150        state
151            .cache
152            .set(cache_key.clone(), cached_response.clone())
153            .await;
154        tracing::debug!("Cached response for: {} {}", method_str, cache_key);
155    }
156
157    Ok(build_response_from_cache(cached_response))
158}
159
160/// Handle WebSocket and other upgrade requests by establishing a direct TCP tunnel
161/// 
162/// This function handles long-lived connections like WebSocket by:
163/// 1. Connecting to the backend server
164/// 2. Forwarding the upgrade request
165/// 3. Capturing both client and backend upgrade connections
166/// 4. Creating a bidirectional TCP tunnel between them
167/// 
168/// The tunnel remains open for the lifetime of the connection, allowing
169/// full-duplex communication. Data flows directly between client and backend
170/// without any caching or inspection.
171async fn handle_upgrade_request(
172    state: Arc<ProxyState>,
173    mut req: Request<Body>,
174) -> Result<Response<Body>, StatusCode> {
175    let target_url = format!("{}{}", state.config.proxy_url, req.uri());
176    
177    // Parse the backend URL to extract host and port
178    let backend_uri = target_url.parse::<hyper::Uri>().map_err(|e| {
179        tracing::error!("Failed to parse backend URL: {}", e);
180        StatusCode::BAD_GATEWAY
181    })?;
182    
183    let host = backend_uri.host().ok_or_else(|| {
184        tracing::error!("No host in backend URL");
185        StatusCode::BAD_GATEWAY
186    })?;
187    
188    let port = backend_uri.port_u16().unwrap_or_else(|| {
189        if backend_uri.scheme_str() == Some("https") {
190            443
191        } else {
192            80
193        }
194    });
195    
196    // IMPORTANT: Set up client upgrade BEFORE processing the request
197    // This captures the client's connection for later upgrade
198    let client_upgrade = hyper::upgrade::on(&mut req);
199    
200    // Connect to backend
201    let backend_stream = tokio::net::TcpStream::connect((host, port))
202        .await
203        .map_err(|e| {
204            tracing::error!("Failed to connect to backend {}:{}: {}", host, port, e);
205            StatusCode::BAD_GATEWAY
206        })?;
207    
208    let backend_io = TokioIo::new(backend_stream);
209    
210    // Build the backend request with upgrade support
211    let (mut sender, conn) = hyper::client::conn::http1::handshake(backend_io)
212        .await
213        .map_err(|e| {
214            tracing::error!("Failed to handshake with backend: {}", e);
215            StatusCode::BAD_GATEWAY
216        })?;
217    
218    // Spawn a task to poll the connection - this will handle the upgrade
219    let conn_task = tokio::spawn(async move {
220        match conn.with_upgrades().await {
221            Ok(parts) => {
222                tracing::debug!("Backend connection upgraded successfully");
223                Ok(parts)
224            }
225            Err(e) => {
226                tracing::error!("Backend connection failed: {}", e);
227                Err(e)
228            }
229        }
230    });
231    
232    // Forward the request to the backend
233    let backend_response = sender.send_request(req).await.map_err(|e| {
234        tracing::error!("Failed to send request to backend: {}", e);
235        StatusCode::BAD_GATEWAY
236    })?;
237    
238    // Check if backend accepted the upgrade
239    let status = backend_response.status();
240    if status != StatusCode::SWITCHING_PROTOCOLS {
241        tracing::warn!("Backend did not accept upgrade request, status: {}", status);
242        // Convert the backend response to our response type
243        let (parts, body) = backend_response.into_parts();
244        let body = Body::new(body);
245        return Ok(Response::from_parts(parts, body));
246    }
247    
248    // Extract headers before moving backend_response
249    let backend_headers = backend_response.headers().clone();
250    
251    // Get the upgraded backend connection
252    let backend_upgrade = hyper::upgrade::on(backend_response);
253    
254    // Spawn a task to handle bidirectional streaming between client and backend
255    tokio::spawn(async move {
256        tracing::debug!("Starting upgrade tunnel establishment");
257        
258        // Wait for both upgrades to complete
259        let (client_result, backend_result) = tokio::join!(
260            client_upgrade,
261            backend_upgrade
262        );
263        
264        // Drop the connection task as we now have the upgraded streams
265        drop(conn_task);
266        
267        match (client_result, backend_result) {
268            (Ok(client_upgraded), Ok(backend_upgraded)) => {
269                tracing::debug!("Both upgrades successful, establishing bidirectional tunnel");
270                
271                // Wrap both in TokioIo for AsyncRead + AsyncWrite
272                let mut client_stream = TokioIo::new(client_upgraded);
273                let mut backend_stream = TokioIo::new(backend_upgraded);
274                
275                // Create bidirectional tunnel
276                match tokio::io::copy_bidirectional(&mut client_stream, &mut backend_stream).await {
277                    Ok((client_to_backend, backend_to_client)) => {
278                        tracing::debug!(
279                            "Tunnel closed gracefully. Transferred {} bytes client->backend, {} bytes backend->client",
280                            client_to_backend,
281                            backend_to_client
282                        );
283                    }
284                    Err(e) => {
285                        tracing::error!("Tunnel error: {}", e);
286                    }
287                }
288            }
289            (Err(e), _) => {
290                tracing::error!("Client upgrade failed: {}", e);
291            }
292            (_, Err(e)) => {
293                tracing::error!("Backend upgrade failed: {}", e);
294            }
295        }
296    });
297    
298    // Build the response to send back to the client with upgrade support
299    let mut response = Response::builder()
300        .status(StatusCode::SWITCHING_PROTOCOLS)
301        .body(Body::empty())
302        .unwrap();
303    
304    // Copy necessary headers from backend response
305    // These headers are essential for WebSocket handshake
306    if let Some(upgrade_header) = backend_headers.get(axum::http::header::UPGRADE) {
307        response.headers_mut().insert(
308            axum::http::header::UPGRADE,
309            upgrade_header.clone(),
310        );
311    }
312    if let Some(connection_header) = backend_headers.get(axum::http::header::CONNECTION) {
313        response.headers_mut().insert(
314            axum::http::header::CONNECTION,
315            connection_header.clone(),
316        );
317    }
318    if let Some(sec_websocket_accept) = backend_headers.get("sec-websocket-accept") {
319        response.headers_mut().insert(
320            HeaderName::from_static("sec-websocket-accept"),
321            sec_websocket_accept.clone(),
322        );
323    }
324    
325    tracing::debug!("Upgrade response sent to client, tunnel task spawned");
326    
327    Ok(response)
328}
329
330fn build_response_from_cache(cached: CachedResponse) -> Response<Body> {
331    let mut response = Response::builder().status(cached.status);
332
333    // Add headers
334    let headers = response.headers_mut().unwrap();
335    for (key, value) in cached.headers {
336        if let Ok(header_name) = key.parse::<HeaderName>() {
337            if let Ok(header_value) = HeaderValue::from_str(&value) {
338                headers.insert(header_name, header_value);
339            } else {
340                tracing::warn!("Failed to parse header value for key '{}': {:?}", key, value);
341            }
342        } else {
343            tracing::warn!("Failed to parse header name: {}", key);
344        }
345    }
346
347    response.body(Body::from(cached.body)).unwrap()
348}
349
350fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
351    let mut req_headers = reqwest::header::HeaderMap::new();
352    for (key, value) in headers {
353        // Skip host header as reqwest will set it
354        if key == axum::http::header::HOST {
355            continue;
356        }
357        if let Ok(val) = value.to_str() {
358            if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
359                req_headers.insert(key.clone(), header_value);
360            }
361        }
362    }
363    req_headers
364}
365
366fn convert_headers_to_map(
367    headers: &reqwest::header::HeaderMap,
368) -> std::collections::HashMap<String, String> {
369    let mut map = std::collections::HashMap::new();
370    for (key, value) in headers {
371        if let Ok(val) = value.to_str() {
372            map.insert(key.to_string(), val.to_string());
373        } else {
374            // Log when we can't convert a header (might be binary)
375            tracing::debug!("Could not convert header '{}' to string", key);
376        }
377    }
378    map
379}