1#![doc = include_str!("../README.md")]
2
3use std::{convert::Infallible, fmt::Debug, future::Future, marker::PhantomData, sync::Arc};
4
5use parking_lot::RwLock;
6use tokio::time::{sleep, Duration, Instant};
7
8pub struct Refreshed<T, E> {
10 inner: Arc<RwLock<RefreshState<T, E>>>,
11}
12
13impl<T, E> Clone for Refreshed<T, E> {
14 fn clone(&self) -> Self {
15 Self {
16 inner: self.inner.clone(),
17 }
18 }
19}
20
21struct RefreshState<T, E> {
23 pub value: Arc<T>,
25 updated: Instant,
27 last_error: Option<Arc<E>>,
29}
30
31impl<T, E> Clone for RefreshState<T, E> {
32 fn clone(&self) -> Self {
33 RefreshState {
34 value: self.value.clone(),
35 updated: self.updated,
36 last_error: self.last_error.clone(),
37 }
38 }
39}
40
41impl<T, E> Refreshed<T, E> {
42 pub fn builder() -> Builder<T, E> {
44 Builder::default()
45 }
46
47 pub fn get(&self) -> Arc<T> {
49 self.inner.read().value.clone()
50 }
51
52 pub fn get_updated(&self) -> Instant {
54 self.inner.read().updated
55 }
56
57 pub fn get_last_error(&self) -> Option<Arc<E>> {
61 self.inner.read().last_error.clone()
62 }
63
64 #[cfg(test)]
65 fn get_state(&self) -> RefreshState<T, E> {
67 self.inner.read().clone()
68 }
69}
70
/// Configuration for a [`Refreshed`] value: the refresh interval plus the
/// success, error and exit callbacks.
///
/// Obtain one via [`Refreshed::builder`], adjust it with the setter methods,
/// then finish with `try_build` (or `build` for infallible fetches).
pub struct Builder<T, E> {
    /// Sleep between the end of one refresh and the start of the next.
    duration: Duration,
    /// Invoked with the value of each successful background refresh.
    success: Arc<dyn Fn(&T) + Send + Sync>,
    /// Invoked with the error of each failed background refresh.
    error: Arc<dyn Fn(&E) + Send + Sync>,
    /// Invoked once when the background refresh task ends.
    exit: Arc<dyn Fn() + Send + Sync>,
    /// Marker tying the builder to its value and error types.
    _phantom: PhantomData<Result<T, E>>,
}
80
81impl<T, E> Default for Builder<T, E> {
82 fn default() -> Self {
83 Builder {
84 duration: Duration::from_secs(60),
85 success: Arc::new(|_| ()),
86 error: Arc::new(|_| ()),
87 exit: Arc::new(|| log::debug!("Refresh loop exited")),
88 _phantom: PhantomData,
89 }
90 }
91}
92
93impl<T, E> Builder<T, E>
94where
95 T: Send + Sync + 'static,
96 E: Send + Sync + 'static,
97{
98 pub fn duration(&mut self, duration: Duration) -> &mut Self {
100 self.duration = duration;
101 self
102 }
103
104 pub fn error(&mut self, error: impl Fn(&E) + Send + Sync + 'static) -> &mut Self {
106 self.error = Arc::new(error);
107 self
108 }
109
110 pub fn success(&mut self, success: impl Fn(&T) + Send + Sync + 'static) -> &mut Self {
112 self.success = Arc::new(success);
113 self
114 }
115
116 pub fn exit(&mut self, exit: impl Fn() + Send + Sync + 'static) -> &mut Self {
118 self.exit = Arc::new(exit);
119 self
120 }
121
122 pub async fn try_build<Fut, MkFut>(&self, mut mk_fut: MkFut) -> Result<Refreshed<T, E>, E>
126 where
127 Fut: Future<Output = Result<T, E>> + Send + 'static,
128 MkFut: FnMut(bool) -> Fut + Send + 'static,
129 {
130 let init = RefreshState {
131 value: Arc::new(mk_fut(false).await?),
132 updated: Instant::now(),
133 last_error: None,
134 };
135 let refresh = Refreshed {
136 inner: Arc::new(RwLock::new(init)),
137 };
138 let weak = Arc::downgrade(&refresh.inner);
139 let duration = self.duration;
140 let success = self.success.clone();
141 let error = self.error.clone();
142 let exit = self.exit.clone();
143 tokio::spawn(async move {
144 let _exit = Dropper(Some(|| exit()));
145 loop {
146 sleep(duration).await;
147 let arc = match weak.upgrade() {
148 None => break,
149 Some(arc) => arc,
150 };
151
152 match mk_fut(true).await {
153 Err(e) => {
154 error(&e);
155 arc.write().last_error = Some(Arc::new(e));
156 }
157 Ok(t) => {
158 success(&t);
159 let mut lock = arc.write();
160 lock.value = Arc::new(t);
161 lock.updated = Instant::now();
162 lock.last_error = None;
163 }
164 }
165 }
166 });
167 Ok(refresh)
168 }
169}
170
/// Drop guard that runs a one-shot callback when it goes out of scope.
///
/// Because it fires from `Drop`, the callback also runs while a panic unwinds
/// through the owning scope (assuming the default `panic = "unwind"`).
struct Dropper<F: FnOnce()>(Option<F>);

impl<F: FnOnce()> Drop for Dropper<F> {
    fn drop(&mut self) {
        // `take` moves the `FnOnce` out of `&mut self` so it can be called by
        // value, and leaves `None` behind so it can never run twice.
        let pending = self.0.take();
        if let Some(callback) = pending {
            callback();
        }
    }
}
181
182impl<T> Builder<T, Infallible>
183where
184 T: Send + Sync + 'static,
185{
186 pub async fn build<Fut, MkFut>(&self, mut mk_fut: MkFut) -> Refreshed<T, Infallible>
188 where
189 Fut: Future<Output = T> + Send + 'static,
190 MkFut: FnMut(bool) -> Fut + Send + 'static,
191 {
192 let res = self
193 .try_build(move |is_refresh| {
194 let fut = mk_fut(is_refresh);
195 async move {
196 let t = fut.await;
197 Ok::<_, Infallible>(t)
198 }
199 })
200 .await;
201
202 absurd(res)
203 }
204}
205
/// Unwraps a `Result` whose error type is uninhabited.
///
/// `Infallible` has no values, so the `Err` arm can never run. Matching on it
/// with an empty `match` lets the compiler prove that, instead of keeping a
/// `panic!` path around as `expect` would.
fn absurd<T>(res: Result<T, Infallible>) -> T {
    match res {
        Ok(value) => value,
        Err(never) => match never {},
    }
}
209
210impl<T, E> Builder<T, E>
211where
212 T: Send + Sync + 'static,
213 E: Debug + Send + Sync + 'static,
214{
215 pub fn log_errors(&mut self) -> &mut Self {
217 self.error(|e| log::error!("{:?}", e))
218 }
219}
220
#[cfg(test)]
mod tests {
    use std::{convert::Infallible, sync::Arc};

    use parking_lot::RwLock;
    use tokio::time::{sleep, Duration};

    use super::Refreshed;

    // The value produced by the initial fetch is returned straight away.
    #[tokio::test]
    async fn simple_no_refresh() {
        let x = Refreshed::builder()
            .try_build(|_| async { Ok::<_, Infallible>(42_u32) })
            .await
            .unwrap();
        assert_eq!(*x.get(), 42);
    }

    // Each fetch increments `counter` and returns the new count, so after
    // every refresh the stored value must equal the counter.
    // NOTE(review): timing-based; assumes the background task gets to run
    // during each `sleep(duration)` below.
    #[tokio::test]
    async fn refreshes() {
        let counter = Arc::new(RwLock::new(0u32));
        let counter_clone = counter.clone();
        let mk_fut = move |_| {
            // Clone per call so each returned `'static` future owns a handle.
            let counter_clone = counter_clone.clone();
            async move {
                let mut lock = counter_clone.write();
                *lock += 1;
                Ok::<u32, Infallible>(*lock)
            }
        };
        let duration = Duration::from_millis(10);
        let x = Refreshed::builder()
            .duration(duration)
            .try_build(mk_fut)
            .await
            .unwrap();
        // Only the initial fetch has run so far.
        assert_eq!(*x.get(), 1);
        for _ in 0..10u32 {
            sleep(duration).await;
            assert_eq!(*x.get(), *counter.read());
        }
    }

    // Dropping the last `Refreshed` handle stops the background task: the
    // counter stops advancing and the exit callback has run.
    #[tokio::test]
    async fn stops_refreshing() {
        let exited = Arc::new(RwLock::new(false));
        let exited_clone = exited.clone();
        let counter = Arc::new(RwLock::new(0u32));
        let counter_clone = counter.clone();
        let mk_fut = move |_| {
            let counter_clone = counter_clone.clone();
            async move {
                let mut lock = counter_clone.write();
                *lock += 1;
                Ok::<u32, Infallible>(*lock)
            }
        };
        let duration = Duration::from_millis(10);
        let x = Refreshed::builder()
            .duration(duration)
            .exit(move || *exited_clone.write() = true)
            .try_build(mk_fut)
            .await
            .unwrap();
        assert_eq!(*x.get(), 1);
        // The task is still running while `x` is alive.
        assert_eq!(*exited.read(), false);
        sleep(duration).await;
        std::mem::drop(x);
        // Snapshot the count right after the drop; no further fetch may run.
        let val = *counter.read();
        for _ in 0..5u32 {
            sleep(duration).await;
            assert_eq!(val, *counter.read());
        }
        assert_eq!(*exited.read(), true);
    }

    // The success callback fires once per background refresh. `success`
    // starts at 1 because the initial fetch does not invoke the callback, so
    // it tracks the fetch counter exactly.
    #[tokio::test]
    async fn count_successes() {
        let counter = Arc::new(RwLock::new(0u32));
        let counter_clone = counter.clone();
        let success = Arc::new(RwLock::new(1u32));
        let success_clone = success.clone();
        let mk_fut = move |_| {
            let counter_clone = counter_clone.clone();
            async move {
                let mut lock = counter_clone.write();
                *lock += 1;
                Ok::<u32, Infallible>(*lock)
            }
        };
        let duration = Duration::from_millis(10);
        let x = Refreshed::builder()
            .duration(duration)
            .success(move |_| *success_clone.write() += 1)
            .try_build(mk_fut)
            .await
            .unwrap();
        assert_eq!(*x.get(), 1);
        for _ in 0..10u32 {
            sleep(duration).await;
            assert_eq!(*x.get(), *counter.read());
            assert_eq!(*x.get(), *success.read());
        }
    }

    // `build` accepts fetch futures that yield the value directly (no Result).
    #[tokio::test]
    async fn simple_build() {
        let x = Refreshed::builder().build(|_| async { 42_u32 }).await;
        assert_eq!(*x.get(), 42);
    }

    // A panic inside a background refresh ends the task. The exit callback
    // still runs (the task's drop guard fires during unwinding), and no error
    // is recorded because the panic never produced an `Err`.
    #[tokio::test]
    async fn exit_on_panic() {
        let exited = Arc::new(RwLock::new(false));
        let exited_clone = exited.clone();
        // Succeeds for the initial fetch (`false`), panics on any refresh.
        let mk_fut = move |is_refresh| async move {
            if is_refresh {
                panic!("Don't panic!");
            } else {
                ()
            }
        };
        let duration = Duration::from_millis(10);
        let x = Refreshed::builder()
            .duration(duration)
            .exit(move || *exited_clone.write() = true)
            .build(mk_fut)
            .await;
        assert_eq!(*exited.read(), false);
        // Two periods: enough for the first refresh to start and panic.
        sleep(duration).await;
        sleep(duration).await;
        assert_eq!(*exited.read(), true);
        assert_eq!(x.get_state().last_error, None);
    }
}