Skip to main content

atomr_streams/
recovery.rs

//! Recovery operators on `Source<Result<T, E>>`.
//!
//! Operators: `Recover`, `RecoverWith`, `RecoverWithRetries`, `MapError`
//! (`select_error` is provided as an alias of `map_error`).
//!
//! These operators are exposed as free functions that take a
//! `Source<Result<T, E>>` and return a transformed `Source`. They
//! sit alongside the linear-operator surface on `Source<T>` rather
//! than directly on `Source` because Result-shaped sources need
//! their own combinator semantics.
10
11use crate::source::Source;
12use futures::stream::StreamExt;
13
14/// Replace any `Err(e)` with `f(e)` and continue, mapping the stream
15/// to `T` (success values are unwrapped). The first error stops the
16/// upstream — subsequent elements are dropped.
17///
18pub fn recover<T, E, F>(src: Source<Result<T, E>>, mut f: F) -> Source<T>
19where
20    T: Send + 'static,
21    E: Send + 'static,
22    F: FnMut(E) -> Option<T> + Send + 'static,
23{
24    let inner = src.into_boxed();
25    let mut errored = false;
26    let stream = inner
27        .take_while(move |item| {
28            let cont = !errored;
29            if item.is_err() {
30                errored = true;
31            }
32            futures::future::ready(cont)
33        })
34        .filter_map(move |item| {
35            futures::future::ready(match item {
36                Ok(v) => Some(v),
37                Err(e) => f(e),
38            })
39        });
40    Source { inner: stream.boxed() }
41}
42
43/// Map the error variant via `f`. Both `Ok` and `Err` continue
44/// downstream; only the `Err` payload type changes.
45///
46pub fn map_error<T, E1, E2, F>(src: Source<Result<T, E1>>, mut f: F) -> Source<Result<T, E2>>
47where
48    T: Send + 'static,
49    E1: Send + 'static,
50    E2: Send + 'static,
51    F: FnMut(E1) -> E2 + Send + 'static,
52{
53    let stream = src.into_boxed().map(move |item| match item {
54        Ok(v) => Ok(v),
55        Err(e) => Err(f(e)),
56    });
57    Source { inner: stream.boxed() }
58}
59
60/// Replace the upstream's tail with `replacement` upon the first
61/// `Err(_)`. Pre-error `Ok(_)` values flow through unchanged.
62///
63/// with
64/// `maxAttempts = 1` (multi-attempt retry waits on the
65/// `RestartSource` machinery — Phase 12 follow-on).
66pub fn recover_with<T, E>(src: Source<Result<T, E>>, replacement: Source<T>) -> Source<T>
67where
68    T: Send + 'static,
69    E: Send + 'static,
70{
71    use futures::stream;
72    let mut tripped = false;
73    let mut replacement_opt = Some(replacement);
74    let inner = src.into_boxed();
75    let stream = inner.flat_map(move |item| {
76        if tripped {
77            return stream::empty().boxed();
78        }
79        match item {
80            Ok(v) => stream::iter(std::iter::once(v)).boxed(),
81            Err(_) => {
82                tripped = true;
83                if let Some(rep) = replacement_opt.take() {
84                    rep.into_boxed()
85                } else {
86                    stream::empty().boxed()
87                }
88            }
89        }
90    });
91    Source { inner: stream.boxed() }
92}
93
94/// Replace the upstream's tail with `replacement_factory()` on each
95/// error, capped at `max_attempts` total replacements. After
96/// `max_attempts`, subsequent errors propagate as terminations.
97///
98pub fn recover_with_retries<T, E, F>(
99    src: Source<Result<T, E>>,
100    max_attempts: usize,
101    mut replacement_factory: F,
102) -> Source<T>
103where
104    T: Send + 'static,
105    E: Send + 'static,
106    F: FnMut() -> Source<T> + Send + 'static,
107{
108    use futures::stream;
109    let mut attempts_left = max_attempts;
110    let mut tripped = false;
111    let inner = src.into_boxed();
112    let stream = inner.flat_map(move |item| {
113        if tripped {
114            return stream::empty().boxed();
115        }
116        match item {
117            Ok(v) => stream::iter(std::iter::once(v)).boxed(),
118            Err(_) if attempts_left > 0 => {
119                attempts_left -= 1;
120                replacement_factory().into_boxed()
121            }
122            Err(_) => {
123                tripped = true;
124                stream::empty().boxed()
125            }
126        }
127    });
128    Source { inner: stream.boxed() }
129}
130
131/// Alias for [`map_error`] matching
132/// naming. Keeping both names makes porting tests verbatim possible.
133pub fn select_error<T, E1, E2, F>(src: Source<Result<T, E1>>, f: F) -> Source<Result<T, E2>>
134where
135    T: Send + 'static,
136    E1: Send + 'static,
137    E2: Send + 'static,
138    F: FnMut(E1) -> E2 + Send + 'static,
139{
140    map_error(src, f)
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;

    /// The `Result`-shaped source type every test starts from.
    type Fallible = Source<Result<i32, &'static str>>;

    #[tokio::test]
    async fn recover_replaces_error_with_value_and_terminates() {
        let upstream: Fallible = Source::from_iter(vec![Ok(1), Ok(2), Err("oops"), Ok(99)]);
        let out = Sink::collect(recover(upstream, |_err| Some(0))).await;
        // The fallback value is the last element; Ok(99) never arrives.
        assert_eq!(out, vec![1, 2, 0]);
    }

    #[tokio::test]
    async fn recover_with_none_drops_error() {
        let upstream: Fallible = Source::from_iter(vec![Ok(1), Err("e"), Ok(2)]);
        let out = Sink::collect(recover(upstream, |_| None)).await;
        assert_eq!(out, vec![1]);
    }

    #[tokio::test]
    async fn recover_passes_through_when_no_error() {
        let upstream: Fallible = Source::from_iter(vec![Ok(1), Ok(2), Ok(3)]);
        let out = Sink::collect(recover(upstream, |_| Some(0))).await;
        assert_eq!(out, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn map_error_changes_error_type() {
        #[derive(Debug, PartialEq)]
        struct Wrapped(String);
        let upstream: Fallible = Source::from_iter(vec![Ok(1), Err("boom")]);
        let out = Sink::collect(map_error(upstream, |msg| Wrapped(msg.to_string()))).await;
        assert_eq!(out, vec![Ok(1), Err(Wrapped("boom".into()))]);
    }

    #[tokio::test]
    async fn recover_with_switches_to_replacement_on_error() {
        let upstream: Fallible = Source::from_iter(vec![Ok(1), Ok(2), Err("e"), Ok(99)]);
        let fallback: Source<i32> = Source::from_iter(vec![100, 200]);
        let out = Sink::collect(recover_with(upstream, fallback)).await;
        assert_eq!(out, vec![1, 2, 100, 200]);
    }

    #[tokio::test]
    async fn recover_with_passes_through_when_no_error() {
        let upstream: Fallible = Source::from_iter(vec![Ok(1), Ok(2)]);
        let fallback: Source<i32> = Source::from_iter(vec![100]);
        let out = Sink::collect(recover_with(upstream, fallback)).await;
        assert_eq!(out, vec![1, 2]);
    }

    #[tokio::test]
    async fn recover_with_retries_replays_factory_each_time() {
        let upstream: Fallible = Source::from_iter(vec![Ok(1), Err("e1"), Err("e2"), Ok(99)]);
        let mut calls = 0;
        let recovered = recover_with_retries(upstream, 2, move || {
            calls += 1;
            Source::from_iter(vec![calls * 10])
        });
        let out = Sink::collect(recovered).await;
        // 1, then the first error is spliced out for 10, the second for
        // 20; both attempts are used, and the following Ok(99) still
        // flows because the replacements are inserted rather than
        // replacing the upstream's tail.
        assert_eq!(out, vec![1, 10, 20, 99]);
    }

    #[tokio::test]
    async fn recover_with_retries_caps_at_max_attempts() {
        let upstream: Fallible = Source::from_iter(vec![Err("e1"), Err("e2"), Err("e3")]);
        let recovered = recover_with_retries(upstream, 1, || Source::from_iter(vec![777]));
        let out = Sink::collect(recovered).await;
        // The single attempt is spent on the first error (777 emitted);
        // the second error ends the stream.
        assert_eq!(out, vec![777]);
    }

    #[tokio::test]
    async fn select_error_alias_matches_map_error() {
        let upstream: Fallible = Source::from_iter(vec![Ok(1), Err("boom")]);
        let out = Sink::collect(select_error(upstream, |msg| msg.to_string())).await;
        assert_eq!(out, vec![Ok(1), Err("boom".to_string())]);
    }
}