use clap::{error::ErrorKind, Error, Parser};
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use kun_peng::args::KLMTArgs;
use kun_peng::utils::{find_files, format_bytes, open_file};
use kun_peng::KBuildHasher;
use seqkmer::{read_parallel, BufferFastaReader};
use serde_json;
use std::collections::HashSet;
use std::fs::File;
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
#[derive(Parser, Debug, Clone)]
#[clap(
version,
about = "estimate capacity",
long_about = "Estimates the size of the Kraken 2 hash table."
)]
pub struct Args {
#[arg(long, default_value = "lib")]
pub database: PathBuf,
#[clap(flatten)]
pub klmt: KLMTArgs,
#[arg(long, default_value_t = true)]
pub cache: bool,
#[clap(short, long, default_value = "4")]
pub n: usize,
#[clap(long, long, default_value_t = 0.7)]
pub load_factor: f64,
#[clap(short = 'p', long, default_value_t = 10)]
pub threads: usize,
}
/// Number of equal buckets the low bits of each minimizer hash are divided into; only
/// buckets `0..n` are sampled and the resulting count is scaled by `RANGE_SECTIONS / n`.
const RANGE_SECTIONS: u64 = 1 << 10;
/// Mask that extracts a hash's bucket index (valid because `RANGE_SECTIONS` is a power of two).
const RANGE_MASK: u64 = RANGE_SECTIONS - 1;
/// Builds the path of a sidecar file placed next to `input_path`.
///
/// The input's final extension is dropped and `.{extension}` is appended to the
/// remaining stem, e.g. `lib/library.fna` + `hllp_4.json` -> `lib/library.hllp_4.json`.
/// Dots already inside the stem are preserved, so `library.1.fna` and `library.2.fna`
/// get distinct sidecars instead of both collapsing onto `library.hllp_4.json`.
///
/// Non-UTF-8 path components are converted lossily instead of panicking.
fn build_output_path<P: AsRef<Path>>(input_path: &P, extension: &str) -> String {
    let path = input_path.as_ref();
    let parent_dir = path.parent().unwrap_or_else(|| Path::new(""));
    let stem = path.file_stem().unwrap_or_else(|| path.as_os_str());
    // Append instead of `set_extension`, which would overwrite the stem's own last
    // dotted segment (e.g. the `.1` in `library.1`).
    let mut file_name = stem.to_os_string();
    file_name.push(".");
    file_name.push(extension);
    parent_dir.join(file_name).to_string_lossy().into_owned()
}
/// Builds (or loads from cache) a HyperLogLog++ sketch of the sampled minimizers of
/// one FASTA file.
///
/// Only minimizers whose hash falls in buckets `0..args.n` (out of `RANGE_SECTIONS`)
/// are inserted, so the sketch covers roughly `n / RANGE_SECTIONS` of the distinct
/// minimizers; the caller scales the count back up.
///
/// When `args.cache` is set and the sidecar `*.hllp_<n>.json` next to `fna_file`
/// exists, that sketch is returned instead. An unreadable or corrupt cache is reported
/// on stderr and the sketch is rebuilt. A freshly built sketch is always written back
/// to the sidecar path (best effort; failures are only reported).
///
/// NOTE(review): the cache name encodes only `n`, not the `klmt` (k/l/minimizer)
/// settings, so a cache built with different settings is reused as-is — confirm intended.
///
/// # Panics
/// If the FASTA file cannot be opened or parallel reading fails.
fn process_sequence<P: AsRef<Path>>(
    fna_file: &P,
    args: Args,
) -> HyperLogLogPlus<u64, KBuildHasher> {
    let json_path = build_output_path(fna_file, &format!("hllp_{}.json", args.n));
    if args.cache && Path::new(&json_path).exists() {
        if let Some(hllp) = load_cached_hllp(&json_path) {
            return hllp;
        }
    }

    let meros = args.klmt.as_meros();
    let mut hllp: HyperLogLogPlus<u64, KBuildHasher> =
        HyperLogLogPlus::new(16, KBuildHasher::default()).unwrap();
    let mut reader = BufferFastaReader::from_path(fna_file, 1)
        .expect("Failed to open the FASTA file with FastaReader");
    let range_n = args.n as u64;
    read_parallel(
        &mut reader,
        args.threads,
        &meros,
        |record_set| {
            // Dedup within this batch so the consumer inserts each hash once per batch.
            let mut minimizer_set: HashSet<u64> = HashSet::new();
            for record in record_set {
                record.body.apply_mut(|m_iter| {
                    // Keep only hashes whose low bits land in the first `n` buckets.
                    minimizer_set.extend(
                        m_iter
                            .filter(|(_, hash_key)| (*hash_key & RANGE_MASK) < range_n)
                            .map(|(_, hash_key)| hash_key),
                    );
                });
            }
            minimizer_set
        },
        |record_sets| {
            while let Some(data) = record_sets.next() {
                for minimizer in data.unwrap() {
                    hllp.insert(&minimizer);
                }
            }
        },
    )
    .expect("read parallel error");

    save_hllp(&json_path, &hllp);
    hllp
}

/// Reads a sketch previously written by `save_hllp`.
///
/// Returns `None` (after a warning on stderr) if the file cannot be opened, read, or
/// parsed, so the caller can rebuild the sketch instead of aborting.
fn load_cached_hllp(json_path: &str) -> Option<HyperLogLogPlus<u64, KBuildHasher>> {
    let mut file = match open_file(json_path) {
        Ok(file) => file,
        Err(e) => {
            eprintln!("Ignoring unreadable cache {}: {:?}", json_path, e);
            return None;
        }
    };
    let mut serialized_hllp = String::new();
    if let Err(e) = file.read_to_string(&mut serialized_hllp) {
        eprintln!("Ignoring unreadable cache {}: {}", json_path, e);
        return None;
    }
    match serde_json::from_str(&serialized_hllp) {
        Ok(hllp) => Some(hllp),
        Err(e) => {
            eprintln!("Ignoring corrupt cache {}: {}", json_path, e);
            None
        }
    }
}

/// Writes `hllp` to `json_path` as JSON so a later run can reuse it.
///
/// Best effort: failures are reported on stderr and otherwise ignored.
fn save_hllp(json_path: &str, hllp: &HyperLogLogPlus<u64, KBuildHasher>) {
    let serialized_hllp = serde_json::to_string(hllp).unwrap();
    match File::create(json_path) {
        Ok(mut file) => {
            if let Err(e) = file.write_all(serialized_hllp.as_bytes()) {
                eprintln!("Failed to write to file: {}", e);
            }
        }
        Err(e) => eprintln!("Failed to create file: {}: {}", json_path, e),
    }
}
pub fn run(args: Args) -> usize {
let meros = args.klmt.as_meros();
if meros.k_mer < meros.l_mer {
let err = Error::raw(ErrorKind::ValueValidation, "k cannot be less than l");
err.exit();
}
let mut hllp: HyperLogLogPlus<u64, KBuildHasher> =
HyperLogLogPlus::new(16, KBuildHasher::default()).unwrap();
let source: PathBuf = args.database.clone();
let fna_files = if source.is_file() {
vec![source.clone()]
} else {
let library_dir = &args.database.join("library");
find_files(library_dir, "library", ".fna")
};
if fna_files.is_empty() {
panic!("Error: No library.fna files found in the specified directory. Please ensure that the directory contains at least one library.fna file and try again.");
}
println!("estimate start... ");
for fna_file in fna_files {
let args_clone = Args {
database: source.clone(),
..args
};
let local_hllp = process_sequence(&fna_file, args_clone);
if let Err(e) = hllp.merge(&local_hllp) {
println!("hllp merge err {:?}", e);
}
}
let hllp_count = (hllp.count() * RANGE_SECTIONS as f64 / args.n as f64).round() as u64;
let required_capacity = (hllp_count + 8192) as f64 / args.load_factor;
println!(
"estimate count: {:?}, required capacity: {:?}, Estimated hash table requirement: {:}",
hllp_count,
required_capacity.ceil(),
format_bytes(required_capacity * 4f64)
);
required_capacity.ceil() as usize
}
// Stand-alone entry point: parse the CLI and run the estimate, discarding the returned
// capacity. The `dead_code` allowance is presumably there because this file is also
// compiled as a module of another binary that calls `run` directly — confirm.
#[allow(dead_code)]
fn main() {
    let _ = run(Args::parse());
}