1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use crate::{IntoServiceFactory, Service, ServiceFactory, dev::ServiceChainFactory};
4
5pub fn apply<M, S, R, C, U>(
7 mw: M,
8 factory: U,
9) -> ServiceChainFactory<ApplyMiddleware<M, S, C>, R, C>
10where
11 S: ServiceFactory<R, C>,
12 M: Middleware<S::Service, C>,
13 U: IntoServiceFactory<S, R, C>,
14{
15 ServiceChainFactory {
16 factory: ApplyMiddleware::new(mw, factory.into_factory()),
17 _t: PhantomData,
18 }
19}
20
21pub trait Middleware<Svc, Cfg = ()> {
91 type Service;
93
94 fn create(&self, service: Svc, cfg: Cfg) -> Self::Service;
96
97 fn apply<Fac, Req>(
102 self,
103 factory: Fac,
104 ) -> ServiceChainFactory<ApplyMiddleware<Self, Fac, Cfg>, Req, Cfg>
105 where
106 Fac: ServiceFactory<Req, Cfg, Service = Svc>,
107 Cfg: Clone,
108 Self: Sized,
109 Self::Service: Service<Req>,
110 {
111 crate::chain_factory(ApplyMiddleware::new(self, factory))
112 }
113}
114
115impl<M, Svc, Cfg> Middleware<Svc, Cfg> for Rc<M>
116where
117 M: Middleware<Svc, Cfg>,
118{
119 type Service = M::Service;
120
121 fn create(&self, service: Svc, cfg: Cfg) -> M::Service {
122 self.as_ref().create(service, cfg)
123 }
124}
125
/// Pairs a middleware with a service factory.
///
/// The `(middleware, factory)` tuple is kept behind a single `Rc`, so clones
/// of an `ApplyMiddleware` share the same pair. `Cfg` is only a type marker.
pub struct ApplyMiddleware<M, Fac, Cfg>(Rc<(M, Fac)>, PhantomData<Cfg>);

impl<M, Fac, Cfg> ApplyMiddleware<M, Fac, Cfg> {
    /// Create a new `ApplyMiddleware` from a middleware and a factory.
    pub(crate) fn new(mw: M, fac: Fac) -> Self {
        let shared = Rc::new((mw, fac));
        ApplyMiddleware(shared, PhantomData)
    }
}

impl<M, Fac, Cfg> Clone for ApplyMiddleware<M, Fac, Cfg> {
    /// Clones the `Rc` handle; the middleware and factory are not cloned.
    fn clone(&self) -> Self {
        ApplyMiddleware(Rc::clone(&self.0), PhantomData)
    }
}

impl<M, Fac, Cfg> fmt::Debug for ApplyMiddleware<M, Fac, Cfg>
where
    M: fmt::Debug,
    Fac: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (middleware, factory) = &*self.0;

        // The factory is printed before the middleware.
        let mut out = f.debug_struct("ApplyMiddleware");
        out.field("factory", factory);
        out.field("middleware", middleware);
        out.finish()
    }
}
154
155impl<M, Fac, Req, Cfg> ServiceFactory<Req, Cfg> for ApplyMiddleware<M, Fac, Cfg>
156where
157 Fac: ServiceFactory<Req, Cfg>,
158 M: Middleware<Fac::Service, Cfg>,
159 M::Service: Service<Req>,
160 Cfg: Clone,
161{
162 type Response = <M::Service as Service<Req>>::Response;
163 type Error = <M::Service as Service<Req>>::Error;
164
165 type Service = M::Service;
166 type InitError = Fac::InitError;
167
168 #[inline]
169 async fn create(&self, cfg: Cfg) -> Result<Self::Service, Self::InitError> {
170 Ok(self.0.0.create(self.0.1.create(cfg.clone()).await?, cfg))
171 }
172}
173
/// Middleware that does nothing: it returns the service it is given.
#[derive(Clone, Copy, Debug)]
pub struct Identity;
179
180impl<S, Cfg> Middleware<S, Cfg> for Identity {
181 type Service = S;
182
183 #[inline]
184 fn create(&self, service: S, _: Cfg) -> Self::Service {
185 service
186 }
187}
188
/// Two middlewares combined into one.
///
/// When used as a middleware, `inner` is applied to the service first and
/// `outer` is applied to the result.
#[derive(Clone, Debug)]
pub struct Stack<Inner, Outer> {
    inner: Inner,
    outer: Outer,
}

impl<Inner, Outer> Stack<Inner, Outer> {
    /// Create a stack from an `inner` and an `outer` value.
    pub fn new(inner: Inner, outer: Outer) -> Self {
        Self { inner, outer }
    }
}
201
202impl<S, Inner, Outer, C> Middleware<S, C> for Stack<Inner, Outer>
203where
204 Inner: Middleware<S, C>,
205 Outer: Middleware<Inner::Service, C>,
206 C: Clone,
207{
208 type Service = Outer::Service;
209
210 fn create(&self, service: S, cfg: C) -> Self::Service {
211 self.outer
212 .create(self.inner.create(service, cfg.clone()), cfg)
213 }
214}
215
#[cfg(test)]
#[allow(clippy::redundant_clone)] // the tests call `.clone()` on freshly built values
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{Pipeline, ServiceCtx, fn_service};

    /// Test middleware: adds one to the shared counter for every service it
    /// creates and hands the counter on to that service.
    #[derive(Debug, Clone)]
    struct Mw<R>(PhantomData<R>, Rc<Cell<usize>>);

    impl<S, R, C> Middleware<S, C> for Mw<R> {
        type Service = Srv<S, R>;

        fn create(&self, service: S, _: C) -> Self::Service {
            let counter = &self.1;
            counter.set(counter.get() + 1);
            Srv(service, PhantomData, Rc::clone(counter))
        }
    }

    /// Test service: forwards `ready` and `call` to the wrapped service and
    /// adds one to the shared counter on `shutdown`.
    #[derive(Debug, Clone)]
    struct Srv<S, R>(S, PhantomData<R>, Rc<Cell<usize>>);

    impl<S, R> Service<R> for Srv<S, R>
    where
        S: Service<R>,
    {
        type Response = S::Response;
        type Error = S::Error;

        async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            ctx.ready(&self.0).await
        }

        async fn call(
            &self,
            req: R,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<S::Response, S::Error> {
            ctx.call(&self.0, req).await
        }

        async fn shutdown(&self) {
            let counter = &self.2;
            counter.set(counter.get() + 1);
        }
    }

    /// `apply()` and `chain_factory(..).apply(..)` with an `Rc`-wrapped middleware.
    #[ntex::test]
    async fn middleware() {
        let counter = Rc::new(Cell::new(0));
        let mw = Rc::new(Mw(PhantomData, counter.clone()).clone());
        let factory = apply(
            mw,
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
        )
        .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert_eq!(res, Ok(20));
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        // One increment from `Mw::create`, one from `Srv::shutdown`.
        assert_eq!(counter.get(), 2);

        let mw = Rc::new(Mw(PhantomData, Rc::new(Cell::new(0))).clone());
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(mw)
                .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert_eq!(res, Ok(20));
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
    }

    /// `Middleware::apply` followed by `.boxed()`.
    #[ntex::test]
    async fn middleware_apply() {
        let counter = Rc::new(Cell::new(0));
        let factory = Mw(PhantomData, counter.clone())
            .apply(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
            .boxed();

        let srv = factory.pipeline(&()).await.unwrap();
        let res = srv.call(10).await;
        assert_eq!(res, Ok(20));
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        // One increment from `Mw::create`, one from `Srv::shutdown`.
        assert_eq!(counter.get(), 2);
    }

    /// Middleware applied through a factory chain.
    #[ntex::test]
    async fn middleware_chain() {
        let counter = Rc::new(Cell::new(0));
        let mw = Mw(PhantomData, counter.clone()).clone();
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(mw);

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert_eq!(res, Ok(20));
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        // One increment from `Mw::create`, one from `Srv::shutdown`.
        assert_eq!(counter.get(), 2);
    }

    /// `Stack` of `Identity` (inner) and `Mw` (outer), used directly as a middleware.
    #[ntex::test]
    async fn stack() {
        let counter = Rc::new(Cell::new(0));
        let mw = Stack::new(Identity, Mw(PhantomData, counter.clone()));
        let _ = format!("{mw:?}");

        let svc = Middleware::create(
            &mw,
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
            (),
        );
        let pl = Pipeline::new(svc);
        let res = pl.call(10).await;
        assert_eq!(res, Ok(20));
        assert_eq!(pl.ready().await, Ok(()));
        pl.shutdown().await;
        // One increment from `Mw::create`, one from `Srv::shutdown`.
        assert_eq!(counter.get(), 2);
    }
}