use crate::cache::{CacheStore, CachedResponse};
use crate::compression::{
client_accepts_encoding, compress_body_async, configured_encoding, decode_upstream_body_async,
decompress_body_async, identity_acceptable,
};
use crate::path_matcher::should_cache_path;
use crate::{CompressStrategy, CreateProxyConfig, ProxyMode, WebhookType};
use axum::{
body::Body,
extract::Extension,
http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
};
use hyper_util::rt::TokioIo;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Shared state handed to every proxy request handler via an axum
/// `Extension`. Cloning is cheap: `CacheStore` and `reqwest::Client` are
/// internally reference-counted (hence `#[derive(Clone)]`).
#[derive(Clone)]
pub struct ProxyState {
    // Response cache (main entries plus a dedicated 404 cache).
    cache: CacheStore,
    // User-supplied proxy configuration (mode, webhooks, cache rules, ...).
    config: CreateProxyConfig,
    // HTTP client used for requests to the upstream backend.
    upstream_client: reqwest::Client,
    // HTTP client used for webhook callbacks.
    webhook_client: reqwest::Client,
}
impl ProxyState {
    /// Assembles a `ProxyState` from its four components.
    ///
    /// Pure field assignment — no validation or I/O happens here.
    pub fn new(
        cache: CacheStore,
        config: CreateProxyConfig,
        upstream_client: reqwest::Client,
        webhook_client: reqwest::Client,
    ) -> Self {
        Self {
            webhook_client,
            upstream_client,
            config,
            cache,
        }
    }
}
/// Creates the HTTP client used to talk to the upstream backend.
///
/// Automatic decompression (gzip/deflate/brotli) is disabled on purpose so
/// upstream bodies pass through with their original `Content-Encoding`
/// intact; the proxy decides itself how to decode and re-compress them.
pub(crate) fn build_upstream_client() -> anyhow::Result<reqwest::Client> {
    let client = reqwest::Client::builder()
        .no_gzip()
        .no_deflate()
        .no_brotli()
        .tcp_keepalive(Duration::from_secs(30))
        .timeout(Duration::from_secs(30))
        .connect_timeout(Duration::from_secs(5))
        .pool_idle_timeout(Duration::from_secs(90))
        .build()?;
    Ok(client)
}
/// Creates the HTTP client used for webhook callbacks.
///
/// Redirects are disabled so a webhook's redirect response is surfaced to
/// the proxy (which forwards it to the client) instead of being followed.
/// Per-request timeouts are applied at call time in `call_webhook`.
pub(crate) fn build_webhook_client() -> anyhow::Result<reqwest::Client> {
    let client = reqwest::Client::builder()
        .redirect(reqwest::redirect::Policy::none())
        .connect_timeout(Duration::from_secs(3))
        .pool_idle_timeout(Duration::from_secs(30))
        .build()?;
    Ok(client)
}
/// Returns true when the request asks for a protocol upgrade (e.g. WebSocket).
///
/// The `Connection` header is a comma-separated token list (RFC 7230); each
/// token is compared case-insensitively against `upgrade`, which avoids the
/// previous `to_lowercase().contains("upgrade")` check that allocated a new
/// string and could false-positive on substrings such as `x-upgrade-foo`.
/// The presence of an `Upgrade` header alone is also treated as an upgrade
/// request, matching the original behavior.
fn is_upgrade_request(headers: &HeaderMap) -> bool {
    let connection_requests_upgrade = headers
        .get(axum::http::header::CONNECTION)
        .and_then(|v| v.to_str().ok())
        .map(|v| {
            v.split(',')
                .any(|token| token.trim().eq_ignore_ascii_case("upgrade"))
        })
        .unwrap_or(false);
    connection_requests_upgrade || headers.contains_key(axum::http::header::UPGRADE)
}
/// Serializes the request line and headers into the JSON document that is
/// POSTed to every configured webhook.
///
/// Headers whose values are not valid UTF-8 are silently skipped. If a
/// header name appears multiple times, the last value wins (matching the
/// previous collect-into-map behavior).
fn build_webhook_payload(
    method: &str,
    path: &str,
    query: &str,
    headers: &HeaderMap,
) -> serde_json::Value {
    let mut headers_map = serde_json::Map::new();
    for (name, value) in headers {
        if let Ok(text) = value.to_str() {
            headers_map.insert(
                name.as_str().to_string(),
                serde_json::Value::String(text.to_string()),
            );
        }
    }
    serde_json::json!({
        "method": method,
        "path": path,
        "query": query,
        "headers": headers_map,
    })
}
/// Outcome of a webhook POST: the pieces of the response the proxy acts on.
struct WebhookCallResult {
    // HTTP status returned by the webhook endpoint.
    status: StatusCode,
    // `Location` header value, if the webhook responded with a redirect.
    location: Option<String>,
    // Response body as text (empty when it could not be read).
    body: String,
}
/// POSTs `payload` to a webhook endpoint and extracts the fields the proxy
/// cares about (status, `Location`, body text).
///
/// Any transport error, timeout, or status-conversion failure collapses to
/// `Err(())` — callers only need to know "the webhook did not answer
/// successfully", not why.
async fn call_webhook(
    client: &reqwest::Client,
    url: &str,
    payload: &serde_json::Value,
    timeout_ms: u64,
) -> Result<WebhookCallResult, ()> {
    let request = client
        .post(url)
        .timeout(std::time::Duration::from_millis(timeout_ms))
        .json(payload);
    let response = request.send().await.map_err(|_| ())?;
    // Convert reqwest's status type into axum's http StatusCode.
    let status = StatusCode::from_u16(response.status().as_u16()).map_err(|_| ())?;
    let location = response
        .headers()
        .get(reqwest::header::LOCATION)
        .and_then(|value| value.to_str().ok())
        .map(str::to_string);
    // A missing/unreadable body is treated as empty, not as a failure.
    let body = response.text().await.unwrap_or_default();
    Ok(WebhookCallResult {
        status,
        location,
        body,
    })
}
/// Main proxy entry point: runs webhooks, consults the caches, and falls
/// back to fetching from the upstream backend.
///
/// Request flow, in order:
/// 1. Upgrade (e.g. WebSocket) requests are tunneled directly or rejected.
/// 2. Non-GET requests are rejected when `forward_get_only` is set.
/// 3. Configured webhooks run: `Notify` is fire-and-forget, `Blocking` can
///    deny or redirect the request, `CacheKey` can override the cache key.
/// 4. The 404 cache is consulted, then the main cache (if the path is
///    cacheable); `PreGenerate` mode without fallthrough 404s on a miss.
/// 5. On a miss the request is forwarded upstream; the response may be
///    written to the 404 cache or the main cache before being returned.
pub async fn proxy_handler(
    Extension(state): Extension<Arc<ProxyState>>,
    req: Request<Body>,
) -> Result<Response<Body>, StatusCode> {
    let request_started = Instant::now();
    // Upgrade requests bypass webhooks and caching entirely: either tunnel
    // them to the backend or reject with 501.
    let is_upgrade = is_upgrade_request(req.headers());
    if is_upgrade {
        let method_str = req.method().as_str();
        let path = req.uri().path();
        // WebSockets are only tunneled in Dynamic mode, or in PreGenerate
        // mode when fallthrough to the backend is allowed.
        let ws_allowed = state.config.enable_websocket
            && match &state.config.proxy_mode {
                ProxyMode::Dynamic => true,
                ProxyMode::PreGenerate { fallthrough, .. } => *fallthrough,
            };
        if ws_allowed {
            tracing::debug!(
                "Upgrade request detected for {} {}, establishing direct proxy tunnel",
                method_str,
                path
            );
            return handle_upgrade_request(state, req).await;
        } else {
            tracing::warn!(
                "Upgrade request detected for {} {} but WebSocket support is disabled or not available in current proxy mode",
                method_str,
                path
            );
            return Err(StatusCode::NOT_IMPLEMENTED);
        }
    }
    // Clone the request line pieces up front; the request body is consumed
    // later when forwarding upstream.
    let method = req.method().clone();
    let method_str = method.as_str();
    let uri = req.uri().clone();
    let path = uri.path();
    let query = uri.query().unwrap_or("");
    let headers = req.headers().clone();
    tracing::debug!(
        method = method_str,
        path,
        query,
        "proxy request entered handler"
    );
    if state.config.forward_get_only && method != axum::http::Method::GET {
        tracing::warn!(
            "Non-GET request {} {} rejected (forward_get_only is enabled)",
            method_str,
            path
        );
        return Err(StatusCode::METHOD_NOT_ALLOWED);
    }
    // A CacheKey webhook may replace the default cache key for this request.
    let mut cache_key_override: Option<String> = None;
    if !state.config.webhooks.is_empty() {
        // One payload is built and shared by all webhooks for this request.
        let payload = build_webhook_payload(method_str, path, query, &headers);
        let webhook_started = Instant::now();
        for webhook in &state.config.webhooks {
            match webhook.webhook_type {
                // Fire-and-forget: spawned so it never delays the request.
                WebhookType::Notify => {
                    let url = webhook.url.clone();
                    let payload_clone = payload.clone();
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    let webhook_client = state.webhook_client.clone();
                    tokio::spawn(async move {
                        if let Err(()) =
                            call_webhook(&webhook_client, &url, &payload_clone, timeout_ms).await
                        {
                            tracing::warn!("Notify webhook POST to '{}' failed", url);
                        }
                    });
                }
                // Blocking: 2xx allows, 3xx short-circuits with a redirect,
                // any other status denies with that status, and a failed
                // call denies with 503 (fail-closed).
                WebhookType::Blocking => {
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    match call_webhook(&state.webhook_client, &webhook.url, &payload, timeout_ms)
                        .await
                    {
                        Ok(result) if result.status.is_success() => {
                            tracing::debug!(
                                "Blocking webhook '{}' allowed {} {}",
                                webhook.url,
                                method_str,
                                path
                            );
                        }
                        Ok(result) if result.status.is_redirection() => {
                            tracing::debug!(
                                "Blocking webhook '{}' redirecting {} {} to {}",
                                webhook.url,
                                method_str,
                                path,
                                result.location.as_deref().unwrap_or("(no location)")
                            );
                            let mut builder = Response::builder().status(result.status);
                            if let Some(loc) = &result.location {
                                builder =
                                    builder.header(axum::http::header::LOCATION, loc.as_str());
                            }
                            return Ok(builder
                                .body(Body::empty())
                                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?);
                        }
                        Ok(result) => {
                            tracing::warn!(
                                "Blocking webhook '{}' denied {} {} with status {}",
                                webhook.url,
                                method_str,
                                path,
                                result.status
                            );
                            return Err(result.status);
                        }
                        Err(()) => {
                            tracing::warn!(
                                "Blocking webhook '{}' timed out or failed for {} {} — denying request",
                                webhook.url,
                                method_str,
                                path
                            );
                            return Err(StatusCode::SERVICE_UNAVAILABLE);
                        }
                    }
                }
                // CacheKey: a 2xx response with a non-empty body overrides
                // the cache key; anything else falls back to the default
                // key (fail-open, unlike Blocking).
                WebhookType::CacheKey => {
                    let timeout_ms = webhook.timeout_ms.unwrap_or(5000);
                    match call_webhook(&state.webhook_client, &webhook.url, &payload, timeout_ms)
                        .await
                    {
                        Ok(result) if result.status.is_success() => {
                            let key = result.body.trim().to_string();
                            if !key.is_empty() {
                                tracing::debug!(
                                    "Cache key webhook '{}' set key '{}' for {} {}",
                                    webhook.url,
                                    key,
                                    method_str,
                                    path
                                );
                                cache_key_override = Some(key);
                            } else {
                                tracing::warn!(
                                    "Cache key webhook '{}' returned empty body for {} {} — using default key",
                                    webhook.url,
                                    method_str,
                                    path
                                );
                            }
                        }
                        Ok(result) => {
                            tracing::warn!(
                                "Cache key webhook '{}' returned non-2xx {} for {} {} — using default key",
                                webhook.url,
                                result.status,
                                method_str,
                                path
                            );
                        }
                        Err(()) => {
                            tracing::warn!(
                                "Cache key webhook '{}' timed out or failed for {} {} — using default key",
                                webhook.url,
                                method_str,
                                path
                            );
                        }
                    }
                }
            }
        }
        tracing::debug!(
            method = method_str,
            path,
            elapsed_ms = webhook_started.elapsed().as_millis(),
            "proxy request completed webhook phase"
        );
    }
    // Path/method-based cacheability (include/exclude lists).
    let should_cache = should_cache_path(
        method_str,
        path,
        &state.config.include_paths,
        &state.config.exclude_paths,
    );
    let req_info = crate::RequestInfo {
        method: method_str,
        path,
        query,
        headers: &headers,
    };
    // Webhook-provided key wins; otherwise the configured key function runs.
    let cache_key = cache_key_override.unwrap_or_else(|| (state.config.cache_key_fn)(&req_info));
    let cache_reads_enabled = !matches!(state.config.cache_strategy, crate::CacheStrategy::None);
    // 404 cache is checked first and independently of `should_cache`.
    if cache_reads_enabled && state.config.cache_404_capacity > 0 {
        if let Some(cached) = state.cache.get_404(&cache_key).await {
            // Re-validate the entry against the current cache strategy.
            if cached_response_is_allowed(&state.config.cache_strategy, &cached) {
                tracing::debug!("404 cache hit for: {} {}", method_str, cache_key);
                let response = build_response_from_cache(cached, &headers).await?;
                tracing::debug!(
                    method = method_str,
                    path,
                    elapsed_ms = request_started.elapsed().as_millis(),
                    "proxy request served from 404 cache"
                );
                return Ok(response);
            }
        }
    }
    if should_cache && cache_reads_enabled {
        if let Some(cached) = state.cache.get(&cache_key).await {
            if cached_response_is_allowed(&state.config.cache_strategy, &cached) {
                tracing::debug!("Cache hit for: {} {}", method_str, cache_key);
                let response = build_response_from_cache(cached, &headers).await?;
                tracing::debug!(
                    method = method_str,
                    path,
                    elapsed_ms = request_started.elapsed().as_millis(),
                    "proxy request served from main cache"
                );
                return Ok(response);
            }
        }
        // PreGenerate without fallthrough serves only pre-warmed entries.
        if let ProxyMode::PreGenerate { fallthrough, .. } = &state.config.proxy_mode {
            if !fallthrough {
                tracing::debug!(
                    "PreGenerate cache miss for: {} {} — returning 404 (fallthrough disabled)",
                    method_str,
                    cache_key
                );
                return Err(StatusCode::NOT_FOUND);
            }
        }
        tracing::debug!(
            "Cache miss for: {} {}, fetching from backend",
            method_str,
            cache_key
        );
    } else if !cache_reads_enabled {
        tracing::debug!(
            "{} {} not cacheable (cache strategy: none), proxying directly",
            method_str,
            path
        );
    } else {
        tracing::debug!(
            "{} {} not cacheable (filtered), proxying directly",
            method_str,
            path
        );
    }
    // Buffer the full request body so it can be forwarded upstream.
    let body_bytes = match axum::body::to_bytes(req.into_body(), usize::MAX).await {
        Ok(bytes) => bytes,
        Err(e) => {
            tracing::error!("Failed to read request body: {}", e);
            return Err(StatusCode::BAD_REQUEST);
        }
    };
    let path_and_query = uri
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or_else(|| uri.path());
    let target_url = format!("{}{}", state.config.proxy_url, path_and_query);
    let upstream_started = Instant::now();
    let response = match state
        .upstream_client
        .request(method.clone(), &target_url)
        .headers(convert_headers(&headers))
        .body(body_bytes.to_vec())
        .send()
        .await
    {
        Ok(resp) => resp,
        Err(e) => {
            tracing::error!("Failed to fetch from backend: {}", e);
            return Err(StatusCode::BAD_GATEWAY);
        }
    };
    tracing::debug!(
        method = method_str,
        path,
        elapsed_ms = upstream_started.elapsed().as_millis(),
        "proxy request received upstream response headers"
    );
    let status = response.status().as_u16();
    let response_headers = response.headers().clone();
    let body_bytes = match response.bytes().await {
        Ok(bytes) => bytes.to_vec(),
        Err(e) => {
            tracing::error!("Failed to read response body: {}", e);
            return Err(StatusCode::BAD_GATEWAY);
        }
    };
    let response_content_type = response_headers
        .get(axum::http::header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok());
    let response_is_cacheable = state
        .config
        .cache_strategy
        .allows_content_type(response_content_type);
    let upstream_content_encoding = response_headers
        .get(axum::http::header::CONTENT_ENCODING)
        .and_then(|value| value.to_str().ok());
    let should_try_cache = cache_reads_enabled
        && response_is_cacheable
        && (should_cache || state.config.cache_404_capacity > 0);
    // Decode the upstream body (if its encoding is supported) whenever we
    // may cache it, or need the plaintext for 404 meta-tag detection.
    // `None` means the body stays pass-through and is never cached.
    let normalized_body = if should_try_cache || state.config.use_404_meta {
        match decode_upstream_body_async(
            body_bytes.clone(),
            upstream_content_encoding.map(|value| value.to_string()),
        )
        .await
        {
            Ok(body) => Some(body),
            Err(error) => {
                tracing::warn!(
                    "Skipping cache compression for {} {} due to unsupported upstream encoding: {}",
                    method_str,
                    path,
                    error
                );
                None
            }
        }
    } else {
        None
    };
    // A "404" is either a real 404 status or, with `use_404_meta`, a 200
    // whose body carries the phantom-404 meta tag (soft-404 pages).
    let mut is_404 = status == 404;
    if !is_404 && state.config.use_404_meta {
        if let Some(body) = normalized_body.as_deref() {
            is_404 = body_contains_404_meta(body);
        }
    }
    let should_store_404 = is_404
        && state.config.cache_404_capacity > 0
        && response_is_cacheable
        && cache_reads_enabled
        && normalized_body.is_some();
    let should_store_response = !is_404
        && should_cache
        && response_is_cacheable
        && cache_reads_enabled
        && normalized_body.is_some();
    if should_store_404 || should_store_response {
        let cached_response = match build_cached_response(
            status,
            &response_headers,
            normalized_body.as_deref().unwrap(),
            &state.config.compress_strategy,
        )
        .await
        {
            Ok(cached_response) => cached_response,
            Err(error) => {
                // Cache preparation failed: still serve the upstream body.
                tracing::warn!(
                    "Failed to prepare cached response for {} {}: {}",
                    method_str,
                    path,
                    error
                );
                return Ok(build_response_from_upstream(
                    status,
                    &response_headers,
                    body_bytes,
                ));
            }
        };
        if should_store_404 {
            state
                .cache
                .set_404(cache_key.clone(), cached_response.clone())
                .await;
            tracing::debug!("Cached 404 response for: {} {}", method_str, cache_key);
        } else {
            state
                .cache
                .set(cache_key.clone(), cached_response.clone())
                .await;
            tracing::debug!("Cached response for: {} {}", method_str, cache_key);
        }
        // Serve from the cached form so the client gets exactly what future
        // cache hits will return (same encoding and headers).
        let response = build_response_from_cache(cached_response, &headers).await?;
        tracing::debug!(
            method = method_str,
            path,
            elapsed_ms = request_started.elapsed().as_millis(),
            "proxy request completed after upstream fetch and cache write"
        );
        return Ok(response);
    }
    tracing::debug!(
        method = method_str,
        path,
        elapsed_ms = request_started.elapsed().as_millis(),
        "proxy request completed without caching"
    );
    Ok(build_response_from_upstream(
        status,
        &response_headers,
        body_bytes,
    ))
}
/// Tunnels a protocol-upgrade request (e.g. WebSocket) to the backend.
///
/// Sends the original request to the backend over a raw HTTP/1.1
/// connection; if the backend answers `101 Switching Protocols`, both the
/// client and backend connections are upgraded and bytes are copied
/// bidirectionally in a spawned task until either side closes. A non-101
/// backend answer is forwarded to the client as-is.
async fn handle_upgrade_request(
    state: Arc<ProxyState>,
    mut req: Request<Body>,
) -> Result<Response<Body>, StatusCode> {
    let req_path_and_query = req
        .uri()
        .path_and_query()
        .map(|pq| pq.as_str())
        .unwrap_or_else(|| req.uri().path());
    let target_url = format!("{}{}", state.config.proxy_url, req_path_and_query);
    let backend_uri = target_url.parse::<hyper::Uri>().map_err(|e| {
        tracing::error!("Failed to parse backend URL: {}", e);
        StatusCode::BAD_GATEWAY
    })?;
    let host = backend_uri.host().ok_or_else(|| {
        tracing::error!("No host in backend URL");
        StatusCode::BAD_GATEWAY
    })?;
    // Default port from the scheme when the URL does not specify one.
    let port = backend_uri.port_u16().unwrap_or_else(|| {
        if backend_uri.scheme_str() == Some("https") {
            443
        } else {
            80
        }
    });
    // Grab the client-side upgrade future BEFORE consuming the request by
    // sending it to the backend.
    let client_upgrade = hyper::upgrade::on(&mut req);
    let backend_stream = tokio::net::TcpStream::connect((host, port))
        .await
        .map_err(|e| {
            tracing::error!("Failed to connect to backend {}:{}: {}", host, port, e);
            StatusCode::BAD_GATEWAY
        })?;
    let backend_io = TokioIo::new(backend_stream);
    // Raw HTTP/1.1 handshake: gives us a request sender plus a connection
    // driver that must be polled for the exchange (and upgrade) to progress.
    let (mut sender, conn) = hyper::client::conn::http1::handshake(backend_io)
        .await
        .map_err(|e| {
            tracing::error!("Failed to handshake with backend: {}", e);
            StatusCode::BAD_GATEWAY
        })?;
    // Drive the backend connection; `with_upgrades()` lets it hand the
    // underlying I/O back once the 101 response completes.
    let conn_task = tokio::spawn(async move {
        match conn.with_upgrades().await {
            Ok(parts) => {
                tracing::debug!("Backend connection upgraded successfully");
                Ok(parts)
            }
            Err(e) => {
                tracing::error!("Backend connection failed: {}", e);
                Err(e)
            }
        }
    });
    let backend_response = sender.send_request(req).await.map_err(|e| {
        tracing::error!("Failed to send request to backend: {}", e);
        StatusCode::BAD_GATEWAY
    })?;
    let status = backend_response.status();
    if status != StatusCode::SWITCHING_PROTOCOLS {
        // Backend refused the upgrade: relay its response unchanged.
        tracing::warn!("Backend did not accept upgrade request, status: {}", status);
        let (parts, body) = backend_response.into_parts();
        let body = Body::new(body);
        return Ok(Response::from_parts(parts, body));
    }
    let backend_headers = backend_response.headers().clone();
    let backend_upgrade = hyper::upgrade::on(backend_response);
    // Complete both upgrades and pump bytes in the background; the 101
    // response to the client is returned immediately below.
    tokio::spawn(async move {
        tracing::debug!("Starting upgrade tunnel establishment");
        let (client_result, backend_result) = tokio::join!(client_upgrade, backend_upgrade);
        // The connection driver handle is no longer needed; dropping the
        // JoinHandle detaches (does not abort) the task.
        drop(conn_task);
        match (client_result, backend_result) {
            (Ok(client_upgraded), Ok(backend_upgraded)) => {
                tracing::debug!("Both upgrades successful, establishing bidirectional tunnel");
                let mut client_stream = TokioIo::new(client_upgraded);
                let mut backend_stream = TokioIo::new(backend_upgraded);
                match tokio::io::copy_bidirectional(&mut client_stream, &mut backend_stream).await {
                    Ok((client_to_backend, backend_to_client)) => {
                        tracing::debug!(
                            "Tunnel closed gracefully. Transferred {} bytes client->backend, {} bytes backend->client",
                            client_to_backend,
                            backend_to_client
                        );
                    }
                    Err(e) => {
                        tracing::error!("Tunnel error: {}", e);
                    }
                }
            }
            (Err(e), _) => {
                tracing::error!("Client upgrade failed: {}", e);
            }
            (_, Err(e)) => {
                tracing::error!("Backend upgrade failed: {}", e);
            }
        }
    });
    // Echo the handshake-relevant headers from the backend's 101 response.
    let mut response = Response::builder()
        .status(StatusCode::SWITCHING_PROTOCOLS)
        .body(Body::empty())
        .unwrap();
    if let Some(upgrade_header) = backend_headers.get(axum::http::header::UPGRADE) {
        response
            .headers_mut()
            .insert(axum::http::header::UPGRADE, upgrade_header.clone());
    }
    if let Some(connection_header) = backend_headers.get(axum::http::header::CONNECTION) {
        response
            .headers_mut()
            .insert(axum::http::header::CONNECTION, connection_header.clone());
    }
    if let Some(sec_websocket_accept) = backend_headers.get("sec-websocket-accept") {
        response.headers_mut().insert(
            HeaderName::from_static("sec-websocket-accept"),
            sec_websocket_accept.clone(),
        );
    }
    tracing::debug!("Upgrade response sent to client, tunnel task spawned");
    Ok(response)
}
/// Converts a stored `CachedResponse` into an HTTP response tailored to
/// `request_headers`.
///
/// If the entry is stored compressed and the client accepts that encoding,
/// the compressed bytes are served as-is. Otherwise the body is decoded to
/// identity on the fly; if the client accepts neither the stored encoding
/// nor identity, `406 Not Acceptable` is returned. `Transfer-Encoding` is
/// dropped and `Content-Length` recomputed for the body actually sent.
async fn build_response_from_cache(
    cached: CachedResponse,
    request_headers: &HeaderMap,
) -> Result<Response<Body>, StatusCode> {
    let mut response_headers = cached.headers;
    let body = if let Some(content_encoding) = cached.content_encoding {
        if client_accepts_encoding(request_headers, content_encoding) {
            // Client supports the stored encoding: serve cached bytes as-is.
            upsert_vary_accept_encoding(&mut response_headers);
            cached.body
        } else {
            if !identity_acceptable(request_headers) {
                tracing::warn!(
                    "Client does not accept cached encoding '{}' or identity fallback",
                    content_encoding.as_header_value()
                );
                return Err(StatusCode::NOT_ACCEPTABLE);
            }
            response_headers.remove("content-encoding");
            upsert_vary_accept_encoding(&mut response_headers);
            // Move the body instead of cloning it: nothing reads
            // `cached.body` after this point, so the previous
            // `cached.body.clone()` was a redundant full-body allocation.
            match decompress_body_async(cached.body, content_encoding).await {
                Ok(body) => body,
                Err(error) => {
                    tracing::error!("Failed to decompress cached response: {}", error);
                    return Err(StatusCode::INTERNAL_SERVER_ERROR);
                }
            }
        }
    } else {
        // Entry stored uncompressed: serve it directly.
        cached.body
    };
    response_headers.remove("transfer-encoding");
    response_headers.insert("content-length".to_string(), body.len().to_string());
    Ok(build_response(cached.status, response_headers, body))
}
/// Builds the `CachedResponse` that gets stored in the cache from an
/// already-decoded (identity) upstream body.
///
/// Upstream framing headers are discarded, the body is re-compressed
/// according to `compress_strategy` (if one is configured), and
/// `Content-Length` is set to the final stored size.
async fn build_cached_response(
    status: u16,
    response_headers: &reqwest::header::HeaderMap,
    normalized_body: &[u8],
    compress_strategy: &CompressStrategy,
) -> anyhow::Result<CachedResponse> {
    let mut headers = convert_headers_to_map(response_headers);
    // Drop upstream framing/encoding headers; they are recomputed below.
    for framing in ["content-encoding", "content-length", "transfer-encoding"] {
        headers.remove(framing);
    }
    let content_encoding = configured_encoding(compress_strategy);
    let body = match content_encoding {
        Some(encoding) => {
            // Compress for storage and advertise the stored encoding.
            let compressed = compress_body_async(normalized_body.to_vec(), encoding).await?;
            headers.insert(
                "content-encoding".to_string(),
                encoding.as_header_value().to_string(),
            );
            upsert_vary_accept_encoding(&mut headers);
            compressed
        }
        None => normalized_body.to_vec(),
    };
    headers.insert("content-length".to_string(), body.len().to_string());
    Ok(CachedResponse {
        body,
        headers,
        status,
        content_encoding,
    })
}
/// Builds the pass-through response used when an upstream body is not
/// cached: upstream headers are kept, minus `Transfer-Encoding` (the body
/// has been fully buffered), and `Content-Length` is set to the buffered
/// body's size.
fn build_response_from_upstream(
    status: u16,
    response_headers: &reqwest::header::HeaderMap,
    body: Vec<u8>,
) -> Response<Body> {
    let mut headers = convert_headers_to_map(response_headers);
    headers.remove("transfer-encoding");
    let length = body.len().to_string();
    headers.insert("content-length".to_string(), length);
    build_response(status, headers, body)
}
/// Assembles an axum `Response` from raw parts, skipping (and logging) any
/// header whose name or value does not parse into the typed header forms.
fn build_response(
    status: u16,
    response_headers: HashMap<String, String>,
    body: Vec<u8>,
) -> Response<Body> {
    let mut builder = Response::builder().status(status);
    // A freshly created builder always exposes its header map.
    let headers = builder.headers_mut().unwrap();
    for (key, value) in response_headers {
        let Ok(header_name) = key.parse::<HeaderName>() else {
            tracing::warn!("Failed to parse header name: {}", key);
            continue;
        };
        match HeaderValue::from_str(&value) {
            Ok(header_value) => {
                headers.insert(header_name, header_value);
            }
            Err(_) => {
                tracing::warn!(
                    "Failed to parse header value for key '{}': {:?}",
                    key,
                    value
                );
            }
        }
    }
    builder.body(Body::from(body)).unwrap()
}
/// Re-validates a cache entry against the currently configured cache
/// strategy, so entries written under a different strategy are not served.
fn cached_response_is_allowed(strategy: &crate::CacheStrategy, cached: &CachedResponse) -> bool {
    let content_type = cached.headers.get("content-type").map(String::as_str);
    strategy.allows_content_type(content_type)
}
/// Detects the soft-404 marker `<meta name="phantom-404" content="true">`
/// in an HTML body, accepting either single or double quotes.
///
/// Non-UTF-8 bodies never match. The check is a plain substring search:
/// the `name` and `content` attributes may match anywhere in the document
/// (they are not required to sit in the same tag), exactly as before.
fn body_contains_404_meta(body: &[u8]) -> bool {
    let Ok(html) = std::str::from_utf8(body) else {
        return false;
    };
    let has_name = ["name=\"phantom-404\"", "name='phantom-404'"]
        .iter()
        .any(|needle| html.contains(needle));
    let has_content = ["content=\"true\"", "content='true'"]
        .iter()
        .any(|needle| html.contains(needle));
    has_name && has_content
}
/// Ensures the response's `vary` header lists `Accept-Encoding`, so
/// downstream caches key on the client's compression support.
///
/// - No `vary` entry: one is created with the value `Accept-Encoding`.
/// - Existing blank entry: replaced with `Accept-Encoding`. (Previously a
///   blank value produced the malformed header `", Accept-Encoding"`.)
/// - Existing non-empty entry: `Accept-Encoding` is appended unless one of
///   the comma-separated members already matches it case-insensitively.
fn upsert_vary_accept_encoding(headers: &mut HashMap<String, String>) {
    match headers.get_mut("vary") {
        Some(value) if value.trim().is_empty() => {
            // Robustness fix: appending to a blank value would yield the
            // malformed header value ", Accept-Encoding".
            *value = "Accept-Encoding".to_string();
        }
        Some(value) => {
            let has_accept_encoding = value
                .split(',')
                .any(|part| part.trim().eq_ignore_ascii_case("accept-encoding"));
            if !has_accept_encoding {
                value.push_str(", Accept-Encoding");
            }
        }
        None => {
            headers.insert("vary".to_string(), "Accept-Encoding".to_string());
        }
    }
}
/// Converts the client's request headers into the header map forwarded to
/// the upstream backend.
///
/// `Host` is dropped (reqwest derives it from the target URL), as are the
/// hop-by-hop headers listed in RFC 7230 §6.1 — those describe the
/// client↔proxy connection, not the proxy↔backend one. Forwarding them was
/// a bug: in particular, the request body is fully buffered and re-sent,
/// so relaying a client `Transfer-Encoding: chunked` header (or its
/// `Connection` options) could corrupt the upstream exchange. Headers with
/// non-UTF-8 values are silently skipped, as before.
fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
    // Hop-by-hop headers per RFC 7230 §6.1; `HeaderName::as_str()` is
    // always lowercase, so plain string comparison is sufficient.
    const HOP_BY_HOP: [&str; 8] = [
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    ];
    let mut req_headers = reqwest::header::HeaderMap::new();
    for (key, value) in headers {
        if key == axum::http::header::HOST {
            continue;
        }
        if HOP_BY_HOP.contains(&key.as_str()) {
            continue;
        }
        if let Ok(val) = value.to_str() {
            if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
                req_headers.insert(key.clone(), header_value);
            }
        }
    }
    req_headers
}
/// Fetches `path` from the backend once and stores the result in the main
/// cache — used to pre-warm snapshots (e.g. in `PreGenerate` mode) before
/// real traffic arrives.
///
/// The cache key is derived from a synthetic GET request with no query
/// string and no headers, so it must match how live requests for the same
/// path are keyed.
pub(crate) async fn fetch_and_cache_snapshot(
    path: &str,
    client: &reqwest::Client,
    proxy_url: &str,
    cache: &CacheStore,
    compress_strategy: &CompressStrategy,
    cache_key_fn: &std::sync::Arc<dyn Fn(&crate::RequestInfo) -> String + Send + Sync>,
) -> anyhow::Result<()> {
    // Key the snapshot as if it were a bare, header-less GET.
    let empty_headers = axum::http::HeaderMap::new();
    let cache_key = cache_key_fn(&crate::RequestInfo {
        method: "GET",
        path,
        query: "",
        headers: &empty_headers,
    });
    let url = format!("{}{}", proxy_url, path);
    let response = client
        .get(&url)
        .send()
        .await
        .map_err(|e| anyhow::anyhow!("Failed to fetch snapshot '{}': {}", path, e))?;
    let status = response.status().as_u16();
    let response_headers = response.headers().clone();
    let body_bytes = response
        .bytes()
        .await
        .map_err(|e| anyhow::anyhow!("Failed to read snapshot response for '{}': {}", path, e))?
        .to_vec();
    // Decode whatever encoding the upstream applied so the body can be
    // re-compressed for storage.
    let upstream_encoding = response_headers
        .get(axum::http::header::CONTENT_ENCODING)
        .and_then(|v| v.to_str().ok())
        .map(str::to_string);
    let normalized = decode_upstream_body_async(body_bytes, upstream_encoding)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to decode snapshot body for '{}': {}", path, e))?;
    let cached =
        build_cached_response(status, &response_headers, &normalized, compress_strategy).await?;
    cache.set(cache_key, cached).await;
    tracing::debug!("Snapshot pre-generated: {}", path);
    Ok(())
}
/// Flattens a reqwest header map into a plain `HashMap` with lowercase
/// string keys (the format `CachedResponse` stores).
///
/// Non-UTF-8 header values are skipped with a debug log; for repeated
/// header names, the last value wins.
fn convert_headers_to_map(
    headers: &reqwest::header::HeaderMap,
) -> std::collections::HashMap<String, String> {
    let mut map = std::collections::HashMap::new();
    for (name, value) in headers {
        match value.to_str() {
            Ok(text) => {
                map.insert(name.as_str().to_ascii_lowercase(), text.to_string());
            }
            Err(_) => {
                tracing::debug!("Could not convert header '{}' to string", name);
            }
        }
    }
    map
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compression::ContentEncoding;
    use axum::body::to_bytes;
    // Minimal upstream-style header map with just a Content-Type, shared
    // by the cache-construction tests.
    fn response_headers() -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("text/html; charset=utf-8"),
        );
        headers
    }
    // Storing with a Gzip strategy must record the encoding on the entry,
    // set the content-encoding header, and add Vary: Accept-Encoding.
    #[tokio::test]
    async fn test_build_cached_response_uses_selected_encoding() {
        let cached = build_cached_response(
            200,
            &response_headers(),
            b"<html>compressed</html>",
            &CompressStrategy::Gzip,
        )
        .await
        .unwrap();
        assert_eq!(cached.content_encoding, Some(ContentEncoding::Gzip));
        assert_eq!(
            cached.headers.get("content-encoding"),
            Some(&"gzip".to_string())
        );
        assert_eq!(
            cached.headers.get("vary"),
            Some(&"Accept-Encoding".to_string())
        );
    }
    // A client that does not accept the stored encoding (br) but accepts
    // identity must get a decompressed body with no content-encoding header.
    #[tokio::test]
    async fn test_build_response_from_cache_falls_back_to_identity() {
        let body = b"<html>identity</html>";
        let compressed = crate::compression::compress_body(body, ContentEncoding::Brotli).unwrap();
        let cached = CachedResponse {
            body: compressed,
            headers: HashMap::from([
                ("content-type".to_string(), "text/html".to_string()),
                ("content-encoding".to_string(), "br".to_string()),
                // Deliberately stale length; the builder must recompute it.
                ("content-length".to_string(), "123".to_string()),
                ("vary".to_string(), "Accept-Encoding".to_string()),
            ]),
            status: 200,
            content_encoding: Some(ContentEncoding::Brotli),
        };
        let mut request_headers = HeaderMap::new();
        request_headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip"),
        );
        let response = build_response_from_cache(cached, &request_headers)
            .await
            .unwrap();
        assert!(response
            .headers()
            .get(axum::http::header::CONTENT_ENCODING)
            .is_none());
        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(body.as_ref(), b"<html>identity</html>");
    }
    // A client that accepts the stored encoding must receive the cached
    // compressed bytes verbatim, with the content-encoding header intact.
    #[tokio::test]
    async fn test_build_response_from_cache_keeps_supported_encoding() {
        let body = b"<html>compressed</html>";
        let compressed = crate::compression::compress_body(body, ContentEncoding::Brotli).unwrap();
        let cached = CachedResponse {
            body: compressed.clone(),
            headers: HashMap::from([
                ("content-type".to_string(), "text/html".to_string()),
                ("content-encoding".to_string(), "br".to_string()),
                ("content-length".to_string(), compressed.len().to_string()),
                ("vary".to_string(), "Accept-Encoding".to_string()),
            ]),
            status: 200,
            content_encoding: Some(ContentEncoding::Brotli),
        };
        let mut request_headers = HeaderMap::new();
        request_headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("br, gzip;q=0.5"),
        );
        let response = build_response_from_cache(cached, &request_headers)
            .await
            .unwrap();
        assert_eq!(
            response.headers().get(axum::http::header::CONTENT_ENCODING),
            Some(&HeaderValue::from_static("br"))
        );
        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(body.as_ref(), compressed.as_slice());
    }
}