phantom_frame/
proxy.rs

1use crate::cache::{CacheStore, CachedResponse};
2use crate::path_matcher::should_cache_path;
3use crate::CreateProxyConfig;
4use axum::{
5    body::Body,
6    extract::Extension,
7    http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
8};
9use std::sync::Arc;
10use hyper_util::rt::TokioIo;
11
12#[derive(Clone)]
13pub struct ProxyState {
14    cache: CacheStore,
15    config: CreateProxyConfig,
16}
17
18impl ProxyState {
19    pub fn new(cache: CacheStore, config: CreateProxyConfig) -> Self {
20        Self { cache, config }
21    }
22}
23
24/// Check if the request is a WebSocket or other upgrade request
25/// 
26/// WebSocket and other protocol upgrades are detected by checking for:
27/// - `Connection: Upgrade` header (case-insensitive)
28/// - Presence of `Upgrade` header
29/// 
30/// These requests will bypass caching and use direct TCP tunneling instead.
31fn is_upgrade_request(headers: &HeaderMap) -> bool {
32    headers
33        .get(axum::http::header::CONNECTION)
34        .and_then(|v| v.to_str().ok())
35        .map(|v| v.to_lowercase().contains("upgrade"))
36        .unwrap_or(false)
37        || headers.contains_key(axum::http::header::UPGRADE)
38}
39
40/// Main proxy handler that serves prerendered content from cache
41/// or fetches from backend if not cached
42pub async fn proxy_handler(
43    Extension(state): Extension<Arc<ProxyState>>,
44    req: Request<Body>,
45) -> Result<Response<Body>, StatusCode> {
46    // Check for upgrade requests FIRST (before consuming anything from the request)
47    // This is critical for WebSocket to work properly
48    let is_upgrade = is_upgrade_request(req.headers());
49    
50    if is_upgrade {
51        let method_str = req.method().as_str();
52        let path = req.uri().path();
53        
54        if state.config.enable_websocket {
55            tracing::debug!("Upgrade request detected for {} {}, establishing direct proxy tunnel", method_str, path);
56            return handle_upgrade_request(state, req).await;
57        } else {
58            tracing::warn!("Upgrade request detected for {} {} but WebSocket support is disabled", method_str, path);
59            return Err(StatusCode::NOT_IMPLEMENTED);
60        }
61    }
62    
63    // Extract request details (only after we know it's not an upgrade request)
64    let method = req.method().clone();
65    let method_str = method.as_str();
66    let uri = req.uri().clone();
67    let path = uri.path();
68    let query = uri.query().unwrap_or("");
69    let headers = req.headers().clone();
70    
71    // Check if only GET requests are allowed
72    if state.config.forward_get_only && method != axum::http::Method::GET {
73        tracing::warn!("Non-GET request {} {} rejected (forward_get_only is enabled)", method_str, path);
74        return Err(StatusCode::METHOD_NOT_ALLOWED);
75    }
76    
77    // Check if this path should be cached based on include/exclude patterns
78    let should_cache = should_cache_path(
79        method_str,
80        path,
81        &state.config.include_paths,
82        &state.config.exclude_paths,
83    );
84    
85    // Generate cache key using the configured function
86    let req_info = crate::RequestInfo {
87        method: method_str,
88        path,
89        query,
90        headers: &headers,
91    };
92    let cache_key = (state.config.cache_key_fn)(&req_info);
93
94    // Try to get 404 cache first (available even if should_cache is false)
95    if state.config.cache_404_capacity > 0 {
96        if let Some(cached) = state.cache.get_404(&cache_key).await {
97            tracing::debug!("404 cache hit for: {} {}", method_str, cache_key);
98            return Ok(build_response_from_cache(cached));
99        }
100    }
101
102    // Try to get from cache first (only if caching is enabled for this path)
103    if should_cache {
104        if let Some(cached) = state.cache.get(&cache_key).await {
105            tracing::debug!("Cache hit for: {} {}", method_str, cache_key);
106            return Ok(build_response_from_cache(cached));
107        }
108        tracing::debug!("Cache miss for: {} {}, fetching from backend", method_str, cache_key);
109    } else {
110        tracing::debug!("{} {} not cacheable (filtered), proxying directly", method_str, path);
111    }
112    
113    // Convert body to bytes to forward it
114    let body_bytes = match axum::body::to_bytes(req.into_body(), usize::MAX).await {
115        Ok(bytes) => bytes,
116        Err(e) => {
117            tracing::error!("Failed to read request body: {}", e);
118            return Err(StatusCode::BAD_REQUEST);
119        }
120    };
121
122    // Fetch from backend (proxy_url)
123    let target_url = format!("{}{}", state.config.proxy_url, uri);
124    let client = reqwest::Client::new();
125
126    let response = match client
127        .request(method.clone(), &target_url)
128        .headers(convert_headers(&headers))
129        .body(body_bytes.to_vec())
130        .send()
131        .await
132    {
133        Ok(resp) => resp,
134        Err(e) => {
135            tracing::error!("Failed to fetch from backend: {}", e);
136            return Err(StatusCode::BAD_GATEWAY);
137        }
138    };
139
140    // Cache the response (only if caching is enabled for this path)
141    let status = response.status().as_u16();
142    let response_headers = response.headers().clone();
143    let body_bytes = match response.bytes().await {
144        Ok(bytes) => bytes.to_vec(),
145        Err(e) => {
146            tracing::error!("Failed to read response body: {}", e);
147            return Err(StatusCode::BAD_GATEWAY);
148        }
149    };
150
151    let cached_response = CachedResponse {
152        body: body_bytes.clone(),
153        headers: convert_headers_to_map(&response_headers),
154        status,
155    };
156
157    // Determine if this should be cached as a 404 (either by status or by meta tag if enabled)
158    let mut is_404 = status == 404;
159    if !is_404 && state.config.use_404_meta {
160        // check if body contains special meta tag
161        if let Ok(body_str) = std::str::from_utf8(&body_bytes) {
162            // look for name="phantom-404" or name='phantom-404' AND content="true" or content='true'
163            let name_dbl = "name=\"phantom-404\"";
164            let name_sgl = "name='phantom-404'";
165            let content_dbl = "content=\"true\"";
166            let content_sgl = "content='true'";
167
168            if (body_str.contains(name_dbl) || body_str.contains(name_sgl))
169                && (body_str.contains(content_dbl) || body_str.contains(content_sgl))
170            {
171                is_404 = true;
172            }
173        }
174    }
175
176    if is_404 && state.config.cache_404_capacity > 0 {
177        // store in 404 cache
178        state
179            .cache
180            .set_404(cache_key.clone(), cached_response.clone())
181            .await;
182        tracing::debug!("Cached 404 response for: {} {}", method_str, cache_key);
183    } else if should_cache {
184        state
185            .cache
186            .set(cache_key.clone(), cached_response.clone())
187            .await;
188        tracing::debug!("Cached response for: {} {}", method_str, cache_key);
189    }
190
191    Ok(build_response_from_cache(cached_response))
192}
193
194/// Handle WebSocket and other upgrade requests by establishing a direct TCP tunnel
195/// 
196/// This function handles long-lived connections like WebSocket by:
197/// 1. Connecting to the backend server
198/// 2. Forwarding the upgrade request
199/// 3. Capturing both client and backend upgrade connections
200/// 4. Creating a bidirectional TCP tunnel between them
201/// 
202/// The tunnel remains open for the lifetime of the connection, allowing
203/// full-duplex communication. Data flows directly between client and backend
204/// without any caching or inspection.
205async fn handle_upgrade_request(
206    state: Arc<ProxyState>,
207    mut req: Request<Body>,
208) -> Result<Response<Body>, StatusCode> {
209    let target_url = format!("{}{}", state.config.proxy_url, req.uri());
210    
211    // Parse the backend URL to extract host and port
212    let backend_uri = target_url.parse::<hyper::Uri>().map_err(|e| {
213        tracing::error!("Failed to parse backend URL: {}", e);
214        StatusCode::BAD_GATEWAY
215    })?;
216    
217    let host = backend_uri.host().ok_or_else(|| {
218        tracing::error!("No host in backend URL");
219        StatusCode::BAD_GATEWAY
220    })?;
221    
222    let port = backend_uri.port_u16().unwrap_or_else(|| {
223        if backend_uri.scheme_str() == Some("https") {
224            443
225        } else {
226            80
227        }
228    });
229    
230    // IMPORTANT: Set up client upgrade BEFORE processing the request
231    // This captures the client's connection for later upgrade
232    let client_upgrade = hyper::upgrade::on(&mut req);
233    
234    // Connect to backend
235    let backend_stream = tokio::net::TcpStream::connect((host, port))
236        .await
237        .map_err(|e| {
238            tracing::error!("Failed to connect to backend {}:{}: {}", host, port, e);
239            StatusCode::BAD_GATEWAY
240        })?;
241    
242    let backend_io = TokioIo::new(backend_stream);
243    
244    // Build the backend request with upgrade support
245    let (mut sender, conn) = hyper::client::conn::http1::handshake(backend_io)
246        .await
247        .map_err(|e| {
248            tracing::error!("Failed to handshake with backend: {}", e);
249            StatusCode::BAD_GATEWAY
250        })?;
251    
252    // Spawn a task to poll the connection - this will handle the upgrade
253    let conn_task = tokio::spawn(async move {
254        match conn.with_upgrades().await {
255            Ok(parts) => {
256                tracing::debug!("Backend connection upgraded successfully");
257                Ok(parts)
258            }
259            Err(e) => {
260                tracing::error!("Backend connection failed: {}", e);
261                Err(e)
262            }
263        }
264    });
265    
266    // Forward the request to the backend
267    let backend_response = sender.send_request(req).await.map_err(|e| {
268        tracing::error!("Failed to send request to backend: {}", e);
269        StatusCode::BAD_GATEWAY
270    })?;
271    
272    // Check if backend accepted the upgrade
273    let status = backend_response.status();
274    if status != StatusCode::SWITCHING_PROTOCOLS {
275        tracing::warn!("Backend did not accept upgrade request, status: {}", status);
276        // Convert the backend response to our response type
277        let (parts, body) = backend_response.into_parts();
278        let body = Body::new(body);
279        return Ok(Response::from_parts(parts, body));
280    }
281    
282    // Extract headers before moving backend_response
283    let backend_headers = backend_response.headers().clone();
284    
285    // Get the upgraded backend connection
286    let backend_upgrade = hyper::upgrade::on(backend_response);
287    
288    // Spawn a task to handle bidirectional streaming between client and backend
289    tokio::spawn(async move {
290        tracing::debug!("Starting upgrade tunnel establishment");
291        
292        // Wait for both upgrades to complete
293        let (client_result, backend_result) = tokio::join!(
294            client_upgrade,
295            backend_upgrade
296        );
297        
298        // Drop the connection task as we now have the upgraded streams
299        drop(conn_task);
300        
301        match (client_result, backend_result) {
302            (Ok(client_upgraded), Ok(backend_upgraded)) => {
303                tracing::debug!("Both upgrades successful, establishing bidirectional tunnel");
304                
305                // Wrap both in TokioIo for AsyncRead + AsyncWrite
306                let mut client_stream = TokioIo::new(client_upgraded);
307                let mut backend_stream = TokioIo::new(backend_upgraded);
308                
309                // Create bidirectional tunnel
310                match tokio::io::copy_bidirectional(&mut client_stream, &mut backend_stream).await {
311                    Ok((client_to_backend, backend_to_client)) => {
312                        tracing::debug!(
313                            "Tunnel closed gracefully. Transferred {} bytes client->backend, {} bytes backend->client",
314                            client_to_backend,
315                            backend_to_client
316                        );
317                    }
318                    Err(e) => {
319                        tracing::error!("Tunnel error: {}", e);
320                    }
321                }
322            }
323            (Err(e), _) => {
324                tracing::error!("Client upgrade failed: {}", e);
325            }
326            (_, Err(e)) => {
327                tracing::error!("Backend upgrade failed: {}", e);
328            }
329        }
330    });
331    
332    // Build the response to send back to the client with upgrade support
333    let mut response = Response::builder()
334        .status(StatusCode::SWITCHING_PROTOCOLS)
335        .body(Body::empty())
336        .unwrap();
337    
338    // Copy necessary headers from backend response
339    // These headers are essential for WebSocket handshake
340    if let Some(upgrade_header) = backend_headers.get(axum::http::header::UPGRADE) {
341        response.headers_mut().insert(
342            axum::http::header::UPGRADE,
343            upgrade_header.clone(),
344        );
345    }
346    if let Some(connection_header) = backend_headers.get(axum::http::header::CONNECTION) {
347        response.headers_mut().insert(
348            axum::http::header::CONNECTION,
349            connection_header.clone(),
350        );
351    }
352    if let Some(sec_websocket_accept) = backend_headers.get("sec-websocket-accept") {
353        response.headers_mut().insert(
354            HeaderName::from_static("sec-websocket-accept"),
355            sec_websocket_accept.clone(),
356        );
357    }
358    
359    tracing::debug!("Upgrade response sent to client, tunnel task spawned");
360    
361    Ok(response)
362}
363
364fn build_response_from_cache(cached: CachedResponse) -> Response<Body> {
365    let mut response = Response::builder().status(cached.status);
366
367    // Add headers
368    let headers = response.headers_mut().unwrap();
369    for (key, value) in cached.headers {
370        if let Ok(header_name) = key.parse::<HeaderName>() {
371            if let Ok(header_value) = HeaderValue::from_str(&value) {
372                headers.insert(header_name, header_value);
373            } else {
374                tracing::warn!("Failed to parse header value for key '{}': {:?}", key, value);
375            }
376        } else {
377            tracing::warn!("Failed to parse header name: {}", key);
378        }
379    }
380
381    response.body(Body::from(cached.body)).unwrap()
382}
383
384fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
385    let mut req_headers = reqwest::header::HeaderMap::new();
386    for (key, value) in headers {
387        // Skip host header as reqwest will set it
388        if key == axum::http::header::HOST {
389            continue;
390        }
391        if let Ok(val) = value.to_str() {
392            if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
393                req_headers.insert(key.clone(), header_value);
394            }
395        }
396    }
397    req_headers
398}
399
400fn convert_headers_to_map(
401    headers: &reqwest::header::HeaderMap,
402) -> std::collections::HashMap<String, String> {
403    let mut map = std::collections::HashMap::new();
404    for (key, value) in headers {
405        if let Ok(val) = value.to_str() {
406            map.insert(key.to_string(), val.to_string());
407        } else {
408            // Log when we can't convert a header (might be binary)
409            tracing::debug!("Could not convert header '{}' to string", key);
410        }
411    }
412    map
413}