1use crate::Context;
4use crate::error::BoxError;
5use std::fmt;
6use std::marker::PhantomData;
7use std::pin::Pin;
8use std::sync::Arc;
9
/// An asynchronous request handler, generic over the state `S` carried by the
/// [`Context`] and over the `Request` type it accepts.
///
/// Implementors must be `Sized + Send + Sync + 'static`.
pub trait Service<S, Request>: Sized + Send + Sync + 'static {
    /// Value produced when [`Service::serve`] succeeds.
    type Response: Send + 'static;

    /// Value produced when [`Service::serve`] fails.
    type Error: Send + 'static;

    /// Handles `req` using the given `ctx`.
    ///
    /// The returned future is `Send` and may borrow `self` (the `'_` bound),
    /// and resolves to either a [`Service::Response`] or a [`Service::Error`].
    fn serve(
        &self,
        ctx: Context<S>,
        req: Request,
    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_;

    /// Consumes the service and wraps it in a [`BoxService`], erasing its
    /// concrete type while keeping the same response and error types.
    fn boxed(self) -> BoxService<S, Request, Self::Response, Self::Error> {
        BoxService::new(self)
    }
}
32
33impl<S, State, Request> Service<State, Request> for std::sync::Arc<S>
34where
35    S: Service<State, Request>,
36{
37    type Response = S::Response;
38    type Error = S::Error;
39
40    #[inline]
41    fn serve(
42        &self,
43        ctx: Context<State>,
44        req: Request,
45    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
46        self.as_ref().serve(ctx, req)
47    }
48}
49
50impl<S, State, Request> Service<State, Request> for &'static S
51where
52    S: Service<State, Request>,
53{
54    type Response = S::Response;
55    type Error = S::Error;
56
57    #[inline]
58    fn serve(
59        &self,
60        ctx: Context<State>,
61        req: Request,
62    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
63        (**self).serve(ctx, req)
64    }
65}
66
67impl<S, State, Request> Service<State, Request> for Box<S>
68where
69    S: Service<State, Request>,
70{
71    type Response = S::Response;
72    type Error = S::Error;
73
74    #[inline]
75    fn serve(
76        &self,
77        ctx: Context<State>,
78        req: Request,
79    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
80        self.as_ref().serve(ctx, req)
81    }
82}
83
/// Private counterpart of [`Service`] that can be used as a trait object.
///
/// Instead of returning `impl Future`, its single method returns a pinned,
/// boxed future, which is what allows [`BoxService`] to store it as
/// `dyn DynService`.
trait DynService<S, Request> {
    /// Value produced when the call succeeds.
    type Response;
    /// Value produced when the call fails.
    type Error;

    /// Handles `req` using `ctx`, returning the resulting future boxed and
    /// pinned. The future is `Send` and may borrow `self`.
    // The boxed-future return type is long by nature; silence the lint here.
    #[allow(clippy::type_complexity)]
    fn serve_box(
        &self,
        ctx: Context<S>,
        req: Request,
    ) -> Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + '_>>;
}
99
100impl<S, Request, T> DynService<S, Request> for T
101where
102    T: Service<S, Request>,
103{
104    type Response = T::Response;
105    type Error = T::Error;
106
107    fn serve_box(
108        &self,
109        ctx: Context<S>,
110        req: Request,
111    ) -> Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + '_>> {
112        Box::pin(self.serve(ctx, req))
113    }
114}
115
/// A type-erased [`Service`] for fixed state, request, response and error
/// types.
///
/// Created through [`BoxService::new`] or [`Service::boxed`].
pub struct BoxService<S, Request, Response, Error> {
    // Shared, dynamically dispatched handle to the wrapped service. The
    // `Send + Sync + 'static` bounds are part of the trait object type.
    inner:
        Arc<dyn DynService<S, Request, Response = Response, Error = Error> + Send + Sync + 'static>,
}
122
123impl<S, Request, Response, Error> Clone for BoxService<S, Request, Response, Error> {
124    fn clone(&self) -> Self {
125        Self {
126            inner: self.inner.clone(),
127        }
128    }
129}
130
131impl<S, Request, Response, Error> BoxService<S, Request, Response, Error> {
132    #[inline]
134    pub fn new<T>(service: T) -> Self
135    where
136        T: Service<S, Request, Response = Response, Error = Error>,
137    {
138        Self {
139            inner: Arc::new(service),
140        }
141    }
142}
143
144impl<S, Request, Response, Error> std::fmt::Debug for BoxService<S, Request, Response, Error> {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        f.debug_struct("BoxService").finish()
147    }
148}
149
150impl<S, Request, Response, Error> Service<S, Request> for BoxService<S, Request, Response, Error>
151where
152    S: 'static,
153    Request: 'static,
154    Response: Send + 'static,
155    Error: Send + 'static,
156{
157    type Response = Response;
158    type Error = Error;
159
160    #[inline]
161    fn serve(
162        &self,
163        ctx: Context<S>,
164        req: Request,
165    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
166        self.inner.serve_box(ctx, req)
167    }
168
169    #[inline]
170    fn boxed(self) -> Self {
171        self
172    }
173}
174
// Generates a `Service` impl for the `crate::combinators` enum named `$id`,
// whose variants are named after its type parameters (`$param`, ...).
//
// Every variant's service must produce the same `Response`; each variant's
// error only has to convert into `BoxError`, which is the error type of the
// generated impl.
macro_rules! impl_service_either {
    ($id:ident, $($param:ident),+ $(,)?) => {
        impl<$($param),+, State, Request, Response> Service<State, Request> for crate::combinators::$id<$($param),+>
        where
            $(
                $param: Service<State, Request, Response = Response, Error: Into<BoxError>>,
            )+
            Request: Send + 'static,
            State: Clone + Send + Sync + 'static,
            Response: Send + 'static,
        {
            type Response = Response;
            type Error = BoxError;

            async fn serve(&self, ctx: Context<State>, req: Request) -> Result<Self::Response, Self::Error> {
                // One arm per variant: forward to the active service and
                // convert its error into `BoxError`.
                match self {
                    $(
                        crate::combinators::$id::$param(s) => s.serve(ctx, req).await.map_err(Into::into),
                    )+
                }
            }
        }
    };
}

// Hand the generator above to `impl_either!`.
// NOTE(review): presumably `impl_either!` invokes it once per Either-style
// enum in `crate::combinators` — confirm against that macro's definition.
crate::combinators::impl_either!(impl_service_either);
201
// Declares the unit-like `RejectError` type, the default error type of
// `RejectService`.
// NOTE(review): presumably "request rejected" also becomes the type's
// `Display` message — confirm against `static_str_error!`.
rama_utils::macros::error::static_str_error! {
    #[doc = "request rejected"]
    pub struct RejectError;
}
206
/// A [`Service`] that never succeeds: every request is answered with a clone
/// of the stored error.
///
/// `R` is the (never produced) response type and `E` the error type; they
/// default to `()` and [`RejectError`].
pub struct RejectService<R = (), E = RejectError> {
    // Error that is cloned into the result of every call.
    error: E,
    // Records the response type `R` without storing a value of it.
    _phantom: PhantomData<fn() -> R>,
}
212
213impl Default for RejectService {
214    fn default() -> Self {
215        Self {
216            error: RejectError,
217            _phantom: PhantomData,
218        }
219    }
220}
221
222impl<R, E: Clone + Send + Sync + 'static> RejectService<R, E> {
223    pub fn new(error: E) -> Self {
225        Self {
226            error,
227            _phantom: PhantomData,
228        }
229    }
230}
231
232impl<R, E: Clone> Clone for RejectService<R, E> {
233    fn clone(&self) -> Self {
234        Self {
235            error: self.error.clone(),
236            _phantom: PhantomData,
237        }
238    }
239}
240
241impl<R, E: fmt::Debug> fmt::Debug for RejectService<R, E> {
242    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243        f.debug_struct("RejectService")
244            .field("error", &self.error)
245            .field(
246                "_phantom",
247                &format_args!("{}", std::any::type_name::<fn() -> R>()),
248            )
249            .finish()
250    }
251}
252
253impl<S, Request, Response, Error> Service<S, Request> for RejectService<Response, Error>
254where
255    S: 'static,
256    Request: 'static,
257    Response: Send + 'static,
258    Error: Clone + Send + Sync + 'static,
259{
260    type Response = Response;
261    type Error = Error;
262
263    #[inline]
264    fn serve(
265        &self,
266        _ctx: Context<S>,
267        _req: Request,
268    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
269        let error = self.error.clone();
270        std::future::ready(Err(error))
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;

    /// Test service that adds its stored value to the request.
    #[derive(Debug)]
    struct AddSvc(usize);

    impl Service<(), usize> for AddSvc {
        type Response = usize;
        type Error = Infallible;

        async fn serve(
            &self,
            _ctx: Context<()>,
            req: usize,
        ) -> Result<Self::Response, Self::Error> {
            let sum = self.0 + req;
            Ok(sum)
        }
    }

    /// Test service that multiplies the request by its stored value.
    #[derive(Debug)]
    struct MulSvc(usize);

    impl Service<(), usize> for MulSvc {
        type Response = usize;
        type Error = Infallible;

        async fn serve(
            &self,
            _ctx: Context<()>,
            req: usize,
        ) -> Result<Self::Response, Self::Error> {
            let product = self.0 * req;
            Ok(product)
        }
    }

    // Compile-time check: the service types are `Send`.
    #[test]
    fn assert_send() {
        use rama_utils::test_helpers::assert_send as check_send;

        check_send::<AddSvc>();
        check_send::<MulSvc>();
        check_send::<BoxService<(), (), (), ()>>();
        check_send::<RejectService>();
    }

    // Compile-time check: the service types are `Sync`.
    #[test]
    fn assert_sync() {
        use rama_utils::test_helpers::assert_sync as check_sync;

        check_sync::<AddSvc>();
        check_sync::<MulSvc>();
        check_sync::<BoxService<(), (), (), ()>>();
        check_sync::<RejectService>();
    }

    #[tokio::test]
    async fn add_svc() {
        let ctx = Context::default();
        let adder = AddSvc(1);

        let sum = adder.serve(ctx, 1).await.unwrap();
        assert_eq!(sum, 2);
    }

    // Homogeneous collection: every element has the same concrete type.
    #[tokio::test]
    async fn static_dispatch() {
        let ctx = Context::default();
        let adders = vec![AddSvc(1), AddSvc(2), AddSvc(3)];

        for (idx, adder) in adders.into_iter().enumerate() {
            // The service at position `idx` adds `idx + 1` to its input `idx`.
            let expected = idx * 2 + 1;
            let got = adder.serve(ctx.clone(), idx).await.unwrap();
            assert_eq!(got, expected);
        }
    }

    // Heterogeneous collection: different services behind `BoxService`.
    #[tokio::test]
    async fn dynamic_dispatch() {
        let ctx = Context::default();
        let boxed = vec![
            AddSvc(1).boxed(),
            AddSvc(2).boxed(),
            AddSvc(3).boxed(),
            MulSvc(4).boxed(),
            MulSvc(5).boxed(),
        ];

        for (idx, svc) in boxed.into_iter().enumerate() {
            // The first three entries add `idx + 1`; the rest multiply by `idx + 1`.
            let expected = if idx < 3 { idx * 2 + 1 } else { idx * (idx + 1) };
            let got = svc.serve(ctx.clone(), idx).await.unwrap();
            assert_eq!(got, expected);
        }
    }

    #[tokio::test]
    async fn service_arc() {
        let ctx = Context::default();
        let shared = Arc::new(AddSvc(1));

        let sum = shared.serve(ctx, 1).await.unwrap();
        assert_eq!(sum, 2);
    }

    #[tokio::test]
    async fn box_service_arc() {
        let ctx = Context::default();
        let boxed = Arc::new(AddSvc(1)).boxed();

        let sum = boxed.serve(ctx, 1).await.unwrap();
        assert_eq!(sum, 2);
    }

    #[tokio::test]
    async fn reject_svc() {
        let svc = RejectService::default();

        let ctx = Context::default();

        let err = svc.serve(ctx, 1).await.unwrap_err();
        let expected = RejectError::new().to_string();
        assert_eq!(err.to_string(), expected);
    }
}
404}