Skip to main content

atomr_streams/
restart.rs

1//! Restart combinators — re-run the inner graph on failure/completion.
2//! akka.net: `Dsl/RestartSource.cs`, `Dsl/RestartFlow.cs`, `Dsl/RestartSink.cs`.
3
4use std::time::Duration;
5
6use futures::stream::StreamExt;
7
8use crate::source::Source;
9
/// Backoff and restart-limit policy for the restart combinators in this module.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RestartSettings {
    /// Base delay of the exponential backoff applied before each restart.
    pub min_backoff: Duration,
    /// Upper bound on the backoff delay. If it is smaller than `min_backoff`,
    /// `min_backoff` is used as the bound instead.
    pub max_backoff: Duration,
    /// Jitter factor for the backoff delay, named after akka.net's `randomFactor`.
    /// NOTE(review): `RestartSource::with_backoff` does not currently read this
    /// field, so it has no effect yet.
    pub random_factor: f64,
    /// Maximum number of times the source factory is invoked, *counting the
    /// initial subscription*; `None` means unlimited. `Some(0)` yields a source
    /// that completes immediately without ever calling the factory.
    pub max_restarts: Option<usize>,
}

impl Default for RestartSettings {
    /// 100 ms minimum backoff, 30 s maximum backoff, no jitter, and at most five
    /// factory invocations.
    fn default() -> Self {
        let min_backoff = Duration::from_millis(100);
        let max_backoff = Duration::from_secs(30);
        RestartSettings {
            min_backoff,
            max_backoff,
            random_factor: 0.0,
            max_restarts: Some(5),
        }
    }
}
28
29pub struct RestartSource;
30
31impl RestartSource {
32    /// Re-subscribe to the source returned by the factory after it completes
33    /// (and every element it produced has been emitted). Mirrors
34    /// `RestartSource.WithBackoff` when combined with the built-in backoff.
35    pub fn with_backoff<T, F>(settings: RestartSettings, factory: F) -> Source<T>
36    where
37        T: Send + 'static,
38        F: FnMut() -> Source<T> + Send + 'static,
39    {
40        let state = RestartState { factory, settings, attempts: 0 };
41        let s = futures::stream::unfold(
42            (state, None::<futures::stream::BoxStream<'static, T>>),
43            |(mut state, current)| async move {
44                // Lazily open a subscription.
45                let mut stream = match current {
46                    Some(s) => s,
47                    None => state.next_stream().await?,
48                };
49                if let Some(v) = stream.next().await {
50                    Some((v, (state, Some(stream))))
51                } else {
52                    // Completed; check restart policy.
53                    let maybe_next = state.next_stream().await;
54                    match maybe_next {
55                        Some(mut s) => s.next().await.map(|v| (v, (state, Some(s)))),
56                        None => None,
57                    }
58                }
59            },
60        )
61        .boxed();
62        Source { inner: s }
63    }
64}
65
66struct RestartState<T, F>
67where
68    F: FnMut() -> Source<T> + Send + 'static,
69{
70    factory: F,
71    settings: RestartSettings,
72    attempts: usize,
73}
74
75impl<T, F> RestartState<T, F>
76where
77    T: Send + 'static,
78    F: FnMut() -> Source<T> + Send + 'static,
79{
80    async fn next_stream(&mut self) -> Option<futures::stream::BoxStream<'static, T>> {
81        if let Some(limit) = self.settings.max_restarts {
82            if self.attempts >= limit {
83                return None;
84            }
85        }
86        if self.attempts > 0 {
87            let base = self.settings.min_backoff.as_millis() as u64;
88            let cap = self.settings.max_backoff.as_millis() as u64;
89            let back = (base.saturating_mul(1 << self.attempts.min(20))).min(cap.max(base));
90            tokio::time::sleep(Duration::from_millis(back)).await;
91        }
92        self.attempts += 1;
93        let src = (self.factory)();
94        Some(src.into_boxed())
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// A factory yielding a one-element source is invoked exactly
    /// `max_restarts` times, and every emitted element reaches the sink.
    #[tokio::test]
    async fn restart_source_resubscribes_until_max() {
        let factory_calls = Arc::new(AtomicUsize::new(0));
        let settings = RestartSettings {
            min_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(5),
            random_factor: 0.0,
            max_restarts: Some(3),
        };

        let source = {
            let counter = Arc::clone(&factory_calls);
            RestartSource::with_backoff(settings, move || {
                counter.fetch_add(1, Ordering::SeqCst);
                crate::source::Source::from_iter(vec![1])
            })
        };

        let collected = Sink::collect(source).await;
        assert_eq!(collected, vec![1, 1, 1]);
        assert_eq!(factory_calls.load(Ordering::SeqCst), 3);
    }
}