1use std::fmt::Display;
11
12use anyhow::Result;
13#[cfg(not(target_arch = "wasm32"))]
14use backon::Retryable;
15use futures::{StreamExt, stream};
16
17#[cfg(not(target_arch = "wasm32"))]
18use crate::retry::{is_retryable_anyhow, retry_backoff};
19
/// The result of processing a single item in [`process_bulk`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BulkOutcome<T> {
    /// The processor returned `Ok(Some(value))`; holds that value.
    Success(T),
    /// The processor returned `Ok(None)`; holds a reason string.
    Skipped(String),
    /// The processor returned an error; holds the error's `to_string()` text.
    Failed(String),
}
30
31#[derive(Debug, Clone)]
33pub struct BulkResult<I, T> {
34 pub succeeded: usize,
36 pub failed: usize,
38 pub skipped: usize,
40 pub outcomes: Vec<(I, BulkOutcome<T>)>,
42}
43
44impl<I, T> Default for BulkResult<I, T> {
45 fn default() -> Self {
46 Self {
47 succeeded: 0,
48 failed: 0,
49 skipped: 0,
50 outcomes: Vec::new(),
51 }
52 }
53}
54
55pub async fn process_bulk<I, D, T, F, Fut, P>(
119 items: Vec<(I, D)>,
120 processor: F,
121 progress_callback: P,
122) -> BulkResult<I, T>
123where
124 I: Clone + Display + Send + 'static,
125 D: Clone + Send + 'static,
126 T: Send + 'static,
127 F: Fn((I, D)) -> Fut + Send + Sync + 'static,
128 Fut: std::future::Future<Output = Result<Option<T>>> + Send,
129 P: Fn(usize, usize, &str) + Send + Sync + 'static,
130{
131 let total = items.len();
132 let progress_callback = std::sync::Arc::new(progress_callback);
133 let processor = std::sync::Arc::new(processor);
134
135 let mut tasks = Vec::new();
137 for (idx, (id, data)) in items.into_iter().enumerate() {
138 let id_clone = id.clone();
139 let data_clone = data.clone();
140 let progress_callback = progress_callback.clone();
141 let processor = processor.clone();
142
143 let task = async move {
144 progress_callback(idx + 1, total, &format!("Processing {id}"));
146
147 let id_for_retry = id_clone.clone();
149 let data_for_retry = data_clone.clone();
150 #[cfg(not(target_arch = "wasm32"))]
151 let result = (|| {
152 let processor = processor.clone();
153 let id = id_for_retry.clone();
154 let data = data_for_retry.clone();
155 async move { processor((id, data)).await }
156 })
157 .retry(retry_backoff())
158 .when(is_retryable_anyhow)
159 .notify(|err, dur| {
160 tracing::warn!(
161 error = %err,
162 delay_ms = dur.as_millis(),
163 item = %id_clone,
164 "Retrying after transient failure"
165 );
166 })
167 .await;
168 #[cfg(target_arch = "wasm32")]
169 let result = processor((id_for_retry, data_for_retry)).await;
170
171 (id_clone, result)
172 };
173
174 tasks.push(task);
175 }
176
177 let outcomes = stream::iter(tasks)
178 .buffer_unordered(5)
179 .collect::<Vec<_>>()
180 .await;
181
182 let mut bulk_result = BulkResult::default();
184
185 for (id, result) in outcomes {
186 match result {
187 Ok(Some(value)) => {
188 bulk_result.succeeded += 1;
189 bulk_result.outcomes.push((id, BulkOutcome::Success(value)));
190 }
191 Ok(None) => {
192 bulk_result.skipped += 1;
193 bulk_result
194 .outcomes
195 .push((id, BulkOutcome::Skipped("Skipped".to_string())));
196 }
197 Err(e) => {
198 bulk_result.failed += 1;
199 bulk_result
200 .outcomes
201 .push((id, BulkOutcome::Failed(e.to_string())));
202 }
203 }
204 }
205
206 bulk_result
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the outcome recorded for `id`, panicking if there is none.
    ///
    /// Looks items up by identifier so the assertions do not depend on the
    /// order in which outcomes are reported.
    fn outcome_for<'a, T>(result: &'a BulkResult<String, T>, id: &str) -> &'a BulkOutcome<T> {
        result
            .outcomes
            .iter()
            .find(|(item_id, _)| item_id == id)
            .map(|(_, outcome)| outcome)
            .unwrap_or_else(|| panic!("no outcome recorded for {id}"))
    }

    #[tokio::test]
    async fn test_successful_processing() {
        let items = vec![
            ("item1".to_string(), 1),
            ("item2".to_string(), 2),
            ("item3".to_string(), 3),
        ];

        let result = process_bulk(
            items,
            |(id, value)| async move { Ok(Some(format!("{}: {}", id, value * 2))) },
            |_current, _total, _action| {},
        )
        .await;

        assert_eq!(result.succeeded, 3);
        assert_eq!(result.failed, 0);
        assert_eq!(result.skipped, 0);
        assert_eq!(result.outcomes.len(), 3);
        // Each value must be attributed to the item that produced it.
        assert!(matches!(
            outcome_for(&result, "item2"),
            BulkOutcome::Success(value) if value == "item2: 4"
        ));
    }

    #[tokio::test]
    async fn test_mixed_outcomes() {
        let items = vec![
            ("success".to_string(), 1),
            ("skip".to_string(), 2),
            ("fail".to_string(), 3),
        ];

        let result = process_bulk(
            items,
            |(id, _value)| async move {
                match id.as_str() {
                    "success" => Ok(Some("done".to_string())),
                    "skip" => Ok(None),
                    "fail" => Err(anyhow::anyhow!("Processing failed")),
                    _ => unreachable!(),
                }
            },
            |_current, _total, _action| {},
        )
        .await;

        assert_eq!(result.succeeded, 1);
        assert_eq!(result.failed, 1);
        assert_eq!(result.skipped, 1);
        assert_eq!(result.outcomes.len(), 3);
        // Each outcome kind must be attached to the right item, and the
        // processor's error message must be propagated.
        assert!(matches!(
            outcome_for(&result, "success"),
            BulkOutcome::Success(value) if value == "done"
        ));
        assert!(matches!(
            outcome_for(&result, "skip"),
            BulkOutcome::Skipped(_)
        ));
        assert!(matches!(
            outcome_for(&result, "fail"),
            BulkOutcome::Failed(message) if message.contains("Processing failed")
        ));
    }

    #[tokio::test]
    async fn test_progress_callback_invocation() {
        use std::sync::{Arc, Mutex};

        let items = vec![("item1".to_string(), 1), ("item2".to_string(), 2)];

        let progress_calls = Arc::new(Mutex::new(Vec::new()));
        let progress_calls_clone = progress_calls.clone();

        let _result = process_bulk(
            items,
            |(_id, _value)| async move { Ok(Some("done".to_string())) },
            move |current, total, action| {
                progress_calls_clone
                    .lock()
                    .unwrap()
                    .push((current, total, action.to_string()));
            },
        )
        .await;

        // Verifies positions, totals, and the progress message text.
        let calls = progress_calls.lock().unwrap();
        assert_eq!(
            *calls,
            vec![
                (1, 2, "Processing item1".to_string()),
                (2, 2, "Processing item2".to_string()),
            ]
        );
    }
}