Skip to main content

ordinary_utils/
server.rs

1// Copyright (C) 2026 Ordinary Labs, LLC.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4
5use axum::extract::ConnectInfo;
6use axum::http::{HeaderName, HeaderValue, Request};
7use axum::routing::get;
8use futures_util::stream::StreamExt;
9use std::net::SocketAddr;
10use std::time::Duration;
11use tower::ServiceBuilder;
12use tower_http::classify::ServerErrorsFailureClass;
13use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
14use tower_http::set_header::SetResponseHeaderLayer;
15use tower_http::timeout::TimeoutLayer;
16use tower_http::trace::TraceLayer;
17use tracing::Span;
18use uuid::Uuid;
19
20use anyhow::anyhow;
21use axum::Router;
22use axum::handler::Handler;
23use axum::response::Response;
24use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
25use blake2::{
26    Blake2bVar,
27    digest::{Update, VariableOutput},
28};
29use bytes::Bytes;
30use http_body_util::Full;
31use hyper::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, SET_COOKIE};
32use hyper::{HeaderMap, StatusCode, Uri, header};
33use ordinary_config::RedactedHashAlg;
34use rcgen::{CertifiedKey, generate_simple_self_signed};
35use rustls_acme::{AcmeState, EventError, EventOk};
36use std::any::Any;
37use std::fmt;
38use std::fmt::{Debug, Display};
39use std::fs::File;
40use std::io::Write;
41use std::path::Path;
42use std::sync::Arc;
43use tokio::sync::watch::Sender;
44use tokio_rustls::{
45    rustls::ServerConfig,
46    rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
47};
48use tower_http::catch_panic::CatchPanicLayer;
49use tower_http::compression::CompressionLayer;
50use tower_http::decompression::RequestDecompressionLayer;
51use valuable::{Mappable, Valuable, Value, Visit};
52
/// Header name under which request ids are exchanged by default.
pub const REQUEST_ID_HEADER: &str = "x-request-id";
/// De-facto proxy header carrying the host the client originally requested.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
55
56pub struct WrappedRedactedHashingAlg(pub RedactedHashAlg);
57
58impl WrappedRedactedHashingAlg {
59    fn hash(&self, header_value: &str) -> String {
60        let span = tracing::info_span!("redacted:hash");
61
62        span.in_scope(|| match self.0 {
63            RedactedHashAlg::Blake2 => {
64                let mut out = [0u8; 32];
65
66                let mut hasher = match Blake2bVar::new(32) {
67                    Ok(v) => v,
68                    Err(err) => {
69                        tracing::error!(%err);
70                        return "redacted".into();
71                    }
72                };
73
74                hasher.update(header_value.as_bytes());
75                if let Err(err) = hasher.finalize_variable(&mut out) {
76                    tracing::error!(%err);
77                    return "redacted".into();
78                }
79
80                b64.encode(out)
81            }
82            RedactedHashAlg::Blake3 => b64.encode(blake3::hash(header_value.as_bytes()).as_bytes()),
83        })
84    }
85}
86pub struct HeadersDebug<'a>(
87    pub &'a HeaderMap,
88    pub Arc<Option<WrappedRedactedHashingAlg>>,
89);
90
#[cfg(tracing_unstable)]
impl Valuable for HeadersDebug<'_> {
    fn as_value(&self) -> Value<'_> {
        Value::Mappable(self)
    }

    /// Emits one map entry per header value that converts to `&str`; other values are skipped.
    /// `authorization`, `proxy-authorization`, `cookie` and `set-cookie` are emitted as the
    /// configured hash, or as `"redacted"` when no hashing algorithm is set.
    fn visit(&self, visit: &mut dyn Visit) {
        for (name, raw_value) in self.0 {
            let Ok(text) = raw_value.to_str() else {
                continue;
            };

            let is_sensitive = name == AUTHORIZATION
                || name == PROXY_AUTHORIZATION
                || name == COOKIE
                || name == SET_COOKIE;

            if !is_sensitive {
                visit.visit_entry(name.as_str().as_value(), text.as_value());
            } else if let Some(hasher) = &*self.1 {
                visit.visit_entry(name.as_str().as_value(), hasher.hash(text).as_value());
            } else {
                visit.visit_entry(name.as_str().as_value(), "redacted".as_value());
            }
        }
    }
}
114
#[cfg(tracing_unstable)]
impl Mappable for HeadersDebug<'_> {
    /// Bounds on the number of entries `visit` will emit.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `visit` skips values that are not visible ASCII, so no minimum can be promised.
        // `HeaderMap::len` counts every stored value (duplicates included), which is exactly
        // what `visit` iterates over and therefore a valid upper bound.
        (0, Some(self.0.len()))
    }
}
121
122impl Debug for HeadersDebug<'_> {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        use std::fmt::Write;
125
126        f.write_char('{')?;
127
128        let mut is_first = true;
129
130        for (k, v) in self.0 {
131            if let Ok(v) = v.to_str() {
132                if is_first {
133                    is_first = false;
134                    f.write_char('"')?;
135                } else {
136                    f.write_str(",\"")?;
137                }
138
139                f.write_str(k.as_str())?;
140                f.write_str("\":\"")?;
141
142                if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE
143                {
144                    f.write_str("redacted")?;
145                    f.write_char('"')?;
146                } else {
147                    f.write_str(v)?;
148                    f.write_char('"')?;
149                }
150            }
151        }
152
153        f.write_char('}')
154    }
155}
156
157pub fn get_bearer_token_as_bytes(headers: &HeaderMap) -> anyhow::Result<Vec<u8>> {
158    if let Some(auth_header) = headers.get("authorization")
159        && let Ok(str_val) = auth_header.to_str()
160        && let Some(b64_token) = str_val.strip_prefix("Bearer ")
161        && let Ok(token) = b64.decode(b64_token)
162    {
163        Ok(token)
164    } else {
165        Err(anyhow!("failed to get token as bytes"))
166    }
167}
168
169pub fn get_host(headers: &HeaderMap, uri: &Uri) -> Option<String> {
170    if let Some(forwarded_values) = headers.get(header::FORWARDED)
171        && let Ok(forwarded_values_str) = forwarded_values.to_str()
172        && let Some(first_value) = forwarded_values_str.split(',').next()
173        && let Some(host) = first_value.split(';').find_map(|pair| {
174            let (key, value) = pair.split_once('=')?;
175            key.trim()
176                .eq_ignore_ascii_case("host")
177                .then(|| value.trim().trim_matches('"'))
178        })
179    {
180        return Some(host.to_owned());
181    }
182
183    if let Some(host) = headers
184        .get(X_FORWARDED_HOST_HEADER_KEY)
185        .and_then(|host| host.to_str().ok())
186    {
187        return Some(host.to_owned());
188    }
189
190    if let Some(host) = headers
191        .get(header::HOST)
192        .and_then(|host| host.to_str().ok())
193    {
194        return Some(host.to_owned());
195    }
196
197    if let Some(authority) = uri.authority() {
198        return authority.as_str().rsplit('@').next().map(ToOwned::to_owned);
199    }
200
201    None
202}
203
/// Human-readable latency, given in nanoseconds.
///
/// The value is shown with three significant digits in the first of `ns`, `µs`, `ms` or `s`
/// where it stays below 1000 after rounding. Anything of roughly 1000 seconds or more is
/// printed as whole seconds.
pub struct LatencyDisplay(pub f64);

impl Display for LatencyDisplay {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut t = self.0;

        for unit in ["ns", "µs", "ms", "s"] {
            // Thresholds account for rounding at the chosen precision. For example, 9.999
            // must use one decimal ("10.0"), not two ("10.00"), and 999.7 must move to the
            // next unit instead of printing "1000".
            if t < 9.995 {
                return write!(f, "{t:.2}{unit}");
            } else if t < 99.95 {
                return write!(f, "{t:.1}{unit}");
            } else if t < 999.5 {
                return write!(f, "{t:.0}{unit}");
            }
            t /= 1000.0;
        }

        // Past the largest unit: undo the last division and print whole seconds.
        write!(f, "{:.0}s", t * 1000.0)
    }
}
223
224#[allow(clippy::needless_pass_by_value)]
225pub fn response_for_panic(_: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
226    #[allow(clippy::declare_interior_mutable_const)]
227    const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
228
229    let mut res = Response::new(Full::new(Bytes::from_static(b"500 Internal Server Error")));
230
231    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
232    res.headers_mut().insert(header::CONTENT_TYPE, TEXT_PLAIN);
233
234    res
235}
236
237pub fn rustls_server_config(
238    key: impl AsRef<Path>,
239    cert: impl AsRef<Path>,
240) -> anyhow::Result<Arc<ServerConfig>> {
241    let key = PrivateKeyDer::from_pem_file(key)?;
242
243    let certs = CertificateDer::pem_file_iter(cert)?.flatten().collect();
244
245    let mut config = ServerConfig::builder()
246        .with_no_client_auth()
247        .with_single_cert(certs, key)?;
248
249    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
250
251    Ok(Arc::new(config))
252}
253
254/// writes crt.pem and key.pem to directory
255pub fn generate_self_signed_localhost_certs(cert_dir_path: impl AsRef<Path>) -> anyhow::Result<()> {
256    std::fs::create_dir_all(&cert_dir_path)?;
257
258    let cert_path = cert_dir_path.as_ref().join("crt.pem");
259    let key_path = cert_dir_path.as_ref().join("key.pem");
260
261    if !cert_path.exists() || !key_path.exists() {
262        let subject_alt_names = vec!["localhost".to_string()];
263
264        let CertifiedKey { cert, signing_key } =
265            match generate_simple_self_signed(subject_alt_names) {
266                Ok(ck) => {
267                    tracing::info!("generated self-signed localhost cert");
268                    ck
269                }
270                Err(err) => {
271                    tracing::error!("failed to generate self-signed localhost cert");
272                    return Err(err.into());
273                }
274            };
275
276        let cert = cert.pem();
277        let key = signing_key.serialize_pem();
278
279        let mut cert_file = File::create(cert_path)?;
280        let mut key_file = File::create(key_path)?;
281
282        cert_file.write_all(cert.as_bytes())?;
283        key_file.write_all(key.as_bytes())?;
284    }
285
286    Ok(())
287}
288
289pub fn acme_task(
290    acme_span_clone: Span,
291    mut state: AcmeState<std::io::Error>,
292    signal_tx: Sender<()>,
293) {
294    tokio::spawn(async move {
295        loop {
296            let event = tokio::select! {
297                state = state.next() => state,
298                () = signal_tx.closed() => {
299                    acme_span_clone.in_scope(|| {
300                       tracing::warn!("not accepting new connections");
301                    });
302                    break;
303                }
304            };
305
306            if let Some(event) = event {
307                match event {
308                    Ok(evt) => {
309                        acme_span_clone.in_scope(|| match evt {
310                            EventOk::DeployedNewCert => {
311                                tracing::info!(evt.deploy = %"new", "cert");
312                            }
313                            EventOk::CertCacheStore => {
314                                tracing::info!(evt.cache = %"stored", "cert");
315                            }
316                            EventOk::AccountCacheStore => {
317                                tracing::info!(evt.cache = %"stored", "account");
318                            }
319                            EventOk::DeployedCachedCert => {
320                                tracing::info!(evt.deploy = %"cached", "cert");
321                            }
322                        });
323                    }
324                    Err(err) => match err {
325                        EventError::AccountCacheStore(err) => {
326                            tracing::error!(%err, evt.cache = %"store", "account");
327                        }
328                        EventError::CertCacheStore(err) => {
329                            tracing::error!(%err, evt.cache = %"store", "cert");
330                        }
331                        EventError::AccountCacheLoad(err) => {
332                            tracing::error!(%err, evt.cache = %"load", "account");
333                        }
334                        EventError::CachedCertParse(err) => {
335                            tracing::error!(%err, evt.parse = %"cache", "cert");
336                        }
337                        EventError::NewCertParse(err) => {
338                            tracing::error!(%err, evt.parse = %"new", "cert");
339                        }
340                        EventError::CertCacheLoad(err) => {
341                            tracing::error!(%err, evt.cache = %"load", "cert");
342                        }
343                        EventError::Order(err) => {
344                            tracing::error!(%err, "order");
345                        }
346                    },
347                }
348            } else {
349                break;
350            }
351        }
352    });
353}
354
355pub fn redirect_service<H, T, S>(
356    span_clone: Span,
357    redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
358    log_ips: bool,
359    log_headers: bool,
360    request_id_header: HeaderName,
361    handler: H,
362    state: S,
363) -> Router
364where
365    H: Handler<T, S>,
366    T: 'static,
367    S: Clone + Send + Sync + 'static,
368{
369    let redacted_hash_clone = redacted_hash.clone();
370
371    Router::new()
372        .route("/healthz", get(|| async { StatusCode::OK }))
373        .fallback(handler)
374        .with_state(state)
375        .layer(
376            ServiceBuilder::new()
377                .layer(CatchPanicLayer::custom(response_for_panic))
378                .layer(RequestDecompressionLayer::new())
379                .layer(CompressionLayer::new()),
380        )
381        .layer(
382            ServiceBuilder::new()
383                .layer(SetRequestIdLayer::new(
384                    request_id_header.clone(),
385                    MakeRequestUuid,
386                ))
387                .layer(
388                    TraceLayer::new_for_http()
389                        .make_span_with(move |req: &Request<_>| {
390                            let request_id = req.headers().get(REQUEST_ID_HEADER);
391
392                            let host =
393                                get_host(req.headers(), req.uri()).map(tracing::field::display);
394
395                            let ip = log_ips.then(|| {
396                                req.extensions()
397                                    .get::<ConnectInfo<SocketAddr>>()
398                                    .map(|addr| tracing::field::display(addr.ip()))
399                            });
400
401                            let query = req.uri().query().map(tracing::field::display);
402
403                            span_clone.in_scope(|| match request_id {
404                                Some(rid) => {
405                                    tracing::warn_span!(
406                                        "redirect",
407                                        host,
408                                        id = %rid
409                                            .to_str()
410                                            .unwrap_or(Uuid::new_v4().to_string().as_str()),
411                                        ip,
412                                        path = %req.uri().path(),
413                                        query,
414                                    )
415                                }
416                                None => {
417                                    tracing::warn_span!(
418                                        "redirect",
419                                        host,
420                                        id = %Uuid::new_v4(),
421                                        ip,
422                                        path = %req.uri().path(),
423                                        query,
424                                    )
425                                }
426                            })
427                        })
428                        .on_request(move |req: &Request<_>, _: &Span| {
429                            let hd = log_headers
430                                .then_some(HeadersDebug(req.headers(), redacted_hash.clone()));
431
432                            #[cfg(tracing_unstable)]
433                            let headers = log_headers.then_some(tracing::field::valuable(&hd));
434
435                            #[cfg(not(tracing_unstable))]
436                            let headers = log_headers.then_some(tracing::field::debug(&hd));
437
438                            tracing::warn!(
439                                version = ?req.version(),
440                                method = %req.method(),
441                                headers,
442                                "req"
443                            );
444                        })
445                        .on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
446                            let hd = log_headers.then_some(HeadersDebug(
447                                res.headers(),
448                                redacted_hash_clone.clone(),
449                            ));
450
451                            #[cfg(tracing_unstable)]
452                            let headers = log_headers.then_some(tracing::field::valuable(&hd));
453
454                            #[cfg(not(tracing_unstable))]
455                            let headers = log_headers.then_some(tracing::field::debug(&hd));
456
457                            let status = res.status().as_u16();
458                            let latency = LatencyDisplay(latency.as_nanos() as f64);
459
460                            if status >= 500 {
461                                tracing::error!(status, headers, %latency, "res");
462                            } else if status >= 400 {
463                                tracing::warn!(status, headers, %latency, "res");
464                            } else {
465                                tracing::info!(status, headers, %latency, "res");
466                            }
467                        })
468                        .on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
469                            tracing::error!(
470                                err = %error,
471                                "fail"
472                            );
473                        }),
474                )
475                .layer(TimeoutLayer::with_status_code(
476                    StatusCode::REQUEST_TIMEOUT,
477                    Duration::from_secs(5),
478                ))
479                .layer(PropagateRequestIdLayer::new(request_id_header))
480                .layer(SetResponseHeaderLayer::if_not_present(
481                    header::SERVER,
482                    HeaderValue::from_static("Ordinary"),
483                )),
484        )
485}