use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
use axum::extract::MatchedPath;
use axum::http::{HeaderName, Method, Request, Response};
use pin_project_lite::pin_project;
use tower::{Layer, Service};
/// Response header read as a fallback request id when the request carried no
/// `crate::middleware::RequestId` extension.
static X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
/// `tracing` target used for every access-log event, so subscribers can route
/// or filter access logs separately from other events.
pub const ACCESS_LOG_TARGET: &str = "autumn::access";
/// Value logged in the `route` field when no [`MatchedPath`] extension was
/// present on the request (presumably because no route matched, or because the
/// layer runs outside the router — confirm against how the layers are mounted).
pub const UNMATCHED_ROUTE: &str = "_unmatched";
/// Shared "already logged" flag for a single request.
///
/// A fallback [`AccessLogLayer`] creates one per request and stores a clone in
/// the request extensions. A non-fallback layer further in marks it after
/// logging. When the response is ready, the fallback layer checks the flag so
/// the same request is not logged twice. All clones share one underlying flag.
#[derive(Clone, Debug, Default)]
pub struct AccessLogEmitted(Arc<std::sync::atomic::AtomicBool>);

impl AccessLogEmitted {
    /// Raises the flag; the `Release` store pairs with the `Acquire` load in
    /// `is_marked`.
    fn mark(&self) {
        let flag: &std::sync::atomic::AtomicBool = &self.0;
        flag.store(true, std::sync::atomic::Ordering::Release);
    }

    /// Returns `true` once `mark` has been called on this value or any clone.
    fn is_marked(&self) -> bool {
        let flag: &std::sync::atomic::AtomicBool = &self.0;
        flag.load(std::sync::atomic::Ordering::Acquire)
    }
}
#[derive(Clone, Debug)]
pub struct AccessLogLayer {
exclude: Arc<[String]>,
fallback: bool,
}
impl AccessLogLayer {
#[must_use]
pub fn new(exclude: Vec<String>) -> Self {
Self {
exclude: normalize_exclusions(exclude),
fallback: false,
}
}
#[must_use]
pub fn fallback(exclude: Vec<String>) -> Self {
Self {
exclude: normalize_exclusions(exclude),
fallback: true,
}
}
}
/// Normalizes configured exclusion prefixes so `is_excluded` can match them on
/// `/` segment boundaries.
///
/// For each entry:
/// - surrounding whitespace is trimmed;
/// - trailing `/`s are stripped (`"/actuator/"` becomes `"/actuator"`);
/// - a missing leading `/` is added, because request paths always start with
///   one, so `"health"` would otherwise never match `/health`.
///
/// Entries that end up empty (`""`, `"/"`, `"//"`, whitespace) are dropped. An
/// empty prefix passes the segment-boundary check for every path that starts
/// with `/`, so keeping it would silently disable access logging entirely.
fn normalize_exclusions(exclude: Vec<String>) -> Arc<[String]> {
    exclude
        .into_iter()
        .filter_map(|raw| {
            let trimmed = raw.trim().trim_end_matches('/');
            if trimmed.is_empty() {
                None
            } else if trimmed.starts_with('/') {
                Some(trimmed.to_owned())
            } else {
                Some(format!("/{trimmed}"))
            }
        })
        .collect()
}
impl<S> Layer<S> for AccessLogLayer {
    type Service = AccessLogService<S>;

    /// Wraps `inner` in an [`AccessLogService`].
    ///
    /// The service shares this layer's exclusion list (an `Arc` refcount bump,
    /// not a copy) and inherits its fallback mode.
    fn layer(&self, inner: S) -> Self::Service {
        let Self { exclude, fallback } = self;
        AccessLogService {
            inner,
            exclude: Arc::clone(exclude),
            fallback: *fallback,
        }
    }
}
/// [`Service`] produced by [`AccessLogLayer`].
///
/// It captures request metadata in `call` and logs when the wrapped service's
/// future yields a response.
#[derive(Clone, Debug)]
pub struct AccessLogService<S> {
    // The wrapped service.
    inner: S,
    // Normalized exclusion prefixes, shared with the layer that built this service.
    exclude: Arc<[String]>,
    // Whether this service creates the `AccessLogEmitted` sentinel (fallback mode)
    // or picks it up from the request extensions and marks it.
    fallback: bool,
}
/// Request data captured in `call`, before the request is moved into the inner
/// service, so it is still available when the response arrives.
struct RequestMeta {
    // HTTP method of the request.
    method: Method,
    // Path pattern from axum's `MatchedPath` extension, if one was present.
    route: Option<String>,
    // `Display` form of the `crate::middleware::RequestId` extension, if present.
    request_id: Option<String>,
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for AccessLogService<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = AccessLogFuture<S::Future>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Captures what the log line needs from the request, sets up the shared
    /// [`AccessLogEmitted`] sentinel, and forwards the request to the inner
    /// service.
    ///
    /// For excluded paths no metadata is captured (`meta` is `None`), so this
    /// layer emits no event for the request.
    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
        // Start timing before `inner.call`: a service may do synchronous work
        // there (routing, eager middleware) before returning its future, and
        // that time belongs in the reported latency.
        let start = Instant::now();

        let meta = if is_excluded(req.uri().path(), &self.exclude) {
            None
        } else {
            Some(RequestMeta {
                method: req.method().clone(),
                route: req
                    .extensions()
                    .get::<MatchedPath>()
                    .map(|matched| matched.as_str().to_owned()),
                request_id: req
                    .extensions()
                    .get::<crate::middleware::RequestId>()
                    .map(ToString::to_string),
            })
        };

        let sentinel = if self.fallback {
            // The fallback layer owns the sentinel and publishes a clone via the
            // request extensions. A non-fallback layer further in (if any) picks
            // it up and marks it after logging.
            let sentinel = AccessLogEmitted::default();
            req.extensions_mut().insert(sentinel.clone());
            Some(sentinel)
        } else {
            req.extensions().get::<AccessLogEmitted>().cloned()
        };

        AccessLogFuture {
            inner: self.inner.call(req),
            meta,
            start,
            fallback: self.fallback,
            sentinel,
        }
    }
}
/// Returns `true` when `path` equals one of the `exclude` prefixes or lies
/// beneath one of them on a `/` segment boundary.
///
/// With the prefix `/health`, both `/health` and `/health/live` match, but
/// `/healthz` does not.
fn is_excluded(path: &str, exclude: &[String]) -> bool {
    for prefix in exclude {
        let Some(remainder) = path.strip_prefix(prefix.as_str()) else {
            continue;
        };
        if remainder.is_empty() || remainder.starts_with('/') {
            return true;
        }
    }
    false
}
pin_project! {
    /// Response future returned by [`AccessLogService`].
    ///
    /// It may emit the access-log event when the inner future completes with
    /// `Ok`; `Pending` and `Err` results pass through unchanged.
    pub struct AccessLogFuture<F> {
        // Future returned by the wrapped service; structurally pinned.
        #[pin]
        inner: F,
        // `None` when the path was excluded, or once the response has been handled.
        meta: Option<RequestMeta>,
        // Captured in `AccessLogService::call`; used to compute `duration_ms`.
        start: Instant,
        // Copied from the service: whether this future belongs to the fallback layer.
        fallback: bool,
        // Shared "already logged" flag, if one was created (fallback) or found (otherwise).
        sentinel: Option<AccessLogEmitted>,
    }
}
impl<F, ResBody, E> Future for AccessLogFuture<F>
where
    F: Future<Output = Result<Response<ResBody>, E>>,
{
    type Output = Result<Response<ResBody>, E>;

    /// Drives the wrapped future and, on its first `Ok` response, emits the
    /// access-log event.
    ///
    /// The event is skipped when the path was excluded, or, for the fallback
    /// layer, when an inner layer already logged the request. `Pending` and
    /// `Err` results are returned untouched.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        let response = match this.inner.poll(cx) {
            Poll::Ready(Ok(response)) => response,
            pending_or_err => return pending_or_err,
        };

        // The fallback layer defers to an inner layer that has already logged.
        let suppressed = *this.fallback
            && matches!(this.sentinel.as_ref(), Some(flag) if flag.is_marked());

        // `take` always runs, so the metadata is consumed exactly once whether
        // or not the event is emitted.
        if let Some(meta) = this.meta.take().filter(|_| !suppressed) {
            let duration_ms = this.start.elapsed().as_secs_f64() * 1000.0;
            // Prefer the id captured from the request. Otherwise fall back to
            // the `x-request-id` response header, if it is valid visible ASCII.
            let request_id = meta.request_id.as_deref().or_else(|| {
                response
                    .headers()
                    .get(&X_REQUEST_ID)
                    .and_then(|value| value.to_str().ok())
            });
            tracing::info!(
                target: ACCESS_LOG_TARGET,
                method = %meta.method,
                route = meta.route.as_deref().unwrap_or(UNMATCHED_ROUTE),
                status = response.status().as_u16(),
                duration_ms,
                request_id,
                "request served"
            );
            // A non-fallback layer tells an enclosing fallback layer that this
            // request has been logged.
            if let (false, Some(flag)) = (*this.fallback, this.sentinel.as_ref()) {
                flag.mark();
            }
        }

        Poll::Ready(Ok(response))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a normalized exclusion list from string literals, the same way
    /// the layer constructors do.
    fn exclude(prefixes: &[&str]) -> Arc<[String]> {
        normalize_exclusions(prefixes.iter().map(|p| (*p).to_owned()).collect())
    }

    #[test]
    fn excludes_exact_path_and_sub_segments() {
        let prefixes = exclude(&["/health", "/actuator", "/static"]);
        assert!(is_excluded("/health", &prefixes));
        assert!(is_excluded("/actuator/health", &prefixes));
        assert!(is_excluded("/static/css/app.css", &prefixes));
    }

    #[test]
    fn excludes_request_paths_with_trailing_slash() {
        let prefixes = exclude(&["/health"]);
        assert!(is_excluded("/health/", &prefixes));
    }

    #[test]
    fn does_not_exclude_lookalike_prefixes() {
        let prefixes = exclude(&["/health", "/static"]);
        assert!(!is_excluded("/healthz", &prefixes));
        assert!(!is_excluded("/staticfiles", &prefixes));
        assert!(!is_excluded("/users/1", &prefixes));
        assert!(!is_excluded("/", &prefixes));
    }

    #[test]
    fn nested_prefixes_match_only_their_own_subtree() {
        let prefixes = exclude(&["/api/internal"]);
        assert!(is_excluded("/api/internal", &prefixes));
        assert!(is_excluded("/api/internal/jobs/7", &prefixes));
        assert!(!is_excluded("/api", &prefixes));
        assert!(!is_excluded("/api/internals", &prefixes));
        assert!(!is_excluded("/api/public", &prefixes));
    }

    #[test]
    fn trailing_slashes_in_config_are_tolerated() {
        let prefixes = exclude(&["/actuator/"]);
        assert!(is_excluded("/actuator", &prefixes));
        assert!(is_excluded("/actuator/metrics", &prefixes));
    }

    #[test]
    fn repeated_trailing_slashes_in_config_are_stripped() {
        let prefixes = exclude(&["/actuator//", "/static/"]);
        assert_eq!(
            &*prefixes,
            &["/actuator".to_owned(), "/static".to_owned()][..]
        );
    }

    #[test]
    fn empty_or_slash_only_prefixes_exclude_nothing() {
        assert!(!is_excluded("/users/1", &exclude(&[""])));
        assert!(!is_excluded("/users/1", &exclude(&["/"])));
        assert!(!is_excluded("/users/1", &[]));
        // Such entries are dropped entirely rather than kept as empty prefixes,
        // which would otherwise match every path.
        assert!(exclude(&["", "/", "//"]).is_empty());
    }
}