Skip to main content

ordinary_utils/
server.rs

1// Copyright (C) 2026 Ordinary Labs, LLC.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4
5use axum::extract::ConnectInfo;
6use axum::http::{HeaderName, HeaderValue, Request};
7use axum::routing::get;
8use futures_util::stream::StreamExt;
9use std::net::SocketAddr;
10use std::time::Duration;
11use tower::ServiceBuilder;
12use tower_http::classify::ServerErrorsFailureClass;
13use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
14use tower_http::set_header::SetResponseHeaderLayer;
15use tower_http::timeout::TimeoutLayer;
16use tower_http::trace::TraceLayer;
17use tracing::Span;
18use uuid::Uuid;
19
20use axum::Router;
21use axum::handler::Handler;
22use axum::response::Response;
23use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
24use blake2::{
25    Blake2bVar,
26    digest::{Update, VariableOutput},
27};
28use bytes::Bytes;
29use http_body_util::Full;
30use hyper::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, SET_COOKIE};
31use hyper::{HeaderMap, StatusCode, Uri, header};
32use ordinary_config::RedactedHashAlg;
33use rcgen::{CertifiedKey, generate_simple_self_signed};
34use rustls_acme::{AcmeState, EventError, EventOk};
35use std::any::Any;
36use std::fmt;
37use std::fmt::{Debug, Display};
38use std::fs::File;
39use std::io::Write;
40use std::path::Path;
41use std::sync::Arc;
42use tokio::sync::watch::Sender;
43use tokio_rustls::{
44    rustls::ServerConfig,
45    rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
46};
47use tower_http::catch_panic::CatchPanicLayer;
48use tower_http::compression::CompressionLayer;
49use tower_http::decompression::RequestDecompressionLayer;
50use valuable::{Mappable, Valuable, Value, Visit};
51
/// Name of the HTTP header that carries the per-request identifier.
pub const REQUEST_ID_HEADER: &str = "x-request-id";
/// Name of the `X-Forwarded-Host` header consulted when resolving the request host.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
54
55pub struct WrappedRedactedHashingAlg(pub RedactedHashAlg);
56
57impl WrappedRedactedHashingAlg {
58    fn hash(&self, header_value: &str) -> String {
59        let span = tracing::info_span!("redacted:hash");
60
61        span.in_scope(|| match self.0 {
62            RedactedHashAlg::Blake2 => {
63                let mut out = [0u8; 32];
64
65                let mut hasher = match Blake2bVar::new(32) {
66                    Ok(v) => v,
67                    Err(err) => {
68                        tracing::error!(%err);
69                        return "redacted".into();
70                    }
71                };
72
73                hasher.update(header_value.as_bytes());
74                if let Err(err) = hasher.finalize_variable(&mut out) {
75                    tracing::error!(%err);
76                    return "redacted".into();
77                }
78
79                b64.encode(out)
80            }
81            RedactedHashAlg::Blake3 => b64.encode(blake3::hash(header_value.as_bytes()).as_bytes()),
82        })
83    }
84}
/// Borrowed view over a [`HeaderMap`] used to log headers with sensitive
/// values (authorization and cookie headers) redacted.
pub struct HeadersDebug<'a>(
    /// The headers to render.
    pub &'a HeaderMap,
    /// Optional hasher; when present, the `Valuable` implementation logs a
    /// hash of each sensitive value instead of the literal `"redacted"`.
    pub Arc<Option<WrappedRedactedHashingAlg>>,
);
89
#[cfg(tracing_unstable)]
impl Valuable for HeadersDebug<'_> {
    fn as_value(&self) -> Value<'_> {
        Value::Mappable(self)
    }

    /// Visits every header whose value is valid text as a `name -> value`
    /// entry. Authorization and cookie headers are replaced by their hash when
    /// a hasher is configured, and by `"redacted"` otherwise.
    fn visit(&self, visit: &mut dyn Visit) {
        let hasher = (*self.1).as_ref();

        for (name, raw) in self.0 {
            // Headers whose value is not representable as a string are skipped.
            let Ok(text) = raw.to_str() else {
                continue;
            };

            let sensitive = name == AUTHORIZATION
                || name == PROXY_AUTHORIZATION
                || name == COOKIE
                || name == SET_COOKIE;

            if !sensitive {
                visit.visit_entry(name.as_str().as_value(), text.as_value());
                continue;
            }

            match hasher {
                Some(hasher) => {
                    visit.visit_entry(name.as_str().as_value(), hasher.hash(text).as_value());
                }
                None => {
                    visit.visit_entry(name.as_str().as_value(), "redacted".as_value());
                }
            }
        }
    }
}

#[cfg(tracing_unstable)]
impl Mappable for HeadersDebug<'_> {
    /// Reports the size hint of the underlying header iterator.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let header_entries = self.0.iter();
        header_entries.size_hint()
    }
}
120
121impl Debug for HeadersDebug<'_> {
122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123        use std::fmt::Write;
124
125        f.write_char('{')?;
126
127        let mut is_first = true;
128
129        for (k, v) in self.0 {
130            if let Ok(v) = v.to_str() {
131                if is_first {
132                    is_first = false;
133                    f.write_char('"')?;
134                } else {
135                    f.write_str(",\"")?;
136                }
137
138                f.write_str(k.as_str())?;
139                f.write_str("\":\"")?;
140
141                if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE
142                {
143                    f.write_str("redacted")?;
144                    f.write_char('"')?;
145                } else {
146                    f.write_str(v)?;
147                    f.write_char('"')?;
148                }
149            }
150        }
151
152        f.write_char('}')
153    }
154}
155
156pub fn get_host(headers: &HeaderMap, uri: &Uri) -> Option<String> {
157    if let Some(forwarded_values) = headers.get(header::FORWARDED)
158        && let Ok(forwarded_values_str) = forwarded_values.to_str()
159        && let Some(first_value) = forwarded_values_str.split(',').next()
160        && let Some(host) = first_value.split(';').find_map(|pair| {
161            let (key, value) = pair.split_once('=')?;
162            key.trim()
163                .eq_ignore_ascii_case("host")
164                .then(|| value.trim().trim_matches('"'))
165        })
166    {
167        return Some(host.to_owned());
168    }
169
170    if let Some(host) = headers
171        .get(X_FORWARDED_HOST_HEADER_KEY)
172        .and_then(|host| host.to_str().ok())
173    {
174        return Some(host.to_owned());
175    }
176
177    if let Some(host) = headers
178        .get(header::HOST)
179        .and_then(|host| host.to_str().ok())
180    {
181        return Some(host.to_owned());
182    }
183
184    if let Some(authority) = uri.authority() {
185        return authority.as_str().rsplit('@').next().map(ToOwned::to_owned);
186    }
187
188    None
189}
190
/// A latency expressed in nanoseconds, displayed with a human-friendly unit.
pub struct LatencyDisplay(pub f64);

impl Display for LatencyDisplay {
    /// Scales the value through `ns`, `µs`, `ms` and `s`, stopping at the first
    /// unit in which it is below 1000, and prints it with roughly three
    /// significant digits. Anything of 1000 seconds or more is printed as
    /// whole seconds.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const UNITS: [&str; 4] = ["ns", "µs", "ms", "s"];

        let mut scaled = self.0;

        for unit in UNITS {
            // Fewer decimals as the integer part grows.
            let precision: Option<usize> = match scaled {
                v if v < 10.0 => Some(2),
                v if v < 100.0 => Some(1),
                v if v < 1000.0 => Some(0),
                _ => None,
            };

            if let Some(precision) = precision {
                return write!(f, "{scaled:.precision$}{unit}");
            }

            scaled /= 1000.0;
        }

        // The loop divided once more after the seconds step; undo that.
        write!(f, "{:.0}s", scaled * 1000.0)
    }
}
210
211#[allow(clippy::needless_pass_by_value)]
212pub fn response_for_panic(_: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
213    #[allow(clippy::declare_interior_mutable_const)]
214    const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
215
216    let mut res = Response::new(Full::new(Bytes::from_static(b"500 Internal Server Error")));
217
218    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
219    res.headers_mut().insert(header::CONTENT_TYPE, TEXT_PLAIN);
220
221    res
222}
223
224pub fn rustls_server_config(
225    key: impl AsRef<Path>,
226    cert: impl AsRef<Path>,
227) -> anyhow::Result<Arc<ServerConfig>> {
228    let key = PrivateKeyDer::from_pem_file(key)?;
229
230    let certs = CertificateDer::pem_file_iter(cert)?.flatten().collect();
231
232    let mut config = ServerConfig::builder()
233        .with_no_client_auth()
234        .with_single_cert(certs, key)?;
235
236    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
237
238    Ok(Arc::new(config))
239}
240
241/// writes crt.pem and key.pem to directory
242pub fn generate_self_signed_localhost_certs(cert_dir_path: impl AsRef<Path>) -> anyhow::Result<()> {
243    std::fs::create_dir_all(&cert_dir_path)?;
244
245    let cert_path = cert_dir_path.as_ref().join("crt.pem");
246    let key_path = cert_dir_path.as_ref().join("key.pem");
247
248    if !cert_path.exists() || !key_path.exists() {
249        let subject_alt_names = vec!["localhost".to_string()];
250
251        let CertifiedKey { cert, signing_key } =
252            match generate_simple_self_signed(subject_alt_names) {
253                Ok(ck) => {
254                    tracing::info!("generated self-signed localhost cert");
255                    ck
256                }
257                Err(err) => {
258                    tracing::error!("failed to generate self-signed localhost cert");
259                    return Err(err.into());
260                }
261            };
262
263        let cert = cert.pem();
264        let key = signing_key.serialize_pem();
265
266        let mut cert_file = File::create(cert_path)?;
267        let mut key_file = File::create(key_path)?;
268
269        cert_file.write_all(cert.as_bytes())?;
270        key_file.write_all(key.as_bytes())?;
271    }
272
273    Ok(())
274}
275
276pub fn acme_task(
277    acme_span_clone: Span,
278    mut state: AcmeState<std::io::Error>,
279    signal_tx: Sender<()>,
280) {
281    tokio::spawn(async move {
282        loop {
283            let event = tokio::select! {
284                state = state.next() => state,
285                () = signal_tx.closed() => {
286                    acme_span_clone.in_scope(|| {
287                       tracing::warn!("not accepting new connections");
288                    });
289                    break;
290                }
291            };
292
293            if let Some(event) = event {
294                match event {
295                    Ok(evt) => {
296                        acme_span_clone.in_scope(|| match evt {
297                            EventOk::DeployedNewCert => {
298                                tracing::info!(evt.deploy = %"new", "cert");
299                            }
300                            EventOk::CertCacheStore => {
301                                tracing::info!(evt.cache = %"stored", "cert");
302                            }
303                            EventOk::AccountCacheStore => {
304                                tracing::info!(evt.cache = %"stored", "account");
305                            }
306                            EventOk::DeployedCachedCert => {
307                                tracing::info!(evt.deploy = %"cached", "cert");
308                            }
309                        });
310                    }
311                    Err(err) => match err {
312                        EventError::AccountCacheStore(err) => {
313                            tracing::error!(%err, evt.cache = %"store", "account");
314                        }
315                        EventError::CertCacheStore(err) => {
316                            tracing::error!(%err, evt.cache = %"store", "cert");
317                        }
318                        EventError::AccountCacheLoad(err) => {
319                            tracing::error!(%err, evt.cache = %"load", "account");
320                        }
321                        EventError::CachedCertParse(err) => {
322                            tracing::error!(%err, evt.parse = %"cache", "cert");
323                        }
324                        EventError::NewCertParse(err) => {
325                            tracing::error!(%err, evt.parse = %"new", "cert");
326                        }
327                        EventError::CertCacheLoad(err) => {
328                            tracing::error!(%err, evt.cache = %"load", "cert");
329                        }
330                        EventError::Order(err) => {
331                            tracing::error!(%err, "order");
332                        }
333                    },
334                }
335            } else {
336                break;
337            }
338        }
339    });
340}
341
342pub fn redirect_service<H, T, S>(
343    span_clone: Span,
344    redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
345    log_ips: bool,
346    log_headers: bool,
347    request_id_header: HeaderName,
348    handler: H,
349    state: S,
350) -> Router
351where
352    H: Handler<T, S>,
353    T: 'static,
354    S: Clone + Send + Sync + 'static,
355{
356    let redacted_hash_clone = redacted_hash.clone();
357
358    Router::new()
359        .route("/healthz", get(|| async { StatusCode::OK }))
360        .fallback(handler)
361        .with_state(state)
362        .layer(
363            ServiceBuilder::new()
364                .layer(CatchPanicLayer::custom(response_for_panic))
365                .layer(RequestDecompressionLayer::new())
366                .layer(CompressionLayer::new()),
367        )
368        .layer(
369            ServiceBuilder::new()
370                .layer(SetRequestIdLayer::new(
371                    request_id_header.clone(),
372                    MakeRequestUuid,
373                ))
374                .layer(
375                    TraceLayer::new_for_http()
376                        .make_span_with(move |req: &Request<_>| {
377                            let request_id = req.headers().get(REQUEST_ID_HEADER);
378
379                            let host =
380                                get_host(req.headers(), req.uri()).map(tracing::field::display);
381
382                            let ip = log_ips.then(|| {
383                                req.extensions()
384                                    .get::<ConnectInfo<SocketAddr>>()
385                                    .map(|addr| tracing::field::display(addr.ip()))
386                            });
387
388                            let query = req.uri().query().map(tracing::field::display);
389
390                            span_clone.in_scope(|| match request_id {
391                                Some(rid) => {
392                                    tracing::warn_span!(
393                                        "redirect",
394                                        host,
395                                        id = %rid
396                                            .to_str()
397                                            .unwrap_or(Uuid::new_v4().to_string().as_str()),
398                                        ip,
399                                        path = %req.uri().path(),
400                                        query,
401                                    )
402                                }
403                                None => {
404                                    tracing::warn_span!(
405                                        "redirect",
406                                        host,
407                                        id = %Uuid::new_v4(),
408                                        ip,
409                                        path = %req.uri().path(),
410                                        query,
411                                    )
412                                }
413                            })
414                        })
415                        .on_request(move |req: &Request<_>, _: &Span| {
416                            let hd = log_headers
417                                .then_some(HeadersDebug(req.headers(), redacted_hash.clone()));
418
419                            #[cfg(tracing_unstable)]
420                            let headers = log_headers.then_some(tracing::field::valuable(&hd));
421
422                            #[cfg(not(tracing_unstable))]
423                            let headers = log_headers.then_some(tracing::field::debug(&hd));
424
425                            tracing::warn!(
426                                version = ?req.version(),
427                                method = %req.method(),
428                                headers,
429                                "req"
430                            );
431                        })
432                        .on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
433                            let hd = log_headers.then_some(HeadersDebug(
434                                res.headers(),
435                                redacted_hash_clone.clone(),
436                            ));
437
438                            #[cfg(tracing_unstable)]
439                            let headers = log_headers.then_some(tracing::field::valuable(&hd));
440
441                            #[cfg(not(tracing_unstable))]
442                            let headers = log_headers.then_some(tracing::field::debug(&hd));
443
444                            let status = res.status().as_u16();
445                            let latency = LatencyDisplay(latency.as_nanos() as f64);
446
447                            if status >= 500 {
448                                tracing::error!(status, headers, %latency, "res");
449                            } else if status >= 400 {
450                                tracing::warn!(status, headers, %latency, "res");
451                            } else {
452                                tracing::info!(status, headers, %latency, "res");
453                            }
454                        })
455                        .on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
456                            tracing::error!(
457                                err = %error,
458                                "fail"
459                            );
460                        }),
461                )
462                .layer(TimeoutLayer::with_status_code(
463                    StatusCode::REQUEST_TIMEOUT,
464                    Duration::from_secs(5),
465                ))
466                .layer(PropagateRequestIdLayer::new(request_id_header))
467                .layer(SetResponseHeaderLayer::if_not_present(
468                    header::SERVER,
469                    HeaderValue::from_static("Ordinary"),
470                )),
471        )
472}