use core::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use crate::{
codec::{
batch_error::PartialInit, codebook::Codebook, codec_config::CodecConfig,
compressed_vector::CompressedVector, parallelism::Parallelism, service::Codec,
},
errors::CodecError,
};
use alloc::vec::Vec;
/// Compresses a row-major batch of vectors: `vectors` holds `rows` contiguous
/// rows of `cols` floats each, and one `CompressedVector` is produced per row
/// via `Codec::compress`, scheduled by `parallelism`.
///
/// Output slots live in uninitialized storage (`PartialInit`) that worker
/// closures fill through a shared raw pointer; a per-slot `AtomicBool` flag
/// records which rows were written so partially-filled storage can be dropped
/// safely on the error path.
///
/// # Errors
/// Returns the first `CodecError` captured under the mutex (not necessarily
/// the lowest-indexed failing row). Once any row fails, remaining rows are
/// skipped on a best-effort basis.
///
/// # Panics
/// Debug builds assert `vectors.len() == rows * cols`; in release builds a
/// mismatch instead panics at the per-row slice indexing below.
#[allow(unsafe_code)]
pub(super) fn compress_batch_parallel(
    vectors: &[f32],
    rows: usize,
    cols: usize,
    config: &CodecConfig,
    codebook: &Codebook,
    parallelism: Parallelism,
) -> Result<Vec<CompressedVector>, CodecError> {
    debug_assert_eq!(
        vectors.len(),
        rows * cols,
        "compress_batch_parallel: vectors.len() ({}) != rows ({}) * cols ({})",
        vectors.len(),
        rows,
        cols
    );
    // `rows` uninitialized output slots plus "written" flags; the flags let
    // PartialInit's cleanup free only the rows that were actually initialized.
    let mut partial: PartialInit<CompressedVector> = PartialInit::new(rows);
    // AtomicPtr is used purely to make the base pointer Send + Sync for the
    // worker closure; the stored pointer value never changes after this line.
    let out_atomic = AtomicPtr::new(partial.as_mut_ptr().cast::<CompressedVector>());
    let flags: &[AtomicBool] = partial.flags();
    // First error wins. The mutex serializes error writers; the plain bool is
    // a cheap "bail out early" signal so the hot path never takes the lock.
    let first_error = std::sync::Mutex::new(None::<CodecError>);
    let has_error = AtomicBool::new(false);
    parallelism.for_each_row(rows, |i| {
        // Best-effort cancellation: Relaxed is sufficient because skipping a
        // row is always safe — its flag stays false, so its slot is never read.
        if has_error.load(Ordering::Relaxed) {
            return;
        }
        // In-bounds for i < rows given vectors.len() == rows * cols (checked
        // by the debug_assert above); a release-mode mismatch panics here.
        #[allow(clippy::indexing_slicing)]
        let slice = &vectors[i * cols..(i + 1) * cols];
        match Codec::new().compress(slice, config, codebook) {
            Ok(cv) => {
                let base = out_atomic.load(Ordering::Relaxed);
                // SAFETY: i < rows, so `base.add(i)` is in-bounds; assuming
                // for_each_row dispatches each index at most once (TODO:
                // confirm Parallelism's contract), the slot is written by
                // exactly one worker. `write` avoids dropping the slot's
                // uninitialized previous contents.
                unsafe { base.add(i).write(cv) };
                // Mark the slot initialized. Release publishes the write above
                // to whichever thread later observes the flag. `get` is
                // expected to always succeed here (flags should have one entry
                // per row) but keeps this path panic-free by construction.
                if let Some(flag) = flags.get(i) {
                    flag.store(true, Ordering::Release);
                }
            }
            Err(e) => {
                has_error.store(true, Ordering::Relaxed);
                // Record only the first error; recover from a poisoned lock
                // since the Option inside remains valid even if another
                // worker panicked while holding it.
                let mut slot = first_error
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
                if slot.is_none() {
                    *slot = Some(e);
                }
            }
        }
    });
    match first_error
        .into_inner()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
    {
        Some(e) => {
            // Dropping `partial` frees exactly the rows whose flags were set;
            // rows skipped or failed were never written and are not dropped.
            drop(partial);
            Err(e)
        }
        None => {
            // SAFETY: no error was recorded, so every row compressed
            // successfully, wrote its slot, and set its flag — all `rows`
            // slots are initialized, which is `into_vec`'s requirement.
            // NOTE(review): this also assumes for_each_row joins all workers
            // before returning — confirm against Parallelism's contract.
            Ok(unsafe { partial.into_vec() })
        }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::compress_batch_parallel;
    use crate::codec::{codebook::Codebook, codec_config::CodecConfig, parallelism::Parallelism};

    /// Deterministic pseudo-random training data: 32 rows x 8 cols of floats
    /// mapped into [-1.0, 1.0] from a fixed-seed ChaCha20 stream.
    fn make_training() -> alloc::vec::Vec<f32> {
        use rand_chacha::rand_core::{RngCore, SeedableRng};
        let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(99);
        let mut data = alloc::vec::Vec::with_capacity(32 * 8);
        for _ in 0..32 * 8 {
            let bits = rng.next_u32();
            #[allow(clippy::cast_precision_loss)]
            let v = (bits as f32 / u32::MAX as f32) * 2.0 - 1.0;
            data.push(v);
        }
        data
    }

    /// Compressing the same batch serially and via a custom single-threaded
    /// scheduler must yield identical results.
    #[test]
    #[cfg_attr(miri, ignore = "libm sqrtsd SSE2 asm unsupported under Miri")]
    fn serial_and_custom_match_byte_for_byte() {
        let training = make_training();
        let cfg = CodecConfig::new(4, 42, 8, true).unwrap();
        let cb = Codebook::train(&training, &cfg).unwrap();
        let (rows, cols) = (16_usize, 8_usize);
        let batch = &training[..rows * cols];

        // One driver, two scheduling strategies.
        let run = |parallelism: Parallelism| {
            compress_batch_parallel(batch, rows, cols, &cfg, &cb, parallelism).unwrap()
        };
        let serial = run(Parallelism::Serial);
        let custom = run(Parallelism::Custom(|count, body| (0..count).for_each(body)));

        assert_eq!(serial.len(), custom.len());
        for (lhs, rhs) in serial.iter().zip(custom.iter()) {
            assert_eq!(lhs.indices(), rhs.indices(), "indices mismatch");
            assert_eq!(lhs.residual(), rhs.residual(), "residual mismatch");
        }
    }
}