ordinary-utils 0.8.2

Utils for Ordinary
Documentation
// Copyright (C) 2026 Ordinary Labs, LLC.
//
// SPDX-License-Identifier: AGPL-3.0-only

use crate::{
    GMT_FORMAT, HeadersDebug, LatencyDisplay, SERVER, WrappedRedactedHashingAlg, X_REQUEST_ID,
    X_VIA, get_host, get_http_version_str, response_for_panic,
};
use ahash::AHasher;
use axum::Router;
use axum::extract::Request;
use axum::http::{HeaderValue, StatusCode, header};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
use http_body_util::BodyExt;
use hyper::HeaderMap;
use ordinary_config::{HttpCache, HttpEtagAlgorithm, XXH3Variation};
use std::hash::Hasher;
use std::sync::Arc;
use std::time::Duration;
use time::UtcDateTime;
use tower::ServiceBuilder;
use tower_http::catch_panic::CatchPanicLayer;
use tower_http::classify::ServerErrorsFailureClass;
use tower_http::compression::CompressionLayer;
use tower_http::decompression::RequestDecompressionLayer;
use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
use tower_http::set_header::{SetRequestHeaderLayer, SetResponseHeaderLayer};
use tower_http::trace::TraceLayer;
use tracing::Span;
use uuid::Uuid;

/// Which kind of Ordinary service a router belongs to.
///
/// `apply_common_middleware` uses this to pick the name of the per-request tracing
/// span (`app`, `api`, `redirect`, `proxy`); the `Redirect` span omits the `domain`
/// field that the other kinds record.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ServiceKind {
    /// Application service; span name `app`.
    App,
    /// API service; span name `api`.
    Api,
    /// Redirect service; span name `redirect`.
    Redirect,
    /// Proxy service; span name `proxy`.
    Proxy,
}

#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
pub fn apply_common_middleware<T>(
    router: Router<T>,
    state: &T,
    server_span: Option<Span>,
    domain: String,
    log_headers: bool,
    log_ips: bool,
    redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
    kind: ServiceKind,
    via_domain: Option<String>,
) -> Router
where
    T: Clone + Send + Sync + 'static,
{
    let redacted_hash_clone = redacted_hash.clone();

    router
        .with_state(state.clone())
        .layer(
            ServiceBuilder::new()
                .layer(CatchPanicLayer::custom(response_for_panic))
                .layer(RequestDecompressionLayer::new())
                .layer(CompressionLayer::new()),
        )
        .layer(
            ServiceBuilder::new()
                .layer(SetRequestIdLayer::new(X_REQUEST_ID, MakeRequestUuid))
                .layer(
                    TraceLayer::new_for_http()
                        .make_span_with(move |req: &axum::http::Request<_>| {
                            let request_id = req
                                .headers()
                                .get(X_REQUEST_ID)
                                .and_then(|rid| {
                                    rid.to_str()
                                        .ok()
                                        .and_then(|rid| Uuid::parse_str(rid).ok())
                                        .map(tracing::field::display)
                                })
                                .unwrap_or(tracing::field::display(Uuid::new_v4()));

                            let host =
                                get_host(req.headers(), req.uri()).map(tracing::field::display);

                            let ip = crate::get_display_ip(log_ips, req);

                            let query = req.uri().query().map(tracing::field::display);

                            match kind {
                                ServiceKind::App => {
                                    if let Some(server_span) = &server_span {
                                        server_span.in_scope(|| {
                                            tracing::info_span!(
                                                "app",
                                                %domain,
                                                host,
                                                rid = %request_id,
                                                ip,
                                                path = %req.uri().path(),
                                                query,
                                            )
                                        })
                                    } else {
                                        tracing::info_span!(
                                            "app",
                                            %domain,
                                            host,
                                            rid = %request_id,
                                            ip,
                                            path = %req.uri().path(),
                                            query,
                                        )
                                    }
                                }
                                ServiceKind::Proxy => {
                                    if let Some(server_span) = &server_span {
                                        server_span.in_scope(|| {
                                            tracing::info_span!(
                                                "proxy",
                                                %domain,
                                                host,
                                                rid = %request_id,
                                                ip,
                                                path = %req.uri().path(),
                                                query,
                                            )
                                        })
                                    } else {
                                        tracing::info_span!(
                                            "proxy",
                                            %domain,
                                            host,
                                            rid = %request_id,
                                            ip,
                                            path = %req.uri().path(),
                                            query,
                                        )
                                    }
                                }
                                ServiceKind::Api => {
                                    if let Some(server_span) = &server_span {
                                        server_span.in_scope(|| {
                                            tracing::info_span!(
                                                "api",
                                                %domain,
                                                host,
                                                rid = %request_id,
                                                ip,
                                                path = %req.uri().path(),
                                                query,
                                            )
                                        })
                                    } else {
                                        tracing::info_span!(
                                            "api",
                                            %domain,
                                            host,
                                            rid = %request_id,
                                            ip,
                                            path = %req.uri().path(),
                                            query,
                                        )
                                    }
                                }
                                ServiceKind::Redirect => {
                                    if let Some(server_span) = &server_span {
                                        server_span.in_scope(|| {
                                            tracing::info_span!(
                                                "redirect",
                                                host,
                                                rid = %request_id,
                                                ip,
                                                path = %req.uri().path(),
                                                query,
                                            )
                                        })
                                    } else {
                                        tracing::info_span!(
                                            "redirect",
                                            host,
                                            rid = %request_id,
                                            ip,
                                            path = %req.uri().path(),
                                            query,
                                        )
                                    }
                                }
                            }
                        })
                        .on_request(move |req: &axum::http::Request<_>, _: &Span| {
                            let hd = log_headers
                                .then_some(HeadersDebug(req.headers(), redacted_hash.clone()));

                            #[cfg(tracing_unstable)]
                            let headers = log_headers.then_some(tracing::field::valuable(&hd));
                            #[cfg(not(tracing_unstable))]
                            let headers = log_headers.then_some(tracing::field::debug(&hd));

                            tracing::info!(
                                version = ?req.version(),
                                method = %req.method(),
                                headers,
                                "req"
                            );
                        })
                        .on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
                            let hd = log_headers.then_some(HeadersDebug(
                                res.headers(),
                                redacted_hash_clone.clone(),
                            ));

                            #[cfg(tracing_unstable)]
                            let headers = log_headers.then_some(tracing::field::valuable(&hd));

                            #[cfg(not(tracing_unstable))]
                            let headers = log_headers.then_some(tracing::field::debug(&hd));

                            let status = res.status().as_u16();
                            let latency = LatencyDisplay(latency.as_nanos() as f64);

                            if status >= 500 {
                                tracing::error!(status, headers, %latency, "res");
                            } else if status >= 400 {
                                tracing::warn!(status, headers, %latency, "res");
                            } else {
                                tracing::info!(status, headers, %latency, "res");
                            }
                        })
                        .on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
                            tracing::error!(
                                err = %error,
                                "fail"
                            );
                        }),
                )
                .layer(PropagateRequestIdLayer::new(X_REQUEST_ID))
                .layer(SetResponseHeaderLayer::if_not_present(
                    header::SERVER,
                    HeaderValue::from_static(SERVER),
                ))
                .option_layer(via_domain.map(|domain| {
                    SetRequestHeaderLayer::overriding(X_VIA, move |req: &axum::http::Request<_>| {
                        let req_version = get_http_version_str(req.version());
                        HeaderValue::from_str(&format!("{req_version} {domain} (ordinaryd)")).ok()
                    })
                })),
        )
}

#[allow(clippy::similar_names)]
pub async fn http_cache_middleware(
    last_modified: UtcDateTime,
    req_headers: HeaderMap,
    request: Request,
    next: Next,
) -> Response {
    let response = next.run(request).await;
    let (mut parts, body) = response.into_parts();

    let body_bytes = if let Ok(collected) = body.collect().await {
        collected.to_bytes()
    } else {
        return StatusCode::INTERNAL_SERVER_ERROR.into_response();
    };

    let mut res_headers = HeaderMap::new();

    let etag_string = get_etag_hash(body_bytes.as_ref(), None);
    let etag_str = etag_string.as_str();

    if let Some(if_none_match) = req_headers.get(header::IF_NONE_MATCH)
        && let Ok(if_none_match_str) = if_none_match.to_str()
        && if_none_match_str == etag_str
    {
        res_headers.insert(header::ETAG, if_none_match.to_owned());

        return (StatusCode::NOT_MODIFIED, res_headers).into_response();
    } else if let Ok(etag_header) = HeaderValue::from_str(etag_str) {
        if let Some(if_modified_since) = req_headers.get(header::IF_MODIFIED_SINCE)
            && let Ok(if_modified_since_str) = if_modified_since.to_str()
            && let Ok(if_modified_since) = UtcDateTime::parse(if_modified_since_str, &GMT_FORMAT)
            && if_modified_since >= last_modified
        {
            res_headers.insert(header::ETAG, etag_header);
            return (StatusCode::NOT_MODIFIED, res_headers).into_response();
        }

        parts.headers.insert(header::ETAG, etag_header);
    }

    (parts, body_bytes).into_response()
}

/// Hashes `content` into a URL-safe, unpadded base64 ETag string.
///
/// The algorithm comes from `http_cache.etag.alg`. When any link of that chain is
/// absent, `AHash` is used. Digest sizes:
/// * `AHash`, `Rustc` (Fx) and `XXH3` 64-bit — 8 bytes, encoded big-endian (11 chars);
/// * `XXH3` 128-bit — 16 bytes, encoded big-endian (22 chars);
/// * `Blake3` — the first 16 bytes of the digest (22 chars).
#[must_use]
pub fn get_etag_hash(content: &[u8], http_cache: Option<&HttpCache>) -> String {
    let selected = http_cache
        .and_then(|cache| cache.etag.as_ref())
        .and_then(|etag| etag.alg.as_ref());

    match selected {
        Some(HttpEtagAlgorithm::XXH3(XXH3Variation::Bit64)) => {
            b64.encode(xxhash_rust::xxh3::xxh3_64(content).to_be_bytes())
        }
        Some(HttpEtagAlgorithm::XXH3(XXH3Variation::Bit128)) => {
            b64.encode(xxhash_rust::xxh3::xxh3_128(content).to_be_bytes())
        }
        Some(HttpEtagAlgorithm::Rustc) => {
            let mut fx = rustc_hash::FxHasher::default();
            fx.write(content);
            b64.encode(fx.finish().to_be_bytes())
        }
        Some(HttpEtagAlgorithm::Blake3) => {
            let digest = blake3::hash(content);
            b64.encode(&digest.as_bytes()[..16])
        }
        Some(HttpEtagAlgorithm::AHash) | None => {
            let mut ahasher = AHasher::default();
            ahasher.write(content);
            b64.encode(ahasher.finish().to_be_bytes())
        }
    }
}

/// Copies the request's `X_VIA` header, if any, onto the response as `Via`.
///
/// `headers` are the incoming request headers. Any `Via` header the inner service set
/// is replaced, not appended to.
#[must_use]
pub async fn x_via(headers: HeaderMap, request: Request, next: Next) -> Response {
    let forwarded_via = headers.get(X_VIA).cloned();

    let mut response = next.run(request).await;

    if let Some(via) = forwarded_via {
        response.headers_mut().insert(header::VIA, via);
    }

    response
}

pub fn modify_etag_for_encoding(res: &Response) -> Option<HeaderValue> {
    let headers = res.headers();

    if let Some(curr_etag) = headers.get(header::ETAG)
        && let Ok(curr_etag_str) = curr_etag.to_str()
    {
        let etag_len = curr_etag_str.len();

        if (etag_len == 22 || etag_len == 11)
            && let Some(compression) = headers.get(header::CONTENT_ENCODING)
            && let Ok(compression_str) = compression.to_str()
        {
            let mut etag_string = curr_etag_str.to_owned();

            match compression_str {
                "gzip" => etag_string.push('1'),
                "zstd" => etag_string.push('2'),
                "br" => etag_string.push('3'),
                "deflate" => etag_string.push('4'),
                _ => (),
            }

            match HeaderValue::from_str(etag_string.as_str()) {
                Ok(v) => return Some(v),
                Err(err) => tracing::error!(%err),
            }
        } else {
            return Some(curr_etag.clone());
        }
    }

    None
}

/// Compares the request's `If-None-Match` header with `etag`, tolerating a trailing
/// one-character suffix on the client's value.
///
/// Returns:
/// * `None` when the header is absent, is not visible ASCII, or is shorter than
///   11 characters.
/// * `Some(etag)` when the header equals `etag` and either `etag` is 12 or 23
///   characters, or the header is 11 or 22 characters.
/// * `Some(header)` when the header with its last character removed equals `etag`.
///   Presumably this covers a client echoing an encoding-suffixed tag while the
///   current tag is unsuffixed.
/// * `None` otherwise.
///
/// NOTE(review): exact matches of any other length are rejected. `*`, comma-separated
/// lists and `W/` weak tags are not recognised. Confirm this is intended.
pub fn check_if_none_match<'a>(headers: &'a HeaderMap, etag: &'a str) -> Option<&'a str> {
    let candidate = headers.get(header::IF_NONE_MATCH)?.to_str().ok()?;

    if candidate.len() < 11 {
        return None;
    }

    let recognised_length =
        matches!(etag.len(), 12 | 23) || matches!(candidate.len(), 11 | 22);
    if recognised_length && candidate == etag {
        return Some(etag);
    }

    // `to_str` succeeded, so every byte is ASCII and this split is on a char boundary.
    let (without_suffix, _) = candidate.split_at(candidate.len() - 1);
    (without_suffix == etag).then_some(candidate)
}