use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::io;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::debug;
/// Tunable parameters for a resumable, checkpointed stream transfer.
///
/// Both timeouts are (de)serialized as whole seconds via the
/// `duration_secs` helper module.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingTransferConfig {
    /// Size of the read/write buffer, in bytes.
    pub chunk_size: usize,
    /// Number of completed chunks between checkpoint callbacks.
    pub checkpoint_interval: usize,
    /// When true, the computed SHA-256 is compared against the checkpoint's
    /// expected digest (if that digest is non-empty).
    pub verify_digest: bool,
    /// Timeout applied to each individual read; serialized as integer seconds.
    #[serde(with = "duration_secs")]
    pub read_timeout: Duration,
    /// Timeout applied to each individual write; serialized as integer seconds.
    #[serde(with = "duration_secs")]
    pub write_timeout: Duration,
}
/// Serde adapter that (de)serializes a `Duration` as a plain integer
/// number of whole seconds (sub-second precision is dropped).
mod duration_secs {
    use serde::{Deserialize, Deserializer, Serializer};
    use std::time::Duration;

    /// Emits the duration's whole-second count as a `u64`.
    pub fn serialize<S: Serializer>(value: &Duration, ser: S) -> Result<S::Ok, S::Error> {
        ser.serialize_u64(value.as_secs())
    }

    /// Reads a `u64` second count back into a `Duration`.
    pub fn deserialize<'de, D: Deserializer<'de>>(de: D) -> Result<Duration, D::Error> {
        u64::deserialize(de).map(Duration::from_secs)
    }
}
impl Default for StreamingTransferConfig {
fn default() -> Self {
Self::tactical()
}
}
impl StreamingTransferConfig {
    /// Shared constructor: digest verification enabled, symmetric
    /// read/write timeouts.
    fn profile(chunk_size: usize, checkpoint_interval: usize, timeout_secs: u64) -> Self {
        Self {
            chunk_size,
            checkpoint_interval,
            verify_digest: true,
            read_timeout: Duration::from_secs(timeout_secs),
            write_timeout: Duration::from_secs(timeout_secs),
        }
    }

    /// Profile for fast, reliable links: 32 MiB chunks, a checkpoint every
    /// 128 chunks, 60 s timeouts.
    pub fn datacenter() -> Self {
        Self::profile(32 * 1024 * 1024, 128, 60)
    }

    /// Mid-tier profile (also the `Default`): 8 MiB chunks, a checkpoint
    /// every 64 chunks, 120 s timeouts.
    pub fn tactical() -> Self {
        Self::profile(8 * 1024 * 1024, 64, 120)
    }

    /// Profile for slow or constrained links: 1 MiB chunks, a checkpoint
    /// every 32 chunks, 300 s timeouts.
    pub fn edge() -> Self {
        Self::profile(1024 * 1024, 32, 300)
    }

    /// Caller-supplied chunking; timeouts match the tactical profile (120 s).
    pub fn custom(chunk_size: usize, checkpoint_interval: usize) -> Self {
        Self::profile(chunk_size, checkpoint_interval, 120)
    }

    /// Number of bytes covered between two consecutive checkpoints.
    pub fn checkpoint_bytes(&self) -> u64 {
        self.chunk_size as u64 * self.checkpoint_interval as u64
    }
}
/// Serializable resume state for a single transfer session.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TransferCheckpoint {
    /// Caller-chosen identifier for this transfer session.
    pub session_id: String,
    /// Expected digest of the full stream (`sha256:<hex>`); empty when unknown.
    pub digest: String,
    /// Expected total stream length in bytes.
    pub total_size: u64,
    /// Bytes transferred so far.
    pub offset: u64,
    /// Chunks transferred so far.
    pub chunks_completed: u64,
    // NOTE(review): not read or written by `stream_transfer` in this file —
    // presumably serialized intermediate hash state; confirm its producer.
    pub partial_sha256: Vec<u8>,
    // NOTE(review): unused in this file — presumably a resumable-upload
    // session URL for registry targets; verify against callers.
    pub upload_session_url: Option<String>,
}
impl TransferCheckpoint {
    /// Fresh checkpoint at offset zero for a brand-new transfer session.
    pub fn new(session_id: &str, digest: &str, total_size: u64) -> Self {
        Self {
            session_id: session_id.to_owned(),
            digest: digest.to_owned(),
            total_size,
            offset: 0,
            chunks_completed: 0,
            partial_sha256: Vec::new(),
            upload_session_url: None,
        }
    }

    /// True once the offset has reached (or passed) the expected size.
    pub fn is_complete(&self) -> bool {
        self.remaining() == 0
    }

    /// Fraction complete in `[0, 1]`; a zero-byte transfer counts as done.
    pub fn progress(&self) -> f64 {
        match self.total_size {
            0 => 1.0,
            total => self.offset as f64 / total as f64,
        }
    }

    /// Bytes still to transfer; saturates at zero rather than underflowing.
    pub fn remaining(&self) -> u64 {
        self.total_size.saturating_sub(self.offset)
    }
}
/// Summary of a finished transfer returned by `stream_transfer`.
#[derive(Clone, Debug)]
pub struct TransferResult {
    /// Bytes actually written during this call (excludes any resumed prefix).
    pub bytes_transferred: u64,
    /// Total bytes observed across the whole stream, including the resumed
    /// prefix (set from the final checkpoint offset, not `total_size`).
    pub total_size: u64,
    /// `sha256:<hex>` digest of the entire stream.
    pub computed_digest: String,
    /// True when the transfer started from a non-zero checkpoint offset.
    pub resumed: bool,
    /// Number of times the checkpoint callback returned successfully.
    pub checkpoints_saved: u64,
}
/// Callback invoked every `checkpoint_interval` chunks with the current
/// checkpoint; returning an error aborts the transfer.
pub type CheckpointCallback = Box<dyn FnMut(&TransferCheckpoint) -> io::Result<()> + Send>;
pub async fn stream_transfer<R, W>(
mut source: R,
mut target: W,
config: &StreamingTransferConfig,
checkpoint: &mut TransferCheckpoint,
mut on_checkpoint: Option<CheckpointCallback>,
) -> io::Result<TransferResult>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let resumed = checkpoint.offset > 0;
let initial_offset = checkpoint.offset;
let mut hasher = Sha256::new();
let mut buf = vec![0u8; config.chunk_size];
let mut checkpoints_saved: u64 = 0;
if resumed {
let mut skip_remaining = checkpoint.offset;
while skip_remaining > 0 {
let to_read = (skip_remaining as usize).min(buf.len());
let n = source.read(&mut buf[..to_read]).await?;
if n == 0 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
format!(
"source ended at {} while skipping to offset {}",
checkpoint.total_size - skip_remaining,
checkpoint.offset
),
));
}
hasher.update(&buf[..n]);
skip_remaining -= n as u64;
}
debug!(
session_id = %checkpoint.session_id,
offset = checkpoint.offset,
"resumed transfer, skipped to offset"
);
}
loop {
let n = tokio::time::timeout(config.read_timeout, source.read(&mut buf))
.await
.map_err(|_| {
io::Error::new(
io::ErrorKind::TimedOut,
format!(
"read timed out after {:?} at offset {}",
config.read_timeout, checkpoint.offset
),
)
})?
.map_err(|e| {
io::Error::new(
e.kind(),
format!("read failed at offset {}: {e}", checkpoint.offset),
)
})?;
if n == 0 {
break; }
hasher.update(&buf[..n]);
tokio::time::timeout(config.write_timeout, target.write_all(&buf[..n]))
.await
.map_err(|_| {
io::Error::new(
io::ErrorKind::TimedOut,
format!(
"write timed out after {:?} at offset {}",
config.write_timeout, checkpoint.offset
),
)
})?
.map_err(|e| {
io::Error::new(
e.kind(),
format!("write failed at offset {}: {e}", checkpoint.offset),
)
})?;
checkpoint.offset += n as u64;
checkpoint.chunks_completed += 1;
if checkpoint
.chunks_completed
.is_multiple_of(config.checkpoint_interval as u64)
{
if let Some(ref mut cb) = on_checkpoint {
cb(checkpoint)?;
checkpoints_saved += 1;
debug!(
session_id = %checkpoint.session_id,
offset = checkpoint.offset,
progress = format!("{:.1}%", checkpoint.progress() * 100.0),
"checkpoint saved"
);
}
}
}
target.flush().await?;
let hash = hasher.finalize();
let computed_digest = format!("sha256:{}", hex::encode(hash));
if config.verify_digest && !checkpoint.digest.is_empty() && computed_digest != checkpoint.digest
{
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"digest mismatch: expected {}, computed {}",
checkpoint.digest, computed_digest
),
));
}
let bytes_transferred = checkpoint.offset - initial_offset;
Ok(TransferResult {
bytes_transferred,
total_size: checkpoint.offset,
computed_digest,
resumed,
checkpoints_saved,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_config_profiles() {
        let dc = StreamingTransferConfig::datacenter();
        assert_eq!(dc.chunk_size, 32 * 1024 * 1024);
        assert_eq!(dc.checkpoint_interval, 128);
        assert_eq!(dc.checkpoint_bytes(), 32 * 1024 * 1024 * 128);
        let tac = StreamingTransferConfig::tactical();
        assert_eq!(tac.chunk_size, 8 * 1024 * 1024);
        let edge = StreamingTransferConfig::edge();
        assert_eq!(edge.chunk_size, 1024 * 1024);
        assert_eq!(edge.checkpoint_interval, 32);
    }

    #[test]
    fn test_config_custom() {
        let c = StreamingTransferConfig::custom(4096, 10);
        assert_eq!(c.chunk_size, 4096);
        assert_eq!(c.checkpoint_interval, 10);
        assert_eq!(c.checkpoint_bytes(), 40960);
    }

    #[test]
    fn test_checkpoint_new() {
        let cp = TransferCheckpoint::new("sess-1", "sha256:abc", 1000);
        assert_eq!(cp.session_id, "sess-1");
        assert_eq!(cp.offset, 0);
        assert!(!cp.is_complete());
        assert_eq!(cp.remaining(), 1000);
        assert!((cp.progress() - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_checkpoint_progress() {
        let mut cp = TransferCheckpoint::new("sess-1", "sha256:abc", 1000);
        cp.offset = 500;
        assert!((cp.progress() - 0.5).abs() < f64::EPSILON);
        assert_eq!(cp.remaining(), 500);
        assert!(!cp.is_complete());
        cp.offset = 1000;
        assert!(cp.is_complete());
        assert_eq!(cp.remaining(), 0);
    }

    #[test]
    fn test_checkpoint_zero_size() {
        // A zero-byte transfer is complete from the start.
        let cp = TransferCheckpoint::new("sess-1", "", 0);
        assert!(cp.is_complete());
        assert!((cp.progress() - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_checkpoint_serde_roundtrip() {
        let mut cp = TransferCheckpoint::new("sess-1", "sha256:abc", 5000);
        cp.offset = 2048;
        cp.chunks_completed = 4;
        cp.upload_session_url = Some("https://registry.example.com/upload/123".to_string());
        let json = serde_json::to_string(&cp).unwrap();
        let deserialized: TransferCheckpoint = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.session_id, "sess-1");
        assert_eq!(deserialized.offset, 2048);
        assert_eq!(deserialized.chunks_completed, 4);
        assert!(deserialized.upload_session_url.is_some());
    }

    #[tokio::test]
    async fn test_stream_transfer_small_blob() {
        let data = b"hello world, this is a test blob";
        let source = Cursor::new(data.to_vec());
        let mut target = Vec::new();
        let config = StreamingTransferConfig::custom(16, 2);
        let mut checkpoint = TransferCheckpoint::new("test-1", "", data.len() as u64);
        let result = stream_transfer(source, &mut target, &config, &mut checkpoint, None)
            .await
            .unwrap();
        assert_eq!(target, data);
        assert_eq!(result.bytes_transferred, data.len() as u64);
        assert_eq!(result.total_size, data.len() as u64);
        assert!(!result.resumed);
        assert!(result.computed_digest.starts_with("sha256:"));
    }

    #[tokio::test]
    async fn test_stream_transfer_with_checkpoints() {
        // 1024 bytes in 100-byte chunks = 11 chunks; interval 2 => 5 callbacks.
        let data = vec![0xABu8; 1024];
        let source = Cursor::new(data.clone());
        let mut target = Vec::new();
        let config = StreamingTransferConfig::custom(100, 2);
        let mut checkpoint = TransferCheckpoint::new("test-2", "", data.len() as u64);
        let on_checkpoint: CheckpointCallback = Box::new(|_cp| Ok(()));
        let result = stream_transfer(
            source,
            &mut target,
            &config,
            &mut checkpoint,
            Some(on_checkpoint),
        )
        .await
        .unwrap();
        assert_eq!(target, data);
        assert!(result.checkpoints_saved > 0);
        assert!(result.checkpoints_saved >= 4);
    }

    #[tokio::test]
    async fn test_stream_transfer_digest_verification() {
        let data = b"test data for digest verification";
        let mut hasher = Sha256::new();
        hasher.update(data);
        let expected = format!("sha256:{}", hex::encode(hasher.finalize()));
        let source = Cursor::new(data.to_vec());
        let mut target = Vec::new();
        let config = StreamingTransferConfig::custom(1024, 1);
        let mut checkpoint = TransferCheckpoint::new("test-3", &expected, data.len() as u64);
        let result = stream_transfer(source, &mut target, &config, &mut checkpoint, None)
            .await
            .unwrap();
        assert_eq!(result.computed_digest, expected);
    }

    #[tokio::test]
    async fn test_stream_transfer_digest_mismatch() {
        let data = b"test data";
        let source = Cursor::new(data.to_vec());
        let mut target = Vec::new();
        let config = StreamingTransferConfig::custom(1024, 1);
        let mut checkpoint = TransferCheckpoint::new("test-4", "sha256:wrong", data.len() as u64);
        let result = stream_transfer(source, &mut target, &config, &mut checkpoint, None).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("digest mismatch"));
    }

    #[tokio::test]
    async fn test_stream_transfer_resume() {
        // Resuming from offset 50 hashes all 200 bytes but writes only 150.
        let data = vec![0xCDu8; 200];
        let source = Cursor::new(data.clone());
        let mut target = Vec::new();
        let config = StreamingTransferConfig::custom(50, 1);
        let mut checkpoint = TransferCheckpoint::new("test-5", "", data.len() as u64);
        checkpoint.offset = 50;
        let result = stream_transfer(source, &mut target, &config, &mut checkpoint, None)
            .await
            .unwrap();
        assert!(result.resumed);
        assert_eq!(result.bytes_transferred, 150);
        assert_eq!(target.len(), 150);
        assert_eq!(result.total_size, 200);
    }

    #[tokio::test]
    async fn test_stream_transfer_empty() {
        let source = Cursor::new(Vec::<u8>::new());
        let mut target = Vec::new();
        let config = StreamingTransferConfig::custom(1024, 1);
        let mut checkpoint = TransferCheckpoint::new("test-6", "", 0);
        let result = stream_transfer(source, &mut target, &config, &mut checkpoint, None)
            .await
            .unwrap();
        assert!(target.is_empty());
        assert_eq!(result.bytes_transferred, 0);
        assert!(checkpoint.is_complete());
    }

    #[tokio::test]
    async fn test_stream_transfer_exact_chunk_boundary() {
        // Data length is an exact multiple of the chunk size.
        let data = vec![0xEFu8; 300];
        let source = Cursor::new(data.clone());
        let mut target = Vec::new();
        let config = StreamingTransferConfig::custom(100, 1);
        let mut checkpoint = TransferCheckpoint::new("test-7", "", 300);
        let result = stream_transfer(source, &mut target, &config, &mut checkpoint, None)
            .await
            .unwrap();
        assert_eq!(target, data);
        assert_eq!(result.bytes_transferred, 300);
    }

    #[tokio::test]
    async fn test_stream_transfer_checkpoint_callback_error() {
        let data = vec![0u8; 500];
        let source = Cursor::new(data);
        let mut target = Vec::new();
        let config = StreamingTransferConfig::custom(50, 2);
        let mut checkpoint = TransferCheckpoint::new("test-8", "", 500);
        // A failing callback must abort the transfer.
        // (io::Error::other replaces the verbose ErrorKind::Other form.)
        let on_checkpoint: CheckpointCallback =
            Box::new(|_cp| Err(io::Error::other("checkpoint store full")));
        let result = stream_transfer(
            source,
            &mut target,
            &config,
            &mut checkpoint,
            Some(on_checkpoint),
        )
        .await;
        assert!(result.is_err());
    }

    #[test]
    fn test_transfer_result_fields() {
        let result = TransferResult {
            bytes_transferred: 5000,
            total_size: 10000,
            computed_digest: "sha256:abc".to_string(),
            resumed: true,
            checkpoints_saved: 3,
        };
        assert_eq!(result.bytes_transferred, 5000);
        assert!(result.resumed);
    }
}