1use futures::{
2 StreamExt, TryFutureExt,
3 io::{AsyncReadExt, AsyncWriteExt},
4};
5use sha2::{Digest, Sha256, digest::FixedOutput};
6
7use crate::transit::TransitRole;
8
9use super::{offer::*, *};
10
11#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
12#[serde(rename_all = "kebab-case")]
13pub enum OfferMessage {
14 Message(String),
15 File {
16 filename: String,
17 filesize: u64,
18 },
19 Directory {
20 dirname: String,
21 mode: String,
22 zipsize: u64,
23 numbytes: u64,
24 numfiles: u64,
25 },
26 #[serde(other)]
27 Unknown,
28}
29
30#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
31#[serde(rename_all = "snake_case")]
32pub enum AnswerMessage {
33 MessageAck(String),
34 FileAck(String),
35}
36
37#[derive(Serialize, Deserialize, Debug, Clone)]
41#[serde(rename_all = "kebab-case")]
42pub struct TransitV1 {
43 pub abilities_v1: TransitAbilities,
44 pub hints_v1: transit::Hints,
45}
46
47#[derive(Serialize, Deserialize, Debug, PartialEq)]
48#[serde(rename_all = "kebab-case")]
49struct TransitAck {
50 pub ack: String,
51 pub sha256: String,
52}
53
54impl TransitAck {
55 pub(crate) fn new(msg: impl Into<String>, sha256: impl Into<String>) -> Self {
56 TransitAck {
57 ack: msg.into(),
58 sha256: sha256.into(),
59 }
60 }
61
62 #[cfg(test)]
63 pub(crate) fn serialize(&self) -> String {
64 json!(self).to_string()
65 }
66
67 pub(crate) fn serialize_vec(&self) -> Vec<u8> {
68 serde_json::to_vec(self).unwrap()
69 }
70}
71
72pub(crate) async fn send(
73 wormhole: Wormhole,
74 relay_hints: Vec<transit::RelayHint>,
75 transit_abilities: transit::Abilities,
76 offer: OfferSend,
77 progress_handler: impl FnMut(u64, u64) + 'static,
78 transit_handler: impl FnOnce(transit::TransitInfo),
79 _peer_version: AppVersion,
80 cancel: impl Future<Output = ()>,
81) -> Result<(), TransferError> {
82 if offer.is_multiple() {
83 let folder = OfferSendEntry::Directory {
84 content: offer.content,
85 };
86 send_folder(
87 wormhole,
88 relay_hints,
89 "<unnamed folder>".into(),
90 folder,
91 transit_abilities,
92 transit_handler,
93 progress_handler,
94 cancel,
95 )
96 .await
97 } else if offer.is_directory() {
98 let (folder_name, folder) = offer.content.into_iter().next().unwrap();
99 send_folder(
100 wormhole,
101 relay_hints,
102 folder_name,
103 folder,
104 transit_abilities,
105 transit_handler,
106 progress_handler,
107 cancel,
108 )
109 .await
110 } else {
111 let (file_name, file) = offer.content.into_iter().next().unwrap();
112 let (mut file, file_size) = match file {
113 OfferSendEntry::RegularFile { content, size } => {
114 let content = content();
116 let content = content.await?;
117 (content, size)
118 },
119 _ => unreachable!(),
120 };
121 send_file(
122 wormhole,
123 relay_hints,
124 &mut file,
125 file_name,
126 file_size,
127 transit_abilities,
128 transit_handler,
129 progress_handler,
130 cancel,
131 )
132 .await
133 }
134}
135
136pub(crate) async fn send_file<F, G, H>(
137 mut wormhole: Wormhole,
138 relay_hints: Vec<transit::RelayHint>,
139 file: &mut F,
140 file_name: impl Into<String>,
141 file_size: u64,
142 transit_abilities: transit::Abilities,
143 transit_handler: G,
144 progress_handler: H,
145 cancel: impl Future<Output = ()>,
146) -> Result<(), TransferError>
147where
148 F: AsyncRead + Unpin + Send,
149 G: FnOnce(transit::TransitInfo),
150 H: FnMut(u64, u64) + 'static,
151{
152 let run = Box::pin(async {
153 let connector = transit::init(transit_abilities, None, relay_hints).await?;
154
155 tracing::debug!("Sending transit message '{:?}", connector.our_hints());
157 wormhole
158 .send_json(&PeerMessage::transit_v1(
159 *connector.our_abilities(),
160 (**connector.our_hints()).clone(),
161 ))
162 .await?;
163
164 tracing::debug!("Sending file offer");
166 wormhole
167 .send_json(&PeerMessage::offer_file_v1(file_name, file_size))
168 .await?;
169
170 let (their_abilities, their_hints): (transit::Abilities, transit::Hints) =
172 match wormhole.receive_json::<PeerMessage>().await??.check_err()? {
173 PeerMessage::Transit(transit) => {
174 tracing::debug!("Received transit message: {:?}", transit);
175 (transit.abilities_v1, transit.hints_v1)
176 },
177 other => {
178 bail!(TransferError::unexpected_message("transit", other))
179 },
180 };
181
182 {
183 let fileack_msg = wormhole.receive_json::<PeerMessage>().await??;
185 tracing::debug!("Received file ack message: {:?}", fileack_msg);
186
187 match fileack_msg.check_err()? {
188 PeerMessage::Answer(AnswerMessage::FileAck(msg)) => {
189 ensure!(msg == "ok", TransferError::AckError);
190 },
191 _ => {
192 bail!(TransferError::unexpected_message(
193 "answer/file_ack",
194 fileack_msg
195 ));
196 },
197 }
198 }
199
200 let (mut transit, info) = connector
201 .connect(
202 TransitRole::Leader,
203 wormhole.key().derive_transit_key(wormhole.appid()),
204 their_abilities,
205 Arc::new(their_hints),
206 )
207 .await?;
208 transit_handler(info);
209
210 tracing::debug!("Beginning file transfer");
211
212 let file = futures::stream::once(futures::future::ready(std::io::Result::Ok(
214 Box::new(file) as Box<dyn AsyncRead + Unpin + Send>,
215 )));
216 let checksum = v1::send_records(&mut transit, file, file_size, progress_handler).await?;
217
218 tracing::debug!("sent file. Waiting for ack");
220 let transit_ack = transit.receive_record().await?;
221 let transit_ack_msg = serde_json::from_slice::<TransitAck>(&transit_ack)?;
222 ensure!(
223 transit_ack_msg.sha256 == hex::encode(checksum),
224 TransferError::Checksum
225 );
226 tracing::debug!("Transfer complete!");
227
228 Ok(())
229 });
230
231 futures::pin_mut!(cancel);
232 let result = cancel::cancellable_2(run, cancel).await;
233 cancel::handle_run_result(wormhole, result).await
234}
235
236pub(crate) async fn send_folder(
237 mut wormhole: Wormhole,
238 relay_hints: Vec<transit::RelayHint>,
239 mut folder_name: String,
240 folder: OfferSendEntry,
241 transit_abilities: transit::Abilities,
242 transit_handler: impl FnOnce(transit::TransitInfo),
243 progress_handler: impl FnMut(u64, u64) + 'static,
244 cancel: impl Future<Output = ()>,
245) -> Result<(), TransferError> {
246 let run = Box::pin(async {
247 let connector = transit::init(transit_abilities, None, relay_hints).await?;
248
249 tracing::debug!("Sending transit message '{:?}", connector.our_hints());
251 wormhole
252 .send_json(&PeerMessage::transit_v1(
253 *connector.our_abilities(),
254 (**connector.our_hints()).clone(),
255 ))
256 .await?;
257
258 tracing::debug!("Estimating the file size");
263
264 use futures::{
266 future::{BoxFuture, ready},
267 io::Cursor,
268 };
269 use std::io::Result as IoResult;
270
271 type WrappedDataFut = BoxFuture<'static, IoResult<Box<dyn AsyncRead + Unpin + Send>>>;
272
273 fn wrap(buffer: impl AsRef<[u8]> + Unpin + Send + 'static) -> WrappedDataFut {
275 Box::pin(ready(IoResult::Ok(
276 Box::new(Cursor::new(buffer)) as Box<dyn AsyncRead + Unpin + Send>
277 ))) as _
278 }
279
280 fn create_offer(
282 mut total_content: Vec<WrappedDataFut>,
283 total_size: &mut u64,
284 offer: OfferSendEntry,
285 path: &mut Vec<String>,
286 ) -> IoResult<Vec<WrappedDataFut>> {
287 match offer {
288 OfferSendEntry::Directory { content } => {
289 tracing::debug!("Adding directory {path:?}");
290 let header = tar_helper::create_header_directory(path)?;
291 *total_size += header.len() as u64;
292 total_content.push(wrap(header));
293
294 for (name, file) in content {
295 path.push(name);
296 total_content = create_offer(total_content, total_size, file, path)?;
297 path.pop();
298 }
299 },
300 OfferSendEntry::RegularFile { size, content } => {
301 tracing::debug!("Adding file {path:?}; {size} bytes");
302 let header = tar_helper::create_header_file(path, size)?;
303 let padding = tar_helper::padding(size);
304 *total_size += header.len() as u64;
305 *total_size += padding.len() as u64;
306 *total_size += size;
307
308 total_content.push(wrap(header));
309 let content = content().map_ok(
310 |read| Box::new(read) as Box<dyn AsyncRead + Unpin + Send>,
312 );
313 total_content.push(Box::pin(content) as _);
314 total_content.push(wrap(padding));
315 },
316 }
318 Ok(total_content)
319 }
320
321 let mut total_size = 0;
322 let mut content = create_offer(
323 Vec::new(),
324 &mut total_size,
325 folder,
326 &mut vec![folder_name.clone()],
327 )?;
328
329 total_size += 1024;
331 content.push(wrap([0; 1024]));
332
333 let content = futures::stream::iter(content).then(|content| content);
334
335 tracing::debug!("Sending file offer ({total_size} bytes)");
339 folder_name.push_str(".tar");
340 wormhole
341 .send_json(&PeerMessage::offer_file_v1(folder_name, total_size))
342 .await?;
343
344 let (their_abilities, their_hints): (transit::Abilities, transit::Hints) =
346 match wormhole.receive_json::<PeerMessage>().await??.check_err()? {
347 PeerMessage::Transit(transit) => {
348 tracing::debug!("received transit message: {:?}", transit);
349 (transit.abilities_v1, transit.hints_v1)
350 },
351 other => {
352 bail!(TransferError::unexpected_message("transit", other));
353 },
354 };
355
356 match wormhole.receive_json::<PeerMessage>().await??.check_err()? {
358 PeerMessage::Answer(AnswerMessage::FileAck(msg)) => {
359 ensure!(msg == "ok", TransferError::AckError);
360 },
361 other => {
362 bail!(TransferError::unexpected_message("answer/file_ack", other));
363 },
364 }
365
366 let (mut transit, info) = connector
367 .connect(
368 TransitRole::Leader,
369 wormhole.key().derive_transit_key(wormhole.appid()),
370 their_abilities,
371 Arc::new(their_hints),
372 )
373 .await?;
374 transit_handler(info);
375
376 tracing::debug!("Beginning file transfer");
377
378 let checksum =
380 v1::send_records(&mut transit, content, total_size, progress_handler).await?;
381
382 tracing::debug!("sent file. Waiting for ack");
384 let transit_ack = transit.receive_record().await?;
385 let transit_ack_msg = serde_json::from_slice::<TransitAck>(&transit_ack)?;
386 ensure!(
387 transit_ack_msg.sha256 == hex::encode(checksum),
388 TransferError::Checksum
389 );
390 tracing::debug!("Transfer complete!");
391
392 Ok(())
393 });
394
395 futures::pin_mut!(cancel);
396 let result = cancel::cancellable_2(run, cancel).await;
397 cancel::handle_run_result(wormhole, result).await
398}
399
400pub async fn request(
409 mut wormhole: Wormhole,
410 relay_hints: Vec<transit::RelayHint>,
411 transit_abilities: transit::Abilities,
412 cancel: impl Future<Output = ()>,
413) -> Result<Option<ReceiveRequest>, TransferError> {
414 let run = Box::pin(async {
416 let connector = transit::init(transit_abilities, None, relay_hints).await?;
417
418 tracing::debug!("Sending transit message '{:?}", connector.our_hints());
420 wormhole
421 .send_json(&PeerMessage::transit_v1(
422 *connector.our_abilities(),
423 (**connector.our_hints()).clone(),
424 ))
425 .await?;
426
427 let (their_abilities, their_hints): (transit::Abilities, transit::Hints) =
429 match wormhole.receive_json::<PeerMessage>().await??.check_err()? {
430 PeerMessage::Transit(transit) => {
431 tracing::debug!("received transit message: {:?}", transit);
432 (transit.abilities_v1, transit.hints_v1)
433 },
434 other => {
435 bail!(TransferError::unexpected_message("transit", other));
436 },
437 };
438
439 let (filename, filesize) =
441 match wormhole.receive_json::<PeerMessage>().await??.check_err()? {
442 PeerMessage::Offer(offer_type) => match offer_type {
443 v1::OfferMessage::File { filename, filesize } => (filename, filesize),
444 v1::OfferMessage::Directory {
445 mut dirname,
446 zipsize,
447 ..
448 } => {
449 dirname.push_str(".zip");
450 (dirname, zipsize)
451 },
452 _ => bail!(TransferError::UnsupportedOffer),
453 },
454 other => {
455 bail!(TransferError::unexpected_message("offer", other));
456 },
457 };
458
459 Ok((filename, filesize, connector, their_abilities, their_hints))
460 });
461
462 futures::pin_mut!(cancel);
463 let result = cancel::cancellable_2(run, cancel).await;
464 cancel::handle_run_result_noclose(wormhole, result)
465 .await
466 .map(|inner: Option<_>| {
467 inner.map(
468 |((filename, filesize, connector, their_abilities, their_hints), wormhole, _)| {
469 ReceiveRequest::new(
470 filename,
471 filesize,
472 connector,
473 their_abilities,
474 their_hints,
475 wormhole,
476 )
477 },
478 )
479 })
480}
481
482#[must_use]
488pub struct ReceiveRequest {
489 wormhole: Wormhole,
490 connector: TransitConnector,
491
492 file_name: String,
493
494 filesize: u64,
496
497 #[allow(dead_code)]
498 offer: Arc<Offer>,
499 their_abilities: transit::Abilities,
500 their_hints: Arc<transit::Hints>,
501}
502
503impl ReceiveRequest {
504 fn new(
505 file_name: String,
506 filesize: u64,
507 connector: TransitConnector,
508 their_abilities: transit::Abilities,
509 their_hints: transit::Hints,
510 wormhole: Wormhole,
511 ) -> Self {
512 let their_hints = Arc::new(their_hints);
513 let mut content = BTreeMap::new();
514
515 content.insert(
517 file_name.clone(),
518 OfferEntry::RegularFile {
519 size: filesize,
520 content: (),
521 },
522 );
523
524 let offer = Arc::new(Offer { content });
525
526 Self {
527 wormhole,
528 connector,
529 file_name,
530 filesize,
531 offer,
532 their_abilities,
533 their_hints,
534 }
535 }
536
537 pub async fn accept<F, G, W>(
543 mut self,
544 transit_handler: G,
545 progress_handler: F,
546 content_handler: &mut W,
547 cancel: impl Future<Output = ()>,
548 ) -> Result<(), TransferError>
549 where
550 F: FnMut(u64, u64) + 'static,
551 G: FnOnce(transit::TransitInfo),
552 W: AsyncWrite + Unpin,
553 {
554 let run = Box::pin(async {
555 tracing::debug!("Sending ack");
557 self.wormhole
558 .send_json(&PeerMessage::file_ack_v1("ok"))
559 .await?;
560
561 let (mut transit, info) = self
562 .connector
563 .connect(
564 TransitRole::Follower,
565 self.wormhole
566 .key()
567 .derive_transit_key(self.wormhole.appid()),
568 self.their_abilities,
569 self.their_hints.clone(),
570 )
571 .await?;
572 transit_handler(info);
573
574 tracing::debug!("Beginning file transfer");
575 tcp_file_receive(
576 &mut transit,
577 self.filesize,
578 progress_handler,
579 content_handler,
580 )
581 .await?;
582 Ok(())
583 });
584
585 futures::pin_mut!(cancel);
586 let result = cancel::cancellable_2(run, cancel).await;
587 cancel::handle_run_result(self.wormhole, result).await
588 }
589
590 pub async fn reject(mut self) -> Result<(), TransferError> {
596 self.wormhole
597 .send_json(&PeerMessage::error_message("transfer rejected"))
598 .await?;
599 self.wormhole.close().await?;
600
601 Ok(())
602 }
603
604 #[cfg(feature = "experimental-transfer-v2")]
605 #[allow(missing_docs)]
606 pub fn offer(&self) -> Arc<Offer> {
607 self.offer.clone()
608 }
609
610 pub fn file_name(&self) -> String {
614 self.file_name.clone()
615 }
616
617 pub fn file_size(&self) -> u64 {
619 self.filesize
620 }
621}
622
623pub(crate) async fn send_records<'a>(
626 transit: &mut Transit,
627 files: impl futures::Stream<Item = std::io::Result<Box<dyn AsyncRead + Unpin + Send + 'a>>>,
628 file_size: u64,
629 mut progress_handler: impl FnMut(u64, u64) + 'static,
630) -> Result<Vec<u8>, TransferError> {
631 progress_handler(0, file_size);
642
643 let mut hasher = Sha256::default();
644
645 let mut plaintext = vec![0u8; 16 * 1024].into_boxed_slice();
646 let mut sent_size = 0;
647 futures::pin_mut!(files);
648 while let Some(mut file) = files.next().await.transpose()? {
649 loop {
650 let n = file.read(&mut plaintext[..]).await?;
652
653 if n == 0 {
654 break;
656 }
657
658 transit.send_record(&plaintext[0..n]).await?;
660 sent_size += n as u64;
661 progress_handler(sent_size, file_size);
662
663 hasher.update(&plaintext[..n]);
665
666 }
671 }
672 transit.flush().await?;
673
674 ensure!(
675 sent_size == file_size,
676 TransferError::FileSize {
677 sent_size,
678 file_size
679 }
680 );
681
682 Ok(hasher.finalize_fixed().to_vec())
683}
684
685pub(crate) async fn receive_records<F, W>(
686 filesize: u64,
687 transit: &mut Transit,
688 mut progress_handler: F,
689 mut content_handler: W,
690) -> Result<Vec<u8>, TransferError>
691where
692 F: FnMut(u64, u64) + 'static,
693 W: AsyncWrite + Unpin,
694{
695 let mut hasher = Sha256::default();
696 let total = filesize;
697
698 let mut remaining_size = filesize as usize;
699
700 progress_handler(0, total);
703
704 while remaining_size > 0 {
705 let plaintext = transit.receive_record().await?;
707
708 content_handler.write_all(&plaintext).await?;
709
710 hasher.update(&plaintext);
712
713 remaining_size -= plaintext.len();
714
715 let remaining = remaining_size as u64;
716 progress_handler(total - remaining, total);
717 }
718 content_handler.close().await?;
719
720 tracing::debug!("done");
721 Ok(hasher.finalize_fixed().to_vec())
723}
724
725pub(crate) async fn tcp_file_receive<F, W>(
726 transit: &mut Transit,
727 filesize: u64,
728 progress_handler: F,
729 content_handler: &mut W,
730) -> Result<(), TransferError>
731where
732 F: FnMut(u64, u64) + 'static,
733 W: AsyncWrite + Unpin,
734{
735 let checksum = receive_records(filesize, transit, progress_handler, content_handler).await?;
739
740 let sha256sum = hex::encode(checksum.as_slice());
741 tracing::debug!("sha256 sum: {:?}", sha256sum);
742
743 transit
745 .send_record(&TransitAck::new("ok", &sha256sum).serialize_vec())
746 .await?;
747
748 tracing::debug!("Transfer complete");
751 Ok(())
752}
753
754mod tar_helper {
756 use std::{
758 borrow::Cow,
759 io::{self, Read, Write},
760 str,
761 };
762
763 pub(crate) fn create_header_file(path: &[String], size: u64) -> std::io::Result<Vec<u8>> {
764 let mut header = tar::Header::new_gnu();
765 header.set_size(size);
766 let mut data = Vec::with_capacity(1024);
767 prepare_header_path(&mut data, &mut header, path.join("/").as_ref())?;
768 header.set_mode(0o644);
769 header.set_cksum();
770 data.write_all(header.as_bytes())?;
771 Ok(data)
772 }
773
774 pub(crate) fn create_header_directory(path: &[String]) -> std::io::Result<Vec<u8>> {
775 let mut header = tar::Header::new_gnu();
776 header.set_entry_type(tar::EntryType::Directory);
777 let mut data = Vec::with_capacity(1024);
778 prepare_header_path(&mut data, &mut header, path.join("/").as_ref())?;
779 header.set_mode(0o755);
780 header.set_cksum();
781 data.write_all(header.as_bytes())?;
782 Ok(data)
784 }
785
786 pub(crate) fn padding(size: u64) -> &'static [u8] {
787 const BLOCK: [u8; 512] = [0; 512];
788 if !size.is_multiple_of(512) {
789 &BLOCK[size as usize % 512..]
790 } else {
791 &[]
792 }
793 }
794
795 fn append(
796 mut dst: &mut dyn std::io::Write,
797 header: &tar::Header,
798 mut data: &mut dyn std::io::Read,
799 ) -> std::io::Result<()> {
800 dst.write_all(header.as_bytes())?;
801 let len = std::io::copy(&mut data, &mut dst)?;
802 dst.write_all(padding(len))?;
803 Ok(())
804 }
805
806 fn prepare_header(size: u64, entry_type: u8) -> tar::Header {
807 let mut header = tar::Header::new_gnu();
808 let name = b"././@LongLink";
809 header.as_gnu_mut().unwrap().name[..name.len()].clone_from_slice(&name[..]);
810 header.set_mode(0o644);
811 header.set_uid(0);
812 header.set_gid(0);
813 header.set_mtime(0);
814 header.set_size(size + 1);
816 header.set_entry_type(tar::EntryType::new(entry_type));
817 header.set_cksum();
818 header
819 }
820
821 fn prepare_header_path(
822 dst: &mut dyn std::io::Write,
823 header: &mut tar::Header,
824 path: &str,
825 ) -> std::io::Result<()> {
826 if let Err(e) = header.set_path(path) {
831 let data = path2bytes(path);
832 let max = header.as_old().name.len();
833 if data.len() < max {
836 return Err(e);
837 }
838 let header2 = prepare_header(data.len() as u64, b'L');
839 let mut data2 = data.chain(io::repeat(0).take(1));
841 append(dst, &header2, &mut data2)?;
842
843 let truncated = match std::str::from_utf8(&data[..max]) {
849 Ok(s) => s,
850 Err(e) => std::str::from_utf8(&data[..e.valid_up_to()]).unwrap(),
851 };
852 header.set_path(truncated)?;
853 }
854 Ok(())
855 }
856
857 #[cfg(any(windows, target_arch = "wasm32"))]
858 pub(crate) fn path2bytes(p: &str) -> Cow<'_, [u8]> {
859 let bytes = p.as_bytes();
860 if bytes.contains(&b'\\') {
861 let mut bytes = bytes.to_owned();
863 for b in &mut bytes {
864 if *b == b'\\' {
865 *b = b'/';
866 }
867 }
868 Cow::Owned(bytes)
869 } else {
870 Cow::Borrowed(bytes)
871 }
872 }
873
874 #[cfg(unix)]
875 pub(crate) fn path2bytes(p: &str) -> Cow<'_, [u8]> {
876 Cow::Borrowed(p.as_bytes())
877 }
878}
879
#[cfg(test)]
mod test {
    use super::*;

    /// The acknowledgement must serialize to exactly this JSON object.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_transit_ack() {
        let ack = TransitAck::new("ok", "deadbeaf");
        let expected = r#"{"ack":"ok","sha256":"deadbeaf"}"#;
        assert_eq!(ack.serialize(), expected);
    }
}