use super::iter::SeqIter;
use super::{SeqVec, SeqVecBitReader, SeqVecError};
use crate::variable::traits::Storable;
use dsi_bitstream::dispatch::CodesRead;
use dsi_bitstream::prelude::{BitRead, BitSeek, Endianness};
use rayon::prelude::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelIterator,
};
impl<T, E, B> SeqVec<T, E, B>
where
T: Storable + Send + Sync,
E: Endianness + Send + Sync,
B: AsRef<[u64]> + Send + Sync,
for<'a> SeqVecBitReader<'a, E>: BitRead<E, Error = core::convert::Infallible>
+ CodesRead<E>
+ BitSeek<Error = core::convert::Infallible>
+ Send,
{
/// Returns a parallel iterator over all sequences, each decoded into its
/// own freshly allocated `Vec<T>`.
///
/// One reader is created per rayon work split (via `map_init`) and reused
/// for every sequence handled by that split. When per-sequence lengths are
/// stored, the output vector is pre-sized with the stored length; otherwise
/// it starts with zero capacity.
///
/// # Panics
///
/// Panics if the reader's `decode_into` returns an error.
pub fn par_iter(&self) -> impl ParallelIterator<Item = Vec<T>> + '_ {
    let stored_lengths = self.seq_lengths.as_ref();
    (0..self.num_sequences()).into_par_iter().map_init(
        move || self.reader(),
        move |reader, seq_idx| {
            // SAFETY: `seq_idx` is drawn from `0..num_sequences()`; this
            // assumes `seq_lengths` (when present) holds one entry per
            // sequence — TODO confirm against the constructor.
            let expected_len =
                stored_lengths.map_or(0, |l| unsafe { l.get_unchecked(seq_idx) as usize });
            let mut decoded = Vec::with_capacity(expected_len);
            reader.decode_into(seq_idx, &mut decoded).unwrap();
            decoded
        },
    )
}
/// Applies `f` in parallel to a [`SeqIter`] built for every sequence index
/// in `0..num_sequences()` and collects the return values into a `Vec`.
///
/// Each iterator is constructed from the backing words, the start and end
/// bit offsets of the sequence, the encoding, and the stored length when
/// one is available. The output has one entry per sequence; rayon's
/// `collect` on an indexed range keeps them in index order.
pub fn par_for_each<F, R>(&self, f: F) -> Vec<R>
where
    F: Fn(SeqIter<'_, T, E>) -> R + Sync + Send,
    R: Send,
{
    let words = self.data.as_ref();
    let offsets = &self.bit_offsets;
    let stored_lengths = self.seq_lengths.as_ref();
    let codec = self.encoding;

    (0..self.num_sequences())
        .into_par_iter()
        .map(|seq_idx| {
            // SAFETY: `seq_idx < num_sequences()`; this assumes `bit_offsets`
            // holds `num_sequences() + 1` entries — TODO confirm.
            let (first_bit, last_bit) = unsafe {
                (
                    offsets.get_unchecked(seq_idx),
                    offsets.get_unchecked(seq_idx + 1),
                )
            };
            // SAFETY: same index bound; assumes one length per sequence.
            let known_len = stored_lengths.map(|l| unsafe { l.get_unchecked(seq_idx) as usize });
            f(SeqIter::new_with_len(
                words, first_bit, last_bit, codec, known_len,
            ))
        })
        .collect()
}
/// Maps every sequence through `f` in parallel and combines the
/// per-sequence results with rayon's `reduce(identity, op)`.
///
/// `f` receives a [`SeqIter`] built for one sequence. `identity` and `op`
/// are forwarded unchanged to rayon's `reduce`; see its documentation for
/// the requirements on them (rayon may call `identity` several times and
/// expects `op` to be associative).
pub fn par_for_each_reduce<F, R, ID, OP>(&self, f: F, identity: ID, op: OP) -> R
where
    F: Fn(SeqIter<'_, T, E>) -> R + Sync + Send,
    R: Send,
    ID: Fn() -> R + Sync + Send,
    OP: Fn(R, R) -> R + Sync + Send,
{
    let words = self.data.as_ref();
    let offsets = &self.bit_offsets;
    let stored_lengths = self.seq_lengths.as_ref();
    let codec = self.encoding;

    let mapped = (0..self.num_sequences()).into_par_iter().map(|seq_idx| {
        // SAFETY: `seq_idx < num_sequences()`; this assumes `bit_offsets`
        // holds `num_sequences() + 1` entries — TODO confirm.
        let (first_bit, last_bit) = unsafe {
            (
                offsets.get_unchecked(seq_idx),
                offsets.get_unchecked(seq_idx + 1),
            )
        };
        // SAFETY: same index bound; assumes one length per sequence.
        let known_len = stored_lengths.map(|l| unsafe { l.get_unchecked(seq_idx) as usize });
        f(SeqIter::new_with_len(
            words, first_bit, last_bit, codec, known_len,
        ))
    });
    mapped.reduce(identity, op)
}
/// Consumes the container and decodes every sequence in parallel,
/// returning one `Vec<T>` per sequence.
///
/// A reader is created per rayon work split and reused within it. Output
/// vectors are pre-sized from the stored lengths when those are present.
///
/// # Panics
///
/// Panics if the reader's `decode_into` returns an error.
pub fn par_into_vecs(self) -> Vec<Vec<T>> {
    let owner = &self;
    let stored_lengths = owner.seq_lengths.as_ref();
    (0..owner.num_sequences())
        .into_par_iter()
        .map_init(
            move || owner.reader(),
            move |reader, seq_idx| {
                // SAFETY: `seq_idx` is drawn from `0..num_sequences()`; this
                // assumes `seq_lengths` (when present) holds one entry per
                // sequence — TODO confirm.
                let expected_len =
                    stored_lengths.map_or(0, |l| unsafe { l.get_unchecked(seq_idx) as usize });
                let mut decoded = Vec::with_capacity(expected_len);
                reader.decode_into(seq_idx, &mut decoded).unwrap();
                decoded
            },
        )
        .collect()
}
/// Decodes the sequences named by `indices` in parallel, after checking
/// every index against `num_sequences()`.
///
/// An empty `indices` slice yields `Ok` with an empty vector.
///
/// # Errors
///
/// Returns [`SeqVecError::IndexOutOfBounds`] carrying the first index in
/// `indices` that is `>= num_sequences()`.
pub fn par_decode_many(&self, indices: &[usize]) -> Result<Vec<Vec<T>>, SeqVecError> {
    if indices.is_empty() {
        return Ok(Vec::new());
    }
    let limit = self.num_sequences();
    match indices.iter().copied().find(|&idx| idx >= limit) {
        Some(offending) => Err(SeqVecError::IndexOutOfBounds(offending)),
        // SAFETY: the search above found no index `>= num_sequences()`.
        None => Ok(unsafe { self.par_decode_many_unchecked(indices) }),
    }
}
pub unsafe fn par_decode_many_unchecked(&self, indices: &[usize]) -> Vec<Vec<T>> {
#[cfg(debug_assertions)]
{
let num_sequences = self.num_sequences();
for &index in indices {
debug_assert!(
index < num_sequences,
"Index out of bounds: index was {} but num_sequences was {}",
index,
num_sequences
);
}
}
if indices.is_empty() {
return Vec::new();
}
let mut results = vec![Vec::new(); indices.len()];
results.par_iter_mut().enumerate().for_each_init(
|| self.reader(),
|reader, (original_pos, result)| {
let target_index = indices[original_pos];
if let Some(lengths) = &self.seq_lengths {
let capacity = unsafe { lengths.get_unchecked(target_index) as usize };
result.reserve(capacity);
}
reader.decode_into(target_index, result).unwrap();
},
);
results
}
/// Applies `f` in parallel to a [`SeqIter`] for each sequence named by
/// `indices`, returning the results in the same order as `indices`.
///
/// An empty `indices` slice yields `Ok` with an empty vector.
///
/// # Errors
///
/// Returns [`SeqVecError::IndexOutOfBounds`] carrying the first index in
/// `indices` that is `>= num_sequences()`; `f` is not called in that case.
pub fn par_for_each_many<F, R>(&self, indices: &[usize], f: F) -> Result<Vec<R>, SeqVecError>
where
    F: Fn(SeqIter<'_, T, E>) -> R + Sync + Send,
    R: Send,
{
    if indices.is_empty() {
        return Ok(Vec::new());
    }
    let limit = self.num_sequences();
    if let Some(offending) = indices.iter().copied().find(|&idx| idx >= limit) {
        return Err(SeqVecError::IndexOutOfBounds(offending));
    }

    let words = self.data.as_ref();
    let offsets = &self.bit_offsets;
    let stored_lengths = self.seq_lengths.as_ref();
    let codec = self.encoding;

    Ok(indices
        .par_iter()
        .map(|&seq_idx| {
            // SAFETY: `seq_idx < num_sequences()` was verified above; this
            // assumes `bit_offsets` holds `num_sequences() + 1` entries —
            // TODO confirm.
            let (first_bit, last_bit) = unsafe {
                (
                    offsets.get_unchecked(seq_idx),
                    offsets.get_unchecked(seq_idx + 1),
                )
            };
            // SAFETY: same index bound; assumes one length per sequence.
            let known_len = stored_lengths.map(|l| unsafe { l.get_unchecked(seq_idx) as usize });
            f(SeqIter::new_with_len(
                words, first_bit, last_bit, codec, known_len,
            ))
        })
        .collect())
}
}