1use axum::extract::ConnectInfo;
6use axum::http::{HeaderName, HeaderValue, Request};
7use axum::routing::get;
8use futures_util::stream::StreamExt;
9use std::net::SocketAddr;
10use std::time::Duration;
11use tower::ServiceBuilder;
12use tower_http::classify::ServerErrorsFailureClass;
13use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
14use tower_http::set_header::SetResponseHeaderLayer;
15use tower_http::timeout::TimeoutLayer;
16use tower_http::trace::TraceLayer;
17use tracing::Span;
18use uuid::Uuid;
19
20use axum::Router;
21use axum::handler::Handler;
22use axum::response::Response;
23use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
24use blake2::{
25 Blake2bVar,
26 digest::{Update, VariableOutput},
27};
28use bytes::Bytes;
29use http_body_util::Full;
30use hyper::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, SET_COOKIE};
31use hyper::{HeaderMap, StatusCode, Uri, header};
32use ordinary_config::RedactedHashAlg;
33use rcgen::{CertifiedKey, generate_simple_self_signed};
34use rustls_acme::{AcmeState, EventError, EventOk};
35use std::any::Any;
36use std::error::Error;
37use std::fmt;
38use std::fmt::{Debug, Display};
39use std::fs::File;
40use std::io::Write;
41use std::path::Path;
42use std::sync::Arc;
43use tokio::sync::watch::Sender;
44use tokio_rustls::{
45 rustls::ServerConfig,
46 rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
47};
48use tower_http::catch_panic::CatchPanicLayer;
49use tower_http::compression::CompressionLayer;
50use tower_http::decompression::RequestDecompressionLayer;
51use valuable::{Mappable, Valuable, Value, Visit};
52
/// Header name `x-request-id`, conventionally used to carry a per-request identifier.
pub const REQUEST_ID_HEADER: &str = "x-request-id";
/// `X-Forwarded-Host` header name; [`get_host`] consults it when the `Forwarded` header does
/// not supply a host. (`HeaderMap` lookups by `&str` are case-insensitive.)
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
55
56pub struct WrappedRedactedHashingAlg(pub RedactedHashAlg);
57
58impl WrappedRedactedHashingAlg {
59 fn hash(&self, header_value: &str) -> String {
60 let span = tracing::info_span!("redacted:hash");
61
62 span.in_scope(|| match self.0 {
63 RedactedHashAlg::Blake2 => {
64 let mut out = [0u8; 32];
65
66 let mut hasher = match Blake2bVar::new(32) {
67 Ok(v) => v,
68 Err(err) => {
69 tracing::error!(%err);
70 return "redacted".into();
71 }
72 };
73
74 hasher.update(header_value.as_bytes());
75 if let Err(err) = hasher.finalize_variable(&mut out) {
76 tracing::error!(%err);
77 return "redacted".into();
78 }
79
80 b64.encode(out)
81 }
82 RedactedHashAlg::Blake3 => b64.encode(blake3::hash(header_value.as_bytes()).as_bytes()),
83 })
84 }
85}
86pub struct HeadersDebug<'a>(
87 pub &'a HeaderMap,
88 pub Arc<Option<WrappedRedactedHashingAlg>>,
89);
90
91#[cfg(tracing_unstable)]
92impl Valuable for HeadersDebug<'_> {
93 fn as_value(&self) -> Value<'_> {
94 Value::Mappable(self)
95 }
96
97 fn visit(&self, visit: &mut dyn Visit) {
98 for (k, v) in self.0 {
99 if let Ok(v) = v.to_str() {
100 if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE
101 {
102 if let Some(hasher) = &*self.1 {
103 visit.visit_entry(k.as_str().as_value(), hasher.hash(v).as_value());
104 } else {
105 visit.visit_entry(k.as_str().as_value(), "redacted".as_value());
106 }
107 } else {
108 visit.visit_entry(k.as_str().as_value(), v.as_value());
109 }
110 }
111 }
112 }
113}
114
115#[cfg(tracing_unstable)]
116impl Mappable for HeadersDebug<'_> {
117 fn size_hint(&self) -> (usize, Option<usize>) {
118 self.0.iter().size_hint()
119 }
120}
121
122impl Debug for HeadersDebug<'_> {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 use std::fmt::Write;
125
126 f.write_char('{')?;
127
128 let mut is_first = true;
129
130 for (k, v) in self.0 {
131 if let Ok(v) = v.to_str() {
132 if is_first {
133 is_first = false;
134 f.write_char('"')?;
135 } else {
136 f.write_str(",\"")?;
137 }
138
139 f.write_str(k.as_str())?;
140 f.write_str("\":\"")?;
141
142 if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE
143 {
144 f.write_str("redacted")?;
145 f.write_char('"')?;
146 } else {
147 f.write_str(v)?;
148 f.write_char('"')?;
149 }
150 }
151 }
152
153 f.write_char('}')
154 }
155}
156
157pub fn get_host(headers: &HeaderMap, uri: &Uri) -> Option<String> {
158 if let Some(forwarded_values) = headers.get(header::FORWARDED)
159 && let Ok(forwarded_values_str) = forwarded_values.to_str()
160 && let Some(first_value) = forwarded_values_str.split(',').next()
161 && let Some(host) = first_value.split(';').find_map(|pair| {
162 let (key, value) = pair.split_once('=')?;
163 key.trim()
164 .eq_ignore_ascii_case("host")
165 .then(|| value.trim().trim_matches('"'))
166 })
167 {
168 return Some(host.to_owned());
169 }
170
171 if let Some(host) = headers
172 .get(X_FORWARDED_HOST_HEADER_KEY)
173 .and_then(|host| host.to_str().ok())
174 {
175 return Some(host.to_owned());
176 }
177
178 if let Some(host) = headers
179 .get(header::HOST)
180 .and_then(|host| host.to_str().ok())
181 {
182 return Some(host.to_owned());
183 }
184
185 if let Some(authority) = uri.authority() {
186 return authority.as_str().rsplit('@').next().map(ToOwned::to_owned);
187 }
188
189 None
190}
191
/// Human-readable latency, constructed from a duration in **nanoseconds**.
///
/// Picks the first unit among `ns`, `µs`, `ms`, `s` whose rounded value stays below 1000 and
/// prints roughly three significant digits (`4.20ms`, `42.0ms`, `420ms`). Durations of
/// 1000 s or more are printed as whole seconds (`5000s`).
pub struct LatencyDisplay(pub f64);

impl Display for LatencyDisplay {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut t = self.0;

        for unit in ["ns", "µs", "ms", "s"] {
            // The thresholds are the points where rounding to the chosen precision would add a
            // digit (e.g. 9.996 -> "10.00", 999.7 -> "1000"). Compare against them instead of
            // 10/100/1000 so the output never gains a fourth significant digit and never shows
            // a four-digit value when the next unit up applies.
            if t < 9.995 {
                return write!(f, "{t:.2}{unit}");
            } else if t < 99.95 {
                return write!(f, "{t:.1}{unit}");
            } else if t < 999.5 {
                return write!(f, "{t:.0}{unit}");
            }
            t /= 1000.0;
        }

        // `t` is now in kiloseconds; print whole seconds.
        write!(f, "{:.0}s", t * 1000.0)
    }
}
211
212#[allow(clippy::needless_pass_by_value)]
213pub fn response_for_panic(_: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
214 #[allow(clippy::declare_interior_mutable_const)]
215 const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
216
217 let mut res = Response::new(Full::new(Bytes::from_static(b"500 Internal Server Error")));
218
219 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
220 res.headers_mut().insert(header::CONTENT_TYPE, TEXT_PLAIN);
221
222 res
223}
224
225pub fn rustls_server_config(
226 key: impl AsRef<Path>,
227 cert: impl AsRef<Path>,
228) -> Result<Arc<ServerConfig>, Box<dyn Error>> {
229 let key = PrivateKeyDer::from_pem_file(key)?;
230
231 let certs = CertificateDer::pem_file_iter(cert)?.flatten().collect();
232
233 let mut config = ServerConfig::builder()
234 .with_no_client_auth()
235 .with_single_cert(certs, key)?;
236
237 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
238
239 Ok(Arc::new(config))
240}
241
242pub fn generate_self_signed_localhost_certs(
244 cert_dir_path: impl AsRef<Path>,
245) -> Result<(), Box<dyn Error>> {
246 std::fs::create_dir_all(&cert_dir_path)?;
247
248 let cert_path = cert_dir_path.as_ref().join("crt.pem");
249 let key_path = cert_dir_path.as_ref().join("key.pem");
250
251 if !cert_path.exists() || !key_path.exists() {
252 let subject_alt_names = vec!["localhost".to_string()];
253
254 let CertifiedKey { cert, signing_key } =
255 match generate_simple_self_signed(subject_alt_names) {
256 Ok(ck) => {
257 tracing::info!("generated self-signed localhost cert");
258 ck
259 }
260 Err(err) => {
261 tracing::error!("failed to generate self-signed localhost cert");
262 return Err(err.into());
263 }
264 };
265
266 let cert = cert.pem();
267 let key = signing_key.serialize_pem();
268
269 let mut cert_file = File::create(cert_path)?;
270 let mut key_file = File::create(key_path)?;
271
272 cert_file.write_all(cert.as_bytes())?;
273 key_file.write_all(key.as_bytes())?;
274 }
275
276 Ok(())
277}
278
279pub fn acme_task(
280 acme_span_clone: Span,
281 mut state: AcmeState<std::io::Error>,
282 signal_tx: Sender<()>,
283) {
284 tokio::spawn(async move {
285 loop {
286 let event = tokio::select! {
287 state = state.next() => state,
288 () = signal_tx.closed() => {
289 acme_span_clone.in_scope(|| {
290 tracing::warn!("not accepting new connections");
291 });
292 break;
293 }
294 };
295
296 if let Some(event) = event {
297 match event {
298 Ok(evt) => {
299 acme_span_clone.in_scope(|| match evt {
300 EventOk::DeployedNewCert => {
301 tracing::info!(evt.deploy = %"new", "cert");
302 }
303 EventOk::CertCacheStore => {
304 tracing::info!(evt.cache = %"stored", "cert");
305 }
306 EventOk::AccountCacheStore => {
307 tracing::info!(evt.cache = %"stored", "account");
308 }
309 EventOk::DeployedCachedCert => {
310 tracing::info!(evt.deploy = %"cached", "cert");
311 }
312 });
313 }
314 Err(err) => match err {
315 EventError::AccountCacheStore(err) => {
316 tracing::error!(%err, evt.cache = %"store", "account");
317 }
318 EventError::CertCacheStore(err) => {
319 tracing::error!(%err, evt.cache = %"store", "cert");
320 }
321 EventError::AccountCacheLoad(err) => {
322 tracing::error!(%err, evt.cache = %"load", "account");
323 }
324 EventError::CachedCertParse(err) => {
325 tracing::error!(%err, evt.parse = %"cache", "cert");
326 }
327 EventError::NewCertParse(err) => {
328 tracing::error!(%err, evt.parse = %"new", "cert");
329 }
330 EventError::CertCacheLoad(err) => {
331 tracing::error!(%err, evt.cache = %"load", "cert");
332 }
333 EventError::Order(err) => {
334 tracing::error!(%err, "order");
335 }
336 },
337 }
338 } else {
339 break;
340 }
341 }
342 });
343}
344
345pub fn redirect_service<H, T, S>(
346 span_clone: Span,
347 redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
348 log_ips: bool,
349 log_headers: bool,
350 request_id_header: HeaderName,
351 handler: H,
352 state: S,
353) -> Router
354where
355 H: Handler<T, S>,
356 T: 'static,
357 S: Clone + Send + Sync + 'static,
358{
359 let redacted_hash_clone = redacted_hash.clone();
360
361 Router::new()
362 .route("/healthz", get(|| async { StatusCode::OK }))
363 .fallback(handler)
364 .with_state(state)
365 .layer(
366 ServiceBuilder::new()
367 .layer(CatchPanicLayer::custom(response_for_panic))
368 .layer(RequestDecompressionLayer::new())
369 .layer(CompressionLayer::new()),
370 )
371 .layer(
372 ServiceBuilder::new()
373 .layer(SetRequestIdLayer::new(
374 request_id_header.clone(),
375 MakeRequestUuid,
376 ))
377 .layer(
378 TraceLayer::new_for_http()
379 .make_span_with(move |req: &Request<_>| {
380 let request_id = req.headers().get(REQUEST_ID_HEADER);
381
382 let host =
383 get_host(req.headers(), req.uri()).map(tracing::field::display);
384
385 let ip = log_ips.then(|| {
386 req.extensions()
387 .get::<ConnectInfo<SocketAddr>>()
388 .map(|addr| tracing::field::display(addr.ip()))
389 });
390
391 let query = req.uri().query().map(tracing::field::display);
392
393 span_clone.in_scope(|| match request_id {
394 Some(rid) => {
395 tracing::warn_span!(
396 "redirect",
397 host,
398 id = %rid
399 .to_str()
400 .unwrap_or(Uuid::new_v4().to_string().as_str()),
401 ip,
402 path = %req.uri().path(),
403 query,
404 )
405 }
406 None => {
407 tracing::warn_span!(
408 "redirect",
409 host,
410 id = %Uuid::new_v4(),
411 ip,
412 path = %req.uri().path(),
413 query,
414 )
415 }
416 })
417 })
418 .on_request(move |req: &Request<_>, _: &Span| {
419 let hd = log_headers
420 .then_some(HeadersDebug(req.headers(), redacted_hash.clone()));
421
422 #[cfg(tracing_unstable)]
423 let headers = log_headers.then_some(tracing::field::valuable(&hd));
424
425 #[cfg(not(tracing_unstable))]
426 let headers = log_headers.then_some(tracing::field::debug(&hd));
427
428 tracing::warn!(
429 version = ?req.version(),
430 method = %req.method(),
431 headers,
432 "req"
433 );
434 })
435 .on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
436 let hd = log_headers.then_some(HeadersDebug(
437 res.headers(),
438 redacted_hash_clone.clone(),
439 ));
440
441 #[cfg(tracing_unstable)]
442 let headers = log_headers.then_some(tracing::field::valuable(&hd));
443
444 #[cfg(not(tracing_unstable))]
445 let headers = log_headers.then_some(tracing::field::debug(&hd));
446
447 let status = res.status().as_u16();
448 let latency = LatencyDisplay(latency.as_nanos() as f64);
449
450 if status >= 500 {
451 tracing::error!(status, headers, %latency, "res");
452 } else if status >= 400 {
453 tracing::warn!(status, headers, %latency, "res");
454 } else {
455 tracing::info!(status, headers, %latency, "res");
456 }
457 })
458 .on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
459 tracing::error!(
460 err = %error,
461 "fail"
462 );
463 }),
464 )
465 .layer(TimeoutLayer::with_status_code(
466 StatusCode::REQUEST_TIMEOUT,
467 Duration::from_secs(5),
468 ))
469 .layer(PropagateRequestIdLayer::new(request_id_header))
470 .layer(SetResponseHeaderLayer::if_not_present(
471 header::SERVER,
472 HeaderValue::from_static("Ordinary"),
473 )),
474 )
475}