1use std::time::Duration;
4
5use futures::stream::StreamExt;
6
7use crate::source::Source;
8
/// Configuration for a source that is re-created whenever it completes.
#[derive(Clone, Copy, Debug)]
pub struct RestartSettings {
    /// Base delay of the backoff applied between consecutive restarts.
    pub min_backoff: Duration,
    /// Upper bound for the delay between restarts (treated as at least `min_backoff`).
    pub max_backoff: Duration,
    /// Jitter factor for the backoff delay.
    ///
    /// NOTE(review): nothing visible in this file reads this field; confirm whether
    /// jitter is applied elsewhere or still has to be implemented.
    pub random_factor: f64,
    /// Maximum number of times the source factory may be invoked (the initial start
    /// counts too); `None` lifts the limit.
    pub max_restarts: Option<usize>,
}
16
17impl Default for RestartSettings {
18 fn default() -> Self {
19 Self {
20 min_backoff: Duration::from_millis(100),
21 max_backoff: Duration::from_secs(30),
22 random_factor: 0.0,
23 max_restarts: Some(5),
24 }
25 }
26}
27
28pub struct RestartSource;
29
30impl RestartSource {
31 pub fn with_backoff<T, F>(settings: RestartSettings, factory: F) -> Source<T>
35 where
36 T: Send + 'static,
37 F: FnMut() -> Source<T> + Send + 'static,
38 {
39 let state = RestartState { factory, settings, attempts: 0 };
40 let s = futures::stream::unfold(
41 (state, None::<futures::stream::BoxStream<'static, T>>),
42 |(mut state, current)| async move {
43 let mut stream = match current {
45 Some(s) => s,
46 None => state.next_stream().await?,
47 };
48 if let Some(v) = stream.next().await {
49 Some((v, (state, Some(stream))))
50 } else {
51 let maybe_next = state.next_stream().await;
53 match maybe_next {
54 Some(mut s) => s.next().await.map(|v| (v, (state, Some(s)))),
55 None => None,
56 }
57 }
58 },
59 )
60 .boxed();
61 Source { inner: s }
62 }
63}
64
65struct RestartState<T, F>
66where
67 F: FnMut() -> Source<T> + Send + 'static,
68{
69 factory: F,
70 settings: RestartSettings,
71 attempts: usize,
72}
73
74impl<T, F> RestartState<T, F>
75where
76 T: Send + 'static,
77 F: FnMut() -> Source<T> + Send + 'static,
78{
79 async fn next_stream(&mut self) -> Option<futures::stream::BoxStream<'static, T>> {
80 if let Some(limit) = self.settings.max_restarts {
81 if self.attempts >= limit {
82 return None;
83 }
84 }
85 if self.attempts > 0 {
86 let base = self.settings.min_backoff.as_millis() as u64;
87 let cap = self.settings.max_backoff.as_millis() as u64;
88 let back = (base.saturating_mul(1 << self.attempts.min(20))).min(cap.max(base));
89 tokio::time::sleep(Duration::from_millis(back)).await;
90 }
91 self.attempts += 1;
92 let src = (self.factory)();
93 Some(src.into_boxed())
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// A single-element upstream with a budget of three starts must yield three
    /// elements and invoke the factory exactly three times.
    #[tokio::test]
    async fn restart_source_resubscribes_until_max() {
        let settings = RestartSettings {
            min_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(5),
            random_factor: 0.0,
            max_restarts: Some(3),
        };
        let invocations = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&invocations);
        let factory = move || {
            counter.fetch_add(1, Ordering::SeqCst);
            crate::source::Source::from_iter(vec![1])
        };

        let collected = Sink::collect(RestartSource::with_backoff(settings, factory)).await;

        assert_eq!(collected, vec![1, 1, 1]);
        assert_eq!(invocations.load(Ordering::SeqCst), 3);
    }
}