use core::marker::PhantomData;
use crate::service::Service;
/// Zero-sized markers selecting which part of a service signature a `TypeEraser` erases.
///
/// The structs are `pub` so they may appear in `TypeEraser`'s public constructors, while the
/// module itself is private so code outside this module tree cannot name them directly.
#[doc(hidden)]
mod marker {
    /// Selects request body erasure (`TypeEraser::request_body`).
    pub struct EraseReqBody;
    /// Selects response body erasure (`TypeEraser::response_body`).
    pub struct EraseResBody;
    /// Selects error type erasure (`TypeEraser::error`).
    pub struct EraseErr;
}

use marker::{EraseErr, EraseReqBody, EraseResBody};

/// Middleware that erases "unwanted" concrete types produced by a service tree.
///
/// Obtain one through [`TypeEraser::request_body`], [`TypeEraser::response_body`] or
/// [`TypeEraser::error`]. The marker type `M` decides which erasing service is built.
pub struct TypeEraser<M>(PhantomData<M>);

impl<M> TypeEraser<M> {
    // Private constructor shared by `Clone` and the public marker-specific constructors.
    const fn marked() -> Self {
        TypeEraser(PhantomData)
    }
}

// Implemented by hand: `#[derive(Clone)]` would require `M: Clone`, and the marker
// structs above derive nothing.
impl<M> Clone for TypeEraser<M> {
    fn clone(&self) -> Self {
        Self::marked()
    }
}

impl TypeEraser<EraseReqBody> {
    /// Erase the generic request body type.
    ///
    /// The wrapped service is then called with a `WebContext` using its default body type.
    pub const fn request_body() -> Self {
        Self::marked()
    }
}

impl TypeEraser<EraseResBody> {
    /// Erase the generic response body type.
    ///
    /// The wrapped service's response is exposed as the default `WebResponse`.
    pub const fn response_body() -> Self {
        Self::marked()
    }
}

impl TypeEraser<EraseErr> {
    /// Erase the generic error type.
    ///
    /// The wrapped service's error is converted into the crate's `Error`.
    pub const fn error() -> Self {
        Self::marked()
    }
}
impl<M, S, E> Service<Result<S, E>> for TypeEraser<M> {
    type Response = service::EraserService<M, S>;
    type Error = E;

    /// Wrap a successfully built inner service into an `EraserService` tagged with this
    /// eraser's marker `M`; a build error is returned unchanged.
    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
        let service = res?;
        Ok(service::EraserService {
            service,
            _erase: PhantomData,
        })
    }
}
mod service {
    use core::cell::RefCell;

    use crate::{
        WebContext, body::BodyStream, body::ResponseBody, error::Error, http::WebResponse, service::ready::ReadyService,
    };

    use super::*;

    /// Service produced by calling a [`TypeEraser`] with a built inner service.
    ///
    /// `service` is the wrapped service. The zero-sized marker `M` selects which of the
    /// `Service` impls below applies.
    pub struct EraserService<M, S> {
        pub(super) service: S,
        pub(super) _erase: PhantomData<M>,
    }

    /// Request body eraser.
    ///
    /// Accepts a context carrying any `ReqB` body and converts that body with
    /// `crate::body::downcast_body`. It then calls the inner service with a `WebContext` that
    /// uses the default body type. The inner response body is boxed into the default
    /// `WebResponse` body, and errors pass through unchanged.
    impl<'r, S, C, ReqB, ResB, Err> Service<WebContext<'r, C, ReqB>> for EraserService<EraseReqBody, S>
    where
        S: for<'rs> Service<WebContext<'rs, C>, Response = WebResponse<ResB>, Error = Err>,
        ReqB: BodyStream + Default + 'static,
        ResB: BodyStream + 'static,
    {
        type Response = WebResponse;
        type Error = Err;

        async fn call(&self, mut ctx: WebContext<'r, C, ReqB>) -> Result<Self::Response, Self::Error> {
            // Move the concrete body out of the outer context before taking the context
            // apart. The inner context then borrows a fresh cell holding the converted body.
            let mut erased = RefCell::new(crate::body::downcast_body(ctx.take_body_mut()));
            let WebContext { req, ctx: state, .. } = ctx;
            let inner = WebContext::new(req, &mut erased, state);
            self.service
                .call(inner)
                .await
                .map(|res| res.map(ResponseBody::boxed))
        }
    }

    /// Response body eraser.
    ///
    /// Boxes the inner response body into the default `WebResponse` body type. The request
    /// and the error are passed through unchanged.
    impl<S, Req, ResB> Service<Req> for EraserService<EraseResBody, S>
    where
        S: Service<Req, Response = WebResponse<ResB>>,
        ResB: BodyStream + 'static,
    {
        type Response = WebResponse;
        type Error = S::Error;

        #[inline]
        async fn call(&self, req: Req) -> Result<Self::Response, Self::Error> {
            self.service.call(req).await.map(|res| res.map(ResponseBody::boxed))
        }
    }

    /// Error eraser.
    ///
    /// Converts the inner service's error into `crate::error::Error` through its `Into`
    /// bound. The response is passed through unchanged.
    impl<'r, C, B, S> Service<WebContext<'r, C, B>> for EraserService<EraseErr, S>
    where
        S: Service<WebContext<'r, C, B>>,
        S::Error: Into<Error>,
    {
        type Response = S::Response;
        type Error = Error;

        #[inline]
        async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
            match self.service.call(ctx).await {
                Ok(res) => Ok(res),
                Err(e) => Err(e.into()),
            }
        }
    }

    /// Readiness is delegated to the wrapped service, whatever the marker is.
    impl<M, S> ReadyService for EraserService<M, S>
    where
        S: ReadyService,
    {
        type Ready = S::Ready;

        #[inline]
        async fn ready(&self) -> Self::Ready {
            let Self { service, .. } = self;
            service.ready().await
        }
    }
}
#[cfg(test)]
mod test {
    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::{
        App, WebContext,
        body::Full,
        bytes::Bytes,
        error::Error,
        handler::handler_service,
        http::{Request, StatusCode, WebResponse},
        middleware::Group,
        service::ServiceExt,
    };

    use super::*;

    // Leaf handler; only its presence in the service tree matters to these tests.
    async fn handler(_: &WebContext<'_>) -> &'static str {
        "996"
    }

    // Ignores the inner service and answers with an empty `Full<Bytes>` body. This gives the
    // response a concrete body type instead of the default one of `WebResponse`.
    async fn map_body<S, C, B, Err>(_: &S, _: WebContext<'_, C, B>) -> Result<WebResponse<Full<Bytes>>, Err>
    where
        S: for<'r> Service<WebContext<'r, C, B>, Response = WebResponse, Error = Err>,
    {
        let empty = Full::new(Bytes::new());
        Ok(WebResponse::new(empty))
    }

    // Pass-through middleware whose bound only accepts inner services responding with the
    // default `WebResponse`.
    async fn middleware_fn<S, C, B, Err>(s: &S, ctx: WebContext<'_, C, B>) -> Result<WebResponse, Err>
    where
        S: for<'r> Service<WebContext<'r, C, B>, Response = WebResponse, Error = Err>,
    {
        s.call(ctx).await
    }

    #[test]
    fn erase_body() {
        // `map_body` yields `WebResponse<Full<Bytes>>`, while `middleware_fn` requires
        // `WebResponse`. The response body eraser sits between them.
        let service = App::new()
            .at("/", handler_service(handler).enclosed_fn(map_body))
            .enclosed(TypeEraser::response_body())
            .enclosed_fn(middleware_fn)
            .finish()
            .call(())
            .now_or_panic()
            .unwrap();

        let _ = service.call(Request::default()).now_or_panic().unwrap();
    }

    #[test]
    fn erase_error() {
        // Maps any inner error to `StatusCode`.
        async fn middleware_fn<S, C, B, Err>(s: &S, ctx: WebContext<'_, C, B>) -> Result<WebResponse, StatusCode>
        where
            S: for<'r> Service<WebContext<'r, C, B>, Response = WebResponse, Error = Err>,
        {
            s.call(ctx).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
        }

        // Only accepts inner services whose error type is exactly `Error`.
        async fn middleware_fn2<S, C, B>(s: &S, ctx: WebContext<'_, C, B>) -> Result<WebResponse, Error>
        where
            S: for<'r> Service<WebContext<'r, C, B>, Response = WebResponse, Error = Error>,
        {
            s.call(ctx).await
        }

        // The group chains `middleware_fn` (errors become `StatusCode`), an error eraser, and
        // `middleware_fn2` (requires `Error`). The route's error is erased as well.
        let service = App::new()
            .at("/", handler_service(handler).enclosed(TypeEraser::error()))
            .enclosed(
                Group::new()
                    .enclosed_fn(middleware_fn)
                    .enclosed(TypeEraser::error())
                    .enclosed_fn(middleware_fn2),
            )
            .finish()
            .call(())
            .now_or_panic()
            .unwrap();

        let _ = service.call(Request::default()).now_or_panic().unwrap();
    }
}