ordinary-utils 0.6.0-pre.14

Utils for Ordinary
Documentation
// Copyright (C) 2026 Ordinary Labs, LLC.
//
// SPDX-License-Identifier: AGPL-3.0-only

use crate::GMT_FORMAT;
use ahash::AHasher;
use axum::extract::Request;
use axum::http::{HeaderValue, StatusCode, header};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
use http_body_util::BodyExt;
use hyper::HeaderMap;
use ordinary_config::{HttpCache, HttpEtagAlgorithm, XXH3Variation};
use std::hash::Hasher;
use time::UtcDateTime;

#[allow(clippy::similar_names)]
pub async fn http_cache_middleware(
    last_modified: UtcDateTime,
    req_headers: HeaderMap,
    request: Request,
    next: Next,
) -> Response {
    let response = next.run(request).await;
    let (mut parts, body) = response.into_parts();

    let body_bytes = if let Ok(collected) = body.collect().await {
        collected.to_bytes()
    } else {
        return StatusCode::INTERNAL_SERVER_ERROR.into_response();
    };

    let mut res_headers = HeaderMap::new();

    let etag_string = get_etag_hash(body_bytes.as_ref(), None);
    let etag_str = etag_string.as_str();

    if let Some(if_none_match) = req_headers.get(header::IF_NONE_MATCH)
        && let Ok(if_none_match_str) = if_none_match.to_str()
        && if_none_match_str == etag_str
    {
        res_headers.insert(header::ETAG, if_none_match.to_owned());

        return (StatusCode::NOT_MODIFIED, res_headers).into_response();
    } else if let Ok(etag_header) = HeaderValue::from_str(etag_str) {
        if let Some(if_modified_since) = req_headers.get(header::IF_MODIFIED_SINCE)
            && let Ok(if_modified_since_str) = if_modified_since.to_str()
            && let Ok(if_modified_since) = UtcDateTime::parse(if_modified_since_str, &GMT_FORMAT)
            && if_modified_since >= last_modified
        {
            res_headers.insert(header::ETAG, etag_header);
            return (StatusCode::NOT_MODIFIED, res_headers).into_response();
        }

        parts.headers.insert(header::ETAG, etag_header);
    }

    (parts, body_bytes).into_response()
}

#[must_use]
pub fn get_etag_hash(content: &[u8], http_cache: Option<&HttpCache>) -> String {
    if let Some(http_cache) = http_cache
        && let Some(etag_config) = &http_cache.etag
        && let Some(etag_alg) = &etag_config.alg
    {
        return match etag_alg {
            HttpEtagAlgorithm::AHash => {
                let mut hasher = AHasher::default();
                hasher.write(content);
                b64.encode(hasher.finish().to_be_bytes())
            }
            HttpEtagAlgorithm::XXH3(variation) => match variation {
                XXH3Variation::Bit64 => {
                    b64.encode(xxhash_rust::xxh3::xxh3_64(content).to_be_bytes())
                }
                XXH3Variation::Bit128 => {
                    b64.encode(xxhash_rust::xxh3::xxh3_128(content).to_be_bytes())
                }
            },
            HttpEtagAlgorithm::Rustc => {
                let mut hasher = rustc_hash::FxHasher::default();
                hasher.write(content);

                b64.encode(hasher.finish().to_be_bytes())
            }
            HttpEtagAlgorithm::Blake3 => b64.encode(&blake3::hash(content).as_bytes()[0..16]),
        };
    }

    let mut hasher = AHasher::default();
    hasher.write(content);
    b64.encode(hasher.finish().to_be_bytes())
}

pub fn modify_etag_for_encoding(res: &Response) -> Option<HeaderValue> {
    let headers = res.headers();

    if let Some(curr_etag) = headers.get(header::ETAG)
        && let Ok(curr_etag_str) = curr_etag.to_str()
    {
        let etag_len = curr_etag_str.len();

        if (etag_len == 22 || etag_len == 11)
            && let Some(compression) = headers.get(header::CONTENT_ENCODING)
            && let Ok(compression_str) = compression.to_str()
        {
            let mut etag_string = curr_etag_str.to_owned();

            match compression_str {
                "gzip" => etag_string.push('1'),
                "zstd" => etag_string.push('2'),
                "br" => etag_string.push('3'),
                "deflate" => etag_string.push('4'),
                _ => (),
            }

            match HeaderValue::from_str(etag_string.as_str()) {
                Ok(v) => return Some(v),
                Err(err) => tracing::error!(%err),
            }
        } else {
            return Some(curr_etag.clone());
        }
    }

    None
}

/// Checks the request's `If-None-Match` value against the current `etag`.
///
/// Two kinds of match are recognised. The return value is the tag to echo back:
/// - An exact match, when either the current tag has a suffixed length (12 or 23 characters)
///   or the client's value has an un-suffixed length (11 or 22). Returns `Some(etag)`.
/// - A client value whose final character is an encoding suffix on top of `etag` (the value
///   minus its last character equals `etag`). Returns `Some(<client value>)`.
///
/// Returns `None` when the header is absent, not visible ASCII, shorter than 11 characters,
/// or matches in neither way.
pub fn check_if_none_match<'a>(headers: &'a HeaderMap, etag: &'a str) -> Option<&'a str> {
    let candidate = headers.get(header::IF_NONE_MATCH)?.to_str().ok()?;

    if candidate.len() < 11 {
        return None;
    }

    let recognised_len = matches!(etag.len(), 12 | 23) || matches!(candidate.len(), 11 | 22);
    if recognised_len && candidate == etag {
        return Some(etag);
    }

    // `to_str` only succeeds for visible ASCII, so splitting on a byte index is always on a
    // char boundary; the length guard above keeps `len - 1` from underflowing.
    let (base, _suffix) = candidate.split_at(candidate.len() - 1);
    (base == etag).then_some(candidate)
}