//! phantom_frame/proxy.rs — reverse-proxy request handling: response
//! caching, webhook dispatch, compression negotiation, and WebSocket
//! tunnelling.
1use crate::cache::{CacheStore, CachedResponse};
2use crate::compression::{
3    client_accepts_encoding, compress_body_async, configured_encoding, decode_upstream_body_async,
4    decompress_body_async, identity_acceptable,
5};
6use crate::path_matcher::should_cache_path;
7use crate::{CompressStrategy, CreateProxyConfig, ProxyMode, WebhookType};
8use axum::{
9    body::Body,
10    extract::Extension,
11    http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
12};
13use hyper_util::rt::TokioIo;
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
/// Shared state handed (via `Extension<Arc<ProxyState>>`) to every proxy
/// request handler.
#[derive(Clone)]
pub struct ProxyState {
    /// Response cache (backs both the main and the 404 lookups).
    cache: CacheStore,
    /// Immutable proxy configuration supplied at construction time.
    config: CreateProxyConfig,
    /// HTTP client used for upstream (backend) requests.
    upstream_client: reqwest::Client,
    /// HTTP client used for webhook POSTs.
    webhook_client: reqwest::Client,
}
25
26impl ProxyState {
27    pub fn new(
28        cache: CacheStore,
29        config: CreateProxyConfig,
30        upstream_client: reqwest::Client,
31        webhook_client: reqwest::Client,
32    ) -> Self {
33        Self {
34            cache,
35            config,
36            upstream_client,
37            webhook_client,
38        }
39    }
40}
41
42pub(crate) fn build_upstream_client() -> anyhow::Result<reqwest::Client> {
43    reqwest::Client::builder()
44        .pool_idle_timeout(Duration::from_secs(90))
45        .connect_timeout(Duration::from_secs(5))
46        .timeout(Duration::from_secs(30))
47        .tcp_keepalive(Duration::from_secs(30))
48        .no_brotli()
49        .no_deflate()
50        .no_gzip()
51        .build()
52        .map_err(Into::into)
53}
54
55pub(crate) fn build_webhook_client() -> anyhow::Result<reqwest::Client> {
56    reqwest::Client::builder()
57        .pool_idle_timeout(Duration::from_secs(30))
58        .connect_timeout(Duration::from_secs(3))
59        .redirect(reqwest::redirect::Policy::none())
60        .build()
61        .map_err(Into::into)
62}
63
64/// Check if the request is a WebSocket or other upgrade request
65///
66/// WebSocket and other protocol upgrades are detected by checking for:
67/// - `Connection: Upgrade` header (case-insensitive)
68/// - Presence of `Upgrade` header
69///
70/// These requests will bypass caching and use direct TCP tunneling instead.
71fn is_upgrade_request(headers: &HeaderMap) -> bool {
72    headers
73        .get(axum::http::header::CONNECTION)
74        .and_then(|v| v.to_str().ok())
75        .map(|v| v.to_lowercase().contains("upgrade"))
76        .unwrap_or(false)
77        || headers.contains_key(axum::http::header::UPGRADE)
78}
79
80/// Build the JSON payload sent to webhook endpoints.
81///
82/// Contains `method`, `path`, `query`, and `headers` (as a flat string-to-string
83/// map). The request body is intentionally excluded so that the caller never has
84/// to consume it before the payload is built.
85fn build_webhook_payload(
86    method: &str,
87    path: &str,
88    query: &str,
89    headers: &HeaderMap,
90) -> serde_json::Value {
91    let headers_map: serde_json::Map<String, serde_json::Value> = headers
92        .iter()
93        .filter_map(|(name, value)| {
94            value.to_str().ok().map(|v| {
95                (
96                    name.as_str().to_string(),
97                    serde_json::Value::String(v.to_string()),
98                )
99            })
100        })
101        .collect();
102
103    serde_json::json!({
104        "method": method,
105        "path": path,
106        "query": query,
107        "headers": headers_map,
108    })
109}
110
/// Result of a webhook HTTP call.
///
/// Captures everything webhook dispatch needs from the response so the
/// `reqwest::Response` itself does not have to be passed around.
struct WebhookCallResult {
    /// HTTP status returned by the webhook server.
    status: StatusCode,
    /// Value of the `Location` header, if present.
    /// Used to forward redirects to the client when a blocking webhook returns 3xx.
    location: Option<String>,
    /// Response body as plain text.
    /// Used by `cache_key` webhooks to override the cache lookup key.
    body: String,
}
122
123/// POST `payload` to `url`.
124///
125/// Redirects are **not** followed — a `3xx` status is returned as-is so the
126/// caller can forward it to the original client.
127///
128/// Returns:
129/// - `Ok(WebhookCallResult)` — status, optional `Location` header, and body.
130/// - `Err(())` — timeout, connection error, or other transport failure.
131async fn call_webhook(
132    client: &reqwest::Client,
133    url: &str,
134    payload: &serde_json::Value,
135    timeout_ms: u64,
136) -> Result<WebhookCallResult, ()> {
137    let response = client
138        .post(url)
139        .timeout(std::time::Duration::from_millis(timeout_ms))
140        .json(payload)
141        .send()
142        .await
143        .map_err(|_| ())?;
144
145    let status = StatusCode::from_u16(response.status().as_u16()).map_err(|_| ())?;
146    let location = response
147        .headers()
148        .get(reqwest::header::LOCATION)
149        .and_then(|v| v.to_str().ok())
150        .map(|s| s.to_string());
151    let body = response.text().await.unwrap_or_default();
152
153    Ok(WebhookCallResult {
154        status,
155        location,
156        body,
157    })
158}
159
/// Main proxy handler that serves prerendered content from cache
/// or fetches from backend if not cached
///
/// Request lifecycle, top to bottom:
/// 1. Upgrade (WebSocket) detection — tunnelled or rejected before the
///    request is consumed in any way.
/// 2. Optional non-GET rejection (`forward_get_only`).
/// 3. Webhook dispatch (notify / blocking / cache-key) — runs before cache
///    reads so access control also covers would-be cache hits.
/// 4. Cache lookup: the 404 cache first, then the main cache.
/// 5. Upstream fetch, optional cache write, and response construction.
///
/// Error statuses produced here: 501 (upgrade not supported), 405 (non-GET
/// when restricted), 503 (blocking webhook transport failure), any status a
/// blocking webhook returns, 404 (PreGenerate miss with fallthrough off),
/// 400 (unreadable request body), 502 (upstream fetch/body failure).
pub async fn proxy_handler(
    Extension(state): Extension<Arc<ProxyState>>,
    req: Request<Body>,
) -> Result<Response<Body>, StatusCode> {
    let request_started = Instant::now();
    // Check for upgrade requests FIRST (before consuming anything from the request)
    // This is critical for WebSocket to work properly
    let is_upgrade = is_upgrade_request(req.headers());

    if is_upgrade {
        let method_str = req.method().as_str();
        let path = req.uri().path();

        // WebSocket / upgrade tunnelling is only meaningful when there is a live
        // backend to tunnel to.  Pure SSG servers (PreGenerate with fallthrough
        // disabled) have no backend reachable at request time, so we always
        // return 501 for them regardless of the `enable_websocket` flag.
        let ws_allowed = state.config.enable_websocket
            && match &state.config.proxy_mode {
                ProxyMode::Dynamic => true,
                ProxyMode::PreGenerate { fallthrough, .. } => *fallthrough,
            };

        if ws_allowed {
            tracing::debug!(
                "Upgrade request detected for {} {}, establishing direct proxy tunnel",
                method_str,
                path
            );
            return handle_upgrade_request(state, req).await;
        } else {
            tracing::warn!(
                "Upgrade request detected for {} {} but WebSocket support is disabled or not available in current proxy mode",
                method_str,
                path
            );
            return Err(StatusCode::NOT_IMPLEMENTED);
        }
    }

    // Extract request details (only after we know it's not an upgrade request)
    let method = req.method().clone();
    let method_str = method.as_str();
    let uri = req.uri().clone();
    let path = uri.path();
    let query = uri.query().unwrap_or("");
    let headers = req.headers().clone();
    tracing::debug!(
        method = method_str,
        path,
        query,
        "proxy request entered handler"
    );

    // Check if only GET requests are allowed
    if state.config.forward_get_only && method != axum::http::Method::GET {
        tracing::warn!(
            "Non-GET request {} {} rejected (forward_get_only is enabled)",
            method_str,
            path
        );
        return Err(StatusCode::METHOD_NOT_ALLOWED);
    }

    // ── Webhook dispatch ────────────────────────────────────────────────────
    // Webhooks fire before cache reads so that access control is enforced even
    // for requests that would otherwise be served from the cache.
    let mut cache_key_override: Option<String> = None;
    if !state.config.webhooks.is_empty() {
        // The payload never includes the request body, so it can be built once
        // and reused by every webhook in the list.
        let payload = build_webhook_payload(method_str, path, query, &headers);
        let webhook_started = Instant::now();

        for webhook in &state.config.webhooks {
            match webhook.webhook_type {
                WebhookType::Notify => {
                    // Fire-and-forget: spawn without awaiting.
                    let url = webhook.url.clone();
                    let payload_clone = payload.clone();
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    let webhook_client = state.webhook_client.clone();
                    tokio::spawn(async move {
                        if let Err(()) =
                            call_webhook(&webhook_client, &url, &payload_clone, timeout_ms).await
                        {
                            tracing::warn!("Notify webhook POST to '{}' failed", url);
                        }
                    });
                }
                WebhookType::Blocking => {
                    // Blocking webhooks gate the request: 2xx allows, 3xx is
                    // forwarded as a redirect, anything else denies, and a
                    // transport failure fails closed with 503.
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    match call_webhook(&state.webhook_client, &webhook.url, &payload, timeout_ms)
                        .await
                    {
                        Ok(result) if result.status.is_success() => {
                            tracing::debug!(
                                "Blocking webhook '{}' allowed {} {}",
                                webhook.url,
                                method_str,
                                path
                            );
                        }
                        Ok(result) if result.status.is_redirection() => {
                            tracing::debug!(
                                "Blocking webhook '{}' redirecting {} {} to {}",
                                webhook.url,
                                method_str,
                                path,
                                result.location.as_deref().unwrap_or("(no location)")
                            );
                            let mut builder = Response::builder().status(result.status);
                            if let Some(loc) = &result.location {
                                builder =
                                    builder.header(axum::http::header::LOCATION, loc.as_str());
                            }
                            return Ok(builder
                                .body(Body::empty())
                                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?);
                        }
                        Ok(result) => {
                            tracing::warn!(
                                "Blocking webhook '{}' denied {} {} with status {}",
                                webhook.url,
                                method_str,
                                path,
                                result.status
                            );
                            return Err(result.status);
                        }
                        Err(()) => {
                            tracing::warn!(
                                "Blocking webhook '{}' timed out or failed for {} {} — denying request",
                                webhook.url,
                                method_str,
                                path
                            );
                            return Err(StatusCode::SERVICE_UNAVAILABLE);
                        }
                    }
                }
                WebhookType::CacheKey => {
                    // Cache-key webhooks are best-effort: any failure or empty
                    // body falls back to the default key computed below.
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    match call_webhook(&state.webhook_client, &webhook.url, &payload, timeout_ms)
                        .await
                    {
                        Ok(result) if result.status.is_success() => {
                            let key = result.body.trim().to_string();
                            if !key.is_empty() {
                                tracing::debug!(
                                    "Cache key webhook '{}' set key '{}' for {} {}",
                                    webhook.url,
                                    key,
                                    method_str,
                                    path
                                );
                                cache_key_override = Some(key);
                            } else {
                                tracing::warn!(
                                    "Cache key webhook '{}' returned empty body for {} {} — using default key",
                                    webhook.url,
                                    method_str,
                                    path
                                );
                            }
                        }
                        Ok(result) => {
                            tracing::warn!(
                                "Cache key webhook '{}' returned non-2xx {} for {} {} — using default key",
                                webhook.url,
                                result.status,
                                method_str,
                                path
                            );
                        }
                        Err(()) => {
                            tracing::warn!(
                                "Cache key webhook '{}' timed out or failed for {} {} — using default key",
                                webhook.url,
                                method_str,
                                path
                            );
                        }
                    }
                }
            }
        }

        tracing::debug!(
            method = method_str,
            path,
            elapsed_ms = webhook_started.elapsed().as_millis(),
            "proxy request completed webhook phase"
        );
    }

    // Check if this path should be cached based on include/exclude patterns
    let should_cache = should_cache_path(
        method_str,
        path,
        &state.config.include_paths,
        &state.config.exclude_paths,
    );

    // Generate cache key using the configured function
    let req_info = crate::RequestInfo {
        method: method_str,
        path,
        query,
        headers: &headers,
    };
    let cache_key = cache_key_override.unwrap_or_else(|| (state.config.cache_key_fn)(&req_info));
    let cache_reads_enabled = !matches!(state.config.cache_strategy, crate::CacheStrategy::None);

    // Try to get 404 cache first (available even if should_cache is false)
    if cache_reads_enabled && state.config.cache_404_capacity > 0 {
        if let Some(cached) = state.cache.get_404(&cache_key).await {
            if cached_response_is_allowed(&state.config.cache_strategy, &cached) {
                tracing::debug!("404 cache hit for: {} {}", method_str, cache_key);
                let response = build_response_from_cache(cached, &headers).await?;
                tracing::debug!(
                    method = method_str,
                    path,
                    elapsed_ms = request_started.elapsed().as_millis(),
                    "proxy request served from 404 cache"
                );
                return Ok(response);
            }
        }
    }

    // Try to get from cache first (only if caching is enabled for this path)
    if should_cache && cache_reads_enabled {
        if let Some(cached) = state.cache.get(&cache_key).await {
            if cached_response_is_allowed(&state.config.cache_strategy, &cached) {
                tracing::debug!("Cache hit for: {} {}", method_str, cache_key);
                let response = build_response_from_cache(cached, &headers).await?;
                tracing::debug!(
                    method = method_str,
                    path,
                    elapsed_ms = request_started.elapsed().as_millis(),
                    "proxy request served from main cache"
                );
                return Ok(response);
            }
        }
        // PreGenerate mode: serve only from cache, no backend fallthrough on miss
        if let ProxyMode::PreGenerate { fallthrough, .. } = &state.config.proxy_mode {
            if !fallthrough {
                tracing::debug!(
                    "PreGenerate cache miss for: {} {} — returning 404 (fallthrough disabled)",
                    method_str,
                    cache_key
                );
                return Err(StatusCode::NOT_FOUND);
            }
        }
        tracing::debug!(
            "Cache miss for: {} {}, fetching from backend",
            method_str,
            cache_key
        );
    } else if !cache_reads_enabled {
        tracing::debug!(
            "{} {} not cacheable (cache strategy: none), proxying directly",
            method_str,
            path
        );
    } else {
        tracing::debug!(
            "{} {} not cacheable (filtered), proxying directly",
            method_str,
            path
        );
    }

    // Convert body to bytes to forward it
    let body_bytes = match axum::body::to_bytes(req.into_body(), usize::MAX).await {
        Ok(bytes) => bytes,
        Err(e) => {
            tracing::error!("Failed to read request body: {}", e);
            return Err(StatusCode::BAD_REQUEST);
        }
    };

    // Fetch from backend (proxy_url)
    // Use path+query only — not the full `uri` — because HTTP/2 requests carry an
    // absolute-form URI (e.g. `https://example.com/path`) which would corrupt the
    // concatenated URL when appended to proxy_url.
    let path_and_query = uri
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or_else(|| uri.path());
    let target_url = format!("{}{}", state.config.proxy_url, path_and_query);
    let upstream_started = Instant::now();

    let response = match state
        .upstream_client
        .request(method.clone(), &target_url)
        .headers(convert_headers(&headers))
        .body(body_bytes.to_vec())
        .send()
        .await
    {
        Ok(resp) => resp,
        Err(e) => {
            tracing::error!("Failed to fetch from backend: {}", e);
            return Err(StatusCode::BAD_GATEWAY);
        }
    };
    tracing::debug!(
        method = method_str,
        path,
        elapsed_ms = upstream_started.elapsed().as_millis(),
        "proxy request received upstream response headers"
    );

    // Cache the response (only if caching is enabled for this path)
    let status = response.status().as_u16();
    let response_headers = response.headers().clone();
    let body_bytes = match response.bytes().await {
        Ok(bytes) => bytes.to_vec(),
        Err(e) => {
            tracing::error!("Failed to read response body: {}", e);
            return Err(StatusCode::BAD_GATEWAY);
        }
    };

    // Cacheability is decided from the upstream Content-Type against the
    // configured cache strategy.
    let response_content_type = response_headers
        .get(axum::http::header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok());
    let response_is_cacheable = state
        .config
        .cache_strategy
        .allows_content_type(response_content_type);
    let upstream_content_encoding = response_headers
        .get(axum::http::header::CONTENT_ENCODING)
        .and_then(|value| value.to_str().ok());
    let should_try_cache = cache_reads_enabled
        && response_is_cacheable
        && (should_cache || state.config.cache_404_capacity > 0);
    // Decode the upstream body (per its Content-Encoding) so the cached copy
    // starts from identity bytes; the decoded body is also needed for 404
    // meta-tag sniffing below. On unsupported encodings we keep serving but
    // skip caching.
    let normalized_body = if should_try_cache || state.config.use_404_meta {
        match decode_upstream_body_async(
            body_bytes.clone(),
            upstream_content_encoding.map(|value| value.to_string()),
        )
        .await
        {
            Ok(body) => Some(body),
            Err(error) => {
                tracing::warn!(
                    "Skipping cache compression for {} {} due to unsupported upstream encoding: {}",
                    method_str,
                    path,
                    error
                );
                None
            }
        }
    } else {
        None
    };

    // Determine if this should be cached as a 404 (either by status or by meta tag if enabled)
    let mut is_404 = status == 404;
    if !is_404 && state.config.use_404_meta {
        if let Some(body) = normalized_body.as_deref() {
            is_404 = body_contains_404_meta(body);
        }
    }

    // A response is only stored when caching is on, the content type is
    // allowed, and body normalization above succeeded.
    let should_store_404 = is_404
        && state.config.cache_404_capacity > 0
        && response_is_cacheable
        && cache_reads_enabled
        && normalized_body.is_some();
    let should_store_response = !is_404
        && should_cache
        && response_is_cacheable
        && cache_reads_enabled
        && normalized_body.is_some();

    if should_store_404 || should_store_response {
        let cached_response = match build_cached_response(
            status,
            &response_headers,
            normalized_body.as_deref().unwrap(),
            &state.config.compress_strategy,
        )
        .await
        {
            Ok(cached_response) => cached_response,
            Err(error) => {
                // Cache preparation failure is non-fatal: fall back to
                // relaying the upstream bytes unmodified.
                tracing::warn!(
                    "Failed to prepare cached response for {} {}: {}",
                    method_str,
                    path,
                    error
                );
                return Ok(build_response_from_upstream(
                    status,
                    &response_headers,
                    body_bytes,
                ));
            }
        };

        if should_store_404 {
            state
                .cache
                .set_404(cache_key.clone(), cached_response.clone())
                .await;
            tracing::debug!("Cached 404 response for: {} {}", method_str, cache_key);
        } else {
            state
                .cache
                .set(cache_key.clone(), cached_response.clone())
                .await;
            tracing::debug!("Cached response for: {} {}", method_str, cache_key);
        }

        // Serve the freshly cached entry through the same encoding-negotiation
        // path a cache hit would take.
        let response = build_response_from_cache(cached_response, &headers).await?;
        tracing::debug!(
            method = method_str,
            path,
            elapsed_ms = request_started.elapsed().as_millis(),
            "proxy request completed after upstream fetch and cache write"
        );
        return Ok(response);
    }

    tracing::debug!(
        method = method_str,
        path,
        elapsed_ms = request_started.elapsed().as_millis(),
        "proxy request completed without caching"
    );
    Ok(build_response_from_upstream(
        status,
        &response_headers,
        body_bytes,
    ))
}
603
/// Handle WebSocket and other upgrade requests by establishing a direct TCP tunnel
///
/// This function handles long-lived connections like WebSocket by:
/// 1. Connecting to the backend server
/// 2. Forwarding the upgrade request
/// 3. Capturing both client and backend upgrade connections
/// 4. Creating a bidirectional TCP tunnel between them
///
/// The tunnel remains open for the lifetime of the connection, allowing
/// full-duplex communication. Data flows directly between client and backend
/// without any caching or inspection.
///
/// Returns `502 Bad Gateway` when the backend URL cannot be parsed or the
/// backend connect/handshake/send fails. If the backend answers with anything
/// other than `101 Switching Protocols`, that response is forwarded verbatim.
async fn handle_upgrade_request(
    state: Arc<ProxyState>,
    mut req: Request<Body>,
) -> Result<Response<Body>, StatusCode> {
    // Use path+query only for the same reason as in proxy_handler (HTTP/2 absolute-form URI).
    let req_path_and_query = req
        .uri()
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or_else(|| req.uri().path());
    let target_url = format!("{}{}", state.config.proxy_url, req_path_and_query);

    // Parse the backend URL to extract host and port
    let backend_uri = target_url.parse::<hyper::Uri>().map_err(|e| {
        tracing::error!("Failed to parse backend URL: {}", e);
        StatusCode::BAD_GATEWAY
    })?;

    let host = backend_uri.host().ok_or_else(|| {
        tracing::error!("No host in backend URL");
        StatusCode::BAD_GATEWAY
    })?;

    // NOTE(review): the connection below is plain TCP even when the scheme is
    // https (port 443); a TLS backend would presumably need a TLS handshake
    // here — confirm whether https proxy_urls are expected with WebSocket.
    let port = backend_uri.port_u16().unwrap_or_else(|| {
        if backend_uri.scheme_str() == Some("https") {
            443
        } else {
            80
        }
    });

    // IMPORTANT: Set up client upgrade BEFORE processing the request
    // This captures the client's connection for later upgrade
    let client_upgrade = hyper::upgrade::on(&mut req);

    // Connect to backend
    let backend_stream = tokio::net::TcpStream::connect((host, port))
        .await
        .map_err(|e| {
            tracing::error!("Failed to connect to backend {}:{}: {}", host, port, e);
            StatusCode::BAD_GATEWAY
        })?;

    let backend_io = TokioIo::new(backend_stream);

    // Build the backend request with upgrade support
    let (mut sender, conn) = hyper::client::conn::http1::handshake(backend_io)
        .await
        .map_err(|e| {
            tracing::error!("Failed to handshake with backend: {}", e);
            StatusCode::BAD_GATEWAY
        })?;

    // Spawn a task to poll the connection - this will handle the upgrade
    let conn_task = tokio::spawn(async move {
        match conn.with_upgrades().await {
            Ok(parts) => {
                tracing::debug!("Backend connection upgraded successfully");
                Ok(parts)
            }
            Err(e) => {
                tracing::error!("Backend connection failed: {}", e);
                Err(e)
            }
        }
    });

    // Forward the request to the backend
    let backend_response = sender.send_request(req).await.map_err(|e| {
        tracing::error!("Failed to send request to backend: {}", e);
        StatusCode::BAD_GATEWAY
    })?;

    // Check if backend accepted the upgrade
    let status = backend_response.status();
    if status != StatusCode::SWITCHING_PROTOCOLS {
        tracing::warn!("Backend did not accept upgrade request, status: {}", status);
        // Convert the backend response to our response type
        let (parts, body) = backend_response.into_parts();
        let body = Body::new(body);
        return Ok(Response::from_parts(parts, body));
    }

    // Extract headers before moving backend_response
    let backend_headers = backend_response.headers().clone();

    // Get the upgraded backend connection
    let backend_upgrade = hyper::upgrade::on(backend_response);

    // Spawn a task to handle bidirectional streaming between client and backend
    tokio::spawn(async move {
        tracing::debug!("Starting upgrade tunnel establishment");

        // Wait for both upgrades to complete
        let (client_result, backend_result) = tokio::join!(client_upgrade, backend_upgrade);

        // Drop the connection task as we now have the upgraded streams
        drop(conn_task);

        match (client_result, backend_result) {
            (Ok(client_upgraded), Ok(backend_upgraded)) => {
                tracing::debug!("Both upgrades successful, establishing bidirectional tunnel");

                // Wrap both in TokioIo for AsyncRead + AsyncWrite
                let mut client_stream = TokioIo::new(client_upgraded);
                let mut backend_stream = TokioIo::new(backend_upgraded);

                // Create bidirectional tunnel
                match tokio::io::copy_bidirectional(&mut client_stream, &mut backend_stream).await {
                    Ok((client_to_backend, backend_to_client)) => {
                        tracing::debug!(
                            "Tunnel closed gracefully. Transferred {} bytes client->backend, {} bytes backend->client",
                            client_to_backend,
                            backend_to_client
                        );
                    }
                    Err(e) => {
                        tracing::error!("Tunnel error: {}", e);
                    }
                }
            }
            (Err(e), _) => {
                tracing::error!("Client upgrade failed: {}", e);
            }
            (_, Err(e)) => {
                tracing::error!("Backend upgrade failed: {}", e);
            }
        }
    });

    // Build the response to send back to the client with upgrade support
    let mut response = Response::builder()
        .status(StatusCode::SWITCHING_PROTOCOLS)
        .body(Body::empty())
        .unwrap();

    // Copy necessary headers from backend response
    // These headers are essential for WebSocket handshake
    if let Some(upgrade_header) = backend_headers.get(axum::http::header::UPGRADE) {
        response
            .headers_mut()
            .insert(axum::http::header::UPGRADE, upgrade_header.clone());
    }
    if let Some(connection_header) = backend_headers.get(axum::http::header::CONNECTION) {
        response
            .headers_mut()
            .insert(axum::http::header::CONNECTION, connection_header.clone());
    }
    if let Some(sec_websocket_accept) = backend_headers.get("sec-websocket-accept") {
        response.headers_mut().insert(
            HeaderName::from_static("sec-websocket-accept"),
            sec_websocket_accept.clone(),
        );
    }

    tracing::debug!("Upgrade response sent to client, tunnel task spawned");

    Ok(response)
}
774
775async fn build_response_from_cache(
776    cached: CachedResponse,
777    request_headers: &HeaderMap,
778) -> Result<Response<Body>, StatusCode> {
779    let mut response_headers = cached.headers;
780    let body = if let Some(content_encoding) = cached.content_encoding {
781        if client_accepts_encoding(request_headers, content_encoding) {
782            upsert_vary_accept_encoding(&mut response_headers);
783            cached.body
784        } else {
785            if !identity_acceptable(request_headers) {
786                tracing::warn!(
787                    "Client does not accept cached encoding '{}' or identity fallback",
788                    content_encoding.as_header_value()
789                );
790                return Err(StatusCode::NOT_ACCEPTABLE);
791            }
792
793            response_headers.remove("content-encoding");
794            upsert_vary_accept_encoding(&mut response_headers);
795            match decompress_body_async(cached.body.clone(), content_encoding).await {
796                Ok(body) => body,
797                Err(error) => {
798                    tracing::error!("Failed to decompress cached response: {}", error);
799                    return Err(StatusCode::INTERNAL_SERVER_ERROR);
800                }
801            }
802        }
803    } else {
804        cached.body
805    };
806
807    response_headers.remove("transfer-encoding");
808    response_headers.insert("content-length".to_string(), body.len().to_string());
809
810    Ok(build_response(cached.status, response_headers, body))
811}
812
813async fn build_cached_response(
814    status: u16,
815    response_headers: &reqwest::header::HeaderMap,
816    normalized_body: &[u8],
817    compress_strategy: &CompressStrategy,
818) -> anyhow::Result<CachedResponse> {
819    let mut headers = convert_headers_to_map(response_headers);
820    headers.remove("content-encoding");
821    headers.remove("content-length");
822    headers.remove("transfer-encoding");
823
824    let content_encoding = configured_encoding(compress_strategy);
825    let body = if let Some(content_encoding) = content_encoding {
826        let compressed = compress_body_async(normalized_body.to_vec(), content_encoding).await?;
827        headers.insert(
828            "content-encoding".to_string(),
829            content_encoding.as_header_value().to_string(),
830        );
831        upsert_vary_accept_encoding(&mut headers);
832        compressed
833    } else {
834        normalized_body.to_vec()
835    };
836
837    headers.insert("content-length".to_string(), body.len().to_string());
838
839    Ok(CachedResponse {
840        body,
841        headers,
842        status,
843        content_encoding,
844    })
845}
846
847fn build_response_from_upstream(
848    status: u16,
849    response_headers: &reqwest::header::HeaderMap,
850    body: Vec<u8>,
851) -> Response<Body> {
852    let mut headers = convert_headers_to_map(response_headers);
853    headers.remove("transfer-encoding");
854    headers.insert("content-length".to_string(), body.len().to_string());
855    build_response(status, headers, body)
856}
857
858fn build_response(
859    status: u16,
860    response_headers: HashMap<String, String>,
861    body: Vec<u8>,
862) -> Response<Body> {
863    let mut response = Response::builder().status(status);
864
865    // Add headers
866    let headers = response.headers_mut().unwrap();
867    for (key, value) in response_headers {
868        if let Ok(header_name) = key.parse::<HeaderName>() {
869            if let Ok(header_value) = HeaderValue::from_str(&value) {
870                headers.insert(header_name, header_value);
871            } else {
872                tracing::warn!(
873                    "Failed to parse header value for key '{}': {:?}",
874                    key,
875                    value
876                );
877            }
878        } else {
879            tracing::warn!("Failed to parse header name: {}", key);
880        }
881    }
882
883    response.body(Body::from(body)).unwrap()
884}
885
886fn cached_response_is_allowed(strategy: &crate::CacheStrategy, cached: &CachedResponse) -> bool {
887    strategy.allows_content_type(
888        cached
889            .headers
890            .get("content-type")
891            .map(|value| value.as_str()),
892    )
893}
894
/// Detects the sentinel `<meta name="phantom-404" content="true">` tag in a
/// response body. Both single- and double-quoted attribute forms are
/// recognized; non-UTF-8 bodies never match.
fn body_contains_404_meta(body: &[u8]) -> bool {
    let Ok(text) = std::str::from_utf8(body) else {
        return false;
    };

    let has_name = ["name=\"phantom-404\"", "name='phantom-404'"]
        .iter()
        .any(|needle| text.contains(needle));
    let has_content = ["content=\"true\"", "content='true'"]
        .iter()
        .any(|needle| text.contains(needle));

    has_name && has_content
}
908
/// Ensure the `vary` header lists `Accept-Encoding` (case-insensitive
/// check), appending to an existing comma-separated list or inserting a
/// fresh header when absent.
fn upsert_vary_accept_encoding(headers: &mut HashMap<String, String>) {
    if let Some(existing) = headers.get_mut("vary") {
        let already_listed = existing
            .split(',')
            .any(|token| token.trim().eq_ignore_ascii_case("accept-encoding"));
        if !already_listed {
            existing.push_str(", Accept-Encoding");
        }
    } else {
        headers.insert("vary".to_string(), "Accept-Encoding".to_string());
    }
}
924
925fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
926    let mut req_headers = reqwest::header::HeaderMap::new();
927    for (key, value) in headers {
928        // Skip host header as reqwest will set it
929        if key == axum::http::header::HOST {
930            continue;
931        }
932        if let Ok(val) = value.to_str() {
933            if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
934                req_headers.insert(key.clone(), header_value);
935            }
936        }
937    }
938    req_headers
939}
940
941/// Fetch a single path from the upstream server, compress it, and store it in the cache.
942/// Used by the snapshot worker for PreGenerate warm-up and runtime snapshot management.
943pub(crate) async fn fetch_and_cache_snapshot(
944    path: &str,
945    client: &reqwest::Client,
946    proxy_url: &str,
947    cache: &CacheStore,
948    compress_strategy: &CompressStrategy,
949    cache_key_fn: &std::sync::Arc<dyn Fn(&crate::RequestInfo) -> String + Send + Sync>,
950) -> anyhow::Result<()> {
951    let empty_headers = axum::http::HeaderMap::new();
952    let req_info = crate::RequestInfo {
953        method: "GET",
954        path,
955        query: "",
956        headers: &empty_headers,
957    };
958    let cache_key = cache_key_fn(&req_info);
959
960    let url = format!("{}{}", proxy_url, path);
961    let response = client
962        .get(&url)
963        .send()
964        .await
965        .map_err(|e| anyhow::anyhow!("Failed to fetch snapshot '{}': {}", path, e))?;
966
967    let status = response.status().as_u16();
968    let response_headers = response.headers().clone();
969    let body_bytes = response
970        .bytes()
971        .await
972        .map_err(|e| anyhow::anyhow!("Failed to read snapshot response for '{}': {}", path, e))?
973        .to_vec();
974
975    let upstream_encoding = response_headers
976        .get(axum::http::header::CONTENT_ENCODING)
977        .and_then(|v| v.to_str().ok());
978    let normalized =
979        decode_upstream_body_async(body_bytes, upstream_encoding.map(|value| value.to_string()))
980            .await
981            .map_err(|e| anyhow::anyhow!("Failed to decode snapshot body for '{}': {}", path, e))?;
982
983    let cached =
984        build_cached_response(status, &response_headers, &normalized, compress_strategy).await?;
985    cache.set(cache_key, cached).await;
986    tracing::debug!("Snapshot pre-generated: {}", path);
987    Ok(())
988}
989
990fn convert_headers_to_map(
991    headers: &reqwest::header::HeaderMap,
992) -> std::collections::HashMap<String, String> {
993    let mut map = std::collections::HashMap::new();
994    for (key, value) in headers {
995        if let Ok(val) = value.to_str() {
996            map.insert(key.as_str().to_ascii_lowercase(), val.to_string());
997        } else {
998            // Log when we can't convert a header (might be binary)
999            tracing::debug!("Could not convert header '{}' to string", key);
1000        }
1001    }
1002    map
1003}
1004
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compression::ContentEncoding;
    use axum::body::to_bytes;

    // Minimal upstream-style header set (just a content-type) shared by the
    // cache-building tests below.
    fn response_headers() -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("text/html; charset=utf-8"),
        );
        headers
    }

    // With a gzip compress strategy, the cached entry must carry the gzip
    // content-encoding, the matching header, and a Vary: Accept-Encoding.
    #[tokio::test]
    async fn test_build_cached_response_uses_selected_encoding() {
        let cached = build_cached_response(
            200,
            &response_headers(),
            b"<html>compressed</html>",
            &CompressStrategy::Gzip,
        )
        .await
        .unwrap();

        assert_eq!(cached.content_encoding, Some(ContentEncoding::Gzip));
        assert_eq!(
            cached.headers.get("content-encoding"),
            Some(&"gzip".to_string())
        );
        assert_eq!(
            cached.headers.get("vary"),
            Some(&"Accept-Encoding".to_string())
        );
    }

    // A client that accepts only gzip (not br) but implicitly accepts
    // identity should receive the decompressed body with the
    // content-encoding header removed.
    #[tokio::test]
    async fn test_build_response_from_cache_falls_back_to_identity() {
        let body = b"<html>identity</html>";
        let compressed = crate::compression::compress_body(body, ContentEncoding::Brotli).unwrap();
        let cached = CachedResponse {
            body: compressed,
            headers: HashMap::from([
                ("content-type".to_string(), "text/html".to_string()),
                ("content-encoding".to_string(), "br".to_string()),
                // Deliberately stale length: build_response_from_cache must
                // recompute it from the actual body.
                ("content-length".to_string(), "123".to_string()),
                ("vary".to_string(), "Accept-Encoding".to_string()),
            ]),
            status: 200,
            content_encoding: Some(ContentEncoding::Brotli),
        };

        let mut request_headers = HeaderMap::new();
        request_headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip"),
        );

        let response = build_response_from_cache(cached, &request_headers)
            .await
            .unwrap();
        assert!(response
            .headers()
            .get(axum::http::header::CONTENT_ENCODING)
            .is_none());

        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(body.as_ref(), b"<html>identity</html>");
    }

    // A client that accepts br should get the cached compressed bytes
    // unchanged, with the content-encoding header preserved.
    #[tokio::test]
    async fn test_build_response_from_cache_keeps_supported_encoding() {
        let body = b"<html>compressed</html>";
        let compressed = crate::compression::compress_body(body, ContentEncoding::Brotli).unwrap();
        let cached = CachedResponse {
            body: compressed.clone(),
            headers: HashMap::from([
                ("content-type".to_string(), "text/html".to_string()),
                ("content-encoding".to_string(), "br".to_string()),
                ("content-length".to_string(), compressed.len().to_string()),
                ("vary".to_string(), "Accept-Encoding".to_string()),
            ]),
            status: 200,
            content_encoding: Some(ContentEncoding::Brotli),
        };

        let mut request_headers = HeaderMap::new();
        request_headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("br, gzip;q=0.5"),
        );

        let response = build_response_from_cache(cached, &request_headers)
            .await
            .unwrap();
        assert_eq!(
            response.headers().get(axum::http::header::CONTENT_ENCODING),
            Some(&HeaderValue::from_static("br"))
        );

        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(body.as_ref(), compressed.as_slice());
    }
}