use alloc::vec::Vec;
use p3_maybe_rayon::prelude::*;
use tracing::instrument;
use crate::field::Field;
use crate::{
ExtensionField, FieldArray, PackedFieldExtension, PackedValue, PrimeCharacteristicRing,
};
/// Returns the multiplicative inverses of all elements of `x`, in the same order.
///
/// The returned vector has length `x.len()`, and entry `i` holds the inverse of `x[i]`.
/// The input is cut into chunks of 1024 elements, which are handed to `p3_maybe_rayon`'s
/// `par_chunks` / `par_chunks_mut`. Inside each chunk:
/// - the elements are grouped four at a time into `FieldArray<F, 4>` values, and that packed
///   prefix is batch inverted with [`batch_multiplicative_inverse_general`];
/// - any elements left over after packing are batch inverted as plain scalars.
///
/// # Panics
///
/// Each batch inversion calls `inverse()` on the product of its inputs. A zero element therefore
/// sends a zero (or a packed value with a zero lane) to `inverse()`. NOTE(review): this
/// presumably panics, as `Field::inverse` does on zero; confirm.
#[instrument(level = "debug", skip_all)]
#[must_use]
pub fn batch_multiplicative_inverse<F: Field>(x: &[F]) -> Vec<F> {
    // Elements processed by one parallel task.
    const CHUNK_SIZE: usize = 1024;
    // Number of scalars packed together into one `FieldArray`.
    const WIDTH: usize = 4;

    // Inverts `src` into `dst`: the packed prefix first, then the unpacked tail.
    fn invert_chunk<G: Field>(src: &[G], dst: &mut [G]) {
        let (src_packed, src_rest) = FieldArray::<G, WIDTH>::pack_slice_with_suffix(src);
        let (dst_packed, dst_rest) = FieldArray::<G, WIDTH>::pack_slice_with_suffix_mut(dst);
        batch_multiplicative_inverse_general(src_packed, dst_packed, |group| group.inverse());
        batch_multiplicative_inverse_general(src_rest, dst_rest, |single| single.inverse());
    }

    let mut inverses = F::zero_vec(x.len());
    let dst_chunks = inverses.par_chunks_mut(CHUNK_SIZE);
    x.par_chunks(CHUNK_SIZE)
        .zip(dst_chunks)
        .for_each(|(src, dst)| invert_chunk(src, dst));
    inverses
}
/// Batch-inverts `x` into `result` with Montgomery's trick.
///
/// When `inv` computes a multiplicative inverse, `result[i]` ends up holding the inverse of
/// `x[i]`. `inv` is called exactly once, on the product of all elements of `x`. The rest of the
/// work is roughly `3 * x.len()` multiplications. For empty input, `inv` is never called.
///
/// If any element of `x` is zero, the value passed to `inv` is zero. What happens then is up
/// to `inv`.
///
/// # Panics
///
/// Panics if `result.len() != x.len()`.
#[inline]
pub fn batch_multiplicative_inverse_general<F, Inv>(x: &[F], result: &mut [F], inv: Inv)
where
    F: PrimeCharacteristicRing + Copy,
    Inv: Fn(F) -> F,
{
    assert_eq!(result.len(), x.len());
    if x.is_empty() {
        return;
    }

    // Forward pass: each slot receives the product of all preceding inputs (ONE for the first).
    // After the loop, `running` is the product of every input.
    let mut running = F::ONE;
    for (slot, &value) in result.iter_mut().zip(x) {
        *slot = running;
        running = running * value;
    }

    // Backward pass: `acc` starts as 1 / (x[0] * ... * x[n-1]). After handling slot i it becomes
    // 1 / (x[0] * ... * x[i-1]). So prefix * acc at slot i is exactly 1 / x[i].
    let mut acc = inv(running);
    for (slot, &value) in result.iter_mut().zip(x).rev() {
        *slot *= acc;
        acc *= value;
    }
}
#[inline]
pub fn invert_packed_extension<F, EF>(packed: EF::ExtensionPacking) -> EF::ExtensionPacking
where
F: Field,
EF: ExtensionField<F>,
{
match F::Packing::WIDTH {
1 => invert_packed_extension_const::<F, EF, 1>(packed),
2 => invert_packed_extension_const::<F, EF, 2>(packed),
4 => invert_packed_extension_const::<F, EF, 4>(packed),
8 => invert_packed_extension_const::<F, EF, 8>(packed),
16 => invert_packed_extension_const::<F, EF, 16>(packed),
w => panic!("unsupported PackedField WIDTH = {w}"),
}
}
/// Inverts the first `W` lanes of `packed` using fixed-size stack arrays.
///
/// Lanes `0..W` are read with `extract`, batch inverted together, and written back through
/// `from_ext_fn`. Callers are expected to pass `W == F::Packing::WIDTH`. NOTE(review): this is
/// not checked here.
#[inline]
fn invert_packed_extension_const<F, EF, const W: usize>(
    packed: EF::ExtensionPacking,
) -> EF::ExtensionPacking
where
    F: Field,
    EF: ExtensionField<F>,
{
    // Output buffer; every entry is overwritten by the batch inversion below.
    let mut inverses = [EF::ZERO; W];
    let values: [EF; W] = core::array::from_fn(|lane| packed.extract(lane));
    batch_multiplicative_inverse_general(&values, &mut inverses, |value| value.inverse());
    EF::ExtensionPacking::from_ext_fn(|lane| inverses[lane])
}