use futures::{
io::{AsyncReadExt, AsyncWriteExt},
StreamExt, TryFutureExt,
};
use sha2::{digest::FixedOutput, Digest, Sha256};
use super::{offer::*, *};
/// Describes what the sending side offers to transfer.
///
/// Variant and field names are (de)serialized in kebab-case via serde.
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
#[serde(rename_all = "kebab-case")]
pub enum OfferMessage {
/// An offer carrying a single string payload.
Message(String),
/// An offer for a single file.
File {
/// Name of the offered file.
filename: String,
/// Size of the offered file (the receiving code uses it as the number of bytes to expect).
filesize: u64,
},
/// An offer for a directory.
///
/// The receiving code in this file only reads `dirname` and `zipsize`: it appends
/// `.zip` to the name and treats `zipsize` as the transfer size.
Directory {
dirname: String,
mode: String,
zipsize: u64,
numbytes: u64,
numfiles: u64,
},
/// Fallback that serde selects for any variant name it does not recognise.
#[serde(other)]
Unknown,
}
/// Answer to an [`OfferMessage`].
///
/// Variant names are (de)serialized in snake_case via serde. The sending functions in
/// this file only proceed when they receive a `FileAck` whose payload equals `"ok"`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AnswerMessage {
/// Acknowledgement of a message offer, with a string payload.
MessageAck(String),
/// Acknowledgement of a file offer, with a string payload.
FileAck(String),
}
/// Payload of the transit message both sides exchange before connecting.
///
/// Field names are (de)serialized in kebab-case via serde.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct TransitV1 {
/// Transit abilities announced by the side that sent the message.
pub abilities_v1: TransitAbilities,
/// Connection hints announced by the side that sent the message.
pub hints_v1: transit::Hints,
}
/// Acknowledgement record the receiver sends over the transit connection once all
/// data has arrived.
///
/// Field names are (de)serialized in kebab-case via serde.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "kebab-case")]
struct TransitAck {
/// Status string; the receiving code in this file always sends `"ok"`.
pub ack: String,
/// Hex-encoded SHA-256 of the received data; the sender compares it with its own digest.
pub sha256: String,
}
impl TransitAck {
    /// Builds an acknowledgement from a status message and a checksum string.
    pub(crate) fn new(msg: impl Into<String>, sha256: impl Into<String>) -> Self {
        let ack = msg.into();
        let sha256 = sha256.into();
        Self { ack, sha256 }
    }

    /// Renders the acknowledgement as JSON text (only compiled for tests).
    #[cfg(test)]
    pub(crate) fn serialize(&self) -> String {
        let value = json!(self);
        value.to_string()
    }

    /// Renders the acknowledgement as JSON bytes.
    ///
    /// # Panics
    /// Panics if JSON serialization fails.
    pub(crate) fn serialize_vec(&self) -> Vec<u8> {
        let encoded = serde_json::to_vec(self);
        encoded.unwrap()
    }
}
/// Sends `offer` to the peer, picking the matching v1 routine.
///
/// * An offer with multiple entries is wrapped into one directory named
///   `<unnamed folder>` and sent through [`send_folder`].
/// * An offer consisting of a single directory is sent through [`send_folder`]
///   under its own name.
/// * Anything else is treated as a single regular file and sent through
///   [`send_file`].
///
/// `_peer_version` is accepted but not used.
///
/// # Panics
/// Panics if a non-multiple offer has no entry at all.
pub(crate) async fn send(
    wormhole: Wormhole,
    relay_hints: Vec<transit::RelayHint>,
    transit_abilities: transit::Abilities,
    offer: OfferSend,
    progress_handler: impl FnMut(u64, u64) + 'static,
    transit_handler: impl FnOnce(transit::TransitInfo),
    _peer_version: AppVersion,
    cancel: impl Future<Output = ()>,
) -> Result<(), TransferError> {
    let multiple = offer.is_multiple();

    // Both folder flavours end up in the same call; only name and root entry differ.
    if multiple || offer.is_directory() {
        let (folder_name, folder) = if multiple {
            (
                String::from("<unnamed folder>"),
                OfferSendEntry::Directory {
                    content: offer.content,
                },
            )
        } else {
            offer.content.into_iter().next().unwrap()
        };

        return send_folder(
            wormhole,
            relay_hints,
            folder_name,
            folder,
            transit_abilities,
            transit_handler,
            progress_handler,
            cancel,
        )
        .await;
    }

    // Single regular file: open its content before starting the transfer.
    let (file_name, entry) = offer.content.into_iter().next().unwrap();
    let (mut file, file_size) = match entry {
        OfferSendEntry::RegularFile { content, size } => (content().await?, size),
        _ => unreachable!(),
    };

    send_file(
        wormhole,
        relay_hints,
        &mut file,
        file_name,
        file_size,
        transit_abilities,
        transit_handler,
        progress_handler,
        cancel,
    )
    .await
}
/// Sends a single file to the peer using the v1 transfer protocol.
///
/// Steps, in order: initialise a transit connector, send our transit message and the
/// file offer, wait for the peer's transit message and for a file ack equal to `"ok"`,
/// connect as transit leader, stream `file` through `v1::send_records`, and finally
/// compare the SHA-256 reported in the peer's [`TransitAck`] with our own digest.
///
/// * `file` - reader providing the content to send.
/// * `file_name`, `file_size` - announced to the peer in the file offer; `file_size` is
///   also handed to `send_records` as the expected total.
/// * `transit_handler` - called once with the connection info after the transit
///   connection has been established.
/// * `progress_handler` - forwarded to `send_records`.
/// * `cancel` - raced against the transfer through `cancel::cancellable_2`.
///
/// # Errors
/// Fails when any protocol step fails, when an unexpected message arrives, when the
/// ack is not `"ok"` (`TransferError::AckError`) or when the checksums differ
/// (`TransferError::Checksum`).
pub(crate) async fn send_file<F, G, H>(
mut wormhole: Wormhole,
relay_hints: Vec<transit::RelayHint>,
file: &mut F,
file_name: impl Into<String>,
file_size: u64,
transit_abilities: transit::Abilities,
transit_handler: G,
progress_handler: H,
cancel: impl Future<Output = ()>,
) -> Result<(), TransferError>
where
F: AsyncRead + Unpin + Send,
G: FnOnce(transit::TransitInfo),
H: FnMut(u64, u64) + 'static,
{
// The whole exchange lives in one future so that it can be raced against `cancel`.
let run = Box::pin(async {
let connector = transit::init(transit_abilities, None, relay_hints).await?;
// Announce our own transit abilities and hints first.
tracing::debug!("Sending transit message '{:?}", connector.our_hints());
wormhole
.send_json(&PeerMessage::transit_v1(
*connector.our_abilities(),
(**connector.our_hints()).clone(),
))
.await?;
// Then announce the file itself.
tracing::debug!("Sending file offer");
wormhole
.send_json(&PeerMessage::offer_file_v1(file_name, file_size))
.await?;
// The peer must answer with its transit message ...
let (their_abilities, their_hints): (transit::Abilities, transit::Hints) =
match wormhole.receive_json::<PeerMessage>().await??.check_err()? {
PeerMessage::Transit(transit) => {
tracing::debug!("Received transit message: {:?}", transit);
(transit.abilities_v1, transit.hints_v1)
},
other => {
bail!(TransferError::unexpected_message("transit", other))
},
};
// ... followed by a file ack that has to be exactly "ok".
{
let fileack_msg = wormhole.receive_json::<PeerMessage>().await??;
tracing::debug!("Received file ack message: {:?}", fileack_msg);
match fileack_msg.check_err()? {
PeerMessage::Answer(AnswerMessage::FileAck(msg)) => {
ensure!(msg == "ok", TransferError::AckError);
},
_ => {
bail!(TransferError::unexpected_message(
"answer/file_ack",
fileack_msg
));
},
}
}
// Establish the transit connection; the sender takes the leader role.
let (mut transit, info) = connector
.leader_connect(
wormhole.key().derive_transit_key(wormhole.appid()),
their_abilities,
Arc::new(their_hints),
)
.await?;
transit_handler(info);
tracing::debug!("Beginning file transfer");
// `send_records` consumes a stream of readers; wrap the single file into a one-item stream.
let file = futures::stream::once(futures::future::ready(std::io::Result::Ok(
Box::new(file) as Box<dyn AsyncRead + Unpin + Send>,
)));
let checksum = v1::send_records(&mut transit, file, file_size, progress_handler).await?;
tracing::debug!("sent file. Waiting for ack");
// The peer reports the SHA-256 of what it received; it has to match ours.
let transit_ack = transit.receive_record().await?;
let transit_ack_msg = serde_json::from_slice::<TransitAck>(&transit_ack)?;
ensure!(
transit_ack_msg.sha256 == hex::encode(checksum),
TransferError::Checksum
);
tracing::debug!("Transfer complete!");
Ok(())
});
futures::pin_mut!(cancel);
let result = cancel::cancellable_2(run, cancel).await;
// Hands the wormhole and the outcome (finished, failed or cancelled) to the shared handler.
cancel::handle_run_result(wormhole, result).await
}
/// Sends a folder to the peer as a tar archive that is assembled on the fly.
///
/// The archive is never materialised: `create_offer` walks `folder` and collects an
/// ordered list of readers (tar headers, file contents, padding) while summing up the
/// exact archive size. The archive is then offered to the peer as a single file named
/// `<folder_name>.tar` and, once acknowledged, streamed through `v1::send_records`.
///
/// * `folder_name` - name of the root directory inside the archive; `.tar` is appended
///   for the offered file name.
/// * `folder` - root entry to send.
/// * `transit_handler` - called once with the connection info after the transit
///   connection has been established.
/// * `progress_handler` - forwarded to `send_records`.
/// * `cancel` - raced against the transfer through `cancel::cancellable_2`.
///
/// # Errors
/// Fails when any protocol step fails, when an unexpected message arrives, when the
/// ack is not `"ok"` (`TransferError::AckError`) or when the checksums differ
/// (`TransferError::Checksum`).
pub(crate) async fn send_folder(
mut wormhole: Wormhole,
relay_hints: Vec<transit::RelayHint>,
mut folder_name: String,
folder: OfferSendEntry,
transit_abilities: transit::Abilities,
transit_handler: impl FnOnce(transit::TransitInfo),
progress_handler: impl FnMut(u64, u64) + 'static,
cancel: impl Future<Output = ()>,
) -> Result<(), TransferError> {
// The whole exchange lives in one future so that it can be raced against `cancel`.
let run = Box::pin(async {
let connector = transit::init(transit_abilities, None, relay_hints).await?;
// Announce our own transit abilities and hints first.
tracing::debug!("Sending transit message '{:?}", connector.our_hints());
wormhole
.send_json(&PeerMessage::transit_v1(
*connector.our_abilities(),
(**connector.our_hints()).clone(),
))
.await?;
tracing::debug!("Estimating the file size");
use futures::{
future::{ready, BoxFuture},
io::Cursor,
};
use std::io::Result as IoResult;
// One piece of the archive: a future resolving to a reader for that piece.
type WrappedDataFut = BoxFuture<'static, IoResult<Box<dyn AsyncRead + Unpin + Send>>>;
// Turns an in-memory buffer into an already-resolved archive piece.
fn wrap(buffer: impl AsRef<[u8]> + Unpin + Send + 'static) -> WrappedDataFut {
Box::pin(ready(IoResult::Ok(
Box::new(Cursor::new(buffer)) as Box<dyn AsyncRead + Unpin + Send>
))) as _
}
// Recursively appends the archive pieces for `offer` (located at `path`) to
// `total_content` and adds their byte counts to `total_size`.
fn create_offer(
mut total_content: Vec<WrappedDataFut>,
total_size: &mut u64,
offer: OfferSendEntry,
path: &mut Vec<String>,
) -> IoResult<Vec<WrappedDataFut>> {
match offer {
OfferSendEntry::Directory { content } => {
tracing::debug!("Adding directory {path:?}");
let header = tar_helper::create_header_directory(path)?;
*total_size += header.len() as u64;
total_content.push(wrap(header));
// Descend into every child with its name pushed onto the path.
for (name, file) in content {
path.push(name);
total_content = create_offer(total_content, total_size, file, path)?;
path.pop();
}
},
OfferSendEntry::RegularFile { size, content } => {
tracing::debug!("Adding file {path:?}; {size} bytes");
let header = tar_helper::create_header_file(path, size)?;
let padding = tar_helper::padding(size);
// A file contributes its header, its content and the padding after it.
*total_size += header.len() as u64;
*total_size += padding.len() as u64;
*total_size += size;
total_content.push(wrap(header));
let content = content().map_ok(
|read| Box::new(read) as Box<dyn AsyncRead + Unpin + Send>,
);
total_content.push(Box::pin(content) as _);
total_content.push(wrap(padding));
},
}
Ok(total_content)
}
let mut total_size = 0;
let mut content = create_offer(
Vec::new(),
&mut total_size,
folder,
&mut vec![folder_name.clone()],
)?;
// The archive is terminated by 1024 zero bytes (tar end-of-archive marker).
total_size += 1024;
content.push(wrap([0; 1024]));
// Resolve the pieces one after another, in order, as the stream is consumed.
let content = futures::stream::iter(content).then(|content| content);
tracing::debug!("Sending file offer ({total_size} bytes)");
folder_name.push_str(".tar");
wormhole
.send_json(&PeerMessage::offer_file_v1(folder_name, total_size))
.await?;
// The peer must answer with its transit message ...
let (their_abilities, their_hints): (transit::Abilities, transit::Hints) =
match wormhole.receive_json::<PeerMessage>().await??.check_err()? {
PeerMessage::Transit(transit) => {
tracing::debug!("received transit message: {:?}", transit);
(transit.abilities_v1, transit.hints_v1)
},
other => {
bail!(TransferError::unexpected_message("transit", other));
},
};
// ... followed by a file ack that has to be exactly "ok".
match wormhole.receive_json::<PeerMessage>().await??.check_err()? {
PeerMessage::Answer(AnswerMessage::FileAck(msg)) => {
ensure!(msg == "ok", TransferError::AckError);
},
other => {
bail!(TransferError::unexpected_message("answer/file_ack", other));
},
}
// Establish the transit connection; the sender takes the leader role.
let (mut transit, info) = connector
.leader_connect(
wormhole.key().derive_transit_key(wormhole.appid()),
their_abilities,
Arc::new(their_hints),
)
.await?;
transit_handler(info);
tracing::debug!("Beginning file transfer");
let checksum =
v1::send_records(&mut transit, content, total_size, progress_handler).await?;
tracing::debug!("sent file. Waiting for ack");
// The peer reports the SHA-256 of what it received; it has to match ours.
let transit_ack = transit.receive_record().await?;
let transit_ack_msg = serde_json::from_slice::<TransitAck>(&transit_ack)?;
ensure!(
transit_ack_msg.sha256 == hex::encode(checksum),
TransferError::Checksum
);
tracing::debug!("Transfer complete!");
Ok(())
});
futures::pin_mut!(cancel);
let result = cancel::cancellable_2(run, cancel).await;
// Hands the wormhole and the outcome (finished, failed or cancelled) to the shared handler.
cancel::handle_run_result(wormhole, result).await
}
/// Waits for the peer's file offer and wraps it into a [`ReceiveRequest`].
///
/// Our transit message is sent first; afterwards the peer's transit message and its
/// offer are awaited. File offers are taken as they are; directory offers are
/// presented as `<dirname>.zip` with `zipsize` as the size. Any other offer is
/// rejected with `TransferError::UnsupportedOffer`.
///
/// The exchange is raced against `cancel`. The value is `None` whenever
/// `cancel::handle_run_result_noclose` yields no result.
pub async fn request(
    mut wormhole: Wormhole,
    relay_hints: Vec<transit::RelayHint>,
    transit_abilities: transit::Abilities,
    cancel: impl Future<Output = ()>,
) -> Result<Option<ReceiveRequest>, TransferError> {
    let run = Box::pin(async {
        let connector = transit::init(transit_abilities, None, relay_hints).await?;

        // Tell the peer how it may reach us.
        tracing::debug!("Sending transit message '{:?}", connector.our_hints());
        let our_transit = PeerMessage::transit_v1(
            *connector.our_abilities(),
            (**connector.our_hints()).clone(),
        );
        wormhole.send_json(&our_transit).await?;

        // First expected message: the peer's transit information.
        let transit_msg = wormhole.receive_json::<PeerMessage>().await??.check_err()?;
        let (their_abilities, their_hints): (transit::Abilities, transit::Hints) =
            match transit_msg {
                PeerMessage::Transit(transit) => {
                    tracing::debug!("received transit message: {:?}", transit);
                    (transit.abilities_v1, transit.hints_v1)
                },
                other => {
                    bail!(TransferError::unexpected_message("transit", other));
                },
            };

        // Second expected message: the offer itself.
        let offer_msg = wormhole.receive_json::<PeerMessage>().await??.check_err()?;
        let (filename, filesize) = match offer_msg {
            PeerMessage::Offer(v1::OfferMessage::File { filename, filesize }) => {
                (filename, filesize)
            },
            PeerMessage::Offer(v1::OfferMessage::Directory {
                mut dirname,
                zipsize,
                ..
            }) => {
                dirname.push_str(".zip");
                (dirname, zipsize)
            },
            PeerMessage::Offer(_) => bail!(TransferError::UnsupportedOffer),
            other => {
                bail!(TransferError::unexpected_message("offer", other));
            },
        };

        Ok((filename, filesize, connector, their_abilities, their_hints))
    });

    futures::pin_mut!(cancel);
    let result = cancel::cancellable_2(run, cancel).await;

    // The wormhole stays open: it is needed later to accept or reject the offer.
    let outcome: Option<_> = cancel::handle_run_result_noclose(wormhole, result).await?;
    Ok(outcome.map(
        |((filename, filesize, connector, their_abilities, their_hints), wormhole, _)| {
            ReceiveRequest::new(
                filename,
                filesize,
                connector,
                their_abilities,
                their_hints,
                wormhole,
            )
        },
    ))
}
/// A file offer received from the peer that has not been answered yet.
///
/// Created by [`request`]; it is consumed by either [`ReceiveRequest::accept`] or
/// [`ReceiveRequest::reject`].
#[must_use]
pub struct ReceiveRequest {
/// Wormhole connection used to answer the offer.
wormhole: Wormhole,
/// Transit connector used to connect to the peer when the offer is accepted.
connector: TransitConnector,
/// Offered file name as a path (not available on wasm targets).
#[deprecated(since = "0.7.0", note = "use ReceiveRequest::file_name(..) instead")]
#[cfg(not(target_family = "wasm"))]
pub filename: PathBuf,
/// Offered file name.
file_name: String,
/// Offered file size.
#[deprecated(since = "0.7.0", note = "use ReceiveRequest::file_size(..) instead")]
pub filesize: u64,
/// The same offer expressed as an [`Offer`] with a single regular-file entry.
#[allow(dead_code)]
offer: Arc<Offer>,
/// Transit abilities announced by the peer.
their_abilities: transit::Abilities,
/// Transit hints announced by the peer.
their_hints: Arc<transit::Hints>,
}
#[allow(deprecated)]
impl ReceiveRequest {
    /// Assembles a request from the data gathered while waiting for the offer.
    ///
    /// Besides storing the arguments, this builds an [`Offer`] containing one
    /// regular-file entry named `file_name` with size `filesize`.
    fn new(
        file_name: String,
        filesize: u64,
        connector: TransitConnector,
        their_abilities: transit::Abilities,
        their_hints: transit::Hints,
        wormhole: Wormhole,
    ) -> Self {
        let entry = OfferEntry::RegularFile {
            size: filesize,
            content: (),
        };
        let mut content = BTreeMap::new();
        content.insert(file_name.clone(), entry);

        #[allow(deprecated)]
        Self {
            wormhole,
            connector,
            #[cfg(not(target_family = "wasm"))]
            filename: PathBuf::from(&file_name),
            file_name,
            filesize,
            offer: Arc::new(Offer { content }),
            their_abilities,
            their_hints: Arc::new(their_hints),
        }
    }

    /// Accepts the offer and receives the file into `content_handler`.
    ///
    /// Sends a file ack of `"ok"`, connects to the peer as transit follower, calls
    /// `transit_handler` with the connection info and then receives the data through
    /// [`tcp_file_receive`], reporting to `progress_handler`. The whole operation is
    /// raced against `cancel`.
    pub async fn accept<F, G, W>(
        mut self,
        transit_handler: G,
        progress_handler: F,
        content_handler: &mut W,
        cancel: impl Future<Output = ()>,
    ) -> Result<(), TransferError>
    where
        F: FnMut(u64, u64) + 'static,
        G: FnOnce(transit::TransitInfo),
        W: AsyncWrite + Unpin,
    {
        let run = Box::pin(async {
            tracing::debug!("Sending ack");
            let ack = PeerMessage::file_ack_v1("ok");
            self.wormhole.send_json(&ack).await?;

            let transit_key = self
                .wormhole
                .key()
                .derive_transit_key(self.wormhole.appid());
            let (mut transit, info) = self
                .connector
                .follower_connect(
                    transit_key,
                    self.their_abilities,
                    Arc::clone(&self.their_hints),
                )
                .await?;
            transit_handler(info);

            tracing::debug!("Beginning file transfer");
            tcp_file_receive(
                &mut transit,
                self.filesize,
                progress_handler,
                content_handler,
            )
            .await?;
            Ok(())
        });

        futures::pin_mut!(cancel);
        let result = cancel::cancellable_2(run, cancel).await;
        cancel::handle_run_result(self.wormhole, result).await
    }

    /// Rejects the offer: sends the error message `"transfer rejected"` to the peer
    /// and closes the wormhole.
    pub async fn reject(mut self) -> Result<(), TransferError> {
        let rejection = PeerMessage::error_message("transfer rejected");
        self.wormhole.send_json(&rejection).await?;
        self.wormhole.close().await?;
        Ok(())
    }

    /// The offer as an [`Offer`] value.
    #[cfg(feature = "experimental-transfer-v2")]
    #[allow(missing_docs)]
    pub fn offer(&self) -> Arc<Offer> {
        Arc::clone(&self.offer)
    }

    /// Name of the offered file.
    pub fn file_name(&self) -> String {
        self.file_name.to_owned()
    }

    /// Size of the offered file.
    pub fn file_size(&self) -> u64 {
        self.filesize
    }
}
/// Streams every reader yielded by `files` over `transit` and returns the SHA-256
/// digest of all bytes sent.
///
/// Data is read in chunks of up to 16 KiB; each chunk becomes one transit record.
/// `progress_handler` is called with `(0, file_size)` at the start and with
/// `(bytes_sent, file_size)` after every record.
///
/// # Errors
/// Fails on read or transit errors, and with `TransferError::FileSize` when the
/// number of bytes sent differs from `file_size` once all readers are exhausted.
pub(crate) async fn send_records<'a>(
    transit: &mut Transit,
    files: impl futures::Stream<Item = std::io::Result<Box<dyn AsyncRead + Unpin + Send + 'a>>>,
    file_size: u64,
    mut progress_handler: impl FnMut(u64, u64) + 'static,
) -> Result<Vec<u8>, TransferError> {
    progress_handler(0, file_size);

    let mut digest = Sha256::default();
    let mut buffer = vec![0u8; 16 * 1024].into_boxed_slice();
    let mut total_sent: u64 = 0;

    futures::pin_mut!(files);
    while let Some(next) = files.next().await {
        let mut reader = next?;
        loop {
            // A zero-length read marks the end of the current reader.
            let chunk = match reader.read(&mut buffer[..]).await? {
                0 => break,
                len => &buffer[..len],
            };
            transit.send_record(chunk).await?;
            total_sent += chunk.len() as u64;
            progress_handler(total_sent, file_size);
            digest.update(chunk);
        }
    }
    transit.flush().await?;

    ensure!(
        total_sent == file_size,
        TransferError::FileSize {
            sent_size: total_sent,
            file_size
        }
    );
    Ok(digest.finalize_fixed().to_vec())
}
pub(crate) async fn receive_records<F, W>(
filesize: u64,
transit: &mut Transit,
mut progress_handler: F,
mut content_handler: W,
) -> Result<Vec<u8>, TransferError>
where
F: FnMut(u64, u64) + 'static,
W: AsyncWrite + Unpin,
{
let mut hasher = Sha256::default();
let total = filesize;
let mut remaining_size = filesize as usize;
progress_handler(0, total);
while remaining_size > 0 {
let plaintext = transit.receive_record().await?;
content_handler.write_all(&plaintext).await?;
hasher.update(&plaintext);
remaining_size -= plaintext.len();
let remaining = remaining_size as u64;
progress_handler(total - remaining, total);
}
content_handler.close().await?;
tracing::debug!("done");
Ok(hasher.finalize_fixed().to_vec())
}
/// Receives `filesize` bytes over `transit` into `content_handler` and confirms the
/// transfer to the sender.
///
/// After [`receive_records`] has finished, a [`TransitAck`] with status `"ok"` and
/// the hex-encoded SHA-256 of the received data is sent back as a transit record.
pub(crate) async fn tcp_file_receive<F, W>(
    transit: &mut Transit,
    filesize: u64,
    progress_handler: F,
    content_handler: &mut W,
) -> Result<(), TransferError>
where
    F: FnMut(u64, u64) + 'static,
    W: AsyncWrite + Unpin,
{
    let digest = receive_records(filesize, transit, progress_handler, content_handler).await?;

    let sha256sum = hex::encode(digest.as_slice());
    tracing::debug!("sha256 sum: {:?}", sha256sum);

    // The sender compares this checksum with its own before reporting success.
    let ack = TransitAck::new("ok", &sha256sum).serialize_vec();
    transit.send_record(&ack).await?;

    tracing::debug!("Transfer complete");
    Ok(())
}
/// Helpers for producing the pieces of a tar archive (headers and padding) by hand.
mod tar_helper {
    #[allow(unused_imports)]
    use std::{
        borrow::Cow,
        io::{self, Read, Write},
        path::Path,
        str,
    };

    /// Builds the header bytes for a regular file of `size` bytes located at `path`
    /// (components are joined with `/`). The file mode is `0o644`.
    ///
    /// For paths that do not fit into a plain header, the returned bytes start with
    /// a GNU long-name entry.
    pub(crate) fn create_header_file(path: &[String], size: u64) -> std::io::Result<Vec<u8>> {
        let joined = path.join("/");
        let mut out = Vec::with_capacity(1024);

        let mut header = tar::Header::new_gnu();
        header.set_size(size);
        prepare_header_path(&mut out, &mut header, &joined)?;
        header.set_mode(0o644);
        // The checksum covers all other fields, so it is computed last.
        header.set_cksum();

        out.write_all(header.as_bytes())?;
        Ok(out)
    }

    /// Builds the header bytes for a directory located at `path` (components are
    /// joined with `/`). The mode is `0o755`.
    ///
    /// For paths that do not fit into a plain header, the returned bytes start with
    /// a GNU long-name entry.
    pub(crate) fn create_header_directory(path: &[String]) -> std::io::Result<Vec<u8>> {
        let joined = path.join("/");
        let mut out = Vec::with_capacity(1024);

        let mut header = tar::Header::new_gnu();
        header.set_entry_type(tar::EntryType::Directory);
        prepare_header_path(&mut out, &mut header, &joined)?;
        header.set_mode(0o755);
        // The checksum covers all other fields, so it is computed last.
        header.set_cksum();

        out.write_all(header.as_bytes())?;
        Ok(out)
    }

    /// Returns the zero bytes needed to pad content of `size` bytes up to the next
    /// multiple of 512 (empty when `size` already is a multiple).
    pub(crate) fn padding(size: u64) -> &'static [u8] {
        const ZEROS: [u8; 512] = [0; 512];
        match (size % 512) as usize {
            0 => &[],
            used => &ZEROS[used..],
        }
    }

    /// Writes `header`, everything readable from `data`, and the block padding for
    /// the copied amount to `dst`.
    fn append(
        mut dst: &mut dyn std::io::Write,
        header: &tar::Header,
        mut data: &mut dyn std::io::Read,
    ) -> std::io::Result<()> {
        dst.write_all(header.as_bytes())?;
        let copied = io::copy(&mut data, &mut dst)?;
        dst.write_all(padding(copied))
    }

    /// Creates the header of a GNU `././@LongLink` entry whose payload is `size`
    /// bytes plus one trailing NUL byte.
    fn prepare_header(size: u64, entry_type: u8) -> tar::Header {
        const LONG_LINK: &[u8] = b"././@LongLink";

        let mut header = tar::Header::new_gnu();
        let gnu = header.as_gnu_mut().unwrap();
        gnu.name[..LONG_LINK.len()].copy_from_slice(LONG_LINK);

        header.set_mode(0o644);
        header.set_uid(0);
        header.set_gid(0);
        header.set_mtime(0);
        // One extra byte for the NUL terminator that follows the name.
        header.set_size(size + 1);
        header.set_entry_type(tar::EntryType::new(entry_type));
        header.set_cksum();
        header
    }

    /// Stores `path` in `header`.
    ///
    /// When the path cannot be stored directly and is at least as long as the old
    /// name field, a GNU long-name entry carrying the full path is written to `dst`
    /// and `header` receives a truncated path. Shorter paths that fail are reported
    /// with the original error.
    fn prepare_header_path(
        dst: &mut dyn std::io::Write,
        header: &mut tar::Header,
        path: &str,
    ) -> std::io::Result<()> {
        let set_path_err = match header.set_path(path) {
            Ok(()) => return Ok(()),
            Err(err) => err,
        };

        let data = path2bytes(path);
        let max = header.as_old().name.len();
        if data.len() < max {
            // Length is not the problem, so the long-name extension cannot help.
            return Err(set_path_err);
        }

        let long_name_header = prepare_header(data.len() as u64, b'L');
        let mut data2 = data.chain(io::repeat(0).take(1));
        append(dst, &long_name_header, &mut data2)?;

        // Truncate to the field size without cutting a UTF-8 sequence in half.
        let truncated = std::str::from_utf8(&data[..max])
            .unwrap_or_else(|e| std::str::from_utf8(&data[..e.valid_up_to()]).unwrap());
        header.set_path(truncated)?;
        Ok(())
    }

    /// Path bytes with every backslash replaced by a forward slash.
    #[cfg(any(windows, target_arch = "wasm32"))]
    pub(crate) fn path2bytes(p: &str) -> Cow<[u8]> {
        let raw = p.as_bytes();
        if !raw.contains(&b'\\') {
            return Cow::Borrowed(raw);
        }
        let converted = raw
            .iter()
            .map(|&b| if b == b'\\' { b'/' } else { b })
            .collect();
        Cow::Owned(converted)
    }

    /// Path bytes, unchanged.
    #[cfg(unix)]
    pub(crate) fn path2bytes(p: &str) -> Cow<[u8]> {
        Cow::from(p.as_bytes())
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// A transit ack serializes to a JSON object with the `ack` and `sha256` keys.
    #[test]
    fn test_transit_ack() {
        let ack = TransitAck::new("ok", "deadbeaf");
        let expected = r#"{"ack":"ok","sha256":"deadbeaf"}"#;
        assert_eq!(ack.serialize(), expected);
    }
}