// miku_server_timing/lib.rs
use std::{
    future::Future,
    pin::Pin,
    task::{ready, Context, Poll},
    time::Instant,
};

use http::{
    header::{Entry as HeaderEntry, InvalidHeaderValue},
    HeaderName, HeaderValue, Request, Response,
};
use macro_toolset::string::{NumStr, StringExtT};
use pin_project_lite::pin_project;

/// A [`tower_layer::Layer`] that measures how long the wrapped service takes to
/// produce a response and reports it in the `server-timing` response header.
///
/// Each response gets an entry of the form `<app>;desc="<description>";dur=<ms>`;
/// the `desc` part is only written when a description was set with
/// [`ServerTimingLayer::with_description`].
#[derive(Debug, Clone, Copy)]
pub struct ServerTimingLayer<'a> {
    /// Metric name written as the first token of the `server-timing` entry.
    app: &'a str,

    /// Optional `desc` parameter of the `server-timing` entry.
    description: Option<&'a str>,
}

impl<'a> ServerTimingLayer<'a> {
    /// Creates a layer whose timing entry is named `app`, without a description.
    #[inline]
    #[must_use]
    pub const fn new(app: &'a str) -> Self {
        Self {
            app,
            description: None,
        }
    }

    /// Creates a layer whose timing entry is named `"runtime"`, without a
    /// description.
    #[inline]
    #[must_use]
    pub const fn new_default() -> Self {
        Self::new("runtime")
    }

    /// Sets the `desc` parameter of the timing entry, replacing any previous one.
    ///
    /// NOTE(review): the description is written verbatim inside a quoted string;
    /// a `"` in it would produce a malformed `server-timing` value.
    #[inline]
    #[must_use]
    pub const fn with_description(mut self, description: &'a str) -> Self {
        self.description = Some(description);
        self
    }
}

impl Default for ServerTimingLayer<'_> {
    /// Equivalent to [`ServerTimingLayer::new_default`].
    #[inline]
    fn default() -> Self {
        Self::new_default()
    }
}

51impl<'a, S> tower_layer::Layer<S> for ServerTimingLayer<'a> {
52 type Service = ServerTimingService<'a, S>;
53
54 fn layer(&self, service: S) -> Self::Service {
55 ServerTimingService {
56 service,
57 app: self.app,
58 description: self.description,
59 }
60 }
61}
62
/// The [`tower_service::Service`] produced by `ServerTimingLayer`.
///
/// It forwards every request to the wrapped service and adds a
/// `server-timing` entry to the resulting response.
#[derive(Debug, Clone)]
pub struct ServerTimingService<'a, S> {
    /// The wrapped (inner) service.
    service: S,

    /// Metric name of the `server-timing` entry.
    app: &'a str,

    /// Optional `desc` parameter of the `server-timing` entry.
    description: Option<&'a str>,
}

76impl<'a, S, ReqBody, ResBody> tower_service::Service<Request<ReqBody>>
77 for ServerTimingService<'a, S>
78where
79 S: tower_service::Service<Request<ReqBody>, Response = Response<ResBody>>,
80 ResBody: Default,
81{
82 type Response = S::Response;
83 type Error = S::Error;
84 type Future = ResponseFuture<'a, S::Future>;
85
86 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87 self.service.poll_ready(cx)
88 }
89
90 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
91 ResponseFuture {
92 inner: self.service.call(req),
93 request_time: Instant::now(),
94 app: self.app,
95 description: self.description,
96 }
97 }
98}
99
100pin_project! {
101 pub struct ResponseFuture<'a, F> {
103 #[pin]
104 inner: F,
105 request_time: Instant,
106 app: &'a str,
107 description: Option<&'a str>,
108 }
109}
110
111const SERVER_TIMING: HeaderName = HeaderName::from_static("server-timing");
112
113impl<F, B, E> Future for ResponseFuture<'_, F>
114where
115 F: Future<Output = Result<Response<B>, E>>,
116 B: Default,
117{
118 type Output = Result<Response<B>, E>;
119
120 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
121 let this = self.project();
122
123 let mut response: Response<B> = ready!(this.inner.poll(cx))?;
124
125 match response.headers_mut().try_entry(SERVER_TIMING) {
126 Ok(entry) => match entry {
127 HeaderEntry::Occupied(mut val) => {
128 if let Ok(v) = (
129 this.app.with_suffix(";"),
130 this.description.with_prefix("desc=\"").with_suffix("\";"),
131 NumStr::new_default(this.request_time.elapsed().as_secs_f32() * 1000.0)
132 .set_resize_len::<1>()
133 .with_prefix("dur="),
134 ", ",
135 val.get().to_str(),
136 )
137 .to_http_header_value()
138 {
139 val.insert(v);
140 } else {
141 }
143 }
144 HeaderEntry::Vacant(val) => {
145 if let Ok(v) = (
146 this.app.with_suffix(";"),
147 this.description.with_prefix("desc=\"").with_suffix("\";"),
148 NumStr::new_default(this.request_time.elapsed().as_secs_f32() * 1000.0)
149 .set_resize_len::<1>()
150 .with_prefix("dur="),
151 )
152 .to_http_header_value()
153 {
154 val.insert(v);
155 } else {
156 }
158 }
159 },
160 Err(_e) => {
161 #[cfg(feature = "feat-tracing")]
162 tracing::error!("Failed to add `server-timing` header: {_e:?}");
163 }
166 };
167
168 Poll::Ready(Ok(response))
169 }
170}
171
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use axum::{routing::get, Router};
    use http::{HeaderMap, HeaderValue};

    use super::ServerTimingLayer;

    /// Serves `app` in the background on an OS-assigned loopback port.
    ///
    /// Returns the URL of its root. Using port 0 avoids collisions between
    /// tests and with other processes.
    async fn serve(app: Router) -> String {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move { axum::serve(listener, app.into_make_service()).await });
        format!("http://{addr}/")
    }

    #[test]
    fn service_name() {
        let name = "svc1";
        let obj = ServerTimingLayer::new(name);
        assert_eq!(obj.app, name);
    }

    #[test]
    fn service_desc() {
        let name = "svc1";
        let desc = "desc1";
        let obj = ServerTimingLayer::new(name).with_description(desc);
        assert_eq!(obj.app, name);
        assert_eq!(obj.description, Some(desc));
    }

    #[tokio::test]
    async fn axum_test() {
        let name = "svc1";
        let app = Router::new()
            .route(
                "/",
                get(|| async move {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    ""
                }),
            )
            .layer(ServerTimingLayer::new(name));

        let url = serve(app).await;

        // The blocking client runs on another thread; its assertion panics only
        // surface through the `JoinError`, so the join result must be checked.
        tokio::task::spawn_blocking(move || {
            let headers = minreq::get(url).send().unwrap().headers;

            let hdr = headers.get("server-timing").unwrap_or_else(|| {
                panic!("Cannot find `server-timing` from: {headers:#?}")
            });

            let val: f32 = hdr
                .strip_prefix(&format!("{name};dur="))
                .and_then(|dur| dur.parse().ok())
                .unwrap_or_else(|| panic!("Invalid `server-timing` from: {headers:#?}"));
            assert!(
                (100f32..300f32).contains(&val),
                "Invalid `server-timing` from: {headers:#?}"
            );
        })
        .await
        .expect("client assertions failed");
    }

    #[tokio::test]
    async fn support_existing_header() {
        let name = "svc1";
        let app = Router::new()
            .route(
                "/",
                get(|| async move {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    let mut hdr = HeaderMap::new();
                    hdr.insert("server-timing", HeaderValue::from_static("inner;dur=23"));
                    (hdr, "")
                }),
            )
            .layer(ServerTimingLayer::new(name).with_description("desc1"));

        let url = serve(app).await;

        tokio::task::spawn_blocking(move || {
            let headers = minreq::get(url).send().unwrap().headers;

            let hdr = headers.get("server-timing").unwrap_or_else(|| {
                panic!("Cannot find `server-timing` from: {headers:#?}")
            });

            // Our entry comes first, followed by the inner service's entry.
            assert!(
                hdr.starts_with("svc1;desc=\"desc1\";dur="),
                "Unexpected `server-timing`: {hdr}"
            );
            assert!(
                hdr.ends_with(", inner;dur=23"),
                "Unexpected `server-timing`: {hdr}"
            );
        })
        .await
        .expect("client assertions failed");
    }
}