ntex_util/services/
keepalive.rs

1use std::{cell::Cell, convert::Infallible, fmt, marker, task::Context, task::Poll, time};
2
3use ntex_service::{Service, ServiceCtx, ServiceFactory};
4
5use crate::time::{Millis, Sleep, now, sleep};
6
7/// KeepAlive service factory
8///
9/// Controls min time between requests.
10pub struct KeepAlive<R, E, F> {
11    f: F,
12    ka: Millis,
13    _t: marker::PhantomData<(R, E)>,
14}
15
16impl<R, E, F> KeepAlive<R, E, F>
17where
18    F: Fn() -> E + Clone,
19{
20    /// Construct KeepAlive service factory.
21    ///
22    /// ka - keep-alive timeout
23    /// err - error factory function
24    pub fn new(ka: Millis, err: F) -> Self {
25        KeepAlive {
26            ka,
27            f: err,
28            _t: marker::PhantomData,
29        }
30    }
31}
32
33impl<R, E, F> Clone for KeepAlive<R, E, F>
34where
35    F: Clone,
36{
37    fn clone(&self) -> Self {
38        KeepAlive {
39            f: self.f.clone(),
40            ka: self.ka,
41            _t: marker::PhantomData,
42        }
43    }
44}
45
46impl<R, E, F> fmt::Debug for KeepAlive<R, E, F> {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        f.debug_struct("KeepAlive")
49            .field("ka", &self.ka)
50            .field("f", &std::any::type_name::<F>())
51            .finish()
52    }
53}
54
55impl<R, E, F, C> ServiceFactory<R, C> for KeepAlive<R, E, F>
56where
57    F: Fn() -> E + Clone,
58{
59    type Response = R;
60    type Error = E;
61
62    type Service = KeepAliveService<R, E, F>;
63    type InitError = Infallible;
64
65    #[inline]
66    async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
67        Ok(KeepAliveService::new(self.ka, self.f.clone()))
68    }
69}
70
71pub struct KeepAliveService<R, E, F> {
72    f: F,
73    dur: Millis,
74    sleep: Sleep,
75    expire: Cell<time::Instant>,
76    _t: marker::PhantomData<(R, E)>,
77}
78
79impl<R, E, F> KeepAliveService<R, E, F>
80where
81    F: Fn() -> E,
82{
83    pub fn new(dur: Millis, f: F) -> Self {
84        let expire = Cell::new(now());
85
86        KeepAliveService {
87            f,
88            dur,
89            expire,
90            sleep: sleep(dur),
91            _t: marker::PhantomData,
92        }
93    }
94}
95
96impl<R, E, F> fmt::Debug for KeepAliveService<R, E, F> {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        f.debug_struct("KeepAliveService")
99            .field("dur", &self.dur)
100            .field("expire", &self.expire)
101            .field("f", &std::any::type_name::<F>())
102            .finish()
103    }
104}
105
106impl<R, E, F> Service<R> for KeepAliveService<R, E, F>
107where
108    F: Fn() -> E,
109{
110    type Response = R;
111    type Error = E;
112
113    async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
114        let expire = self.expire.get() + time::Duration::from(self.dur);
115        if expire <= now() { Err((self.f)()) } else { Ok(()) }
116    }
117
118    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
119        match self.sleep.poll_elapsed(cx) {
120            Poll::Ready(_) => {
121                let now = now();
122                let expire = self.expire.get() + time::Duration::from(self.dur);
123                if expire <= now {
124                    Err((self.f)())
125                } else {
126                    let expire = expire - now;
127                    self.sleep
128                        .reset(Millis(expire.as_millis().try_into().unwrap_or(u32::MAX)));
129                    let _ = self.sleep.poll_elapsed(cx);
130                    Ok(())
131                }
132            }
133            Poll::Pending => Ok(()),
134        }
135    }
136
137    #[inline]
138    async fn call(&self, req: R, _: ServiceCtx<'_, Self>) -> Result<R, E> {
139        self.expire.set(now());
140        Ok(req)
141    }
142}
143
#[cfg(test)]
mod tests {
    use std::task::Poll;

    use super::*;
    use crate::future::lazy;

    #[derive(Debug, PartialEq)]
    struct TestErr;

    /// A request keeps the service ready. After idling past the timeout,
    /// readiness reports the error built by the factory.
    #[ntex::test]
    async fn test_ka() {
        let factory = KeepAlive::new(Millis(100), || TestErr);
        let factory_dbg = format!("{factory:?}");
        assert!(factory_dbg.contains("KeepAlive"));
        let _ = factory.clone();

        let svc = factory.pipeline(&()).await.unwrap().bind();
        let svc_dbg = format!("{svc:?}");
        assert!(svc_dbg.contains("KeepAliveService"));

        // The request is passed through and the service stays ready.
        assert_eq!(svc.call(1usize).await, Ok(1usize));
        let ready = lazy(|cx| svc.poll_ready(cx)).await;
        assert!(ready.is_ready());

        // Idle well beyond the 100ms keep-alive window.
        sleep(Millis(500)).await;
        let expired = lazy(|cx| svc.poll_ready(cx)).await;
        assert_eq!(expired, Poll::Ready(Err(TestErr)));
    }
}