phantom_frame/
proxy.rs

1use crate::cache::{CacheStore, CachedResponse};
2use crate::path_matcher::should_cache_path;
3use crate::CreateProxyConfig;
4use axum::{
5    body::Body,
6    extract::Extension,
7    http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
8};
9use std::sync::Arc;
10use hyper_util::rt::TokioIo;
11
12#[derive(Clone)]
13pub struct ProxyState {
14    cache: CacheStore,
15    config: CreateProxyConfig,
16}
17
18impl ProxyState {
19    pub fn new(cache: CacheStore, config: CreateProxyConfig) -> Self {
20        Self { cache, config }
21    }
22}
23
24/// Check if the request is a WebSocket or other upgrade request
25/// 
26/// WebSocket and other protocol upgrades are detected by checking for:
27/// - `Connection: Upgrade` header (case-insensitive)
28/// - Presence of `Upgrade` header
29/// 
30/// These requests will bypass caching and use direct TCP tunneling instead.
31fn is_upgrade_request(headers: &HeaderMap) -> bool {
32    headers
33        .get(axum::http::header::CONNECTION)
34        .and_then(|v| v.to_str().ok())
35        .map(|v| v.to_lowercase().contains("upgrade"))
36        .unwrap_or(false)
37        || headers.contains_key(axum::http::header::UPGRADE)
38}
39
40/// Main proxy handler that serves prerendered content from cache
41/// or fetches from backend if not cached
42pub async fn proxy_handler(
43    Extension(state): Extension<Arc<ProxyState>>,
44    req: Request<Body>,
45) -> Result<Response<Body>, StatusCode> {
46    let method_str = req.method().as_str();
47    let path = req.uri().path();
48    let query = req.uri().query().unwrap_or("");
49    let headers = req.headers();
50    
51    // Check if this is an upgrade request (WebSocket, etc.)
52    // If so, handle it via direct TCP proxying (if enabled)
53    if state.config.enable_websocket && is_upgrade_request(headers) {
54        tracing::info!("Upgrade request detected for {} {}, establishing direct proxy tunnel", method_str, path);
55        return handle_upgrade_request(state, req).await;
56    } else if !state.config.enable_websocket && is_upgrade_request(headers) {
57        tracing::warn!("Upgrade request detected for {} {} but WebSocket support is disabled", method_str, path);
58        return Err(StatusCode::NOT_IMPLEMENTED);
59    }
60    
61    // Check if this path should be cached based on include/exclude patterns
62    let should_cache = should_cache_path(
63        method_str,
64        path,
65        &state.config.include_paths,
66        &state.config.exclude_paths,
67    );
68    
69    // Generate cache key using the configured function
70    let req_info = crate::RequestInfo {
71        method: method_str,
72        path,
73        query,
74        headers,
75    };
76    let cache_key = (state.config.cache_key_fn)(&req_info);
77
78    // Try to get from cache first (only if caching is enabled for this path)
79    if should_cache {
80        if let Some(cached) = state.cache.get(&cache_key).await {
81            tracing::info!("Cache hit for: {} {}", method_str, cache_key);
82            return Ok(build_response_from_cache(cached));
83        }
84        tracing::info!("Cache miss for: {} {}, fetching from backend", method_str, cache_key);
85    } else {
86        tracing::info!("{} {} not cacheable (filtered), proxying directly", method_str, path);
87    }
88
89    // Fetch from backend (proxy_url)
90    let target_url = format!("{}{}", state.config.proxy_url, req.uri());
91    let client = reqwest::Client::new();
92
93    let method = req.method().clone();
94    let headers = req.headers().clone();
95
96    let response = match client
97        .request(method, &target_url)
98        .headers(convert_headers(&headers))
99        .send()
100        .await
101    {
102        Ok(resp) => resp,
103        Err(e) => {
104            tracing::error!("Failed to fetch from backend: {}", e);
105            return Err(StatusCode::BAD_GATEWAY);
106        }
107    };
108
109    // Cache the response (only if caching is enabled for this path)
110    let status = response.status().as_u16();
111    let response_headers = response.headers().clone();
112    let body_bytes = match response.bytes().await {
113        Ok(bytes) => bytes.to_vec(),
114        Err(e) => {
115            tracing::error!("Failed to read response body: {}", e);
116            return Err(StatusCode::BAD_GATEWAY);
117        }
118    };
119
120    let cached_response = CachedResponse {
121        body: body_bytes.clone(),
122        headers: convert_headers_to_map(&response_headers),
123        status,
124    };
125
126    if should_cache {
127        state
128            .cache
129            .set(cache_key.clone(), cached_response.clone())
130            .await;
131        tracing::info!("Cached response for: {} {}", method_str, cache_key);
132    }
133
134    Ok(build_response_from_cache(cached_response))
135}
136
137/// Handle WebSocket and other upgrade requests by establishing a direct TCP tunnel
138/// 
139/// This function handles long-lived connections like WebSocket by:
140/// 1. Connecting to the backend server
141/// 2. Forwarding the upgrade request
142/// 3. Capturing both client and backend upgrade connections
143/// 4. Creating a bidirectional TCP tunnel between them
144/// 
145/// The tunnel remains open for the lifetime of the connection, allowing
146/// full-duplex communication. Data flows directly between client and backend
147/// without any caching or inspection.
148async fn handle_upgrade_request(
149    state: Arc<ProxyState>,
150    mut req: Request<Body>,
151) -> Result<Response<Body>, StatusCode> {
152    let target_url = format!("{}{}", state.config.proxy_url, req.uri());
153    
154    // Parse the backend URL to extract host and port
155    let backend_uri = target_url.parse::<hyper::Uri>().map_err(|e| {
156        tracing::error!("Failed to parse backend URL: {}", e);
157        StatusCode::BAD_GATEWAY
158    })?;
159    
160    let host = backend_uri.host().ok_or_else(|| {
161        tracing::error!("No host in backend URL");
162        StatusCode::BAD_GATEWAY
163    })?;
164    
165    let port = backend_uri.port_u16().unwrap_or_else(|| {
166        if backend_uri.scheme_str() == Some("https") {
167            443
168        } else {
169            80
170        }
171    });
172    
173    // IMPORTANT: Set up client upgrade BEFORE processing the request
174    // This captures the client's connection for later upgrade
175    let client_upgrade = hyper::upgrade::on(&mut req);
176    
177    // Connect to backend
178    let backend_stream = tokio::net::TcpStream::connect((host, port))
179        .await
180        .map_err(|e| {
181            tracing::error!("Failed to connect to backend {}:{}: {}", host, port, e);
182            StatusCode::BAD_GATEWAY
183        })?;
184    
185    let backend_stream = TokioIo::new(backend_stream);
186    
187    // Build the backend request
188    let (mut sender, conn) = hyper::client::conn::http1::handshake(backend_stream)
189        .await
190        .map_err(|e| {
191            tracing::error!("Failed to handshake with backend: {}", e);
192            StatusCode::BAD_GATEWAY
193        })?;
194    
195    // Spawn a task to poll the connection
196    tokio::spawn(async move {
197        if let Err(e) = conn.await {
198            tracing::error!("Connection to backend failed: {}", e);
199        }
200    });
201    
202    // Forward the request to the backend
203    let backend_response = sender.send_request(req).await.map_err(|e| {
204        tracing::error!("Failed to send request to backend: {}", e);
205        StatusCode::BAD_GATEWAY
206    })?;
207    
208    // Check if backend accepted the upgrade
209    let status = backend_response.status();
210    if status != StatusCode::SWITCHING_PROTOCOLS {
211        tracing::warn!("Backend did not accept upgrade request, status: {}", status);
212        // Convert the backend response to our response type
213        let (parts, body) = backend_response.into_parts();
214        let body = Body::new(body);
215        return Ok(Response::from_parts(parts, body));
216    }
217    
218    // Extract headers before moving backend_response
219    let backend_headers = backend_response.headers().clone();
220    
221    // Spawn a task to handle bidirectional streaming between client and backend
222    tokio::spawn(async move {
223        tracing::info!("Starting upgrade tunnel establishment");
224        
225        // Wait for both upgrades to complete
226        let (client_result, backend_result) = tokio::join!(
227            client_upgrade,
228            hyper::upgrade::on(backend_response)
229        );
230        
231        match (client_result, backend_result) {
232            (Ok(client_upgraded), Ok(backend_upgraded)) => {
233                tracing::info!("Both upgrades successful, establishing bidirectional tunnel");
234                
235                // Wrap both in TokioIo for AsyncRead + AsyncWrite
236                let mut client_stream = TokioIo::new(client_upgraded);
237                let mut backend_stream = TokioIo::new(backend_upgraded);
238                
239                // Create bidirectional tunnel
240                match tokio::io::copy_bidirectional(&mut client_stream, &mut backend_stream).await {
241                    Ok((client_to_backend, backend_to_client)) => {
242                        tracing::info!(
243                            "Tunnel closed gracefully. Transferred {} bytes client->backend, {} bytes backend->client",
244                            client_to_backend,
245                            backend_to_client
246                        );
247                    }
248                    Err(e) => {
249                        tracing::error!("Tunnel error: {}", e);
250                    }
251                }
252            }
253            (Err(e), _) => {
254                tracing::error!("Client upgrade failed: {}", e);
255            }
256            (_, Err(e)) => {
257                tracing::error!("Backend upgrade failed: {}", e);
258            }
259        }
260    });
261    
262    // Build the response to send back to the client with upgrade support
263    let mut response = Response::builder()
264        .status(StatusCode::SWITCHING_PROTOCOLS)
265        .body(Body::empty())
266        .unwrap();
267    
268    // Copy necessary headers from backend response
269    // These headers are essential for WebSocket handshake
270    if let Some(upgrade_header) = backend_headers.get(axum::http::header::UPGRADE) {
271        response.headers_mut().insert(
272            axum::http::header::UPGRADE,
273            upgrade_header.clone(),
274        );
275    }
276    if let Some(connection_header) = backend_headers.get(axum::http::header::CONNECTION) {
277        response.headers_mut().insert(
278            axum::http::header::CONNECTION,
279            connection_header.clone(),
280        );
281    }
282    if let Some(sec_websocket_accept) = backend_headers.get("sec-websocket-accept") {
283        response.headers_mut().insert(
284            HeaderName::from_static("sec-websocket-accept"),
285            sec_websocket_accept.clone(),
286        );
287    }
288    
289    tracing::info!("Upgrade response sent to client, tunnel task spawned");
290    
291    Ok(response)
292}
293
294fn build_response_from_cache(cached: CachedResponse) -> Response<Body> {
295    let mut response = Response::builder().status(cached.status);
296
297    // Add headers
298    let headers = response.headers_mut().unwrap();
299    for (key, value) in cached.headers {
300        if let Ok(header_name) = key.parse::<HeaderName>() {
301            if let Ok(header_value) = HeaderValue::from_str(&value) {
302                headers.insert(header_name, header_value);
303            }
304        }
305    }
306
307    response.body(Body::from(cached.body)).unwrap()
308}
309
310fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
311    let mut req_headers = reqwest::header::HeaderMap::new();
312    for (key, value) in headers {
313        if let Ok(val) = value.to_str() {
314            if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
315                req_headers.insert(key.clone(), header_value);
316            }
317        }
318    }
319    req_headers
320}
321
322fn convert_headers_to_map(
323    headers: &reqwest::header::HeaderMap,
324) -> std::collections::HashMap<String, String> {
325    let mut map = std::collections::HashMap::new();
326    for (key, value) in headers {
327        if let Ok(val) = value.to_str() {
328            map.insert(key.to_string(), val.to_string());
329        }
330    }
331    map
332}