Skip to main content

phantom_frame/
proxy.rs

1use crate::cache::{CacheStore, CachedResponse};
2use crate::compression::{
3    client_accepts_encoding, compress_body, configured_encoding, decode_upstream_body,
4    decompress_body, identity_acceptable,
5};
6use crate::path_matcher::should_cache_path;
7use crate::{CompressStrategy, CreateProxyConfig, ProxyMode, WebhookType};
8use axum::{
9    body::Body,
10    extract::Extension,
11    http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
12};
13use hyper_util::rt::TokioIo;
14use std::collections::HashMap;
15use std::sync::Arc;
16
17#[derive(Clone)]
18pub struct ProxyState {
19    cache: CacheStore,
20    config: CreateProxyConfig,
21}
22
23impl ProxyState {
24    pub fn new(cache: CacheStore, config: CreateProxyConfig) -> Self {
25        Self { cache, config }
26    }
27}
28
29/// Check if the request is a WebSocket or other upgrade request
30///
31/// WebSocket and other protocol upgrades are detected by checking for:
32/// - `Connection: Upgrade` header (case-insensitive)
33/// - Presence of `Upgrade` header
34///
35/// These requests will bypass caching and use direct TCP tunneling instead.
36fn is_upgrade_request(headers: &HeaderMap) -> bool {
37    headers
38        .get(axum::http::header::CONNECTION)
39        .and_then(|v| v.to_str().ok())
40        .map(|v| v.to_lowercase().contains("upgrade"))
41        .unwrap_or(false)
42        || headers.contains_key(axum::http::header::UPGRADE)
43}
44
45/// Build the JSON payload sent to webhook endpoints.
46///
47/// Contains `method`, `path`, `query`, and `headers` (as a flat string-to-string
48/// map). The request body is intentionally excluded so that the caller never has
49/// to consume it before the payload is built.
50fn build_webhook_payload(
51    method: &str,
52    path: &str,
53    query: &str,
54    headers: &HeaderMap,
55) -> serde_json::Value {
56    let headers_map: serde_json::Map<String, serde_json::Value> = headers
57        .iter()
58        .filter_map(|(name, value)| {
59            value
60                .to_str()
61                .ok()
62                .map(|v| (name.as_str().to_string(), serde_json::Value::String(v.to_string())))
63        })
64        .collect();
65
66    serde_json::json!({
67        "method": method,
68        "path": path,
69        "query": query,
70        "headers": headers_map,
71    })
72}
73
/// Result of a webhook HTTP call (see [`call_webhook`]).
struct WebhookCallResult {
    /// HTTP status returned by the webhook server.
    status: StatusCode,
    /// Value of the `Location` header, if present.
    /// Used to forward redirects to the client when a blocking webhook returns 3xx.
    location: Option<String>,
    /// Response body as plain text (empty string if the body could not be read).
    /// Used by `cache_key` webhooks to override the cache lookup key.
    body: String,
}
85
86/// POST `payload` to `url`.
87///
88/// Redirects are **not** followed — a `3xx` status is returned as-is so the
89/// caller can forward it to the original client.
90///
91/// Returns:
92/// - `Ok(WebhookCallResult)` — status, optional `Location` header, and body.
93/// - `Err(())` — timeout, connection error, or other transport failure.
94async fn call_webhook(
95    url: &str,
96    payload: &serde_json::Value,
97    timeout_ms: u64,
98) -> Result<WebhookCallResult, ()> {
99    let client = reqwest::Client::builder()
100        .timeout(std::time::Duration::from_millis(timeout_ms))
101        .redirect(reqwest::redirect::Policy::none())
102        .build()
103        .map_err(|_| ())?;
104
105    let response = client
106        .post(url)
107        .json(payload)
108        .send()
109        .await
110        .map_err(|_| ())?;
111
112    let status = StatusCode::from_u16(response.status().as_u16()).map_err(|_| ())?;
113    let location = response
114        .headers()
115        .get(reqwest::header::LOCATION)
116        .and_then(|v| v.to_str().ok())
117        .map(|s| s.to_string());
118    let body = response.text().await.unwrap_or_default();
119
120    Ok(WebhookCallResult { status, location, body })
121}
122
/// Main proxy handler that serves prerendered content from cache
/// or fetches from backend if not cached.
///
/// Pipeline, in order:
/// 1. Upgrade (WebSocket) detection — tunnels via [`handle_upgrade_request`]
///    or rejects with `501 Not Implemented`.
/// 2. `forward_get_only` enforcement — `405` for non-GET methods when enabled.
/// 3. Webhook dispatch (notify / blocking / cache-key) — runs *before* cache
///    reads so access control also covers cache hits.
/// 4. 404-cache lookup, then main cache lookup; a PreGenerate miss with
///    fallthrough disabled returns `404`.
/// 5. Backend fetch, optional cache store, and response construction.
///
/// # Errors
/// Returns `StatusCode` errors for rejected or failed requests
/// (405, 404, 406, 501, 502, 503, or a status chosen by a blocking webhook).
pub async fn proxy_handler(
    Extension(state): Extension<Arc<ProxyState>>,
    req: Request<Body>,
) -> Result<Response<Body>, StatusCode> {
    // Check for upgrade requests FIRST (before consuming anything from the request)
    // This is critical for WebSocket to work properly
    let is_upgrade = is_upgrade_request(req.headers());

    if is_upgrade {
        let method_str = req.method().as_str();
        let path = req.uri().path();

        // WebSocket / upgrade tunnelling is only meaningful when there is a live
        // backend to tunnel to.  Pure SSG servers (PreGenerate with fallthrough
        // disabled) have no backend reachable at request time, so we always
        // return 501 for them regardless of the `enable_websocket` flag.
        let ws_allowed = state.config.enable_websocket
            && match &state.config.proxy_mode {
                ProxyMode::Dynamic => true,
                ProxyMode::PreGenerate { fallthrough, .. } => *fallthrough,
            };

        if ws_allowed {
            tracing::debug!(
                "Upgrade request detected for {} {}, establishing direct proxy tunnel",
                method_str,
                path
            );
            return handle_upgrade_request(state, req).await;
        } else {
            tracing::warn!(
                "Upgrade request detected for {} {} but WebSocket support is disabled or not available in current proxy mode",
                method_str,
                path
            );
            return Err(StatusCode::NOT_IMPLEMENTED);
        }
    }

    // Extract request details (only after we know it's not an upgrade request)
    let method = req.method().clone();
    let method_str = method.as_str();
    let uri = req.uri().clone();
    let path = uri.path();
    let query = uri.query().unwrap_or("");
    let headers = req.headers().clone();

    // Check if only GET requests are allowed
    if state.config.forward_get_only && method != axum::http::Method::GET {
        tracing::warn!(
            "Non-GET request {} {} rejected (forward_get_only is enabled)",
            method_str,
            path
        );
        return Err(StatusCode::METHOD_NOT_ALLOWED);
    }

    // ── Webhook dispatch ────────────────────────────────────────────────────
    // Webhooks fire before cache reads so that access control is enforced even
    // for requests that would otherwise be served from the cache.
    // The payload is built once and shared by all configured webhooks; it never
    // includes the request body (see build_webhook_payload).
    let mut cache_key_override: Option<String> = None;
    if !state.config.webhooks.is_empty() {
        let payload = build_webhook_payload(method_str, path, query, &headers);

        for webhook in &state.config.webhooks {
            match webhook.webhook_type {
                WebhookType::Notify => {
                    // Fire-and-forget: spawn without awaiting.
                    let url = webhook.url.clone();
                    let payload_clone = payload.clone();
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    tokio::spawn(async move {
                        if let Err(()) = call_webhook(&url, &payload_clone, timeout_ms).await {
                            tracing::warn!("Notify webhook POST to '{}' failed", url);
                        }
                    });
                }
                WebhookType::Blocking => {
                    // Blocking webhooks gate the request: 2xx allows, 3xx is
                    // forwarded to the client, anything else (or a transport
                    // failure) denies the request.
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    match call_webhook(&webhook.url, &payload, timeout_ms).await {
                        Ok(result) if result.status.is_success() => {
                            tracing::debug!(
                                "Blocking webhook '{}' allowed {} {}",
                                webhook.url,
                                method_str,
                                path
                            );
                        }
                        Ok(result) if result.status.is_redirection() => {
                            tracing::debug!(
                                "Blocking webhook '{}' redirecting {} {} to {}",
                                webhook.url,
                                method_str,
                                path,
                                result.location.as_deref().unwrap_or("(no location)")
                            );
                            let mut builder = Response::builder().status(result.status);
                            if let Some(loc) = &result.location {
                                builder = builder.header(axum::http::header::LOCATION, loc.as_str());
                            }
                            return Ok(builder
                                .body(Body::empty())
                                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?);
                        }
                        Ok(result) => {
                            tracing::warn!(
                                "Blocking webhook '{}' denied {} {} with status {}",
                                webhook.url,
                                method_str,
                                path,
                                result.status
                            );
                            // The webhook's own status is propagated to the client.
                            return Err(result.status);
                        }
                        Err(()) => {
                            tracing::warn!(
                                "Blocking webhook '{}' timed out or failed for {} {} — denying request",
                                webhook.url,
                                method_str,
                                path
                            );
                            // Fail closed: an unreachable blocking webhook denies the request.
                            return Err(StatusCode::SERVICE_UNAVAILABLE);
                        }
                    }
                }
                WebhookType::CacheKey => {
                    // Cache-key webhooks are advisory: any failure or empty
                    // body falls back to the configured cache_key_fn.
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    match call_webhook(&webhook.url, &payload, timeout_ms).await {
                        Ok(result) if result.status.is_success() => {
                            let key = result.body.trim().to_string();
                            if !key.is_empty() {
                                tracing::debug!(
                                    "Cache key webhook '{}' set key '{}' for {} {}",
                                    webhook.url,
                                    key,
                                    method_str,
                                    path
                                );
                                cache_key_override = Some(key);
                            } else {
                                tracing::warn!(
                                    "Cache key webhook '{}' returned empty body for {} {} — using default key",
                                    webhook.url,
                                    method_str,
                                    path
                                );
                            }
                        }
                        Ok(result) => {
                            tracing::warn!(
                                "Cache key webhook '{}' returned non-2xx {} for {} {} — using default key",
                                webhook.url,
                                result.status,
                                method_str,
                                path
                            );
                        }
                        Err(()) => {
                            tracing::warn!(
                                "Cache key webhook '{}' timed out or failed for {} {} — using default key",
                                webhook.url,
                                method_str,
                                path
                            );
                        }
                    }
                }
            }
        }
    }

    // Check if this path should be cached based on include/exclude patterns
    let should_cache = should_cache_path(
        method_str,
        path,
        &state.config.include_paths,
        &state.config.exclude_paths,
    );

    // Generate cache key using the configured function
    let req_info = crate::RequestInfo {
        method: method_str,
        path,
        query,
        headers: &headers,
    };
    // A cache-key webhook override (if any) wins over the configured key function.
    let cache_key = cache_key_override
        .unwrap_or_else(|| (state.config.cache_key_fn)(&req_info));
    let cache_reads_enabled = !matches!(state.config.cache_strategy, crate::CacheStrategy::None);

    // Try to get 404 cache first (available even if should_cache is false)
    if cache_reads_enabled && state.config.cache_404_capacity > 0 {
        if let Some(cached) = state.cache.get_404(&cache_key).await {
            if cached_response_is_allowed(&state.config.cache_strategy, &cached) {
                tracing::debug!("404 cache hit for: {} {}", method_str, cache_key);
                return build_response_from_cache(cached, &headers);
            }
        }
    }

    // Try to get from cache first (only if caching is enabled for this path)
    if should_cache && cache_reads_enabled {
        if let Some(cached) = state.cache.get(&cache_key).await {
            if cached_response_is_allowed(&state.config.cache_strategy, &cached) {
                tracing::debug!("Cache hit for: {} {}", method_str, cache_key);
                return build_response_from_cache(cached, &headers);
            }
        }
        // PreGenerate mode: serve only from cache, no backend fallthrough on miss
        if let ProxyMode::PreGenerate { fallthrough, .. } = &state.config.proxy_mode {
            if !fallthrough {
                tracing::debug!(
                    "PreGenerate cache miss for: {} {} — returning 404 (fallthrough disabled)",
                    method_str,
                    cache_key
                );
                return Err(StatusCode::NOT_FOUND);
            }
        }
        tracing::debug!(
            "Cache miss for: {} {}, fetching from backend",
            method_str,
            cache_key
        );
    } else if !cache_reads_enabled {
        tracing::debug!(
            "{} {} not cacheable (cache strategy: none), proxying directly",
            method_str,
            path
        );
    } else {
        tracing::debug!(
            "{} {} not cacheable (filtered), proxying directly",
            method_str,
            path
        );
    }

    // Convert body to bytes to forward it
    // NOTE(review): usize::MAX means the request body is buffered without a
    // size limit — confirm an upstream layer bounds request size.
    let body_bytes = match axum::body::to_bytes(req.into_body(), usize::MAX).await {
        Ok(bytes) => bytes,
        Err(e) => {
            tracing::error!("Failed to read request body: {}", e);
            return Err(StatusCode::BAD_REQUEST);
        }
    };

    // Fetch from backend (proxy_url)
    // Use path+query only — not the full `uri` — because HTTP/2 requests carry an
    // absolute-form URI (e.g. `https://example.com/path`) which would corrupt the
    // concatenated URL when appended to proxy_url.
    let path_and_query = uri
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or_else(|| uri.path());
    let target_url = format!("{}{}", state.config.proxy_url, path_and_query);
    // Disable reqwest's transparent decompression so the upstream body (and its
    // Content-Encoding header) pass through unmodified; decoding for the cache
    // is handled explicitly via decode_upstream_body below.
    let client = match reqwest::Client::builder()
        .no_brotli()
        .no_deflate()
        .no_gzip()
        .build()
    {
        Ok(client) => client,
        Err(error) => {
            tracing::error!("Failed to build upstream HTTP client: {}", error);
            return Err(StatusCode::INTERNAL_SERVER_ERROR);
        }
    };

    let response = match client
        .request(method.clone(), &target_url)
        .headers(convert_headers(&headers))
        .body(body_bytes.to_vec())
        .send()
        .await
    {
        Ok(resp) => resp,
        Err(e) => {
            tracing::error!("Failed to fetch from backend: {}", e);
            return Err(StatusCode::BAD_GATEWAY);
        }
    };

    // Cache the response (only if caching is enabled for this path)
    let status = response.status().as_u16();
    let response_headers = response.headers().clone();
    let body_bytes = match response.bytes().await {
        Ok(bytes) => bytes.to_vec(),
        Err(e) => {
            tracing::error!("Failed to read response body: {}", e);
            return Err(StatusCode::BAD_GATEWAY);
        }
    };

    let response_content_type = response_headers
        .get(axum::http::header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok());
    let response_is_cacheable = state
        .config
        .cache_strategy
        .allows_content_type(response_content_type);
    let upstream_content_encoding = response_headers
        .get(axum::http::header::CONTENT_ENCODING)
        .and_then(|value| value.to_str().ok());
    let should_try_cache = cache_reads_enabled
        && response_is_cacheable
        && (should_cache || state.config.cache_404_capacity > 0);
    // Decode the upstream body once, both for cache storage and for the
    // 404-meta-tag sniff below; an unsupported encoding disables caching for
    // this response but the raw body is still forwarded to the client.
    let normalized_body = if should_try_cache || state.config.use_404_meta {
        match decode_upstream_body(&body_bytes, upstream_content_encoding) {
            Ok(body) => Some(body),
            Err(error) => {
                tracing::warn!(
                    "Skipping cache compression for {} {} due to unsupported upstream encoding: {}",
                    method_str,
                    path,
                    error
                );
                None
            }
        }
    } else {
        None
    };

    // Determine if this should be cached as a 404 (either by status or by meta tag if enabled)
    let mut is_404 = status == 404;
    if !is_404 && state.config.use_404_meta {
        if let Some(body) = normalized_body.as_deref() {
            is_404 = body_contains_404_meta(body);
        }
    }

    let should_store_404 = is_404
        && state.config.cache_404_capacity > 0
        && response_is_cacheable
        && cache_reads_enabled
        && normalized_body.is_some();
    let should_store_response = !is_404
        && should_cache
        && response_is_cacheable
        && cache_reads_enabled
        && normalized_body.is_some();

    if should_store_404 || should_store_response {
        let cached_response = match build_cached_response(
            status,
            &response_headers,
            normalized_body.as_deref().unwrap(),
            &state.config.compress_strategy,
        ) {
            Ok(cached_response) => cached_response,
            Err(error) => {
                // Cache preparation failure is non-fatal: fall back to serving
                // the raw upstream response without storing it.
                tracing::warn!(
                    "Failed to prepare cached response for {} {}: {}",
                    method_str,
                    path,
                    error
                );
                return Ok(build_response_from_upstream(
                    status,
                    &response_headers,
                    body_bytes,
                ));
            }
        };

        if should_store_404 {
            state
                .cache
                .set_404(cache_key.clone(), cached_response.clone())
                .await;
            tracing::debug!("Cached 404 response for: {} {}", method_str, cache_key);
        } else {
            state
                .cache
                .set(cache_key.clone(), cached_response.clone())
                .await;
            tracing::debug!("Cached response for: {} {}", method_str, cache_key);
        }

        // Serve from the just-built cached representation so the client gets
        // exactly what subsequent cache hits will receive.
        return build_response_from_cache(cached_response, &headers);
    }

    Ok(build_response_from_upstream(
        status,
        &response_headers,
        body_bytes,
    ))
}
514
/// Handle WebSocket and other upgrade requests by establishing a direct TCP tunnel
///
/// This function handles long-lived connections like WebSocket by:
/// 1. Connecting to the backend server
/// 2. Forwarding the upgrade request
/// 3. Capturing both client and backend upgrade connections
/// 4. Creating a bidirectional TCP tunnel between them
///
/// The tunnel remains open for the lifetime of the connection, allowing
/// full-duplex communication. Data flows directly between client and backend
/// without any caching or inspection.
///
/// # Errors
/// Returns `502 Bad Gateway` when the backend URL cannot be parsed, the TCP
/// connect fails, or the HTTP/1 handshake / request forwarding fails. If the
/// backend answers with anything other than `101 Switching Protocols`, that
/// response is forwarded to the client as-is.
async fn handle_upgrade_request(
    state: Arc<ProxyState>,
    mut req: Request<Body>,
) -> Result<Response<Body>, StatusCode> {
    // Use path+query only for the same reason as in proxy_handler (HTTP/2 absolute-form URI).
    let req_path_and_query = req
        .uri()
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or_else(|| req.uri().path());
    let target_url = format!("{}{}", state.config.proxy_url, req_path_and_query);

    // Parse the backend URL to extract host and port
    let backend_uri = target_url.parse::<hyper::Uri>().map_err(|e| {
        tracing::error!("Failed to parse backend URL: {}", e);
        StatusCode::BAD_GATEWAY
    })?;

    let host = backend_uri.host().ok_or_else(|| {
        tracing::error!("No host in backend URL");
        StatusCode::BAD_GATEWAY
    })?;

    // Default port follows the URL scheme when none is given explicitly.
    let port = backend_uri.port_u16().unwrap_or_else(|| {
        if backend_uri.scheme_str() == Some("https") {
            443
        } else {
            80
        }
    });

    // IMPORTANT: Set up client upgrade BEFORE processing the request
    // This captures the client's connection for later upgrade
    let client_upgrade = hyper::upgrade::on(&mut req);

    // Connect to backend
    // NOTE(review): this opens a plain TCP connection even for https backend
    // URLs (no TLS) — confirm upgrade targets are reachable over cleartext.
    let backend_stream = tokio::net::TcpStream::connect((host, port))
        .await
        .map_err(|e| {
            tracing::error!("Failed to connect to backend {}:{}: {}", host, port, e);
            StatusCode::BAD_GATEWAY
        })?;

    let backend_io = TokioIo::new(backend_stream);

    // Build the backend request with upgrade support
    let (mut sender, conn) = hyper::client::conn::http1::handshake(backend_io)
        .await
        .map_err(|e| {
            tracing::error!("Failed to handshake with backend: {}", e);
            StatusCode::BAD_GATEWAY
        })?;

    // Spawn a task to poll the connection - this will handle the upgrade
    let conn_task = tokio::spawn(async move {
        match conn.with_upgrades().await {
            Ok(parts) => {
                tracing::debug!("Backend connection upgraded successfully");
                Ok(parts)
            }
            Err(e) => {
                tracing::error!("Backend connection failed: {}", e);
                Err(e)
            }
        }
    });

    // Forward the request to the backend
    let backend_response = sender.send_request(req).await.map_err(|e| {
        tracing::error!("Failed to send request to backend: {}", e);
        StatusCode::BAD_GATEWAY
    })?;

    // Check if backend accepted the upgrade
    let status = backend_response.status();
    if status != StatusCode::SWITCHING_PROTOCOLS {
        tracing::warn!("Backend did not accept upgrade request, status: {}", status);
        // Convert the backend response to our response type
        let (parts, body) = backend_response.into_parts();
        let body = Body::new(body);
        return Ok(Response::from_parts(parts, body));
    }

    // Extract headers before moving backend_response
    let backend_headers = backend_response.headers().clone();

    // Get the upgraded backend connection
    let backend_upgrade = hyper::upgrade::on(backend_response);

    // Spawn a task to handle bidirectional streaming between client and backend
    tokio::spawn(async move {
        tracing::debug!("Starting upgrade tunnel establishment");

        // Wait for both upgrades to complete
        let (client_result, backend_result) = tokio::join!(client_upgrade, backend_upgrade);

        // Drop the connection task as we now have the upgraded streams
        // (dropping a tokio JoinHandle detaches the task; it is not aborted)
        drop(conn_task);

        match (client_result, backend_result) {
            (Ok(client_upgraded), Ok(backend_upgraded)) => {
                tracing::debug!("Both upgrades successful, establishing bidirectional tunnel");

                // Wrap both in TokioIo for AsyncRead + AsyncWrite
                let mut client_stream = TokioIo::new(client_upgraded);
                let mut backend_stream = TokioIo::new(backend_upgraded);

                // Create bidirectional tunnel
                match tokio::io::copy_bidirectional(&mut client_stream, &mut backend_stream).await {
                    Ok((client_to_backend, backend_to_client)) => {
                        tracing::debug!(
                            "Tunnel closed gracefully. Transferred {} bytes client->backend, {} bytes backend->client",
                            client_to_backend,
                            backend_to_client
                        );
                    }
                    Err(e) => {
                        tracing::error!("Tunnel error: {}", e);
                    }
                }
            }
            (Err(e), _) => {
                tracing::error!("Client upgrade failed: {}", e);
            }
            (_, Err(e)) => {
                tracing::error!("Backend upgrade failed: {}", e);
            }
        }
    });

    // Build the response to send back to the client with upgrade support
    let mut response = Response::builder()
        .status(StatusCode::SWITCHING_PROTOCOLS)
        .body(Body::empty())
        .unwrap();

    // Copy necessary headers from backend response
    // These headers are essential for WebSocket handshake
    if let Some(upgrade_header) = backend_headers.get(axum::http::header::UPGRADE) {
        response
            .headers_mut()
            .insert(axum::http::header::UPGRADE, upgrade_header.clone());
    }
    if let Some(connection_header) = backend_headers.get(axum::http::header::CONNECTION) {
        response
            .headers_mut()
            .insert(axum::http::header::CONNECTION, connection_header.clone());
    }
    if let Some(sec_websocket_accept) = backend_headers.get("sec-websocket-accept") {
        response.headers_mut().insert(
            HeaderName::from_static("sec-websocket-accept"),
            sec_websocket_accept.clone(),
        );
    }

    tracing::debug!("Upgrade response sent to client, tunnel task spawned");

    Ok(response)
}
685
686fn build_response_from_cache(
687    cached: CachedResponse,
688    request_headers: &HeaderMap,
689) -> Result<Response<Body>, StatusCode> {
690    let mut response_headers = cached.headers;
691    let body = if let Some(content_encoding) = cached.content_encoding {
692        if client_accepts_encoding(request_headers, content_encoding) {
693            upsert_vary_accept_encoding(&mut response_headers);
694            cached.body
695        } else {
696            if !identity_acceptable(request_headers) {
697                tracing::warn!(
698                    "Client does not accept cached encoding '{}' or identity fallback",
699                    content_encoding.as_header_value()
700                );
701                return Err(StatusCode::NOT_ACCEPTABLE);
702            }
703
704            response_headers.remove("content-encoding");
705            upsert_vary_accept_encoding(&mut response_headers);
706            match decompress_body(&cached.body, content_encoding) {
707                Ok(body) => body,
708                Err(error) => {
709                    tracing::error!("Failed to decompress cached response: {}", error);
710                    return Err(StatusCode::INTERNAL_SERVER_ERROR);
711                }
712            }
713        }
714    } else {
715        cached.body
716    };
717
718    response_headers.remove("transfer-encoding");
719    response_headers.insert("content-length".to_string(), body.len().to_string());
720
721    Ok(build_response(cached.status, response_headers, body))
722}
723
724fn build_cached_response(
725    status: u16,
726    response_headers: &reqwest::header::HeaderMap,
727    normalized_body: &[u8],
728    compress_strategy: &CompressStrategy,
729) -> anyhow::Result<CachedResponse> {
730    let mut headers = convert_headers_to_map(response_headers);
731    headers.remove("content-encoding");
732    headers.remove("content-length");
733    headers.remove("transfer-encoding");
734
735    let content_encoding = configured_encoding(compress_strategy);
736    let body = if let Some(content_encoding) = content_encoding {
737        let compressed = compress_body(normalized_body, content_encoding)?;
738        headers.insert(
739            "content-encoding".to_string(),
740            content_encoding.as_header_value().to_string(),
741        );
742        upsert_vary_accept_encoding(&mut headers);
743        compressed
744    } else {
745        normalized_body.to_vec()
746    };
747
748    headers.insert("content-length".to_string(), body.len().to_string());
749
750    Ok(CachedResponse {
751        body,
752        headers,
753        status,
754        content_encoding,
755    })
756}
757
758fn build_response_from_upstream(
759    status: u16,
760    response_headers: &reqwest::header::HeaderMap,
761    body: Vec<u8>,
762) -> Response<Body> {
763    let mut headers = convert_headers_to_map(response_headers);
764    headers.remove("transfer-encoding");
765    headers.insert("content-length".to_string(), body.len().to_string());
766    build_response(status, headers, body)
767}
768
769fn build_response(
770    status: u16,
771    response_headers: HashMap<String, String>,
772    body: Vec<u8>,
773) -> Response<Body> {
774    let mut response = Response::builder().status(status);
775
776    // Add headers
777    let headers = response.headers_mut().unwrap();
778    for (key, value) in response_headers {
779        if let Ok(header_name) = key.parse::<HeaderName>() {
780            if let Ok(header_value) = HeaderValue::from_str(&value) {
781                headers.insert(header_name, header_value);
782            } else {
783                tracing::warn!(
784                    "Failed to parse header value for key '{}': {:?}",
785                    key,
786                    value
787                );
788            }
789        } else {
790            tracing::warn!("Failed to parse header name: {}", key);
791        }
792    }
793
794    response.body(Body::from(body)).unwrap()
795}
796
797fn cached_response_is_allowed(strategy: &crate::CacheStrategy, cached: &CachedResponse) -> bool {
798    strategy.allows_content_type(
799        cached
800            .headers
801            .get("content-type")
802            .map(|value| value.as_str()),
803    )
804}
805
/// Detect the soft-404 marker `<meta name="phantom-404" content="true">` in a
/// response body, accepting either single- or double-quoted attributes.
///
/// Non-UTF-8 bodies (e.g. binary assets) can never carry the marker and
/// return `false` immediately.
fn body_contains_404_meta(body: &[u8]) -> bool {
    let Ok(text) = std::str::from_utf8(body) else {
        return false;
    };

    // Both attribute substrings must be present, in either quoting style.
    let has_any = |needles: &[&str]| needles.iter().any(|needle| text.contains(needle));

    has_any(&["name=\"phantom-404\"", "name='phantom-404'"])
        && has_any(&["content=\"true\"", "content='true'"])
}
819
/// Ensure the header map's `Vary` value lists `Accept-Encoding`, so caches
/// downstream store one variant per encoding.
///
/// Existing `Vary` members are preserved; the token is only appended when it
/// is not already present (matched case-insensitively). A present-but-blank
/// `Vary` value is replaced outright — the previous code appended
/// `", Accept-Encoding"` to it, producing a malformed leading comma.
fn upsert_vary_accept_encoding(headers: &mut HashMap<String, String>) {
    match headers.get_mut("vary") {
        Some(value) => {
            let has_accept_encoding = value
                .split(',')
                .any(|part| part.trim().eq_ignore_ascii_case("accept-encoding"));
            if !has_accept_encoding {
                if value.trim().is_empty() {
                    // Blank value: replace instead of appending ", …".
                    *value = "Accept-Encoding".to_string();
                } else {
                    value.push_str(", Accept-Encoding");
                }
            }
        }
        None => {
            headers.insert("vary".to_string(), "Accept-Encoding".to_string());
        }
    }
}
835
836fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
837    let mut req_headers = reqwest::header::HeaderMap::new();
838    for (key, value) in headers {
839        // Skip host header as reqwest will set it
840        if key == axum::http::header::HOST {
841            continue;
842        }
843        if let Ok(val) = value.to_str() {
844            if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
845                req_headers.insert(key.clone(), header_value);
846            }
847        }
848    }
849    req_headers
850}
851
852/// Fetch a single path from the upstream server, compress it, and store it in the cache.
853/// Used by the snapshot worker for PreGenerate warm-up and runtime snapshot management.
854pub(crate) async fn fetch_and_cache_snapshot(
855    path: &str,
856    proxy_url: &str,
857    cache: &CacheStore,
858    compress_strategy: &CompressStrategy,
859    cache_key_fn: &std::sync::Arc<dyn Fn(&crate::RequestInfo) -> String + Send + Sync>,
860) -> anyhow::Result<()> {
861    let empty_headers = axum::http::HeaderMap::new();
862    let req_info = crate::RequestInfo {
863        method: "GET",
864        path,
865        query: "",
866        headers: &empty_headers,
867    };
868    let cache_key = cache_key_fn(&req_info);
869
870    let client = reqwest::Client::builder()
871        .no_brotli()
872        .no_deflate()
873        .no_gzip()
874        .build()
875        .map_err(|e| anyhow::anyhow!("Failed to build HTTP client for snapshot fetch: {}", e))?;
876
877    let url = format!("{}{}", proxy_url, path);
878    let response = client
879        .get(&url)
880        .send()
881        .await
882        .map_err(|e| anyhow::anyhow!("Failed to fetch snapshot '{}': {}", path, e))?;
883
884    let status = response.status().as_u16();
885    let response_headers = response.headers().clone();
886    let body_bytes = response
887        .bytes()
888        .await
889        .map_err(|e| anyhow::anyhow!("Failed to read snapshot response for '{}': {}", path, e))?
890        .to_vec();
891
892    let upstream_encoding = response_headers
893        .get(axum::http::header::CONTENT_ENCODING)
894        .and_then(|v| v.to_str().ok());
895    let normalized = decode_upstream_body(&body_bytes, upstream_encoding)
896        .map_err(|e| anyhow::anyhow!("Failed to decode snapshot body for '{}': {}", path, e))?;
897
898    let cached = build_cached_response(status, &response_headers, &normalized, compress_strategy)?;
899    cache.set(cache_key, cached).await;
900    tracing::debug!("Snapshot pre-generated: {}", path);
901    Ok(())
902}
903
904fn convert_headers_to_map(
905    headers: &reqwest::header::HeaderMap,
906) -> std::collections::HashMap<String, String> {
907    let mut map = std::collections::HashMap::new();
908    for (key, value) in headers {
909        if let Ok(val) = value.to_str() {
910            map.insert(key.as_str().to_ascii_lowercase(), val.to_string());
911        } else {
912            // Log when we can't convert a header (might be binary)
913            tracing::debug!("Could not convert header '{}' to string", key);
914        }
915    }
916    map
917}
918
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compression::{compress_body, ContentEncoding};
    use axum::body::to_bytes;

    // Minimal upstream-style header map (HTML content type) shared by the
    // build_cached_response test below.
    fn response_headers() -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("text/html; charset=utf-8"),
        );
        headers
    }

    // With the Gzip strategy, build_cached_response must record the chosen
    // encoding both in the struct field and in the stored headers, and add
    // `Vary: Accept-Encoding` so downstream caches key on encoding.
    #[test]
    fn test_build_cached_response_uses_selected_encoding() {
        let cached = build_cached_response(
            200,
            &response_headers(),
            b"<html>compressed</html>",
            &CompressStrategy::Gzip,
        )
        .unwrap();

        assert_eq!(cached.content_encoding, Some(ContentEncoding::Gzip));
        assert_eq!(
            cached.headers.get("content-encoding"),
            Some(&"gzip".to_string())
        );
        assert_eq!(
            cached.headers.get("vary"),
            Some(&"Accept-Encoding".to_string())
        );
    }

    // Cache holds a brotli body but the client only accepts gzip: the
    // response must be decompressed to identity, with no content-encoding
    // header on the way out.
    #[tokio::test]
    async fn test_build_response_from_cache_falls_back_to_identity() {
        let body = b"<html>identity</html>";
        let compressed = compress_body(body, ContentEncoding::Brotli).unwrap();
        // Deliberately stale content-length ("123") — build_response_from_cache
        // is expected to rewrite it.
        let cached = CachedResponse {
            body: compressed,
            headers: HashMap::from([
                ("content-type".to_string(), "text/html".to_string()),
                ("content-encoding".to_string(), "br".to_string()),
                ("content-length".to_string(), "123".to_string()),
                ("vary".to_string(), "Accept-Encoding".to_string()),
            ]),
            status: 200,
            content_encoding: Some(ContentEncoding::Brotli),
        };

        // Client advertises gzip only — brotli is not acceptable.
        let mut request_headers = HeaderMap::new();
        request_headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip"),
        );

        let response = build_response_from_cache(cached, &request_headers).unwrap();
        assert!(response
            .headers()
            .get(axum::http::header::CONTENT_ENCODING)
            .is_none());

        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(body.as_ref(), b"<html>identity</html>");
    }

    // Cache holds a brotli body and the client accepts br: the stored bytes
    // must be served as-is with `Content-Encoding: br` preserved.
    #[tokio::test]
    async fn test_build_response_from_cache_keeps_supported_encoding() {
        let body = b"<html>compressed</html>";
        let compressed = compress_body(body, ContentEncoding::Brotli).unwrap();
        let cached = CachedResponse {
            body: compressed.clone(),
            headers: HashMap::from([
                ("content-type".to_string(), "text/html".to_string()),
                ("content-encoding".to_string(), "br".to_string()),
                ("content-length".to_string(), compressed.len().to_string()),
                ("vary".to_string(), "Accept-Encoding".to_string()),
            ]),
            status: 200,
            content_encoding: Some(ContentEncoding::Brotli),
        };

        // Client prefers br over gzip (q=0.5), so br should be kept.
        let mut request_headers = HeaderMap::new();
        request_headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("br, gzip;q=0.5"),
        );

        let response = build_response_from_cache(cached, &request_headers).unwrap();
        assert_eq!(
            response.headers().get(axum::http::header::CONTENT_ENCODING),
            Some(&HeaderValue::from_static("br"))
        );

        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(body.as_ref(), compressed.as_slice());
    }
}