Skip to main content

ordinary_utils/
middleware.rs

1// Copyright (C) 2026 Ordinary Labs, LLC.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4
5use crate::{
6    GMT_FORMAT, HeadersDebug, LatencyDisplay, SERVER, WrappedRedactedHashingAlg, X_REQUEST_ID,
7    X_VIA, get_host, get_http_version_str, response_for_panic,
8};
9use ahash::AHasher;
10use axum::Router;
11use axum::body::HttpBody;
12use axum::extract::Request;
13use axum::http::{HeaderValue, StatusCode, header};
14use axum::middleware::Next;
15use axum::response::{IntoResponse, Response};
16use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
17use hyper::HeaderMap;
18use hyper::header::LAST_MODIFIED;
19use ordinary_config::{HttpCache, HttpEtagAlgorithm, XXH3Variation};
20use std::hash::Hasher;
21use std::sync::Arc;
22use std::time::Duration;
23use time::UtcDateTime;
24use tower::ServiceBuilder;
25use tower_http::catch_panic::CatchPanicLayer;
26use tower_http::classify::ServerErrorsFailureClass;
27use tower_http::compression::CompressionLayer;
28use tower_http::decompression::RequestDecompressionLayer;
29use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
30use tower_http::set_header::{SetRequestHeaderLayer, SetResponseHeaderLayer};
31use tower_http::trace::TraceLayer;
32use tracing::Span;
33use uuid::Uuid;
34
/// The kind of service a router serves.
///
/// [`apply_common_middleware`] uses this to pick the name of the per-request
/// tracing span (`"app"`, `"api"`, `"redirect"` or `"proxy"`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ServiceKind {
    /// Requests are traced under an `"app"` span.
    App,
    /// Requests are traced under an `"api"` span.
    Api,
    /// Requests are traced under a `"redirect"` span, which carries no
    /// `domain` field.
    Redirect,
    /// Requests are traced under a `"proxy"` span.
    Proxy,
}
42
43#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
44pub fn apply_common_middleware<T>(
45    router: Router<T>,
46    state: &T,
47    server_span: Option<Span>,
48    domain: String,
49    log_headers: bool,
50    log_ips: bool,
51    redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
52    kind: ServiceKind,
53    via_domain: Option<String>,
54) -> Router
55where
56    T: Clone + Send + Sync + 'static,
57{
58    let redacted_hash_clone = redacted_hash.clone();
59
60    router
61        .with_state(state.clone())
62        .layer(
63            ServiceBuilder::new()
64                .layer(CatchPanicLayer::custom(response_for_panic))
65                .layer(RequestDecompressionLayer::new())
66                .layer(CompressionLayer::new()),
67        )
68        .layer(
69            ServiceBuilder::new()
70                .layer(SetRequestIdLayer::new(X_REQUEST_ID, MakeRequestUuid))
71                .layer(
72                    TraceLayer::new_for_http()
73                        .make_span_with(move |req: &axum::http::Request<_>| {
74                            let request_id = req
75                                .headers()
76                                .get(X_REQUEST_ID)
77                                .and_then(|rid| {
78                                    rid.to_str()
79                                        .ok()
80                                        .and_then(|rid| Uuid::parse_str(rid).ok())
81                                        .map(tracing::field::display)
82                                })
83                                .unwrap_or(tracing::field::display(Uuid::new_v4()));
84
85                            let host =
86                                get_host(req.headers(), req.uri()).map(tracing::field::display);
87
88                            let ip = crate::get_display_ip(log_ips, req);
89
90                            let query = req.uri().query().map(tracing::field::display);
91
92                            match kind {
93                                ServiceKind::App => {
94                                    if let Some(server_span) = &server_span {
95                                        server_span.in_scope(|| {
96                                            tracing::info_span!(
97                                                "app",
98                                                %domain,
99                                                host,
100                                                rid = %request_id,
101                                                ip,
102                                                path = %req.uri().path(),
103                                                query,
104                                            )
105                                        })
106                                    } else {
107                                        tracing::info_span!(
108                                            "app",
109                                            %domain,
110                                            host,
111                                            rid = %request_id,
112                                            ip,
113                                            path = %req.uri().path(),
114                                            query,
115                                        )
116                                    }
117                                }
118                                ServiceKind::Proxy => {
119                                    if let Some(server_span) = &server_span {
120                                        server_span.in_scope(|| {
121                                            tracing::info_span!(
122                                                "proxy",
123                                                %domain,
124                                                host,
125                                                rid = %request_id,
126                                                ip,
127                                                path = %req.uri().path(),
128                                                query,
129                                            )
130                                        })
131                                    } else {
132                                        tracing::info_span!(
133                                            "proxy",
134                                            %domain,
135                                            host,
136                                            rid = %request_id,
137                                            ip,
138                                            path = %req.uri().path(),
139                                            query,
140                                        )
141                                    }
142                                }
143                                ServiceKind::Api => {
144                                    if let Some(server_span) = &server_span {
145                                        server_span.in_scope(|| {
146                                            tracing::info_span!(
147                                                "api",
148                                                %domain,
149                                                host,
150                                                rid = %request_id,
151                                                ip,
152                                                path = %req.uri().path(),
153                                                query,
154                                            )
155                                        })
156                                    } else {
157                                        tracing::info_span!(
158                                            "api",
159                                            %domain,
160                                            host,
161                                            rid = %request_id,
162                                            ip,
163                                            path = %req.uri().path(),
164                                            query,
165                                        )
166                                    }
167                                }
168                                ServiceKind::Redirect => {
169                                    if let Some(server_span) = &server_span {
170                                        server_span.in_scope(|| {
171                                            tracing::info_span!(
172                                                "redirect",
173                                                host,
174                                                rid = %request_id,
175                                                ip,
176                                                path = %req.uri().path(),
177                                                query,
178                                            )
179                                        })
180                                    } else {
181                                        tracing::info_span!(
182                                            "redirect",
183                                            host,
184                                            rid = %request_id,
185                                            ip,
186                                            path = %req.uri().path(),
187                                            query,
188                                        )
189                                    }
190                                }
191                            }
192                        })
193                        .on_request(move |req: &axum::http::Request<_>, _: &Span| {
194                            let hd = log_headers
195                                .then_some(HeadersDebug(req.headers(), redacted_hash.clone()));
196
197                            #[cfg(tracing_unstable)]
198                            let headers = log_headers.then_some(tracing::field::valuable(&hd));
199                            #[cfg(not(tracing_unstable))]
200                            let headers = log_headers.then_some(tracing::field::debug(&hd));
201
202                            tracing::info!(
203                                version = ?req.version(),
204                                method = %req.method(),
205                                headers,
206                                "req"
207                            );
208                        })
209                        .on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
210                            let hd = log_headers.then_some(HeadersDebug(
211                                res.headers(),
212                                redacted_hash_clone.clone(),
213                            ));
214
215                            #[cfg(tracing_unstable)]
216                            let headers = log_headers.then_some(tracing::field::valuable(&hd));
217
218                            #[cfg(not(tracing_unstable))]
219                            let headers = log_headers.then_some(tracing::field::debug(&hd));
220
221                            let status = res.status().as_u16();
222                            let latency = LatencyDisplay(latency.as_nanos() as f64);
223
224                            if status >= 500 {
225                                tracing::error!(status, headers, %latency, "res");
226                            } else if status >= 400 {
227                                tracing::warn!(status, headers, %latency, "res");
228                            } else {
229                                tracing::info!(status, headers, %latency, "res");
230                            }
231                        })
232                        .on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
233                            tracing::error!(
234                                err = %error,
235                                "fail"
236                            );
237                        }),
238                )
239                .layer(PropagateRequestIdLayer::new(X_REQUEST_ID))
240                .layer(SetResponseHeaderLayer::if_not_present(
241                    header::SERVER,
242                    HeaderValue::from_static(SERVER),
243                ))
244                .option_layer(via_domain.map(|domain| {
245                    SetRequestHeaderLayer::overriding(X_VIA, move |req: &axum::http::Request<_>| {
246                        let req_version = get_http_version_str(req.version());
247                        HeaderValue::from_str(&format!("{req_version} {domain} (ordinaryd)")).ok()
248                    })
249                })),
250        )
251}
252
253#[allow(clippy::similar_names)]
254pub async fn http_cache_middleware(
255    last_modified: UtcDateTime,
256    last_modified_header: HeaderValue,
257    req_headers: HeaderMap,
258    request: Request,
259    next: Next,
260) -> Response {
261    let response = next.run(request).await;
262    let (mut parts, body) = response.into_parts();
263
264    let body_bytes = if let Some(limit) = body.size_hint().upper()
265        && let Ok(limit) = usize::try_from(limit)
266        && let Ok(body_bytes) = axum::body::to_bytes(body, limit).await
267    {
268        body_bytes
269    } else {
270        return StatusCode::INTERNAL_SERVER_ERROR.into_response();
271    };
272
273    let mut res_headers = HeaderMap::new();
274    res_headers.insert(LAST_MODIFIED, last_modified_header);
275
276    let etag_string = get_etag_hash(body_bytes.as_ref(), None);
277    let etag_str = etag_string.as_str();
278
279    if let Some(if_none_match) = req_headers.get(header::IF_NONE_MATCH)
280        && let Ok(if_none_match_str) = if_none_match.to_str()
281        && if_none_match_str == etag_str
282    {
283        res_headers.insert(header::ETAG, if_none_match.to_owned());
284
285        return (StatusCode::NOT_MODIFIED, res_headers).into_response();
286    } else if let Ok(etag_header) = HeaderValue::from_str(etag_str) {
287        if let Some(if_modified_since) = req_headers.get(header::IF_MODIFIED_SINCE)
288            && let Ok(if_modified_since_str) = if_modified_since.to_str()
289            && let Ok(if_modified_since) = UtcDateTime::parse(if_modified_since_str, &GMT_FORMAT)
290            && if_modified_since >= last_modified
291        {
292            res_headers.insert(header::ETAG, etag_header);
293            return (StatusCode::NOT_MODIFIED, res_headers).into_response();
294        }
295
296        parts.headers.insert(header::ETAG, etag_header);
297    }
298
299    (parts, body_bytes).into_response()
300}
301
302#[must_use]
303pub fn get_etag_hash(content: &[u8], http_cache: Option<&HttpCache>) -> String {
304    if let Some(http_cache) = http_cache
305        && let Some(etag_config) = &http_cache.etag
306        && let Some(etag_alg) = &etag_config.alg
307    {
308        return match etag_alg {
309            HttpEtagAlgorithm::AHash => {
310                let mut hasher = AHasher::default();
311                hasher.write(content);
312                b64.encode(hasher.finish().to_be_bytes())
313            }
314            HttpEtagAlgorithm::XXH3(variation) => match variation {
315                XXH3Variation::Bit64 => {
316                    b64.encode(xxhash_rust::xxh3::xxh3_64(content).to_be_bytes())
317                }
318                XXH3Variation::Bit128 => {
319                    b64.encode(xxhash_rust::xxh3::xxh3_128(content).to_be_bytes())
320                }
321            },
322            HttpEtagAlgorithm::Rustc => {
323                let mut hasher = rustc_hash::FxHasher::default();
324                hasher.write(content);
325
326                b64.encode(hasher.finish().to_be_bytes())
327            }
328            HttpEtagAlgorithm::Blake3 => b64.encode(&blake3::hash(content).as_bytes()[0..16]),
329        };
330    }
331
332    let mut hasher = AHasher::default();
333    hasher.write(content);
334    b64.encode(hasher.finish().to_be_bytes())
335}
336
337// todo: switch to using request extensions
338#[must_use]
339pub async fn x_via(headers: HeaderMap, request: Request, next: Next) -> Response {
340    let mut response = next.run(request).await;
341
342    if let Some(x_via) = headers.get(X_VIA) {
343        response.headers_mut().insert(header::VIA, x_via.to_owned());
344    }
345
346    response
347}
348
349pub fn modify_etag_for_encoding(res: &Response) -> Option<HeaderValue> {
350    let headers = res.headers();
351
352    if let Some(curr_etag) = headers.get(header::ETAG)
353        && let Ok(curr_etag_str) = curr_etag.to_str()
354    {
355        let etag_len = curr_etag_str.len();
356
357        if (etag_len == 22 || etag_len == 11)
358            && let Some(compression) = headers.get(header::CONTENT_ENCODING)
359            && let Ok(compression_str) = compression.to_str()
360        {
361            let mut etag_string = curr_etag_str.to_owned();
362
363            match compression_str {
364                "gzip" => etag_string.push('1'),
365                "zstd" => etag_string.push('2'),
366                "br" => etag_string.push('3'),
367                "deflate" => etag_string.push('4'),
368                _ => (),
369            }
370
371            match HeaderValue::from_str(etag_string.as_str()) {
372                Ok(v) => return Some(v),
373                Err(err) => tracing::error!(%err),
374            }
375        } else {
376            return Some(curr_etag.clone());
377        }
378    }
379
380    None
381}
382
383pub fn check_if_none_match<'a>(headers: &'a HeaderMap, etag: &'a str) -> Option<&'a str> {
384    if let Some(if_none_match) = headers.get(header::IF_NONE_MATCH)
385        && let Ok(if_none_match_str) = if_none_match.to_str()
386    {
387        if if_none_match_str.len() < 11 {
388            return None;
389        }
390
391        if (etag.len() == 23
392            || etag.len() == 12
393            || if_none_match_str.len() == 22
394            || if_none_match_str.len() == 11)
395            && if_none_match_str == etag
396        {
397            return Some(etag);
398        }
399
400        if &if_none_match_str[..if_none_match_str.len() - 1] == etag {
401            Some(if_none_match_str)
402        } else {
403            None
404        }
405    } else {
406        None
407    }
408}