//! phantom_frame — `proxy.rs`: request handling for the caching reverse proxy.
use crate::cache::{CacheStore, CachedResponse};
use crate::compression::{
    client_accepts_encoding, compress_body, configured_encoding, decode_upstream_body,
    decompress_body, identity_acceptable,
};
use crate::path_matcher::should_cache_path;
use crate::{CompressStrategy, CreateProxyConfig, ProxyMode};
use axum::{
    body::Body,
    extract::Extension,
    http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
};
use hyper_util::rt::TokioIo;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
16
/// Shared per-proxy state handed to every request handler via an axum
/// `Extension`: the response cache plus the user-supplied configuration.
#[derive(Clone)]
pub struct ProxyState {
    // Cache of stored upstream responses (regular entries and 404 entries).
    cache: CacheStore,
    // Behaviour knobs: proxy mode, path filters, compression/cache strategy, ...
    config: CreateProxyConfig,
}
23impl ProxyState {
24    pub fn new(cache: CacheStore, config: CreateProxyConfig) -> Self {
25        Self { cache, config }
26    }
27}
28
29/// Check if the request is a WebSocket or other upgrade request
30///
31/// WebSocket and other protocol upgrades are detected by checking for:
32/// - `Connection: Upgrade` header (case-insensitive)
33/// - Presence of `Upgrade` header
34///
35/// These requests will bypass caching and use direct TCP tunneling instead.
36fn is_upgrade_request(headers: &HeaderMap) -> bool {
37    headers
38        .get(axum::http::header::CONNECTION)
39        .and_then(|v| v.to_str().ok())
40        .map(|v| v.to_lowercase().contains("upgrade"))
41        .unwrap_or(false)
42        || headers.contains_key(axum::http::header::UPGRADE)
43}
44
45/// Main proxy handler that serves prerendered content from cache
46/// or fetches from backend if not cached
47pub async fn proxy_handler(
48    Extension(state): Extension<Arc<ProxyState>>,
49    req: Request<Body>,
50) -> Result<Response<Body>, StatusCode> {
51    // Check for upgrade requests FIRST (before consuming anything from the request)
52    // This is critical for WebSocket to work properly
53    let is_upgrade = is_upgrade_request(req.headers());
54
55    if is_upgrade {
56        let method_str = req.method().as_str();
57        let path = req.uri().path();
58
59        // WebSocket / upgrade tunnelling is only meaningful when there is a live
60        // backend to tunnel to.  Pure SSG servers (PreGenerate with fallthrough
61        // disabled) have no backend reachable at request time, so we always
62        // return 501 for them regardless of the `enable_websocket` flag.
63        let ws_allowed = state.config.enable_websocket
64            && match &state.config.proxy_mode {
65                ProxyMode::Dynamic => true,
66                ProxyMode::PreGenerate { fallthrough, .. } => *fallthrough,
67            };
68
69        if ws_allowed {
70            tracing::debug!(
71                "Upgrade request detected for {} {}, establishing direct proxy tunnel",
72                method_str,
73                path
74            );
75            return handle_upgrade_request(state, req).await;
76        } else {
77            tracing::warn!(
78                "Upgrade request detected for {} {} but WebSocket support is disabled or not available in current proxy mode",
79                method_str,
80                path
81            );
82            return Err(StatusCode::NOT_IMPLEMENTED);
83        }
84    }
85
86    // Extract request details (only after we know it's not an upgrade request)
87    let method = req.method().clone();
88    let method_str = method.as_str();
89    let uri = req.uri().clone();
90    let path = uri.path();
91    let query = uri.query().unwrap_or("");
92    let headers = req.headers().clone();
93
94    // Check if only GET requests are allowed
95    if state.config.forward_get_only && method != axum::http::Method::GET {
96        tracing::warn!(
97            "Non-GET request {} {} rejected (forward_get_only is enabled)",
98            method_str,
99            path
100        );
101        return Err(StatusCode::METHOD_NOT_ALLOWED);
102    }
103
104    // Check if this path should be cached based on include/exclude patterns
105    let should_cache = should_cache_path(
106        method_str,
107        path,
108        &state.config.include_paths,
109        &state.config.exclude_paths,
110    );
111
112    // Generate cache key using the configured function
113    let req_info = crate::RequestInfo {
114        method: method_str,
115        path,
116        query,
117        headers: &headers,
118    };
119    let cache_key = (state.config.cache_key_fn)(&req_info);
120    let cache_reads_enabled = !matches!(state.config.cache_strategy, crate::CacheStrategy::None);
121
122    // Try to get 404 cache first (available even if should_cache is false)
123    if cache_reads_enabled && state.config.cache_404_capacity > 0 {
124        if let Some(cached) = state.cache.get_404(&cache_key).await {
125            if cached_response_is_allowed(&state.config.cache_strategy, &cached) {
126                tracing::debug!("404 cache hit for: {} {}", method_str, cache_key);
127                return build_response_from_cache(cached, &headers);
128            }
129        }
130    }
131
132    // Try to get from cache first (only if caching is enabled for this path)
133    if should_cache && cache_reads_enabled {
134        if let Some(cached) = state.cache.get(&cache_key).await {
135            if cached_response_is_allowed(&state.config.cache_strategy, &cached) {
136                tracing::debug!("Cache hit for: {} {}", method_str, cache_key);
137                return build_response_from_cache(cached, &headers);
138            }
139        }
140        // PreGenerate mode: serve only from cache, no backend fallthrough on miss
141        if let ProxyMode::PreGenerate { fallthrough, .. } = &state.config.proxy_mode {
142            if !fallthrough {
143                tracing::debug!(
144                    "PreGenerate cache miss for: {} {} — returning 404 (fallthrough disabled)",
145                    method_str,
146                    cache_key
147                );
148                return Err(StatusCode::NOT_FOUND);
149            }
150        }
151        tracing::debug!(
152            "Cache miss for: {} {}, fetching from backend",
153            method_str,
154            cache_key
155        );
156    } else if !cache_reads_enabled {
157        tracing::debug!(
158            "{} {} not cacheable (cache strategy: none), proxying directly",
159            method_str,
160            path
161        );
162    } else {
163        tracing::debug!(
164            "{} {} not cacheable (filtered), proxying directly",
165            method_str,
166            path
167        );
168    }
169
170    // Convert body to bytes to forward it
171    let body_bytes = match axum::body::to_bytes(req.into_body(), usize::MAX).await {
172        Ok(bytes) => bytes,
173        Err(e) => {
174            tracing::error!("Failed to read request body: {}", e);
175            return Err(StatusCode::BAD_REQUEST);
176        }
177    };
178
179    // Fetch from backend (proxy_url)
180    // Use path+query only — not the full `uri` — because HTTP/2 requests carry an
181    // absolute-form URI (e.g. `https://example.com/path`) which would corrupt the
182    // concatenated URL when appended to proxy_url.
183    let path_and_query = uri
184        .path_and_query()
185        .map(|pq| pq.as_str())
186        .unwrap_or_else(|| uri.path());
187    let target_url = format!("{}{}", state.config.proxy_url, path_and_query);
188    let client = match reqwest::Client::builder()
189        .no_brotli()
190        .no_deflate()
191        .no_gzip()
192        .build()
193    {
194        Ok(client) => client,
195        Err(error) => {
196            tracing::error!("Failed to build upstream HTTP client: {}", error);
197            return Err(StatusCode::INTERNAL_SERVER_ERROR);
198        }
199    };
200
201    let response = match client
202        .request(method.clone(), &target_url)
203        .headers(convert_headers(&headers))
204        .body(body_bytes.to_vec())
205        .send()
206        .await
207    {
208        Ok(resp) => resp,
209        Err(e) => {
210            tracing::error!("Failed to fetch from backend: {}", e);
211            return Err(StatusCode::BAD_GATEWAY);
212        }
213    };
214
215    // Cache the response (only if caching is enabled for this path)
216    let status = response.status().as_u16();
217    let response_headers = response.headers().clone();
218    let body_bytes = match response.bytes().await {
219        Ok(bytes) => bytes.to_vec(),
220        Err(e) => {
221            tracing::error!("Failed to read response body: {}", e);
222            return Err(StatusCode::BAD_GATEWAY);
223        }
224    };
225
226    let response_content_type = response_headers
227        .get(axum::http::header::CONTENT_TYPE)
228        .and_then(|value| value.to_str().ok());
229    let response_is_cacheable = state
230        .config
231        .cache_strategy
232        .allows_content_type(response_content_type);
233    let upstream_content_encoding = response_headers
234        .get(axum::http::header::CONTENT_ENCODING)
235        .and_then(|value| value.to_str().ok());
236    let should_try_cache = cache_reads_enabled
237        && response_is_cacheable
238        && (should_cache || state.config.cache_404_capacity > 0);
239    let normalized_body = if should_try_cache || state.config.use_404_meta {
240        match decode_upstream_body(&body_bytes, upstream_content_encoding) {
241            Ok(body) => Some(body),
242            Err(error) => {
243                tracing::warn!(
244                    "Skipping cache compression for {} {} due to unsupported upstream encoding: {}",
245                    method_str,
246                    path,
247                    error
248                );
249                None
250            }
251        }
252    } else {
253        None
254    };
255
256    // Determine if this should be cached as a 404 (either by status or by meta tag if enabled)
257    let mut is_404 = status == 404;
258    if !is_404 && state.config.use_404_meta {
259        if let Some(body) = normalized_body.as_deref() {
260            is_404 = body_contains_404_meta(body);
261        }
262    }
263
264    let should_store_404 = is_404
265        && state.config.cache_404_capacity > 0
266        && response_is_cacheable
267        && cache_reads_enabled
268        && normalized_body.is_some();
269    let should_store_response = !is_404
270        && should_cache
271        && response_is_cacheable
272        && cache_reads_enabled
273        && normalized_body.is_some();
274
275    if should_store_404 || should_store_response {
276        let cached_response = match build_cached_response(
277            status,
278            &response_headers,
279            normalized_body.as_deref().unwrap(),
280            &state.config.compress_strategy,
281        ) {
282            Ok(cached_response) => cached_response,
283            Err(error) => {
284                tracing::warn!(
285                    "Failed to prepare cached response for {} {}: {}",
286                    method_str,
287                    path,
288                    error
289                );
290                return Ok(build_response_from_upstream(
291                    status,
292                    &response_headers,
293                    body_bytes,
294                ));
295            }
296        };
297
298        if should_store_404 {
299            state
300                .cache
301                .set_404(cache_key.clone(), cached_response.clone())
302                .await;
303            tracing::debug!("Cached 404 response for: {} {}", method_str, cache_key);
304        } else {
305            state
306                .cache
307                .set(cache_key.clone(), cached_response.clone())
308                .await;
309            tracing::debug!("Cached response for: {} {}", method_str, cache_key);
310        }
311
312        return build_response_from_cache(cached_response, &headers);
313    }
314
315    Ok(build_response_from_upstream(
316        status,
317        &response_headers,
318        body_bytes,
319    ))
320}
321
/// Handle WebSocket and other upgrade requests by establishing a direct TCP tunnel
///
/// This function handles long-lived connections like WebSocket by:
/// 1. Connecting to the backend server
/// 2. Forwarding the upgrade request
/// 3. Capturing both client and backend upgrade connections
/// 4. Creating a bidirectional TCP tunnel between them
///
/// The tunnel remains open for the lifetime of the connection, allowing
/// full-duplex communication. Data flows directly between client and backend
/// without any caching or inspection.
///
/// NOTE(review): the backend leg below is a plain `TcpStream`; even when
/// `proxy_url` uses `https` (default port 443 below) no TLS handshake is
/// performed — confirm whether TLS upstreams are expected to work here.
async fn handle_upgrade_request(
    state: Arc<ProxyState>,
    mut req: Request<Body>,
) -> Result<Response<Body>, StatusCode> {
    // Use path+query only for the same reason as in proxy_handler (HTTP/2 absolute-form URI).
    let req_path_and_query = req
        .uri()
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or_else(|| req.uri().path());
    let target_url = format!("{}{}", state.config.proxy_url, req_path_and_query);

    // Parse the backend URL to extract host and port
    let backend_uri = target_url.parse::<hyper::Uri>().map_err(|e| {
        tracing::error!("Failed to parse backend URL: {}", e);
        StatusCode::BAD_GATEWAY
    })?;

    let host = backend_uri.host().ok_or_else(|| {
        tracing::error!("No host in backend URL");
        StatusCode::BAD_GATEWAY
    })?;

    // Default port follows the URL scheme when none is given explicitly.
    let port = backend_uri.port_u16().unwrap_or_else(|| {
        if backend_uri.scheme_str() == Some("https") {
            443
        } else {
            80
        }
    });

    // IMPORTANT: Set up client upgrade BEFORE processing the request
    // This captures the client's connection for later upgrade
    let client_upgrade = hyper::upgrade::on(&mut req);

    // Connect to backend
    let backend_stream = tokio::net::TcpStream::connect((host, port))
        .await
        .map_err(|e| {
            tracing::error!("Failed to connect to backend {}:{}: {}", host, port, e);
            StatusCode::BAD_GATEWAY
        })?;

    let backend_io = TokioIo::new(backend_stream);

    // Build the backend request with upgrade support
    let (mut sender, conn) = hyper::client::conn::http1::handshake(backend_io)
        .await
        .map_err(|e| {
            tracing::error!("Failed to handshake with backend: {}", e);
            StatusCode::BAD_GATEWAY
        })?;

    // Spawn a task to poll the connection - this will handle the upgrade.
    // `with_upgrades()` tells hyper to hand over the raw I/O after a 101
    // response instead of closing the connection.
    let conn_task = tokio::spawn(async move {
        match conn.with_upgrades().await {
            Ok(parts) => {
                tracing::debug!("Backend connection upgraded successfully");
                Ok(parts)
            }
            Err(e) => {
                tracing::error!("Backend connection failed: {}", e);
                Err(e)
            }
        }
    });

    // Forward the request to the backend
    let backend_response = sender.send_request(req).await.map_err(|e| {
        tracing::error!("Failed to send request to backend: {}", e);
        StatusCode::BAD_GATEWAY
    })?;

    // Check if backend accepted the upgrade
    let status = backend_response.status();
    if status != StatusCode::SWITCHING_PROTOCOLS {
        tracing::warn!("Backend did not accept upgrade request, status: {}", status);
        // Convert the backend response to our response type
        let (parts, body) = backend_response.into_parts();
        let body = Body::new(body);
        return Ok(Response::from_parts(parts, body));
    }

    // Extract headers before moving backend_response
    let backend_headers = backend_response.headers().clone();

    // Get the upgraded backend connection
    let backend_upgrade = hyper::upgrade::on(backend_response);

    // Spawn a task to handle bidirectional streaming between client and backend
    tokio::spawn(async move {
        tracing::debug!("Starting upgrade tunnel establishment");

        // Wait for both upgrades to complete
        let (client_result, backend_result) = tokio::join!(client_upgrade, backend_upgrade);

        // Dropping a tokio JoinHandle detaches the connection-driver task (it
        // does not abort it); once both upgrades have resolved the handle is
        // no longer needed.
        drop(conn_task);

        match (client_result, backend_result) {
            (Ok(client_upgraded), Ok(backend_upgraded)) => {
                tracing::debug!("Both upgrades successful, establishing bidirectional tunnel");

                // Wrap both in TokioIo for AsyncRead + AsyncWrite
                let mut client_stream = TokioIo::new(client_upgraded);
                let mut backend_stream = TokioIo::new(backend_upgraded);

                // Create bidirectional tunnel
                match tokio::io::copy_bidirectional(&mut client_stream, &mut backend_stream).await {
                    Ok((client_to_backend, backend_to_client)) => {
                        tracing::debug!(
                            "Tunnel closed gracefully. Transferred {} bytes client->backend, {} bytes backend->client",
                            client_to_backend,
                            backend_to_client
                        );
                    }
                    Err(e) => {
                        tracing::error!("Tunnel error: {}", e);
                    }
                }
            }
            (Err(e), _) => {
                tracing::error!("Client upgrade failed: {}", e);
            }
            (_, Err(e)) => {
                tracing::error!("Backend upgrade failed: {}", e);
            }
        }
    });

    // Build the response to send back to the client with upgrade support
    let mut response = Response::builder()
        .status(StatusCode::SWITCHING_PROTOCOLS)
        .body(Body::empty())
        .unwrap();

    // Copy necessary headers from backend response
    // These headers are essential for WebSocket handshake
    if let Some(upgrade_header) = backend_headers.get(axum::http::header::UPGRADE) {
        response
            .headers_mut()
            .insert(axum::http::header::UPGRADE, upgrade_header.clone());
    }
    if let Some(connection_header) = backend_headers.get(axum::http::header::CONNECTION) {
        response
            .headers_mut()
            .insert(axum::http::header::CONNECTION, connection_header.clone());
    }
    if let Some(sec_websocket_accept) = backend_headers.get("sec-websocket-accept") {
        response.headers_mut().insert(
            HeaderName::from_static("sec-websocket-accept"),
            sec_websocket_accept.clone(),
        );
    }

    tracing::debug!("Upgrade response sent to client, tunnel task spawned");

    Ok(response)
}
492
493fn build_response_from_cache(
494    cached: CachedResponse,
495    request_headers: &HeaderMap,
496) -> Result<Response<Body>, StatusCode> {
497    let mut response_headers = cached.headers;
498    let body = if let Some(content_encoding) = cached.content_encoding {
499        if client_accepts_encoding(request_headers, content_encoding) {
500            upsert_vary_accept_encoding(&mut response_headers);
501            cached.body
502        } else {
503            if !identity_acceptable(request_headers) {
504                tracing::warn!(
505                    "Client does not accept cached encoding '{}' or identity fallback",
506                    content_encoding.as_header_value()
507                );
508                return Err(StatusCode::NOT_ACCEPTABLE);
509            }
510
511            response_headers.remove("content-encoding");
512            upsert_vary_accept_encoding(&mut response_headers);
513            match decompress_body(&cached.body, content_encoding) {
514                Ok(body) => body,
515                Err(error) => {
516                    tracing::error!("Failed to decompress cached response: {}", error);
517                    return Err(StatusCode::INTERNAL_SERVER_ERROR);
518                }
519            }
520        }
521    } else {
522        cached.body
523    };
524
525    response_headers.remove("transfer-encoding");
526    response_headers.insert("content-length".to_string(), body.len().to_string());
527
528    Ok(build_response(cached.status, response_headers, body))
529}
530
531fn build_cached_response(
532    status: u16,
533    response_headers: &reqwest::header::HeaderMap,
534    normalized_body: &[u8],
535    compress_strategy: &CompressStrategy,
536) -> anyhow::Result<CachedResponse> {
537    let mut headers = convert_headers_to_map(response_headers);
538    headers.remove("content-encoding");
539    headers.remove("content-length");
540    headers.remove("transfer-encoding");
541
542    let content_encoding = configured_encoding(compress_strategy);
543    let body = if let Some(content_encoding) = content_encoding {
544        let compressed = compress_body(normalized_body, content_encoding)?;
545        headers.insert(
546            "content-encoding".to_string(),
547            content_encoding.as_header_value().to_string(),
548        );
549        upsert_vary_accept_encoding(&mut headers);
550        compressed
551    } else {
552        normalized_body.to_vec()
553    };
554
555    headers.insert("content-length".to_string(), body.len().to_string());
556
557    Ok(CachedResponse {
558        body,
559        headers,
560        status,
561        content_encoding,
562    })
563}
564
565fn build_response_from_upstream(
566    status: u16,
567    response_headers: &reqwest::header::HeaderMap,
568    body: Vec<u8>,
569) -> Response<Body> {
570    let mut headers = convert_headers_to_map(response_headers);
571    headers.remove("transfer-encoding");
572    headers.insert("content-length".to_string(), body.len().to_string());
573    build_response(status, headers, body)
574}
575
576fn build_response(
577    status: u16,
578    response_headers: HashMap<String, String>,
579    body: Vec<u8>,
580) -> Response<Body> {
581    let mut response = Response::builder().status(status);
582
583    // Add headers
584    let headers = response.headers_mut().unwrap();
585    for (key, value) in response_headers {
586        if let Ok(header_name) = key.parse::<HeaderName>() {
587            if let Ok(header_value) = HeaderValue::from_str(&value) {
588                headers.insert(header_name, header_value);
589            } else {
590                tracing::warn!(
591                    "Failed to parse header value for key '{}': {:?}",
592                    key,
593                    value
594                );
595            }
596        } else {
597            tracing::warn!("Failed to parse header name: {}", key);
598        }
599    }
600
601    response.body(Body::from(body)).unwrap()
602}
603
604fn cached_response_is_allowed(strategy: &crate::CacheStrategy, cached: &CachedResponse) -> bool {
605    strategy.allows_content_type(
606        cached
607            .headers
608            .get("content-type")
609            .map(|value| value.as_str()),
610    )
611}
612
/// Detect the phantom 404 marker meta tag in an HTML body.
///
/// Matches bodies containing both `name="phantom-404"` and `content="true"`
/// (single or double quotes accepted for each). Non-UTF-8 bodies never match.
fn body_contains_404_meta(body: &[u8]) -> bool {
    let Ok(text) = std::str::from_utf8(body) else {
        return false;
    };

    const NAME_MARKERS: [&str; 2] = ["name=\"phantom-404\"", "name='phantom-404'"];
    const CONTENT_MARKERS: [&str; 2] = ["content=\"true\"", "content='true'"];

    let has_name = NAME_MARKERS.iter().any(|marker| text.contains(marker));
    let has_content = CONTENT_MARKERS.iter().any(|marker| text.contains(marker));
    has_name && has_content
}
626
/// Ensure the `vary` entry of a header map lists `Accept-Encoding`.
///
/// Appends `", Accept-Encoding"` to an existing comma-separated value unless
/// the token is already present (case-insensitive); inserts a fresh entry
/// when no `vary` header exists.
fn upsert_vary_accept_encoding(headers: &mut HashMap<String, String>) {
    if let Some(existing) = headers.get_mut("vary") {
        let already_listed = existing
            .split(',')
            .any(|token| token.trim().eq_ignore_ascii_case("accept-encoding"));
        if !already_listed {
            existing.push_str(", Accept-Encoding");
        }
    } else {
        headers.insert("vary".to_string(), "Accept-Encoding".to_string());
    }
}
642
643fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
644    let mut req_headers = reqwest::header::HeaderMap::new();
645    for (key, value) in headers {
646        // Skip host header as reqwest will set it
647        if key == axum::http::header::HOST {
648            continue;
649        }
650        if let Ok(val) = value.to_str() {
651            if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
652                req_headers.insert(key.clone(), header_value);
653            }
654        }
655    }
656    req_headers
657}
658
659/// Fetch a single path from the upstream server, compress it, and store it in the cache.
660/// Used by the snapshot worker for PreGenerate warm-up and runtime snapshot management.
661pub(crate) async fn fetch_and_cache_snapshot(
662    path: &str,
663    proxy_url: &str,
664    cache: &CacheStore,
665    compress_strategy: &CompressStrategy,
666    cache_key_fn: &std::sync::Arc<dyn Fn(&crate::RequestInfo) -> String + Send + Sync>,
667) -> anyhow::Result<()> {
668    let empty_headers = axum::http::HeaderMap::new();
669    let req_info = crate::RequestInfo {
670        method: "GET",
671        path,
672        query: "",
673        headers: &empty_headers,
674    };
675    let cache_key = cache_key_fn(&req_info);
676
677    let client = reqwest::Client::builder()
678        .no_brotli()
679        .no_deflate()
680        .no_gzip()
681        .build()
682        .map_err(|e| anyhow::anyhow!("Failed to build HTTP client for snapshot fetch: {}", e))?;
683
684    let url = format!("{}{}", proxy_url, path);
685    let response = client
686        .get(&url)
687        .send()
688        .await
689        .map_err(|e| anyhow::anyhow!("Failed to fetch snapshot '{}': {}", path, e))?;
690
691    let status = response.status().as_u16();
692    let response_headers = response.headers().clone();
693    let body_bytes = response
694        .bytes()
695        .await
696        .map_err(|e| anyhow::anyhow!("Failed to read snapshot response for '{}': {}", path, e))?
697        .to_vec();
698
699    let upstream_encoding = response_headers
700        .get(axum::http::header::CONTENT_ENCODING)
701        .and_then(|v| v.to_str().ok());
702    let normalized = decode_upstream_body(&body_bytes, upstream_encoding)
703        .map_err(|e| anyhow::anyhow!("Failed to decode snapshot body for '{}': {}", path, e))?;
704
705    let cached = build_cached_response(status, &response_headers, &normalized, compress_strategy)?;
706    cache.set(cache_key, cached).await;
707    tracing::debug!("Snapshot pre-generated: {}", path);
708    Ok(())
709}
710
711fn convert_headers_to_map(
712    headers: &reqwest::header::HeaderMap,
713) -> std::collections::HashMap<String, String> {
714    let mut map = std::collections::HashMap::new();
715    for (key, value) in headers {
716        if let Ok(val) = value.to_str() {
717            map.insert(key.as_str().to_ascii_lowercase(), val.to_string());
718        } else {
719            // Log when we can't convert a header (might be binary)
720            tracing::debug!("Could not convert header '{}' to string", key);
721        }
722    }
723    map
724}
725
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compression::{compress_body, ContentEncoding};
    use axum::body::to_bytes;

    /// Minimal upstream header set: just an HTML content type.
    fn response_headers() -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("text/html; charset=utf-8"),
        );
        headers
    }

    // Storing with a gzip strategy must record the encoding both in the
    // struct field and in the headers, and add `Vary: Accept-Encoding`.
    #[test]
    fn test_build_cached_response_uses_selected_encoding() {
        let cached = build_cached_response(
            200,
            &response_headers(),
            b"<html>compressed</html>",
            &CompressStrategy::Gzip,
        )
        .unwrap();

        assert_eq!(cached.content_encoding, Some(ContentEncoding::Gzip));
        assert_eq!(
            cached.headers.get("content-encoding"),
            Some(&"gzip".to_string())
        );
        assert_eq!(
            cached.headers.get("vary"),
            Some(&"Accept-Encoding".to_string())
        );
    }

    // A brotli-stored entry served to a client that only accepts gzip must be
    // decompressed to identity: no content-encoding header, original bytes.
    #[tokio::test]
    async fn test_build_response_from_cache_falls_back_to_identity() {
        let body = b"<html>identity</html>";
        let compressed = compress_body(body, ContentEncoding::Brotli).unwrap();
        let cached = CachedResponse {
            body: compressed,
            headers: HashMap::from([
                ("content-type".to_string(), "text/html".to_string()),
                ("content-encoding".to_string(), "br".to_string()),
                // Deliberately wrong length: the builder must overwrite it.
                ("content-length".to_string(), "123".to_string()),
                ("vary".to_string(), "Accept-Encoding".to_string()),
            ]),
            status: 200,
            content_encoding: Some(ContentEncoding::Brotli),
        };

        let mut request_headers = HeaderMap::new();
        request_headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip"),
        );

        let response = build_response_from_cache(cached, &request_headers).unwrap();
        assert!(response
            .headers()
            .get(axum::http::header::CONTENT_ENCODING)
            .is_none());

        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(body.as_ref(), b"<html>identity</html>");
    }

    // A client that accepts brotli must receive the stored compressed bytes
    // untouched, with the content-encoding header preserved.
    #[tokio::test]
    async fn test_build_response_from_cache_keeps_supported_encoding() {
        let body = b"<html>compressed</html>";
        let compressed = compress_body(body, ContentEncoding::Brotli).unwrap();
        let cached = CachedResponse {
            body: compressed.clone(),
            headers: HashMap::from([
                ("content-type".to_string(), "text/html".to_string()),
                ("content-encoding".to_string(), "br".to_string()),
                ("content-length".to_string(), compressed.len().to_string()),
                ("vary".to_string(), "Accept-Encoding".to_string()),
            ]),
            status: 200,
            content_encoding: Some(ContentEncoding::Brotli),
        };

        let mut request_headers = HeaderMap::new();
        request_headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("br, gzip;q=0.5"),
        );

        let response = build_response_from_cache(cached, &request_headers).unwrap();
        assert_eq!(
            response.headers().get(axum::http::header::CONTENT_ENCODING),
            Some(&HeaderValue::from_static("br"))
        );

        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(body.as_ref(), compressed.as_slice());
    }
}