caco3_web/middleware/
request_trace.rs1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use axum::extract::MatchedPath;
6use axum::http::{Method, Request, Response, Uri};
7use futures_core::ready;
8use pin_project::pin_project;
9use tower::{Layer, Service};
10use tracing::trace;
11
/// Per-request policy deciding whether [`RequestTraceService`] marks a request as traced.
///
/// The service builds a fresh tracer from its factory closure for every request.
pub trait RequestTrace {
    /// Returns whether the request identified by `path` should be traced.
    ///
    /// `matched` is `true` when `path` is the route template taken from axum's
    /// [`MatchedPath`] request extension. It is `false` when no such extension was
    /// present and `path` is the raw URI path instead.
    fn is_traced(&self, path: &str, matched: bool) -> bool;

    /// Master switch for this tracer; defaults to `true`.
    ///
    /// When this returns `false`, `is_traced` is never consulted and no
    /// [`RequestTraceData`] is produced for the request.
    fn enabled(&self) -> bool {
        true
    }
}
19
20#[derive(Debug, Clone)]
22pub struct RequestTraceData {
23 pub trace: bool,
25 pub method: Method,
27 pub uri: Uri,
29}
30
/// Tower service produced by [`RequestTraceLayer`].
///
/// Wraps `inner` and calls `make_tracer` once per request to obtain a tracer.
#[derive(Clone, Debug)]
pub struct RequestTraceService<S, F> {
    /// Downstream service that actually handles the request.
    inner: S,
    /// Factory producing a [`RequestTrace`] implementation for each request.
    make_tracer: F,
}
37
/// Tower layer that wraps services in [`RequestTraceService`].
///
/// `make_tracer` is cloned into every service the layer produces. `Debug` is
/// derived for parity with [`RequestTraceService`]. The derive only applies when
/// `F: Debug`, so closure factories remain usable.
#[derive(Debug, Clone)]
pub struct RequestTraceLayer<F> {
    /// Factory producing a [`RequestTrace`] implementation for each request.
    make_tracer: F,
}
43
44impl<S, F> Layer<S> for RequestTraceLayer<F>
45 where
46 F: Clone,
47{
48 type Service = RequestTraceService<S, F>;
49
50 fn layer(&self, inner: S) -> Self::Service {
51 RequestTraceService {
52 inner,
53 make_tracer: self.make_tracer.clone(),
54 }
55 }
56}
57
58impl<F> RequestTraceLayer<F> {
59 pub fn new(make_tracer: F) -> Self {
60 Self { make_tracer }
61 }
62}
63
64impl<ReqBody, ResBody, S, F, T> Service<Request<ReqBody>> for RequestTraceService<S, F>
65 where
66 S: Service<Request<ReqBody>, Response=Response<ResBody>>,
67 F: FnMut() -> T,
68 T: RequestTrace,
69{
70 type Response = S::Response;
71 type Error = S::Error;
72 type Future = RequestTraceFuture<Request<ReqBody>, S>;
73
74 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75 self.inner.poll_ready(cx)
76 }
77
78 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
79 let tracer = (self.make_tracer)();
80 let enabled = tracer.enabled();
81 let mut request_trace = None;
82
83 if enabled {
84 let matched;
85 let path;
86 if let Some(matched_path) = req.extensions().get::<MatchedPath>() {
87 matched = true;
88 path = matched_path.as_str();
89 } else {
90 matched = false;
91 path = req.uri().path();
92 };
93 let trace = tracer.is_traced(path, matched);
94 request_trace = Some(RequestTraceData {
95 trace,
96 method: req.method().clone(),
97 uri: req.uri().clone(),
98 });
99 trace!(
100 "RequestTraceService: path = {path:?}, \
101 request_trace = {request_trace:?}",
102 );
103 }
104
105 RequestTraceFuture {
106 request_trace,
107 state: FutureState::Polling(self.inner.call(req)),
108 }
109 }
110}
111
112#[pin_project]
113pub struct RequestTraceFuture<Request, S: Service<Request>> {
114 request_trace: Option<RequestTraceData>,
115 #[pin]
116 state: FutureState<Request, S>,
117}
118
119#[pin_project(project = FutureStateProj)]
120enum FutureState<Request, S: Service<Request>> {
121 Polling(#[pin] S::Future),
122 Finished,
123}
124
125impl<Request, ResBody, S> Future for RequestTraceFuture<Request, S>
126 where
127 S: Service<Request, Response=Response<ResBody>>,
128{
129 type Output = Result<S::Response, S::Error>;
130
131 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
132 let mut this = self.project();
133 match this.state.as_mut().project() {
134 FutureStateProj::Polling(service_fut) => {
135 let mut output: Self::Output = ready!(service_fut.poll(cx));
136 if let Ok(response) = &mut output {
137 if let Some(request_trace) = this.request_trace.take() {
138 response.extensions_mut().insert(request_trace);
139 }
140 }
141 this.state.set(FutureState::Finished);
142 Poll::Ready(output)
143 }
144 FutureStateProj::Finished => {
145 panic!("RequestTraceFuture polled after completion");
146 }
147 }
148 }
149}