// ark_relations/utils/linear_combination.rs

1#![allow(clippy::suspicious_arithmetic_impl)]
2
3use ark_ff::Field;
4use ark_std::{
5    ops::{Add, AddAssign, Deref, DerefMut, Mul, MulAssign, Neg, Sub},
6    vec,
7    vec::Vec,
8};
9
10use super::variable::Variable;
11
12/// A linear combination of variables according to associated coefficients.
13#[derive(Debug, Clone, PartialEq, Eq, Default, PartialOrd, Ord)]
14#[must_use]
15pub struct LinearCombination<F: Field>(pub Vec<(F, Variable)>);
16
/// Build a `LinearCombination` from a list of terms.
///
/// Three input forms are accepted (a trailing comma is allowed in the last two):
/// - `lc!()` expands to `LinearCombination::new()`;
/// - `lc![(c1, v1), (c2, v2)]` expands to `LinearCombination::from_sum_coeff_vars`;
/// - `lc![v1, v2, v3]` expands to `LinearCombination::sum_vars`.
///
/// The pair form is listed before the variable form so that parenthesised
/// pairs are not swallowed by the more general expression pattern.
#[macro_export]
macro_rules! lc {
    // No terms at all.
    () => {
        $crate::gr1cs::LinearCombination::new()
    };

    // One or more `(coefficient, variable)` pairs.
    ($(($c:expr, $v:expr)),+ $(,)?) => {
        $crate::gr1cs::LinearCombination::from_sum_coeff_vars(&[$(($c, $v)),+])
    };

    // One or more bare variables.
    ($($v:expr),+ $(,)?) => {
        $crate::gr1cs::LinearCombination::sum_vars(&[$($v),+])
    };
}
30
/// Generate a `LinearCombination` representing the difference of two variables.
///
/// `lc_diff!(a, b)` expands to `LinearCombination::diff_vars(a, b)`. A trailing
/// comma is accepted, matching the `lc!` macro.
#[macro_export]
macro_rules! lc_diff {
    // Subtraction of two variables: lc_diff!(a, b) or lc_diff!(a, b,)
    ($a:expr, $b:expr $(,)?) => {
        $crate::gr1cs::LinearCombination::diff_vars($a, $b)
    };
}
39
40impl<F: Field> LinearCombination<F> {
41    /// Create a new empty linear combination.
42    pub fn new() -> Self {
43        Self::default()
44    }
45
46    /// Create a new empty linear combination.
47    pub fn zero() -> Self {
48        Self::new()
49    }
50
51    /// Deduplicate entries in `self` by combining coefficients of identical variables.
52    #[inline]
53    pub fn compactify(&mut self) {
54        // For 0 or 1 element, there is nothing to do.
55        if self.len() <= 1 {
56            return;
57        }
58
59        // Sort by the variable key.
60        self.0.sort_unstable_by_key(|e| e.1);
61
62        // Use write_index to indicate where to write the next unique element.
63        let mut write_index = 0;
64
65        // Iterate through the vector starting at the second element.
66        for read_index in 1..self.0.len() {
67            // Compare the current (unique) element with the next one.
68            if self.0[write_index].1 == self.0[read_index].1 {
69                // They have the same key: accumulate the coefficient.
70                let add_coeff = self.0[read_index].0; // Copy out the value to avoid borrowing issues.
71                self.0[write_index].0 += add_coeff;
72            } else {
73                // When encountering a new key, move the write pointer forward
74                // and copy the new pair.
75                write_index += 1;
76                self.0[write_index] = self.0[read_index];
77            }
78        }
79
80        // Drop any extra entries that were overwritten.
81        self.0.truncate(write_index + 1);
82    }
83
84    /// Create a new linear combination from the sum of many variables.
85    #[inline]
86    pub fn sum_vars(variables: &[Variable]) -> Self {
87        let lc = variables
88            .iter()
89            .map(|&var| (F::ONE, var))
90            .collect::<Vec<_>>();
91        let mut lc = LinearCombination(lc);
92        lc.compactify();
93        lc
94    }
95
96    /// Create a new linear combination from the sum of many (coefficient, variable) pairs.
97    #[inline]
98    pub fn from_sum_coeff_vars(terms: &[(F, Variable)]) -> Self {
99        let mut lc = LinearCombination(terms.to_vec());
100        lc.compactify();
101        lc
102    }
103
104    /// Create a new linear combination from the difference of two variables.
105    pub fn diff_vars(a: Variable, b: Variable) -> Self {
106        if a == b {
107            LinearCombination::zero()
108        } else {
109            LinearCombination(vec![(F::one(), a), (-F::one(), b)])
110        }
111    }
112}
113
114impl<F: Field> Deref for LinearCombination<F> {
115    type Target = Vec<(F, Variable)>;
116
117    #[inline]
118    fn deref(&self) -> &Vec<(F, Variable)> {
119        &self.0
120    }
121}
122
123impl<F: Field> DerefMut for LinearCombination<F> {
124    #[inline]
125    fn deref_mut(&mut self) -> &mut Self::Target {
126        &mut self.0
127    }
128}
129
130impl<F: Field> From<(F, Variable)> for LinearCombination<F> {
131    #[inline]
132    fn from(input: (F, Variable)) -> Self {
133        if input.0.is_zero() || input.1.is_zero() {
134            LinearCombination::zero()
135        } else {
136            LinearCombination(vec![input])
137        }
138    }
139}
140
141impl<F: Field> From<Variable> for LinearCombination<F> {
142    #[inline]
143    fn from(var: Variable) -> Self {
144        if var.is_zero() {
145            LinearCombination::zero()
146        } else {
147            LinearCombination::from((F::one(), var))
148        }
149    }
150}
151
152impl<F: Field> IntoIterator for LinearCombination<F> {
153    type Item = (F, Variable);
154    type IntoIter = ark_std::vec::IntoIter<(F, Variable)>;
155
156    #[inline]
157    fn into_iter(self) -> Self::IntoIter {
158        self.0.into_iter()
159    }
160}
161
162impl<F: Field> LinearCombination<F> {
163    /// Negate the coefficients of all variables in `self`.
164    #[inline]
165    pub fn negate_in_place(&mut self) {
166        self.0.iter_mut().for_each(|(coeff, _)| *coeff = -(*coeff));
167    }
168
169    /// Get the location of a variable in `self`.
170    ///
171    /// # Errors
172    /// If the variable is not found, returns the index where it would be inserted.
173    #[inline]
174    pub fn get_var_loc(&self, search_var: &Variable) -> Result<usize, usize> {
175        if self.0.len() < 6 {
176            let mut found_index = 0;
177            for (i, (_, var)) in self.iter().enumerate() {
178                if var >= search_var {
179                    found_index = i;
180                    break;
181                } else {
182                    found_index += 1;
183                }
184            }
185            Err(found_index)
186        } else {
187            self.0
188                .binary_search_by_key(search_var, |&(_, cur_var)| cur_var)
189        }
190    }
191}
192
193impl<F: Field> Add<(F, Variable)> for LinearCombination<F> {
194    type Output = Self;
195
196    #[inline]
197    fn add(mut self, coeff_var: (F, Variable)) -> Self {
198        self += coeff_var;
199        self
200    }
201}
202
203impl<F: Field> AddAssign<(F, Variable)> for LinearCombination<F> {
204    #[inline]
205    fn add_assign(&mut self, (coeff, var): (F, Variable)) {
206        match self.get_var_loc(&var) {
207            Ok(found) => self.0[found].0 += &coeff,
208            Err(not_found) => self.0.insert(not_found, (coeff, var)),
209        }
210    }
211}
212
213impl<F: Field> Sub<(F, Variable)> for LinearCombination<F> {
214    type Output = Self;
215
216    #[inline]
217    fn sub(self, (coeff, var): (F, Variable)) -> Self {
218        self + (-coeff, var)
219    }
220}
221
222impl<F: Field> Neg for LinearCombination<F> {
223    type Output = Self;
224
225    #[inline]
226    fn neg(mut self) -> Self {
227        self.negate_in_place();
228        self
229    }
230}
231
232impl<F: Field> Mul<F> for LinearCombination<F> {
233    type Output = Self;
234
235    #[inline]
236    fn mul(mut self, scalar: F) -> Self {
237        self *= scalar;
238        self
239    }
240}
241
242impl<F: Field> Mul<F> for &LinearCombination<F> {
243    type Output = LinearCombination<F>;
244
245    #[inline]
246    fn mul(self, scalar: F) -> LinearCombination<F> {
247        let mut cur = self.clone();
248        cur *= scalar;
249        cur
250    }
251}
252
253impl<F: Field> MulAssign<F> for LinearCombination<F> {
254    #[inline]
255    fn mul_assign(&mut self, scalar: F) {
256        self.0.iter_mut().for_each(|(coeff, _)| *coeff *= &scalar);
257    }
258}
259
260impl<F: Field> Add<Variable> for LinearCombination<F> {
261    type Output = Self;
262
263    #[inline]
264    fn add(self, other: Variable) -> LinearCombination<F> {
265        self + (F::one(), other)
266    }
267}
268
269impl<'a, F: Field> Add<&'a Variable> for LinearCombination<F> {
270    type Output = Self;
271
272    #[inline]
273    fn add(self, other: &'a Variable) -> LinearCombination<F> {
274        self + *other
275    }
276}
277
278impl<'a, F: Field> Sub<&'a Variable> for LinearCombination<F> {
279    type Output = Self;
280
281    #[inline]
282    fn sub(self, other: &'a Variable) -> LinearCombination<F> {
283        self - *other
284    }
285}
286
287impl<F: Field> Sub<Variable> for LinearCombination<F> {
288    type Output = LinearCombination<F>;
289
290    #[inline]
291    fn sub(self, other: Variable) -> LinearCombination<F> {
292        self - (F::one(), other)
293    }
294}
295
296fn op_impl<F: Field, F1, F2>(
297    cur: &LinearCombination<F>,
298    other: &LinearCombination<F>,
299    push_fn: F1,
300    combine_fn: F2,
301) -> LinearCombination<F>
302where
303    F1: Fn(F) -> F,
304    F2: Fn(F, F) -> F,
305{
306    let mut new_vec = Vec::new();
307    let mut i = 0;
308    let mut j = 0;
309    while i < cur.len() && j < other.len() {
310        use core::cmp::Ordering;
311
312        let self_cur = &cur[i];
313        let other_cur = &other[j];
314        match self_cur.1.cmp(&other_cur.1) {
315            Ordering::Greater => {
316                new_vec.push((push_fn(other[j].0), other[j].1));
317                j += 1;
318            },
319            Ordering::Less => {
320                new_vec.push(*self_cur);
321                i += 1;
322            },
323            Ordering::Equal => {
324                new_vec.push((combine_fn(self_cur.0, other_cur.0), self_cur.1));
325                i += 1;
326                j += 1;
327            },
328        }
329    }
330    new_vec.extend_from_slice(&cur[i..]);
331    while j < other.0.len() {
332        new_vec.push((push_fn(other[j].0), other[j].1));
333        j += 1;
334    }
335    LinearCombination(new_vec)
336}
337
338impl<F: Field> Add<&LinearCombination<F>> for &LinearCombination<F> {
339    type Output = LinearCombination<F>;
340
341    fn add(self, other: &LinearCombination<F>) -> LinearCombination<F> {
342        if other.0.is_empty() {
343            return self.clone();
344        } else if self.0.is_empty() {
345            return other.clone();
346        }
347        op_impl(
348            self,
349            other,
350            |coeff| coeff,
351            |cur_coeff, other_coeff| cur_coeff + other_coeff,
352        )
353    }
354}
355
356impl<F: Field> Add<LinearCombination<F>> for &LinearCombination<F> {
357    type Output = LinearCombination<F>;
358
359    fn add(self, other: LinearCombination<F>) -> LinearCombination<F> {
360        if self.0.is_empty() {
361            return other;
362        } else if other.0.is_empty() {
363            return self.clone();
364        }
365        op_impl(
366            self,
367            &other,
368            |coeff| coeff,
369            |cur_coeff, other_coeff| cur_coeff + other_coeff,
370        )
371    }
372}
373
374impl<'a, F: Field> Add<&'a LinearCombination<F>> for LinearCombination<F> {
375    type Output = LinearCombination<F>;
376
377    fn add(self, other: &'a LinearCombination<F>) -> LinearCombination<F> {
378        if other.0.is_empty() {
379            return self;
380        } else if self.0.is_empty() {
381            return other.clone();
382        }
383        op_impl(
384            &self,
385            other,
386            |coeff| coeff,
387            |cur_coeff, other_coeff| cur_coeff + other_coeff,
388        )
389    }
390}
391
392impl<F: Field> Add<LinearCombination<F>> for LinearCombination<F> {
393    type Output = Self;
394
395    fn add(self, other: Self) -> Self {
396        if other.0.is_empty() {
397            return self;
398        } else if self.0.is_empty() {
399            return other;
400        }
401        op_impl(
402            &self,
403            &other,
404            |coeff| coeff,
405            |cur_coeff, other_coeff| cur_coeff + other_coeff,
406        )
407    }
408}
409
410impl<F: Field> Sub<&LinearCombination<F>> for &LinearCombination<F> {
411    type Output = LinearCombination<F>;
412
413    fn sub(self, other: &LinearCombination<F>) -> LinearCombination<F> {
414        if other.0.is_empty() {
415            let cur = self.clone();
416            return cur;
417        } else if self.0.is_empty() {
418            let mut other = other.clone();
419            other.negate_in_place();
420            return other;
421        }
422
423        op_impl(
424            self,
425            other,
426            |coeff| -coeff,
427            |cur_coeff, other_coeff| cur_coeff - other_coeff,
428        )
429    }
430}
431
432impl<'a, F: Field> Sub<&'a LinearCombination<F>> for LinearCombination<F> {
433    type Output = LinearCombination<F>;
434
435    fn sub(self, other: &'a LinearCombination<F>) -> LinearCombination<F> {
436        if other.0.is_empty() {
437            return self;
438        } else if self.0.is_empty() {
439            let mut other = other.clone();
440            other.negate_in_place();
441            return other;
442        }
443        op_impl(
444            &self,
445            other,
446            |coeff| -coeff,
447            |cur_coeff, other_coeff| cur_coeff - other_coeff,
448        )
449    }
450}
451
452impl<F: Field> Sub<LinearCombination<F>> for &LinearCombination<F> {
453    type Output = LinearCombination<F>;
454
455    fn sub(self, mut other: LinearCombination<F>) -> LinearCombination<F> {
456        if self.0.is_empty() {
457            other.negate_in_place();
458            return other;
459        } else if other.0.is_empty() {
460            return self.clone();
461        }
462
463        op_impl(
464            self,
465            &other,
466            |coeff| -coeff,
467            |cur_coeff, other_coeff| cur_coeff - other_coeff,
468        )
469    }
470}
471
472impl<F: Field> Sub<LinearCombination<F>> for LinearCombination<F> {
473    type Output = LinearCombination<F>;
474
475    fn sub(self, mut other: LinearCombination<F>) -> LinearCombination<F> {
476        if other.0.is_empty() {
477            return self;
478        } else if self.0.is_empty() {
479            other.negate_in_place();
480            return other;
481        }
482        op_impl(
483            &self,
484            &other,
485            |coeff| -coeff,
486            |cur_coeff, other_coeff| cur_coeff - other_coeff,
487        )
488    }
489}
490
491impl<F: Field> Add<(F, &LinearCombination<F>)> for &LinearCombination<F> {
492    type Output = LinearCombination<F>;
493
494    fn add(self, (mul_coeff, other): (F, &LinearCombination<F>)) -> LinearCombination<F> {
495        if other.0.is_empty() {
496            return self.clone();
497        } else if self.0.is_empty() {
498            let mut other = other.clone();
499            other.mul_assign(mul_coeff);
500            return other;
501        }
502        op_impl(
503            self,
504            other,
505            |coeff| mul_coeff * coeff,
506            |cur_coeff, other_coeff| cur_coeff + mul_coeff * other_coeff,
507        )
508    }
509}
510
511impl<'a, F: Field> Add<(F, &'a LinearCombination<F>)> for LinearCombination<F> {
512    type Output = LinearCombination<F>;
513
514    fn add(self, (mul_coeff, other): (F, &'a LinearCombination<F>)) -> LinearCombination<F> {
515        if other.0.is_empty() {
516            return self;
517        } else if self.0.is_empty() {
518            let mut other = other.clone();
519            other.mul_assign(mul_coeff);
520            return other;
521        }
522        op_impl(
523            &self,
524            other,
525            |coeff| mul_coeff * coeff,
526            |cur_coeff, other_coeff| cur_coeff + mul_coeff * other_coeff,
527        )
528    }
529}
530
531impl<F: Field> Add<(F, LinearCombination<F>)> for &LinearCombination<F> {
532    type Output = LinearCombination<F>;
533
534    fn add(self, (mul_coeff, mut other): (F, LinearCombination<F>)) -> LinearCombination<F> {
535        if other.0.is_empty() {
536            return self.clone();
537        } else if self.0.is_empty() {
538            other.mul_assign(mul_coeff);
539            return other;
540        }
541        op_impl(
542            self,
543            &other,
544            |coeff| mul_coeff * coeff,
545            |cur_coeff, other_coeff| cur_coeff + mul_coeff * other_coeff,
546        )
547    }
548}
549
550impl<F: Field> Add<(F, Self)> for LinearCombination<F> {
551    type Output = Self;
552
553    fn add(self, (mul_coeff, other): (F, Self)) -> Self {
554        if other.0.is_empty() {
555            return self;
556        } else if self.0.is_empty() {
557            let mut other = other;
558            other.mul_assign(mul_coeff);
559            return other;
560        }
561        op_impl(
562            &self,
563            &other,
564            |coeff| mul_coeff * coeff,
565            |cur_coeff, other_coeff| cur_coeff + mul_coeff * other_coeff,
566        )
567    }
568}
569
570impl<F: Field> Sub<(F, &LinearCombination<F>)> for &LinearCombination<F> {
571    type Output = LinearCombination<F>;
572
573    fn sub(self, (coeff, other): (F, &LinearCombination<F>)) -> LinearCombination<F> {
574        self + (-coeff, other)
575    }
576}
577
578impl<'a, F: Field> Sub<(F, &'a LinearCombination<F>)> for LinearCombination<F> {
579    type Output = LinearCombination<F>;
580
581    fn sub(self, (coeff, other): (F, &'a LinearCombination<F>)) -> LinearCombination<F> {
582        self + (-coeff, other)
583    }
584}
585
586impl<F: Field> Sub<(F, LinearCombination<F>)> for &LinearCombination<F> {
587    type Output = LinearCombination<F>;
588
589    fn sub(self, (coeff, other): (F, LinearCombination<F>)) -> LinearCombination<F> {
590        self + (-coeff, other)
591    }
592}
593
594impl<F: Field> Sub<(F, LinearCombination<F>)> for LinearCombination<F> {
595    type Output = LinearCombination<F>;
596
597    fn sub(self, (coeff, other): (F, LinearCombination<F>)) -> LinearCombination<F> {
598        self + (-coeff, other)
599    }
600}