Skip to main content

api_bones_tower/
lib.rs

1//! Tower middleware building blocks for api-bones services.
2//!
3//! Provides composable Tower [`Layer`](tower::Layer) / [`Service`](tower::Service)
4//! implementations for:
5//!
//! | Layer                | What it does                                                     |
//! |----------------------|------------------------------------------------------------------|
//! | [`RequestIdLayer`]   | Generates / propagates `X-Request-Id` on every request           |
//! | [`ProblemJsonLayer`] | Maps inner-service errors (any `Into<ApiError>`) to Problem+JSON |
10//!
11//! ## Feature flags
12//!
13//! By default this crate enables `std` and `serde` on `api-bones`.
14//! Additional `api-bones` features can be opted into:
15//!
16//! | Feature  | What it enables                              |
17//! |----------|----------------------------------------------|
18//! | `uuid`   | UUID-based request IDs (`api-bones/uuid`)    |
19//! | `chrono` | Chrono timestamp types (`api-bones/chrono`)  |
20//!
21//! # Example
22//!
23//! ```rust,no_run
24//! use api_bones_tower::{RequestIdLayer, ProblemJsonLayer};
25//! use tower::ServiceBuilder;
26//!
27//! let _svc = ServiceBuilder::new()
28//!     .layer(RequestIdLayer::new())
29//!     .layer(ProblemJsonLayer::new())
30//!     .service(tower::service_fn(|_req: http::Request<()>| async {
31//!         Ok::<_, std::convert::Infallible>(http::Response::new(()))
32//!     }));
33//! ```
34
35use std::future::Future;
36use std::pin::Pin;
37use std::sync::Arc;
38use std::sync::atomic::{AtomicU64, Ordering};
39use std::task::{Context, Poll};
40
41use api_bones::error::ApiError;
42use http::{Request, Response};
43use tower::{Layer, Service};
44
45// ---------------------------------------------------------------------------
46// RequestIdLayer
47// ---------------------------------------------------------------------------
48
49/// Tower [`Layer`] that ensures every request carries an `X-Request-Id` header.
50///
51/// - If the incoming request already has an `X-Request-Id`, it is forwarded
52///   unchanged.
53/// - Otherwise a monotonically-increasing numeric ID is generated and injected
54///   (format: `req-<n>`).
55///
56/// The same header value is echoed back in the response.
57///
58/// # Example
59///
60/// ```rust,no_run
61/// use api_bones_tower::RequestIdLayer;
62/// use tower::ServiceBuilder;
63///
64/// let _svc = ServiceBuilder::new()
65///     .layer(RequestIdLayer::new())
66///     .service(tower::service_fn(|_req: http::Request<()>| async {
67///         Ok::<_, std::convert::Infallible>(http::Response::new(()))
68///     }));
69/// ```
70#[derive(Clone, Debug)]
71pub struct RequestIdLayer {
72    counter: Arc<AtomicU64>,
73}
74
75impl RequestIdLayer {
76    /// Create a new `RequestIdLayer` with an internal counter starting at 1.
77    #[must_use]
78    pub fn new() -> Self {
79        Self {
80            counter: Arc::new(AtomicU64::new(1)),
81        }
82    }
83}
84
85impl Default for RequestIdLayer {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl<S> Layer<S> for RequestIdLayer {
92    type Service = RequestIdService<S>;
93
94    fn layer(&self, inner: S) -> Self::Service {
95        RequestIdService {
96            inner,
97            counter: Arc::clone(&self.counter),
98        }
99    }
100}
101
102/// Tower [`Service`] produced by [`RequestIdLayer`].
103#[derive(Clone, Debug)]
104pub struct RequestIdService<S> {
105    inner: S,
106    counter: Arc<AtomicU64>,
107}
108
109impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for RequestIdService<S>
110where
111    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
112    S::Future: Send,
113    S::Error: Send,
114    ReqBody: Send + 'static,
115    ResBody: Default + Send,
116{
117    type Response = Response<ResBody>;
118    type Error = S::Error;
119    type Future = RequestIdFuture<S::Future, ResBody>;
120
121    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
122        self.inner.poll_ready(cx)
123    }
124
125    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
126        // Determine (or generate) the request ID.
127        let request_id: String = if let Some(existing) = req.headers().get("x-request-id") {
128            existing.to_str().unwrap_or("invalid").to_owned()
129        } else {
130            let n = self.counter.fetch_add(1, Ordering::Relaxed);
131            let id = format!("req-{n}");
132            if let Ok(val) = http::HeaderValue::from_str(&id) {
133                req.headers_mut().insert("x-request-id", val);
134            }
135            id
136        };
137
138        let future = self.inner.call(req);
139        RequestIdFuture {
140            inner: future,
141            request_id,
142            _body: std::marker::PhantomData,
143        }
144    }
145}
146
147/// Future returned by [`RequestIdService`].
148#[pin_project::pin_project]
149pub struct RequestIdFuture<F, ResBody> {
150    #[pin]
151    inner: F,
152    request_id: String,
153    _body: std::marker::PhantomData<ResBody>,
154}
155
156impl<F, ResBody, E> Future for RequestIdFuture<F, ResBody>
157where
158    F: Future<Output = Result<Response<ResBody>, E>>,
159{
160    type Output = Result<Response<ResBody>, E>;
161
162    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
163        let this = self.project();
164        match this.inner.poll(cx) {
165            Poll::Pending => Poll::Pending,
166            Poll::Ready(Ok(mut resp)) => {
167                if let Ok(val) = http::HeaderValue::from_str(this.request_id) {
168                    resp.headers_mut().entry("x-request-id").or_insert(val);
169                }
170                Poll::Ready(Ok(resp))
171            }
172            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
173        }
174    }
175}
176
177// ---------------------------------------------------------------------------
178// ProblemJsonLayer
179// ---------------------------------------------------------------------------
180
181/// Tower [`Layer`] that maps inner-service errors into Problem+JSON HTTP
182/// responses.
183///
184/// Any `Err` propagated from the inner service is converted to an [`ApiError`]
185/// via the [`Into<ApiError>`] bound and then serialized as
186/// `application/problem+json`.
187///
188/// Successful responses are passed through unchanged.
189///
190/// # Example
191///
192/// ```rust,no_run
193/// use api_bones_tower::ProblemJsonLayer;
194/// use tower::ServiceBuilder;
195///
196/// let _svc = ServiceBuilder::new()
197///     .layer(ProblemJsonLayer::new())
198///     .service(tower::service_fn(|_req: http::Request<()>| async {
199///         Ok::<_, api_bones::ApiError>(http::Response::new(String::new()))
200///     }));
201/// ```
202#[derive(Clone, Debug, Default)]
203pub struct ProblemJsonLayer;
204
205impl ProblemJsonLayer {
206    /// Create a new `ProblemJsonLayer`.
207    #[must_use]
208    pub fn new() -> Self {
209        Self
210    }
211}
212
213impl<S> Layer<S> for ProblemJsonLayer {
214    type Service = ProblemJsonService<S>;
215
216    fn layer(&self, inner: S) -> Self::Service {
217        ProblemJsonService { inner }
218    }
219}
220
221/// Tower [`Service`] produced by [`ProblemJsonLayer`].
222#[derive(Clone, Debug)]
223pub struct ProblemJsonService<S> {
224    inner: S,
225}
226
227impl<S, ReqBody> Service<Request<ReqBody>> for ProblemJsonService<S>
228where
229    S: Service<Request<ReqBody>, Response = Response<String>> + Clone + Send + 'static,
230    S::Error: Into<ApiError> + Send,
231    S::Future: Send,
232    ReqBody: Send + 'static,
233{
234    type Response = Response<String>;
235    type Error = std::convert::Infallible;
236    type Future =
237        Pin<Box<dyn Future<Output = Result<Response<String>, std::convert::Infallible>> + Send>>;
238
239    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
240        match self.inner.poll_ready(cx) {
241            Poll::Pending => Poll::Pending,
242            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
243            Poll::Ready(Err(_e)) => unreachable!("inner service poll_ready returned Err"),
244        }
245    }
246
247    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
248        let future = self.inner.call(req);
249        Box::pin(async move {
250            match future.await {
251                Ok(resp) => Ok(resp),
252                Err(e) => {
253                    let api_err: ApiError = e.into();
254                    Ok(api_error_to_response(api_err))
255                }
256            }
257        })
258    }
259}
260
261/// Convert an [`ApiError`] into an HTTP response with `application/problem+json`.
262fn api_error_to_response(err: ApiError) -> Response<String> {
263    use api_bones::error::ProblemJson;
264
265    let status = err.status;
266    let problem = ProblemJson::from(err);
267    let body = serde_json::to_string(&problem).expect("ProblemJson serialization is infallible");
268
269    let status_code =
270        http::StatusCode::from_u16(status).unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR);
271
272    Response::builder()
273        .status(status_code)
274        .header("content-type", "application/problem+json")
275        .body(body)
276        .expect("response construction is infallible for valid status codes")
277}
278
279// ---------------------------------------------------------------------------
280// Tests
281// ---------------------------------------------------------------------------
282
#[cfg(test)]
mod tests {
    use super::*;
    use tower::{ServiceBuilder, ServiceExt};

    /// Reads the `x-request-id` header of `resp` as a string slice.
    ///
    /// Panics (failing the calling test) when the header is missing or is not
    /// visible ASCII.
    fn response_request_id<B>(resp: &Response<B>) -> &str {
        resp.headers()
            .get("x-request-id")
            .expect("response should carry x-request-id")
            .to_str()
            .expect("x-request-id should be visible ASCII")
    }

    #[tokio::test]
    async fn request_id_layer_injects_header() {
        let svc = ServiceBuilder::new()
            .layer(RequestIdLayer::new())
            .service(tower::service_fn(|req: Request<()>| async move {
                let id = req
                    .headers()
                    .get("x-request-id")
                    .and_then(|v| v.to_str().ok())
                    .unwrap_or("")
                    .to_owned();
                let resp = Response::new(id);
                Ok::<_, std::convert::Infallible>(resp)
            }));

        let req = Request::builder().uri("/").body(()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        // The first generated ID is `req-1`, and the inner service saw the same
        // value that is echoed on the response.
        assert_eq!(response_request_id(&resp), "req-1");
        assert_eq!(resp.body(), "req-1");
    }

    #[tokio::test]
    async fn request_id_layer_generates_sequential_ids() {
        let svc = ServiceBuilder::new()
            .layer(RequestIdLayer::new())
            .service(tower::service_fn(|_req: Request<()>| async move {
                Ok::<_, std::convert::Infallible>(Response::new(String::new()))
            }));

        // Clones share the layer's counter, so IDs keep increasing across them.
        for expected in ["req-1", "req-2"] {
            let req = Request::builder().uri("/").body(()).unwrap();
            let resp = svc.clone().oneshot(req).await.unwrap();
            assert_eq!(response_request_id(&resp), expected);
        }
    }

    #[tokio::test]
    async fn request_id_layer_preserves_existing_header() {
        let svc = ServiceBuilder::new()
            .layer(RequestIdLayer::new())
            .service(tower::service_fn(|_req: Request<()>| async move {
                Ok::<_, std::convert::Infallible>(Response::new(String::new()))
            }));

        let req = Request::builder()
            .uri("/")
            .header("x-request-id", "client-id")
            .body(())
            .unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(response_request_id(&resp), "client-id");
    }

    #[tokio::test]
    async fn problem_json_layer_maps_error() {
        let svc = ServiceBuilder::new()
            .layer(ProblemJsonLayer::new())
            .service(tower::service_fn(|_req: Request<()>| async move {
                Err::<Response<String>, ApiError>(ApiError::not_found("item 1"))
            }));

        let req = Request::builder().uri("/").body(()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status().as_u16(), 404);
        assert_eq!(
            resp.headers()
                .get("content-type")
                .unwrap()
                .to_str()
                .unwrap(),
            "application/problem+json"
        );
    }

    #[tokio::test]
    async fn problem_json_layer_passes_through_ok() {
        let svc = ServiceBuilder::new()
            .layer(ProblemJsonLayer::new())
            .service(tower::service_fn(|_req: Request<()>| async move {
                Ok::<_, ApiError>(
                    Response::builder()
                        .status(200)
                        .body("ok".to_owned())
                        .unwrap(),
                )
            }));

        let req = Request::builder().uri("/").body(()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status().as_u16(), 200);
        assert_eq!(resp.body(), "ok");
    }

    #[tokio::test]
    async fn request_id_layer_default_is_same_as_new() {
        // `default()` must start the counter where `new()` does: at 1.
        let svc = ServiceBuilder::new()
            .layer(RequestIdLayer::default())
            .service(tower::service_fn(|_req: Request<()>| async move {
                Ok::<_, std::convert::Infallible>(Response::new(String::new()))
            }));

        let req = Request::builder().uri("/").body(()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(response_request_id(&resp), "req-1");
    }

    #[tokio::test]
    async fn problem_json_service_poll_ready() {
        let inner = tower::service_fn(|_req: Request<()>| async move {
            Ok::<_, ApiError>(Response::builder().body("ok".to_owned()).unwrap())
        });
        let mut svc = ProblemJsonService { inner };
        let svc_ref = svc.ready().await.unwrap();
        let req = Request::builder().uri("/").body(()).unwrap();
        let resp = svc_ref.call(req).await.unwrap();
        assert_eq!(resp.status().as_u16(), 200);
    }

    #[tokio::test]
    async fn request_id_future_propagates_inner_error() {
        let svc = ServiceBuilder::new()
            .layer(RequestIdLayer::new())
            .service(tower::service_fn(|_req: Request<()>| async move {
                Err::<Response<String>, ApiError>(ApiError::internal("boom"))
            }));

        let req = Request::builder().uri("/").body(()).unwrap();
        let result = svc.oneshot(req).await;
        let err = result.unwrap_err();
        assert_eq!(err.status, 500);
    }

    #[tokio::test]
    async fn request_id_future_poll_pending() {
        use std::sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        };

        let ready = Arc::new(AtomicBool::new(false));
        let ready2 = Arc::clone(&ready);

        // The inner future yields once, so `RequestIdFuture` observes Pending
        // before the response is ready.
        let inner = tower::service_fn(move |_req: Request<()>| {
            let flag = Arc::clone(&ready2);
            async move {
                tokio::task::yield_now().await;
                flag.store(true, Ordering::SeqCst);
                Ok::<Response<String>, std::convert::Infallible>(
                    Response::builder().body(String::new()).unwrap(),
                )
            }
        });

        let layer = RequestIdLayer::new();
        let mut svc = layer.layer(inner);

        let req = Request::builder().uri("/").body(()).unwrap();
        let fut = tower::Service::call(&mut svc, req);
        let resp = fut.await.unwrap();
        assert!(resp.headers().contains_key("x-request-id"));
        assert!(ready.load(Ordering::SeqCst));
    }
}