Skip to main content

miku_server_timing/
lib.rs

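//! Miku's Server-Timing middleware for Axum.
//!
//! [`ServerTimingLayer`] wraps a service and reports how long the service took
//! to produce each successful response in the `server-timing` response header,
//! formatted as `<app>;desc="<description>";dur=<milliseconds>` (the `desc`
//! part is only present when a description is configured). If the response
//! already carries a `server-timing` value, this metric is prepended to it.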
2
3use std::{
4    future::Future,
5    pin::Pin,
6    task::{ready, Context, Poll},
7    time::Instant,
8};
9
10use http::{header::Entry as HeaderEntry, HeaderName, Request, Response};
11use macro_toolset::string::{NumStr, StringExtT};
12use pin_project_lite::pin_project;
13
/// A [`tower_layer::Layer`] that adds a `Server-Timing` header to responses.
///
/// Every response of the wrapped service gets a metric named after `app`,
/// optionally with a `desc` parameter, carrying the elapsed time as `dur`.
#[derive(Debug, Clone, Copy)]
pub struct ServerTimingLayer<'a> {
    /// The service name, used as the metric name.
    app: &'a str,

    /// An optional description of the service, emitted as `desc="..."`.
    description: Option<&'a str>,
}

impl<'a> ServerTimingLayer<'a> {
    /// Creates a new `ServerTimingLayer` with the given service name.
    #[inline]
    #[must_use]
    pub const fn new(app: &'a str) -> Self {
        Self {
            app,
            description: None,
        }
    }

    /// Creates a new `ServerTimingLayer` with the default service name (`runtime`).
    #[inline]
    #[must_use]
    pub const fn new_default() -> Self {
        Self::new("runtime")
    }

    /// Adds a description to the service name.
    ///
    /// This returns the updated layer. The layer is `Copy` and `self` is taken
    /// by value, so discarding the result leaves the original layer unchanged.
    #[inline]
    #[must_use]
    pub const fn with_description(mut self, description: &'a str) -> Self {
        self.description = Some(description);
        self
    }
}

impl Default for ServerTimingLayer<'_> {
    /// Equivalent to [`ServerTimingLayer::new_default`] (service name `runtime`).
    fn default() -> Self {
        Self::new_default()
    }
}
50
51impl<'a, S> tower_layer::Layer<S> for ServerTimingLayer<'a> {
52    type Service = ServerTimingService<'a, S>;
53
54    fn layer(&self, service: S) -> Self::Service {
55        ServerTimingService {
56            service,
57            app: self.app,
58            description: self.description,
59        }
60    }
61}
62
/// The service produced by [`ServerTimingLayer`].
///
/// It forwards each request to the wrapped service and adds a Server-Timing
/// header to the response.
#[derive(Debug, Clone)]
pub struct ServerTimingService<'a, S> {
    /// The wrapped (inner) service that actually handles requests.
    service: S,

    /// The service name, used as the metric name in the header.
    app: &'a str,

    /// An optional description, emitted as the metric's `desc` parameter.
    description: Option<&'a str>,
}
75
76impl<'a, S, ReqBody, ResBody> tower_service::Service<Request<ReqBody>>
77    for ServerTimingService<'a, S>
78where
79    S: tower_service::Service<Request<ReqBody>, Response = Response<ResBody>>,
80    ResBody: Default,
81{
82    type Response = S::Response;
83    type Error = S::Error;
84    type Future = ResponseFuture<'a, S::Future>;
85
86    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87        self.service.poll_ready(cx)
88    }
89
90    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
91        ResponseFuture {
92            inner: self.service.call(req),
93            request_time: Instant::now(),
94            app: self.app,
95            description: self.description,
96        }
97    }
98}
99
pin_project! {
    /// A future that will add a Server-Timing header to the response.
    ///
    /// Created by `ServerTimingService::call`; it resolves to the inner
    /// service's result, with the header added only on success.
    pub struct ResponseFuture<'a, F> {
        // The wrapped service's response future; `#[pin]` makes it structurally
        // pinned so it can be polled through `Pin<&mut Self>`.
        #[pin]
        inner: F,
        // Start instant recorded by `ServerTimingService::call`; the reported
        // `dur` is the time elapsed since this instant.
        request_time: Instant,
        // Metric name written at the start of the header value.
        app: &'a str,
        // Optional value written as the metric's `desc="..."` parameter.
        description: Option<&'a str>,
    }
}
110
111const SERVER_TIMING: HeaderName = HeaderName::from_static("server-timing");
112
113impl<F, B, E> Future for ResponseFuture<'_, F>
114where
115    F: Future<Output = Result<Response<B>, E>>,
116    B: Default,
117{
118    type Output = Result<Response<B>, E>;
119
120    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
121        let this = self.project();
122
123        let mut response: Response<B> = ready!(this.inner.poll(cx))?;
124
125        match response.headers_mut().try_entry(SERVER_TIMING) {
126            Ok(entry) => match entry {
127                HeaderEntry::Occupied(mut val) => {
128                    if let Ok(v) = (
129                        this.app.with_suffix(";"),
130                        this.description.with_prefix("desc=\"").with_suffix("\";"),
131                        NumStr::new_default(this.request_time.elapsed().as_secs_f32() * 1000.0)
132                            .set_resize_len::<1>()
133                            .with_prefix("dur="),
134                            ", ",
135                        val.get().to_str(),
136                    )
137                        .to_http_header_value()
138                    {
139                        val.insert(v);
140                    } else {
141                        // unlikely to happen, but if it does, just ignore it.
142                    }
143                }
144                HeaderEntry::Vacant(val) => {
145                    if let Ok(v) = (
146                        this.app.with_suffix(";"),
147                        this.description.with_prefix("desc=\"").with_suffix("\";"),
148                        NumStr::new_default(this.request_time.elapsed().as_secs_f32() * 1000.0)
149                            .set_resize_len::<1>()
150                            .with_prefix("dur="),
151                    )
152                        .to_http_header_value()
153                    {
154                        val.insert(v);
155                    } else {
156                        // unlikely to happen, but if it does, just ignore it.
157                    }
158                }
159            },
160            Err(_e) => {
161                #[cfg(feature = "feat-tracing")]
162                tracing::error!("Failed to add `server-timing` header: {_e:?}");
163                // header name was invalid (it wasn't) or too many headers (just
164                // give up).
165            }
166        };
167
168        Poll::Ready(Ok(response))
169    }
170}
171
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use axum::{routing::get, Router};
    use http::{HeaderMap, HeaderValue};

    use super::ServerTimingLayer;

    /// Serves `app` on an OS-assigned loopback port and returns its base URL.
    ///
    /// Port 0 lets the OS pick a free port, so tests cannot collide with each
    /// other or with unrelated processes the way hard-coded ports can.
    async fn serve(app: Router) -> String {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("failed to bind an ephemeral loopback port");
        let addr = listener
            .local_addr()
            .expect("bound listener has a local address");
        tokio::spawn(async move { axum::serve(listener, app.into_make_service()).await });
        format!("http://{addr}/")
    }

    /// Runs blocking client code off the async runtime and re-raises any panic
    /// (i.e. a failed assertion) inside the test itself.
    ///
    /// Discarding the `spawn_blocking` result would swallow such panics and let
    /// the test pass no matter what the assertions found.
    async fn run_client<C>(client: C)
    where
        C: FnOnce() + Send + 'static,
    {
        if let Err(err) = tokio::task::spawn_blocking(client).await {
            match err.try_into_panic() {
                Ok(payload) => std::panic::resume_unwind(payload),
                Err(err) => panic!("client task did not complete: {err}"),
            }
        }
    }

    #[test]
    fn service_name() {
        let name = "svc1";
        let obj = ServerTimingLayer::new(name);
        assert_eq!(obj.app, name);
    }

    #[test]
    fn service_desc() {
        let name = "svc1";
        let desc = "desc1";
        let obj = ServerTimingLayer::new(name).with_description(desc);
        assert_eq!(obj.app, name);
        assert_eq!(obj.description, Some(desc));
    }

    #[tokio::test]
    async fn axum_test() {
        let name = "svc1";
        let app = Router::new()
            .route(
                "/",
                get(|| async move {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    ""
                }),
            )
            .layer(ServerTimingLayer::new(name));

        let url = serve(app).await;

        run_client(move || {
            let headers = minreq::get(url).send().unwrap().headers;

            let hdr = headers.get("server-timing");
            assert!(
                hdr.is_some(),
                "Cannot find `server-timing` from: {headers:#?}"
            );

            let prefix = format!("{name};dur=");
            let val: f32 = hdr
                .unwrap()
                .strip_prefix(prefix.as_str())
                .unwrap_or_else(|| panic!("Unexpected `server-timing` format: {headers:#?}"))
                .parse()
                .unwrap();
            assert!(
                (100f32..300f32).contains(&val),
                "Invalid `server-timing` from: {headers:#?}"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn support_existing_header() {
        let name = "svc1";
        let app = Router::new()
            .route(
                "/",
                get(|| async move {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    let mut hdr = HeaderMap::new();
                    hdr.insert("server-timing", HeaderValue::from_static("inner;dur=23"));
                    (hdr, "")
                }),
            )
            .layer(ServerTimingLayer::new(name).with_description("desc1"));

        let url = serve(app).await;

        run_client(move || {
            let headers = minreq::get(url).send().unwrap().headers;

            let hdr = headers
                .get("server-timing")
                .unwrap_or_else(|| panic!("Cannot find `server-timing` from: {headers:#?}"));
            // The layer's metric comes first; the handler's own value is kept after it.
            assert!(hdr.starts_with("svc1;desc=\"desc1\";dur="), "{hdr}");
            assert!(hdr.ends_with(", inner;dur=23"), "{hdr}");
        })
        .await;
    }
}