1use crate::Context;
4use crate::error::BoxError;
5use std::fmt;
6use std::marker::PhantomData;
7use std::pin::Pin;
8use std::sync::Arc;
9
10pub trait Service<S, Request>: Sized + Send + Sync + 'static {
13 type Response: Send + 'static;
15
16 type Error: Send + 'static;
18
19 fn serve(
22 &self,
23 ctx: Context<S>,
24 req: Request,
25 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_;
26
27 fn boxed(self) -> BoxService<S, Request, Self::Response, Self::Error> {
29 BoxService::new(self)
30 }
31}
32
33impl<S, State, Request> Service<State, Request> for std::sync::Arc<S>
34where
35 S: Service<State, Request>,
36{
37 type Response = S::Response;
38 type Error = S::Error;
39
40 #[inline]
41 fn serve(
42 &self,
43 ctx: Context<State>,
44 req: Request,
45 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
46 self.as_ref().serve(ctx, req)
47 }
48}
49
50impl<S, State, Request> Service<State, Request> for &'static S
51where
52 S: Service<State, Request>,
53{
54 type Response = S::Response;
55 type Error = S::Error;
56
57 #[inline]
58 fn serve(
59 &self,
60 ctx: Context<State>,
61 req: Request,
62 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
63 (**self).serve(ctx, req)
64 }
65}
66
67impl<S, State, Request> Service<State, Request> for Box<S>
68where
69 S: Service<State, Request>,
70{
71 type Response = S::Response;
72 type Error = S::Error;
73
74 #[inline]
75 fn serve(
76 &self,
77 ctx: Context<State>,
78 req: Request,
79 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
80 self.as_ref().serve(ctx, req)
81 }
82}
83
84trait DynService<S, Request> {
89 type Response;
90 type Error;
91
92 #[allow(clippy::type_complexity)]
93 fn serve_box(
94 &self,
95 ctx: Context<S>,
96 req: Request,
97 ) -> Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + '_>>;
98}
99
100impl<S, Request, T> DynService<S, Request> for T
101where
102 T: Service<S, Request>,
103{
104 type Response = T::Response;
105 type Error = T::Error;
106
107 fn serve_box(
108 &self,
109 ctx: Context<S>,
110 req: Request,
111 ) -> Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + '_>> {
112 Box::pin(self.serve(ctx, req))
113 }
114}
115
116pub struct BoxService<S, Request, Response, Error> {
119 inner:
120 Arc<dyn DynService<S, Request, Response = Response, Error = Error> + Send + Sync + 'static>,
121}
122
123impl<S, Request, Response, Error> Clone for BoxService<S, Request, Response, Error> {
124 fn clone(&self) -> Self {
125 Self {
126 inner: self.inner.clone(),
127 }
128 }
129}
130
131impl<S, Request, Response, Error> BoxService<S, Request, Response, Error> {
132 #[inline]
134 pub fn new<T>(service: T) -> Self
135 where
136 T: Service<S, Request, Response = Response, Error = Error>,
137 {
138 Self {
139 inner: Arc::new(service),
140 }
141 }
142}
143
144impl<S, Request, Response, Error> std::fmt::Debug for BoxService<S, Request, Response, Error> {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 f.debug_struct("BoxService").finish()
147 }
148}
149
150impl<S, Request, Response, Error> Service<S, Request> for BoxService<S, Request, Response, Error>
151where
152 S: 'static,
153 Request: 'static,
154 Response: Send + 'static,
155 Error: Send + 'static,
156{
157 type Response = Response;
158 type Error = Error;
159
160 #[inline]
161 fn serve(
162 &self,
163 ctx: Context<S>,
164 req: Request,
165 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
166 self.inner.serve_box(ctx, req)
167 }
168
169 #[inline]
170 fn boxed(self) -> Self {
171 self
172 }
173}
174
175macro_rules! impl_service_either {
176 ($id:ident, $($param:ident),+ $(,)?) => {
177 impl<$($param),+, State, Request, Response> Service<State, Request> for crate::combinators::$id<$($param),+>
178 where
179 $(
180 $param: Service<State, Request, Response = Response, Error: Into<BoxError>>,
181 )+
182 Request: Send + 'static,
183 State: Clone + Send + Sync + 'static,
184 Response: Send + 'static,
185 {
186 type Response = Response;
187 type Error = BoxError;
188
189 async fn serve(&self, ctx: Context<State>, req: Request) -> Result<Self::Response, Self::Error> {
190 match self {
191 $(
192 crate::combinators::$id::$param(s) => s.serve(ctx, req).await.map_err(Into::into),
193 )+
194 }
195 }
196 }
197 };
198}
199
200crate::combinators::impl_either!(impl_service_either);
201
202rama_utils::macros::error::static_str_error! {
203 #[doc = "request rejected"]
204 pub struct RejectError;
205}
206
207pub struct RejectService<R = (), E = RejectError> {
209 error: E,
210 _phantom: PhantomData<fn() -> R>,
211}
212
213impl Default for RejectService {
214 fn default() -> Self {
215 Self {
216 error: RejectError,
217 _phantom: PhantomData,
218 }
219 }
220}
221
222impl<R, E: Clone + Send + Sync + 'static> RejectService<R, E> {
223 pub fn new(error: E) -> Self {
225 Self {
226 error,
227 _phantom: PhantomData,
228 }
229 }
230}
231
232impl<R, E: Clone> Clone for RejectService<R, E> {
233 fn clone(&self) -> Self {
234 Self {
235 error: self.error.clone(),
236 _phantom: PhantomData,
237 }
238 }
239}
240
241impl<R, E: fmt::Debug> fmt::Debug for RejectService<R, E> {
242 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243 f.debug_struct("RejectService")
244 .field("error", &self.error)
245 .field(
246 "_phantom",
247 &format_args!("{}", std::any::type_name::<fn() -> R>()),
248 )
249 .finish()
250 }
251}
252
253impl<S, Request, Response, Error> Service<S, Request> for RejectService<Response, Error>
254where
255 S: 'static,
256 Request: 'static,
257 Response: Send + 'static,
258 Error: Clone + Send + Sync + 'static,
259{
260 type Response = Response;
261 type Error = Error;
262
263 #[inline]
264 fn serve(
265 &self,
266 _ctx: Context<S>,
267 _req: Request,
268 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
269 let error = self.error.clone();
270 std::future::ready(Err(error))
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;

    /// Test service that adds its stored value to the request.
    #[derive(Debug)]
    struct AddSvc(usize);

    impl Service<(), usize> for AddSvc {
        type Response = usize;
        type Error = Infallible;

        async fn serve(
            &self,
            _ctx: Context<()>,
            req: usize,
        ) -> Result<Self::Response, Self::Error> {
            let sum = self.0 + req;
            Ok(sum)
        }
    }

    /// Test service that multiplies the request by its stored value.
    #[derive(Debug)]
    struct MulSvc(usize);

    impl Service<(), usize> for MulSvc {
        type Response = usize;
        type Error = Infallible;

        async fn serve(
            &self,
            _ctx: Context<()>,
            req: usize,
        ) -> Result<Self::Response, Self::Error> {
            let product = self.0 * req;
            Ok(product)
        }
    }

    /// Compile-time check that the service types are `Send`.
    #[test]
    fn assert_send() {
        use rama_utils::test_helpers::*;

        assert_send::<AddSvc>();
        assert_send::<MulSvc>();
        assert_send::<BoxService<(), (), (), ()>>();
        assert_send::<RejectService>();
    }

    /// Compile-time check that the service types are `Sync`.
    #[test]
    fn assert_sync() {
        use rama_utils::test_helpers::*;

        assert_sync::<AddSvc>();
        assert_sync::<MulSvc>();
        assert_sync::<BoxService<(), (), (), ()>>();
        assert_sync::<RejectService>();
    }

    #[tokio::test]
    async fn add_svc() {
        let ctx = Context::default();
        let service = AddSvc(1);

        let sum = service.serve(ctx, 1).await.unwrap();
        assert_eq!(sum, 2);
    }

    #[tokio::test]
    async fn static_dispatch() {
        let ctx = Context::default();
        let services = [AddSvc(1), AddSvc(2), AddSvc(3)];

        // Service `i` adds `i + 1` to the request `i`.
        for (i, service) in services.into_iter().enumerate() {
            let out = service.serve(ctx.clone(), i).await.unwrap();
            assert_eq!(out, i * 2 + 1);
        }
    }

    #[tokio::test]
    async fn dynamic_dispatch() {
        let ctx = Context::default();
        let services = [
            AddSvc(1).boxed(),
            AddSvc(2).boxed(),
            AddSvc(3).boxed(),
            MulSvc(4).boxed(),
            MulSvc(5).boxed(),
        ];

        for (i, service) in services.into_iter().enumerate() {
            let out = service.serve(ctx.clone(), i).await.unwrap();
            // The first three add `i + 1`; the remaining ones multiply by `i + 1`.
            let expected = if i < 3 { i * 2 + 1 } else { i * (i + 1) };
            assert_eq!(out, expected);
        }
    }

    #[tokio::test]
    async fn service_arc() {
        let ctx = Context::default();
        let shared = Arc::new(AddSvc(1));

        let out = shared.serve(ctx, 1).await.unwrap();
        assert_eq!(out, 2);
    }

    #[tokio::test]
    async fn box_service_arc() {
        let ctx = Context::default();
        let boxed = Arc::new(AddSvc(1)).boxed();

        let out = boxed.serve(ctx, 1).await.unwrap();
        assert_eq!(out, 2);
    }

    #[tokio::test]
    async fn reject_svc() {
        let ctx = Context::default();
        let service = RejectService::default();

        let rejection = service.serve(ctx, 1).await.unwrap_err();
        let expected = RejectError::new().to_string();
        assert_eq!(rejection.to_string(), expected);
    }
}