1use crate::cache::{CacheStore, CachedResponse};
2use crate::compression::{
3 client_accepts_encoding, compress_body_async, configured_encoding, decode_upstream_body_async,
4 decompress_body_async, identity_acceptable,
5};
6use crate::path_matcher::should_cache_path;
7use crate::{CompressStrategy, CreateProxyConfig, ProxyMode, WebhookType};
8use axum::{
9 body::Body,
10 extract::Extension,
11 http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
12};
13use hyper_util::rt::TokioIo;
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
/// Shared, cloneable state handed to every proxy request handler.
#[derive(Clone)]
pub struct ProxyState {
    // Response cache (main entries plus the separate 404 store).
    cache: CacheStore,
    // User-supplied proxy configuration (mode, webhooks, cache strategy, …).
    config: CreateProxyConfig,
    // HTTP client used for upstream/backend fetches.
    upstream_client: reqwest::Client,
    // HTTP client used for webhook POSTs.
    webhook_client: reqwest::Client,
}
25
26impl ProxyState {
27 pub fn new(
28 cache: CacheStore,
29 config: CreateProxyConfig,
30 upstream_client: reqwest::Client,
31 webhook_client: reqwest::Client,
32 ) -> Self {
33 Self {
34 cache,
35 config,
36 upstream_client,
37 webhook_client,
38 }
39 }
40}
41
42pub(crate) fn build_upstream_client() -> anyhow::Result<reqwest::Client> {
43 reqwest::Client::builder()
44 .pool_idle_timeout(Duration::from_secs(90))
45 .connect_timeout(Duration::from_secs(5))
46 .timeout(Duration::from_secs(30))
47 .tcp_keepalive(Duration::from_secs(30))
48 .no_brotli()
49 .no_deflate()
50 .no_gzip()
51 .build()
52 .map_err(Into::into)
53}
54
55pub(crate) fn build_webhook_client() -> anyhow::Result<reqwest::Client> {
56 reqwest::Client::builder()
57 .pool_idle_timeout(Duration::from_secs(30))
58 .connect_timeout(Duration::from_secs(3))
59 .redirect(reqwest::redirect::Policy::none())
60 .build()
61 .map_err(Into::into)
62}
63
64fn is_upgrade_request(headers: &HeaderMap) -> bool {
72 headers
73 .get(axum::http::header::CONNECTION)
74 .and_then(|v| v.to_str().ok())
75 .map(|v| v.to_lowercase().contains("upgrade"))
76 .unwrap_or(false)
77 || headers.contains_key(axum::http::header::UPGRADE)
78}
79
80fn build_webhook_payload(
86 method: &str,
87 path: &str,
88 query: &str,
89 headers: &HeaderMap,
90) -> serde_json::Value {
91 let headers_map: serde_json::Map<String, serde_json::Value> = headers
92 .iter()
93 .filter_map(|(name, value)| {
94 value.to_str().ok().map(|v| {
95 (
96 name.as_str().to_string(),
97 serde_json::Value::String(v.to_string()),
98 )
99 })
100 })
101 .collect();
102
103 serde_json::json!({
104 "method": method,
105 "path": path,
106 "query": query,
107 "headers": headers_map,
108 })
109}
110
/// The subset of a webhook's HTTP response the proxy acts on.
struct WebhookCallResult {
    // Status returned by the webhook endpoint (drives allow/deny/redirect).
    status: StatusCode,
    // `Location` header, if present (used when a Blocking webhook redirects).
    location: Option<String>,
    // Response body text (a CacheKey webhook returns the cache key here).
    body: String,
}
122
123async fn call_webhook(
132 client: &reqwest::Client,
133 url: &str,
134 payload: &serde_json::Value,
135 timeout_ms: u64,
136) -> Result<WebhookCallResult, ()> {
137 let response = client
138 .post(url)
139 .timeout(std::time::Duration::from_millis(timeout_ms))
140 .json(payload)
141 .send()
142 .await
143 .map_err(|_| ())?;
144
145 let status = StatusCode::from_u16(response.status().as_u16()).map_err(|_| ())?;
146 let location = response
147 .headers()
148 .get(reqwest::header::LOCATION)
149 .and_then(|v| v.to_str().ok())
150 .map(|s| s.to_string());
151 let body = response.text().await.unwrap_or_default();
152
153 Ok(WebhookCallResult {
154 status,
155 location,
156 body,
157 })
158}
159
/// Main axum handler: runs the webhook phase, consults the 404 and main
/// caches, proxies misses to the configured upstream, and optionally writes
/// the response (or a detected soft-404) back into the cache.
///
/// Returns the response to serve, or an error `StatusCode` when the request
/// is rejected (webhook denial, method filter, PreGenerate miss, upstream
/// failure).
pub async fn proxy_handler(
    Extension(state): Extension<Arc<ProxyState>>,
    req: Request<Body>,
) -> Result<Response<Body>, StatusCode> {
    let request_started = Instant::now();
    let is_upgrade = is_upgrade_request(req.headers());

    // Upgrade (WebSocket) requests bypass webhooks and caching entirely and
    // are tunnelled directly — but only when configuration allows it.
    if is_upgrade {
        let method_str = req.method().as_str();
        let path = req.uri().path();

        // WebSockets need a live upstream: Dynamic mode always has one,
        // PreGenerate mode only when fallthrough is enabled.
        let ws_allowed = state.config.enable_websocket
            && match &state.config.proxy_mode {
                ProxyMode::Dynamic => true,
                ProxyMode::PreGenerate { fallthrough, .. } => *fallthrough,
            };

        if ws_allowed {
            tracing::debug!(
                "Upgrade request detected for {} {}, establishing direct proxy tunnel",
                method_str,
                path
            );
            return handle_upgrade_request(state, req).await;
        } else {
            tracing::warn!(
                "Upgrade request detected for {} {} but WebSocket support is disabled or not available in current proxy mode",
                method_str,
                path
            );
            return Err(StatusCode::NOT_IMPLEMENTED);
        }
    }

    // Clone method/URI/headers up front: the request body is consumed below
    // (`req.into_body()`), after which the request parts are gone.
    let method = req.method().clone();
    let method_str = method.as_str();
    let uri = req.uri().clone();
    let path = uri.path();
    let query = uri.query().unwrap_or("");
    let headers = req.headers().clone();
    tracing::debug!(
        method = method_str,
        path,
        query,
        "proxy request entered handler"
    );

    if state.config.forward_get_only && method != axum::http::Method::GET {
        tracing::warn!(
            "Non-GET request {} {} rejected (forward_get_only is enabled)",
            method_str,
            path
        );
        return Err(StatusCode::METHOD_NOT_ALLOWED);
    }

    // Webhook phase. Notify webhooks are fire-and-forget; Blocking webhooks
    // may deny or redirect the request; CacheKey webhooks may override the
    // cache key. Webhooks run in configured order.
    let mut cache_key_override: Option<String> = None;
    if !state.config.webhooks.is_empty() {
        let payload = build_webhook_payload(method_str, path, query, &headers);
        let webhook_started = Instant::now();

        for webhook in &state.config.webhooks {
            match webhook.webhook_type {
                WebhookType::Notify => {
                    // Spawned so it never delays the request; result is only logged.
                    let url = webhook.url.clone();
                    let payload_clone = payload.clone();
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    let webhook_client = state.webhook_client.clone();
                    tokio::spawn(async move {
                        if let Err(()) =
                            call_webhook(&webhook_client, &url, &payload_clone, timeout_ms).await
                        {
                            tracing::warn!("Notify webhook POST to '{}' failed", url);
                        }
                    });
                }
                WebhookType::Blocking => {
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    match call_webhook(&state.webhook_client, &webhook.url, &payload, timeout_ms)
                        .await
                    {
                        // 2xx: allow; continue to the next webhook.
                        Ok(result) if result.status.is_success() => {
                            tracing::debug!(
                                "Blocking webhook '{}' allowed {} {}",
                                webhook.url,
                                method_str,
                                path
                            );
                        }
                        // 3xx: short-circuit, relaying the redirect to the client.
                        Ok(result) if result.status.is_redirection() => {
                            tracing::debug!(
                                "Blocking webhook '{}' redirecting {} {} to {}",
                                webhook.url,
                                method_str,
                                path,
                                result.location.as_deref().unwrap_or("(no location)")
                            );
                            let mut builder = Response::builder().status(result.status);
                            if let Some(loc) = &result.location {
                                builder =
                                    builder.header(axum::http::header::LOCATION, loc.as_str());
                            }
                            return Ok(builder
                                .body(Body::empty())
                                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?);
                        }
                        // Any other status: deny with the webhook's own status.
                        Ok(result) => {
                            tracing::warn!(
                                "Blocking webhook '{}' denied {} {} with status {}",
                                webhook.url,
                                method_str,
                                path,
                                result.status
                            );
                            return Err(result.status);
                        }
                        // Transport failure or timeout: fail closed.
                        Err(()) => {
                            tracing::warn!(
                                "Blocking webhook '{}' timed out or failed for {} {} — denying request",
                                webhook.url,
                                method_str,
                                path
                            );
                            return Err(StatusCode::SERVICE_UNAVAILABLE);
                        }
                    }
                }
                WebhookType::CacheKey => {
                    // Failures here are non-fatal: fall back to the default key.
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    match call_webhook(&state.webhook_client, &webhook.url, &payload, timeout_ms)
                        .await
                    {
                        Ok(result) if result.status.is_success() => {
                            let key = result.body.trim().to_string();
                            if !key.is_empty() {
                                tracing::debug!(
                                    "Cache key webhook '{}' set key '{}' for {} {}",
                                    webhook.url,
                                    key,
                                    method_str,
                                    path
                                );
                                cache_key_override = Some(key);
                            } else {
                                tracing::warn!(
                                    "Cache key webhook '{}' returned empty body for {} {} — using default key",
                                    webhook.url,
                                    method_str,
                                    path
                                );
                            }
                        }
                        Ok(result) => {
                            tracing::warn!(
                                "Cache key webhook '{}' returned non-2xx {} for {} {} — using default key",
                                webhook.url,
                                result.status,
                                method_str,
                                path
                            );
                        }
                        Err(()) => {
                            tracing::warn!(
                                "Cache key webhook '{}' timed out or failed for {} {} — using default key",
                                webhook.url,
                                method_str,
                                path
                            );
                        }
                    }
                }
            }
        }

        tracing::debug!(
            method = method_str,
            path,
            elapsed_ms = webhook_started.elapsed().as_millis(),
            "proxy request completed webhook phase"
        );
    }

    // Whether this method+path combination is eligible for the main cache.
    let should_cache = should_cache_path(
        method_str,
        path,
        &state.config.include_paths,
        &state.config.exclude_paths,
    );

    let req_info = crate::RequestInfo {
        method: method_str,
        path,
        query,
        headers: &headers,
    };
    let cache_key = cache_key_override.unwrap_or_else(|| (state.config.cache_key_fn)(&req_info));
    let cache_reads_enabled = !matches!(state.config.cache_strategy, crate::CacheStrategy::None);

    // The 404 cache is consulted even for paths filtered out of the main cache.
    if cache_reads_enabled && state.config.cache_404_capacity > 0 {
        if let Some(cached) = state.cache.get_404(&cache_key).await {
            if cached_response_is_allowed(&state.config.cache_strategy, &cached) {
                tracing::debug!("404 cache hit for: {} {}", method_str, cache_key);
                let response = build_response_from_cache(cached, &headers).await?;
                tracing::debug!(
                    method = method_str,
                    path,
                    elapsed_ms = request_started.elapsed().as_millis(),
                    "proxy request served from 404 cache"
                );
                return Ok(response);
            }
        }
    }

    if should_cache && cache_reads_enabled {
        if let Some(cached) = state.cache.get(&cache_key).await {
            if cached_response_is_allowed(&state.config.cache_strategy, &cached) {
                tracing::debug!("Cache hit for: {} {}", method_str, cache_key);
                let response = build_response_from_cache(cached, &headers).await?;
                tracing::debug!(
                    method = method_str,
                    path,
                    elapsed_ms = request_started.elapsed().as_millis(),
                    "proxy request served from main cache"
                );
                return Ok(response);
            }
        }
        // PreGenerate without fallthrough never contacts the upstream: a miss
        // is simply a 404.
        if let ProxyMode::PreGenerate { fallthrough, .. } = &state.config.proxy_mode {
            if !fallthrough {
                tracing::debug!(
                    "PreGenerate cache miss for: {} {} — returning 404 (fallthrough disabled)",
                    method_str,
                    cache_key
                );
                return Err(StatusCode::NOT_FOUND);
            }
        }
        tracing::debug!(
            "Cache miss for: {} {}, fetching from backend",
            method_str,
            cache_key
        );
    } else if !cache_reads_enabled {
        tracing::debug!(
            "{} {} not cacheable (cache strategy: none), proxying directly",
            method_str,
            path
        );
    } else {
        tracing::debug!(
            "{} {} not cacheable (filtered), proxying directly",
            method_str,
            path
        );
    }

    // Buffer the full request body before forwarding.
    // NOTE(review): usize::MAX imposes no size limit — confirm an outer layer
    // bounds request bodies, otherwise a large upload can exhaust memory.
    let body_bytes = match axum::body::to_bytes(req.into_body(), usize::MAX).await {
        Ok(bytes) => bytes,
        Err(e) => {
            tracing::error!("Failed to read request body: {}", e);
            return Err(StatusCode::BAD_REQUEST);
        }
    };

    let path_and_query = uri
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or_else(|| uri.path());
    let target_url = format!("{}{}", state.config.proxy_url, path_and_query);
    let upstream_started = Instant::now();

    let response = match state
        .upstream_client
        .request(method.clone(), &target_url)
        .headers(convert_headers(&headers))
        .body(body_bytes.to_vec())
        .send()
        .await
    {
        Ok(resp) => resp,
        Err(e) => {
            tracing::error!("Failed to fetch from backend: {}", e);
            return Err(StatusCode::BAD_GATEWAY);
        }
    };
    tracing::debug!(
        method = method_str,
        path,
        elapsed_ms = upstream_started.elapsed().as_millis(),
        "proxy request received upstream response headers"
    );

    let status = response.status().as_u16();
    let response_headers = response.headers().clone();
    let body_bytes = match response.bytes().await {
        Ok(bytes) => bytes.to_vec(),
        Err(e) => {
            tracing::error!("Failed to read response body: {}", e);
            return Err(StatusCode::BAD_GATEWAY);
        }
    };

    let response_content_type = response_headers
        .get(axum::http::header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok());
    let response_is_cacheable = state
        .config
        .cache_strategy
        .allows_content_type(response_content_type);
    let upstream_content_encoding = response_headers
        .get(axum::http::header::CONTENT_ENCODING)
        .and_then(|value| value.to_str().ok());
    let should_try_cache = cache_reads_enabled
        && response_is_cacheable
        && (should_cache || state.config.cache_404_capacity > 0);
    // Decode the upstream body to identity form so the cache can re-encode it
    // with the configured strategy (also required to sniff soft-404 markers).
    // A `None` here means "do not cache" — the raw bytes are still served.
    let normalized_body = if should_try_cache || state.config.use_404_meta {
        match decode_upstream_body_async(
            body_bytes.clone(),
            upstream_content_encoding.map(|value| value.to_string()),
        )
        .await
        {
            Ok(body) => Some(body),
            Err(error) => {
                tracing::warn!(
                    "Skipping cache compression for {} {} due to unsupported upstream encoding: {}",
                    method_str,
                    path,
                    error
                );
                None
            }
        }
    } else {
        None
    };

    // A response counts as a 404 by status code, or (when enabled) by an
    // embedded "phantom-404" meta marker in the decoded body.
    let mut is_404 = status == 404;
    if !is_404 && state.config.use_404_meta {
        if let Some(body) = normalized_body.as_deref() {
            is_404 = body_contains_404_meta(body);
        }
    }

    let should_store_404 = is_404
        && state.config.cache_404_capacity > 0
        && response_is_cacheable
        && cache_reads_enabled
        && normalized_body.is_some();
    let should_store_response = !is_404
        && should_cache
        && response_is_cacheable
        && cache_reads_enabled
        && normalized_body.is_some();

    if should_store_404 || should_store_response {
        let cached_response = match build_cached_response(
            status,
            &response_headers,
            normalized_body.as_deref().unwrap(),
            &state.config.compress_strategy,
        )
        .await
        {
            Ok(cached_response) => cached_response,
            Err(error) => {
                // Caching failed — still serve the raw upstream response.
                tracing::warn!(
                    "Failed to prepare cached response for {} {}: {}",
                    method_str,
                    path,
                    error
                );
                return Ok(build_response_from_upstream(
                    status,
                    &response_headers,
                    body_bytes,
                ));
            }
        };

        if should_store_404 {
            state
                .cache
                .set_404(cache_key.clone(), cached_response.clone())
                .await;
            tracing::debug!("Cached 404 response for: {} {}", method_str, cache_key);
        } else {
            state
                .cache
                .set(cache_key.clone(), cached_response.clone())
                .await;
            tracing::debug!("Cached response for: {} {}", method_str, cache_key);
        }

        // Serve from the just-built cached form so encoding negotiation
        // matches what future cache hits will do.
        let response = build_response_from_cache(cached_response, &headers).await?;
        tracing::debug!(
            method = method_str,
            path,
            elapsed_ms = request_started.elapsed().as_millis(),
            "proxy request completed after upstream fetch and cache write"
        );
        return Ok(response);
    }

    tracing::debug!(
        method = method_str,
        path,
        elapsed_ms = request_started.elapsed().as_millis(),
        "proxy request completed without caching"
    );
    Ok(build_response_from_upstream(
        status,
        &response_headers,
        body_bytes,
    ))
}
603
/// Tunnels an HTTP Upgrade (e.g. WebSocket) request: replays it against the
/// backend over a fresh HTTP/1.1 connection, and on a 101 response splices
/// the two upgraded connections together with a bidirectional byte copy.
///
/// Returns the 101 response (with handshake headers echoed from the backend),
/// the backend's non-101 response as-is, or a gateway error.
async fn handle_upgrade_request(
    state: Arc<ProxyState>,
    mut req: Request<Body>,
) -> Result<Response<Body>, StatusCode> {
    let req_path_and_query = req
        .uri()
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or_else(|| req.uri().path());
    let target_url = format!("{}{}", state.config.proxy_url, req_path_and_query);

    let backend_uri = target_url.parse::<hyper::Uri>().map_err(|e| {
        tracing::error!("Failed to parse backend URL: {}", e);
        StatusCode::BAD_GATEWAY
    })?;

    let host = backend_uri.host().ok_or_else(|| {
        tracing::error!("No host in backend URL");
        StatusCode::BAD_GATEWAY
    })?;

    // Default the port by scheme when the URL does not carry one explicitly.
    let port = backend_uri.port_u16().unwrap_or_else(|| {
        if backend_uri.scheme_str() == Some("https") {
            443
        } else {
            80
        }
    });

    // Capture the client-side upgrade future *before* the request is consumed
    // by the backend send; it resolves once we reply 101 to the client.
    let client_upgrade = hyper::upgrade::on(&mut req);

    // NOTE(review): this opens a plain TCP connection even when the backend
    // scheme is https — confirm TLS upgrade tunnelling is out of scope here.
    let backend_stream = tokio::net::TcpStream::connect((host, port))
        .await
        .map_err(|e| {
            tracing::error!("Failed to connect to backend {}:{}: {}", host, port, e);
            StatusCode::BAD_GATEWAY
        })?;

    let backend_io = TokioIo::new(backend_stream);

    let (mut sender, conn) = hyper::client::conn::http1::handshake(backend_io)
        .await
        .map_err(|e| {
            tracing::error!("Failed to handshake with backend: {}", e);
            StatusCode::BAD_GATEWAY
        })?;

    // Drive the backend connection; `with_upgrades()` lets it hand the raw
    // socket back after the 101 response completes.
    let conn_task = tokio::spawn(async move {
        match conn.with_upgrades().await {
            Ok(parts) => {
                tracing::debug!("Backend connection upgraded successfully");
                Ok(parts)
            }
            Err(e) => {
                tracing::error!("Backend connection failed: {}", e);
                Err(e)
            }
        }
    });

    let backend_response = sender.send_request(req).await.map_err(|e| {
        tracing::error!("Failed to send request to backend: {}", e);
        StatusCode::BAD_GATEWAY
    })?;

    let status = backend_response.status();
    if status != StatusCode::SWITCHING_PROTOCOLS {
        // Backend refused the upgrade: relay its response to the client as-is.
        tracing::warn!("Backend did not accept upgrade request, status: {}", status);
        let (parts, body) = backend_response.into_parts();
        let body = Body::new(body);
        return Ok(Response::from_parts(parts, body));
    }

    let backend_headers = backend_response.headers().clone();

    let backend_upgrade = hyper::upgrade::on(backend_response);

    // Complete both upgrades in the background and pump bytes between them;
    // the 101 response below must be sent for the client upgrade to resolve.
    tokio::spawn(async move {
        tracing::debug!("Starting upgrade tunnel establishment");

        let (client_result, backend_result) = tokio::join!(client_upgrade, backend_upgrade);

        // Dropping the JoinHandle detaches (does not abort) the backend
        // connection task.
        drop(conn_task);

        match (client_result, backend_result) {
            (Ok(client_upgraded), Ok(backend_upgraded)) => {
                tracing::debug!("Both upgrades successful, establishing bidirectional tunnel");

                let mut client_stream = TokioIo::new(client_upgraded);
                let mut backend_stream = TokioIo::new(backend_upgraded);

                match tokio::io::copy_bidirectional(&mut client_stream, &mut backend_stream).await {
                    Ok((client_to_backend, backend_to_client)) => {
                        tracing::debug!(
                            "Tunnel closed gracefully. Transferred {} bytes client->backend, {} bytes backend->client",
                            client_to_backend,
                            backend_to_client
                        );
                    }
                    Err(e) => {
                        tracing::error!("Tunnel error: {}", e);
                    }
                }
            }
            (Err(e), _) => {
                tracing::error!("Client upgrade failed: {}", e);
            }
            (_, Err(e)) => {
                tracing::error!("Backend upgrade failed: {}", e);
            }
        }
    });

    // Reply 101 to the client, echoing the handshake headers the backend sent.
    let mut response = Response::builder()
        .status(StatusCode::SWITCHING_PROTOCOLS)
        .body(Body::empty())
        .unwrap();

    if let Some(upgrade_header) = backend_headers.get(axum::http::header::UPGRADE) {
        response
            .headers_mut()
            .insert(axum::http::header::UPGRADE, upgrade_header.clone());
    }
    if let Some(connection_header) = backend_headers.get(axum::http::header::CONNECTION) {
        response
            .headers_mut()
            .insert(axum::http::header::CONNECTION, connection_header.clone());
    }
    if let Some(sec_websocket_accept) = backend_headers.get("sec-websocket-accept") {
        response.headers_mut().insert(
            HeaderName::from_static("sec-websocket-accept"),
            sec_websocket_accept.clone(),
        );
    }

    tracing::debug!("Upgrade response sent to client, tunnel task spawned");

    Ok(response)
}
774
775async fn build_response_from_cache(
776 cached: CachedResponse,
777 request_headers: &HeaderMap,
778) -> Result<Response<Body>, StatusCode> {
779 let mut response_headers = cached.headers;
780 let body = if let Some(content_encoding) = cached.content_encoding {
781 if client_accepts_encoding(request_headers, content_encoding) {
782 upsert_vary_accept_encoding(&mut response_headers);
783 cached.body
784 } else {
785 if !identity_acceptable(request_headers) {
786 tracing::warn!(
787 "Client does not accept cached encoding '{}' or identity fallback",
788 content_encoding.as_header_value()
789 );
790 return Err(StatusCode::NOT_ACCEPTABLE);
791 }
792
793 response_headers.remove("content-encoding");
794 upsert_vary_accept_encoding(&mut response_headers);
795 match decompress_body_async(cached.body.clone(), content_encoding).await {
796 Ok(body) => body,
797 Err(error) => {
798 tracing::error!("Failed to decompress cached response: {}", error);
799 return Err(StatusCode::INTERNAL_SERVER_ERROR);
800 }
801 }
802 }
803 } else {
804 cached.body
805 };
806
807 response_headers.remove("transfer-encoding");
808 response_headers.insert("content-length".to_string(), body.len().to_string());
809
810 Ok(build_response(cached.status, response_headers, body))
811}
812
813async fn build_cached_response(
814 status: u16,
815 response_headers: &reqwest::header::HeaderMap,
816 normalized_body: &[u8],
817 compress_strategy: &CompressStrategy,
818) -> anyhow::Result<CachedResponse> {
819 let mut headers = convert_headers_to_map(response_headers);
820 headers.remove("content-encoding");
821 headers.remove("content-length");
822 headers.remove("transfer-encoding");
823
824 let content_encoding = configured_encoding(compress_strategy);
825 let body = if let Some(content_encoding) = content_encoding {
826 let compressed = compress_body_async(normalized_body.to_vec(), content_encoding).await?;
827 headers.insert(
828 "content-encoding".to_string(),
829 content_encoding.as_header_value().to_string(),
830 );
831 upsert_vary_accept_encoding(&mut headers);
832 compressed
833 } else {
834 normalized_body.to_vec()
835 };
836
837 headers.insert("content-length".to_string(), body.len().to_string());
838
839 Ok(CachedResponse {
840 body,
841 headers,
842 status,
843 content_encoding,
844 })
845}
846
847fn build_response_from_upstream(
848 status: u16,
849 response_headers: &reqwest::header::HeaderMap,
850 body: Vec<u8>,
851) -> Response<Body> {
852 let mut headers = convert_headers_to_map(response_headers);
853 headers.remove("transfer-encoding");
854 headers.insert("content-length".to_string(), body.len().to_string());
855 build_response(status, headers, body)
856}
857
858fn build_response(
859 status: u16,
860 response_headers: HashMap<String, String>,
861 body: Vec<u8>,
862) -> Response<Body> {
863 let mut response = Response::builder().status(status);
864
865 let headers = response.headers_mut().unwrap();
867 for (key, value) in response_headers {
868 if let Ok(header_name) = key.parse::<HeaderName>() {
869 if let Ok(header_value) = HeaderValue::from_str(&value) {
870 headers.insert(header_name, header_value);
871 } else {
872 tracing::warn!(
873 "Failed to parse header value for key '{}': {:?}",
874 key,
875 value
876 );
877 }
878 } else {
879 tracing::warn!("Failed to parse header name: {}", key);
880 }
881 }
882
883 response.body(Body::from(body)).unwrap()
884}
885
886fn cached_response_is_allowed(strategy: &crate::CacheStrategy, cached: &CachedResponse) -> bool {
887 strategy.allows_content_type(
888 cached
889 .headers
890 .get("content-type")
891 .map(|value| value.as_str()),
892 )
893}
894
/// Detects a "soft 404": a body that embeds a `phantom-404` meta marker
/// (`name="phantom-404"` plus `content="true"`, with either quote style).
/// Non-UTF-8 bodies never match.
fn body_contains_404_meta(body: &[u8]) -> bool {
    let Ok(text) = std::str::from_utf8(body) else {
        return false;
    };

    let has_name = ["name=\"phantom-404\"", "name='phantom-404'"]
        .iter()
        .any(|marker| text.contains(marker));
    let has_content = ["content=\"true\"", "content='true'"]
        .iter()
        .any(|marker| text.contains(marker));

    has_name && has_content
}
908
/// Ensures the `vary` header lists `Accept-Encoding`, appending it to an
/// existing comma-separated value (case-insensitive duplicate check) or
/// inserting a fresh header when none exists.
fn upsert_vary_accept_encoding(headers: &mut HashMap<String, String>) {
    if let Some(existing) = headers.get_mut("vary") {
        let already_listed = existing
            .split(',')
            .any(|token| token.trim().eq_ignore_ascii_case("accept-encoding"));
        if !already_listed {
            existing.push_str(", Accept-Encoding");
        }
    } else {
        headers.insert("vary".to_string(), "Accept-Encoding".to_string());
    }
}
924
925fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
926 let mut req_headers = reqwest::header::HeaderMap::new();
927 for (key, value) in headers {
928 if key == axum::http::header::HOST {
930 continue;
931 }
932 if let Ok(val) = value.to_str() {
933 if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
934 req_headers.insert(key.clone(), header_value);
935 }
936 }
937 }
938 req_headers
939}
940
/// Pre-generates a cache entry for `path`: fetches it from the upstream,
/// decodes the body to identity form, re-compresses it per
/// `compress_strategy`, and stores it under the key produced by
/// `cache_key_fn` for a header-less GET request.
///
/// # Errors
/// Returns an error when the fetch, body read, decode, or cache-entry
/// preparation fails; the cache is left untouched in that case.
pub(crate) async fn fetch_and_cache_snapshot(
    path: &str,
    client: &reqwest::Client,
    proxy_url: &str,
    cache: &CacheStore,
    compress_strategy: &CompressStrategy,
    cache_key_fn: &std::sync::Arc<dyn Fn(&crate::RequestInfo) -> String + Send + Sync>,
) -> anyhow::Result<()> {
    // Snapshot requests carry no headers: the key must match whatever a plain
    // header-less GET produces at request time.
    let empty_headers = axum::http::HeaderMap::new();
    let req_info = crate::RequestInfo {
        method: "GET",
        path,
        query: "",
        headers: &empty_headers,
    };
    let cache_key = cache_key_fn(&req_info);

    let url = format!("{}{}", proxy_url, path);
    let response = client
        .get(&url)
        .send()
        .await
        .map_err(|e| anyhow::anyhow!("Failed to fetch snapshot '{}': {}", path, e))?;

    let status = response.status().as_u16();
    let response_headers = response.headers().clone();
    let body_bytes = response
        .bytes()
        .await
        .map_err(|e| anyhow::anyhow!("Failed to read snapshot response for '{}': {}", path, e))?
        .to_vec();

    // Undo whatever encoding the upstream used before re-encoding for cache.
    let upstream_encoding = response_headers
        .get(axum::http::header::CONTENT_ENCODING)
        .and_then(|v| v.to_str().ok());
    let normalized =
        decode_upstream_body_async(body_bytes, upstream_encoding.map(|value| value.to_string()))
            .await
            .map_err(|e| anyhow::anyhow!("Failed to decode snapshot body for '{}': {}", path, e))?;

    let cached =
        build_cached_response(status, &response_headers, &normalized, compress_strategy).await?;
    cache.set(cache_key, cached).await;
    tracing::debug!("Snapshot pre-generated: {}", path);
    Ok(())
}
989
990fn convert_headers_to_map(
991 headers: &reqwest::header::HeaderMap,
992) -> std::collections::HashMap<String, String> {
993 let mut map = std::collections::HashMap::new();
994 for (key, value) in headers {
995 if let Ok(val) = value.to_str() {
996 map.insert(key.as_str().to_ascii_lowercase(), val.to_string());
997 } else {
998 tracing::debug!("Could not convert header '{}' to string", key);
1000 }
1001 }
1002 map
1003}
1004
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compression::ContentEncoding;
    use axum::body::to_bytes;

    // Minimal upstream header fixture: an HTML content type only.
    fn response_headers() -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("text/html; charset=utf-8"),
        );
        headers
    }

    // Caching with a gzip strategy must record the encoding and advertise
    // `Accept-Encoding` in Vary on the stored headers.
    #[tokio::test]
    async fn test_build_cached_response_uses_selected_encoding() {
        let cached = build_cached_response(
            200,
            &response_headers(),
            b"<html>compressed</html>",
            &CompressStrategy::Gzip,
        )
        .await
        .unwrap();

        assert_eq!(cached.content_encoding, Some(ContentEncoding::Gzip));
        assert_eq!(
            cached.headers.get("content-encoding"),
            Some(&"gzip".to_string())
        );
        assert_eq!(
            cached.headers.get("vary"),
            Some(&"Accept-Encoding".to_string())
        );
    }

    // A client that does not accept the stored encoding (br) but allows
    // identity must receive the decompressed body with no Content-Encoding.
    #[tokio::test]
    async fn test_build_response_from_cache_falls_back_to_identity() {
        let body = b"<html>identity</html>";
        let compressed = crate::compression::compress_body(body, ContentEncoding::Brotli).unwrap();
        let cached = CachedResponse {
            body: compressed,
            headers: HashMap::from([
                ("content-type".to_string(), "text/html".to_string()),
                ("content-encoding".to_string(), "br".to_string()),
                // Deliberately stale length: serving must rewrite it.
                ("content-length".to_string(), "123".to_string()),
                ("vary".to_string(), "Accept-Encoding".to_string()),
            ]),
            status: 200,
            content_encoding: Some(ContentEncoding::Brotli),
        };

        let mut request_headers = HeaderMap::new();
        request_headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip"),
        );

        let response = build_response_from_cache(cached, &request_headers)
            .await
            .unwrap();
        assert!(response
            .headers()
            .get(axum::http::header::CONTENT_ENCODING)
            .is_none());

        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(body.as_ref(), b"<html>identity</html>");
    }

    // A client that accepts the stored encoding gets the compressed bytes
    // verbatim, with Content-Encoding preserved.
    #[tokio::test]
    async fn test_build_response_from_cache_keeps_supported_encoding() {
        let body = b"<html>compressed</html>";
        let compressed = crate::compression::compress_body(body, ContentEncoding::Brotli).unwrap();
        let cached = CachedResponse {
            body: compressed.clone(),
            headers: HashMap::from([
                ("content-type".to_string(), "text/html".to_string()),
                ("content-encoding".to_string(), "br".to_string()),
                ("content-length".to_string(), compressed.len().to_string()),
                ("vary".to_string(), "Accept-Encoding".to_string()),
            ]),
            status: 200,
            content_encoding: Some(ContentEncoding::Brotli),
        };

        let mut request_headers = HeaderMap::new();
        request_headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("br, gzip;q=0.5"),
        );

        let response = build_response_from_cache(cached, &request_headers)
            .await
            .unwrap();
        assert_eq!(
            response.headers().get(axum::http::header::CONTENT_ENCODING),
            Some(&HeaderValue::from_static("br"))
        );

        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(body.as_ref(), compressed.as_slice());
    }
}