// jam_rs/file_io.rs

1use crate::cli::Commands;
2use crate::cli::HashAlgorithms;
3use crate::cli::OutputFormats;
4use crate::compare::CompareResult;
5use crate::hash_functions::Function;
6use crate::signature::Signature;
7use crate::sketch::Sketch;
8use crate::sketcher;
9use anyhow::anyhow;
10use anyhow::Result;
11use needletail::parse_fastx_file;
12use rayon::prelude::IntoParallelRefIterator;
13use rayon::prelude::ParallelIterator;
14use sourmash::signature::Signature as SourmashSignature;
15use std::io;
16use std::io::Write;
17use std::sync::mpsc;
18use std::sync::mpsc::Receiver;
19use std::thread;
20use std::{
21    ffi::OsStr,
22    fs::{self, File},
23    io::{BufRead, BufReader},
24    path::PathBuf,
25};
26
/// Stateless namespace for the file-level operations of this module.
///
/// Every operation is an associated function (none takes `self`), so the type
/// carries no data; the derives only make it usable in generic/debug contexts.
#[derive(Debug, Clone, Copy, Default)]
pub struct FileHandler {}
28
29impl FileHandler {
30    pub fn sketch_files(command: Commands, threads: Option<usize>) -> Result<()> {
31        match command.to_owned() {
32            Commands::Sketch {
33                input,
34                output,
35                kmer_size,
36                fscale,
37                kscale,
38                nmin,
39                nmax,
40                algorithm,
41                format,
42                singleton,
43            } => {
44                let files = FileHandler::test_and_collect_files(input, true)?;
45                let pool = rayon::ThreadPoolBuilder::new()
46                    .num_threads(threads.unwrap_or_default())
47                    .build()?;
48
49                let function = Function::from_alg(algorithm.clone(), kmer_size);
50
51                let (send, recv) = mpsc::channel();
52
53                let is_stdout = output.is_none();
54                let handler = thread::spawn(|| {
55                    FileHandler::write_output(output, format, recv)
56                    // thread code
57                });
58
59                let _ = pool.install(|| {
60                    files.par_iter().try_for_each(|file_path| {
61                        match FileHandler::sketch_file(
62                            file_path,
63                            kmer_size,
64                            fscale,
65                            kscale,
66                            nmin,
67                            nmax,
68                            singleton,
69                            false,
70                            function.clone(),
71                            algorithm.clone(),
72                            is_stdout,
73                        ) {
74                            Ok(sig) => send.send(sig).map_err(|_| anyhow!("Error while sending")),
75                            Err(_) => Err(anyhow!("Error while sketching file {:?}", file_path)),
76                        }
77                    })
78                });
79
80                drop(send);
81
82                Ok(handler
83                    .join()
84                    .map_err(|_| anyhow!("Unable to join threads"))??)
85            }
86            _ => Err(anyhow!("Wrong command")),
87        }
88    }
89
90    pub fn sketch_file(
91        input: &PathBuf,
92        kmer_length: u8,
93        fscale: Option<u64>,
94        kscale: Option<u64>,
95        nmin: Option<u64>,
96        nmax: Option<u64>,
97        singleton: bool,
98        stats: bool,
99        function: Function,
100        algorithm: HashAlgorithms,
101        stdout: bool,
102    ) -> Result<Signature> {
103        let mut x = fs::metadata(input)?.len();
104        if let Some(ext) = input.extension() {
105            if let Some(ext_str) = ext.to_str() {
106                if ext_str == "gz" {
107                    // Approximate the size of uncompressed file
108                    x *= 3;
109                }
110            }
111        }
112        let start = std::time::Instant::now();
113        let kscale = if let Some(kscale) = kscale {
114            (x as f64 / kscale as f64) as u64
115        } else {
116            u64::MAX
117        };
118        let max_hash = if let Some(fscale) = fscale {
119            (u64::MAX as f64 / fscale as f64) as u64
120        } else {
121            u64::MAX
122        };
123        let mut sketcher = sketcher::Sketcher::new(
124            kmer_length,
125            input
126                .to_str()
127                .ok_or_else(|| anyhow!("Unknown path"))?
128                .to_string(),
129            singleton,
130            stats,
131            kscale,
132            max_hash,
133            nmin,
134            nmax,
135            function,
136            algorithm,
137        );
138        let mut reader = parse_fastx_file(input)?;
139        let mut counter = 0;
140        while let Some(record) = reader.next() {
141            sketcher.process(&record?);
142            counter += 1;
143        }
144        let elapsed = start.elapsed().as_millis();
145        if !stdout {
146            println!(
147                "Processed {:?} with {} records, in {:?} seconds",
148                input,
149                counter,
150                elapsed as f64 / 1000.0,
151            );
152        }
153        Ok(sketcher.finish())
154    }
155
156    pub fn write_output(
157        output: Option<PathBuf>,
158        output_format: OutputFormats,
159        signature_recv: Receiver<Signature>,
160    ) -> Result<()> {
161        let stdout = output.is_none();
162        let mut output: Box<dyn Write> = match output {
163            Some(o) => Box::new(std::io::BufWriter::new(File::create(o)?)),
164            None => Box::new(std::io::BufWriter::new(io::stdout())),
165        };
166
167        match output_format {
168            OutputFormats::Bin => {
169                while let Ok(sig) = signature_recv.recv() {
170                    let name = sig.file_name.clone();
171                    let len = sig.sketches.first().unwrap().hashes.len();
172                    bincode::serialize_into(&mut output, &vec![sig])?;
173                    if !stdout {
174                        println!("Wrote signature: {:?} with {:?} hashes.", name, len);
175                    }
176                }
177            }
178            OutputFormats::Sourmash => {
179                while let Ok(sig) = signature_recv.recv() {
180                    let sourmash_sig: SourmashSignature = sig.into();
181                    serde_json::to_writer(&mut output, &vec![sourmash_sig])?;
182                }
183            }
184        }
185
186        Ok(())
187    }
188
189    pub fn read_signatures(input: &PathBuf) -> Result<Vec<Signature>> {
190        let read_to_bytes = std::fs::read(input)?;
191        Ok(bincode::deserialize_from(read_to_bytes.as_slice()).unwrap())
192    }
193
194    pub fn concat(inputs: Vec<PathBuf>, output: PathBuf) -> Result<()> {
195        let o_file = std::fs::File::create(output)?;
196        let mut bufwriter = std::io::BufWriter::new(o_file);
197
198        for input in inputs {
199            let mut reader = BufReader::new(std::fs::File::open(input)?);
200            while let Ok(result) =
201                bincode::deserialize_from::<&mut BufReader<File>, Sketch>(&mut reader)
202            {
203                bincode::serialize_into(&mut bufwriter, &result)?;
204            }
205        }
206        Ok(())
207    }
208
209    pub fn test_and_collect_files(input: Vec<PathBuf>, check_ext: bool) -> Result<Vec<PathBuf>> {
210        let mut resulting_paths = Vec::new();
211        let mut found_list: Option<PathBuf> = None;
212        for path in input {
213            if !path.exists() {
214                return Err(anyhow::anyhow!("File {:?} does not exist", path));
215            }
216            if path.is_dir() {
217                for p in path.read_dir()? {
218                    let p = p?;
219                    if p.path().is_file() {
220                        if let Some(ext) = p.path().extension() {
221                            if test_extension(ext) {
222                                resulting_paths.push(p.path());
223                            } else if ext == "list" {
224                                if resulting_paths.is_empty() {
225                                    found_list = Some(p.path());
226                                    break;
227                                } else {
228                                    return Err(anyhow::anyhow!(
229                                        "Found multiple list files in {:?}",
230                                        path
231                                    ));
232                                }
233                            } else {
234                                return Err(anyhow::anyhow!(
235                                    "File with {:?} invalid extension",
236                                    path
237                                ));
238                            }
239                        } else {
240                            return Err(anyhow::anyhow!(
241                                "File {:?} does not have an extension",
242                                p.path()
243                            ));
244                        }
245                    } else {
246                        return Err(anyhow::anyhow!("File {:?} is not a file", p.path()));
247                    }
248                }
249            }
250
251            if path.is_file() {
252                if let Some(ext) = path.extension() {
253                    if test_extension(ext) || !check_ext {
254                        resulting_paths.push(path);
255                    } else if ext == "list" {
256                        if resulting_paths.is_empty() {
257                            found_list = Some(path);
258                            break;
259                        } else {
260                            return Err(anyhow::anyhow!("Found multiple list files in {:?}", path));
261                        }
262                    } else {
263                        return Err(anyhow::anyhow!("File with {:?} invalid extension", path));
264                    }
265                } else {
266                    return Err(anyhow::anyhow!(
267                        "File {:?} does not have an extension",
268                        path
269                    ));
270                }
271            }
272        }
273
274        if let Some(list) = found_list {
275            let reader = BufReader::new(std::fs::File::open(list)?);
276            for line in reader.lines() {
277                let as_path_buf = PathBuf::from(line?);
278                if as_path_buf.exists()
279                    && test_extension(as_path_buf.extension().ok_or_else(|| {
280                        anyhow::anyhow!("File {:?} does not have an extension", as_path_buf)
281                    })?)
282                    || !check_ext
283                {
284                    resulting_paths.push(as_path_buf);
285                }
286            }
287        }
288        Ok(resulting_paths)
289    }
290
291    pub fn write_result(result: &Vec<CompareResult>, output: PathBuf) -> Result<()> {
292        let o_file = std::fs::File::create(output)?;
293        let mut bufwriter = std::io::BufWriter::new(o_file);
294        for r in result {
295            writeln!(bufwriter, "{}", r)?;
296        }
297        Ok(())
298    }
299}
300
/// Returns `true` when `ext` is one of the accepted sequence-file extensions:
/// `fasta`, `fa`, `fastq`, `fq` or `gz`.
///
/// The comparison is exact (case-sensitive); an extension that is not valid
/// UTF-8 never matches.
pub fn test_extension(ext: &OsStr) -> bool {
    matches!(
        ext.to_str(),
        Some("fasta" | "fa" | "fastq" | "fq" | "gz")
    )
}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Sequence-file extensions are accepted; `txt` and `list` are rejected.
    #[test]
    fn test_test_extension() {
        let cases = [
            ("fasta", true),
            ("fa", true),
            ("fastq", true),
            ("fq", true),
            ("gz", true),
            ("txt", false),
            ("list", false),
        ];
        for (ext, expected) in cases {
            assert_eq!(test_extension(OsStr::new(ext)), expected);
        }
    }
}
319}