use core::mem::swap;
use core::panic;
use std::mem;
use hashbrown::HashMap;
use indicatif::ProgressIterator;
use super::QualOpts;
use crate::io_utils::any_fastq;
use crate::ska_dict::bit_encoding::UInt;
use crate::ska_dict::SkaDict;
/// One input file entry: `(sample name, first file path, optional second file path)`
/// — the second path is used for paired input (see the destructuring in
/// `multi_append`/`build_and_merge`).
pub type InputFastx = (String, String, Option<String>);
/// Split k-mers merged across multiple samples.
///
/// Maps each split k-mer to a `Vec<u8>` of bases with one slot per sample
/// (0 where the k-mer was not seen in that sample).
pub struct MergeSkaDict<IntT> {
    /// K-mer size
    k: usize,
    /// Whether reverse-complement (both strands) was used
    rc: bool,
    /// Number of samples (length of every base vector)
    n_samples: usize,
    /// Sample names, indexed by sample idx; "" until a sample is appended
    names: Vec<String>,
    /// Split k-mer -> per-sample base vector
    split_kmers: HashMap<IntT, Vec<u8>>,
}
impl<IntT> MergeSkaDict<IntT>
where
IntT: for<'a> UInt<'a>,
{
/// Create an empty merged dictionary sized for `n_samples` samples.
///
/// Sample names start empty and the k-mer map starts with no entries;
/// fill via [`Self::append`] or [`Self::merge`].
pub fn new(k: usize, n_samples: usize, rc: bool) -> Self {
    let names = vec![String::new(); n_samples];
    Self {
        k,
        rc,
        n_samples,
        names,
        split_kmers: HashMap::default(),
    }
}
/// Populate this dictionary by exchanging contents with pre-built parts.
///
/// Uses `swap` rather than cloning, so the caller's `names` and
/// `split_kmers` receive this dictionary's previous contents.
pub fn build_from_array<'a>(
    &'a mut self,
    names: &'a mut Vec<String>,
    split_kmers: &mut HashMap<IntT, Vec<u8>>,
) {
    swap(&mut self.names, names);
    swap(&mut self.split_kmers, split_kmers);
}
/// Add a single sample's [`SkaDict`] into the merged dictionary at its
/// sample index.
///
/// # Panics
/// If the k-mer length or strand setting of `other` does not match, or
/// if `other`'s index is outside `0..n_samples`.
pub fn append(&mut self, other: &SkaDict<IntT>) {
    assert!(
        other.kmer_len() == self.k,
        "K-mer lengths do not match: {} {}",
        other.kmer_len(),
        self.k
    );
    assert!(other.rc() == self.rc, "Strand use inconsistent");
    let sample_idx = other.idx();
    self.names[sample_idx] = other.name().clone();
    if self.ksize() == 0 {
        // Fast path: dictionary is empty, so insert every k-mer directly.
        for (kmer, base) in other.kmers() {
            let mut bases = vec![0u8; self.n_samples];
            bases[sample_idx] = *base;
            self.split_kmers.insert(*kmer, bases);
        }
    } else {
        // General path: set this sample's slot, creating the base vector
        // only for k-mers not yet present.
        for (kmer, base) in other.kmers() {
            self.split_kmers
                .entry(*kmer)
                .and_modify(|bases| bases[sample_idx] = *base)
                .or_insert_with(|| {
                    let mut bases = vec![0u8; self.n_samples];
                    bases[sample_idx] = *base;
                    bases
                });
        }
    }
}
/// Combine another [`MergeSkaDict`] over the same sample set into `self`.
///
/// NOTE: `other` is left drained — the fast path swaps its contents into
/// `self`, and the general path `mem::take`s base vectors out of it.
///
/// # Panics
/// If the k-mer length or strand setting of the two dictionaries differ.
pub fn merge<'a>(&'a mut self, other: &'a mut MergeSkaDict<IntT>) {
    if other.k != self.k {
        panic!("K-mer lengths do not match: {} {}", other.k, self.k);
    }
    if other.rc() != self.rc {
        panic!("Strand use inconsistent");
    }
    if other.ksize() > 0 {
        if self.ksize() == 0 {
            // Fast path: self is empty, so steal other's contents wholesale.
            swap(&mut other.names, &mut self.names);
            swap(&mut other.split_kmers, &mut self.split_kmers);
        } else {
            // Fill any sample names still unset in self.
            for name_it in other.names.iter_mut().zip(self.names.iter_mut()) {
                let (other_name, self_name) = name_it;
                if self_name.is_empty() {
                    swap(self_name, other_name);
                }
            }
            for (kmer, other_vec) in &mut other.split_kmers {
                self.split_kmers
                    .entry(*kmer)
                    // K-mer in both: bitwise-OR the per-sample bases together.
                    .and_modify(|self_vec| {
                        for base_it in other_vec.iter().zip(self_vec.iter_mut()) {
                            *base_it.1 |= *base_it.0;
                        }
                    })
                    // K-mer only in other: move its vector across without
                    // cloning (leaves an empty Vec behind in other).
                    .or_insert_with(|| mem::take(other_vec));
            }
        }
    }
}
/// Concatenate another dictionary as *additional* samples (contrast with
/// [`Self::merge`], which overlays by shared sample index).
///
/// `other`'s samples are appended after `self`'s; k-mers missing on one
/// side are zero-padded for the absent samples so every base vector ends
/// up `self.n_samples + other.n_samples` long.
///
/// # Panics
/// If the k-mer length or strand setting of the two dictionaries differ.
pub fn extend<'a>(&'a mut self, other: &'a mut MergeSkaDict<IntT>) {
    if other.k != self.k {
        panic!("K-mer lengths do not match: {} {}", other.k, self.k);
    }
    if other.rc() != self.rc {
        panic!("Strand use inconsistent");
    }
    self.names.extend_from_slice(&other.names);
    let total_samples = self.n_samples + other.nsamples();
    for (kmer, other_vec) in &mut other.split_kmers {
        self.split_kmers
            .entry(*kmer)
            // K-mer in both: append other's bases after self's.
            .and_modify(|self_vec| {
                self_vec.extend_from_slice(other_vec);
            })
            // K-mer only in other: zero-pad self's samples first.
            .or_insert_with(|| {
                let mut empty_samples = vec![0; self.n_samples];
                empty_samples.extend_from_slice(other_vec);
                empty_samples
            });
    }
    // K-mers only present in self still have the old length; pad them out
    // for other's samples. `resize` avoids the temporary `vec![0; n]`
    // allocation per k-mer that `extend` would have needed.
    for self_vec in self.split_kmers.values_mut() {
        if self_vec.len() != total_samples {
            self_vec.resize(total_samples, 0);
        }
    }
    self.n_samples = total_samples;
}
/// K-mer length used to build the dictionary.
pub fn kmer_len(&self) -> usize {
    self.k
}
/// Whether reverse-complement (both strands) was used.
pub fn rc(&self) -> bool {
    self.rc
}
/// Sample names, ordered by sample index.
pub fn names(&self) -> &Vec<String> {
    &self.names
}
/// The merged map from split k-mer to per-sample bases.
pub fn kmer_dict(&self) -> &HashMap<IntT, Vec<u8>> {
    &self.split_kmers
}
/// Number of distinct split k-mers currently stored.
pub fn ksize(&self) -> usize {
    self.split_kmers.len()
}
/// Number of samples the base vectors are sized for.
pub fn nsamples(&self) -> usize {
    self.n_samples
}
}
/// Serially build a [`SkaDict`] for each input file and append them into a
/// single [`MergeSkaDict`].
///
/// Sample indices run `offset..offset + input_files.len()`; base vectors
/// are sized for `total_size` samples overall so results from different
/// chunks can later be merged.
fn multi_append<IntT>(
    input_files: &[InputFastx],
    offset: usize,
    total_size: usize,
    k: usize,
    rc: bool,
    qual: &QualOpts,
    proportion_reads: Option<f64>,
) -> MergeSkaDict<IntT>
where
    IntT: for<'a> UInt<'a>,
{
    let mut merged = MergeSkaDict::new(k, total_size, rc);
    for (file_idx, (name, fastx1, fastx2)) in input_files.iter().enumerate() {
        let sample_dict = SkaDict::new(
            k,
            file_idx + offset,
            (fastx1, fastx2.as_ref()),
            name,
            rc,
            qual,
            proportion_reads,
        );
        merged.append(&sample_dict);
    }
    merged
}
/// Recursively build and merge dictionaries over `file_list`, halving the
/// work at each level with [`rayon::join`].
///
/// At `depth == 1` each half is built serially with [`multi_append`];
/// above that the function recurses on each half in parallel. The two
/// results are then merged into one.
#[allow(clippy::too_many_arguments)]
fn parallel_append<IntT>(
    depth: usize,
    offset: usize,
    file_list: &[InputFastx],
    total_size: usize,
    k: usize,
    rc: bool,
    qual: &QualOpts,
    proportion_reads: Option<f64>,
) -> MergeSkaDict<IntT>
where
    IntT: for<'a> UInt<'a>,
{
    let split_point = file_list.len() / 2;
    let (bottom, top) = file_list.split_at(split_point);
    // A single closure for either half removes the duplicated join arms
    // the depth == 1 and recursive cases previously carried.
    let build_half = |files: &[InputFastx], half_offset: usize| -> MergeSkaDict<IntT> {
        if depth == 1 {
            multi_append(files, half_offset, total_size, k, rc, qual, proportion_reads)
        } else {
            parallel_append(
                depth - 1,
                half_offset,
                files,
                total_size,
                k,
                rc,
                qual,
                proportion_reads,
            )
        }
    };
    let (mut bottom_merge, mut top_merge) = rayon::join(
        || build_half(bottom, offset),
        || build_half(top, offset + split_point),
    );
    bottom_merge.merge(&mut top_merge);
    bottom_merge
}
/// Build a [`MergeSkaDict`] from a list of FASTA/FASTQ input files,
/// optionally in parallel.
///
/// With `threads > 1` a divide-and-conquer parallel build is used, with
/// depth capped so each leaf processes roughly ten or more files;
/// otherwise files are processed serially with a progress bar.
///
/// # Panics
/// Propagates panics from [`MergeSkaDict::append`] if inputs disagree on
/// k-mer length or strand settings.
pub fn build_and_merge<IntT>(
    input_files: &[InputFastx],
    k: usize,
    rc: bool,
    qual: &QualOpts,
    threads: usize,
    proportion_reads: Option<f64>,
) -> MergeSkaDict<IntT>
where
    IntT: for<'a> UInt<'a>,
{
    log::info!("Building skf dicts from sequence input");
    if any_fastq(input_files) {
        log::info!("FASTQ files filtered with: {qual}");
    } else {
        log::info!("All input files FASTA (no error filtering)");
    }
    if threads > 1 {
        // build_global() errors if the global rayon pool already exists
        // (e.g. this function being called a second time in one process).
        // The previous unwrap() panicked in that case; warn and continue
        // with the existing pool instead.
        if let Err(err) = rayon::ThreadPoolBuilder::new()
            .num_threads(threads)
            .build_global()
        {
            log::warn!("Could not set {threads} threads ({err}); using existing thread pool");
        }
    }
    let total_size = input_files.len();
    let mut merged_dict = MergeSkaDict::new(k, total_size, rc);
    // Cap parallelism so each parallel leaf handles at least ~10 files,
    // then round down to a power-of-two tree depth.
    let max_threads = usize::max(1, usize::min(threads, 1 + total_size / 10));
    let max_depth = f64::floor(f64::log2(max_threads as f64)) as usize;
    if max_depth > 0 {
        log::info!(
            "Build and merge skf dicts in parallel using {} threads",
            1 << max_depth
        );
        merged_dict = parallel_append(
            max_depth,
            0,
            input_files,
            total_size,
            k,
            rc,
            qual,
            proportion_reads,
        );
    } else {
        log::info!("Build and merge serially");
        for (idx, (name, filename, second_file)) in input_files.iter().progress().enumerate() {
            let ska_dict = SkaDict::new(
                k,
                idx,
                (filename, second_file.as_ref()),
                name,
                rc,
                qual,
                proportion_reads,
            );
            merged_dict.append(&ska_dict);
        }
    }
    merged_dict
}