axum_server_timing/
lib.rs

1use http::{HeaderValue, Request, Response};
2use pin_project_lite::pin_project;
3use std::{
4    future::Future,
5    pin::Pin,
6    sync::{Arc, Mutex},
7    task::{ready, Context, Poll},
8    time::{Duration, Instant},
9};
10use tower::{Layer, Service};
11
12#[allow(dead_code)]
13pub type ServerTimingExtension = Arc<Mutex<ServerTiming>>;
14
/// Per-request collector of `Server-Timing` metrics.
///
/// A fresh instance is created for every request handled by [`ServerTimingService`] and
/// shared with handlers through the request extensions as an `Arc<Mutex<ServerTiming>>`.
/// When the inner service produces a response, the application-wide metric and every
/// recorded entry are rendered into the `Server-Timing` response header.
#[derive(Debug)]
pub struct ServerTiming {
    /// Name of the application-wide metric (first entry in the header).
    app: String,
    /// Optional description of the application-wide metric.
    description: Option<String>,
    /// Creation time: start of the application-wide metric and of the first `record`.
    created: Instant,
    /// Recorded metrics, in insertion order.
    data: Vec<ServerTimingData>,
}

impl ServerTiming {
    /// Records the duration of the operation that just finished.
    ///
    /// The duration is measured from the previous data point: the moment the most recent
    /// entry was added (by `record` or `record_timing`), or the creation of this collector
    /// if nothing has been recorded yet.
    pub fn record(&mut self, name: String, description: Option<String>) {
        let now = Instant::now();
        let since = self
            .data
            .last()
            .map_or(self.created, |last| last.recorded_at);
        self.data.push(ServerTimingData {
            name,
            description,
            duration: now.saturating_duration_since(since),
            recorded_at: now,
        });
    }

    /// Records an operation whose duration the caller measured itself.
    ///
    /// The current time is stored as this entry's data point, so a following
    /// [`ServerTiming::record`] call measures from here.
    pub fn record_timing(
        &mut self,
        name: String,
        duration: Duration,
        description: Option<String>,
    ) {
        self.data.push(ServerTimingData {
            name,
            description,
            duration,
            recorded_at: Instant::now(),
        });
    }
}

/// A single named metric rendered into the `Server-Timing` header.
#[derive(Debug)]
pub struct ServerTimingData {
    /// Metric name.
    name: String,
    /// Optional human-readable description.
    description: Option<String>,
    /// Reported duration of the metric.
    duration: Duration,
    /// When this entry was recorded; the next `ServerTiming::record` measures from here.
    recorded_at: Instant,
}
54
55#[cfg(test)]
56mod test;
57
/// Tower [`Layer`] that adds a `Server-Timing` header to every response.
///
/// `app` names the application-wide metric, which spans from when the wrapped service
/// is called until it produces a response. An optional description can be attached with
/// [`ServerTimingLayer::with_description`].
#[derive(Debug, Clone)]
pub struct ServerTimingLayer<'a> {
    app: &'a str,
    description: Option<&'a str>,
}

impl<'a> ServerTimingLayer<'a> {
    /// Creates a layer whose application-wide metric is named `app`, with no description.
    pub fn new(app: &'a str) -> Self {
        ServerTimingLayer {
            app,
            description: None,
        }
    }

    /// Returns a copy of this layer whose application-wide metric carries `description`.
    ///
    /// `self` is left unchanged; the returned layer must be used.
    #[must_use]
    pub fn with_description(&self, description: &'a str) -> Self {
        ServerTimingLayer {
            app: self.app,
            description: Some(description),
        }
    }
}
78
79impl<'a, S> Layer<S> for ServerTimingLayer<'a> {
80    type Service = ServerTimingService<'a, S>;
81
82    fn layer(&self, service: S) -> Self::Service {
83        ServerTimingService {
84            service,
85            app: self.app,
86            description: self.description,
87        }
88    }
89}
90
/// Tower [`Service`] produced by [`ServerTimingLayer`].
///
/// Attaches a fresh [`ServerTiming`] collector to every request before forwarding it to
/// `service`, and writes the collected metrics into the `Server-Timing` response header.
#[derive(Debug, Clone)]
pub struct ServerTimingService<'a, S> {
    /// The wrapped inner service.
    service: S,
    /// Name of the application-wide metric.
    app: &'a str,
    /// Optional description of the application-wide metric.
    description: Option<&'a str>,
}
97
98impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ServerTimingService<'_, S>
99where
100    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
101    ResBody: Default,
102{
103    type Response = S::Response;
104    type Error = S::Error;
105    type Future = ResponseFuture<S::Future>;
106    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107        self.service.poll_ready(cx)
108    }
109
110    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
111        let timings = ServerTiming {
112            app: self.app.to_string(),
113            created: Instant::now(),
114            description: self.description.map(|elem| elem.to_string()),
115            data: vec![],
116        };
117        let x = Arc::new(Mutex::new(timings));
118        req.extensions_mut().insert(x.clone());
119
120        let (parts, body) = req.into_parts();
121
122        let req = Request::from_parts(parts, body);
123        ResponseFuture {
124            inner: self.service.call(req),
125            timings: x,
126        }
127    }
128}
129
130pin_project! {
131    pub struct ResponseFuture<F> {
132        #[pin]
133        inner: F,
134        timings: Arc<Mutex<ServerTiming>>,
135    }
136}
137
138impl<F, B, E> Future for ResponseFuture<F>
139where
140    F: Future<Output = Result<Response<B>, E>>,
141    B: Default,
142{
143    type Output = Result<Response<B>, E>;
144
145    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
146        let timing = self.timings.clone();
147        let mut response: Response<B> = ready!(self.project().inner.poll(cx))?;
148        let hdr = response.headers_mut();
149        // TODO: Once stable for a while, use `as_millis_f32`
150        let timing_after = timing.lock().unwrap();
151        let x = timing_after.created.elapsed().as_secs_f32() * 1000.0;
152        let app = timing_after.app.clone();
153        let mut header_value = match &timing_after.description {
154            Some(val) => format!("{app};desc=\"{val}\";dur={x:.2}"),
155            None => format!("{app};dur={x:.2}"),
156        };
157        for data in timing_after.data.iter() {
158            let x = data.duration.as_secs_f32() * 1000.0;
159            let name = data.name.clone();
160            let newval = match &data.description {
161                Some(val) => format!("{name};desc=\"{val}\";dur={x:.2}"),
162                None => format!("{name};dur={x:.2}"),
163            };
164            header_value = format!("{header_value}, {newval}");
165        }
166        match hdr.try_entry("Server-Timing") {
167            Ok(entry) => {
168                match entry {
169                    http::header::Entry::Occupied(mut val) => {
170                        //has val
171                        let old_val = val.get();
172                        let new_val = format!("{header_value}, {}", old_val.to_str().unwrap());
173                        val.insert(HeaderValue::from_str(&new_val).unwrap());
174                    }
175                    http::header::Entry::Vacant(val) => {
176                        val.insert(HeaderValue::from_str(&header_value).unwrap());
177                    }
178                }
179            }
180            Err(_) => {
181                // header name was invalid (it wasn't) or too many headers (just give up).
182            }
183        }
184
185        Poll::Ready(Ok(response))
186    }
187}