1use crate::{Body, Request, Response, matcher::HttpMatcher};
2use rama_core::{Context, Layer, Service, layer::MapResponseLayer, service::BoxService};
3use std::{convert::Infallible, fmt};
4
5pub mod extract;
6pub mod response;
7
8use response::IntoResponse;
9
10pub(crate) struct Endpoint<State> {
11 pub(crate) matcher: HttpMatcher<State, Body>,
12 pub(crate) service: BoxService<State, Request, Response, Infallible>,
13}
14
15pub trait IntoEndpointService<State, T>: private::Sealed<T> {
17 fn into_endpoint_service(
19 self,
20 ) -> impl Service<State, Request, Response = Response, Error = Infallible>;
21}
22
23impl<State, S, R> IntoEndpointService<State, (State, R)> for S
24where
25 State: Clone + Send + Sync + 'static,
26 S: Service<State, Request, Response = R, Error = Infallible>,
27 R: IntoResponse + Send + Sync + 'static,
28{
29 fn into_endpoint_service(
30 self,
31 ) -> impl Service<State, Request, Response = Response, Error = Infallible> {
32 MapResponseLayer::new(R::into_response).into_layer(self)
33 }
34}
35
36impl<State, R> IntoEndpointService<State, ()> for R
37where
38 State: Clone + Send + Sync + 'static,
39 R: IntoResponse + Clone + Send + Sync + 'static,
40{
41 fn into_endpoint_service(
42 self,
43 ) -> impl Service<State, Request, Response = Response, Error = Infallible> {
44 StaticService(self)
45 }
46}
47
/// Service wrapper around a single value of type `R`.
///
/// When `R` implements `IntoResponse + Clone`, the `Service` implementation
/// answers every request with a clone of the wrapped value converted into a
/// response.
pub struct StaticService<R>(R);
50
51impl<R> StaticService<R>
52where
53 R: IntoResponse + Clone + Send + Sync + 'static,
54{
55 pub fn new(response: R) -> Self {
57 Self(response)
58 }
59}
60
61impl<T> fmt::Debug for StaticService<T>
62where
63 T: fmt::Debug,
64{
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 f.debug_tuple("StaticService").field(&self.0).finish()
67 }
68}
69
70impl<R> Clone for StaticService<R>
71where
72 R: Clone,
73{
74 fn clone(&self) -> Self {
75 Self(self.0.clone())
76 }
77}
78
79impl<R, State> Service<State, Request> for StaticService<R>
80where
81 R: IntoResponse + Clone + Send + Sync + 'static,
82 State: Clone + Send + Sync + 'static,
83{
84 type Response = Response;
85 type Error = Infallible;
86
87 async fn serve(&self, _: Context<State>, _: Request) -> Result<Self::Response, Self::Error> {
88 Ok(self.0.clone().into_response())
89 }
90}
91
92mod service;
93#[doc(inline)]
94pub use service::EndpointServiceFn;
95
/// Adapter that lets an [`EndpointServiceFn`] be used as a `Service`.
///
/// `S` and `T` are not stored; they are only recorded at the type level.
struct EndpointServiceFnWrapper<F, S, T> {
    /// The wrapped endpoint function.
    inner: F,
    /// Ties `S` and `T` to the wrapper through a function-pointer type
    /// rather than by holding values of those types.
    _marker: std::marker::PhantomData<fn(S, T)>,
}
100
101impl<F: std::fmt::Debug, S, T> std::fmt::Debug for EndpointServiceFnWrapper<F, S, T> {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 f.debug_struct("EndpointServiceFnWrapper")
104 .field("inner", &self.inner)
105 .field(
106 "_marker",
107 &format_args!("{}", std::any::type_name::<fn(S, T) -> ()>()),
108 )
109 .finish()
110 }
111}
112
113impl<F, S, T> Clone for EndpointServiceFnWrapper<F, S, T>
114where
115 F: Clone,
116{
117 fn clone(&self) -> Self {
118 Self {
119 inner: self.inner.clone(),
120 _marker: std::marker::PhantomData,
121 }
122 }
123}
124
125impl<F, S, T> Service<S, Request> for EndpointServiceFnWrapper<F, S, T>
126where
127 F: EndpointServiceFn<S, T>,
128 S: Clone + Send + Sync + 'static,
129 T: Send + 'static,
130{
131 type Response = Response;
132 type Error = Infallible;
133
134 async fn serve(&self, ctx: Context<S>, req: Request) -> Result<Self::Response, Self::Error> {
135 Ok(self.inner.call(ctx, req).await)
136 }
137}
138
139impl<F, S, T> IntoEndpointService<S, (F, S, T)> for F
140where
141 F: EndpointServiceFn<S, T>,
142 S: Clone + Send + Sync + 'static,
143 T: Send + 'static,
144{
145 fn into_endpoint_service(
146 self,
147 ) -> impl Service<S, Request, Response = Response, Error = Infallible> {
148 EndpointServiceFnWrapper {
149 inner: self,
150 _marker: std::marker::PhantomData,
151 }
152 }
153}
154
155mod private {
156 use super::*;
157
158 pub trait Sealed<T> {}
159
160 impl<State, S, R> Sealed<(State, R)> for S
161 where
162 State: Clone + Send + Sync + 'static,
163 S: Service<State, Request, Response = R, Error = Infallible>,
164 R: IntoResponse + Send + Sync + 'static,
165 {
166 }
167
168 impl<State, F, Fut, R> Sealed<(State, F, Context<State>, Fut, R)> for F
169 where
170 State: Clone + Send + Sync + 'static,
171 F: Fn(Context<State>) -> Fut + Send + Sync + 'static,
172 Fut: Future<Output = R> + Send + 'static,
173 R: IntoResponse + Send + Sync + 'static,
174 {
175 }
176
177 impl<State, F, Fut, R> Sealed<(State, F, Context<State>, Request, Fut, R)> for F
178 where
179 State: Clone + Send + Sync + 'static,
180 F: Fn(Context<State>, Request) -> Fut + Send + Sync + 'static,
181 Fut: Future<Output = R> + Send + 'static,
182 R: IntoResponse + Send + Sync + 'static,
183 {
184 }
185
186 impl<R> Sealed<()> for R where R: IntoResponse + Send + Sync + 'static {}
187
188 impl<F, S, T> Sealed<(F, S, T)> for F where F: EndpointServiceFn<S, T> {}
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Body, Method, Request, StatusCode, dep::http_body_util::BodyExt};
    use extract::*;

    /// Compile-time check only: accepts any value that implements
    /// `IntoEndpointService` for the `()` state.
    fn assert_into_endpoint_service<T, I>(_: I)
    where
        I: IntoEndpointService<(), T>,
    {
    }

    /// Builds a request for `uri` with an empty body.
    fn empty_request(uri: &str) -> Request {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    #[test]
    fn test_into_endpoint_service_static() {
        assert_into_endpoint_service(StatusCode::OK);
        assert_into_endpoint_service("hello");
        assert_into_endpoint_service("hello".to_owned());
    }

    #[tokio::test]
    async fn test_into_endpoint_service_impl() {
        #[derive(Debug, Clone)]
        struct OkService;

        impl<State> Service<State, Request> for OkService
        where
            State: Clone + Send + Sync + 'static,
        {
            type Response = StatusCode;
            type Error = Infallible;

            async fn serve(
                &self,
                _ctx: Context<State>,
                _req: Request,
            ) -> Result<Self::Response, Self::Error> {
                Ok(StatusCode::OK)
            }
        }

        let service = OkService;
        let status = service
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(status, StatusCode::OK);

        assert_into_endpoint_service(service)
    }

    #[test]
    fn test_into_endpoint_service_fn_no_param() {
        assert_into_endpoint_service(async || StatusCode::OK);
        assert_into_endpoint_service(async || "hello");
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_no_param() {
        let handler = async || StatusCode::OK;
        let endpoint = handler.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_single_param_request() {
        let handler = async |req: Request| req.uri().to_string();
        let endpoint = handler.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "http://example.com/")
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_single_param_host() {
        let handler = async |Host(host): Host| host.to_string();
        let endpoint = handler.into_endpoint_service();

        let response = endpoint
            .serve(Context::default(), empty_request("http://example.com"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "example.com")
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_multi_param_host() {
        #[derive(Debug, Clone, serde::Deserialize)]
        struct Params {
            foo: String,
        }

        let router = crate::service::web::WebService::default().get(
            "/:foo/bar",
            async |Host(host): Host, Path(params): Path<Params>| {
                format!("{} => {}", host, params.foo)
            },
        );
        let endpoint = router.into_endpoint_service();

        let response = endpoint
            .serve(
                Context::default(),
                empty_request("http://example.com/42/bar"),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "example.com => 42")
    }

    #[test]
    fn test_into_endpoint_service_fn_single_param() {
        #[derive(Debug, Clone, serde::Deserialize)]
        struct Params {
            foo: String,
        }

        assert_into_endpoint_service(async |_path: Path<Params>| StatusCode::OK);
        assert_into_endpoint_service(async |Path(params): Path<Params>| params.foo);
        assert_into_endpoint_service(async |Query(query): Query<Params>| query.foo);
        assert_into_endpoint_service(async |method: Method| method.to_string());
        assert_into_endpoint_service(async |req: Request| req.uri().to_string());
        assert_into_endpoint_service(async |_host: Host| StatusCode::OK);
        assert_into_endpoint_service(async |Host(_host): Host| StatusCode::OK);
    }
}