use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::time::Duration;
use compio::io::AsyncRead;
use crate::{
EngineError,
discovery::BroadcasterGuard,
network,
protocol::{Metadata, TRANSFER_DIR, TRANSFER_FILE},
transfer,
};
/// Builder-style configuration for the sending side of a transfer.
///
/// A sender works in one of two modes, chosen by which of `target` / `code` is set.
#[derive(Debug, Clone)]
pub struct HayateSender {
    /// Peer address to connect to directly; `None` when pairing by code.
    target: Option<SocketAddr>,
    /// Pairing phrase used for discovery and passed to the handshake; `None` in direct mode.
    code: Option<String>,
    /// Compression flag forwarded to the payload writer.
    compress: bool,
    /// Name of the hash algorithm advertised in the transfer metadata.
    hash_algo: String,
}

impl Default for HayateSender {
    /// No target, no code, compression enabled, `"blake3"` as the hash algorithm.
    fn default() -> Self {
        let hash_algo = String::from("blake3");
        Self {
            hash_algo,
            compress: true,
            code: None,
            target: None,
        }
    }
}
impl HayateSender {
    /// Creates a sender with the default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Selects direct mode: [`send`](Self::send) will connect to `target`.
    ///
    /// Any previously configured pairing code is cleared, so the two modes are
    /// mutually exclusive.
    #[must_use]
    pub fn target(mut self, target: SocketAddr) -> Self {
        self.target = Some(target);
        self.code = None;
        self
    }

    /// Selects pairing mode: [`send`](Self::send) will advertise itself for `code`
    /// and wait for the peer to connect.
    ///
    /// Any previously configured target address is cleared.
    #[must_use]
    pub fn code(mut self, code: String) -> Self {
        self.code = Some(code);
        self.target = None;
        self
    }

    /// Sets the compression flag that is forwarded to the payload writer.
    #[must_use]
    pub fn compress(mut self, compress: bool) -> Self {
        self.compress = compress;
        self
    }

    /// Sets the hash algorithm name that is placed in the metadata and passed to
    /// the payload writer.
    #[must_use]
    pub fn hash_algo(mut self, algo: String) -> Self {
        self.hash_algo = algo;
        self
    }

    /// Sends the file or directory at `path` to the peer.
    ///
    /// Steps, in order:
    /// 1. Reject a non-existent `path` and build the transfer [`Metadata`].
    /// 2. Obtain a connection. With a target set, connect to it as a client.
    ///    Otherwise bind a server endpoint on `0.0.0.0` port `0`, spawn a
    ///    broadcaster task for the channel derived from the code, and accept the
    ///    first incoming connection.
    /// 3. Open a bidirectional stream and run `transfer::handshake_sender_split`.
    /// 4. Stream the payload (directory or single file).
    /// 5. Finish the send stream, attempt a one-byte read on the receive stream,
    ///    and close the connection with reason `complete`.
    ///
    /// `progress_cb` is handed unchanged to `transfer::send_payload_write`
    /// (presumably called with a byte count — confirm against that function).
    ///
    /// Returns the string produced by the payload writer.
    ///
    /// # Errors
    ///
    /// * [`EngineError::Io`] with kind `NotFound` if `path` does not exist, or with
    ///   kind `InvalidInput` if it has no final file-name component.
    /// * [`EngineError::Handshake`] if neither a target nor a code was configured,
    ///   or if the endpoint closes while waiting for the peer.
    /// * Any error propagated from the network, handshake or payload layers.
    pub async fn send(
        self,
        path: impl AsRef<Path>,
        progress_cb: impl FnMut(u64) + Send + 'static,
    ) -> Result<String, EngineError> {
        let path = path.as_ref();
        if !path.exists() {
            return Err(EngineError::Io(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                format!("Path does not exist: {}", path.display()),
            )));
        }
        let (meta, _) = self.build_metadata(path)?;
        // `_endpoint` is bound (not discarded with `_`) so it is dropped only when
        // this function returns, i.e. after the connection has been closed.
        let (_endpoint, conn) = if let Some(target_addr) = self.target {
            let endpoint = network::bind_client().await?;
            let client_cfg = network::client_config()?;
            let connecting = endpoint.connect(target_addr, "hayate.local", Some(client_cfg))?;
            let conn = connecting.await?;
            (endpoint, conn)
        } else {
            let phrase = self.code.as_ref().ok_or_else(|| {
                EngineError::Handshake("Neither target nor code specified".into())
            })?;
            let bind_addr =
                SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0);
            let endpoint = network::bind_server(bind_addr).await?;
            let local_port = endpoint.local_addr()?.port();
            let phrase_clone = phrase.clone();
            let (cancel_tx, cancel_rx) = flume::bounded(1);
            // The guard lives only inside this branch: it is dropped as soon as a
            // connection is established or an error returns early.
            // NOTE(review): presumably dropping it cancels the broadcaster via
            // `cancel_tx` — confirm against `BroadcasterGuard`.
            let _broadcaster_guard = BroadcasterGuard::new(cancel_tx);
            compio::runtime::spawn(async move {
                let channel_id = crate::discovery::derive_channel_id(&phrase_clone);
                // Broadcasting is best effort; its result is intentionally ignored.
                let _ =
                    crate::discovery::start_broadcaster(&channel_id, local_port, cancel_rx).await;
            })
            .detach();
            let incoming = endpoint
                .wait_incoming()
                .await
                .ok_or_else(|| EngineError::Handshake("Endpoint closed during pairing".into()))?;
            let conn = incoming.await?;
            (endpoint, conn)
        };
        let (mut send_stream, mut recv_stream) = conn.open_bi()?;
        let (key, cipher_id) = transfer::handshake_sender_split(
            &mut send_stream,
            &mut recv_stream,
            &meta,
            self.code.as_deref(),
        )
        .await?;
        let checksum = if path.is_dir() {
            self.send_directory(
                path,
                &key,
                cipher_id,
                &self.hash_algo,
                &mut send_stream,
                progress_cb,
            )
            .await?
        } else {
            self.send_file(
                path,
                &key,
                cipher_id,
                &self.hash_algo,
                &mut send_stream,
                progress_cb,
            )
            .await?
        };
        send_stream.finish()?;
        // One-byte read whose outcome is deliberately ignored.
        // NOTE(review): presumably this waits for the peer to finish or close its
        // side before we close the connection — confirm against the receiver.
        let drain_buf = vec![0u8; 1];
        let compio::BufResult(res, _) = recv_stream.read(drain_buf).await;
        let _ = res;
        conn.close(0u32.into(), b"complete");
        Ok(checksum)
    }

    /// Builds the [`Metadata`] announced to the peer and returns it together with
    /// the total size that was stored in it.
    ///
    /// For a directory the size comes from `crate::tar::estimate_dir_size` and the
    /// type is `TRANSFER_DIR`; otherwise the size is the file length reported by
    /// `std::fs::metadata` and the type is `TRANSFER_FILE`. The file name is the
    /// final path component, converted lossily to UTF-8.
    ///
    /// # Errors
    ///
    /// [`EngineError::Io`] if `path` has no final component or its metadata cannot
    /// be read.
    fn build_metadata(&self, path: &Path) -> Result<(Metadata, u64), EngineError> {
        let filename = path
            .file_name()
            .ok_or_else(|| {
                EngineError::Io(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Path has no filename",
                ))
            })?
            .to_string_lossy()
            .into_owned();
        // Only the size source and the transfer type differ between the two kinds,
        // so compute those first and build the metadata once.
        let (total, transfer_type) = if path.is_dir() {
            (crate::tar::estimate_dir_size(path), TRANSFER_DIR)
        } else {
            (
                std::fs::metadata(path).map_err(EngineError::Io)?.len(),
                TRANSFER_FILE,
            )
        };
        Ok((
            Metadata {
                filename,
                total_size: total,
                transfer_type,
                hash_algo: self.hash_algo.clone(),
            },
            total,
        ))
    }

    /// Streams a single file through `transfer::send_payload_write`.
    ///
    /// The file is opened with `compio::fs::File` and wrapped in
    /// `PayloadSource::File` starting at position 0. The file name is passed along
    /// only when it is valid UTF-8; otherwise `None` is passed.
    ///
    /// # Errors
    ///
    /// [`EngineError::Io`] if the file cannot be opened, plus anything returned by
    /// the payload writer.
    async fn send_file(
        &self,
        path: &Path,
        key: &[u8; 32],
        cipher_id: u8,
        hash_algo: &str,
        stream: &mut compio_quic::SendStream,
        progress_cb: impl FnMut(u64) + Send + 'static,
    ) -> Result<String, EngineError> {
        let file = compio::fs::File::open(path)
            .await
            .map_err(EngineError::Io)?;
        let source = transfer::PayloadSource::File { file, pos: 0 };
        let filename = path.file_name().and_then(|s| s.to_str());
        transfer::send_payload_write(
            key,
            cipher_id,
            source,
            stream,
            self.compress,
            filename,
            hash_algo,
            progress_cb,
        )
        .await
    }

    /// Streams a directory through `transfer::send_payload_write`.
    ///
    /// A dedicated OS thread runs `crate::tar::write_tar_sync` into a 128 KiB
    /// `BufWriter`. Each chunk the buffer emits is copied into a `Vec<u8>` and sent
    /// over a bounded channel (capacity 8), which the async side consumes as
    /// `PayloadSource::Channel`. An I/O error in the producer is forwarded over
    /// the same channel as an `Err` item. No file name is passed to the writer.
    ///
    /// # Errors
    ///
    /// [`EngineError::Io`] if the producer thread cannot be created, plus anything
    /// returned by the payload writer.
    async fn send_directory(
        &self,
        dir: &Path,
        key: &[u8; 32],
        cipher_id: u8,
        hash_algo: &str,
        stream: &mut compio_quic::SendStream,
        progress_cb: impl FnMut(u64) + Send + 'static,
    ) -> Result<String, EngineError> {
        let (tx, rx) = flume::bounded::<Result<Vec<u8>, std::io::Error>>(8);
        let dir_clone = dir.to_path_buf();
        // `Builder::spawn` reports a failed thread creation as an `io::Error`,
        // whereas `std::thread::spawn` would panic. The join handle is dropped,
        // which leaves the thread detached.
        std::thread::Builder::new()
            .name("hayate-tar".to_owned())
            .spawn(move || {
                use std::io::Write;
                /// `Write` adapter that forwards every written slice over the channel.
                struct ChanWriter {
                    tx: flume::Sender<Result<Vec<u8>, std::io::Error>>,
                }
                impl std::io::Write for ChanWriter {
                    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                        // A failed send means the receiving side has been dropped.
                        self.tx.send(Ok(buf.to_vec())).map_err(|_| {
                            std::io::Error::new(std::io::ErrorKind::BrokenPipe, "receiver gone")
                        })?;
                        Ok(buf.len())
                    }
                    fn flush(&mut self) -> std::io::Result<()> {
                        Ok(())
                    }
                }
                let writer = ChanWriter { tx: tx.clone() };
                let mut buffered_writer = std::io::BufWriter::with_capacity(128 * 1024, writer);
                let mut run = move || -> Result<(), std::io::Error> {
                    crate::tar::write_tar_sync(&dir_clone, &mut buffered_writer)?;
                    buffered_writer.flush()?;
                    Ok(())
                };
                if let Err(e) = run() {
                    // Best effort: if the consumer is already gone nobody can be told.
                    let _ = tx.send(Err(e));
                }
            })
            .map_err(EngineError::Io)?;
        let source = transfer::PayloadSource::Channel(rx);
        transfer::send_payload_write(
            key,
            cipher_id,
            source,
            stream,
            self.compress,
            None,
            hash_algo,
            progress_cb,
        )
        .await
    }
}
/// Builder-style configuration for the receiving side of a transfer.
#[derive(Debug, Clone)]
pub struct HayateReceiver {
    /// Address the server endpoint binds to when no pairing code is configured.
    bind_addr: SocketAddr,
    /// Pairing phrase; when set, the receiver listens for the sender's broadcast and
    /// connects to it instead of binding `bind_addr`.
    code: Option<String>,
    /// When `true`, incoming transfers are accepted without calling the consent callback.
    auto_accept: bool,
}

impl Default for HayateReceiver {
    /// Binds to `0.0.0.0:50001`, with no pairing code and auto-accept disabled.
    fn default() -> Self {
        Self {
            // Built directly rather than parsed from a string: no runtime parse and
            // no `unwrap` that could ever panic.
            bind_addr: SocketAddr::new(
                std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
                50001,
            ),
            code: None,
            auto_accept: false,
        }
    }
}
impl HayateReceiver {
    /// Builds a receiver from the [`Default`] configuration.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces the address the server endpoint binds to when no code is set.
    #[must_use]
    pub fn bind(self, addr: SocketAddr) -> Self {
        Self {
            bind_addr: addr,
            ..self
        }
    }

    /// Sets the pairing code used to find the sender and passed to the handshake.
    #[must_use]
    pub fn code(self, code: String) -> Self {
        Self {
            code: Some(code),
            ..self
        }
    }

    /// Chooses whether transfers are accepted without asking `consent_cb`.
    #[must_use]
    pub fn auto_accept(self, auto_accept: bool) -> Self {
        Self {
            auto_accept,
            ..self
        }
    }

    /// Receives one transfer into `output_dir`.
    ///
    /// With a code configured, the receiver listens for the sender's broadcast for
    /// up to 30 seconds and connects to the advertised address. Without one, it
    /// binds a server endpoint on the configured address and accepts the first
    /// incoming connection. It then accepts a bidirectional stream, runs
    /// `transfer::handshake_receiver_split`, and reports the consent decision to
    /// the sender. If accepted, it stores the payload at the path derived from the
    /// announced file name.
    ///
    /// `consent_cb` is consulted only when auto-accept is off. `progress_cb` is
    /// handed unchanged to `transfer::receive_payload_split`.
    ///
    /// Returns the string produced by the payload reader together with the
    /// destination path.
    ///
    /// # Errors
    ///
    /// * [`EngineError::Handshake`] if no broadcast arrives in time or the endpoint
    ///   closes before a connection is accepted.
    /// * [`EngineError::TransferRejected`] if the transfer is declined.
    /// * Any error propagated from the discovery, network, handshake or payload layers.
    pub async fn receive(
        self,
        output_dir: impl AsRef<Path>,
        consent_cb: impl FnOnce(&Metadata) -> bool,
        progress_cb: impl FnMut(u64) + Send + 'static,
    ) -> Result<(String, PathBuf), EngineError> {
        let out_root = output_dir.as_ref();
        let pairing_code = self.code.as_deref();

        // `_endpoint` is bound so it is dropped only when this function returns.
        let (_endpoint, conn) = match pairing_code {
            Some(phrase) => {
                let discovered =
                    crate::discovery::listen_for_broadcast(Some(phrase), Duration::from_secs(30))
                        .await?;
                let Some((_name, peer_addr, _os)) = discovered else {
                    return Err(EngineError::Handshake(
                        "Timed out waiting for sender broadcast".into(),
                    ));
                };
                let client = network::bind_client().await?;
                let client_cfg = network::client_config()?;
                let pending = client.connect(peer_addr, "hayate.local", Some(client_cfg))?;
                let established = pending.await?;
                (client, established)
            }
            None => {
                let server = network::bind_server(self.bind_addr).await?;
                let Some(incoming) = server.wait_incoming().await else {
                    return Err(EngineError::Handshake("Endpoint closed".into()));
                };
                let established = incoming.await?;
                (server, established)
            }
        };

        let (mut tx_stream, mut rx_stream) = conn.accept_bi().await?;
        let ((key, cipher_id), meta) =
            transfer::handshake_receiver_split(&mut tx_stream, &mut rx_stream, pairing_code)
                .await?;

        // Short-circuits: the callback is not invoked when auto-accept is on.
        let accepted = self.auto_accept || consent_cb(&meta);
        transfer::send_consent_write(&mut tx_stream, accepted).await?;
        if !accepted {
            // NOTE(review): the connection is closed right after the consent write;
            // confirm the sender reliably observes the rejection.
            conn.close(0u32.into(), b"rejected");
            return Err(EngineError::TransferRejected);
        }

        let dest = resolve_output(out_root, &meta);
        let checksum = transfer::receive_payload_split(
            &key,
            cipher_id,
            &mut rx_stream,
            &dest,
            meta.transfer_type,
            meta.total_size,
            &meta.hash_algo,
            progress_cb,
        )
        .await?;
        conn.close(0u32.into(), b"complete");
        Ok((checksum, dest))
    }
}
/// Computes where a received payload is stored inside `output_dir`.
///
/// Only the final component of the announced file name is used, so any directory
/// parts in the sender-supplied name are dropped. If the name has no final
/// component, the fixed name `received_file` is used instead.
fn resolve_output(output_dir: &Path, meta: &Metadata) -> PathBuf {
    let announced = Path::new(&meta.filename);
    match announced.file_name() {
        Some(base) => output_dir.join(base),
        None => output_dir.join(std::ffi::OsStr::new("received_file")),
    }
}