1use std::{convert::Infallible, marker::PhantomData, sync::Arc};
4
5use motore::{layer::Layer, service::Service};
6
7use super::{
8 IntoResponse,
9 handler::{MiddlewareHandlerFromFn, MiddlewareHandlerMapResponse},
10 route::Route,
11};
12use crate::{body::Body, context::ServerContext, request::Request, response::Response};
13
/// [`Layer`] returned by [`from_fn`]; applying it to a service yields a [`FromFn`].
///
/// Type parameters:
///
/// - `F`: the middleware function stored in the layer.
/// - `B`: body type of the requests accepted by the resulting [`FromFn`] service.
/// - `B2` / `E2`: request body type and error type of the wrapped inner service.
/// - `T`: forwarded unchanged to `MiddlewareHandlerFromFn`
///   (presumably it describes the extra arguments of `F`; confirm against that trait).
pub struct FromFnLayer<F, T, B, B2, E2> {
    /// Middleware function handed over to the [`FromFn`] built by this layer.
    f: F,
    /// Zero-sized marker that ties the otherwise unused type parameters to the struct.
    ///
    /// The `fn(..)` form means the struct does not own values of these types, so the
    /// marker stays `Send` and `Sync` whatever the parameters are.
    // Silences clippy's `type_complexity` lint for the marker type.
    #[allow(clippy::type_complexity)]
    _marker: PhantomData<fn(T, B, B2, E2)>,
}
22
23impl<F, T, B, B2, E2> Clone for FromFnLayer<F, T, B, B2, E2>
24where
25 F: Clone,
26{
27 fn clone(&self) -> Self {
28 Self {
29 f: self.f.clone(),
30 _marker: self._marker,
31 }
32 }
33}
34
35pub fn from_fn<F, T, B, B2, E2>(f: F) -> FromFnLayer<F, T, B, B2, E2> {
157 FromFnLayer {
158 f,
159 _marker: PhantomData,
160 }
161}
162
163impl<S, F, T, B, B2, E2> Layer<S> for FromFnLayer<F, T, B, B2, E2>
164where
165 S: Service<ServerContext, Request<B2>, Response = Response, Error = E2> + Send + Sync + 'static,
166{
167 type Service = FromFn<Arc<S>, F, T, B, B2, E2>;
168
169 fn layer(self, service: S) -> Self::Service {
170 FromFn {
171 service: Arc::new(service),
172 f: self.f,
173 _marker: PhantomData,
174 }
175 }
176}
177
/// Service that runs a middleware function in front of an inner service.
///
/// `S` is the inner service and `F` the middleware function; the remaining type
/// parameters only appear in the marker and in the trait implementations.
pub struct FromFn<S, F, T, B, B2, E2> {
    /// Inner service that is handed to the middleware function wrapped in a [`Next`].
    service: S,
    /// Middleware function invoked for every request.
    f: F,
    /// Zero-sized marker that ties the otherwise unused type parameters to the struct.
    _marker: PhantomData<fn(T, B, B2, E2)>,
}
184
185impl<S, F, T, B, B2, E2> Clone for FromFn<S, F, T, B, B2, E2>
186where
187 S: Clone,
188 F: Clone,
189{
190 fn clone(&self) -> Self {
191 Self {
192 service: self.service.clone(),
193 f: self.f.clone(),
194 _marker: self._marker,
195 }
196 }
197}
198
199impl<S, F, T, B, B2, E2> Service<ServerContext, Request<B>> for FromFn<S, F, T, B, B2, E2>
200where
201 S: Service<ServerContext, Request<B2>, Response = Response, Error = E2>
202 + Clone
203 + Send
204 + Sync
205 + 'static,
206 F: for<'r> MiddlewareHandlerFromFn<'r, T, B, B2, E2> + Sync,
207 B: Send,
208 B2: 'static,
209{
210 type Response = Response;
211 type Error = Infallible;
212
213 async fn call(
214 &self,
215 cx: &mut ServerContext,
216 req: Request<B>,
217 ) -> Result<Self::Response, Self::Error> {
218 let next = Next {
219 service: Route::new(self.service.clone()),
220 };
221 Ok(self.f.handle(cx, req, next).await.into_response())
222 }
223}
224
225pub struct Next<B = Body, E = Infallible> {
232 service: Route<B, E>,
233}
234
235impl<B, E> Next<B, E> {
236 pub async fn run(self, cx: &mut ServerContext, req: Request<B>) -> Result<Response, E> {
238 self.service.call(cx, req).await
239 }
240}
241
/// [`Layer`] returned by [`map_response`]; applying it to a service yields a
/// [`MapResponse`].
///
/// Type parameters:
///
/// - `F`: the response-mapping function stored in the layer.
/// - `R1`: response type produced by the wrapped inner service.
/// - `R2`: response type produced by `F`.
/// - `T`: forwarded unchanged to `MiddlewareHandlerMapResponse`
///   (presumably it describes the extra arguments of `F`; confirm against that trait).
pub struct MapResponseLayer<F, T, R1, R2> {
    /// Mapping function handed over to the [`MapResponse`] built by this layer.
    f: F,
    /// Zero-sized marker that ties the otherwise unused type parameters to the struct.
    _marker: PhantomData<fn(T, R1, R2)>,
}
249
250impl<F, T, R1, R2> Clone for MapResponseLayer<F, T, R1, R2>
251where
252 F: Clone,
253{
254 fn clone(&self) -> Self {
255 Self {
256 f: self.f.clone(),
257 _marker: self._marker,
258 }
259 }
260}
261
262pub fn map_response<F, T, R1, R2>(f: F) -> MapResponseLayer<F, T, R1, R2> {
295 MapResponseLayer {
296 f,
297 _marker: PhantomData,
298 }
299}
300
301impl<S, F, T, R1, R2> Layer<S> for MapResponseLayer<F, T, R1, R2> {
302 type Service = MapResponse<S, F, T, R1, R2>;
303
304 fn layer(self, service: S) -> Self::Service {
305 MapResponse {
306 service,
307 f: self.f,
308 _marker: self._marker,
309 }
310 }
311}
312
/// Service that calls an inner service and then maps its response with a function.
///
/// `S` is the inner service and `F` the mapping function; `T`, `R1` and `R2` only
/// appear in the marker and in the trait implementations.
pub struct MapResponse<S, F, T, R1, R2> {
    /// Inner service whose successful response is passed to `f`.
    service: S,
    /// Function applied to every successful response of `service`.
    f: F,
    /// Zero-sized marker that ties the otherwise unused type parameters to the struct.
    _marker: PhantomData<fn(T, R1, R2)>,
}
319
320impl<S, F, T, R1, R2> Clone for MapResponse<S, F, T, R1, R2>
321where
322 S: Clone,
323 F: Clone,
324{
325 fn clone(&self) -> Self {
326 Self {
327 service: self.service.clone(),
328 f: self.f.clone(),
329 _marker: self._marker,
330 }
331 }
332}
333
334impl<S, F, T, Req, R1, R2> Service<ServerContext, Req> for MapResponse<S, F, T, R1, R2>
335where
336 S: Service<ServerContext, Req, Response = R1> + Send + Sync,
337 F: for<'r> MiddlewareHandlerMapResponse<'r, T, R1, R2> + Sync,
338 Req: Send,
339{
340 type Response = R2;
341 type Error = S::Error;
342
343 async fn call(&self, cx: &mut ServerContext, req: Req) -> Result<Self::Response, Self::Error> {
344 let resp = self.service.call(cx, req).await?;
345
346 Ok(self.f.handle(cx, resp).await)
347 }
348}
349
#[cfg(test)]
mod middleware_tests {
    use faststr::FastStr;
    use http::{HeaderValue, Method, StatusCode, Uri};
    use motore::service::service_fn;

    use super::*;
    use crate::{
        body::{Body, BodyConversion},
        context::ServerContext,
        request::Request,
        response::Response,
        server::{
            response::IntoResponse,
            route::{any, get_service},
            test_helpers::empty_cx,
        },
        utils::test_helpers::simple_req,
    };

    /// Innermost service used by the tests: echoes the request body as the response body.
    async fn print_body_handler(
        _: &mut ServerContext,
        req: Request<String>,
    ) -> Result<Response<Body>, Infallible> {
        let echoed = req.into_body();
        Ok(Response::new(echoed.into()))
    }

    /// Middleware using only the mandatory arguments: appends `"test"` to the
    /// request body before forwarding the request.
    async fn append_body_mw(
        cx: &mut ServerContext,
        req: Request<String>,
        next: Next<String>,
    ) -> Response {
        let (parts, mut body) = req.into_parts();
        body.push_str("test");
        next.run(cx, Request::from_parts(parts, body))
            .await
            .into_response()
    }

    /// Middleware with extra leading arguments (`Method`, `Uri`): forwards the
    /// request, then adds three CORS headers derived from them to the response.
    async fn cors_mw(
        method: Method,
        url: Uri,
        cx: &mut ServerContext,
        req: Request<String>,
        next: Next<String>,
    ) -> Response {
        let mut resp = next.run(cx, req).await.into_response();

        let origin = url.to_string();
        let cors_headers = [
            ("Access-Control-Allow-Methods", method.as_str()),
            ("Access-Control-Allow-Origin", origin.as_str()),
            ("Access-Control-Allow-Headers", "*"),
        ];
        for (name, value) in cors_headers {
            resp.headers_mut()
                .insert(name, HeaderValue::from_str(value).unwrap());
        }
        resp
    }

    /// Asserts that `headers` carries the values `cors_mw` sets for a `GET /` request.
    fn assert_cors_headers(headers: &http::HeaderMap) {
        let expected_headers = [
            ("Access-Control-Allow-Methods", "GET"),
            ("Access-Control-Allow-Origin", "/"),
            ("Access-Control-Allow-Headers", "*"),
        ];
        for (name, expected) in expected_headers {
            assert_eq!(headers.get(name).unwrap(), expected);
        }
    }

    #[tokio::test]
    async fn test_from_fn_with_necessary_params() {
        /// Middleware that never calls `next` and always fails with status 500.
        async fn error_mw(
            _: &mut ServerContext,
            _: Request<String>,
            _: Next<String>,
        ) -> Result<Response, StatusCode> {
            Err(StatusCode::INTERNAL_SERVER_ERROR)
        }

        let handler = service_fn(print_body_handler);
        let mut cx = empty_cx();

        // The middleware appends "test" to the empty body, which the handler echoes.
        let appending = from_fn(append_body_mw).layer(handler);
        let empty_req = simple_req(Method::GET, "/", String::from(""));
        let echoed = appending.call(&mut cx, empty_req).await.unwrap();
        assert_eq!(echoed.into_body().into_string().await.unwrap(), "test");

        // The `Err` returned by the middleware becomes a response with that status
        // and an empty body; the handler is never reached.
        let failing = from_fn(error_mw).layer(handler);
        let test_req = simple_req(Method::GET, "/", String::from("test"));
        let failed = failing.call(&mut cx, test_req).await.unwrap();
        let (parts, body) = failed.into_parts();
        assert_eq!(parts.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body.into_string().await.unwrap(), "");
    }

    #[tokio::test]
    async fn test_from_fn_with_optional_params() {
        let handler = service_fn(print_body_handler);
        let mut cx = empty_cx();

        let with_cors = from_fn(cors_mw).layer(handler);
        let req = simple_req(Method::GET, "/", String::from(""));
        let resp = with_cors.call(&mut cx, req).await.unwrap();

        assert_cors_headers(resp.headers());
    }

    #[tokio::test]
    async fn test_from_fn_with_multiple_mws() {
        let handler = service_fn(print_body_handler);
        let mut cx = empty_cx();

        // `append_body_mw` is the outermost layer and wraps `cors_mw`.
        let with_cors = from_fn(cors_mw).layer(handler);
        let stacked = from_fn(append_body_mw).layer(with_cors);

        let req = simple_req(Method::GET, "/", String::from(""));
        let resp = stacked.call(&mut cx, req).await.unwrap();
        let (parts, body) = resp.into_parts();

        assert_cors_headers(&parts.headers);
        assert_eq!(body.into_string().await.unwrap(), "test");
    }

    #[tokio::test]
    async fn test_from_fn_converts() {
        /// Middleware that changes the body type from `String` to `FastStr`
        /// before forwarding the request.
        async fn converter(
            cx: &mut ServerContext,
            req: Request<String>,
            next: Next<FastStr>,
        ) -> Response {
            let (parts, body) = req.into_parts();
            let converted = body.into_faststr().await.unwrap();
            // The annotation pins the body type handed to `next`.
            let req: Request<FastStr> = Request::from_parts(parts, converted);
            next.run(cx, req).await.into_response()
        }

        /// Inner service that only accepts `FastStr` bodies.
        async fn service(
            _: &mut ServerContext,
            _: Request<FastStr>,
        ) -> Result<Response, Infallible> {
            Ok(Response::new(String::from("Hello, World").into()))
        }

        let route = Route::new(get_service(service_fn(service)));
        let converting = from_fn(converter).layer(route);

        let mut cx = empty_cx();
        let req = simple_req(Method::GET, "/", String::from(""));
        // Only the types are checked here; the response itself is not inspected.
        let _: Result<Response, Infallible> = converting.call(&mut cx, req).await;
    }

    /// Plain handler used as the route target in `test_map_response`.
    async fn index_handler() -> &'static str {
        "Hello, World"
    }

    #[tokio::test]
    async fn test_map_response() {
        /// Mapping function that pairs the response with a `Server` header.
        async fn append_header(resp: Response) -> ((&'static str, &'static str), Response) {
            (("Server", "nginx"), resp)
        }

        let route: Route<String> = Route::new(any(index_handler));
        let mapped = map_response(append_header).layer(route);

        let mut cx = empty_cx();
        let req = simple_req(Method::GET, "/", String::from(""));
        let resp = mapped.call(&mut cx, req).await.unwrap();

        let (parts, _) = resp.into_response().into_parts();
        assert_eq!(parts.headers.get("Server").unwrap(), "nginx");
    }
}