rama_core/service/
handler.rs

//! Adapters for using an `async fn(...)` or async closure as a [`Service`](crate::Service).
2
3use std::marker::PhantomData;
4
5use crate::{Context, Service};
6
7/// Create a [`ServiceFn`] from a function.
8pub fn service_fn<F, T, R, O, E>(f: F) -> ServiceFn<F, T, R, O, E>
9where
10    F: Factory<T, R, O, E>,
11    R: Future<Output = Result<O, E>>,
12{
13    ServiceFn::new(f)
14}
15
/// Adapter between a handler function and the single parameter `T` that a
/// [`ServiceFn`] hands to it.
///
/// `T` packs all of the handler's arguments into one value (for example
/// `(Context<State>, Request)`), so that handlers of different arities can be
/// invoked uniformly. The returned future `R` resolves to `Result<O, E>`.
pub trait Factory<T, R: Future<Output = Result<O, E>>, O, E>: Send + Sync + 'static {
    /// Invoke the wrapped handler with the packed parameter `input`.
    fn call(&self, input: T) -> R;
}
24
25impl<F, R, O, E> Factory<(), R, O, E> for F
26where
27    F: Fn() -> R + Send + Sync + 'static,
28    R: Future<Output = Result<O, E>>,
29{
30    fn call(&self, _: ()) -> R {
31        (self)()
32    }
33}
34
35impl<State, Request, F, R, O, E> Factory<(Context<State>, Request), R, O, E> for F
36where
37    F: Fn(Context<State>, Request) -> R + Send + Sync + 'static,
38    R: Future<Output = Result<O, E>>,
39{
40    fn call(&self, (ctx, req): (Context<State>, Request)) -> R {
41        (self)(ctx, req)
42    }
43}
44
45impl<Request, F, R, O, E> Factory<((), Request), R, O, E> for F
46where
47    F: Fn(Request) -> R + Send + Sync + 'static,
48    R: Future<Output = Result<O, E>>,
49{
50    fn call(&self, ((), req): ((), Request)) -> R {
51        (self)(req)
52    }
53}
54
55/// A [`ServiceFn`] is a [`Service`] implemented using a function.
56///
57/// You do not need to implement this trait yourself.
58/// Instead, you need to use the [`service_fn`] function to create a [`ServiceFn`].
59///
60/// [`Service`]: crate
61pub struct ServiceFn<F, T, R, O, E>
62where
63    F: Factory<T, R, O, E>,
64    R: Future<Output = Result<O, E>>,
65{
66    hnd: F,
67    _t: PhantomData<fn(T, R, O) -> ()>,
68}
69
70impl<F, T, R, O, E> std::fmt::Debug for ServiceFn<F, T, R, O, E>
71where
72    F: Factory<T, R, O, E>,
73    R: Future<Output = Result<O, E>>,
74{
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        f.debug_struct("ServiceFn").finish()
77    }
78}
79
80impl<F, T, R, O, E> ServiceFn<F, T, R, O, E>
81where
82    F: Factory<T, R, O, E>,
83    R: Future<Output = Result<O, E>>,
84{
85    pub(crate) fn new(hnd: F) -> Self {
86        Self {
87            hnd,
88            _t: PhantomData,
89        }
90    }
91}
92
93impl<F, T, R, O, E> Clone for ServiceFn<F, T, R, O, E>
94where
95    F: Factory<T, R, O, E> + Clone,
96    R: Future<Output = Result<O, E>>,
97{
98    fn clone(&self) -> Self {
99        Self {
100            hnd: self.hnd.clone(),
101            _t: PhantomData,
102        }
103    }
104}
105
106impl<State, Request, F, T, R, O, E> Service<State, Request> for ServiceFn<F, T, R, O, E>
107where
108    F: Factory<T, R, O, E>,
109    R: Future<Output = Result<O, E>> + Send + 'static,
110    T: FromContextRequest<State, Request>,
111    O: Send + 'static,
112    E: Send + Sync + 'static,
113{
114    type Response = O;
115    type Error = E;
116
117    fn serve(
118        &self,
119        ctx: Context<State>,
120        req: Request,
121    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
122        let param = T::from_context_request(ctx, req);
123        self.hnd.call(param)
124    }
125}
126
127/// Convert a context+request into a parameter for the [`ServiceFn`] handler function.
128pub trait FromContextRequest<State, Request>: Send + 'static {
129    /// Convert a context+request into a parameter for the [`ServiceFn`] handler function.
130    fn from_context_request(ctx: Context<State>, req: Request) -> Self;
131}
132
133impl<State, Request> FromContextRequest<State, Request> for () {
134    fn from_context_request(_ctx: Context<State>, _req: Request) -> Self {}
135}
136
137impl<State, Request> FromContextRequest<State, Request> for ((), Request)
138where
139    State: Clone + Send + Sync + 'static,
140    Request: Send + 'static,
141{
142    fn from_context_request(_ctx: Context<State>, req: Request) -> Self {
143        ((), req)
144    }
145}
146
147impl<State, Request> FromContextRequest<State, Request> for (Context<State>, Request)
148where
149    State: Clone + Send + Sync + 'static,
150    Request: Send + 'static,
151{
152    fn from_context_request(ctx: Context<State>, req: Request) -> Self {
153        (ctx, req)
154    }
155}
156
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::*;

    /// Only compiles when `T` is thread-safe and owns all of its data.
    fn assert_send_sync<T: Send + Sync + 'static>(_t: T) {}

    #[tokio::test]
    async fn test_service_fn() {
        // One service per supported handler shape; once boxed they share a single
        // type and must all accept the same context and request.
        let without_args = service_fn(async || Ok(())).boxed();
        let with_request = service_fn(async |req: String| {
            assert_eq!(req, "hello");
            Ok(())
        })
        .boxed();
        let with_context_and_request = service_fn(async |_ctx: Context<()>, req: String| {
            assert_eq!(req, "hello");
            Ok(())
        })
        .boxed();

        for svc in [without_args, with_request, with_context_and_request] {
            let result: Result<(), Infallible> =
                svc.serve(Context::default(), "hello".to_owned()).await;
            assert!(result.is_ok());
        }
    }

    #[test]
    fn test_service_fn_without_usage() {
        assert_send_sync(service_fn(async || Ok::<_, Infallible>(())));
        assert_send_sync(service_fn(async |_req: String| Ok::<_, Infallible>(())));
        assert_send_sync(service_fn(
            async |_ctx: Context<()>, _req: String| Ok::<_, Infallible>(()),
        ));
    }
}