1use serde::{Deserialize, Serialize};
25use sha2::{Digest, Sha256};
26use std::io;
27use std::time::Duration;
28use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
29use tracing::debug;
30
31#[derive(Clone, Debug, Serialize, Deserialize)]
33pub struct StreamingTransferConfig {
34 pub chunk_size: usize,
36 pub checkpoint_interval: usize,
38 pub verify_digest: bool,
40 #[serde(with = "duration_secs")]
42 pub read_timeout: Duration,
43 #[serde(with = "duration_secs")]
45 pub write_timeout: Duration,
46}
47
48mod duration_secs {
50 use serde::{Deserialize, Deserializer, Serializer};
51 use std::time::Duration;
52
53 pub fn serialize<S: Serializer>(d: &Duration, s: S) -> Result<S::Ok, S::Error> {
54 s.serialize_u64(d.as_secs())
55 }
56
57 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
58 let secs = u64::deserialize(d)?;
59 Ok(Duration::from_secs(secs))
60 }
61}
62
63impl Default for StreamingTransferConfig {
64 fn default() -> Self {
65 Self::tactical()
66 }
67}
68
69impl StreamingTransferConfig {
70 pub fn datacenter() -> Self {
72 Self {
73 chunk_size: 32 * 1024 * 1024, checkpoint_interval: 128,
75 verify_digest: true,
76 read_timeout: Duration::from_secs(60),
77 write_timeout: Duration::from_secs(60),
78 }
79 }
80
81 pub fn tactical() -> Self {
83 Self {
84 chunk_size: 8 * 1024 * 1024, checkpoint_interval: 64,
86 verify_digest: true,
87 read_timeout: Duration::from_secs(120),
88 write_timeout: Duration::from_secs(120),
89 }
90 }
91
92 pub fn edge() -> Self {
94 Self {
95 chunk_size: 1024 * 1024, checkpoint_interval: 32,
97 verify_digest: true,
98 read_timeout: Duration::from_secs(300),
99 write_timeout: Duration::from_secs(300),
100 }
101 }
102
103 pub fn custom(chunk_size: usize, checkpoint_interval: usize) -> Self {
105 Self {
106 chunk_size,
107 checkpoint_interval,
108 verify_digest: true,
109 read_timeout: Duration::from_secs(120),
110 write_timeout: Duration::from_secs(120),
111 }
112 }
113
114 pub fn checkpoint_bytes(&self) -> u64 {
116 self.chunk_size as u64 * self.checkpoint_interval as u64
117 }
118}
119
120#[derive(Clone, Debug, Serialize, Deserialize)]
124pub struct TransferCheckpoint {
125 pub session_id: String,
127 pub digest: String,
129 pub total_size: u64,
131 pub offset: u64,
133 pub chunks_completed: u64,
135 pub partial_sha256: Vec<u8>,
139 pub upload_session_url: Option<String>,
141}
142
143impl TransferCheckpoint {
144 pub fn new(session_id: &str, digest: &str, total_size: u64) -> Self {
146 Self {
147 session_id: session_id.to_string(),
148 digest: digest.to_string(),
149 total_size,
150 offset: 0,
151 chunks_completed: 0,
152 partial_sha256: Vec::new(),
153 upload_session_url: None,
154 }
155 }
156
157 pub fn is_complete(&self) -> bool {
159 self.offset >= self.total_size
160 }
161
162 pub fn progress(&self) -> f64 {
164 if self.total_size == 0 {
165 return 1.0;
166 }
167 self.offset as f64 / self.total_size as f64
168 }
169
170 pub fn remaining(&self) -> u64 {
172 self.total_size.saturating_sub(self.offset)
173 }
174}
175
/// Summary returned by a successful [`stream_transfer`] call.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TransferResult {
    /// Bytes written to the target during this call. Excludes any prefix
    /// skipped on resume.
    pub bytes_transferred: u64,
    /// Final stream offset: the skipped prefix plus bytes written. This is the
    /// observed stream length, not the checkpoint's expected `total_size`.
    pub total_size: u64,
    /// `sha256:<hex>` digest of the entire source stream, including any
    /// skipped prefix.
    pub computed_digest: String,
    /// Whether the call started from a non-zero checkpoint offset.
    pub resumed: bool,
    /// How many times the checkpoint callback ran successfully.
    pub checkpoints_saved: u64,
}
190
191pub type CheckpointCallback = Box<dyn FnMut(&TransferCheckpoint) -> io::Result<()> + Send>;
196
197pub async fn stream_transfer<R, W>(
215 mut source: R,
216 mut target: W,
217 config: &StreamingTransferConfig,
218 checkpoint: &mut TransferCheckpoint,
219 mut on_checkpoint: Option<CheckpointCallback>,
220) -> io::Result<TransferResult>
221where
222 R: AsyncRead + Unpin,
223 W: AsyncWrite + Unpin,
224{
225 let resumed = checkpoint.offset > 0;
226 let initial_offset = checkpoint.offset;
227 let mut hasher = Sha256::new();
228 let mut buf = vec![0u8; config.chunk_size];
229 let mut checkpoints_saved: u64 = 0;
230
231 if resumed {
234 let mut skip_remaining = checkpoint.offset;
235 while skip_remaining > 0 {
236 let to_read = (skip_remaining as usize).min(buf.len());
237 let n = source.read(&mut buf[..to_read]).await?;
238 if n == 0 {
239 return Err(io::Error::new(
240 io::ErrorKind::UnexpectedEof,
241 format!(
242 "source ended at {} while skipping to offset {}",
243 checkpoint.total_size - skip_remaining,
244 checkpoint.offset
245 ),
246 ));
247 }
248 hasher.update(&buf[..n]);
250 skip_remaining -= n as u64;
251 }
252 debug!(
253 session_id = %checkpoint.session_id,
254 offset = checkpoint.offset,
255 "resumed transfer, skipped to offset"
256 );
257 }
258
259 loop {
261 let n = tokio::time::timeout(config.read_timeout, source.read(&mut buf))
262 .await
263 .map_err(|_| {
264 io::Error::new(
265 io::ErrorKind::TimedOut,
266 format!(
267 "read timed out after {:?} at offset {}",
268 config.read_timeout, checkpoint.offset
269 ),
270 )
271 })?
272 .map_err(|e| {
273 io::Error::new(
274 e.kind(),
275 format!("read failed at offset {}: {e}", checkpoint.offset),
276 )
277 })?;
278 if n == 0 {
279 break; }
281
282 hasher.update(&buf[..n]);
284
285 tokio::time::timeout(config.write_timeout, target.write_all(&buf[..n]))
287 .await
288 .map_err(|_| {
289 io::Error::new(
290 io::ErrorKind::TimedOut,
291 format!(
292 "write timed out after {:?} at offset {}",
293 config.write_timeout, checkpoint.offset
294 ),
295 )
296 })?
297 .map_err(|e| {
298 io::Error::new(
299 e.kind(),
300 format!("write failed at offset {}: {e}", checkpoint.offset),
301 )
302 })?;
303
304 checkpoint.offset += n as u64;
306 checkpoint.chunks_completed += 1;
307
308 if checkpoint
310 .chunks_completed
311 .is_multiple_of(config.checkpoint_interval as u64)
312 {
313 if let Some(ref mut cb) = on_checkpoint {
314 cb(checkpoint)?;
315 checkpoints_saved += 1;
316 debug!(
317 session_id = %checkpoint.session_id,
318 offset = checkpoint.offset,
319 progress = format!("{:.1}%", checkpoint.progress() * 100.0),
320 "checkpoint saved"
321 );
322 }
323 }
324 }
325
326 target.flush().await?;
327
328 let hash = hasher.finalize();
330 let computed_digest = format!("sha256:{}", hex::encode(hash));
331
332 if config.verify_digest && !checkpoint.digest.is_empty() && computed_digest != checkpoint.digest
334 {
335 return Err(io::Error::new(
336 io::ErrorKind::InvalidData,
337 format!(
338 "digest mismatch: expected {}, computed {}",
339 checkpoint.digest, computed_digest
340 ),
341 ));
342 }
343
344 let bytes_transferred = checkpoint.offset - initial_offset;
345
346 Ok(TransferResult {
347 bytes_transferred,
348 total_size: checkpoint.offset,
349 computed_digest,
350 resumed,
351 checkpoints_saved,
352 })
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::sync::{Arc, Mutex};

    /// `sha256:<hex>` digest of `data`, in the format `stream_transfer` reports.
    fn sha256_digest(data: &[u8]) -> String {
        let mut hasher = Sha256::new();
        hasher.update(data);
        format!("sha256:{}", hex::encode(hasher.finalize()))
    }

    #[test]
    fn test_config_profiles() {
        let dc = StreamingTransferConfig::datacenter();
        assert_eq!(dc.chunk_size, 32 * 1024 * 1024);
        assert_eq!(dc.checkpoint_interval, 128);
        assert_eq!(dc.checkpoint_bytes(), 32 * 1024 * 1024 * 128);

        let tac = StreamingTransferConfig::tactical();
        assert_eq!(tac.chunk_size, 8 * 1024 * 1024);
        assert_eq!(tac.checkpoint_interval, 64);

        let edge = StreamingTransferConfig::edge();
        assert_eq!(edge.chunk_size, 1024 * 1024);
        assert_eq!(edge.checkpoint_interval, 32);
    }

    #[test]
    fn test_config_default_is_tactical() {
        let default = StreamingTransferConfig::default();
        let tac = StreamingTransferConfig::tactical();
        assert_eq!(default.chunk_size, tac.chunk_size);
        assert_eq!(default.checkpoint_interval, tac.checkpoint_interval);
        assert_eq!(default.read_timeout, tac.read_timeout);
        assert_eq!(default.write_timeout, tac.write_timeout);
    }

    #[test]
    fn test_config_custom() {
        let c = StreamingTransferConfig::custom(4096, 10);
        assert_eq!(c.chunk_size, 4096);
        assert_eq!(c.checkpoint_interval, 10);
        assert_eq!(c.checkpoint_bytes(), 40960);
        assert!(c.verify_digest);
        assert_eq!(c.read_timeout, Duration::from_secs(120));
        assert_eq!(c.write_timeout, Duration::from_secs(120));
    }

    #[test]
    fn test_config_serde_roundtrip() {
        let json = serde_json::to_string(&StreamingTransferConfig::edge()).unwrap();
        let back: StreamingTransferConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(back.chunk_size, 1024 * 1024);
        assert_eq!(back.checkpoint_interval, 32);
        assert!(back.verify_digest);
        assert_eq!(back.read_timeout, Duration::from_secs(300));
        assert_eq!(back.write_timeout, Duration::from_secs(300));
    }

    #[test]
    fn test_checkpoint_new() {
        let cp = TransferCheckpoint::new("sess-1", "sha256:abc", 1000);
        assert_eq!(cp.session_id, "sess-1");
        assert_eq!(cp.digest, "sha256:abc");
        assert_eq!(cp.offset, 0);
        assert_eq!(cp.chunks_completed, 0);
        assert!(!cp.is_complete());
        assert_eq!(cp.remaining(), 1000);
        assert!(cp.progress().abs() < f64::EPSILON);
    }

    #[test]
    fn test_checkpoint_progress() {
        let mut cp = TransferCheckpoint::new("sess-1", "sha256:abc", 1000);
        cp.offset = 500;
        assert!((cp.progress() - 0.5).abs() < f64::EPSILON);
        assert_eq!(cp.remaining(), 500);
        assert!(!cp.is_complete());

        cp.offset = 1000;
        assert!(cp.is_complete());
        assert_eq!(cp.remaining(), 0);
    }

    #[test]
    fn test_checkpoint_zero_size() {
        let cp = TransferCheckpoint::new("sess-1", "", 0);
        assert!(cp.is_complete());
        assert!((cp.progress() - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_checkpoint_serde_roundtrip() {
        let mut cp = TransferCheckpoint::new("sess-1", "sha256:abc", 5000);
        cp.offset = 2048;
        cp.chunks_completed = 4;
        cp.upload_session_url = Some("https://registry.example.com/upload/123".to_string());

        let json = serde_json::to_string(&cp).unwrap();
        let deserialized: TransferCheckpoint = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.session_id, "sess-1");
        assert_eq!(deserialized.digest, "sha256:abc");
        assert_eq!(deserialized.total_size, 5000);
        assert_eq!(deserialized.offset, 2048);
        assert_eq!(deserialized.chunks_completed, 4);
        assert_eq!(
            deserialized.upload_session_url.as_deref(),
            Some("https://registry.example.com/upload/123")
        );
    }

    #[tokio::test]
    async fn test_stream_transfer_small_blob() {
        let data = b"hello world, this is a test blob";
        let source = Cursor::new(data.to_vec());
        let mut target = Vec::new();

        let config = StreamingTransferConfig::custom(16, 2);
        let mut checkpoint = TransferCheckpoint::new("test-1", "", data.len() as u64);

        let result = stream_transfer(source, &mut target, &config, &mut checkpoint, None)
            .await
            .unwrap();

        assert_eq!(target, data);
        assert_eq!(result.bytes_transferred, data.len() as u64);
        assert_eq!(result.total_size, data.len() as u64);
        assert!(!result.resumed);
        assert_eq!(result.computed_digest, sha256_digest(data));
    }

    #[tokio::test]
    async fn test_stream_transfer_with_checkpoints() {
        let data = vec![0xABu8; 1024];
        let source = Cursor::new(data.clone());
        let mut target = Vec::new();

        // 100-byte chunks over 1024 bytes = 11 reads; interval 2 fires after
        // reads 2, 4, 6, 8 and 10.
        let config = StreamingTransferConfig::custom(100, 2);
        let mut checkpoint = TransferCheckpoint::new("test-2", "", data.len() as u64);

        let seen_offsets = Arc::new(Mutex::new(Vec::new()));
        let recorder = Arc::clone(&seen_offsets);
        let on_checkpoint: CheckpointCallback = Box::new(move |cp: &TransferCheckpoint| {
            recorder.lock().unwrap().push(cp.offset);
            Ok(())
        });

        let result = stream_transfer(
            source,
            &mut target,
            &config,
            &mut checkpoint,
            Some(on_checkpoint),
        )
        .await
        .unwrap();

        assert_eq!(target, data);
        assert_eq!(result.checkpoints_saved, 5);
        assert_eq!(checkpoint.chunks_completed, 11);
        assert_eq!(*seen_offsets.lock().unwrap(), vec![200, 400, 600, 800, 1000]);
    }

    #[tokio::test]
    async fn test_stream_transfer_digest_verification() {
        let data = b"test data for digest verification";
        let expected = sha256_digest(data);

        let source = Cursor::new(data.to_vec());
        let mut target = Vec::new();
        let config = StreamingTransferConfig::custom(1024, 1);
        let mut checkpoint = TransferCheckpoint::new("test-3", &expected, data.len() as u64);

        let result = stream_transfer(source, &mut target, &config, &mut checkpoint, None)
            .await
            .unwrap();
        assert_eq!(result.computed_digest, expected);
    }

    #[tokio::test]
    async fn test_stream_transfer_digest_mismatch() {
        let data = b"test data";
        let source = Cursor::new(data.to_vec());
        let mut target = Vec::new();
        let config = StreamingTransferConfig::custom(1024, 1);
        let mut checkpoint = TransferCheckpoint::new("test-4", "sha256:wrong", data.len() as u64);

        let result = stream_transfer(source, &mut target, &config, &mut checkpoint, None).await;
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("digest mismatch"));
    }

    #[tokio::test]
    async fn test_stream_transfer_resume() {
        let data: Vec<u8> = (0..200u8).collect();
        let source = Cursor::new(data.clone());
        let mut target = Vec::new();

        let config = StreamingTransferConfig::custom(50, 1);
        let mut checkpoint = TransferCheckpoint::new("test-5", "", data.len() as u64);
        checkpoint.offset = 50;

        let result = stream_transfer(source, &mut target, &config, &mut checkpoint, None)
            .await
            .unwrap();

        assert!(result.resumed);
        assert_eq!(result.bytes_transferred, 150);
        assert_eq!(result.total_size, 200);
        // Only the suffix after the resume offset is written again.
        assert_eq!(target, &data[50..]);
    }

    #[tokio::test]
    async fn test_stream_transfer_resume_verifies_full_digest() {
        let data: Vec<u8> = (0..200u8).collect();
        let expected = sha256_digest(&data);
        let source = Cursor::new(data.clone());
        let mut target = Vec::new();

        let config = StreamingTransferConfig::custom(64, 1);
        let mut checkpoint = TransferCheckpoint::new("test-5b", &expected, data.len() as u64);
        checkpoint.offset = 100;

        let result = stream_transfer(source, &mut target, &config, &mut checkpoint, None)
            .await
            .unwrap();

        // The skipped prefix is hashed, so the digest covers the whole blob.
        assert_eq!(result.computed_digest, expected);
        assert_eq!(target, &data[100..]);
    }

    #[tokio::test]
    async fn test_stream_transfer_resume_past_end_of_source() {
        let source = Cursor::new(vec![0u8; 100]);
        let mut target = Vec::new();
        let config = StreamingTransferConfig::custom(64, 1);
        let mut checkpoint = TransferCheckpoint::new("test-5c", "", 200);
        checkpoint.offset = 150;

        let err = stream_transfer(source, &mut target, &config, &mut checkpoint, None)
            .await
            .unwrap_err();

        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        assert!(target.is_empty());
    }

    #[tokio::test]
    async fn test_stream_transfer_empty() {
        let source = Cursor::new(Vec::<u8>::new());
        let mut target = Vec::new();
        let config = StreamingTransferConfig::custom(1024, 1);
        let mut checkpoint = TransferCheckpoint::new("test-6", "", 0);

        let result = stream_transfer(source, &mut target, &config, &mut checkpoint, None)
            .await
            .unwrap();

        assert!(target.is_empty());
        assert_eq!(result.bytes_transferred, 0);
        assert!(checkpoint.is_complete());
    }

    #[tokio::test]
    async fn test_stream_transfer_exact_chunk_boundary() {
        let data = vec![0xEFu8; 300];
        let source = Cursor::new(data.clone());
        let mut target = Vec::new();
        let config = StreamingTransferConfig::custom(100, 1);
        let mut checkpoint = TransferCheckpoint::new("test-7", "", 300);

        let result = stream_transfer(source, &mut target, &config, &mut checkpoint, None)
            .await
            .unwrap();

        assert_eq!(target, data);
        assert_eq!(result.bytes_transferred, 300);
        assert_eq!(checkpoint.chunks_completed, 3);
    }

    #[tokio::test]
    async fn test_stream_transfer_checkpoint_callback_error() {
        let data = vec![0u8; 500];
        let source = Cursor::new(data);
        let mut target = Vec::new();
        let config = StreamingTransferConfig::custom(50, 2);
        let mut checkpoint = TransferCheckpoint::new("test-8", "", 500);

        let on_checkpoint: CheckpointCallback =
            Box::new(|_cp| Err(io::Error::other("checkpoint store full")));

        let err = stream_transfer(
            source,
            &mut target,
            &config,
            &mut checkpoint,
            Some(on_checkpoint),
        )
        .await
        .unwrap_err();

        assert_eq!(err.kind(), io::ErrorKind::Other);
        assert!(err.to_string().contains("checkpoint store full"));
        // The callback fires after the second 50-byte chunk and aborts the copy.
        assert_eq!(target.len(), 100);
        assert_eq!(checkpoint.offset, 100);
    }

    #[test]
    fn test_transfer_result_fields() {
        let result = TransferResult {
            bytes_transferred: 5000,
            total_size: 10000,
            computed_digest: "sha256:abc".to_string(),
            resumed: true,
            checkpoints_saved: 3,
        };
        assert_eq!(result.bytes_transferred, 5000);
        assert!(result.resumed);
    }
}