1#![allow(clippy::type_complexity)]
2use std::{fmt, marker};
3
4use crate::dev::{ServiceChain, ServiceChainFactory};
5use crate::{
6 IntoService, IntoServiceFactory, Pipeline, Service, ServiceCtx, ServiceFactory,
7};
8
9pub fn apply_fn<T, Req, F, In, Out, Err, U>(
11 service: U,
12 f: F,
13) -> ServiceChain<Apply<T, Req, F, In, Out, Err>, In>
14where
15 T: Service<Req>,
16 F: AsyncFn(In, &Pipeline<T>) -> Result<Out, Err>,
17 U: IntoService<T, Req>,
18 Err: From<T::Error>,
19{
20 crate::chain(Apply::new(service.into_service(), f))
21}
22
23pub fn apply_fn_factory<T, Req, Cfg, F, In, Out, Err, U>(
25 service: U,
26 f: F,
27) -> ServiceChainFactory<ApplyFactory<T, Req, Cfg, F, In, Out, Err>, In, Cfg>
28where
29 T: ServiceFactory<Req, Cfg>,
30 F: AsyncFn(In, &Pipeline<T::Service>) -> Result<Out, Err> + Clone,
31 U: IntoServiceFactory<T, Req, Cfg>,
32 Err: From<T::Error>,
33{
34 crate::chain_factory(ApplyFactory::new(service.into_factory(), f))
35}
36
37pub struct Apply<T, Req, F, In, Out, Err>
39where
40 T: Service<Req>,
41{
42 service: Pipeline<T>,
43 f: F,
44 r: marker::PhantomData<fn(Req) -> (In, Out, Err)>,
45}
46
47impl<T, Req, F, In, Out, Err> Apply<T, Req, F, In, Out, Err>
48where
49 T: Service<Req>,
50 F: AsyncFn(In, &Pipeline<T>) -> Result<Out, Err>,
51 Err: From<T::Error>,
52{
53 pub(crate) fn new(service: T, f: F) -> Self {
54 Apply {
55 f,
56 service: Pipeline::new(service),
57 r: marker::PhantomData,
58 }
59 }
60}
61
62impl<T, Req, F, In, Out, Err> Clone for Apply<T, Req, F, In, Out, Err>
63where
64 T: Service<Req> + Clone,
65 F: AsyncFn(In, &Pipeline<T>) -> Result<Out, Err> + Clone,
66 Err: From<T::Error>,
67{
68 fn clone(&self) -> Self {
69 Apply {
70 service: self.service.clone(),
71 f: self.f.clone(),
72 r: marker::PhantomData,
73 }
74 }
75}
76
77impl<T, Req, F, In, Out, Err> fmt::Debug for Apply<T, Req, F, In, Out, Err>
78where
79 T: Service<Req> + fmt::Debug,
80{
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 f.debug_struct("Apply")
83 .field("service", &self.service)
84 .field("map", &std::any::type_name::<F>())
85 .finish()
86 }
87}
88
89impl<T, Req, F, In, Out, Err> Service<In> for Apply<T, Req, F, In, Out, Err>
90where
91 T: Service<Req>,
92 F: AsyncFn(In, &Pipeline<T>) -> Result<Out, Err>,
93 Err: From<T::Error>,
94{
95 type Response = Out;
96 type Error = Err;
97
98 #[inline]
99 async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Err> {
100 self.service.ready().await.map_err(From::from)
101 }
102
103 #[inline]
104 async fn call(
105 &self,
106 req: In,
107 _: ServiceCtx<'_, Self>,
108 ) -> Result<Self::Response, Self::Error> {
109 (self.f)(req, &self.service).await
110 }
111
112 crate::forward_poll!(service);
113 crate::forward_shutdown!(service);
114}
115
116pub struct ApplyFactory<T, Req, Cfg, F, In, Out, Err>
118where
119 T: ServiceFactory<Req, Cfg>,
120 F: AsyncFn(In, &Pipeline<T::Service>) -> Result<Out, Err> + Clone,
121{
122 service: T,
123 f: F,
124 r: marker::PhantomData<fn(Req, Cfg) -> (In, Out)>,
125}
126
127impl<T, Req, Cfg, F, In, Out, Err> ApplyFactory<T, Req, Cfg, F, In, Out, Err>
128where
129 T: ServiceFactory<Req, Cfg>,
130 F: AsyncFn(In, &Pipeline<T::Service>) -> Result<Out, Err> + Clone,
131 Err: From<T::Error>,
132{
133 pub(crate) fn new(service: T, f: F) -> Self {
135 Self {
136 f,
137 service,
138 r: marker::PhantomData,
139 }
140 }
141}
142
143impl<T, Req, Cfg, F, In, Out, Err> Clone for ApplyFactory<T, Req, Cfg, F, In, Out, Err>
144where
145 T: ServiceFactory<Req, Cfg> + Clone,
146 F: AsyncFn(In, &Pipeline<T::Service>) -> Result<Out, Err> + Clone,
147 Err: From<T::Error>,
148{
149 fn clone(&self) -> Self {
150 Self {
151 service: self.service.clone(),
152 f: self.f.clone(),
153 r: marker::PhantomData,
154 }
155 }
156}
157
158impl<T, Req, Cfg, F, In, Out, Err> fmt::Debug for ApplyFactory<T, Req, Cfg, F, In, Out, Err>
159where
160 T: ServiceFactory<Req, Cfg> + fmt::Debug,
161 F: AsyncFn(In, &Pipeline<T::Service>) -> Result<Out, Err> + Clone,
162 Err: From<T::Error>,
163{
164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 f.debug_struct("ApplyFactory")
166 .field("factory", &self.service)
167 .field("map", &std::any::type_name::<F>())
168 .finish()
169 }
170}
171
172impl<T, Req, Cfg, F, In, Out, Err> ServiceFactory<In, Cfg>
173 for ApplyFactory<T, Req, Cfg, F, In, Out, Err>
174where
175 T: ServiceFactory<Req, Cfg>,
176 F: AsyncFn(In, &Pipeline<T::Service>) -> Result<Out, Err> + Clone,
177 Err: From<T::Error>,
178{
179 type Response = Out;
180 type Error = Err;
181
182 type Service = Apply<T::Service, Req, F, In, Out, Err>;
183 type InitError = T::InitError;
184
185 #[inline]
186 async fn create(&self, cfg: Cfg) -> Result<Self::Service, Self::InitError> {
187 self.service.create(cfg).await.map(|svc| Apply {
188 service: svc.into(),
189 f: self.f.clone(),
190 r: marker::PhantomData,
191 })
192 }
193}
194
#[cfg(test)]
mod tests {
    use ntex::util::lazy;
    use std::{cell::Cell, rc::Rc, task::Context};

    use super::*;
    use crate::{chain, chain_factory, fn_factory};

    /// Inner test service. Each `poll` and each `shutdown` increments the shared counter.
    #[derive(Debug, Default, Clone)]
    struct Srv(Rc<Cell<usize>>);

    impl Service<()> for Srv {
        type Response = ();
        type Error = ();

        async fn call(&self, _r: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            Ok(())
        }

        fn poll(&self, _: &mut Context<'_>) -> Result<(), Self::Error> {
            self.0.set(self.0.get() + 1);
            Ok(())
        }

        async fn shutdown(&self) {
            self.0.set(self.0.get() + 1);
        }
    }

    /// Error type used as the `Err` parameter of the applied function.
    /// It shadows the prelude `Err` variant inside this module.
    #[derive(Debug, PartialEq, Eq)]
    struct Err;

    // Satisfies the `Err: From<T::Error>` bound, since `Srv::Error` is `()`.
    impl From<()> for Err {
        fn from(_e: ()) -> Self {
            Err
        }
    }

    /// `apply_fn` forwards readiness, `poll` and `shutdown` to the inner
    /// service, and `call` runs the transform.
    #[ntex::test]
    async fn test_call() {
        let cnt_sht = Rc::new(Cell::new(0));
        let srv = chain(
            apply_fn(Srv(cnt_sht.clone()), async move |req: &'static str, svc| {
                svc.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone(),
        )
        .into_pipeline();

        assert_eq!(srv.ready().await, Ok::<_, Err>(()));

        // Forwarded `poll` bumps the counter once...
        lazy(|cx| srv.poll(cx)).await.unwrap();
        assert_eq!(cnt_sht.get(), 1);

        // ...and forwarded `shutdown` bumps it again.
        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 2);

        let res = srv.call("srv").await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), ("srv", ()));
    }

    /// Same as `test_call`, but built with the chain's `apply_fn` builder method.
    /// `poll` is not called here, so `shutdown` alone brings the counter to 1.
    #[ntex::test]
    async fn test_call_chain() {
        let cnt_sht = Rc::new(Cell::new(0));
        let srv = chain(Srv(cnt_sht.clone()))
            .apply_fn(async move |req: &'static str, svc| {
                svc.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone()
            .into_pipeline();

        assert_eq!(srv.ready().await, Ok::<_, Err>(()));

        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 1);

        let res = srv.call("srv").await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), ("srv", ()));
        // Exercises the `Debug` impl; output is not checked.
        let _ = format!("{srv:?}");
    }

    /// `apply_fn_factory` creates a working service. Also exercises the
    /// factory's `Debug` impl and the `From<()>` conversion for `Err`.
    #[ntex::test]
    async fn test_create() {
        let new_srv = chain_factory(
            apply_fn_factory(
                fn_factory(|| async { Ok::<_, ()>(Srv::default()) }),
                async move |req: &'static str, srv| {
                    srv.call(()).await.unwrap();
                    Ok((req, ()))
                },
            )
            .clone(),
        );

        let srv = new_srv.pipeline(&()).await.unwrap();

        assert_eq!(srv.ready().await, Ok::<_, Err>(()));

        let res = srv.call("srv").await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), ("srv", ()));
        let _ = format!("{new_srv:?}");

        assert!(Err == Err::from(()));
    }

    /// Same as `test_create`, but built with the factory chain's `apply_fn` builder method.
    #[ntex::test]
    async fn test_create_chain() {
        let new_srv = chain_factory(fn_factory(|| async { Ok::<_, ()>(Srv::default()) }))
            .apply_fn(async move |req: &'static str, srv| {
                srv.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone();

        let srv = new_srv.pipeline(&()).await.unwrap();

        assert_eq!(srv.ready().await, Ok::<_, Err>(()));

        let res = srv.call("srv").await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), ("srv", ()));
        let _ = format!("{new_srv:?}");
    }
}