bellpepper_core/
lc.rs

1use std::ops::{Add, Sub};
2
3use ff::PrimeField;
4use serde::{Deserialize, Serialize};
5
6/// Represents a variable in our constraint system.
7#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
8pub struct Variable(pub Index);
9
10impl Variable {
11    /// This constructs a variable with an arbitrary index.
12    /// Circuit implementations are not recommended to use this.
13    pub fn new_unchecked(idx: Index) -> Variable {
14        Variable(idx)
15    }
16
17    /// This returns the index underlying the variable.
18    /// Circuit implementations are not recommended to use this.
19    pub fn get_unchecked(&self) -> Index {
20        self.0
21    }
22}
23
24/// Represents the index of either an input variable or
25/// auxiliary variable.
26#[derive(Copy, Clone, PartialEq, Debug, Eq, Hash, Serialize, Deserialize)]
27pub enum Index {
28    Input(usize),
29    Aux(usize),
30}
31
32/// This represents a linear combination of some variables, with coefficients
33/// in the scalar field of a pairing-friendly elliptic curve group.
34#[derive(Clone, Debug, PartialEq)]
35pub struct LinearCombination<Scalar: PrimeField> {
36    inputs: Indexer<Scalar>,
37    aux: Indexer<Scalar>,
38}
39
/// Sparse map from `usize` keys to values of type `T`, stored as `(key, value)`
/// pairs that `insert_or_update` keeps sorted by key.
#[derive(Clone, Debug)]
struct Indexer<T> {
    /// Stores a list of `T` indexed by the number in the first slot of the tuple,
    /// sorted by that number so lookups can binary-search.
    values: Vec<(usize, T)>,
    /// `(index, key)` of the last insertion operation. Used to optimize consecutive operations.
    /// This is purely a lookup cache: it depends on insertion history, not on the contents.
    last_inserted: Option<(usize, usize)>,
}

/// Two indexers are equal when they hold the same `(key, value)` pairs.
///
/// `last_inserted` is deliberately ignored. It records where the most recent
/// insertion happened, so a derived comparison would report indexers (and the
/// `LinearCombination`s built from them) holding identical terms as different
/// merely because their terms were added in a different order.
impl<T: PartialEq> PartialEq for Indexer<T> {
    fn eq(&self, other: &Self) -> bool {
        self.values == other.values
    }
}
47
48impl<T> Default for Indexer<T> {
49    fn default() -> Self {
50        Indexer {
51            values: Vec::new(),
52            last_inserted: None,
53        }
54    }
55}
56
57impl<T> Indexer<T> {
58    pub fn from_value(index: usize, value: T) -> Self {
59        Indexer {
60            values: vec![(index, value)],
61            last_inserted: Some((0, index)),
62        }
63    }
64
65    pub fn iter(&self) -> impl Iterator<Item = (&usize, &T)> + '_ {
66        self.values.iter().map(|(key, value)| (key, value))
67    }
68
69    pub fn iter_mut(&mut self) -> impl Iterator<Item = (&usize, &mut T)> + '_ {
70        self.values.iter_mut().map(|(key, value)| (&*key, value))
71    }
72
73    pub fn insert_or_update<F, G>(&mut self, key: usize, insert: F, update: G)
74    where
75        F: FnOnce() -> T,
76        G: FnOnce(&mut T),
77    {
78        if let Some((last_index, last_key)) = self.last_inserted {
79            // Optimization to avoid doing binary search on inserts & updates that are linear, meaning
80            // they are adding a consecutive values.
81            if last_key == key {
82                // update the same key again
83                update(&mut self.values[last_index].1);
84                return;
85            } else if last_key + 1 == key {
86                // optimization for follow on updates
87                let i = last_index + 1;
88                if i >= self.values.len() {
89                    // insert at the end
90                    self.values.push((key, insert()));
91                    self.last_inserted = Some((i, key));
92                } else if self.values[i].0 == key {
93                    // update
94                    update(&mut self.values[i].1);
95                } else {
96                    // insert
97                    self.values.insert(i, (key, insert()));
98                    self.last_inserted = Some((i, key));
99                }
100                return;
101            }
102        }
103        match self.values.binary_search_by_key(&key, |(k, _)| *k) {
104            Ok(i) => {
105                update(&mut self.values[i].1);
106            }
107            Err(i) => {
108                self.values.insert(i, (key, insert()));
109                self.last_inserted = Some((i, key));
110            }
111        }
112    }
113
114    pub fn len(&self) -> usize {
115        self.values.len()
116    }
117
118    pub fn is_empty(&self) -> bool {
119        self.values.is_empty()
120    }
121}
122
123impl<Scalar: PrimeField> Default for LinearCombination<Scalar> {
124    fn default() -> Self {
125        Self::zero()
126    }
127}
128
129impl<Scalar: PrimeField> LinearCombination<Scalar> {
130    pub fn zero() -> LinearCombination<Scalar> {
131        LinearCombination {
132            inputs: Default::default(),
133            aux: Default::default(),
134        }
135    }
136
137    pub fn from_coeff(var: Variable, coeff: Scalar) -> Self {
138        match var {
139            Variable(Index::Input(i)) => Self {
140                inputs: Indexer::from_value(i, coeff),
141                aux: Default::default(),
142            },
143            Variable(Index::Aux(i)) => Self {
144                inputs: Default::default(),
145                aux: Indexer::from_value(i, coeff),
146            },
147        }
148    }
149
150    pub fn from_variable(var: Variable) -> Self {
151        Self::from_coeff(var, Scalar::ONE)
152    }
153
154    pub fn iter(&self) -> impl Iterator<Item = (Variable, &Scalar)> + '_ {
155        self.inputs
156            .iter()
157            .map(|(k, v)| (Variable(Index::Input(*k)), v))
158            .chain(self.aux.iter().map(|(k, v)| (Variable(Index::Aux(*k)), v)))
159    }
160
161    #[inline]
162    pub fn iter_inputs(&self) -> impl Iterator<Item = (&usize, &Scalar)> + '_ {
163        self.inputs.iter()
164    }
165
166    #[inline]
167    pub fn iter_aux(&self) -> impl Iterator<Item = (&usize, &Scalar)> + '_ {
168        self.aux.iter()
169    }
170
171    pub fn iter_mut(&mut self) -> impl Iterator<Item = (Variable, &mut Scalar)> + '_ {
172        self.inputs
173            .iter_mut()
174            .map(|(k, v)| (Variable(Index::Input(*k)), v))
175            .chain(
176                self.aux
177                    .iter_mut()
178                    .map(|(k, v)| (Variable(Index::Aux(*k)), v)),
179            )
180    }
181
182    #[inline]
183    fn add_assign_unsimplified_input(&mut self, new_var: usize, coeff: Scalar) {
184        self.inputs
185            .insert_or_update(new_var, || coeff, |val| *val += coeff);
186    }
187
188    #[inline]
189    fn add_assign_unsimplified_aux(&mut self, new_var: usize, coeff: Scalar) {
190        self.aux
191            .insert_or_update(new_var, || coeff, |val| *val += coeff);
192    }
193
194    pub fn add_unsimplified(
195        mut self,
196        (coeff, var): (Scalar, Variable),
197    ) -> LinearCombination<Scalar> {
198        match var.0 {
199            Index::Input(new_var) => {
200                self.add_assign_unsimplified_input(new_var, coeff);
201            }
202            Index::Aux(new_var) => {
203                self.add_assign_unsimplified_aux(new_var, coeff);
204            }
205        }
206
207        self
208    }
209
210    #[inline]
211    fn sub_assign_unsimplified_input(&mut self, new_var: usize, coeff: Scalar) {
212        self.add_assign_unsimplified_input(new_var, -coeff);
213    }
214
215    #[inline]
216    fn sub_assign_unsimplified_aux(&mut self, new_var: usize, coeff: Scalar) {
217        self.add_assign_unsimplified_aux(new_var, -coeff);
218    }
219
220    pub fn sub_unsimplified(
221        mut self,
222        (coeff, var): (Scalar, Variable),
223    ) -> LinearCombination<Scalar> {
224        match var.0 {
225            Index::Input(new_var) => {
226                self.sub_assign_unsimplified_input(new_var, coeff);
227            }
228            Index::Aux(new_var) => {
229                self.sub_assign_unsimplified_aux(new_var, coeff);
230            }
231        }
232
233        self
234    }
235
236    pub fn len(&self) -> usize {
237        self.inputs.len() + self.aux.len()
238    }
239
240    pub fn is_empty(&self) -> bool {
241        self.inputs.is_empty() && self.aux.is_empty()
242    }
243
244    pub fn eval(&self, input_assignment: &[Scalar], aux_assignment: &[Scalar]) -> Scalar {
245        let mut acc = Scalar::ZERO;
246
247        let one = Scalar::ONE;
248
249        for (index, coeff) in self.iter_inputs() {
250            let mut tmp = input_assignment[*index];
251            if coeff != &one {
252                tmp *= coeff;
253            }
254            acc += tmp;
255        }
256
257        for (index, coeff) in self.iter_aux() {
258            let mut tmp = aux_assignment[*index];
259            if coeff != &one {
260                tmp *= coeff;
261            }
262            acc += tmp;
263        }
264
265        acc
266    }
267}
268
269impl<Scalar: PrimeField> Add<(Scalar, Variable)> for LinearCombination<Scalar> {
270    type Output = LinearCombination<Scalar>;
271
272    fn add(self, (coeff, var): (Scalar, Variable)) -> LinearCombination<Scalar> {
273        self.add_unsimplified((coeff, var))
274    }
275}
276
277impl<Scalar: PrimeField> Sub<(Scalar, Variable)> for LinearCombination<Scalar> {
278    type Output = LinearCombination<Scalar>;
279
280    #[allow(clippy::suspicious_arithmetic_impl)]
281    fn sub(self, (coeff, var): (Scalar, Variable)) -> LinearCombination<Scalar> {
282        self.sub_unsimplified((coeff, var))
283    }
284}
285
286impl<Scalar: PrimeField> Add<Variable> for LinearCombination<Scalar> {
287    type Output = LinearCombination<Scalar>;
288
289    fn add(self, other: Variable) -> LinearCombination<Scalar> {
290        self + (Scalar::ONE, other)
291    }
292}
293
294impl<Scalar: PrimeField> Sub<Variable> for LinearCombination<Scalar> {
295    type Output = LinearCombination<Scalar>;
296
297    fn sub(self, other: Variable) -> LinearCombination<Scalar> {
298        self - (Scalar::ONE, other)
299    }
300}
301
302impl<'a, Scalar: PrimeField> Add<&'a LinearCombination<Scalar>> for LinearCombination<Scalar> {
303    type Output = LinearCombination<Scalar>;
304
305    fn add(mut self, other: &'a LinearCombination<Scalar>) -> LinearCombination<Scalar> {
306        for (var, val) in other.inputs.iter() {
307            self.add_assign_unsimplified_input(*var, *val);
308        }
309
310        for (var, val) in other.aux.iter() {
311            self.add_assign_unsimplified_aux(*var, *val);
312        }
313
314        self
315    }
316}
317
318impl<'a, Scalar: PrimeField> Sub<&'a LinearCombination<Scalar>> for LinearCombination<Scalar> {
319    type Output = LinearCombination<Scalar>;
320
321    fn sub(mut self, other: &'a LinearCombination<Scalar>) -> LinearCombination<Scalar> {
322        for (var, val) in other.inputs.iter() {
323            self.sub_assign_unsimplified_input(*var, *val);
324        }
325
326        for (var, val) in other.aux.iter() {
327            self.sub_assign_unsimplified_aux(*var, *val);
328        }
329
330        self
331    }
332}
333
334impl<'a, Scalar: PrimeField> Add<(Scalar, &'a LinearCombination<Scalar>)>
335    for LinearCombination<Scalar>
336{
337    type Output = LinearCombination<Scalar>;
338
339    fn add(
340        mut self,
341        (coeff, other): (Scalar, &'a LinearCombination<Scalar>),
342    ) -> LinearCombination<Scalar> {
343        for (var, val) in other.inputs.iter() {
344            self.add_assign_unsimplified_input(*var, *val * coeff);
345        }
346
347        for (var, val) in other.aux.iter() {
348            self.add_assign_unsimplified_aux(*var, *val * coeff);
349        }
350
351        self
352    }
353}
354
355impl<'a, Scalar: PrimeField> Sub<(Scalar, &'a LinearCombination<Scalar>)>
356    for LinearCombination<Scalar>
357{
358    type Output = LinearCombination<Scalar>;
359
360    fn sub(
361        mut self,
362        (coeff, other): (Scalar, &'a LinearCombination<Scalar>),
363    ) -> LinearCombination<Scalar> {
364        for (var, val) in other.inputs.iter() {
365            self.sub_assign_unsimplified_input(*var, *val * coeff);
366        }
367
368        for (var, val) in other.aux.iter() {
369            self.sub_assign_unsimplified_aux(*var, *val * coeff);
370        }
371
372        self
373    }
374}
375
#[cfg(all(test, feature = "groth16"))]
mod tests {
    use super::*;
    use blstrs::Scalar;
    use ff::Field;

    #[test]
    fn test_add_simplify() {
        let n = 5;
        let mut lc = LinearCombination::<Scalar>::zero();
        let mut expected_sums = vec![Scalar::ZERO; n];
        let mut total_additions = 0;

        // Aux variable `i` is added `i + 1` times, each time with coefficient one.
        for (i, expected_sum) in expected_sums.iter_mut().enumerate() {
            let var = Variable::new_unchecked(Index::Aux(i));
            for _ in 0..=i {
                lc = lc + (Scalar::ONE, var);
                *expected_sum += Scalar::ONE;
            }
            total_additions += i + 1;
        }

        // Repeated additions of one variable collapse into a single term, so there is
        // one term per distinct index rather than one per addition.
        assert_eq!(n, lc.len());
        assert_ne!(lc.len(), total_additions);

        // Each term's coefficient is the sum of everything added for its index.
        for (var, coeff) in lc.iter() {
            match var.get_unchecked() {
                Index::Aux(i) => assert_eq!(expected_sums[i], *coeff),
                Index::Input(_) => panic!("unexpected variable type"),
            }
        }
    }

    #[test]
    fn test_insert_or_update() {
        let one = Scalar::ONE;
        let two = one + one;

        // (key passed to insert_or_update, expected contents, expected insertion cache)
        let steps: [(usize, &[(usize, Scalar)], (usize, usize)); 4] = [
            (2, &[(2, one)], (0, 2)),
            (3, &[(2, one), (3, one)], (1, 3)),
            (1, &[(1, one), (2, one), (3, one)], (0, 1)),
            // Updating an existing key bumps its value but leaves the cache untouched.
            (2, &[(1, one), (2, two), (3, one)], (0, 1)),
        ];

        let mut indexer = Indexer::<Scalar>::default();
        for (key, expected_values, expected_last) in steps {
            indexer.insert_or_update(key, || one, |v| *v += one);
            assert_eq!(indexer.values.as_slice(), expected_values);
            assert_eq!(indexer.last_inserted, Some(expected_last));
        }
    }
}