use crate::streaming::channel::DELTA_MIN_SIZE;
use crate::streaming::channel::SyncStats;
use crate::streaming::protocol::{
Data, DataEnd, DataFlags, Delete, DeleteEnd, DestFileEnd, DestFileEntry, DestFileFlags, FileEnd, FileEntry, MessageType, Mkdir, Symlink,
};
use crate::temp_file::TempFileGuard;
use anyhow::{Context, Result};
use bytes::{Buf, Bytes, BytesMut};
use std::collections::HashMap;
use std::path::{Component, Path, PathBuf};
use tokio::fs::{self, File, OpenOptions};
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, SeekFrom};
/// Largest payload, in bytes, accepted for a single delta copy or insert operation (16 MiB).
const MAX_DELTA_COPY_SIZE: usize = 16 << 20;
/// Encoded destination entries are buffered until at least this many bytes (64 KiB) are
/// pending, then handed to the `scan_dest` callback as one batch.
const DEST_ENTRY_BATCH_SIZE: usize = 64 << 10;
/// Resolves an untrusted, sender-supplied relative path against `root`.
///
/// The check is purely lexical (the filesystem is not consulted). On success the
/// returned path is `root.join(relative)`, un-normalized.
///
/// # Errors
///
/// Fails when `relative`
/// * is empty,
/// * is absolute, rooted, or carries a Windows prefix,
/// * contains a `..` component,
/// * names no entry at all (for example `"."` or `"./"`), which would resolve to the
///   root directory itself, or
/// * still ends up outside `root` after lexical normalization.
fn validate_path(root: &Path, relative: &str) -> Result<PathBuf> {
    if relative.is_empty() {
        anyhow::bail!("Empty path not allowed");
    }
    let rel_path = Path::new(relative);
    if rel_path.is_absolute() {
        anyhow::bail!("Absolute paths not allowed: {}", relative);
    }
    for component in rel_path.components() {
        match component {
            Component::ParentDir => {
                anyhow::bail!("Path traversal not allowed: {}", relative);
            }
            Component::Prefix(_) => {
                anyhow::bail!("Windows prefix paths not allowed: {}", relative);
            }
            // A rooted path without a prefix (Windows `\foo`) is not "absolute", yet
            // joining it would discard everything in `root` except its prefix.
            Component::RootDir => {
                anyhow::bail!("Absolute paths not allowed: {}", relative);
            }
            Component::CurDir | Component::Normal(_) => {}
        }
    }
    // A path made only of `.` components resolves to the root itself. Accepting it
    // would let a directory delete target the whole root and would make
    // sibling-path derivations (temp files) land outside the root.
    if !rel_path.components().any(|c| matches!(c, Component::Normal(_))) {
        anyhow::bail!("Path must name an entry inside the root directory: {}", relative);
    }
    let full = root.join(rel_path);
    let normalized = normalize_path(&full);
    let root_normalized = normalize_path(root);
    if !normalized.starts_with(&root_normalized) {
        anyhow::bail!("Path escapes root directory: {}", relative);
    }
    Ok(full)
}
/// Lexically normalizes `path` without touching the filesystem.
///
/// * `.` components are dropped.
/// * `..` removes the preceding normal component.
/// * `..` directly under the root stays at the root.
/// * `..` with nothing left to remove (a relative path that climbs above its starting
///   point) is preserved, so `../x` stays `../x` instead of collapsing to `x`. Dropping
///   it would make an escaping path look as if it were still inside a relative root.
///
/// Symlinks are not resolved.
fn normalize_path(path: &Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                let tail = normalized.components().next_back();
                let at_root = matches!(tail, Some(Component::RootDir));
                let can_pop = matches!(tail, Some(Component::Normal(_)));
                if can_pop {
                    normalized.pop();
                } else if !at_root {
                    // Nothing to cancel against: keep the `..` so the result still
                    // reflects that the path leaves its starting directory.
                    normalized.push(component);
                }
            }
            other => normalized.push(other),
        }
    }
    normalized
}
/// Checks that a symlink at `link_path` pointing to `target` stays inside `root`.
///
/// Absolute targets are always refused. A relative target is joined onto the link's
/// parent directory, normalized lexically and compared against the normalized root.
/// When `link_path` has no parent the containment check is skipped.
fn validate_symlink_target(root: &Path, link_path: &Path, target: &str) -> Result<()> {
    let target_path = Path::new(target);
    if target_path.is_absolute() {
        anyhow::bail!("Absolute symlink targets not allowed: {} -> {}", link_path.display(), target);
    }
    match link_path.parent() {
        None => Ok(()),
        Some(link_dir) => {
            let resolved = normalize_path(&link_dir.join(target_path));
            if resolved.starts_with(normalize_path(root)) {
                Ok(())
            } else {
                anyhow::bail!("Symlink target escapes root: {} -> {}", link_path.display(), target)
            }
        }
    }
}
/// Settings a [`Receiver`] is constructed with.
pub struct ReceiverConfig {
    /// Destination directory; every incoming path is validated against it and joined
    /// onto it.
    pub root: PathBuf,
    /// Block size handed to the checksum computation and reported in destination
    /// entries that carry checksums.
    pub block_size: u32,
}
/// Receiving side of a sync session: applies incoming protocol messages to the
/// destination tree and keeps running statistics.
pub struct Receiver {
    /// Root directory and block size this receiver operates with.
    config: ReceiverConfig,
    /// Files that have been announced but not yet finished, keyed by the path string
    /// carried in the protocol messages.
    pending_files: HashMap<String, PendingFile>,
    /// Counters updated as messages are handled.
    stats: SyncStats,
}
/// Per-file state kept between the announcement of a file and its end-of-data message.
///
/// Fields are declared (and therefore dropped) in this order on purpose: keep the file
/// handles ahead of `guard`.
struct PendingFile {
    /// The announcement this transfer belongs to; its mode and mtime are applied on
    /// completion.
    entry: FileEntry,
    /// Temporary file the incoming data is written to before being renamed into place.
    temp_path: PathBuf,
    /// Open handle on `temp_path`; taken out when the transfer is finalized.
    file: Option<File>,
    /// Existing destination file, opened on the first delta chunk and reused afterwards.
    original_file: Option<File>,
    /// Sum of the payload lengths of all data messages applied so far.
    bytes_written: u64,
    /// Guard created for `temp_path`; defused once the rename into place succeeded.
    /// Presumably it removes the temp file when dropped un-defused (see `TempFileGuard`).
    guard: Option<TempFileGuard>,
}
impl Receiver {
/// Creates a receiver for `config` with no pending files and fresh statistics.
pub fn new(config: ReceiverConfig) -> Self {
    let pending_files = HashMap::new();
    let stats = SyncStats::new();
    Self { config, pending_files, stats }
}
/// Scans the destination root and streams its contents to `on_entry`.
///
/// Each scanned entry with a non-empty relative path is encoded as a `DestFileEntry`.
/// Regular files of at least `DELTA_MIN_SIZE` bytes additionally carry block checksums.
/// Encoded entries are accumulated and flushed to `on_entry` once the buffer reaches
/// `DEST_ENTRY_BATCH_SIZE`; any remainder is flushed at the end, followed by an encoded
/// `DestFileEnd` summary.
///
/// Returns `(entries, bytes)`: the number of entries reported (directories included)
/// and the sum of their sizes.
///
/// # Errors
///
/// Propagates scanner, checksum and callback failures.
pub async fn scan_dest<F>(&self, mut on_entry: F) -> Result<(u64, u64)>
where
    F: FnMut(Bytes) -> Result<()>,
{
    // The scan itself is blocking work, so it runs off the async executor.
    let scanner = crate::sync::scanner::Scanner::new(&self.config.root);
    let scanned = tokio::task::spawn_blocking(move || scanner.scan()).await??;

    let mut entry_count = 0u64;
    let mut byte_count = 0u64;
    let mut outgoing = BytesMut::with_capacity(DEST_ENTRY_BATCH_SIZE);

    for entry in scanned {
        let rel_path = entry.relative_path.as_ref();
        let path = rel_path.to_string_lossy().into_owned();
        // An empty relative path is skipped rather than reported.
        if path.is_empty() {
            continue;
        }

        let is_dir = entry.is_dir;
        let wants_checksums = !is_dir && entry.size >= DELTA_MIN_SIZE;

        let mut flags = DestFileFlags::empty();
        if is_dir {
            flags |= DestFileFlags::DIR;
        }
        if wants_checksums {
            flags |= DestFileFlags::HAS_CHECKSUMS;
        }

        let (block_size, checksums) = if wants_checksums {
            (self.config.block_size, self.compute_checksums(&entry.path).await?)
        } else {
            (0, Vec::new())
        };

        // Timestamps before the epoch are reported as 0.
        let mtime = entry.modified.duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_secs() as i64;
        let mode = if is_dir { 0o755 } else { 0o644 };

        let encoded = DestFileEntry { path, size: entry.size, mtime, mode, flags, block_size, checksums }.encode();
        outgoing.extend_from_slice(&encoded);
        if outgoing.len() >= DEST_ENTRY_BATCH_SIZE {
            on_entry(outgoing.split().freeze())?;
        }

        entry_count += 1;
        byte_count += entry.size;
    }

    if !outgoing.is_empty() {
        on_entry(outgoing.freeze())?;
    }

    let summary = DestFileEnd { total_files: entry_count, total_bytes: byte_count };
    on_entry(summary.encode())?;
    Ok((entry_count, byte_count))
}
/// Computes the block checksums of the file at `path` using the configured block size.
///
/// The computation runs on the blocking thread pool; its results are converted
/// field by field into the protocol's `BlockChecksum` type.
async fn compute_checksums(&self, path: &Path) -> Result<Vec<crate::streaming::protocol::BlockChecksum>> {
    use crate::streaming::protocol::BlockChecksum;

    let owned_path = path.to_path_buf();
    let block_size = self.config.block_size as usize;
    let raw = tokio::task::spawn_blocking(move || {
        crate::delta::checksum::compute_checksums(&owned_path, block_size)
    })
    .await??;

    let mut converted = Vec::new();
    for sum in raw {
        converted.push(BlockChecksum { offset: sum.offset, weak: sum.weak, strong: sum.strong });
    }
    Ok(converted)
}
/// Decodes `payload` according to `msg_type` and dispatches it to the matching handler.
///
/// `FileEnd` and `DeleteEnd` are decoded (so malformed payloads are still rejected) but
/// otherwise have no effect. Any other message type is ignored.
///
/// # Errors
///
/// Returns decode errors and whatever the invoked handler reports.
pub async fn handle_message(&mut self, msg_type: MessageType, payload: Bytes) -> Result<()> {
    match msg_type {
        MessageType::FileEntry => self.handle_file_entry(FileEntry::decode(payload)?).await,
        MessageType::Data => self.handle_data(Data::decode(payload)?).await,
        MessageType::DataEnd => self.handle_data_end(DataEnd::decode(payload)?).await,
        MessageType::Mkdir => self.handle_mkdir(Mkdir::decode(payload)?).await,
        MessageType::Symlink => self.handle_symlink(Symlink::decode(payload)?).await,
        MessageType::Delete => self.handle_delete(Delete::decode(payload)?).await,
        MessageType::FileEnd => {
            let _decoded = FileEnd::decode(payload)?;
            Ok(())
        }
        MessageType::DeleteEnd => {
            let _decoded = DeleteEnd::decode(payload)?;
            Ok(())
        }
        _ => Ok(()),
    }
}
async fn handle_file_entry(&mut self, entry: FileEntry) -> Result<()> {
let full_path = validate_path(&self.config.root, &entry.path)?;
if let Some(parent) = full_path.parent() {
fs::create_dir_all(parent).await?;
}
let temp_path = full_path.with_extension("sy.tmp");
let guard = TempFileGuard::new(&temp_path);
let file = OpenOptions::new().write(true).create(true).truncate(true).open(&temp_path).await?;
self.pending_files.insert(
entry.path.clone(),
PendingFile {
entry,
temp_path,
file: Some(file),
original_file: None, bytes_written: 0,
guard: Some(guard),
},
);
Ok(())
}
/// Applies one data message to the pending file it names.
///
/// * With the `DELTA` flag the payload is a sequence of delta operations, applied
///   against the existing destination file. That file is opened on the first delta
///   chunk and kept open for later ones.
/// * Otherwise the payload is written verbatim at `data.offset` in the temp file.
///
/// The payload length is added to the file's `bytes_written` counter in both cases.
///
/// # Errors
///
/// Fails when no transfer is pending for the path, when the original file cannot be
/// opened, when the delta is malformed, or on I/O errors.
async fn handle_data(&mut self, data: Data) -> Result<()> {
    // `pending_files` and `config` are disjoint fields, so the root can be borrowed
    // directly further down; no per-message clone of the root path is needed.
    let pending = self
        .pending_files
        .get_mut(&data.path)
        .ok_or_else(|| anyhow::anyhow!("No pending file for {}", data.path))?;
    let file = match pending.file.as_mut() {
        Some(file) => file,
        // No open temp file: nothing to write to.
        None => return Ok(()),
    };

    if data.flags.contains(DataFlags::DELTA) {
        let original = match pending.original_file.take() {
            Some(already_open) => already_open,
            None => {
                let original_path = validate_path(&self.config.root, &data.path)?;
                File::open(&original_path).await.context("Failed to open original file for delta application")?
            }
        };
        // Store the handle back so later delta chunks reuse it.
        let original = pending.original_file.get_or_insert(original);
        Self::apply_delta_with_original(file, original, &data.data).await?;
    } else {
        file.seek(SeekFrom::Start(data.offset)).await?;
        file.write_all(&data.data).await?;
    }
    pending.bytes_written += data.data.len() as u64;
    Ok(())
}
/// Finishes the pending transfer named by `end`.
///
/// The temp file is flushed and synced. With `STATUS_OK` it is then renamed over the
/// destination, the announced mode (Unix only) and mtime are applied, and the success
/// counters are updated. With any other status only the error counter is bumped; the
/// pending state, including its temp-file guard, is dropped.
///
/// Applying the mode and the mtime is best-effort: failures are logged as warnings and
/// do not fail the transfer. An end message for a path with no pending transfer is a
/// no-op.
///
/// # Errors
///
/// Fails on flush/sync errors, when the path is rejected by validation, when the rename
/// fails, or when the blocking mtime task cannot be joined.
async fn handle_data_end(&mut self, end: DataEnd) -> Result<()> {
    let mut pending = match self.pending_files.remove(&end.path) {
        Some(pending) => pending,
        None => return Ok(()),
    };
    if let Some(mut file) = pending.file.take() {
        file.flush().await?;
        file.sync_all().await?;
    }
    let full_path = validate_path(&self.config.root, &end.path)?;
    if end.status != DataEnd::STATUS_OK {
        self.stats.files_err += 1;
        return Ok(());
    }

    // Close the handle on the old destination (opened for delta application) before
    // replacing that file; an open handle can get in the way of the rename on some
    // platforms.
    drop(pending.original_file.take());

    fs::rename(&pending.temp_path, &full_path).await?;
    if let Some(guard) = pending.guard.take() {
        guard.defuse();
    }
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let perms = std::fs::Permissions::from_mode(pending.entry.mode);
        if let Err(e) = fs::set_permissions(&full_path, perms).await {
            tracing::warn!("Failed to set permissions on {}: {}", full_path.display(), e);
        }
    }
    let mtime = filetime::FileTime::from_unix_time(pending.entry.mtime, 0);
    // The path moves into the blocking task; hand it back on failure so the warning
    // can name the file.
    let mtime_outcome = tokio::task::spawn_blocking(move || {
        filetime::set_file_mtime(&full_path, mtime).map_err(|e| (full_path, e))
    })
    .await?;
    if let Err((path, e)) = mtime_outcome {
        tracing::warn!("Failed to set mtime on {}: {}", path.display(), e);
    }
    self.stats.files_ok += 1;
    self.stats.bytes_transferred += pending.bytes_written;
    Ok(())
}
/// Creates the announced directory, including missing parents, and counts it.
///
/// On Unix the announced mode is applied afterwards. A failure to do so is only logged.
///
/// # Errors
///
/// Fails when the path is rejected by validation or the directory cannot be created.
async fn handle_mkdir(&mut self, mkdir: Mkdir) -> Result<()> {
    let dir_path = validate_path(&self.config.root, &mkdir.path)?;
    fs::create_dir_all(&dir_path).await?;
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let wanted = std::fs::Permissions::from_mode(mkdir.mode);
        match fs::set_permissions(&dir_path, wanted).await {
            Ok(()) => {}
            Err(e) => {
                tracing::warn!("Failed to set permissions on {}: {}", dir_path.display(), e);
            }
        }
    }
    self.stats.dirs_created += 1;
    Ok(())
}
/// Creates the announced symlink after validating both the link path and its target.
///
/// Whatever file currently sits at the link path is removed first on a best-effort
/// basis (the result of that removal is ignored). On Windows a file symlink is created
/// on the blocking thread pool.
///
/// # Errors
///
/// Fails when the link path or target is rejected, or when creating the link fails.
async fn handle_symlink(&mut self, symlink: Symlink) -> Result<()> {
    let link_path = validate_path(&self.config.root, &symlink.path)?;
    validate_symlink_target(&self.config.root, &link_path, &symlink.target)?;

    let _ = fs::remove_file(&link_path).await;

    #[cfg(unix)]
    {
        fs::symlink(&symlink.target, &link_path).await?;
    }
    #[cfg(windows)]
    {
        let target = symlink.target.clone();
        let path = link_path.clone();
        tokio::task::spawn_blocking(move || std::os::windows::fs::symlink_file(&target, &path)).await??;
    }

    self.stats.symlinks_created += 1;
    Ok(())
}
async fn handle_delete(&mut self, delete: Delete) -> Result<()> {
let full_path = validate_path(&self.config.root, &delete.path)?;
if delete.is_dir {
let _ = fs::remove_dir_all(&full_path).await;
} else {
let _ = fs::remove_file(&full_path).await;
}
self.stats.deleted += 1;
Ok(())
}
/// Applies a sequence of delta operations, appending the result to `file`.
///
/// `delta_data` is a concatenation of operations, each introduced by a one-byte tag.
/// Integers are read big-endian.
///
/// * `0x00` copy: `u64` offset + `u32` size. Copies that range from `original`.
/// * `0x01` insert: `u32` length followed by that many literal bytes.
///
/// Both sizes are capped at `MAX_DELTA_COPY_SIZE`, and a copy range must lie entirely
/// within the original file.
///
/// # Errors
///
/// Fails on truncated or oversized operations, out-of-range copies, unknown tags, and
/// I/O errors.
async fn apply_delta_with_original(file: &mut File, original: &mut File, delta_data: &[u8]) -> Result<()> {
    let file_size = original.metadata().await?.len();
    // Scratch buffer reused across copy operations.
    let mut copy_buf = Vec::new();
    let mut reader = delta_data;
    while reader.has_remaining() {
        let op_type = reader.get_u8();
        match op_type {
            0x00 => {
                if reader.remaining() < 12 {
                    anyhow::bail!("Delta copy op truncated");
                }
                let offset = reader.get_u64();
                let size = reader.get_u32() as usize;
                if size > MAX_DELTA_COPY_SIZE {
                    anyhow::bail!("Delta copy size {} exceeds max {}", size, MAX_DELTA_COPY_SIZE);
                }
                if offset > file_size {
                    anyhow::bail!("Delta copy offset {} exceeds file size {}", offset, file_size);
                }
                let range_end = offset.saturating_add(size as u64);
                if range_end > file_size {
                    anyhow::bail!("Delta copy range {}..{} exceeds file size {}", offset, range_end, file_size);
                }
                copy_buf.resize(size, 0);
                original.seek(SeekFrom::Start(offset)).await?;
                original.read_exact(&mut copy_buf).await?;
                file.write_all(&copy_buf).await?;
            }
            0x01 => {
                if reader.remaining() < 4 {
                    anyhow::bail!("Delta insert op truncated");
                }
                let len = reader.get_u32() as usize;
                if len > MAX_DELTA_COPY_SIZE {
                    anyhow::bail!("Delta insert size {} exceeds max {}", len, MAX_DELTA_COPY_SIZE);
                }
                if reader.remaining() < len {
                    anyhow::bail!("Delta insert data truncated");
                }
                // The literal bytes are already in memory: write them straight from
                // the delta instead of copying through the scratch buffer.
                let (literal, rest) = reader.split_at(len);
                file.write_all(literal).await?;
                reader = rest;
            }
            _ => anyhow::bail!("Unknown delta op type: {}", op_type),
        }
    }
    Ok(())
}
pub fn stats(&self) -> &SyncStats {
&self.stats
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Number of leading bytes dropped from every encoded message before it is handed
    /// to `handle_message` (presumably the frame header written by `encode`).
    const FRAME_HEADER_LEN: usize = 5;

    /// Announces a file, sends its content in one plain data message, finishes it and
    /// checks that the content landed at the destination.
    #[tokio::test]
    async fn test_receiver_basic() {
        let dest = TempDir::new().unwrap();
        let mut receiver = Receiver::new(ReceiverConfig { root: dest.path().to_path_buf(), block_size: 4096 });

        let announce = FileEntry {
            path: "test.txt".to_string(),
            size: 11,
            mtime: 1234567890,
            mode: 0o644,
            inode: 0,
            flags: crate::streaming::protocol::FileFlags::empty(),
            symlink_target: None,
            link_target: None,
        };
        let announce_payload = announce.encode().slice(FRAME_HEADER_LEN..);
        receiver.handle_message(MessageType::FileEntry, announce_payload).await.unwrap();

        let chunk = Data {
            path: "test.txt".to_string(),
            offset: 0,
            flags: crate::streaming::protocol::DataFlags::empty(),
            data: Bytes::from("hello world"),
        };
        let chunk_payload = chunk.encode().slice(FRAME_HEADER_LEN..);
        receiver.handle_message(MessageType::Data, chunk_payload).await.unwrap();

        let finish = DataEnd { path: "test.txt".to_string(), status: DataEnd::STATUS_OK };
        let finish_payload = finish.encode().slice(FRAME_HEADER_LEN..);
        receiver.handle_message(MessageType::DataEnd, finish_payload).await.unwrap();

        let written = std::fs::read_to_string(dest.path().join("test.txt")).unwrap();
        assert_eq!(written, "hello world");
    }
}