ntex_util/services/keepalive.rs

1use std::{cell::Cell, convert::Infallible, fmt, marker, task::Context, task::Poll, time};
2
3use ntex_service::{Service, ServiceCtx, ServiceFactory};
4
5use crate::time::{now, sleep, Millis, Sleep};
6
7/// KeepAlive service factory
8///
9/// Controls min time between requests.
10pub struct KeepAlive<R, E, F> {
11    f: F,
12    ka: Millis,
13    _t: marker::PhantomData<(R, E)>,
14}
15
16impl<R, E, F> KeepAlive<R, E, F>
17where
18    F: Fn() -> E + Clone,
19{
20    /// Construct KeepAlive service factory.
21    ///
22    /// ka - keep-alive timeout
23    /// err - error factory function
24    pub fn new(ka: Millis, err: F) -> Self {
25        KeepAlive {
26            ka,
27            f: err,
28            _t: marker::PhantomData,
29        }
30    }
31}
32
33impl<R, E, F> Clone for KeepAlive<R, E, F>
34where
35    F: Clone,
36{
37    fn clone(&self) -> Self {
38        KeepAlive {
39            f: self.f.clone(),
40            ka: self.ka,
41            _t: marker::PhantomData,
42        }
43    }
44}
45
46impl<R, E, F> fmt::Debug for KeepAlive<R, E, F> {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        f.debug_struct("KeepAlive")
49            .field("ka", &self.ka)
50            .field("f", &std::any::type_name::<F>())
51            .finish()
52    }
53}
54
55impl<R, E, F, C> ServiceFactory<R, C> for KeepAlive<R, E, F>
56where
57    F: Fn() -> E + Clone,
58{
59    type Response = R;
60    type Error = E;
61
62    type Service = KeepAliveService<R, E, F>;
63    type InitError = Infallible;
64
65    #[inline]
66    async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
67        Ok(KeepAliveService::new(self.ka, self.f.clone()))
68    }
69}
70
71pub struct KeepAliveService<R, E, F> {
72    f: F,
73    dur: Millis,
74    sleep: Sleep,
75    expire: Cell<time::Instant>,
76    _t: marker::PhantomData<(R, E)>,
77}
78
79impl<R, E, F> KeepAliveService<R, E, F>
80where
81    F: Fn() -> E,
82{
83    pub fn new(dur: Millis, f: F) -> Self {
84        let expire = Cell::new(now());
85
86        KeepAliveService {
87            f,
88            dur,
89            expire,
90            sleep: sleep(dur),
91            _t: marker::PhantomData,
92        }
93    }
94}
95
96impl<R, E, F> fmt::Debug for KeepAliveService<R, E, F> {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        f.debug_struct("KeepAliveService")
99            .field("dur", &self.dur)
100            .field("expire", &self.expire)
101            .field("f", &std::any::type_name::<F>())
102            .finish()
103    }
104}
105
106impl<R, E, F> Service<R> for KeepAliveService<R, E, F>
107where
108    F: Fn() -> E,
109{
110    type Response = R;
111    type Error = E;
112
113    async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
114        let expire = self.expire.get() + time::Duration::from(self.dur);
115        if expire <= now() {
116            Err((self.f)())
117        } else {
118            Ok(())
119        }
120    }
121
122    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
123        match self.sleep.poll_elapsed(cx) {
124            Poll::Ready(_) => {
125                let now = now();
126                let expire = self.expire.get() + time::Duration::from(self.dur);
127                if expire <= now {
128                    Err((self.f)())
129                } else {
130                    let expire = expire - now;
131                    self.sleep
132                        .reset(Millis(expire.as_millis().try_into().unwrap_or(u32::MAX)));
133                    let _ = self.sleep.poll_elapsed(cx);
134                    Ok(())
135                }
136            }
137            Poll::Pending => Ok(()),
138        }
139    }
140
141    #[inline]
142    async fn call(&self, req: R, _: ServiceCtx<'_, Self>) -> Result<R, E> {
143        self.expire.set(now());
144        Ok(req)
145    }
146}
147
#[cfg(test)]
mod tests {
    use std::task::Poll;

    use super::*;
    use crate::future::lazy;

    /// Marker error produced by the error factory under test.
    #[derive(Debug, PartialEq)]
    struct TestErr;

    /// A freshly used service is ready. Once it has been idle for longer than
    /// the keep-alive period, readiness fails with the factory's error.
    #[ntex_macros::rt_test2]
    async fn test_ka() {
        let factory = KeepAlive::new(Millis(100), || TestErr);
        let factory_dbg = format!("{factory:?}");
        assert!(factory_dbg.contains("KeepAlive"));
        let _ = factory.clone();

        let service = factory.pipeline(&()).await.unwrap().bind();
        let service_dbg = format!("{service:?}");
        assert!(service_dbg.contains("KeepAliveService"));

        // The request passes through unchanged and restarts the idle period.
        assert_eq!(service.call(1usize).await, Ok(1usize));
        let ready = lazy(|cx| service.poll_ready(cx)).await;
        assert!(ready.is_ready());

        // Stay idle well beyond the 100ms keep-alive period.
        sleep(Millis(500)).await;
        let expired = lazy(|cx| service.poll_ready(cx)).await;
        assert_eq!(expired, Poll::Ready(Err(TestErr)));
    }
}