Skip to main content

ordinary_utils/
server.rs

1// Copyright (C) 2026 Ordinary Labs, LLC.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4
5use axum::extract::ConnectInfo;
6use axum::http::{HeaderName, HeaderValue, Request, Version};
7use axum::routing::get;
8use futures_util::stream::StreamExt;
9use std::net::{IpAddr, SocketAddr};
10use std::time::Duration;
11use tower::ServiceBuilder;
12use tower_http::timeout::TimeoutLayer;
13use tracing::{Instrument, Span};
14
15use crate::middleware::{ServiceKind, apply_common_middleware, x_via};
16use anyhow::anyhow;
17use axum::Router;
18use axum::body::Body;
19use axum::handler::Handler;
20use axum::response::Response;
21use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
22use blake2::{
23    Blake2bVar,
24    digest::{Update, VariableOutput},
25};
26use bytes::Bytes;
27use http_body_util::Full;
28use hyper::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, SET_COOKIE};
29use hyper::{HeaderMap, StatusCode, Uri, header};
30use ordinary_config::RedactedHashAlg;
31use rcgen::{CertifiedKey, generate_simple_self_signed};
32use rustls_acme::{AcmeState, EventError, EventOk};
33use std::any::Any;
34use std::fmt;
35use std::fmt::{Debug, Display};
36use std::fs::File;
37use std::io::Write;
38use std::path::Path;
39use std::sync::Arc;
40use tokio::sync::watch::{Receiver, Sender};
41use tokio_rustls::{
42    rustls::ServerConfig,
43    rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
44};
45use tracing::field::DisplayValue;
46use valuable::{Mappable, Valuable, Value, Visit};
47
/// Header name `x-via`.
pub const X_VIA: HeaderName = HeaderName::from_static("x-via");
/// Header name `x-request-id`.
pub const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
/// Header name `reporting-endpoints`.
pub const REPORTING_ENDPOINTS: HeaderName = HeaderName::from_static("reporting-endpoints");
/// Header name `x-forwarded-for`.
pub const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
/// Header name `x-forwarded-host`; consulted by [`get_host`] when no `Forwarded` host is found.
pub const X_FORWARDED_HOST: HeaderName = HeaderName::from_static("x-forwarded-host");
/// Header name `x-forwarded-proto`.
pub const X_FORWARDED_PROTO: HeaderName = HeaderName::from_static("x-forwarded-proto");
54
/// Provisioning environment carried by [`SecurityMode::Secure`].
///
/// NOTE(review): the variants presumably select between a self-signed localhost
/// certificate and the ACME staging / production directories — confirm against
/// the callers, which are not part of this file.
#[derive(PartialEq, Clone)]
pub enum ProvisionMode {
    Localhost,
    Staging,
    Production,
}
61
/// Whether a server runs without TLS (`Insecure`) or with it (`Secure`).
///
/// `Secure` carries a path-like value together with the [`ProvisionMode`].
/// NOTE(review): the path is presumably the certificate / cache directory —
/// confirm against the callers.
#[derive(Clone)]
pub enum SecurityMode<T: AsRef<Path>> {
    Insecure,
    Secure(T, ProvisionMode),
}
67
68pub struct WrappedRedactedHashingAlg(pub RedactedHashAlg);
69
70impl WrappedRedactedHashingAlg {
71    fn hash(&self, header_value: &str) -> String {
72        let span = tracing::info_span!("redacted:hash");
73
74        span.in_scope(|| match self.0 {
75            RedactedHashAlg::Blake2 => {
76                let mut out = [0u8; 32];
77
78                let mut hasher = match Blake2bVar::new(32) {
79                    Ok(v) => v,
80                    Err(err) => {
81                        tracing::error!(%err);
82                        return "redacted".into();
83                    }
84                };
85
86                hasher.update(header_value.as_bytes());
87                if let Err(err) = hasher.finalize_variable(&mut out) {
88                    tracing::error!(%err);
89                    return "redacted".into();
90                }
91
92                b64.encode(&out[0..6])
93            }
94            RedactedHashAlg::Blake3 => {
95                b64.encode(&blake3::hash(header_value.as_bytes()).as_bytes()[0..6])
96            }
97        })
98    }
99}
/// Logging adapter for a [`HeaderMap`] that hides the values of
/// `authorization`, `proxy-authorization`, `cookie` and `set-cookie`.
///
/// Field `0` is the header map to render. Field `1` is the optional hasher:
/// the `Valuable` implementation emits a digest token for sensitive values when
/// it is `Some`, and the literal `"redacted"` otherwise.
pub struct HeadersDebug<'a>(
    pub &'a HeaderMap,
    pub Arc<Option<WrappedRedactedHashingAlg>>,
);
104
/// Structured (`valuable`) view of the headers as a string-to-string map.
#[cfg(tracing_unstable)]
impl Valuable for HeadersDebug<'_> {
    fn as_value(&self) -> Value<'_> {
        Value::Mappable(self)
    }

    /// Visits one entry per header whose value is valid header text; sensitive
    /// headers are replaced by a digest token, or by `"redacted"` when no
    /// hasher is configured.
    fn visit(&self, visit: &mut dyn Visit) {
        for (name, raw) in self.0 {
            // Values that cannot be viewed as a string are skipped entirely.
            let Ok(text) = raw.to_str() else {
                continue;
            };

            let sensitive = name == AUTHORIZATION
                || name == PROXY_AUTHORIZATION
                || name == COOKIE
                || name == SET_COOKIE;

            if !sensitive {
                visit.visit_entry(name.as_str().as_value(), text.as_value());
                continue;
            }

            match &*self.1 {
                Some(hasher) => {
                    let token = hasher.hash(text);
                    visit.visit_entry(name.as_str().as_value(), token.as_value());
                }
                None => {
                    visit.visit_entry(name.as_str().as_value(), "redacted".as_value());
                }
            }
        }
    }
}
128
#[cfg(tracing_unstable)]
impl Mappable for HeadersDebug<'_> {
    /// Forwards the size hint of the underlying header iterator.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let entries = self.0.iter();
        entries.size_hint()
    }
}
135
136impl Debug for HeadersDebug<'_> {
137    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138        use std::fmt::Write;
139
140        f.write_char('{')?;
141
142        let mut is_first = true;
143
144        for (k, v) in self.0 {
145            if let Ok(v) = v.to_str() {
146                if is_first {
147                    is_first = false;
148                    f.write_char('"')?;
149                } else {
150                    f.write_str(",\"")?;
151                }
152
153                f.write_str(k.as_str())?;
154                f.write_str("\":\"")?;
155
156                if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE
157                {
158                    f.write_str("redacted")?;
159                    f.write_char('"')?;
160                } else {
161                    f.write_str(v)?;
162                    f.write_char('"')?;
163                }
164            }
165        }
166
167        f.write_char('}')
168    }
169}
170
171#[must_use]
172pub fn get_http_version_str(v: Version) -> &'static str {
173    match v {
174        Version::HTTP_09 => "0.9",
175        Version::HTTP_10 => "1.0",
176        Version::HTTP_11 => "1.1",
177        Version::HTTP_2 => "2.0",
178        Version::HTTP_3 => "3.0",
179        _ => unreachable!(),
180    }
181}
182
183#[must_use]
184pub fn get_display_ip(log_ips: bool, req: &Request<Body>) -> Option<DisplayValue<IpAddr>> {
185    log_ips
186        .then(|| {
187            req.extensions()
188                .get::<ConnectInfo<SocketAddr>>()
189                .map(|addr| tracing::field::display(get_mapped_ip_for_addr(&addr.0)))
190        })
191        .flatten()
192}
193
/// Returns the IP of `addr`, converting an IPv4-mapped IPv6 address
/// (`::ffff:a.b.c.d`) to its plain IPv4 form; every other address is returned
/// unchanged.
#[must_use]
pub fn get_mapped_ip_for_addr(addr: &SocketAddr) -> IpAddr {
    // `to_canonical` is the standard-library form of this exact conversion.
    addr.ip().to_canonical()
}
206
207pub fn get_bearer_token_as_bytes(headers: &HeaderMap) -> anyhow::Result<Vec<u8>> {
208    if let Some(auth_header) = headers.get("authorization")
209        && let Ok(str_val) = auth_header.to_str()
210        && let Some(b64_token) = str_val.strip_prefix("Bearer ")
211        && let Ok(token) = b64.decode(b64_token)
212    {
213        Ok(token)
214    } else {
215        Err(anyhow!("failed to get token as bytes"))
216    }
217}
218
219pub fn get_host(headers: &HeaderMap, uri: &Uri) -> Option<String> {
220    if let Some(forwarded_values) = headers.get(header::FORWARDED)
221        && let Ok(forwarded_values_str) = forwarded_values.to_str()
222        && let Some(first_value) = forwarded_values_str.split(',').next()
223        && let Some(host) = first_value.split(';').find_map(|pair| {
224            let (key, value) = pair.split_once('=')?;
225            key.trim()
226                .eq_ignore_ascii_case("host")
227                .then(|| value.trim().trim_matches('"'))
228        })
229    {
230        return Some(host.to_owned());
231    }
232
233    if let Some(host) = headers
234        .get(X_FORWARDED_HOST)
235        .and_then(|host| host.to_str().ok())
236    {
237        return Some(host.to_owned());
238    }
239
240    if let Some(host) = headers
241        .get(header::HOST)
242        .and_then(|host| host.to_str().ok())
243    {
244        return Some(host.to_owned());
245    }
246
247    if let Some(authority) = uri.authority() {
248        return authority.as_str().rsplit('@').next().map(ToOwned::to_owned);
249    }
250
251    None
252}
253
/// Formats a latency given in nanoseconds with an automatically chosen unit.
pub struct LatencyDisplay(pub f64);

impl Display for LatencyDisplay {
    /// Scales the value through `ns`, `µs`, `ms` and `s` until it drops below
    /// 1000, then prints it with 2, 1 or 0 decimals (below 10, 100 and 1000
    /// respectively). Values of 1000 seconds or more are printed as whole
    /// seconds.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const UNITS: [&str; 4] = ["ns", "µs", "ms", "s"];

        let mut scaled = self.0;

        for unit in UNITS {
            let digits: Option<usize> = if scaled < 10.0 {
                Some(2)
            } else if scaled < 100.0 {
                Some(1)
            } else if scaled < 1000.0 {
                Some(0)
            } else {
                None
            };

            if let Some(digits) = digits {
                return write!(f, "{:.*}{}", digits, scaled, unit);
            }

            scaled /= 1000.0;
        }

        // The loop divided once more after the seconds step; undo that.
        write!(f, "{:.0}s", scaled * 1000.0)
    }
}
273
274#[allow(clippy::needless_pass_by_value)]
275pub fn response_for_panic(_: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
276    #[allow(clippy::declare_interior_mutable_const)]
277    const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
278
279    let mut res = Response::new(Full::new(Bytes::from_static(b"500 Internal Server Error")));
280
281    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
282    res.headers_mut().insert(header::CONTENT_TYPE, TEXT_PLAIN);
283
284    res
285}
286
287pub fn rustls_server_config(
288    key: impl AsRef<Path>,
289    cert: impl AsRef<Path>,
290) -> anyhow::Result<Arc<ServerConfig>> {
291    let key = PrivateKeyDer::from_pem_file(key)?;
292
293    let certs = CertificateDer::pem_file_iter(cert)?.flatten().collect();
294
295    let mut config = ServerConfig::builder()
296        .with_no_client_auth()
297        .with_single_cert(certs, key)?;
298
299    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
300
301    Ok(Arc::new(config))
302}
303
304/// writes crt.pem and key.pem to directory
305pub fn generate_self_signed_localhost_certs(cert_dir_path: impl AsRef<Path>) -> anyhow::Result<()> {
306    std::fs::create_dir_all(&cert_dir_path)?;
307
308    let cert_path = cert_dir_path.as_ref().join("crt.pem");
309    let key_path = cert_dir_path.as_ref().join("key.pem");
310
311    if !cert_path.exists() || !key_path.exists() {
312        let subject_alt_names = vec!["localhost".to_string()];
313
314        let CertifiedKey { cert, signing_key } =
315            match generate_simple_self_signed(subject_alt_names) {
316                Ok(ck) => {
317                    tracing::info!("generated self-signed localhost cert");
318                    ck
319                }
320                Err(err) => {
321                    tracing::error!("failed to generate self-signed localhost cert");
322                    return Err(err.into());
323                }
324            };
325
326        let cert = cert.pem();
327        let key = signing_key.serialize_pem();
328
329        let mut cert_file = File::create(cert_path)?;
330        let mut key_file = File::create(key_path)?;
331
332        cert_file.write_all(cert.as_bytes())?;
333        cert_file.flush()?;
334        key_file.write_all(key.as_bytes())?;
335        key_file.flush()?;
336    }
337
338    Ok(())
339}
340
/// Spawns a background Tokio task that drives the ACME `state` stream and logs
/// every event it yields, all inside `acme_span_clone`.
///
/// The task ends when the stream yields `None`, when `signal_tx.closed()`
/// resolves, or — if `terminate_rx` is provided — when `terminate_rx.changed()`
/// resolves (its result is ignored, so both a change and an error stop the
/// loop). The spawned task's handle is dropped, so the task is not awaited by
/// the caller.
pub fn acme_task(
    acme_span_clone: Span,
    mut state: AcmeState<std::io::Error>,
    signal_tx: Sender<()>,
    mut terminate_rx: Option<Receiver<bool>>,
) {
    tokio::spawn(async move {
        async {
            tracing::info!("ready");

            loop {
                // Wait for the next ACME event or for one of the stop signals.
                // The two `select!` blocks differ only in whether the optional
                // terminate channel takes part.
                let event = if let Some(terminate_rx) = terminate_rx.as_mut() {
                    tokio::select! {
                        state = state.next() => state,
                        () = signal_tx.closed() => {
                            tracing::warn!("not accepting new connections");
                            break;
                        },
                        _ = terminate_rx.changed() => {
                            tracing::warn!("not accepting new connections");
                            break;
                        }
                    }
                } else {
                    tokio::select! {
                        state = state.next() => state,
                        () = signal_tx.closed() => {
                            tracing::warn!("not accepting new connections");
                            break;
                        }
                    }
                };

                if let Some(event) = event {
                    // Events are only logged; errors do not stop the loop.
                    match event {
                        Ok(evt) => match evt {
                            EventOk::DeployedNewCert => {
                                tracing::info!(evt.deploy = %"new", "cert");
                            }
                            EventOk::CertCacheStore => {
                                tracing::info!(evt.cache = %"stored", "cert");
                            }
                            EventOk::AccountCacheStore => {
                                tracing::info!(evt.cache = %"stored", "account");
                            }
                            EventOk::DeployedCachedCert => {
                                tracing::info!(evt.deploy = %"cached", "cert");
                            }
                        },
                        Err(err) => match err {
                            EventError::AccountCacheStore(err) => {
                                tracing::error!(%err, evt.cache = %"store", "account");
                            }
                            EventError::CertCacheStore(err) => {
                                tracing::error!(%err, evt.cache = %"store", "cert");
                            }
                            EventError::AccountCacheLoad(err) => {
                                tracing::error!(%err, evt.cache = %"load", "account");
                            }
                            EventError::CachedCertParse(err) => {
                                tracing::error!(%err, evt.parse = %"cache", "cert");
                            }
                            EventError::NewCertParse(err) => {
                                tracing::error!(%err, evt.parse = %"new", "cert");
                            }
                            EventError::CertCacheLoad(err) => {
                                tracing::error!(%err, evt.cache = %"load", "cert");
                            }
                            EventError::Order(err) => {
                                tracing::error!(%err, "order");
                            }
                        },
                    }
                } else {
                    // The event stream ended.
                    break;
                }
            }
        }
        .instrument(acme_span_clone)
        .await;
    });
}
423
424#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
425pub fn redirect_service<H, T, S>(
426    span_clone: Span,
427    redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
428    log_ips: bool,
429    log_headers: bool,
430    handler: H,
431    state: S,
432    api_domain: Option<String>,
433) -> Router
434where
435    H: Handler<T, S>,
436    T: 'static,
437    S: Clone + Send + Sync + 'static,
438{
439    let router = Router::new()
440        .route("/healthz", get(|| async { StatusCode::OK }))
441        .fallback(handler);
442
443    apply_common_middleware(
444        router,
445        &state,
446        Some(span_clone),
447        String::new(),
448        log_headers,
449        log_ips,
450        redacted_hash,
451        ServiceKind::Redirect,
452        Some(api_domain.unwrap_or("redirect".to_owned())),
453    )
454    .layer(
455        ServiceBuilder::new()
456            .layer(TimeoutLayer::with_status_code(
457                StatusCode::REQUEST_TIMEOUT,
458                Duration::from_secs(5),
459            ))
460            .layer(axum::middleware::from_fn(x_via)),
461    )
462}