quick_file_transfer/server/
util.rs1use std::{
2 fs::{self, File},
3 io::{self, BufReader, BufWriter, StdoutLock, Write},
4 net::{TcpListener, TcpStream},
5 path::{Path, PathBuf},
6 sync::{atomic::AtomicBool, Arc},
7 thread::JoinHandle,
8};
9
10use flate2::read::GzDecoder;
11use lz4_flex::frame::FrameDecoder;
12
13use crate::{
14 config::{
15 compression::CompressionVariant,
16 transfer::{
17 command::{ServerCommand, ServerResult},
18 listen::ListenArgs,
19 },
20 },
21 server::child::run_child,
22 util::{bind_listen_to_free_port_in_range, format_data_size, incremental_rw},
23 BUFFERED_RW_BUFSIZE, TCP_STREAM_BUFSIZE,
24};
25
26pub fn file_with_bufwriter(path: &Path) -> anyhow::Result<BufWriter<File>> {
27 let f = match fs::File::create(path) {
28 Ok(f) => f,
29 Err(e) => {
30 if e.kind() == io::ErrorKind::PermissionDenied {
31 log::error!("{e}");
32 log::info!("Attempting to retrieve additional debug information...");
33 let file_exists = path.exists();
34 let fpath_str = path.display().to_string();
35 let file_permissions: Option<fs::Permissions> = if file_exists {
36 if let Ok(md) = path.metadata() {
37 Some(md.permissions())
38 } else {
39 log::error!("Failed to retrieve permissions for {fpath_str}");
40 None
41 }
42 } else {
43 None
44 };
45
46 let parent = path.parent();
47 let parent_permissions: Option<fs::Permissions> =
48 parent.and_then(|p| p.metadata().ok().map(|md| md.permissions()));
49 let mut context_str = String::new();
50 if file_exists {
51 context_str.push_str(&format!("\n\tFile {fpath_str} exists on disk"));
52 } else {
53 context_str.push_str(&format!("\n\tFile {fpath_str} does not exist"));
54 }
55 if let Some(fpermission) = file_permissions {
56 context_str.push_str(&format!(" - with permissions: {fpermission:?}"));
57 }
58 if let Some(parent_permissions) = parent_permissions {
59 context_str.push_str(&format!(
60 "\n\tParent directory {:?} - permissions: {parent_permissions:?}",
61 parent.unwrap(),
62 ));
63 }
64 log::debug!("Additional context for {fpath_str}:{context_str}");
65 };
66 return Err(e.into());
67 }
68 };
69 let writer = BufWriter::with_capacity(BUFFERED_RW_BUFSIZE, f);
70 Ok(writer)
71}
72
73pub fn stdout_bufwriter() -> BufWriter<StdoutLock<'static>> {
74 let stdout = io::stdout().lock();
75 BufWriter::with_capacity(BUFFERED_RW_BUFSIZE, stdout)
76}
77
78pub fn handle_receive_data(
79 listen_args: &ListenArgs,
80 tcp_socket: &mut TcpStream,
81 fname: String,
82 decompression: Option<CompressionVariant>,
83 root_dest: Option<&Path>,
84) -> anyhow::Result<u64> {
85 let mut bufwriter = match (
86 listen_args.output.as_deref(),
87 listen_args.output_dir.as_deref(),
88 root_dest,
89 ) {
90 (_, _, Some(root_dest)) => {
91 if root_dest.is_file() {
92 tracing::info!("Initiation bufwriter targeting {root_dest:?}");
93 file_with_bufwriter(root_dest)?
94 } else {
95 let full_path = root_dest.join(fname);
96 tracing::info!("Initiation bufwriter targeting {full_path:?}");
97 file_with_bufwriter(&full_path)?
98 }
99 }
100 (None, Some(d), _) => {
101 if !d.is_dir() && d.exists() {
102 anyhow::bail!("Output directory path {d:?} is invalid - has to point at a directory or non-existent path")
103 }
104 if !d.exists() {
105 fs::create_dir(d)?;
106 }
107 let new_fpath = d.join(fname);
108 file_with_bufwriter(&new_fpath)?
109 }
110 (Some(f), None, _) => file_with_bufwriter(f)?,
111 (None, None, _) => {
112 unreachable!()
113 }
114 (Some(_), Some(_), _) => {
115 unreachable!("Specifying both an output name and an output directory is invalid")
116 }
117 };
118
119 let mut buf_tcp_reader = BufReader::with_capacity(BUFFERED_RW_BUFSIZE, tcp_socket);
120
121 let len = match decompression {
122 Some(compr) => match compr {
123 CompressionVariant::Bzip2 => {
124 let mut tcp_decoder = bzip2::read::BzDecoder::new(buf_tcp_reader);
125 incremental_rw::<TCP_STREAM_BUFSIZE, _, _>(&mut bufwriter, &mut tcp_decoder)?
126 }
127 CompressionVariant::Gzip => {
128 let mut tcp_decoder = GzDecoder::new(buf_tcp_reader);
129 incremental_rw::<TCP_STREAM_BUFSIZE, _, _>(&mut bufwriter, &mut tcp_decoder)?
130 }
131 CompressionVariant::Lz4 => {
132 let mut tcp_decoder = FrameDecoder::new(buf_tcp_reader);
133 incremental_rw::<TCP_STREAM_BUFSIZE, _, _>(&mut bufwriter, &mut tcp_decoder)?
134 }
135 CompressionVariant::Xz => {
136 let mut tcp_decoder = xz2::read::XzDecoder::new(buf_tcp_reader);
137 incremental_rw::<TCP_STREAM_BUFSIZE, _, _>(&mut bufwriter, &mut tcp_decoder)?
138 }
139 },
140 None => incremental_rw::<TCP_STREAM_BUFSIZE, _, _>(&mut bufwriter, &mut buf_tcp_reader)?,
141 };
142 if len < 1023 {
143 log::info!("Received: {len} B");
144 } else {
145 log::info!("Received: {} [{len} B]", format_data_size(len));
146 }
147
148 Ok(len)
149}
150
151pub fn send_result(stream: &mut TcpStream, result: &ServerResult) -> anyhow::Result<()> {
153 tracing::trace!("Sending result: {result:?}");
154 let result_bytes = bincode::serialize(result)?;
155 debug_assert!(result_bytes.len() <= u8::MAX as usize);
156 let size = result_bytes.len() as u16;
157 let header = size.to_be_bytes();
158
159 stream.write_all(&header)?;
161 stream.write_all(&result_bytes)?;
162 Ok(())
163}
164
165pub fn join_all_threads(handles: Vec<JoinHandle<anyhow::Result<()>>>) -> Result<(), String> {
166 let mut errors = String::new();
167 for h in handles {
168 let mut h_name = h.thread().name().unwrap_or_default().to_owned();
169 match h.join().map_err(|e| format!("{e:?}")) {
170 Ok(_) => (),
171 Err(e) => {
172 tracing::error!("Thread {h_name} joined with error: {e}");
173 h_name.push_str(" failed: ");
174 if !errors.is_empty() {
175 errors.push('\n');
176 }
177 errors.extend(h_name.drain(..));
178 errors.push_str(&e);
179 }
180 }
181 }
182 if errors.is_empty() {
183 Ok(())
184 } else {
185 tracing::warn!("{errors}");
186 Err(errors)
187 }
188}
189
190pub fn spawn_child_on_new_port(
191 socket: &mut TcpStream,
192 cfg: &ListenArgs,
193 stop_flag: &Arc<AtomicBool>,
194 server_cmd_get_free_port: &ServerCommand,
195 root_dest: Option<PathBuf>,
196) -> anyhow::Result<JoinHandle<anyhow::Result<()>>> {
197 let (start_port_range, end_port_range) = match server_cmd_get_free_port {
198 ServerCommand::GetFreePort((start_port_range, end_port_range)) => {
199 (start_port_range, end_port_range)
200 }
201 _ => unreachable!(),
202 };
203 let start = start_port_range.unwrap_or(49152);
204 let end = end_port_range.unwrap_or(61000);
205 let thread_listener: TcpListener = match bind_listen_to_free_port_in_range(&cfg.ip, start, end)
206 {
207 Some(listener) => listener,
208 None => {
209 log::error!("Unable to find free port in range {start}-{end}, attempting to bind to any free port");
210 TcpListener::bind((cfg.ip.as_str(), 0))?
211 }
212 };
213 let free_port = thread_listener
214 .local_addr()
215 .expect("Unable to get local address for TCP listener")
216 .port();
217 tracing::trace!("Bound to free port: {free_port}");
218
219 let free_port_be_bytes = free_port.to_be_bytes();
220 debug_assert_eq!(free_port_be_bytes.len(), 2);
221 socket.write_all(&free_port_be_bytes)?;
222 socket.flush()?;
223 let thread_builder = std::thread::Builder::new().name(format!("ThreadOn#{free_port}"));
224 let handle: JoinHandle<anyhow::Result<()>> = thread_builder
225 .spawn({
226 let cfg = cfg.clone();
227 let local_stop_flag = Arc::clone(stop_flag);
228 move || {
229 thread_listener.set_nonblocking(true)?;
230 run_child(
231 &thread_listener,
232 &cfg,
233 &local_stop_flag,
234 root_dest.as_deref(),
235 )
236 }
237 })
238 .expect("Failed spawning thread");
239 Ok(handle)
240}