1use axum::extract::ConnectInfo;
6use axum::http::{HeaderName, HeaderValue, Request, Version};
7use axum::routing::get;
8use futures_util::stream::StreamExt;
9use std::net::{IpAddr, SocketAddr};
10use std::time::Duration;
11use tower::ServiceBuilder;
12use tower_http::timeout::TimeoutLayer;
13use tracing::{Instrument, Span};
14
15use crate::middleware::{ServiceKind, apply_common_middleware, x_via};
16use anyhow::anyhow;
17use axum::Router;
18use axum::body::Body;
19use axum::handler::Handler;
20use axum::response::Response;
21use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
22use blake2::{
23 Blake2bVar,
24 digest::{Update, VariableOutput},
25};
26use bytes::Bytes;
27use http_body_util::Full;
28use hyper::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, SET_COOKIE};
29use hyper::{HeaderMap, StatusCode, Uri, header};
30use ordinary_config::RedactedHashAlg;
31use rcgen::{CertifiedKey, generate_simple_self_signed};
32use rustls_acme::{AcmeState, EventError, EventOk};
33use std::any::Any;
34use std::fmt;
35use std::fmt::{Debug, Display};
36use std::fs::File;
37use std::io::Write;
38use std::path::Path;
39use std::sync::Arc;
40use tokio::sync::watch::{Receiver, Sender};
41use tokio_rustls::{
42 rustls::ServerConfig,
43 rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
44};
45use tracing::field::DisplayValue;
46use valuable::{Mappable, Valuable, Value, Visit};
47
48pub const X_VIA: HeaderName = HeaderName::from_static("x-via");
49pub const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
50pub const REPORTING_ENDPOINTS: HeaderName = HeaderName::from_static("reporting-endpoints");
51pub const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
52pub const X_FORWARDED_HOST: HeaderName = HeaderName::from_static("x-forwarded-host");
53pub const X_FORWARDED_PROTO: HeaderName = HeaderName::from_static("x-forwarded-proto");
54
/// Certificate provisioning target carried by [`SecurityMode::Secure`].
///
/// NOTE(review): presumably `Localhost` means self-signed certificates and `Staging` /
/// `Production` select ACME environments — confirm at the call sites.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProvisionMode {
    Localhost,
    Staging,
    Production,
}

/// Whether the server runs without TLS (`Insecure`) or with TLS (`Secure`).
///
/// `T` is a path; NOTE(review): presumably the certificate/cache directory — confirm.
#[derive(Debug, Clone)]
pub enum SecurityMode<T: AsRef<Path>> {
    Insecure,
    Secure(T, ProvisionMode),
}
67
68pub struct WrappedRedactedHashingAlg(pub RedactedHashAlg);
69
70impl WrappedRedactedHashingAlg {
71 fn hash(&self, header_value: &str) -> String {
72 let span = tracing::info_span!("redacted:hash");
73
74 span.in_scope(|| match self.0 {
75 RedactedHashAlg::Blake2 => {
76 let mut out = [0u8; 32];
77
78 let mut hasher = match Blake2bVar::new(32) {
79 Ok(v) => v,
80 Err(err) => {
81 tracing::error!(%err);
82 return "redacted".into();
83 }
84 };
85
86 hasher.update(header_value.as_bytes());
87 if let Err(err) = hasher.finalize_variable(&mut out) {
88 tracing::error!(%err);
89 return "redacted".into();
90 }
91
92 b64.encode(&out[0..6])
93 }
94 RedactedHashAlg::Blake3 => {
95 b64.encode(&blake3::hash(header_value.as_bytes()).as_bytes()[0..6])
96 }
97 })
98 }
99}
100pub struct HeadersDebug<'a>(
101 pub &'a HeaderMap,
102 pub Arc<Option<WrappedRedactedHashingAlg>>,
103);
104
#[cfg(tracing_unstable)]
impl Valuable for HeadersDebug<'_> {
    fn as_value(&self) -> Value<'_> {
        Value::Mappable(self)
    }

    /// Emits one map entry per header whose value is visible ASCII.
    ///
    /// Credential-bearing headers are replaced by a fingerprint when a hashing algorithm is
    /// configured, and by the literal `redacted` otherwise.
    fn visit(&self, visit: &mut dyn Visit) {
        let printable = self
            .0
            .iter()
            .filter_map(|(name, value)| Some((name, value.to_str().ok()?)));

        for (name, text) in printable {
            let key = name.as_str().as_value();
            let sensitive = name == AUTHORIZATION
                || name == PROXY_AUTHORIZATION
                || name == COOKIE
                || name == SET_COOKIE;

            match (sensitive, &*self.1) {
                (false, _) => visit.visit_entry(key, text.as_value()),
                (true, Some(hasher)) => visit.visit_entry(key, hasher.hash(text).as_value()),
                (true, None) => visit.visit_entry(key, "redacted".as_value()),
            }
        }
    }
}
128
#[cfg(tracing_unstable)]
impl Mappable for HeadersDebug<'_> {
    /// Bounds on the number of entries emitted by `Valuable::visit`.
    ///
    /// `visit` skips values that are not visible ASCII, so the only guaranteed lower bound is
    /// 0. The upper bound is the total number of (name, value) pairs: `HeaderMap::len` counts
    /// values, not distinct names. Reporting the iterator's own lower bound overstated the
    /// count whenever a value was skipped.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, Some(self.0.len()))
    }
}
135
136impl Debug for HeadersDebug<'_> {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 use std::fmt::Write;
139
140 f.write_char('{')?;
141
142 let mut is_first = true;
143
144 for (k, v) in self.0 {
145 if let Ok(v) = v.to_str() {
146 if is_first {
147 is_first = false;
148 f.write_char('"')?;
149 } else {
150 f.write_str(",\"")?;
151 }
152
153 f.write_str(k.as_str())?;
154 f.write_str("\":\"")?;
155
156 if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE
157 {
158 f.write_str("redacted")?;
159 f.write_char('"')?;
160 } else {
161 f.write_str(v)?;
162 f.write_char('"')?;
163 }
164 }
165 }
166
167 f.write_char('}')
168 }
169}
170
171#[must_use]
172pub fn get_http_version_str(v: Version) -> &'static str {
173 match v {
174 Version::HTTP_09 => "0.9",
175 Version::HTTP_10 => "1.0",
176 Version::HTTP_11 => "1.1",
177 Version::HTTP_2 => "2.0",
178 Version::HTTP_3 => "3.0",
179 _ => unreachable!(),
180 }
181}
182
183#[must_use]
184pub fn get_display_ip(log_ips: bool, req: &Request<Body>) -> Option<DisplayValue<IpAddr>> {
185 log_ips
186 .then(|| {
187 req.extensions()
188 .get::<ConnectInfo<SocketAddr>>()
189 .map(|addr| tracing::field::display(get_mapped_ip_for_addr(&addr.0)))
190 })
191 .flatten()
192}
193
194#[must_use]
195pub fn get_mapped_ip_for_addr(addr: &SocketAddr) -> IpAddr {
196 let ip = addr.ip();
197
198 if let IpAddr::V6(ipv6) = &ip
199 && let Some(ipv4) = ipv6.to_ipv4_mapped()
200 {
201 IpAddr::V4(ipv4)
202 } else {
203 ip
204 }
205}
206
207pub fn get_bearer_token_as_bytes(headers: &HeaderMap) -> anyhow::Result<Vec<u8>> {
208 if let Some(auth_header) = headers.get("authorization")
209 && let Ok(str_val) = auth_header.to_str()
210 && let Some(b64_token) = str_val.strip_prefix("Bearer ")
211 && let Ok(token) = b64.decode(b64_token)
212 {
213 Ok(token)
214 } else {
215 Err(anyhow!("failed to get token as bytes"))
216 }
217}
218
219pub fn get_host(headers: &HeaderMap, uri: &Uri) -> Option<String> {
220 if let Some(forwarded_values) = headers.get(header::FORWARDED)
221 && let Ok(forwarded_values_str) = forwarded_values.to_str()
222 && let Some(first_value) = forwarded_values_str.split(',').next()
223 && let Some(host) = first_value.split(';').find_map(|pair| {
224 let (key, value) = pair.split_once('=')?;
225 key.trim()
226 .eq_ignore_ascii_case("host")
227 .then(|| value.trim().trim_matches('"'))
228 })
229 {
230 return Some(host.to_owned());
231 }
232
233 if let Some(host) = headers
234 .get(X_FORWARDED_HOST)
235 .and_then(|host| host.to_str().ok())
236 {
237 return Some(host.to_owned());
238 }
239
240 if let Some(host) = headers
241 .get(header::HOST)
242 .and_then(|host| host.to_str().ok())
243 {
244 return Some(host.to_owned());
245 }
246
247 if let Some(authority) = uri.authority() {
248 return authority.as_str().rsplit('@').next().map(ToOwned::to_owned);
249 }
250
251 None
252}
253
/// A latency in nanoseconds, displayed in the largest unit that keeps it below 1000.
///
/// The unit is one of `ns`, `µs`, `ms` or `s`. Precision shrinks as the magnitude grows:
/// 2 decimals below 10, 1 below 100, 0 otherwise. Durations of 1000 s or more are shown as
/// whole seconds.
pub struct LatencyDisplay(pub f64);

impl Display for LatencyDisplay {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const UNITS: [&str; 4] = ["ns", "µs", "ms", "s"];

        let mut scaled = self.0;

        for unit in UNITS {
            if scaled < 1000.0 {
                let precision: usize = if scaled < 10.0 {
                    2
                } else if scaled < 100.0 {
                    1
                } else {
                    0
                };
                return write!(f, "{scaled:.precision$}{unit}");
            }
            scaled /= 1000.0;
        }

        // Past the largest unit: undo the final division and print whole seconds.
        write!(f, "{:.0}s", scaled * 1000.0)
    }
}
277 const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
278
279 let mut res = Response::new(Full::new(Bytes::from_static(b"500 Internal Server Error")));
280
281 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
282 res.headers_mut().insert(header::CONTENT_TYPE, TEXT_PLAIN);
283
284 res
285}
286
287pub fn rustls_server_config(
288 key: impl AsRef<Path>,
289 cert: impl AsRef<Path>,
290) -> anyhow::Result<Arc<ServerConfig>> {
291 let key = PrivateKeyDer::from_pem_file(key)?;
292
293 let certs = CertificateDer::pem_file_iter(cert)?.flatten().collect();
294
295 let mut config = ServerConfig::builder()
296 .with_no_client_auth()
297 .with_single_cert(certs, key)?;
298
299 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
300
301 Ok(Arc::new(config))
302}
303
304pub fn generate_self_signed_localhost_certs(cert_dir_path: impl AsRef<Path>) -> anyhow::Result<()> {
306 std::fs::create_dir_all(&cert_dir_path)?;
307
308 let cert_path = cert_dir_path.as_ref().join("crt.pem");
309 let key_path = cert_dir_path.as_ref().join("key.pem");
310
311 if !cert_path.exists() || !key_path.exists() {
312 let subject_alt_names = vec!["localhost".to_string()];
313
314 let CertifiedKey { cert, signing_key } =
315 match generate_simple_self_signed(subject_alt_names) {
316 Ok(ck) => {
317 tracing::info!("generated self-signed localhost cert");
318 ck
319 }
320 Err(err) => {
321 tracing::error!("failed to generate self-signed localhost cert");
322 return Err(err.into());
323 }
324 };
325
326 let cert = cert.pem();
327 let key = signing_key.serialize_pem();
328
329 let mut cert_file = File::create(cert_path)?;
330 let mut key_file = File::create(key_path)?;
331
332 cert_file.write_all(cert.as_bytes())?;
333 cert_file.flush()?;
334 key_file.write_all(key.as_bytes())?;
335 key_file.flush()?;
336 }
337
338 Ok(())
339}
340
341pub fn acme_task(
342 acme_span_clone: Span,
343 mut state: AcmeState<std::io::Error>,
344 signal_tx: Sender<()>,
345 mut terminate_rx: Option<Receiver<bool>>,
346) {
347 tokio::spawn(async move {
348 async {
349 tracing::info!("ready");
350
351 loop {
352 let event = if let Some(terminate_rx) = terminate_rx.as_mut() {
353 tokio::select! {
354 state = state.next() => state,
355 () = signal_tx.closed() => {
356 tracing::warn!("not accepting new connections");
357 break;
358 },
359 _ = terminate_rx.changed() => {
360 tracing::warn!("not accepting new connections");
361 break;
362 }
363 }
364 } else {
365 tokio::select! {
366 state = state.next() => state,
367 () = signal_tx.closed() => {
368 tracing::warn!("not accepting new connections");
369 break;
370 }
371 }
372 };
373
374 if let Some(event) = event {
375 match event {
376 Ok(evt) => match evt {
377 EventOk::DeployedNewCert => {
378 tracing::info!(evt.deploy = %"new", "cert");
379 }
380 EventOk::CertCacheStore => {
381 tracing::info!(evt.cache = %"stored", "cert");
382 }
383 EventOk::AccountCacheStore => {
384 tracing::info!(evt.cache = %"stored", "account");
385 }
386 EventOk::DeployedCachedCert => {
387 tracing::info!(evt.deploy = %"cached", "cert");
388 }
389 },
390 Err(err) => match err {
391 EventError::AccountCacheStore(err) => {
392 tracing::error!(%err, evt.cache = %"store", "account");
393 }
394 EventError::CertCacheStore(err) => {
395 tracing::error!(%err, evt.cache = %"store", "cert");
396 }
397 EventError::AccountCacheLoad(err) => {
398 tracing::error!(%err, evt.cache = %"load", "account");
399 }
400 EventError::CachedCertParse(err) => {
401 tracing::error!(%err, evt.parse = %"cache", "cert");
402 }
403 EventError::NewCertParse(err) => {
404 tracing::error!(%err, evt.parse = %"new", "cert");
405 }
406 EventError::CertCacheLoad(err) => {
407 tracing::error!(%err, evt.cache = %"load", "cert");
408 }
409 EventError::Order(err) => {
410 tracing::error!(%err, "order");
411 }
412 },
413 }
414 } else {
415 break;
416 }
417 }
418 }
419 .instrument(acme_span_clone)
420 .await;
421 });
422}
423
424#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
425pub fn redirect_service<H, T, S>(
426 span_clone: Span,
427 redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
428 log_ips: bool,
429 log_headers: bool,
430 handler: H,
431 state: S,
432 api_domain: Option<String>,
433) -> Router
434where
435 H: Handler<T, S>,
436 T: 'static,
437 S: Clone + Send + Sync + 'static,
438{
439 let router = Router::new()
440 .route("/healthz", get(|| async { StatusCode::OK }))
441 .fallback(handler);
442
443 apply_common_middleware(
444 router,
445 &state,
446 Some(span_clone),
447 String::new(),
448 log_headers,
449 log_ips,
450 redacted_hash,
451 ServiceKind::Redirect,
452 Some(api_domain.unwrap_or("redirect".to_owned())),
453 )
454 .layer(
455 ServiceBuilder::new()
456 .layer(TimeoutLayer::with_status_code(
457 StatusCode::REQUEST_TIMEOUT,
458 Duration::from_secs(5),
459 ))
460 .layer(axum::middleware::from_fn(x_via)),
461 )
462}