1use std::future::Future;
36use std::pin::Pin;
37use std::sync::Arc;
38use std::sync::atomic::{AtomicU64, Ordering};
39use std::task::{Context, Poll};
40
41use api_bones::error::ApiError;
42use http::{Request, Response};
43use tower::{Layer, Service};
44
/// Tower [`Layer`] that makes sure every request carries an `x-request-id` header.
///
/// Requests without the header receive a generated id of the form `req-<n>`. The id is also
/// mirrored onto the response when the inner service did not set one itself.
///
/// The counter lives behind an [`Arc`], so cloning the layer shares it: every service built
/// from this layer or any of its clones draws ids from one sequence.
#[derive(Debug, Clone)]
pub struct RequestIdLayer {
    /// Source of the `<n>` in generated `req-<n>` ids; shared with every service built here.
    counter: Arc<AtomicU64>,
}
74
75impl RequestIdLayer {
76 #[must_use]
78 pub fn new() -> Self {
79 Self {
80 counter: Arc::new(AtomicU64::new(1)),
81 }
82 }
83}
84
85impl Default for RequestIdLayer {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl<S> Layer<S> for RequestIdLayer {
92 type Service = RequestIdService<S>;
93
94 fn layer(&self, inner: S) -> Self::Service {
95 RequestIdService {
96 inner,
97 counter: Arc::clone(&self.counter),
98 }
99 }
100}
101
/// Service produced by [`RequestIdLayer`]; forwards requests to `inner` after tagging them
/// with an `x-request-id`.
#[derive(Debug, Clone)]
pub struct RequestIdService<S> {
    /// Wrapped service that receives the tagged request.
    inner: S,
    /// Id counter shared with the layer that built this service.
    counter: Arc<AtomicU64>,
}
108
109impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for RequestIdService<S>
110where
111 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
112 S::Future: Send,
113 S::Error: Send,
114 ReqBody: Send + 'static,
115 ResBody: Default + Send,
116{
117 type Response = Response<ResBody>;
118 type Error = S::Error;
119 type Future = RequestIdFuture<S::Future, ResBody>;
120
121 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
122 self.inner.poll_ready(cx)
123 }
124
125 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
126 let request_id: String = if let Some(existing) = req.headers().get("x-request-id") {
128 existing.to_str().unwrap_or("invalid").to_owned()
129 } else {
130 let n = self.counter.fetch_add(1, Ordering::Relaxed);
131 let id = format!("req-{n}");
132 if let Ok(val) = http::HeaderValue::from_str(&id) {
133 req.headers_mut().insert("x-request-id", val);
134 }
135 id
136 };
137
138 let future = self.inner.call(req);
139 RequestIdFuture {
140 inner: future,
141 request_id,
142 _body: std::marker::PhantomData,
143 }
144 }
145}
146
147#[pin_project::pin_project]
149pub struct RequestIdFuture<F, ResBody> {
150 #[pin]
151 inner: F,
152 request_id: String,
153 _body: std::marker::PhantomData<ResBody>,
154}
155
156impl<F, ResBody, E> Future for RequestIdFuture<F, ResBody>
157where
158 F: Future<Output = Result<Response<ResBody>, E>>,
159{
160 type Output = Result<Response<ResBody>, E>;
161
162 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
163 let this = self.project();
164 match this.inner.poll(cx) {
165 Poll::Pending => Poll::Pending,
166 Poll::Ready(Ok(mut resp)) => {
167 if let Ok(val) = http::HeaderValue::from_str(this.request_id) {
168 resp.headers_mut().entry("x-request-id").or_insert(val);
169 }
170 Poll::Ready(Ok(resp))
171 }
172 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
173 }
174 }
175}
176
/// Tower [`Layer`] that turns inner-service errors into `application/problem+json` responses.
///
/// The layer is a unit struct with no configuration, so constructing or cloning it is free.
#[derive(Debug, Default, Clone)]
pub struct ProblemJsonLayer;
204
205impl ProblemJsonLayer {
206 #[must_use]
208 pub fn new() -> Self {
209 Self
210 }
211}
212
213impl<S> Layer<S> for ProblemJsonLayer {
214 type Service = ProblemJsonService<S>;
215
216 fn layer(&self, inner: S) -> Self::Service {
217 ProblemJsonService { inner }
218 }
219}
220
/// Service produced by [`ProblemJsonLayer`]; forwards requests to `inner` and maps its errors
/// into problem+json responses.
#[derive(Debug, Clone)]
pub struct ProblemJsonService<S> {
    /// Wrapped service whose errors get mapped.
    inner: S,
}
226
227impl<S, ReqBody> Service<Request<ReqBody>> for ProblemJsonService<S>
228where
229 S: Service<Request<ReqBody>, Response = Response<String>> + Clone + Send + 'static,
230 S::Error: Into<ApiError> + Send,
231 S::Future: Send,
232 ReqBody: Send + 'static,
233{
234 type Response = Response<String>;
235 type Error = std::convert::Infallible;
236 type Future =
237 Pin<Box<dyn Future<Output = Result<Response<String>, std::convert::Infallible>> + Send>>;
238
239 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
240 match self.inner.poll_ready(cx) {
241 Poll::Pending => Poll::Pending,
242 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
243 Poll::Ready(Err(_e)) => unreachable!("inner service poll_ready returned Err"),
244 }
245 }
246
247 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
248 let future = self.inner.call(req);
249 Box::pin(async move {
250 match future.await {
251 Ok(resp) => Ok(resp),
252 Err(e) => {
253 let api_err: ApiError = e.into();
254 Ok(api_error_to_response(api_err))
255 }
256 }
257 })
258 }
259}
260
261fn api_error_to_response(err: ApiError) -> Response<String> {
263 use api_bones::error::ProblemJson;
264
265 let status = err.status;
266 let problem = ProblemJson::from(err);
267 let body = serde_json::to_string(&problem).expect("ProblemJson serialization is infallible");
268
269 let status_code =
270 http::StatusCode::from_u16(status).unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR);
271
272 Response::builder()
273 .status(status_code)
274 .header("content-type", "application/problem+json")
275 .body(body)
276 .expect("response construction is infallible for valid status codes")
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use tower::{ServiceBuilder, ServiceExt};

    /// Sends one `GET /` through `layer`, wrapped around a handler that returns an empty body.
    /// Returns the `x-request-id` value found on the response.
    async fn response_request_id(layer: RequestIdLayer) -> String {
        let svc = ServiceBuilder::new()
            .layer(layer)
            .service(tower::service_fn(|_req: Request<()>| async move {
                Ok::<_, std::convert::Infallible>(Response::new(String::new()))
            }));

        let req = Request::builder().uri("/").body(()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        resp.headers()
            .get("x-request-id")
            .expect("response should carry x-request-id")
            .to_str()
            .unwrap()
            .to_owned()
    }

    #[tokio::test]
    async fn request_id_layer_injects_header() {
        let svc = ServiceBuilder::new()
            .layer(RequestIdLayer::new())
            .service(tower::service_fn(|req: Request<()>| async move {
                // Echo the id the handler observed so it can be compared with the response.
                let id = req
                    .headers()
                    .get("x-request-id")
                    .and_then(|v| v.to_str().ok())
                    .unwrap_or("")
                    .to_owned();
                Ok::<_, std::convert::Infallible>(Response::new(id))
            }));

        let req = Request::builder().uri("/").body(()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        let header = resp
            .headers()
            .get("x-request-id")
            .expect("response should carry x-request-id")
            .to_str()
            .unwrap()
            .to_owned();
        assert!(header.starts_with("req-"));
        // The handler must have seen the same id that is reported back to the client.
        assert_eq!(resp.body(), &header);
    }

    #[tokio::test]
    async fn request_id_layer_preserves_existing_header() {
        let svc = ServiceBuilder::new()
            .layer(RequestIdLayer::new())
            .service(tower::service_fn(|_req: Request<()>| async move {
                Ok::<_, std::convert::Infallible>(Response::new(String::new()))
            }));

        let req = Request::builder()
            .uri("/")
            .header("x-request-id", "client-id")
            .body(())
            .unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(
            resp.headers()
                .get("x-request-id")
                .unwrap()
                .to_str()
                .unwrap(),
            "client-id"
        );
    }

    #[tokio::test]
    async fn problem_json_layer_maps_error() {
        let svc = ServiceBuilder::new()
            .layer(ProblemJsonLayer::new())
            .service(tower::service_fn(|_req: Request<()>| async move {
                Err::<Response<String>, ApiError>(ApiError::not_found("item 1"))
            }));

        let req = Request::builder().uri("/").body(()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status().as_u16(), 404);
        assert_eq!(
            resp.headers()
                .get("content-type")
                .unwrap()
                .to_str()
                .unwrap(),
            "application/problem+json"
        );
    }

    #[tokio::test]
    async fn problem_json_layer_passes_through_ok() {
        let svc = ServiceBuilder::new()
            .layer(ProblemJsonLayer::new())
            .service(tower::service_fn(|_req: Request<()>| async move {
                Ok::<_, ApiError>(
                    Response::builder()
                        .status(200)
                        .body("ok".to_owned())
                        .unwrap(),
                )
            }));

        let req = Request::builder().uri("/").body(()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status().as_u16(), 200);
    }

    #[tokio::test]
    async fn request_id_layer_default_is_same_as_new() {
        // Both constructors must start a fresh sequence at 1.
        assert_eq!(response_request_id(RequestIdLayer::default()).await, "req-1");
        assert_eq!(response_request_id(RequestIdLayer::new()).await, "req-1");
    }

    #[tokio::test]
    async fn request_id_layer_clones_share_counter() {
        let layer = RequestIdLayer::new();
        assert_eq!(response_request_id(layer.clone()).await, "req-1");
        assert_eq!(response_request_id(layer).await, "req-2");
    }

    #[tokio::test]
    async fn problem_json_service_poll_ready() {
        use tower::{Service, ServiceExt};

        let inner = tower::service_fn(|_req: Request<()>| async move {
            Ok::<_, ApiError>(Response::builder().body("ok".to_owned()).unwrap())
        });
        let mut svc = ProblemJsonService { inner };
        let svc_ref = svc.ready().await.unwrap();
        let req = Request::builder().uri("/").body(()).unwrap();
        let resp = svc_ref.call(req).await.unwrap();
        assert_eq!(resp.status().as_u16(), 200);
    }

    #[tokio::test]
    async fn request_id_future_propagates_inner_error() {
        let svc = ServiceBuilder::new()
            .layer(RequestIdLayer::new())
            .service(tower::service_fn(|_req: Request<()>| async move {
                Err::<Response<String>, ApiError>(ApiError::internal("boom"))
            }));

        let req = Request::builder().uri("/").body(()).unwrap();
        let result = svc.oneshot(req).await;
        let err = result.unwrap_err();
        assert_eq!(err.status, 500);
    }

    #[tokio::test]
    async fn request_id_future_poll_pending() {
        use std::sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        };

        let ready = Arc::new(AtomicBool::new(false));
        let ready2 = Arc::clone(&ready);

        let inner = tower::service_fn(move |_req: Request<()>| {
            let flag = Arc::clone(&ready2);
            async move {
                // Yield once so the wrapping future observes `Poll::Pending` before completion.
                tokio::task::yield_now().await;
                flag.store(true, Ordering::SeqCst);
                Ok::<Response<String>, std::convert::Infallible>(
                    Response::builder().body(String::new()).unwrap(),
                )
            }
        });

        let layer = RequestIdLayer::new();
        let mut svc = layer.layer(inner);

        let req = Request::builder().uri("/").body(()).unwrap();
        let fut = tower::Service::call(&mut svc, req);
        let resp = fut.await.unwrap();
        assert!(resp.headers().contains_key("x-request-id"));
        assert!(ready.load(Ordering::SeqCst));
    }
}