1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use crate::{IntoServiceFactory, Service, ServiceFactory};
4
5pub fn apply<T, S, R, C, U>(t: T, factory: U) -> ApplyMiddleware<T, S, C>
7where
8 S: ServiceFactory<R, C>,
9 T: Middleware<S::Service>,
10 U: IntoServiceFactory<S, R, C>,
11{
12 ApplyMiddleware::new(t, factory.into_factory())
13}
14
15pub fn apply2<T, S, R, C, U>(t: T, factory: U) -> ApplyMiddleware2<T, S, C>
17where
18 S: ServiceFactory<R, C>,
19 T: Middleware2<S::Service, C>,
20 U: IntoServiceFactory<S, R, C>,
21{
22 ApplyMiddleware2::new(t, factory.into_factory())
23}
24
/// Turns a service of type `S` into another service.
///
/// Used by [`ApplyMiddleware`] (wrapping each service a factory builds) and by
/// [`Stack`] (chaining two middlewares).
pub trait Middleware<S> {
    /// The service type produced by wrapping `S`.
    type Service;

    /// Consumes `service` and returns the wrapped service.
    fn create(&self, service: S) -> Self::Service;
}
102
/// Like [`Middleware`], but `create` also receives the factory configuration.
///
/// `Cfg` defaults to `()` for middlewares that do not need any configuration.
pub trait Middleware2<S, Cfg = ()> {
    /// The service type produced by wrapping `S`.
    type Service;

    /// Consumes `service` and returns the wrapped service, given `cfg`.
    fn create(&self, service: S, cfg: Cfg) -> Self::Service;
}
180
181impl<T, S> Middleware<S> for Rc<T>
182where
183 T: Middleware<S>,
184{
185 type Service = T::Service;
186
187 fn create(&self, service: S) -> T::Service {
188 self.as_ref().create(service)
189 }
190}
191
192impl<T, S, C> Middleware2<S, C> for Rc<T>
193where
194 T: Middleware2<S, C>,
195{
196 type Service = T::Service;
197
198 fn create(&self, service: S, cfg: C) -> T::Service {
199 self.as_ref().create(service, cfg)
200 }
201}
202
/// Service factory returned by [`apply`]: builds services with the inner
/// factory `S`, then wraps each one with the middleware `T`.
///
/// The middleware/factory pair lives behind one `Rc`, so cloning this value
/// only bumps a reference count and needs neither `T: Clone` nor `S: Clone`.
pub struct ApplyMiddleware<T, S, C>(Rc<(T, S)>, PhantomData<C>);

impl<T, S, C> ApplyMiddleware<T, S, C> {
    /// Pairs the middleware `mw` with the service factory `svc`.
    pub(crate) fn new(mw: T, svc: S) -> Self {
        let shared = Rc::new((mw, svc));
        ApplyMiddleware(shared, PhantomData)
    }
}

impl<T, S, C> Clone for ApplyMiddleware<T, S, C> {
    /// Returns a handle to the same shared middleware/factory pair.
    fn clone(&self) -> Self {
        ApplyMiddleware(Rc::clone(&self.0), PhantomData)
    }
}

impl<T, S, C> fmt::Debug for ApplyMiddleware<T, S, C>
where
    T: fmt::Debug,
    S: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (middleware, service) = &*self.0;
        f.debug_struct("ApplyMiddleware")
            .field("service", service)
            .field("middleware", middleware)
            .finish()
    }
}
231
232impl<T, S, R, C> ServiceFactory<R, C> for ApplyMiddleware<T, S, C>
233where
234 S: ServiceFactory<R, C>,
235 T: Middleware<S::Service>,
236 T::Service: Service<R>,
237{
238 type Response = <T::Service as Service<R>>::Response;
239 type Error = <T::Service as Service<R>>::Error;
240
241 type Service = T::Service;
242 type InitError = S::InitError;
243
244 #[inline]
245 async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
246 Ok(self.0.0.create(self.0.1.create(cfg).await?))
247 }
248}
249
/// Service factory returned by [`apply2`]: builds services with the inner
/// factory `S`, then wraps each one with the config-aware middleware `T`.
///
/// The middleware/factory pair lives behind one `Rc`, so cloning this value
/// only bumps a reference count and needs neither `T: Clone` nor `S: Clone`.
pub struct ApplyMiddleware2<T, S, C>(Rc<(T, S)>, PhantomData<C>);

impl<T, S, C> ApplyMiddleware2<T, S, C> {
    /// Pairs the middleware `mw` with the service factory `svc`.
    pub(crate) fn new(mw: T, svc: S) -> Self {
        Self(Rc::new((mw, svc)), PhantomData)
    }
}

impl<T, S, C> Clone for ApplyMiddleware2<T, S, C> {
    /// Returns a handle to the same shared middleware/factory pair.
    fn clone(&self) -> Self {
        Self(Rc::clone(&self.0), PhantomData)
    }
}

impl<T, S, C> fmt::Debug for ApplyMiddleware2<T, S, C>
where
    T: fmt::Debug,
    S: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (middleware, service) = &*self.0;
        // Report this type's own name; it previously printed "ApplyMiddleware",
        // making the two factory kinds indistinguishable in debug output.
        f.debug_struct("ApplyMiddleware2")
            .field("service", service)
            .field("middleware", middleware)
            .finish()
    }
}
278
279impl<T, S, R, C> ServiceFactory<R, C> for ApplyMiddleware2<T, S, C>
280where
281 S: ServiceFactory<R, C>,
282 T: Middleware2<S::Service, C>,
283 T::Service: Service<R>,
284 C: Clone,
285{
286 type Response = <T::Service as Service<R>>::Response;
287 type Error = <T::Service as Service<R>>::Error;
288
289 type Service = T::Service;
290 type InitError = S::InitError;
291
292 #[inline]
293 async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
294 Ok(self.0.0.create(self.0.1.create(cfg.clone()).await?, cfg))
295 }
296}
297
/// No-op middleware: implements both [`Middleware`] and [`Middleware2`] by
/// returning the given service unchanged.
#[derive(Clone, Copy, Debug)]
pub struct Identity;
303
304impl<S> Middleware<S> for Identity {
305 type Service = S;
306
307 #[inline]
308 fn create(&self, service: S) -> Self::Service {
309 service
310 }
311}
312
313impl<S, Cfg> Middleware2<S, Cfg> for Identity {
314 type Service = S;
315
316 #[inline]
317 fn create(&self, service: S, _: Cfg) -> Self::Service {
318 service
319 }
320}
321
/// Composes two middlewares: `inner` wraps the service first and `outer`
/// wraps whatever `inner` produced.
#[derive(Clone, Debug)]
pub struct Stack<Inner, Outer> {
    /// Middleware applied first, directly to the service.
    inner: Inner,
    /// Middleware applied to the output of `inner`.
    outer: Outer,
}

impl<Inner, Outer> Stack<Inner, Outer> {
    /// Creates a stack that applies `inner` first and then `outer`.
    pub fn new(inner: Inner, outer: Outer) -> Self {
        Self { inner, outer }
    }
}
334
335impl<S, Inner, Outer> Middleware<S> for Stack<Inner, Outer>
336where
337 Inner: Middleware<S>,
338 Outer: Middleware<Inner::Service>,
339{
340 type Service = Outer::Service;
341
342 fn create(&self, service: S) -> Self::Service {
343 self.outer.create(self.inner.create(service))
344 }
345}
346
347impl<S, Inner, Outer, C> Middleware2<S, C> for Stack<Inner, Outer>
348where
349 Inner: Middleware2<S, C>,
350 Outer: Middleware2<Inner::Service, C>,
351 C: Clone,
352{
353 type Service = Outer::Service;
354
355 fn create(&self, service: S, cfg: C) -> Self::Service {
356 self.outer
357 .create(self.inner.create(service, cfg.clone()), cfg)
358 }
359}
360
#[cfg(test)]
// The tests deliberately call `.clone()` on values that are never reused, to
// exercise the `Clone` impls of the factories, middleware and services.
#[allow(clippy::redundant_clone)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{Pipeline, ServiceCtx, fn_service};

    /// Test middleware. `R` is the request type of the wrapped service; the
    /// shared counter is handed to every `Srv` this middleware creates.
    #[derive(Debug, Clone)]
    struct Mw<R>(PhantomData<R>, Rc<Cell<usize>>);

    // `Middleware` path: wraps the service without touching the counter.
    impl<S, R> Middleware<S> for Mw<R> {
        type Service = Srv<S, R>;

        fn create(&self, service: S) -> Self::Service {
            Srv(service, PhantomData, self.1.clone())
        }
    }

    // `Middleware2` path: bumps the counter once per `create` call, so tests
    // can observe that the config-aware path was taken.
    impl<S, R, C> Middleware2<S, C> for Mw<R> {
        type Service = Srv<S, R>;

        fn create(&self, service: S, _: C) -> Self::Service {
            self.1.set(self.1.get() + 1);
            Srv(service, PhantomData, self.1.clone())
        }
    }

    /// Pass-through service: forwards `ready`/`call` to the wrapped service and
    /// bumps the shared counter in `shutdown`.
    #[derive(Debug, Clone)]
    struct Srv<S, R>(S, PhantomData<R>, Rc<Cell<usize>>);

    impl<S: Service<R>, R> Service<R> for Srv<S, R> {
        type Response = S::Response;
        type Error = S::Error;

        async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            ctx.ready(&self.0).await
        }

        async fn call(
            &self,
            req: R,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<S::Response, S::Error> {
            ctx.call(&self.0, req).await
        }

        async fn shutdown(&self) {
            self.2.set(self.2.get() + 1);
        }
    }

    // `apply` / chain `.apply` with an `Rc`-wrapped `Middleware`.
    #[ntex::test]
    async fn middleware() {
        let cnt_sht = Rc::new(Cell::new(0));
        let factory = apply(
            Rc::new(Mw(PhantomData, cnt_sht.clone()).clone()),
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
        )
        .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        // Only checks that the `Debug` impls compile and do not panic.
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        // `Middleware::create` leaves the counter alone, so the single bump
        // expected here comes from `Srv::shutdown`.
        assert_eq!(cnt_sht.get(), 1);

        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(Rc::new(Mw(PhantomData, Rc::new(Cell::new(0))).clone()))
                .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
    }

    // `apply2` with an `Rc`-wrapped `Middleware2`.
    #[ntex::test]
    async fn middleware2() {
        let cnt_sht = Rc::new(Cell::new(0));
        let factory = apply2(
            Rc::new(Mw(PhantomData, cnt_sht.clone()).clone()),
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
        )
        .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        // One bump from `Middleware2::create` plus one from `Srv::shutdown`.
        assert_eq!(cnt_sht.get(), 2);

        // NOTE(review): this half uses chain `.apply` (the `Middleware` path),
        // not `.apply2`; `apply2` on a chain is covered by `middleware2_chain`.
        // Confirm this duplication of the `middleware` test is intentional.
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(Rc::new(Mw(PhantomData, Rc::new(Cell::new(0))).clone()))
                .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
    }

    // Chain `.apply2` with a plain (non-`Rc`) `Middleware2`.
    #[ntex::test]
    async fn middleware2_chain() {
        let cnt_sht = Rc::new(Cell::new(0));
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply2(Mw(PhantomData, cnt_sht.clone()).clone());

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        // One bump from `Middleware2::create` plus one from `Srv::shutdown`.
        assert_eq!(cnt_sht.get(), 2);
    }
}