Skip to main content

ordinary_utils/
middleware.rs

1// Copyright (C) 2026 Ordinary Labs, LLC.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4
5use crate::GMT_FORMAT;
6use ahash::AHasher;
7use axum::extract::Request;
8use axum::http::{HeaderValue, StatusCode, header};
9use axum::middleware::Next;
10use axum::response::{IntoResponse, Response};
11use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
12use http_body_util::BodyExt;
13use hyper::HeaderMap;
14use ordinary_config::{HttpCache, HttpEtagAlgorithm, XXH3Variation};
15use std::hash::Hasher;
16use time::UtcDateTime;
17
18#[allow(clippy::similar_names)]
19pub async fn http_cache_middleware(
20    last_modified: UtcDateTime,
21    req_headers: HeaderMap,
22    request: Request,
23    next: Next,
24) -> Response {
25    let response = next.run(request).await;
26    let (mut parts, body) = response.into_parts();
27
28    let body_bytes = if let Ok(collected) = body.collect().await {
29        collected.to_bytes()
30    } else {
31        return StatusCode::INTERNAL_SERVER_ERROR.into_response();
32    };
33
34    let mut res_headers = HeaderMap::new();
35
36    let etag_string = get_etag_hash(body_bytes.as_ref(), None);
37    let etag_str = etag_string.as_str();
38
39    if let Some(if_none_match) = req_headers.get(header::IF_NONE_MATCH)
40        && let Ok(if_none_match_str) = if_none_match.to_str()
41        && if_none_match_str == etag_str
42    {
43        res_headers.insert(header::ETAG, if_none_match.to_owned());
44
45        return (StatusCode::NOT_MODIFIED, res_headers).into_response();
46    } else if let Ok(etag_header) = HeaderValue::from_str(etag_str) {
47        if let Some(if_modified_since) = req_headers.get(header::IF_MODIFIED_SINCE)
48            && let Ok(if_modified_since_str) = if_modified_since.to_str()
49            && let Ok(if_modified_since) = UtcDateTime::parse(if_modified_since_str, &GMT_FORMAT)
50            && if_modified_since >= last_modified
51        {
52            res_headers.insert(header::ETAG, etag_header);
53            return (StatusCode::NOT_MODIFIED, res_headers).into_response();
54        }
55
56        parts.headers.insert(header::ETAG, etag_header);
57    }
58
59    (parts, body_bytes).into_response()
60}
61
62#[must_use]
63pub fn get_etag_hash(content: &[u8], http_cache: Option<&HttpCache>) -> String {
64    if let Some(http_cache) = http_cache
65        && let Some(etag_config) = &http_cache.etag
66        && let Some(etag_alg) = &etag_config.alg
67    {
68        return match etag_alg {
69            HttpEtagAlgorithm::AHash => {
70                let mut hasher = AHasher::default();
71                hasher.write(content);
72                b64.encode(hasher.finish().to_be_bytes())
73            }
74            HttpEtagAlgorithm::XXH3(variation) => match variation {
75                XXH3Variation::Bit64 => {
76                    b64.encode(xxhash_rust::xxh3::xxh3_64(content).to_be_bytes())
77                }
78                XXH3Variation::Bit128 => {
79                    b64.encode(xxhash_rust::xxh3::xxh3_128(content).to_be_bytes())
80                }
81            },
82            HttpEtagAlgorithm::Rustc => {
83                let mut hasher = rustc_hash::FxHasher::default();
84                hasher.write(content);
85
86                b64.encode(hasher.finish().to_be_bytes())
87            }
88            HttpEtagAlgorithm::Blake3 => b64.encode(&blake3::hash(content).as_bytes()[0..16]),
89        };
90    }
91
92    let mut hasher = AHasher::default();
93    hasher.write(content);
94    b64.encode(hasher.finish().to_be_bytes())
95}
96
97pub fn modify_etag_for_encoding(res: &Response) -> Option<HeaderValue> {
98    let headers = res.headers();
99
100    if let Some(curr_etag) = headers.get(header::ETAG)
101        && let Ok(curr_etag_str) = curr_etag.to_str()
102    {
103        let etag_len = curr_etag_str.len();
104
105        if (etag_len == 22 || etag_len == 11)
106            && let Some(compression) = headers.get(header::CONTENT_ENCODING)
107            && let Ok(compression_str) = compression.to_str()
108        {
109            let mut etag_string = curr_etag_str.to_owned();
110
111            match compression_str {
112                "gzip" => etag_string.push('1'),
113                "zstd" => etag_string.push('2'),
114                "br" => etag_string.push('3'),
115                "deflate" => etag_string.push('4'),
116                _ => (),
117            }
118
119            match HeaderValue::from_str(etag_string.as_str()) {
120                Ok(v) => return Some(v),
121                Err(err) => tracing::error!(%err),
122            }
123        } else {
124            return Some(curr_etag.clone());
125        }
126    }
127
128    None
129}
130
131pub fn check_if_none_match<'a>(headers: &'a HeaderMap, etag: &'a str) -> Option<&'a str> {
132    if let Some(if_none_match) = headers.get(header::IF_NONE_MATCH)
133        && let Ok(if_none_match_str) = if_none_match.to_str()
134    {
135        if if_none_match_str.len() < 11 {
136            return None;
137        }
138
139        if (etag.len() == 23
140            || etag.len() == 12
141            || if_none_match_str.len() == 22
142            || if_none_match_str.len() == 11)
143            && if_none_match_str == etag
144        {
145            return Some(etag);
146        }
147
148        if &if_none_match_str[..if_none_match_str.len() - 1] == etag {
149            Some(if_none_match_str)
150        } else {
151            None
152        }
153    } else {
154        None
155    }
156}