1use std::fmt::Display;
11
12use anyhow::Result;
13use backon::Retryable;
14use futures::{StreamExt, stream};
15
16use crate::{is_retryable_anyhow, retry_backoff};
17
/// What happened to a single item handled by a bulk operation.
#[derive(Clone, Debug)]
pub enum BulkOutcome<T> {
    /// The item was processed; holds the value produced for it.
    Success(T),
    /// The item was skipped; holds a textual reason.
    Skipped(String),
    /// Processing the item failed; holds the error rendered as text.
    Failed(String),
}
28
29#[derive(Debug, Clone)]
31pub struct BulkResult<I, T> {
32 pub succeeded: usize,
34 pub failed: usize,
36 pub skipped: usize,
38 pub outcomes: Vec<(I, BulkOutcome<T>)>,
40}
41
42impl<I, T> Default for BulkResult<I, T> {
43 fn default() -> Self {
44 Self {
45 succeeded: 0,
46 failed: 0,
47 skipped: 0,
48 outcomes: Vec::new(),
49 }
50 }
51}
52
53pub async fn process_bulk<I, D, T, F, Fut, P>(
117 items: Vec<(I, D)>,
118 processor: F,
119 progress_callback: P,
120) -> BulkResult<I, T>
121where
122 I: Clone + Display + Send + 'static,
123 D: Clone + Send + 'static,
124 T: Send + 'static,
125 F: Fn((I, D)) -> Fut + Send + Sync + 'static,
126 Fut: std::future::Future<Output = Result<Option<T>>> + Send,
127 P: Fn(usize, usize, &str) + Send + Sync + 'static,
128{
129 let total = items.len();
130 let progress_callback = std::sync::Arc::new(progress_callback);
131 let processor = std::sync::Arc::new(processor);
132
133 let mut tasks = Vec::new();
135 for (idx, (id, data)) in items.into_iter().enumerate() {
136 let id_clone = id.clone();
137 let data_clone = data.clone();
138 let progress_callback = progress_callback.clone();
139 let processor = processor.clone();
140
141 let task = async move {
142 progress_callback(idx + 1, total, &format!("Processing {id}"));
144
145 let id_for_retry = id_clone.clone();
147 let data_for_retry = data_clone.clone();
148 let result = (|| {
149 let processor = processor.clone();
150 let id = id_for_retry.clone();
151 let data = data_for_retry.clone();
152 async move { processor((id, data)).await }
153 })
154 .retry(retry_backoff())
155 .when(is_retryable_anyhow)
156 .notify(|err, dur| {
157 tracing::warn!(
158 error = %err,
159 delay_ms = dur.as_millis(),
160 item = %id_clone,
161 "Retrying after transient failure"
162 );
163 })
164 .await;
165
166 (id_clone, result)
167 };
168
169 tasks.push(task);
170 }
171
172 let outcomes = stream::iter(tasks)
173 .buffer_unordered(5)
174 .collect::<Vec<_>>()
175 .await;
176
177 let mut bulk_result = BulkResult::default();
179
180 for (id, result) in outcomes {
181 match result {
182 Ok(Some(value)) => {
183 bulk_result.succeeded += 1;
184 bulk_result.outcomes.push((id, BulkOutcome::Success(value)));
185 }
186 Ok(None) => {
187 bulk_result.skipped += 1;
188 bulk_result
189 .outcomes
190 .push((id, BulkOutcome::Skipped("Skipped".to_string())));
191 }
192 Err(e) => {
193 bulk_result.failed += 1;
194 bulk_result
195 .outcomes
196 .push((id, BulkOutcome::Failed(e.to_string())));
197 }
198 }
199 }
200
201 bulk_result
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Progress callback that discards every notification.
    fn ignore_progress(_current: usize, _total: usize, _action: &str) {}

    /// Every item yields a value, so all three are counted as succeeded.
    #[tokio::test]
    async fn test_successful_processing() {
        let items: Vec<(String, i32)> = (1..=3).map(|n| (format!("item{n}"), n)).collect();

        let summary = process_bulk(
            items,
            |(id, value)| async move { Ok(Some(format!("{id}: {}", value * 2))) },
            ignore_progress,
        )
        .await;

        assert_eq!(summary.succeeded, 3);
        assert_eq!(summary.failed, 0);
        assert_eq!(summary.skipped, 0);
        assert_eq!(summary.outcomes.len(), 3);
    }

    /// One success, one skip and one failure each land in their own counter.
    #[tokio::test]
    async fn test_mixed_outcomes() {
        let items = vec![
            (String::from("success"), 1),
            (String::from("skip"), 2),
            (String::from("fail"), 3),
        ];

        let summary = process_bulk(
            items,
            |(id, _value)| async move {
                match id.as_str() {
                    "success" => Ok(Some("done".to_string())),
                    "skip" => Ok(None),
                    "fail" => Err(anyhow::anyhow!("Processing failed")),
                    _ => unreachable!(),
                }
            },
            ignore_progress,
        )
        .await;

        assert_eq!(summary.succeeded, 1);
        assert_eq!(summary.failed, 1);
        assert_eq!(summary.skipped, 1);
        assert_eq!(summary.outcomes.len(), 3);
    }

    /// The callback is invoked once per item with its 1-based position and the total.
    #[tokio::test]
    async fn test_progress_callback_invocation() {
        use std::sync::{Arc, Mutex};

        let recorded = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&recorded);

        let _result = process_bulk(
            vec![("item1".to_string(), 1), ("item2".to_string(), 2)],
            |(_id, _value)| async move { Ok(Some("done".to_string())) },
            move |current, total, action| {
                let mut calls = sink.lock().unwrap();
                calls.push((current, total, action.to_string()));
            },
        )
        .await;

        let calls = recorded.lock().unwrap();
        assert_eq!(calls.len(), 2);

        let positions: Vec<(usize, usize)> = calls
            .iter()
            .map(|(current, total, _)| (*current, *total))
            .collect();
        assert_eq!(positions, [(1, 2), (2, 2)]);
    }
}