qs_core/
send.rs

1#![allow(clippy::suspicious_open_options)]
2
3use crate::{
4    common::{get_files_available, receive_packet, send_packet, FileSendRecvTree, PacketRecvError},
5    packets::{ReceiverToSender, SenderToReceiver},
6    BUF_SIZE, QS_PROTO_VERSION,
7};
8use async_compression::tokio::write::GzipEncoder;
9use std::path::PathBuf;
10use thiserror::Error;
11use tokio::io::AsyncWriteExt;
12
13/// Generic send function
14///
15/// # Returns
16/// * `Ok(true)` if the transfer should continue
17/// * `Ok(false)` if the transfer should stop
18pub async fn send_file<S, R>(
19    send: &mut S,
20    file: &mut R,
21    skip: u64,
22    size: u64,
23    write_callback: &mut impl FnMut(u64),
24    should_continue: &mut impl FnMut() -> bool,
25) -> std::io::Result<bool>
26where
27    S: tokio::io::AsyncWriteExt + Unpin,
28    R: tokio::io::AsyncReadExt + tokio::io::AsyncSeekExt + Unpin,
29{
30    file.seek(tokio::io::SeekFrom::Start(skip)).await?;
31
32    let mut buf = vec![0; BUF_SIZE];
33    let mut read = skip;
34
35    while read < size {
36        if !should_continue() {
37            return Ok(false);
38        }
39
40        let to_read = std::cmp::min(BUF_SIZE as u64, size - read);
41        let n = file.read_exact(&mut buf[..to_read as usize]).await?;
42
43        if n == 0 {
44            return Err(std::io::Error::new(
45                std::io::ErrorKind::UnexpectedEof,
46                "unexpected eof",
47            ));
48        }
49
50        send.write_all(&buf[..n]).await?;
51        read += n as u64;
52
53        write_callback(n as u64);
54    }
55
56    Ok(true)
57}
58
59/// # Returns
60/// * `Ok(true)` if the transfer should continue
61/// * `Ok(false)` if the transfer should stop
62pub fn send_directory<S>(
63    send: &mut S,
64    root_path: &std::path::Path,
65    files: &[FileSendRecvTree],
66    write_callback: &mut impl FnMut(u64),
67    should_continue: &mut impl FnMut() -> bool,
68) -> std::io::Result<bool>
69where
70    S: tokio::io::AsyncWriteExt + Unpin + Send,
71{
72    for file in files {
73        match file {
74            FileSendRecvTree::File { name, skip, size } => {
75                let path = root_path.join(name);
76
77                let continues = tokio::task::block_in_place(|| {
78                    let rt = tokio::runtime::Runtime::new().unwrap();
79                    rt.block_on(async {
80                        let mut file = tokio::fs::OpenOptions::new().read(true).open(&path).await?;
81
82                        if !send_file(
83                            send,
84                            &mut file,
85                            *skip,
86                            *size,
87                            write_callback,
88                            should_continue,
89                        )
90                        .await?
91                        {
92                            return Ok::<bool, std::io::Error>(false);
93                        }
94
95                        file.shutdown().await?;
96                        Ok::<bool, std::io::Error>(true)
97                    })
98                })?;
99
100                if !continues {
101                    return Ok(false);
102                }
103            }
104            FileSendRecvTree::Dir { name, files } => {
105                let root_path = root_path.join(name);
106                if !send_directory(send, &root_path, files, write_callback, should_continue)? {
107                    return Ok(false);
108                };
109            }
110        }
111    }
112
113    Ok(true)
114}
115
/// Errors that can occur on the sending side of a transfer.
#[derive(Debug, Error)]
pub enum SendError {
    /// A path listed in the sender arguments does not exist on disk.
    #[error("files do not exist: {0}")]
    FileDoesNotExists(PathBuf),
    /// An underlying I/O operation failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    // #[error("connect error: {0}")]
    // Connect(#[from] iroh::endpoint::ConnectError),
    /// The connection to the receiver failed.
    #[error("connection error: {0}")]
    Connection(#[from] iroh::endpoint::ConnectionError),
    /// Reading from a stream failed.
    #[error("read error: {0}")]
    Read(#[from] quinn::ReadError),
    /// The receiver reported a protocol version mismatch: the first field is
    /// the version the receiver expected, the second the version that was sent.
    #[error("wrong version, the receiver expected: {0}, but got: {1}")]
    WrongVersion(String, String),
    /// Rendezvous protocol version mismatch (expected, got).
    /// NOTE(review): not constructed in this file — confirm where it is produced.
    #[error(
        "wrong roundezvous protocol version, the roundezvous server expected {0}, but got: {1}"
    )]
    WrongRoundezvousVersion(u32, u32),
    /// The receiver answered with a packet that is not valid at this point of
    /// the exchange; the offending packet is carried along.
    #[error("unexpected data packet: {0:?}")]
    UnexpectedDataPacket(ReceiverToSender),
    /// The receiver declined the offered files.
    #[error("files rejected")]
    FilesRejected,
    /// Receiving a packet from the peer failed.
    #[error("receive packet error: {0}")]
    ReceivePacket(#[from] PacketRecvError),
    /// The node address could not be fetched; the string carries the reason.
    /// NOTE(review): not constructed in this file — confirm where it is produced.
    #[error("failed to fetch node addr: {0}")]
    NodeAddr(String),
}
143
/// A client that can send files
///
/// Built by `Sender::connect` once a receiver has connected to the local
/// endpoint.
pub struct Sender {
    /// Sender arguments (the paths offered to the receiver by `send_files`)
    args: SenderArgs,
    /// The connection to the receiver
    conn: iroh::endpoint::Connection,
    /// The local endpoint; used to query the connection type and closed by `close`
    endpoint: iroh::Endpoint,
}
153
/// Arguments for the sender
pub struct SenderArgs {
    /// Files/Directories to send.
    ///
    /// Every path must exist when `Sender::send_files` runs; a missing path
    /// makes it fail with `SendError::FileDoesNotExists`.
    pub files: Vec<PathBuf>,
}
159
160impl Sender {
161    pub async fn connect(
162        this_endpoint: iroh::Endpoint,
163        args: SenderArgs,
164    ) -> Result<Self, SendError> {
165        if let Some(incoming) = this_endpoint.accept().await {
166            let connecting = incoming.accept()?;
167            let conn = connecting.await?;
168
169            tracing::info!("receiver connected to sender");
170
171            return Ok(Self {
172                args,
173                conn,
174                endpoint: this_endpoint,
175            });
176        }
177
178        unreachable!();
179    }
180
181    /// Close the connection
182    pub async fn close(&mut self) {
183        self.conn.close(0u32.into(), &[0]);
184        self.endpoint.close().await;
185    }
186
187    /// Wait for the other peer to close the connection
188    pub async fn wait_for_close(&mut self) {
189        self.conn.closed().await;
190    }
191
192    /// Get the type of the connection
193    pub async fn connection_type(&self) -> Option<iroh::endpoint::ConnectionType> {
194        let node_id = self.conn.remote_node_id().ok()?;
195        self.endpoint.conn_type(node_id).ok()?.get().ok()
196    }
197
198    /// Send files
199    /// # Arguments
200    /// * `wait_for_other_peer_to_accept_files_callback` - Callback to wait for the other peer to accept the files
201    /// * `files_decision_callback` - Callback with the decision of the other peer to accept the files
202    /// * `initial_progress_callback` - Callback with the initial progress of each file to send (name, current, total)
203    /// * `write_callback` - Callback every time data is written to the connection
204    /// * `should_continue` - Callback to check if the transfer should continue
205    ///
206    /// # Returns
207    /// * `Ok(true)` if the transfer was finished successfully
208    /// * `Ok(false)` if the transfer was stopped
209    pub async fn send_files(
210        &mut self,
211        mut wait_for_other_peer_to_accept_files_callback: impl FnMut(),
212        mut files_decision_callback: impl FnMut(bool),
213        mut initial_progress_callback: impl FnMut(&[(String, u64, u64)]),
214        write_callback: &mut impl FnMut(u64),
215        should_continue: &mut impl FnMut() -> bool,
216    ) -> Result<bool, SendError> {
217        send_packet(
218            SenderToReceiver::ConnRequest {
219                version_num: QS_PROTO_VERSION.to_string(),
220            },
221            &self.conn,
222        )
223        .await?;
224
225        match receive_packet::<ReceiverToSender>(&self.conn).await? {
226            ReceiverToSender::Ok => (),
227            ReceiverToSender::WrongVersion { expected } => {
228                return Err(SendError::WrongVersion(expected, QS_PROTO_VERSION.to_string()));
229            }
230            p => return Err(SendError::UnexpectedDataPacket(p)),
231        }
232
233        let files_available = {
234            let mut files = Vec::new();
235            for file in &self.args.files {
236                if !file.exists() {
237                    return Err(SendError::FileDoesNotExists(file.clone()));
238                }
239                files.push(get_files_available(file)?);
240            }
241            files
242        };
243
244        send_packet(
245            SenderToReceiver::FileInfo {
246                files: files_available.clone(),
247            },
248            &self.conn,
249        )
250        .await?;
251
252        wait_for_other_peer_to_accept_files_callback();
253
254        let to_skip = match receive_packet::<ReceiverToSender>(&self.conn).await? {
255            ReceiverToSender::AcceptFilesSkip { files } => {
256                files_decision_callback(true);
257                files
258            }
259            ReceiverToSender::RejectFiles => {
260                files_decision_callback(false);
261                self.close().await;
262                return Err(SendError::FilesRejected);
263            }
264            p => return Err(SendError::UnexpectedDataPacket(p)),
265        };
266
267        let to_send: Vec<Option<FileSendRecvTree>> = files_available
268            .iter()
269            .zip(&to_skip)
270            .map(|(file, skip)| {
271                if let Some(skip) = skip {
272                    file.remove_skipped(skip)
273                } else {
274                    Some(file.to_send_recv_tree())
275                }
276            })
277            .collect();
278
279        let mut progress: Vec<(String, u64, u64)> = Vec::with_capacity(files_available.len());
280        for (file, skip) in files_available.iter().zip(to_skip) {
281            progress.push((
282                file.name().to_string(),
283                skip.as_ref().map(|s| s.skip()).unwrap_or(0),
284                file.size(),
285            ));
286        }
287
288        initial_progress_callback(&progress);
289
290        let send = self.conn.open_uni().await?;
291        let mut send = GzipEncoder::new(send);
292
293        let mut interrupted = false;
294
295        for (path, file) in self.args.files.iter().zip(to_send) {
296            if let Some(file) = file {
297                match file {
298                    FileSendRecvTree::File { skip, size, .. } => {
299                        let mut file = tokio::fs::File::open(&path).await?;
300                        if !send_file(
301                            &mut send,
302                            &mut file,
303                            skip,
304                            size,
305                            write_callback,
306                            should_continue,
307                        )
308                        .await?
309                        {
310                            interrupted = true;
311                            break;
312                        }
313                    }
314                    FileSendRecvTree::Dir { files, .. } => {
315                        if !send_directory(
316                            &mut send,
317                            path,
318                            &files,
319                            write_callback,
320                            should_continue,
321                        )? {
322                            interrupted = true;
323                            break;
324                        }
325                    }
326                }
327            }
328        }
329
330        send.shutdown().await?;
331
332        if !interrupted {
333            self.wait_for_close().await;
334        } else {
335            tracing::info!("the transfer was interrupted");
336        }
337
338        Ok(!interrupted)
339    }
340}