use axum::extract::ConnectInfo;
use axum::http::{HeaderName, HeaderValue, Request, Version};
use axum::routing::get;
use futures_util::stream::StreamExt;
use std::net::{IpAddr, SocketAddr};
use std::time::Duration;
use tower::ServiceBuilder;
use tower_http::timeout::TimeoutLayer;
use tracing::{Instrument, Span};
use crate::middleware::{ServiceKind, apply_common_middleware, x_via};
use anyhow::anyhow;
use axum::Router;
use axum::body::Body;
use axum::handler::Handler;
use axum::response::Response;
use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
use blake2::{
Blake2bVar,
digest::{Update, VariableOutput},
};
use bytes::Bytes;
use http_body_util::Full;
use hyper::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, SET_COOKIE};
use hyper::{HeaderMap, StatusCode, Uri, header};
use ordinary_config::RedactedHashAlg;
use rcgen::{CertifiedKey, generate_simple_self_signed};
use rustls_acme::{AcmeState, EventError, EventOk};
use std::any::Any;
use std::fmt;
use std::fmt::{Debug, Display};
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::watch::{Receiver, Sender};
use tokio_rustls::{
rustls::ServerConfig,
rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
};
use tracing::field::DisplayValue;
use valuable::{Mappable, Valuable, Value, Visit};
/// Name of the `x-via` header.
pub const X_VIA: HeaderName = HeaderName::from_static("x-via");
/// Name of the `x-request-id` header.
pub const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
/// Name of the `reporting-endpoints` header.
pub const REPORTING_ENDPOINTS: HeaderName = HeaderName::from_static("reporting-endpoints");
/// Name of the `x-forwarded-for` header.
pub const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
/// Name of the `x-forwarded-host` header.
pub const X_FORWARDED_HOST: HeaderName = HeaderName::from_static("x-forwarded-host");
/// Name of the `x-forwarded-proto` header.
///
/// This previously held `"x-forwarded-host"`, which made it an alias of
/// [`X_FORWARDED_HOST`] instead of naming the protocol header.
pub const X_FORWARDED_PROTO: HeaderName = HeaderName::from_static("x-forwarded-proto");
/// Provisioning mode carried by [`SecurityMode::Secure`].
///
/// NOTE(review): presumably selects how TLS certificates are obtained
/// (self-signed for localhost vs. an ACME staging/production directory) —
/// confirm against the callers; this file does not show where it is consumed.
#[derive(PartialEq, Clone)]
pub enum ProvisionMode {
/// Local development mode.
Localhost,
/// Staging mode.
Staging,
/// Production mode.
Production,
}
/// Whether a service runs without or with transport security.
///
/// `T` is any path-like value.
/// NOTE(review): presumably the directory holding certificates / ACME cache —
/// confirm against the callers.
#[derive(Clone)]
pub enum SecurityMode<T: AsRef<Path>> {
/// No transport security.
Insecure,
/// Secured, carrying a path-like value and the [`ProvisionMode`] to use.
Secure(T, ProvisionMode),
}
/// Newtype around the configured [`RedactedHashAlg`] that knows how to turn a
/// sensitive header value into a short, loggable fingerprint.
pub struct WrappedRedactedHashingAlg(pub RedactedHashAlg);

impl WrappedRedactedHashingAlg {
    /// Hashes `header_value` with the wrapped algorithm and returns the first
    /// six digest bytes, encoded as unpadded URL-safe base64.
    ///
    /// The work runs inside a `redacted:hash` tracing span. If the Blake2
    /// hasher cannot be built or finalized, the error is logged and the
    /// literal `"redacted"` is returned instead of a fingerprint.
    fn hash(&self, header_value: &str) -> String {
        // Only a short prefix of the digest is ever emitted.
        const FINGERPRINT_BYTES: usize = 6;

        let _in_span = tracing::info_span!("redacted:hash").entered();
        let input = header_value.as_bytes();

        match self.0 {
            RedactedHashAlg::Blake3 => {
                let digest = blake3::hash(input);
                b64.encode(&digest.as_bytes()[..FINGERPRINT_BYTES])
            }
            RedactedHashAlg::Blake2 => {
                let mut digest = [0u8; 32];
                let mut state = match Blake2bVar::new(digest.len()) {
                    Ok(state) => state,
                    Err(err) => {
                        tracing::error!(%err);
                        return "redacted".into();
                    }
                };
                state.update(input);
                match state.finalize_variable(&mut digest) {
                    Ok(()) => b64.encode(&digest[..FINGERPRINT_BYTES]),
                    Err(err) => {
                        tracing::error!(%err);
                        "redacted".into()
                    }
                }
            }
        }
    }
}
/// Borrowed view over a [`HeaderMap`] used to render headers for logging while
/// keeping the values of `authorization`, `proxy-authorization`, `cookie` and
/// `set-cookie` out of the output.
pub struct HeadersDebug<'a>(
/// The headers to render.
pub &'a HeaderMap,
/// Optional hashing algorithm; the `Valuable` implementation uses it to emit
/// a fingerprint of a sensitive value instead of the fixed `"redacted"` text.
pub Arc<Option<WrappedRedactedHashingAlg>>,
);
#[cfg(tracing_unstable)]
impl Valuable for HeadersDebug<'_> {
    /// Exposes the headers as a structured map value.
    fn as_value(&self) -> Value<'_> {
        Value::Mappable(self)
    }

    /// Visits every header whose value converts via `to_str` as a
    /// `name -> value` entry.
    ///
    /// Values of `authorization`, `proxy-authorization`, `cookie` and
    /// `set-cookie` are replaced by a hash fingerprint when a hasher is
    /// configured, and by `"redacted"` otherwise. Headers whose value fails
    /// `to_str` are skipped.
    fn visit(&self, visit: &mut dyn Visit) {
        for (name, raw) in self.0 {
            let Ok(text) = raw.to_str() else {
                continue;
            };

            let sensitive = name == AUTHORIZATION
                || name == PROXY_AUTHORIZATION
                || name == COOKIE
                || name == SET_COOKIE;

            if !sensitive {
                visit.visit_entry(name.as_str().as_value(), text.as_value());
                continue;
            }

            match &*self.1 {
                Some(hasher) => {
                    visit.visit_entry(name.as_str().as_value(), hasher.hash(text).as_value());
                }
                None => {
                    visit.visit_entry(name.as_str().as_value(), "redacted".as_value());
                }
            }
        }
    }
}
#[cfg(tracing_unstable)]
impl Mappable for HeadersDebug<'_> {
/// Forwards the size hint of the wrapped header map's iterator.
///
/// The hint is not adjusted for headers that `visit` skips because their
/// value does not convert with `to_str`.
fn size_hint(&self) -> (usize, Option<usize>) {
self.0.iter().size_hint()
}
}
impl Debug for HeadersDebug<'_> {
    /// Writes the headers as a compact `{"name":"value",...}` object.
    ///
    /// * Headers whose value does not convert with `to_str` are skipped.
    /// * Values of `authorization`, `proxy-authorization`, `cookie` and
    ///   `set-cookie` are never written verbatim: the configured hasher's
    ///   fingerprint is written when one is present (matching the `Valuable`
    ///   implementation), and the literal `redacted` otherwise.
    /// * In all other values `"` and `\` are backslash-escaped and a tab is
    ///   written as `\t`. Header values are client-controlled, so without
    ///   escaping a quote would close the string early and let a request
    ///   inject extra fields into the log line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use std::fmt::Write;

        f.write_char('{')?;
        let mut is_first = true;
        for (k, v) in self.0 {
            let Ok(v) = v.to_str() else {
                continue;
            };

            // Opening quote of the key, preceded by a separator after the first entry.
            if is_first {
                is_first = false;
                f.write_char('"')?;
            } else {
                f.write_str(",\"")?;
            }
            f.write_str(k.as_str())?;
            f.write_str("\":\"")?;

            if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE {
                match &*self.1 {
                    // The fingerprint is URL-safe base64 and needs no escaping.
                    Some(hasher) => f.write_str(&hasher.hash(v))?,
                    None => f.write_str("redacted")?,
                }
            } else {
                for ch in v.chars() {
                    match ch {
                        '"' | '\\' => {
                            f.write_char('\\')?;
                            f.write_char(ch)?;
                        }
                        '\t' => f.write_str("\\t")?,
                        _ => f.write_char(ch)?,
                    }
                }
            }
            f.write_char('"')?;
        }
        f.write_char('}')
    }
}
/// Returns the dotted numeric label (`"0.9"`, `"1.0"`, `"1.1"`, `"2.0"`,
/// `"3.0"`) for an HTTP [`Version`].
///
/// # Panics
///
/// Panics via `unreachable!` if `v` is none of the five versions listed above.
#[must_use]
pub fn get_http_version_str(v: Version) -> &'static str {
    let known = [
        (Version::HTTP_09, "0.9"),
        (Version::HTTP_10, "1.0"),
        (Version::HTTP_11, "1.1"),
        (Version::HTTP_2, "2.0"),
        (Version::HTTP_3, "3.0"),
    ];

    match known.into_iter().find(|(version, _)| *version == v) {
        Some((_, label)) => label,
        None => unreachable!(),
    }
}
#[must_use]
pub fn get_display_ip(log_ips: bool, req: &Request<Body>) -> Option<DisplayValue<IpAddr>> {
log_ips
.then(|| {
req.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|addr| tracing::field::display(get_mapped_ip_for_addr(&addr.0)))
})
.flatten()
}
/// Extracts the IP of `addr`, unwrapping IPv4-mapped IPv6 addresses
/// (`::ffff:a.b.c.d`) to their IPv4 form.
///
/// Any other address, IPv4 or IPv6, is returned unchanged.
#[must_use]
pub fn get_mapped_ip_for_addr(addr: &SocketAddr) -> IpAddr {
    match addr.ip() {
        IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(IpAddr::V6(v6), IpAddr::V4),
        v4 => v4,
    }
}
/// Decodes the bearer token from the `authorization` header into raw bytes.
///
/// The header value must start with the exact prefix `"Bearer "`; the
/// remainder is decoded as unpadded URL-safe base64.
///
/// # Errors
///
/// Returns a single generic error when the header is missing, does not convert
/// with `to_str`, lacks the `"Bearer "` prefix, or fails to decode.
pub fn get_bearer_token_as_bytes(headers: &HeaderMap) -> anyhow::Result<Vec<u8>> {
    headers
        .get("authorization")
        .and_then(|value| value.to_str().ok())
        .and_then(|text| text.strip_prefix("Bearer "))
        .and_then(|encoded| b64.decode(encoded).ok())
        .ok_or_else(|| anyhow!("failed to get token as bytes"))
}
pub fn get_host(headers: &HeaderMap, uri: &Uri) -> Option<String> {
if let Some(forwarded_values) = headers.get(header::FORWARDED)
&& let Ok(forwarded_values_str) = forwarded_values.to_str()
&& let Some(first_value) = forwarded_values_str.split(',').next()
&& let Some(host) = first_value.split(';').find_map(|pair| {
let (key, value) = pair.split_once('=')?;
key.trim()
.eq_ignore_ascii_case("host")
.then(|| value.trim().trim_matches('"'))
})
{
return Some(host.to_owned());
}
if let Some(host) = headers
.get(X_FORWARDED_HOST)
.and_then(|host| host.to_str().ok())
{
return Some(host.to_owned());
}
if let Some(host) = headers
.get(header::HOST)
.and_then(|host| host.to_str().ok())
{
return Some(host.to_owned());
}
if let Some(authority) = uri.authority() {
return authority.as_str().rsplit('@').next().map(ToOwned::to_owned);
}
None
}
/// A latency in nanoseconds that displays itself in the most fitting unit.
pub struct LatencyDisplay(pub f64);

impl Display for LatencyDisplay {
    /// Scales the value through `ns`, `µs`, `ms` and `s` and prints it with
    /// roughly three significant digits: two decimals below 10, one below 100,
    /// none below 1000. Anything of 1000 seconds or more is printed as whole
    /// seconds.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut scaled = self.0;

        for suffix in ["ns", "µs", "ms", "s"] {
            let decimals: Option<usize> = if scaled < 10.0 {
                Some(2)
            } else if scaled < 100.0 {
                Some(1)
            } else if scaled < 1000.0 {
                Some(0)
            } else {
                None
            };

            if let Some(decimals) = decimals {
                return write!(f, "{scaled:.decimals$}{suffix}");
            }
            scaled /= 1000.0;
        }

        // The loop divided once more after the seconds step; undo that.
        write!(f, "{:.0}s", scaled * 1000.0)
    }
}
#[allow(clippy::needless_pass_by_value)]
pub fn response_for_panic(_: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
#[allow(clippy::declare_interior_mutable_const)]
const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
let mut res = Response::new(Full::new(Bytes::from_static(b"500 Internal Server Error")));
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
res.headers_mut().insert(header::CONTENT_TYPE, TEXT_PLAIN);
res
}
/// Builds a shared rustls [`ServerConfig`] from a PEM private key file and a
/// PEM certificate chain file.
///
/// The config is built with no client authentication and advertises `h2` and
/// `http/1.1` via ALPN, in that order.
///
/// # Errors
///
/// Returns an error when the key file cannot be read or parsed, when the
/// certificate file cannot be opened, when **any** entry in the certificate
/// file fails to parse, or when rustls rejects the certificate/key pair.
pub fn rustls_server_config(
    key: impl AsRef<Path>,
    cert: impl AsRef<Path>,
) -> anyhow::Result<Arc<ServerConfig>> {
    let key = PrivateKeyDer::from_pem_file(key)?;
    // Collect fallibly: a malformed chain entry must surface as an error
    // instead of being silently dropped and leaving a truncated chain.
    let certs = CertificateDer::pem_file_iter(cert)?.collect::<Result<Vec<_>, _>>()?;
    let mut config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)?;
    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
    Ok(Arc::new(config))
}
/// Ensures `cert_dir_path` contains a self-signed `localhost` certificate
/// (`crt.pem`) and its private key (`key.pem`).
///
/// The directory is created if needed. When both files already exist nothing
/// else happens; when either is missing a new pair is generated and both files
/// are (re)written.
///
/// # Errors
///
/// Returns an error if the directory or files cannot be created or written, or
/// if certificate generation fails.
pub fn generate_self_signed_localhost_certs(cert_dir_path: impl AsRef<Path>) -> anyhow::Result<()> {
    let dir = cert_dir_path.as_ref();
    std::fs::create_dir_all(dir)?;

    let cert_path = dir.join("crt.pem");
    let key_path = dir.join("key.pem");
    if cert_path.exists() && key_path.exists() {
        return Ok(());
    }

    let CertifiedKey { cert, signing_key } =
        generate_simple_self_signed(vec!["localhost".to_string()]).inspect_err(|_| {
            tracing::error!("failed to generate self-signed localhost cert");
        })?;
    tracing::info!("generated self-signed localhost cert");

    let cert_pem = cert.pem();
    let key_pem = signing_key.serialize_pem();

    // Both files are created before either is written.
    let mut cert_out = File::create(&cert_path)?;
    let mut key_out = File::create(&key_path)?;

    cert_out.write_all(cert_pem.as_bytes())?;
    cert_out.flush()?;
    key_out.write_all(key_pem.as_bytes())?;
    key_out.flush()?;

    Ok(())
}
/// Spawns a background Tokio task that drives the ACME state stream and logs
/// every event it yields.
///
/// The task runs inside `acme_span_clone` and loops until one of:
///
/// * `state` stops yielding events (`next()` returns `None`),
/// * `signal_tx.closed()` completes,
/// * `terminate_rx` is provided and its `changed()` future completes (the
///   result is ignored, so any outcome ends the loop).
///
/// The join handle returned by `tokio::spawn` is dropped, so the task is not
/// awaited by the caller.
pub fn acme_task(
acme_span_clone: Span,
mut state: AcmeState<std::io::Error>,
signal_tx: Sender<()>,
mut terminate_rx: Option<Receiver<bool>>,
) {
tokio::spawn(async move {
async {
tracing::info!("ready");
loop {
// Wait for the next ACME event or for a shutdown condition. The two
// `select!` forms differ only in whether `terminate_rx` is watched.
let event = if let Some(terminate_rx) = terminate_rx.as_mut() {
tokio::select! {
state = state.next() => state,
() = signal_tx.closed() => {
tracing::warn!("not accepting new connections");
break;
},
_ = terminate_rx.changed() => {
tracing::warn!("not accepting new connections");
break;
}
}
} else {
tokio::select! {
state = state.next() => state,
() = signal_tx.closed() => {
tracing::warn!("not accepting new connections");
break;
}
}
};
if let Some(event) = event {
match event {
// Successful events are logged at info level.
Ok(evt) => match evt {
EventOk::DeployedNewCert => {
tracing::info!(evt.deploy = %"new", "cert");
}
EventOk::CertCacheStore => {
tracing::info!(evt.cache = %"stored", "cert");
}
EventOk::AccountCacheStore => {
tracing::info!(evt.cache = %"stored", "account");
}
EventOk::DeployedCachedCert => {
tracing::info!(evt.deploy = %"cached", "cert");
}
},
// Failures are logged at error level; the loop keeps running.
Err(err) => match err {
EventError::AccountCacheStore(err) => {
tracing::error!(%err, evt.cache = %"store", "account");
}
EventError::CertCacheStore(err) => {
tracing::error!(%err, evt.cache = %"store", "cert");
}
EventError::AccountCacheLoad(err) => {
tracing::error!(%err, evt.cache = %"load", "account");
}
EventError::CachedCertParse(err) => {
tracing::error!(%err, evt.parse = %"cache", "cert");
}
EventError::NewCertParse(err) => {
tracing::error!(%err, evt.parse = %"new", "cert");
}
EventError::CertCacheLoad(err) => {
tracing::error!(%err, evt.cache = %"load", "cert");
}
EventError::Order(err) => {
tracing::error!(%err, "order");
}
},
}
} else {
// The event stream ended; nothing more to drive.
break;
}
}
}
.instrument(acme_span_clone)
.await;
});
}
/// Builds the router for the redirect service.
///
/// The router answers `GET /healthz` with `200 OK` and sends every other
/// request to `handler`. It is then passed through `apply_common_middleware`
/// (as `ServiceKind::Redirect`, with an empty string for the fourth argument)
/// and finally wrapped with a 5-second timeout that responds with
/// `408 Request Timeout`, plus the `x_via` middleware.
///
/// `api_domain` is forwarded to the common middleware; when it is `None` the
/// literal `"redirect"` is used instead.
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
pub fn redirect_service<H, T, S>(
    span_clone: Span,
    redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
    log_ips: bool,
    log_headers: bool,
    handler: H,
    state: S,
    api_domain: Option<String>,
) -> Router
where
    H: Handler<T, S>,
    T: 'static,
    S: Clone + Send + Sync + 'static,
{
    let router = Router::new()
        .route("/healthz", get(|| async { StatusCode::OK }))
        .fallback(handler);

    apply_common_middleware(
        router,
        &state,
        Some(span_clone),
        String::new(),
        log_headers,
        log_ips,
        redacted_hash,
        ServiceKind::Redirect,
        // Built lazily so no default string is allocated when a domain is given.
        Some(api_domain.unwrap_or_else(|| "redirect".to_owned())),
    )
    .layer(
        ServiceBuilder::new()
            .layer(TimeoutLayer::with_status_code(
                StatusCode::REQUEST_TIMEOUT,
                Duration::from_secs(5),
            ))
            .layer(axum::middleware::from_fn(x_via)),
    )
}