miku_server_timing/
lib.rs

//! Miku's `Server-Timing` middleware for Axum.
2
3use std::{
4    future::Future,
5    pin::Pin,
6    task::{ready, Context, Poll},
7    time::Instant,
8};
9
10use http::{header::Entry as HeaderEntry, HeaderName, Request, Response};
11use macro_toolset::string::{NumStr, StringExtT};
12use pin_project_lite::pin_project;
13
/// Tower layer that stamps every response with a `Server-Timing` metric
/// named after the application.
#[derive(Debug, Clone)]
pub struct ServerTimingLayer<'a> {
    /// Metric name written at the start of the header entry (e.g. `svc1`).
    app: &'a str,

    /// Text emitted as the metric's `desc` parameter, if any.
    description: Option<&'a str>,
}

impl<'a> ServerTimingLayer<'a> {
    /// Builds a layer that reports under the metric name `app`, without a
    /// description.
    #[inline]
    pub const fn new(app: &'a str) -> Self {
        Self {
            description: None,
            app,
        }
    }

    /// Consumes the layer and returns it with `description` attached as the
    /// metric's `desc` parameter (replacing any previous description).
    #[inline]
    pub const fn with_description(self, description: &'a str) -> Self {
        Self {
            description: Some(description),
            ..self
        }
    }
}
41
42impl<'a, S> tower_layer::Layer<S> for ServerTimingLayer<'a> {
43    type Service = ServerTimingService<'a, S>;
44
45    fn layer(&self, service: S) -> Self::Service {
46        ServerTimingService {
47            service,
48            app: self.app,
49            description: self.description,
50        }
51    }
52}
53
/// Tower service produced by `ServerTimingLayer`: forwards each request to the
/// wrapped service and adds a `Server-Timing` header to its response.
#[derive(Debug, Clone)]
pub struct ServerTimingService<'a, S> {
    /// Inner service that actually handles the request.
    service: S,

    /// Metric name written at the start of the header entry.
    app: &'a str,

    /// Text emitted as the metric's `desc` parameter, if any.
    description: Option<&'a str>,
}
66
67impl<'a, S, ReqBody, ResBody> tower_service::Service<Request<ReqBody>>
68    for ServerTimingService<'a, S>
69where
70    S: tower_service::Service<Request<ReqBody>, Response = Response<ResBody>>,
71    ResBody: Default,
72{
73    type Response = S::Response;
74    type Error = S::Error;
75    type Future = ResponseFuture<'a, S::Future>;
76
77    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78        self.service.poll_ready(cx)
79    }
80
81    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
82        ResponseFuture {
83            inner: self.service.call(req),
84            request_time: Instant::now(),
85            app: self.app,
86            description: self.description,
87        }
88    }
89}
90
91pin_project! {
92    /// A future that will add a Server-Timing header to the response.
93    pub struct ResponseFuture<'a, F> {
94        #[pin]
95        inner: F,
96        request_time: Instant,
97        app: &'a str,
98        description: Option<&'a str>,
99    }
100}
101
102const SERVER_TIMING: HeaderName = HeaderName::from_static("server-timing");
103
104impl<F, B, E> Future for ResponseFuture<'_, F>
105where
106    F: Future<Output = Result<Response<B>, E>>,
107    B: Default,
108{
109    type Output = Result<Response<B>, E>;
110
111    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
112        let this = self.project();
113
114        let mut response: Response<B> = ready!(this.inner.poll(cx))?;
115
116        match response.headers_mut().try_entry(SERVER_TIMING) {
117            Ok(entry) => match entry {
118                HeaderEntry::Occupied(mut val) => {
119                    if let Ok(v) = (
120                        this.app.with_suffix(";"),
121                        this.description.with_prefix("desc=\"").with_suffix("\";"),
122                        NumStr::new_default(this.request_time.elapsed().as_secs_f32() * 1000.0)
123                            .set_resize_len::<1>()
124                            .with_prefix("dur="),
125                        val.get().to_str().with_prefix(", "),
126                    )
127                        .to_http_header_value()
128                    {
129                        val.insert(v);
130                    } else {
131                        // unlikely to happen, but if it does, just ignore it.
132                    }
133                }
134                HeaderEntry::Vacant(val) => {
135                    if let Ok(v) = (
136                        this.app.with_suffix(";"),
137                        this.description.with_prefix("desc=\"").with_suffix("\";"),
138                        NumStr::new_default(this.request_time.elapsed().as_secs_f32() * 1000.0)
139                            .set_resize_len::<1>()
140                            .with_prefix("dur="),
141                    )
142                        .to_http_header_value()
143                    {
144                        val.insert(v);
145                    } else {
146                        // unlikely to happen, but if it does, just ignore it.
147                    }
148                }
149            },
150            Err(_e) => {
151                #[cfg(feature = "feat-tracing")]
152                tracing::error!("Failed to add `server-timing` header: {_e:?}");
153                // header name was invalid (it wasn't) or too many headers (just
154                // give up).
155            }
156        };
157
158        Poll::Ready(Ok(response))
159    }
160}
161
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use axum::{routing::get, Router};
    use http::{HeaderMap, HeaderValue};

    use super::ServerTimingLayer;

    /// Serves `app` on an OS-assigned loopback port and returns its base URL.
    ///
    /// Binding port 0 avoids collisions between tests and with other processes,
    /// which fixed ports (previously 3001/3003) were prone to.
    async fn spawn_server(app: Router) -> String {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("failed to bind test listener");
        let addr = listener
            .local_addr()
            .expect("test listener has no local address");

        tokio::spawn(async move { axum::serve(listener, app.into_make_service()).await });

        format!("http://{addr}/")
    }

    /// Runs blocking HTTP-client checks off the async runtime and re-raises any
    /// panic they produce.
    ///
    /// Discarding the `JoinHandle` result (as `let _ = ... .await` did) swallows
    /// failed assertions inside the closure, making the test pass unconditionally.
    async fn run_client<C>(check: C)
    where
        C: FnOnce() + Send + 'static,
    {
        if let Err(err) = tokio::task::spawn_blocking(check).await {
            std::panic::resume_unwind(err.into_panic());
        }
    }

    #[test]
    fn service_name() {
        let name = "svc1";
        let obj = ServerTimingLayer::new(name);
        assert_eq!(obj.app, name);
    }

    #[test]
    fn service_desc() {
        let name = "svc1";
        let desc = "desc1";
        let obj = ServerTimingLayer::new(name).with_description(desc);
        assert_eq!(obj.app, name);
        assert_eq!(obj.description, Some(desc));
    }

    #[tokio::test]
    async fn axum_test() {
        let name = "svc1";
        let app = Router::new()
            .route(
                "/",
                get(|| async move {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    ""
                }),
            )
            .layer(ServerTimingLayer::new(name));

        let url = spawn_server(app).await;

        run_client(move || {
            let headers = minreq::get(url).send().unwrap().headers;

            let hdr = headers.get("server-timing");
            assert!(
                hdr.is_some(),
                "Cannot find `server-timing` from: {headers:#?}"
            );

            // Expected shape without a description: `<name>;dur=<ms>`.
            let prefix = format!("{name};dur=");
            let val: f32 = hdr
                .and_then(|h| h.strip_prefix(prefix.as_str()))
                .and_then(|dur| dur.parse().ok())
                .unwrap_or_else(|| panic!("Invalid `server-timing` from: {headers:#?}"));
            assert!(
                (100f32..300f32).contains(&val),
                "Invalid `server-timing` from: {headers:#?}"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn support_existing_header() {
        let name = "svc1";
        let app = Router::new()
            .route(
                "/",
                get(|| async move {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    let mut hdr = HeaderMap::new();
                    hdr.insert("server-timing", HeaderValue::from_static("inner;dur=23"));
                    (hdr, "")
                }),
            )
            .layer(ServerTimingLayer::new(name).with_description("desc1"));

        let url = spawn_server(app).await;

        run_client(move || {
            let headers = minreq::get(url).send().unwrap().headers;

            let hdr = headers
                .get("server-timing")
                .unwrap_or_else(|| panic!("Cannot find `server-timing` from: {headers:#?}"));
            assert!(hdr.contains(name), "missing own metric in: {hdr}");
            assert!(hdr.contains("inner"), "existing metric was dropped: {hdr}");
            println!("{hdr}");
        })
        .await;
    }
}