axum_help/filter/future.rs

1use super::predicate::AsyncPredicate;
2use axum::response::Response;
3use futures_core::ready;
4use pin_project_lite::pin_project;
5use std::{
6    future::Future,
7    marker::PhantomData,
8    pin::Pin,
9    task::{Context, Poll},
10};
11use tower::Service;
12
13pin_project! {
14    /// Filtered response future from [`FilterEx`] services.
15    ///
16    #[project = ResponseKind]
17    pub enum ResponseFuture<F> {
18        Future {#[pin] future: F },
19        Error { response: Option<Response> },
20    }
21}
22
23impl<F, E> Future for ResponseFuture<F>
24where
25    F: Future<Output = Result<Response, E>>,
26{
27    type Output = F::Output;
28
29    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
30        match self.project() {
31            ResponseKind::Future { future } => future.poll(cx),
32            ResponseKind::Error { response } => {
33                let response = response.take().unwrap();
34                Poll::Ready(Ok(response))
35            }
36        }
37    }
38}
39
40pin_project! {
41    /// Filtered response future from [`AsyncFilterEx`](super::AsyncFilterEx) services.
42    ///
43    pub struct AsyncResponseFuture<P, S, R>
44    where
45        P:  AsyncPredicate<R>,
46        S: Service<P::Request>,
47    {
48        #[pin]
49        state: State<P::Future, S::Future>,
50        service: S,
51        _p: PhantomData<P>
52    }
53}
54
55pin_project! {
56    #[project = StateProj]
57    #[derive(Debug)]
58    enum State<F, G> {
59        /// Waiting for the predicate future
60        Check { #[pin] check: F},
61        /// Waiting for the response future
62        WaitResponse { #[pin] response: G}
63    }
64}
65
66impl<P, S, R> AsyncResponseFuture<P, S, R>
67where
68    P: AsyncPredicate<R>,
69    S: Service<P::Request>,
70{
71    pub(super) fn new(check: P::Future, service: S) -> Self {
72        Self {
73            state: State::Check { check },
74            service,
75            _p: PhantomData,
76        }
77    }
78}
79
80impl<P, S, R> Future for AsyncResponseFuture<P, S, R>
81where
82    P: AsyncPredicate<R>,
83    S: Service<P::Request, Response = <P as AsyncPredicate<R>>::Response>,
84{
85    type Output = Result<S::Response, S::Error>;
86
87    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
88        let mut this = self.project();
89
90        loop {
91            match this.state.as_mut().project() {
92                StateProj::Check { mut check } => match ready!(check.as_mut().poll(cx)) {
93                    Ok(request) => {
94                        let response = this.service.call(request);
95                        this.state.set(State::WaitResponse { response });
96                    }
97                    Err(e) => {
98                        return Poll::Ready(Ok(e));
99                    }
100                },
101
102                StateProj::WaitResponse { response } => {
103                    return response.poll(cx);
104                }
105            }
106        }
107    }
108}