Skip to main content

ordinary_utils/
server.rs

1// Copyright (C) 2026 Ordinary Labs, LLC.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4
5use axum::extract::ConnectInfo;
6use axum::http::{HeaderName, HeaderValue, Request};
7use axum::routing::get;
8use futures_util::stream::StreamExt;
9use std::net::SocketAddr;
10use std::time::Duration;
11use tower::ServiceBuilder;
12use tower_http::classify::ServerErrorsFailureClass;
13use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
14use tower_http::set_header::SetResponseHeaderLayer;
15use tower_http::timeout::TimeoutLayer;
16use tower_http::trace::TraceLayer;
17use tracing::Span;
18use uuid::Uuid;
19
20use anyhow::anyhow;
21use axum::Router;
22use axum::handler::Handler;
23use axum::response::Response;
24use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
25use blake2::{
26    Blake2bVar,
27    digest::{Update, VariableOutput},
28};
29use bytes::Bytes;
30use http_body_util::Full;
31use hyper::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, SET_COOKIE};
32use hyper::{HeaderMap, StatusCode, Uri, header};
33use ordinary_config::RedactedHashAlg;
34use rcgen::{CertifiedKey, generate_simple_self_signed};
35use rustls_acme::{AcmeState, EventError, EventOk};
36use std::any::Any;
37use std::fmt;
38use std::fmt::{Debug, Display};
39use std::fs::File;
40use std::io::Write;
41use std::path::Path;
42use std::sync::Arc;
43use tokio::sync::watch::Sender;
44use tokio_rustls::{
45    rustls::ServerConfig,
46    rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
47};
48use tower_http::catch_panic::CatchPanicLayer;
49use tower_http::compression::CompressionLayer;
50use tower_http::decompression::RequestDecompressionLayer;
51use valuable::{Mappable, Valuable, Value, Visit};
52
/// De-facto proxy header consulted by `get_host` when `Forwarded` carries no `host=`.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

/// Conventional name of the request-id header (`x-request-id`).
pub const REQUEST_ID_HEADER: &str = "x-request-id";
55
56pub struct WrappedRedactedHashingAlg(pub RedactedHashAlg);
57
58impl WrappedRedactedHashingAlg {
59    fn hash(&self, header_value: &str) -> String {
60        let span = tracing::info_span!("redacted:hash");
61
62        span.in_scope(|| match self.0 {
63            RedactedHashAlg::Blake2 => {
64                let mut out = [0u8; 32];
65
66                let mut hasher = match Blake2bVar::new(32) {
67                    Ok(v) => v,
68                    Err(err) => {
69                        tracing::error!(%err);
70                        return "redacted".into();
71                    }
72                };
73
74                hasher.update(header_value.as_bytes());
75                if let Err(err) = hasher.finalize_variable(&mut out) {
76                    tracing::error!(%err);
77                    return "redacted".into();
78                }
79
80                b64.encode(&out[0..6])
81            }
82            RedactedHashAlg::Blake3 => {
83                b64.encode(&blake3::hash(header_value.as_bytes()).as_bytes()[0..6])
84            }
85        })
86    }
87}
88pub struct HeadersDebug<'a>(
89    pub &'a HeaderMap,
90    pub Arc<Option<WrappedRedactedHashingAlg>>,
91);
92
93#[cfg(tracing_unstable)]
94impl Valuable for HeadersDebug<'_> {
95    fn as_value(&self) -> Value<'_> {
96        Value::Mappable(self)
97    }
98
99    fn visit(&self, visit: &mut dyn Visit) {
100        for (k, v) in self.0 {
101            if let Ok(v) = v.to_str() {
102                if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE
103                {
104                    if let Some(hasher) = &*self.1 {
105                        visit.visit_entry(k.as_str().as_value(), hasher.hash(v).as_value());
106                    } else {
107                        visit.visit_entry(k.as_str().as_value(), "redacted".as_value());
108                    }
109                } else {
110                    visit.visit_entry(k.as_str().as_value(), v.as_value());
111                }
112            }
113        }
114    }
115}
116
117#[cfg(tracing_unstable)]
118impl Mappable for HeadersDebug<'_> {
119    fn size_hint(&self) -> (usize, Option<usize>) {
120        self.0.iter().size_hint()
121    }
122}
123
124impl Debug for HeadersDebug<'_> {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        use std::fmt::Write;
127
128        f.write_char('{')?;
129
130        let mut is_first = true;
131
132        for (k, v) in self.0 {
133            if let Ok(v) = v.to_str() {
134                if is_first {
135                    is_first = false;
136                    f.write_char('"')?;
137                } else {
138                    f.write_str(",\"")?;
139                }
140
141                f.write_str(k.as_str())?;
142                f.write_str("\":\"")?;
143
144                if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE
145                {
146                    f.write_str("redacted")?;
147                    f.write_char('"')?;
148                } else {
149                    f.write_str(v)?;
150                    f.write_char('"')?;
151                }
152            }
153        }
154
155        f.write_char('}')
156    }
157}
158
159pub fn get_bearer_token_as_bytes(headers: &HeaderMap) -> anyhow::Result<Vec<u8>> {
160    if let Some(auth_header) = headers.get("authorization")
161        && let Ok(str_val) = auth_header.to_str()
162        && let Some(b64_token) = str_val.strip_prefix("Bearer ")
163        && let Ok(token) = b64.decode(b64_token)
164    {
165        Ok(token)
166    } else {
167        Err(anyhow!("failed to get token as bytes"))
168    }
169}
170
171pub fn get_host(headers: &HeaderMap, uri: &Uri) -> Option<String> {
172    if let Some(forwarded_values) = headers.get(header::FORWARDED)
173        && let Ok(forwarded_values_str) = forwarded_values.to_str()
174        && let Some(first_value) = forwarded_values_str.split(',').next()
175        && let Some(host) = first_value.split(';').find_map(|pair| {
176            let (key, value) = pair.split_once('=')?;
177            key.trim()
178                .eq_ignore_ascii_case("host")
179                .then(|| value.trim().trim_matches('"'))
180        })
181    {
182        return Some(host.to_owned());
183    }
184
185    if let Some(host) = headers
186        .get(X_FORWARDED_HOST_HEADER_KEY)
187        .and_then(|host| host.to_str().ok())
188    {
189        return Some(host.to_owned());
190    }
191
192    if let Some(host) = headers
193        .get(header::HOST)
194        .and_then(|host| host.to_str().ok())
195    {
196        return Some(host.to_owned());
197    }
198
199    if let Some(authority) = uri.authority() {
200        return authority.as_str().rsplit('@').next().map(ToOwned::to_owned);
201    }
202
203    None
204}
205
/// Human-readable duration, given in nanoseconds, scaled to `ns`/`µs`/`ms`/`s`.
///
/// Keeps about three significant digits, e.g. `5.00ns`, `42.0ms`, `250ns`.
/// Durations of 1000 seconds or more are printed as whole seconds (`2500s`).
pub struct LatencyDisplay(pub f64);

impl Display for LatencyDisplay {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut t = self.0;

        for unit in ["ns", "µs", "ms", "s"] {
            // Each threshold is where rounding to the chosen precision would carry
            // into an extra digit (999.7 with `{:.0}` prints "1000"). Such values
            // move on to the coarser precision or the next unit instead.
            if t < 9.995 {
                return write!(f, "{t:.2}{unit}");
            } else if t < 99.95 {
                return write!(f, "{t:.1}{unit}");
            } else if t < 999.5 {
                return write!(f, "{t:.0}{unit}");
            }
            t /= 1000.0;
        }

        // `t` is now in kiloseconds; report whole seconds.
        write!(f, "{:.0}s", t * 1000.0)
    }
}
225
226#[allow(clippy::needless_pass_by_value)]
227pub fn response_for_panic(_: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
228    #[allow(clippy::declare_interior_mutable_const)]
229    const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
230
231    let mut res = Response::new(Full::new(Bytes::from_static(b"500 Internal Server Error")));
232
233    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
234    res.headers_mut().insert(header::CONTENT_TYPE, TEXT_PLAIN);
235
236    res
237}
238
239pub fn rustls_server_config(
240    key: impl AsRef<Path>,
241    cert: impl AsRef<Path>,
242) -> anyhow::Result<Arc<ServerConfig>> {
243    let key = PrivateKeyDer::from_pem_file(key)?;
244
245    let certs = CertificateDer::pem_file_iter(cert)?.flatten().collect();
246
247    let mut config = ServerConfig::builder()
248        .with_no_client_auth()
249        .with_single_cert(certs, key)?;
250
251    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
252
253    Ok(Arc::new(config))
254}
255
256/// writes crt.pem and key.pem to directory
257pub fn generate_self_signed_localhost_certs(cert_dir_path: impl AsRef<Path>) -> anyhow::Result<()> {
258    std::fs::create_dir_all(&cert_dir_path)?;
259
260    let cert_path = cert_dir_path.as_ref().join("crt.pem");
261    let key_path = cert_dir_path.as_ref().join("key.pem");
262
263    if !cert_path.exists() || !key_path.exists() {
264        let subject_alt_names = vec!["localhost".to_string()];
265
266        let CertifiedKey { cert, signing_key } =
267            match generate_simple_self_signed(subject_alt_names) {
268                Ok(ck) => {
269                    tracing::info!("generated self-signed localhost cert");
270                    ck
271                }
272                Err(err) => {
273                    tracing::error!("failed to generate self-signed localhost cert");
274                    return Err(err.into());
275                }
276            };
277
278        let cert = cert.pem();
279        let key = signing_key.serialize_pem();
280
281        let mut cert_file = File::create(cert_path)?;
282        let mut key_file = File::create(key_path)?;
283
284        cert_file.write_all(cert.as_bytes())?;
285        cert_file.flush()?;
286        key_file.write_all(key.as_bytes())?;
287        key_file.flush()?;
288    }
289
290    Ok(())
291}
292
293pub fn acme_task(
294    acme_span_clone: Span,
295    mut state: AcmeState<std::io::Error>,
296    signal_tx: Sender<()>,
297) {
298    tokio::spawn(async move {
299        loop {
300            let event = tokio::select! {
301                state = state.next() => state,
302                () = signal_tx.closed() => {
303                    acme_span_clone.in_scope(|| {
304                       tracing::warn!("not accepting new connections");
305                    });
306                    break;
307                }
308            };
309
310            if let Some(event) = event {
311                match event {
312                    Ok(evt) => {
313                        acme_span_clone.in_scope(|| match evt {
314                            EventOk::DeployedNewCert => {
315                                tracing::info!(evt.deploy = %"new", "cert");
316                            }
317                            EventOk::CertCacheStore => {
318                                tracing::info!(evt.cache = %"stored", "cert");
319                            }
320                            EventOk::AccountCacheStore => {
321                                tracing::info!(evt.cache = %"stored", "account");
322                            }
323                            EventOk::DeployedCachedCert => {
324                                tracing::info!(evt.deploy = %"cached", "cert");
325                            }
326                        });
327                    }
328                    Err(err) => match err {
329                        EventError::AccountCacheStore(err) => {
330                            tracing::error!(%err, evt.cache = %"store", "account");
331                        }
332                        EventError::CertCacheStore(err) => {
333                            tracing::error!(%err, evt.cache = %"store", "cert");
334                        }
335                        EventError::AccountCacheLoad(err) => {
336                            tracing::error!(%err, evt.cache = %"load", "account");
337                        }
338                        EventError::CachedCertParse(err) => {
339                            tracing::error!(%err, evt.parse = %"cache", "cert");
340                        }
341                        EventError::NewCertParse(err) => {
342                            tracing::error!(%err, evt.parse = %"new", "cert");
343                        }
344                        EventError::CertCacheLoad(err) => {
345                            tracing::error!(%err, evt.cache = %"load", "cert");
346                        }
347                        EventError::Order(err) => {
348                            tracing::error!(%err, "order");
349                        }
350                    },
351                }
352            } else {
353                break;
354            }
355        }
356    });
357}
358
359#[allow(clippy::too_many_lines)]
360pub fn redirect_service<H, T, S>(
361    span_clone: Span,
362    redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
363    log_ips: bool,
364    log_headers: bool,
365    request_id_header: HeaderName,
366    handler: H,
367    state: S,
368) -> Router
369where
370    H: Handler<T, S>,
371    T: 'static,
372    S: Clone + Send + Sync + 'static,
373{
374    let redacted_hash_clone = redacted_hash.clone();
375
376    Router::new()
377        .route("/healthz", get(|| async { StatusCode::OK }))
378        .fallback(handler)
379        .with_state(state)
380        .layer(
381            ServiceBuilder::new()
382                .layer(CatchPanicLayer::custom(response_for_panic))
383                .layer(RequestDecompressionLayer::new())
384                .layer(CompressionLayer::new()),
385        )
386        .layer(
387            ServiceBuilder::new()
388                .layer(SetRequestIdLayer::new(
389                    request_id_header.clone(),
390                    MakeRequestUuid,
391                ))
392                .layer(
393                    TraceLayer::new_for_http()
394                        .make_span_with(move |req: &Request<_>| {
395                            let request_id = req.headers().get(REQUEST_ID_HEADER);
396
397                            let host =
398                                get_host(req.headers(), req.uri()).map(tracing::field::display);
399
400                            let ip = log_ips.then(|| {
401                                req.extensions()
402                                    .get::<ConnectInfo<SocketAddr>>()
403                                    .map(|addr| tracing::field::display(addr.ip()))
404                            });
405
406                            let query = req.uri().query().map(tracing::field::display);
407
408                            span_clone.in_scope(|| match request_id {
409                                Some(rid) => {
410                                    tracing::warn_span!(
411                                        "redirect",
412                                        host,
413                                        rid = %rid
414                                            .to_str()
415                                            .unwrap_or(Uuid::new_v4().to_string().as_str()),
416                                        ip,
417                                        path = %req.uri().path(),
418                                        query,
419                                    )
420                                }
421                                None => {
422                                    tracing::warn_span!(
423                                        "redirect",
424                                        host,
425                                        rid = %Uuid::new_v4(),
426                                        ip,
427                                        path = %req.uri().path(),
428                                        query,
429                                    )
430                                }
431                            })
432                        })
433                        .on_request(move |req: &Request<_>, _: &Span| {
434                            let hd = log_headers
435                                .then_some(HeadersDebug(req.headers(), redacted_hash.clone()));
436
437                            #[cfg(tracing_unstable)]
438                            let headers = log_headers.then_some(tracing::field::valuable(&hd));
439
440                            #[cfg(not(tracing_unstable))]
441                            let headers = log_headers.then_some(tracing::field::debug(&hd));
442
443                            tracing::warn!(
444                                version = ?req.version(),
445                                method = %req.method(),
446                                headers,
447                                "req"
448                            );
449                        })
450                        .on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
451                            let hd = log_headers.then_some(HeadersDebug(
452                                res.headers(),
453                                redacted_hash_clone.clone(),
454                            ));
455
456                            #[cfg(tracing_unstable)]
457                            let headers = log_headers.then_some(tracing::field::valuable(&hd));
458
459                            #[cfg(not(tracing_unstable))]
460                            let headers = log_headers.then_some(tracing::field::debug(&hd));
461
462                            let status = res.status().as_u16();
463                            let latency = LatencyDisplay(latency.as_nanos() as f64);
464
465                            if status >= 500 {
466                                tracing::error!(status, headers, %latency, "res");
467                            } else if status >= 400 {
468                                tracing::warn!(status, headers, %latency, "res");
469                            } else {
470                                tracing::info!(status, headers, %latency, "res");
471                            }
472                        })
473                        .on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
474                            tracing::error!(
475                                err = %error,
476                                "fail"
477                            );
478                        }),
479                )
480                .layer(TimeoutLayer::with_status_code(
481                    StatusCode::REQUEST_TIMEOUT,
482                    Duration::from_secs(5),
483                ))
484                .layer(PropagateRequestIdLayer::new(request_id_header))
485                .layer(SetResponseHeaderLayer::if_not_present(
486                    header::SERVER,
487                    HeaderValue::from_static("Ordinary"),
488                )),
489        )
490}