ordinary-utils 0.6.0-pre.1

Utils for Ordinary
Documentation
use axum::extract::ConnectInfo;
use axum::http::{HeaderName, HeaderValue, Request};
use axum::routing::get;
use std::net::SocketAddr;
use std::time::Duration;
use tower::ServiceBuilder;
use tower_http::classify::ServerErrorsFailureClass;
use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
use tower_http::set_header::SetResponseHeaderLayer;
use tower_http::timeout::TimeoutLayer;
use tower_http::trace::TraceLayer;
use tracing::Span;
use uuid::Uuid;

use axum::Router;
use axum::handler::Handler;
use axum::response::Response;
use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
use blake2::{
    Blake2bVar,
    digest::{Update, VariableOutput},
};
use bytes::Bytes;
use http_body_util::Full;
use hyper::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, SET_COOKIE};
use hyper::{HeaderMap, StatusCode, Uri, header};
use ordinary_config::RedactedHashAlg;
use rcgen::{CertifiedKey, generate_simple_self_signed};
use std::any::Any;
use std::error::Error;
use std::fmt;
use std::fmt::{Debug, Display};
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::sync::Arc;
use tokio_rustls::{
    rustls::ServerConfig,
    rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
};
use tower_http::catch_panic::CatchPanicLayer;
use tower_http::compression::CompressionLayer;
use tower_http::decompression::RequestDecompressionLayer;
use valuable::{Mappable, Valuable, Value, Visit};

/// Header name `x-request-id`, conventionally used to carry a per-request identifier.
pub const REQUEST_ID_HEADER: &str = "x-request-id";
/// Proxy-set header that `get_host` consults for the host the client addressed.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

/// Newtype around the configured [`RedactedHashAlg`], used to fingerprint
/// sensitive header values instead of printing them verbatim.
pub struct WrappedRedactedHashingAlg(pub RedactedHashAlg);

impl WrappedRedactedHashingAlg {
    /// Hashes `header_value` with the wrapped algorithm and returns the digest
    /// as URL-safe base64 without padding.
    ///
    /// Runs inside an `info`-level `redacted:hash` span. For BLAKE2b a 32-byte
    /// digest is produced; if the hasher cannot be built or finalized, the
    /// error is logged and the literal `"redacted"` is returned instead.
    fn hash(&self, header_value: &str) -> String {
        let input = header_value.as_bytes();

        tracing::info_span!("redacted:hash").in_scope(|| match self.0 {
            RedactedHashAlg::Blake3 => b64.encode(blake3::hash(input).as_bytes()),
            RedactedHashAlg::Blake2 => Self::blake2b_256(input)
                .map_or_else(|| "redacted".into(), |digest| b64.encode(digest)),
        })
    }

    /// Computes a 32-byte BLAKE2b digest of `input`.
    ///
    /// Any error reported by the variable-output hasher is logged at `error`
    /// level and mapped to `None`.
    fn blake2b_256(input: &[u8]) -> Option<[u8; 32]> {
        let mut hasher = Blake2bVar::new(32)
            .inspect_err(|err| {
                tracing::error!(%err);
            })
            .ok()?;
        hasher.update(input);

        let mut digest = [0u8; 32];
        hasher
            .finalize_variable(&mut digest)
            .inspect_err(|err| {
                tracing::error!(%err);
            })
            .ok()?;

        Some(digest)
    }
}
/// Borrowed view of a [`HeaderMap`] for logging request/response headers.
///
/// The credential-bearing headers `authorization`, `proxy-authorization`,
/// `cookie` and `set-cookie` are never printed verbatim.
///
/// - `.0`: the headers to render.
/// - `.1`: optional algorithm for fingerprinting those sensitive values instead
///   of showing the literal `redacted`. It is held in an `Arc` so callers can
///   clone it cheaply per request.
pub struct HeadersDebug<'a>(
    pub &'a HeaderMap,
    pub Arc<Option<WrappedRedactedHashingAlg>>,
);

#[cfg(tracing_unstable)]
impl Valuable for HeadersDebug<'_> {
    /// Exposes the headers to `valuable` as a map.
    fn as_value(&self) -> Value<'_> {
        Value::Mappable(self)
    }

    /// Emits one map entry per header value that `HeaderValue::to_str` accepts.
    ///
    /// Values it rejects are skipped. Credential headers (`authorization`,
    /// `proxy-authorization`, `cookie`, `set-cookie`) are replaced by the
    /// configured hash, or by `"redacted"` when no algorithm is configured.
    fn visit(&self, visit: &mut dyn Visit) {
        let hasher = (*self.1).as_ref();

        for (name, raw) in self.0 {
            let Ok(text) = raw.to_str() else {
                continue;
            };

            let sensitive = name == AUTHORIZATION
                || name == PROXY_AUTHORIZATION
                || name == COOKIE
                || name == SET_COOKIE;

            let key = name.as_str().as_value();

            match (sensitive, hasher) {
                (false, _) => visit.visit_entry(key, text.as_value()),
                (true, Some(alg)) => visit.visit_entry(key, alg.hash(text).as_value()),
                (true, None) => visit.visit_entry(key, "redacted".as_value()),
            }
        }
    }
}

#[cfg(tracing_unstable)]
impl Mappable for HeadersDebug<'_> {
    /// Reports exactly how many entries `visit` will emit.
    ///
    /// `visit` skips header values that `HeaderValue::to_str` rejects, so the
    /// raw header count could overstate the lower bound. Counting only the
    /// values that will actually be visited keeps both bounds exact.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let emitted = self
            .0
            .iter()
            .filter(|(_, value)| value.to_str().is_ok())
            .count();
        (emitted, Some(emitted))
    }
}

impl Debug for HeadersDebug<'_> {
    /// Writes the headers as a compact JSON-style object, e.g.
    /// `{"host":"example.com","authorization":"redacted"}`.
    ///
    /// Values rejected by `HeaderValue::to_str` are skipped.
    ///
    /// Credential headers (`authorization`, `proxy-authorization`, `cookie`,
    /// `set-cookie`) are replaced by the configured hash when one is set, or
    /// by `redacted` otherwise. This is the same policy as the `Valuable` impl
    /// used under `tracing_unstable`, so the hashing setting is honoured in
    /// both builds.
    ///
    /// Values are written with Rust string escaping, so an embedded `"`, `\`
    /// or tab cannot terminate the quoted value early.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use std::fmt::Write;

        let hasher = (*self.1).as_ref();
        let mut is_first = true;

        f.write_char('{')?;

        for (k, v) in self.0 {
            let Ok(v) = v.to_str() else {
                continue;
            };

            if !is_first {
                f.write_char(',')?;
            }
            is_first = false;

            // Header names are lowercase tokens and never need escaping.
            write!(f, "\"{}\":", k.as_str())?;

            if k == AUTHORIZATION || k == PROXY_AUTHORIZATION || k == COOKIE || k == SET_COOKIE
            {
                match hasher {
                    Some(alg) => write!(f, "{:?}", alg.hash(v))?,
                    None => f.write_str("\"redacted\"")?,
                }
            } else {
                write!(f, "{v:?}")?;
            }
        }

        f.write_char('}')
    }
}

/// Determines the host the client originally addressed.
///
/// Sources are consulted in order, and the first non-empty one wins:
/// 1. the `host=` parameter of the first element of the `Forwarded` header;
/// 2. the first entry of `X-Forwarded-Host` (chained proxies may append a
///    comma-separated list, whose first entry is the original host);
/// 3. the `Host` header;
/// 4. the request URI's authority, with any `userinfo@` prefix removed.
///
/// Header values that `HeaderValue::to_str` rejects are ignored. Returns
/// `None` when no source yields a host.
pub fn get_host(headers: &HeaderMap, uri: &Uri) -> Option<String> {
    if let Some(forwarded_values) = headers.get(header::FORWARDED)
        && let Ok(forwarded_values_str) = forwarded_values.to_str()
        && let Some(first_value) = forwarded_values_str.split(',').next()
        && let Some(host) = first_value.split(';').find_map(|pair| {
            let (key, value) = pair.split_once('=')?;
            key.trim()
                .eq_ignore_ascii_case("host")
                .then(|| value.trim().trim_matches('"'))
        })
        && !host.is_empty()
    {
        return Some(host.to_owned());
    }

    if let Some(host) = headers
        .get(X_FORWARDED_HOST_HEADER_KEY)
        .and_then(|host| host.to_str().ok())
        .and_then(|hosts| hosts.split(',').next())
        .map(str::trim)
        .filter(|host| !host.is_empty())
    {
        return Some(host.to_owned());
    }

    if let Some(host) = headers
        .get(header::HOST)
        .and_then(|host| host.to_str().ok())
        .filter(|host| !host.is_empty())
    {
        return Some(host.to_owned());
    }

    uri.authority()
        .and_then(|authority| authority.as_str().rsplit('@').next())
        .map(ToOwned::to_owned)
}

/// Human-readable latency, constructed from a duration in nanoseconds.
///
/// The value is expressed in the first of `ns`, `µs`, `ms`, `s` in which it is
/// below 1000, with roughly three significant digits (`5.00ns`, `42.0ms`,
/// `250ns`). Values of 1000 seconds or more are printed as whole seconds.
pub struct LatencyDisplay(pub f64);

impl Display for LatencyDisplay {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const UNITS: [&str; 4] = ["ns", "µs", "ms", "s"];

        let mut value = self.0;

        for unit in UNITS {
            match value {
                v if v < 10.0 => return write!(f, "{v:.2}{unit}"),
                v if v < 100.0 => return write!(f, "{v:.1}{unit}"),
                v if v < 1000.0 => return write!(f, "{v:.0}{unit}"),
                _ => value /= 1000.0,
            }
        }

        // Ran past the last unit: `value` is now in thousands of seconds.
        write!(f, "{:.0}s", value * 1000.0)
    }
}

#[allow(clippy::needless_pass_by_value)]
pub fn response_for_panic(_: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
    #[allow(clippy::declare_interior_mutable_const)]
    const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");

    let mut res = Response::new(Full::new(Bytes::from_static(b"500 Internal Server Error")));

    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
    res.headers_mut().insert(header::CONTENT_TYPE, TEXT_PLAIN);

    res
}

/// Loads a PEM private key and certificate chain and builds a rustls
/// [`ServerConfig`] that advertises HTTP/2 and HTTP/1.1 via ALPN.
///
/// The first private key found in `key` is used. Every certificate in `cert`
/// is passed to rustls as the chain (rustls expects the end-entity
/// certificate first).
///
/// # Errors
///
/// Returns an error if either file cannot be read or parsed, or if rustls
/// rejects the key/certificate pair. A malformed certificate anywhere in
/// `cert` counts as a parse error.
pub fn rustls_server_config(
    key: impl AsRef<Path>,
    cert: impl AsRef<Path>,
) -> Result<Arc<ServerConfig>, Box<dyn Error>> {
    let key = PrivateKeyDer::from_pem_file(key)?;

    // Propagate parse errors instead of `.flatten()`-ing them away: silently
    // dropping a broken certificate would yield a truncated chain that only
    // fails later, at handshake time, with a far less obvious error.
    let certs = CertificateDer::pem_file_iter(cert)?.collect::<Result<Vec<_>, _>>()?;

    let mut config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)?;

    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

    Ok(Arc::new(config))
}

/// writes crt.pem and key.pem to directory
pub fn generate_self_signed_localhost_certs(
    cert_dir_path: impl AsRef<Path>,
) -> Result<(), Box<dyn Error>> {
    std::fs::create_dir_all(&cert_dir_path)?;

    let cert_path = cert_dir_path.as_ref().join("crt.pem");
    let key_path = cert_dir_path.as_ref().join("key.pem");

    if !cert_path.exists() || !key_path.exists() {
        let subject_alt_names = vec!["localhost".to_string()];

        let CertifiedKey { cert, signing_key } =
            match generate_simple_self_signed(subject_alt_names) {
                Ok(ck) => {
                    tracing::info!("generated self-signed localhost cert");
                    ck
                }
                Err(err) => {
                    tracing::error!("failed to generate self-signed localhost cert");
                    return Err(err.into());
                }
            };

        let cert = cert.pem();
        let key = signing_key.serialize_pem();

        let mut cert_file = File::create(cert_path)?;
        let mut key_file = File::create(key_path)?;

        cert_file.write_all(cert.as_bytes())?;
        key_file.write_all(key.as_bytes())?;
    }

    Ok(())
}

pub fn redirect_service<H, T, S>(
    span_clone: Span,
    redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
    log_ips: bool,
    log_headers: bool,
    request_id_header: HeaderName,
    handler: H,
    state: S,
) -> Router
where
    H: Handler<T, S>,
    T: 'static,
    S: Clone + Send + Sync + 'static,
{
    let redacted_hash_clone = redacted_hash.clone();

    Router::new()
        .route("/healthz", get(|| async { StatusCode::OK }))
        .fallback(handler)
        .with_state(state)
        .layer(
            ServiceBuilder::new()
                .layer(CatchPanicLayer::custom(response_for_panic))
                .layer(RequestDecompressionLayer::new())
                .layer(CompressionLayer::new()),
        )
        .layer(
            ServiceBuilder::new()
                .layer(SetRequestIdLayer::new(
                    request_id_header.clone(),
                    MakeRequestUuid,
                ))
                .layer(
                    TraceLayer::new_for_http()
                        .make_span_with(move |req: &Request<_>| {
                            let request_id = req.headers().get(REQUEST_ID_HEADER);

                            let host =
                                get_host(req.headers(), req.uri()).map(tracing::field::display);

                            let ip = log_ips.then(|| {
                                req.extensions()
                                    .get::<ConnectInfo<SocketAddr>>()
                                    .map(|addr| tracing::field::display(addr.ip()))
                            });

                            let query = req.uri().query().map(tracing::field::display);

                            span_clone.in_scope(|| match request_id {
                                Some(rid) => {
                                    tracing::warn_span!(
                                        "redirect",
                                        host,
                                        id = %rid
                                            .to_str()
                                            .unwrap_or(Uuid::new_v4().to_string().as_str()),
                                        ip,
                                        path = %req.uri().path(),
                                        query,
                                    )
                                }
                                None => {
                                    tracing::warn_span!(
                                        "redirect",
                                        host,
                                        id = %Uuid::new_v4(),
                                        ip,
                                        path = %req.uri().path(),
                                        query,
                                    )
                                }
                            })
                        })
                        .on_request(move |req: &Request<_>, _: &Span| {
                            let hd = log_headers
                                .then_some(HeadersDebug(req.headers(), redacted_hash.clone()));

                            #[cfg(tracing_unstable)]
                            let headers = log_headers.then_some(tracing::field::valuable(&hd));

                            #[cfg(not(tracing_unstable))]
                            let headers = log_headers.then_some(tracing::field::debug(&hd));

                            tracing::warn!(
                                version = ?req.version(),
                                method = %req.method(),
                                headers,
                                "req"
                            );
                        })
                        .on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
                            let hd = log_headers.then_some(HeadersDebug(
                                res.headers(),
                                redacted_hash_clone.clone(),
                            ));

                            #[cfg(tracing_unstable)]
                            let headers = log_headers.then_some(tracing::field::valuable(&hd));

                            #[cfg(not(tracing_unstable))]
                            let headers = log_headers.then_some(tracing::field::debug(&hd));

                            let status = res.status().as_u16();
                            let latency = LatencyDisplay(latency.as_nanos() as f64);

                            if status >= 500 {
                                tracing::error!(status, headers, %latency, "res");
                            } else if status >= 400 {
                                tracing::warn!(status, headers, %latency, "res");
                            } else {
                                tracing::info!(status, headers, %latency, "res");
                            }
                        })
                        .on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
                            tracing::error!(
                                err = %error,
                                "fail"
                            );
                        }),
                )
                .layer(TimeoutLayer::with_status_code(
                    StatusCode::REQUEST_TIMEOUT,
                    Duration::from_secs(5),
                ))
                .layer(PropagateRequestIdLayer::new(request_id_header))
                .layer(SetResponseHeaderLayer::if_not_present(
                    header::SERVER,
                    HeaderValue::from_static("Ordinary"),
                )),
        )
}