1use std::{fmt, task::Context};
3
4use ntex_service::{Service, ServiceCtx, ServiceFactory};
5
6use crate::future::Either;
7
/// A service wrapper that holds exactly one of two inner services.
///
/// Which side is held is decided when the value is built (via
/// `EitherService::left` / `EitherService::right`, or by
/// `EitherServiceFactory`). The `Service` impl forwards `ready`, `call`,
/// `poll` and `shutdown` to whichever side is held.
#[derive(Clone)]
pub struct EitherService<SLeft, SRight> {
    // The wrapped service: `Left` holds an `SLeft`, `Right` holds an `SRight`.
    svc: Either<SLeft, SRight>,
}
15
/// A service factory that builds an `EitherService` from one of two inner factories.
///
/// On each `create` call, `choose_left_fn` is invoked with a reference to the
/// configuration value. `true` builds the service with `left`; `false` builds it
/// with `right`.
#[derive(Clone)]
pub struct EitherServiceFactory<ChooseFn, SFLeft, SFRight> {
    // Factory used when `choose_left_fn` returns `true`.
    left: SFLeft,
    // Factory used when `choose_left_fn` returns `false`.
    right: SFRight,
    // Predicate over the configuration that selects the left factory.
    choose_left_fn: ChooseFn,
}
25
26impl<ChooseFn, SFLeft, SFRight> EitherServiceFactory<ChooseFn, SFLeft, SFRight> {
27 pub fn new(choose_left_fn: ChooseFn, sf_left: SFLeft, sf_right: SFRight) -> Self {
29 EitherServiceFactory {
30 choose_left_fn,
31 left: sf_left,
32 right: sf_right,
33 }
34 }
35}
36
37impl<ChooseFn, SFLeft, SFRight> fmt::Debug
38 for EitherServiceFactory<ChooseFn, SFLeft, SFRight>
39{
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 f.debug_struct("EitherServiceFactory")
42 .field("left", &std::any::type_name::<SFLeft>())
43 .field("right", &std::any::type_name::<SFRight>())
44 .field("choose_fn", &std::any::type_name::<ChooseFn>())
45 .finish()
46 }
47}
48
49impl<R, C, ChooseFn, SFLeft, SFRight> ServiceFactory<R, C>
50 for EitherServiceFactory<ChooseFn, SFLeft, SFRight>
51where
52 ChooseFn: Fn(&C) -> bool,
53 SFLeft: ServiceFactory<R, C>,
54 SFRight: ServiceFactory<
55 R,
56 C,
57 Response = SFLeft::Response,
58 InitError = SFLeft::InitError,
59 Error = SFLeft::Error,
60 >,
61{
62 type Response = SFLeft::Response;
63 type Error = SFLeft::Error;
64 type InitError = SFLeft::InitError;
65 type Service = EitherService<SFLeft::Service, SFRight::Service>;
66
67 async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
68 let choose_left = (self.choose_left_fn)(&cfg);
69
70 if choose_left {
71 let svc = self.left.create(cfg).await?;
72 Ok(EitherService {
73 svc: Either::Left(svc),
74 })
75 } else {
76 let svc = self.right.create(cfg).await?;
77 Ok(EitherService {
78 svc: Either::Right(svc),
79 })
80 }
81 }
82}
83
84impl<SLeft, SRight> EitherService<SLeft, SRight> {
85 pub fn left(svc: SLeft) -> Self {
87 EitherService {
88 svc: Either::Left(svc),
89 }
90 }
91
92 pub fn right(svc: SRight) -> Self {
94 EitherService {
95 svc: Either::Right(svc),
96 }
97 }
98}
99
100impl<SLeft, SRight> fmt::Debug for EitherService<SLeft, SRight> {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 f.debug_struct("EitherService")
103 .field("left", &std::any::type_name::<SLeft>())
104 .field("right", &std::any::type_name::<SRight>())
105 .finish()
106 }
107}
108
109impl<Req, SLeft, SRight> Service<Req> for EitherService<SLeft, SRight>
110where
111 SLeft: Service<Req>,
112 SRight: Service<Req, Response = SLeft::Response, Error = SLeft::Error>,
113{
114 type Response = SLeft::Response;
115 type Error = SLeft::Error;
116
117 #[inline]
118 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
119 match self.svc {
120 Either::Left(ref svc) => ctx.ready(svc).await,
121 Either::Right(ref svc) => ctx.ready(svc).await,
122 }
123 }
124
125 #[inline]
126 async fn shutdown(&self) {
127 match self.svc {
128 Either::Left(ref svc) => svc.shutdown().await,
129 Either::Right(ref svc) => svc.shutdown().await,
130 }
131 }
132
133 #[inline]
134 async fn call(
135 &self,
136 req: Req,
137 ctx: ServiceCtx<'_, Self>,
138 ) -> Result<Self::Response, Self::Error> {
139 match self.svc {
140 Either::Left(ref svc) => ctx.call(svc, req).await,
141 Either::Right(ref svc) => ctx.call(svc, req).await,
142 }
143 }
144
145 #[inline]
146 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
147 match self.svc {
148 Either::Left(ref svc) => svc.poll(cx),
149 Either::Right(ref svc) => svc.poll(cx),
150 }
151 }
152}
153
#[cfg(test)]
mod tests {
    use ntex_service::{Pipeline, ServiceFactory};

    use super::*;

    /// Test service that always answers `"svc1"`.
    #[derive(Copy, Clone, Debug, PartialEq)]
    struct Svc1;
    impl Service<()> for Svc1 {
        type Response = &'static str;
        type Error = ();

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<&'static str, ()> {
            Ok("svc1")
        }
    }

    /// Factory producing `Svc1`; the configuration value is ignored.
    #[derive(Clone)]
    struct Svc1Factory;
    impl ServiceFactory<(), &'static str> for Svc1Factory {
        type Response = &'static str;
        type Error = ();
        type InitError = ();
        type Service = Svc1;

        async fn create(&self, _: &'static str) -> Result<Self::Service, Self::InitError> {
            Ok(Svc1)
        }
    }

    /// Test service that always answers `"svc2"`.
    #[derive(Copy, Clone, Debug, PartialEq)]
    struct Svc2;
    impl Service<()> for Svc2 {
        type Response = &'static str;
        type Error = ();

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<&'static str, ()> {
            Ok("svc2")
        }
    }

    /// Factory producing `Svc2`; the configuration value is ignored.
    #[derive(Clone)]
    struct Svc2Factory;
    impl ServiceFactory<(), &'static str> for Svc2Factory {
        type Response = &'static str;
        type Error = ();
        type InitError = ();
        type Service = Svc2;

        async fn create(&self, _: &'static str) -> Result<Self::Service, Self::InitError> {
            Ok(Svc2)
        }
    }

    type Either = EitherService<Svc1, Svc2>;
    type EitherFactory<F> = EitherServiceFactory<F, Svc1Factory, Svc2Factory>;

    #[ntex_macros::rt_test2]
    async fn test_success() {
        // Both constructors yield the same concrete type, so they can share one table.
        let cases = [
            (Either::left(Svc1).clone(), "svc1"),
            (Either::right(Svc2).clone(), "svc2"),
        ];

        for (service, expected) in cases {
            let pipeline = Pipeline::new(service);
            assert_eq!(pipeline.call(()).await, Ok(expected));
            assert_eq!(pipeline.ready().await, Ok(()));
            pipeline.shutdown().await;
            assert!(format!("{pipeline:?}").contains("EitherService"));
        }
    }

    #[ntex_macros::rt_test2]
    async fn test_factory() {
        let pick_svc1 = |cfg: &&'static str| *cfg == "svc1";
        let factory = EitherFactory::new(pick_svc1, Svc1Factory, Svc2Factory).clone();
        assert!(format!("{factory:?}").contains("EitherServiceFactory"));

        // Any configuration other than "svc1" must fall through to the right side.
        for (cfg, expected) in [("svc1", "svc1"), ("other", "svc2")] {
            let pipeline = factory.pipeline(cfg).await.unwrap();
            assert_eq!(pipeline.call(()).await, Ok(expected));
        }
    }
}