use alloc::vec::Vec;
use p3_maybe_rayon::prelude::*;
use tracing::instrument;
use crate::field::Field;
use crate::{FieldArray, PackedValue, PrimeCharacteristicRing};
/// Returns a vector whose `i`-th entry is the multiplicative inverse of `x[i]`.
///
/// The input is cut into chunks of at most `CHUNK_SIZE` elements. Each chunk is inverted
/// independently by `batch_multiplicative_inverse_helper`, which amortises the cost of field
/// inversion over the whole chunk. Chunks are dispatched through `p3_maybe_rayon`, so they
/// presumably run in parallel when its rayon backend is enabled (confirm against the crate's
/// feature configuration).
///
/// NOTE(review): every element must be invertible. A zero anywhere makes the product of its
/// chunk zero; the outcome then depends on `Field::inverse` (assumed to panic — confirm).
#[instrument(level = "debug", skip_all)]
#[must_use]
pub fn batch_multiplicative_inverse<F: Field>(x: &[F]) -> Vec<F> {
    const CHUNK_SIZE: usize = 1024;

    // Pre-sized output buffer; every slot is overwritten by the per-chunk helper.
    let mut inverses = F::zero_vec(x.len());

    // Pair each output chunk with the input chunk at the same offset.
    inverses
        .par_chunks_mut(CHUNK_SIZE)
        .zip(x.par_chunks(CHUNK_SIZE))
        .for_each(|(out_chunk, in_chunk)| {
            batch_multiplicative_inverse_helper(in_chunk, out_chunk);
        });

    inverses
}
/// Writes the multiplicative inverse of each `x[i]` into `result[i]`.
///
/// When `x.len()` is a multiple of `WIDTH`, both slices are viewed as slices of
/// `FieldArray<F, WIDTH>`. The batched inversion then runs lane-wise, so each lane keeps its own
/// running product; presumably this is done for instruction-level parallelism. Otherwise the
/// scalar path is used for the whole slice.
///
/// # Panics
/// Panics if `result.len() != x.len()`.
fn batch_multiplicative_inverse_helper<F: Field>(x: &[F], result: &mut [F]) {
    // Number of lanes in the packed representation.
    const WIDTH: usize = 4;

    let n = x.len();
    assert_eq!(result.len(), n);

    if n.is_multiple_of(WIDTH) {
        // Pack with `WIDTH` itself (not a literal) so the lane count always matches the
        // divisibility check above.
        let x_packed = FieldArray::<F, WIDTH>::pack_slice(x);
        let result_packed = FieldArray::<F, WIDTH>::pack_slice_mut(result);
        batch_multiplicative_inverse_general(x_packed, result_packed, |x_packed| {
            x_packed.inverse()
        });
    } else {
        // Length cannot be packed evenly: fall back to the scalar batched inversion.
        batch_multiplicative_inverse_general(x, result, |x| x.inverse());
    }
}
/// Batched inversion via the prefix-product (Montgomery) trick, calling `inv` exactly once
/// for non-empty input and never for empty input.
///
/// The algorithm has three steps:
/// 1. A forward sweep stores the exclusive prefix product `x[0] * ... * x[i - 1]` in `result[i]`.
/// 2. `inv` is applied to the product of all elements.
/// 3. A backward sweep multiplies each slot by the running inverse, then folds `x[i]` into that
///    running value.
///
/// When multiplication commutes and `inv` returns a true inverse, `result[i]` ends as `x[i]⁻¹`.
///
/// # Panics
/// Panics if `result.len() != x.len()`.
pub(crate) fn batch_multiplicative_inverse_general<F, Inv>(x: &[F], result: &mut [F], inv: Inv)
where
    F: PrimeCharacteristicRing + Copy,
    Inv: Fn(F) -> F,
{
    assert_eq!(result.len(), x.len());
    if x.is_empty() {
        return;
    }

    // Forward sweep: the slot receives the product of all earlier elements (ONE for index 0).
    let mut running = F::ONE;
    for (slot, &xi) in result.iter_mut().zip(x) {
        *slot = running;
        running = running * xi;
    }

    // `running` now holds the product of every element; this is the only inversion.
    // NOTE(review): a single zero in `x` makes this product zero; the result then depends on `inv`.
    let mut suffix_inv = inv(running);

    // Backward sweep: before the update, `suffix_inv` is the inverse of x[0] * ... * x[i].
    // Multiplying the stored prefix by it leaves x[i]⁻¹. Folding in x[i] then advances
    // `suffix_inv` to the inverse of the next shorter prefix.
    for (slot, &xi) in result.iter_mut().zip(x).rev() {
        *slot *= suffix_inv;
        suffix_inv *= xi;
    }
}