acir/native_types/expression/
operators.rs1use crate::native_types::Witness;
2use acir_field::FieldElement;
3use std::{
4 cmp::Ordering,
5 ops::{Add, Mul, Neg, Sub},
6};
7
8use super::Expression;
9
10impl Neg for &Expression {
13 type Output = Expression;
14 fn neg(self) -> Self::Output {
15 let mul_terms: Vec<_> =
18 self.mul_terms.iter().map(|(q_m, w_l, w_r)| (-*q_m, *w_l, *w_r)).collect();
19
20 let linear_combinations: Vec<_> =
21 self.linear_combinations.iter().map(|(q_k, w_k)| (-*q_k, *w_k)).collect();
22 let q_c = -self.q_c;
23
24 Expression { mul_terms, linear_combinations, q_c }
25 }
26}
27
28impl Add<FieldElement> for Expression {
31 type Output = Expression;
32 fn add(self, rhs: FieldElement) -> Self::Output {
33 let q_c = self.q_c + rhs;
35
36 Expression { mul_terms: self.mul_terms, q_c, linear_combinations: self.linear_combinations }
37 }
38}
39
40impl Add<Expression> for FieldElement {
41 type Output = Expression;
42 #[inline]
43 fn add(self, rhs: Expression) -> Self::Output {
44 rhs + self
45 }
46}
47
48impl Sub<FieldElement> for Expression {
49 type Output = Expression;
50 fn sub(self, rhs: FieldElement) -> Self::Output {
51 let q_c = self.q_c - rhs;
53
54 Expression { mul_terms: self.mul_terms, q_c, linear_combinations: self.linear_combinations }
55 }
56}
57
58impl Sub<Expression> for FieldElement {
59 type Output = Expression;
60 #[inline]
61 fn sub(self, rhs: Expression) -> Self::Output {
62 rhs - self
63 }
64}
65
66impl Mul<FieldElement> for &Expression {
67 type Output = Expression;
68 fn mul(self, rhs: FieldElement) -> Self::Output {
69 let mul_terms: Vec<_> =
71 self.mul_terms.iter().map(|(q_m, w_l, w_r)| (*q_m * rhs, *w_l, *w_r)).collect();
72
73 let lin_combinations: Vec<_> =
75 self.linear_combinations.iter().map(|(q_l, w_l)| (*q_l * rhs, *w_l)).collect();
76
77 let q_c = self.q_c * rhs;
79
80 Expression { mul_terms, q_c, linear_combinations: lin_combinations }
81 }
82}
83
84impl Mul<&Expression> for FieldElement {
85 type Output = Expression;
86 #[inline]
87 fn mul(self, rhs: &Expression) -> Self::Output {
88 rhs * self
89 }
90}
91
92impl Add<Witness> for &Expression {
95 type Output = Expression;
96 fn add(self, rhs: Witness) -> Expression {
97 self + &Expression::from(rhs)
98 }
99}
100
101impl Add<&Expression> for Witness {
102 type Output = Expression;
103 #[inline]
104 fn add(self, rhs: &Expression) -> Expression {
105 rhs + self
106 }
107}
108
109impl Sub<Witness> for &Expression {
110 type Output = Expression;
111 fn sub(self, rhs: Witness) -> Expression {
112 self - &Expression::from(rhs)
113 }
114}
115
116impl Sub<&Expression> for Witness {
117 type Output = Expression;
118 #[inline]
119 fn sub(self, rhs: &Expression) -> Expression {
120 rhs - self
121 }
122}
123
124impl Add<&Expression> for &Expression {
129 type Output = Expression;
130 fn add(self, rhs: &Expression) -> Expression {
131 self.add_mul(FieldElement::one(), rhs)
132 }
133}
134
135impl Sub<&Expression> for &Expression {
136 type Output = Expression;
137 fn sub(self, rhs: &Expression) -> Expression {
138 self.add_mul(-FieldElement::one(), rhs)
139 }
140}
141
142impl Mul<&Expression> for &Expression {
143 type Output = Option<Expression>;
144 fn mul(self, rhs: &Expression) -> Option<Expression> {
145 if self.is_const() {
146 return Some(self.q_c * rhs);
147 } else if rhs.is_const() {
148 return Some(self * rhs.q_c);
149 } else if !(self.is_linear() && rhs.is_linear()) {
150 return None;
153 }
154
155 let mut output = Expression::from_field(self.q_c * rhs.q_c);
156
157 for lc in &self.linear_combinations {
159 let single = single_mul(lc.1, rhs);
160 output = output.add_mul(lc.0, &single);
161 }
162
163 let mut i1 = 0; let mut i2 = 0; while i1 < self.linear_combinations.len() && i2 < rhs.linear_combinations.len() {
167 let (a_c, a_w) = self.linear_combinations[i1];
168 let (b_c, b_w) = rhs.linear_combinations[i2];
169
170 let a_c = rhs.q_c * a_c;
172 let b_c = self.q_c * b_c;
173
174 let (coeff, witness) = match a_w.cmp(&b_w) {
175 Ordering::Greater => {
176 i2 += 1;
177 (b_c, b_w)
178 }
179 Ordering::Less => {
180 i1 += 1;
181 (a_c, a_w)
182 }
183 Ordering::Equal => {
184 i1 += 1;
187 i2 += 1;
188 (a_c + b_c, a_w)
189 }
190 };
191
192 if !coeff.is_zero() {
193 output.linear_combinations.push((coeff, witness));
194 }
195 }
196 while i1 < self.linear_combinations.len() {
197 let (a_c, a_w) = self.linear_combinations[i1];
198 let coeff = rhs.q_c * a_c;
199 if !coeff.is_zero() {
200 output.linear_combinations.push((coeff, a_w));
201 }
202 i1 += 1;
203 }
204 while i2 < rhs.linear_combinations.len() {
205 let (b_c, b_w) = rhs.linear_combinations[i2];
206 let coeff = self.q_c * b_c;
207 if !coeff.is_zero() {
208 output.linear_combinations.push((coeff, b_w));
209 }
210 i2 += 1;
211 }
212
213 Some(output)
214 }
215}
216
217fn single_mul(w: Witness, b: &Expression) -> Expression {
219 Expression {
220 mul_terms: b
221 .linear_combinations
222 .iter()
223 .map(|(a, wit)| {
224 let (wl, wr) = if w < *wit { (w, *wit) } else { (*wit, w) };
225 (*a, wl, wr)
226 })
227 .collect(),
228 ..Default::default()
229 }
230}
231
/// Checks `(2*w2 + 2) + (4*w4 + 1) == 2*w2 + 4*w4 + 3` and that the sum is
/// the same with the operands swapped.
#[test]
fn add_smoke_test() {
    let two = FieldElement::from(2u128);
    let four = FieldElement::from(4u128);

    // a = 2*w2 + 2
    let a = Expression {
        mul_terms: Vec::new(),
        linear_combinations: vec![(two, Witness(2))],
        q_c: two,
    };

    // b = 4*w4 + 1
    let b = Expression {
        mul_terms: Vec::new(),
        linear_combinations: vec![(four, Witness(4))],
        q_c: FieldElement::one(),
    };

    let expected_sum = Expression {
        mul_terms: Vec::new(),
        linear_combinations: vec![(two, Witness(2)), (four, Witness(4))],
        q_c: FieldElement::from(3u128),
    };

    assert_eq!(&a + &b, expected_sum);

    // Addition must be commutative.
    assert_eq!(&a + &b, &b + &a);
}
261
/// Checks `(2*w2 + 2) * (4*w4 + 1) == 8*w2*w4 + 2*w2 + 8*w4 + 2` and that the
/// product is the same with the operands swapped.
#[test]
fn mul_smoke_test() {
    let two = FieldElement::from(2u128);
    let four = FieldElement::from(4u128);
    let eight = FieldElement::from(8u128);

    // a = 2*w2 + 2
    let a = Expression {
        mul_terms: Vec::new(),
        linear_combinations: vec![(two, Witness(2))],
        q_c: two,
    };

    // b = 4*w4 + 1
    let b = Expression {
        mul_terms: Vec::new(),
        linear_combinations: vec![(four, Witness(4))],
        q_c: FieldElement::one(),
    };

    let expected_product = Expression {
        mul_terms: vec![(eight, Witness(2), Witness(4))],
        linear_combinations: vec![(two, Witness(2)), (eight, Witness(4))],
        q_c: two,
    };

    assert_eq!((&a * &b).unwrap(), expected_product);

    // Multiplication must be commutative.
    assert_eq!(&a * &b, &b * &a);
}