1use crate::source::Source;
12use futures::stream::StreamExt;
13
14pub fn recover<T, E, F>(src: Source<Result<T, E>>, mut f: F) -> Source<T>
19where
20 T: Send + 'static,
21 E: Send + 'static,
22 F: FnMut(E) -> Option<T> + Send + 'static,
23{
24 let inner = src.into_boxed();
25 let mut errored = false;
26 let stream = inner
27 .take_while(move |item| {
28 let cont = !errored;
29 if item.is_err() {
30 errored = true;
31 }
32 futures::future::ready(cont)
33 })
34 .filter_map(move |item| {
35 futures::future::ready(match item {
36 Ok(v) => Some(v),
37 Err(e) => f(e),
38 })
39 });
40 Source { inner: stream.boxed() }
41}
42
43pub fn map_error<T, E1, E2, F>(src: Source<Result<T, E1>>, mut f: F) -> Source<Result<T, E2>>
47where
48 T: Send + 'static,
49 E1: Send + 'static,
50 E2: Send + 'static,
51 F: FnMut(E1) -> E2 + Send + 'static,
52{
53 let stream = src.into_boxed().map(move |item| match item {
54 Ok(v) => Ok(v),
55 Err(e) => Err(f(e)),
56 });
57 Source { inner: stream.boxed() }
58}
59
60pub fn recover_with<T, E>(src: Source<Result<T, E>>, replacement: Source<T>) -> Source<T>
67where
68 T: Send + 'static,
69 E: Send + 'static,
70{
71 use futures::stream;
72 let mut tripped = false;
73 let mut replacement_opt = Some(replacement);
74 let inner = src.into_boxed();
75 let stream = inner.flat_map(move |item| {
76 if tripped {
77 return stream::empty().boxed();
78 }
79 match item {
80 Ok(v) => stream::iter(std::iter::once(v)).boxed(),
81 Err(_) => {
82 tripped = true;
83 if let Some(rep) = replacement_opt.take() {
84 rep.into_boxed()
85 } else {
86 stream::empty().boxed()
87 }
88 }
89 }
90 });
91 Source { inner: stream.boxed() }
92}
93
94pub fn recover_with_retries<T, E, F>(
99 src: Source<Result<T, E>>,
100 max_attempts: usize,
101 mut replacement_factory: F,
102) -> Source<T>
103where
104 T: Send + 'static,
105 E: Send + 'static,
106 F: FnMut() -> Source<T> + Send + 'static,
107{
108 use futures::stream;
109 let mut attempts_left = max_attempts;
110 let mut tripped = false;
111 let inner = src.into_boxed();
112 let stream = inner.flat_map(move |item| {
113 if tripped {
114 return stream::empty().boxed();
115 }
116 match item {
117 Ok(v) => stream::iter(std::iter::once(v)).boxed(),
118 Err(_) if attempts_left > 0 => {
119 attempts_left -= 1;
120 replacement_factory().into_boxed()
121 }
122 Err(_) => {
123 tripped = true;
124 stream::empty().boxed()
125 }
126 }
127 });
128 Source { inner: stream.boxed() }
129}
130
131pub fn select_error<T, E1, E2, F>(src: Source<Result<T, E1>>, f: F) -> Source<Result<T, E2>>
134where
135 T: Send + 'static,
136 E1: Send + 'static,
137 E2: Send + 'static,
138 F: FnMut(E1) -> E2 + Send + 'static,
139{
140 map_error(src, f)
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;

    /// Upstream source type shared by every test below.
    type Fallible = Source<Result<i32, &'static str>>;

    #[tokio::test]
    async fn recover_replaces_error_with_value_and_terminates() {
        // The element after the first error (99) must not reach the output.
        let input: Fallible = Source::from_iter(vec![Ok(1), Ok(2), Err("oops"), Ok(99)]);
        let output = Sink::collect(recover(input, |_err| Some(0))).await;
        assert_eq!(output, vec![1, 2, 0]);
    }

    #[tokio::test]
    async fn recover_with_none_drops_error() {
        let input: Fallible = Source::from_iter(vec![Ok(1), Err("e"), Ok(2)]);
        let output = Sink::collect(recover(input, |_err| None)).await;
        assert_eq!(output, vec![1]);
    }

    #[tokio::test]
    async fn recover_passes_through_when_no_error() {
        let input: Fallible = Source::from_iter(vec![Ok(1), Ok(2), Ok(3)]);
        let output = Sink::collect(recover(input, |_err| Some(0))).await;
        assert_eq!(output, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn map_error_changes_error_type() {
        #[derive(Debug, PartialEq)]
        struct Wrapped(String);

        let input: Fallible = Source::from_iter(vec![Ok(1), Err("boom")]);
        let output = Sink::collect(map_error(input, |err| Wrapped(err.to_string()))).await;
        assert_eq!(output, vec![Ok(1), Err(Wrapped(String::from("boom")))]);
    }

    #[tokio::test]
    async fn recover_with_switches_to_replacement_on_error() {
        let input: Fallible = Source::from_iter(vec![Ok(1), Ok(2), Err("e"), Ok(99)]);
        let fallback: Source<i32> = Source::from_iter(vec![100, 200]);
        let output = Sink::collect(recover_with(input, fallback)).await;
        assert_eq!(output, vec![1, 2, 100, 200]);
    }

    #[tokio::test]
    async fn recover_with_passes_through_when_no_error() {
        let input: Fallible = Source::from_iter(vec![Ok(1), Ok(2)]);
        let fallback: Source<i32> = Source::from_iter(vec![100]);
        let output = Sink::collect(recover_with(input, fallback)).await;
        assert_eq!(output, vec![1, 2]);
    }

    #[tokio::test]
    async fn recover_with_retries_replays_factory_each_time() {
        let input: Fallible = Source::from_iter(vec![Ok(1), Err("e1"), Err("e2"), Ok(99)]);
        // Each factory call yields a one-element source: 10 on the first call, 20 on the second.
        let mut calls = 0;
        let output = Sink::collect(recover_with_retries(input, 2, move || {
            calls += 1;
            Source::from_iter(vec![calls * 10])
        }))
        .await;
        assert_eq!(output, vec![1, 10, 20, 99]);
    }

    #[tokio::test]
    async fn recover_with_retries_caps_at_max_attempts() {
        let input: Fallible = Source::from_iter(vec![Err("e1"), Err("e2"), Err("e3")]);
        let output =
            Sink::collect(recover_with_retries(input, 1, || Source::from_iter(vec![777]))).await;
        assert_eq!(output, vec![777]);
    }

    #[tokio::test]
    async fn select_error_alias_matches_map_error() {
        let input: Fallible = Source::from_iter(vec![Ok(1), Err("boom")]);
        let output = Sink::collect(select_error(input, |err| err.to_string())).await;
        assert_eq!(output, vec![Ok(1), Err("boom".to_string())]);
    }
}