1use anyhow::Result;
24use async_trait::async_trait;
25use std::sync::Arc;
26use tokio::sync::mpsc;
27use tokio::task::JoinSet;
28
29#[async_trait]
39pub trait UploadHandler: Send + Sync + 'static {
40 type Id: Clone + Send + std::fmt::Debug + 'static;
43
44 async fn upload(&self, id: Self::Id) -> Result<u64>;
49
50 fn pending_items(&self) -> Vec<Self::Id>;
53
54 fn name(&self) -> &str;
56}
57
58pub enum UploadMessage<Id> {
60 Upload(Id),
62 UploadWithAck(Id, tokio::sync::oneshot::Sender<anyhow::Result<()>>),
65 Shutdown,
67}
68
69impl<Id: std::fmt::Debug> std::fmt::Debug for UploadMessage<Id> {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 match self {
72 Self::Upload(id) => f.debug_tuple("Upload").field(id).finish(),
73 Self::UploadWithAck(id, _) => f.debug_tuple("UploadWithAck").field(id).finish(),
74 Self::Shutdown => write!(f, "Shutdown"),
75 }
76 }
77}
78
/// Running totals kept by [`ConcurrentUploader`].
///
/// Snapshots are returned by [`ConcurrentUploader::stats`] and
/// [`ConcurrentUploader::run`].
#[derive(Clone, Debug, Default)]
pub struct UploaderStats {
    /// Upload tasks that have finished: succeeded, failed, or panicked.
    pub uploads_attempted: u64,
    /// Uploads whose handler returned `Ok`.
    pub uploads_succeeded: u64,
    /// Uploads whose handler returned `Err`, or whose task panicked.
    pub uploads_failed: u64,
    /// Sum of the byte counts returned by successful uploads.
    pub bytes_uploaded: u64,
}
87
88pub struct ConcurrentUploader<H: UploadHandler> {
93 handler: Arc<H>,
94 max_concurrent: usize,
95 stats: Arc<tokio::sync::Mutex<UploaderStats>>,
96}
97
98impl<H: UploadHandler> ConcurrentUploader<H> {
99 pub fn new(handler: Arc<H>, max_concurrent: usize) -> Self {
101 Self {
102 handler,
103 max_concurrent: max_concurrent.max(1),
104 stats: Arc::new(tokio::sync::Mutex::new(UploaderStats::default())),
105 }
106 }
107
108 pub async fn run(
110 &self,
111 mut rx: mpsc::Receiver<UploadMessage<H::Id>>,
112 ) -> Result<UploaderStats> {
113 tracing::info!(
114 "[{}] Uploader started (max_concurrent={})",
115 self.handler.name(),
116 self.max_concurrent
117 );
118
119 let mut in_flight: JoinSet<(H::Id, Result<u64>)> = JoinSet::new();
120
121 let pending = self.handler.pending_items();
123 if !pending.is_empty() {
124 tracing::info!(
125 "[{}] Resuming {} pending uploads",
126 self.handler.name(),
127 pending.len()
128 );
129 for id in pending {
130 while in_flight.len() >= self.max_concurrent {
131 if let Some(result) = in_flight.join_next().await {
132 self.handle_join_result(result).await;
133 }
134 }
135 let handler = self.handler.clone();
136 let id_clone = id.clone();
137 in_flight.spawn(async move {
138 let result = handler.upload(id_clone.clone()).await;
139 (id_clone, result)
140 });
141 }
142 }
143
144 loop {
146 tokio::select! {
147 msg = rx.recv(), if in_flight.len() < self.max_concurrent => {
149 match msg {
150 Some(UploadMessage::Upload(id)) => {
151 let handler = self.handler.clone();
152 let id_clone = id.clone();
153 in_flight.spawn(async move {
154 let result = handler.upload(id_clone.clone()).await;
155 (id_clone, result)
156 });
157 }
158 Some(UploadMessage::UploadWithAck(id, ack_tx)) => {
159 let handler = self.handler.clone();
160 let id_clone = id.clone();
161 in_flight.spawn(async move {
162 let result = handler.upload(id_clone.clone()).await;
163 let ack_result = match &result {
164 Ok(_) => Ok(()),
165 Err(e) => Err(anyhow::anyhow!("{}", e)),
166 };
167 let _ = ack_tx.send(ack_result);
168 (id_clone, result)
169 });
170 }
171 Some(UploadMessage::Shutdown) => {
172 tracing::info!(
173 "[{}] Shutdown signal, draining {} in-flight",
174 self.handler.name(),
175 in_flight.len()
176 );
177 while let Some(result) = in_flight.join_next().await {
178 self.handle_join_result(result).await;
179 }
180 break;
181 }
182 None => {
183 tracing::info!(
184 "[{}] Channel closed, draining {} in-flight",
185 self.handler.name(),
186 in_flight.len()
187 );
188 while let Some(result) = in_flight.join_next().await {
189 self.handle_join_result(result).await;
190 }
191 break;
192 }
193 }
194 }
195 Some(result) = in_flight.join_next() => {
197 self.handle_join_result(result).await;
198 }
199 }
200 }
201
202 let stats = self.stats.lock().await.clone();
203 tracing::info!(
204 "[{}] Uploader stopped. Stats: {:?}",
205 self.handler.name(),
206 stats
207 );
208 Ok(stats)
209 }
210
211 async fn handle_join_result(
213 &self,
214 result: Result<(H::Id, Result<u64>), tokio::task::JoinError>,
215 ) {
216 match result {
217 Ok((id, Ok(bytes))) => {
218 let mut stats = self.stats.lock().await;
219 stats.uploads_attempted += 1;
220 stats.uploads_succeeded += 1;
221 stats.bytes_uploaded += bytes;
222 tracing::debug!("[{}] Uploaded {:?} ({} bytes)", self.handler.name(), id, bytes);
223 }
224 Ok((id, Err(e))) => {
225 let mut stats = self.stats.lock().await;
226 stats.uploads_attempted += 1;
227 stats.uploads_failed += 1;
228 tracing::error!("[{}] Upload failed for {:?}: {}", self.handler.name(), id, e);
229 }
230 Err(e) => {
231 let mut stats = self.stats.lock().await;
232 stats.uploads_attempted += 1;
233 stats.uploads_failed += 1;
234 tracing::error!("[{}] Upload task panicked: {}", self.handler.name(), e);
235 }
236 }
237 }
238
239 pub async fn stats(&self) -> UploaderStats {
241 self.stats.lock().await.clone()
242 }
243}
244
245pub fn spawn_uploader<H: UploadHandler>(
261 uploader: Arc<ConcurrentUploader<H>>,
262) -> (
263 mpsc::Sender<UploadMessage<H::Id>>,
264 tokio::task::JoinHandle<()>,
265) {
266 let (tx, rx) = mpsc::channel(1000);
267
268 let handle = tokio::spawn(async move {
269 if let Err(e) = uploader.run(rx).await {
270 tracing::error!("Uploader task failed: {}", e);
271 }
272 });
273
274 (tx, handle)
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
    use std::sync::Mutex;
    use tokio::time::{timeout, Duration};

    /// Test double for `UploadHandler`.
    ///
    /// It records successful uploads and can simulate latency and per-id
    /// failures. It also tracks the peak number of `upload` calls running at once.
    struct MockHandler {
        name: String,
        /// Returned verbatim by `pending_items`.
        pending: Mutex<Vec<u64>>,
        /// Ids whose upload succeeded.
        uploaded: Mutex<HashSet<u64>>,
        /// Optional sleep inside each upload, used to force uploads to overlap.
        upload_delay: Option<Duration>,
        /// Ids whose upload returns an error.
        fail_ids: Mutex<HashSet<u64>>,
        /// Number of `upload` calls currently inside their body.
        active: AtomicUsize,
        /// Highest value `active` has reached.
        peak_concurrent: AtomicUsize,
        /// Number of `upload` calls that have completed, successful or not.
        upload_count: AtomicU64,
    }

    impl MockHandler {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                pending: Mutex::new(vec![]),
                uploaded: Mutex::new(HashSet::new()),
                upload_delay: None,
                fail_ids: Mutex::new(HashSet::new()),
                active: AtomicUsize::new(0),
                peak_concurrent: AtomicUsize::new(0),
                upload_count: AtomicU64::new(0),
            }
        }

        fn with_pending(mut self, pending: Vec<u64>) -> Self {
            self.pending = Mutex::new(pending);
            self
        }

        fn with_delay(mut self, delay: Duration) -> Self {
            self.upload_delay = Some(delay);
            self
        }

        fn with_fail_ids(mut self, ids: HashSet<u64>) -> Self {
            self.fail_ids = Mutex::new(ids);
            self
        }

        fn uploaded_ids(&self) -> HashSet<u64> {
            self.uploaded.lock().unwrap().clone()
        }

        fn peak_concurrent(&self) -> usize {
            self.peak_concurrent.load(Ordering::SeqCst)
        }

        fn upload_count(&self) -> u64 {
            self.upload_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl UploadHandler for MockHandler {
        type Id = u64;

        async fn upload(&self, id: u64) -> Result<u64> {
            let active = self.active.fetch_add(1, Ordering::SeqCst) + 1;
            self.peak_concurrent.fetch_max(active, Ordering::SeqCst);

            if let Some(delay) = self.upload_delay {
                tokio::time::sleep(delay).await;
            }

            self.active.fetch_sub(1, Ordering::SeqCst);
            self.upload_count.fetch_add(1, Ordering::SeqCst);

            if self.fail_ids.lock().unwrap().contains(&id) {
                return Err(anyhow::anyhow!("Simulated failure for id {}", id));
            }

            self.uploaded.lock().unwrap().insert(id);
            // Every successful mock upload reports 100 bytes.
            Ok(100)
        }

        fn pending_items(&self) -> Vec<u64> {
            self.pending.lock().unwrap().clone()
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Polls `uploader.stats()` until `uploads_attempted` reaches `expected`.
    ///
    /// Fails the test if that takes longer than one second. This replaces
    /// fixed sleeps, which are flaky on loaded machines.
    async fn wait_for_attempts(
        uploader: &ConcurrentUploader<MockHandler>,
        expected: u64,
    ) -> UploaderStats {
        timeout(Duration::from_secs(1), async {
            loop {
                let stats = uploader.stats().await;
                if stats.uploads_attempted >= expected {
                    return stats;
                }
                tokio::time::sleep(Duration::from_millis(1)).await;
            }
        })
        .await
        .expect("stats did not reach the expected attempt count in time")
    }

    #[tokio::test]
    async fn test_basic_upload() {
        let handler = Arc::new(MockHandler::new("test"));
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 4));

        let (tx, rx) = mpsc::channel(10);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        tx.send(UploadMessage::Upload(1)).await.unwrap();
        tx.send(UploadMessage::Shutdown).await.unwrap();

        let stats = task.await.unwrap().unwrap();
        assert_eq!(stats.uploads_succeeded, 1);
        assert_eq!(stats.uploads_failed, 0);
        assert_eq!(stats.bytes_uploaded, 100);
        assert!(handler.uploaded_ids().contains(&1));
    }

    #[tokio::test]
    async fn test_multiple_uploads() {
        let handler = Arc::new(MockHandler::new("test"));
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 4));

        let (tx, rx) = mpsc::channel(20);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        for i in 1..=10 {
            tx.send(UploadMessage::Upload(i)).await.unwrap();
        }
        tx.send(UploadMessage::Shutdown).await.unwrap();

        let stats = task.await.unwrap().unwrap();
        assert_eq!(stats.uploads_succeeded, 10);
        assert_eq!(stats.bytes_uploaded, 1000);
        assert_eq!(handler.uploaded_ids().len(), 10);
    }

    #[tokio::test]
    async fn test_resume_pending() {
        let handler = Arc::new(MockHandler::new("test").with_pending(vec![1, 2, 3]));
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 4));

        let (tx, rx) = mpsc::channel(10);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        // No sleep needed: pending uploads are spawned before the first message is
        // received, and Shutdown drains everything in flight.
        tx.send(UploadMessage::Shutdown).await.unwrap();

        let stats = timeout(Duration::from_secs(5), task)
            .await
            .expect("should complete")
            .unwrap()
            .unwrap();
        assert_eq!(stats.uploads_succeeded, 3);
        assert!(handler.uploaded_ids().contains(&1));
        assert!(handler.uploaded_ids().contains(&2));
        assert!(handler.uploaded_ids().contains(&3));
    }

    #[tokio::test]
    async fn test_resume_pending_respects_limit() {
        // More pending items than slots exercises the wait-for-slot path on startup.
        let handler = Arc::new(
            MockHandler::new("test")
                .with_pending(vec![1, 2, 3, 4, 5])
                .with_delay(Duration::from_millis(10)),
        );
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 2));

        let (tx, rx) = mpsc::channel(10);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        tx.send(UploadMessage::Shutdown).await.unwrap();

        let stats = timeout(Duration::from_secs(5), task)
            .await
            .expect("should complete")
            .unwrap()
            .unwrap();
        assert_eq!(stats.uploads_succeeded, 5);
        assert_eq!(handler.uploaded_ids().len(), 5);
        assert!(
            handler.peak_concurrent() <= 2,
            "peak concurrent was {}, expected <= 2",
            handler.peak_concurrent()
        );
    }

    #[tokio::test]
    async fn test_concurrent_respects_limit() {
        let handler = Arc::new(
            MockHandler::new("test").with_delay(Duration::from_millis(50)),
        );
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 3));

        let (tx, rx) = mpsc::channel(20);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        for i in 1..=10 {
            tx.send(UploadMessage::Upload(i)).await.unwrap();
        }
        tx.send(UploadMessage::Shutdown).await.unwrap();

        let stats = timeout(Duration::from_secs(5), task)
            .await
            .expect("should complete")
            .unwrap()
            .unwrap();

        assert_eq!(stats.uploads_succeeded, 10);
        assert!(
            handler.peak_concurrent() <= 3,
            "peak concurrent was {}, expected <= 3",
            handler.peak_concurrent()
        );
        assert!(
            handler.peak_concurrent() > 1,
            "peak concurrent was {}, expected > 1 (should use concurrency)",
            handler.peak_concurrent()
        );
    }

    #[tokio::test]
    async fn test_failure_doesnt_block_others() {
        let mut fail_ids = HashSet::new();
        fail_ids.insert(3);
        fail_ids.insert(7);

        let handler = Arc::new(MockHandler::new("test").with_fail_ids(fail_ids));
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 4));

        let (tx, rx) = mpsc::channel(20);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        for i in 1..=10 {
            tx.send(UploadMessage::Upload(i)).await.unwrap();
        }
        tx.send(UploadMessage::Shutdown).await.unwrap();

        let stats = task.await.unwrap().unwrap();
        assert_eq!(stats.uploads_succeeded, 8);
        assert_eq!(stats.uploads_failed, 2);
        assert!(!handler.uploaded_ids().contains(&3));
        assert!(!handler.uploaded_ids().contains(&7));
        assert!(handler.uploaded_ids().contains(&1));
        assert!(handler.uploaded_ids().contains(&10));
    }

    #[tokio::test]
    async fn test_graceful_shutdown_drains() {
        let handler = Arc::new(
            MockHandler::new("test").with_delay(Duration::from_millis(100)),
        );
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 4));

        let (tx, rx) = mpsc::channel(10);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        for i in 1..=4 {
            tx.send(UploadMessage::Upload(i)).await.unwrap();
        }
        // Give the uploader time to start the uploads so Shutdown arrives mid-flight.
        tokio::time::sleep(Duration::from_millis(20)).await;
        tx.send(UploadMessage::Shutdown).await.unwrap();

        let stats = timeout(Duration::from_secs(5), task)
            .await
            .expect("should complete within timeout")
            .unwrap()
            .unwrap();

        assert_eq!(stats.uploads_succeeded, 4);
    }

    #[tokio::test]
    async fn test_channel_close_drains() {
        let handler = Arc::new(
            MockHandler::new("test").with_delay(Duration::from_millis(50)),
        );
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 4));

        let (tx, rx) = mpsc::channel(10);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        for i in 1..=3 {
            tx.send(UploadMessage::Upload(i)).await.unwrap();
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
        // Dropping the only sender closes the channel, which must also drain.
        drop(tx);

        let stats = timeout(Duration::from_secs(5), task)
            .await
            .expect("should complete")
            .unwrap()
            .unwrap();

        assert_eq!(stats.uploads_succeeded, 3);
    }

    #[tokio::test]
    async fn test_spawn_uploader_helper() {
        let handler = Arc::new(MockHandler::new("test"));
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 4));

        let (tx, handle) = spawn_uploader(uploader);

        tx.send(UploadMessage::Upload(1)).await.unwrap();
        tx.send(UploadMessage::Upload(2)).await.unwrap();
        tx.send(UploadMessage::Shutdown).await.unwrap();

        handle.await.unwrap();
        assert_eq!(handler.upload_count(), 2);
    }

    #[tokio::test]
    async fn test_stats_tracking() {
        let handler = Arc::new(MockHandler::new("test"));
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 4));

        let (tx, rx) = mpsc::channel(10);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        tx.send(UploadMessage::Upload(1)).await.unwrap();
        let stats = wait_for_attempts(&uploader, 1).await;
        assert_eq!(stats.uploads_attempted, 1);
        assert_eq!(stats.bytes_uploaded, 100);

        tx.send(UploadMessage::Upload(2)).await.unwrap();
        let stats = wait_for_attempts(&uploader, 2).await;
        assert_eq!(stats.uploads_attempted, 2);
        assert_eq!(stats.bytes_uploaded, 200);

        tx.send(UploadMessage::Shutdown).await.unwrap();
        // Await the uploader so the test observes its final stats and any panic.
        let final_stats = task.await.unwrap().unwrap();
        assert_eq!(final_stats.uploads_attempted, 2);
        assert_eq!(final_stats.uploads_succeeded, 2);
    }

    #[tokio::test]
    async fn test_max_concurrent_one() {
        let handler = Arc::new(
            MockHandler::new("test").with_delay(Duration::from_millis(20)),
        );
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 1));

        let (tx, rx) = mpsc::channel(10);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        for i in 1..=5 {
            tx.send(UploadMessage::Upload(i)).await.unwrap();
        }
        tx.send(UploadMessage::Shutdown).await.unwrap();

        let stats = timeout(Duration::from_secs(5), task)
            .await
            .expect("should complete")
            .unwrap()
            .unwrap();

        assert_eq!(stats.uploads_succeeded, 5);
        assert_eq!(handler.peak_concurrent(), 1);
    }

    #[tokio::test]
    async fn test_zero_concurrency_defaults_to_one() {
        let handler = Arc::new(MockHandler::new("test"));
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 0));

        let (tx, rx) = mpsc::channel(10);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        tx.send(UploadMessage::Upload(1)).await.unwrap();
        tx.send(UploadMessage::Shutdown).await.unwrap();

        let stats = task.await.unwrap().unwrap();
        assert_eq!(stats.uploads_succeeded, 1);
    }

    #[tokio::test]
    async fn test_upload_with_ack_success() {
        let handler = Arc::new(MockHandler::new("test"));
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 4));

        let (tx, rx) = mpsc::channel(10);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        let (ack_tx, ack_rx) = tokio::sync::oneshot::channel();
        tx.send(UploadMessage::UploadWithAck(1, ack_tx)).await.unwrap();

        let result = ack_rx.await.unwrap();
        assert!(result.is_ok());
        assert!(handler.uploaded_ids().contains(&1));

        tx.send(UploadMessage::Shutdown).await.unwrap();
        let stats = task.await.unwrap().unwrap();
        assert_eq!(stats.uploads_succeeded, 1);
    }

    #[tokio::test]
    async fn test_upload_with_ack_failure() {
        let mut fail_ids = HashSet::new();
        fail_ids.insert(1);

        let handler = Arc::new(MockHandler::new("test").with_fail_ids(fail_ids));
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 4));

        let (tx, rx) = mpsc::channel(10);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        let (ack_tx, ack_rx) = tokio::sync::oneshot::channel();
        tx.send(UploadMessage::UploadWithAck(1, ack_tx)).await.unwrap();

        let result = ack_rx.await.unwrap();
        assert!(result.is_err());
        assert!(!handler.uploaded_ids().contains(&1));

        tx.send(UploadMessage::Shutdown).await.unwrap();
        let stats = task.await.unwrap().unwrap();
        assert_eq!(stats.uploads_failed, 1);
    }

    #[tokio::test]
    async fn test_upload_with_ack_mixed_with_fire_and_forget() {
        let handler = Arc::new(MockHandler::new("test"));
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 4));

        let (tx, rx) = mpsc::channel(10);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        tx.send(UploadMessage::Upload(1)).await.unwrap();
        tx.send(UploadMessage::Upload(2)).await.unwrap();

        let (ack_tx, ack_rx) = tokio::sync::oneshot::channel();
        tx.send(UploadMessage::UploadWithAck(3, ack_tx)).await.unwrap();

        tx.send(UploadMessage::Upload(4)).await.unwrap();

        let result = ack_rx.await.unwrap();
        assert!(result.is_ok());

        tx.send(UploadMessage::Shutdown).await.unwrap();
        let stats = task.await.unwrap().unwrap();
        assert_eq!(stats.uploads_succeeded, 4);
    }

    #[tokio::test]
    async fn test_empty_shutdown() {
        let handler = Arc::new(MockHandler::new("test"));
        let uploader = Arc::new(ConcurrentUploader::new(handler.clone(), 4));

        let (tx, rx) = mpsc::channel(10);
        let uploader_clone = uploader.clone();
        let task = tokio::spawn(async move { uploader_clone.run(rx).await });

        tx.send(UploadMessage::Shutdown).await.unwrap();

        let stats = task.await.unwrap().unwrap();
        assert_eq!(stats.uploads_attempted, 0);
    }

    #[test]
    fn test_upload_message_debug() {
        assert_eq!(format!("{:?}", UploadMessage::<u64>::Upload(7)), "Upload(7)");
        assert_eq!(format!("{:?}", UploadMessage::<u64>::Shutdown), "Shutdown");
        // The ack sender must not appear in the Debug output.
        let (ack_tx, _ack_rx) = tokio::sync::oneshot::channel();
        assert_eq!(
            format!("{:?}", UploadMessage::<u64>::UploadWithAck(3, ack_tx)),
            "UploadWithAck(3)"
        );
    }
}
719}