1use std::{cell::Cell, convert::Infallible, fmt, marker, task::Context, task::Poll, time};
2
3use ntex_service::{Service, ServiceCtx, ServiceFactory};
4
5use crate::time::{now, sleep, Millis, Sleep};
6
7pub struct KeepAlive<R, E, F> {
11 f: F,
12 ka: Millis,
13 _t: marker::PhantomData<(R, E)>,
14}
15
16impl<R, E, F> KeepAlive<R, E, F>
17where
18 F: Fn() -> E + Clone,
19{
20 pub fn new(ka: Millis, err: F) -> Self {
25 KeepAlive {
26 ka,
27 f: err,
28 _t: marker::PhantomData,
29 }
30 }
31}
32
33impl<R, E, F> Clone for KeepAlive<R, E, F>
34where
35 F: Clone,
36{
37 fn clone(&self) -> Self {
38 KeepAlive {
39 f: self.f.clone(),
40 ka: self.ka,
41 _t: marker::PhantomData,
42 }
43 }
44}
45
46impl<R, E, F> fmt::Debug for KeepAlive<R, E, F> {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 f.debug_struct("KeepAlive")
49 .field("ka", &self.ka)
50 .field("f", &std::any::type_name::<F>())
51 .finish()
52 }
53}
54
55impl<R, E, F, C> ServiceFactory<R, C> for KeepAlive<R, E, F>
56where
57 F: Fn() -> E + Clone,
58{
59 type Response = R;
60 type Error = E;
61
62 type Service = KeepAliveService<R, E, F>;
63 type InitError = Infallible;
64
65 #[inline]
66 async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
67 Ok(KeepAliveService::new(self.ka, self.f.clone()))
68 }
69}
70
71pub struct KeepAliveService<R, E, F> {
72 f: F,
73 dur: Millis,
74 sleep: Sleep,
75 expire: Cell<time::Instant>,
76 _t: marker::PhantomData<(R, E)>,
77}
78
79impl<R, E, F> KeepAliveService<R, E, F>
80where
81 F: Fn() -> E,
82{
83 pub fn new(dur: Millis, f: F) -> Self {
84 let expire = Cell::new(now());
85
86 KeepAliveService {
87 f,
88 dur,
89 expire,
90 sleep: sleep(dur),
91 _t: marker::PhantomData,
92 }
93 }
94}
95
96impl<R, E, F> fmt::Debug for KeepAliveService<R, E, F> {
97 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 f.debug_struct("KeepAliveService")
99 .field("dur", &self.dur)
100 .field("expire", &self.expire)
101 .field("f", &std::any::type_name::<F>())
102 .finish()
103 }
104}
105
106impl<R, E, F> Service<R> for KeepAliveService<R, E, F>
107where
108 F: Fn() -> E,
109{
110 type Response = R;
111 type Error = E;
112
113 async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
114 let expire = self.expire.get() + time::Duration::from(self.dur);
115 if expire <= now() {
116 Err((self.f)())
117 } else {
118 Ok(())
119 }
120 }
121
122 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
123 match self.sleep.poll_elapsed(cx) {
124 Poll::Ready(_) => {
125 let now = now();
126 let expire = self.expire.get() + time::Duration::from(self.dur);
127 if expire <= now {
128 Err((self.f)())
129 } else {
130 let expire = expire - now;
131 self.sleep
132 .reset(Millis(expire.as_millis().try_into().unwrap_or(u32::MAX)));
133 let _ = self.sleep.poll_elapsed(cx);
134 Ok(())
135 }
136 }
137 Poll::Pending => Ok(()),
138 }
139 }
140
141 #[inline]
142 async fn call(&self, req: R, _: ServiceCtx<'_, Self>) -> Result<R, E> {
143 self.expire.set(now());
144 Ok(req)
145 }
146}
147
#[cfg(test)]
mod tests {
    use std::task::Poll;

    use super::*;
    use crate::future::lazy;

    #[derive(Debug, PartialEq)]
    struct TestErr;

    /// A service that has just handled a request is ready; after idling well
    /// past the 100ms keep-alive period it reports `TestErr`.
    #[ntex_macros::rt_test2]
    async fn test_ka() {
        let ka_factory = KeepAlive::new(Millis(100), || TestErr);
        assert!(format!("{ka_factory:?}").contains("KeepAlive"));
        let _ = ka_factory.clone();

        let srv = ka_factory.pipeline(&()).await.unwrap().bind();
        assert!(format!("{srv:?}").contains("KeepAliveService"));

        assert_eq!(srv.call(1usize).await, Ok(1usize));
        let fresh = lazy(|cx| srv.poll_ready(cx)).await;
        assert!(fresh.is_ready());

        sleep(Millis(500)).await;
        let idle = lazy(|cx| srv.poll_ready(cx)).await;
        assert_eq!(idle, Poll::Ready(Err(TestErr)));
    }
}