use futures::{AsyncReadExt, AsyncWriteExt};
use log::*;
use sha2::{digest::FixedOutput, Digest, Sha256};
use std::path::PathBuf;
use super::*;
pub async fn send_file<F, N, G, H>(
mut wormhole: Wormhole,
relay_hints: Vec<transit::RelayHint>,
file: &mut F,
file_name: N,
file_size: u64,
transit_abilities: transit::Abilities,
transit_handler: G,
progress_handler: H,
cancel: impl Future<Output = ()>,
) -> Result<(), TransferError>
where
F: AsyncRead + Unpin,
N: Into<PathBuf>,
G: FnOnce(transit::TransitInfo, std::net::SocketAddr),
H: FnMut(u64, u64) + 'static,
{
let run = Box::pin(async {
let connector = transit::init(transit_abilities, None, relay_hints).await?;
debug!("Sending transit message '{:?}", connector.our_hints());
wormhole
.send_json(&PeerMessage::transit(
*connector.our_abilities(),
(**connector.our_hints()).clone(),
))
.await?;
debug!("Sending file offer");
wormhole
.send_json(&PeerMessage::offer_file(file_name, file_size))
.await?;
let (their_abilities, their_hints): (transit::Abilities, transit::Hints) =
match wormhole.receive_json().await?? {
PeerMessage::Transit(transit) => {
debug!("Received transit message: {:?}", transit);
(transit.abilities_v1, transit.hints_v1)
},
PeerMessage::Error(err) => {
bail!(TransferError::PeerError(err));
},
other => {
bail!(TransferError::unexpected_message("transit", other))
},
};
{
let fileack_msg = wormhole.receive_json().await??;
debug!("Received file ack message: {:?}", fileack_msg);
match fileack_msg {
PeerMessage::Answer(Answer::FileAck(msg)) => {
ensure!(msg == "ok", TransferError::AckError);
},
PeerMessage::Error(err) => {
bail!(TransferError::PeerError(err));
},
_ => {
bail!(TransferError::unexpected_message(
"answer/file_ack",
fileack_msg
));
},
}
}
let (mut transit, info, addr) = connector
.leader_connect(
wormhole.key().derive_transit_key(wormhole.appid()),
their_abilities,
Arc::new(their_hints),
)
.await?;
transit_handler(info, addr);
debug!("Beginning file transfer");
let checksum = v1::send_records(&mut transit, file, file_size, progress_handler).await?;
debug!("sent file. Waiting for ack");
let transit_ack = transit.receive_record().await?;
let transit_ack_msg = serde_json::from_slice::<TransitAck>(&transit_ack)?;
ensure!(
transit_ack_msg.sha256 == hex::encode(checksum),
TransferError::Checksum
);
debug!("Transfer complete!");
Ok(())
});
futures::pin_mut!(cancel);
let result = crate::util::cancellable_2(run, cancel).await;
super::handle_run_result(wormhole, result).await
}
/// Send the directory at `folder_path` to the peer as a tar archive over `wormhole`,
/// offering it under the name `folder_name`.
///
/// The archive is generated twice with identical builder settings (deterministic
/// headers, symlinks not followed):
/// 1. First into a byte counter plus SHA-256 hasher, to learn the size announced in
///    the offer.
/// 2. Then again in a background task that feeds the transit stream through a ring
///    buffer.
///
/// If the two passes hash differently, the directory changed in between and the
/// transfer fails with `TransferError::FilesystemSkew`.
///
/// `transit_handler` is called once with the established connection's info and peer
/// address; `progress_handler` receives `(sent, total)` updates. The whole run is
/// raced against `cancel`, and its outcome is passed to `handle_run_result` together
/// with the wormhole.
///
/// # Panics
///
/// Panics if `folder_path` is not a directory.
///
/// # Errors
///
/// Fails on I/O errors while archiving the directory (in either pass), on peer errors
/// or unexpected messages, if the peer does not answer the offer with `"ok"`, on
/// size or filesystem skew between the two passes, or if the checksum acknowledged by
/// the peer does not match.
pub async fn send_folder<N, M, G, H>(
    mut wormhole: Wormhole,
    relay_hints: Vec<transit::RelayHint>,
    folder_path: N,
    folder_name: M,
    transit_abilities: transit::Abilities,
    transit_handler: G,
    progress_handler: H,
    cancel: impl Future<Output = ()>,
) -> Result<(), TransferError>
where
    N: Into<PathBuf>,
    M: Into<PathBuf>,
    G: FnOnce(transit::TransitInfo, std::net::SocketAddr),
    H: FnMut(u64, u64) + 'static,
{
    let run = Box::pin(async {
        let connector = transit::init(transit_abilities, None, relay_hints).await?;
        let folder_path = folder_path.into();
        if !folder_path.is_dir() {
            panic!(
                "You should only call this method with directory paths, but '{}' is not",
                folder_path.display()
            );
        }
        debug!("Sending transit message '{:?}", connector.our_hints());
        wormhole
            .send_json(&PeerMessage::transit(
                *connector.our_abilities(),
                (**connector.our_hints()).clone(),
            ))
            .await?;

        // `std::io::Write` adapter that counts the bytes accepted by `inner`.
        struct CountWrite<W> {
            inner: W,
            count: u64,
        }
        impl<W: std::io::Write> std::io::Write for CountWrite<W> {
            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                let written = self.inner.write(buf)?;
                self.count += written as u64;
                Ok(written)
            }
            fn flush(&mut self) -> std::io::Result<()> {
                self.inner.flush()
            }
        }

        // First pass: archive into counter + hasher only, to learn the size to offer.
        log::info!("Calculating the size of '{}'", folder_path.display());
        let folder_path2 = folder_path.clone();
        let (length, sha256sum_initial) = {
            let mut hasher = Sha256::new();
            let mut counter = CountWrite {
                inner: &mut hasher,
                count: 0,
            };
            let mut builder = async_tar::Builder::new(futures::io::AllowStdIo::new(&mut counter));
            builder.mode(async_tar::HeaderMode::Deterministic);
            builder.follow_symlinks(false);
            // Propagate I/O errors (unreadable entries, permission problems, ...) to the
            // caller instead of panicking.
            builder.append_dir_all("", folder_path2).await?;
            builder.finish().await?;
            // Release the borrows of `counter` and `hasher` before reading them.
            std::mem::drop(builder);
            let count = counter.count;
            std::mem::drop(counter);
            (count, hasher.finalize_fixed())
        };

        debug!("Sending file offer");
        wormhole
            .send_json(&PeerMessage::offer_file(folder_name, length))
            .await?;
        let (their_abilities, their_hints): (transit::Abilities, transit::Hints) =
            match wormhole.receive_json().await?? {
                PeerMessage::Transit(transit) => {
                    debug!("received transit message: {:?}", transit);
                    (transit.abilities_v1, transit.hints_v1)
                },
                PeerMessage::Error(err) => {
                    bail!(TransferError::PeerError(err));
                },
                other => {
                    bail!(TransferError::unexpected_message("transit", other));
                },
            };
        match wormhole.receive_json().await?? {
            PeerMessage::Answer(Answer::FileAck(msg)) => {
                ensure!(msg == "ok", TransferError::AckError);
            },
            PeerMessage::Error(err) => {
                bail!(TransferError::PeerError(err));
            },
            other => {
                bail!(TransferError::unexpected_message("answer/file_ack", other));
            },
        }
        let (mut transit, info, addr) = connector
            .leader_connect(
                wormhole.key().derive_transit_key(wormhole.appid()),
                their_abilities,
                Arc::new(their_hints),
            )
            .await?;
        transit_handler(info, addr);
        debug!("Beginning file transfer");

        // `AsyncWrite` adapter that hashes every byte successfully written to `writer`.
        pub struct HashWriter<D: sha2::digest::Update, W: futures::io::AsyncWrite + Unpin> {
            writer: W,
            hasher: D,
        }
        use std::{
            pin::Pin,
            task::{Context, Poll},
        };
        impl<D: sha2::digest::Update + Unpin, W: futures::io::AsyncWrite + Unpin>
            futures::io::AsyncWrite for HashWriter<D, W>
        {
            fn poll_write(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                match Pin::new(&mut self.writer).poll_write(cx, buf) {
                    Poll::Ready(Ok(n)) => {
                        // Hash only the prefix that was actually accepted.
                        self.hasher.update(&buf[..n]);
                        Poll::Ready(Ok(n))
                    },
                    res => res,
                }
            }
            fn poll_flush(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Pin::new(&mut self.writer).poll_flush(cx)
            }
            fn poll_close(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Pin::new(&mut self.writer).poll_close(cx)
            }
        }

        // Second pass: a background task archives into the ring buffer while
        // `send_records` drains it into the transit connection.
        let (mut reader, writer) = futures_ringbuf::RingBuffer::new(4096).split();
        let file_sender = async_std::task::spawn(async move {
            let mut hash_writer = HashWriter {
                writer,
                hasher: Sha256::new(),
            };
            let mut builder = async_tar::Builder::new(&mut hash_writer);
            builder.mode(async_tar::HeaderMode::Deterministic);
            builder.follow_symlinks(false);
            builder.append_dir_all("", folder_path).await?;
            builder.finish().await?;
            std::mem::drop(builder);
            hash_writer.flush().await?;
            hash_writer.close().await?;
            let hasher = hash_writer.hasher;
            std::io::Result::Ok(hasher.finalize_fixed())
        });
        let (checksum, sha256sum) =
            match v1::send_records(&mut transit, &mut reader, length, progress_handler).await {
                Ok(checksum) => (checksum, file_sender.await?),
                Err(err) => {
                    log::debug!("Some more error {err}");
                    // Stop the archiver; report its own failure, if any, as a warning.
                    if let Some(Err(err)) = file_sender.cancel().await {
                        log::warn!("Error in background task: {err}");
                    }
                    return Err(err);
                },
            };
        // Both passes must have produced the same archive.
        ensure!(
            sha256sum == sha256sum_initial,
            TransferError::FilesystemSkew
        );
        debug!("sent file. Waiting for ack");
        let transit_ack = transit.receive_record().await?;
        let transit_ack_msg = serde_json::from_slice::<TransitAck>(&transit_ack)?;
        ensure!(
            transit_ack_msg.sha256 == hex::encode(checksum),
            TransferError::Checksum
        );
        debug!("Transfer complete!");
        Ok(())
    });
    futures::pin_mut!(cancel);
    let result = crate::util::cancellable_2(run, cancel).await;
    super::handle_run_result(wormhole, result).await
}
/// Stream the contents of `file` over `transit` as plaintext records of at most
/// 4096 bytes each, and return the SHA-256 digest of the bytes sent.
///
/// Exactly `file_size` bytes are expected. `progress_handler` is called with
/// `(sent, total)` once before the first record and after every record sent.
///
/// # Errors
///
/// Propagates read, send and flush errors. Returns `TransferError::FileSize` in
/// either of these cases:
/// - `file` ends before `file_size` bytes were sent.
/// - `file` yields more than `file_size` bytes. This is detected *before* the
///   offending chunk is sent, so the peer never receives more data than was announced.
///   Here `sent_size` includes that rejected chunk.
pub async fn send_records<F>(
    transit: &mut Transit,
    file: &mut (impl AsyncRead + Unpin),
    file_size: u64,
    mut progress_handler: F,
) -> Result<Vec<u8>, TransferError>
where
    F: FnMut(u64, u64) + 'static,
{
    progress_handler(0, file_size);
    let mut hasher = Sha256::default();
    let mut plaintext = vec![0u8; 4096];
    let mut sent_size: u64 = 0;
    loop {
        let n = file.read(&mut plaintext[..]).await?;
        // Per-chunk logging is very chatty (one line per 4 KiB), so keep it at trace.
        log::trace!("Read {n}");
        if n == 0 {
            break;
        }
        let chunk = &plaintext[..n];
        let next_size = sent_size + n as u64;
        // Never send past the announced size: the receiver counts down from
        // `file_size` and would mis-handle a record crossing that boundary.
        ensure!(
            next_size <= file_size,
            TransferError::FileSize {
                sent_size: next_size,
                file_size
            }
        );
        transit.send_record(chunk).await?;
        sent_size = next_size;
        progress_handler(sent_size, file_size);
        hasher.update(chunk);
    }
    transit.flush().await?;
    // Catches a source that ended early (e.g. the file shrank after it was measured).
    ensure!(
        sent_size == file_size,
        TransferError::FileSize {
            sent_size,
            file_size
        }
    );
    Ok(hasher.finalize_fixed().to_vec())
}
pub async fn receive_records<F, W>(
filesize: u64,
transit: &mut Transit,
mut progress_handler: F,
content_handler: &mut W,
) -> Result<Vec<u8>, TransferError>
where
F: FnMut(u64, u64) + 'static,
W: AsyncWrite + Unpin,
{
let mut hasher = Sha256::default();
let total = filesize;
let mut remaining_size = filesize as usize;
progress_handler(0, total);
while remaining_size > 0 {
let plaintext = transit.receive_record().await?;
content_handler.write_all(&plaintext).await?;
hasher.update(&plaintext);
remaining_size -= plaintext.len();
let remaining = remaining_size as u64;
progress_handler(total - remaining, total);
}
debug!("done");
Ok(hasher.finalize_fixed().to_vec())
}
/// Receive a single file over an already-established `transit` connection and
/// acknowledge it.
///
/// Reads `filesize` bytes into `content_handler` via `receive_records` (reporting
/// progress through `progress_handler`). It then sends back a `TransitAck` carrying
/// `"ok"` and the hex-encoded SHA-256 of the received bytes, so the sender can verify
/// them.
pub async fn tcp_file_receive<F, W>(
    transit: &mut Transit,
    filesize: u64,
    progress_handler: F,
    content_handler: &mut W,
) -> Result<(), TransferError>
where
    F: FnMut(u64, u64) + 'static,
    W: AsyncWrite + Unpin,
{
    let digest = receive_records(filesize, transit, progress_handler, content_handler).await?;
    let sha256sum = hex::encode(&digest);
    debug!("sha256 sum: {:?}", sha256sum);

    let ack = TransitAck::new("ok", &sha256sum).serialize_vec();
    transit.send_record(&ack).await?;

    debug!("Transfer complete");
    Ok(())
}