async_refresh/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{convert::Infallible, fmt::Debug, future::Future, marker::PhantomData, sync::Arc};
4
5use parking_lot::RwLock;
6use tokio::time::{sleep, Duration, Instant};
7
8/// A value which will be refreshed asynchronously.
9pub struct Refreshed<T, E> {
10    inner: Arc<RwLock<RefreshState<T, E>>>,
11}
12
13impl<T, E> Clone for Refreshed<T, E> {
14    fn clone(&self) -> Self {
15        Self {
16            inner: self.inner.clone(),
17        }
18    }
19}
20
/// The internal state of a [Refreshed], shared by all of its handles and written by the
/// background refresh task.
struct RefreshState<T, E> {
    /// The most recently loaded value.
    value: Arc<T>,
    /// When `value` was stored, i.e. the time of the latest successful load.
    updated: Instant,
    /// The error, if present, from the last attempted refresh; cleared by a successful one.
    last_error: Option<Arc<E>>,
}

// Written by hand rather than derived: deriving would require `T: Clone` and `E: Clone`,
// but only the `Arc`s (and the `Copy` timestamp) are duplicated here.
impl<T, E> Clone for RefreshState<T, E> {
    fn clone(&self) -> Self {
        RefreshState {
            value: self.value.clone(),
            updated: self.updated,
            last_error: self.last_error.clone(),
        }
    }
}
40
41impl<T, E> Refreshed<T, E> {
42    /// Create an initial [Builder] value with defaults.
43    pub fn builder() -> Builder<T, E> {
44        Builder::default()
45    }
46
47    /// Get the most recent value
48    pub fn get(&self) -> Arc<T> {
49        self.inner.read().value.clone()
50    }
51
52    /// Get the timestamp of the most recent successful update
53    pub fn get_updated(&self) -> Instant {
54        self.inner.read().updated
55    }
56
57    /// The error message, if present, from the last attempted refresh.
58    ///
59    /// Note that on each successful refresh, this is reset to `None`.
60    pub fn get_last_error(&self) -> Option<Arc<E>> {
61        self.inner.read().last_error.clone()
62    }
63
64    #[cfg(test)]
65    /// Get the full state
66    fn get_state(&self) -> RefreshState<T, E> {
67        self.inner.read().clone()
68    }
69}
70
71/// Construct the settings around how a [Refreshed] should be created and
72/// updated.
73pub struct Builder<T, E> {
74    duration: Duration,
75    success: Arc<dyn Fn(&T) + Send + Sync>,
76    error: Arc<dyn Fn(&E) + Send + Sync>,
77    exit: Arc<dyn Fn() + Send + Sync>,
78    _phantom: PhantomData<Result<T, E>>,
79}
80
81impl<T, E> Default for Builder<T, E> {
82    fn default() -> Self {
83        Builder {
84            duration: Duration::from_secs(60),
85            success: Arc::new(|_| ()),
86            error: Arc::new(|_| ()),
87            exit: Arc::new(|| log::debug!("Refresh loop exited")),
88            _phantom: PhantomData,
89        }
90    }
91}
92
93impl<T, E> Builder<T, E>
94where
95    T: Send + Sync + 'static,
96    E: Send + Sync + 'static,
97{
98    /// Set the duration for refreshing. Default value: 60 seconds.
99    pub fn duration(&mut self, duration: Duration) -> &mut Self {
100        self.duration = duration;
101        self
102    }
103
104    /// What should we do with error values produced while refreshing? Default: no action.
105    pub fn error(&mut self, error: impl Fn(&E) + Send + Sync + 'static) -> &mut Self {
106        self.error = Arc::new(error);
107        self
108    }
109
110    /// What should we do with success values produced while refreshing? Default: no action.
111    pub fn success(&mut self, success: impl Fn(&T) + Send + Sync + 'static) -> &mut Self {
112        self.success = Arc::new(success);
113        self
114    }
115
116    /// What should we do when the refresh loop exits? Default: debug level log message.
117    pub fn exit(&mut self, exit: impl Fn() + Send + Sync + 'static) -> &mut Self {
118        self.exit = Arc::new(exit);
119        self
120    }
121
122    /// Construct a [Refreshed] value from the given initialization function, which may fail.
123    ///
124    /// The closure is provided `false` on the first call, and `true` on subsequent refresh calls.
125    pub async fn try_build<Fut, MkFut>(&self, mut mk_fut: MkFut) -> Result<Refreshed<T, E>, E>
126    where
127        Fut: Future<Output = Result<T, E>> + Send + 'static,
128        MkFut: FnMut(bool) -> Fut + Send + 'static,
129    {
130        let init = RefreshState {
131            value: Arc::new(mk_fut(false).await?),
132            updated: Instant::now(),
133            last_error: None,
134        };
135        let refresh = Refreshed {
136            inner: Arc::new(RwLock::new(init)),
137        };
138        let weak = Arc::downgrade(&refresh.inner);
139        let duration = self.duration;
140        let success = self.success.clone();
141        let error = self.error.clone();
142        let exit = self.exit.clone();
143        tokio::spawn(async move {
144            let _exit = Dropper(Some(|| exit()));
145            loop {
146                sleep(duration).await;
147                let arc = match weak.upgrade() {
148                    None => break,
149                    Some(arc) => arc,
150                };
151
152                match mk_fut(true).await {
153                    Err(e) => {
154                        error(&e);
155                        arc.write().last_error = Some(Arc::new(e));
156                    }
157                    Ok(t) => {
158                        success(&t);
159                        let mut lock = arc.write();
160                        lock.value = Arc::new(t);
161                        lock.updated = Instant::now();
162                        lock.last_error = None;
163                    }
164                }
165            }
166        });
167        Ok(refresh)
168    }
169}
170
/// Guard that invokes a one-shot callback when it goes out of scope.
///
/// The callback is stored in an `Option` so that `drop`, which only receives `&mut self`,
/// can move it out and call it by value; an empty guard does nothing.
struct Dropper<F: FnOnce()>(Option<F>);

impl<F: FnOnce()> Drop for Dropper<F> {
    fn drop(&mut self) {
        if let Some(on_drop) = std::mem::take(&mut self.0) {
            on_drop();
        }
    }
}
181
182impl<T> Builder<T, Infallible>
183where
184    T: Send + Sync + 'static,
185{
186    /// Construct a [Refreshed] value from the given initialization function
187    pub async fn build<Fut, MkFut>(&self, mut mk_fut: MkFut) -> Refreshed<T, Infallible>
188    where
189        Fut: Future<Output = T> + Send + 'static,
190        MkFut: FnMut(bool) -> Fut + Send + 'static,
191    {
192        let res = self
193            .try_build(move |is_refresh| {
194                let fut = mk_fut(is_refresh);
195                async move {
196                    let t = fut.await;
197                    Ok::<_, Infallible>(t)
198                }
199            })
200            .await;
201
202        absurd(res)
203    }
204}
205
/// Unwrap a `Result` whose error type is uninhabited.
///
/// `Infallible` has no values, so the `Err` arm can never be reached. The empty `match`
/// lets the compiler prove that, instead of relying on a runtime `expect` that can never fire.
fn absurd<T>(res: Result<T, Infallible>) -> T {
    match res {
        Ok(value) => value,
        Err(never) => match never {},
    }
}
209
210impl<T, E> Builder<T, E>
211where
212    T: Send + Sync + 'static,
213    E: Debug + Send + Sync + 'static,
214{
215    /// Turn on default error logging when an error occurs.
216    pub fn log_errors(&mut self) -> &mut Self {
217        self.error(|e| log::error!("{:?}", e))
218    }
219}
220
#[cfg(test)]
mod tests {
    use std::{convert::Infallible, sync::Arc};

    use parking_lot::RwLock;
    use tokio::time::{sleep, Duration};

    use super::Refreshed;

    #[tokio::test]
    async fn simple_no_refresh() {
        let x = Refreshed::builder()
            .try_build(|_| async { Ok::<_, Infallible>(42_u32) })
            .await
            .unwrap();
        assert_eq!(*x.get(), 42);
    }

    #[tokio::test]
    async fn refreshes() {
        let counter = Arc::new(RwLock::new(0u32));
        let counter_clone = counter.clone();
        let mk_fut = move |_| {
            let counter_clone = counter_clone.clone();
            async move {
                let mut lock = counter_clone.write();
                *lock += 1;
                Ok::<u32, Infallible>(*lock)
            }
        };
        let duration = Duration::from_millis(10);
        let x = Refreshed::builder()
            .duration(duration)
            .try_build(mk_fut)
            .await
            .unwrap();
        assert_eq!(*x.get(), 1);
        for _ in 0..10u32 {
            sleep(duration).await;
            assert_eq!(*x.get(), *counter.read());
        }
    }

    #[tokio::test]
    async fn stops_refreshing() {
        let exited = Arc::new(RwLock::new(false));
        let exited_clone = exited.clone();
        let counter = Arc::new(RwLock::new(0u32));
        let counter_clone = counter.clone();
        let mk_fut = move |_| {
            let counter_clone = counter_clone.clone();
            async move {
                let mut lock = counter_clone.write();
                *lock += 1;
                Ok::<u32, Infallible>(*lock)
            }
        };
        let duration = Duration::from_millis(10);
        let x = Refreshed::builder()
            .duration(duration)
            .exit(move || *exited_clone.write() = true)
            .try_build(mk_fut)
            .await
            .unwrap();
        assert_eq!(*x.get(), 1);
        assert!(!*exited.read());
        sleep(duration).await;
        // Dropping the only handle must stop the background loop.
        drop(x);
        let val = *counter.read();
        for _ in 0..5u32 {
            sleep(duration).await;
            assert_eq!(val, *counter.read());
        }
        assert!(*exited.read());
    }

    #[tokio::test]
    async fn count_successes() {
        let counter = Arc::new(RwLock::new(0u32));
        let counter_clone = counter.clone();
        // start at 1, since we don't count the initial load
        let success = Arc::new(RwLock::new(1u32));
        let success_clone = success.clone();
        let mk_fut = move |_| {
            let counter_clone = counter_clone.clone();
            async move {
                let mut lock = counter_clone.write();
                *lock += 1;
                Ok::<u32, Infallible>(*lock)
            }
        };
        let duration = Duration::from_millis(10);
        let x = Refreshed::builder()
            .duration(duration)
            .success(move |_| *success_clone.write() += 1)
            .try_build(mk_fut)
            .await
            .unwrap();
        assert_eq!(*x.get(), 1);
        for _ in 0..10u32 {
            sleep(duration).await;
            assert_eq!(*x.get(), *counter.read());
            assert_eq!(*x.get(), *success.read());
        }
    }

    #[tokio::test]
    async fn simple_build() {
        let x = Refreshed::builder().build(|_| async { 42_u32 }).await;
        assert_eq!(*x.get(), 42);
    }

    #[tokio::test]
    async fn exit_on_panic() {
        let exited = Arc::new(RwLock::new(false));
        let exited_clone = exited.clone();
        // The initial load succeeds; the first refresh panics, which tears down the task.
        let mk_fut = |is_refresh| async move {
            if is_refresh {
                panic!("Don't panic!");
            }
        };
        let duration = Duration::from_millis(10);
        let x = Refreshed::builder()
            .duration(duration)
            .exit(move || *exited_clone.write() = true)
            .build(mk_fut)
            .await;
        assert!(!*exited.read());
        sleep(duration).await;
        sleep(duration).await;
        assert!(*exited.read());
        assert_eq!(x.get_state().last_error, None);
    }
}