use axum::extract::ConnectInfo;
use axum::http::{HeaderName, HeaderValue, Request};
use axum::routing::{MethodRouter, get};
use std::convert::Infallible;
use std::net::SocketAddr;
use std::time::Duration;
use tower::ServiceBuilder;
use tower_http::classify::ServerErrorsFailureClass;
use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
use tower_http::set_header::SetResponseHeaderLayer;
use tower_http::timeout::TimeoutLayer;
use tower_http::trace::TraceLayer;
use tracing::Span;
use uuid::Uuid;
use axum::handler::Handler;
use axum::response::Response;
use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
use blake2::{
Blake2bVar,
digest::{Update, VariableOutput},
};
use bytes::Bytes;
use http_body_util::Full;
use hyper::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, SET_COOKIE};
use hyper::{HeaderMap, StatusCode, Uri, header};
use ordinary_config::RedactedHashAlg;
use rcgen::{CertifiedKey, generate_simple_self_signed};
use std::any::Any;
use std::error::Error;
use std::fmt;
use std::fmt::{Debug, Display};
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::sync::Arc;
use tokio_rustls::{
rustls::ServerConfig,
rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
};
use tower_http::catch_panic::CatchPanicLayer;
use tower_http::compression::CompressionLayer;
use tower_http::decompression::RequestDecompressionLayer;
use valuable::{Mappable, Valuable, Value, Visit};
/// Proxy header consulted by [`get_host`] when no `Forwarded` host is present.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
/// Default name of the header that carries the per-request id.
pub const REQUEST_ID_HEADER: &str = "x-request-id";
/// Newtype over [`RedactedHashAlg`] that turns a sensitive header value into a short,
/// log-safe digest.
pub struct WrappedRedactedHashingAlg(pub RedactedHashAlg);
impl WrappedRedactedHashingAlg {
    /// Hashes `header_value` with the wrapped algorithm and returns the 32-byte digest as
    /// unpadded URL-safe base64.
    ///
    /// Runs inside a `redacted:hash` span. If the BLAKE2b hasher reports an error, the
    /// error is logged and the literal `"redacted"` is returned instead of a digest.
    fn hash(&self, header_value: &str) -> String {
        let span = tracing::info_span!("redacted:hash");
        let _entered = span.enter();
        let input = header_value.as_bytes();
        match self.0 {
            RedactedHashAlg::Blake3 => b64.encode(blake3::hash(input).as_bytes()),
            RedactedHashAlg::Blake2 => {
                let mut digest = [0u8; 32];
                let mut hasher = match Blake2bVar::new(digest.len()) {
                    Ok(hasher) => hasher,
                    Err(err) => {
                        tracing::error!(%err);
                        return "redacted".into();
                    }
                };
                hasher.update(input);
                match hasher.finalize_variable(&mut digest) {
                    Ok(()) => b64.encode(digest),
                    Err(err) => {
                        tracing::error!(%err);
                        "redacted".into()
                    }
                }
            }
        }
    }
}
/// Borrowed [`HeaderMap`] wrapper whose formatting implementations are meant for logs:
/// credential-bearing headers (`Authorization`, `Proxy-Authorization`, `Cookie`,
/// `Set-Cookie`) are never emitted verbatim.
///
/// The second field optionally names an algorithm for hashing those values instead of
/// printing a fixed placeholder.
pub struct HeadersDebug<'a>(pub &'a HeaderMap, pub Arc<Option<WrappedRedactedHashingAlg>>);
#[cfg(tracing_unstable)]
impl Valuable for HeadersDebug<'_> {
    /// Exposes the headers as a map of name -> value.
    fn as_value(&self) -> Value<'_> {
        Value::Mappable(self)
    }

    /// Emits one entry per header value that is valid visible ASCII.
    ///
    /// Credential headers are emitted as a hash when an algorithm is configured, or as the
    /// literal `redacted` otherwise.
    fn visit(&self, visit: &mut dyn Visit) {
        let Self(headers, hasher) = self;
        let printable = headers
            .iter()
            .filter_map(|(name, value)| value.to_str().ok().map(|text| (name, text)));
        for (name, text) in printable {
            let key = name.as_str().as_value();
            let is_secret = name == AUTHORIZATION
                || name == PROXY_AUTHORIZATION
                || name == COOKIE
                || name == SET_COOKIE;
            if !is_secret {
                visit.visit_entry(key, text.as_value());
                continue;
            }
            match &**hasher {
                Some(alg) => visit.visit_entry(key, alg.hash(text).as_value()),
                None => visit.visit_entry(key, "redacted".as_value()),
            }
        }
    }
}
#[cfg(tracing_unstable)]
impl Mappable for HeadersDebug<'_> {
    /// Bounds on the number of entries `visit` emits.
    ///
    /// `visit` emits at most one entry per stored value (`HeaderMap::len` counts every
    /// value, including repeated names). It skips values that are not visible ASCII, so no
    /// non-zero lower bound can be promised.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, Some(self.0.len()))
    }
}
/// Renders the headers as a compact JSON-style object, e.g. `{"host":"a","accept":"*/*"}`.
///
/// Values that are not visible ASCII (`HeaderValue::to_str` fails) are omitted.
///
/// The credential headers `Authorization`, `Proxy-Authorization`, `Cookie` and
/// `Set-Cookie` are never printed verbatim. They are replaced with a hash from the
/// configured [`WrappedRedactedHashingAlg`] when one is set, or with `redacted` otherwise.
/// This matches the `Valuable` rendering.
///
/// Values are written with Rust string escaping, so an embedded `"` or `\` cannot break the
/// object's structure.
impl Debug for HeadersDebug<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use std::fmt::Write;
        f.write_char('{')?;
        let mut separator = "";
        for (name, value) in self.0 {
            // Values with bytes outside visible ASCII cannot be shown as text; skip them.
            let Ok(value) = value.to_str() else {
                continue;
            };
            write!(f, "{separator}\"{}\":", name.as_str())?;
            separator = ",";
            let is_credential = name == AUTHORIZATION
                || name == PROXY_AUTHORIZATION
                || name == COOKIE
                || name == SET_COOKIE;
            if !is_credential {
                // `{:?}` quotes the value and escapes `"`, `\` and tabs.
                write!(f, "{value:?}")?;
            } else if let Some(hasher) = &*self.1 {
                write!(f, "{:?}", hasher.hash(value))?;
            } else {
                f.write_str("\"redacted\"")?;
            }
        }
        f.write_char('}')
    }
}
pub fn get_host(headers: &HeaderMap, uri: &Uri) -> Option<String> {
if let Some(forwarded_values) = headers.get(header::FORWARDED)
&& let Ok(forwarded_values_str) = forwarded_values.to_str()
&& let Some(first_value) = forwarded_values_str.split(',').next()
&& let Some(host) = first_value.split(';').find_map(|pair| {
let (key, value) = pair.split_once('=')?;
key.trim()
.eq_ignore_ascii_case("host")
.then(|| value.trim().trim_matches('"'))
})
{
return Some(host.to_owned());
}
if let Some(host) = headers
.get(X_FORWARDED_HOST_HEADER_KEY)
.and_then(|host| host.to_str().ok())
{
return Some(host.to_owned());
}
if let Some(host) = headers
.get(header::HOST)
.and_then(|host| host.to_str().ok())
{
return Some(host.to_owned());
}
if let Some(authority) = uri.authority() {
return authority.as_str().rsplit('@').next().map(ToOwned::to_owned);
}
None
}
/// Human-readable rendering of a latency given in nanoseconds.
///
/// The value is scaled to the first of `ns`, `µs`, `ms`, `s` in which it rounds to less
/// than 1000. It is printed with 2, 1 or 0 decimals, so at most three digits appear before
/// the unit (`5.00ns`, `42.0µs`, `999ms`). Latencies of 999.5 s or more are printed as
/// whole seconds (`2500s`).
pub struct LatencyDisplay(pub f64);
impl Display for LatencyDisplay {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // (exclusive limit, decimals). Each limit is the half-unit point at which rounding
        // to that many decimals would carry into the next tier (e.g. 9.996 -> "10.00").
        const TIERS: [(f64, usize); 3] = [(9.995, 2), (99.95, 1), (999.5, 0)];
        let mut t = self.0;
        for unit in ["ns", "µs", "ms", "s"] {
            for (limit, precision) in TIERS {
                if t < limit {
                    return write!(f, "{t:.precision$}{unit}");
                }
            }
            t /= 1000.0;
        }
        // No unit above seconds: undo the last division and print whole seconds.
        write!(f, "{:.0}s", t * 1000.0)
    }
}
#[allow(clippy::needless_pass_by_value)]
pub fn response_for_panic(_: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
#[allow(clippy::declare_interior_mutable_const)]
const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
let mut res = Response::new(Full::new(Bytes::from_static(b"500 Internal Server Error")));
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
res.headers_mut().insert(header::CONTENT_TYPE, TEXT_PLAIN);
res
}
/// Builds a rustls [`ServerConfig`] from a PEM private key and a PEM certificate chain.
///
/// No client authentication is requested. ALPN advertises `h2` first, then `http/1.1`.
///
/// # Errors
///
/// Returns an error if either file cannot be read or parsed, if `cert` contains no
/// certificates, or if rustls rejects the key/certificate pair.
pub fn rustls_server_config(
    key: impl AsRef<Path>,
    cert: impl AsRef<Path>,
) -> Result<Arc<ServerConfig>, Box<dyn Error>> {
    let key = PrivateKeyDer::from_pem_file(key)?;
    let cert = cert.as_ref();
    // Collect into a `Result` so a malformed PEM section aborts with its parse error
    // instead of being silently dropped (which could leave a truncated chain).
    let certs = CertificateDer::pem_file_iter(cert)?.collect::<Result<Vec<_>, _>>()?;
    if certs.is_empty() {
        return Err(format!("no certificates found in {}", cert.display()).into());
    }
    let mut config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)?;
    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
    Ok(Arc::new(config))
}
pub fn generate_self_signed_localhost_certs(
cert_dir_path: impl AsRef<Path>,
) -> Result<(), Box<dyn Error>> {
std::fs::create_dir_all(&cert_dir_path)?;
let cert_path = cert_dir_path.as_ref().join("crt.pem");
let key_path = cert_dir_path.as_ref().join("key.pem");
if !cert_path.exists() || !key_path.exists() {
let subject_alt_names = vec!["localhost".to_string()];
let CertifiedKey { cert, signing_key } =
match generate_simple_self_signed(subject_alt_names) {
Ok(ck) => {
tracing::info!("generated self-signed localhost cert");
ck
}
Err(err) => {
tracing::error!("failed to generate self-signed localhost cert");
return Err(err.into());
}
};
let cert = cert.pem();
let key = signing_key.serialize_pem();
let mut cert_file = File::create(cert_path)?;
let mut key_file = File::create(key_path)?;
cert_file.write_all(cert.as_bytes())?;
key_file.write_all(key.as_bytes())?;
}
Ok(())
}
pub fn redirect_service<H, T, S>(
span_clone: Span,
redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
log_ips: bool,
log_headers: bool,
request_id_header: HeaderName,
handler: H,
state: S,
) -> MethodRouter
where
H: Handler<T, S>,
T: 'static,
S: Clone + Send + Sync + 'static,
{
let redacted_hash_clone = redacted_hash.clone();
get(handler)
.with_state(state)
.layer::<_, Infallible>(
ServiceBuilder::new()
.layer(CatchPanicLayer::custom(response_for_panic))
.layer(RequestDecompressionLayer::new())
.layer(CompressionLayer::new()),
)
.layer::<_, Infallible>(
ServiceBuilder::new()
.layer(SetRequestIdLayer::new(
request_id_header.clone(),
MakeRequestUuid,
))
.layer(
TraceLayer::new_for_http()
.make_span_with(move |req: &Request<_>| {
let request_id = req.headers().get(REQUEST_ID_HEADER);
let host =
get_host(req.headers(), req.uri()).map(tracing::field::display);
let ip = log_ips.then(|| {
req.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|addr| tracing::field::display(addr.ip()))
});
let query = req.uri().query().map(tracing::field::display);
span_clone.in_scope(|| match request_id {
Some(rid) => {
tracing::warn_span!(
"redirect",
host,
id = %rid
.to_str()
.unwrap_or(Uuid::new_v4().to_string().as_str()),
ip,
path = %req.uri().path(),
query,
)
}
None => {
tracing::warn_span!(
"redirect",
host,
id = %Uuid::new_v4(),
ip,
path = %req.uri().path(),
query,
)
}
})
})
.on_request(move |req: &Request<_>, _: &Span| {
let hd = log_headers
.then_some(HeadersDebug(req.headers(), redacted_hash.clone()));
#[cfg(tracing_unstable)]
let headers = log_headers.then_some(tracing::field::valuable(&hd));
#[cfg(not(tracing_unstable))]
let headers = log_headers.then_some(tracing::field::debug(&hd));
tracing::warn!(
version = ?req.version(),
method = %req.method(),
headers,
"req"
);
})
.on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
let hd = log_headers.then_some(HeadersDebug(
res.headers(),
redacted_hash_clone.clone(),
));
#[cfg(tracing_unstable)]
let headers = log_headers.then_some(tracing::field::valuable(&hd));
#[cfg(not(tracing_unstable))]
let headers = log_headers.then_some(tracing::field::debug(&hd));
let status = res.status().as_u16();
let latency = LatencyDisplay(latency.as_nanos() as f64);
if status >= 500 {
tracing::error!(status, headers, %latency, "res");
} else if status >= 400 {
tracing::warn!(status, headers, %latency, "res");
} else {
tracing::info!(status, headers, %latency, "res");
}
})
.on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
tracing::error!(
err = %error,
"fail"
);
}),
)
.layer(TimeoutLayer::with_status_code(
StatusCode::REQUEST_TIMEOUT,
Duration::from_secs(5),
))
.layer(PropagateRequestIdLayer::new(request_id_header))
.layer(SetResponseHeaderLayer::if_not_present(
header::SERVER,
HeaderValue::from_static("Ordinary"),
)),
)
}