use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use axum::extract::MatchedPath;
use axum::http::{Method, Request, Response, Uri};
use futures_core::ready;
use pin_project::pin_project;
use tower::{Layer, Service};
use tracing::trace;
/// Per-request tracing policy consulted by [`RequestTraceService`].
///
/// A fresh value implementing this trait is obtained from the service's
/// factory closure every time a request is dispatched.
pub trait RequestTrace {
    /// Decides whether the request identified by `path` should be traced.
    ///
    /// `matched` is `true` when `path` is the route pattern taken from the
    /// request's `MatchedPath` extension, and `false` when no such extension
    /// was present and `path` is the raw path of the request URI instead.
    fn is_traced(&self, path: &str, matched: bool) -> bool;

    /// Master switch for the tracer; defaults to `true`.
    ///
    /// When this returns `false` the service never calls
    /// [`is_traced`](RequestTrace::is_traced) and attaches no
    /// `RequestTraceData` to the response.
    fn enabled(&self) -> bool {
        true
    }
}
/// Snapshot of a request's tracing decision and identity.
///
/// Built by `RequestTraceService` while the request is still available and
/// later inserted into the extensions of a successful response.
#[derive(Debug, Clone)]
pub struct RequestTraceData {
    /// Result of `RequestTrace::is_traced` for this request.
    pub trace: bool,
    /// HTTP method of the request (cloned from the request).
    pub method: Method,
    /// Full URI of the request (cloned from the request).
    pub uri: Uri,
}
/// Service wrapper that records tracing information for each request and
/// hands it over to the response produced by the wrapped service.
#[derive(Debug, Clone)]
pub struct RequestTraceService<S, F> {
    /// The wrapped (downstream) service.
    inner: S,
    /// Factory invoked once per call to obtain the tracer for that request.
    make_tracer: F,
}
/// Layer that wraps a service in a [`RequestTraceService`].
///
/// The layer only stores the tracer factory; every call to `layer` clones
/// that factory into the service it creates.
///
/// `Debug` is derived alongside `Clone` so this public type matches
/// `RequestTraceService`, which derives both. The derived impl is only
/// available when `F: Debug`, so factories without a `Debug` impl (such as
/// closures) are unaffected.
#[derive(Debug, Clone)]
pub struct RequestTraceLayer<F> {
    /// Factory cloned into every service produced by this layer.
    make_tracer: F,
}
/// `Layer` implementation: wraps any inner service `S` in a
/// `RequestTraceService` that shares a clone of this layer's tracer factory.
impl<S, F> Layer<S> for RequestTraceLayer<F>
where
    F: Clone,
{
    type Service = RequestTraceService<S, F>;

    /// Builds the tracing service around `inner`.
    fn layer(&self, inner: S) -> Self::Service {
        // Each produced service owns its own copy of the factory.
        let make_tracer = self.make_tracer.clone();
        RequestTraceService { inner, make_tracer }
    }
}
impl<F> RequestTraceLayer<F> {
pub fn new(make_tracer: F) -> Self {
Self { make_tracer }
}
}
/// `Service` implementation: captures tracing data from the request, then
/// forwards the request to the wrapped service.
impl<ReqBody, ResBody, S, F, T> Service<Request<ReqBody>> for RequestTraceService<S, F>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    F: FnMut() -> T,
    T: RequestTrace,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = RequestTraceFuture<Request<ReqBody>, S>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Obtains a tracer from the factory, records the tracing decision for
    /// `req` (only when the tracer is enabled) and returns a future that
    /// drives the wrapped service's call.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        let tracer = (self.make_tracer)();

        let request_trace = if tracer.enabled() {
            // Prefer the matched route pattern when the `MatchedPath`
            // extension is present; otherwise fall back to the raw URI path.
            let (path, matched) = match req.extensions().get::<MatchedPath>() {
                Some(matched_path) => (matched_path.as_str(), true),
                None => (req.uri().path(), false),
            };

            let request_trace = Some(RequestTraceData {
                trace: tracer.is_traced(path, matched),
                method: req.method().clone(),
                uri: req.uri().clone(),
            });
            trace!(
                "RequestTraceService: path = {path:?}, \
                 request_trace = {request_trace:?}",
            );
            request_trace
        } else {
            // Disabled tracer: nothing is recorded for this request.
            None
        };

        // `req` is moved into the inner service only after everything needed
        // from it has been cloned above.
        let state = FutureState::Polling(self.inner.call(req));
        RequestTraceFuture {
            request_trace,
            state,
        }
    }
}
/// Response future returned by `RequestTraceService`.
///
/// Carries the tracing data captured from the request until the wrapped
/// service's future resolves.
#[pin_project]
pub struct RequestTraceFuture<Request, S>
where
    S: Service<Request>,
{
    /// Data captured at call time; `None` when the tracer was disabled or
    /// once it has been moved into the response.
    request_trace: Option<RequestTraceData>,
    /// Progress of the wrapped service's future.
    #[pin]
    state: FutureState<Request, S>,
}
/// Internal state of a `RequestTraceFuture`.
#[pin_project(project = FutureStateProj)]
enum FutureState<Request, S>
where
    S: Service<Request>,
{
    /// The wrapped service's future is still being driven.
    Polling(#[pin] S::Future),
    /// The output has already been returned; polling again is an error.
    Finished,
}
/// Drives the wrapped service's future and, on success, moves the captured
/// `RequestTraceData` into the response extensions.
impl<Request, ResBody, S> Future for RequestTraceFuture<Request, S>
where
    S: Service<Request, Response = Response<ResBody>>,
{
    type Output = Result<S::Response, S::Error>;

    /// Polls the inner future.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        let service_fut = match this.state.as_mut().project() {
            FutureStateProj::Finished => {
                panic!("RequestTraceFuture polled after completion");
            }
            FutureStateProj::Polling(pending) => pending,
        };

        // Returns `Poll::Pending` early while the inner future is not done.
        let mut outcome: Self::Output = ready!(service_fut.poll(cx));

        // Only successful responses receive the trace data; on error the
        // data stays in place and is dropped together with this future.
        if let Ok(response) = outcome.as_mut() {
            if let Some(data) = this.request_trace.take() {
                response.extensions_mut().insert(data);
            }
        }

        // Drop the inner future and remember that the output was delivered.
        this.state.set(FutureState::Finished);
        Poll::Ready(outcome)
    }
}