1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use crate::dev::{Apply, ApplyCtx, ServiceChainFactory};
4use crate::{IntoServiceFactory, Service, ServiceFactory};
5
6pub fn apply<M, S, R, C, U>(
8 mw: M,
9 factory: U,
10) -> ServiceChainFactory<ApplyMiddleware<M, S, C>, R, C>
11where
12 S: ServiceFactory<R, C>,
13 M: Middleware<S::Service, C>,
14 U: IntoServiceFactory<S, R, C>,
15{
16 ServiceChainFactory {
17 factory: ApplyMiddleware::new(mw, factory.into_factory()),
18 _t: PhantomData,
19 }
20}
21
22pub trait Middleware<Svc, Cfg = ()> {
92 type Service;
94
95 fn create(&self, service: Svc, cfg: Cfg) -> Self::Service;
97
98 fn apply<Fac, Req>(
103 self,
104 factory: Fac,
105 ) -> ServiceChainFactory<ApplyMiddleware<Self, Fac, Cfg>, Req, Cfg>
106 where
107 Fac: ServiceFactory<Req, Cfg, Service = Svc>,
108 Cfg: Clone,
109 Self: Sized,
110 Self::Service: Service<Req>,
111 {
112 crate::chain_factory(ApplyMiddleware::new(self, factory))
113 }
114}
115
116impl<M, Svc, Cfg> Middleware<Svc, Cfg> for Rc<M>
117where
118 M: Middleware<Svc, Cfg>,
119{
120 type Service = M::Service;
121
122 fn create(&self, service: Svc, cfg: Cfg) -> M::Service {
123 self.as_ref().create(service, cfg)
124 }
125}
126
/// Service factory that pairs a middleware (`M`) with an inner service factory (`Fac`).
///
/// The pair lives behind an `Rc`, so clones of this factory share one middleware and one
/// inner factory. `Cfg` only records the configuration type; no value of it is stored.
pub struct ApplyMiddleware<M, Fac, Cfg>(Rc<(M, Fac)>, PhantomData<Cfg>);
129
130impl<M, Fac, Cfg> ApplyMiddleware<M, Fac, Cfg> {
131 pub(crate) fn new(mw: M, fac: Fac) -> Self {
133 Self(Rc::new((mw, fac)), PhantomData)
134 }
135}
136
137impl<M, Fac, Cfg> Clone for ApplyMiddleware<M, Fac, Cfg> {
138 fn clone(&self) -> Self {
139 Self(self.0.clone(), PhantomData)
140 }
141}
142
143impl<M, Fac, Cfg> fmt::Debug for ApplyMiddleware<M, Fac, Cfg>
144where
145 M: fmt::Debug,
146 Fac: fmt::Debug,
147{
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149 f.debug_struct("ApplyMiddleware")
150 .field("factory", &self.0.1)
151 .field("middleware", &self.0.0)
152 .finish()
153 }
154}
155
156impl<M, Fac, Req, Cfg> ServiceFactory<Req, Cfg> for ApplyMiddleware<M, Fac, Cfg>
157where
158 Fac: ServiceFactory<Req, Cfg>,
159 M: Middleware<Fac::Service, Cfg>,
160 M::Service: Service<Req>,
161 Cfg: Clone,
162{
163 type Response = <M::Service as Service<Req>>::Response;
164 type Error = <M::Service as Service<Req>>::Error;
165
166 type Service = M::Service;
167 type InitError = Fac::InitError;
168
169 #[inline]
170 async fn create(&self, cfg: Cfg) -> Result<Self::Service, Self::InitError> {
171 Ok(self.0.0.create(self.0.1.create(cfg.clone()).await?, cfg))
172 }
173}
174
/// Middleware that leaves the wrapped service untouched.
///
/// Its `Middleware::create` returns the given service as-is and ignores the configuration.
#[derive(Debug, Clone, Copy)]
pub struct Identity;
180
181impl<S, Cfg> Middleware<S, Cfg> for Identity {
182 type Service = S;
183
184 #[inline]
185 fn create(&self, service: S, _: Cfg) -> Self::Service {
186 service
187 }
188}
189
/// Composition of two middlewares.
///
/// `inner` wraps the original service first, and `outer` wraps the result of `inner`.
#[derive(Debug, Clone)]
pub struct Stack<Inner, Outer> {
    /// Applied first, directly to the original service.
    inner: Inner,
    /// Applied second, to the service produced by `inner`.
    outer: Outer,
}
196
197impl<Inner, Outer> Stack<Inner, Outer> {
198 pub fn new(inner: Inner, outer: Outer) -> Self {
199 Stack { inner, outer }
200 }
201}
202
203impl<S, Inner, Outer, C> Middleware<S, C> for Stack<Inner, Outer>
204where
205 Inner: Middleware<S, C>,
206 Outer: Middleware<Inner::Service, C>,
207 C: Clone,
208{
209 type Service = Outer::Service;
210
211 fn create(&self, service: S, cfg: C) -> Self::Service {
212 self.outer
213 .create(self.inner.create(service, cfg.clone()), cfg)
214 }
215}
216
217#[doc(hidden)]
218pub fn fn_layer<T, Req, F, In, Out, Err>(f: F) -> FnMiddleware<T, Req, F, In, Out, Err>
220where
221 F: AsyncFn(In, &ApplyCtx<'_, T>) -> Result<Out, Err> + Clone,
222{
223 FnMiddleware { f, r: PhantomData }
224}
225
/// Middleware built from an async function; see `fn_layer`.
///
/// Only the function `f` is stored. The remaining type parameters are carried through
/// `PhantomData`.
// The `fn(T, Req) -> (In, Out, Err)` marker type triggers `type_complexity`; it is only a
// compile-time marker, so the lint is allowed here.
#[allow(clippy::type_complexity)]
pub struct FnMiddleware<T, Req, F, In, Out, Err> {
    /// Async function invoked by the services this middleware creates.
    f: F,
    /// Type-level marker for the otherwise unused parameters.
    r: PhantomData<fn(T, Req) -> (In, Out, Err)>,
}
232
233impl<T, Req, F, In, Out, Err> Clone for FnMiddleware<T, Req, F, In, Out, Err>
234where
235 F: Clone,
236{
237 fn clone(&self) -> Self {
238 FnMiddleware {
239 f: self.f.clone(),
240 r: PhantomData,
241 }
242 }
243}
244
245impl<T, Req, F, In, Out, Err> fmt::Debug for FnMiddleware<T, Req, F, In, Out, Err> {
246 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247 f.debug_struct("FnMiddleware")
248 .field("layer", &std::any::type_name::<F>())
249 .finish()
250 }
251}
252
253impl<T, C, R, F, In, Out, Err> Middleware<T, C> for FnMiddleware<T, R, F, In, Out, Err>
254where
255 T: Service<R>,
256 F: AsyncFn(In, &ApplyCtx<'_, T>) -> Result<Out, Err> + Clone,
257 Err: From<T::Error>,
258{
259 type Service = Apply<T, R, F, In, Out, Err>;
260
261 fn create(&self, service: T, _: C) -> Self::Service {
262 Apply::new(service, self.f.clone())
263 }
264}
265
#[cfg(test)]
// The seemingly redundant `.clone()` calls are intentional: they exercise the `Clone`
// impls of the middleware, factories and services under test.
#[allow(clippy::redundant_clone)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{Pipeline, ServiceCtx, fn_service};

    /// Test middleware: increments the shared counter each time it creates a service.
    #[derive(Debug, Clone)]
    struct Mw<R>(PhantomData<R>, Rc<Cell<usize>>);

    impl<S, R, C> Middleware<S, C> for Mw<R> {
        type Service = Srv<S, R>;

        fn create(&self, service: S, _: C) -> Self::Service {
            self.1.set(self.1.get() + 1);
            Srv(service, PhantomData, self.1.clone())
        }
    }

    /// Pass-through test service: delegates `ready`/`call` to the wrapped service and
    /// increments the shared counter on `shutdown`.
    #[derive(Debug, Clone)]
    struct Srv<S, R>(S, PhantomData<R>, Rc<Cell<usize>>);

    impl<S: Service<R>, R> Service<R> for Srv<S, R> {
        type Response = S::Response;
        type Error = S::Error;

        async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            ctx.ready(&self.0).await
        }

        async fn call(
            &self,
            req: R,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<S::Response, S::Error> {
            ctx.call(&self.0, req).await
        }

        async fn shutdown(&self) {
            self.2.set(self.2.get() + 1);
        }
    }

    #[ntex::test]
    async fn middleware() {
        // Free `apply` function with an `Rc`-wrapped middleware.
        let cnt_sht = Rc::new(Cell::new(0));
        let factory = apply(
            Rc::new(Mw(PhantomData, cnt_sht.clone()).clone()),
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
        )
        .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        // One `create` plus one `shutdown`.
        assert_eq!(cnt_sht.get(), 2);

        // `ServiceChainFactory::apply` with an `Rc`-wrapped middleware. Keep a handle on the
        // counter so this path's create/shutdown accounting is verified as well.
        let cnt_sht = Rc::new(Cell::new(0));
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(Rc::new(Mw(PhantomData, cnt_sht.clone()).clone()))
                .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 2);
    }

    #[ntex::test]
    async fn middleware_apply() {
        let cnt_sht = Rc::new(Cell::new(0));
        let factory = Mw(PhantomData, cnt_sht.clone())
            .apply(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
            .boxed();

        let srv = factory.pipeline(&()).await.unwrap();
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 2);
    }

    #[ntex::test]
    async fn middleware_chain() {
        let cnt_sht = Rc::new(Cell::new(0));
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(Mw(PhantomData, cnt_sht.clone()).clone());

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 2);
    }

    #[ntex::test]
    async fn stack() {
        let cnt_sht = Rc::new(Cell::new(0));
        let mw = Stack::new(Identity, Mw(PhantomData, cnt_sht.clone()));
        let _ = format!("{mw:?}");

        let pl = Pipeline::new(Middleware::create(
            &mw,
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
            (),
        ));
        let res = pl.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        assert_eq!(pl.ready().await, Ok(()));
        pl.shutdown().await;
        assert_eq!(cnt_sht.get(), 2);
    }

    #[ntex::test]
    async fn fn_middleware_service() {
        let cnt_sht = Rc::new(Cell::new(0));
        let cnt_sht2 = cnt_sht.clone();
        let mw = fn_layer(async move |req: &'static str, svc| {
            cnt_sht2.set(cnt_sht2.get() + 1);
            let result = svc.call(1).await?;
            Ok::<_, ()>((req, result))
        })
        .clone();
        let _ = format!("{mw:?}");

        let svc = Pipeline::new(
            mw.create(fn_service(async move |i: usize| Ok::<_, ()>(i * 2)), ()),
        );

        let res = svc.call("test").await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), ("test", 2));
        let _ = format!("{svc:?}");

        assert_eq!(svc.ready().await, Ok(()));
        svc.shutdown().await;
        // The counter is bumped once per call of the layer function, not on shutdown.
        assert_eq!(cnt_sht.get(), 1);
    }
}