1use std::{cell::Cell, convert::Infallible, fmt, marker, task::Context, task::Poll, time};
2
3use ntex_service::{Service, ServiceCtx, ServiceFactory};
4
5use crate::time::{Millis, Sleep, now, sleep};
6
7pub struct KeepAlive<R, E, F> {
11 f: F,
12 ka: Millis,
13 _t: marker::PhantomData<(R, E)>,
14}
15
16impl<R, E, F> KeepAlive<R, E, F>
17where
18 F: Fn() -> E + Clone,
19{
20 pub fn new(ka: Millis, err: F) -> Self {
25 KeepAlive {
26 ka,
27 f: err,
28 _t: marker::PhantomData,
29 }
30 }
31}
32
33impl<R, E, F> Clone for KeepAlive<R, E, F>
34where
35 F: Clone,
36{
37 fn clone(&self) -> Self {
38 KeepAlive {
39 f: self.f.clone(),
40 ka: self.ka,
41 _t: marker::PhantomData,
42 }
43 }
44}
45
46impl<R, E, F> fmt::Debug for KeepAlive<R, E, F> {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 f.debug_struct("KeepAlive")
49 .field("ka", &self.ka)
50 .field("f", &std::any::type_name::<F>())
51 .finish()
52 }
53}
54
55impl<R, E, F, C> ServiceFactory<R, C> for KeepAlive<R, E, F>
56where
57 F: Fn() -> E + Clone,
58{
59 type Response = R;
60 type Error = E;
61
62 type Service = KeepAliveService<R, E, F>;
63 type InitError = Infallible;
64
65 #[inline]
66 async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
67 Ok(KeepAliveService::new(self.ka, self.f.clone()))
68 }
69}
70
71pub struct KeepAliveService<R, E, F> {
72 f: F,
73 dur: Millis,
74 sleep: Sleep,
75 expire: Cell<time::Instant>,
76 _t: marker::PhantomData<(R, E)>,
77}
78
79impl<R, E, F> KeepAliveService<R, E, F>
80where
81 F: Fn() -> E,
82{
83 pub fn new(dur: Millis, f: F) -> Self {
84 let expire = Cell::new(now());
85
86 KeepAliveService {
87 f,
88 dur,
89 expire,
90 sleep: sleep(dur),
91 _t: marker::PhantomData,
92 }
93 }
94}
95
96impl<R, E, F> fmt::Debug for KeepAliveService<R, E, F> {
97 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 f.debug_struct("KeepAliveService")
99 .field("dur", &self.dur)
100 .field("expire", &self.expire)
101 .field("f", &std::any::type_name::<F>())
102 .finish()
103 }
104}
105
106impl<R, E, F> Service<R> for KeepAliveService<R, E, F>
107where
108 F: Fn() -> E,
109{
110 type Response = R;
111 type Error = E;
112
113 async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
114 let expire = self.expire.get() + time::Duration::from(self.dur);
115 if expire <= now() { Err((self.f)()) } else { Ok(()) }
116 }
117
118 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
119 match self.sleep.poll_elapsed(cx) {
120 Poll::Ready(_) => {
121 let now = now();
122 let expire = self.expire.get() + time::Duration::from(self.dur);
123 if expire <= now {
124 Err((self.f)())
125 } else {
126 let expire = expire - now;
127 self.sleep
128 .reset(Millis(expire.as_millis().try_into().unwrap_or(u32::MAX)));
129 let _ = self.sleep.poll_elapsed(cx);
130 Ok(())
131 }
132 }
133 Poll::Pending => Ok(()),
134 }
135 }
136
137 #[inline]
138 async fn call(&self, req: R, _: ServiceCtx<'_, Self>) -> Result<R, E> {
139 self.expire.set(now());
140 Ok(req)
141 }
142}
143
#[cfg(test)]
mod tests {
    use std::task::Poll;

    use super::*;
    use crate::future::lazy;

    /// Error value produced by the keep-alive closure under test.
    #[derive(Debug, PartialEq)]
    struct TestErr;

    /// A freshly used service is ready; after idling well past its 100ms
    /// keep-alive period, readiness reports `TestErr`.
    #[ntex::test]
    async fn test_ka() {
        let factory = KeepAlive::new(Millis(100), || TestErr);
        let factory_dbg = format!("{factory:?}");
        assert!(factory_dbg.contains("KeepAlive"));
        let _ = factory.clone();

        let svc = factory.pipeline(&()).await.unwrap().bind();
        let svc_dbg = format!("{svc:?}");
        assert!(svc_dbg.contains("KeepAliveService"));

        let echoed = svc.call(1usize).await;
        assert_eq!(echoed, Ok(1usize));
        let first = lazy(|cx| svc.poll_ready(cx)).await;
        assert!(first.is_ready());

        // Idle for five keep-alive periods so expiry is guaranteed.
        sleep(Millis(500)).await;
        let second = lazy(|cx| svc.poll_ready(cx)).await;
        assert_eq!(second, Poll::Ready(Err(TestErr)));
    }
}