use std::{convert::Infallible, marker::PhantomData, sync::Arc};
use motore::{layer::Layer, service::Service};
use super::{
IntoResponse,
handler::{MiddlewareHandlerFromFn, MiddlewareHandlerMapResponse},
route::Route,
};
use crate::{body::Body, context::ServerContext, request::Request, response::Response};
/// [`Layer`] returned by [`from_fn`]; wraps an inner service with the
/// middleware function `f`.
///
/// Apart from `F`, the type parameters only appear in `_marker`:
/// - `T`: presumably the middleware's extra leading argument types — TODO
///   confirm against `MiddlewareHandlerFromFn`.
/// - `B`: body type of the request accepted by the produced service.
/// - `B2`: body type of the request forwarded to the inner service.
/// - `E2`: error type of the inner service.
pub struct FromFnLayer<F, T, B, B2, E2> {
    f: F,
    // `fn(..)` inside `PhantomData` records the type parameters without owning
    // values of those types (fn pointers are always `Send + Sync`).
    // NOTE(review): `FromFn::_marker` has the same type without this allow, so
    // the attribute may be redundant — confirm with `cargo clippy`.
    #[allow(clippy::type_complexity)]
    _marker: PhantomData<fn(T, B, B2, E2)>,
}
impl<F, T, B, B2, E2> Clone for FromFnLayer<F, T, B, B2, E2>
where
F: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
_marker: self._marker,
}
}
}
/// Creates a [`FromFnLayer`] from the middleware function `f`.
///
/// On every request, the produced service calls `f` with the context, the
/// request and a [`Next`] handle that can be used to invoke the wrapped
/// service. Whatever `f` returns is turned into a [`Response`] via
/// [`IntoResponse`].
pub fn from_fn<F, T, B, B2, E2>(f: F) -> FromFnLayer<F, T, B, B2, E2> {
    let _marker = PhantomData;
    FromFnLayer { f, _marker }
}
impl<S, F, T, B, B2, E2> Layer<S> for FromFnLayer<F, T, B, B2, E2>
where
    S: Service<ServerContext, Request<B2>, Response = Response, Error = E2> + Send + Sync + 'static,
{
    type Service = FromFn<Arc<S>, F, T, B, B2, E2>;

    /// Moves `service` behind an [`Arc`], so that each request only clones a
    /// reference-counted handle, and pairs it with the middleware function.
    fn layer(self, service: S) -> Self::Service {
        let shared = Arc::new(service);
        let Self { f, .. } = self;
        FromFn {
            service: shared,
            f,
            _marker: PhantomData,
        }
    }
}
/// Service produced by [`FromFnLayer`]: runs the middleware function `f`
/// around the inner `service`.
///
/// When built through [`FromFnLayer`], `S` is an `Arc` around the inner
/// service, and it is cloned for every request.
pub struct FromFn<S, F, T, B, B2, E2> {
    service: S,
    f: F,
    // Phantom record of the middleware's type parameters (see `FromFnLayer`).
    _marker: PhantomData<fn(T, B, B2, E2)>,
}
/// Hand-written so that only `S` and `F` need to be `Clone`; the phantom
/// parameters carry no bound.
impl<S: Clone, F: Clone, T, B, B2, E2> Clone for FromFn<S, F, T, B, B2, E2> {
    fn clone(&self) -> Self {
        let Self {
            service,
            f,
            _marker: marker,
        } = self;
        Self {
            service: service.clone(),
            f: f.clone(),
            _marker: *marker,
        }
    }
}
impl<S, F, T, B, B2, E2> Service<ServerContext, Request<B>> for FromFn<S, F, T, B, B2, E2>
where
    S: Service<ServerContext, Request<B2>, Response = Response, Error = E2>
        + Clone
        + Send
        + Sync
        + 'static,
    F: for<'r> MiddlewareHandlerFromFn<'r, T, B, B2, E2> + Sync,
    B: Send,
    B2: 'static,
{
    type Response = Response;
    type Error = Infallible;

    /// Runs the middleware function for a single request.
    ///
    /// A fresh [`Next`] is built around a clone of the inner service. It is
    /// passed to the middleware together with the context and the request,
    /// and the middleware's output is converted with [`IntoResponse`]. Any
    /// error the middleware returns (for example `Result<_, StatusCode>`) ends
    /// up inside that response, so this service itself never fails.
    async fn call(
        &self,
        cx: &mut ServerContext,
        req: Request<B>,
    ) -> Result<Self::Response, Self::Error> {
        let route = Route::new(self.service.clone());
        let output = self.f.handle(cx, req, Next { service: route }).await;
        Ok(output.into_response())
    }
}
/// Handle to the wrapped service, handed to middleware functions used with
/// [`from_fn`].
///
/// `B` is the request body type the wrapped service expects (defaults to
/// [`Body`]) and `E` is its error type (defaults to [`Infallible`]).
pub struct Next<B = Body, E = Infallible> {
    service: Route<B, E>,
}
impl<B, E> Next<B, E> {
    /// Consumes the handle and forwards `req` to the wrapped service. The
    /// service's response or error is returned unchanged.
    pub async fn run(self, cx: &mut ServerContext, req: Request<B>) -> Result<Response, E> {
        let Self { service } = self;
        service.call(cx, req).await
    }
}
/// [`Layer`] returned by [`map_response`]; post-processes the inner service's
/// successful response with `f`.
///
/// `R1` is the inner service's response type and `R2` is the type `f`
/// produces. `T` is presumably the mapper's extra argument types — TODO
/// confirm against `MiddlewareHandlerMapResponse`.
pub struct MapResponseLayer<F, T, R1, R2> {
    f: F,
    _marker: PhantomData<fn(T, R1, R2)>,
}
impl<F, T, R1, R2> Clone for MapResponseLayer<F, T, R1, R2>
where
F: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
_marker: self._marker,
}
}
}
/// Creates a [`MapResponseLayer`] that passes the inner service's successful
/// response, together with the context, through `f`.
///
/// Errors from the inner service skip `f` and are returned as-is.
pub fn map_response<F, T, R1, R2>(f: F) -> MapResponseLayer<F, T, R1, R2> {
    let layer = MapResponseLayer {
        _marker: PhantomData,
        f,
    };
    layer
}
impl<S, F, T, R1, R2> Layer<S> for MapResponseLayer<F, T, R1, R2> {
type Service = MapResponse<S, F, T, R1, R2>;
fn layer(self, service: S) -> Self::Service {
MapResponse {
service,
f: self.f,
_marker: self._marker,
}
}
}
/// Service produced by [`MapResponseLayer`]: calls `service`, then maps its
/// successful response through `f`.
pub struct MapResponse<S, F, T, R1, R2> {
    service: S,
    f: F,
    // Phantom record of the mapper's type parameters (see `MapResponseLayer`).
    _marker: PhantomData<fn(T, R1, R2)>,
}
impl<S, F, T, R1, R2> Clone for MapResponse<S, F, T, R1, R2>
where
S: Clone,
F: Clone,
{
fn clone(&self) -> Self {
Self {
service: self.service.clone(),
f: self.f.clone(),
_marker: self._marker,
}
}
}
impl<S, F, T, Req, R1, R2> Service<ServerContext, Req> for MapResponse<S, F, T, R1, R2>
where
    S: Service<ServerContext, Req, Response = R1> + Send + Sync,
    F: for<'r> MiddlewareHandlerMapResponse<'r, T, R1, R2> + Sync,
    Req: Send,
{
    type Response = R2;
    type Error = S::Error;

    /// Calls the inner service. On success, its response is fed, together
    /// with the context, through the mapping function. On failure, the error
    /// is returned without invoking the mapper.
    async fn call(&self, cx: &mut ServerContext, req: Req) -> Result<Self::Response, Self::Error> {
        let outcome = self.service.call(cx, req).await;
        match outcome {
            Ok(resp) => Ok(self.f.handle(cx, resp).await),
            Err(err) => Err(err),
        }
    }
}
#[cfg(test)]
mod middleware_tests {
    use faststr::FastStr;
    use http::{HeaderValue, Method, StatusCode, Uri};
    use motore::service::service_fn;

    use super::*;
    use crate::{
        body::{Body, BodyConversion},
        context::ServerContext,
        request::Request,
        response::Response,
        server::{
            response::IntoResponse,
            route::{any, get_service},
            test_helpers::empty_cx,
        },
        utils::test_helpers::simple_req,
    };

    /// Terminal handler: echoes the request body back as the response body.
    async fn print_body_handler(
        _: &mut ServerContext,
        req: Request<String>,
    ) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(req.into_body().into()))
    }

    /// Middleware that appends `"test"` to the request body before calling
    /// the next service.
    async fn append_body_mw(
        cx: &mut ServerContext,
        req: Request<String>,
        next: Next<String>,
    ) -> Response {
        let (parts, mut body) = req.into_parts();
        body += "test";
        let req = Request::from_parts(parts, body);
        next.run(cx, req).await.into_response()
    }

    /// Middleware taking extra leading arguments (`Method`, `Uri`), which are
    /// presumably extracted from the request. It copies them into CORS-style
    /// response headers.
    async fn cors_mw(
        method: Method,
        url: Uri,
        cx: &mut ServerContext,
        req: Request<String>,
        next: Next<String>,
    ) -> Response {
        let mut resp = next.run(cx, req).await.into_response();
        resp.headers_mut().insert(
            "Access-Control-Allow-Methods",
            HeaderValue::from_str(method.as_str()).unwrap(),
        );
        resp.headers_mut().insert(
            "Access-Control-Allow-Origin",
            HeaderValue::from_str(url.to_string().as_str()).unwrap(),
        );
        resp.headers_mut().insert(
            "Access-Control-Allow-Headers",
            HeaderValue::from_str("*").unwrap(),
        );
        resp
    }

    /// Middleware with only `cx`, `req` and `next`: it rewrites the request,
    /// and an error it returns is rendered as the response.
    #[tokio::test]
    async fn test_from_fn_with_necessary_params() {
        let handler = service_fn(print_body_handler);
        let mut cx = empty_cx();
        let service = from_fn(append_body_mw).layer(handler);
        let req = simple_req(Method::GET, "/", String::from(""));
        let resp = service.call(&mut cx, req).await.unwrap();
        assert_eq!(resp.into_body().into_string().await.unwrap(), "test");

        async fn error_mw(
            _: &mut ServerContext,
            _: Request<String>,
            _: Next<String>,
        ) -> Result<Response, StatusCode> {
            Err(StatusCode::INTERNAL_SERVER_ERROR)
        }

        let service = from_fn(error_mw).layer(handler);
        let req = simple_req(Method::GET, "/", String::from("test"));
        let resp = service.call(&mut cx, req).await.unwrap();
        let status = resp.status();
        let (_, body) = resp.into_parts();
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body.into_string().await.unwrap(), "");
    }

    /// Middleware with extra leading arguments receives the request's method
    /// and URI.
    #[tokio::test]
    async fn test_from_fn_with_optional_params() {
        let handler = service_fn(print_body_handler);
        let mut cx = empty_cx();
        let service = from_fn(cors_mw).layer(handler);
        let req = simple_req(Method::GET, "/", String::from(""));
        let resp = service.call(&mut cx, req).await.unwrap();
        assert_eq!(
            resp.headers().get("Access-Control-Allow-Methods").unwrap(),
            "GET"
        );
        assert_eq!(
            resp.headers().get("Access-Control-Allow-Origin").unwrap(),
            "/"
        );
        assert_eq!(
            resp.headers().get("Access-Control-Allow-Headers").unwrap(),
            "*"
        );
    }

    /// Two stacked `from_fn` layers both take effect.
    #[tokio::test]
    async fn test_from_fn_with_multiple_mws() {
        let handler = service_fn(print_body_handler);
        let mut cx = empty_cx();
        let service = from_fn(cors_mw).layer(handler);
        let service = from_fn(append_body_mw).layer(service);
        let req = simple_req(Method::GET, "/", String::from(""));
        let resp = service.call(&mut cx, req).await.unwrap();
        let (parts, body) = resp.into_parts();
        assert_eq!(
            parts.headers.get("Access-Control-Allow-Methods").unwrap(),
            "GET"
        );
        assert_eq!(
            parts.headers.get("Access-Control-Allow-Origin").unwrap(),
            "/"
        );
        assert_eq!(
            parts.headers.get("Access-Control-Allow-Headers").unwrap(),
            "*"
        );
        assert_eq!(body.into_string().await.unwrap(), "test");
    }

    /// `from_fn` can change the body type between layers: the middleware
    /// receives `Request<String>` and forwards `Request<FastStr>`.
    #[tokio::test]
    async fn test_from_fn_converts() {
        async fn converter(
            cx: &mut ServerContext,
            req: Request<String>,
            next: Next<FastStr>,
        ) -> Response {
            let (parts, body) = req.into_parts();
            let s = body.into_faststr().await.unwrap();
            let req = Request::from_parts(parts, s);
            // Compile-time check that the request now carries a `FastStr` body.
            let _: Request<FastStr> = req;
            next.run(cx, req).await.into_response()
        }

        async fn service(
            _: &mut ServerContext,
            _: Request<FastStr>,
        ) -> Result<Response, Infallible> {
            Ok(Response::new(String::from("Hello, World").into()))
        }

        let route = Route::new(get_service(service_fn(service)));
        let service = from_fn(converter).layer(route);
        // Assert on the result, not just its type, so the converted request
        // must actually reach the inner handler.
        let resp: Result<Response, Infallible> = service
            .call(
                &mut empty_cx(),
                simple_req(Method::GET, "/", String::from("")),
            )
            .await;
        let resp = resp.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.into_body().into_string().await.unwrap(),
            "Hello, World"
        );
    }

    async fn index_handler() -> &'static str {
        "Hello, World"
    }

    /// `map_response` can decorate the inner response; here it adds a header.
    #[tokio::test]
    async fn test_map_response() {
        async fn append_header(resp: Response) -> ((&'static str, &'static str), Response) {
            (("Server", "nginx"), resp)
        }

        let route: Route<String> = Route::new(any(index_handler));
        let service = map_response(append_header).layer(route);
        let mut cx = empty_cx();
        let req = simple_req(Method::GET, "/", String::from(""));
        let resp = service.call(&mut cx, req).await.unwrap();
        let (parts, _) = resp.into_response().into_parts();
        assert_eq!(parts.headers.get("Server").unwrap(), "nginx");
    }
}