axum_help/filter/
future.rs1use super::predicate::AsyncPredicate;
2use axum::response::Response;
3use futures_core::ready;
4use pin_project_lite::pin_project;
5use std::{
6 future::Future,
7 marker::PhantomData,
8 pin::Pin,
9 task::{Context, Poll},
10};
11use tower::Service;
12
13pin_project! {
14 #[project = ResponseKind]
17 pub enum ResponseFuture<F> {
18 Future {#[pin] future: F },
19 Error { response: Option<Response> },
20 }
21}
22
23impl<F, E> Future for ResponseFuture<F>
24where
25 F: Future<Output = Result<Response, E>>,
26{
27 type Output = F::Output;
28
29 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
30 match self.project() {
31 ResponseKind::Future { future } => future.poll(cx),
32 ResponseKind::Error { response } => {
33 let response = response.take().unwrap();
34 Poll::Ready(Ok(response))
35 }
36 }
37 }
38}
39
40pin_project! {
41 pub struct AsyncResponseFuture<P, S, R>
44 where
45 P: AsyncPredicate<R>,
46 S: Service<P::Request>,
47 {
48 #[pin]
49 state: State<P::Future, S::Future>,
50 service: S,
51 _p: PhantomData<P>
52 }
53}
54
55pin_project! {
56 #[project = StateProj]
57 #[derive(Debug)]
58 enum State<F, G> {
59 Check { #[pin] check: F},
61 WaitResponse { #[pin] response: G}
63 }
64}
65
66impl<P, S, R> AsyncResponseFuture<P, S, R>
67where
68 P: AsyncPredicate<R>,
69 S: Service<P::Request>,
70{
71 pub(super) fn new(check: P::Future, service: S) -> Self {
72 Self {
73 state: State::Check { check },
74 service,
75 _p: PhantomData,
76 }
77 }
78}
79
80impl<P, S, R> Future for AsyncResponseFuture<P, S, R>
81where
82 P: AsyncPredicate<R>,
83 S: Service<P::Request, Response = <P as AsyncPredicate<R>>::Response>,
84{
85 type Output = Result<S::Response, S::Error>;
86
87 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
88 let mut this = self.project();
89
90 loop {
91 match this.state.as_mut().project() {
92 StateProj::Check { mut check } => match ready!(check.as_mut().poll(cx)) {
93 Ok(request) => {
94 let response = this.service.call(request);
95 this.state.set(State::WaitResponse { response });
96 }
97 Err(e) => {
98 return Poll::Ready(Ok(e));
99 }
100 },
101
102 StateProj::WaitResponse { response } => {
103 return response.poll(cx);
104 }
105 }
106 }
107 }
108}