rama_http/service/web/endpoint/mod.rs

1use crate::{Body, IntoResponse, Request, Response, matcher::HttpMatcher};
2use rama_core::{Context, Layer, Service, layer::MapResponseLayer, service::BoxService};
3use std::{convert::Infallible, fmt};
4
5pub mod extract;
6
/// An endpoint: a request matcher paired with the type-erased service that
/// handles the requests it matches.
pub(crate) struct Endpoint<State> {
    /// Matcher for incoming requests (presumably evaluated by the owning web
    /// service before `service` is invoked — NOTE(review): confirm at call site).
    pub(crate) matcher: HttpMatcher<State, Body>,
    /// Boxed (type-erased) handler; it cannot fail (`Infallible`) and always
    /// yields a [`Response`].
    pub(crate) service: BoxService<State, Request, Response, Infallible>,
}
11
/// Utility trait that lets [`super::WebService`] accept several kinds of values
/// as an endpoint service.
///
/// The trait is sealed. In this module it is implemented for:
/// - any [`Service`] whose response implements [`IntoResponse`];
/// - any cloneable [`IntoResponse`] value, served as a static response;
/// - any [`EndpointServiceFn`].
///
/// The type parameter `T` is only a marker that keeps those blanket
/// implementations from overlapping; callers normally let it be inferred.
pub trait IntoEndpointService<State, T>: private::Sealed<T> {
    /// Convert the type into a [`rama_core::Service`] that never fails and
    /// always produces a [`Response`].
    fn into_endpoint_service(
        self,
    ) -> impl Service<State, Request, Response = Response, Error = Infallible>;
}
19
20impl<State, S, R> IntoEndpointService<State, (State, R)> for S
21where
22    State: Clone + Send + Sync + 'static,
23    S: Service<State, Request, Response = R, Error = Infallible>,
24    R: IntoResponse + Send + Sync + 'static,
25{
26    fn into_endpoint_service(
27        self,
28    ) -> impl Service<State, Request, Response = Response, Error = Infallible> {
29        MapResponseLayer::new(R::into_response).into_layer(self)
30    }
31}
32
33impl<State, R> IntoEndpointService<State, ()> for R
34where
35    State: Clone + Send + Sync + 'static,
36    R: IntoResponse + Clone + Send + Sync + 'static,
37{
38    fn into_endpoint_service(
39        self,
40    ) -> impl Service<State, Request, Response = Response, Error = Infallible> {
41        StaticService(self)
42    }
43}
44
45/// A static [`Service`] that serves a pre-defined response.
46pub struct StaticService<R>(R);
47
48impl<R> StaticService<R>
49where
50    R: IntoResponse + Clone + Send + Sync + 'static,
51{
52    /// Create a new [`StaticService`] with the given response.
53    pub fn new(response: R) -> Self {
54        Self(response)
55    }
56}
57
58impl<T> fmt::Debug for StaticService<T>
59where
60    T: fmt::Debug,
61{
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        f.debug_tuple("StaticService").field(&self.0).finish()
64    }
65}
66
67impl<R> Clone for StaticService<R>
68where
69    R: Clone,
70{
71    fn clone(&self) -> Self {
72        Self(self.0.clone())
73    }
74}
75
76impl<R, State> Service<State, Request> for StaticService<R>
77where
78    R: IntoResponse + Clone + Send + Sync + 'static,
79    State: Clone + Send + Sync + 'static,
80{
81    type Response = Response;
82    type Error = Infallible;
83
84    async fn serve(&self, _: Context<State>, _: Request) -> Result<Self::Response, Self::Error> {
85        Ok(self.0.clone().into_response())
86    }
87}
88
89mod service;
90#[doc(inline)]
91pub use service::EndpointServiceFn;
92
/// Adapts an [`EndpointServiceFn`] into a [`Service`].
///
/// `S` (the state type) and `T` (the extra type parameter of
/// `EndpointServiceFn<S, T>`) exist only at the type level. They are recorded
/// as `PhantomData<fn(S, T)>`: a function-pointer marker owns no `S` or `T`,
/// so the wrapper's `Send`/`Sync` depend solely on `F`.
struct EndpointServiceFnWrapper<F, S, T> {
    /// The wrapped endpoint function.
    inner: F,
    _marker: std::marker::PhantomData<fn(S, T)>,
}
97
98impl<F: std::fmt::Debug, S, T> std::fmt::Debug for EndpointServiceFnWrapper<F, S, T> {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        f.debug_struct("EndpointServiceFnWrapper")
101            .field("inner", &self.inner)
102            .field(
103                "_marker",
104                &format_args!("{}", std::any::type_name::<fn(S, T) -> ()>()),
105            )
106            .finish()
107    }
108}
109
110impl<F, S, T> Clone for EndpointServiceFnWrapper<F, S, T>
111where
112    F: Clone,
113{
114    fn clone(&self) -> Self {
115        Self {
116            inner: self.inner.clone(),
117            _marker: std::marker::PhantomData,
118        }
119    }
120}
121
122impl<F, S, T> Service<S, Request> for EndpointServiceFnWrapper<F, S, T>
123where
124    F: EndpointServiceFn<S, T>,
125    S: Clone + Send + Sync + 'static,
126    T: Send + 'static,
127{
128    type Response = Response;
129    type Error = Infallible;
130
131    async fn serve(&self, ctx: Context<S>, req: Request) -> Result<Self::Response, Self::Error> {
132        Ok(self.inner.call(ctx, req).await)
133    }
134}
135
136impl<F, S, T> IntoEndpointService<S, (F, S, T)> for F
137where
138    F: EndpointServiceFn<S, T>,
139    S: Clone + Send + Sync + 'static,
140    T: Send + 'static,
141{
142    fn into_endpoint_service(
143        self,
144    ) -> impl Service<S, Request, Response = Response, Error = Infallible> {
145        EndpointServiceFnWrapper {
146            inner: self,
147            _marker: std::marker::PhantomData,
148        }
149    }
150}
151
/// Sealing support for [`IntoEndpointService`].
///
/// `Sealed` is `pub` but lives in this private module, so code outside the
/// parent module cannot name it and therefore cannot add its own
/// [`IntoEndpointService`] implementations. Each marker type `T` below mirrors
/// the marker used by the corresponding `IntoEndpointService` impl.
mod private {
    use super::*;

    /// Supertrait of [`IntoEndpointService`]; only implemented in this module.
    pub trait Sealed<T> {}

    // Mirrors `IntoEndpointService<State, (State, R)>` (infallible services).
    impl<State, S, R> Sealed<(State, R)> for S
    where
        State: Clone + Send + Sync + 'static,
        S: Service<State, Request, Response = R, Error = Infallible>,
        R: IntoResponse + Send + Sync + 'static,
    {
    }

    // NOTE(review): no `IntoEndpointService` impl in this file uses this
    // marker (nor the next one); they look like leftovers from an earlier
    // closure-based design. Confirm nothing else relies on them before removing.
    impl<State, F, Fut, R> Sealed<(State, F, Context<State>, Fut, R)> for F
    where
        State: Clone + Send + Sync + 'static,
        F: Fn(Context<State>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoResponse + Send + Sync + 'static,
    {
    }

    impl<State, F, Fut, R> Sealed<(State, F, Context<State>, Request, Fut, R)> for F
    where
        State: Clone + Send + Sync + 'static,
        F: Fn(Context<State>, Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoResponse + Send + Sync + 'static,
    {
    }

    // Mirrors `IntoEndpointService<State, ()>` (static responses). This seal
    // does not require `Clone`; that bound is enforced by the public impl.
    impl<R> Sealed<()> for R where R: IntoResponse + Send + Sync + 'static {}

    // Mirrors `IntoEndpointService<S, (F, S, T)>` (endpoint functions).
    impl<F, S, T> Sealed<(F, S, T)> for F where F: EndpointServiceFn<S, T> {}
}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Body, Method, Request, StatusCode, dep::http_body_util::BodyExt};
    use extract::*;

    /// Compile-time check that `I` converts into an endpoint service for the
    /// unit state `()` (with marker `T` inferred).
    fn assert_into_endpoint_service<T, I>(_: I)
    where
        I: IntoEndpointService<(), T>,
    {
    }

    /// Builds a body-less request aimed at `uri`.
    fn empty_request(uri: &str) -> Request {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    #[test]
    fn test_into_endpoint_service_static() {
        assert_into_endpoint_service(StatusCode::OK);
        assert_into_endpoint_service("hello");
        assert_into_endpoint_service(String::from("hello"));
    }

    #[tokio::test]
    async fn test_into_endpoint_service_impl() {
        #[derive(Debug, Clone)]
        struct OkService;

        impl<State> Service<State, Request> for OkService
        where
            State: Clone + Send + Sync + 'static,
        {
            type Response = StatusCode;
            type Error = Infallible;

            async fn serve(
                &self,
                _ctx: Context<State>,
                _req: Request,
            ) -> Result<Self::Response, Self::Error> {
                Ok(StatusCode::OK)
            }
        }

        let service = OkService;
        let status = service
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(status, StatusCode::OK);

        assert_into_endpoint_service(service)
    }

    #[test]
    fn test_into_endpoint_service_fn_no_param() {
        assert_into_endpoint_service(async || StatusCode::OK);
        assert_into_endpoint_service(async || "hello");
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_no_param() {
        let handler = async || StatusCode::OK;
        let endpoint = handler.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_single_param_request() {
        let handler = async |req: Request| req.uri().to_string();
        let endpoint = handler.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let payload = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(payload, "http://example.com/")
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_single_param_host() {
        let handler = async |Host(host): Host| host.to_string();
        let endpoint = handler.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let payload = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(payload, "example.com")
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_multi_param_host() {
        #[derive(Debug, Clone, serde::Deserialize)]
        struct Params {
            foo: String,
        }

        let web = crate::service::web::WebService::default().get(
            "/:foo/bar",
            async |Host(host): Host, Path(params): Path<Params>| {
                format!("{} => {}", host, params.foo)
            },
        );
        let endpoint = web.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com/42/bar"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let payload = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(payload, "example.com => 42")
    }

    #[test]
    fn test_into_endpoint_service_fn_single_param() {
        #[derive(Debug, Clone, serde::Deserialize)]
        struct Params {
            foo: String,
        }

        assert_into_endpoint_service(async |_path: Path<Params>| StatusCode::OK);
        assert_into_endpoint_service(async |Path(params): Path<Params>| params.foo);
        assert_into_endpoint_service(async |Query(query): Query<Params>| query.foo);
        assert_into_endpoint_service(async |method: Method| method.to_string());
        assert_into_endpoint_service(async |req: Request| req.uri().to_string());
        assert_into_endpoint_service(async |_host: Host| StatusCode::OK);
        assert_into_endpoint_service(async |Host(_host): Host| StatusCode::OK);
    }
}