1#![allow(clippy::suspicious_open_options)]
2
3use crate::{
4 common::{
5 get_files_available, receive_packet, send_packet, FileSendRecvTree, FilesAvailable,
6 PacketRecvError,
7 },
8 packets::{ReceiverToSender, SenderToReceiver},
9 BUF_SIZE, QS_ALPN, QS_PROTO_VERSION,
10};
11use async_compression::tokio::bufread::GzipDecoder;
12use std::{io, path::PathBuf};
13use thiserror::Error;
14use tokio::io::AsyncWriteExt;
15
16pub async fn receive_file<R, W>(
22 recv: &mut R,
23 file: &mut W,
24 skip: u64,
25 size: u64,
26 read_callback: &mut impl FnMut(u64),
27 should_continue: &mut impl FnMut() -> bool,
28) -> std::io::Result<bool>
29where
30 R: tokio::io::AsyncReadExt + Unpin,
31 W: tokio::io::AsyncWriteExt + tokio::io::AsyncSeekExt + Unpin,
32{
33 file.seek(tokio::io::SeekFrom::Start(skip)).await?;
34
35 let mut buf = vec![0; BUF_SIZE];
36 let mut written = skip;
37
38 while written < size {
39 if !should_continue() {
40 return Ok(false);
41 }
42
43 let to_write = std::cmp::min(BUF_SIZE as u64, size - written);
44 let n = recv.read_exact(&mut buf[..to_write as usize]).await?;
45
46 if n == 0 {
47 return Err(io::Error::new(
48 io::ErrorKind::UnexpectedEof,
49 "unexpected eof",
50 ));
51 }
52
53 file.write_all(&buf[..n]).await?;
54 written += n as u64;
55
56 read_callback(n as u64);
57 }
58
59 Ok(true)
60}
61
62pub fn receive_directory<S>(
66 send: &mut S,
67 root_path: &std::path::Path,
68 files: &[FileSendRecvTree],
69 read_callback: &mut impl FnMut(u64),
70 should_continue: &mut impl FnMut() -> bool,
71) -> std::io::Result<bool>
72where
73 S: tokio::io::AsyncReadExt + Unpin + Send,
74{
75 for file in files {
76 match file {
77 FileSendRecvTree::File { name, skip, size } => {
78 let path = root_path.join(name);
79
80 let continues = tokio::task::block_in_place(|| {
81 let rt = tokio::runtime::Runtime::new().unwrap();
82 rt.block_on(async {
83 let mut file = tokio::fs::OpenOptions::new()
84 .write(true)
85 .create(true)
86 .open(&path)
87 .await?;
88 let continues = receive_file(
89 send,
90 &mut file,
91 *skip,
92 *size,
93 read_callback,
94 should_continue,
95 )
96 .await?;
97
98 file.sync_all().await?;
99 file.shutdown().await?;
100 Ok::<bool, std::io::Error>(continues)
101 })
102 })?;
103
104 if !continues {
105 return Ok(false);
106 }
107 }
108 FileSendRecvTree::Dir { name, files } => {
109 let root_path = root_path.join(name);
110
111 if !root_path.exists() {
112 std::fs::create_dir(&root_path)?;
113 }
114
115 if !receive_directory(send, &root_path, files, read_callback, should_continue)? {
116 return Ok(false);
117 }
118 }
119 }
120 }
121
122 Ok(true)
123}
124
125#[derive(Debug, Error)]
126pub enum ReceiveError {
127 #[error("IO error: {0}")]
128 Io(#[from] std::io::Error),
129 #[error("connect error: {0}")]
130 Connect(String),
131 #[error("connection error: {0}")]
132 Connection(#[from] iroh::endpoint::ConnectionError),
133 #[error("write error: {0}")]
134 Write(#[from] quinn::WriteError),
135 #[error("read error {0}")]
136 Read(#[from] quinn::ReadError),
137 #[error("version mismatch, expected: {0}, got: {1}")]
138 WrongVersion(String, String),
139 #[error(
140 "wrong roundezvous protocol version, the roundezvous server expected {0}, but got: {1}"
141 )]
142 WrongRoundezvousVersion(u32, u32),
143 #[error("unexpected data packet: {0:?}")]
144 UnexpectedDataPacket(SenderToReceiver),
145 #[error("files rejected")]
146 FilesRejected,
147 #[error("invalid code")]
148 InvalidCode,
149 #[error("receive packet error: {0}")]
150 ReceivePacket(#[from] PacketRecvError),
151}
152
/// Receiving side of a transfer: an open connection to a sender plus the local endpoint.
pub struct Receiver {
    /// Receiver options; `resume` is consulted by `Receiver::receive_files`.
    args: ReceiverArgs,
    /// Connection to the sender, used for control packets and the file-data stream.
    conn: iroh::endpoint::Connection,
    /// Local endpoint the connection was made from; closed by `Receiver::close`.
    endpoint: iroh::Endpoint,
}
162
/// Options for a [`Receiver`].
pub struct ReceiverArgs {
    /// When `true`, `receive_files` compares offered files with what already exists in the
    /// output directory. Parts reported as skippable (via `get_skippable`) are then not requested
    /// again.
    pub resume: bool,
}
168
169impl Receiver {
170 pub async fn connect(
171 this_endpoint: iroh::Endpoint,
172 node_addr: iroh::NodeAddr,
173 args: ReceiverArgs,
174 ) -> Result<Self, ReceiveError> {
175 let conn = this_endpoint
176 .connect(node_addr, QS_ALPN)
177 .await
178 .map_err(|e| ReceiveError::Connect(e.to_string()))?;
179
180 tracing::info!("receiver connected to sender");
181
182 Ok(Self {
183 args,
184 conn,
185 endpoint: this_endpoint,
186 })
187 }
188
189 pub async fn close(&mut self) {
191 self.conn.close(0u32.into(), &[0]);
192 self.endpoint.close().await;
193 }
194
195 pub async fn wait_for_close(&mut self) {
197 self.conn.closed().await;
198 }
199
200 pub async fn connection_type(&self) -> Option<iroh::endpoint::ConnectionType> {
202 let node_id = self.conn.remote_node_id().ok()?;
203 self.endpoint.conn_type(node_id).ok()?.get().ok()
204 }
205
206 pub async fn receive_files(
217 &mut self,
218 mut initial_progress_callback: impl FnMut(&[(String, u64, u64)]),
219 mut accept_files_callback: impl FnMut(&[FilesAvailable]) -> Option<PathBuf>,
220 read_callback: &mut impl FnMut(u64),
221 should_continue: &mut impl FnMut() -> bool,
222 ) -> Result<bool, ReceiveError> {
223 match receive_packet::<SenderToReceiver>(&self.conn).await? {
224 SenderToReceiver::ConnRequest { version_num } => {
225 if version_num != QS_PROTO_VERSION {
226 send_packet(
227 ReceiverToSender::WrongVersion {
228 expected: QS_PROTO_VERSION.to_string(),
229 },
230 &self.conn,
231 )
232 .await?;
233 return Err(ReceiveError::WrongVersion(
234 QS_PROTO_VERSION.to_string(),
235 version_num,
236 ));
237 }
238 send_packet(ReceiverToSender::Ok, &self.conn).await?;
239 }
240 p => return Err(ReceiveError::UnexpectedDataPacket(p)),
241 }
242
243 let files_offered = match receive_packet::<SenderToReceiver>(&self.conn).await? {
244 SenderToReceiver::FileInfo { files } => files,
245 p => return Err(ReceiveError::UnexpectedDataPacket(p)),
246 };
247
248 let output_path = match accept_files_callback(&files_offered) {
249 Some(path) => path,
250 None => {
251 send_packet(ReceiverToSender::RejectFiles, &self.conn).await?;
252 self.wait_for_close().await;
254 return Err(ReceiveError::FilesRejected);
255 }
256 };
257
258 let files_available = {
259 let mut files = Vec::new();
260 for file in &files_offered {
261 let path = output_path.join(file.name());
262 files.push(get_files_available(&path).ok());
263 }
264
265 files
266 };
267
268 let files_to_skip = if self.args.resume {
269 let mut to_skip = Vec::new();
270 for (available, offered) in files_available.iter().zip(&files_offered) {
271 match available {
272 Some(available) => to_skip.push(offered.get_skippable(available)),
273 None => to_skip.push(None),
274 }
275 }
276
277 to_skip
278 } else {
279 vec![None; files_offered.len()]
281 };
282
283 let to_receive: Vec<Option<FileSendRecvTree>> = files_offered
284 .iter()
285 .zip(&files_to_skip)
286 .map(|(offered, skip)| {
287 if let Some(skip) = skip {
288 offered.remove_skipped(skip)
289 } else {
290 Some(offered.to_send_recv_tree())
291 }
292 })
293 .collect();
294
295 let mut progress: Vec<(String, u64, u64)> = Vec::with_capacity(to_receive.len());
297 for (offered, skip) in files_offered.iter().zip(&files_to_skip) {
298 progress.push((
299 offered.name().to_string(),
300 skip.as_ref().map(|s| s.skip()).unwrap_or(0),
301 offered.size(),
302 ));
303 }
304
305 initial_progress_callback(&progress);
306
307 send_packet(
308 ReceiverToSender::AcceptFilesSkip {
309 files: files_to_skip,
310 },
311 &self.conn,
312 )
313 .await?;
314
315 let recv = self.conn.accept_uni().await?;
316 let mut recv = GzipDecoder::new(tokio::io::BufReader::with_capacity(BUF_SIZE, recv));
317
318 let mut interrupted = false;
319
320 for file in to_receive.into_iter().flatten() {
321 match file {
322 FileSendRecvTree::File { name, skip, size } => {
323 let path = output_path.join(name);
324 let mut file = tokio::fs::OpenOptions::new()
325 .write(true)
326 .create(true)
327 .open(&path)
328 .await?;
329
330 interrupted = !receive_file(
331 &mut recv,
332 &mut file,
333 skip,
334 size,
335 read_callback,
336 should_continue,
337 )
338 .await?;
339 file.sync_all().await?;
340 file.shutdown().await?;
341
342 if interrupted {
343 break;
344 }
345 }
346 FileSendRecvTree::Dir { name, files } => {
347 let path = output_path.join(name);
348
349 if !path.exists() {
350 std::fs::create_dir(&path)?;
351 }
352
353 if !receive_directory(&mut recv, &path, &files, read_callback, should_continue)?
354 {
355 interrupted = true;
356 break;
357 }
358 }
359 }
360 }
361
362 self.close().await;
363
364 if interrupted {
365 tracing::info!("transfer interrupted");
366 }
367
368 Ok(!interrupted)
369 }
370}