1use crate::cli::Commands;
2use crate::cli::HashAlgorithms;
3use crate::cli::OutputFormats;
4use crate::compare::CompareResult;
5use crate::hash_functions::Function;
6use crate::signature::Signature;
7use crate::sketch::Sketch;
8use crate::sketcher;
9use anyhow::anyhow;
10use anyhow::Result;
11use needletail::parse_fastx_file;
12use rayon::prelude::IntoParallelRefIterator;
13use rayon::prelude::ParallelIterator;
14use sourmash::signature::Signature as SourmashSignature;
15use std::io;
16use std::io::Write;
17use std::sync::mpsc;
18use std::sync::mpsc::Receiver;
19use std::thread;
20use std::{
21 ffi::OsStr,
22 fs::{self, File},
23 io::{BufRead, BufReader},
24 path::PathBuf,
25};
26
/// Data-less namespace type for the file-handling routines in this module.
///
/// The associated functions defined in this file take no `self`; the type
/// exists only to group them under a common name.
pub struct FileHandler;
28
29impl FileHandler {
30 pub fn sketch_files(command: Commands, threads: Option<usize>) -> Result<()> {
31 match command.to_owned() {
32 Commands::Sketch {
33 input,
34 output,
35 kmer_size,
36 fscale,
37 kscale,
38 nmin,
39 nmax,
40 algorithm,
41 format,
42 singleton,
43 } => {
44 let files = FileHandler::test_and_collect_files(input, true)?;
45 let pool = rayon::ThreadPoolBuilder::new()
46 .num_threads(threads.unwrap_or_default())
47 .build()?;
48
49 let function = Function::from_alg(algorithm.clone(), kmer_size);
50
51 let (send, recv) = mpsc::channel();
52
53 let is_stdout = output.is_none();
54 let handler = thread::spawn(|| {
55 FileHandler::write_output(output, format, recv)
56 });
58
59 let _ = pool.install(|| {
60 files.par_iter().try_for_each(|file_path| {
61 match FileHandler::sketch_file(
62 file_path,
63 kmer_size,
64 fscale,
65 kscale,
66 nmin,
67 nmax,
68 singleton,
69 false,
70 function.clone(),
71 algorithm.clone(),
72 is_stdout,
73 ) {
74 Ok(sig) => send.send(sig).map_err(|_| anyhow!("Error while sending")),
75 Err(_) => Err(anyhow!("Error while sketching file {:?}", file_path)),
76 }
77 })
78 });
79
80 drop(send);
81
82 Ok(handler
83 .join()
84 .map_err(|_| anyhow!("Unable to join threads"))??)
85 }
86 _ => Err(anyhow!("Wrong command")),
87 }
88 }
89
90 pub fn sketch_file(
91 input: &PathBuf,
92 kmer_length: u8,
93 fscale: Option<u64>,
94 kscale: Option<u64>,
95 nmin: Option<u64>,
96 nmax: Option<u64>,
97 singleton: bool,
98 stats: bool,
99 function: Function,
100 algorithm: HashAlgorithms,
101 stdout: bool,
102 ) -> Result<Signature> {
103 let mut x = fs::metadata(input)?.len();
104 if let Some(ext) = input.extension() {
105 if let Some(ext_str) = ext.to_str() {
106 if ext_str == "gz" {
107 x *= 3;
109 }
110 }
111 }
112 let start = std::time::Instant::now();
113 let kscale = if let Some(kscale) = kscale {
114 (x as f64 / kscale as f64) as u64
115 } else {
116 u64::MAX
117 };
118 let max_hash = if let Some(fscale) = fscale {
119 (u64::MAX as f64 / fscale as f64) as u64
120 } else {
121 u64::MAX
122 };
123 let mut sketcher = sketcher::Sketcher::new(
124 kmer_length,
125 input
126 .to_str()
127 .ok_or_else(|| anyhow!("Unknown path"))?
128 .to_string(),
129 singleton,
130 stats,
131 kscale,
132 max_hash,
133 nmin,
134 nmax,
135 function,
136 algorithm,
137 );
138 let mut reader = parse_fastx_file(input)?;
139 let mut counter = 0;
140 while let Some(record) = reader.next() {
141 sketcher.process(&record?);
142 counter += 1;
143 }
144 let elapsed = start.elapsed().as_millis();
145 if !stdout {
146 println!(
147 "Processed {:?} with {} records, in {:?} seconds",
148 input,
149 counter,
150 elapsed as f64 / 1000.0,
151 );
152 }
153 Ok(sketcher.finish())
154 }
155
156 pub fn write_output(
157 output: Option<PathBuf>,
158 output_format: OutputFormats,
159 signature_recv: Receiver<Signature>,
160 ) -> Result<()> {
161 let stdout = output.is_none();
162 let mut output: Box<dyn Write> = match output {
163 Some(o) => Box::new(std::io::BufWriter::new(File::create(o)?)),
164 None => Box::new(std::io::BufWriter::new(io::stdout())),
165 };
166
167 match output_format {
168 OutputFormats::Bin => {
169 while let Ok(sig) = signature_recv.recv() {
170 let name = sig.file_name.clone();
171 let len = sig.sketches.first().unwrap().hashes.len();
172 bincode::serialize_into(&mut output, &vec![sig])?;
173 if !stdout {
174 println!("Wrote signature: {:?} with {:?} hashes.", name, len);
175 }
176 }
177 }
178 OutputFormats::Sourmash => {
179 while let Ok(sig) = signature_recv.recv() {
180 let sourmash_sig: SourmashSignature = sig.into();
181 serde_json::to_writer(&mut output, &vec![sourmash_sig])?;
182 }
183 }
184 }
185
186 Ok(())
187 }
188
189 pub fn read_signatures(input: &PathBuf) -> Result<Vec<Signature>> {
190 let read_to_bytes = std::fs::read(input)?;
191 Ok(bincode::deserialize_from(read_to_bytes.as_slice()).unwrap())
192 }
193
194 pub fn concat(inputs: Vec<PathBuf>, output: PathBuf) -> Result<()> {
195 let o_file = std::fs::File::create(output)?;
196 let mut bufwriter = std::io::BufWriter::new(o_file);
197
198 for input in inputs {
199 let mut reader = BufReader::new(std::fs::File::open(input)?);
200 while let Ok(result) =
201 bincode::deserialize_from::<&mut BufReader<File>, Sketch>(&mut reader)
202 {
203 bincode::serialize_into(&mut bufwriter, &result)?;
204 }
205 }
206 Ok(())
207 }
208
209 pub fn test_and_collect_files(input: Vec<PathBuf>, check_ext: bool) -> Result<Vec<PathBuf>> {
210 let mut resulting_paths = Vec::new();
211 let mut found_list: Option<PathBuf> = None;
212 for path in input {
213 if !path.exists() {
214 return Err(anyhow::anyhow!("File {:?} does not exist", path));
215 }
216 if path.is_dir() {
217 for p in path.read_dir()? {
218 let p = p?;
219 if p.path().is_file() {
220 if let Some(ext) = p.path().extension() {
221 if test_extension(ext) {
222 resulting_paths.push(p.path());
223 } else if ext == "list" {
224 if resulting_paths.is_empty() {
225 found_list = Some(p.path());
226 break;
227 } else {
228 return Err(anyhow::anyhow!(
229 "Found multiple list files in {:?}",
230 path
231 ));
232 }
233 } else {
234 return Err(anyhow::anyhow!(
235 "File with {:?} invalid extension",
236 path
237 ));
238 }
239 } else {
240 return Err(anyhow::anyhow!(
241 "File {:?} does not have an extension",
242 p.path()
243 ));
244 }
245 } else {
246 return Err(anyhow::anyhow!("File {:?} is not a file", p.path()));
247 }
248 }
249 }
250
251 if path.is_file() {
252 if let Some(ext) = path.extension() {
253 if test_extension(ext) || !check_ext {
254 resulting_paths.push(path);
255 } else if ext == "list" {
256 if resulting_paths.is_empty() {
257 found_list = Some(path);
258 break;
259 } else {
260 return Err(anyhow::anyhow!("Found multiple list files in {:?}", path));
261 }
262 } else {
263 return Err(anyhow::anyhow!("File with {:?} invalid extension", path));
264 }
265 } else {
266 return Err(anyhow::anyhow!(
267 "File {:?} does not have an extension",
268 path
269 ));
270 }
271 }
272 }
273
274 if let Some(list) = found_list {
275 let reader = BufReader::new(std::fs::File::open(list)?);
276 for line in reader.lines() {
277 let as_path_buf = PathBuf::from(line?);
278 if as_path_buf.exists()
279 && test_extension(as_path_buf.extension().ok_or_else(|| {
280 anyhow::anyhow!("File {:?} does not have an extension", as_path_buf)
281 })?)
282 || !check_ext
283 {
284 resulting_paths.push(as_path_buf);
285 }
286 }
287 }
288 Ok(resulting_paths)
289 }
290
291 pub fn write_result(result: &Vec<CompareResult>, output: PathBuf) -> Result<()> {
292 let o_file = std::fs::File::create(output)?;
293 let mut bufwriter = std::io::BufWriter::new(o_file);
294 for r in result {
295 writeln!(bufwriter, "{}", r)?;
296 }
297 Ok(())
298 }
299}
300
/// Reports whether `ext` is one of the extensions this module accepts as a
/// sequence file.
///
/// The comparison is exact and case-sensitive against `fasta`, `fa`, `fastq`,
/// `fq` and `gz`.
pub fn test_extension(ext: &OsStr) -> bool {
    const ACCEPTED: [&str; 5] = ["fasta", "fa", "fastq", "fq", "gz"];
    ACCEPTED.iter().any(|known| ext == *known)
}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// `test_extension` accepts exactly the known sequence-file extensions.
    #[test]
    fn test_test_extension() {
        let accepted = ["fasta", "fa", "fastq", "fq", "gz"];
        let rejected = ["txt", "list"];

        for ext in accepted.iter() {
            assert!(test_extension(OsStr::new(ext)));
        }
        for ext in rejected.iter() {
            assert!(!test_extension(OsStr::new(ext)));
        }
    }
}