1use crate::cache::{CacheStore, CachedResponse};
2use crate::compression::{
3 client_accepts_encoding, compress_body, configured_encoding, decode_upstream_body,
4 decompress_body, identity_acceptable,
5};
6use crate::path_matcher::should_cache_path;
7use crate::{CompressStrategy, CreateProxyConfig, ProxyMode, WebhookType};
8use axum::{
9 body::Body,
10 extract::Extension,
11 http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
12};
13use hyper_util::rt::TokioIo;
14use std::collections::HashMap;
15use std::sync::Arc;
16
/// Shared, per-proxy state injected into every request via an axum `Extension`.
///
/// Cloning is cheap enough to do per request (the handler receives it wrapped
/// in an `Arc` anyway).
#[derive(Clone)]
pub struct ProxyState {
    // Response cache; also backs the separate negative (404) cache entries.
    cache: CacheStore,
    // Immutable proxy configuration captured when the proxy was created.
    config: CreateProxyConfig,
}

impl ProxyState {
    /// Bundles the cache store and configuration into request-shareable state.
    pub fn new(cache: CacheStore, config: CreateProxyConfig) -> Self {
        Self { cache, config }
    }
}
28
29fn is_upgrade_request(headers: &HeaderMap) -> bool {
37 headers
38 .get(axum::http::header::CONNECTION)
39 .and_then(|v| v.to_str().ok())
40 .map(|v| v.to_lowercase().contains("upgrade"))
41 .unwrap_or(false)
42 || headers.contains_key(axum::http::header::UPGRADE)
43}
44
45fn build_webhook_payload(
51 method: &str,
52 path: &str,
53 query: &str,
54 headers: &HeaderMap,
55) -> serde_json::Value {
56 let headers_map: serde_json::Map<String, serde_json::Value> = headers
57 .iter()
58 .filter_map(|(name, value)| {
59 value
60 .to_str()
61 .ok()
62 .map(|v| (name.as_str().to_string(), serde_json::Value::String(v.to_string())))
63 })
64 .collect();
65
66 serde_json::json!({
67 "method": method,
68 "path": path,
69 "query": query,
70 "headers": headers_map,
71 })
72}
73
/// Minimal view of a webhook endpoint's HTTP response, used by the handler to
/// make allow/deny/redirect and cache-key decisions.
struct WebhookCallResult {
    /// Status code returned by the webhook endpoint.
    status: StatusCode,
    /// `Location` header value, if the webhook issued a redirect.
    location: Option<String>,
    /// Response body as text (empty string when the body could not be read).
    body: String,
}
85
86async fn call_webhook(
95 url: &str,
96 payload: &serde_json::Value,
97 timeout_ms: u64,
98) -> Result<WebhookCallResult, ()> {
99 let client = reqwest::Client::builder()
100 .timeout(std::time::Duration::from_millis(timeout_ms))
101 .redirect(reqwest::redirect::Policy::none())
102 .build()
103 .map_err(|_| ())?;
104
105 let response = client
106 .post(url)
107 .json(payload)
108 .send()
109 .await
110 .map_err(|_| ())?;
111
112 let status = StatusCode::from_u16(response.status().as_u16()).map_err(|_| ())?;
113 let location = response
114 .headers()
115 .get(reqwest::header::LOCATION)
116 .and_then(|v| v.to_str().ok())
117 .map(|s| s.to_string());
118 let body = response.text().await.unwrap_or_default();
119
120 Ok(WebhookCallResult { status, location, body })
121}
122
/// Main proxy entry point.
///
/// Request flow, in order:
/// 1. Upgrade (e.g. WebSocket) requests are diverted to a raw TCP tunnel when
///    enabled, otherwise rejected with 501.
/// 2. Non-GET requests are rejected with 405 when `forward_get_only` is set.
/// 3. Configured webhooks run: `Notify` fire-and-forget, `Blocking` can deny
///    or redirect the request, `CacheKey` can override the cache key.
/// 4. The 404 cache and the regular cache are consulted; in `PreGenerate`
///    mode without fallthrough, a cache miss becomes a 404.
/// 5. Otherwise the request is forwarded upstream and the response is
///    optionally normalized, cached (regular or 404 cache), and returned.
pub async fn proxy_handler(
    Extension(state): Extension<Arc<ProxyState>>,
    req: Request<Body>,
) -> Result<Response<Body>, StatusCode> {
    let is_upgrade = is_upgrade_request(req.headers());

    if is_upgrade {
        let method_str = req.method().as_str();
        let path = req.uri().path();

        // Upgrades are only tunneled when websockets are enabled AND the proxy
        // mode can reach the backend (Dynamic always can; PreGenerate only
        // when fallthrough is on).
        let ws_allowed = state.config.enable_websocket
            && match &state.config.proxy_mode {
                ProxyMode::Dynamic => true,
                ProxyMode::PreGenerate { fallthrough, .. } => *fallthrough,
            };

        if ws_allowed {
            tracing::debug!(
                "Upgrade request detected for {} {}, establishing direct proxy tunnel",
                method_str,
                path
            );
            return handle_upgrade_request(state, req).await;
        } else {
            tracing::warn!(
                "Upgrade request detected for {} {} but WebSocket support is disabled or not available in current proxy mode",
                method_str,
                path
            );
            return Err(StatusCode::NOT_IMPLEMENTED);
        }
    }

    // Clone the pieces we need up front; `req` is consumed later when the
    // body is read.
    let method = req.method().clone();
    let method_str = method.as_str();
    let uri = req.uri().clone();
    let path = uri.path();
    let query = uri.query().unwrap_or("");
    let headers = req.headers().clone();

    if state.config.forward_get_only && method != axum::http::Method::GET {
        tracing::warn!(
            "Non-GET request {} {} rejected (forward_get_only is enabled)",
            method_str,
            path
        );
        return Err(StatusCode::METHOD_NOT_ALLOWED);
    }

    // Webhook phase. CacheKey webhooks may replace the default cache key;
    // Blocking webhooks may short-circuit the request entirely.
    let mut cache_key_override: Option<String> = None;
    if !state.config.webhooks.is_empty() {
        let payload = build_webhook_payload(method_str, path, query, &headers);

        for webhook in &state.config.webhooks {
            match webhook.webhook_type {
                WebhookType::Notify => {
                    // Fire-and-forget: spawned so it never delays the request.
                    let url = webhook.url.clone();
                    let payload_clone = payload.clone();
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    tokio::spawn(async move {
                        if let Err(()) = call_webhook(&url, &payload_clone, timeout_ms).await {
                            tracing::warn!("Notify webhook POST to '{}' failed", url);
                        }
                    });
                }
                WebhookType::Blocking => {
                    // Awaited inline: 2xx allows, 3xx redirects the client,
                    // anything else (or failure) denies the request.
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    match call_webhook(&webhook.url, &payload, timeout_ms).await {
                        Ok(result) if result.status.is_success() => {
                            tracing::debug!(
                                "Blocking webhook '{}' allowed {} {}",
                                webhook.url,
                                method_str,
                                path
                            );
                        }
                        Ok(result) if result.status.is_redirection() => {
                            tracing::debug!(
                                "Blocking webhook '{}' redirecting {} {} to {}",
                                webhook.url,
                                method_str,
                                path,
                                result.location.as_deref().unwrap_or("(no location)")
                            );
                            // Mirror the webhook's redirect (status + Location)
                            // back to the client with an empty body.
                            let mut builder = Response::builder().status(result.status);
                            if let Some(loc) = &result.location {
                                builder = builder.header(axum::http::header::LOCATION, loc.as_str());
                            }
                            return Ok(builder
                                .body(Body::empty())
                                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?);
                        }
                        Ok(result) => {
                            // Deny with the webhook's own status code.
                            tracing::warn!(
                                "Blocking webhook '{}' denied {} {} with status {}",
                                webhook.url,
                                method_str,
                                path,
                                result.status
                            );
                            return Err(result.status);
                        }
                        Err(()) => {
                            // Fail closed: an unreachable blocking webhook
                            // denies the request.
                            tracing::warn!(
                                "Blocking webhook '{}' timed out or failed for {} {} — denying request",
                                webhook.url,
                                method_str,
                                path
                            );
                            return Err(StatusCode::SERVICE_UNAVAILABLE);
                        }
                    }
                }
                WebhookType::CacheKey => {
                    // Best-effort: any failure or empty body falls back to the
                    // default cache key without affecting the request.
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    match call_webhook(&webhook.url, &payload, timeout_ms).await {
                        Ok(result) if result.status.is_success() => {
                            let key = result.body.trim().to_string();
                            if !key.is_empty() {
                                tracing::debug!(
                                    "Cache key webhook '{}' set key '{}' for {} {}",
                                    webhook.url,
                                    key,
                                    method_str,
                                    path
                                );
                                cache_key_override = Some(key);
                            } else {
                                tracing::warn!(
                                    "Cache key webhook '{}' returned empty body for {} {} — using default key",
                                    webhook.url,
                                    method_str,
                                    path
                                );
                            }
                        }
                        Ok(result) => {
                            tracing::warn!(
                                "Cache key webhook '{}' returned non-2xx {} for {} {} — using default key",
                                webhook.url,
                                result.status,
                                method_str,
                                path
                            );
                        }
                        Err(()) => {
                            tracing::warn!(
                                "Cache key webhook '{}' timed out or failed for {} {} — using default key",
                                webhook.url,
                                method_str,
                                path
                            );
                        }
                    }
                }
            }
        }
    }

    // Path/method filtering decides whether the regular cache applies at all.
    let should_cache = should_cache_path(
        method_str,
        path,
        &state.config.include_paths,
        &state.config.exclude_paths,
    );

    let req_info = crate::RequestInfo {
        method: method_str,
        path,
        query,
        headers: &headers,
    };
    // Webhook-provided key wins over the configured key function.
    let cache_key = cache_key_override
        .unwrap_or_else(|| (state.config.cache_key_fn)(&req_info));
    // CacheStrategy::None disables both cache reads and writes.
    let cache_reads_enabled = !matches!(state.config.cache_strategy, crate::CacheStrategy::None);

    // Negative (404) cache is consulted first, regardless of path filtering.
    if cache_reads_enabled && state.config.cache_404_capacity > 0 {
        if let Some(cached) = state.cache.get_404(&cache_key).await {
            if cached_response_is_allowed(&state.config.cache_strategy, &cached) {
                tracing::debug!("404 cache hit for: {} {}", method_str, cache_key);
                return build_response_from_cache(cached, &headers);
            }
        }
    }

    if should_cache && cache_reads_enabled {
        if let Some(cached) = state.cache.get(&cache_key).await {
            if cached_response_is_allowed(&state.config.cache_strategy, &cached) {
                tracing::debug!("Cache hit for: {} {}", method_str, cache_key);
                return build_response_from_cache(cached, &headers);
            }
        }
        // PreGenerate without fallthrough never contacts the backend: a miss
        // is a hard 404.
        if let ProxyMode::PreGenerate { fallthrough, .. } = &state.config.proxy_mode {
            if !fallthrough {
                tracing::debug!(
                    "PreGenerate cache miss for: {} {} — returning 404 (fallthrough disabled)",
                    method_str,
                    cache_key
                );
                return Err(StatusCode::NOT_FOUND);
            }
        }
        tracing::debug!(
            "Cache miss for: {} {}, fetching from backend",
            method_str,
            cache_key
        );
    } else if !cache_reads_enabled {
        tracing::debug!(
            "{} {} not cacheable (cache strategy: none), proxying directly",
            method_str,
            path
        );
    } else {
        tracing::debug!(
            "{} {} not cacheable (filtered), proxying directly",
            method_str,
            path
        );
    }

    // Buffer the whole request body for forwarding (no size cap here).
    let body_bytes = match axum::body::to_bytes(req.into_body(), usize::MAX).await {
        Ok(bytes) => bytes,
        Err(e) => {
            tracing::error!("Failed to read request body: {}", e);
            return Err(StatusCode::BAD_REQUEST);
        }
    };

    let path_and_query = uri
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or_else(|| uri.path());
    let target_url = format!("{}{}", state.config.proxy_url, path_and_query);
    // Transparent decompression is disabled so the raw upstream encoding is
    // preserved and normalized by our own pipeline.
    // NOTE(review): a fresh reqwest::Client is built per request; reusing one
    // (e.g. stored in ProxyState) would enable connection pooling — confirm
    // before changing.
    let client = match reqwest::Client::builder()
        .no_brotli()
        .no_deflate()
        .no_gzip()
        .build()
    {
        Ok(client) => client,
        Err(error) => {
            tracing::error!("Failed to build upstream HTTP client: {}", error);
            return Err(StatusCode::INTERNAL_SERVER_ERROR);
        }
    };

    let response = match client
        .request(method.clone(), &target_url)
        .headers(convert_headers(&headers))
        .body(body_bytes.to_vec())
        .send()
        .await
    {
        Ok(resp) => resp,
        Err(e) => {
            tracing::error!("Failed to fetch from backend: {}", e);
            return Err(StatusCode::BAD_GATEWAY);
        }
    };

    let status = response.status().as_u16();
    let response_headers = response.headers().clone();
    let body_bytes = match response.bytes().await {
        Ok(bytes) => bytes.to_vec(),
        Err(e) => {
            tracing::error!("Failed to read response body: {}", e);
            return Err(StatusCode::BAD_GATEWAY);
        }
    };

    let response_content_type = response_headers
        .get(axum::http::header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok());
    let response_is_cacheable = state
        .config
        .cache_strategy
        .allows_content_type(response_content_type);
    let upstream_content_encoding = response_headers
        .get(axum::http::header::CONTENT_ENCODING)
        .and_then(|value| value.to_str().ok());
    // Caching (regular or 404) is only attempted when reads are enabled and
    // the content type passes the strategy filter.
    let should_try_cache = cache_reads_enabled
        && response_is_cacheable
        && (should_cache || state.config.cache_404_capacity > 0);
    // Decode the upstream body to identity so it can be (a) scanned for the
    // 404 meta tag and (b) re-compressed with the configured strategy. A
    // decode failure just disables caching for this response.
    let normalized_body = if should_try_cache || state.config.use_404_meta {
        match decode_upstream_body(&body_bytes, upstream_content_encoding) {
            Ok(body) => Some(body),
            Err(error) => {
                tracing::warn!(
                    "Skipping cache compression for {} {} due to unsupported upstream encoding: {}",
                    method_str,
                    path,
                    error
                );
                None
            }
        }
    } else {
        None
    };

    // A response counts as a 404 either by status or, when enabled, by a
    // soft-404 marker embedded in the (decoded) body.
    let mut is_404 = status == 404;
    if !is_404 && state.config.use_404_meta {
        if let Some(body) = normalized_body.as_deref() {
            is_404 = body_contains_404_meta(body);
        }
    }

    let should_store_404 = is_404
        && state.config.cache_404_capacity > 0
        && response_is_cacheable
        && cache_reads_enabled
        && normalized_body.is_some();
    let should_store_response = !is_404
        && should_cache
        && response_is_cacheable
        && cache_reads_enabled
        && normalized_body.is_some();

    if should_store_404 || should_store_response {
        let cached_response = match build_cached_response(
            status,
            &response_headers,
            // Safe: both store flags require normalized_body.is_some().
            normalized_body.as_deref().unwrap(),
            &state.config.compress_strategy,
        ) {
            Ok(cached_response) => cached_response,
            Err(error) => {
                // Caching failed — still serve the raw upstream response.
                tracing::warn!(
                    "Failed to prepare cached response for {} {}: {}",
                    method_str,
                    path,
                    error
                );
                return Ok(build_response_from_upstream(
                    status,
                    &response_headers,
                    body_bytes,
                ));
            }
        };

        if should_store_404 {
            state
                .cache
                .set_404(cache_key.clone(), cached_response.clone())
                .await;
            tracing::debug!("Cached 404 response for: {} {}", method_str, cache_key);
        } else {
            state
                .cache
                .set(cache_key.clone(), cached_response.clone())
                .await;
            tracing::debug!("Cached response for: {} {}", method_str, cache_key);
        }

        // Serve from the just-built cache entry so the client gets the same
        // negotiated encoding a later cache hit would produce.
        return build_response_from_cache(cached_response, &headers);
    }

    Ok(build_response_from_upstream(
        status,
        &response_headers,
        body_bytes,
    ))
}
514
/// Tunnels a protocol-upgrade request (e.g. WebSocket) to the backend.
///
/// Establishes a raw TCP + HTTP/1.1 connection to the backend, forwards the
/// upgrade request, and — if the backend answers 101 — spawns a task that
/// joins both upgraded connections with a bidirectional byte copy. The client
/// immediately receives a synthesized 101 echoing the backend's
/// upgrade-relevant headers.
async fn handle_upgrade_request(
    state: Arc<ProxyState>,
    mut req: Request<Body>,
) -> Result<Response<Body>, StatusCode> {
    let req_path_and_query = req
        .uri()
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or_else(|| req.uri().path());
    let target_url = format!("{}{}", state.config.proxy_url, req_path_and_query);

    let backend_uri = target_url.parse::<hyper::Uri>().map_err(|e| {
        tracing::error!("Failed to parse backend URL: {}", e);
        StatusCode::BAD_GATEWAY
    })?;

    let host = backend_uri.host().ok_or_else(|| {
        tracing::error!("No host in backend URL");
        StatusCode::BAD_GATEWAY
    })?;

    // Default port by scheme when the URL omits one.
    // NOTE(review): for https this connects a plain TcpStream to port 443 —
    // no TLS handshake is visible here; confirm https backends are supported.
    let port = backend_uri.port_u16().unwrap_or_else(|| {
        if backend_uri.scheme_str() == Some("https") {
            443
        } else {
            80
        }
    });

    // Capture the client-side upgrade future BEFORE `req` is consumed by
    // send_request below.
    let client_upgrade = hyper::upgrade::on(&mut req);

    let backend_stream = tokio::net::TcpStream::connect((host, port))
        .await
        .map_err(|e| {
            tracing::error!("Failed to connect to backend {}:{}: {}", host, port, e);
            StatusCode::BAD_GATEWAY
        })?;

    let backend_io = TokioIo::new(backend_stream);

    let (mut sender, conn) = hyper::client::conn::http1::handshake(backend_io)
        .await
        .map_err(|e| {
            tracing::error!("Failed to handshake with backend: {}", e);
            StatusCode::BAD_GATEWAY
        })?;

    // The connection future must be polled for the request (and the upgrade)
    // to make progress, so drive it on its own task.
    let conn_task = tokio::spawn(async move {
        match conn.with_upgrades().await {
            Ok(parts) => {
                tracing::debug!("Backend connection upgraded successfully");
                Ok(parts)
            }
            Err(e) => {
                tracing::error!("Backend connection failed: {}", e);
                Err(e)
            }
        }
    });

    let backend_response = sender.send_request(req).await.map_err(|e| {
        tracing::error!("Failed to send request to backend: {}", e);
        StatusCode::BAD_GATEWAY
    })?;

    let status = backend_response.status();
    if status != StatusCode::SWITCHING_PROTOCOLS {
        // Backend refused the upgrade — relay its response verbatim.
        tracing::warn!("Backend did not accept upgrade request, status: {}", status);
        let (parts, body) = backend_response.into_parts();
        let body = Body::new(body);
        return Ok(Response::from_parts(parts, body));
    }

    let backend_headers = backend_response.headers().clone();

    let backend_upgrade = hyper::upgrade::on(backend_response);

    // The tunnel runs detached; the client gets its 101 below while this task
    // waits for both upgrades to complete.
    tokio::spawn(async move {
        tracing::debug!("Starting upgrade tunnel establishment");

        let (client_result, backend_result) = tokio::join!(client_upgrade, backend_upgrade);

        // Dropping a tokio JoinHandle only detaches the task — the backend
        // connection driver keeps running; we just stop holding its handle.
        drop(conn_task);

        match (client_result, backend_result) {
            (Ok(client_upgraded), Ok(backend_upgraded)) => {
                tracing::debug!("Both upgrades successful, establishing bidirectional tunnel");

                let mut client_stream = TokioIo::new(client_upgraded);
                let mut backend_stream = TokioIo::new(backend_upgraded);

                // Pump bytes both ways until either side closes.
                match tokio::io::copy_bidirectional(&mut client_stream, &mut backend_stream).await {
                    Ok((client_to_backend, backend_to_client)) => {
                        tracing::debug!(
                            "Tunnel closed gracefully. Transferred {} bytes client->backend, {} bytes backend->client",
                            client_to_backend,
                            backend_to_client
                        );
                    }
                    Err(e) => {
                        tracing::error!("Tunnel error: {}", e);
                    }
                }
            }
            (Err(e), _) => {
                tracing::error!("Client upgrade failed: {}", e);
            }
            (_, Err(e)) => {
                tracing::error!("Backend upgrade failed: {}", e);
            }
        }
    });

    // Synthesize the 101 for the client, copying only the headers relevant to
    // the upgrade handshake from the backend's response.
    let mut response = Response::builder()
        .status(StatusCode::SWITCHING_PROTOCOLS)
        .body(Body::empty())
        .unwrap();

    if let Some(upgrade_header) = backend_headers.get(axum::http::header::UPGRADE) {
        response
            .headers_mut()
            .insert(axum::http::header::UPGRADE, upgrade_header.clone());
    }
    if let Some(connection_header) = backend_headers.get(axum::http::header::CONNECTION) {
        response
            .headers_mut()
            .insert(axum::http::header::CONNECTION, connection_header.clone());
    }
    if let Some(sec_websocket_accept) = backend_headers.get("sec-websocket-accept") {
        response.headers_mut().insert(
            HeaderName::from_static("sec-websocket-accept"),
            sec_websocket_accept.clone(),
        );
    }

    tracing::debug!("Upgrade response sent to client, tunnel task spawned");

    Ok(response)
}
685
686fn build_response_from_cache(
687 cached: CachedResponse,
688 request_headers: &HeaderMap,
689) -> Result<Response<Body>, StatusCode> {
690 let mut response_headers = cached.headers;
691 let body = if let Some(content_encoding) = cached.content_encoding {
692 if client_accepts_encoding(request_headers, content_encoding) {
693 upsert_vary_accept_encoding(&mut response_headers);
694 cached.body
695 } else {
696 if !identity_acceptable(request_headers) {
697 tracing::warn!(
698 "Client does not accept cached encoding '{}' or identity fallback",
699 content_encoding.as_header_value()
700 );
701 return Err(StatusCode::NOT_ACCEPTABLE);
702 }
703
704 response_headers.remove("content-encoding");
705 upsert_vary_accept_encoding(&mut response_headers);
706 match decompress_body(&cached.body, content_encoding) {
707 Ok(body) => body,
708 Err(error) => {
709 tracing::error!("Failed to decompress cached response: {}", error);
710 return Err(StatusCode::INTERNAL_SERVER_ERROR);
711 }
712 }
713 }
714 } else {
715 cached.body
716 };
717
718 response_headers.remove("transfer-encoding");
719 response_headers.insert("content-length".to_string(), body.len().to_string());
720
721 Ok(build_response(cached.status, response_headers, body))
722}
723
724fn build_cached_response(
725 status: u16,
726 response_headers: &reqwest::header::HeaderMap,
727 normalized_body: &[u8],
728 compress_strategy: &CompressStrategy,
729) -> anyhow::Result<CachedResponse> {
730 let mut headers = convert_headers_to_map(response_headers);
731 headers.remove("content-encoding");
732 headers.remove("content-length");
733 headers.remove("transfer-encoding");
734
735 let content_encoding = configured_encoding(compress_strategy);
736 let body = if let Some(content_encoding) = content_encoding {
737 let compressed = compress_body(normalized_body, content_encoding)?;
738 headers.insert(
739 "content-encoding".to_string(),
740 content_encoding.as_header_value().to_string(),
741 );
742 upsert_vary_accept_encoding(&mut headers);
743 compressed
744 } else {
745 normalized_body.to_vec()
746 };
747
748 headers.insert("content-length".to_string(), body.len().to_string());
749
750 Ok(CachedResponse {
751 body,
752 headers,
753 status,
754 content_encoding,
755 })
756}
757
758fn build_response_from_upstream(
759 status: u16,
760 response_headers: &reqwest::header::HeaderMap,
761 body: Vec<u8>,
762) -> Response<Body> {
763 let mut headers = convert_headers_to_map(response_headers);
764 headers.remove("transfer-encoding");
765 headers.insert("content-length".to_string(), body.len().to_string());
766 build_response(status, headers, body)
767}
768
769fn build_response(
770 status: u16,
771 response_headers: HashMap<String, String>,
772 body: Vec<u8>,
773) -> Response<Body> {
774 let mut response = Response::builder().status(status);
775
776 let headers = response.headers_mut().unwrap();
778 for (key, value) in response_headers {
779 if let Ok(header_name) = key.parse::<HeaderName>() {
780 if let Ok(header_value) = HeaderValue::from_str(&value) {
781 headers.insert(header_name, header_value);
782 } else {
783 tracing::warn!(
784 "Failed to parse header value for key '{}': {:?}",
785 key,
786 value
787 );
788 }
789 } else {
790 tracing::warn!("Failed to parse header name: {}", key);
791 }
792 }
793
794 response.body(Body::from(body)).unwrap()
795}
796
797fn cached_response_is_allowed(strategy: &crate::CacheStrategy, cached: &CachedResponse) -> bool {
798 strategy.allows_content_type(
799 cached
800 .headers
801 .get("content-type")
802 .map(|value| value.as_str()),
803 )
804}
805
/// Detects the soft-404 marker `<meta name="phantom-404" content="true">` in a
/// response body via plain substring search (both single- and double-quoted
/// attribute forms). Non-UTF-8 bodies never match.
fn body_contains_404_meta(body: &[u8]) -> bool {
    let Ok(text) = std::str::from_utf8(body) else {
        return false;
    };

    let has_marker_name =
        text.contains("name=\"phantom-404\"") || text.contains("name='phantom-404'");
    let has_true_content =
        text.contains("content=\"true\"") || text.contains("content='true'");

    has_marker_name && has_true_content
}
819
/// Ensures the `vary` entry lists `Accept-Encoding`, creating the entry if
/// missing and appending (comma-separated) if it exists without that token.
/// The membership check is case-insensitive, so an existing lowercase token
/// is left untouched.
fn upsert_vary_accept_encoding(headers: &mut HashMap<String, String>) {
    if let Some(existing) = headers.get_mut("vary") {
        let already_listed = existing
            .split(',')
            .any(|token| token.trim().eq_ignore_ascii_case("accept-encoding"));
        if !already_listed {
            existing.push_str(", Accept-Encoding");
        }
    } else {
        headers.insert("vary".to_string(), "Accept-Encoding".to_string());
    }
}
835
836fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
837 let mut req_headers = reqwest::header::HeaderMap::new();
838 for (key, value) in headers {
839 if key == axum::http::header::HOST {
841 continue;
842 }
843 if let Ok(val) = value.to_str() {
844 if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
845 req_headers.insert(key.clone(), header_value);
846 }
847 }
848 }
849 req_headers
850}
851
852pub(crate) async fn fetch_and_cache_snapshot(
855 path: &str,
856 proxy_url: &str,
857 cache: &CacheStore,
858 compress_strategy: &CompressStrategy,
859 cache_key_fn: &std::sync::Arc<dyn Fn(&crate::RequestInfo) -> String + Send + Sync>,
860) -> anyhow::Result<()> {
861 let empty_headers = axum::http::HeaderMap::new();
862 let req_info = crate::RequestInfo {
863 method: "GET",
864 path,
865 query: "",
866 headers: &empty_headers,
867 };
868 let cache_key = cache_key_fn(&req_info);
869
870 let client = reqwest::Client::builder()
871 .no_brotli()
872 .no_deflate()
873 .no_gzip()
874 .build()
875 .map_err(|e| anyhow::anyhow!("Failed to build HTTP client for snapshot fetch: {}", e))?;
876
877 let url = format!("{}{}", proxy_url, path);
878 let response = client
879 .get(&url)
880 .send()
881 .await
882 .map_err(|e| anyhow::anyhow!("Failed to fetch snapshot '{}': {}", path, e))?;
883
884 let status = response.status().as_u16();
885 let response_headers = response.headers().clone();
886 let body_bytes = response
887 .bytes()
888 .await
889 .map_err(|e| anyhow::anyhow!("Failed to read snapshot response for '{}': {}", path, e))?
890 .to_vec();
891
892 let upstream_encoding = response_headers
893 .get(axum::http::header::CONTENT_ENCODING)
894 .and_then(|v| v.to_str().ok());
895 let normalized = decode_upstream_body(&body_bytes, upstream_encoding)
896 .map_err(|e| anyhow::anyhow!("Failed to decode snapshot body for '{}': {}", path, e))?;
897
898 let cached = build_cached_response(status, &response_headers, &normalized, compress_strategy)?;
899 cache.set(cache_key, cached).await;
900 tracing::debug!("Snapshot pre-generated: {}", path);
901 Ok(())
902}
903
904fn convert_headers_to_map(
905 headers: &reqwest::header::HeaderMap,
906) -> std::collections::HashMap<String, String> {
907 let mut map = std::collections::HashMap::new();
908 for (key, value) in headers {
909 if let Ok(val) = value.to_str() {
910 map.insert(key.as_str().to_ascii_lowercase(), val.to_string());
911 } else {
912 tracing::debug!("Could not convert header '{}' to string", key);
914 }
915 }
916 map
917}
918
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compression::{compress_body, ContentEncoding};
    use axum::body::to_bytes;

    /// Minimal upstream header set: just an HTML content type.
    fn response_headers() -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("text/html; charset=utf-8"),
        );
        headers
    }

    /// A Gzip compress strategy must produce a cache entry with the body
    /// stored gzip-encoded and the matching content-encoding/vary headers.
    #[test]
    fn test_build_cached_response_uses_selected_encoding() {
        let cached = build_cached_response(
            200,
            &response_headers(),
            b"<html>compressed</html>",
            &CompressStrategy::Gzip,
        )
        .unwrap();

        assert_eq!(cached.content_encoding, Some(ContentEncoding::Gzip));
        assert_eq!(
            cached.headers.get("content-encoding"),
            Some(&"gzip".to_string())
        );
        assert_eq!(
            cached.headers.get("vary"),
            Some(&"Accept-Encoding".to_string())
        );
    }

    /// A client that does not accept the cached encoding (br) but accepts
    /// identity must get the decompressed body with no Content-Encoding.
    #[tokio::test]
    async fn test_build_response_from_cache_falls_back_to_identity() {
        let body = b"<html>identity</html>";
        let compressed = compress_body(body, ContentEncoding::Brotli).unwrap();
        // Deliberately wrong content-length ("123") to show it is recomputed.
        let cached = CachedResponse {
            body: compressed,
            headers: HashMap::from([
                ("content-type".to_string(), "text/html".to_string()),
                ("content-encoding".to_string(), "br".to_string()),
                ("content-length".to_string(), "123".to_string()),
                ("vary".to_string(), "Accept-Encoding".to_string()),
            ]),
            status: 200,
            content_encoding: Some(ContentEncoding::Brotli),
        };

        // Client accepts gzip only — not the cached brotli.
        let mut request_headers = HeaderMap::new();
        request_headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip"),
        );

        let response = build_response_from_cache(cached, &request_headers).unwrap();
        assert!(response
            .headers()
            .get(axum::http::header::CONTENT_ENCODING)
            .is_none());

        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(body.as_ref(), b"<html>identity</html>");
    }

    /// A client that accepts the cached encoding must receive the stored
    /// compressed bytes unchanged, with Content-Encoding preserved.
    #[tokio::test]
    async fn test_build_response_from_cache_keeps_supported_encoding() {
        let body = b"<html>compressed</html>";
        let compressed = compress_body(body, ContentEncoding::Brotli).unwrap();
        let cached = CachedResponse {
            body: compressed.clone(),
            headers: HashMap::from([
                ("content-type".to_string(), "text/html".to_string()),
                ("content-encoding".to_string(), "br".to_string()),
                ("content-length".to_string(), compressed.len().to_string()),
                ("vary".to_string(), "Accept-Encoding".to_string()),
            ]),
            status: 200,
            content_encoding: Some(ContentEncoding::Brotli),
        };

        // Client prefers br over gzip.
        let mut request_headers = HeaderMap::new();
        request_headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("br, gzip;q=0.5"),
        );

        let response = build_response_from_cache(cached, &request_headers).unwrap();
        assert_eq!(
            response.headers().get(axum::http::header::CONTENT_ENCODING),
            Some(&HeaderValue::from_static("br"))
        );

        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(body.as_ref(), compressed.as_slice());
    }
}
1018}