use crate::base::if_rayon;
use alloc::vec::Vec;
#[cfg(feature = "rayon")]
use core::cmp::max;
use core::ops::{Mul, MulAssign};
use num_traits::{Inv, One, Zero};
#[cfg(feature = "rayon")]
use rayon::prelude::*;
/// Replaces every nonzero element of `v` with its multiplicative inverse, in place.
///
/// This is [`batch_inversion_and_mul`] with the scaling coefficient fixed to
/// `F::one()`, so the results are plain (unscaled) inverses. Elements for which
/// `is_zero()` holds are skipped by the underlying routine and left untouched.
///
/// # Panics
///
/// Panics if the underlying routine cannot invert the product of the nonzero
/// elements it processes (i.e. `inv()` returns `None`).
#[tracing::instrument(name = "BatchInversion::batch_inversion", level = "debug", skip_all)]
pub fn batch_inversion<F>(v: &mut [F])
where
    F: One + Zero + MulAssign + Inv<Output = Option<F>> + Mul<Output = F> + Send + Sync + Copy,
{
    // A coefficient of one leaves the computed inverses unscaled.
    let unit_coeff = F::one();
    batch_inversion_and_mul(v, unit_coeff);
}
/// Replaces every nonzero element `x` of `v` with `coeff * x^-1`, in place.
///
/// Elements for which `is_zero()` holds are skipped and left untouched.
///
/// The work is dispatched through the `if_rayon!` macro: one branch splits `v`
/// into chunks that are each inverted independently via rayon's
/// `par_chunks_mut`, the other runs the serial routine over the whole slice.
/// NOTE(review): presumably the macro selects the parallel branch only when the
/// `rayon` feature is enabled — confirm against the macro definition.
///
/// # Panics
///
/// Panics if the product of the nonzero elements of a processed chunk (or of the
/// whole slice on the serial path) has no inverse, i.e. `inv()` returns `None`.
#[tracing::instrument(
    name = "BatchInversion::batch_inversion_and_mul",
    level = "debug",
    skip_all
)]
pub fn batch_inversion_and_mul<F>(v: &mut [F], coeff: F)
where
    F: One + Zero + MulAssign + Inv<Output = Option<F>> + Mul<Output = F> + Send + Sync + Copy,
{
    if_rayon!(
        {
            // Never divide by a thread count of zero.
            let thread_count = max(1, rayon::current_num_threads());
            // Spread the elements evenly over the threads, but never hand out
            // chunks shorter than `MIN_RAYON_LEN`.
            let even_share = v.len().div_ceil(thread_count);
            let chunk_len = max(even_share, super::MIN_RAYON_LEN);
            // Each chunk is inverted on its own, so chunks do not share state.
            v.par_chunks_mut(chunk_len).for_each(|chunk| {
                serial_batch_inversion_and_mul(chunk, coeff);
            });
        },
        serial_batch_inversion_and_mul(v, coeff)
    );
}
/// Single-threaded batch inversion: replaces every nonzero element `x` of `v`
/// with `coeff * x^-1`, using only one call to `inv()` for the whole slice.
///
/// The routine makes two passes over the nonzero elements:
///
/// 1. a forward pass that records, for each nonzero element, the product of all
///    nonzero elements that precede it, while accumulating the total product;
/// 2. a backward pass that starts from `coeff * total^-1` and peels off one
///    element at a time to obtain each scaled inverse.
///
/// Elements for which `is_zero()` holds are skipped in both passes and left
/// untouched.
///
/// # Panics
///
/// Panics if `inv()` returns `None` for the product of the nonzero elements.
fn serial_batch_inversion_and_mul<F>(v: &mut [F], coeff: F)
where
    F: One + Zero + MulAssign + Inv<Output = Option<F>> + Mul<Output = F> + Copy,
{
    // Forward pass: `preceding[i]` is the product of the nonzero elements that
    // come before the i-th nonzero element (so the first entry is one), and
    // `running` ends up as the product of all nonzero elements.
    let mut preceding = Vec::with_capacity(v.len());
    let mut running = F::one();
    for elem in v.iter().copied().filter(|e| !e.is_zero()) {
        preceding.push(running);
        running *= elem;
    }

    // One inversion for the whole batch, scaled so that every result carries
    // the factor `coeff`.
    let mut acc = running.inv().unwrap();
    acc *= coeff;

    // Backward pass. Invariant at the top of each iteration:
    // `acc == coeff * (product of nonzero elements up to and including *slot)^-1`.
    let nonzero_slots = v.iter_mut().rev().filter(|e| !e.is_zero());
    for (slot, before) in nonzero_slots.zip(preceding.into_iter().rev()) {
        // Cancelling `*slot` out of `acc` yields the accumulator for the
        // element before it; multiplying by `before` isolates `coeff / *slot`.
        let carried = acc * *slot;
        *slot = acc * before;
        acc = carried;
    }
}