1use std::ops::{Add, Sub};
2
3use ff::PrimeField;
4use serde::{Deserialize, Serialize};
5
6#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
8pub struct Variable(pub Index);
9
10impl Variable {
11 pub fn new_unchecked(idx: Index) -> Variable {
14 Variable(idx)
15 }
16
17 pub fn get_unchecked(&self) -> Index {
20 self.0
21 }
22}
23
/// Location of a variable: which assignment vector it belongs to, and its
/// position within that vector.
///
/// NOTE(review): variant order is deliberately left untouched. With serde,
/// some formats may encode the variant by its position.
#[derive(Copy, Clone, PartialEq, Debug, Eq, Hash, Serialize, Deserialize)]
pub enum Index {
    /// Position in the input assignment (`input_assignment` in `LinearCombination::eval`).
    Input(usize),
    /// Position in the auxiliary assignment (`aux_assignment` in `LinearCombination::eval`).
    Aux(usize),
}
31
/// A sparse linear combination `sum(coeff_i * var_i)` of [`Variable`]s.
///
/// Terms over input variables and terms over aux variables are stored
/// separately, each keyed by the variable's index. Terms are never simplified
/// away. Adding a term for a variable that is already present accumulates into
/// its coefficient, and a term whose coefficient becomes zero is kept.
#[derive(Clone, Debug, PartialEq)]
pub struct LinearCombination<Scalar: PrimeField> {
    /// Terms over `Index::Input` variables, keyed by input index.
    inputs: Indexer<Scalar>,
    /// Terms over `Index::Aux` variables, keyed by aux index.
    aux: Indexer<Scalar>,
}
39
/// Sparse map from `usize` keys to values of type `T`.
///
/// Entries are stored in a `Vec` of `(key, value)` pairs in strictly ascending
/// key order. Lookups use binary search. A fast path skips the search when an
/// operation targets the most recently inserted key, or the key immediately
/// after it.
#[derive(Clone, Debug)]
struct Indexer<T> {
    /// `(key, value)` pairs, sorted by key; each key appears at most once.
    values: Vec<(usize, T)>,
    /// `(position in values, key)` of the most recent insertion. This is only
    /// a lookup cache for the fast path in `insert_or_update`.
    last_inserted: Option<(usize, usize)>,
}

// Equality is defined on the stored entries only. `last_inserted` depends on the
// order in which entries were inserted. A derived `PartialEq` would therefore make
// two indexers (and two `LinearCombination`s) holding identical terms compare
// unequal, e.g. `x0 + x1` vs `x1 + x0`.
impl<T: PartialEq> PartialEq for Indexer<T> {
    fn eq(&self, other: &Self) -> bool {
        self.values == other.values
    }
}

impl<T> Default for Indexer<T> {
    /// Creates an empty indexer with no cached insertion position.
    fn default() -> Self {
        Indexer {
            values: Vec::new(),
            last_inserted: None,
        }
    }
}

impl<T> Indexer<T> {
    /// Creates an indexer holding the single entry `(index, value)`.
    pub fn from_value(index: usize, value: T) -> Self {
        Indexer {
            values: vec![(index, value)],
            last_inserted: Some((0, index)),
        }
    }

    /// Iterates over `(key, value)` pairs in ascending key order.
    pub fn iter(&self) -> impl Iterator<Item = (&usize, &T)> + '_ {
        self.values.iter().map(|(key, value)| (key, value))
    }

    /// Iterates over `(key, value)` pairs in ascending key order with mutable
    /// access to the values. Keys stay read-only, so neither the sort order
    /// nor the `last_inserted` cache can be invalidated through this iterator.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = (&usize, &mut T)> + '_ {
        self.values.iter_mut().map(|(key, value)| (&*key, value))
    }

    /// Applies `update` to the value stored under `key`. If `key` is absent,
    /// inserts the value produced by `insert` instead. Exactly one of the two
    /// closures is called.
    ///
    /// The binary search is skipped when `key` equals the last inserted key or
    /// the key right after it. Inserting into the middle of the entries still
    /// shifts all later entries.
    pub fn insert_or_update<F, G>(&mut self, key: usize, insert: F, update: G)
    where
        F: FnOnce() -> T,
        G: FnOnce(&mut T),
    {
        if let Some((last_index, last_key)) = self.last_inserted {
            if last_key == key {
                // Same key as the last insertion: its position is cached.
                update(&mut self.values[last_index].1);
                return;
            }
            // `checked_add` avoids an overflow panic (in debug builds) when the last
            // inserted key is `usize::MAX`; no key can follow it in that case.
            if last_key.checked_add(1) == Some(key) {
                // Keys are strictly ascending, so if `key` is present it sits right
                // after `last_key`; otherwise that slot is its insertion point.
                let i = last_index + 1;
                if i >= self.values.len() {
                    self.values.push((key, insert()));
                    self.last_inserted = Some((i, key));
                } else if self.values[i].0 == key {
                    // Updating in place moves nothing, so the cache stays valid.
                    update(&mut self.values[i].1);
                } else {
                    self.values.insert(i, (key, insert()));
                    self.last_inserted = Some((i, key));
                }
                return;
            }
        }
        // General case: locate `key` by binary search over the sorted entries.
        match self.values.binary_search_by_key(&key, |(k, _)| *k) {
            Ok(i) => update(&mut self.values[i].1),
            Err(i) => {
                self.values.insert(i, (key, insert()));
                self.last_inserted = Some((i, key));
            }
        }
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// Returns `true` if there are no stored entries.
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }
}
122
123impl<Scalar: PrimeField> Default for LinearCombination<Scalar> {
124 fn default() -> Self {
125 Self::zero()
126 }
127}
128
129impl<Scalar: PrimeField> LinearCombination<Scalar> {
130 pub fn zero() -> LinearCombination<Scalar> {
131 LinearCombination {
132 inputs: Default::default(),
133 aux: Default::default(),
134 }
135 }
136
137 pub fn from_coeff(var: Variable, coeff: Scalar) -> Self {
138 match var {
139 Variable(Index::Input(i)) => Self {
140 inputs: Indexer::from_value(i, coeff),
141 aux: Default::default(),
142 },
143 Variable(Index::Aux(i)) => Self {
144 inputs: Default::default(),
145 aux: Indexer::from_value(i, coeff),
146 },
147 }
148 }
149
150 pub fn from_variable(var: Variable) -> Self {
151 Self::from_coeff(var, Scalar::ONE)
152 }
153
154 pub fn iter(&self) -> impl Iterator<Item = (Variable, &Scalar)> + '_ {
155 self.inputs
156 .iter()
157 .map(|(k, v)| (Variable(Index::Input(*k)), v))
158 .chain(self.aux.iter().map(|(k, v)| (Variable(Index::Aux(*k)), v)))
159 }
160
161 #[inline]
162 pub fn iter_inputs(&self) -> impl Iterator<Item = (&usize, &Scalar)> + '_ {
163 self.inputs.iter()
164 }
165
166 #[inline]
167 pub fn iter_aux(&self) -> impl Iterator<Item = (&usize, &Scalar)> + '_ {
168 self.aux.iter()
169 }
170
171 pub fn iter_mut(&mut self) -> impl Iterator<Item = (Variable, &mut Scalar)> + '_ {
172 self.inputs
173 .iter_mut()
174 .map(|(k, v)| (Variable(Index::Input(*k)), v))
175 .chain(
176 self.aux
177 .iter_mut()
178 .map(|(k, v)| (Variable(Index::Aux(*k)), v)),
179 )
180 }
181
182 #[inline]
183 fn add_assign_unsimplified_input(&mut self, new_var: usize, coeff: Scalar) {
184 self.inputs
185 .insert_or_update(new_var, || coeff, |val| *val += coeff);
186 }
187
188 #[inline]
189 fn add_assign_unsimplified_aux(&mut self, new_var: usize, coeff: Scalar) {
190 self.aux
191 .insert_or_update(new_var, || coeff, |val| *val += coeff);
192 }
193
194 pub fn add_unsimplified(
195 mut self,
196 (coeff, var): (Scalar, Variable),
197 ) -> LinearCombination<Scalar> {
198 match var.0 {
199 Index::Input(new_var) => {
200 self.add_assign_unsimplified_input(new_var, coeff);
201 }
202 Index::Aux(new_var) => {
203 self.add_assign_unsimplified_aux(new_var, coeff);
204 }
205 }
206
207 self
208 }
209
210 #[inline]
211 fn sub_assign_unsimplified_input(&mut self, new_var: usize, coeff: Scalar) {
212 self.add_assign_unsimplified_input(new_var, -coeff);
213 }
214
215 #[inline]
216 fn sub_assign_unsimplified_aux(&mut self, new_var: usize, coeff: Scalar) {
217 self.add_assign_unsimplified_aux(new_var, -coeff);
218 }
219
220 pub fn sub_unsimplified(
221 mut self,
222 (coeff, var): (Scalar, Variable),
223 ) -> LinearCombination<Scalar> {
224 match var.0 {
225 Index::Input(new_var) => {
226 self.sub_assign_unsimplified_input(new_var, coeff);
227 }
228 Index::Aux(new_var) => {
229 self.sub_assign_unsimplified_aux(new_var, coeff);
230 }
231 }
232
233 self
234 }
235
236 pub fn len(&self) -> usize {
237 self.inputs.len() + self.aux.len()
238 }
239
240 pub fn is_empty(&self) -> bool {
241 self.inputs.is_empty() && self.aux.is_empty()
242 }
243
244 pub fn eval(&self, input_assignment: &[Scalar], aux_assignment: &[Scalar]) -> Scalar {
245 let mut acc = Scalar::ZERO;
246
247 let one = Scalar::ONE;
248
249 for (index, coeff) in self.iter_inputs() {
250 let mut tmp = input_assignment[*index];
251 if coeff != &one {
252 tmp *= coeff;
253 }
254 acc += tmp;
255 }
256
257 for (index, coeff) in self.iter_aux() {
258 let mut tmp = aux_assignment[*index];
259 if coeff != &one {
260 tmp *= coeff;
261 }
262 acc += tmp;
263 }
264
265 acc
266 }
267}
268
269impl<Scalar: PrimeField> Add<(Scalar, Variable)> for LinearCombination<Scalar> {
270 type Output = LinearCombination<Scalar>;
271
272 fn add(self, (coeff, var): (Scalar, Variable)) -> LinearCombination<Scalar> {
273 self.add_unsimplified((coeff, var))
274 }
275}
276
277impl<Scalar: PrimeField> Sub<(Scalar, Variable)> for LinearCombination<Scalar> {
278 type Output = LinearCombination<Scalar>;
279
280 #[allow(clippy::suspicious_arithmetic_impl)]
281 fn sub(self, (coeff, var): (Scalar, Variable)) -> LinearCombination<Scalar> {
282 self.sub_unsimplified((coeff, var))
283 }
284}
285
286impl<Scalar: PrimeField> Add<Variable> for LinearCombination<Scalar> {
287 type Output = LinearCombination<Scalar>;
288
289 fn add(self, other: Variable) -> LinearCombination<Scalar> {
290 self + (Scalar::ONE, other)
291 }
292}
293
294impl<Scalar: PrimeField> Sub<Variable> for LinearCombination<Scalar> {
295 type Output = LinearCombination<Scalar>;
296
297 fn sub(self, other: Variable) -> LinearCombination<Scalar> {
298 self - (Scalar::ONE, other)
299 }
300}
301
302impl<'a, Scalar: PrimeField> Add<&'a LinearCombination<Scalar>> for LinearCombination<Scalar> {
303 type Output = LinearCombination<Scalar>;
304
305 fn add(mut self, other: &'a LinearCombination<Scalar>) -> LinearCombination<Scalar> {
306 for (var, val) in other.inputs.iter() {
307 self.add_assign_unsimplified_input(*var, *val);
308 }
309
310 for (var, val) in other.aux.iter() {
311 self.add_assign_unsimplified_aux(*var, *val);
312 }
313
314 self
315 }
316}
317
318impl<'a, Scalar: PrimeField> Sub<&'a LinearCombination<Scalar>> for LinearCombination<Scalar> {
319 type Output = LinearCombination<Scalar>;
320
321 fn sub(mut self, other: &'a LinearCombination<Scalar>) -> LinearCombination<Scalar> {
322 for (var, val) in other.inputs.iter() {
323 self.sub_assign_unsimplified_input(*var, *val);
324 }
325
326 for (var, val) in other.aux.iter() {
327 self.sub_assign_unsimplified_aux(*var, *val);
328 }
329
330 self
331 }
332}
333
334impl<'a, Scalar: PrimeField> Add<(Scalar, &'a LinearCombination<Scalar>)>
335 for LinearCombination<Scalar>
336{
337 type Output = LinearCombination<Scalar>;
338
339 fn add(
340 mut self,
341 (coeff, other): (Scalar, &'a LinearCombination<Scalar>),
342 ) -> LinearCombination<Scalar> {
343 for (var, val) in other.inputs.iter() {
344 self.add_assign_unsimplified_input(*var, *val * coeff);
345 }
346
347 for (var, val) in other.aux.iter() {
348 self.add_assign_unsimplified_aux(*var, *val * coeff);
349 }
350
351 self
352 }
353}
354
355impl<'a, Scalar: PrimeField> Sub<(Scalar, &'a LinearCombination<Scalar>)>
356 for LinearCombination<Scalar>
357{
358 type Output = LinearCombination<Scalar>;
359
360 fn sub(
361 mut self,
362 (coeff, other): (Scalar, &'a LinearCombination<Scalar>),
363 ) -> LinearCombination<Scalar> {
364 for (var, val) in other.inputs.iter() {
365 self.sub_assign_unsimplified_input(*var, *val * coeff);
366 }
367
368 for (var, val) in other.aux.iter() {
369 self.sub_assign_unsimplified_aux(*var, *val * coeff);
370 }
371
372 self
373 }
374}
375
#[cfg(all(test, feature = "groth16"))]
mod tests {
    use super::*;
    use blstrs::Scalar;
    use ff::Field;

    /// Repeatedly adding the same aux variable must accumulate into a single
    /// term rather than create duplicate entries.
    #[test]
    fn test_add_simplify() {
        let n = 5;

        let mut lc = LinearCombination::<Scalar>::zero();

        let mut expected_sums = vec![Scalar::ZERO; n];
        let mut total_additions = 0;
        // Add aux variable `i` exactly `i + 1` times, each time with coefficient one.
        for (i, expected_sum) in expected_sums.iter_mut().enumerate() {
            for _ in 0..i + 1 {
                let coeff = Scalar::ONE;
                lc = lc + (coeff, Variable::new_unchecked(Index::Aux(i)));
                *expected_sum += coeff;
                total_additions += 1;
            }
        }

        // One stored term per distinct variable, not one per addition.
        assert_eq!(n, lc.len());
        assert!(lc.len() != total_additions);

        // Each variable's coefficient is the sum of everything added for it.
        lc.iter().for_each(|(var, coeff)| match var.0 {
            Index::Aux(i) => assert_eq!(expected_sums[i], *coeff),
            _ => panic!("unexpected variable type"),
        });
    }

    /// Walks `Indexer::insert_or_update` through its code paths and checks both
    /// the stored entries and the cached `last_inserted` after each step.
    #[test]
    fn test_insert_or_update() {
        let mut indexer = Indexer::default();
        let one = Scalar::ONE;
        let mut two = one;
        two += one;

        // Empty indexer: binary-search path inserts at position 0.
        indexer.insert_or_update(2, || one, |v| *v += one);
        assert_eq!(&indexer.values, &[(2, one)]);
        assert_eq!(&indexer.last_inserted, &Some((0, 2)));

        // 3 follows the last inserted key 2: fast path appends at the end.
        indexer.insert_or_update(3, || one, |v| *v += one);
        assert_eq!(&indexer.values, &[(2, one), (3, one)]);
        assert_eq!(&indexer.last_inserted, &Some((1, 3)));

        // 1 is neither 3 nor 4: binary-search path inserts at the front.
        indexer.insert_or_update(1, || one, |v| *v += one);
        assert_eq!(&indexer.values, &[(1, one), (2, one), (3, one)]);
        assert_eq!(&indexer.last_inserted, &Some((0, 1)));

        // 2 follows the last inserted key 1: fast path updates the existing entry.
        // Updates leave `last_inserted` unchanged.
        indexer.insert_or_update(2, || one, |v| *v += one);
        assert_eq!(&indexer.values, &[(1, one), (2, two), (3, one)]);
        assert_eq!(&indexer.last_inserted, &Some((0, 1)));
    }
}