Skip to main content

blobtk/
io.rs

extern crate atty;
use std::collections::HashSet;
use std::ffi::OsStr;
use std::fs::{create_dir_all, File, OpenOptions};
use std::io::Read;
use std::io::{self, BufRead, BufReader, BufWriter, Result, Write};
use std::path::{Path, PathBuf};

use crate::utils::styled_bytes_progress_bar;
use flate2::read::{GzDecoder, MultiGzDecoder};
use flate2::write;
use flate2::Compression;
use indicatif::ProgressBar;
14
15struct ProgressRead<R: Read> {
16    inner: R,
17    pb: Option<ProgressBar>,
18    total: Option<u64>,
19    finished: bool,
20    label: Option<String>,
21}
22
23impl<R: Read> ProgressRead<R> {
24    fn new(inner: R, pb: Option<ProgressBar>, total: Option<u64>, label: Option<String>) -> Self {
25        ProgressRead {
26            inner,
27            pb,
28            total,
29            finished: false,
30            label,
31        }
32    }
33}
34
35impl<R: Read> Read for ProgressRead<R> {
36    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
37        let n = self.inner.read(buf)?;
38        if n > 0 {
39            if let Some(pb) = &self.pb {
40                if !self.finished {
41                    pb.inc(n as u64);
42                    // If we know the total content length, finish when we've
43                    // reached it so the final message is shown even if the
44                    // wrapped decoder doesn't issue a zero-length read.
45                    if let Some(total) = self.total {
46                        if pb.position() >= total {
47                            let pos = pb.position();
48                            pb.finish();
49                            if let Some(lbl) = &self.label {
50                                let _ = eprintln!("Downloaded {} bytes from {}", pos, lbl);
51                            } else {
52                                let _ = eprintln!("Downloaded {} bytes", pos);
53                            }
54                            self.finished = true;
55                        }
56                    }
57                }
58            }
59        } else {
60            if let Some(pb) = &self.pb {
61                if !self.finished {
62                    // Replace the progress bar with a final message so the
63                    // completed report is preserved and not overwritten by
64                    // subsequent progress bars.
65                    let pos = pb.position();
66                    pb.finish();
67                    if let Some(lbl) = &self.label {
68                        let _ = eprintln!("Downloaded {} bytes from {}", pos, lbl);
69                    } else {
70                        let _ = eprintln!("Downloaded {} bytes", pos);
71                    }
72                    self.finished = true;
73                }
74            }
75        }
76        Ok(n)
77    }
78}
79
80fn read_stdin() -> Vec<Vec<u8>> {
81    let stdin = io::stdin();
82    let mut list: Vec<Vec<u8>> = vec![];
83    if atty::is(atty::Stream::Stdin) {
84        eprintln!("No input on STDIN!");
85        return list;
86    }
87    for line in stdin.lock().lines() {
88        let line_as_vec = match line {
89            Err(why) => panic!("couldn't read line: {}", why),
90            Ok(l) => l.as_bytes().to_vec(),
91        };
92        list.push(line_as_vec)
93    }
94    list
95}
96
/// Open `filename` and return a buffered iterator over its lines.
///
/// Each item is an `io::Result<String>` with the line terminator removed.
///
/// # Errors
/// Returns the `File::open` error if the file cannot be opened.
pub fn read_lines<P>(filename: P) -> io::Result<io::Lines<io::BufReader<File>>>
where
    P: AsRef<Path>,
{
    File::open(filename).map(|handle| BufReader::new(handle).lines())
}
104
105fn read_file(file_path: &PathBuf) -> Vec<Vec<u8>> {
106    let mut output: Vec<Vec<u8>> = vec![];
107    if let Ok(lines) = read_lines(file_path) {
108        for line in lines {
109            let line_as_vec = match line {
110                Err(why) => panic!("couldn't read line: {}", why),
111                Ok(l) => l.as_bytes().to_vec(),
112            };
113            output.push(line_as_vec)
114        }
115    }
116    output
117}
118
119pub fn get_list(file_path: &Option<PathBuf>) -> HashSet<Vec<u8>> {
120    let list = match file_path {
121        None => vec![],
122        Some(p) if p == Path::new("-") => read_stdin(),
123        Some(_) => read_file(file_path.as_ref().unwrap()),
124    };
125    HashSet::from_iter(list)
126}
127
128pub fn get_file_writer(file_path: &PathBuf, append: bool) -> Box<dyn Write> {
129    if let Err(e) = create_dir_all(file_path.parent().unwrap()) {
130        panic!(
131            "couldn't create directory {}: {}",
132            file_path.parent().unwrap().display(),
133            e
134        );
135    }
136    let file = if append {
137        match OpenOptions::new().append(true).open(file_path) {
138            Err(why) => panic!("couldn't open {}: {}", file_path.display(), why),
139            Ok(file) => file,
140        }
141    } else {
142        match File::create(file_path) {
143            Err(why) => panic!("couldn't open {}: {}", file_path.display(), why),
144            Ok(file) => file,
145        }
146    };
147
148    let writer: Box<dyn Write> = if file_path.extension() == Some(OsStr::new("gz")) {
149        Box::new(BufWriter::with_capacity(
150            128 * 1024,
151            write::GzEncoder::new(file, Compression::default()),
152        ))
153    } else {
154        Box::new(BufWriter::with_capacity(128 * 1024, file))
155    };
156    writer
157}
158
159pub fn get_writer(file_path: &Option<PathBuf>) -> Box<dyn Write> {
160    let writer: Box<dyn Write> = match file_path {
161        Some(path) if path == Path::new("-") => Box::new(BufWriter::new(io::stdout().lock())),
162        Some(path) => {
163            create_dir_all(path.parent().unwrap()).unwrap();
164            get_file_writer(path, false)
165        }
166        None => Box::new(BufWriter::new(io::stdout().lock())),
167    };
168    writer
169}
170
171pub fn get_append_writer(file_path: &Option<PathBuf>) -> Box<dyn Write> {
172    let writer: Box<dyn Write> = match file_path {
173        Some(path) if path == Path::new("-") => Box::new(BufWriter::new(io::stdout().lock())),
174        Some(path) => {
175            create_dir_all(path.parent().unwrap()).unwrap();
176            get_file_writer(path, true)
177        }
178        None => Box::new(BufWriter::new(io::stdout().lock())),
179    };
180    writer
181}
182
183pub fn get_csv_writer(file_path: &Option<PathBuf>, delimiter: u8) -> csv::Writer<Box<dyn Write>> {
184    let file_writer = get_writer(file_path);
185    if delimiter == b'\t' {
186        csv::WriterBuilder::new()
187            .delimiter(b'\t')
188            .from_writer(file_writer)
189    } else {
190        csv::WriterBuilder::new().from_writer(file_writer)
191    }
192}
193
194/// Return a BufRead object for a given file path.
195/// If the file path has a `.gz` extension, the file is decompressed on the fly.
196pub fn local_file_reader(path: PathBuf) -> io::Result<Box<dyn BufRead>> {
197    let file = File::open(&path)?;
198
199    if path.extension() == Some(OsStr::new("gz")) {
200        Ok(Box::new(BufReader::new(GzDecoder::new(file))))
201    } else {
202        Ok(Box::new(BufReader::new(file)))
203    }
204}
205
206/// Return a BufRead object for a given URL path.
207/// The file will be fetched.
208pub fn remote_file_reader(url: &str) -> io::Result<Box<dyn BufRead>> {
209    let response = reqwest::blocking::get(url.to_string())
210        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
211    if response.status().is_success() {
212        let content_len = response.content_length();
213        let pb = content_len.map(|len| styled_bytes_progress_bar(len, url));
214
215        let is_gz_ext = url.ends_with(".gz");
216        let is_gz_encoding = response
217            .headers()
218            .get(reqwest::header::CONTENT_ENCODING)
219            .and_then(|v| v.to_str().ok())
220            .map(|s| s.to_lowercase().contains("gzip") || s.to_lowercase().contains("x-gzip"))
221            .unwrap_or(false);
222        if is_gz_ext || is_gz_encoding {
223            // Wrap the original response with ProgressRead so the progress bar
224            // measures compressed bytes (Content-Length) rather than the
225            // decompressed stream which would exceed the reported length.
226            let progress_response =
227                ProgressRead::new(response, pb, content_len, Some(url.to_string()));
228            let decoder = GzDecoder::new(progress_response);
229            Ok(Box::new(BufReader::new(decoder)))
230        } else {
231            let reader = ProgressRead::new(response, pb, content_len, Some(url.to_string()));
232            Ok(Box::new(BufReader::new(reader)))
233        }
234    } else {
235        let response = reqwest::blocking::get(url.to_string().replace(".gz", ""))
236            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
237        if response.status().is_success() {
238            let content_len = response.content_length();
239            let pb = content_len.map(|len| styled_bytes_progress_bar(len, url));
240
241            let is_gz_encoding = response
242                .headers()
243                .get(reqwest::header::CONTENT_ENCODING)
244                .and_then(|v| v.to_str().ok())
245                .map(|s| s.to_lowercase().contains("gzip") || s.to_lowercase().contains("x-gzip"))
246                .unwrap_or(false);
247            if is_gz_encoding {
248                // See note above: count compressed bytes by wrapping the
249                // response in ProgressRead before GzDecoder.
250                let progress_response =
251                    ProgressRead::new(response, pb, content_len, Some(url.to_string()));
252                let decoder = GzDecoder::new(progress_response);
253                Ok(Box::new(BufReader::new(decoder)))
254            } else {
255                let reader = ProgressRead::new(response, pb, content_len, Some(url.to_string()));
256                Ok(Box::new(BufReader::new(reader)))
257            }
258        } else {
259            Err(io::Error::other(format!(
260                "Failed to fetch file: {}",
261                response.status()
262            )))
263        }
264    }
265}
266
267pub fn ssh_file_reader(path: &str) -> io::Result<Box<dyn BufRead>> {
268    // Remove protocol from path
269    let path = path.replace("ssh://", "");
270    // Split the path into host and file
271    let parts: Vec<&str> = path.split(':').collect();
272    if parts.len() != 2 {
273        return Err(io::Error::new(
274            io::ErrorKind::InvalidInput,
275            "Invalid SSH path format. Expected ssh://host:path",
276        ));
277    }
278    let host = parts[0];
279    let path = parts[1];
280    // Use SSH to read the file
281    let command = if path.ends_with(".gz") {
282        format!(
283            "ssh {} 'if [ -f {} ]; then cat {}; else cat {}; fi 2>/dev/null'",
284            host,
285            path,
286            path,
287            path.trim_end_matches(".gz")
288        )
289    } else {
290        format!("ssh {} cat {} 2>/dev/null", host, path)
291    };
292
293    let process = std::process::Command::new("sh")
294        .arg("-c")
295        .arg(command)
296        .stdout(std::process::Stdio::piped())
297        .spawn()
298        .map_err(|e| {
299            io::Error::new(
300                io::ErrorKind::Other,
301                format!("Failed to start SSH command: {}", e),
302            )
303        })?;
304
305    let stdout = process
306        .stdout
307        .ok_or_else(|| io::Error::other("Failed to capture stdout"))?;
308    let mut buffer = [0u8; 2];
309    let mut stdout_reader = BufReader::new(stdout);
310    match io::Read::read_exact(&mut stdout_reader, &mut buffer) {
311        Ok(_) => {
312            let is_gzipped = buffer == [0x1F, 0x8B];
313            let stdout = io::Read::chain(std::io::Cursor::new(buffer), stdout_reader);
314            if is_gzipped {
315                Ok(Box::new(BufReader::new(GzDecoder::new(stdout))))
316            } else {
317                Ok(Box::new(BufReader::new(stdout)))
318            }
319        }
320        Err(e) => Err(io::Error::new(
321            io::ErrorKind::UnexpectedEof,
322            format!("Failed to read from SSH output: {}", e),
323        )),
324    }
325}
326
327/// Return a BufRead object for a given file path.
328/// If the path is a URL the file will be fetched.
329pub fn file_reader(path: PathBuf) -> io::Result<Box<dyn BufRead>> {
330    if path.to_string_lossy().starts_with("http") {
331        return remote_file_reader(&path.to_string_lossy());
332    } else if path.to_string_lossy().starts_with("ssh") {
333        return ssh_file_reader(&path.to_string_lossy());
334    }
335
336    let file = File::open(&path);
337
338    if path.extension() == Some(OsStr::new("gz")) {
339        match file {
340            Ok(f) => {
341                // Check gzip magic bytes
342                let mut magic = [0u8; 2];
343                let mut f_clone = f.try_clone()?;
344                use std::io::Read;
345                f_clone.read_exact(&mut magic)?;
346                if magic != [0x1F, 0x8B] {
347                    return Err(io::Error::new(
348                        io::ErrorKind::InvalidData,
349                        format!(
350                            "File {} has .gz extension but is not a valid gzip file (magic bytes: {:x?})",
351                            path.display(),
352                            magic
353                        ),
354                    ));
355                }
356                // Re-open for actual reading
357                let file = File::open(&path)?;
358                Ok(Box::new(BufReader::new(GzDecoder::new(file))))
359            }
360            Err(_) => {
361                // Try unzipped file
362                let mut unzipped_path = path.clone();
363                unzipped_path.set_extension("");
364                let file = File::open(&unzipped_path)?;
365                Ok(Box::new(BufReader::new(file)))
366            }
367        }
368    } else {
369        let file = file?;
370        Ok(Box::new(BufReader::new(file)))
371    }
372}
373
/// Return a boxed `BufRead` that is already at end-of-file.
///
/// It yields no bytes and no lines. It serves as a placeholder reader when the
/// actual input is taken from STDIN instead.
pub fn get_empty_reader() -> Box<dyn BufRead> {
    // `io::Empty` implements `BufRead` directly and always reports EOF.
    let nothing = io::empty();
    Box::new(nothing)
}
379
380/// Return a csv::Reader object for a given file path.
381/// If the file path has a `.gz` extension, the file is decompressed on the fly.
382pub fn get_csv_reader(
383    file_path: &Option<PathBuf>,
384    delimiter: u8,
385    has_headers: bool,
386    comment_char: Option<u8>,
387    skip_lines: usize,
388    flexible: bool,
389) -> io::Result<csv::Reader<Box<dyn BufRead>>> {
390    dbg!(&file_path);
391    let file_reader = file_reader(file_path.as_ref().unwrap().clone())?;
392    // Skip the first `skip_lines` lines
393    let mut file_reader = Box::new(file_reader);
394    for _ in 0..skip_lines {
395        let mut line = String::new();
396        file_reader.read_line(&mut line).unwrap();
397    }
398
399    Ok(csv::ReaderBuilder::new()
400        .delimiter(delimiter)
401        .has_headers(has_headers)
402        .comment(comment_char)
403        .flexible(flexible) // Allow incomplete rows
404        .from_reader(file_reader))
405}
406
407pub fn write_list(entries: &HashSet<Vec<u8>>, file_path: &Option<PathBuf>) -> Result<()> {
408    let mut writer = get_writer(file_path);
409    for line in entries.iter() {
410        writeln!(&mut writer, "{}", String::from_utf8(line.to_vec()).unwrap())?;
411    }
412    Ok(())
413}
414
/// Return a copy of `p` with `s` appended verbatim to the end of the path.
///
/// Unlike `Path::join` or `PathBuf::set_extension`, no separator or dot is
/// inserted. For example, `out/data.tsv` plus `.gz` gives `out/data.tsv.gz`.
pub fn append_to_path(p: &PathBuf, s: &str) -> PathBuf {
    let mut raw = p.as_os_str().to_os_string();
    raw.push(s);
    PathBuf::from(raw)
}