rama_core/service/svc.rs

//! The [`Service`] trait and the type-erased [`BoxService`] wrapper.
2
3use crate::Context;
4use crate::error::BoxError;
5use std::pin::Pin;
6
7/// A [`Service`] that produces rama services,
8/// to serve requests with, be it transport layer requests or application layer requests.
9pub trait Service<S, Request>: Sized + Send + Sync + 'static {
10    /// The type of response returned by the service.
11    type Response: Send + 'static;
12
13    /// The type of error returned by the service.
14    type Error: Send + 'static;
15
16    /// Serve a response or error for the given request,
17    /// using the given context.
18    fn serve(
19        &self,
20        ctx: Context<S>,
21        req: Request,
22    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_;
23
24    /// Box this service to allow for dynamic dispatch.
25    fn boxed(self) -> BoxService<S, Request, Self::Response, Self::Error> {
26        BoxService {
27            inner: Box::new(self),
28        }
29    }
30}
31
32impl<S, State, Request> Service<State, Request> for std::sync::Arc<S>
33where
34    S: Service<State, Request>,
35{
36    type Response = S::Response;
37    type Error = S::Error;
38
39    #[inline]
40    fn serve(
41        &self,
42        ctx: Context<State>,
43        req: Request,
44    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
45        self.as_ref().serve(ctx, req)
46    }
47}
48
49impl<S, State, Request> Service<State, Request> for &'static S
50where
51    S: Service<State, Request>,
52{
53    type Response = S::Response;
54    type Error = S::Error;
55
56    #[inline]
57    fn serve(
58        &self,
59        ctx: Context<State>,
60        req: Request,
61    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
62        (**self).serve(ctx, req)
63    }
64}
65
66impl<S, State, Request> Service<State, Request> for Box<S>
67where
68    S: Service<State, Request>,
69{
70    type Response = S::Response;
71    type Error = S::Error;
72
73    #[inline]
74    fn serve(
75        &self,
76        ctx: Context<State>,
77        req: Request,
78    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
79        self.as_ref().serve(ctx, req)
80    }
81}
82
83/// Internal trait for dynamic dispatch of Async Traits,
84/// implemented according to the pioneers of this Design Pattern
85/// found at <https://rust-lang.github.io/async-fundamentals-initiative/evaluation/case-studies/builder-provider-api.html#dynamic-dispatch-behind-the-api>
86/// and widely published at <https://blog.rust-lang.org/inside-rust/2023/05/03/stabilizing-async-fn-in-trait.html>.
87trait DynService<S, Request> {
88    type Response;
89    type Error;
90
91    #[allow(clippy::type_complexity)]
92    fn serve_box(
93        &self,
94        ctx: Context<S>,
95        req: Request,
96    ) -> Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + '_>>;
97}
98
99impl<S, Request, T> DynService<S, Request> for T
100where
101    T: Service<S, Request>,
102{
103    type Response = T::Response;
104    type Error = T::Error;
105
106    fn serve_box(
107        &self,
108        ctx: Context<S>,
109        req: Request,
110    ) -> Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + '_>> {
111        Box::pin(self.serve(ctx, req))
112    }
113}
114
115/// A boxed [`Service`], to serve requests with,
116/// for where you require dynamic dispatch.
117pub struct BoxService<S, Request, Response, Error> {
118    inner:
119        Box<dyn DynService<S, Request, Response = Response, Error = Error> + Send + Sync + 'static>,
120}
121
122impl<S, Request, Response, Error> BoxService<S, Request, Response, Error> {
123    /// Create a new [`BoxService`] from the given service.
124    pub fn new<T>(service: T) -> Self
125    where
126        T: Service<S, Request, Response = Response, Error = Error>,
127    {
128        Self {
129            inner: Box::new(service),
130        }
131    }
132}
133
134impl<S, Request, Response, Error> std::fmt::Debug for BoxService<S, Request, Response, Error> {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        f.debug_struct("BoxService").finish()
137    }
138}
139
140impl<S, Request, Response, Error> Service<S, Request> for BoxService<S, Request, Response, Error>
141where
142    S: 'static,
143    Request: 'static,
144    Response: Send + 'static,
145    Error: Send + 'static,
146{
147    type Response = Response;
148    type Error = Error;
149
150    fn serve(
151        &self,
152        ctx: Context<S>,
153        req: Request,
154    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
155        self.inner.serve_box(ctx, req)
156    }
157}
158
159macro_rules! impl_service_either {
160    ($id:ident, $($param:ident),+ $(,)?) => {
161        impl<$($param),+, State, Request, Response> Service<State, Request> for crate::combinators::$id<$($param),+>
162        where
163            $(
164                $param: Service<State, Request, Response = Response, Error: Into<BoxError>>,
165            )+
166            Request: Send + 'static,
167            State: Clone + Send + Sync + 'static,
168            Response: Send + 'static,
169        {
170            type Response = Response;
171            type Error = BoxError;
172
173            async fn serve(&self, ctx: Context<State>, req: Request) -> Result<Self::Response, Self::Error> {
174                match self {
175                    $(
176                        crate::combinators::$id::$param(s) => s.serve(ctx, req).await.map_err(Into::into),
177                    )+
178                }
179            }
180        }
181    };
182}
183
184crate::combinators::impl_either!(impl_service_either);
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;

    #[derive(Debug)]
    struct AddSvc(usize);

    impl Service<(), usize> for AddSvc {
        type Response = usize;
        type Error = Infallible;

        async fn serve(
            &self,
            _ctx: Context<()>,
            req: usize,
        ) -> Result<Self::Response, Self::Error> {
            Ok(self.0 + req)
        }
    }

    #[derive(Debug)]
    struct MulSvc(usize);

    impl Service<(), usize> for MulSvc {
        type Response = usize;
        type Error = Infallible;

        async fn serve(
            &self,
            _ctx: Context<()>,
            req: usize,
        ) -> Result<Self::Response, Self::Error> {
            Ok(self.0 * req)
        }
    }

    #[test]
    fn assert_send() {
        use rama_utils::test_helpers::*;

        assert_send::<AddSvc>();
        assert_send::<MulSvc>();
        assert_send::<BoxService<(), (), (), ()>>();
    }

    #[test]
    fn assert_sync() {
        use rama_utils::test_helpers::*;

        assert_sync::<AddSvc>();
        assert_sync::<MulSvc>();
        assert_sync::<BoxService<(), (), (), ()>>();
    }

    #[tokio::test]
    async fn add_svc() {
        let svc = AddSvc(1);

        let ctx = Context::default();

        let response = svc.serve(ctx, 1).await.unwrap();
        assert_eq!(response, 2);
    }

    #[tokio::test]
    async fn static_dispatch() {
        let services = vec![AddSvc(1), AddSvc(2), AddSvc(3)];

        let ctx = Context::default();

        for (i, svc) in services.into_iter().enumerate() {
            let response = svc.serve(ctx.clone(), i).await.unwrap();
            assert_eq!(response, i * 2 + 1);
        }
    }

    #[tokio::test]
    async fn dynamic_dispatch() {
        let services = vec![
            AddSvc(1).boxed(),
            AddSvc(2).boxed(),
            AddSvc(3).boxed(),
            MulSvc(4).boxed(),
            MulSvc(5).boxed(),
        ];

        let ctx = Context::default();

        for (i, svc) in services.into_iter().enumerate() {
            let response = svc.serve(ctx.clone(), i).await.unwrap();
            if i < 3 {
                assert_eq!(response, i * 2 + 1);
            } else {
                assert_eq!(response, i * (i + 1));
            }
        }
    }

    #[tokio::test]
    async fn service_arc() {
        let svc = std::sync::Arc::new(AddSvc(1));

        let ctx = Context::default();

        let response = svc.serve(ctx, 1).await.unwrap();
        assert_eq!(response, 2);
    }

    #[tokio::test]
    async fn box_service_arc() {
        let svc = std::sync::Arc::new(AddSvc(1)).boxed();

        let ctx = Context::default();

        let response = svc.serve(ctx, 1).await.unwrap();
        assert_eq!(response, 2);
    }

    #[tokio::test]
    async fn service_box() {
        let svc = Box::new(MulSvc(3));

        let ctx: Context<()> = Context::default();

        // UFCS with `&Box<MulSvc>` pins `Self = Box<MulSvc>`, exercising the `Box` impl.
        let response = Service::serve(&svc, ctx, 2usize).await.unwrap();
        assert_eq!(response, 6);
    }

    #[tokio::test]
    async fn service_static_ref() {
        static SVC: AddSvc = AddSvc(10);
        let svc: &'static AddSvc = &SVC;

        let ctx: Context<()> = Context::default();

        // `&svc` is `&&'static AddSvc`, so `Self = &'static AddSvc` (the reference impl).
        let response = Service::serve(&svc, ctx, 5usize).await.unwrap();
        assert_eq!(response, 15);
    }

    #[tokio::test]
    async fn box_service_new() {
        let svc = BoxService::new(MulSvc(2));

        let ctx = Context::default();

        let response = svc.serve(ctx, 21).await.unwrap();
        assert_eq!(response, 42);
    }

    #[test]
    fn box_service_debug() {
        let svc = AddSvc(1).boxed();
        assert_eq!(format!("{svc:?}"), "BoxService");
    }
}