1use std::convert::Infallible;
4use std::fmt;
5use std::marker::PhantomData;
6use std::pin::Pin;
7use std::sync::Arc;
8
/// An asynchronous function from an `Input` to a `Result<Self::Output, Self::Error>`.
///
/// A service must be `Send + Sync + 'static`, its output and error types must be
/// `Send + 'static`, and the future returned by [`Service::serve`] must be `Send`.
pub trait Service<Input>: Sized + Send + Sync + 'static {
    /// The value produced when serving succeeds.
    type Output: Send + 'static;

    /// The error produced when serving fails.
    type Error: Send + 'static;

    /// Serve `input`.
    ///
    /// The returned future is `Send` and may borrow `self` for as long as it is alive.
    fn serve(
        &self,
        input: Input,
    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send + '_;

    /// Erase the concrete service type by wrapping `self` in a [`BoxService`].
    ///
    /// Types that are already boxed (such as [`BoxService`] itself) may override this
    /// to avoid wrapping twice.
    fn boxed(self) -> BoxService<Input, Self::Output, Self::Error> {
        BoxService::new(self)
    }
}
31
32impl<Input> Service<Input> for ()
33where
34 Input: Send + 'static,
35{
36 type Output = Input;
37 type Error = Infallible;
38
39 async fn serve(&self, input: Input) -> Result<Self::Output, Self::Error> {
40 Ok(input)
41 }
42}
43
44impl<S, Input> Service<Input> for std::sync::Arc<S>
45where
46 S: Service<Input>,
47{
48 type Output = S::Output;
49 type Error = S::Error;
50
51 #[inline]
52 fn serve(
53 &self,
54 input: Input,
55 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send + '_ {
56 self.as_ref().serve(input)
57 }
58}
59
60impl<S, Input> Service<Input> for &'static S
61where
62 S: Service<Input>,
63{
64 type Output = S::Output;
65 type Error = S::Error;
66
67 #[inline(always)]
68 fn serve(
69 &self,
70 input: Input,
71 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send + '_ {
72 (**self).serve(input)
73 }
74}
75
76impl<S, Input> Service<Input> for Box<S>
77where
78 S: Service<Input>,
79{
80 type Output = S::Output;
81 type Error = S::Error;
82
83 #[inline]
84 fn serve(
85 &self,
86
87 input: Input,
88 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send + '_ {
89 self.as_ref().serve(input)
90 }
91}
92
/// Heap-allocated, `Send` future returned by [`DynService::serve_box`], borrowing the
/// service for `'a`.
type DynServeFuture<'a, O, E> = Pin<Box<dyn Future<Output = Result<O, E>> + Send + 'a>>;

/// Object-safe counterpart of [`Service`]: the future is boxed so the trait can be used as
/// `dyn DynService`.
trait DynService<Input> {
    /// The value produced on success.
    type Output;
    /// The error produced on failure.
    type Error;

    /// Serve `input`, returning the future boxed and pinned.
    fn serve_box(&self, input: Input) -> DynServeFuture<'_, Self::Output, Self::Error>;
}
108
109impl<Input, T> DynService<Input> for T
110where
111 T: Service<Input>,
112{
113 type Output = T::Output;
114 type Error = T::Error;
115
116 fn serve_box(
117 &self,
118
119 input: Input,
120 ) -> Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send + '_>> {
121 Box::pin(self.serve(input))
122 }
123}
124
/// A type-erased [`Service`] with fixed `Input`, `Output` and `Error` types.
///
/// The wrapped service is stored behind an `Arc`, so cloning a `BoxService` only clones
/// the `Arc` and every clone shares the same underlying service.
pub struct BoxService<Input, Output, Error> {
    // Object-safe view of the erased service; `Send + Sync` keeps `BoxService` shareable.
    inner: Arc<dyn DynService<Input, Output = Output, Error = Error> + Send + Sync + 'static>,
}
130
131impl<Input, Output, Error> Clone for BoxService<Input, Output, Error> {
132 fn clone(&self) -> Self {
133 Self {
134 inner: self.inner.clone(),
135 }
136 }
137}
138
139impl<Input, Output, Error> BoxService<Input, Output, Error> {
140 #[inline]
142 pub fn new<T>(service: T) -> Self
143 where
144 T: Service<Input, Output = Output, Error = Error>,
145 {
146 Self {
147 inner: Arc::new(service),
148 }
149 }
150}
151
152impl<Input, Output, Error> std::fmt::Debug for BoxService<Input, Output, Error> {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 f.debug_struct("BoxService").finish()
155 }
156}
157
158impl<Input, Output, Error> Service<Input> for BoxService<Input, Output, Error>
159where
160 Input: 'static,
161 Output: Send + 'static,
162 Error: Send + 'static,
163{
164 type Output = Output;
165 type Error = Error;
166
167 #[inline]
168 fn serve(
169 &self,
170
171 input: Input,
172 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send + '_ {
173 self.inner.serve_box(input)
174 }
175
176 #[inline]
177 fn boxed(self) -> Self {
178 self
179 }
180}
181
// Implements `Service` for an either-style enum from `crate::combinators`.
//
// `$id` is the enum's name; `$first` and each `$param` are its generic parameters, which
// also serve as its variant names (each variant `$id::$x(s)` wraps a value of type `$x`).
// All variants must share one `Output`. The enum's `Error` is the first variant's error,
// and every other variant's error is converted into it with `Into`.
macro_rules! impl_service_either {
    ($id:ident, $first:ident $(, $param:ident)* $(,)?) => {
        impl<$first, $($param,)* Input, Output> Service<Input> for crate::combinators::$id<$first $(,$param)*>
        where
            $first: Service<Input, Output = Output>,
            $(
                // Other variants only need an error convertible into the first one's.
                $param: Service<Input, Output = Output, Error: Into<$first::Error>>,
            )*
            Input: Send + 'static,
            Output: Send + 'static,
        {
            type Output = Output;
            type Error = $first::Error;

            async fn serve(&self, input: Input) -> Result<Self::Output, Self::Error> {
                // Dispatch to whichever variant is active.
                match self {
                    crate::combinators::$id::$first(s) => s.serve(input).await,
                    $(
                        crate::combinators::$id::$param(s) => s.serve(input).await.map_err(Into::into),
                    )*
                }
            }
        }
    };
}
207
// Hand `impl_service_either!` to the combinators module, which presumably expands it once
// per either-style enum it defines (see `crate::combinators::impl_either!` to confirm).
crate::combinators::impl_either!(impl_service_either);
209
/// A [`Service`] that returns its input unchanged and never fails.
///
/// `#[non_exhaustive]` prevents other crates from constructing it with the bare
/// `MirrorService` literal; they must use [`MirrorService::new`] or `Default`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, Default)]
pub struct MirrorService;
215
216impl MirrorService {
217 #[inline(always)]
219 #[must_use]
220 pub fn new() -> Self {
221 Self
222 }
223}
224
225impl<Input> Service<Input> for MirrorService
226where
227 Input: Send + 'static,
228{
229 type Output = Input;
230 type Error = Infallible;
231
232 #[inline]
233 fn serve(
234 &self,
235 input: Input,
236 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send + '_ {
237 std::future::ready(Ok(input))
238 }
239}
240
// `RejectError` is the default error type of `RejectService`. It is declared via rama's
// `static_str_error!` macro; the `#[doc]` text is presumably also used as its display
// message (TODO confirm against the macro definition).
rama_utils::macros::error::static_str_error! {
    #[doc = "Input rejected"]
    pub struct RejectError;
}
245
/// A [`Service`] that rejects every input with a clone of a fixed error.
///
/// `R` is the `Output` type the service nominally produces (it never succeeds), and `E` is
/// the error returned for each request.
pub struct RejectService<R = (), E = RejectError> {
    // Error cloned into every response.
    error: E,
    // `fn() -> R` ties `R` to the type without storing one; fn pointers are `Send + Sync`,
    // so `R` does not affect the auto traits of the service.
    _phantom: PhantomData<fn() -> R>,
}
251
252impl Default for RejectService {
253 fn default() -> Self {
254 Self {
255 error: RejectError,
256 _phantom: PhantomData,
257 }
258 }
259}
260
261impl<R, E: Clone + Send + Sync + 'static> RejectService<R, E> {
262 pub fn new(error: E) -> Self {
264 Self {
265 error,
266 _phantom: PhantomData,
267 }
268 }
269}
270
271impl<R, E: Clone> Clone for RejectService<R, E> {
272 fn clone(&self) -> Self {
273 Self {
274 error: self.error.clone(),
275 _phantom: PhantomData,
276 }
277 }
278}
279
280impl<R, E: fmt::Debug> fmt::Debug for RejectService<R, E> {
281 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282 f.debug_struct("RejectService")
283 .field("error", &self.error)
284 .field(
285 "_phantom",
286 &format_args!("{}", std::any::type_name::<fn() -> R>()),
287 )
288 .finish()
289 }
290}
291
292impl<Input, Output, Error> Service<Input> for RejectService<Output, Error>
293where
294 Input: 'static,
295 Output: Send + 'static,
296 Error: Clone + Send + Sync + 'static,
297{
298 type Output = Output;
299 type Error = Error;
300
301 #[inline]
302 fn serve(
303 &self,
304
305 _input: Input,
306 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send + '_ {
307 let error = self.error.clone();
308 std::future::ready(Err(error))
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;

    /// Test service that adds a fixed amount to its input.
    #[derive(Debug)]
    struct AddSvc(usize);

    impl Service<usize> for AddSvc {
        type Output = usize;
        type Error = Infallible;

        async fn serve(&self, input: usize) -> Result<Self::Output, Self::Error> {
            Ok(self.0 + input)
        }
    }

    /// Test service that multiplies its input by a fixed factor.
    #[derive(Debug)]
    struct MulSvc(usize);

    impl Service<usize> for MulSvc {
        type Output = usize;
        type Error = Infallible;

        async fn serve(&self, input: usize) -> Result<Self::Output, Self::Error> {
            Ok(self.0 * input)
        }
    }

    #[test]
    fn assert_send() {
        use rama_utils::test_helpers::*;

        assert_send::<AddSvc>();
        assert_send::<MulSvc>();
        assert_send::<BoxService<(), (), ()>>();
        assert_send::<RejectService>();
        assert_send::<MirrorService>();
    }

    #[test]
    fn assert_sync() {
        use rama_utils::test_helpers::*;

        assert_sync::<AddSvc>();
        assert_sync::<MulSvc>();
        assert_sync::<BoxService<(), (), ()>>();
        assert_sync::<RejectService>();
        assert_sync::<MirrorService>();
    }

    #[tokio::test]
    async fn add_svc() {
        let svc = AddSvc(1);

        let output = svc.serve(1).await.unwrap();
        assert_eq!(output, 2);
    }

    #[tokio::test]
    async fn static_dispatch() {
        let services = vec![AddSvc(1), AddSvc(2), AddSvc(3)];

        for (i, svc) in services.into_iter().enumerate() {
            let output = svc.serve(i).await.unwrap();
            assert_eq!(output, i * 2 + 1);
        }
    }

    #[tokio::test]
    async fn dynamic_dispatch() {
        let services = vec![
            AddSvc(1).boxed(),
            AddSvc(2).boxed(),
            AddSvc(3).boxed(),
            MulSvc(4).boxed(),
            MulSvc(5).boxed(),
        ];

        for (i, svc) in services.into_iter().enumerate() {
            let output = svc.serve(i).await.unwrap();
            if i < 3 {
                assert_eq!(output, i * 2 + 1);
            } else {
                assert_eq!(output, i * (i + 1));
            }
        }
    }

    #[tokio::test]
    async fn service_arc() {
        let svc = std::sync::Arc::new(AddSvc(1));

        let output = svc.serve(1).await.unwrap();
        assert_eq!(output, 2);
    }

    #[tokio::test]
    async fn box_service_arc() {
        let svc = std::sync::Arc::new(AddSvc(1)).boxed();

        let output = svc.serve(1).await.unwrap();
        assert_eq!(output, 2);
    }

    #[tokio::test]
    async fn service_box() {
        let svc = Box::new(AddSvc(1));

        let output = svc.serve(1).await.unwrap();
        assert_eq!(output, 2);
    }

    #[tokio::test]
    async fn service_static_ref() {
        static SVC: AddSvc = AddSvc(2);
        let svc: &'static AddSvc = &SVC;

        // Call through the `&'static S` impl explicitly: plain method syntax would pick
        // `AddSvc`'s own impl via auto-deref and skip the forwarding impl under test.
        let output = <&'static AddSvc as Service<usize>>::serve(&svc, 1)
            .await
            .unwrap();
        assert_eq!(output, 3);
    }

    #[test]
    fn boxed_box_service_is_not_rewrapped() {
        let svc = AddSvc(1).boxed();
        let reboxed = svc.clone().boxed();

        // `BoxService::boxed` must hand back the same shared allocation.
        assert!(std::sync::Arc::ptr_eq(&svc.inner, &reboxed.inner));
    }

    #[tokio::test]
    async fn unit_svc_echoes_input() {
        let output = ().serve("hello").await.unwrap();
        assert_eq!(output, "hello");
    }

    #[tokio::test]
    async fn mirror_svc_echoes_input() {
        let svc = MirrorService::new();

        let output = svc.serve(42usize).await.unwrap();
        assert_eq!(output, 42);
    }

    #[tokio::test]
    async fn reject_svc() {
        let svc = RejectService::default();

        let err = svc.serve(1).await.unwrap_err();
        assert_eq!(err.to_string(), RejectError::new().to_string());
    }

    #[tokio::test]
    async fn reject_svc_custom_error() {
        let svc = RejectService::<(), &'static str>::new("nope");

        let err = svc.serve(1).await.unwrap_err();
        assert_eq!(err, "nope");
    }
}