1use std::{
2 future::Future,
3 marker::PhantomData,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use futures_util::ready;
9use pin_project::pin_project;
10
11use crate::Svc;
12
13#[pin_project]
15pub struct ThenSvc<S1, S2, Int, Res> {
16 #[pin]
17 svc1: S1,
18 svc2: Option<S2>,
19 _int: PhantomData<Int>,
20 _res: PhantomData<Res>,
21}
22
23impl<S1, S2, Int, Res> ThenSvc<S1, S2, Int, Res> {
24 pub fn new(svc1: S1, svc2: S2) -> Self {
25 Self {
26 svc1,
27 svc2: Some(svc2),
28 _int: PhantomData,
29 _res: PhantomData,
30 }
31 }
32}
33
34impl<S1, S2, Req, Int, Res> Svc<Req> for ThenSvc<S1, S2, Int, Res>
35where
36 S1: Svc<Req, Res = Int>,
37 S2: Svc<Int, Res = Res>,
38{
39 type Res = Res;
40 type Fut = ThenSvcFut<S1, S2, Req, Int, Res>;
41
42 fn exec(self: Pin<&mut Self>, req: Req) -> Self::Fut {
43 let this = self.project();
44 ThenSvcFut {
45 state: State::Svc1(
46 this.svc1.exec(req),
47 this.svc2.take().expect("Service must not be executed twice."),
48 ),
49 }
50 }
51}
52
53#[pin_project]
54pub struct ThenSvcFut<S1, S2, Req, Int, Res>
55where
56 S1: Svc<Req, Res = Int>,
57 S2: Svc<Int, Res = Res>,
58{
59 #[pin]
60 state: State<S1, S2, Req, Int, Res>,
61}
62
63#[pin_project(project = StateProj)]
64enum State<S1, S2, Req, Int, Res>
65where
66 S1: Svc<Req, Res = Int>,
67 S2: Svc<Int, Res = Res>,
68{
69 Svc1(#[pin] S1::Fut, #[pin] S2),
70 Svc2(#[pin] S2::Fut),
71 Done,
72}
73
74impl<S1, S2, Req, Int, Res> Future for ThenSvcFut<S1, S2, Req, Int, Res>
75where
76 S1: Svc<Req, Res = Int>,
77 S2: Svc<Int, Res = Res>,
78{
79 type Output = Res;
80
81 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
82 let mut this = self.as_mut().project();
83
84 match this.state.as_mut().project() {
85 StateProj::Svc1(s1_fut, s2) => {
86 let s1_res = ready!(s1_fut.poll(cx));
87 let s2_exec = s2.exec(s1_res);
88 this.state.set(State::Svc2(s2_exec));
89 self.poll(cx)
90 }
91 StateProj::Svc2(s2_fut) => {
92 let s2_res = ready!(s2_fut.poll(cx));
93 this.state.set(State::Done);
94 Poll::Ready(s2_res)
95 }
96 StateProj::Done => panic!("Future must not be polled after it returned `Poll::Ready`."),
97 }
98 }
99}
100
#[cfg(test)]
mod tests {
    use futures_util::pin_mut;

    use super::*;
    use crate::FnSvc;

    async fn doubler(n: u64) -> u64 {
        n * 2
    }

    /// Chaining two doublers multiplies the request by four.
    #[tokio::test]
    async fn test_then_service() {
        let bnf = ThenSvc::new(FnSvc::new(doubler), FnSvc::new(doubler));
        pin_mut!(bnf);

        let res = bnf.exec(42).await;
        assert_eq!(res, 168);
    }

    /// `svc2` is handed to the first future, so executing the combinator again must panic.
    #[tokio::test]
    #[should_panic(expected = "Service must not be executed twice.")]
    async fn test_then_service_exec_twice_panics() {
        let svc = ThenSvc::new(FnSvc::new(doubler), FnSvc::new(doubler));
        pin_mut!(svc);

        assert_eq!(svc.as_mut().exec(1).await, 4);
        let _ = svc.as_mut().exec(1);
    }
}