caco3_web/middleware/
request_trace.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use axum::extract::MatchedPath;
6use axum::http::{Method, Request, Response, Uri};
7use futures_core::ready;
8use pin_project::pin_project;
9use tower::{Layer, Service};
10use tracing::trace;
11
/// Per-request policy deciding whether request tracing applies.
///
/// [`RequestTraceService`] builds a fresh value implementing this trait for
/// every request by calling its `make_tracer` factory.
pub trait RequestTrace {
    /// Returns whether the request identified by `path` should be traced.
    ///
    /// `matched` is `true` when `path` is the route pattern taken from the
    /// router's [`MatchedPath`], and `false` when no route matched and `path`
    /// is the raw URI path instead.
    fn is_traced(&self, path: &str, matched: bool) -> bool;

    /// Returns whether the tracing middleware is active for this request.
    ///
    /// When this returns `false`, [`is_traced`](Self::is_traced) is never
    /// called and no [`RequestTraceData`] is attached to the response.
    /// The default implementation always returns `true`.
    fn enabled(&self) -> bool {
        true
    }
}
19
20/// A struct contain Http request info.
21#[derive(Debug, Clone)]
22pub struct RequestTraceData {
23    /// Indicate that request trace should be shown for route.
24    pub trace: bool,
25    /// Request method.
26    pub method: Method,
27    /// Request uri.
28    pub uri: Uri,
29}
30
/// Middleware that adds [`RequestTraceData`] to response extension.
///
/// Created by [`RequestTraceLayer`]. On every request it calls `make_tracer`
/// to obtain a [`RequestTrace`] policy, then forwards the request to `inner`.
#[derive(Clone, Debug)]
pub struct RequestTraceService<S, F> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Factory invoked once per request to build the tracer.
    make_tracer: F,
}
37
/// [`Layer`] that adds [`RequestTraceData`] to response extension.
///
/// Each service produced by this layer is a [`RequestTraceService`] holding a
/// clone of `make_tracer`.
// `Debug` is derived for consistency with `RequestTraceService`; the impl only
// exists when `F: Debug`, so closure factories are unaffected.
#[derive(Clone, Debug)]
pub struct RequestTraceLayer<F> {
    /// Tracer factory cloned into every wrapped service.
    make_tracer: F,
}
43
44impl<S, F> Layer<S> for RequestTraceLayer<F>
45    where
46        F: Clone,
47{
48    type Service = RequestTraceService<S, F>;
49
50    fn layer(&self, inner: S) -> Self::Service {
51        RequestTraceService {
52            inner,
53            make_tracer: self.make_tracer.clone(),
54        }
55    }
56}
57
58impl<F> RequestTraceLayer<F> {
59    pub fn new(make_tracer: F) -> Self {
60        Self { make_tracer }
61    }
62}
63
64impl<ReqBody, ResBody, S, F, T> Service<Request<ReqBody>> for RequestTraceService<S, F>
65    where
66        S: Service<Request<ReqBody>, Response=Response<ResBody>>,
67        F: FnMut() -> T,
68        T: RequestTrace,
69{
70    type Response = S::Response;
71    type Error = S::Error;
72    type Future = RequestTraceFuture<Request<ReqBody>, S>;
73
74    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75        self.inner.poll_ready(cx)
76    }
77
78    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
79        let tracer = (self.make_tracer)();
80        let enabled = tracer.enabled();
81        let mut request_trace = None;
82
83        if enabled {
84            let matched;
85            let path;
86            if let Some(matched_path) = req.extensions().get::<MatchedPath>() {
87                matched = true;
88                path = matched_path.as_str();
89            } else {
90                matched = false;
91                path = req.uri().path();
92            };
93            let trace = tracer.is_traced(path, matched);
94            request_trace = Some(RequestTraceData {
95                trace,
96                method: req.method().clone(),
97                uri: req.uri().clone(),
98            });
99            trace!(
100                "RequestTraceService: path = {path:?}, \
101                request_trace = {request_trace:?}",
102            );
103        }
104
105        RequestTraceFuture {
106            request_trace,
107            state: FutureState::Polling(self.inner.call(req)),
108        }
109    }
110}
111
112#[pin_project]
113pub struct RequestTraceFuture<Request, S: Service<Request>> {
114    request_trace: Option<RequestTraceData>,
115    #[pin]
116    state: FutureState<Request, S>,
117}
118
119#[pin_project(project = FutureStateProj)]
120enum FutureState<Request, S: Service<Request>> {
121    Polling(#[pin] S::Future),
122    Finished,
123}
124
125impl<Request, ResBody, S> Future for RequestTraceFuture<Request, S>
126    where
127        S: Service<Request, Response=Response<ResBody>>,
128{
129    type Output = Result<S::Response, S::Error>;
130
131    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
132        let mut this = self.project();
133        match this.state.as_mut().project() {
134            FutureStateProj::Polling(service_fut) => {
135                let mut output: Self::Output = ready!(service_fut.poll(cx));
136                if let Ok(response) = &mut output {
137                    if let Some(request_trace) = this.request_trace.take() {
138                        response.extensions_mut().insert(request_trace);
139                    }
140                }
141                this.state.set(FutureState::Finished);
142                Poll::Ready(output)
143            }
144            FutureStateProj::Finished => {
145                panic!("RequestTraceFuture polled after completion");
146            }
147        }
148    }
149}