vidi_tower/
lib.rs

1//! An adapter that makes a tower [`Service`] into a [`Handler`].
2
3use tower::{Service, ServiceExt};
4use vidi_core::{Body, BoxError, Bytes, Error, Handler, HttpBody, Request, Response, Result};
5
6mod service;
7pub use service::HandlerService;
8
9mod middleware;
10pub use middleware::Middleware;
11
12mod layer;
13pub use layer::Layered;
14
/// Adapts a tower [`Service`] so that it can be used wherever a [`Handler`] is expected.
///
/// The wrapper is a transparent newtype: cloning or debug-printing it clones or
/// prints the wrapped service.
#[derive(Clone, Debug)]
pub struct ServiceHandler<S>(S);

impl<S> ServiceHandler<S> {
    /// Wraps `service` in a [`ServiceHandler`].
    ///
    /// This is a `const fn`, so the adapter can be built in constant contexts.
    pub const fn new(service: S) -> Self {
        ServiceHandler(service)
    }
}
25
26#[vidi_core::async_trait]
27impl<O, S> Handler<Request> for ServiceHandler<S>
28where
29    O: HttpBody + Send + 'static,
30    O::Data: Into<Bytes>,
31    O::Error: Into<BoxError>,
32    S: Service<Request, Response = Response<O>> + Send + Sync + Clone + 'static,
33    S::Future: Send,
34    S::Error: Into<BoxError>,
35{
36    type Output = Result<Response>;
37
38    async fn call(&self, req: Request) -> Self::Output {
39        self.0
40            .clone()
41            .oneshot(req)
42            .await
43            .map_err(Error::boxed)
44            .map(|resp| resp.map(Body::wrap))
45    }
46}
47
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::{
            Arc,
            atomic::{AtomicU64, Ordering},
        },
        time::Duration,
    };
    use tower::util::{MapErrLayer, MapRequestLayer, MapResponseLayer};
    use tower::{ServiceBuilder, service_fn};
    use tower_http::{
        limit::RequestBodyLimitLayer,
        request_id::{MakeRequestId, RequestId, SetRequestIdLayer},
        timeout::TimeoutLayer,
    };
    use vidi_core::{
        Body, BoxHandler, Handler, HandlerExt, IntoResponse, Request, RequestExt, Response,
        StatusCode,
    };

    /// Generates sequential request ids ("0", "1", ...) from a shared counter.
    #[derive(Clone, Debug, Default)]
    struct MyMakeRequestId {
        counter: Arc<AtomicU64>,
    }

    impl MakeRequestId for MyMakeRequestId {
        fn make_request_id<B>(&mut self, _: &Request<B>) -> Option<RequestId> {
            let request_id = self
                .counter
                .fetch_add(1, Ordering::SeqCst)
                .to_string()
                .parse()
                .expect("a decimal integer is always a valid header value");

            Some(RequestId::new(request_id))
        }
    }

    /// Reads the whole request body and echoes it back as the response.
    async fn hello(mut req: Request) -> Result<Response> {
        let bytes = req.bytes().await?;
        Ok(bytes.into_response())
    }

    #[tokio::test]
    async fn tower_service_into_handler() {
        let hello_svc = service_fn(hello);

        // Layers listed first are outermost: the 1-byte body limit applies before
        // the request reaches `hello`.
        let svc = ServiceBuilder::new()
            .layer(RequestBodyLimitLayer::new(1))
            .layer(MapErrLayer::new(Error::from))
            .layer(SetRequestIdLayer::x_request_id(MyMakeRequestId::default()))
            .layer(MapResponseLayer::new(IntoResponse::into_response))
            .layer(MapRequestLayer::new(|req: Request<_>| req.map(Body::wrap)))
            .layer(TimeoutLayer::with_status_code(
                StatusCode::REQUEST_TIMEOUT,
                Duration::from_secs(10),
            ))
            .service(hello_svc);

        // A 2-byte body exceeds the 1-byte limit, so reading it must fail.
        let r0 = Request::new(Body::Full("12".into()));
        let h0 = ServiceHandler::new(svc);
        assert!(h0.call(r0).await.is_err());

        // A 1-byte body fits within the limit; it must also work through a boxed handler.
        let r1 = Request::new(Body::Full("1".into()));
        let b0: BoxHandler = h0.boxed();
        let Ok(resp) = b0.call(r1).await else {
            panic!("a body within the limit must be accepted");
        };
        assert_eq!(resp.status(), StatusCode::OK);
    }
}