1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
//! Batched field inversion APIs, using [Montgomery's trick].
//!
//! [Montgomery's trick]: https://zcash.github.io/halo2/background/fields.html#montgomerys-trick

use subtle::ConstantTimeEq;

use crate::Field;

/// Extension trait that lets an iterator over mutable field-element references be
/// inverted as a single batch.
///
/// A blanket implementation covers every `I: IntoIterator<Item = &'a mut F>` with
/// `F: Field + ConstantTimeEq`, provided the `alloc` feature flag is enabled (that
/// implementation buffers the iterator in a `Vec`).
///
/// When allocation is unavailable, use the [`BatchInverter`] struct instead.
#[cfg(feature = "alloc")]
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub trait BatchInvert<F: Field> {
    /// Consumes the iterator, replacing every nonzero element with its multiplicative
    /// inverse. Elements equal to zero are not modified.
    ///
    /// Returns the inverse of the product of all the nonzero elements.
    fn batch_invert(self) -> F;
}

#[cfg(feature = "alloc")]
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
impl<'a, F, I> BatchInvert<F> for I
where
    F: Field + ConstantTimeEq,
    I: IntoIterator<Item = &'a mut F>,
{
    fn batch_invert(self) -> F {
        let elements = self.into_iter();

        // Forward pass: pair every slot with the running product of the nonzero
        // elements that precede it. Zero elements are skipped via `conditional_select`
        // rather than an `if`, so the loop has no branch on element values.
        let mut running = F::one();
        let mut prefixes: alloc::vec::Vec<(F, &'a mut F)> =
            alloc::vec::Vec::with_capacity(elements.size_hint().0);
        for slot in elements {
            let value = *slot;
            prefixes.push((running, slot));
            let is_zero = value.ct_eq(&F::zero());
            running = F::conditional_select(&(running * value), &running, is_zero);
        }

        // `running` starts at one and only ever absorbs nonzero factors, so it is
        // nonzero and the inversion cannot fail.
        let total_inverse = running.invert().unwrap();

        // Backward pass: `inverse` holds the inverse of the product of the nonzero
        // elements up to and including the current slot. Multiplying by the stored
        // prefix isolates the current element's inverse; multiplying by the element
        // then drops it from `inverse` before moving one step left.
        let mut inverse = total_inverse;
        for (prefix, slot) in prefixes.into_iter().rev() {
            let is_zero = slot.ct_eq(&F::zero());
            let candidate = prefix * inverse;
            inverse = F::conditional_select(&(inverse * *slot), &inverse, is_zero);
            *slot = F::conditional_select(&candidate, slot, is_zero);
        }

        total_inverse
    }
}

/// A batch inverter that performs no heap allocation.
///
/// The struct carries no data; it only groups associated functions. The caller provides
/// the temporary storage needed by Montgomery's trick, either as a separate slice or
/// embedded in each item.
pub struct BatchInverter {}

impl BatchInverter {
    /// Inverts every nonzero element of `elements` in place; zero elements stay zero.
    ///
    /// - `scratch_space` supplies one temporary slot per element. Its prior contents are
    ///   never read: each slot is written before it is used.
    ///
    /// Returns the inverse of the product of all nonzero field elements.
    ///
    /// # Panics
    ///
    /// Panics if `elements.len() != scratch_space.len()`.
    pub fn invert_with_external_scratch<F>(elements: &mut [F], scratch_space: &mut [F]) -> F
    where
        F: Field + ConstantTimeEq,
    {
        assert_eq!(elements.len(), scratch_space.len());

        // Forward pass: scratch slot i receives the product of the nonzero elements
        // strictly before index i.
        let mut running = F::one();
        for (value, prefix) in elements.iter().zip(scratch_space.iter_mut()) {
            *prefix = running;
            let is_zero = value.ct_eq(&F::zero());
            running = F::conditional_select(&(running * *value), &running, is_zero);
        }

        // A product of nonzero factors (or the empty product, one) is invertible.
        let total_inverse = running.invert().unwrap();

        // Backward pass: peel one factor at a time off the inverted total product.
        let mut inverse = total_inverse;
        for (value, prefix) in elements.iter_mut().zip(scratch_space.iter()).rev() {
            let is_zero = value.ct_eq(&F::zero());
            let candidate = *prefix * inverse;
            inverse = F::conditional_select(&(inverse * *value), &inverse, is_zero);
            *value = F::conditional_select(&candidate, value, is_zero);
        }

        total_inverse
    }

    /// Inverts the field element held by each of `items` (when nonzero); elements equal
    /// to zero are left as zero.
    ///
    /// - `element` projects an item onto the field element to invert.
    /// - `scratch_space` projects an item onto a field element that may be freely
    ///   overwritten. It must refer to storage distinct from what `element` returns:
    ///   the forward pass writes the scratch slot before reading the element.
    ///
    /// Each projection is called exactly twice per item, once on the forward pass and
    /// once on the backward pass, with `scratch_space` called before `element`.
    ///
    /// Returns the inverse of the product of all nonzero field elements.
    pub fn invert_with_internal_scratch<F, T, TE, TS>(
        items: &mut [T],
        element: TE,
        scratch_space: TS,
    ) -> F
    where
        F: Field + ConstantTimeEq,
        TE: Fn(&mut T) -> &mut F,
        TS: Fn(&mut T) -> &mut F,
    {
        // Forward pass: each item's scratch slot receives the product of the nonzero
        // elements in earlier items.
        let mut running = F::one();
        for item in items.iter_mut() {
            *scratch_space(item) = running;
            let value = *element(item);
            let is_zero = value.ct_eq(&F::zero());
            running = F::conditional_select(&(running * value), &running, is_zero);
        }

        // A product of nonzero factors (or the empty product, one) is invertible.
        let total_inverse = running.invert().unwrap();

        // Backward pass: walk the items right to left, recovering each inverse.
        let mut inverse = total_inverse;
        for item in items.iter_mut().rev() {
            let candidate = *scratch_space(item) * inverse;
            let slot = element(item);
            let is_zero = slot.ct_eq(&F::zero());
            inverse = F::conditional_select(&(inverse * *slot), &inverse, is_zero);
            *slot = F::conditional_select(&candidate, slot, is_zero);
        }

        total_inverse
    }
}