1#![allow(clippy::suspicious_open_options)]
2
3use crate::{
4 common::{get_files_available, receive_packet, send_packet, FileSendRecvTree, PacketRecvError},
5 packets::{ReceiverToSender, SenderToReceiver},
6 BUF_SIZE, QS_PROTO_VERSION,
7};
8use async_compression::tokio::write::GzipEncoder;
9use std::path::PathBuf;
10use thiserror::Error;
11use tokio::io::AsyncWriteExt;
12
13pub async fn send_file<S, R>(
19 send: &mut S,
20 file: &mut R,
21 skip: u64,
22 size: u64,
23 write_callback: &mut impl FnMut(u64),
24 should_continue: &mut impl FnMut() -> bool,
25) -> std::io::Result<bool>
26where
27 S: tokio::io::AsyncWriteExt + Unpin,
28 R: tokio::io::AsyncReadExt + tokio::io::AsyncSeekExt + Unpin,
29{
30 file.seek(tokio::io::SeekFrom::Start(skip)).await?;
31
32 let mut buf = vec![0; BUF_SIZE];
33 let mut read = skip;
34
35 while read < size {
36 if !should_continue() {
37 return Ok(false);
38 }
39
40 let to_read = std::cmp::min(BUF_SIZE as u64, size - read);
41 let n = file.read_exact(&mut buf[..to_read as usize]).await?;
42
43 if n == 0 {
44 return Err(std::io::Error::new(
45 std::io::ErrorKind::UnexpectedEof,
46 "unexpected eof",
47 ));
48 }
49
50 send.write_all(&buf[..n]).await?;
51 read += n as u64;
52
53 write_callback(n as u64);
54 }
55
56 Ok(true)
57}
58
59pub fn send_directory<S>(
63 send: &mut S,
64 root_path: &std::path::Path,
65 files: &[FileSendRecvTree],
66 write_callback: &mut impl FnMut(u64),
67 should_continue: &mut impl FnMut() -> bool,
68) -> std::io::Result<bool>
69where
70 S: tokio::io::AsyncWriteExt + Unpin + Send,
71{
72 for file in files {
73 match file {
74 FileSendRecvTree::File { name, skip, size } => {
75 let path = root_path.join(name);
76
77 let continues = tokio::task::block_in_place(|| {
78 let rt = tokio::runtime::Runtime::new().unwrap();
79 rt.block_on(async {
80 let mut file = tokio::fs::OpenOptions::new().read(true).open(&path).await?;
81
82 if !send_file(
83 send,
84 &mut file,
85 *skip,
86 *size,
87 write_callback,
88 should_continue,
89 )
90 .await?
91 {
92 return Ok::<bool, std::io::Error>(false);
93 }
94
95 file.shutdown().await?;
96 Ok::<bool, std::io::Error>(true)
97 })
98 })?;
99
100 if !continues {
101 return Ok(false);
102 }
103 }
104 FileSendRecvTree::Dir { name, files } => {
105 let root_path = root_path.join(name);
106 if !send_directory(send, &root_path, files, write_callback, should_continue)? {
107 return Ok(false);
108 };
109 }
110 }
111 }
112
113 Ok(true)
114}
115
116#[derive(Debug, Error)]
117pub enum SendError {
118 #[error("files do not exist: {0}")]
119 FileDoesNotExists(PathBuf),
120 #[error("IO error: {0}")]
121 Io(#[from] std::io::Error),
122 #[error("connection error: {0}")]
125 Connection(#[from] iroh::endpoint::ConnectionError),
126 #[error("read error: {0}")]
127 Read(#[from] quinn::ReadError),
128 #[error("wrong version, the receiver expected: {0}, but got: {1}")]
129 WrongVersion(String, String),
130 #[error(
131 "wrong roundezvous protocol version, the roundezvous server expected {0}, but got: {1}"
132 )]
133 WrongRoundezvousVersion(u32, u32),
134 #[error("unexpected data packet: {0:?}")]
135 UnexpectedDataPacket(ReceiverToSender),
136 #[error("files rejected")]
137 FilesRejected,
138 #[error("receive packet error: {0}")]
139 ReceivePacket(#[from] PacketRecvError),
140 #[error("failed to fetch node addr: {0}")]
141 NodeAddr(String),
142}
143
144pub struct Sender {
146 args: SenderArgs,
148 conn: iroh::endpoint::Connection,
150 endpoint: iroh::Endpoint,
152}
153
/// User-supplied options for a [`Sender`].
#[derive(Debug, Clone)]
pub struct SenderArgs {
    /// Top-level files and/or directories to send. Each path must exist when
    /// `Sender::send_files` runs, otherwise it fails with `SendError::FileDoesNotExists`.
    pub files: Vec<PathBuf>,
}
159
160impl Sender {
161 pub async fn connect(
162 this_endpoint: iroh::Endpoint,
163 args: SenderArgs,
164 ) -> Result<Self, SendError> {
165 if let Some(incoming) = this_endpoint.accept().await {
166 let connecting = incoming.accept()?;
167 let conn = connecting.await?;
168
169 tracing::info!("receiver connected to sender");
170
171 return Ok(Self {
172 args,
173 conn,
174 endpoint: this_endpoint,
175 });
176 }
177
178 unreachable!();
179 }
180
181 pub async fn close(&mut self) {
183 self.conn.close(0u32.into(), &[0]);
184 self.endpoint.close().await;
185 }
186
187 pub async fn wait_for_close(&mut self) {
189 self.conn.closed().await;
190 }
191
192 pub async fn connection_type(&self) -> Option<iroh::endpoint::ConnectionType> {
194 let node_id = self.conn.remote_node_id().ok()?;
195 self.endpoint.conn_type(node_id).ok()?.get().ok()
196 }
197
198 pub async fn send_files(
210 &mut self,
211 mut wait_for_other_peer_to_accept_files_callback: impl FnMut(),
212 mut files_decision_callback: impl FnMut(bool),
213 mut initial_progress_callback: impl FnMut(&[(String, u64, u64)]),
214 write_callback: &mut impl FnMut(u64),
215 should_continue: &mut impl FnMut() -> bool,
216 ) -> Result<bool, SendError> {
217 send_packet(
218 SenderToReceiver::ConnRequest {
219 version_num: QS_PROTO_VERSION.to_string(),
220 },
221 &self.conn,
222 )
223 .await?;
224
225 match receive_packet::<ReceiverToSender>(&self.conn).await? {
226 ReceiverToSender::Ok => (),
227 ReceiverToSender::WrongVersion { expected } => {
228 return Err(SendError::WrongVersion(expected, QS_PROTO_VERSION.to_string()));
229 }
230 p => return Err(SendError::UnexpectedDataPacket(p)),
231 }
232
233 let files_available = {
234 let mut files = Vec::new();
235 for file in &self.args.files {
236 if !file.exists() {
237 return Err(SendError::FileDoesNotExists(file.clone()));
238 }
239 files.push(get_files_available(file)?);
240 }
241 files
242 };
243
244 send_packet(
245 SenderToReceiver::FileInfo {
246 files: files_available.clone(),
247 },
248 &self.conn,
249 )
250 .await?;
251
252 wait_for_other_peer_to_accept_files_callback();
253
254 let to_skip = match receive_packet::<ReceiverToSender>(&self.conn).await? {
255 ReceiverToSender::AcceptFilesSkip { files } => {
256 files_decision_callback(true);
257 files
258 }
259 ReceiverToSender::RejectFiles => {
260 files_decision_callback(false);
261 self.close().await;
262 return Err(SendError::FilesRejected);
263 }
264 p => return Err(SendError::UnexpectedDataPacket(p)),
265 };
266
267 let to_send: Vec<Option<FileSendRecvTree>> = files_available
268 .iter()
269 .zip(&to_skip)
270 .map(|(file, skip)| {
271 if let Some(skip) = skip {
272 file.remove_skipped(skip)
273 } else {
274 Some(file.to_send_recv_tree())
275 }
276 })
277 .collect();
278
279 let mut progress: Vec<(String, u64, u64)> = Vec::with_capacity(files_available.len());
280 for (file, skip) in files_available.iter().zip(to_skip) {
281 progress.push((
282 file.name().to_string(),
283 skip.as_ref().map(|s| s.skip()).unwrap_or(0),
284 file.size(),
285 ));
286 }
287
288 initial_progress_callback(&progress);
289
290 let send = self.conn.open_uni().await?;
291 let mut send = GzipEncoder::new(send);
292
293 let mut interrupted = false;
294
295 for (path, file) in self.args.files.iter().zip(to_send) {
296 if let Some(file) = file {
297 match file {
298 FileSendRecvTree::File { skip, size, .. } => {
299 let mut file = tokio::fs::File::open(&path).await?;
300 if !send_file(
301 &mut send,
302 &mut file,
303 skip,
304 size,
305 write_callback,
306 should_continue,
307 )
308 .await?
309 {
310 interrupted = true;
311 break;
312 }
313 }
314 FileSendRecvTree::Dir { files, .. } => {
315 if !send_directory(
316 &mut send,
317 path,
318 &files,
319 write_callback,
320 should_continue,
321 )? {
322 interrupted = true;
323 break;
324 }
325 }
326 }
327 }
328 }
329
330 send.shutdown().await?;
331
332 if !interrupted {
333 self.wait_for_close().await;
334 } else {
335 tracing::info!("the transfer was interrupted");
336 }
337
338 Ok(!interrupted)
339 }
340}