1#![allow(clippy::type_complexity)]
2use std::{fmt, marker};
3
4use crate::ctx::WaitersRef;
5use crate::dev::{ServiceChain, ServiceChainFactory};
6use crate::{IntoService, IntoServiceFactory, Service, ServiceCtx, ServiceFactory};
7
8pub fn apply_fn<T, Req, F, In, Out, Err, U>(
10 service: U,
11 f: F,
12) -> ServiceChain<Apply<T, Req, F, In, Out, Err>, In>
13where
14 T: Service<Req>,
15 F: AsyncFn(In, &ApplyCtx<'_, T>) -> Result<Out, Err>,
16 U: IntoService<T, Req>,
17 Err: From<T::Error>,
18{
19 crate::chain(Apply::new(service.into_service(), f))
20}
21
22pub fn apply_fn_factory<T, Req, Cfg, F, In, Out, Err, U>(
24 service: U,
25 f: F,
26) -> ServiceChainFactory<ApplyFactory<T, Req, Cfg, F, In, Out, Err>, In, Cfg>
27where
28 T: ServiceFactory<Req, Cfg>,
29 F: AsyncFn(In, &ApplyCtx<'_, T::Service>) -> Result<Out, Err> + Clone,
30 U: IntoServiceFactory<T, Req, Cfg>,
31 Err: From<T::Error>,
32{
33 crate::chain_factory(ApplyFactory::new(service.into_factory(), f))
34}
35
36#[derive(Debug)]
37pub struct ApplyCtx<'a, S> {
38 idx: u32,
39 waiters: &'a WaitersRef,
40 service: &'a S,
41}
42
43impl<'a, S> ApplyCtx<'a, S> {
44 #[inline]
45 pub async fn call<R>(&self, req: R) -> Result<S::Response, S::Error>
47 where
48 S: Service<R>,
49 R: 'a,
50 {
51 let ctx = ServiceCtx::new(self.idx, self.waiters);
52
53 self.service.ready(ctx).await?;
54 self.service.call(req, ctx).await
55 }
56}
57
/// Service that runs a user function `f` in front of a wrapped service.
///
/// `T` is the wrapped service and `F` the function. `Req`, `In`, `Out` and
/// `Err` are carried only at the type level through the `PhantomData` marker.
pub struct Apply<T, Req, F, In, Out, Err> {
    /// The wrapped service.
    service: T,
    /// The function invoked for every incoming request.
    f: F,
    /// Zero-sized marker tying the otherwise unused type parameters to the struct.
    r: marker::PhantomData<fn(Req) -> (In, Out, Err)>,
}
64
65impl<T, Req, F, In, Out, Err> Apply<T, Req, F, In, Out, Err>
66where
67 T: Service<Req>,
68 F: AsyncFn(In, &ApplyCtx<'_, T>) -> Result<Out, Err>,
69 Err: From<T::Error>,
70{
71 pub(crate) fn new(service: T, f: F) -> Self {
72 Apply {
73 f,
74 service,
75 r: marker::PhantomData,
76 }
77 }
78}
79
80impl<T, Req, F, In, Out, Err> Clone for Apply<T, Req, F, In, Out, Err>
81where
82 T: Clone,
83 F: Clone,
84{
85 fn clone(&self) -> Self {
86 Apply {
87 service: self.service.clone(),
88 f: self.f.clone(),
89 r: marker::PhantomData,
90 }
91 }
92}
93
94impl<T, Req, F, In, Out, Err> fmt::Debug for Apply<T, Req, F, In, Out, Err>
95where
96 T: fmt::Debug,
97{
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 f.debug_struct("Apply")
100 .field("service", &self.service)
101 .field("map", &std::any::type_name::<F>())
102 .finish()
103 }
104}
105
106impl<T, Req, F, In, Out, Err> Service<In> for Apply<T, Req, F, In, Out, Err>
107where
108 T: Service<Req>,
109 F: AsyncFn(In, &ApplyCtx<'_, T>) -> Result<Out, Err>,
110 Err: From<T::Error>,
111{
112 type Response = Out;
113 type Error = Err;
114
115 #[inline]
116 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Err> {
117 ctx.ready(&self.service).await.map_err(From::from)
118 }
119
120 #[inline]
121 async fn call(
122 &self,
123 req: In,
124 ctx: ServiceCtx<'_, Self>,
125 ) -> Result<Self::Response, Self::Error> {
126 let (idx, waiters) = ctx.inner();
127
128 let ctx = ApplyCtx {
129 idx,
130 waiters,
131 service: &self.service,
132 };
133 (self.f)(req, &ctx).await
134 }
135
136 crate::forward_poll!(service);
137 crate::forward_shutdown!(service);
138}
139
140pub struct ApplyFactory<T, Req, Cfg, F, In, Out, Err>
142where
143 T: ServiceFactory<Req, Cfg>,
144 F: AsyncFn(In, &ApplyCtx<'_, T::Service>) -> Result<Out, Err> + Clone,
145{
146 service: T,
147 f: F,
148 r: marker::PhantomData<fn(Req, Cfg) -> (In, Out)>,
149}
150
151impl<T, Req, Cfg, F, In, Out, Err> ApplyFactory<T, Req, Cfg, F, In, Out, Err>
152where
153 T: ServiceFactory<Req, Cfg>,
154 F: AsyncFn(In, &ApplyCtx<'_, T::Service>) -> Result<Out, Err> + Clone,
155 Err: From<T::Error>,
156{
157 pub(crate) fn new(service: T, f: F) -> Self {
159 Self {
160 f,
161 service,
162 r: marker::PhantomData,
163 }
164 }
165}
166
167impl<T, Req, Cfg, F, In, Out, Err> Clone for ApplyFactory<T, Req, Cfg, F, In, Out, Err>
168where
169 T: ServiceFactory<Req, Cfg> + Clone,
170 F: AsyncFn(In, &ApplyCtx<'_, T::Service>) -> Result<Out, Err> + Clone,
171 Err: From<T::Error>,
172{
173 fn clone(&self) -> Self {
174 Self {
175 service: self.service.clone(),
176 f: self.f.clone(),
177 r: marker::PhantomData,
178 }
179 }
180}
181
182impl<T, Req, Cfg, F, In, Out, Err> fmt::Debug for ApplyFactory<T, Req, Cfg, F, In, Out, Err>
183where
184 T: ServiceFactory<Req, Cfg> + fmt::Debug,
185 F: AsyncFn(In, &ApplyCtx<'_, T::Service>) -> Result<Out, Err> + Clone,
186 Err: From<T::Error>,
187{
188 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189 f.debug_struct("ApplyFactory")
190 .field("factory", &self.service)
191 .field("map", &std::any::type_name::<F>())
192 .finish()
193 }
194}
195
196impl<T, Req, Cfg, F, In, Out, Err> ServiceFactory<In, Cfg>
197 for ApplyFactory<T, Req, Cfg, F, In, Out, Err>
198where
199 T: ServiceFactory<Req, Cfg>,
200 F: AsyncFn(In, &ApplyCtx<'_, T::Service>) -> Result<Out, Err> + Clone,
201 Err: From<T::Error>,
202{
203 type Response = Out;
204 type Error = Err;
205
206 type Service = Apply<T::Service, Req, F, In, Out, Err>;
207 type InitError = T::InitError;
208
209 #[inline]
210 async fn create(&self, cfg: Cfg) -> Result<Self::Service, Self::InitError> {
211 self.service.create(cfg).await.map(|service| Apply {
212 service,
213 f: self.f.clone(),
214 r: marker::PhantomData,
215 })
216 }
217}
218
#[cfg(test)]
mod tests {
    use ntex::util::lazy;
    use std::{cell::Cell, rc::Rc, task::Context};

    use super::*;
    use crate::{chain, chain_factory, fn_factory};

    /// Test service whose `poll` and `shutdown` each bump a shared counter.
    #[derive(Debug, Default, Clone)]
    struct Srv(Rc<Cell<usize>>);

    impl Srv {
        /// Increments the shared counter by one.
        fn bump(&self) {
            let counter = &self.0;
            counter.set(counter.get() + 1);
        }
    }

    impl Service<()> for Srv {
        type Response = ();
        type Error = ();

        async fn call(&self, _req: (), _ctx: ServiceCtx<'_, Self>) -> Result<(), ()> {
            Ok(())
        }

        fn poll(&self, _cx: &mut Context<'_>) -> Result<(), Self::Error> {
            self.bump();
            Ok(())
        }

        async fn shutdown(&self) {
            self.bump();
        }
    }

    /// Error type of the applied function; convertible from the `()` error of `Srv`.
    #[derive(Debug, PartialEq, Eq)]
    struct Err;

    impl From<()> for Err {
        fn from(_unit: ()) -> Self {
            Err
        }
    }

    /// `apply_fn` used as a free function: readiness, poll, shutdown and call
    /// all reach the wrapped service.
    #[ntex::test]
    async fn test_call() {
        let counter = Rc::new(Cell::new(0));
        let srv = chain(
            apply_fn(Srv(counter.clone()), async move |req: &'static str, svc| {
                svc.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone(),
        )
        .into_pipeline();

        assert_eq!(srv.ready().await, Ok::<_, Err>(()));

        lazy(|cx| srv.poll(cx)).await.unwrap();
        assert_eq!(counter.get(), 1);

        srv.shutdown().await;
        assert_eq!(counter.get(), 2);

        let reply = srv.call("srv").await;
        assert_eq!(reply, Ok(("srv", ())));
    }

    /// Same scenario through the `apply_fn` method of a service chain.
    #[ntex::test]
    async fn test_call_chain() {
        let counter = Rc::new(Cell::new(0));
        let srv = chain(Srv(counter.clone()))
            .apply_fn(async move |req: &'static str, svc| {
                svc.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone()
            .into_pipeline();

        assert_eq!(srv.ready().await, Ok::<_, Err>(()));

        srv.shutdown().await;
        assert_eq!(counter.get(), 1);

        let reply = srv.call("srv").await;
        assert_eq!(reply, Ok(("srv", ())));
        let _ = format!("{srv:?}");
    }

    /// `apply_fn_factory` used as a free function creates a working service.
    #[ntex::test]
    async fn test_create() {
        let factory = chain_factory(
            apply_fn_factory(
                fn_factory(|| async { Ok::<_, ()>(Srv::default()) }),
                async move |req: &'static str, srv| {
                    srv.call(()).await.unwrap();
                    Ok((req, ()))
                },
            )
            .clone(),
        );

        let srv = factory.pipeline(&()).await.unwrap();

        assert_eq!(srv.ready().await, Ok::<_, Err>(()));

        let reply = srv.call("srv").await;
        assert_eq!(reply, Ok(("srv", ())));
        let _ = format!("{factory:?}");

        assert_eq!(Err::from(()), Err);
    }

    /// Same scenario through the `apply_fn` method of a factory chain.
    #[ntex::test]
    async fn test_create_chain() {
        let factory = chain_factory(fn_factory(|| async { Ok::<_, ()>(Srv::default()) }))
            .apply_fn(async move |req: &'static str, srv| {
                srv.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone();

        let srv = factory.pipeline(&()).await.unwrap();

        assert_eq!(srv.ready().await, Ok::<_, Err>(()));

        let reply = srv.call("srv").await;
        assert_eq!(reply, Ok(("srv", ())));
        let _ = format!("{factory:?}");
    }
}