1use std::time::Duration;
5
6use futures::stream::StreamExt;
7
8use crate::source::Source;
9
/// Backoff and restart-limit configuration for [`RestartSource::with_backoff`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RestartSettings {
    /// Base delay from which the exponential backoff between restarts doubles.
    pub min_backoff: Duration,
    /// Upper bound on the exponentially growing delay; if it is smaller than
    /// `min_backoff`, `min_backoff` is used as the bound instead.
    pub max_backoff: Duration,
    /// Relative randomization factor for backoff delays (`0.0` = deterministic).
    /// See `RestartState::next_stream` for whether and how it is applied.
    pub random_factor: f64,
    /// Caps the total number of times the source factory is invoked, the initial
    /// subscription included; `None` allows unlimited restarts.
    pub max_restarts: Option<usize>,
}

impl Default for RestartSettings {
    /// 100 ms base backoff, 30 s backoff cap, `random_factor` of `0.0`, and at most
    /// five subscriptions in total.
    fn default() -> Self {
        Self {
            min_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(30),
            random_factor: 0.0,
            max_restarts: Some(5),
        }
    }
}
28
29pub struct RestartSource;
30
31impl RestartSource {
32 pub fn with_backoff<T, F>(settings: RestartSettings, factory: F) -> Source<T>
36 where
37 T: Send + 'static,
38 F: FnMut() -> Source<T> + Send + 'static,
39 {
40 let state = RestartState { factory, settings, attempts: 0 };
41 let s = futures::stream::unfold(
42 (state, None::<futures::stream::BoxStream<'static, T>>),
43 |(mut state, current)| async move {
44 let mut stream = match current {
46 Some(s) => s,
47 None => state.next_stream().await?,
48 };
49 if let Some(v) = stream.next().await {
50 Some((v, (state, Some(stream))))
51 } else {
52 let maybe_next = state.next_stream().await;
54 match maybe_next {
55 Some(mut s) => s.next().await.map(|v| (v, (state, Some(s)))),
56 None => None,
57 }
58 }
59 },
60 )
61 .boxed();
62 Source { inner: s }
63 }
64}
65
66struct RestartState<T, F>
67where
68 F: FnMut() -> Source<T> + Send + 'static,
69{
70 factory: F,
71 settings: RestartSettings,
72 attempts: usize,
73}
74
75impl<T, F> RestartState<T, F>
76where
77 T: Send + 'static,
78 F: FnMut() -> Source<T> + Send + 'static,
79{
80 async fn next_stream(&mut self) -> Option<futures::stream::BoxStream<'static, T>> {
81 if let Some(limit) = self.settings.max_restarts {
82 if self.attempts >= limit {
83 return None;
84 }
85 }
86 if self.attempts > 0 {
87 let base = self.settings.min_backoff.as_millis() as u64;
88 let cap = self.settings.max_backoff.as_millis() as u64;
89 let back = (base.saturating_mul(1 << self.attempts.min(20))).min(cap.max(base));
90 tokio::time::sleep(Duration::from_millis(back)).await;
91 }
92 self.attempts += 1;
93 let src = (self.factory)();
94 Some(src.into_boxed())
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// With `max_restarts: Some(3)`, the factory runs three times in total and each
    /// single-element inner source contributes exactly one item.
    #[tokio::test]
    async fn restart_source_resubscribes_until_max() {
        let factory_calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&factory_calls);
        let settings = RestartSettings {
            max_restarts: Some(3),
            random_factor: 0.0,
            max_backoff: Duration::from_millis(5),
            min_backoff: Duration::from_millis(1),
        };

        let restarting = RestartSource::with_backoff(settings, move || {
            counter.fetch_add(1, Ordering::SeqCst);
            crate::source::Source::from_iter(vec![1])
        });
        let collected = Sink::collect(restarting).await;

        assert_eq!(collected, vec![1, 1, 1]);
        assert_eq!(factory_calls.load(Ordering::SeqCst), 3);
    }
}