Skip to main content

ordinary_utils/
server.rs

1// Copyright (C) 2026 Ordinary Labs, LLC.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4
5use axum::extract::ConnectInfo;
6use axum::http::{HeaderName, HeaderValue, Request};
7use axum::routing::get;
8use futures_util::stream::StreamExt;
9use std::net::SocketAddr;
10use std::time::Duration;
11use tower::ServiceBuilder;
12use tower_http::classify::ServerErrorsFailureClass;
13use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
14use tower_http::set_header::SetResponseHeaderLayer;
15use tower_http::timeout::TimeoutLayer;
16use tower_http::trace::TraceLayer;
17use tracing::Span;
18use uuid::Uuid;
19
20use axum::Router;
21use axum::handler::Handler;
22use axum::response::Response;
23use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
24use blake2::{
25    Blake2bVar,
26    digest::{Update, VariableOutput},
27};
28use bytes::Bytes;
29use http_body_util::Full;
30use hyper::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, SET_COOKIE};
31use hyper::{HeaderMap, StatusCode, Uri, header};
32use ordinary_config::RedactedHashAlg;
33use rcgen::{CertifiedKey, generate_simple_self_signed};
34use rustls_acme::{AcmeState, EventError, EventOk};
35use std::any::Any;
36use std::error::Error;
37use std::fmt;
38use std::fmt::{Debug, Display};
39use std::fs::File;
40use std::io::Write;
41use std::path::Path;
42use std::sync::Arc;
43use tokio::sync::watch::Sender;
44use tokio_rustls::{
45    rustls::ServerConfig,
46    rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
47};
48use tower_http::catch_panic::CatchPanicLayer;
49use tower_http::compression::CompressionLayer;
50use tower_http::decompression::RequestDecompressionLayer;
51use valuable::{Mappable, Valuable, Value, Visit};
52
/// Name of the request-identifier header (`x-request-id`).
pub const REQUEST_ID_HEADER: &str = "x-request-id";

/// Proxy header carrying the host the client originally requested.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
55
56pub struct WrappedRedactedHashingAlg(pub RedactedHashAlg);
57
58impl WrappedRedactedHashingAlg {
59    fn hash(&self, header_value: &str) -> String {
60        let span = tracing::info_span!("redacted:hash");
61
62        span.in_scope(|| match self.0 {
63            RedactedHashAlg::Blake2 => {
64                let mut out = [0u8; 32];
65
66                let mut hasher = match Blake2bVar::new(32) {
67                    Ok(v) => v,
68                    Err(err) => {
69                        tracing::error!(%err);
70                        return "redacted".into();
71                    }
72                };
73
74                hasher.update(header_value.as_bytes());
75                if let Err(err) = hasher.finalize_variable(&mut out) {
76                    tracing::error!(%err);
77                    return "redacted".into();
78                }
79
80                b64.encode(out)
81            }
82            RedactedHashAlg::Blake3 => b64.encode(blake3::hash(header_value.as_bytes()).as_bytes()),
83        })
84    }
85}
86pub struct HeadersDebug<'a>(
87    pub &'a HeaderMap,
88    pub Arc<Option<WrappedRedactedHashingAlg>>,
89);
90
#[cfg(tracing_unstable)]
impl Valuable for HeadersDebug<'_> {
    fn as_value(&self) -> Value<'_> {
        Value::Mappable(self)
    }

    /// Emits one `name => value` entry per header whose value is visible ASCII.
    ///
    /// `authorization`, `proxy-authorization`, `cookie` and `set-cookie` values
    /// are emitted hashed when an algorithm is configured, and as the literal
    /// `redacted` otherwise. Other values are emitted verbatim.
    fn visit(&self, visit: &mut dyn Visit) {
        let hasher: &Option<WrappedRedactedHashingAlg> = &self.1;

        for (name, raw) in self.0 {
            let Ok(text) = raw.to_str() else {
                continue;
            };

            let key = name.as_str().as_value();

            let is_sensitive = name == AUTHORIZATION
                || name == PROXY_AUTHORIZATION
                || name == COOKIE
                || name == SET_COOKIE;

            if !is_sensitive {
                visit.visit_entry(key, text.as_value());
            } else if let Some(alg) = hasher {
                visit.visit_entry(key, alg.hash(text).as_value());
            } else {
                visit.visit_entry(key, "redacted".as_value());
            }
        }
    }
}
114
#[cfg(tracing_unstable)]
impl Mappable for HeadersDebug<'_> {
    /// Forwards the size hint of the underlying header iterator.
    // NOTE(review): `visit` skips values that are not visible ASCII, so the lower
    // bound reported here may overstate how many entries are actually visited.
    fn size_hint(&self) -> (usize, Option<usize>) {
        Iterator::size_hint(&self.0.iter())
    }
}
121
122impl Debug for HeadersDebug<'_> {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        use std::fmt::Write;
125
126        f.write_char('{')?;
127
128        let mut is_first = true;
129
130        for (k, v) in self.0 {
131            if let Ok(v) = v.to_str() {
132                if is_first {
133                    is_first = false;
134                    f.write_char('"')?;
135                } else {
136                    f.write_str(",\"")?;
137                }
138
139                f.write_str(k.as_str())?;
140                f.write_str("\":\"")?;
141
142                if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE
143                {
144                    f.write_str("redacted")?;
145                    f.write_char('"')?;
146                } else {
147                    f.write_str(v)?;
148                    f.write_char('"')?;
149                }
150            }
151        }
152
153        f.write_char('}')
154    }
155}
156
157pub fn get_host(headers: &HeaderMap, uri: &Uri) -> Option<String> {
158    if let Some(forwarded_values) = headers.get(header::FORWARDED)
159        && let Ok(forwarded_values_str) = forwarded_values.to_str()
160        && let Some(first_value) = forwarded_values_str.split(',').next()
161        && let Some(host) = first_value.split(';').find_map(|pair| {
162            let (key, value) = pair.split_once('=')?;
163            key.trim()
164                .eq_ignore_ascii_case("host")
165                .then(|| value.trim().trim_matches('"'))
166        })
167    {
168        return Some(host.to_owned());
169    }
170
171    if let Some(host) = headers
172        .get(X_FORWARDED_HOST_HEADER_KEY)
173        .and_then(|host| host.to_str().ok())
174    {
175        return Some(host.to_owned());
176    }
177
178    if let Some(host) = headers
179        .get(header::HOST)
180        .and_then(|host| host.to_str().ok())
181    {
182        return Some(host.to_owned());
183    }
184
185    if let Some(authority) = uri.authority() {
186        return authority.as_str().rsplit('@').next().map(ToOwned::to_owned);
187    }
188
189    None
190}
191
/// Human-readable rendering of a latency expressed in nanoseconds.
///
/// The value is scaled through `ns`, `µs`, `ms` and `s` until it falls below
/// 1000. It is then printed with two, one or zero decimals depending on whether
/// it is below 10, below 100 or below 1000. Values of 1000 seconds or more are
/// printed as whole seconds.
pub struct LatencyDisplay(pub f64);

impl Display for LatencyDisplay {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const UNITS: [&str; 4] = ["ns", "µs", "ms", "s"];

        let mut scaled = self.0;

        for unit in UNITS {
            let decimals: usize = if scaled < 10.0 {
                2
            } else if scaled < 100.0 {
                1
            } else if scaled < 1000.0 {
                0
            } else {
                scaled /= 1000.0;
                continue;
            };

            return write!(f, "{scaled:.decimals$}{unit}");
        }

        // Out of units: `scaled` is now in kiloseconds, so convert back to seconds.
        write!(f, "{:.0}s", scaled * 1000.0)
    }
}
211
212#[allow(clippy::needless_pass_by_value)]
213pub fn response_for_panic(_: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
214    #[allow(clippy::declare_interior_mutable_const)]
215    const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
216
217    let mut res = Response::new(Full::new(Bytes::from_static(b"500 Internal Server Error")));
218
219    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
220    res.headers_mut().insert(header::CONTENT_TYPE, TEXT_PLAIN);
221
222    res
223}
224
225pub fn rustls_server_config(
226    key: impl AsRef<Path>,
227    cert: impl AsRef<Path>,
228) -> Result<Arc<ServerConfig>, Box<dyn Error>> {
229    let key = PrivateKeyDer::from_pem_file(key)?;
230
231    let certs = CertificateDer::pem_file_iter(cert)?.flatten().collect();
232
233    let mut config = ServerConfig::builder()
234        .with_no_client_auth()
235        .with_single_cert(certs, key)?;
236
237    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
238
239    Ok(Arc::new(config))
240}
241
242/// writes crt.pem and key.pem to directory
243pub fn generate_self_signed_localhost_certs(
244    cert_dir_path: impl AsRef<Path>,
245) -> Result<(), Box<dyn Error>> {
246    std::fs::create_dir_all(&cert_dir_path)?;
247
248    let cert_path = cert_dir_path.as_ref().join("crt.pem");
249    let key_path = cert_dir_path.as_ref().join("key.pem");
250
251    if !cert_path.exists() || !key_path.exists() {
252        let subject_alt_names = vec!["localhost".to_string()];
253
254        let CertifiedKey { cert, signing_key } =
255            match generate_simple_self_signed(subject_alt_names) {
256                Ok(ck) => {
257                    tracing::info!("generated self-signed localhost cert");
258                    ck
259                }
260                Err(err) => {
261                    tracing::error!("failed to generate self-signed localhost cert");
262                    return Err(err.into());
263                }
264            };
265
266        let cert = cert.pem();
267        let key = signing_key.serialize_pem();
268
269        let mut cert_file = File::create(cert_path)?;
270        let mut key_file = File::create(key_path)?;
271
272        cert_file.write_all(cert.as_bytes())?;
273        key_file.write_all(key.as_bytes())?;
274    }
275
276    Ok(())
277}
278
279pub fn acme_task(
280    acme_span_clone: Span,
281    mut state: AcmeState<std::io::Error>,
282    signal_tx: Sender<()>,
283) {
284    tokio::spawn(async move {
285        loop {
286            let event = tokio::select! {
287                state = state.next() => state,
288                () = signal_tx.closed() => {
289                    acme_span_clone.in_scope(|| {
290                       tracing::warn!("not accepting new connections");
291                    });
292                    break;
293                }
294            };
295
296            if let Some(event) = event {
297                match event {
298                    Ok(evt) => {
299                        acme_span_clone.in_scope(|| match evt {
300                            EventOk::DeployedNewCert => {
301                                tracing::info!(evt.deploy = %"new", "cert");
302                            }
303                            EventOk::CertCacheStore => {
304                                tracing::info!(evt.cache = %"stored", "cert");
305                            }
306                            EventOk::AccountCacheStore => {
307                                tracing::info!(evt.cache = %"stored", "account");
308                            }
309                            EventOk::DeployedCachedCert => {
310                                tracing::info!(evt.deploy = %"cached", "cert");
311                            }
312                        });
313                    }
314                    Err(err) => match err {
315                        EventError::AccountCacheStore(err) => {
316                            tracing::error!(%err, evt.cache = %"store", "account");
317                        }
318                        EventError::CertCacheStore(err) => {
319                            tracing::error!(%err, evt.cache = %"store", "cert");
320                        }
321                        EventError::AccountCacheLoad(err) => {
322                            tracing::error!(%err, evt.cache = %"load", "account");
323                        }
324                        EventError::CachedCertParse(err) => {
325                            tracing::error!(%err, evt.parse = %"cache", "cert");
326                        }
327                        EventError::NewCertParse(err) => {
328                            tracing::error!(%err, evt.parse = %"new", "cert");
329                        }
330                        EventError::CertCacheLoad(err) => {
331                            tracing::error!(%err, evt.cache = %"load", "cert");
332                        }
333                        EventError::Order(err) => {
334                            tracing::error!(%err, "order");
335                        }
336                    },
337                }
338            } else {
339                break;
340            }
341        }
342    });
343}
344
345pub fn redirect_service<H, T, S>(
346    span_clone: Span,
347    redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
348    log_ips: bool,
349    log_headers: bool,
350    request_id_header: HeaderName,
351    handler: H,
352    state: S,
353) -> Router
354where
355    H: Handler<T, S>,
356    T: 'static,
357    S: Clone + Send + Sync + 'static,
358{
359    let redacted_hash_clone = redacted_hash.clone();
360
361    Router::new()
362        .route("/healthz", get(|| async { StatusCode::OK }))
363        .fallback(handler)
364        .with_state(state)
365        .layer(
366            ServiceBuilder::new()
367                .layer(CatchPanicLayer::custom(response_for_panic))
368                .layer(RequestDecompressionLayer::new())
369                .layer(CompressionLayer::new()),
370        )
371        .layer(
372            ServiceBuilder::new()
373                .layer(SetRequestIdLayer::new(
374                    request_id_header.clone(),
375                    MakeRequestUuid,
376                ))
377                .layer(
378                    TraceLayer::new_for_http()
379                        .make_span_with(move |req: &Request<_>| {
380                            let request_id = req.headers().get(REQUEST_ID_HEADER);
381
382                            let host =
383                                get_host(req.headers(), req.uri()).map(tracing::field::display);
384
385                            let ip = log_ips.then(|| {
386                                req.extensions()
387                                    .get::<ConnectInfo<SocketAddr>>()
388                                    .map(|addr| tracing::field::display(addr.ip()))
389                            });
390
391                            let query = req.uri().query().map(tracing::field::display);
392
393                            span_clone.in_scope(|| match request_id {
394                                Some(rid) => {
395                                    tracing::warn_span!(
396                                        "redirect",
397                                        host,
398                                        id = %rid
399                                            .to_str()
400                                            .unwrap_or(Uuid::new_v4().to_string().as_str()),
401                                        ip,
402                                        path = %req.uri().path(),
403                                        query,
404                                    )
405                                }
406                                None => {
407                                    tracing::warn_span!(
408                                        "redirect",
409                                        host,
410                                        id = %Uuid::new_v4(),
411                                        ip,
412                                        path = %req.uri().path(),
413                                        query,
414                                    )
415                                }
416                            })
417                        })
418                        .on_request(move |req: &Request<_>, _: &Span| {
419                            let hd = log_headers
420                                .then_some(HeadersDebug(req.headers(), redacted_hash.clone()));
421
422                            #[cfg(tracing_unstable)]
423                            let headers = log_headers.then_some(tracing::field::valuable(&hd));
424
425                            #[cfg(not(tracing_unstable))]
426                            let headers = log_headers.then_some(tracing::field::debug(&hd));
427
428                            tracing::warn!(
429                                version = ?req.version(),
430                                method = %req.method(),
431                                headers,
432                                "req"
433                            );
434                        })
435                        .on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
436                            let hd = log_headers.then_some(HeadersDebug(
437                                res.headers(),
438                                redacted_hash_clone.clone(),
439                            ));
440
441                            #[cfg(tracing_unstable)]
442                            let headers = log_headers.then_some(tracing::field::valuable(&hd));
443
444                            #[cfg(not(tracing_unstable))]
445                            let headers = log_headers.then_some(tracing::field::debug(&hd));
446
447                            let status = res.status().as_u16();
448                            let latency = LatencyDisplay(latency.as_nanos() as f64);
449
450                            if status >= 500 {
451                                tracing::error!(status, headers, %latency, "res");
452                            } else if status >= 400 {
453                                tracing::warn!(status, headers, %latency, "res");
454                            } else {
455                                tracing::info!(status, headers, %latency, "res");
456                            }
457                        })
458                        .on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
459                            tracing::error!(
460                                err = %error,
461                                "fail"
462                            );
463                        }),
464                )
465                .layer(TimeoutLayer::with_status_code(
466                    StatusCode::REQUEST_TIMEOUT,
467                    Duration::from_secs(5),
468                ))
469                .layer(PropagateRequestIdLayer::new(request_id_header))
470                .layer(SetResponseHeaderLayer::if_not_present(
471                    header::SERVER,
472                    HeaderValue::from_static("Ordinary"),
473                )),
474        )
475}