1use axum::extract::ConnectInfo;
6use axum::http::{HeaderName, HeaderValue, Request};
7use axum::routing::get;
8use futures_util::stream::StreamExt;
9use std::net::SocketAddr;
10use std::time::Duration;
11use tower::ServiceBuilder;
12use tower_http::classify::ServerErrorsFailureClass;
13use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
14use tower_http::set_header::SetResponseHeaderLayer;
15use tower_http::timeout::TimeoutLayer;
16use tower_http::trace::TraceLayer;
17use tracing::Span;
18use uuid::Uuid;
19
20use anyhow::anyhow;
21use axum::Router;
22use axum::handler::Handler;
23use axum::response::Response;
24use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
25use blake2::{
26 Blake2bVar,
27 digest::{Update, VariableOutput},
28};
29use bytes::Bytes;
30use http_body_util::Full;
31use hyper::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, SET_COOKIE};
32use hyper::{HeaderMap, StatusCode, Uri, header};
33use ordinary_config::RedactedHashAlg;
34use rcgen::{CertifiedKey, generate_simple_self_signed};
35use rustls_acme::{AcmeState, EventError, EventOk};
36use std::any::Any;
37use std::fmt;
38use std::fmt::{Debug, Display};
39use std::fs::File;
40use std::io::Write;
41use std::path::Path;
42use std::sync::Arc;
43use tokio::sync::watch::Sender;
44use tokio_rustls::{
45 rustls::ServerConfig,
46 rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
47};
48use tower_http::catch_panic::CatchPanicLayer;
49use tower_http::compression::CompressionLayer;
50use tower_http::decompression::RequestDecompressionLayer;
51use valuable::{Mappable, Valuable, Value, Visit};
52
/// Header that a reverse proxy uses to forward the host the client asked for;
/// consulted by [`get_host`] after `Forwarded` and before `Host`.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
/// Header name `x-request-id`, conventionally used for a per-request correlation id.
pub const REQUEST_ID_HEADER: &str = "x-request-id";
55
56pub struct WrappedRedactedHashingAlg(pub RedactedHashAlg);
57
58impl WrappedRedactedHashingAlg {
59 fn hash(&self, header_value: &str) -> String {
60 let span = tracing::info_span!("redacted:hash");
61
62 span.in_scope(|| match self.0 {
63 RedactedHashAlg::Blake2 => {
64 let mut out = [0u8; 32];
65
66 let mut hasher = match Blake2bVar::new(32) {
67 Ok(v) => v,
68 Err(err) => {
69 tracing::error!(%err);
70 return "redacted".into();
71 }
72 };
73
74 hasher.update(header_value.as_bytes());
75 if let Err(err) = hasher.finalize_variable(&mut out) {
76 tracing::error!(%err);
77 return "redacted".into();
78 }
79
80 b64.encode(&out[0..6])
81 }
82 RedactedHashAlg::Blake3 => {
83 b64.encode(&blake3::hash(header_value.as_bytes()).as_bytes()[0..6])
84 }
85 })
86 }
87}
88pub struct HeadersDebug<'a>(
89 pub &'a HeaderMap,
90 pub Arc<Option<WrappedRedactedHashingAlg>>,
91);
92
93#[cfg(tracing_unstable)]
94impl Valuable for HeadersDebug<'_> {
95 fn as_value(&self) -> Value<'_> {
96 Value::Mappable(self)
97 }
98
99 fn visit(&self, visit: &mut dyn Visit) {
100 for (k, v) in self.0 {
101 if let Ok(v) = v.to_str() {
102 if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE
103 {
104 if let Some(hasher) = &*self.1 {
105 visit.visit_entry(k.as_str().as_value(), hasher.hash(v).as_value());
106 } else {
107 visit.visit_entry(k.as_str().as_value(), "redacted".as_value());
108 }
109 } else {
110 visit.visit_entry(k.as_str().as_value(), v.as_value());
111 }
112 }
113 }
114 }
115}
116
117#[cfg(tracing_unstable)]
118impl Mappable for HeadersDebug<'_> {
119 fn size_hint(&self) -> (usize, Option<usize>) {
120 self.0.iter().size_hint()
121 }
122}
123
124impl Debug for HeadersDebug<'_> {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 use std::fmt::Write;
127
128 f.write_char('{')?;
129
130 let mut is_first = true;
131
132 for (k, v) in self.0 {
133 if let Ok(v) = v.to_str() {
134 if is_first {
135 is_first = false;
136 f.write_char('"')?;
137 } else {
138 f.write_str(",\"")?;
139 }
140
141 f.write_str(k.as_str())?;
142 f.write_str("\":\"")?;
143
144 if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE
145 {
146 f.write_str("redacted")?;
147 f.write_char('"')?;
148 } else {
149 f.write_str(v)?;
150 f.write_char('"')?;
151 }
152 }
153 }
154
155 f.write_char('}')
156 }
157}
158
159pub fn get_bearer_token_as_bytes(headers: &HeaderMap) -> anyhow::Result<Vec<u8>> {
160 if let Some(auth_header) = headers.get("authorization")
161 && let Ok(str_val) = auth_header.to_str()
162 && let Some(b64_token) = str_val.strip_prefix("Bearer ")
163 && let Ok(token) = b64.decode(b64_token)
164 {
165 Ok(token)
166 } else {
167 Err(anyhow!("failed to get token as bytes"))
168 }
169}
170
171pub fn get_host(headers: &HeaderMap, uri: &Uri) -> Option<String> {
172 if let Some(forwarded_values) = headers.get(header::FORWARDED)
173 && let Ok(forwarded_values_str) = forwarded_values.to_str()
174 && let Some(first_value) = forwarded_values_str.split(',').next()
175 && let Some(host) = first_value.split(';').find_map(|pair| {
176 let (key, value) = pair.split_once('=')?;
177 key.trim()
178 .eq_ignore_ascii_case("host")
179 .then(|| value.trim().trim_matches('"'))
180 })
181 {
182 return Some(host.to_owned());
183 }
184
185 if let Some(host) = headers
186 .get(X_FORWARDED_HOST_HEADER_KEY)
187 .and_then(|host| host.to_str().ok())
188 {
189 return Some(host.to_owned());
190 }
191
192 if let Some(host) = headers
193 .get(header::HOST)
194 .and_then(|host| host.to_str().ok())
195 {
196 return Some(host.to_owned());
197 }
198
199 if let Some(authority) = uri.authority() {
200 return authority.as_str().rsplit('@').next().map(ToOwned::to_owned);
201 }
202
203 None
204}
205
/// A latency in nanoseconds, displayed with about three significant digits.
///
/// The unit adapts to the magnitude (`ns`, `µs`, `ms`, `s`). Values of 1000 s or more
/// are printed as whole seconds.
pub struct LatencyDisplay(pub f64);

impl Display for LatencyDisplay {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const UNITS: [&str; 4] = ["ns", "µs", "ms", "s"];

        let mut scaled = self.0;
        for unit in UNITS {
            // Fewer decimals as the integer part grows keeps the width roughly constant.
            match scaled {
                v if v < 10.0 => return write!(f, "{v:.2}{unit}"),
                v if v < 100.0 => return write!(f, "{v:.1}{unit}"),
                v if v < 1000.0 => return write!(f, "{v:.0}{unit}"),
                _ => scaled /= 1000.0,
            }
        }

        // Ran past seconds: undo the final division and print whole seconds.
        write!(f, "{:.0}s", scaled * 1000.0)
    }
}
225
226#[allow(clippy::needless_pass_by_value)]
227pub fn response_for_panic(_: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
228 #[allow(clippy::declare_interior_mutable_const)]
229 const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
230
231 let mut res = Response::new(Full::new(Bytes::from_static(b"500 Internal Server Error")));
232
233 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
234 res.headers_mut().insert(header::CONTENT_TYPE, TEXT_PLAIN);
235
236 res
237}
238
239pub fn rustls_server_config(
240 key: impl AsRef<Path>,
241 cert: impl AsRef<Path>,
242) -> anyhow::Result<Arc<ServerConfig>> {
243 let key = PrivateKeyDer::from_pem_file(key)?;
244
245 let certs = CertificateDer::pem_file_iter(cert)?.flatten().collect();
246
247 let mut config = ServerConfig::builder()
248 .with_no_client_auth()
249 .with_single_cert(certs, key)?;
250
251 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
252
253 Ok(Arc::new(config))
254}
255
256pub fn generate_self_signed_localhost_certs(cert_dir_path: impl AsRef<Path>) -> anyhow::Result<()> {
258 std::fs::create_dir_all(&cert_dir_path)?;
259
260 let cert_path = cert_dir_path.as_ref().join("crt.pem");
261 let key_path = cert_dir_path.as_ref().join("key.pem");
262
263 if !cert_path.exists() || !key_path.exists() {
264 let subject_alt_names = vec!["localhost".to_string()];
265
266 let CertifiedKey { cert, signing_key } =
267 match generate_simple_self_signed(subject_alt_names) {
268 Ok(ck) => {
269 tracing::info!("generated self-signed localhost cert");
270 ck
271 }
272 Err(err) => {
273 tracing::error!("failed to generate self-signed localhost cert");
274 return Err(err.into());
275 }
276 };
277
278 let cert = cert.pem();
279 let key = signing_key.serialize_pem();
280
281 let mut cert_file = File::create(cert_path)?;
282 let mut key_file = File::create(key_path)?;
283
284 cert_file.write_all(cert.as_bytes())?;
285 cert_file.flush()?;
286 key_file.write_all(key.as_bytes())?;
287 key_file.flush()?;
288 }
289
290 Ok(())
291}
292
293pub fn acme_task(
294 acme_span_clone: Span,
295 mut state: AcmeState<std::io::Error>,
296 signal_tx: Sender<()>,
297) {
298 tokio::spawn(async move {
299 loop {
300 let event = tokio::select! {
301 state = state.next() => state,
302 () = signal_tx.closed() => {
303 acme_span_clone.in_scope(|| {
304 tracing::warn!("not accepting new connections");
305 });
306 break;
307 }
308 };
309
310 if let Some(event) = event {
311 match event {
312 Ok(evt) => {
313 acme_span_clone.in_scope(|| match evt {
314 EventOk::DeployedNewCert => {
315 tracing::info!(evt.deploy = %"new", "cert");
316 }
317 EventOk::CertCacheStore => {
318 tracing::info!(evt.cache = %"stored", "cert");
319 }
320 EventOk::AccountCacheStore => {
321 tracing::info!(evt.cache = %"stored", "account");
322 }
323 EventOk::DeployedCachedCert => {
324 tracing::info!(evt.deploy = %"cached", "cert");
325 }
326 });
327 }
328 Err(err) => match err {
329 EventError::AccountCacheStore(err) => {
330 tracing::error!(%err, evt.cache = %"store", "account");
331 }
332 EventError::CertCacheStore(err) => {
333 tracing::error!(%err, evt.cache = %"store", "cert");
334 }
335 EventError::AccountCacheLoad(err) => {
336 tracing::error!(%err, evt.cache = %"load", "account");
337 }
338 EventError::CachedCertParse(err) => {
339 tracing::error!(%err, evt.parse = %"cache", "cert");
340 }
341 EventError::NewCertParse(err) => {
342 tracing::error!(%err, evt.parse = %"new", "cert");
343 }
344 EventError::CertCacheLoad(err) => {
345 tracing::error!(%err, evt.cache = %"load", "cert");
346 }
347 EventError::Order(err) => {
348 tracing::error!(%err, "order");
349 }
350 },
351 }
352 } else {
353 break;
354 }
355 }
356 });
357}
358
359#[allow(clippy::too_many_lines)]
360pub fn redirect_service<H, T, S>(
361 span_clone: Span,
362 redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
363 log_ips: bool,
364 log_headers: bool,
365 request_id_header: HeaderName,
366 handler: H,
367 state: S,
368) -> Router
369where
370 H: Handler<T, S>,
371 T: 'static,
372 S: Clone + Send + Sync + 'static,
373{
374 let redacted_hash_clone = redacted_hash.clone();
375
376 Router::new()
377 .route("/healthz", get(|| async { StatusCode::OK }))
378 .fallback(handler)
379 .with_state(state)
380 .layer(
381 ServiceBuilder::new()
382 .layer(CatchPanicLayer::custom(response_for_panic))
383 .layer(RequestDecompressionLayer::new())
384 .layer(CompressionLayer::new()),
385 )
386 .layer(
387 ServiceBuilder::new()
388 .layer(SetRequestIdLayer::new(
389 request_id_header.clone(),
390 MakeRequestUuid,
391 ))
392 .layer(
393 TraceLayer::new_for_http()
394 .make_span_with(move |req: &Request<_>| {
395 let request_id = req.headers().get(REQUEST_ID_HEADER);
396
397 let host =
398 get_host(req.headers(), req.uri()).map(tracing::field::display);
399
400 let ip = log_ips.then(|| {
401 req.extensions()
402 .get::<ConnectInfo<SocketAddr>>()
403 .map(|addr| tracing::field::display(addr.ip()))
404 });
405
406 let query = req.uri().query().map(tracing::field::display);
407
408 span_clone.in_scope(|| match request_id {
409 Some(rid) => {
410 tracing::warn_span!(
411 "redirect",
412 host,
413 rid = %rid
414 .to_str()
415 .unwrap_or(Uuid::new_v4().to_string().as_str()),
416 ip,
417 path = %req.uri().path(),
418 query,
419 )
420 }
421 None => {
422 tracing::warn_span!(
423 "redirect",
424 host,
425 rid = %Uuid::new_v4(),
426 ip,
427 path = %req.uri().path(),
428 query,
429 )
430 }
431 })
432 })
433 .on_request(move |req: &Request<_>, _: &Span| {
434 let hd = log_headers
435 .then_some(HeadersDebug(req.headers(), redacted_hash.clone()));
436
437 #[cfg(tracing_unstable)]
438 let headers = log_headers.then_some(tracing::field::valuable(&hd));
439
440 #[cfg(not(tracing_unstable))]
441 let headers = log_headers.then_some(tracing::field::debug(&hd));
442
443 tracing::warn!(
444 version = ?req.version(),
445 method = %req.method(),
446 headers,
447 "req"
448 );
449 })
450 .on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
451 let hd = log_headers.then_some(HeadersDebug(
452 res.headers(),
453 redacted_hash_clone.clone(),
454 ));
455
456 #[cfg(tracing_unstable)]
457 let headers = log_headers.then_some(tracing::field::valuable(&hd));
458
459 #[cfg(not(tracing_unstable))]
460 let headers = log_headers.then_some(tracing::field::debug(&hd));
461
462 let status = res.status().as_u16();
463 let latency = LatencyDisplay(latency.as_nanos() as f64);
464
465 if status >= 500 {
466 tracing::error!(status, headers, %latency, "res");
467 } else if status >= 400 {
468 tracing::warn!(status, headers, %latency, "res");
469 } else {
470 tracing::info!(status, headers, %latency, "res");
471 }
472 })
473 .on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
474 tracing::error!(
475 err = %error,
476 "fail"
477 );
478 }),
479 )
480 .layer(TimeoutLayer::with_status_code(
481 StatusCode::REQUEST_TIMEOUT,
482 Duration::from_secs(5),
483 ))
484 .layer(PropagateRequestIdLayer::new(request_id_header))
485 .layer(SetResponseHeaderLayer::if_not_present(
486 header::SERVER,
487 HeaderValue::from_static("ordinaryd"),
488 )),
489 )
490}