1extern crate atty;
2use std::collections::HashSet;
3use std::fs::{create_dir_all, File, OpenOptions};
4use std::io::{self, BufRead, BufReader, BufWriter, Result, Write};
5use std::path::{Path, PathBuf};
6
7use crate::utils::styled_bytes_progress_bar;
8use flate2::read::GzDecoder;
9use flate2::write;
10use flate2::Compression;
11use indicatif::ProgressBar;
12use std::ffi::OsStr;
13use std::io::Read;
14
15struct ProgressRead<R: Read> {
16 inner: R,
17 pb: Option<ProgressBar>,
18 total: Option<u64>,
19 finished: bool,
20 label: Option<String>,
21}
22
23impl<R: Read> ProgressRead<R> {
24 fn new(inner: R, pb: Option<ProgressBar>, total: Option<u64>, label: Option<String>) -> Self {
25 ProgressRead {
26 inner,
27 pb,
28 total,
29 finished: false,
30 label,
31 }
32 }
33}
34
35impl<R: Read> Read for ProgressRead<R> {
36 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
37 let n = self.inner.read(buf)?;
38 if n > 0 {
39 if let Some(pb) = &self.pb {
40 if !self.finished {
41 pb.inc(n as u64);
42 if let Some(total) = self.total {
46 if pb.position() >= total {
47 let pos = pb.position();
48 pb.finish();
49 if let Some(lbl) = &self.label {
50 let _ = eprintln!("Downloaded {} bytes from {}", pos, lbl);
51 } else {
52 let _ = eprintln!("Downloaded {} bytes", pos);
53 }
54 self.finished = true;
55 }
56 }
57 }
58 }
59 } else {
60 if let Some(pb) = &self.pb {
61 if !self.finished {
62 let pos = pb.position();
66 pb.finish();
67 if let Some(lbl) = &self.label {
68 let _ = eprintln!("Downloaded {} bytes from {}", pos, lbl);
69 } else {
70 let _ = eprintln!("Downloaded {} bytes", pos);
71 }
72 self.finished = true;
73 }
74 }
75 }
76 Ok(n)
77 }
78}
79
80fn read_stdin() -> Vec<Vec<u8>> {
81 let stdin = io::stdin();
82 let mut list: Vec<Vec<u8>> = vec![];
83 if atty::is(atty::Stream::Stdin) {
84 eprintln!("No input on STDIN!");
85 return list;
86 }
87 for line in stdin.lock().lines() {
88 let line_as_vec = match line {
89 Err(why) => panic!("couldn't read line: {}", why),
90 Ok(l) => l.as_bytes().to_vec(),
91 };
92 list.push(line_as_vec)
93 }
94 list
95}
96
/// Opens `filename` and returns an iterator over its lines.
///
/// # Errors
///
/// Returns the underlying I/O error if the file cannot be opened.
pub fn read_lines<P>(filename: P) -> io::Result<io::Lines<io::BufReader<File>>>
where
    P: AsRef<Path>,
{
    File::open(filename).map(|handle| io::BufReader::new(handle).lines())
}
104
105fn read_file(file_path: &PathBuf) -> Vec<Vec<u8>> {
106 let mut output: Vec<Vec<u8>> = vec![];
107 if let Ok(lines) = read_lines(file_path) {
108 for line in lines {
109 let line_as_vec = match line {
110 Err(why) => panic!("couldn't read line: {}", why),
111 Ok(l) => l.as_bytes().to_vec(),
112 };
113 output.push(line_as_vec)
114 }
115 }
116 output
117}
118
119pub fn get_list(file_path: &Option<PathBuf>) -> HashSet<Vec<u8>> {
120 let list = match file_path {
121 None => vec![],
122 Some(p) if p == Path::new("-") => read_stdin(),
123 Some(_) => read_file(file_path.as_ref().unwrap()),
124 };
125 HashSet::from_iter(list)
126}
127
128pub fn get_file_writer(file_path: &PathBuf, append: bool) -> Box<dyn Write> {
129 if let Err(e) = create_dir_all(file_path.parent().unwrap()) {
130 panic!(
131 "couldn't create directory {}: {}",
132 file_path.parent().unwrap().display(),
133 e
134 );
135 }
136 let file = if append {
137 match OpenOptions::new().append(true).open(file_path) {
138 Err(why) => panic!("couldn't open {}: {}", file_path.display(), why),
139 Ok(file) => file,
140 }
141 } else {
142 match File::create(file_path) {
143 Err(why) => panic!("couldn't open {}: {}", file_path.display(), why),
144 Ok(file) => file,
145 }
146 };
147
148 let writer: Box<dyn Write> = if file_path.extension() == Some(OsStr::new("gz")) {
149 Box::new(BufWriter::with_capacity(
150 128 * 1024,
151 write::GzEncoder::new(file, Compression::default()),
152 ))
153 } else {
154 Box::new(BufWriter::with_capacity(128 * 1024, file))
155 };
156 writer
157}
158
159pub fn get_writer(file_path: &Option<PathBuf>) -> Box<dyn Write> {
160 let writer: Box<dyn Write> = match file_path {
161 Some(path) if path == Path::new("-") => Box::new(BufWriter::new(io::stdout().lock())),
162 Some(path) => {
163 create_dir_all(path.parent().unwrap()).unwrap();
164 get_file_writer(path, false)
165 }
166 None => Box::new(BufWriter::new(io::stdout().lock())),
167 };
168 writer
169}
170
171pub fn get_append_writer(file_path: &Option<PathBuf>) -> Box<dyn Write> {
172 let writer: Box<dyn Write> = match file_path {
173 Some(path) if path == Path::new("-") => Box::new(BufWriter::new(io::stdout().lock())),
174 Some(path) => {
175 create_dir_all(path.parent().unwrap()).unwrap();
176 get_file_writer(path, true)
177 }
178 None => Box::new(BufWriter::new(io::stdout().lock())),
179 };
180 writer
181}
182
183pub fn get_csv_writer(file_path: &Option<PathBuf>, delimiter: u8) -> csv::Writer<Box<dyn Write>> {
184 let file_writer = get_writer(file_path);
185 if delimiter == b'\t' {
186 csv::WriterBuilder::new()
187 .delimiter(b'\t')
188 .from_writer(file_writer)
189 } else {
190 csv::WriterBuilder::new().from_writer(file_writer)
191 }
192}
193
194pub fn local_file_reader(path: PathBuf) -> io::Result<Box<dyn BufRead>> {
197 let file = File::open(&path)?;
198
199 if path.extension() == Some(OsStr::new("gz")) {
200 Ok(Box::new(BufReader::new(GzDecoder::new(file))))
201 } else {
202 Ok(Box::new(BufReader::new(file)))
203 }
204}
205
206pub fn remote_file_reader(url: &str) -> io::Result<Box<dyn BufRead>> {
209 let response = reqwest::blocking::get(url.to_string())
210 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
211 if response.status().is_success() {
212 let content_len = response.content_length();
213 let pb = content_len.map(|len| styled_bytes_progress_bar(len, url));
214
215 let is_gz_ext = url.ends_with(".gz");
216 let is_gz_encoding = response
217 .headers()
218 .get(reqwest::header::CONTENT_ENCODING)
219 .and_then(|v| v.to_str().ok())
220 .map(|s| s.to_lowercase().contains("gzip") || s.to_lowercase().contains("x-gzip"))
221 .unwrap_or(false);
222 if is_gz_ext || is_gz_encoding {
223 let progress_response =
227 ProgressRead::new(response, pb, content_len, Some(url.to_string()));
228 let decoder = GzDecoder::new(progress_response);
229 Ok(Box::new(BufReader::new(decoder)))
230 } else {
231 let reader = ProgressRead::new(response, pb, content_len, Some(url.to_string()));
232 Ok(Box::new(BufReader::new(reader)))
233 }
234 } else {
235 let response = reqwest::blocking::get(url.to_string().replace(".gz", ""))
236 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
237 if response.status().is_success() {
238 let content_len = response.content_length();
239 let pb = content_len.map(|len| styled_bytes_progress_bar(len, url));
240
241 let is_gz_encoding = response
242 .headers()
243 .get(reqwest::header::CONTENT_ENCODING)
244 .and_then(|v| v.to_str().ok())
245 .map(|s| s.to_lowercase().contains("gzip") || s.to_lowercase().contains("x-gzip"))
246 .unwrap_or(false);
247 if is_gz_encoding {
248 let progress_response =
251 ProgressRead::new(response, pb, content_len, Some(url.to_string()));
252 let decoder = GzDecoder::new(progress_response);
253 Ok(Box::new(BufReader::new(decoder)))
254 } else {
255 let reader = ProgressRead::new(response, pb, content_len, Some(url.to_string()));
256 Ok(Box::new(BufReader::new(reader)))
257 }
258 } else {
259 Err(io::Error::other(format!(
260 "Failed to fetch file: {}",
261 response.status()
262 )))
263 }
264 }
265}
266
267pub fn ssh_file_reader(path: &str) -> io::Result<Box<dyn BufRead>> {
268 let path = path.replace("ssh://", "");
270 let parts: Vec<&str> = path.split(':').collect();
272 if parts.len() != 2 {
273 return Err(io::Error::new(
274 io::ErrorKind::InvalidInput,
275 "Invalid SSH path format. Expected ssh://host:path",
276 ));
277 }
278 let host = parts[0];
279 let path = parts[1];
280 let command = if path.ends_with(".gz") {
282 format!(
283 "ssh {} 'if [ -f {} ]; then cat {}; else cat {}; fi 2>/dev/null'",
284 host,
285 path,
286 path,
287 path.trim_end_matches(".gz")
288 )
289 } else {
290 format!("ssh {} cat {} 2>/dev/null", host, path)
291 };
292
293 let process = std::process::Command::new("sh")
294 .arg("-c")
295 .arg(command)
296 .stdout(std::process::Stdio::piped())
297 .spawn()
298 .map_err(|e| {
299 io::Error::new(
300 io::ErrorKind::Other,
301 format!("Failed to start SSH command: {}", e),
302 )
303 })?;
304
305 let stdout = process
306 .stdout
307 .ok_or_else(|| io::Error::other("Failed to capture stdout"))?;
308 let mut buffer = [0u8; 2];
309 let mut stdout_reader = BufReader::new(stdout);
310 match io::Read::read_exact(&mut stdout_reader, &mut buffer) {
311 Ok(_) => {
312 let is_gzipped = buffer == [0x1F, 0x8B];
313 let stdout = io::Read::chain(std::io::Cursor::new(buffer), stdout_reader);
314 if is_gzipped {
315 Ok(Box::new(BufReader::new(GzDecoder::new(stdout))))
316 } else {
317 Ok(Box::new(BufReader::new(stdout)))
318 }
319 }
320 Err(e) => Err(io::Error::new(
321 io::ErrorKind::UnexpectedEof,
322 format!("Failed to read from SSH output: {}", e),
323 )),
324 }
325}
326
327pub fn file_reader(path: PathBuf) -> io::Result<Box<dyn BufRead>> {
330 if path.to_string_lossy().starts_with("http") {
331 return remote_file_reader(&path.to_string_lossy());
332 } else if path.to_string_lossy().starts_with("ssh") {
333 return ssh_file_reader(&path.to_string_lossy());
334 }
335
336 let file = File::open(&path);
337
338 if path.extension() == Some(OsStr::new("gz")) {
339 match file {
340 Ok(f) => {
341 let mut magic = [0u8; 2];
343 let mut f_clone = f.try_clone()?;
344 use std::io::Read;
345 f_clone.read_exact(&mut magic)?;
346 if magic != [0x1F, 0x8B] {
347 return Err(io::Error::new(
348 io::ErrorKind::InvalidData,
349 format!(
350 "File {} has .gz extension but is not a valid gzip file (magic bytes: {:x?})",
351 path.display(),
352 magic
353 ),
354 ));
355 }
356 let file = File::open(&path)?;
358 Ok(Box::new(BufReader::new(GzDecoder::new(file))))
359 }
360 Err(_) => {
361 let mut unzipped_path = path.clone();
363 unzipped_path.set_extension("");
364 let file = File::open(&unzipped_path)?;
365 Ok(Box::new(BufReader::new(file)))
366 }
367 }
368 } else {
369 let file = file?;
370 Ok(Box::new(BufReader::new(file)))
371 }
372}
373
/// Returns a buffered reader that is already at end of input.
pub fn get_empty_reader() -> Box<dyn BufRead> {
    let nothing = io::empty();
    let buffered = BufReader::new(nothing);
    Box::new(buffered)
}
379
380pub fn get_csv_reader(
383 file_path: &Option<PathBuf>,
384 delimiter: u8,
385 has_headers: bool,
386 comment_char: Option<u8>,
387 skip_lines: usize,
388 flexible: bool,
389) -> io::Result<csv::Reader<Box<dyn BufRead>>> {
390 dbg!(&file_path);
391 let file_reader = file_reader(file_path.as_ref().unwrap().clone())?;
392 let mut file_reader = Box::new(file_reader);
394 for _ in 0..skip_lines {
395 let mut line = String::new();
396 file_reader.read_line(&mut line).unwrap();
397 }
398
399 Ok(csv::ReaderBuilder::new()
400 .delimiter(delimiter)
401 .has_headers(has_headers)
402 .comment(comment_char)
403 .flexible(flexible) .from_reader(file_reader))
405}
406
407pub fn write_list(entries: &HashSet<Vec<u8>>, file_path: &Option<PathBuf>) -> Result<()> {
408 let mut writer = get_writer(file_path);
409 for line in entries.iter() {
410 writeln!(&mut writer, "{}", String::from_utf8(line.to_vec()).unwrap())?;
411 }
412 Ok(())
413}
414
/// Returns a copy of `p` with `s` appended verbatim to its textual form.
///
/// No path separator is inserted, so this can be used to add a suffix such as
/// `.gz` after an existing extension.
pub fn append_to_path(p: &PathBuf, s: &str) -> PathBuf {
    let mut extended = p.as_os_str().to_os_string();
    extended.push(s);
    PathBuf::from(extended)
}