use subtle::ConstantTimeEq;
use crate::Field;
/// Extension trait for inverting a whole collection of mutable field elements at once.
///
/// When the `alloc` feature is enabled, this is implemented for every `IntoIterator`
/// whose items are `&mut F` with `F: Field + ConstantTimeEq`.
#[cfg(feature = "alloc")]
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub trait BatchInvert<F>
where
    F: Field,
{
    /// Consumes `self`, replacing each nonzero element with its inverse and leaving
    /// zero-valued elements as zero.
    ///
    /// Returns the inverse of the product of all nonzero elements.
    fn batch_invert(self) -> F;
}
#[cfg(feature = "alloc")]
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
impl<'a, F, I> BatchInvert<F> for I
where
    F: Field + ConstantTimeEq,
    I: IntoIterator<Item = &'a mut F>,
{
    /// Inverts every nonzero element in place using a single field inversion plus a few
    /// multiplications per element (prefix products, then a backward sweep).
    ///
    /// Zero elements are detected with `ct_eq` and handled via `conditional_select`
    /// rather than a branch: they stay zero and do not contribute to the running product.
    /// Returns the inverse of the product of all nonzero elements (one if there are none).
    fn batch_invert(self) -> F {
        let zero = F::zero();
        let slots = self.into_iter();

        // Forward pass: for each slot, remember the product of every nonzero element
        // that came before it, together with the slot itself.
        let mut prefixes = alloc::vec::Vec::with_capacity(slots.size_hint().0);
        let mut running = F::one();
        for slot in slots {
            let value = *slot;
            prefixes.push((running, slot));
            let is_zero = value.ct_eq(&zero);
            running = F::conditional_select(&(running * value), &running, is_zero);
        }

        // `running` starts at one and only ever absorbs nonzero factors, so it is
        // invertible and this `unwrap` cannot fail.
        let total_inverse = running.invert().unwrap();

        // Backward pass: walk the slots last-to-first, peeling each element off the
        // inverted product to recover that element's own inverse.
        let mut remaining_inverse = total_inverse;
        while let Some((prefix, slot)) = prefixes.pop() {
            let is_zero = slot.ct_eq(&zero);
            let inverse = prefix * remaining_inverse;
            remaining_inverse =
                F::conditional_select(&(remaining_inverse * *slot), &remaining_inverse, is_zero);
            *slot = F::conditional_select(&inverse, slot, is_zero);
        }

        total_inverse
    }
}
/// Batch inverter for field elements that never allocates.
///
/// Each method inverts every nonzero element in place using a single field inversion
/// (prefix products followed by a backward sweep), leaving zero elements as zero. The
/// caller provides the scratch storage for the intermediate products.
pub struct BatchInverter {}

impl BatchInverter {
    /// Inverts each nonzero element of `elements` in place, using `scratch_space` to hold
    /// intermediate products. Zero elements are left as zero.
    ///
    /// On return, `scratch_space` has been overwritten with prefix products of the
    /// original nonzero elements.
    ///
    /// Returns the inverse of the product of all nonzero elements (one if there are none).
    ///
    /// # Panics
    ///
    /// Panics if `elements.len() != scratch_space.len()`.
    pub fn invert_with_external_scratch<F>(elements: &mut [F], scratch_space: &mut [F]) -> F
    where
        F: Field + ConstantTimeEq,
    {
        assert_eq!(elements.len(), scratch_space.len());

        // Forward pass: scratch[i] = product of the nonzero elements before index i.
        let mut acc = F::one();
        for (p, scratch) in elements.iter().zip(scratch_space.iter_mut()) {
            *scratch = acc;
            acc = F::conditional_select(&(acc * *p), &acc, p.ct_eq(&F::zero()));
        }

        // `acc` is one times nonzero factors only, hence invertible; `unwrap` cannot fail.
        acc = acc.invert().unwrap();
        let allinv = acc;

        // Backward pass: `acc` is the inverse of the product of the nonzero elements up to
        // and including the current index, so `scratch * acc` is that element's inverse.
        for (p, scratch) in elements.iter_mut().zip(scratch_space.iter()).rev() {
            let tmp = *scratch * acc;
            let skip = p.ct_eq(&F::zero());
            acc = F::conditional_select(&(acc * *p), &acc, skip);
            *p = F::conditional_select(&tmp, p, skip);
        }

        allinv
    }

    /// Inverts, in place, the field element that `element` selects from each of `items`.
    /// The per-item slot that `scratch_space` selects holds the intermediate products.
    /// Zero elements are left as zero.
    ///
    /// This lets callers embed the scratch slot inside their own item type instead of
    /// providing a separate buffer. `element` and `scratch_space` must select *different*
    /// locations within an item. If they alias, the scratch write clobbers the element
    /// before it is read, and the results are wrong.
    ///
    /// Returns the inverse of the product of all nonzero elements (one if there are none).
    pub fn invert_with_internal_scratch<F, T, TE, TS>(
        items: &mut [T],
        element: TE,
        scratch_space: TS,
    ) -> F
    where
        F: Field + ConstantTimeEq,
        TE: Fn(&mut T) -> &mut F,
        TS: Fn(&mut T) -> &mut F,
    {
        // Forward pass: each item's scratch slot receives the product of the nonzero
        // elements of all earlier items.
        let mut acc = F::one();
        for item in items.iter_mut() {
            *(scratch_space)(item) = acc;
            let p = (element)(item);
            acc = F::conditional_select(&(acc * *p), &acc, p.ct_eq(&F::zero()));
        }

        // `acc` is one times nonzero factors only, hence invertible; `unwrap` cannot fail.
        acc = acc.invert().unwrap();
        let allinv = acc;

        // Backward pass: recover each element's inverse from its stored prefix product.
        for item in items.iter_mut().rev() {
            let tmp = *(scratch_space)(item) * acc;
            let p = (element)(item);
            let skip = p.ct_eq(&F::zero());
            acc = F::conditional_select(&(acc * *p), &acc, skip);
            *p = F::conditional_select(&tmp, p, skip);
        }

        allinv
    }
}