use alloc::vec;
use alloc::vec::Vec;
use crate::field::Field;
/// Inverts every element of `x` using Montgomery's batch-inversion trick.
///
/// Returns a vector `y` with `y.len() == x.len()` and `y[i] * x[i] == F::one()` for every `i`.
/// Only a single call to [`Field::inverse`] is made, on a product of elements of `x`.
/// Everything else is multiplication: about one per element in the forward pass and two
/// per element in the backward pass.
///
/// For inputs of length at least `WIDTH` (4), the running products are split into `WIDTH`
/// interleaved lanes, where element `i` belongs to lane `i % WIDTH`. As a result, consecutive
/// multiplications in each pass do not depend on one another. Inputs shorter than `WIDTH` are
/// handled by dedicated straight-line code.
///
/// # Panics
///
/// In debug builds, panics if any `y[i] * x[i] != F::one()`.
///
/// If any element of `x` is zero, the product passed to `inverse` is zero. What happens then
/// is up to `F::inverse`. NOTE(review): presumably it panics for the crate's fields; confirm
/// against the `Field` implementations.
pub fn batch_multiplicative_inverse<F: Field>(x: &[F]) -> Vec<F> {
    // Number of interleaved running-product lanes. The lane-combining step below is written
    // out by hand for exactly four lanes, so this constant cannot be changed in isolation.
    const WIDTH: usize = 4;

    let n = x.len();
    match n {
        0 => return Vec::new(),
        1 => return vec![x[0].inverse()],
        2 => {
            let x01 = x[0] * x[1];
            let x01inv = x01.inverse();
            // 1/x0 = x1 / (x0*x1) and 1/x1 = x0 / (x0*x1).
            return vec![x01inv * x[1], x01inv * x[0]];
        }
        3 => {
            let x01 = x[0] * x[1];
            let x012 = x01 * x[2];
            let x012inv = x012.inverse();
            // 1/(x0*x1), recovered from 1/(x0*x1*x2) by multiplying x2 back in.
            let x01inv = x012inv * x[2];
            return vec![x01inv * x[1], x01inv * x[0], x012inv * x01];
        }
        _ => {}
    }
    debug_assert!(n >= WIDTH);

    // Forward pass. Afterwards, buf[i] holds the product of every x[j] with j <= i and
    // j % WIDTH == i % WIDTH (the lane prefix product ending at i), and cumul_prod[k]
    // holds the full product of lane k.
    let mut buf: Vec<F> = Vec::with_capacity(n);
    let mut cumul_prod: [F; WIDTH] = x[..WIDTH]
        .try_into()
        .expect("x[..WIDTH] has exactly WIDTH elements");
    buf.extend(cumul_prod);
    for (i, &xi) in x.iter().enumerate().skip(WIDTH) {
        cumul_prod[i % WIDTH] *= xi;
        buf.push(cumul_prod[i % WIDTH]);
    }
    debug_assert_eq!(buf.len(), n);

    // Invert all four lane products with one field inversion:
    // 1/(c0*c1) = (c2*c3) / (c0*c1*c2*c3), then 1/c0 = c1 / (c0*c1), and so on.
    let mut a_inv = {
        let c01 = cumul_prod[0] * cumul_prod[1];
        let c23 = cumul_prod[2] * cumul_prod[3];
        let c0123 = c01 * c23;
        let c0123inv = c0123.inverse();
        let c01inv = c0123inv * c23;
        let c23inv = c0123inv * c01;
        [
            c01inv * cumul_prod[1],
            c01inv * cumul_prod[0],
            c23inv * cumul_prod[3],
            c23inv * cumul_prod[2],
        ]
    };

    // Backward pass. At the top of each iteration, a_inv[i % WIDTH] is the inverse of the lane
    // prefix product ending at i. Two consequences follow:
    // - 1/x[i] = buf[i - WIDTH] * a_inv[i % WIDTH].
    // - Multiplying a_inv by x[i] rewinds it to the inverse of the prefix ending at i - WIDTH.
    // Walking downward guarantees buf[i - WIDTH] still holds a prefix product when it is read.
    for i in (WIDTH..n).rev() {
        buf[i] = buf[i - WIDTH] * a_inv[i % WIDTH];
        a_inv[i % WIDTH] *= x[i];
    }
    // Every lane is now rewound to its first element, so a_inv[k] == 1/x[k] for k < WIDTH.
    buf[..WIDTH].copy_from_slice(&a_inv);

    if cfg!(debug_assertions) {
        for (&bi, &xi) in buf.iter().zip(x) {
            assert_eq!(bi * xi, F::one());
        }
    }
    buf
}