1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use crate::{IntoServiceFactory, Service, ServiceFactory};
4
5pub fn apply<T, S, R, C, U>(t: T, factory: U) -> ApplyMiddleware<T, S, C>
7where
8 S: ServiceFactory<R, C>,
9 T: Middleware<S::Service>,
10 U: IntoServiceFactory<S, R, C>,
11{
12 ApplyMiddleware::new(t, factory.into_factory())
13}
14
/// Transforms an already-built service of type `S` into a new service.
///
/// A middleware receives the service to wrap and returns [`Middleware::Service`]. It does not
/// return a `Result`, so wrapping itself cannot fail. It is combined with a service factory
/// through [`apply`], which calls [`Middleware::create`] on every service that factory builds.
pub trait Middleware<S> {
    /// The service type produced by wrapping `S`.
    type Service;

    /// Wrap `service`, returning the resulting service.
    fn create(&self, service: S) -> Self::Service;
}
92
93impl<T, S> Middleware<S> for Rc<T>
94where
95 T: Middleware<S>,
96{
97 type Service = T::Service;
98
99 fn create(&self, service: S) -> T::Service {
100 self.as_ref().create(service)
101 }
102}
103
/// A [`ServiceFactory`] that wraps every service built by the inner factory `S` with the
/// middleware `T`.
///
/// Normally obtained from [`apply`]. The middleware and the inner factory live together in one
/// `Rc`, so clones of this factory share them. `C` is the configuration type passed to
/// [`ServiceFactory::create`] and is only carried as a type marker.
pub struct ApplyMiddleware<T, S, C>(Rc<(T, S)>, PhantomData<C>);
106
107impl<T, S, C> ApplyMiddleware<T, S, C> {
108 pub(crate) fn new(mw: T, svc: S) -> Self {
110 Self(Rc::new((mw, svc)), PhantomData)
111 }
112}
113
114impl<T, S, C> Clone for ApplyMiddleware<T, S, C> {
115 fn clone(&self) -> Self {
116 Self(self.0.clone(), PhantomData)
117 }
118}
119
120impl<T, S, C> fmt::Debug for ApplyMiddleware<T, S, C>
121where
122 T: fmt::Debug,
123 S: fmt::Debug,
124{
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 f.debug_struct("ApplyMiddleware")
127 .field("service", &self.0 .1)
128 .field("middleware", &self.0 .0)
129 .finish()
130 }
131}
132
133impl<T, S, R, C> ServiceFactory<R, C> for ApplyMiddleware<T, S, C>
134where
135 S: ServiceFactory<R, C>,
136 T: Middleware<S::Service>,
137 T::Service: Service<R>,
138{
139 type Response = <T::Service as Service<R>>::Response;
140 type Error = <T::Service as Service<R>>::Error;
141
142 type Service = T::Service;
143 type InitError = S::InitError;
144
145 #[inline]
146 async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
147 Ok(self.0 .0.create(self.0 .1.create(cfg).await?))
148 }
149}
150
/// A no-op [`Middleware`] that returns the wrapped service unchanged.
///
/// It is a stateless unit struct, so it also derives `Default`, `PartialEq`, `Eq` and `Hash`.
/// This lets it be used where generic code requires those traits of a middleware.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Identity;
156
157impl<S> Middleware<S> for Identity {
158 type Service = S;
159
160 #[inline]
161 fn create(&self, service: S) -> Self::Service {
162 service
163 }
164}
165
/// Composition of two middlewares.
///
/// When used as a [`Middleware`], `inner` is applied to the service first and `outer` wraps the
/// result. As a result, `outer` forms the outermost layer.
#[derive(Debug, Clone)]
pub struct Stack<Inner, Outer> {
    /// Middleware applied first, directly to the original service.
    inner: Inner,
    /// Middleware applied to the service produced by `inner`.
    outer: Outer,
}
172
173impl<Inner, Outer> Stack<Inner, Outer> {
174 pub fn new(inner: Inner, outer: Outer) -> Self {
175 Stack { inner, outer }
176 }
177}
178
179impl<S, Inner, Outer> Middleware<S> for Stack<Inner, Outer>
180where
181 Inner: Middleware<S>,
182 Outer: Middleware<Inner::Service>,
183{
184 type Service = Outer::Service;
185
186 fn create(&self, service: S) -> Self::Service {
187 self.outer.create(self.inner.create(service))
188 }
189}
190
#[cfg(test)]
// The tests deliberately call `.clone()` on values that are used only once, to exercise the
// `Clone` impls of the middleware, factory and service types. Clippy would report those as
// redundant.
#[allow(clippy::redundant_clone)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{fn_service, Pipeline, ServiceCtx};

    /// Test middleware that produces `Srv`, sharing its shutdown counter with every `Srv` it creates.
    #[derive(Debug, Clone)]
    struct Tr<R>(PhantomData<R>, Rc<Cell<usize>>);

    impl<S, R> Middleware<S> for Tr<R> {
        type Service = Srv<S, R>;

        fn create(&self, service: S) -> Self::Service {
            Srv(service, PhantomData, self.1.clone())
        }
    }

    /// Pass-through service that increments its counter each time `shutdown` is called.
    #[derive(Debug, Clone)]
    struct Srv<S, R>(S, PhantomData<R>, Rc<Cell<usize>>);

    impl<S: Service<R>, R> Service<R> for Srv<S, R> {
        type Response = S::Response;
        type Error = S::Error;

        async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            ctx.ready(&self.0).await
        }

        async fn call(
            &self,
            req: R,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<S::Response, S::Error> {
            ctx.call(&self.0, req).await
        }

        async fn shutdown(&self) {
            self.2.set(self.2.get() + 1);
        }
    }

    #[ntex::test]
    async fn middleware() {
        // Path 1: the free `apply` function.
        let cnt_sht = Rc::new(Cell::new(0));
        let factory = apply(
            Rc::new(Tr(PhantomData, cnt_sht.clone()).clone()),
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
        )
        .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{:?} {:?}", factory, srv);

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 1);

        // Path 2: `chain_factory(..).apply(..)`. Keep the counter so that shutdown
        // propagation through the chained factory is verified as well.
        let cnt_sht = Rc::new(Cell::new(0));
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(Rc::new(Tr(PhantomData, cnt_sht.clone()).clone()))
                .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{:?} {:?}", factory, srv);

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 1);
    }
}