1use std::{fmt, task::Context};
2
3use super::{Service, ServiceCtx, ServiceFactory};
4
5pub struct Inspect<S, F> {
7 svc: S,
8 f: F,
9}
10
11impl<S, F> Inspect<S, F> {
12 pub(crate) fn new<R>(svc: S, f: F) -> Self
14 where
15 S: Service<R>,
16 F: Fn(&S::Response),
17 {
18 Self { svc, f }
19 }
20}
21
22impl<S, F> Clone for Inspect<S, F>
23where
24 S: Clone,
25 F: Clone,
26{
27 #[inline]
28 fn clone(&self) -> Self {
29 Inspect {
30 svc: self.svc.clone(),
31 f: self.f.clone(),
32 }
33 }
34}
35
36impl<S, F> fmt::Debug for Inspect<S, F>
37where
38 S: fmt::Debug,
39{
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 f.debug_struct("Inspect")
42 .field("svc", &self.svc)
43 .field("inspect", &std::any::type_name::<F>())
44 .finish()
45 }
46}
47
48impl<S, F, R> Service<R> for Inspect<S, F>
49where
50 S: Service<R>,
51 F: Fn(&S::Response),
52{
53 type Response = S::Response;
54 type Error = S::Error;
55
56 #[inline]
57 async fn call(&self, r: R, ctx: ServiceCtx<'_, Self>) -> Result<S::Response, S::Error> {
58 ctx.call(&self.svc, r).await.inspect(&self.f)
59 }
60
61 crate::forward_ready!(svc);
62 crate::forward_poll!(svc);
63 crate::forward_shutdown!(svc);
64}
65
66pub struct InspectErr<S, F> {
68 svc: S,
69 f: F,
70}
71
72impl<S, F> InspectErr<S, F> {
73 pub(crate) fn new<R>(svc: S, f: F) -> Self
75 where
76 S: Service<R>,
77 F: Fn(&S::Error),
78 {
79 Self { svc, f }
80 }
81}
82
83impl<S, F> Clone for InspectErr<S, F>
84where
85 S: Clone,
86 F: Clone,
87{
88 #[inline]
89 fn clone(&self) -> Self {
90 InspectErr {
91 svc: self.svc.clone(),
92 f: self.f.clone(),
93 }
94 }
95}
96
97impl<S, F> fmt::Debug for InspectErr<S, F>
98where
99 S: fmt::Debug,
100{
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 f.debug_struct("InspectErr")
103 .field("svc", &self.svc)
104 .field("inspect_err", &std::any::type_name::<F>())
105 .finish()
106 }
107}
108
109impl<S, F, R> Service<R> for InspectErr<S, F>
110where
111 S: Service<R>,
112 F: Fn(&S::Error),
113{
114 type Response = S::Response;
115 type Error = S::Error;
116
117 #[inline]
118 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
119 ctx.ready(&self.svc).await.inspect_err(&self.f)
120 }
121
122 #[inline]
123 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
124 self.svc.poll(cx).inspect_err(&self.f)
125 }
126
127 #[inline]
128 async fn call(&self, r: R, ctx: ServiceCtx<'_, Self>) -> Result<S::Response, S::Error> {
129 ctx.call(&self.svc, r).await.inspect_err(&self.f)
130 }
131
132 crate::forward_shutdown!(svc);
133}
134
/// Factory that wraps every service built by the inner factory `s` in
/// [`Inspect`], giving each one its own clone of the callback `f`.
pub struct InspectFactory<S, F> {
    /// Wrapped factory.
    s: S,
    /// Callback template, cloned into each created service.
    f: F,
}

impl<S, F> InspectFactory<S, F> {
    /// Pairs the factory `s` with the success callback `f`.
    pub(crate) fn new(s: S, f: F) -> Self {
        InspectFactory { s, f }
    }
}
147
148impl<S, F> Clone for InspectFactory<S, F>
149where
150 S: Clone,
151 F: Clone,
152{
153 fn clone(&self) -> Self {
154 Self {
155 s: self.s.clone(),
156 f: self.f.clone(),
157 }
158 }
159}
160
161impl<S, F> fmt::Debug for InspectFactory<S, F>
162where
163 S: fmt::Debug,
164{
165 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166 f.debug_struct("InspectFactory")
167 .field("factory", &self.s)
168 .field("inspect", &std::any::type_name::<F>())
169 .finish()
170 }
171}
172
173impl<S, F, R, C> ServiceFactory<R, C> for InspectFactory<S, F>
174where
175 S: ServiceFactory<R, C>,
176 F: Fn(&S::Response) + Clone,
177{
178 type Response = S::Response;
179 type Error = S::Error;
180
181 type Service = Inspect<S::Service, F>;
182 type InitError = S::InitError;
183
184 #[inline]
185 async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
186 self.s.create(cfg).await.map(|svc| Inspect {
187 svc,
188 f: self.f.clone(),
189 })
190 }
191}
192
/// Factory that wraps every service built by the inner factory `s` in
/// [`InspectErr`], giving each one its own clone of the error callback `f`.
pub struct InspectErrFactory<S, F> {
    /// Wrapped factory.
    s: S,
    /// Error-callback template, cloned into each created service.
    f: F,
}

impl<S, F> InspectErrFactory<S, F> {
    /// Pairs the factory `s` with the error callback `f`.
    pub(crate) fn new(s: S, f: F) -> Self {
        InspectErrFactory { s, f }
    }
}
205
206impl<S, F> Clone for InspectErrFactory<S, F>
207where
208 S: Clone,
209 F: Clone,
210{
211 fn clone(&self) -> Self {
212 Self {
213 s: self.s.clone(),
214 f: self.f.clone(),
215 }
216 }
217}
218
219impl<S, F> fmt::Debug for InspectErrFactory<S, F>
220where
221 S: fmt::Debug,
222{
223 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224 f.debug_struct("InspectErrFactory")
225 .field("factory", &self.s)
226 .field("inspect_err", &std::any::type_name::<F>())
227 .finish()
228 }
229}
230
231impl<S, F, R, C> ServiceFactory<R, C> for InspectErrFactory<S, F>
232where
233 S: ServiceFactory<R, C>,
234 F: Fn(&S::Error) + Clone,
235{
236 type Response = S::Response;
237 type Error = S::Error;
238
239 type Service = InspectErr<S::Service, F>;
240 type InitError = S::InitError;
241
242 #[inline]
243 async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
244 self.s.create(cfg).await.map(|svc| InspectErr {
245 svc,
246 f: self.f.clone(),
247 })
248 }
249}
250
#[cfg(test)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{chain, chain_factory, fn_factory};

    /// Test service. `.0` makes `call` fail and `.1` makes `ready` fail.
    /// `.2` is a shared counter that `shutdown` increments.
    ///
    /// Each test sends its inspect callback to the same counter, so the final
    /// count is "callback invocations + shutdowns".
    #[derive(Debug, Clone)]
    struct Srv(bool, bool, Rc<Cell<usize>>);

    impl Service<()> for Srv {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            if self.1 { Err(()) } else { Ok(()) }
        }

        async fn call(&self, _m: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            if self.0 { Err(()) } else { Ok(()) }
        }

        async fn shutdown(&self) {
            self.2.set(self.2.get() + 1);
        }
    }

    #[ntex::test]
    async fn test_inspect_ready() {
        let cnt = Rc::new(Cell::new(0));
        let cnt2 = cnt.clone();
        let srv = chain(Srv(false, false, cnt.clone()))
            .inspect(move |&()| cnt2.set(cnt2.get() + 1))
            .into_pipeline();
        let res = srv.ready().await;
        assert_eq!(res, Ok(()));

        srv.shutdown().await;
        assert_eq!(cnt.get(), 1);
    }

    #[ntex::test]
    async fn test_inspect_err_ready() {
        let cnt = Rc::new(Cell::new(0));
        let cnt2 = cnt.clone();
        let srv = chain(Srv(true, true, cnt.clone()))
            .inspect_err(move |&()| cnt2.set(cnt2.get() + 1))
            .into_pipeline();
        let res = srv.ready().await;
        assert_eq!(res, Err(()));

        srv.shutdown().await;
        assert_eq!(cnt.get(), 2);
    }

    #[ntex::test]
    async fn test_inspect_service() {
        let cnt = Rc::new(Cell::new(0));
        let cnt2 = cnt.clone();
        let srv = chain(Srv(false, false, cnt.clone()))
            .inspect(move |&()| cnt2.set(cnt2.get() + 1))
            .clone()
            .into_pipeline();
        let res = srv.call(()).await;
        assert_eq!(res, Ok(()));

        let _ = format!("{srv:?}");

        srv.shutdown().await;
        assert_eq!(cnt.get(), 2);
    }

    #[ntex::test]
    async fn test_inspect_skips_err() {
        let cnt = Rc::new(Cell::new(0));
        let cnt2 = cnt.clone();
        // `call` fails, so the success-only `inspect` callback must not run.
        let srv = chain(Srv(true, false, cnt.clone()))
            .inspect(move |&()| cnt2.set(cnt2.get() + 1))
            .into_pipeline();
        let res = srv.call(()).await;
        assert_eq!(res, Err(()));

        srv.shutdown().await;
        // Only the shutdown was counted.
        assert_eq!(cnt.get(), 1);
    }

    #[ntex::test]
    async fn test_inspect_err_service() {
        let cnt = Rc::new(Cell::new(0));
        let cnt2 = cnt.clone();
        // `ready` fails here, so the error surfaces through the pipeline call.
        let srv = chain(Srv(false, true, cnt.clone()))
            .inspect_err(move |&()| cnt2.set(cnt2.get() + 1))
            .clone()
            .into_pipeline();
        let res = srv.call(()).await;
        // Compare the whole `Result`; comparing the unit payloads (`() == ()`)
        // would always pass and trips clippy's `unit_cmp`.
        assert_eq!(res, Err(()));

        let _ = format!("{srv:?}");

        srv.shutdown().await;
        assert_eq!(cnt.get(), 2);
    }

    #[ntex::test]
    async fn test_inspect_err_call() {
        let cnt = Rc::new(Cell::new(0));
        let cnt2 = cnt.clone();
        // `ready` succeeds and `call` fails, so the callback must fire from the
        // `call` path of `InspectErr`.
        let srv = chain(Srv(true, false, cnt.clone()))
            .inspect_err(move |&()| cnt2.set(cnt2.get() + 1))
            .into_pipeline();
        let res = srv.call(()).await;
        assert_eq!(res, Err(()));

        srv.shutdown().await;
        assert_eq!(cnt.get(), 2);
    }

    #[ntex::test]
    async fn test_inspect_factory() {
        let cnt = Rc::new(Cell::new(0));
        let cnt2 = cnt.clone();
        let cnt3 = cnt.clone();
        let new_srv = chain_factory(fn_factory(async move || {
            Ok::<_, ()>(Srv(false, false, cnt2.clone()))
        }))
        .inspect(move |&()| cnt3.set(cnt3.get() + 1))
        .clone();
        let srv = new_srv.pipeline(&()).await.unwrap();
        let res = srv.call(()).await;
        assert_eq!(res, Ok(()));
        let _ = format!("{new_srv:?}");
        srv.shutdown().await;
        assert_eq!(cnt.get(), 2);
    }

    #[ntex::test]
    async fn test_inspect_err_factory() {
        let cnt = Rc::new(Cell::new(0));
        let cnt2 = cnt.clone();
        let cnt3 = cnt.clone();
        let new_srv = chain_factory(fn_factory(async move || {
            Ok::<_, ()>(Srv(false, true, cnt2.clone()))
        }))
        .inspect_err(move |&()| cnt3.set(cnt3.get() + 1))
        .clone();
        let srv = new_srv.pipeline(&()).await.unwrap();
        let res = srv.call(()).await;
        assert_eq!(res, Err(()));
        let _ = format!("{new_srv:?}");
        srv.shutdown().await;
        assert_eq!(cnt.get(), 2);
    }
}