rama_http/service/web/endpoint/
mod.rs

1use crate::{matcher::HttpMatcher, Body, IntoResponse, Request, Response};
2use rama_core::{layer::MapResponseLayer, service::BoxService, Context, Layer, Service};
3use std::future::Future;
4use std::{convert::Infallible, fmt};
5
6pub mod extract;
7
8pub(crate) struct Endpoint<State> {
9    pub(crate) matcher: HttpMatcher<State, Body>,
10    pub(crate) service: BoxService<State, Request, Response, Infallible>,
11}
12
13/// utility trait to accept multiple types as an endpoint service for [`super::WebService`]
14pub trait IntoEndpointService<State, T>: private::Sealed<T> {
15    /// convert the type into a [`rama_core::Service`].
16    fn into_endpoint_service(
17        self,
18    ) -> impl Service<State, Request, Response = Response, Error = Infallible>;
19}
20
21impl<State, S, R> IntoEndpointService<State, (State, R)> for S
22where
23    State: Clone + Send + Sync + 'static,
24    S: Service<State, Request, Response = R, Error = Infallible>,
25    R: IntoResponse + Send + Sync + 'static,
26{
27    fn into_endpoint_service(
28        self,
29    ) -> impl Service<State, Request, Response = Response, Error = Infallible> {
30        MapResponseLayer::new(R::into_response).layer(self)
31    }
32}
33
34impl<State, R> IntoEndpointService<State, ()> for R
35where
36    State: Clone + Send + Sync + 'static,
37    R: IntoResponse + Clone + Send + Sync + 'static,
38{
39    fn into_endpoint_service(
40        self,
41    ) -> impl Service<State, Request, Response = Response, Error = Infallible> {
42        StaticService(self)
43    }
44}
45
/// Service wrapping a fixed response value; its `Service` impl clones the
/// value for every request.
struct StaticService<R>(R);

impl<T: fmt::Debug> fmt::Debug for StaticService<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(inner) = self;
        f.debug_tuple("StaticService").field(inner).finish()
    }
}

impl<R: Clone> Clone for StaticService<R> {
    fn clone(&self) -> Self {
        let Self(inner) = self;
        Self(R::clone(inner))
    }
}
65
66impl<R, State> Service<State, Request> for StaticService<R>
67where
68    R: IntoResponse + Clone + Send + Sync + 'static,
69    State: Clone + Send + Sync + 'static,
70{
71    type Response = Response;
72    type Error = Infallible;
73
74    async fn serve(&self, _: Context<State>, _: Request) -> Result<Self::Response, Self::Error> {
75        Ok(self.0.clone().into_response())
76    }
77}
78
79mod service;
80#[doc(inline)]
81pub use service::EndpointServiceFn;
82
/// Adapter exposing an [`EndpointServiceFn`] as a [`Service`].
///
/// `S` and `T` (the type parameters of `EndpointServiceFn<S, T>`) only appear
/// inside a `PhantomData<fn(S, T)>`, so the wrapper stores no `S`/`T` value and
/// its `Send`/`Sync` status does not depend on them.
struct EndpointServiceFnWrapper<F, S, T> {
    /// The wrapped endpoint function.
    inner: F,
    _marker: std::marker::PhantomData<fn(S, T)>,
}
87
88impl<F: std::fmt::Debug, S, T> std::fmt::Debug for EndpointServiceFnWrapper<F, S, T> {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.debug_struct("EndpointServiceFnWrapper")
91            .field("inner", &self.inner)
92            .field(
93                "_marker",
94                &format_args!("{}", std::any::type_name::<fn(S, T) -> ()>()),
95            )
96            .finish()
97    }
98}
99
100impl<F, S, T> Clone for EndpointServiceFnWrapper<F, S, T>
101where
102    F: Clone,
103{
104    fn clone(&self) -> Self {
105        Self {
106            inner: self.inner.clone(),
107            _marker: std::marker::PhantomData,
108        }
109    }
110}
111
112impl<F, S, T> Service<S, Request> for EndpointServiceFnWrapper<F, S, T>
113where
114    F: EndpointServiceFn<S, T>,
115    S: Clone + Send + Sync + 'static,
116    T: Send + 'static,
117{
118    type Response = Response;
119    type Error = Infallible;
120
121    async fn serve(&self, ctx: Context<S>, req: Request) -> Result<Self::Response, Self::Error> {
122        Ok(self.inner.call(ctx, req).await)
123    }
124}
125
126impl<F, S, T> IntoEndpointService<S, (F, S, T)> for F
127where
128    F: EndpointServiceFn<S, T>,
129    S: Clone + Send + Sync + 'static,
130    T: Send + 'static,
131{
132    fn into_endpoint_service(
133        self,
134    ) -> impl Service<S, Request, Response = Response, Error = Infallible> {
135        EndpointServiceFnWrapper {
136            inner: self,
137            _marker: std::marker::PhantomData,
138        }
139    }
140}
141
142mod private {
143    use super::*;
144
145    pub trait Sealed<T> {}
146
147    impl<State, S, R> Sealed<(State, R)> for S
148    where
149        State: Clone + Send + Sync + 'static,
150        S: Service<State, Request, Response = R, Error = Infallible>,
151        R: IntoResponse + Send + Sync + 'static,
152    {
153    }
154
155    impl<State, F, Fut, R> Sealed<(State, F, Context<State>, Fut, R)> for F
156    where
157        State: Clone + Send + Sync + 'static,
158        F: Fn(Context<State>) -> Fut + Send + Sync + 'static,
159        Fut: Future<Output = R> + Send + 'static,
160        R: IntoResponse + Send + Sync + 'static,
161    {
162    }
163
164    impl<State, F, Fut, R> Sealed<(State, F, Context<State>, Request, Fut, R)> for F
165    where
166        State: Clone + Send + Sync + 'static,
167        F: Fn(Context<State>, Request) -> Fut + Send + Sync + 'static,
168        Fut: Future<Output = R> + Send + 'static,
169        R: IntoResponse + Send + Sync + 'static,
170    {
171    }
172
173    impl<R> Sealed<()> for R where R: IntoResponse + Send + Sync + 'static {}
174
175    impl<F, S, T> Sealed<(F, S, T)> for F where F: EndpointServiceFn<S, T> {}
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{dep::http_body_util::BodyExt, Body, Method, Request, StatusCode};
    use extract::*;

    /// Compiles only if `I` can be turned into an endpoint service for the `()`
    /// state; there is nothing to check at runtime.
    fn assert_into_endpoint_service<T, I>(_: I)
    where
        I: IntoEndpointService<(), T>,
    {
    }

    /// A body-less request for `uri`, shared by the serving tests.
    fn empty_request(uri: &str) -> Request {
        Request::builder()
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    #[test]
    fn test_into_endpoint_service_static() {
        assert_into_endpoint_service(StatusCode::OK);
        assert_into_endpoint_service("hello");
        assert_into_endpoint_service(String::from("hello"));
    }

    #[tokio::test]
    async fn test_into_endpoint_service_impl() {
        #[derive(Debug, Clone)]
        struct OkService;

        impl<State> Service<State, Request> for OkService
        where
            State: Clone + Send + Sync + 'static,
        {
            type Response = StatusCode;
            type Error = Infallible;

            async fn serve(
                &self,
                _ctx: Context<State>,
                _req: Request,
            ) -> Result<Self::Response, Self::Error> {
                Ok(StatusCode::OK)
            }
        }

        let service = OkService;
        let status = service
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(status, StatusCode::OK);

        assert_into_endpoint_service(service);
    }

    #[test]
    fn test_into_endpoint_service_fn_no_param() {
        assert_into_endpoint_service(|| async { StatusCode::OK });
        assert_into_endpoint_service(|| async { "hello" });
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_no_param() {
        let handler = || async { StatusCode::OK };
        let endpoint = handler.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_single_param_request() {
        let handler = |req: Request| async move { req.uri().to_string() };
        let endpoint = handler.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "http://example.com/");
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_single_param_host() {
        let handler = |Host(host): Host| async move { host.to_string() };
        let endpoint = handler.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "example.com");
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_multi_param_host() {
        #[derive(Debug, Clone, serde::Deserialize)]
        struct Params {
            foo: String,
        }

        let web_service = crate::service::web::WebService::default().get(
            "/:foo/bar",
            |Host(host): Host, Path(params): Path<Params>| async move {
                format!("{host} => {}", params.foo)
            },
        );
        let endpoint = web_service.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com/42/bar"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "example.com => 42");
    }

    #[test]
    fn test_into_endpoint_service_fn_single_param() {
        #[derive(Debug, Clone, serde::Deserialize)]
        struct Params {
            foo: String,
        }

        assert_into_endpoint_service(|_path: Path<Params>| async { StatusCode::OK });
        assert_into_endpoint_service(|Path(params): Path<Params>| async move { params.foo });
        assert_into_endpoint_service(|Query(query): Query<Params>| async move { query.foo });
        assert_into_endpoint_service(|method: Method| async move { method.to_string() });
        assert_into_endpoint_service(|req: Request| async move { req.uri().to_string() });
        assert_into_endpoint_service(|_host: Host| async { StatusCode::OK });
        assert_into_endpoint_service(|Host(_host): Host| async { StatusCode::OK });
    }
}