1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use crate::{IntoServiceFactory, Service, ServiceFactory};
4
5pub fn apply<T, S, R, C, U>(t: T, factory: U) -> ApplyMiddleware<T, S, C>
7where
8 S: ServiceFactory<R, C>,
9 T: Middleware<S::Service, C>,
10 U: IntoServiceFactory<S, R, C>,
11{
12 ApplyMiddleware::new(t, factory.into_factory())
13}
14
/// A transformation that turns a service `S` into a new, wrapping service.
///
/// `Cfg` is the configuration value handed to the middleware when the wrapped
/// service is built. It defaults to `()`.
pub trait Middleware<S, Cfg = ()> {
    /// The service type produced by wrapping `S`.
    type Service;

    /// Wraps `svc` using `config` and returns the resulting service.
    fn create(&self, svc: S, config: Cfg) -> Self::Service;
}
92
93impl<T, S, C> Middleware<S, C> for Rc<T>
94where
95 T: Middleware<S, C>,
96{
97 type Service = T::Service;
98
99 fn create(&self, service: S, cfg: C) -> T::Service {
100 self.as_ref().create(service, cfg)
101 }
102}
103
/// Service factory returned by [`apply`].
///
/// It pairs a middleware `T` with a service factory `S` inside one `Rc`, so
/// clones share the same pair. `C` is the configuration type and is tracked
/// only at the type level.
pub struct ApplyMiddleware<T, S, C>(Rc<(T, S)>, PhantomData<C>);
106
107impl<T, S, C> ApplyMiddleware<T, S, C> {
108 pub(crate) fn new(mw: T, svc: S) -> Self {
110 Self(Rc::new((mw, svc)), PhantomData)
111 }
112}
113
114impl<T, S, C> Clone for ApplyMiddleware<T, S, C> {
115 fn clone(&self) -> Self {
116 Self(self.0.clone(), PhantomData)
117 }
118}
119
120impl<T, S, C> fmt::Debug for ApplyMiddleware<T, S, C>
121where
122 T: fmt::Debug,
123 S: fmt::Debug,
124{
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 f.debug_struct("ApplyMiddleware")
127 .field("service", &self.0.1)
128 .field("middleware", &self.0.0)
129 .finish()
130 }
131}
132
133impl<T, S, R, C> ServiceFactory<R, C> for ApplyMiddleware<T, S, C>
134where
135 S: ServiceFactory<R, C>,
136 T: Middleware<S::Service, C>,
137 T::Service: Service<R>,
138 C: Clone,
139{
140 type Response = <T::Service as Service<R>>::Response;
141 type Error = <T::Service as Service<R>>::Error;
142
143 type Service = T::Service;
144 type InitError = S::InitError;
145
146 #[inline]
147 async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
148 Ok(self.0.0.create(self.0.1.create(cfg.clone()).await?, cfg))
149 }
150}
151
/// A no-op middleware: wrapping a service with it yields the same service.
#[derive(Clone, Copy, Debug)]
pub struct Identity;
157
158impl<S, Cfg> Middleware<S, Cfg> for Identity {
159 type Service = S;
160
161 #[inline]
162 fn create(&self, service: S, _: Cfg) -> Self::Service {
163 service
164 }
165}
166
/// Composes two middlewares.
///
/// `inner` wraps the service first, and `outer` then wraps the result.
#[derive(Debug, Clone)]
pub struct Stack<Inner, Outer> {
    // Applied first, directly to the original service.
    inner: Inner,
    // Applied second, to the service produced by `inner`.
    outer: Outer,
}
173
174impl<Inner, Outer> Stack<Inner, Outer> {
175 pub fn new(inner: Inner, outer: Outer) -> Self {
176 Stack { inner, outer }
177 }
178}
179
180impl<S, Inner, Outer, C> Middleware<S, C> for Stack<Inner, Outer>
181where
182 Inner: Middleware<S, C>,
183 Outer: Middleware<Inner::Service, C>,
184 C: Clone,
185{
186 type Service = Outer::Service;
187
188 fn create(&self, service: S, cfg: C) -> Self::Service {
189 self.outer
190 .create(self.inner.create(service, cfg.clone()), cfg)
191 }
192}
193
#[cfg(test)]
// The extra `.clone()` calls in these tests are deliberate. They exercise the
// `Clone` impls of the middleware, factories and services under test.
#[allow(clippy::redundant_clone)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{Pipeline, ServiceCtx, fn_service};

    /// Test middleware: each `create` call increments the shared counter and
    /// wraps the service in `Srv`.
    #[derive(Debug, Clone)]
    struct Mw<R>(PhantomData<R>, Rc<Cell<usize>>);

    impl<S, R, C> Middleware<S, C> for Mw<R> {
        type Service = Srv<S, R>;

        fn create(&self, service: S, _: C) -> Self::Service {
            self.1.set(self.1.get() + 1);
            Srv(service, PhantomData, self.1.clone())
        }
    }

    /// Pass-through service that increments the shared counter on `shutdown`.
    #[derive(Debug, Clone)]
    struct Srv<S, R>(S, PhantomData<R>, Rc<Cell<usize>>);

    impl<S: Service<R>, R> Service<R> for Srv<S, R> {
        type Response = S::Response;
        type Error = S::Error;

        async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            ctx.ready(&self.0).await
        }

        async fn call(
            &self,
            req: R,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<S::Response, S::Error> {
            ctx.call(&self.0, req).await
        }

        async fn shutdown(&self) {
            self.2.set(self.2.get() + 1);
        }
    }

    #[ntex::test]
    async fn middleware() {
        // `apply(..)` with an `Rc`-wrapped middleware.
        let cnt_sht = Rc::new(Cell::new(0));
        let factory = apply(
            Rc::new(Mw(PhantomData, cnt_sht.clone()).clone()),
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
        )
        .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        // One increment from `Mw::create` plus one from `Srv::shutdown`.
        assert_eq!(cnt_sht.get(), 2);

        // `chain_factory(..).apply(..)` with an `Rc`-wrapped middleware. Keep the
        // counter so this path's create/shutdown accounting is verified as well.
        let cnt_sht = Rc::new(Cell::new(0));
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(Rc::new(Mw(PhantomData, cnt_sht.clone()).clone()))
                .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 2);
    }

    #[ntex::test]
    async fn middleware_chain() {
        let cnt_sht = Rc::new(Cell::new(0));
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(Mw(PhantomData, cnt_sht.clone()).clone());

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 2);
    }

    #[ntex::test]
    async fn stack() {
        // `Identity` as the inner layer must leave the `Mw` wrapping observable.
        let cnt_sht = Rc::new(Cell::new(0));
        let mw = Stack::new(Identity, Mw(PhantomData, cnt_sht.clone()));
        let _ = format!("{mw:?}");

        let pl = Pipeline::new(Middleware::create(
            &mw,
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
            (),
        ));
        let res = pl.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        assert_eq!(pl.ready().await, Ok(()));
        pl.shutdown().await;
        assert_eq!(cnt_sht.get(), 2);
    }
}
307}