miku_server_timing/
lib.rs1use std::{
4 future::Future,
5 pin::Pin,
6 task::{ready, Context, Poll},
7 time::Instant,
8};
9
10use http::{header::Entry as HeaderEntry, HeaderName, Request, Response};
11use macro_toolset::string::{NumStr, StringExtT};
12use pin_project_lite::pin_project;
13
/// Tower layer that adds a `server-timing` metric to every successful response
/// of the service it wraps.
///
/// The metric is rendered as `<app>;desc="<description>";dur=<milliseconds>`;
/// the `desc` part is emitted only when a description has been configured.
#[derive(Debug, Clone)]
pub struct ServerTimingLayer<'a> {
    /// Metric name reported in the header.
    app: &'a str,

    /// Optional human-readable text emitted as the metric's `desc` parameter.
    description: Option<&'a str>,
}
23
24impl<'a> ServerTimingLayer<'a> {
25 #[inline]
26 pub const fn new(app: &'a str) -> Self {
28 ServerTimingLayer {
29 app,
30 description: None,
31 }
32 }
33
34 #[inline]
35 pub const fn with_description(mut self, description: &'a str) -> Self {
37 self.description = Some(description);
38 self
39 }
40}
41
42impl<'a, S> tower_layer::Layer<S> for ServerTimingLayer<'a> {
43 type Service = ServerTimingService<'a, S>;
44
45 fn layer(&self, service: S) -> Self::Service {
46 ServerTimingService {
47 service,
48 app: self.app,
49 description: self.description,
50 }
51 }
52}
53
/// Service produced by [`ServerTimingLayer`]: forwards every request to the
/// wrapped `service` and adds a `server-timing` metric to each successful
/// response.
#[derive(Debug, Clone)]
pub struct ServerTimingService<'a, S> {
    /// The wrapped inner service.
    service: S,

    /// Metric name reported in the header.
    app: &'a str,

    /// Optional text emitted as the metric's `desc` parameter.
    description: Option<&'a str>,
}
66
67impl<'a, S, ReqBody, ResBody> tower_service::Service<Request<ReqBody>>
68 for ServerTimingService<'a, S>
69where
70 S: tower_service::Service<Request<ReqBody>, Response = Response<ResBody>>,
71 ResBody: Default,
72{
73 type Response = S::Response;
74 type Error = S::Error;
75 type Future = ResponseFuture<'a, S::Future>;
76
77 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78 self.service.poll_ready(cx)
79 }
80
81 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
82 ResponseFuture {
83 inner: self.service.call(req),
84 request_time: Instant::now(),
85 app: self.app,
86 description: self.description,
87 }
88 }
89}
90
91pin_project! {
92 pub struct ResponseFuture<'a, F> {
94 #[pin]
95 inner: F,
96 request_time: Instant,
97 app: &'a str,
98 description: Option<&'a str>,
99 }
100}
101
102const SERVER_TIMING: HeaderName = HeaderName::from_static("server-timing");
103
104impl<F, B, E> Future for ResponseFuture<'_, F>
105where
106 F: Future<Output = Result<Response<B>, E>>,
107 B: Default,
108{
109 type Output = Result<Response<B>, E>;
110
111 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
112 let this = self.project();
113
114 let mut response: Response<B> = ready!(this.inner.poll(cx))?;
115
116 match response.headers_mut().try_entry(SERVER_TIMING) {
117 Ok(entry) => match entry {
118 HeaderEntry::Occupied(mut val) => {
119 if let Ok(v) = (
120 this.app.with_suffix(";"),
121 this.description.with_prefix("desc=\"").with_suffix("\";"),
122 NumStr::new_default(this.request_time.elapsed().as_secs_f32() * 1000.0)
123 .set_resize_len::<1>()
124 .with_prefix("dur="),
125 val.get().to_str().with_prefix(", "),
126 )
127 .to_http_header_value()
128 {
129 val.insert(v);
130 } else {
131 }
133 }
134 HeaderEntry::Vacant(val) => {
135 if let Ok(v) = (
136 this.app.with_suffix(";"),
137 this.description.with_prefix("desc=\"").with_suffix("\";"),
138 NumStr::new_default(this.request_time.elapsed().as_secs_f32() * 1000.0)
139 .set_resize_len::<1>()
140 .with_prefix("dur="),
141 )
142 .to_http_header_value()
143 {
144 val.insert(v);
145 } else {
146 }
148 }
149 },
150 Err(_e) => {
151 #[cfg(feature = "feat-tracing")]
152 tracing::error!("Failed to add `server-timing` header: {_e:?}");
153 }
156 };
157
158 Poll::Ready(Ok(response))
159 }
160}
161
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use axum::{routing::get, Router};
    use http::{HeaderMap, HeaderValue};

    use super::ServerTimingLayer;

    /// Serves `app` on an OS-assigned localhost port in a background task and
    /// returns the base URL to request.
    ///
    /// Binding port 0 avoids the collisions that fixed ports cause when tests
    /// run in parallel or something else already holds the port.
    async fn spawn_server(app: Router) -> String {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind test listener");
        let addr = listener.local_addr().expect("read bound address");
        tokio::spawn(async move { axum::serve(listener, app.into_make_service()).await });
        format!("http://{addr}/")
    }

    #[test]
    fn service_name() {
        let name = "svc1";
        let obj = ServerTimingLayer::new(name);
        assert_eq!(obj.app, name);
    }

    #[test]
    fn service_desc() {
        let name = "svc1";
        let desc = "desc1";
        let obj = ServerTimingLayer::new(name).with_description(desc);
        assert_eq!(obj.app, name);
        assert_eq!(obj.description, Some(desc));
    }

    #[tokio::test]
    async fn axum_test() {
        let name = "svc1";
        let app = Router::new()
            .route(
                "/",
                get(|| async move {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    ""
                }),
            )
            .layer(ServerTimingLayer::new(name));

        let url = spawn_server(app).await;

        // The join result must be checked: a panicking assertion inside the
        // blocking task only surfaces as a `JoinError`, and discarding it
        // (e.g. `let _ = ...`) would let the test pass anyway.
        tokio::task::spawn_blocking(move || {
            let headers = minreq::get(url).send().unwrap().headers;

            let hdr = headers.get("server-timing");
            assert!(
                hdr.is_some(),
                "Cannot find `server-timing` from: {headers:#?}"
            );

            let dur = hdr
                .unwrap()
                .strip_prefix(&format!("{name};dur="))
                .unwrap_or_else(|| panic!("Invalid `server-timing` from: {headers:#?}"));
            let val: f32 = dur.parse().unwrap();
            assert!(
                (100f32..300f32).contains(&val),
                "Invalid `server-timing` from: {headers:#?}"
            );
        })
        .await
        .expect("client task panicked");
    }

    #[tokio::test]
    async fn support_existing_header() {
        let name = "svc1";
        let app = Router::new()
            .route(
                "/",
                get(|| async move {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    let mut hdr = HeaderMap::new();
                    hdr.insert("server-timing", HeaderValue::from_static("inner;dur=23"));
                    (hdr, "")
                }),
            )
            .layer(ServerTimingLayer::new(name).with_description("desc1"));

        let url = spawn_server(app).await;

        // See `axum_test`: propagate panics from the blocking client task.
        tokio::task::spawn_blocking(move || {
            let headers = minreq::get(url).send().unwrap().headers;

            let hdr = headers.get("server-timing").unwrap();
            assert!(hdr.contains(name));
            assert!(hdr.contains("inner"));
            println!("{hdr}");
        })
        .await
        .expect("client task panicked");
    }
}