hyper_gen/
utils.rs

1use glob::glob;
2use indicatif::{ProgressBar, ProgressStyle};
3
4use log::{info, warn};
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use chrono::Local;
9use clap::{arg, value_parser, Command};
10use env_logger::{Builder, Target};
11use log::LevelFilter;
12use std::io::Write;
13
14use crate::{hd, params, types::*};
15
16pub fn create_cli() -> CliParams {
17    Builder::new()
18        .format(|buf, record| {
19            writeln!(
20                buf,
21                "{} [{}] - {}",
22                Local::now().format("%Y-%m-%d-%H:%M:%S"),
23                record.level(),
24                record.args()
25            )
26        })
27        .filter(None, LevelFilter::Info)
28        .target(Target::Stdout)
29        .init();
30
31    let cmd = Command::new("hyper-gen")
32        .bin_name("hyper-gen")
33        .subcommand_required(true)
34        .version(params::VERSION)
35        .about(
36            "HyperGen: Fast and memory-efficient genome sketching in hyperdimensional space\n\n
37        1. Genome sketching using FracMinhash and hyperdimensional computing (HDC). Three file types (.fna .fa .fasta) are supported:\n
38        hyper-gen-rust sketch -p {fna_path} -o {output_sketch_file} \n\n
39        2. ANI estimation and database search:\n
40        hyper-gen-rust dist -r {ref_sketch} -q {query_sketch} -o {output_ANI_results}",
41        )
42        .subcommand(
43            // sketch command
44            clap::command!(params::CMD_SKETCH).args(&[
45                arg!(-p --path <PATH> "Input folder path to sketch").required(true)
46                    .value_parser(value_parser!(PathBuf)),
47                arg!(-r --path_r <PATH_R> "Path to ref sketch file")
48                    .default_value("1")
49                    .value_parser(value_parser!(PathBuf)),
50                arg!(-q --path_q <PATH_Q> "Path to query sketch file")
51                    .default_value("1")
52                    .value_parser(value_parser!(PathBuf)),
53                arg!(-o --out [OUT] "Output path ").required(true).value_parser(value_parser!(PathBuf)),
54                arg!(-t --thread <THREAD> "# of threads used for computation")
55                    .default_value("16")
56                    .value_parser(value_parser!(u8)),
57                arg!(-m --sketch_method <METHOD> "Sketch method")
58                    .default_value("t1ha2")
59                    .value_parser(value_parser!(String)),
60                arg!(-C --canonical <CANONICAL> "If use canonical kmer")
61                    .default_value("true")
62                    .value_parser(value_parser!(bool)),
63                arg!(-k --ksize <KSIZE> "k-mer size for sketching")
64                    .default_value("21")
65                    .value_parser(value_parser!(u8)),
66                arg!(-S --seed <SEED> "Hash seed")
67                    .default_value("123")
68                    .value_parser(value_parser!(u64)),
69                arg!(-s --scaled <SCALED> "Scaled factor for FracMinHash")
70                    .default_value("1500")
71                    .value_parser(value_parser!(u64)),
72                arg!(-d --hv_d <HD_D> "Dimension for hypervector")
73                    .default_value("4096")
74                    .value_parser(value_parser!(usize)),
75                arg!(-Q --quant_scale <HD_D> "Scaling factor for HV quantization")
76                    .default_value("1.0")
77                    .value_parser(value_parser!(f32)),
78                arg!(-a --ani_th <ANI_TH> "ANI threshold")
79                    .default_value("85.0")
80                    .value_parser(value_parser!(f32)),
81                arg!(-D --device <DEVICE> "Device to run")
82                    .default_value("cpu")
83                    .value_parser(value_parser!(String)),
84            ]),
85        )
86        .subcommand(
87            // dist command
88            clap::command!(params::CMD_DIST).args(&[
89                arg!(-p --path <PATH> "Path to sketch file")
90                    .default_value("1")
91                    .value_parser(value_parser!(PathBuf)),
92                arg!(-r --path_r <PATH_R> "Path to ref sketch file").required(true)
93                    .value_parser(value_parser!(PathBuf)),
94                arg!(-q --path_q <PATH_Q> "Path to query sketch file").required(true)
95                    .value_parser(value_parser!(PathBuf)),
96                arg!(-o --out [OUT] "Output path ").required(true).value_parser(value_parser!(PathBuf)),
97                arg!(-t --thread <THREAD> "# of threads used for computation")
98                    .default_value("16")
99                    .value_parser(value_parser!(u8)),
100                arg!(-m --sketch_method <METHOD> "Sketch method")
101                    .default_value("fracminhash")
102                    .value_parser(value_parser!(String)),
103                arg!(-C --canonical <CANONICAL> "If use canonical kmer")
104                    .default_value("true")
105                    .value_parser(value_parser!(bool)),
106                arg!(-k --ksize <KSIZE> "k-mer size for sketching")
107                    .default_value("21")
108                    .value_parser(value_parser!(u8)),
109                arg!(-S --seed <SEED> "Hash seed")
110                    .default_value("123")
111                    .value_parser(value_parser!(u64)),
112                arg!(-s --scaled <SCALED> "Scaled factor for FracMinHash")
113                    .default_value("1500")
114                    .value_parser(value_parser!(u64)),
115                arg!(-d --hv_d <HD_D> "Dimension for hypervector")
116                    .default_value("4096")
117                    .value_parser(value_parser!(usize)),
118                arg!(-Q --quant_scale <HD_D> "Scaling factor for HV quantization")
119                    .default_value("1.0")
120                    .value_parser(value_parser!(f32)),
121                arg!(-a --ani_th <ANI_TH> "ANI threshold")
122                    .default_value("85.0")
123                    .value_parser(value_parser!(f32)),
124                arg!(-D --device <DEVICE> "Device to run")
125                    .default_value("cpu")
126                    .value_parser(value_parser!(String)),
127            ]),
128        )
129        .subcommand(
130            // search command
131            clap::command!(params::CMD_SEARCH).args(&[
132                arg!(-p --path <PATH> "Path to sketch file")
133                    .default_value("1")
134                    .value_parser(value_parser!(PathBuf)),
135                arg!(-r --path_r <PATH_R> "Path to ref sketch file")
136                    .value_parser(value_parser!(PathBuf)),
137                arg!(-q --path_q <PATH_Q> "Path to query sketch file")
138                    .value_parser(value_parser!(PathBuf)),
139                arg!(-o --out [OUT] "Output path ").value_parser(value_parser!(PathBuf)),
140                arg!(-t --thread <THREAD> "# of threads used for computation")
141                    .default_value("16")
142                    .value_parser(value_parser!(u8)),
143                arg!(-k --ksize <KSIZE> "k-mer size for sketching")
144                    .default_value("21")
145                    .value_parser(value_parser!(u8)),
146                arg!(-m --sketch_method <METHOD> "Sketch method")
147                    .default_value("fracminhash")
148                    .value_parser(value_parser!(String)),
149                arg!(-s --scaled <SCALED> "Scaled factor for FracMinHash")
150                    .default_value("1500")
151                    .value_parser(value_parser!(u64)),
152                arg!(-d --hv_d <HD_D> "Dimension for hypervector")
153                    .default_value("4096")
154                    .value_parser(value_parser!(usize)),
155                arg!(-Q --quant_scale <HD_D> "Scaling factor for HV quantization")
156                    .default_value("1.0")
157                    .value_parser(value_parser!(f32)),
158                arg!(-a --ani_th <ANI_TH> "ANI threshold")
159                    .default_value("85.0")
160                    .value_parser(value_parser!(f32)),
161            ]),
162        );
163
164    parse_cmd(cmd)
165}
166
167pub fn parse_cmd(cmd: Command) -> CliParams {
168    let matches = cmd.get_matches();
169
170    let (mode, matches) = match matches.subcommand() {
171        Some((params::CMD_SKETCH, matches)) => (params::CMD_SKETCH, matches),
172        Some((params::CMD_DIST, matches)) => (params::CMD_DIST, matches),
173        Some((params::CMD_SEARCH, matches)) => (params::CMD_SEARCH, matches),
174        _ => unreachable!("clap should ensure we don't get here"),
175    };
176
177    let cli_params = CliParams {
178        mode: mode.to_string(),
179        path: matches.get_one::<PathBuf>("path").expect("").clone(),
180        path_ref_sketch: matches.get_one::<PathBuf>("path_r").expect("").clone(),
181        path_query_sketch: matches.get_one::<PathBuf>("path_q").expect("").clone(),
182        out_file: {
183            if matches.contains_id("out") {
184                matches.get_one::<PathBuf>("out").expect("").clone()
185            } else {
186                PathBuf::new()
187            }
188        },
189        ksize: matches.get_one::<u8>("ksize").expect("").clone(),
190        sketch_method: matches
191            .get_one::<String>("sketch_method")
192            .expect("")
193            .clone(),
194        canonical: matches.get_one::<bool>("canonical").expect("").clone(),
195        seed: matches.get_one::<u64>("seed").expect("").clone(),
196        scaled: matches.get_one::<u64>("scaled").expect("").clone(),
197        hv_d: matches.get_one::<usize>("hv_d").expect("").clone(),
198        hv_quant_scale: matches.get_one::<f32>("quant_scale").expect("").clone(),
199        ani_threshold: matches.get_one::<f32>("ani_th").expect("").clone(),
200        if_compressed: true, // TODO
201        threads: matches.get_one::<u8>("thread").expect("").clone(),
202        device: matches.get_one::<String>("device").expect("").clone(),
203    };
204
205    cli_params
206}
207
208pub fn get_fasta_files(path: &PathBuf) -> Vec<PathBuf> {
209    // pub fn get_fasta_files(path: PathBuf) -> Vec<Result<PathBuf, GlobError>> {
210    let mut all_files = Vec::new();
211    for t in ["*.fna", "*.fa", "*.fasta"] {
212        let mut files: Vec<_> = glob(path.join(t).to_str().unwrap())
213            .expect("Failed to read glob pattern")
214            .map(|f| f.unwrap())
215            .collect();
216
217        all_files.append(&mut files);
218    }
219
220    all_files
221}
222
223pub fn get_progress_bar(n_file: usize) -> ProgressBar {
224    let pb = ProgressBar::new(n_file as u64);
225    pb.set_style(
226        ProgressStyle::default_bar()
227            .template("{wide_bar} {pos}/{len} ({percent}%) - Elapsed: {elapsed_precise}, ETA: {eta_precise}")
228            .unwrap()
229    );
230
231    pb
232}
233
234pub fn dump_sketch(file_sketch: &Vec<FileSketch>, out_file_path: &PathBuf) {
235    let out_filename = out_file_path.to_str().unwrap();
236
237    // Serialization
238    let serialized = bincode::serialize::<Vec<FileSketch>>(&file_sketch).unwrap();
239    // let serialized = bitcode::encode(file_sketch);
240
241    // Dump sketch file
242    fs::write(out_filename, &serialized).expect("Dump sketch file failed!");
243
244    let sketch_size_mb = serialized.len() as f32 / 1024.0 / 1024.0;
245    info!(
246        "Dump sketch file to {} with size {:.2} MB",
247        out_filename, sketch_size_mb
248    );
249}
250
251pub fn load_sketch(path: &Path) -> Vec<FileSketch> {
252    info!("Loading sketch from {}", path.to_str().unwrap());
253    let serialized = fs::read(path).expect("Opening sketch file failed!");
254    let file_sketch = bincode::deserialize::<Vec<FileSketch>>(&serialized[..]).unwrap();
255    // let file_sketch = bitcode::decode(&serialized[..]).unwrap();
256
257    file_sketch
258}
259
260pub fn dump_ani_file(sketch_dist: &SketchDist) {
261    // Sort based on ANIs
262    let mut indices = (0..sketch_dist.file_ani.len()).collect::<Vec<_>>();
263    indices.sort_by(|&i1, &i2| {
264        sketch_dist.file_ani[i1]
265            .1
266            .partial_cmp(&sketch_dist.file_ani[i2].1)
267            .unwrap()
268    });
269    indices.reverse();
270
271    // Dump in order
272    let mut csv_str = String::new();
273    let mut cnt: f32 = 0.0;
274    for i in 0..sketch_dist.file_ani.len() {
275        if sketch_dist.file_ani[indices[i]].1 >= sketch_dist.ani_threshold {
276            csv_str.push_str(&format!(
277                "{}\t{}\t{:.3}\n",
278                sketch_dist.file_ani[indices[i]].0 .0,
279                sketch_dist.file_ani[indices[i]].0 .1,
280                sketch_dist.file_ani[indices[i]].1
281            ));
282            cnt += 1.0;
283        } else {
284            break;
285        }
286    }
287
288    fs::write(sketch_dist.out_file.to_str().unwrap(), &csv_str.as_bytes())
289        .expect("Dump ANI file failed!");
290
291    // Warning if output ANIs are too sparse
292    let total_dist = sketch_dist.file_ani.len() as f32;
293    let perc = cnt / total_dist * 100.0;
294    if perc < 5.0 {
295        warn!(
296            "Output ANIs with threshold {:.1} are too divergent: {} of {} ({:.2}%) ANIs are reported",
297            sketch_dist.ani_threshold, cnt, total_dist, perc
298        );
299    } else {
300        info!(
301            "Output {} of {} ANIs above threshold {:.1} to file {}",
302            cnt,
303            total_dist,
304            sketch_dist.ani_threshold,
305            sketch_dist.out_file.to_str().unwrap()
306        )
307    }
308}
309
310use std::collections::HashMap;
311
312pub fn dump_distribution_to_txt(path: &Path) {
313    let mut file_sketch = load_sketch(path);
314
315    hd::decompress_file_sketch(&mut file_sketch);
316
317    // Write to files
318    let data: Vec<Vec<i16>> = (0..file_sketch.len())
319        .map(|i| file_sketch[i].hv.clone())
320        .collect();
321
322    // Create a histogram
323    let mut hist: HashMap<i16, u32> = HashMap::new();
324    for i in 0..data.len() {
325        for j in &data[i] {
326            if hist.get(j) == None {
327                hist.insert(*j, 1);
328            } else if let Some(c) = hist.get_mut(&j) {
329                *c += 1;
330            }
331        }
332    }
333
334    for kv in hist {
335        println!("{}\t{}", kv.0, kv.1);
336    }
337}