quick_file_transfer/server/util.rs

1use std::{
2    fs::{self, File},
3    io::{self, BufReader, BufWriter, StdoutLock, Write},
4    net::{TcpListener, TcpStream},
5    path::{Path, PathBuf},
6    sync::{atomic::AtomicBool, Arc},
7    thread::JoinHandle,
8};
9
10use flate2::read::GzDecoder;
11use lz4_flex::frame::FrameDecoder;
12
13use crate::{
14    config::{
15        compression::CompressionVariant,
16        transfer::{
17            command::{ServerCommand, ServerResult},
18            listen::ListenArgs,
19        },
20    },
21    server::child::run_child,
22    util::{bind_listen_to_free_port_in_range, format_data_size, incremental_rw},
23    BUFFERED_RW_BUFSIZE, TCP_STREAM_BUFSIZE,
24};
25
26pub fn file_with_bufwriter(path: &Path) -> anyhow::Result<BufWriter<File>> {
27    let f = match fs::File::create(path) {
28        Ok(f) => f,
29        Err(e) => {
30            if e.kind() == io::ErrorKind::PermissionDenied {
31                log::error!("{e}");
32                log::info!("Attempting to retrieve additional debug information...");
33                let file_exists = path.exists();
34                let fpath_str = path.display().to_string();
35                let file_permissions: Option<fs::Permissions> = if file_exists {
36                    if let Ok(md) = path.metadata() {
37                        Some(md.permissions())
38                    } else {
39                        log::error!("Failed to retrieve permissions for {fpath_str}");
40                        None
41                    }
42                } else {
43                    None
44                };
45
46                let parent = path.parent();
47                let parent_permissions: Option<fs::Permissions> =
48                    parent.and_then(|p| p.metadata().ok().map(|md| md.permissions()));
49                let mut context_str = String::new();
50                if file_exists {
51                    context_str.push_str(&format!("\n\tFile {fpath_str} exists on disk"));
52                } else {
53                    context_str.push_str(&format!("\n\tFile {fpath_str} does not exist"));
54                }
55                if let Some(fpermission) = file_permissions {
56                    context_str.push_str(&format!(" - with permissions: {fpermission:?}"));
57                }
58                if let Some(parent_permissions) = parent_permissions {
59                    context_str.push_str(&format!(
60                        "\n\tParent directory {:?} - permissions: {parent_permissions:?}",
61                        parent.unwrap(),
62                    ));
63                }
64                log::debug!("Additional context for {fpath_str}:{context_str}");
65            };
66            return Err(e.into());
67        }
68    };
69    let writer = BufWriter::with_capacity(BUFFERED_RW_BUFSIZE, f);
70    Ok(writer)
71}
72
73pub fn stdout_bufwriter() -> BufWriter<StdoutLock<'static>> {
74    let stdout = io::stdout().lock();
75    BufWriter::with_capacity(BUFFERED_RW_BUFSIZE, stdout)
76}
77
78pub fn handle_receive_data(
79    listen_args: &ListenArgs,
80    tcp_socket: &mut TcpStream,
81    fname: String,
82    decompression: Option<CompressionVariant>,
83    root_dest: Option<&Path>,
84) -> anyhow::Result<u64> {
85    let mut bufwriter = match (
86        listen_args.output.as_deref(),
87        listen_args.output_dir.as_deref(),
88        root_dest,
89    ) {
90        (_, _, Some(root_dest)) => {
91            if root_dest.is_file() {
92                tracing::info!("Initiation bufwriter targeting {root_dest:?}");
93                file_with_bufwriter(root_dest)?
94            } else {
95                let full_path = root_dest.join(fname);
96                tracing::info!("Initiation bufwriter targeting {full_path:?}");
97                file_with_bufwriter(&full_path)?
98            }
99        }
100        (None, Some(d), _) => {
101            if !d.is_dir() && d.exists() {
102                anyhow::bail!("Output directory path {d:?} is invalid - has to point at a directory or non-existent path")
103            }
104            if !d.exists() {
105                fs::create_dir(d)?;
106            }
107            let new_fpath = d.join(fname);
108            file_with_bufwriter(&new_fpath)?
109        }
110        (Some(f), None, _) => file_with_bufwriter(f)?,
111        (None, None, _) => {
112            unreachable!()
113        }
114        (Some(_), Some(_), _) => {
115            unreachable!("Specifying both an output name and an output directory is invalid")
116        }
117    };
118
119    let mut buf_tcp_reader = BufReader::with_capacity(BUFFERED_RW_BUFSIZE, tcp_socket);
120
121    let len = match decompression {
122        Some(compr) => match compr {
123            CompressionVariant::Bzip2 => {
124                let mut tcp_decoder = bzip2::read::BzDecoder::new(buf_tcp_reader);
125                incremental_rw::<TCP_STREAM_BUFSIZE, _, _>(&mut bufwriter, &mut tcp_decoder)?
126            }
127            CompressionVariant::Gzip => {
128                let mut tcp_decoder = GzDecoder::new(buf_tcp_reader);
129                incremental_rw::<TCP_STREAM_BUFSIZE, _, _>(&mut bufwriter, &mut tcp_decoder)?
130            }
131            CompressionVariant::Lz4 => {
132                let mut tcp_decoder = FrameDecoder::new(buf_tcp_reader);
133                incremental_rw::<TCP_STREAM_BUFSIZE, _, _>(&mut bufwriter, &mut tcp_decoder)?
134            }
135            CompressionVariant::Xz => {
136                let mut tcp_decoder = xz2::read::XzDecoder::new(buf_tcp_reader);
137                incremental_rw::<TCP_STREAM_BUFSIZE, _, _>(&mut bufwriter, &mut tcp_decoder)?
138            }
139        },
140        None => incremental_rw::<TCP_STREAM_BUFSIZE, _, _>(&mut bufwriter, &mut buf_tcp_reader)?,
141    };
142    if len < 1023 {
143        log::info!("Received: {len} B");
144    } else {
145        log::info!("Received: {} [{len} B]", format_data_size(len));
146    }
147
148    Ok(len)
149}
150
151/// Send a [ServerResult] to the client
152pub fn send_result(stream: &mut TcpStream, result: &ServerResult) -> anyhow::Result<()> {
153    tracing::trace!("Sending result: {result:?}");
154    let result_bytes = bincode::serialize(result)?;
155    debug_assert!(result_bytes.len() <= u8::MAX as usize);
156    let size = result_bytes.len() as u16;
157    let header = size.to_be_bytes();
158
159    // Send the header followed by the command
160    stream.write_all(&header)?;
161    stream.write_all(&result_bytes)?;
162    Ok(())
163}
164
165pub fn join_all_threads(handles: Vec<JoinHandle<anyhow::Result<()>>>) -> Result<(), String> {
166    let mut errors = String::new();
167    for h in handles {
168        let mut h_name = h.thread().name().unwrap_or_default().to_owned();
169        match h.join().map_err(|e| format!("{e:?}")) {
170            Ok(_) => (),
171            Err(e) => {
172                tracing::error!("Thread {h_name} joined with error: {e}");
173                h_name.push_str(" failed: ");
174                if !errors.is_empty() {
175                    errors.push('\n');
176                }
177                errors.extend(h_name.drain(..));
178                errors.push_str(&e);
179            }
180        }
181    }
182    if errors.is_empty() {
183        Ok(())
184    } else {
185        tracing::warn!("{errors}");
186        Err(errors)
187    }
188}
189
190pub fn spawn_child_on_new_port(
191    socket: &mut TcpStream,
192    cfg: &ListenArgs,
193    stop_flag: &Arc<AtomicBool>,
194    server_cmd_get_free_port: &ServerCommand,
195    root_dest: Option<PathBuf>,
196) -> anyhow::Result<JoinHandle<anyhow::Result<()>>> {
197    let (start_port_range, end_port_range) = match server_cmd_get_free_port {
198        ServerCommand::GetFreePort((start_port_range, end_port_range)) => {
199            (start_port_range, end_port_range)
200        }
201        _ => unreachable!(),
202    };
203    let start = start_port_range.unwrap_or(49152);
204    let end = end_port_range.unwrap_or(61000);
205    let thread_listener: TcpListener = match bind_listen_to_free_port_in_range(&cfg.ip, start, end)
206    {
207        Some(listener) => listener,
208        None => {
209            log::error!("Unable to find free port in range {start}-{end}, attempting to bind to any free port");
210            TcpListener::bind((cfg.ip.as_str(), 0))?
211        }
212    };
213    let free_port = thread_listener
214        .local_addr()
215        .expect("Unable to get local address for TCP listener")
216        .port();
217    tracing::trace!("Bound to free port: {free_port}");
218
219    let free_port_be_bytes = free_port.to_be_bytes();
220    debug_assert_eq!(free_port_be_bytes.len(), 2);
221    socket.write_all(&free_port_be_bytes)?;
222    socket.flush()?;
223    let thread_builder = std::thread::Builder::new().name(format!("ThreadOn#{free_port}"));
224    let handle: JoinHandle<anyhow::Result<()>> = thread_builder
225        .spawn({
226            let cfg = cfg.clone();
227            let local_stop_flag = Arc::clone(stop_flag);
228            move || {
229                thread_listener.set_nonblocking(true)?;
230                run_child(
231                    &thread_listener,
232                    &cfg,
233                    &local_stop_flag,
234                    root_dest.as_deref(),
235                )
236            }
237        })
238        .expect("Failed spawning thread");
239    Ok(handle)
240}