1use axum::extract::ConnectInfo;
6use axum::http::{HeaderName, HeaderValue, Request};
7use axum::routing::get;
8use futures_util::stream::StreamExt;
9use std::net::SocketAddr;
10use std::time::Duration;
11use tower::ServiceBuilder;
12use tower_http::classify::ServerErrorsFailureClass;
13use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
14use tower_http::set_header::SetResponseHeaderLayer;
15use tower_http::timeout::TimeoutLayer;
16use tower_http::trace::TraceLayer;
17use tracing::Span;
18use uuid::Uuid;
19
20use anyhow::anyhow;
21use axum::Router;
22use axum::handler::Handler;
23use axum::response::Response;
24use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
25use blake2::{
26 Blake2bVar,
27 digest::{Update, VariableOutput},
28};
29use bytes::Bytes;
30use http_body_util::Full;
31use hyper::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, SET_COOKIE};
32use hyper::{HeaderMap, StatusCode, Uri, header};
33use ordinary_config::RedactedHashAlg;
34use rcgen::{CertifiedKey, generate_simple_self_signed};
35use rustls_acme::{AcmeState, EventError, EventOk};
36use std::any::Any;
37use std::fmt;
38use std::fmt::{Debug, Display};
39use std::fs::File;
40use std::io::Write;
41use std::path::Path;
42use std::sync::Arc;
43use tokio::sync::watch::Sender;
44use tokio_rustls::{
45 rustls::ServerConfig,
46 rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
47};
48use tower_http::catch_panic::CatchPanicLayer;
49use tower_http::compression::CompressionLayer;
50use tower_http::decompression::RequestDecompressionLayer;
51use valuable::{Mappable, Valuable, Value, Visit};
52
/// Header name used to carry the per-request id (`x-request-id`).
pub const REQUEST_ID_HEADER: &str = "x-request-id";
/// Proxy header consulted by [`get_host`] after the standard `Forwarded` header.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
55
56pub struct WrappedRedactedHashingAlg(pub RedactedHashAlg);
57
58impl WrappedRedactedHashingAlg {
59 fn hash(&self, header_value: &str) -> String {
60 let span = tracing::info_span!("redacted:hash");
61
62 span.in_scope(|| match self.0 {
63 RedactedHashAlg::Blake2 => {
64 let mut out = [0u8; 32];
65
66 let mut hasher = match Blake2bVar::new(32) {
67 Ok(v) => v,
68 Err(err) => {
69 tracing::error!(%err);
70 return "redacted".into();
71 }
72 };
73
74 hasher.update(header_value.as_bytes());
75 if let Err(err) = hasher.finalize_variable(&mut out) {
76 tracing::error!(%err);
77 return "redacted".into();
78 }
79
80 b64.encode(out)
81 }
82 RedactedHashAlg::Blake3 => b64.encode(blake3::hash(header_value.as_bytes()).as_bytes()),
83 })
84 }
85}
86pub struct HeadersDebug<'a>(
87 pub &'a HeaderMap,
88 pub Arc<Option<WrappedRedactedHashingAlg>>,
89);
90
#[cfg(tracing_unstable)]
impl Valuable for HeadersDebug<'_> {
    /// Exposes the headers as a map value.
    fn as_value(&self) -> Value<'_> {
        Value::Mappable(self)
    }

    /// Emits one `name -> value` entry per header value that can be read as a
    /// string; other values are skipped.
    ///
    /// `authorization`, `proxy-authorization`, `cookie` and `set-cookie` are
    /// replaced by their hash when a hasher is configured, or by `"redacted"`
    /// otherwise.
    fn visit(&self, visit: &mut dyn Visit) {
        let is_sensitive = |name: &HeaderName| {
            name == AUTHORIZATION
                || name == PROXY_AUTHORIZATION
                || name == COOKIE
                || name == SET_COOKIE
        };

        for (name, raw) in self.0 {
            let Ok(text) = raw.to_str() else {
                continue;
            };

            if !is_sensitive(name) {
                visit.visit_entry(name.as_str().as_value(), text.as_value());
                continue;
            }

            match &*self.1 {
                Some(hasher) => {
                    visit.visit_entry(name.as_str().as_value(), hasher.hash(text).as_value());
                }
                None => visit.visit_entry(name.as_str().as_value(), "redacted".as_value()),
            }
        }
    }
}
114
#[cfg(tracing_unstable)]
impl Mappable for HeadersDebug<'_> {
    /// Exact number of entries `visit` will emit.
    ///
    /// `visit` skips header values that `HeaderValue::to_str` rejects. The raw
    /// header-map size hint may therefore over-report the lower bound, so the
    /// entries that will actually be visited are counted instead.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let visible = self
            .0
            .values()
            .filter(|value| value.to_str().is_ok())
            .count();
        (visible, Some(visible))
    }
}
121
122impl Debug for HeadersDebug<'_> {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 use std::fmt::Write;
125
126 f.write_char('{')?;
127
128 let mut is_first = true;
129
130 for (k, v) in self.0 {
131 if let Ok(v) = v.to_str() {
132 if is_first {
133 is_first = false;
134 f.write_char('"')?;
135 } else {
136 f.write_str(",\"")?;
137 }
138
139 f.write_str(k.as_str())?;
140 f.write_str("\":\"")?;
141
142 if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE
143 {
144 f.write_str("redacted")?;
145 f.write_char('"')?;
146 } else {
147 f.write_str(v)?;
148 f.write_char('"')?;
149 }
150 }
151 }
152
153 f.write_char('}')
154 }
155}
156
157pub fn get_bearer_token_as_bytes(headers: &HeaderMap) -> anyhow::Result<Vec<u8>> {
158 if let Some(auth_header) = headers.get("authorization")
159 && let Ok(str_val) = auth_header.to_str()
160 && let Some(b64_token) = str_val.strip_prefix("Bearer ")
161 && let Ok(token) = b64.decode(b64_token)
162 {
163 Ok(token)
164 } else {
165 Err(anyhow!("failed to get token as bytes"))
166 }
167}
168
169pub fn get_host(headers: &HeaderMap, uri: &Uri) -> Option<String> {
170 if let Some(forwarded_values) = headers.get(header::FORWARDED)
171 && let Ok(forwarded_values_str) = forwarded_values.to_str()
172 && let Some(first_value) = forwarded_values_str.split(',').next()
173 && let Some(host) = first_value.split(';').find_map(|pair| {
174 let (key, value) = pair.split_once('=')?;
175 key.trim()
176 .eq_ignore_ascii_case("host")
177 .then(|| value.trim().trim_matches('"'))
178 })
179 {
180 return Some(host.to_owned());
181 }
182
183 if let Some(host) = headers
184 .get(X_FORWARDED_HOST_HEADER_KEY)
185 .and_then(|host| host.to_str().ok())
186 {
187 return Some(host.to_owned());
188 }
189
190 if let Some(host) = headers
191 .get(header::HOST)
192 .and_then(|host| host.to_str().ok())
193 {
194 return Some(host.to_owned());
195 }
196
197 if let Some(authority) = uri.authority() {
198 return authority.as_str().rsplit('@').next().map(ToOwned::to_owned);
199 }
200
201 None
202}
203
/// Human-readable rendering of a duration given in nanoseconds.
///
/// Picks the smallest unit among `ns`, `µs`, `ms`, `s` that keeps the printed
/// value below 1000 and shows three significant digits (`9.87ms`, `98.7ms`,
/// `987ms`). Durations of 1000 seconds or more are printed as whole seconds.
pub struct LatencyDisplay(pub f64);

impl Display for LatencyDisplay {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each threshold sits half a unit of the printed precision below the
        // next boundary. Values that would round up across it are promoted to
        // the next precision or unit, e.g. 9.996 prints "10.0" rather than
        // "10.00", and 999.7 prints "1.00" of the next unit rather than "1000".
        let mut t = self.0;

        for unit in ["ns", "µs", "ms", "s"] {
            if t < 9.995 {
                return write!(f, "{t:.2}{unit}");
            } else if t < 99.95 {
                return write!(f, "{t:.1}{unit}");
            } else if t < 999.5 {
                return write!(f, "{t:.0}{unit}");
            }
            t /= 1000.0;
        }
        write!(f, "{:.0}s", t * 1000.0)
    }
}
223
224#[allow(clippy::needless_pass_by_value)]
225pub fn response_for_panic(_: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
226 #[allow(clippy::declare_interior_mutable_const)]
227 const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
228
229 let mut res = Response::new(Full::new(Bytes::from_static(b"500 Internal Server Error")));
230
231 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
232 res.headers_mut().insert(header::CONTENT_TYPE, TEXT_PLAIN);
233
234 res
235}
236
237pub fn rustls_server_config(
238 key: impl AsRef<Path>,
239 cert: impl AsRef<Path>,
240) -> anyhow::Result<Arc<ServerConfig>> {
241 let key = PrivateKeyDer::from_pem_file(key)?;
242
243 let certs = CertificateDer::pem_file_iter(cert)?.flatten().collect();
244
245 let mut config = ServerConfig::builder()
246 .with_no_client_auth()
247 .with_single_cert(certs, key)?;
248
249 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
250
251 Ok(Arc::new(config))
252}
253
254pub fn generate_self_signed_localhost_certs(cert_dir_path: impl AsRef<Path>) -> anyhow::Result<()> {
256 std::fs::create_dir_all(&cert_dir_path)?;
257
258 let cert_path = cert_dir_path.as_ref().join("crt.pem");
259 let key_path = cert_dir_path.as_ref().join("key.pem");
260
261 if !cert_path.exists() || !key_path.exists() {
262 let subject_alt_names = vec!["localhost".to_string()];
263
264 let CertifiedKey { cert, signing_key } =
265 match generate_simple_self_signed(subject_alt_names) {
266 Ok(ck) => {
267 tracing::info!("generated self-signed localhost cert");
268 ck
269 }
270 Err(err) => {
271 tracing::error!("failed to generate self-signed localhost cert");
272 return Err(err.into());
273 }
274 };
275
276 let cert = cert.pem();
277 let key = signing_key.serialize_pem();
278
279 let mut cert_file = File::create(cert_path)?;
280 let mut key_file = File::create(key_path)?;
281
282 cert_file.write_all(cert.as_bytes())?;
283 key_file.write_all(key.as_bytes())?;
284 }
285
286 Ok(())
287}
288
289pub fn acme_task(
290 acme_span_clone: Span,
291 mut state: AcmeState<std::io::Error>,
292 signal_tx: Sender<()>,
293) {
294 tokio::spawn(async move {
295 loop {
296 let event = tokio::select! {
297 state = state.next() => state,
298 () = signal_tx.closed() => {
299 acme_span_clone.in_scope(|| {
300 tracing::warn!("not accepting new connections");
301 });
302 break;
303 }
304 };
305
306 if let Some(event) = event {
307 match event {
308 Ok(evt) => {
309 acme_span_clone.in_scope(|| match evt {
310 EventOk::DeployedNewCert => {
311 tracing::info!(evt.deploy = %"new", "cert");
312 }
313 EventOk::CertCacheStore => {
314 tracing::info!(evt.cache = %"stored", "cert");
315 }
316 EventOk::AccountCacheStore => {
317 tracing::info!(evt.cache = %"stored", "account");
318 }
319 EventOk::DeployedCachedCert => {
320 tracing::info!(evt.deploy = %"cached", "cert");
321 }
322 });
323 }
324 Err(err) => match err {
325 EventError::AccountCacheStore(err) => {
326 tracing::error!(%err, evt.cache = %"store", "account");
327 }
328 EventError::CertCacheStore(err) => {
329 tracing::error!(%err, evt.cache = %"store", "cert");
330 }
331 EventError::AccountCacheLoad(err) => {
332 tracing::error!(%err, evt.cache = %"load", "account");
333 }
334 EventError::CachedCertParse(err) => {
335 tracing::error!(%err, evt.parse = %"cache", "cert");
336 }
337 EventError::NewCertParse(err) => {
338 tracing::error!(%err, evt.parse = %"new", "cert");
339 }
340 EventError::CertCacheLoad(err) => {
341 tracing::error!(%err, evt.cache = %"load", "cert");
342 }
343 EventError::Order(err) => {
344 tracing::error!(%err, "order");
345 }
346 },
347 }
348 } else {
349 break;
350 }
351 }
352 });
353}
354
355pub fn redirect_service<H, T, S>(
356 span_clone: Span,
357 redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
358 log_ips: bool,
359 log_headers: bool,
360 request_id_header: HeaderName,
361 handler: H,
362 state: S,
363) -> Router
364where
365 H: Handler<T, S>,
366 T: 'static,
367 S: Clone + Send + Sync + 'static,
368{
369 let redacted_hash_clone = redacted_hash.clone();
370
371 Router::new()
372 .route("/healthz", get(|| async { StatusCode::OK }))
373 .fallback(handler)
374 .with_state(state)
375 .layer(
376 ServiceBuilder::new()
377 .layer(CatchPanicLayer::custom(response_for_panic))
378 .layer(RequestDecompressionLayer::new())
379 .layer(CompressionLayer::new()),
380 )
381 .layer(
382 ServiceBuilder::new()
383 .layer(SetRequestIdLayer::new(
384 request_id_header.clone(),
385 MakeRequestUuid,
386 ))
387 .layer(
388 TraceLayer::new_for_http()
389 .make_span_with(move |req: &Request<_>| {
390 let request_id = req.headers().get(REQUEST_ID_HEADER);
391
392 let host =
393 get_host(req.headers(), req.uri()).map(tracing::field::display);
394
395 let ip = log_ips.then(|| {
396 req.extensions()
397 .get::<ConnectInfo<SocketAddr>>()
398 .map(|addr| tracing::field::display(addr.ip()))
399 });
400
401 let query = req.uri().query().map(tracing::field::display);
402
403 span_clone.in_scope(|| match request_id {
404 Some(rid) => {
405 tracing::warn_span!(
406 "redirect",
407 host,
408 id = %rid
409 .to_str()
410 .unwrap_or(Uuid::new_v4().to_string().as_str()),
411 ip,
412 path = %req.uri().path(),
413 query,
414 )
415 }
416 None => {
417 tracing::warn_span!(
418 "redirect",
419 host,
420 id = %Uuid::new_v4(),
421 ip,
422 path = %req.uri().path(),
423 query,
424 )
425 }
426 })
427 })
428 .on_request(move |req: &Request<_>, _: &Span| {
429 let hd = log_headers
430 .then_some(HeadersDebug(req.headers(), redacted_hash.clone()));
431
432 #[cfg(tracing_unstable)]
433 let headers = log_headers.then_some(tracing::field::valuable(&hd));
434
435 #[cfg(not(tracing_unstable))]
436 let headers = log_headers.then_some(tracing::field::debug(&hd));
437
438 tracing::warn!(
439 version = ?req.version(),
440 method = %req.method(),
441 headers,
442 "req"
443 );
444 })
445 .on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
446 let hd = log_headers.then_some(HeadersDebug(
447 res.headers(),
448 redacted_hash_clone.clone(),
449 ));
450
451 #[cfg(tracing_unstable)]
452 let headers = log_headers.then_some(tracing::field::valuable(&hd));
453
454 #[cfg(not(tracing_unstable))]
455 let headers = log_headers.then_some(tracing::field::debug(&hd));
456
457 let status = res.status().as_u16();
458 let latency = LatencyDisplay(latency.as_nanos() as f64);
459
460 if status >= 500 {
461 tracing::error!(status, headers, %latency, "res");
462 } else if status >= 400 {
463 tracing::warn!(status, headers, %latency, "res");
464 } else {
465 tracing::info!(status, headers, %latency, "res");
466 }
467 })
468 .on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
469 tracing::error!(
470 err = %error,
471 "fail"
472 );
473 }),
474 )
475 .layer(TimeoutLayer::with_status_code(
476 StatusCode::REQUEST_TIMEOUT,
477 Duration::from_secs(5),
478 ))
479 .layer(PropagateRequestIdLayer::new(request_id_header))
480 .layer(SetResponseHeaderLayer::if_not_present(
481 header::SERVER,
482 HeaderValue::from_static("Ordinary"),
483 )),
484 )
485}