rama_http/service/web/endpoint/
mod.rs

1use crate::{Body, Request, Response, matcher::HttpMatcher};
2use rama_core::{Context, Layer, Service, layer::MapResponseLayer, service::BoxService};
3use std::{convert::Infallible, fmt};
4
5pub mod extract;
6pub mod response;
7
8use response::IntoResponse;
9
10pub(crate) struct Endpoint<State> {
11    pub(crate) matcher: HttpMatcher<State, Body>,
12    pub(crate) service: BoxService<State, Request, Response, Infallible>,
13}
14
15/// utility trait to accept multiple types as an endpoint service for [`super::WebService`]
16pub trait IntoEndpointService<State, T>: private::Sealed<T> {
17    /// convert the type into a [`rama_core::Service`].
18    fn into_endpoint_service(
19        self,
20    ) -> impl Service<State, Request, Response = Response, Error = Infallible>;
21}
22
23impl<State, S, R> IntoEndpointService<State, (State, R)> for S
24where
25    State: Clone + Send + Sync + 'static,
26    S: Service<State, Request, Response = R, Error = Infallible>,
27    R: IntoResponse + Send + Sync + 'static,
28{
29    fn into_endpoint_service(
30        self,
31    ) -> impl Service<State, Request, Response = Response, Error = Infallible> {
32        MapResponseLayer::new(R::into_response).into_layer(self)
33    }
34}
35
36impl<State, R> IntoEndpointService<State, ()> for R
37where
38    State: Clone + Send + Sync + 'static,
39    R: IntoResponse + Clone + Send + Sync + 'static,
40{
41    fn into_endpoint_service(
42        self,
43    ) -> impl Service<State, Request, Response = Response, Error = Infallible> {
44        StaticService(self)
45    }
46}
47
/// A [`Service`] that ignores the incoming request and always answers with a
/// clone of one pre-defined response value, converted via [`IntoResponse`].
pub struct StaticService<R>(R);
50
51impl<R> StaticService<R>
52where
53    R: IntoResponse + Clone + Send + Sync + 'static,
54{
55    /// Create a new [`StaticService`] with the given response.
56    pub fn new(response: R) -> Self {
57        Self(response)
58    }
59}
60
61impl<T> fmt::Debug for StaticService<T>
62where
63    T: fmt::Debug,
64{
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        f.debug_tuple("StaticService").field(&self.0).finish()
67    }
68}
69
70impl<R> Clone for StaticService<R>
71where
72    R: Clone,
73{
74    fn clone(&self) -> Self {
75        Self(self.0.clone())
76    }
77}
78
79impl<R, State> Service<State, Request> for StaticService<R>
80where
81    R: IntoResponse + Clone + Send + Sync + 'static,
82    State: Clone + Send + Sync + 'static,
83{
84    type Response = Response;
85    type Error = Infallible;
86
87    async fn serve(&self, _: Context<State>, _: Request) -> Result<Self::Response, Self::Error> {
88        Ok(self.0.clone().into_response())
89    }
90}
91
92mod service;
93#[doc(inline)]
94pub use service::EndpointServiceFn;
95
/// Adapter that exposes an [`EndpointServiceFn`] handler as a [`Service`].
///
/// `S` (the state type) and `T` (the handler's argument marker) are tracked only
/// at the type level. They are stored as a `fn(S, T)` pointer inside the
/// `PhantomData`, so the wrapper does not own an `S` or `T`. Its `Send`/`Sync`
/// therefore depend only on `F`.
struct EndpointServiceFnWrapper<F, S, T> {
    /// The wrapped handler.
    inner: F,
    _marker: std::marker::PhantomData<fn(S, T)>,
}
100
101impl<F: std::fmt::Debug, S, T> std::fmt::Debug for EndpointServiceFnWrapper<F, S, T> {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        f.debug_struct("EndpointServiceFnWrapper")
104            .field("inner", &self.inner)
105            .field(
106                "_marker",
107                &format_args!("{}", std::any::type_name::<fn(S, T) -> ()>()),
108            )
109            .finish()
110    }
111}
112
113impl<F, S, T> Clone for EndpointServiceFnWrapper<F, S, T>
114where
115    F: Clone,
116{
117    fn clone(&self) -> Self {
118        Self {
119            inner: self.inner.clone(),
120            _marker: std::marker::PhantomData,
121        }
122    }
123}
124
125impl<F, S, T> Service<S, Request> for EndpointServiceFnWrapper<F, S, T>
126where
127    F: EndpointServiceFn<S, T>,
128    S: Clone + Send + Sync + 'static,
129    T: Send + 'static,
130{
131    type Response = Response;
132    type Error = Infallible;
133
134    async fn serve(&self, ctx: Context<S>, req: Request) -> Result<Self::Response, Self::Error> {
135        Ok(self.inner.call(ctx, req).await)
136    }
137}
138
139impl<F, S, T> IntoEndpointService<S, (F, S, T)> for F
140where
141    F: EndpointServiceFn<S, T>,
142    S: Clone + Send + Sync + 'static,
143    T: Send + 'static,
144{
145    fn into_endpoint_service(
146        self,
147    ) -> impl Service<S, Request, Response = Response, Error = Infallible> {
148        EndpointServiceFnWrapper {
149            inner: self,
150            _marker: std::marker::PhantomData,
151        }
152    }
153}
154
155mod private {
156    use super::*;
157
158    pub trait Sealed<T> {}
159
160    impl<State, S, R> Sealed<(State, R)> for S
161    where
162        State: Clone + Send + Sync + 'static,
163        S: Service<State, Request, Response = R, Error = Infallible>,
164        R: IntoResponse + Send + Sync + 'static,
165    {
166    }
167
168    impl<State, F, Fut, R> Sealed<(State, F, Context<State>, Fut, R)> for F
169    where
170        State: Clone + Send + Sync + 'static,
171        F: Fn(Context<State>) -> Fut + Send + Sync + 'static,
172        Fut: Future<Output = R> + Send + 'static,
173        R: IntoResponse + Send + Sync + 'static,
174    {
175    }
176
177    impl<State, F, Fut, R> Sealed<(State, F, Context<State>, Request, Fut, R)> for F
178    where
179        State: Clone + Send + Sync + 'static,
180        F: Fn(Context<State>, Request) -> Fut + Send + Sync + 'static,
181        Fut: Future<Output = R> + Send + 'static,
182        R: IntoResponse + Send + Sync + 'static,
183    {
184    }
185
186    impl<R> Sealed<()> for R where R: IntoResponse + Send + Sync + 'static {}
187
188    impl<F, S, T> Sealed<(F, S, T)> for F where F: EndpointServiceFn<S, T> {}
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Body, Method, Request, StatusCode, dep::http_body_util::BodyExt};
    use extract::*;

    /// Compile-time check: `I` must be usable as an endpoint for the `()` state.
    fn assert_into_endpoint_service<T, I>(_: I)
    where
        I: IntoEndpointService<(), T>,
    {
    }

    /// Builds a request with an empty body targeting `uri`.
    fn request_to(uri: &str) -> Request {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    #[test]
    fn test_into_endpoint_service_static() {
        assert_into_endpoint_service(StatusCode::OK);
        assert_into_endpoint_service("hello");
        assert_into_endpoint_service("hello".to_owned());
    }

    #[tokio::test]
    async fn test_into_endpoint_service_impl() {
        #[derive(Debug, Clone)]
        struct OkService;

        impl<State> Service<State, Request> for OkService
        where
            State: Clone + Send + Sync + 'static,
        {
            type Response = StatusCode;
            type Error = Infallible;

            async fn serve(
                &self,
                _ctx: Context<State>,
                _req: Request,
            ) -> Result<Self::Response, Self::Error> {
                Ok(StatusCode::OK)
            }
        }

        let svc = OkService;
        let status = svc
            .serve(Context::default(), request_to("http://example.com"))
            .await
            .unwrap();
        assert_eq!(status, StatusCode::OK);

        assert_into_endpoint_service(svc)
    }

    #[test]
    fn test_into_endpoint_service_fn_no_param() {
        assert_into_endpoint_service(async || StatusCode::OK);
        assert_into_endpoint_service(async || "hello");
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_no_param() {
        let handler = async || StatusCode::OK;
        let svc = handler.into_endpoint_service();

        let resp = svc
            .serve(Context::default(), request_to("http://example.com"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_single_param_request() {
        let handler = async |req: Request| req.uri().to_string();
        let svc = handler.into_endpoint_service();

        let resp = svc
            .serve(Context::default(), request_to("http://example.com"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "http://example.com/")
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_single_param_host() {
        let handler = async |Host(host): Host| host.to_string();
        let svc = handler.into_endpoint_service();

        let resp = svc
            .serve(Context::default(), request_to("http://example.com"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "example.com")
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_multi_param_host() {
        #[derive(Debug, Clone, serde::Deserialize)]
        struct Params {
            foo: String,
        }

        let web = crate::service::web::WebService::default().get(
            "/:foo/bar",
            async |Host(host): Host, Path(params): Path<Params>| {
                format!("{} => {}", host, params.foo)
            },
        );
        let svc = web.into_endpoint_service();

        let resp = svc
            .serve(Context::default(), request_to("http://example.com/42/bar"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "example.com => 42")
    }

    #[test]
    fn test_into_endpoint_service_fn_single_param() {
        #[derive(Debug, Clone, serde::Deserialize)]
        struct Params {
            foo: String,
        }

        assert_into_endpoint_service(async |_path: Path<Params>| StatusCode::OK);
        assert_into_endpoint_service(async |Path(params): Path<Params>| params.foo);
        assert_into_endpoint_service(async |Query(query): Query<Params>| query.foo);
        assert_into_endpoint_service(async |method: Method| method.to_string());
        assert_into_endpoint_service(async |req: Request| req.uri().to_string());
        assert_into_endpoint_service(async |_host: Host| StatusCode::OK);
        assert_into_endpoint_service(async |Host(_host): Host| StatusCode::OK);
    }
}