poly_commit/
util.rs

1use zkstd::common::{PrimeField, Vec};
2
3/// power operation for prime field
4pub fn powers_of<F: PrimeField>(scalar: &F, max_degree: usize) -> Vec<F> {
5    let mut powers = Vec::with_capacity(max_degree + 1);
6    powers.push(F::one());
7    for i in 1..=max_degree {
8        powers.push(powers[i - 1] * scalar);
9    }
10    powers
11}
12
13/// batch inversion operation for prime field vectors
14pub fn batch_inversion<F: PrimeField>(v: &mut [F]) {
15    // Montgomery’s Trick and Fast Implementation of Masked AES
16    // Genelle, Prouff and Quisquater
17    // Section 3.2
18
19    // First pass: compute [a, ab, abc, ...]
20    let mut prod = Vec::with_capacity(v.len());
21    let mut tmp = F::one();
22    for f in v.iter().filter(|f| f != &&F::zero()) {
23        tmp.mul_assign(f);
24        prod.push(tmp);
25    }
26
27    // Invert `tmp`.
28    tmp = tmp.invert().unwrap(); // Guaranteed to be nonzero.
29
30    // Second pass: iterate backwards to compute inverses
31    for (f, s) in v
32        .iter_mut()
33        // Backwards
34        .rev()
35        // Ignore normalized elements
36        .filter(|f| f != &&F::zero())
37        // Backwards, skip last element, fill in one for last term.
38        .zip(prod.into_iter().rev().skip(1).chain(Some(F::one())))
39    {
40        // tmp := tmp * f; f := tmp * s = 1/f
41        let new_tmp = tmp * *f;
42        *f = tmp * s;
43        tmp = new_tmp;
44    }
45}
46
#[cfg(test)]
mod test {
    use bls_12_381::Fr as BlsScalar;
    use zkstd::common::Group;

    use super::*;

    /// Every element of an all-nonzero slice is replaced by its inverse.
    #[test]
    fn test_batch_inversion() {
        let one = BlsScalar::from(1);
        let two = BlsScalar::from(2);
        let three = BlsScalar::from(3);
        let four = BlsScalar::from(4);
        let five = BlsScalar::from(5);

        let original_scalars = vec![one, two, three, four, five];
        let mut inverted_scalars = vec![one, two, three, four, five];

        batch_inversion(&mut inverted_scalars);
        for (x, x_inv) in original_scalars.iter().zip(inverted_scalars.iter()) {
            assert_eq!(x.invert().unwrap(), *x_inv);
        }
    }

    /// Zeros are left untouched while their nonzero neighbours are still inverted.
    #[test]
    fn test_batch_inversion_skips_zeros() {
        let zero = BlsScalar::from(0);
        let two = BlsScalar::from(2);
        let five = BlsScalar::from(5);

        let mut scalars = vec![zero, two, zero, five, zero];
        batch_inversion(&mut scalars);

        assert_eq!(scalars[0], zero);
        assert_eq!(scalars[1], two.invert().unwrap());
        assert_eq!(scalars[2], zero);
        assert_eq!(scalars[3], five.invert().unwrap());
        assert_eq!(scalars[4], zero);
    }

    /// An empty slice is a no-op and must not panic.
    #[test]
    fn test_batch_inversion_empty() {
        let mut scalars: [BlsScalar; 0] = [];
        batch_inversion(&mut scalars);
        assert!(scalars.is_empty());
    }

    /// `powers_of` returns `max_degree + 1` entries, where entry `i` is `scalar^i`.
    #[test]
    fn test_powers_of() {
        let three = BlsScalar::from(3);
        let powers = powers_of(&three, 4);

        assert_eq!(powers.len(), 5);
        assert_eq!(powers[0], BlsScalar::from(1));
        assert_eq!(powers[1], BlsScalar::from(3));
        assert_eq!(powers[2], BlsScalar::from(9));
        assert_eq!(powers[3], BlsScalar::from(27));
        assert_eq!(powers[4], BlsScalar::from(81));

        assert_eq!(powers_of(&three, 0), vec![BlsScalar::from(1)]);
    }
}
69}