1use crate::{Body, IntoResponse, Request, Response, matcher::HttpMatcher};
2use rama_core::{Context, Layer, Service, layer::MapResponseLayer, service::BoxService};
3use std::{convert::Infallible, fmt};
4
5pub mod extract;
6
7pub(crate) struct Endpoint<State> {
8 pub(crate) matcher: HttpMatcher<State, Body>,
9 pub(crate) service: BoxService<State, Request, Response, Infallible>,
10}
11
12pub trait IntoEndpointService<State, T>: private::Sealed<T> {
14 fn into_endpoint_service(
16 self,
17 ) -> impl Service<State, Request, Response = Response, Error = Infallible>;
18}
19
20impl<State, S, R> IntoEndpointService<State, (State, R)> for S
21where
22 State: Clone + Send + Sync + 'static,
23 S: Service<State, Request, Response = R, Error = Infallible>,
24 R: IntoResponse + Send + Sync + 'static,
25{
26 fn into_endpoint_service(
27 self,
28 ) -> impl Service<State, Request, Response = Response, Error = Infallible> {
29 MapResponseLayer::new(R::into_response).into_layer(self)
30 }
31}
32
33impl<State, R> IntoEndpointService<State, ()> for R
34where
35 State: Clone + Send + Sync + 'static,
36 R: IntoResponse + Clone + Send + Sync + 'static,
37{
38 fn into_endpoint_service(
39 self,
40 ) -> impl Service<State, Request, Response = Response, Error = Infallible> {
41 StaticService(self)
42 }
43}
44
/// Service wrapper around a single stored value of type `R`.
///
/// The `Service` implementation in this file answers every request with a
/// clone of the stored value converted through `IntoResponse::into_response`.
pub struct StaticService<R>(R);
47
48impl<R> StaticService<R>
49where
50 R: IntoResponse + Clone + Send + Sync + 'static,
51{
52 pub fn new(response: R) -> Self {
54 Self(response)
55 }
56}
57
58impl<T> fmt::Debug for StaticService<T>
59where
60 T: fmt::Debug,
61{
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63 f.debug_tuple("StaticService").field(&self.0).finish()
64 }
65}
66
67impl<R> Clone for StaticService<R>
68where
69 R: Clone,
70{
71 fn clone(&self) -> Self {
72 Self(self.0.clone())
73 }
74}
75
76impl<R, State> Service<State, Request> for StaticService<R>
77where
78 R: IntoResponse + Clone + Send + Sync + 'static,
79 State: Clone + Send + Sync + 'static,
80{
81 type Response = Response;
82 type Error = Infallible;
83
84 async fn serve(&self, _: Context<State>, _: Request) -> Result<Self::Response, Self::Error> {
85 Ok(self.0.clone().into_response())
86 }
87}
88
89mod service;
90#[doc(inline)]
91pub use service::EndpointServiceFn;
92
/// Adapter that stores a value `F` together with the type parameters `S` and `T`.
///
/// The `Service` implementation in this file requires `F: EndpointServiceFn<S, T>`.
struct EndpointServiceFnWrapper<F, S, T> {
    /// The wrapped value.
    inner: F,
    /// Ties `S` and `T` to the type without storing a value of either.
    _marker: std::marker::PhantomData<fn(S, T)>,
}
97
98impl<F: std::fmt::Debug, S, T> std::fmt::Debug for EndpointServiceFnWrapper<F, S, T> {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct("EndpointServiceFnWrapper")
101 .field("inner", &self.inner)
102 .field(
103 "_marker",
104 &format_args!("{}", std::any::type_name::<fn(S, T) -> ()>()),
105 )
106 .finish()
107 }
108}
109
110impl<F, S, T> Clone for EndpointServiceFnWrapper<F, S, T>
111where
112 F: Clone,
113{
114 fn clone(&self) -> Self {
115 Self {
116 inner: self.inner.clone(),
117 _marker: std::marker::PhantomData,
118 }
119 }
120}
121
122impl<F, S, T> Service<S, Request> for EndpointServiceFnWrapper<F, S, T>
123where
124 F: EndpointServiceFn<S, T>,
125 S: Clone + Send + Sync + 'static,
126 T: Send + 'static,
127{
128 type Response = Response;
129 type Error = Infallible;
130
131 async fn serve(&self, ctx: Context<S>, req: Request) -> Result<Self::Response, Self::Error> {
132 Ok(self.inner.call(ctx, req).await)
133 }
134}
135
136impl<F, S, T> IntoEndpointService<S, (F, S, T)> for F
137where
138 F: EndpointServiceFn<S, T>,
139 S: Clone + Send + Sync + 'static,
140 T: Send + 'static,
141{
142 fn into_endpoint_service(
143 self,
144 ) -> impl Service<S, Request, Response = Response, Error = Infallible> {
145 EndpointServiceFnWrapper {
146 inner: self,
147 _marker: std::marker::PhantomData,
148 }
149 }
150}
151
152mod private {
153 use super::*;
154
155 pub trait Sealed<T> {}
156
157 impl<State, S, R> Sealed<(State, R)> for S
158 where
159 State: Clone + Send + Sync + 'static,
160 S: Service<State, Request, Response = R, Error = Infallible>,
161 R: IntoResponse + Send + Sync + 'static,
162 {
163 }
164
165 impl<State, F, Fut, R> Sealed<(State, F, Context<State>, Fut, R)> for F
166 where
167 State: Clone + Send + Sync + 'static,
168 F: Fn(Context<State>) -> Fut + Send + Sync + 'static,
169 Fut: Future<Output = R> + Send + 'static,
170 R: IntoResponse + Send + Sync + 'static,
171 {
172 }
173
174 impl<State, F, Fut, R> Sealed<(State, F, Context<State>, Request, Fut, R)> for F
175 where
176 State: Clone + Send + Sync + 'static,
177 F: Fn(Context<State>, Request) -> Fut + Send + Sync + 'static,
178 Fut: Future<Output = R> + Send + 'static,
179 R: IntoResponse + Send + Sync + 'static,
180 {
181 }
182
183 impl<R> Sealed<()> for R where R: IntoResponse + Send + Sync + 'static {}
184
185 impl<F, S, T> Sealed<(F, S, T)> for F where F: EndpointServiceFn<S, T> {}
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Body, Method, Request, StatusCode, dep::http_body_util::BodyExt};
    use extract::*;

    /// Compile-time check that `I` implements `IntoEndpointService` for the
    /// unit state and some marker `T`.
    fn assert_into_endpoint_service<T, I>(_: I)
    where
        I: IntoEndpointService<(), T>,
    {
    }

    /// Builds a request for `uri` with an empty body.
    fn empty_request(uri: &str) -> Request {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    #[test]
    fn test_into_endpoint_service_static() {
        assert_into_endpoint_service(StatusCode::OK);
        assert_into_endpoint_service("hello");
        assert_into_endpoint_service("hello".to_owned());
    }

    #[tokio::test]
    async fn test_into_endpoint_service_impl() {
        #[derive(Debug, Clone)]
        struct OkService;

        impl<State> Service<State, Request> for OkService
        where
            State: Clone + Send + Sync + 'static,
        {
            type Response = StatusCode;
            type Error = Infallible;

            async fn serve(
                &self,
                _ctx: Context<State>,
                _req: Request,
            ) -> Result<Self::Response, Self::Error> {
                Ok(StatusCode::OK)
            }
        }

        let svc = OkService;
        let request = empty_request("http://example.com");
        let resp = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(resp, StatusCode::OK);

        assert_into_endpoint_service(svc);
    }

    #[test]
    fn test_into_endpoint_service_fn_no_param() {
        assert_into_endpoint_service(async || StatusCode::OK);
        assert_into_endpoint_service(async || "hello");
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_no_param() {
        let svc = async || StatusCode::OK;
        let svc = svc.into_endpoint_service();

        let request = empty_request("http://example.com");
        let resp = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_single_param_request() {
        let svc = async |req: Request| req.uri().to_string();
        let svc = svc.into_endpoint_service();

        let request = empty_request("http://example.com");
        let resp = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "http://example.com/");
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_single_param_host() {
        let svc = async |Host(host): Host| host.to_string();
        let svc = svc.into_endpoint_service();

        let request = empty_request("http://example.com");
        let resp = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "example.com");
    }

    #[tokio::test]
    async fn test_service_fn_wrapper_multi_param_host() {
        #[derive(Debug, Clone, serde::Deserialize)]
        struct Params {
            foo: String,
        }

        let svc = crate::service::web::WebService::default().get(
            "/:foo/bar",
            async |Host(host): Host, Path(params): Path<Params>| {
                format!("{} => {}", host, params.foo)
            },
        );
        let svc = svc.into_endpoint_service();

        let request = empty_request("http://example.com/42/bar");
        let resp = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "example.com => 42");
    }

    #[test]
    fn test_into_endpoint_service_fn_single_param() {
        #[derive(Debug, Clone, serde::Deserialize)]
        struct Params {
            foo: String,
        }

        assert_into_endpoint_service(async |_path: Path<Params>| StatusCode::OK);
        assert_into_endpoint_service(async |Path(params): Path<Params>| params.foo);
        assert_into_endpoint_service(async |Query(query): Query<Params>| query.foo);
        assert_into_endpoint_service(async |method: Method| method.to_string());
        assert_into_endpoint_service(async |req: Request| req.uri().to_string());
        assert_into_endpoint_service(async |_host: Host| StatusCode::OK);
        assert_into_endpoint_service(async |Host(_host): Host| StatusCode::OK);
    }
}