#![allow(dead_code)]
use anyhow::Result;
use bytes::Bytes;
use std::path::{Path, PathBuf};
use tokio::fs;
use tokio::io::{self, AsyncWriteExt};
use tokio::sync::mpsc;
use crate::streaming::{
Generator, GeneratorConfig, Receiver, ReceiverConfig, Sender, SenderConfig,
channel::file_job_channel,
protocol::{self as v2, HelloFlags, MessageType},
};
/// Expands a leading `~` path component to the current user's home directory.
///
/// * `~` alone becomes the home directory, or `.` if it cannot be determined.
/// * `~/rest` becomes `<home>/rest`; it is returned unchanged if there is no home.
/// * Anything else (including `~user/...` and `~` in a later component) is
///   returned unchanged.
///
/// The prefix is matched on path components rather than on a lossy UTF-8
/// rendering, so non-UTF-8 file names are preserved byte-for-byte and a path
/// such as `~//x` cannot collapse into the absolute path `/x`.
fn expand_tilde(path: &Path) -> PathBuf {
    if path.as_os_str() == "~" {
        return dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
    }
    match path.strip_prefix("~") {
        // `rest` is relative by construction (the `~` component was removed),
        // so `join` appends rather than replacing `home`.
        Ok(rest) => dirs::home_dir().map_or_else(|| path.to_path_buf(), |home| home.join(rest)),
        Err(_) => path.to_path_buf(),
    }
}
/// Runs the server side of a streaming session over stdin/stdout.
///
/// The sync root is the last command-line argument after the program name,
/// with a leading `~` expanded. If no argument is given, `.` is used.
///
/// The first frame must be `Hello`. Otherwise a `Fatal` frame (code 1) is
/// written and the function returns `Ok(())` without doing anything else.
///
/// After a valid `Hello`:
/// 1. The root directory is created if it does not exist.
/// 2. A `Hello` with empty flags is sent back.
/// 3. The session is dispatched to the pull handler if the client set
///    `HelloFlags::PULL`, otherwise to the push handler.
///
/// # Errors
/// Propagates frame read/write, decoding and directory-creation errors, and
/// any error returned by the pull/push handler.
pub async fn run_server() -> Result<()> {
    // `skip(1)`: argv[0] is the program path, so it must not be picked as the
    // root when no path argument was passed; then the `.` fallback applies.
    // `args_os`: `std::env::args` panics on non-UTF-8 arguments.
    let raw_path = std::env::args_os()
        .skip(1)
        .last()
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("."));
    let root_path = expand_tilde(&raw_path);

    let mut stdin = io::stdin();
    let mut stdout = io::stdout();

    let (msg_type, payload) = v2::read_frame(&mut stdin).await?;
    if msg_type != MessageType::Hello {
        let fatal = v2::Fatal { code: 1, message: format!("Expected HELLO, got {:?}", msg_type) };
        v2::write_frame(&mut stdout, &fatal.encode()).await?;
        stdout.flush().await?;
        return Ok(());
    }
    let hello = v2::Hello::decode(payload)?;

    // Create the root only for a session that passed the handshake, using
    // non-blocking tokio I/O. A failed `metadata` lookup mirrors the
    // `!Path::exists()` check: an existing non-directory is left alone.
    if fs::metadata(&root_path).await.is_err() {
        fs::create_dir_all(&root_path).await?;
    }

    let resp = v2::Hello::new(HelloFlags::empty(), "");
    v2::write_frame(&mut stdout, &resp.encode()).await?;
    stdout.flush().await?;

    if hello.flags.contains(HelloFlags::PULL) {
        run_server_pull(hello, root_path, stdin, stdout).await
    } else {
        run_server_push(hello, root_path, stdin, stdout).await
    }
}
/// Serves a pull session, in which the client receives files from `root_path`.
///
/// Flow as implemented here:
/// 1. `DestFileEntry` frames are read from `stdin` until `DestFileEnd`, and
///    each decoded entry is handed to the [`Generator`]. Any other frame type
///    aborts with an error.
/// 2. The generator and a [`Sender`] run as two spawned tasks, linked by a
///    file-job channel.
/// 3. The sender's encoded output travels through an unbounded channel and is
///    written to `stdout` frame by frame on the current task.
/// 4. Both tasks are joined (generator first), and a `Done` frame carrying the
///    generator's `(files, bytes)` totals is written.
///
/// `files_err` and `duration_ms` are always reported as 0.
async fn run_server_pull(hello: v2::Hello, root_path: PathBuf, mut stdin: impl io::AsyncRead + Unpin, mut stdout: impl io::AsyncWrite + Unpin) -> Result<()> {
    let delete_enabled = hello.flags.contains(HelloFlags::DELETE);
    let compress = hello.flags.contains(HelloFlags::COMPRESSION);

    let mut generator = Generator::new(GeneratorConfig {
        root: root_path.clone(),
        include_hidden: true,
        follow_symlinks: false,
        delete_enabled,
    });

    // Initial exchange: the client lists what its destination already holds.
    loop {
        let (kind, frame) = v2::read_frame(&mut stdin).await?;
        if kind == MessageType::DestFileEnd {
            break;
        }
        if kind != MessageType::DestFileEntry {
            anyhow::bail!("Unexpected message during Initial Exchange: {:?}", kind);
        }
        generator.add_dest_entry(v2::DestFileEntry::decode(frame)?);
    }

    let (job_tx, job_rx) = file_job_channel();
    let generator_task = tokio::spawn(async move { generator.run(job_tx).await });

    let sender = Sender::new(SenderConfig { root: root_path, compress });
    let (frame_tx, mut frame_rx) = mpsc::unbounded_channel::<Bytes>();
    let sender_task = tokio::spawn(async move {
        sender
            .run(job_rx, |bytes| frame_tx.send(bytes).map_err(|_| anyhow::anyhow!("Data channel closed")))
            .await
    });

    // Forward the sender's output. `recv` yields `None` once the sender task
    // has finished and its future (owning `frame_tx`) has been dropped.
    while let Some(chunk) = frame_rx.recv().await {
        v2::write_frame(&mut stdout, &chunk).await?;
    }
    stdout.flush().await?;

    let (total_files, total_bytes) = generator_task.await??;
    sender_task.await??;

    let summary = v2::Done { files_ok: total_files, files_err: 0, bytes: total_bytes, duration_ms: 0 };
    v2::write_frame(&mut stdout, &summary.encode()).await?;
    stdout.flush().await?;
    Ok(())
}
/// Serves a push session, in which the client sends files to be written under `root_path`.
///
/// Flow as implemented here:
/// 1. A spawned task scans the destination with its own [`Receiver`]. It
///    streams the encoded result through an unbounded channel, which this
///    task writes to `stdout` as frames.
/// 2. Frames read from `stdin` are fed to the main `receiver` until the client
///    sends `Done`.
/// 3. A `Done` frame built from the receiver's stats is written back.
///
/// The handshake (`_hello`) is not consulted here, and `duration_ms` is
/// always reported as 0.
async fn run_server_push(_hello: v2::Hello, root_path: PathBuf, mut stdin: impl io::AsyncRead + Unpin, mut stdout: impl io::AsyncWrite + Unpin) -> Result<()> {
    // Both receivers are configured with the same block size.
    let block_size = 4096;
    let mut receiver = Receiver::new(ReceiverConfig { root: root_path.clone(), block_size });

    let (frame_tx, mut frame_rx) = mpsc::unbounded_channel::<Bytes>();
    let scan_root = root_path;
    let scan_task = tokio::spawn(async move {
        let scanner = Receiver::new(ReceiverConfig { root: scan_root, block_size });
        scanner
            .scan_dest(|bytes| frame_tx.send(bytes).map_err(|_| anyhow::anyhow!("Data channel closed")))
            .await
    });

    // Relay the destination scan. The channel closes when the scan task ends.
    while let Some(chunk) = frame_rx.recv().await {
        v2::write_frame(&mut stdout, &chunk).await?;
    }
    stdout.flush().await?;
    scan_task.await??;

    // Apply the client's stream until it signals completion.
    loop {
        match v2::read_frame(&mut stdin).await? {
            (MessageType::Done, _) => break,
            (kind, frame) => {
                receiver.handle_message(kind, frame).await?;
            }
        }
    }

    let stats = receiver.stats();
    let summary = v2::Done {
        files_ok: stats.files_ok,
        files_err: stats.files_err,
        bytes: stats.bytes_transferred,
        duration_ms: 0,
    };
    v2::write_frame(&mut stdout, &summary.encode()).await?;
    stdout.flush().await?;
    Ok(())
}