use futures_util::future::BoxFuture;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use super::{HttpContext, NextFn};
use crate::error::Error;
use crate::{HttpRequestMut, HttpResponse, HttpResult, http::FilterResult};
/// Internal state machine behind the [`Next`] future.
///
/// `Pending` keeps the `HttpContext` and `NextFn` inline, while `Running` only holds a
/// boxed future, so the variants differ in size; the size lint is silenced below.
/// NOTE(review): presumably boxing `Pending` was judged not worth an allocation — confirm.
#[allow(clippy::large_enum_variant)]
enum NextState {
/// Not started yet: the context and the callback that will be called with it on first poll.
Pending(HttpContext, NextFn),
/// Started: the boxed future that is currently being driven to completion.
Running(BoxFuture<'static, HttpResult>),
}
/// A future that, when first polled, calls the stored `NextFn` with the stored
/// `HttpContext` and then drives the resulting future, resolving to its `HttpResult`.
///
/// Polling it again after it has completed yields an error instead of panicking.
pub struct Next {
// `None` once the inner future has completed; `poll` takes the state out and only
// puts it back while the inner future is still pending.
state: Option<NextState>,
}
/// Opaque `Debug` output: the internal state (context, callback, boxed future)
/// is never printed, only the fixed placeholder `Next(..)`.
impl std::fmt::Debug for Next {
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Next(..)")
    }
}
impl Future for Next {
type Output = HttpResult;
#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
match this.state.take() {
None => Poll::Ready(Err(Error::server_error("Next polled after completion"))),
Some(NextState::Pending(ctx, next)) => {
let mut fut = next(ctx);
let poll = fut.as_mut().poll(cx);
if poll.is_pending() {
this.state = Some(NextState::Running(fut));
}
poll
}
Some(NextState::Running(mut fut)) => {
let poll = fut.as_mut().poll(cx);
if poll.is_pending() {
this.state = Some(NextState::Running(fut));
}
poll
}
}
}
}
impl Next {
    /// Creates a `Next` in the pending state.
    ///
    /// Nothing is executed here: `next` is only called with `ctx` the first time the
    /// returned future is polled.
    pub fn new(ctx: HttpContext, next: NextFn) -> Self {
        let state = Some(NextState::Pending(ctx, next));
        Self { state }
    }
}
/// A middleware: given an `HttpContext` and a `NextFn`, produces a future that
/// resolves to an `HttpResult`.
///
/// Implementors must be `Send + Sync + 'static`. A blanket implementation in this file
/// covers any `Fn(HttpContext, NextFn) -> Fut` whose future is `Send + 'static`.
pub trait Middleware: Send + Sync + 'static {
/// Runs the middleware with the given context and callback.
///
/// The returned future must be `Send + 'static` and resolve to an `HttpResult`.
/// NOTE(review): presumably `next` continues the rest of the pipeline — confirm
/// against the `NextFn` definition.
fn call(
&self,
ctx: HttpContext,
next: NextFn,
) -> impl Future<Output = HttpResult> + Send + 'static;
}
/// A handler that is called with a tuple of arguments `Args` plus a [`Next`] future.
///
/// The macro-generated impls in this file implement it for functions that take the
/// tuple's elements as individual parameters followed by `Next`.
pub trait With<Args>: Clone + Send + Sync + 'static {
/// The value the handler's future resolves to.
type Output;
/// Calls the handler with `args` and `next`, returning its (`Send`) future.
fn with(&self, args: Args, next: Next) -> impl Future<Output = Self::Output> + Send;
}
/// A handler that is called with a tuple of arguments `Args` and whose result can be
/// converted into a `FilterResult`.
///
/// The macro-generated impls in this file implement it for functions that take the
/// tuple's elements as individual parameters.
pub trait Filter<Args>: Clone + Send + Sync + 'static {
/// The value the handler's future resolves to; must convert into `FilterResult`.
type Output: Into<FilterResult>;
/// Calls the handler with `args`, returning its (`Send`) future.
fn filter(&self, args: Args) -> impl Future<Output = Self::Output> + Send;
}
/// A handler that is called with an `HttpRequestMut` and a tuple of extra arguments
/// `Args` (defaults to `()`, i.e. no extra arguments).
///
/// Without the `di` feature it is implemented for `Fn(HttpRequestMut) -> Fut`; with the
/// `di` feature the macro-generated impls also accept additional parameters.
pub trait TapReq<Args = ()>: Clone + Send + Sync + 'static {
/// The value the handler's future resolves to.
type Output;
/// Calls the handler with `req` and `args`, returning its (`Send`) future.
fn tap_req(&self, req: HttpRequestMut, args: Args)
-> impl Future<Output = Self::Output> + Send;
}
/// A handler that is called with an `HttpResponse` and a tuple of extra arguments `Args`.
///
/// The macro-generated impls in this file implement it for functions that take the
/// response followed by the tuple's elements as individual parameters.
/// NOTE(review): the name suggests it is applied to successful responses only — confirm at the call site.
pub trait MapOk<Args>: Clone + Send + Sync + 'static {
/// The value the handler's future resolves to.
type Output;
/// Calls the handler with `resp` and `args`, returning its (`Send`) future.
fn map_ok(&self, resp: HttpResponse, args: Args) -> impl Future<Output = Self::Output> + Send;
}
/// Lets any thread-safe `Fn(HttpContext, NextFn) -> Fut` act as a [`Middleware`],
/// provided its future is `Send + 'static` and resolves to an `HttpResult`.
impl<Func, Fut> Middleware for Func
where
    Func: Fn(HttpContext, NextFn) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = HttpResult> + Send + 'static,
{
    /// Forwards both arguments to the wrapped function and returns its future as-is.
    #[inline]
    fn call(
        &self,
        ctx: HttpContext,
        next: NextFn,
    ) -> impl Future<Output = HttpResult> + Send + 'static {
        (self)(ctx, next)
    }
}
/// Without the `di` feature, any thread-safe, cloneable `Fn(HttpRequestMut) -> Fut`
/// with a `Send` future implements [`TapReq`] with no extra arguments.
#[cfg(not(feature = "di"))]
impl<Func, Fut> TapReq for Func
where
    Func: Fn(HttpRequestMut) -> Fut + Clone + Send + Sync + 'static,
    Fut: Future + Send,
{
    type Output = Fut::Output;

    /// Passes the request to the wrapped function; the unit argument carries no data.
    #[inline]
    fn tap_req(&self, req: HttpRequestMut, (): ()) -> impl Future<Output = Self::Output> + Send {
        (self)(req)
    }
}
/// Generates blanket handler-trait impls for plain functions/closures of one arity.
///
/// Given a list of type parameter names (e.g. `T1 T2`), this emits, for every
/// `Func: Fn(..) -> Fut + Send + Sync + Clone + 'static` with `Fut: Future + Send`:
///
/// * `With<(T1, T2,)>`   for `Fn(T1, T2, Next) -> Fut`
/// * `Filter<(T1, T2,)>` for `Fn(T1, T2) -> Fut` where `Fut::Output: Into<FilterResult>`
/// * `TapReq<(T1, T2,)>` for `Fn(HttpRequestMut, T1, T2) -> Fut` (only with the `di` feature)
/// * `MapOk<(T1, T2,)>`  for `Fn(HttpResponse, T1, T2) -> Fut`
///
/// Each generated method destructures the argument tuple and forwards the elements to
/// the wrapped function. The return types spell out `+ Send`, exactly as the trait
/// declarations do, instead of relying on auto-trait leakage through the opaque type.
macro_rules! define_generic_mw_handler ({ $($param:ident)* } => {
    impl<Func, Fut, $($param,)*> With<($($param,)*)> for Func
    where
        Func: Fn($($param,)* Next) -> Fut + Send + Sync + Clone + 'static,
        Fut: Future + Send,
    {
        type Output = Fut::Output;

        // The type parameter names double as binding names in the tuple pattern.
        #[inline]
        #[allow(non_snake_case)]
        fn with(
            &self,
            ($($param,)*): ($($param,)*),
            next: Next,
        ) -> impl Future<Output = Self::Output> + Send {
            self($($param,)* next)
        }
    }

    impl<Func, Fut, $($param,)*> Filter<($($param,)*)> for Func
    where
        Func: Fn($($param,)*) -> Fut + Send + Sync + Clone + 'static,
        Fut: Future + Send,
        Fut::Output: Into<FilterResult>,
    {
        type Output = Fut::Output;

        // The type parameter names double as binding names in the tuple pattern.
        #[inline]
        #[allow(non_snake_case)]
        fn filter(
            &self,
            ($($param,)*): ($($param,)*),
        ) -> impl Future<Output = Self::Output> + Send {
            self($($param,)*)
        }
    }

    #[cfg(feature = "di")]
    impl<Func, Fut, $($param,)*> TapReq<($($param,)*)> for Func
    where
        Func: Fn(HttpRequestMut, $($param,)*) -> Fut + Send + Sync + Clone + 'static,
        Fut: Future + Send,
    {
        type Output = Fut::Output;

        // The type parameter names double as binding names in the tuple pattern.
        #[inline]
        #[allow(non_snake_case)]
        fn tap_req(
            &self,
            req: HttpRequestMut,
            ($($param,)*): ($($param,)*),
        ) -> impl Future<Output = Self::Output> + Send {
            self(req, $($param,)*)
        }
    }

    impl<Func, Fut, $($param,)*> MapOk<($($param,)*)> for Func
    where
        Func: Fn(HttpResponse, $($param,)*) -> Fut + Send + Sync + Clone + 'static,
        Fut: Future + Send,
    {
        type Output = Fut::Output;

        // The type parameter names double as binding names in the tuple pattern.
        #[inline]
        #[allow(non_snake_case)]
        fn map_ok(
            &self,
            resp: HttpResponse,
            ($($param,)*): ($($param,)*),
        ) -> impl Future<Output = Self::Output> + Send {
            self(resp, $($param,)*)
        }
    }
});
// Instantiate the handler impls (`With`, `Filter`, `MapOk`, and `TapReq` when the `di`
// feature is enabled) for functions taking zero through ten extra parameters.
define_generic_mw_handler! {}
define_generic_mw_handler! { T1 }
define_generic_mw_handler! { T1 T2 }
define_generic_mw_handler! { T1 T2 T3 }
define_generic_mw_handler! { T1 T2 T3 T4 }
define_generic_mw_handler! { T1 T2 T3 T4 T5 }
define_generic_mw_handler! { T1 T2 T3 T4 T5 T6 }
define_generic_mw_handler! { T1 T2 T3 T4 T5 T6 T7 }
define_generic_mw_handler! { T1 T2 T3 T4 T5 T6 T7 T8 }
define_generic_mw_handler! { T1 T2 T3 T4 T5 T6 T7 T8 T9 }
define_generic_mw_handler! { T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 }
#[cfg(test)]
mod tests {
    use super::{MapOk, Next, NextState, With};
    use crate::error::Error;
    use crate::{HttpBody, HttpResponse, status};
    use futures_util::task::noop_waker_ref;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Builds a `Next` that is already running a future resolving to a 204 response.
    fn running_no_content() -> Next {
        Next {
            state: Some(NextState::Running(Box::pin(async { status!(204) }))),
        }
    }

    /// The first poll completes the inner future; the second must report an error
    /// rather than panic or hang.
    #[test]
    fn next_returns_error_when_polled_after_completion() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let mut next = running_no_content();

        match Pin::new(&mut next).poll(&mut cx) {
            Poll::Ready(Ok(_)) => {}
            other => panic!("unexpected poll result: {other:?}"),
        }

        match Pin::new(&mut next).poll(&mut cx) {
            Poll::Ready(Err(err)) => {
                assert!(err.to_string().contains("Next polled after completion"));
            }
            other => panic!("expected error after completion, got {other:?}"),
        }
    }

    /// A closure taking one argument plus `Next` is callable through `With`.
    #[tokio::test]
    async fn middleware_handler_invokes_function_with_next() {
        let handler = |value: u8, next: Next| async move {
            assert_eq!(value, 7);
            next.await
        };

        let outcome = With::with(&handler, (7,), running_no_content()).await;
        let response = outcome.unwrap();
        assert_eq!(response.status(), 204);
    }

    /// A closure taking a response plus one argument is callable through `MapOk`.
    #[tokio::test]
    async fn map_ok_handler_invokes_function() {
        let response = HttpResponse::builder()
            .status(200)
            .body(HttpBody::from("ok"))
            .unwrap();

        let handler = |resp: HttpResponse, extra: &'static str| async move {
            assert_eq!(resp.status(), 200);
            assert_eq!(extra, "ok");
            Ok::<HttpResponse, Error>(resp)
        };

        let result = MapOk::map_ok(&handler, response, ("ok",)).await;
        assert!(result.is_ok());
    }
}