use std::marker::PhantomData;
use motore::{Service, layer::Layer};
use crate::{
context::ServerContext,
request::Request,
response::Response,
server::{IntoResponse, handler::HandlerWithoutRequest},
};
/// A [`Layer`] that wraps an inner service in a [`Filter`].
///
/// The stored `handler` is handed to the [`Filter`] produced by
/// [`Layer::layer`], which invokes it before forwarding each request to the
/// inner service.
///
/// # Type parameters
///
/// - `H`: the handler value stored in the layer.
/// - `R`: the `Err` type of the handler's `Result<(), R>` output, as required
///   by the `Service` implementation of [`Filter`].
/// - `T`: marker type passed through to `HandlerWithoutRequest<T, _>`;
///   presumably the handler's argument types (confirm against that trait).
pub struct FilterLayer<H, R, T> {
    handler: H,
    // `R` and `T` are only carried at the type level; no value of either type
    // is ever stored.
    _marker: PhantomData<(R, T)>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add
// `R: Clone` and `T: Clone` bounds even though both types only occur inside
// `PhantomData`, needlessly making the layer non-cloneable for such types.
impl<H, R, T> Clone for FilterLayer<H, R, T>
where
    H: Clone,
{
    fn clone(&self) -> Self {
        Self {
            handler: self.handler.clone(),
            _marker: PhantomData,
        }
    }
}
impl<H, R, T> FilterLayer<H, R, T> {
pub fn new(h: H) -> Self {
Self {
handler: h,
_marker: PhantomData,
}
}
}
impl<S, H, R, T> Layer<S> for FilterLayer<H, R, T>
where
S: Send + Sync + 'static,
H: Clone + Send + Sync + 'static,
T: Sync,
{
type Service = Filter<S, H, R, T>;
fn layer(self, inner: S) -> Self::Service {
Filter {
service: inner,
handler: self.handler,
_marker: PhantomData,
}
}
}
/// A service wrapper that runs a filtering `handler` before delegating to the
/// wrapped `service`.
///
/// Values of this type are produced by `FilterLayer`'s `Layer::layer`
/// implementation.
///
/// # Type parameters
///
/// - `S`: the wrapped inner service.
/// - `H`: the handler invoked before each call to the inner service.
/// - `R`: the `Err` type of the handler's `Result<(), R>` output.
/// - `T`: marker type passed through to `HandlerWithoutRequest<T, _>`;
///   presumably the handler's argument types (confirm against that trait).
pub struct Filter<S, H, R, T> {
    service: S,
    handler: H,
    // `R` and `T` are only carried at the type level; no value of either type
    // is ever stored.
    _marker: PhantomData<(R, T)>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add
// `R: Clone` and `T: Clone` bounds even though both types only occur inside
// `PhantomData`, needlessly making the service non-cloneable for such types.
impl<S, H, R, T> Clone for Filter<S, H, R, T>
where
    S: Clone,
    H: Clone,
{
    fn clone(&self) -> Self {
        Self {
            service: self.service.clone(),
            handler: self.handler.clone(),
            _marker: PhantomData,
        }
    }
}
impl<S, B, H, R, T> Service<ServerContext, Request<B>> for Filter<S, H, R, T>
where
    S: Service<ServerContext, Request<B>> + Send + Sync + 'static,
    S::Response: IntoResponse,
    S::Error: IntoResponse,
    B: Send,
    H: HandlerWithoutRequest<T, Result<(), R>> + Clone + Send + Sync + 'static,
    R: IntoResponse + Send + Sync,
    T: Sync,
{
    type Response = Response;
    type Error = S::Error;

    /// Runs the handler against the request head, then either forwards the
    /// request to the inner service or answers with a rejection response.
    ///
    /// - Handler yields `Ok(())`: the inner service is called and its
    ///   response is converted with [`IntoResponse`]; an inner error is
    ///   returned unchanged.
    /// - Handler yields `Err(r)`: `r` is converted into the response and the
    ///   inner service is not called.
    /// - The handler call itself fails: a warning is logged and the failure
    ///   value is converted into the response.
    async fn call(
        &self,
        cx: &mut ServerContext,
        req: Request<B>,
    ) -> Result<Self::Response, Self::Error> {
        // The handler only gets the request head; the body is set aside and
        // reattached afterwards so the inner service sees a complete request.
        let (mut head, payload) = req.into_parts();
        let verdict = self.handler.clone().handle(cx, &mut head).await;
        let req = Request::from_parts(head, payload);

        let rejection = match verdict {
            Ok(Ok(())) => {
                // Accepted: hand the request over to the wrapped service.
                let inner = self.service.call(cx, req).await?;
                return Ok(inner.into_response());
            }
            // The handler ran and explicitly refused the request.
            Ok(Err(refusal)) => refusal.into_response(),
            // The handler could not be run at all.
            Err(rej) => {
                tracing::warn!("[Volo-HTTP] FilterLayer: something wrong while extracting");
                rej.into_response()
            }
        };
        Ok(rejection)
    }
}
#[cfg(test)]
mod filter_tests {
    use http::{Method, StatusCode};
    use motore::{Service, layer::Layer};

    use crate::{
        body::BodyConversion,
        server::{
            layer::FilterLayer,
            route::{Route, any},
            test_helpers::empty_cx,
        },
        utils::test_helpers::simple_req,
    };

    /// A filter that refuses `POST` lets a `GET` reach the route and answers
    /// a `POST` with `405 Method Not Allowed`.
    #[tokio::test]
    async fn test_filter_layer() {
        // Filter under test: anything except POST is allowed through.
        async fn deny_post(method: Method) -> Result<(), StatusCode> {
            if method != Method::POST {
                Ok(())
            } else {
                Err(StatusCode::METHOD_NOT_ALLOWED)
            }
        }

        // Inner endpoint that the filter guards.
        async fn hello() -> &'static str {
            "Hello, World"
        }

        let route: Route<&str> = Route::new(any(hello));
        let service = FilterLayer::new(deny_post).layer(route);
        let mut cx = empty_cx();

        // GET passes the filter and reaches the inner handler.
        let get_resp = service
            .call(&mut cx, simple_req(Method::GET, "/", ""))
            .await
            .unwrap();
        let body = get_resp.into_body().into_string().await.unwrap();
        assert_eq!(body, "Hello, World");

        // POST is rejected by the filter with the status it returned.
        let post_resp = service
            .call(&mut cx, simple_req(Method::POST, "/", ""))
            .await
            .unwrap();
        assert_eq!(post_resp.status(), StatusCode::METHOD_NOT_ALLOWED);
    }
}