async_svc/
then.rs

1use std::{
2    future::Future,
3    marker::PhantomData,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use futures_util::ready;
9use pin_project::pin_project;
10
11use crate::Svc;
12
13/// Service1(Response) -> Intermediate => Service2(Intermediate) -> Response
14#[pin_project]
15pub struct ThenSvc<S1, S2, Int, Res> {
16    #[pin]
17    svc1: S1,
18    svc2: Option<S2>,
19    _int: PhantomData<Int>,
20    _res: PhantomData<Res>,
21}
22
23impl<S1, S2, Int, Res> ThenSvc<S1, S2, Int, Res> {
24    pub fn new(svc1: S1, svc2: S2) -> Self {
25        Self {
26            svc1,
27            svc2: Some(svc2),
28            _int: PhantomData,
29            _res: PhantomData,
30        }
31    }
32}
33
34impl<S1, S2, Req, Int, Res> Svc<Req> for ThenSvc<S1, S2, Int, Res>
35where
36    S1: Svc<Req, Res = Int>,
37    S2: Svc<Int, Res = Res>,
38{
39    type Res = Res;
40    type Fut = ThenSvcFut<S1, S2, Req, Int, Res>;
41
42    fn exec(self: Pin<&mut Self>, req: Req) -> Self::Fut {
43        let this = self.project();
44        ThenSvcFut {
45            state: State::Svc1(
46                this.svc1.exec(req),
47                this.svc2.take().expect("Service must not be executed twice."),
48            ),
49        }
50    }
51}
52
53#[pin_project]
54pub struct ThenSvcFut<S1, S2, Req, Int, Res>
55where
56    S1: Svc<Req, Res = Int>,
57    S2: Svc<Int, Res = Res>,
58{
59    #[pin]
60    state: State<S1, S2, Req, Int, Res>,
61}
62
63#[pin_project(project = StateProj)]
64enum State<S1, S2, Req, Int, Res>
65where
66    S1: Svc<Req, Res = Int>,
67    S2: Svc<Int, Res = Res>,
68{
69    Svc1(#[pin] S1::Fut, #[pin] S2),
70    Svc2(#[pin] S2::Fut),
71    Done,
72}
73
74impl<S1, S2, Req, Int, Res> Future for ThenSvcFut<S1, S2, Req, Int, Res>
75where
76    S1: Svc<Req, Res = Int>,
77    S2: Svc<Int, Res = Res>,
78{
79    type Output = Res;
80
81    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
82        let mut this = self.as_mut().project();
83
84        match this.state.as_mut().project() {
85            StateProj::Svc1(s1_fut, s2) => {
86                let s1_res = ready!(s1_fut.poll(cx));
87                let s2_exec = s2.exec(s1_res);
88                this.state.set(State::Svc2(s2_exec));
89                self.poll(cx)
90            }
91            StateProj::Svc2(s2_fut) => {
92                let s2_res = ready!(s2_fut.poll(cx));
93                this.state.set(State::Done);
94                Poll::Ready(s2_res)
95            }
96            StateProj::Done => panic!("Future must not be polled after it returned `Poll::Ready`."),
97        }
98    }
99}
100
#[cfg(test)]
mod tests {
    use futures_util::pin_mut;

    use super::*;
    use crate::FnSvc;

    /// Shared test service body: doubles its input.
    async fn doubler(n: u64) -> u64 {
        n * 2
    }

    #[tokio::test]
    async fn test_then_service() {
        let svc1 = FnSvc::new(doubler);
        let svc2 = FnSvc::new(doubler);
        let bnf = ThenSvc::new(svc1, svc2);

        pin_mut!(bnf);

        // 42 -> 84 (svc1) -> 168 (svc2)
        let res = bnf.exec(42).await;
        assert_eq!(res, 168);
    }

    #[tokio::test]
    #[should_panic]
    async fn test_then_service_exec_twice_panics() {
        let svc = ThenSvc::new(FnSvc::new(doubler), FnSvc::new(doubler));
        pin_mut!(svc);

        let _first = svc.as_mut().exec(1);
        // `svc2` was moved into `_first`, so a second execution must panic.
        let _second = svc.as_mut().exec(2);
    }

    #[tokio::test]
    #[should_panic(expected = "Future must not be polled after it returned `Poll::Ready`.")]
    async fn test_then_service_fut_polled_after_ready_panics() {
        let svc = ThenSvc::new(FnSvc::new(doubler), FnSvc::new(doubler));
        pin_mut!(svc);

        let fut = svc.exec(21);
        pin_mut!(fut);

        // 21 -> 42 (svc1) -> 84 (svc2)
        assert_eq!(fut.as_mut().await, 84);
        // The future is now in the `Done` state; polling it again must panic.
        fut.as_mut().await;
    }
}
123}