1use bytes::{Bytes, BytesMut};
4use futures::stream::{self, StreamExt};
5use parking_lot::RwLock;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::mpsc;
10use tokio::time::sleep;
11
12use crate::error::Result;
13use crate::types::{CloudStorage, TransferProgress};
14
/// Orchestrates uploads and downloads against a [`CloudStorage`] backend,
/// splitting large payloads into concurrently transferred chunks and
/// tracking per-transfer progress state.
pub struct TransferManager {
    /// Backend used for all storage operations; `Arc` so per-chunk tasks
    /// can hold their own cheap clone.
    storage: Arc<dyn CloudStorage>,
    /// Tuning knobs: chunk size, concurrency, retry count.
    config: TransferConfig,
    /// Live transfer state keyed by transfer id
    /// (`"upload-<key>"` / `"download-<key>"`).
    transfers: Arc<RwLock<HashMap<String, TransferState>>>,
}
24
25impl TransferManager {
26 #[must_use]
28 pub fn new(storage: Arc<dyn CloudStorage>, config: TransferConfig) -> Self {
29 Self {
30 storage,
31 config,
32 transfers: Arc::new(RwLock::new(HashMap::new())),
33 }
34 }
35
36 pub async fn upload(
42 &self,
43 key: &str,
44 data: Bytes,
45 progress_tx: Option<mpsc::Sender<TransferProgress>>,
46 ) -> Result<()> {
47 let transfer_id = format!("upload-{key}");
48 self.init_transfer(&transfer_id, data.len() as u64);
49
50 if data.len() <= self.config.chunk_size {
51 self.upload_single_part(key, data, &transfer_id, progress_tx)
53 .await
54 } else {
55 self.upload_multipart(key, data, &transfer_id, progress_tx)
57 .await
58 }
59 }
60
61 pub async fn download(
67 &self,
68 key: &str,
69 progress_tx: Option<mpsc::Sender<TransferProgress>>,
70 ) -> Result<Bytes> {
71 let transfer_id = format!("download-{key}");
72
73 let metadata = self.storage.get_metadata(key).await?;
75 let total_size = metadata.info.size;
76
77 self.init_transfer(&transfer_id, total_size);
78
79 if total_size <= self.config.chunk_size as u64 {
80 self.download_single_part(key, &transfer_id, progress_tx)
82 .await
83 } else {
84 self.download_multipart(key, total_size, &transfer_id, progress_tx)
86 .await
87 }
88 }
89
90 async fn upload_single_part(
92 &self,
93 key: &str,
94 data: Bytes,
95 transfer_id: &str,
96 progress_tx: Option<mpsc::Sender<TransferProgress>>,
97 ) -> Result<()> {
98 let mut attempts = 0;
99 let total_size = data.len() as u64;
100
101 loop {
102 match self.storage.upload(key, data.clone()).await {
103 Ok(()) => {
104 self.update_progress(transfer_id, total_size, total_size, &progress_tx)
105 .await;
106 self.complete_transfer(transfer_id);
107 return Ok(());
108 }
109 Err(e) if e.is_retryable() && attempts < self.config.max_retries => {
110 attempts += 1;
111 tracing::warn!("Upload attempt {} failed: {}", attempts, e);
112 sleep(self.retry_delay(attempts)).await;
113 }
114 Err(e) => {
115 self.fail_transfer(transfer_id);
116 return Err(e);
117 }
118 }
119 }
120 }
121
122 async fn upload_multipart(
124 &self,
125 key: &str,
126 data: Bytes,
127 transfer_id: &str,
128 progress_tx: Option<mpsc::Sender<TransferProgress>>,
129 ) -> Result<()> {
130 let chunk_size = self.config.chunk_size;
131 let total_size = data.len() as u64;
132 let num_chunks = data.len().div_ceil(chunk_size);
133
134 let chunks: Vec<Bytes> = (0..num_chunks)
136 .map(|i| {
137 let start = i * chunk_size;
138 let end = std::cmp::min(start + chunk_size, data.len());
139 data.slice(start..end)
140 })
141 .collect();
142
143 let max_concurrent = self.config.max_concurrent_transfers;
145 let mut bytes_transferred = 0u64;
146
147 let chunk_futures: Vec<_> = chunks
148 .into_iter()
149 .enumerate()
150 .map(|(i, chunk)| {
151 let chunk_key = format!("{key}.part{i}");
152 let storage = self.storage.clone();
153 let chunk_len = chunk.len() as u64;
154
155 async move {
156 let mut attempts = 0;
157 loop {
158 match storage.upload(&chunk_key, chunk.clone()).await {
159 Ok(()) => return Ok(chunk_len),
160 Err(e) if e.is_retryable() && attempts < self.config.max_retries => {
161 attempts += 1;
162 sleep(Duration::from_secs(2u64.pow(attempts))).await;
163 }
164 Err(e) => return Err(e),
165 }
166 }
167 }
168 })
169 .collect();
170
171 let mut stream = stream::iter(chunk_futures).buffer_unordered(max_concurrent);
172
173 while let Some(result) = stream.next().await {
174 let chunk_len = result?;
175 bytes_transferred += chunk_len;
176 self.update_progress(transfer_id, bytes_transferred, total_size, &progress_tx)
177 .await;
178 }
179
180 self.complete_transfer(transfer_id);
183 Ok(())
184 }
185
186 async fn download_single_part(
188 &self,
189 key: &str,
190 transfer_id: &str,
191 progress_tx: Option<mpsc::Sender<TransferProgress>>,
192 ) -> Result<Bytes> {
193 let mut attempts = 0;
194
195 loop {
196 match self.storage.download(key).await {
197 Ok(data) => {
198 let total_size = data.len() as u64;
199 self.update_progress(transfer_id, total_size, total_size, &progress_tx)
200 .await;
201 self.complete_transfer(transfer_id);
202 return Ok(data);
203 }
204 Err(e) if e.is_retryable() && attempts < self.config.max_retries => {
205 attempts += 1;
206 tracing::warn!("Download attempt {} failed: {}", attempts, e);
207 sleep(self.retry_delay(attempts)).await;
208 }
209 Err(e) => {
210 self.fail_transfer(transfer_id);
211 return Err(e);
212 }
213 }
214 }
215 }
216
217 async fn download_multipart(
219 &self,
220 key: &str,
221 total_size: u64,
222 transfer_id: &str,
223 progress_tx: Option<mpsc::Sender<TransferProgress>>,
224 ) -> Result<Bytes> {
225 let chunk_size = self.config.chunk_size as u64;
226 let num_chunks = total_size.div_ceil(chunk_size);
227
228 let ranges: Vec<(u64, u64)> = (0..num_chunks)
230 .map(|i| {
231 let start = i * chunk_size;
232 let end = std::cmp::min(start + chunk_size - 1, total_size - 1);
233 (start, end)
234 })
235 .collect();
236
237 let max_concurrent = self.config.max_concurrent_transfers;
238 let mut bytes_transferred = 0u64;
239
240 let chunk_futures: Vec<_> = ranges
242 .into_iter()
243 .map(|(start, end)| {
244 let storage = self.storage.clone();
245 let key = key.to_string();
246
247 async move {
248 let mut attempts = 0;
249 loop {
250 match storage.download_range(&key, start, end).await {
251 Ok(data) => return Ok((start, data)),
252 Err(e) if e.is_retryable() && attempts < self.config.max_retries => {
253 attempts += 1;
254 sleep(Duration::from_secs(2u64.pow(attempts))).await;
255 }
256 Err(e) => return Err(e),
257 }
258 }
259 }
260 })
261 .collect();
262
263 let mut stream = stream::iter(chunk_futures).buffer_unordered(max_concurrent);
264 let mut chunks: Vec<(u64, Bytes)> = Vec::new();
265
266 while let Some(result) = stream.next().await {
267 let (offset, chunk) = result?;
268 bytes_transferred += chunk.len() as u64;
269 chunks.push((offset, chunk));
270 self.update_progress(transfer_id, bytes_transferred, total_size, &progress_tx)
271 .await;
272 }
273
274 chunks.sort_by_key(|(offset, _)| *offset);
276 let mut combined = BytesMut::with_capacity(total_size as usize);
277 for (_, chunk) in chunks {
278 combined.extend_from_slice(&chunk);
279 }
280
281 self.complete_transfer(transfer_id);
282 Ok(combined.freeze())
283 }
284
285 fn init_transfer(&self, transfer_id: &str, total_size: u64) {
287 let state = TransferState {
288 total_size,
289 bytes_transferred: 0,
290 start_time: Instant::now(),
291 status: TransferStatus::InProgress,
292 };
293 self.transfers
294 .write()
295 .insert(transfer_id.to_string(), state);
296 }
297
298 async fn update_progress(
300 &self,
301 transfer_id: &str,
302 bytes_transferred: u64,
303 total_size: u64,
304 progress_tx: &Option<mpsc::Sender<TransferProgress>>,
305 ) {
306 let (_elapsed, rate_bps, eta_secs) = {
307 if let Some(state) = self.transfers.write().get_mut(transfer_id) {
308 state.bytes_transferred = bytes_transferred;
309
310 let elapsed = state.start_time.elapsed().as_secs_f64();
311 let rate_bps = if elapsed > 0.0 {
312 bytes_transferred as f64 / elapsed
313 } else {
314 0.0
315 };
316
317 let remaining_bytes = total_size.saturating_sub(bytes_transferred);
318 let eta_secs = if rate_bps > 0.0 {
319 Some(remaining_bytes as f64 / rate_bps)
320 } else {
321 None
322 };
323 (elapsed, rate_bps, eta_secs)
324 } else {
325 (0.0, 0.0, None)
326 }
327 };
328
329 if let Some(tx) = progress_tx {
330 let progress = TransferProgress {
331 bytes_transferred,
332 total_bytes: total_size,
333 rate_bps,
334 eta_secs,
335 };
336 let _ = tx.send(progress).await;
337 }
338 }
339
340 fn complete_transfer(&self, transfer_id: &str) {
342 if let Some(state) = self.transfers.write().get_mut(transfer_id) {
343 state.status = TransferStatus::Completed;
344 }
345 }
346
347 fn fail_transfer(&self, transfer_id: &str) {
349 if let Some(state) = self.transfers.write().get_mut(transfer_id) {
350 state.status = TransferStatus::Failed;
351 }
352 }
353
354 fn retry_delay(&self, attempt: u32) -> Duration {
356 let base_delay = Duration::from_secs(1);
357 let max_delay = Duration::from_secs(60);
358 let delay = base_delay * 2u32.pow(attempt);
359 std::cmp::min(delay, max_delay)
360 }
361
362 #[must_use]
364 pub fn get_status(&self, transfer_id: &str) -> Option<TransferState> {
365 self.transfers.read().get(transfer_id).cloned()
366 }
367}
368
/// Tuning parameters for [`TransferManager`].
#[derive(Debug, Clone)]
pub struct TransferConfig {
    /// Chunk size in bytes; payloads at or below this size are transferred
    /// as a single part.
    pub chunk_size: usize,
    /// Maximum number of chunk transfers in flight at once.
    pub max_concurrent_transfers: usize,
    /// Retry attempts allowed per request for retryable errors.
    pub max_retries: u32,
    /// Whether checksums should be verified.
    /// NOTE(review): not consulted anywhere in this file — confirm where
    /// (or whether) it is enforced.
    pub verify_checksum: bool,
    /// Optional bandwidth cap in bytes/sec.
    /// NOTE(review): not consulted anywhere in this file — confirm enforcement.
    pub bandwidth_limit_bps: Option<u64>,
}
383
384impl Default for TransferConfig {
385 fn default() -> Self {
386 Self {
387 chunk_size: 5 * 1024 * 1024, max_concurrent_transfers: 4,
389 max_retries: 3,
390 verify_checksum: true,
391 bandwidth_limit_bps: None,
392 }
393 }
394}
395
396impl TransferConfig {
397 #[must_use]
399 pub fn small_files() -> Self {
400 Self {
401 chunk_size: 1024 * 1024, max_concurrent_transfers: 8,
403 max_retries: 3,
404 verify_checksum: true,
405 bandwidth_limit_bps: None,
406 }
407 }
408
409 #[must_use]
411 pub fn large_files() -> Self {
412 Self {
413 chunk_size: 20 * 1024 * 1024, max_concurrent_transfers: 8,
415 max_retries: 5,
416 verify_checksum: true,
417 bandwidth_limit_bps: None,
418 }
419 }
420}
421
/// Mutable bookkeeping for a single in-flight (or finished) transfer.
#[derive(Debug, Clone)]
pub struct TransferState {
    /// Total payload size in bytes.
    pub total_size: u64,
    /// Bytes confirmed transferred so far.
    pub bytes_transferred: u64,
    /// When the transfer was initiated; used for rate/ETA estimates.
    pub start_time: Instant,
    /// Current lifecycle status.
    pub status: TransferStatus,
}
434
/// Lifecycle status of a transfer tracked by [`TransferManager`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferStatus {
    /// Transfer has started and is still running.
    InProgress,
    /// Transfer finished successfully.
    Completed,
    /// Transfer gave up after a terminal or retry-exhausted error.
    Failed,
    /// Transfer is paused.
    /// NOTE(review): nothing in this file sets this variant — confirm it
    /// is produced elsewhere.
    Paused,
}
447
/// Incrementally computes MD5 and SHA-256 digests over streamed data.
pub struct ChecksumCalculator {
    /// Running MD5 state.
    md5: md5::Md5,
    /// Running SHA-256 state.
    sha256: sha2::Sha256,
}
455
456impl ChecksumCalculator {
457 #[must_use]
459 pub fn new() -> Self {
460 use sha2::Digest;
461 Self {
462 md5: md5::Md5::new(),
463 sha256: sha2::Sha256::new(),
464 }
465 }
466
467 pub fn update(&mut self, data: &[u8]) {
469 use sha2::Digest;
470 self.md5.update(data);
471 self.sha256.update(data);
472 }
473
474 #[must_use]
476 pub fn finalize(self) -> Checksums {
477 use sha2::Digest;
478 let md5_digest = self.md5.finalize();
479 let sha256_digest = self.sha256.finalize();
480
481 Checksums {
482 md5: hex::encode(&md5_digest[..]),
483 sha256: hex::encode(&sha256_digest[..]),
484 }
485 }
486}
487
488impl Default for ChecksumCalculator {
489 fn default() -> Self {
490 Self::new()
491 }
492}
493
/// Finished digests produced by [`ChecksumCalculator::finalize`],
/// encoded as lowercase hex strings.
#[derive(Debug, Clone)]
pub struct Checksums {
    /// MD5 digest, hex-encoded.
    pub md5: String,
    /// SHA-256 digest, hex-encoded.
    pub sha256: String,
}
502
503impl Checksums {
504 #[must_use]
506 pub fn verify_md5(&self, expected: &str) -> bool {
507 self.md5.eq_ignore_ascii_case(expected)
508 }
509
510 #[must_use]
512 pub fn verify_sha256(&self, expected: &str) -> bool {
513 self.sha256.eq_ignore_ascii_case(expected)
514 }
515}
516
/// Exponential-backoff retry configuration for [`execute_with_retry`].
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total number of attempts (first try included); must be > 0.
    pub max_attempts: u32,
    /// Delay before the second attempt, in milliseconds; doubles each retry.
    pub base_delay_ms: u64,
    /// Upper bound on any single delay, in milliseconds.
    pub max_delay_ms: u64,
}
527
528impl Default for RetryPolicy {
529 fn default() -> Self {
530 Self {
531 max_attempts: 3,
532 base_delay_ms: 200,
533 max_delay_ms: 30_000,
534 }
535 }
536}
537
538impl RetryPolicy {
539 #[must_use]
543 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
544 let shift = attempt.min(62); let multiplier = 1u64.checked_shl(shift).unwrap_or(u64::MAX);
546 let ms = self
547 .base_delay_ms
548 .saturating_mul(multiplier)
549 .min(self.max_delay_ms);
550 Duration::from_millis(ms)
551 }
552}
553
554pub fn execute_with_retry<F, T, E>(policy: &RetryPolicy, mut op: F) -> std::result::Result<T, E>
564where
565 F: FnMut(u32) -> std::result::Result<T, E>,
566 E: std::fmt::Debug,
567{
568 let mut last_err: Option<E> = None;
569 for attempt in 0..policy.max_attempts {
570 match op(attempt) {
571 Ok(value) => return Ok(value),
572 Err(e) => {
573 tracing::warn!("Attempt {} failed: {:?}", attempt + 1, e);
574 last_err = Some(e);
575 if attempt + 1 < policy.max_attempts {
576 std::thread::sleep(policy.delay_for_attempt(attempt));
577 }
578 }
579 }
580 }
581 Err(last_err.expect("max_attempts must be > 0"))
583}
584
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the Default config values other code may rely on.
    #[test]
    fn test_transfer_config_defaults() {
        let config = TransferConfig::default();
        assert_eq!(config.chunk_size, 5 * 1024 * 1024);
        assert_eq!(config.max_concurrent_transfers, 4);
        assert_eq!(config.max_retries, 3);
    }

    // The presets must differ from Default in the documented directions.
    #[test]
    fn test_transfer_config_presets() {
        let small = TransferConfig::small_files();
        assert_eq!(small.chunk_size, 1024 * 1024);

        let large = TransferConfig::large_files();
        assert_eq!(large.chunk_size, 20 * 1024 * 1024);
    }

    // Smoke test: both digests are produced and non-empty.
    #[test]
    fn test_checksum_calculator() {
        let mut calc = ChecksumCalculator::new();
        calc.update(b"test data");
        let checksums = calc.finalize();

        assert!(!checksums.md5.is_empty());
        assert!(!checksums.sha256.is_empty());
    }

    // Known-answer test: MD5("test") = 098f6bcd4621d373cade4e832627b4f6.
    // Only MD5 is checked here; SHA-256 is covered by the smoke test above.
    #[test]
    fn test_checksum_verification() {
        let mut calc = ChecksumCalculator::new();
        calc.update(b"test");
        let checksums = calc.finalize();

        assert!(checksums.verify_md5("098f6bcd4621d373cade4e832627b4f6"));
    }

    // Sanity-checks the derived PartialEq on TransferStatus.
    #[test]
    fn test_transfer_status() {
        assert_eq!(TransferStatus::InProgress, TransferStatus::InProgress);
        assert_ne!(TransferStatus::InProgress, TransferStatus::Completed);
    }

    #[test]
    fn test_retry_policy_default() {
        let p = RetryPolicy::default();
        assert_eq!(p.max_attempts, 3);
        assert_eq!(p.base_delay_ms, 200);
        assert_eq!(p.max_delay_ms, 30_000);
    }

    // attempt is zero-based: attempt 0 gets the base delay, then doubles.
    #[test]
    fn test_retry_policy_delay_doubles() {
        let p = RetryPolicy {
            max_attempts: 5,
            base_delay_ms: 100,
            max_delay_ms: 10_000,
        };
        assert_eq!(p.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(p.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(p.delay_for_attempt(2), Duration::from_millis(400));
        assert_eq!(p.delay_for_attempt(3), Duration::from_millis(800));
    }

    // 1000ms << 10 would be ~17 minutes; the cap must clamp it to 3s.
    #[test]
    fn test_retry_policy_delay_capped_at_max() {
        let p = RetryPolicy {
            max_attempts: 10,
            base_delay_ms: 1_000,
            max_delay_ms: 3_000,
        };
        assert_eq!(p.delay_for_attempt(10), Duration::from_millis(3_000));
    }

    // An immediately-successful op must run exactly once.
    #[test]
    fn test_execute_with_retry_success_first_try() {
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay_ms: 0,
            max_delay_ms: 0,
        };
        let mut call_count = 0u32;
        let result: std::result::Result<i32, &str> = execute_with_retry(&policy, |_attempt| {
            call_count += 1;
            Ok(42)
        });
        assert_eq!(result, Ok(42));
        assert_eq!(call_count, 1);
    }

    // A transient first failure must be retried and the success returned.
    #[test]
    fn test_execute_with_retry_succeeds_on_second_attempt() {
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay_ms: 0,
            max_delay_ms: 0,
        };
        let mut call_count = 0u32;
        let result: std::result::Result<i32, &str> = execute_with_retry(&policy, |_attempt| {
            call_count += 1;
            if call_count < 2 {
                Err("transient")
            } else {
                Ok(99)
            }
        });
        assert_eq!(result, Ok(99));
        assert_eq!(call_count, 2);
    }

    // A persistently-failing op must be tried exactly max_attempts times.
    #[test]
    fn test_execute_with_retry_exhausts_attempts() {
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay_ms: 0,
            max_delay_ms: 0,
        };
        let mut call_count = 0u32;
        let result: std::result::Result<i32, &str> = execute_with_retry(&policy, |_attempt| {
            call_count += 1;
            Err("always fails")
        });
        assert!(result.is_err());
        assert_eq!(call_count, 3);
    }
}