acir/native_types/expression/
operators.rs

1use crate::native_types::Witness;
2use acir_field::FieldElement;
3use std::{
4    cmp::Ordering,
5    ops::{Add, Mul, Neg, Sub},
6};
7
8use super::Expression;
9
10// Negation
11
12impl Neg for &Expression {
13    type Output = Expression;
14    fn neg(self) -> Self::Output {
15        // XXX(med) : Implement an efficient way to do this
16
17        let mul_terms: Vec<_> =
18            self.mul_terms.iter().map(|(q_m, w_l, w_r)| (-*q_m, *w_l, *w_r)).collect();
19
20        let linear_combinations: Vec<_> =
21            self.linear_combinations.iter().map(|(q_k, w_k)| (-*q_k, *w_k)).collect();
22        let q_c = -self.q_c;
23
24        Expression { mul_terms, linear_combinations, q_c }
25    }
26}
27
28// FieldElement
29
30impl Add<FieldElement> for Expression {
31    type Output = Expression;
32    fn add(self, rhs: FieldElement) -> Self::Output {
33        // Increase the constant
34        let q_c = self.q_c + rhs;
35
36        Expression { mul_terms: self.mul_terms, q_c, linear_combinations: self.linear_combinations }
37    }
38}
39
40impl Add<Expression> for FieldElement {
41    type Output = Expression;
42    #[inline]
43    fn add(self, rhs: Expression) -> Self::Output {
44        rhs + self
45    }
46}
47
48impl Sub<FieldElement> for Expression {
49    type Output = Expression;
50    fn sub(self, rhs: FieldElement) -> Self::Output {
51        // Increase the constant
52        let q_c = self.q_c - rhs;
53
54        Expression { mul_terms: self.mul_terms, q_c, linear_combinations: self.linear_combinations }
55    }
56}
57
58impl Sub<Expression> for FieldElement {
59    type Output = Expression;
60    #[inline]
61    fn sub(self, rhs: Expression) -> Self::Output {
62        rhs - self
63    }
64}
65
66impl Mul<FieldElement> for &Expression {
67    type Output = Expression;
68    fn mul(self, rhs: FieldElement) -> Self::Output {
69        // Scale the mul terms
70        let mul_terms: Vec<_> =
71            self.mul_terms.iter().map(|(q_m, w_l, w_r)| (*q_m * rhs, *w_l, *w_r)).collect();
72
73        // Scale the linear combinations terms
74        let lin_combinations: Vec<_> =
75            self.linear_combinations.iter().map(|(q_l, w_l)| (*q_l * rhs, *w_l)).collect();
76
77        // Scale the constant
78        let q_c = self.q_c * rhs;
79
80        Expression { mul_terms, q_c, linear_combinations: lin_combinations }
81    }
82}
83
84impl Mul<&Expression> for FieldElement {
85    type Output = Expression;
86    #[inline]
87    fn mul(self, rhs: &Expression) -> Self::Output {
88        rhs * self
89    }
90}
91
92// Witness
93
94impl Add<Witness> for &Expression {
95    type Output = Expression;
96    fn add(self, rhs: Witness) -> Expression {
97        self + &Expression::from(rhs)
98    }
99}
100
101impl Add<&Expression> for Witness {
102    type Output = Expression;
103    #[inline]
104    fn add(self, rhs: &Expression) -> Expression {
105        rhs + self
106    }
107}
108
109impl Sub<Witness> for &Expression {
110    type Output = Expression;
111    fn sub(self, rhs: Witness) -> Expression {
112        self - &Expression::from(rhs)
113    }
114}
115
116impl Sub<&Expression> for Witness {
117    type Output = Expression;
118    #[inline]
119    fn sub(self, rhs: &Expression) -> Expression {
120        rhs - self
121    }
122}
123
124// Mul<Witness> is not implemented as this could result in degree 3 terms.
125
126// Expression
127
128impl Add<&Expression> for &Expression {
129    type Output = Expression;
130    fn add(self, rhs: &Expression) -> Expression {
131        self.add_mul(FieldElement::one(), rhs)
132    }
133}
134
135impl Sub<&Expression> for &Expression {
136    type Output = Expression;
137    fn sub(self, rhs: &Expression) -> Expression {
138        self.add_mul(-FieldElement::one(), rhs)
139    }
140}
141
142impl Mul<&Expression> for &Expression {
143    type Output = Option<Expression>;
144    fn mul(self, rhs: &Expression) -> Option<Expression> {
145        if self.is_const() {
146            return Some(self.q_c * rhs);
147        } else if rhs.is_const() {
148            return Some(self * rhs.q_c);
149        } else if !(self.is_linear() && rhs.is_linear()) {
150            // `Expression`s can only represent terms which are up to degree 2.
151            // We then disallow multiplication of `Expression`s which have degree 2 terms.
152            return None;
153        }
154
155        let mut output = Expression::from_field(self.q_c * rhs.q_c);
156
157        //TODO to optimize...
158        for lc in &self.linear_combinations {
159            let single = single_mul(lc.1, rhs);
160            output = output.add_mul(lc.0, &single);
161        }
162
163        //linear terms
164        let mut i1 = 0; //a
165        let mut i2 = 0; //b
166        while i1 < self.linear_combinations.len() && i2 < rhs.linear_combinations.len() {
167            let (a_c, a_w) = self.linear_combinations[i1];
168            let (b_c, b_w) = rhs.linear_combinations[i2];
169
170            // Apply scaling from multiplication
171            let a_c = rhs.q_c * a_c;
172            let b_c = self.q_c * b_c;
173
174            let (coeff, witness) = match a_w.cmp(&b_w) {
175                Ordering::Greater => {
176                    i2 += 1;
177                    (b_c, b_w)
178                }
179                Ordering::Less => {
180                    i1 += 1;
181                    (a_c, a_w)
182                }
183                Ordering::Equal => {
184                    // Here we're taking both terms as the witness indices are equal.
185                    // We then advance both `i1` and `i2`.
186                    i1 += 1;
187                    i2 += 1;
188                    (a_c + b_c, a_w)
189                }
190            };
191
192            if !coeff.is_zero() {
193                output.linear_combinations.push((coeff, witness));
194            }
195        }
196        while i1 < self.linear_combinations.len() {
197            let (a_c, a_w) = self.linear_combinations[i1];
198            let coeff = rhs.q_c * a_c;
199            if !coeff.is_zero() {
200                output.linear_combinations.push((coeff, a_w));
201            }
202            i1 += 1;
203        }
204        while i2 < rhs.linear_combinations.len() {
205            let (b_c, b_w) = rhs.linear_combinations[i2];
206            let coeff = self.q_c * b_c;
207            if !coeff.is_zero() {
208                output.linear_combinations.push((coeff, b_w));
209            }
210            i2 += 1;
211        }
212
213        Some(output)
214    }
215}
216
217/// Returns `w*b.linear_combinations`
218fn single_mul(w: Witness, b: &Expression) -> Expression {
219    Expression {
220        mul_terms: b
221            .linear_combinations
222            .iter()
223            .map(|(a, wit)| {
224                let (wl, wr) = if w < *wit { (w, *wit) } else { (*wit, w) };
225                (*a, wl, wr)
226            })
227            .collect(),
228        ..Default::default()
229    }
230}
231
/// Checks `(2*w2 + 2) + (4*w4 + 1) == 2*w2 + 4*w4 + 3` and that the addition
/// gives the same result with the operands swapped.
#[test]
fn add_smoke_test() {
    let a = Expression {
        mul_terms: vec![],
        linear_combinations: vec![(FieldElement::from(2u128), Witness(2))],
        q_c: FieldElement::from(2u128),
    };
    let b = Expression {
        mul_terms: vec![],
        linear_combinations: vec![(FieldElement::from(4u128), Witness(4))],
        q_c: FieldElement::one(),
    };

    let expected_sum = Expression {
        mul_terms: vec![],
        linear_combinations: vec![
            (FieldElement::from(2u128), Witness(2)),
            (FieldElement::from(4u128), Witness(4)),
        ],
        q_c: FieldElement::from(3u128),
    };
    assert_eq!(&a + &b, expected_sum);

    // Enforce commutativity
    let swapped_sum = &b + &a;
    assert_eq!(&a + &b, swapped_sum);
}
261
/// Checks `(2*w2 + 2) * (4*w4 + 1) == 8*w2*w4 + 2*w2 + 8*w4 + 2` and that the
/// multiplication gives the same result with the operands swapped.
#[test]
fn mul_smoke_test() {
    let a = Expression {
        mul_terms: vec![],
        linear_combinations: vec![(FieldElement::from(2u128), Witness(2))],
        q_c: FieldElement::from(2u128),
    };
    let b = Expression {
        mul_terms: vec![],
        linear_combinations: vec![(FieldElement::from(4u128), Witness(4))],
        q_c: FieldElement::one(),
    };

    let expected_product = Expression {
        mul_terms: vec![(FieldElement::from(8u128), Witness(2), Witness(4))],
        linear_combinations: vec![
            (FieldElement::from(2u128), Witness(2)),
            (FieldElement::from(8u128), Witness(4)),
        ],
        q_c: FieldElement::from(2u128),
    };
    assert_eq!((&a * &b).unwrap(), expected_product);

    // Enforce commutativity
    let swapped_product = &b * &a;
    assert_eq!(&a * &b, swapped_product);
}