Skip to main content

ordinary_utils/
middleware.rs

// Copyright (C) 2026 Ordinary Labs, LLC.
//
// SPDX-License-Identifier: AGPL-3.0-only
4
5use crate::{
6    GMT_FORMAT, HeadersDebug, LatencyDisplay, SERVER, WrappedRedactedHashingAlg, X_REQUEST_ID,
7    X_VIA, get_host, get_http_version_str, response_for_panic,
8};
9use ahash::AHasher;
10use axum::Router;
11use axum::extract::Request;
12use axum::http::{HeaderValue, StatusCode, header};
13use axum::middleware::Next;
14use axum::response::{IntoResponse, Response};
15use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
16use http_body_util::BodyExt;
17use hyper::HeaderMap;
18use ordinary_config::{HttpCache, HttpEtagAlgorithm, XXH3Variation};
19use std::hash::Hasher;
20use std::sync::Arc;
21use std::time::Duration;
22use time::UtcDateTime;
23use tower::ServiceBuilder;
24use tower_http::catch_panic::CatchPanicLayer;
25use tower_http::classify::ServerErrorsFailureClass;
26use tower_http::compression::CompressionLayer;
27use tower_http::decompression::RequestDecompressionLayer;
28use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
29use tower_http::set_header::{SetRequestHeaderLayer, SetResponseHeaderLayer};
30use tower_http::trace::TraceLayer;
31use tracing::Span;
32use uuid::Uuid;
33
/// The kind of service a router is being built for.
///
/// [`apply_common_middleware`] uses it to choose the name of each request's
/// tracing span; the `Redirect` span is the only one without a `domain` field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServiceKind {
    /// Requests are traced in an `"app"` span.
    App,
    /// Requests are traced in an `"api"` span.
    Api,
    /// Requests are traced in a `"redirect"` span (no `domain` field).
    Redirect,
    /// Requests are traced in a `"proxy"` span.
    Proxy,
}
41
42#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
43pub fn apply_common_middleware<T>(
44    router: Router<T>,
45    state: &T,
46    server_span: Option<Span>,
47    domain: String,
48    log_headers: bool,
49    log_ips: bool,
50    redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
51    kind: ServiceKind,
52    via_domain: Option<String>,
53) -> Router
54where
55    T: Clone + Send + Sync + 'static,
56{
57    let redacted_hash_clone = redacted_hash.clone();
58
59    router
60        .with_state(state.clone())
61        .layer(
62            ServiceBuilder::new()
63                .layer(CatchPanicLayer::custom(response_for_panic))
64                .layer(RequestDecompressionLayer::new())
65                .layer(CompressionLayer::new()),
66        )
67        .layer(
68            ServiceBuilder::new()
69                .layer(SetRequestIdLayer::new(X_REQUEST_ID, MakeRequestUuid))
70                .layer(
71                    TraceLayer::new_for_http()
72                        .make_span_with(move |req: &axum::http::Request<_>| {
73                            let request_id = req
74                                .headers()
75                                .get(X_REQUEST_ID)
76                                .and_then(|rid| {
77                                    rid.to_str()
78                                        .ok()
79                                        .and_then(|rid| Uuid::parse_str(rid).ok())
80                                        .map(tracing::field::display)
81                                })
82                                .unwrap_or(tracing::field::display(Uuid::new_v4()));
83
84                            let host =
85                                get_host(req.headers(), req.uri()).map(tracing::field::display);
86
87                            let ip = crate::get_display_ip(log_ips, req);
88
89                            let query = req.uri().query().map(tracing::field::display);
90
91                            match kind {
92                                ServiceKind::App => {
93                                    if let Some(server_span) = &server_span {
94                                        server_span.in_scope(|| {
95                                            tracing::info_span!(
96                                                "app",
97                                                %domain,
98                                                host,
99                                                rid = %request_id,
100                                                ip,
101                                                path = %req.uri().path(),
102                                                query,
103                                            )
104                                        })
105                                    } else {
106                                        tracing::info_span!(
107                                            "app",
108                                            %domain,
109                                            host,
110                                            rid = %request_id,
111                                            ip,
112                                            path = %req.uri().path(),
113                                            query,
114                                        )
115                                    }
116                                }
117                                ServiceKind::Proxy => {
118                                    if let Some(server_span) = &server_span {
119                                        server_span.in_scope(|| {
120                                            tracing::info_span!(
121                                                "proxy",
122                                                %domain,
123                                                host,
124                                                rid = %request_id,
125                                                ip,
126                                                path = %req.uri().path(),
127                                                query,
128                                            )
129                                        })
130                                    } else {
131                                        tracing::info_span!(
132                                            "proxy",
133                                            %domain,
134                                            host,
135                                            rid = %request_id,
136                                            ip,
137                                            path = %req.uri().path(),
138                                            query,
139                                        )
140                                    }
141                                }
142                                ServiceKind::Api => {
143                                    if let Some(server_span) = &server_span {
144                                        server_span.in_scope(|| {
145                                            tracing::info_span!(
146                                                "api",
147                                                %domain,
148                                                host,
149                                                rid = %request_id,
150                                                ip,
151                                                path = %req.uri().path(),
152                                                query,
153                                            )
154                                        })
155                                    } else {
156                                        tracing::info_span!(
157                                            "api",
158                                            %domain,
159                                            host,
160                                            rid = %request_id,
161                                            ip,
162                                            path = %req.uri().path(),
163                                            query,
164                                        )
165                                    }
166                                }
167                                ServiceKind::Redirect => {
168                                    if let Some(server_span) = &server_span {
169                                        server_span.in_scope(|| {
170                                            tracing::info_span!(
171                                                "redirect",
172                                                host,
173                                                rid = %request_id,
174                                                ip,
175                                                path = %req.uri().path(),
176                                                query,
177                                            )
178                                        })
179                                    } else {
180                                        tracing::info_span!(
181                                            "redirect",
182                                            host,
183                                            rid = %request_id,
184                                            ip,
185                                            path = %req.uri().path(),
186                                            query,
187                                        )
188                                    }
189                                }
190                            }
191                        })
192                        .on_request(move |req: &axum::http::Request<_>, _: &Span| {
193                            let hd = log_headers
194                                .then_some(HeadersDebug(req.headers(), redacted_hash.clone()));
195
196                            #[cfg(tracing_unstable)]
197                            let headers = log_headers.then_some(tracing::field::valuable(&hd));
198                            #[cfg(not(tracing_unstable))]
199                            let headers = log_headers.then_some(tracing::field::debug(&hd));
200
201                            tracing::info!(
202                                version = ?req.version(),
203                                method = %req.method(),
204                                headers,
205                                "req"
206                            );
207                        })
208                        .on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
209                            let hd = log_headers.then_some(HeadersDebug(
210                                res.headers(),
211                                redacted_hash_clone.clone(),
212                            ));
213
214                            #[cfg(tracing_unstable)]
215                            let headers = log_headers.then_some(tracing::field::valuable(&hd));
216
217                            #[cfg(not(tracing_unstable))]
218                            let headers = log_headers.then_some(tracing::field::debug(&hd));
219
220                            let status = res.status().as_u16();
221                            let latency = LatencyDisplay(latency.as_nanos() as f64);
222
223                            if status >= 500 {
224                                tracing::error!(status, headers, %latency, "res");
225                            } else if status >= 400 {
226                                tracing::warn!(status, headers, %latency, "res");
227                            } else {
228                                tracing::info!(status, headers, %latency, "res");
229                            }
230                        })
231                        .on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
232                            tracing::error!(
233                                err = %error,
234                                "fail"
235                            );
236                        }),
237                )
238                .layer(PropagateRequestIdLayer::new(X_REQUEST_ID))
239                .layer(SetResponseHeaderLayer::if_not_present(
240                    header::SERVER,
241                    HeaderValue::from_static(SERVER),
242                ))
243                .option_layer(via_domain.map(|domain| {
244                    SetRequestHeaderLayer::overriding(X_VIA, move |req: &axum::http::Request<_>| {
245                        let req_version = get_http_version_str(req.version());
246                        HeaderValue::from_str(&format!("{req_version} {domain} (ordinaryd)")).ok()
247                    })
248                })),
249        )
250}
251
252#[allow(clippy::similar_names)]
253pub async fn http_cache_middleware(
254    last_modified: UtcDateTime,
255    req_headers: HeaderMap,
256    request: Request,
257    next: Next,
258) -> Response {
259    let response = next.run(request).await;
260    let (mut parts, body) = response.into_parts();
261
262    let body_bytes = if let Ok(collected) = body.collect().await {
263        collected.to_bytes()
264    } else {
265        return StatusCode::INTERNAL_SERVER_ERROR.into_response();
266    };
267
268    let mut res_headers = HeaderMap::new();
269
270    let etag_string = get_etag_hash(body_bytes.as_ref(), None);
271    let etag_str = etag_string.as_str();
272
273    if let Some(if_none_match) = req_headers.get(header::IF_NONE_MATCH)
274        && let Ok(if_none_match_str) = if_none_match.to_str()
275        && if_none_match_str == etag_str
276    {
277        res_headers.insert(header::ETAG, if_none_match.to_owned());
278
279        return (StatusCode::NOT_MODIFIED, res_headers).into_response();
280    } else if let Ok(etag_header) = HeaderValue::from_str(etag_str) {
281        if let Some(if_modified_since) = req_headers.get(header::IF_MODIFIED_SINCE)
282            && let Ok(if_modified_since_str) = if_modified_since.to_str()
283            && let Ok(if_modified_since) = UtcDateTime::parse(if_modified_since_str, &GMT_FORMAT)
284            && if_modified_since >= last_modified
285        {
286            res_headers.insert(header::ETAG, etag_header);
287            return (StatusCode::NOT_MODIFIED, res_headers).into_response();
288        }
289
290        parts.headers.insert(header::ETAG, etag_header);
291    }
292
293    (parts, body_bytes).into_response()
294}
295
296#[must_use]
297pub fn get_etag_hash(content: &[u8], http_cache: Option<&HttpCache>) -> String {
298    if let Some(http_cache) = http_cache
299        && let Some(etag_config) = &http_cache.etag
300        && let Some(etag_alg) = &etag_config.alg
301    {
302        return match etag_alg {
303            HttpEtagAlgorithm::AHash => {
304                let mut hasher = AHasher::default();
305                hasher.write(content);
306                b64.encode(hasher.finish().to_be_bytes())
307            }
308            HttpEtagAlgorithm::XXH3(variation) => match variation {
309                XXH3Variation::Bit64 => {
310                    b64.encode(xxhash_rust::xxh3::xxh3_64(content).to_be_bytes())
311                }
312                XXH3Variation::Bit128 => {
313                    b64.encode(xxhash_rust::xxh3::xxh3_128(content).to_be_bytes())
314                }
315            },
316            HttpEtagAlgorithm::Rustc => {
317                let mut hasher = rustc_hash::FxHasher::default();
318                hasher.write(content);
319
320                b64.encode(hasher.finish().to_be_bytes())
321            }
322            HttpEtagAlgorithm::Blake3 => b64.encode(&blake3::hash(content).as_bytes()[0..16]),
323        };
324    }
325
326    let mut hasher = AHasher::default();
327    hasher.write(content);
328    b64.encode(hasher.finish().to_be_bytes())
329}
330
331#[must_use]
332pub async fn x_via(headers: HeaderMap, request: Request, next: Next) -> Response {
333    let mut response = next.run(request).await;
334
335    if let Some(x_via) = headers.get(X_VIA) {
336        response.headers_mut().insert(header::VIA, x_via.to_owned());
337    }
338
339    response
340}
341
342pub fn modify_etag_for_encoding(res: &Response) -> Option<HeaderValue> {
343    let headers = res.headers();
344
345    if let Some(curr_etag) = headers.get(header::ETAG)
346        && let Ok(curr_etag_str) = curr_etag.to_str()
347    {
348        let etag_len = curr_etag_str.len();
349
350        if (etag_len == 22 || etag_len == 11)
351            && let Some(compression) = headers.get(header::CONTENT_ENCODING)
352            && let Ok(compression_str) = compression.to_str()
353        {
354            let mut etag_string = curr_etag_str.to_owned();
355
356            match compression_str {
357                "gzip" => etag_string.push('1'),
358                "zstd" => etag_string.push('2'),
359                "br" => etag_string.push('3'),
360                "deflate" => etag_string.push('4'),
361                _ => (),
362            }
363
364            match HeaderValue::from_str(etag_string.as_str()) {
365                Ok(v) => return Some(v),
366                Err(err) => tracing::error!(%err),
367            }
368        } else {
369            return Some(curr_etag.clone());
370        }
371    }
372
373    None
374}
375
376pub fn check_if_none_match<'a>(headers: &'a HeaderMap, etag: &'a str) -> Option<&'a str> {
377    if let Some(if_none_match) = headers.get(header::IF_NONE_MATCH)
378        && let Ok(if_none_match_str) = if_none_match.to_str()
379    {
380        if if_none_match_str.len() < 11 {
381            return None;
382        }
383
384        if (etag.len() == 23
385            || etag.len() == 12
386            || if_none_match_str.len() == 22
387            || if_none_match_str.len() == 11)
388            && if_none_match_str == etag
389        {
390            return Some(etag);
391        }
392
393        if &if_none_match_str[..if_none_match_str.len() - 1] == etag {
394            Some(if_none_match_str)
395        } else {
396            None
397        }
398    } else {
399        None
400    }
401}