// atomr_streams/restart.rs

//! Restart combinators: re-subscribe to an inner source each time it terminates,
//! waiting an exponentially growing backoff between attempts.
2
3use std::time::Duration;
4
5use futures::stream::StreamExt;
6
7use crate::source::Source;
8
/// Backoff and retry-budget configuration for [`RestartSource::with_backoff`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RestartSettings {
    /// Base delay of the exponential backoff applied before each re-subscription.
    pub min_backoff: Duration,
    /// Upper bound on the backoff delay. If it is smaller than `min_backoff`,
    /// `min_backoff` is used as the bound instead.
    pub max_backoff: Duration,
    /// Intended jitter factor. It is currently NOT applied by
    /// [`RestartSource::with_backoff`]; delays are fully deterministic.
    pub random_factor: f64,
    /// Maximum number of subscriptions to the inner source, *including* the
    /// initial one (so `Some(3)` means one initial run plus two restarts, and
    /// `Some(0)` never subscribes). `None` places no limit.
    pub max_restarts: Option<usize>,
}
16
17impl Default for RestartSettings {
18    fn default() -> Self {
19        Self {
20            min_backoff: Duration::from_millis(100),
21            max_backoff: Duration::from_secs(30),
22            random_factor: 0.0,
23            max_restarts: Some(5),
24        }
25    }
26}
27
28pub struct RestartSource;
29
30impl RestartSource {
31    /// Re-subscribe to the source returned by the factory after it completes
32    /// (and every element it produced has been emitted). Equivalent to
33    /// `RestartSource.WithBackoff` when combined with the built-in backoff.
34    pub fn with_backoff<T, F>(settings: RestartSettings, factory: F) -> Source<T>
35    where
36        T: Send + 'static,
37        F: FnMut() -> Source<T> + Send + 'static,
38    {
39        let state = RestartState { factory, settings, attempts: 0 };
40        let s = futures::stream::unfold(
41            (state, None::<futures::stream::BoxStream<'static, T>>),
42            |(mut state, current)| async move {
43                // Lazily open a subscription.
44                let mut stream = match current {
45                    Some(s) => s,
46                    None => state.next_stream().await?,
47                };
48                if let Some(v) = stream.next().await {
49                    Some((v, (state, Some(stream))))
50                } else {
51                    // Completed; check restart policy.
52                    let maybe_next = state.next_stream().await;
53                    match maybe_next {
54                        Some(mut s) => s.next().await.map(|v| (v, (state, Some(s)))),
55                        None => None,
56                    }
57                }
58            },
59        )
60        .boxed();
61        Source { inner: s }
62    }
63}
64
65struct RestartState<T, F>
66where
67    F: FnMut() -> Source<T> + Send + 'static,
68{
69    factory: F,
70    settings: RestartSettings,
71    attempts: usize,
72}
73
74impl<T, F> RestartState<T, F>
75where
76    T: Send + 'static,
77    F: FnMut() -> Source<T> + Send + 'static,
78{
79    async fn next_stream(&mut self) -> Option<futures::stream::BoxStream<'static, T>> {
80        if let Some(limit) = self.settings.max_restarts {
81            if self.attempts >= limit {
82                return None;
83            }
84        }
85        if self.attempts > 0 {
86            let base = self.settings.min_backoff.as_millis() as u64;
87            let cap = self.settings.max_backoff.as_millis() as u64;
88            let back = (base.saturating_mul(1 << self.attempts.min(20))).min(cap.max(base));
89            tokio::time::sleep(Duration::from_millis(back)).await;
90        }
91        self.attempts += 1;
92        let src = (self.factory)();
93        Some(src.into_boxed())
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn restart_source_resubscribes_until_max() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_c = Arc::clone(&calls);
        let settings = RestartSettings {
            min_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(5),
            random_factor: 0.0,
            max_restarts: Some(3),
        };
        let source = RestartSource::with_backoff(settings, move || {
            calls_c.fetch_add(1, Ordering::SeqCst);
            crate::source::Source::from_iter(vec![1])
        });
        let out = Sink::collect(source).await;
        assert_eq!(out, vec![1, 1, 1]);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A zero budget must not invoke the factory at all and yields nothing.
    #[tokio::test]
    async fn zero_restart_budget_never_subscribes() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_c = Arc::clone(&calls);
        let settings = RestartSettings {
            max_restarts: Some(0),
            ..RestartSettings::default()
        };
        let source = RestartSource::with_backoff(settings, move || {
            calls_c.fetch_add(1, Ordering::SeqCst);
            crate::source::Source::from_iter(vec![1])
        });
        let out = Sink::collect(source).await;
        assert_eq!(out, Vec::<i32>::new());
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn default_settings_values() {
        let s = RestartSettings::default();
        assert_eq!(s.min_backoff, Duration::from_millis(100));
        assert_eq!(s.max_backoff, Duration::from_secs(30));
        assert_eq!(s.random_factor, 0.0);
        assert_eq!(s.max_restarts, Some(5));
    }
}