use clap::Parser;
use kun_peng::compact_hash::{read_next_page, Compact, HashConfig, Page, Row, Slot};
use kun_peng::utils::{find_and_sort_files, open_file};
use seqkmer::buffer_read_parallel;
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{self, BufReader, BufWriter, Read, Result, Write};
use std::path::Path;
use std::path::PathBuf;
use std::time::Instant;
pub const BUFFER_SIZE: usize = 48 * 1024 * 1024;
#[derive(Parser, Debug, Clone)]
#[clap(
version,
about = "annotate a set of sequences",
long_about = "annotate a set of sequences"
)]
pub struct Args {
#[arg(long = "db", required = true)]
pub database: PathBuf,
#[clap(long)]
pub chunk_dir: PathBuf,
#[clap(long, default_value_t = BUFFER_SIZE)]
pub buffer_size: usize,
#[clap(long, value_parser = clap::value_parser!(u32).range(1..=32), default_value_t = 4)]
pub batch_size: u32,
#[clap(short = 'p', long = "num-threads", value_parser, default_value_t = num_cpus::get())]
pub num_threads: usize,
}
fn read_chunk_header<R: Read>(reader: &mut R) -> io::Result<(usize, usize)> {
let mut buffer = [0u8; 16];
reader.read_exact(&mut buffer)?;
let index = u64::from_le_bytes(
buffer[0..8]
.try_into()
.expect("Failed to convert bytes to u64 for index"),
);
let chunk_size = u64::from_le_bytes(
buffer[8..16]
.try_into()
.expect("Failed to convert bytes to u64 for chunk size"),
);
Ok((index as usize, chunk_size as usize))
}
fn _write_to_file(
file_index: u64,
bytes: &[u8],
last_file_index: &mut Option<u64>,
writer: &mut Option<BufWriter<File>>,
chunk_dir: &PathBuf,
) -> io::Result<()> {
if last_file_index.is_none() || last_file_index.unwrap() != file_index {
if let Some(mut w) = writer.take() {
w.flush()?;
}
let file_name = format!("sample_file_{}.bin", file_index);
let file_path = chunk_dir.join(file_name);
let file = OpenOptions::new()
.create(true)
.append(true)
.open(&file_path)?;
*writer = Some(BufWriter::new(file));
*last_file_index = Some(file_index);
}
if let Some(w) = writer.as_mut() {
w.write_all(bytes)?;
}
Ok(())
}
fn write_to_file(
file_index: u64,
seq_id_mod: u32,
bytes: &[u8],
writers: &mut HashMap<(u64, u32), BufWriter<File>>,
chunk_dir: &PathBuf,
) -> io::Result<()> {
let writer = writers.entry((file_index, seq_id_mod)).or_insert_with(|| {
let file_name = format!("sample_file_{}_{}.bin", file_index, seq_id_mod);
let file_path = chunk_dir.join(file_name);
let file = OpenOptions::new()
.create(true)
.append(true)
.open(&file_path)
.expect("failed to open file");
BufWriter::new(file)
});
writer.write_all(bytes)?;
Ok(())
}
fn clean_up_writers(
writers: &mut HashMap<(u64, u32), BufWriter<File>>,
current_file_index: u64,
) -> io::Result<()> {
let keys_to_remove: Vec<(u64, u32)> = writers
.keys()
.cloned()
.filter(|(idx, _)| *idx != current_file_index)
.collect();
for key in keys_to_remove {
if let Some(mut writer) = writers.remove(&key) {
writer.flush()?; }
}
Ok(())
}
fn process_batch<R>(
reader: &mut R,
hash_config: &HashConfig,
page: &Page,
chunk_dir: PathBuf,
buffer_size: usize,
bin_threads: u32,
num_threads: usize,
) -> std::io::Result<()>
where
R: Read + Send,
{
let row_size = std::mem::size_of::<Row>();
let mut writers: HashMap<(u64, u32), BufWriter<File>> = HashMap::new();
let mut current_file_index: Option<u64> = None;
let value_mask = hash_config.get_value_mask();
let value_bits = hash_config.get_value_bits();
let idx_mask = hash_config.get_idx_mask();
let idx_bits = hash_config.get_idx_bits();
buffer_read_parallel(
reader,
num_threads,
buffer_size,
|dataset: Vec<Slot<u64>>| {
let mut results: HashMap<(u64, u32), Vec<u8>> = HashMap::new();
for slot in dataset {
let indx = slot.idx & idx_mask;
let compacted = slot.value.left(value_bits) as u32;
let taxid = page.find_index(indx, compacted, value_bits, value_mask);
if taxid > 0 {
let kmer_id = slot.idx >> idx_bits;
let file_index = slot.value.right(value_mask) >> 32;
let seq_id = slot.get_seq_id() as u32;
let left = slot.value.left(value_bits) as u32;
let high = u32::combined(left, taxid, value_bits);
let row = Row::new(high, seq_id, kmer_id as u32);
let value_bytes = row.as_slice(row_size);
let seq_id_mod = seq_id % bin_threads;
results
.entry((file_index, seq_id_mod))
.or_insert_with(Vec::new)
.extend(value_bytes);
}
}
results
},
|result| {
while let Some(data) = result.next() {
let res = data.unwrap();
let mut file_keys: Vec<_> = res.keys().cloned().collect();
file_keys.sort_unstable();
for (file_index, seq_id_mod) in file_keys {
if let Some(bytes) = res.get(&(file_index, seq_id_mod)) {
if current_file_index != Some(file_index) {
clean_up_writers(&mut writers, file_index).expect("clean writer");
current_file_index = Some(file_index);
}
write_to_file(file_index, seq_id_mod, bytes, &mut writers, &chunk_dir)
.expect("write to file error");
}
}
}
},
)
.expect("failed");
for writer in writers.values_mut() {
writer.flush()?;
}
Ok(())
}
fn process_chunk_file<P: AsRef<Path>>(
args: &Args,
chunk_file: P,
hash_files: &Vec<PathBuf>,
large_page: &mut Page,
) -> Result<()> {
let file = open_file(chunk_file)?;
let mut reader = BufReader::new(file);
let (page_index, _) = read_chunk_header(&mut reader)?;
let start = Instant::now();
println!("start load table...");
let config = HashConfig::from_hash_header(&args.database.join("hash_config.k2d"))?;
read_next_page(large_page, hash_files, page_index, config)?;
let duration = start.elapsed();
println!("load table took: {:?}", duration);
process_batch(
&mut reader,
&config,
&large_page,
args.chunk_dir.clone(),
args.buffer_size,
args.batch_size,
args.num_threads,
)?;
Ok(())
}
pub fn run(args: Args) -> Result<()> {
let chunk_files = find_and_sort_files(&args.chunk_dir, "sample", ".k2", true)?;
let hash_files = find_and_sort_files(&args.database, "hash", ".k2d", true)?;
let start = Instant::now();
println!("annotate start...");
let config = HashConfig::from_hash_header(&args.database.join("hash_config.k2d"))?;
let mut large_page = Page::with_capacity(0, config.hash_capacity);
for chunk_file in &chunk_files {
process_chunk_file(&args, chunk_file, &hash_files, &mut large_page)?;
let _ = std::fs::remove_file(chunk_file);
}
let duration = start.elapsed();
println!("annotate took: {:?}", duration);
Ok(())
}
#[allow(dead_code)]
fn main() {
let args = Args::parse();
if let Err(e) = run(args) {
eprintln!("Application error: {}", e);
}
}