atomr_streams/
recovery.rs1use crate::source::Source;
13use futures::stream::StreamExt;
14
15pub fn recover<T, E, F>(src: Source<Result<T, E>>, mut f: F) -> Source<T>
21where
22 T: Send + 'static,
23 E: Send + 'static,
24 F: FnMut(E) -> Option<T> + Send + 'static,
25{
26 let inner = src.into_boxed();
27 let mut errored = false;
28 let stream = inner
29 .take_while(move |item| {
30 let cont = !errored;
31 if item.is_err() {
32 errored = true;
33 }
34 futures::future::ready(cont)
35 })
36 .filter_map(move |item| {
37 futures::future::ready(match item {
38 Ok(v) => Some(v),
39 Err(e) => f(e),
40 })
41 });
42 Source { inner: stream.boxed() }
43}
44
45pub fn map_error<T, E1, E2, F>(src: Source<Result<T, E1>>, mut f: F) -> Source<Result<T, E2>>
50where
51 T: Send + 'static,
52 E1: Send + 'static,
53 E2: Send + 'static,
54 F: FnMut(E1) -> E2 + Send + 'static,
55{
56 let stream = src.into_boxed().map(move |item| match item {
57 Ok(v) => Ok(v),
58 Err(e) => Err(f(e)),
59 });
60 Source { inner: stream.boxed() }
61}
62
63pub fn recover_with<T, E>(src: Source<Result<T, E>>, replacement: Source<T>) -> Source<T>
70where
71 T: Send + 'static,
72 E: Send + 'static,
73{
74 use futures::stream;
75 let mut tripped = false;
76 let mut replacement_opt = Some(replacement);
77 let inner = src.into_boxed();
78 let stream = inner.flat_map(move |item| {
79 if tripped {
80 return stream::empty().boxed();
81 }
82 match item {
83 Ok(v) => stream::iter(std::iter::once(v)).boxed(),
84 Err(_) => {
85 tripped = true;
86 if let Some(rep) = replacement_opt.take() {
87 rep.into_boxed()
88 } else {
89 stream::empty().boxed()
90 }
91 }
92 }
93 });
94 Source { inner: stream.boxed() }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;

    // The recovery value replaces the error and nothing after the error is emitted.
    #[tokio::test]
    async fn recover_replaces_error_with_value_and_terminates() {
        let input: Source<Result<i32, &'static str>> =
            Source::from_iter(vec![Ok(1), Ok(2), Err("oops"), Ok(99)]);
        let output = Sink::collect(recover(input, |_e| Some(0))).await;
        assert_eq!(output, vec![1, 2, 0]);
    }

    // A handler returning `None` ends the stream without emitting a substitute.
    #[tokio::test]
    async fn recover_with_none_drops_error() {
        let input: Source<Result<i32, &'static str>> =
            Source::from_iter(vec![Ok(1), Err("e"), Ok(2)]);
        let output = Sink::collect(recover(input, |_| None)).await;
        assert_eq!(output, vec![1]);
    }

    // Without any error the handler is irrelevant and every value comes through.
    #[tokio::test]
    async fn recover_passes_through_when_no_error() {
        let input: Source<Result<i32, &'static str>> =
            Source::from_iter(vec![Ok(1), Ok(2), Ok(3)]);
        let output = Sink::collect(recover(input, |_| Some(0))).await;
        assert_eq!(output, vec![1, 2, 3]);
    }

    // Errors are converted to the new type; `Ok` values are left alone.
    #[tokio::test]
    async fn map_error_changes_error_type() {
        #[derive(Debug, PartialEq)]
        struct Wrapped(String);

        let input: Source<Result<i32, &'static str>> =
            Source::from_iter(vec![Ok(1), Err("boom")]);
        let output = Sink::collect(map_error(input, |e| Wrapped(e.to_string()))).await;
        assert_eq!(output, vec![Ok(1), Err(Wrapped("boom".into()))]);
    }

    // After the first error the replacement takes over and the rest of the primary is ignored.
    #[tokio::test]
    async fn recover_with_switches_to_replacement_on_error() {
        let primary: Source<Result<i32, &'static str>> =
            Source::from_iter(vec![Ok(1), Ok(2), Err("e"), Ok(99)]);
        let backup: Source<i32> = Source::from_iter(vec![100, 200]);
        let output = Sink::collect(recover_with(primary, backup)).await;
        assert_eq!(output, vec![1, 2, 100, 200]);
    }

    // With an error-free primary the replacement is never used.
    #[tokio::test]
    async fn recover_with_passes_through_when_no_error() {
        let primary: Source<Result<i32, &'static str>> = Source::from_iter(vec![Ok(1), Ok(2)]);
        let backup: Source<i32> = Source::from_iter(vec![100]);
        let output = Sink::collect(recover_with(primary, backup)).await;
        assert_eq!(output, vec![1, 2]);
    }
}