use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use crate::error::DownloadError;
/// Magic number that starts every control file: ASCII `BYHE` read as a
/// big-endian `u32`, stored on disk in little-endian byte order.
const MAGIC: u32 = 0x4259_4845;
/// Current on-disk format version; `load_sync` rejects any other value.
const VERSION: u32 = 1;
/// Fixed header length: four `u32` words (magic, version, payload length, CRC32).
const HEADER_SIZE: usize = 4 * std::mem::size_of::<u32>();
/// Resumable-download state persisted to the `.bytehaul` control file.
///
/// Encoded with bincode (see `save_sync`/`load_sync`), which serializes struct
/// fields positionally: reordering, inserting or removing fields changes the
/// on-disk format and must be accompanied by a `VERSION` bump.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ControlSnapshot {
    /// URL of the download this snapshot belongs to.
    pub url: String,
    /// Total size of the resource in bytes (presumably as reported by the server — confirm).
    pub total_size: u64,
    /// Size of each piece in bytes.
    pub piece_size: u64,
    /// Number of pieces the download is split into.
    pub piece_count: usize,
    /// Packed completion flags for the pieces. NOTE(review): bit order/width
    /// per piece is defined by the writer elsewhere — not visible here.
    pub completed_bitset: Vec<u8>,
    /// Bytes downloaded so far at the time the snapshot was taken.
    pub downloaded_bytes: u64,
    /// HTTP `ETag` validator, if one was received.
    pub etag: Option<String>,
    /// HTTP `Last-Modified` validator, if one was received.
    pub last_modified: Option<String>,
}
impl ControlSnapshot {
    /// Derives the control-file path for `output_path` by appending
    /// `.bytehaul` to the full file name (`a.bin` -> `a.bin.bytehaul`).
    pub fn control_path(output_path: &Path) -> PathBuf {
        let mut name = output_path.to_path_buf().into_os_string();
        name.push(".bytehaul");
        name.into()
    }

    /// Persists this snapshot to `control_path`, running the blocking
    /// serialization and file I/O (`save_sync`) on tokio's blocking pool.
    ///
    /// # Errors
    /// Any error from `save_sync`, or `DownloadError::TaskFailed` if the
    /// blocking task could not be joined.
    pub async fn save(&self, control_path: &Path) -> Result<(), DownloadError> {
        tracing::debug!(path = %control_path.display(), downloaded_bytes = self.downloaded_bytes, "saving control file");
        // `spawn_blocking` needs a `'static` closure, so hand it owned copies.
        let owned = self.clone();
        let target = control_path.to_path_buf();
        Self::run_blocking(move || save_sync(&owned, &target)).await
    }

    /// Reads and validates the control file at `control_path` (`load_sync`)
    /// on tokio's blocking pool.
    ///
    /// # Errors
    /// Any error from `load_sync`, or `DownloadError::TaskFailed` if the
    /// blocking task could not be joined.
    pub async fn load(control_path: &Path) -> Result<Self, DownloadError> {
        tracing::debug!(path = %control_path.display(), "loading control file");
        let source = control_path.to_path_buf();
        Self::run_blocking(move || load_sync(&source)).await
    }

    /// Removes the control file; a file that is already gone counts as success.
    ///
    /// # Errors
    /// `DownloadError::Io` for any removal failure other than `NotFound`.
    pub async fn delete(control_path: &Path) -> Result<(), DownloadError> {
        tracing::debug!(path = %control_path.display(), "deleting control file");
        tokio::fs::remove_file(control_path).await.or_else(|err| {
            if err.kind() == std::io::ErrorKind::NotFound {
                Ok(())
            } else {
                Err(DownloadError::Io(err))
            }
        })
    }

    /// Runs `job` on tokio's blocking pool, turning a join failure into
    /// `DownloadError::TaskFailed` and otherwise returning the job's own result.
    async fn run_blocking<T, F>(job: F) -> Result<T, DownloadError>
    where
        F: FnOnce() -> Result<T, DownloadError> + Send + 'static,
        T: Send + 'static,
    {
        match tokio::task::spawn_blocking(job).await {
            Ok(outcome) => outcome,
            Err(join_err) => Err(DownloadError::TaskFailed(format!("spawn_blocking: {join_err}"))),
        }
    }
}
fn save_sync(snapshot: &ControlSnapshot, path: &Path) -> Result<(), DownloadError> {
use std::io::Write;
let payload = bincode::serialize(snapshot)
.map_err(|e| DownloadError::Internal(format!("control file serialize failed: {e}")))?;
let checksum = crc32fast::hash(&payload);
let tmp_path = path.with_extension("bytehaul.tmp");
let mut file = std::fs::File::create(&tmp_path)?;
file.write_all(&MAGIC.to_le_bytes())?;
file.write_all(&VERSION.to_le_bytes())?;
file.write_all(&(payload.len() as u32).to_le_bytes())?;
file.write_all(&checksum.to_le_bytes())?;
file.write_all(&payload)?;
file.flush()?;
file.sync_all()?;
drop(file);
std::fs::rename(&tmp_path, path)?;
Ok(())
}
/// Synchronously reads and validates the control file at `path`.
///
/// Validation runs in this order:
/// 1. the file holds at least `HEADER_SIZE` bytes;
/// 2. the magic number matches;
/// 3. the format version matches;
/// 4. the declared payload fits inside the file;
/// 5. the payload's CRC32 matches.
///
/// The payload is then bincode-decoded. Bytes after the declared payload are
/// ignored.
///
/// # Errors
/// - I/O errors (converted via `From<std::io::Error>`) if the file cannot be read.
/// - `DownloadError::ControlFileCorrupted` for any failed check or decode failure.
fn load_sync(path: &Path) -> Result<ControlSnapshot, DownloadError> {
    let data = std::fs::read(path)?;

    let header = data
        .get(..HEADER_SIZE)
        .ok_or_else(|| DownloadError::ControlFileCorrupted("too short".into()))?;
    // The header is four consecutive little-endian u32 words.
    let word = |index: usize| {
        let at = index * 4;
        u32::from_le_bytes([header[at], header[at + 1], header[at + 2], header[at + 3]])
    };
    let magic = word(0);
    let version = word(1);
    // u32 -> usize is lossless on all 32/64-bit targets.
    let payload_len = word(2) as usize;
    let checksum = word(3);

    if magic != MAGIC {
        return Err(DownloadError::ControlFileCorrupted("bad magic".into()));
    }
    if version != VERSION {
        return Err(DownloadError::ControlFileCorrupted(format!(
            "unsupported version: {version}"
        )));
    }

    // checked_add: a corrupted/hostile length must not overflow `usize` on
    // 32-bit targets; `get` turns any out-of-range end into an error, not a panic.
    let payload = HEADER_SIZE
        .checked_add(payload_len)
        .and_then(|end| data.get(HEADER_SIZE..end))
        .ok_or_else(|| DownloadError::ControlFileCorrupted("truncated payload".into()))?;

    if crc32fast::hash(payload) != checksum {
        return Err(DownloadError::ControlFileCorrupted(
            "checksum mismatch".into(),
        ));
    }

    bincode::deserialize(payload)
        .map_err(|e| DownloadError::ControlFileCorrupted(format!("deserialize: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field must survive a save/load roundtrip unchanged.
    #[tokio::test]
    async fn test_save_load_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let ctrl_path = dir.path().join("test.bytehaul");
        let snapshot = ControlSnapshot {
            url: "https://example.com/file.bin".to_string(),
            total_size: 1_000_000,
            piece_size: 1_000_000,
            piece_count: 1,
            completed_bitset: vec![1],
            downloaded_bytes: 500_000,
            etag: Some("\"abc123\"".to_string()),
            last_modified: Some("Thu, 01 Jan 2026 00:00:00 GMT".to_string()),
        };
        snapshot.save(&ctrl_path).await.unwrap();
        let loaded = ControlSnapshot::load(&ctrl_path).await.unwrap();
        assert_eq!(loaded.url, snapshot.url);
        assert_eq!(loaded.total_size, snapshot.total_size);
        assert_eq!(loaded.piece_size, snapshot.piece_size);
        assert_eq!(loaded.piece_count, snapshot.piece_count);
        assert_eq!(loaded.completed_bitset, snapshot.completed_bitset);
        assert_eq!(loaded.downloaded_bytes, snapshot.downloaded_bytes);
        assert_eq!(loaded.etag, snapshot.etag);
        assert_eq!(loaded.last_modified, snapshot.last_modified);
    }

    /// A second save must atomically replace the first.
    #[tokio::test]
    async fn test_save_overwrites_existing() {
        let dir = tempfile::tempdir().unwrap();
        let ctrl_path = dir.path().join("overwrite.bytehaul");
        let mut snapshot = ControlSnapshot {
            url: "https://example.com/file.bin".to_string(),
            total_size: 100,
            piece_size: 100,
            piece_count: 1,
            completed_bitset: vec![0],
            downloaded_bytes: 10,
            etag: None,
            last_modified: None,
        };
        snapshot.save(&ctrl_path).await.unwrap();
        snapshot.downloaded_bytes = 20;
        snapshot.save(&ctrl_path).await.unwrap();
        let loaded = ControlSnapshot::load(&ctrl_path).await.unwrap();
        assert_eq!(loaded.downloaded_bytes, 20);
    }

    /// 17 bytes of text pass the length check and fail on the magic number.
    #[tokio::test]
    async fn test_load_corrupted() {
        let dir = tempfile::tempdir().unwrap();
        let ctrl_path = dir.path().join("bad.bytehaul");
        std::fs::write(&ctrl_path, b"garbage data here").unwrap();
        let err = ControlSnapshot::load(&ctrl_path).await.unwrap_err();
        assert!(err.to_string().contains("bad magic"));
    }

    #[tokio::test]
    async fn test_delete_nonexistent() {
        let dir = tempfile::tempdir().unwrap();
        let ctrl_path = dir.path().join("nope.bytehaul");
        ControlSnapshot::delete(&ctrl_path).await.unwrap();
    }

    /// `control_path` is a pure function; no async runtime needed.
    #[test]
    fn test_control_path_derivation() {
        let path = std::path::PathBuf::from("/tmp/myfile.bin");
        let ctrl = ControlSnapshot::control_path(&path);
        assert_eq!(ctrl, std::path::PathBuf::from("/tmp/myfile.bin.bytehaul"));
    }

    #[tokio::test]
    async fn test_load_bad_magic() {
        let dir = tempfile::tempdir().unwrap();
        let ctrl_path = dir.path().join("bad_magic.bytehaul");
        let mut data = vec![0u8; 20];
        data[0..4].copy_from_slice(&0xDEADBEEFu32.to_le_bytes());
        data[4..8].copy_from_slice(&1u32.to_le_bytes());
        data[8..12].copy_from_slice(&0u32.to_le_bytes());
        data[12..16].copy_from_slice(&0u32.to_le_bytes());
        std::fs::write(&ctrl_path, &data).unwrap();
        let result = ControlSnapshot::load(&ctrl_path).await;
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("bad magic"));
    }

    #[tokio::test]
    async fn test_load_bad_version() {
        let dir = tempfile::tempdir().unwrap();
        let ctrl_path = dir.path().join("bad_ver.bytehaul");
        let mut data = vec![0u8; 20];
        data[0..4].copy_from_slice(&MAGIC.to_le_bytes());
        data[4..8].copy_from_slice(&99u32.to_le_bytes());
        data[8..12].copy_from_slice(&0u32.to_le_bytes());
        data[12..16].copy_from_slice(&0u32.to_le_bytes());
        std::fs::write(&ctrl_path, &data).unwrap();
        let result = ControlSnapshot::load(&ctrl_path).await;
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("unsupported version"));
    }

    #[tokio::test]
    async fn test_load_truncated_payload() {
        let dir = tempfile::tempdir().unwrap();
        let ctrl_path = dir.path().join("truncated.bytehaul");
        let mut data = vec![0u8; 16];
        data[0..4].copy_from_slice(&MAGIC.to_le_bytes());
        data[4..8].copy_from_slice(&VERSION.to_le_bytes());
        data[8..12].copy_from_slice(&1000u32.to_le_bytes());
        data[12..16].copy_from_slice(&0u32.to_le_bytes());
        std::fs::write(&ctrl_path, &data).unwrap();
        let result = ControlSnapshot::load(&ctrl_path).await;
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("truncated"));
    }

    /// Flipping one payload byte must be caught by the CRC32 check
    /// (CRC32 detects every single-byte error).
    #[tokio::test]
    async fn test_load_checksum_mismatch() {
        let dir = tempfile::tempdir().unwrap();
        let ctrl_path = dir.path().join("bad_crc.bytehaul");
        let snapshot = ControlSnapshot {
            url: "https://example.com/file.bin".to_string(),
            total_size: 100,
            piece_size: 100,
            piece_count: 1,
            completed_bitset: vec![0],
            downloaded_bytes: 0,
            etag: None,
            last_modified: None,
        };
        snapshot.save(&ctrl_path).await.unwrap();
        let mut data = std::fs::read(&ctrl_path).unwrap();
        assert!(data.len() > HEADER_SIZE + 1, "payload unexpectedly short");
        data[HEADER_SIZE + 1] ^= 0xFF;
        std::fs::write(&ctrl_path, &data).unwrap();
        let err = ControlSnapshot::load(&ctrl_path).await.unwrap_err();
        assert!(err.to_string().contains("checksum mismatch"));
    }

    #[tokio::test]
    async fn test_delete_existing_file() {
        let dir = tempfile::tempdir().unwrap();
        let ctrl_path = dir.path().join("to_delete.bytehaul");
        std::fs::write(&ctrl_path, b"some data").unwrap();
        assert!(ctrl_path.exists());
        ControlSnapshot::delete(&ctrl_path).await.unwrap();
        assert!(!ctrl_path.exists());
    }

    #[tokio::test]
    async fn test_load_too_short() {
        let dir = tempfile::tempdir().unwrap();
        let ctrl_path = dir.path().join("short.bytehaul");
        std::fs::write(&ctrl_path, b"short").unwrap();
        let result = ControlSnapshot::load(&ctrl_path).await;
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("too short"));
    }

    /// Errors other than `NotFound` (here: target is a directory) must surface.
    #[tokio::test]
    async fn test_delete_directory_returns_io_error() {
        let dir = tempfile::tempdir().unwrap();
        let err = ControlSnapshot::delete(dir.path()).await.unwrap_err();
        assert!(matches!(err, DownloadError::Io(_)));
    }
}