use crate::GMT_FORMAT;
use ahash::AHasher;
use axum::extract::Request;
use axum::http::{HeaderValue, StatusCode, header};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
use http_body_util::BodyExt;
use hyper::HeaderMap;
use ordinary_config::{HttpCache, HttpEtagAlgorithm, XXH3Variation};
use std::hash::Hasher;
use time::{OffsetDateTime, UtcDateTime};
pub async fn http_cache_middleware(
last_modified: UtcDateTime,
req_headers: HeaderMap,
request: Request,
next: Next,
) -> Response {
let response = next.run(request).await;
let (mut parts, body) = response.into_parts();
let body_bytes = if let Ok(collected) = body.collect().await {
collected.to_bytes()
} else {
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
};
let mut res_headers = HeaderMap::new();
let etag_string = get_etag_hash(body_bytes.as_ref(), None);
let etag_str = etag_string.as_str();
if let Some(if_none_match) = req_headers.get(header::IF_NONE_MATCH)
&& let Ok(if_none_match_str) = if_none_match.to_str()
&& if_none_match_str == etag_str
{
res_headers.insert(header::ETAG, if_none_match.to_owned());
return (StatusCode::NOT_MODIFIED, res_headers).into_response();
} else if let Ok(etag_header) = HeaderValue::from_str(etag_str) {
if let Some(if_modified_since) = req_headers.get(header::IF_MODIFIED_SINCE)
&& let Ok(if_modified_since_str) = if_modified_since.to_str()
&& let Ok(if_modified_since) = OffsetDateTime::parse(if_modified_since_str, &GMT_FORMAT)
&& if_modified_since >= last_modified
{
res_headers.insert(header::ETAG, etag_header);
return (StatusCode::NOT_MODIFIED, res_headers).into_response();
}
parts.headers.insert(header::ETAG, etag_header);
}
(parts, body_bytes).into_response()
}
/// Hashes `content` into an `ETag` value encoded as unpadded URL-safe base64.
///
/// The algorithm is taken from `http_cache.etag.alg` when one is configured.
/// Otherwise 64-bit XXH3 is used. It is a fixed, seedless algorithm, so the same
/// body yields the same tag across process restarts and across server instances.
///
/// `AHash` remains available as an explicit opt-in, but with `ahash`'s default
/// `runtime-rng` feature `AHasher::default()` draws its keys randomly once per
/// process. Its tags are therefore not stable across restarts.
///
/// Output lengths, which [`modify_etag_for_encoding`] relies on:
/// * 11 characters for 64-bit digests (`AHash`, `XXH3` 64-bit, `Rustc`);
/// * 22 characters for 128-bit digests (`XXH3` 128-bit, `Blake3` truncated
///   to its first 16 bytes).
pub fn get_etag_hash(content: &[u8], http_cache: Option<&HttpCache>) -> String {
    let configured = http_cache
        .and_then(|cache| cache.etag.as_ref())
        .and_then(|etag_config| etag_config.alg.as_ref());

    match configured {
        Some(HttpEtagAlgorithm::AHash) => {
            let mut hasher = AHasher::default();
            hasher.write(content);
            b64.encode(hasher.finish().to_be_bytes())
        }
        // Default: stable across processes, same 11-character length as before.
        Some(HttpEtagAlgorithm::XXH3(XXH3Variation::Bit64)) | None => {
            b64.encode(xxhash_rust::xxh3::xxh3_64(content).to_be_bytes())
        }
        Some(HttpEtagAlgorithm::XXH3(XXH3Variation::Bit128)) => {
            b64.encode(xxhash_rust::xxh3::xxh3_128(content).to_be_bytes())
        }
        Some(HttpEtagAlgorithm::Rustc) => {
            let mut hasher = rustc_hash::FxHasher::default();
            hasher.write(content);
            b64.encode(hasher.finish().to_be_bytes())
        }
        // Truncated to 16 bytes so the tag length matches the other 128-bit digests.
        Some(HttpEtagAlgorithm::Blake3) => b64.encode(&blake3::hash(content).as_bytes()[..16]),
    }
}
pub fn modify_etag_for_encoding(res: &Response) -> Option<HeaderValue> {
let headers = res.headers();
if let Some(curr_etag) = headers.get(header::ETAG)
&& let Ok(curr_etag_str) = curr_etag.to_str()
{
let etag_len = curr_etag_str.len();
if (etag_len == 22 || etag_len == 11)
&& let Some(compression) = headers.get(header::CONTENT_ENCODING)
&& let Ok(compression_str) = compression.to_str()
{
let mut etag_string = curr_etag_str.to_owned();
match compression_str {
"gzip" => etag_string.push('1'),
"zstd" => etag_string.push('2'),
"br" => etag_string.push('3'),
"deflate" => etag_string.push('4'),
_ => (),
}
match HeaderValue::from_str(etag_string.as_str()) {
Ok(v) => return Some(v),
Err(err) => tracing::error!(%err),
}
} else {
return Some(curr_etag.clone());
}
}
None
}
/// Matches the request's `If-None-Match` header against the server's `etag`.
///
/// The header may be `*` or a comma-separated list of entity tags. A listed
/// tag matches in two cases:
/// * it equals `etag` exactly;
/// * it is `etag` followed by exactly one content-coding digit (`1`–`4`, as
///   appended by [`modify_etag_for_encoding`]), i.e. the client holds an
///   encoded variant of the same content.
///
/// # Returns
/// The tag to echo in a `304` response:
/// * `etag` itself for a wildcard or exact match;
/// * the client's suffixed tag for an encoded-variant match.
///
/// `None` when `etag` is empty, the header is absent or not visible ASCII, or
/// no listed tag matches.
pub fn check_if_none_match<'a>(headers: &'a HeaderMap, etag: &'a str) -> Option<&'a str> {
    if etag.is_empty() {
        return None;
    }
    let field = headers.get(header::IF_NONE_MATCH)?.to_str().ok()?.trim();
    if field == "*" {
        return Some(etag);
    }
    field
        .split(',')
        .map(str::trim)
        .filter(|tag| !tag.is_empty())
        .find_map(|tag| {
            if tag == etag {
                return Some(etag);
            }
            // Encoded variant: the same hash plus a single coding digit.
            let suffix = tag.strip_prefix(etag)?;
            matches!(suffix, "1" | "2" | "3" | "4").then_some(tag)
        })
}