1use futures::{
2 io::{AsyncReadExt, AsyncWriteExt},
3 StreamExt, TryFutureExt,
4};
5use sha2::{digest::FixedOutput, Digest, Sha256};
6
7use super::{offer::*, *};
8
/// Offer sent by the sending side of the v1 transfer protocol.
///
/// Variant names are serialized in kebab-case (`message`, `file`, `directory`).
/// This is the wire format shared with other implementations, so variant and
/// field names must not be changed.
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
#[serde(rename_all = "kebab-case")]
pub enum OfferMessage {
    /// Text message payload.
    Message(String),
    /// A single regular file.
    File {
        /// Name of the offered file.
        filename: String,
        /// Size of the file in bytes (the exact number of bytes that will be streamed).
        filesize: u64,
    },
    /// A directory, transferred as a zip archive: `request` exposes it to the
    /// receiver as a file named `<dirname>.zip` of `zipsize` bytes.
    Directory {
        /// Name of the offered directory.
        dirname: String,
        // NOTE(review): meaning/format of `mode` is not visible in this file — confirm.
        mode: String,
        /// Size in bytes of the zip archive that will actually be streamed.
        zipsize: u64,
        // presumably the uncompressed byte count and number of entries — confirm;
        // neither is read in this file.
        numbytes: u64,
        numfiles: u64,
    },
    /// Catch-all for offer kinds this implementation does not understand;
    /// `request` rejects it with `TransferError::UnsupportedOffer`.
    ///
    /// NOTE(review): serde documents `#[serde(other)]` for internally or
    /// adjacently tagged enums; this enum is externally tagged, so confirm that
    /// unknown offers carrying a payload really land here rather than failing
    /// to deserialize.
    #[serde(other)]
    Unknown,
}
27
/// Answer sent by the receiving side in reply to an [`OfferMessage`].
///
/// Unlike the other message types in this file, variants are serialized in
/// snake_case (`message_ack`, `file_ack`). This is the wire format; keep it.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AnswerMessage {
    /// Acknowledgement of a text message (not produced or consumed in this file).
    MessageAck(String),
    /// Acknowledgement of a file offer; the senders in this module require the
    /// payload to be exactly `"ok"` and fail with `TransferError::AckError` otherwise.
    FileAck(String),
}
34
/// Transit connection parameters that each peer sends to the other before the
/// data connection is established.
///
/// Fields are serialized as `abilities-v1` and `hints-v1`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct TransitV1 {
    /// Which transit connection types this side supports.
    pub abilities_v1: TransitAbilities,
    /// Addresses (direct and relay) at which this side can be reached.
    pub hints_v1: transit::Hints,
}
44
/// Final record that the receiver sends over the transit connection once all
/// file bytes have arrived.
///
/// The sender compares `sha256` against the digest of the bytes it sent and
/// fails with `TransferError::Checksum` on mismatch.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "kebab-case")]
struct TransitAck {
    /// Status string; this module always sends `"ok"`.
    pub ack: String,
    /// `hex::encode`d SHA-256 of the received bytes.
    pub sha256: String,
}
51
52impl TransitAck {
53 pub(crate) fn new(msg: impl Into<String>, sha256: impl Into<String>) -> Self {
54 TransitAck {
55 ack: msg.into(),
56 sha256: sha256.into(),
57 }
58 }
59
60 #[cfg(test)]
61 pub(crate) fn serialize(&self) -> String {
62 json!(self).to_string()
63 }
64
65 pub(crate) fn serialize_vec(&self) -> Vec<u8> {
66 serde_json::to_vec(self).unwrap()
67 }
68}
69
70pub(crate) async fn send(
71 wormhole: Wormhole,
72 relay_hints: Vec<transit::RelayHint>,
73 transit_abilities: transit::Abilities,
74 offer: OfferSend,
75 progress_handler: impl FnMut(u64, u64) + 'static,
76 transit_handler: impl FnOnce(transit::TransitInfo),
77 _peer_version: AppVersion,
78 cancel: impl Future<Output = ()>,
79) -> Result<(), TransferError> {
80 if offer.is_multiple() {
81 let folder = OfferSendEntry::Directory {
82 content: offer.content,
83 };
84 send_folder(
85 wormhole,
86 relay_hints,
87 "<unnamed folder>".into(),
88 folder,
89 transit_abilities,
90 transit_handler,
91 progress_handler,
92 cancel,
93 )
94 .await
95 } else if offer.is_directory() {
96 let (folder_name, folder) = offer.content.into_iter().next().unwrap();
97 send_folder(
98 wormhole,
99 relay_hints,
100 folder_name,
101 folder,
102 transit_abilities,
103 transit_handler,
104 progress_handler,
105 cancel,
106 )
107 .await
108 } else {
109 let (file_name, file) = offer.content.into_iter().next().unwrap();
110 let (mut file, file_size) = match file {
111 OfferSendEntry::RegularFile { content, size } => {
112 let content = content();
114 let content = content.await?;
115 (content, size)
116 },
117 _ => unreachable!(),
118 };
119 send_file(
120 wormhole,
121 relay_hints,
122 &mut file,
123 file_name,
124 file_size,
125 transit_abilities,
126 transit_handler,
127 progress_handler,
128 cancel,
129 )
130 .await
131 }
132}
133
134pub(crate) async fn send_file<F, G, H>(
135 mut wormhole: Wormhole,
136 relay_hints: Vec<transit::RelayHint>,
137 file: &mut F,
138 file_name: impl Into<String>,
139 file_size: u64,
140 transit_abilities: transit::Abilities,
141 transit_handler: G,
142 progress_handler: H,
143 cancel: impl Future<Output = ()>,
144) -> Result<(), TransferError>
145where
146 F: AsyncRead + Unpin + Send,
147 G: FnOnce(transit::TransitInfo),
148 H: FnMut(u64, u64) + 'static,
149{
150 let run = Box::pin(async {
151 let connector = transit::init(transit_abilities, None, relay_hints).await?;
152
153 tracing::debug!("Sending transit message '{:?}", connector.our_hints());
155 wormhole
156 .send_json(&PeerMessage::transit_v1(
157 *connector.our_abilities(),
158 (**connector.our_hints()).clone(),
159 ))
160 .await?;
161
162 tracing::debug!("Sending file offer");
164 wormhole
165 .send_json(&PeerMessage::offer_file_v1(file_name, file_size))
166 .await?;
167
168 let (their_abilities, their_hints): (transit::Abilities, transit::Hints) =
170 match wormhole.receive_json::<PeerMessage>().await??.check_err()? {
171 PeerMessage::Transit(transit) => {
172 tracing::debug!("Received transit message: {:?}", transit);
173 (transit.abilities_v1, transit.hints_v1)
174 },
175 other => {
176 bail!(TransferError::unexpected_message("transit", other))
177 },
178 };
179
180 {
181 let fileack_msg = wormhole.receive_json::<PeerMessage>().await??;
183 tracing::debug!("Received file ack message: {:?}", fileack_msg);
184
185 match fileack_msg.check_err()? {
186 PeerMessage::Answer(AnswerMessage::FileAck(msg)) => {
187 ensure!(msg == "ok", TransferError::AckError);
188 },
189 _ => {
190 bail!(TransferError::unexpected_message(
191 "answer/file_ack",
192 fileack_msg
193 ));
194 },
195 }
196 }
197
198 let (mut transit, info) = connector
199 .leader_connect(
200 wormhole.key().derive_transit_key(wormhole.appid()),
201 their_abilities,
202 Arc::new(their_hints),
203 )
204 .await?;
205 transit_handler(info);
206
207 tracing::debug!("Beginning file transfer");
208
209 let file = futures::stream::once(futures::future::ready(std::io::Result::Ok(
211 Box::new(file) as Box<dyn AsyncRead + Unpin + Send>,
212 )));
213 let checksum = v1::send_records(&mut transit, file, file_size, progress_handler).await?;
214
215 tracing::debug!("sent file. Waiting for ack");
217 let transit_ack = transit.receive_record().await?;
218 let transit_ack_msg = serde_json::from_slice::<TransitAck>(&transit_ack)?;
219 ensure!(
220 transit_ack_msg.sha256 == hex::encode(checksum),
221 TransferError::Checksum
222 );
223 tracing::debug!("Transfer complete!");
224
225 Ok(())
226 });
227
228 futures::pin_mut!(cancel);
229 let result = cancel::cancellable_2(run, cancel).await;
230 cancel::handle_run_result(wormhole, result).await
231}
232
233pub(crate) async fn send_folder(
234 mut wormhole: Wormhole,
235 relay_hints: Vec<transit::RelayHint>,
236 mut folder_name: String,
237 folder: OfferSendEntry,
238 transit_abilities: transit::Abilities,
239 transit_handler: impl FnOnce(transit::TransitInfo),
240 progress_handler: impl FnMut(u64, u64) + 'static,
241 cancel: impl Future<Output = ()>,
242) -> Result<(), TransferError> {
243 let run = Box::pin(async {
244 let connector = transit::init(transit_abilities, None, relay_hints).await?;
245
246 tracing::debug!("Sending transit message '{:?}", connector.our_hints());
248 wormhole
249 .send_json(&PeerMessage::transit_v1(
250 *connector.our_abilities(),
251 (**connector.our_hints()).clone(),
252 ))
253 .await?;
254
255 tracing::debug!("Estimating the file size");
260
261 use futures::{
263 future::{ready, BoxFuture},
264 io::Cursor,
265 };
266 use std::io::Result as IoResult;
267
268 type WrappedDataFut = BoxFuture<'static, IoResult<Box<dyn AsyncRead + Unpin + Send>>>;
269
270 fn wrap(buffer: impl AsRef<[u8]> + Unpin + Send + 'static) -> WrappedDataFut {
272 Box::pin(ready(IoResult::Ok(
273 Box::new(Cursor::new(buffer)) as Box<dyn AsyncRead + Unpin + Send>
274 ))) as _
275 }
276
277 fn create_offer(
279 mut total_content: Vec<WrappedDataFut>,
280 total_size: &mut u64,
281 offer: OfferSendEntry,
282 path: &mut Vec<String>,
283 ) -> IoResult<Vec<WrappedDataFut>> {
284 match offer {
285 OfferSendEntry::Directory { content } => {
286 tracing::debug!("Adding directory {path:?}");
287 let header = tar_helper::create_header_directory(path)?;
288 *total_size += header.len() as u64;
289 total_content.push(wrap(header));
290
291 for (name, file) in content {
292 path.push(name);
293 total_content = create_offer(total_content, total_size, file, path)?;
294 path.pop();
295 }
296 },
297 OfferSendEntry::RegularFile { size, content } => {
298 tracing::debug!("Adding file {path:?}; {size} bytes");
299 let header = tar_helper::create_header_file(path, size)?;
300 let padding = tar_helper::padding(size);
301 *total_size += header.len() as u64;
302 *total_size += padding.len() as u64;
303 *total_size += size;
304
305 total_content.push(wrap(header));
306 let content = content().map_ok(
307 |read| Box::new(read) as Box<dyn AsyncRead + Unpin + Send>,
309 );
310 total_content.push(Box::pin(content) as _);
311 total_content.push(wrap(padding));
312 },
313 }
315 Ok(total_content)
316 }
317
318 let mut total_size = 0;
319 let mut content = create_offer(
320 Vec::new(),
321 &mut total_size,
322 folder,
323 &mut vec![folder_name.clone()],
324 )?;
325
326 total_size += 1024;
328 content.push(wrap([0; 1024]));
329
330 let content = futures::stream::iter(content).then(|content| content);
331
332 tracing::debug!("Sending file offer ({total_size} bytes)");
336 folder_name.push_str(".tar");
337 wormhole
338 .send_json(&PeerMessage::offer_file_v1(folder_name, total_size))
339 .await?;
340
341 let (their_abilities, their_hints): (transit::Abilities, transit::Hints) =
343 match wormhole.receive_json::<PeerMessage>().await??.check_err()? {
344 PeerMessage::Transit(transit) => {
345 tracing::debug!("received transit message: {:?}", transit);
346 (transit.abilities_v1, transit.hints_v1)
347 },
348 other => {
349 bail!(TransferError::unexpected_message("transit", other));
350 },
351 };
352
353 match wormhole.receive_json::<PeerMessage>().await??.check_err()? {
355 PeerMessage::Answer(AnswerMessage::FileAck(msg)) => {
356 ensure!(msg == "ok", TransferError::AckError);
357 },
358 other => {
359 bail!(TransferError::unexpected_message("answer/file_ack", other));
360 },
361 }
362
363 let (mut transit, info) = connector
364 .leader_connect(
365 wormhole.key().derive_transit_key(wormhole.appid()),
366 their_abilities,
367 Arc::new(their_hints),
368 )
369 .await?;
370 transit_handler(info);
371
372 tracing::debug!("Beginning file transfer");
373
374 let checksum =
376 v1::send_records(&mut transit, content, total_size, progress_handler).await?;
377
378 tracing::debug!("sent file. Waiting for ack");
380 let transit_ack = transit.receive_record().await?;
381 let transit_ack_msg = serde_json::from_slice::<TransitAck>(&transit_ack)?;
382 ensure!(
383 transit_ack_msg.sha256 == hex::encode(checksum),
384 TransferError::Checksum
385 );
386 tracing::debug!("Transfer complete!");
387
388 Ok(())
389 });
390
391 futures::pin_mut!(cancel);
392 let result = cancel::cancellable_2(run, cancel).await;
393 cancel::handle_run_result(wormhole, result).await
394}
395
396pub async fn request(
405 mut wormhole: Wormhole,
406 relay_hints: Vec<transit::RelayHint>,
407 transit_abilities: transit::Abilities,
408 cancel: impl Future<Output = ()>,
409) -> Result<Option<ReceiveRequest>, TransferError> {
410 let run = Box::pin(async {
412 let connector = transit::init(transit_abilities, None, relay_hints).await?;
413
414 tracing::debug!("Sending transit message '{:?}", connector.our_hints());
416 wormhole
417 .send_json(&PeerMessage::transit_v1(
418 *connector.our_abilities(),
419 (**connector.our_hints()).clone(),
420 ))
421 .await?;
422
423 let (their_abilities, their_hints): (transit::Abilities, transit::Hints) =
425 match wormhole.receive_json::<PeerMessage>().await??.check_err()? {
426 PeerMessage::Transit(transit) => {
427 tracing::debug!("received transit message: {:?}", transit);
428 (transit.abilities_v1, transit.hints_v1)
429 },
430 other => {
431 bail!(TransferError::unexpected_message("transit", other));
432 },
433 };
434
435 let (filename, filesize) =
437 match wormhole.receive_json::<PeerMessage>().await??.check_err()? {
438 PeerMessage::Offer(offer_type) => match offer_type {
439 v1::OfferMessage::File { filename, filesize } => (filename, filesize),
440 v1::OfferMessage::Directory {
441 mut dirname,
442 zipsize,
443 ..
444 } => {
445 dirname.push_str(".zip");
446 (dirname, zipsize)
447 },
448 _ => bail!(TransferError::UnsupportedOffer),
449 },
450 other => {
451 bail!(TransferError::unexpected_message("offer", other));
452 },
453 };
454
455 Ok((filename, filesize, connector, their_abilities, their_hints))
456 });
457
458 futures::pin_mut!(cancel);
459 let result = cancel::cancellable_2(run, cancel).await;
460 cancel::handle_run_result_noclose(wormhole, result)
461 .await
462 .map(|inner: Option<_>| {
463 inner.map(
464 |((filename, filesize, connector, their_abilities, their_hints), wormhole, _)| {
465 ReceiveRequest::new(
466 filename,
467 filesize,
468 connector,
469 their_abilities,
470 their_hints,
471 wormhole,
472 )
473 },
474 )
475 })
476}
477
/// A pending v1 file offer from the peer, produced by [`request`].
///
/// Resolve it with [`ReceiveRequest::accept`] (download the file) or
/// [`ReceiveRequest::reject`] (decline it); it is `#[must_use]`.
#[must_use]
pub struct ReceiveRequest {
    // Wormhole connection kept open so the answer can be sent later.
    wormhole: Wormhole,
    // Transit connector initialised in `request`; used to connect as follower on accept.
    connector: TransitConnector,

    /// The offered file name as a path.
    ///
    /// NOTE(review): this comes straight from the peer's offer; sanitize it
    /// before using it as a filesystem path.
    #[deprecated(since = "0.7.0", note = "use ReceiveRequest::file_name(..) instead")]
    #[cfg(not(target_family = "wasm"))]
    pub filename: PathBuf,
    // Offered file name as sent by the peer (`<dirname>.zip` for directory offers).
    file_name: String,

    /// The offered size in bytes.
    #[deprecated(since = "0.7.0", note = "use ReceiveRequest::file_size(..) instead")]
    pub filesize: u64,

    // Single-entry offer built from the v1 name/size; only exposed via `offer()`
    // with the `experimental-transfer-v2` feature, hence the dead_code allowance.
    #[allow(dead_code)]
    offer: Arc<Offer>,
    // Peer transit parameters received during `request`.
    their_abilities: transit::Abilities,
    their_hints: Arc<transit::Hints>,
}
503
// The impl reads and writes the deprecated public fields, hence the allowance.
#[allow(deprecated)]
impl ReceiveRequest {
    /// Builds a request from the values negotiated in `request`.
    ///
    /// Also synthesizes a single-entry [`Offer`] (one regular file named
    /// `file_name` of `filesize` bytes) so v1 requests expose the same shape
    /// as the newer offer API.
    fn new(
        file_name: String,
        filesize: u64,
        connector: TransitConnector,
        their_abilities: transit::Abilities,
        their_hints: transit::Hints,
        wormhole: Wormhole,
    ) -> Self {
        let their_hints = Arc::new(their_hints);
        let mut content = BTreeMap::new();

        content.insert(
            file_name.clone(),
            OfferEntry::RegularFile {
                size: filesize,
                content: (),
            },
        );

        let offer = Arc::new(Offer { content });

        #[allow(deprecated)]
        Self {
            wormhole,
            connector,
            #[cfg(not(target_family = "wasm"))]
            filename: PathBuf::from(file_name.clone()),
            file_name,
            filesize,
            offer,
            their_abilities,
            their_hints,
        }
    }

    /// Accepts the offer and downloads the file into `content_handler`.
    ///
    /// Steps:
    /// 1. Send a `file_ack` of `"ok"`.
    /// 2. Connect the transit as follower and call `transit_handler` with the
    ///    connection info.
    /// 3. Receive the offered number of bytes via `tcp_file_receive`. That
    ///    writes them to `content_handler`, closes it, and replies with a
    ///    [`TransitAck`] carrying the SHA-256.
    ///
    /// `progress_handler` receives `(received_bytes, total_bytes)`.
    ///
    /// The transfer races against `cancel`; the wormhole is passed to
    /// `cancel::handle_run_result` at the end either way.
    pub async fn accept<F, G, W>(
        mut self,
        transit_handler: G,
        progress_handler: F,
        content_handler: &mut W,
        cancel: impl Future<Output = ()>,
    ) -> Result<(), TransferError>
    where
        F: FnMut(u64, u64) + 'static,
        G: FnOnce(transit::TransitInfo),
        W: AsyncWrite + Unpin,
    {
        let run = Box::pin(async {
            tracing::debug!("Sending ack");
            self.wormhole
                .send_json(&PeerMessage::file_ack_v1("ok"))
                .await?;

            let (mut transit, info) = self
                .connector
                .follower_connect(
                    self.wormhole
                        .key()
                        .derive_transit_key(self.wormhole.appid()),
                    self.their_abilities,
                    self.their_hints.clone(),
                )
                .await?;
            transit_handler(info);

            tracing::debug!("Beginning file transfer");
            tcp_file_receive(
                &mut transit,
                self.filesize,
                progress_handler,
                content_handler,
            )
            .await?;
            Ok(())
        });

        futures::pin_mut!(cancel);
        let result = cancel::cancellable_2(run, cancel).await;
        cancel::handle_run_result(self.wormhole, result).await
    }

    /// Declines the offer: sends the peer an error message
    /// (`"transfer rejected"`), then closes the wormhole.
    pub async fn reject(mut self) -> Result<(), TransferError> {
        self.wormhole
            .send_json(&PeerMessage::error_message("transfer rejected"))
            .await?;
        self.wormhole.close().await?;

        Ok(())
    }

    /// The offer as a single-entry [`Offer`] (one regular file).
    #[cfg(feature = "experimental-transfer-v2")]
    #[allow(missing_docs)]
    pub fn offer(&self) -> Arc<Offer> {
        self.offer.clone()
    }

    /// The offered file name, as sent by the peer.
    pub fn file_name(&self) -> String {
        self.file_name.clone()
    }

    /// The offered size in bytes.
    pub fn file_size(&self) -> u64 {
        self.filesize
    }
}
626
627pub(crate) async fn send_records<'a>(
630 transit: &mut Transit,
631 files: impl futures::Stream<Item = std::io::Result<Box<dyn AsyncRead + Unpin + Send + 'a>>>,
632 file_size: u64,
633 mut progress_handler: impl FnMut(u64, u64) + 'static,
634) -> Result<Vec<u8>, TransferError> {
635 progress_handler(0, file_size);
646
647 let mut hasher = Sha256::default();
648
649 let mut plaintext = vec![0u8; 16 * 1024].into_boxed_slice();
650 let mut sent_size = 0;
651 futures::pin_mut!(files);
652 while let Some(mut file) = files.next().await.transpose()? {
653 loop {
654 let n = file.read(&mut plaintext[..]).await?;
656
657 if n == 0 {
658 break;
660 }
661
662 transit.send_record(&plaintext[0..n]).await?;
664 sent_size += n as u64;
665 progress_handler(sent_size, file_size);
666
667 hasher.update(&plaintext[..n]);
669
670 }
675 }
676 transit.flush().await?;
677
678 ensure!(
679 sent_size == file_size,
680 TransferError::FileSize {
681 sent_size,
682 file_size
683 }
684 );
685
686 Ok(hasher.finalize_fixed().to_vec())
687}
688
689pub(crate) async fn receive_records<F, W>(
690 filesize: u64,
691 transit: &mut Transit,
692 mut progress_handler: F,
693 mut content_handler: W,
694) -> Result<Vec<u8>, TransferError>
695where
696 F: FnMut(u64, u64) + 'static,
697 W: AsyncWrite + Unpin,
698{
699 let mut hasher = Sha256::default();
700 let total = filesize;
701
702 let mut remaining_size = filesize as usize;
703
704 progress_handler(0, total);
707
708 while remaining_size > 0 {
709 let plaintext = transit.receive_record().await?;
711
712 content_handler.write_all(&plaintext).await?;
713
714 hasher.update(&plaintext);
716
717 remaining_size -= plaintext.len();
718
719 let remaining = remaining_size as u64;
720 progress_handler(total - remaining, total);
721 }
722 content_handler.close().await?;
723
724 tracing::debug!("done");
725 Ok(hasher.finalize_fixed().to_vec())
727}
728
729pub(crate) async fn tcp_file_receive<F, W>(
730 transit: &mut Transit,
731 filesize: u64,
732 progress_handler: F,
733 content_handler: &mut W,
734) -> Result<(), TransferError>
735where
736 F: FnMut(u64, u64) + 'static,
737 W: AsyncWrite + Unpin,
738{
739 let checksum = receive_records(filesize, transit, progress_handler, content_handler).await?;
743
744 let sha256sum = hex::encode(checksum.as_slice());
745 tracing::debug!("sha256 sum: {:?}", sha256sum);
746
747 transit
749 .send_record(&TransitAck::new("ok", &sha256sum).serialize_vec())
750 .await?;
751
752 tracing::debug!("Transfer complete");
755 Ok(())
756}
757
758mod tar_helper {
760 #[allow(unused_imports)]
762 use std::{
763 borrow::Cow,
764 io::{self, Read, Write},
765 path::Path,
766 str,
767 };
768
769 pub(crate) fn create_header_file(path: &[String], size: u64) -> std::io::Result<Vec<u8>> {
770 let mut header = tar::Header::new_gnu();
771 header.set_size(size);
772 let mut data = Vec::with_capacity(1024);
773 prepare_header_path(&mut data, &mut header, path.join("/").as_ref())?;
774 header.set_mode(0o644);
775 header.set_cksum();
776 data.write_all(header.as_bytes())?;
777 Ok(data)
778 }
779
780 pub(crate) fn create_header_directory(path: &[String]) -> std::io::Result<Vec<u8>> {
781 let mut header = tar::Header::new_gnu();
782 header.set_entry_type(tar::EntryType::Directory);
783 let mut data = Vec::with_capacity(1024);
784 prepare_header_path(&mut data, &mut header, path.join("/").as_ref())?;
785 header.set_mode(0o755);
786 header.set_cksum();
787 data.write_all(header.as_bytes())?;
788 Ok(data)
790 }
791
792 pub(crate) fn padding(size: u64) -> &'static [u8] {
793 const BLOCK: [u8; 512] = [0; 512];
794 if size % 512 != 0 {
795 &BLOCK[size as usize % 512..]
796 } else {
797 &[]
798 }
799 }
800
801 fn append(
802 mut dst: &mut dyn std::io::Write,
803 header: &tar::Header,
804 mut data: &mut dyn std::io::Read,
805 ) -> std::io::Result<()> {
806 dst.write_all(header.as_bytes())?;
807 let len = std::io::copy(&mut data, &mut dst)?;
808 dst.write_all(padding(len))?;
809 Ok(())
810 }
811
812 fn prepare_header(size: u64, entry_type: u8) -> tar::Header {
813 let mut header = tar::Header::new_gnu();
814 let name = b"././@LongLink";
815 header.as_gnu_mut().unwrap().name[..name.len()].clone_from_slice(&name[..]);
816 header.set_mode(0o644);
817 header.set_uid(0);
818 header.set_gid(0);
819 header.set_mtime(0);
820 header.set_size(size + 1);
822 header.set_entry_type(tar::EntryType::new(entry_type));
823 header.set_cksum();
824 header
825 }
826
827 fn prepare_header_path(
828 dst: &mut dyn std::io::Write,
829 header: &mut tar::Header,
830 path: &str,
831 ) -> std::io::Result<()> {
832 if let Err(e) = header.set_path(path) {
837 let data = path2bytes(path);
838 let max = header.as_old().name.len();
839 if data.len() < max {
842 return Err(e);
843 }
844 let header2 = prepare_header(data.len() as u64, b'L');
845 let mut data2 = data.chain(io::repeat(0).take(1));
847 append(dst, &header2, &mut data2)?;
848
849 let truncated = match std::str::from_utf8(&data[..max]) {
855 Ok(s) => s,
856 Err(e) => std::str::from_utf8(&data[..e.valid_up_to()]).unwrap(),
857 };
858 header.set_path(truncated)?;
859 }
860 Ok(())
861 }
862
863 #[cfg(any(windows, target_arch = "wasm32"))]
864 pub(crate) fn path2bytes(p: &str) -> Cow<[u8]> {
865 let bytes = p.as_bytes();
866 if bytes.contains(&b'\\') {
867 let mut bytes = bytes.to_owned();
869 for b in &mut bytes {
870 if *b == b'\\' {
871 *b = b'/';
872 }
873 }
874 Cow::Owned(bytes)
875 } else {
876 Cow::Borrowed(bytes)
877 }
878 }
879
880 #[cfg(unix)]
881 pub(crate) fn path2bytes(p: &str) -> Cow<[u8]> {
882 Cow::Borrowed(p.as_bytes())
883 }
884}
885
#[cfg(test)]
mod test {
    use super::*;

    /// The ack must serialize to exactly this compact JSON.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_transit_ack() {
        let expected = r#"{"ack":"ok","sha256":"deadbeaf"}"#;
        let ack = TransitAck::new("ok", "deadbeaf");
        assert_eq!(ack.serialize(), expected);
    }
}
896}