1use axum::extract::ConnectInfo;
6use axum::http::{HeaderName, HeaderValue, Request};
7use axum::routing::get;
8use futures_util::stream::StreamExt;
9use std::net::SocketAddr;
10use std::time::Duration;
11use tower::ServiceBuilder;
12use tower_http::classify::ServerErrorsFailureClass;
13use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
14use tower_http::set_header::SetResponseHeaderLayer;
15use tower_http::timeout::TimeoutLayer;
16use tower_http::trace::TraceLayer;
17use tracing::Span;
18use uuid::Uuid;
19
20use axum::Router;
21use axum::handler::Handler;
22use axum::response::Response;
23use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
24use blake2::{
25 Blake2bVar,
26 digest::{Update, VariableOutput},
27};
28use bytes::Bytes;
29use http_body_util::Full;
30use hyper::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, SET_COOKIE};
31use hyper::{HeaderMap, StatusCode, Uri, header};
32use ordinary_config::RedactedHashAlg;
33use rcgen::{CertifiedKey, generate_simple_self_signed};
34use rustls_acme::{AcmeState, EventError, EventOk};
35use std::any::Any;
36use std::fmt;
37use std::fmt::{Debug, Display};
38use std::fs::File;
39use std::io::Write;
40use std::path::Path;
41use std::sync::Arc;
42use tokio::sync::watch::Sender;
43use tokio_rustls::{
44 rustls::ServerConfig,
45 rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
46};
47use tower_http::catch_panic::CatchPanicLayer;
48use tower_http::compression::CompressionLayer;
49use tower_http::decompression::RequestDecompressionLayer;
50use valuable::{Mappable, Valuable, Value, Visit};
51
/// Default name of the request-id header (`x-request-id`).
pub const REQUEST_ID_HEADER: &str = "x-request-id";
/// Proxy header consulted by `get_host` when no `Forwarded` host is present.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
54
55pub struct WrappedRedactedHashingAlg(pub RedactedHashAlg);
56
57impl WrappedRedactedHashingAlg {
58 fn hash(&self, header_value: &str) -> String {
59 let span = tracing::info_span!("redacted:hash");
60
61 span.in_scope(|| match self.0 {
62 RedactedHashAlg::Blake2 => {
63 let mut out = [0u8; 32];
64
65 let mut hasher = match Blake2bVar::new(32) {
66 Ok(v) => v,
67 Err(err) => {
68 tracing::error!(%err);
69 return "redacted".into();
70 }
71 };
72
73 hasher.update(header_value.as_bytes());
74 if let Err(err) = hasher.finalize_variable(&mut out) {
75 tracing::error!(%err);
76 return "redacted".into();
77 }
78
79 b64.encode(out)
80 }
81 RedactedHashAlg::Blake3 => b64.encode(blake3::hash(header_value.as_bytes()).as_bytes()),
82 })
83 }
84}
/// Borrowed view of a [`HeaderMap`] used when logging request/response headers.
///
/// The second field optionally carries a hashing algorithm. With it, the sensitive
/// headers (`authorization`, `proxy-authorization`, `cookie`, `set-cookie`) may be
/// logged as a digest. Without one, they are logged as the placeholder `redacted`.
pub struct HeadersDebug<'a>(
    pub &'a HeaderMap,
    pub Arc<Option<WrappedRedactedHashingAlg>>,
);
89
#[cfg(tracing_unstable)]
impl Valuable for HeadersDebug<'_> {
    /// Presents the header map to `valuable` as a mappable value.
    fn as_value(&self) -> Value<'_> {
        Value::Mappable(self)
    }

    /// Emits one `name => value` entry per header whose value is visible ASCII.
    ///
    /// Values of `authorization`, `proxy-authorization`, `cookie` and `set-cookie`
    /// are replaced by their digest when a hashing algorithm is configured, or by
    /// the placeholder `redacted` otherwise. Non-ASCII values are skipped.
    fn visit(&self, visit: &mut dyn Visit) {
        for (name, raw) in self.0 {
            let Ok(text) = raw.to_str() else {
                continue;
            };

            let is_sensitive = name == AUTHORIZATION
                || name == PROXY_AUTHORIZATION
                || name == COOKIE
                || name == SET_COOKIE;
            let key = name.as_str().as_value();

            match (is_sensitive, &*self.1) {
                (false, _) => visit.visit_entry(key, text.as_value()),
                (true, Some(hasher)) => visit.visit_entry(key, hasher.hash(text).as_value()),
                (true, None) => visit.visit_entry(key, "redacted".as_value()),
            }
        }
    }
}
113
#[cfg(tracing_unstable)]
impl Mappable for HeadersDebug<'_> {
    /// Forwards the size hint of the underlying header iterator.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let entries = self.0.iter();
        entries.size_hint()
    }
}
120
121impl Debug for HeadersDebug<'_> {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 use std::fmt::Write;
124
125 f.write_char('{')?;
126
127 let mut is_first = true;
128
129 for (k, v) in self.0 {
130 if let Ok(v) = v.to_str() {
131 if is_first {
132 is_first = false;
133 f.write_char('"')?;
134 } else {
135 f.write_str(",\"")?;
136 }
137
138 f.write_str(k.as_str())?;
139 f.write_str("\":\"")?;
140
141 if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE
142 {
143 f.write_str("redacted")?;
144 f.write_char('"')?;
145 } else {
146 f.write_str(v)?;
147 f.write_char('"')?;
148 }
149 }
150 }
151
152 f.write_char('}')
153 }
154}
155
156pub fn get_host(headers: &HeaderMap, uri: &Uri) -> Option<String> {
157 if let Some(forwarded_values) = headers.get(header::FORWARDED)
158 && let Ok(forwarded_values_str) = forwarded_values.to_str()
159 && let Some(first_value) = forwarded_values_str.split(',').next()
160 && let Some(host) = first_value.split(';').find_map(|pair| {
161 let (key, value) = pair.split_once('=')?;
162 key.trim()
163 .eq_ignore_ascii_case("host")
164 .then(|| value.trim().trim_matches('"'))
165 })
166 {
167 return Some(host.to_owned());
168 }
169
170 if let Some(host) = headers
171 .get(X_FORWARDED_HOST_HEADER_KEY)
172 .and_then(|host| host.to_str().ok())
173 {
174 return Some(host.to_owned());
175 }
176
177 if let Some(host) = headers
178 .get(header::HOST)
179 .and_then(|host| host.to_str().ok())
180 {
181 return Some(host.to_owned());
182 }
183
184 if let Some(authority) = uri.authority() {
185 return authority.as_str().rsplit('@').next().map(ToOwned::to_owned);
186 }
187
188 None
189}
190
/// Displays a latency given in nanoseconds with about three significant digits.
///
/// The value is printed in the first of `ns`, `µs`, `ms`, `s` in which it is below
/// 1000, e.g. `1500.0` renders as `1.50µs`. Latencies of 1000 s or more are printed
/// as whole seconds.
pub struct LatencyDisplay(pub f64);

impl Display for LatencyDisplay {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const UNITS: [&str; 4] = ["ns", "µs", "ms", "s"];

        let mut value = self.0;

        for unit in UNITS {
            match value {
                v if v < 10.0 => return write!(f, "{v:.2}{unit}"),
                v if v < 100.0 => return write!(f, "{v:.1}{unit}"),
                v if v < 1000.0 => return write!(f, "{v:.0}{unit}"),
                _ => value /= 1000.0,
            }
        }

        // Past the largest unit: undo the last division and print whole seconds.
        write!(f, "{:.0}s", value * 1000.0)
    }
}
210
211#[allow(clippy::needless_pass_by_value)]
212pub fn response_for_panic(_: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
213 #[allow(clippy::declare_interior_mutable_const)]
214 const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
215
216 let mut res = Response::new(Full::new(Bytes::from_static(b"500 Internal Server Error")));
217
218 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
219 res.headers_mut().insert(header::CONTENT_TYPE, TEXT_PLAIN);
220
221 res
222}
223
224pub fn rustls_server_config(
225 key: impl AsRef<Path>,
226 cert: impl AsRef<Path>,
227) -> anyhow::Result<Arc<ServerConfig>> {
228 let key = PrivateKeyDer::from_pem_file(key)?;
229
230 let certs = CertificateDer::pem_file_iter(cert)?.flatten().collect();
231
232 let mut config = ServerConfig::builder()
233 .with_no_client_auth()
234 .with_single_cert(certs, key)?;
235
236 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
237
238 Ok(Arc::new(config))
239}
240
241pub fn generate_self_signed_localhost_certs(cert_dir_path: impl AsRef<Path>) -> anyhow::Result<()> {
243 std::fs::create_dir_all(&cert_dir_path)?;
244
245 let cert_path = cert_dir_path.as_ref().join("crt.pem");
246 let key_path = cert_dir_path.as_ref().join("key.pem");
247
248 if !cert_path.exists() || !key_path.exists() {
249 let subject_alt_names = vec!["localhost".to_string()];
250
251 let CertifiedKey { cert, signing_key } =
252 match generate_simple_self_signed(subject_alt_names) {
253 Ok(ck) => {
254 tracing::info!("generated self-signed localhost cert");
255 ck
256 }
257 Err(err) => {
258 tracing::error!("failed to generate self-signed localhost cert");
259 return Err(err.into());
260 }
261 };
262
263 let cert = cert.pem();
264 let key = signing_key.serialize_pem();
265
266 let mut cert_file = File::create(cert_path)?;
267 let mut key_file = File::create(key_path)?;
268
269 cert_file.write_all(cert.as_bytes())?;
270 key_file.write_all(key.as_bytes())?;
271 }
272
273 Ok(())
274}
275
276pub fn acme_task(
277 acme_span_clone: Span,
278 mut state: AcmeState<std::io::Error>,
279 signal_tx: Sender<()>,
280) {
281 tokio::spawn(async move {
282 loop {
283 let event = tokio::select! {
284 state = state.next() => state,
285 () = signal_tx.closed() => {
286 acme_span_clone.in_scope(|| {
287 tracing::warn!("not accepting new connections");
288 });
289 break;
290 }
291 };
292
293 if let Some(event) = event {
294 match event {
295 Ok(evt) => {
296 acme_span_clone.in_scope(|| match evt {
297 EventOk::DeployedNewCert => {
298 tracing::info!(evt.deploy = %"new", "cert");
299 }
300 EventOk::CertCacheStore => {
301 tracing::info!(evt.cache = %"stored", "cert");
302 }
303 EventOk::AccountCacheStore => {
304 tracing::info!(evt.cache = %"stored", "account");
305 }
306 EventOk::DeployedCachedCert => {
307 tracing::info!(evt.deploy = %"cached", "cert");
308 }
309 });
310 }
311 Err(err) => match err {
312 EventError::AccountCacheStore(err) => {
313 tracing::error!(%err, evt.cache = %"store", "account");
314 }
315 EventError::CertCacheStore(err) => {
316 tracing::error!(%err, evt.cache = %"store", "cert");
317 }
318 EventError::AccountCacheLoad(err) => {
319 tracing::error!(%err, evt.cache = %"load", "account");
320 }
321 EventError::CachedCertParse(err) => {
322 tracing::error!(%err, evt.parse = %"cache", "cert");
323 }
324 EventError::NewCertParse(err) => {
325 tracing::error!(%err, evt.parse = %"new", "cert");
326 }
327 EventError::CertCacheLoad(err) => {
328 tracing::error!(%err, evt.cache = %"load", "cert");
329 }
330 EventError::Order(err) => {
331 tracing::error!(%err, "order");
332 }
333 },
334 }
335 } else {
336 break;
337 }
338 }
339 });
340}
341
342pub fn redirect_service<H, T, S>(
343 span_clone: Span,
344 redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
345 log_ips: bool,
346 log_headers: bool,
347 request_id_header: HeaderName,
348 handler: H,
349 state: S,
350) -> Router
351where
352 H: Handler<T, S>,
353 T: 'static,
354 S: Clone + Send + Sync + 'static,
355{
356 let redacted_hash_clone = redacted_hash.clone();
357
358 Router::new()
359 .route("/healthz", get(|| async { StatusCode::OK }))
360 .fallback(handler)
361 .with_state(state)
362 .layer(
363 ServiceBuilder::new()
364 .layer(CatchPanicLayer::custom(response_for_panic))
365 .layer(RequestDecompressionLayer::new())
366 .layer(CompressionLayer::new()),
367 )
368 .layer(
369 ServiceBuilder::new()
370 .layer(SetRequestIdLayer::new(
371 request_id_header.clone(),
372 MakeRequestUuid,
373 ))
374 .layer(
375 TraceLayer::new_for_http()
376 .make_span_with(move |req: &Request<_>| {
377 let request_id = req.headers().get(REQUEST_ID_HEADER);
378
379 let host =
380 get_host(req.headers(), req.uri()).map(tracing::field::display);
381
382 let ip = log_ips.then(|| {
383 req.extensions()
384 .get::<ConnectInfo<SocketAddr>>()
385 .map(|addr| tracing::field::display(addr.ip()))
386 });
387
388 let query = req.uri().query().map(tracing::field::display);
389
390 span_clone.in_scope(|| match request_id {
391 Some(rid) => {
392 tracing::warn_span!(
393 "redirect",
394 host,
395 id = %rid
396 .to_str()
397 .unwrap_or(Uuid::new_v4().to_string().as_str()),
398 ip,
399 path = %req.uri().path(),
400 query,
401 )
402 }
403 None => {
404 tracing::warn_span!(
405 "redirect",
406 host,
407 id = %Uuid::new_v4(),
408 ip,
409 path = %req.uri().path(),
410 query,
411 )
412 }
413 })
414 })
415 .on_request(move |req: &Request<_>, _: &Span| {
416 let hd = log_headers
417 .then_some(HeadersDebug(req.headers(), redacted_hash.clone()));
418
419 #[cfg(tracing_unstable)]
420 let headers = log_headers.then_some(tracing::field::valuable(&hd));
421
422 #[cfg(not(tracing_unstable))]
423 let headers = log_headers.then_some(tracing::field::debug(&hd));
424
425 tracing::warn!(
426 version = ?req.version(),
427 method = %req.method(),
428 headers,
429 "req"
430 );
431 })
432 .on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
433 let hd = log_headers.then_some(HeadersDebug(
434 res.headers(),
435 redacted_hash_clone.clone(),
436 ));
437
438 #[cfg(tracing_unstable)]
439 let headers = log_headers.then_some(tracing::field::valuable(&hd));
440
441 #[cfg(not(tracing_unstable))]
442 let headers = log_headers.then_some(tracing::field::debug(&hd));
443
444 let status = res.status().as_u16();
445 let latency = LatencyDisplay(latency.as_nanos() as f64);
446
447 if status >= 500 {
448 tracing::error!(status, headers, %latency, "res");
449 } else if status >= 400 {
450 tracing::warn!(status, headers, %latency, "res");
451 } else {
452 tracing::info!(status, headers, %latency, "res");
453 }
454 })
455 .on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
456 tracing::error!(
457 err = %error,
458 "fail"
459 );
460 }),
461 )
462 .layer(TimeoutLayer::with_status_code(
463 StatusCode::REQUEST_TIMEOUT,
464 Duration::from_secs(5),
465 ))
466 .layer(PropagateRequestIdLayer::new(request_id_header))
467 .layer(SetResponseHeaderLayer::if_not_present(
468 header::SERVER,
469 HeaderValue::from_static("Ordinary"),
470 )),
471 )
472}