qs_core/
receive.rs

1#![allow(clippy::suspicious_open_options)]
2
3use crate::{
4    common::{
5        get_files_available, receive_packet, send_packet, FileSendRecvTree, FilesAvailable,
6        PacketRecvError,
7    },
8    packets::{ReceiverToSender, SenderToReceiver},
9    BUF_SIZE, QS_ALPN, QS_PROTO_VERSION,
10};
11use async_compression::tokio::bufread::GzipDecoder;
12use std::{io, path::PathBuf};
13use thiserror::Error;
14use tokio::io::AsyncWriteExt;
15
16/// Generic receive function
17///
18/// # Returns
19/// * `Ok(true)` if the transfer should continue
20/// * `Ok(false)` if the transfer should stop
21pub async fn receive_file<R, W>(
22    recv: &mut R,
23    file: &mut W,
24    skip: u64,
25    size: u64,
26    read_callback: &mut impl FnMut(u64),
27    should_continue: &mut impl FnMut() -> bool,
28) -> std::io::Result<bool>
29where
30    R: tokio::io::AsyncReadExt + Unpin,
31    W: tokio::io::AsyncWriteExt + tokio::io::AsyncSeekExt + Unpin,
32{
33    file.seek(tokio::io::SeekFrom::Start(skip)).await?;
34
35    let mut buf = vec![0; BUF_SIZE];
36    let mut written = skip;
37
38    while written < size {
39        if !should_continue() {
40            return Ok(false);
41        }
42
43        let to_write = std::cmp::min(BUF_SIZE as u64, size - written);
44        let n = recv.read_exact(&mut buf[..to_write as usize]).await?;
45
46        if n == 0 {
47            return Err(io::Error::new(
48                io::ErrorKind::UnexpectedEof,
49                "unexpected eof",
50            ));
51        }
52
53        file.write_all(&buf[..n]).await?;
54        written += n as u64;
55
56        read_callback(n as u64);
57    }
58
59    Ok(true)
60}
61
62/// # Returns
63/// * `Ok(true)` if the transfer should continue
64/// * `Ok(false)` if the transfer should stop
65pub fn receive_directory<S>(
66    send: &mut S,
67    root_path: &std::path::Path,
68    files: &[FileSendRecvTree],
69    read_callback: &mut impl FnMut(u64),
70    should_continue: &mut impl FnMut() -> bool,
71) -> std::io::Result<bool>
72where
73    S: tokio::io::AsyncReadExt + Unpin + Send,
74{
75    for file in files {
76        match file {
77            FileSendRecvTree::File { name, skip, size } => {
78                let path = root_path.join(name);
79
80                let continues = tokio::task::block_in_place(|| {
81                    let rt = tokio::runtime::Runtime::new().unwrap();
82                    rt.block_on(async {
83                        let mut file = tokio::fs::OpenOptions::new()
84                            .write(true)
85                            .create(true)
86                            .open(&path)
87                            .await?;
88                        let continues = receive_file(
89                            send,
90                            &mut file,
91                            *skip,
92                            *size,
93                            read_callback,
94                            should_continue,
95                        )
96                        .await?;
97
98                        file.sync_all().await?;
99                        file.shutdown().await?;
100                        Ok::<bool, std::io::Error>(continues)
101                    })
102                })?;
103
104                if !continues {
105                    return Ok(false);
106                }
107            }
108            FileSendRecvTree::Dir { name, files } => {
109                let root_path = root_path.join(name);
110
111                if !root_path.exists() {
112                    std::fs::create_dir(&root_path)?;
113                }
114
115                if !receive_directory(send, &root_path, files, read_callback, should_continue)? {
116                    return Ok(false);
117                }
118            }
119        }
120    }
121
122    Ok(true)
123}
124
125#[derive(Debug, Error)]
126pub enum ReceiveError {
127    #[error("IO error: {0}")]
128    Io(#[from] std::io::Error),
129    #[error("connect error: {0}")]
130    Connect(String),
131    #[error("connection error: {0}")]
132    Connection(#[from] iroh::endpoint::ConnectionError),
133    #[error("write error: {0}")]
134    Write(#[from] quinn::WriteError),
135    #[error("read error {0}")]
136    Read(#[from] quinn::ReadError),
137    #[error("version mismatch, expected: {0}, got: {1}")]
138    WrongVersion(String, String),
139    #[error(
140        "wrong roundezvous protocol version, the roundezvous server expected {0}, but got: {1}"
141    )]
142    WrongRoundezvousVersion(u32, u32),
143    #[error("unexpected data packet: {0:?}")]
144    UnexpectedDataPacket(SenderToReceiver),
145    #[error("files rejected")]
146    FilesRejected,
147    #[error("invalid code")]
148    InvalidCode,
149    #[error("receive packet error: {0}")]
150    ReceivePacket(#[from] PacketRecvError),
151}
152
153/// A receiver that can receive files
154pub struct Receiver {
155    /// Receiver arguments
156    args: ReceiverArgs,
157    /// The connection to the sender
158    conn: iroh::endpoint::Connection,
159    /// The local endpoint
160    endpoint: iroh::Endpoint,
161}
162
/// Options that control how a `Receiver` handles an incoming transfer.
pub struct ReceiverArgs {
    /// When `true`, files already present in the output directory are inspected and the
    /// parts that can be skipped are reported to the sender, so an interrupted transfer
    /// can be resumed.
    pub resume: bool,
}
168
169impl Receiver {
170    pub async fn connect(
171        this_endpoint: iroh::Endpoint,
172        node_addr: iroh::NodeAddr,
173        args: ReceiverArgs,
174    ) -> Result<Self, ReceiveError> {
175        let conn = this_endpoint
176            .connect(node_addr, QS_ALPN)
177            .await
178            .map_err(|e| ReceiveError::Connect(e.to_string()))?;
179
180        tracing::info!("receiver connected to sender");
181
182        Ok(Self {
183            args,
184            conn,
185            endpoint: this_endpoint,
186        })
187    }
188
189    /// Close the connection
190    pub async fn close(&mut self) {
191        self.conn.close(0u32.into(), &[0]);
192        self.endpoint.close().await;
193    }
194
195    /// Wait for the other peer to close the connection
196    pub async fn wait_for_close(&mut self) {
197        self.conn.closed().await;
198    }
199
200    /// Get the type of the connection
201    pub async fn connection_type(&self) -> Option<iroh::endpoint::ConnectionType> {
202        let node_id = self.conn.remote_node_id().ok()?;
203        self.endpoint.conn_type(node_id).ok()?.get().ok()
204    }
205
206    /// Receive files
207    /// # Arguments
208    /// * `initial_progress_callback` - Callback with the initial progress of each file to send (name, current, total)
209    /// * `accept_files_callback` - Callback to accept or reject the files (Some(path) to accept, None to reject)
210    /// * `read_callback` - Callback every time data is written to disk
211    /// * `should_continue` - Callback to check if the transfer should continue
212    ///
213    /// # Returns
214    /// * `Ok(true)` if the transfer was finished successfully
215    /// * `Ok(false)` if the transfer was stopped
216    pub async fn receive_files(
217        &mut self,
218        mut initial_progress_callback: impl FnMut(&[(String, u64, u64)]),
219        mut accept_files_callback: impl FnMut(&[FilesAvailable]) -> Option<PathBuf>,
220        read_callback: &mut impl FnMut(u64),
221        should_continue: &mut impl FnMut() -> bool,
222    ) -> Result<bool, ReceiveError> {
223        match receive_packet::<SenderToReceiver>(&self.conn).await? {
224            SenderToReceiver::ConnRequest { version_num } => {
225                if version_num != QS_PROTO_VERSION {
226                    send_packet(
227                        ReceiverToSender::WrongVersion {
228                            expected: QS_PROTO_VERSION.to_string(),
229                        },
230                        &self.conn,
231                    )
232                    .await?;
233                    return Err(ReceiveError::WrongVersion(
234                        QS_PROTO_VERSION.to_string(),
235                        version_num,
236                    ));
237                }
238                send_packet(ReceiverToSender::Ok, &self.conn).await?;
239            }
240            p => return Err(ReceiveError::UnexpectedDataPacket(p)),
241        }
242
243        let files_offered = match receive_packet::<SenderToReceiver>(&self.conn).await? {
244            SenderToReceiver::FileInfo { files } => files,
245            p => return Err(ReceiveError::UnexpectedDataPacket(p)),
246        };
247
248        let output_path = match accept_files_callback(&files_offered) {
249            Some(path) => path,
250            None => {
251                send_packet(ReceiverToSender::RejectFiles, &self.conn).await?;
252                // Wait for the sender to acknowledge the rejection
253                self.wait_for_close().await;
254                return Err(ReceiveError::FilesRejected);
255            }
256        };
257
258        let files_available = {
259            let mut files = Vec::new();
260            for file in &files_offered {
261                let path = output_path.join(file.name());
262                files.push(get_files_available(&path).ok());
263            }
264
265            files
266        };
267
268        let files_to_skip = if self.args.resume {
269            let mut to_skip = Vec::new();
270            for (available, offered) in files_available.iter().zip(&files_offered) {
271                match available {
272                    Some(available) => to_skip.push(offered.get_skippable(available)),
273                    None => to_skip.push(None),
274                }
275            }
276
277            to_skip
278        } else {
279            // Don't skip any files
280            vec![None; files_offered.len()]
281        };
282
283        let to_receive: Vec<Option<FileSendRecvTree>> = files_offered
284            .iter()
285            .zip(&files_to_skip)
286            .map(|(offered, skip)| {
287                if let Some(skip) = skip {
288                    offered.remove_skipped(skip)
289                } else {
290                    Some(offered.to_send_recv_tree())
291                }
292            })
293            .collect();
294
295        // progress callback
296        let mut progress: Vec<(String, u64, u64)> = Vec::with_capacity(to_receive.len());
297        for (offered, skip) in files_offered.iter().zip(&files_to_skip) {
298            progress.push((
299                offered.name().to_string(),
300                skip.as_ref().map(|s| s.skip()).unwrap_or(0),
301                offered.size(),
302            ));
303        }
304
305        initial_progress_callback(&progress);
306
307        send_packet(
308            ReceiverToSender::AcceptFilesSkip {
309                files: files_to_skip,
310            },
311            &self.conn,
312        )
313        .await?;
314
315        let recv = self.conn.accept_uni().await?;
316        let mut recv = GzipDecoder::new(tokio::io::BufReader::with_capacity(BUF_SIZE, recv));
317
318        let mut interrupted = false;
319
320        for file in to_receive.into_iter().flatten() {
321            match file {
322                FileSendRecvTree::File { name, skip, size } => {
323                    let path = output_path.join(name);
324                    let mut file = tokio::fs::OpenOptions::new()
325                        .write(true)
326                        .create(true)
327                        .open(&path)
328                        .await?;
329
330                    interrupted = !receive_file(
331                        &mut recv,
332                        &mut file,
333                        skip,
334                        size,
335                        read_callback,
336                        should_continue,
337                    )
338                    .await?;
339                    file.sync_all().await?;
340                    file.shutdown().await?;
341
342                    if interrupted {
343                        break;
344                    }
345                }
346                FileSendRecvTree::Dir { name, files } => {
347                    let path = output_path.join(name);
348
349                    if !path.exists() {
350                        std::fs::create_dir(&path)?;
351                    }
352
353                    if !receive_directory(&mut recv, &path, &files, read_callback, should_continue)?
354                    {
355                        interrupted = true;
356                        break;
357                    }
358                }
359            }
360        }
361
362        self.close().await;
363
364        if interrupted {
365            tracing::info!("transfer interrupted");
366        }
367
368        Ok(!interrupted)
369    }
370}