use crate::streaming::{
Generator, GeneratorConfig, Receiver, ReceiverConfig, Sender, SenderConfig,
channel::{SyncStats, file_job_channel},
protocol::{Done, Hello, HelloFlags, MessageType, read_frame, write_frame},
};
use anyhow::Result;
use bytes::Bytes;
use std::path::PathBuf;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc;
/// Parameters for a single streaming sync session between a local tree and a peer.
///
/// The same value drives both directions: see [`StreamingSync::push`] and
/// [`StreamingSync::pull`].
pub struct StreamingSync {
    /// Local directory: the generator/sender root when pushing, the receiver root when pulling.
    pub local_root: PathBuf,
    /// Path on the peer; sent (lossily converted to UTF-8) in the `Hello` frame.
    pub remote_root: PathBuf,
    /// Passed to the generator on push and encoded as `HelloFlags::DELETE` on pull.
    pub delete_enabled: bool,
    /// Passed to the sender on push and encoded as `HelloFlags::COMPRESSION` on pull.
    pub compress: bool,
}
impl StreamingSync {
pub fn new(local_root: PathBuf, remote_root: PathBuf, delete_enabled: bool, compress: bool) -> Self {
Self { local_root, remote_root, delete_enabled, compress }
}
pub async fn push<R, W>(&self, reader: &mut R, writer: &mut W) -> Result<SyncStats>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let hello = Hello::new(HelloFlags::empty(), self.remote_root.to_string_lossy().into_owned());
write_frame(writer, &hello.encode()).await?;
writer.flush().await?;
let (msg_type, payload) = read_frame(reader).await?;
if msg_type != MessageType::Hello {
anyhow::bail!("Expected Hello response, got {:?}", msg_type);
}
let _server_hello = Hello::decode(payload)?;
let mut generator = Generator::new(GeneratorConfig {
root: self.local_root.clone(),
include_hidden: true,
follow_symlinks: false,
delete_enabled: self.delete_enabled,
});
loop {
let (msg_type, payload) = read_frame(reader).await?;
match msg_type {
MessageType::DestFileEntry => {
let entry = crate::streaming::protocol::DestFileEntry::decode(payload)?;
generator.add_dest_entry(entry);
}
MessageType::DestFileEnd => {
break;
}
MessageType::Fatal => {
let fatal = crate::streaming::protocol::Fatal::decode(payload)?;
anyhow::bail!("Remote fatal error: {}", fatal.message);
}
_ => {
anyhow::bail!("Unexpected message during Initial Exchange: {:?}", msg_type);
}
}
}
let (tx, rx) = file_job_channel();
let gen_handle = tokio::spawn(async move { generator.run(tx).await });
let sender = Sender::new(SenderConfig { root: self.local_root.clone(), compress: self.compress });
let (data_tx, mut data_rx) = mpsc::unbounded_channel::<Bytes>();
let sender_handle = tokio::spawn(async move { sender.run(rx, |bytes| data_tx.send(bytes).map_err(|_| anyhow::anyhow!("Data channel closed"))).await });
while let Some(bytes) = data_rx.recv().await {
writer.write_all(&bytes).await?;
}
let client_done = Done { files_ok: 0, files_err: 0, bytes: 0, duration_ms: 0 };
write_frame(writer, &client_done.encode()).await?;
writer.flush().await?;
let (total_files, total_bytes) = gen_handle.await??;
sender_handle.await??;
let (msg_type, payload) = read_frame(reader).await?;
if msg_type == MessageType::Done {
let done = Done::decode(payload)?;
Ok(SyncStats { files_ok: done.files_ok, files_err: done.files_err, bytes_transferred: done.bytes, ..Default::default() })
} else {
Ok(SyncStats { files_ok: total_files, bytes_transferred: total_bytes, ..Default::default() })
}
}
pub async fn pull<R, W>(&self, reader: &mut R, writer: &mut W) -> Result<SyncStats>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let mut flags = HelloFlags::PULL;
if self.delete_enabled {
flags |= HelloFlags::DELETE;
}
if self.compress {
flags |= HelloFlags::COMPRESSION;
}
let hello = Hello::new(flags, self.remote_root.to_string_lossy().into_owned());
write_frame(writer, &hello.encode()).await?;
writer.flush().await?;
let (msg_type, payload) = read_frame(reader).await?;
if msg_type != MessageType::Hello {
anyhow::bail!("Expected Hello response, got {:?}", msg_type);
}
let _server_hello = Hello::decode(payload)?;
if !self.local_root.exists() {
tokio::fs::create_dir_all(&self.local_root).await?;
}
let (data_tx, mut data_rx) = mpsc::unbounded_channel::<Bytes>();
let receiver_root = self.local_root.clone();
let scan_handle = tokio::spawn(async move {
let receiver = Receiver::new(ReceiverConfig { root: receiver_root, block_size: 4096 });
receiver.scan_dest(|bytes| data_tx.send(bytes).map_err(|_| anyhow::anyhow!("Data channel closed"))).await
});
while let Some(bytes) = data_rx.recv().await {
writer.write_all(&bytes).await?;
}
writer.flush().await?;
scan_handle.await??;
let mut receiver = Receiver::new(ReceiverConfig { root: self.local_root.clone(), block_size: 4096 });
loop {
let (msg_type, payload) = read_frame(reader).await?;
if msg_type == MessageType::Done {
let done = Done::decode(payload)?;
let mut stats = receiver.stats().clone();
stats.files_ok = done.files_ok;
stats.files_err = done.files_err;
stats.bytes_transferred = done.bytes;
return Ok(stats);
}
receiver.handle_message(msg_type, payload).await?;
}
}
}