Skip to main content

ordinary_utils/
headers.rs

1// Copyright (C) 2026 Ordinary Labs, LLC.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4
5use crate::{
6    HeadersDebug, LatencyDisplay, WrappedRedactedHashingAlg, X_FORWARDED_FOR, X_FORWARDED_HOST,
7    X_FORWARDED_PROTO, X_VIA, get_host, get_http_version_str, get_mapped_ip_for_addr,
8};
9use axum::extract::{ConnectInfo, Request};
10use axum::http::{HeaderMap, HeaderValue, header};
11use hyper::{Method, Version};
12use std::net::SocketAddr;
13use std::sync::Arc;
14use std::time::Instant;
15
16pub fn log_request(
17    log_headers: bool,
18    headers: &HeaderMap,
19    redacted_hash: &Arc<Option<WrappedRedactedHashingAlg>>,
20    method: &Method,
21) {
22    {
23        let hd = log_headers.then_some(HeadersDebug(headers, redacted_hash.clone()));
24
25        #[cfg(tracing_unstable)]
26        let headers = log_headers.then_some(tracing::field::valuable(&hd));
27        #[cfg(not(tracing_unstable))]
28        let headers = self.log_headers.then_some(tracing::field::debug(&hd));
29
30        tracing::info!(
31            %method,
32            headers,
33            "req"
34        );
35    }
36}
37
38pub fn log_response(
39    status: u16,
40    log_headers: bool,
41    redacted_hash: &Arc<Option<WrappedRedactedHashingAlg>>,
42    start: Instant,
43    headers: &HeaderMap,
44    version: Version,
45) {
46    let hd = log_headers.then_some(HeadersDebug(headers, redacted_hash.clone()));
47
48    #[cfg(tracing_unstable)]
49    let headers = log_headers.then_some(tracing::field::valuable(&hd));
50    #[cfg(not(tracing_unstable))]
51    let headers = log_headers.then_some(tracing::field::debug(&hd));
52
53    let latency = LatencyDisplay(start.elapsed().as_nanos() as f64);
54
55    if status >= 500 {
56        tracing::error!(version = ?version, status, headers, %latency, "res");
57    } else if status >= 400 {
58        tracing::warn!(version = ?version, status, headers, %latency, "res");
59    } else {
60        tracing::info!(version = ?version, status, headers, %latency, "res");
61    }
62}
63
64pub fn get_request_headers_for_forward(
65    req: &Request,
66    forwarded_by: &str,
67    forwarded_proto: &str,
68    via_domain: &str,
69) -> HeaderMap {
70    let mut headers = req.headers().clone();
71
72    let req_version = get_http_version_str(req.version());
73
74    let via = if let Some(src_via) = headers.get(header::VIA)
75        && let Ok(src_via) = src_via.to_str()
76    {
77        format!("{src_via}, {req_version} {via_domain} (ordinaryd)")
78    } else {
79        format!("{req_version} {via_domain} (ordinaryd)")
80    };
81
82    if let Ok(via) = HeaderValue::from_str(&via) {
83        headers.insert(header::VIA, via);
84    }
85
86    let connect_info = req.extensions().get::<ConnectInfo<SocketAddr>>();
87
88    let mut forwarded = if let Some(src_forwarded) = headers.get(header::FORWARDED)
89        && let Ok(src_forwarded) = src_forwarded.to_str()
90    {
91        format!("{src_forwarded}, by={forwarded_by}")
92    } else {
93        format!("by={forwarded_by}")
94    };
95
96    if let Some(addr) = connect_info {
97        let ip = get_mapped_ip_for_addr(&addr.0);
98        let ip_str = ip.to_string();
99
100        if ip.is_ipv6() {
101            forwarded = format!("{forwarded};for=\"[{ip_str}]\"");
102        } else {
103            forwarded = format!("{forwarded};for={ip_str}");
104        }
105
106        let forwarded_for = if let Some(src_forwarded_for) = headers.get("x-forwarded-for")
107            && let Ok(src_forwarded_for) = src_forwarded_for.to_str()
108        {
109            format!("{src_forwarded_for}, {ip_str}")
110        } else {
111            ip_str
112        };
113
114        if let Ok(forwarded_for) = HeaderValue::from_str(&forwarded_for) {
115            headers.insert(X_FORWARDED_FOR, forwarded_for);
116        }
117    }
118
119    if let Some(host) = get_host(req.headers(), req.uri()) {
120        forwarded = format!("{forwarded};host={host}");
121
122        if let Ok(forwarded_host) = HeaderValue::from_str(host.as_str()) {
123            headers.insert(X_FORWARDED_HOST, forwarded_host);
124        }
125    }
126
127    forwarded = format!("{forwarded};proto={forwarded_proto}");
128
129    if let Ok(forwarded_proto) = HeaderValue::from_str(forwarded_proto) {
130        headers.insert(X_FORWARDED_PROTO, forwarded_proto);
131    }
132
133    if let Ok(forwarded) = HeaderValue::from_str(&forwarded) {
134        headers.insert(header::FORWARDED, forwarded);
135    }
136
137    headers.remove(X_VIA);
138    headers
139}