1use super::{util, Service, ServiceCtx, ServiceFactory};
2
/// Service that chains two services: the complete `Result` produced by the
/// first service (success *or* error) is handed to the second service as its
/// request.
#[derive(Clone, Debug)]
pub struct Then<A, B> {
    /// Service invoked first with the incoming request.
    svc1: A,
    /// Service invoked with the `Result` returned by `svc1`.
    svc2: B,
}
12
13impl<A, B> Then<A, B> {
14 pub(crate) fn new(svc1: A, svc2: B) -> Then<A, B> {
16 Self { svc1, svc2 }
17 }
18}
19
20impl<A, B, R> Service<R> for Then<A, B>
21where
22 A: Service<R>,
23 B: Service<Result<A::Response, A::Error>, Error = A::Error>,
24{
25 type Response = B::Response;
26 type Error = B::Error;
27
28 #[inline]
29 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
30 util::ready(&self.svc1, &self.svc2, ctx).await
31 }
32
33 #[inline]
34 fn poll(&self, cx: &mut std::task::Context<'_>) -> Result<(), Self::Error> {
35 self.svc1.poll(cx)?;
36 self.svc2.poll(cx)
37 }
38
39 #[inline]
40 async fn shutdown(&self) {
41 util::shutdown(&self.svc1, &self.svc2).await
42 }
43
44 #[inline]
45 async fn call(
46 &self,
47 req: R,
48 ctx: ServiceCtx<'_, Self>,
49 ) -> Result<Self::Response, Self::Error> {
50 ctx.call(&self.svc2, ctx.call(&self.svc1, req).await).await
51 }
52}
53
/// Factory producing a [`Then`] service out of two inner service factories.
#[derive(Clone, Debug)]
pub struct ThenFactory<A, B> {
    /// Factory for the service that is called first.
    svc1: A,
    /// Factory for the service that receives the first service's `Result`.
    svc2: B,
}
60
61impl<A, B> ThenFactory<A, B> {
62 pub(crate) fn new(svc1: A, svc2: B) -> Self {
64 Self { svc1, svc2 }
65 }
66}
67
68impl<A, B, R, C> ServiceFactory<R, C> for ThenFactory<A, B>
69where
70 A: ServiceFactory<R, C>,
71 B: ServiceFactory<
72 Result<A::Response, A::Error>,
73 C,
74 Error = A::Error,
75 InitError = A::InitError,
76 >,
77 C: Clone,
78{
79 type Response = B::Response;
80 type Error = A::Error;
81
82 type Service = Then<A::Service, B::Service>;
83 type InitError = A::InitError;
84
85 async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
86 Ok(Then {
87 svc1: self.svc1.create(cfg.clone()).await?,
88 svc2: self.svc2.create(cfg).await?,
89 })
90 }
91}
92
#[cfg(test)]
mod tests {
    use ntex_util::future::lazy;
    use std::{cell::Cell, rc::Rc, task::Context};

    use crate::{chain, chain_factory, fn_factory, Service, ServiceCtx};

    /// First test service: passes `Ok` messages through unchanged and maps any
    /// `Err` to `Err(())`.
    ///
    /// Field 0 counts `ready`/`poll` invocations; field 1 counts `shutdown`s.
    #[derive(Clone)]
    struct Srv1(Rc<Cell<usize>>, Rc<Cell<usize>>);

    impl Service<Result<&'static str, &'static str>> for Srv1 {
        type Response = &'static str;
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            self.0.set(self.0.get() + 1);
            Ok(())
        }

        fn poll(&self, _: &mut Context<'_>) -> Result<(), Self::Error> {
            self.0.set(self.0.get() + 1);
            Ok(())
        }

        async fn call(
            &self,
            req: Result<&'static str, &'static str>,
            _: ServiceCtx<'_, Self>,
        ) -> Result<&'static str, ()> {
            req.map_err(|_| ())
        }

        async fn shutdown(&self) {
            self.1.set(self.1.get() + 1);
        }
    }

    /// Second test service: consumes `Srv1`'s `Result`, tagging success with
    /// `"ok"` and turning an error into the `("srv2", "err")` response.
    ///
    /// Field 0 counts `ready`/`poll` invocations; field 1 counts `shutdown`s.
    #[derive(Clone)]
    struct Srv2(Rc<Cell<usize>>, Rc<Cell<usize>>);

    impl Service<Result<&'static str, ()>> for Srv2 {
        type Response = (&'static str, &'static str);
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            self.0.set(self.0.get() + 1);
            Ok(())
        }

        fn poll(&self, _: &mut Context<'_>) -> Result<(), Self::Error> {
            self.0.set(self.0.get() + 1);
            Ok(())
        }

        async fn call(
            &self,
            req: Result<&'static str, ()>,
            _: ServiceCtx<'_, Self>,
        ) -> Result<Self::Response, ()> {
            match req {
                Ok(msg) => Ok((msg, "ok")),
                Err(()) => Ok(("srv2", "err")),
            }
        }

        async fn shutdown(&self) {
            self.1.set(self.1.get() + 1);
        }
    }

    #[ntex::test]
    async fn test_ready() {
        let cnt = Rc::new(Cell::new(0));
        let cnt_sht = Rc::new(Cell::new(0));
        let srv = chain(Srv1(cnt.clone(), cnt_sht.clone()))
            .then(Srv2(cnt.clone(), cnt_sht.clone()))
            .into_pipeline();
        assert_eq!(srv.ready().await, Ok(()));
        // Both inner services must have been asked for readiness.
        assert_eq!(cnt.get(), 2);

        lazy(|cx| srv.clone().poll(cx)).await.unwrap();
        assert_eq!(cnt.get(), 4);

        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 2);
    }

    #[ntex::test]
    async fn test_call() {
        let cnt = Rc::new(Cell::new(0));
        // `.clone()` deliberately exercises the chain's `Clone` impl.
        let srv = chain(Srv1(cnt.clone(), Rc::new(Cell::new(0))))
            .then(Srv2(cnt, Rc::new(Cell::new(0))))
            .clone()
            .into_pipeline();

        assert_eq!(srv.call(Ok("srv1")).await, Ok(("srv1", "ok")));

        // An error from the first service is forwarded to the second one,
        // which turns it into a successful response.
        assert_eq!(srv.call(Err("srv")).await, Ok(("srv2", "err")));
    }

    #[ntex::test]
    async fn test_factory() {
        let cnt = Rc::new(Cell::new(0));
        let cnt2 = cnt.clone();
        let blank = fn_factory(move || {
            let cnt = cnt2.clone();
            async move { Ok::<_, ()>(Srv1(cnt, Rc::new(Cell::new(0)))) }
        });
        // `.clone()` deliberately exercises the factory's `Clone` impl.
        let factory = chain_factory(blank)
            .then(fn_factory(move || {
                let cnt = cnt.clone();
                async move { Ok(Srv2(cnt, Rc::new(Cell::new(0)))) }
            }))
            .clone();
        let srv = factory.pipeline(&()).await.unwrap();

        assert_eq!(srv.call(Ok("srv1")).await, Ok(("srv1", "ok")));
        assert_eq!(srv.call(Err("srv")).await, Ok(("srv2", "err")));
    }
}