1use crate::Context;
4use crate::error::BoxError;
5use std::pin::Pin;
6
7pub trait Service<S, Request>: Sized + Send + Sync + 'static {
10 type Response: Send + 'static;
12
13 type Error: Send + 'static;
15
16 fn serve(
19 &self,
20 ctx: Context<S>,
21 req: Request,
22 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_;
23
24 fn boxed(self) -> BoxService<S, Request, Self::Response, Self::Error> {
26 BoxService {
27 inner: Box::new(self),
28 }
29 }
30}
31
32impl<S, State, Request> Service<State, Request> for std::sync::Arc<S>
33where
34 S: Service<State, Request>,
35{
36 type Response = S::Response;
37 type Error = S::Error;
38
39 #[inline]
40 fn serve(
41 &self,
42 ctx: Context<State>,
43 req: Request,
44 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
45 self.as_ref().serve(ctx, req)
46 }
47}
48
49impl<S, State, Request> Service<State, Request> for &'static S
50where
51 S: Service<State, Request>,
52{
53 type Response = S::Response;
54 type Error = S::Error;
55
56 #[inline]
57 fn serve(
58 &self,
59 ctx: Context<State>,
60 req: Request,
61 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
62 (**self).serve(ctx, req)
63 }
64}
65
66impl<S, State, Request> Service<State, Request> for Box<S>
67where
68 S: Service<State, Request>,
69{
70 type Response = S::Response;
71 type Error = S::Error;
72
73 #[inline]
74 fn serve(
75 &self,
76 ctx: Context<State>,
77 req: Request,
78 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
79 self.as_ref().serve(ctx, req)
80 }
81}
82
83trait DynService<S, Request> {
88 type Response;
89 type Error;
90
91 #[allow(clippy::type_complexity)]
92 fn serve_box(
93 &self,
94 ctx: Context<S>,
95 req: Request,
96 ) -> Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + '_>>;
97}
98
99impl<S, Request, T> DynService<S, Request> for T
100where
101 T: Service<S, Request>,
102{
103 type Response = T::Response;
104 type Error = T::Error;
105
106 fn serve_box(
107 &self,
108 ctx: Context<S>,
109 req: Request,
110 ) -> Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + '_>> {
111 Box::pin(self.serve(ctx, req))
112 }
113}
114
115pub struct BoxService<S, Request, Response, Error> {
118 inner:
119 Box<dyn DynService<S, Request, Response = Response, Error = Error> + Send + Sync + 'static>,
120}
121
122impl<S, Request, Response, Error> BoxService<S, Request, Response, Error> {
123 pub fn new<T>(service: T) -> Self
125 where
126 T: Service<S, Request, Response = Response, Error = Error>,
127 {
128 Self {
129 inner: Box::new(service),
130 }
131 }
132}
133
134impl<S, Request, Response, Error> std::fmt::Debug for BoxService<S, Request, Response, Error> {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 f.debug_struct("BoxService").finish()
137 }
138}
139
140impl<S, Request, Response, Error> Service<S, Request> for BoxService<S, Request, Response, Error>
141where
142 S: 'static,
143 Request: 'static,
144 Response: Send + 'static,
145 Error: Send + 'static,
146{
147 type Response = Response;
148 type Error = Error;
149
150 fn serve(
151 &self,
152 ctx: Context<S>,
153 req: Request,
154 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
155 self.inner.serve_box(ctx, req)
156 }
157}
158
159macro_rules! impl_service_either {
160 ($id:ident, $($param:ident),+ $(,)?) => {
161 impl<$($param),+, State, Request, Response> Service<State, Request> for crate::combinators::$id<$($param),+>
162 where
163 $(
164 $param: Service<State, Request, Response = Response, Error: Into<BoxError>>,
165 )+
166 Request: Send + 'static,
167 State: Clone + Send + Sync + 'static,
168 Response: Send + 'static,
169 {
170 type Response = Response;
171 type Error = BoxError;
172
173 async fn serve(&self, ctx: Context<State>, req: Request) -> Result<Self::Response, Self::Error> {
174 match self {
175 $(
176 crate::combinators::$id::$param(s) => s.serve(ctx, req).await.map_err(Into::into),
177 )+
178 }
179 }
180 }
181 };
182}
183
184crate::combinators::impl_either!(impl_service_either);
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;

    /// Responds with the request plus a fixed offset.
    #[derive(Debug)]
    struct AddSvc(usize);

    impl Service<(), usize> for AddSvc {
        type Response = usize;
        type Error = Infallible;

        async fn serve(
            &self,
            _ctx: Context<()>,
            req: usize,
        ) -> Result<Self::Response, Self::Error> {
            Ok(req + self.0)
        }
    }

    /// Responds with the request multiplied by a fixed factor.
    #[derive(Debug)]
    struct MulSvc(usize);

    impl Service<(), usize> for MulSvc {
        type Response = usize;
        type Error = Infallible;

        async fn serve(
            &self,
            _ctx: Context<()>,
            req: usize,
        ) -> Result<Self::Response, Self::Error> {
            Ok(req * self.0)
        }
    }

    #[test]
    fn assert_send() {
        use rama_utils::test_helpers::*;

        assert_send::<AddSvc>();
        assert_send::<MulSvc>();
        assert_send::<BoxService<(), (), (), ()>>();
    }

    #[test]
    fn assert_sync() {
        use rama_utils::test_helpers::*;

        assert_sync::<AddSvc>();
        assert_sync::<MulSvc>();
        assert_sync::<BoxService<(), (), (), ()>>();
    }

    #[tokio::test]
    async fn add_svc() {
        let svc = AddSvc(1);

        let response = svc.serve(Context::default(), 1).await.unwrap();
        assert_eq!(2, response);
    }

    #[tokio::test]
    async fn static_dispatch() {
        let ctx = Context::default();
        let services = vec![AddSvc(1), AddSvc(2), AddSvc(3)];

        for (req, svc) in services.into_iter().enumerate() {
            let response = svc.serve(ctx.clone(), req).await.unwrap();
            // AddSvc(req + 1) applied to `req`.
            assert_eq!(2 * req + 1, response);
        }
    }

    #[tokio::test]
    async fn dynamic_dispatch() {
        let ctx = Context::default();
        let services = vec![
            AddSvc(1).boxed(),
            AddSvc(2).boxed(),
            AddSvc(3).boxed(),
            MulSvc(4).boxed(),
            MulSvc(5).boxed(),
        ];

        // The first three are AddSvc(req + 1), the rest MulSvc(req + 1).
        let expected = |req: usize| if req < 3 { 2 * req + 1 } else { req * (req + 1) };

        for (req, svc) in services.into_iter().enumerate() {
            let response = svc.serve(ctx.clone(), req).await.unwrap();
            assert_eq!(expected(req), response);
        }
    }

    #[tokio::test]
    async fn service_arc() {
        let svc = std::sync::Arc::new(AddSvc(1));

        let response = svc.serve(Context::default(), 1).await.unwrap();
        assert_eq!(2, response);
    }

    #[tokio::test]
    async fn box_service_arc() {
        let svc = std::sync::Arc::new(AddSvc(1)).boxed();

        let response = svc.serve(Context::default(), 1).await.unwrap();
        assert_eq!(2, response);
    }
}