ntex_service/
fn_shutdown.rs

1use std::{cell::Cell, fmt, marker::PhantomData};
2
3use crate::{Service, ServiceCtx, ServiceFactory};
4
5#[inline]
6/// Create `FnShutdown` for function that can act as a `on_shutdown` callback.
7pub fn fn_shutdown<Req, Err, F>(f: F) -> FnShutdown<Req, Err, F>
8where
9    F: AsyncFnOnce(),
10{
11    FnShutdown::new(f)
12}
13
/// Pass-through service whose only extra behaviour is running a one-shot
/// async callback when it is shut down.
///
/// The callback lives in a `Cell<Option<_>>` so it can be moved out through a
/// shared reference; once taken, the slot stays empty.
pub struct FnShutdown<Req, Err, F> {
    /// Marker that ties the request and error types to this value.
    _t: PhantomData<(Req, Err)>,
    /// Callback awaited on shutdown; `None` once it has been taken.
    f_shutdown: Cell<Option<F>>,
}
18
19impl<Req, Err, F> FnShutdown<Req, Err, F> {
20    pub(crate) fn new(f: F) -> Self {
21        Self {
22            f_shutdown: Cell::new(Some(f)),
23            _t: PhantomData,
24        }
25    }
26}
27
28impl<Req, Err, F> Clone for FnShutdown<Req, Err, F>
29where
30    F: Clone,
31{
32    #[inline]
33    fn clone(&self) -> Self {
34        let f = self.f_shutdown.take();
35        self.f_shutdown.set(f.clone());
36        Self {
37            f_shutdown: Cell::new(f),
38            _t: PhantomData,
39        }
40    }
41}
42
43impl<Req, Err, F> fmt::Debug for FnShutdown<Req, Err, F> {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        f.debug_struct("FnShutdown")
46            .field("fn", &std::any::type_name::<F>())
47            .finish()
48    }
49}
50
51impl<Req, Err, C, F> ServiceFactory<Req, C> for FnShutdown<Req, Err, F>
52where
53    F: AsyncFnOnce() + Clone,
54{
55    type Response = Req;
56    type Error = Err;
57    type Service = FnShutdown<Req, Err, F>;
58    type InitError = ();
59
60    #[inline]
61    async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
62        if let Some(f) = self.f_shutdown.take() {
63            self.f_shutdown.set(Some(f.clone()));
64            Ok(FnShutdown {
65                f_shutdown: Cell::new(Some(f)),
66                _t: PhantomData,
67            })
68        } else {
69            panic!("FnShutdown was used already");
70        }
71    }
72}
73
74impl<Req, Err, F> Service<Req> for FnShutdown<Req, Err, F>
75where
76    F: AsyncFnOnce(),
77{
78    type Response = Req;
79    type Error = Err;
80
81    #[inline]
82    async fn shutdown(&self) {
83        if let Some(f) = self.f_shutdown.take() {
84            (f)().await
85        }
86    }
87
88    #[inline]
89    async fn call(&self, req: Req, _: ServiceCtx<'_, Self>) -> Result<Req, Err> {
90        Ok(req)
91    }
92}
93
#[cfg(test)]
mod tests {
    use std::{future::poll_fn, rc::Rc};

    use crate::{chain_factory, fn_service};

    use super::*;

    #[ntex::test]
    async fn test_fn_shutdown() {
        let is_called = Rc::new(Cell::new(false));
        let srv = fn_service(|_| async { Ok::<_, ()>("pipe") });
        let is_called2 = is_called.clone();
        let on_shutdown = fn_shutdown(async move || {
            is_called2.set(true);
        });

        let pipe = chain_factory(srv)
            .and_then(on_shutdown)
            .clone()
            .pipeline(())
            .await
            .unwrap();

        let res = pipe.call(()).await;
        assert_eq!(pipe.ready().await, Ok(()));
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), "pipe");
        assert!(!pipe.is_shutdown());
        pipe.shutdown().await;
        assert!(is_called.get());
        assert!(!pipe.is_shutdown());

        let pipe = pipe.bind();
        let _ = poll_fn(|cx| pipe.poll_shutdown(cx)).await;
        assert!(pipe.is_shutdown());

        let _ = format!("{pipe:?}");
    }

    // The expected message is pinned on purpose: a bare `#[should_panic]`
    // would also accept a failure of the `is_called` / `is_shutdown`
    // assertions below, silently hiding a shutdown regression.
    #[ntex::test]
    #[should_panic(expected = "FnShutdown was used already")]
    async fn test_fn_shutdown_panic() {
        let is_called = Rc::new(Cell::new(false));
        let is_called2 = is_called.clone();
        let on_shutdown = fn_shutdown::<(), (), _>(async move || {
            is_called2.set(true);
        });

        let pipe = chain_factory(on_shutdown).pipeline(()).await.unwrap();
        pipe.shutdown().await;
        assert!(is_called.get());
        assert!(!pipe.is_shutdown());

        // `shutdown` already took this value's callback, so using it as a
        // factory again must panic.
        let _factory = pipe.get_ref().create(()).await;
    }
}
150}