1use crate::{matcher::HttpMatcher, Body, IntoResponse, Request, Response};
2use rama_core::{layer::MapResponseLayer, service::BoxService, Context, Layer, Service};
3use std::future::Future;
4use std::{convert::Infallible, fmt};
5
6pub mod extract;
7
/// A routable endpoint: a request matcher paired with a type-erased handler.
///
/// NOTE(review): the router that consults `matcher` before dispatching to
/// `service` lives outside this file — confirm against its call sites.
pub(crate) struct Endpoint<State> {
    /// Predicate over incoming requests (with `Body` payloads) for this endpoint.
    pub(crate) matcher: HttpMatcher<State, Body>,
    /// Boxed handler producing a [`Response`]; its error type is `Infallible`,
    /// so serving never yields an error value.
    pub(crate) service: BoxService<State, Request, Response, Infallible>,
}
12
/// Conversion into an endpoint [`Service`]: one that accepts a [`Request`],
/// yields a [`Response`] and has [`Infallible`] as its error type.
///
/// Implemented in this module for:
/// - any [`Service`] with an `Infallible` error whose response implements
///   [`IntoResponse`];
/// - any `Clone` value implementing [`IntoResponse`], which is then returned
///   for every request;
/// - any [`EndpointServiceFn`].
///
/// `T` only keeps these blanket implementations from overlapping and is
/// normally inferred. The trait is sealed through a private supertrait, so it
/// cannot be implemented outside this crate.
pub trait IntoEndpointService<State, T>: private::Sealed<T> {
    /// Consumes `self` and returns the endpoint service.
    fn into_endpoint_service(
        self,
    ) -> impl Service<State, Request, Response = Response, Error = Infallible>;
}
20
21impl<State, S, R> IntoEndpointService<State, (State, R)> for S
22where
23 State: Clone + Send + Sync + 'static,
24 S: Service<State, Request, Response = R, Error = Infallible>,
25 R: IntoResponse + Send + Sync + 'static,
26{
27 fn into_endpoint_service(
28 self,
29 ) -> impl Service<State, Request, Response = Response, Error = Infallible> {
30 MapResponseLayer::new(R::into_response).layer(self)
31 }
32}
33
34impl<State, R> IntoEndpointService<State, ()> for R
35where
36 State: Clone + Send + Sync + 'static,
37 R: IntoResponse + Clone + Send + Sync + 'static,
38{
39 fn into_endpoint_service(
40 self,
41 ) -> impl Service<State, Request, Response = Response, Error = Infallible> {
42 StaticService(self)
43 }
44}
45
/// Service wrapping a fixed value that is cloned for every request it serves.
///
/// `Debug` and `Clone` are derived. The derives add exactly the `R: Debug` and
/// `R: Clone` bounds a hand-written impl would need. `Debug` renders as
/// `StaticService(<value>)`.
#[derive(Debug, Clone)]
struct StaticService<R>(R);
65
66impl<R, State> Service<State, Request> for StaticService<R>
67where
68 R: IntoResponse + Clone + Send + Sync + 'static,
69 State: Clone + Send + Sync + 'static,
70{
71 type Response = Response;
72 type Error = Infallible;
73
74 async fn serve(&self, _: Context<State>, _: Request) -> Result<Self::Response, Self::Error> {
75 Ok(self.0.clone().into_response())
76 }
77}
78
79mod service;
80#[doc(inline)]
81pub use service::EndpointServiceFn;
82
/// Adapter that lets an [`EndpointServiceFn`] be used as a [`Service`].
///
/// `S` and `T` mirror the parameters of `EndpointServiceFn<S, T>` and appear
/// only in the phantom marker. A `fn(S, T)` pointer type is used there because:
/// - fn pointers are always `Send + Sync`, so the marker never restricts the
///   wrapper's auto traits;
/// - the wrapper does not claim to own values of `S` or `T`.
struct EndpointServiceFnWrapper<F, S, T> {
    inner: F,
    _marker: std::marker::PhantomData<fn(S, T)>,
}

// Hand-written instead of derived: `#[derive(Debug)]` would add unnecessary
// `S: Debug` and `T: Debug` bounds. The marker is rendered as its type name.
impl<F: fmt::Debug, S, T> fmt::Debug for EndpointServiceFnWrapper<F, S, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("EndpointServiceFnWrapper")
            .field("inner", &self.inner)
            .field(
                "_marker",
                &format_args!("{}", std::any::type_name::<fn(S, T)>()),
            )
            .finish()
    }
}

// Hand-written instead of derived: `#[derive(Clone)]` would require
// `S: Clone` and `T: Clone`, yet only `inner` is actually cloned.
impl<F, S, T> Clone for EndpointServiceFnWrapper<F, S, T>
where
    F: Clone,
{
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            _marker: std::marker::PhantomData,
        }
    }
}
111
112impl<F, S, T> Service<S, Request> for EndpointServiceFnWrapper<F, S, T>
113where
114 F: EndpointServiceFn<S, T>,
115 S: Clone + Send + Sync + 'static,
116 T: Send + 'static,
117{
118 type Response = Response;
119 type Error = Infallible;
120
121 async fn serve(&self, ctx: Context<S>, req: Request) -> Result<Self::Response, Self::Error> {
122 Ok(self.inner.call(ctx, req).await)
123 }
124}
125
126impl<F, S, T> IntoEndpointService<S, (F, S, T)> for F
127where
128 F: EndpointServiceFn<S, T>,
129 S: Clone + Send + Sync + 'static,
130 T: Send + 'static,
131{
132 fn into_endpoint_service(
133 self,
134 ) -> impl Service<S, Request, Response = Response, Error = Infallible> {
135 EndpointServiceFnWrapper {
136 inner: self,
137 _marker: std::marker::PhantomData,
138 }
139 }
140}
141
/// Sealing machinery for `IntoEndpointService`. `Sealed` lives in a private
/// module, so code outside this module and its submodules cannot name it and
/// therefore cannot implement `IntoEndpointService`.
mod private {
    use super::*;

    /// Supertrait of `IntoEndpointService`; `T` mirrors that trait's marker parameter.
    pub trait Sealed<T> {}

    // Counterpart of the `IntoEndpointService<State, (State, R)>` impl for services.
    impl<State, S, R> Sealed<(State, R)> for S
    where
        State: Clone + Send + Sync + 'static,
        S: Service<State, Request, Response = R, Error = Infallible>,
        R: IntoResponse + Send + Sync + 'static,
    {
    }

    // NOTE(review): no `IntoEndpointService` impl in this file uses the marker
    // `(State, F, Context<State>, Fut, R)`. Async functions are covered by the
    // `(F, S, T)` / `EndpointServiceFn` impl instead. This looks like a leftover;
    // confirm no submodule relies on it before removing.
    impl<State, F, Fut, R> Sealed<(State, F, Context<State>, Fut, R)> for F
    where
        State: Clone + Send + Sync + 'static,
        F: Fn(Context<State>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoResponse + Send + Sync + 'static,
    {
    }

    // NOTE(review): likewise unused by the `IntoEndpointService` impls in this file.
    impl<State, F, Fut, R> Sealed<(State, F, Context<State>, Request, Fut, R)> for F
    where
        State: Clone + Send + Sync + 'static,
        F: Fn(Context<State>, Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoResponse + Send + Sync + 'static,
    {
    }

    // Counterpart of the static-response impl. It omits the `Clone` bound that
    // `IntoEndpointService<State, ()>` requires; that impl still enforces it.
    impl<R> Sealed<()> for R where R: IntoResponse + Send + Sync + 'static {}

    // Counterpart of the `IntoEndpointService<S, (F, S, T)>` impl for endpoint functions.
    impl<F, S, T> Sealed<(F, S, T)> for F where F: EndpointServiceFn<S, T> {}
}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{dep::http_body_util::BodyExt, Body, Method, Request, StatusCode};
    use extract::*;

    /// Compiles only when `I` implements `IntoEndpointService<(), T>` for some `T`;
    /// it does nothing at runtime.
    fn assert_into_endpoint_service<T, I>(_: I)
    where
        I: IntoEndpointService<(), T>,
    {
    }

    /// Builds the body-less request sent by every runtime test below.
    fn empty_request(uri: &str) -> Request {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    #[test]
    fn test_into_endpoint_service_static() {
        assert_into_endpoint_service(StatusCode::OK);
        assert_into_endpoint_service("hello");
        assert_into_endpoint_service(String::from("hello"));
    }

    #[tokio::test]
    async fn test_into_endpoint_service_impl() {
        #[derive(Debug, Clone)]
        struct OkService;

        impl<State> Service<State, Request> for OkService
        where
            State: Clone + Send + Sync + 'static,
        {
            type Response = StatusCode;
            type Error = Infallible;

            async fn serve(
                &self,
                _ctx: Context<State>,
                _req: Request,
            ) -> Result<Self::Response, Self::Error> {
                Ok(StatusCode::OK)
            }
        }

        let ok_service = OkService;
        let status = ok_service
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(status, StatusCode::OK);

        // The very same value must also be accepted as an endpoint.
        assert_into_endpoint_service(ok_service);
    }

    #[test]
    fn test_into_endpoint_service_fn_no_param() {
        assert_into_endpoint_service(|| async { StatusCode::OK });
        assert_into_endpoint_service(|| async { "hello" });
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_no_param() {
        let handler = || async { StatusCode::OK };
        let endpoint = handler.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_single_param_request() {
        let handler = |req: Request| async move { req.uri().to_string() };
        let endpoint = handler.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let payload = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(payload, "http://example.com/");
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_single_param_host() {
        let handler = |Host(host): Host| async move { host.to_string() };
        let endpoint = handler.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let payload = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(payload, "example.com");
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_multi_param_host() {
        #[derive(Debug, Clone, serde::Deserialize)]
        struct Params {
            foo: String,
        }

        let web = crate::service::web::WebService::default().get(
            "/:foo/bar",
            |Host(host): Host, Path(params): Path<Params>| async move {
                format!("{} => {}", host, params.foo)
            },
        );
        let endpoint = web.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com/42/bar"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let payload = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(payload, "example.com => 42");
    }

    #[test]
    fn test_into_endpoint_service_fn_single_param() {
        #[derive(Debug, Clone, serde::Deserialize)]
        struct Params {
            foo: String,
        }

        assert_into_endpoint_service(|_path: Path<Params>| async { StatusCode::OK });
        assert_into_endpoint_service(|Path(params): Path<Params>| async move { params.foo });
        assert_into_endpoint_service(|Query(query): Query<Params>| async move { query.foo });
        assert_into_endpoint_service(|method: Method| async move { method.to_string() });
        assert_into_endpoint_service(|req: Request| async move { req.uri().to_string() });
        assert_into_endpoint_service(|_host: Host| async { StatusCode::OK });
        assert_into_endpoint_service(|Host(_host): Host| async { StatusCode::OK });
    }
}