Skip to main content

atomr_streams/
recovery.rs

//! Recovery operators on `Source<Result<T, E>>`.
//!
//! Phase 12.7 of `docs/full-port-plan.md`. Akka.NET / Akka Streams
//! parity: `Recover`, `RecoverWith`, `RecoverWithRetries` (single
//! attempt only), `MapError`.
//!
//! These operators are exposed as free functions that take a
//! `Source<Result<T, E>>` and return a transformed `Source`. They
//! sit alongside the linear-operator surface on `Source<T>` rather
//! than directly on `Source` because Result-shaped sources need
//! their own combinator semantics.
11
12use crate::source::Source;
13use futures::stream::StreamExt;
14
15/// Replace any `Err(e)` with `f(e)` and continue, mapping the stream
16/// to `T` (success values are unwrapped). The first error stops the
17/// upstream — subsequent elements are dropped.
18///
19/// Akka.NET: `Source.Recover<T>(Func<Exception, Option<T>>)`.
20pub fn recover<T, E, F>(src: Source<Result<T, E>>, mut f: F) -> Source<T>
21where
22    T: Send + 'static,
23    E: Send + 'static,
24    F: FnMut(E) -> Option<T> + Send + 'static,
25{
26    let inner = src.into_boxed();
27    let mut errored = false;
28    let stream = inner
29        .take_while(move |item| {
30            let cont = !errored;
31            if item.is_err() {
32                errored = true;
33            }
34            futures::future::ready(cont)
35        })
36        .filter_map(move |item| {
37            futures::future::ready(match item {
38                Ok(v) => Some(v),
39                Err(e) => f(e),
40            })
41        });
42    Source { inner: stream.boxed() }
43}
44
45/// Map the error variant via `f`. Both `Ok` and `Err` continue
46/// downstream; only the `Err` payload type changes.
47///
48/// Akka.NET: `Source.SelectError`.
49pub fn map_error<T, E1, E2, F>(src: Source<Result<T, E1>>, mut f: F) -> Source<Result<T, E2>>
50where
51    T: Send + 'static,
52    E1: Send + 'static,
53    E2: Send + 'static,
54    F: FnMut(E1) -> E2 + Send + 'static,
55{
56    let stream = src.into_boxed().map(move |item| match item {
57        Ok(v) => Ok(v),
58        Err(e) => Err(f(e)),
59    });
60    Source { inner: stream.boxed() }
61}
62
63/// Replace the upstream's tail with `replacement` upon the first
64/// `Err(_)`. Pre-error `Ok(_)` values flow through unchanged.
65///
66/// Akka.NET: `Source.RecoverWithRetries(maxAttempts, …)` with
67/// `maxAttempts = 1` (multi-attempt retry waits on the
68/// `RestartSource` machinery — Phase 12 follow-on).
69pub fn recover_with<T, E>(src: Source<Result<T, E>>, replacement: Source<T>) -> Source<T>
70where
71    T: Send + 'static,
72    E: Send + 'static,
73{
74    use futures::stream;
75    let mut tripped = false;
76    let mut replacement_opt = Some(replacement);
77    let inner = src.into_boxed();
78    let stream = inner.flat_map(move |item| {
79        if tripped {
80            return stream::empty().boxed();
81        }
82        match item {
83            Ok(v) => stream::iter(std::iter::once(v)).boxed(),
84            Err(_) => {
85                tripped = true;
86                if let Some(rep) = replacement_opt.take() {
87                    rep.into_boxed()
88                } else {
89                    stream::empty().boxed()
90                }
91            }
92        }
93    });
94    Source { inner: stream.boxed() }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;

    #[tokio::test]
    async fn recover_replaces_error_with_value_and_terminates() {
        let s: Source<Result<i32, &'static str>> = Source::from_iter(vec![Ok(1), Ok(2), Err("oops"), Ok(99)]);
        let recovered = recover(s, |_e| Some(0));
        let collected = Sink::collect(recovered).await;
        assert_eq!(collected, vec![1, 2, 0]);
    }

    #[tokio::test]
    async fn recover_with_none_drops_error() {
        let s: Source<Result<i32, &'static str>> = Source::from_iter(vec![Ok(1), Err("e"), Ok(2)]);
        let recovered = recover(s, |_| None);
        let collected = Sink::collect(recovered).await;
        assert_eq!(collected, vec![1]);
    }

    #[tokio::test]
    async fn recover_passes_through_when_no_error() {
        let s: Source<Result<i32, &'static str>> = Source::from_iter(vec![Ok(1), Ok(2), Ok(3)]);
        let recovered = recover(s, |_| Some(0));
        let collected = Sink::collect(recovered).await;
        assert_eq!(collected, vec![1, 2, 3]);
    }

    // An error as the very first element still yields the handler's value
    // and nothing from the tail.
    #[tokio::test]
    async fn recover_handles_error_as_first_element() {
        let s: Source<Result<i32, &'static str>> = Source::from_iter(vec![Err("e"), Ok(1)]);
        let recovered = recover(s, |_| Some(7));
        let collected = Sink::collect(recovered).await;
        assert_eq!(collected, vec![7]);
    }

    #[tokio::test]
    async fn recover_on_empty_source_is_empty() {
        let s: Source<Result<i32, &'static str>> =
            Source::from_iter(Vec::<Result<i32, &'static str>>::new());
        let recovered = recover(s, |_| Some(0));
        let collected = Sink::collect(recovered).await;
        assert_eq!(collected, Vec::<i32>::new());
    }

    #[tokio::test]
    async fn map_error_changes_error_type() {
        #[derive(Debug, PartialEq)]
        struct Wrapped(String);
        let s: Source<Result<i32, &'static str>> = Source::from_iter(vec![Ok(1), Err("boom")]);
        let mapped = map_error(s, |e| Wrapped(e.to_string()));
        let collected = Sink::collect(mapped).await;
        assert_eq!(collected.len(), 2);
        assert_eq!(collected[0], Ok(1));
        assert_eq!(collected[1], Err(Wrapped("boom".into())));
    }

    // Unlike `recover`, `map_error` keeps forwarding elements after an error.
    #[tokio::test]
    async fn map_error_keeps_streaming_after_errors() {
        let s: Source<Result<i32, &'static str>> = Source::from_iter(vec![Err("a"), Ok(1), Err("b")]);
        let mapped = map_error(s, |e| e.len());
        let collected = Sink::collect(mapped).await;
        assert_eq!(collected, vec![Err(1), Ok(1), Err(1)]);
    }

    #[tokio::test]
    async fn recover_with_switches_to_replacement_on_error() {
        let s: Source<Result<i32, &'static str>> = Source::from_iter(vec![Ok(1), Ok(2), Err("e"), Ok(99)]);
        let replacement: Source<i32> = Source::from_iter(vec![100, 200]);
        let recovered = recover_with(s, replacement);
        let collected = Sink::collect(recovered).await;
        assert_eq!(collected, vec![1, 2, 100, 200]);
    }

    #[tokio::test]
    async fn recover_with_passes_through_when_no_error() {
        let s: Source<Result<i32, &'static str>> = Source::from_iter(vec![Ok(1), Ok(2)]);
        let replacement: Source<i32> = Source::from_iter(vec![100]);
        let recovered = recover_with(s, replacement);
        let collected = Sink::collect(recovered).await;
        assert_eq!(collected, vec![1, 2]);
    }

    #[tokio::test]
    async fn recover_with_error_first_yields_only_replacement() {
        let s: Source<Result<i32, &'static str>> = Source::from_iter(vec![Err("e"), Ok(1)]);
        let replacement: Source<i32> = Source::from_iter(vec![5]);
        let recovered = recover_with(s, replacement);
        let collected = Sink::collect(recovered).await;
        assert_eq!(collected, vec![5]);
    }

    #[tokio::test]
    async fn recover_with_empty_replacement_stops_at_error() {
        let s: Source<Result<i32, &'static str>> = Source::from_iter(vec![Ok(1), Err("e"), Ok(2)]);
        let replacement: Source<i32> = Source::from_iter(Vec::<i32>::new());
        let recovered = recover_with(s, replacement);
        let collected = Sink::collect(recovered).await;
        assert_eq!(collected, vec![1]);
    }
}