1use acir::{
2 native_types::{Expression, Witness, WitnessMap},
3 FieldElement,
4};
5
6use super::{insert_value, ErrorLocation, OpcodeNotSolvable, OpcodeResolutionError};
7
/// Solver for arithmetic (`Expression`) opcodes.
///
/// Given a partial witness assignment, it derives the value of the single remaining unknown
/// witness of an expression, or checks that a fully-known expression evaluates to zero.
pub(crate) struct ExpressionSolver;
11
/// Result of evaluating the linear ("fan-in") terms of an expression against known witnesses.
// Every variant shares the `Opcode` prefix, which `clippy::enum_variant_names` would flag.
#[allow(clippy::enum_variant_names)]
pub(super) enum OpcodeStatus {
    /// Every linear term is known; holds the sum of the evaluated terms.
    OpcodeSatisfied(FieldElement),
    /// Exactly one linear term is unknown; holds the sum of the known terms together with the
    /// `(coefficient, witness)` pair of the unknown term.
    OpcodeSolvable(FieldElement, (FieldElement, Witness)),
    /// The expression cannot be solved (e.g. more than one linear term is unknown).
    OpcodeUnsolvable,
}
18
/// Result of evaluating a multiplication term `q_m * w_l * w_r` against known witnesses.
pub(crate) enum MulTerm {
    /// Exactly one factor is unknown: `(q_m * known_factor_value, unknown_witness)`.
    OneUnknown(FieldElement, Witness),
    /// Both factors are unknown.
    TooManyUnknowns,
    /// Both factors are known (or there is no multiplication term); holds the product.
    Solved(FieldElement),
}
24
25impl ExpressionSolver {
26 pub(crate) fn solve(
28 initial_witness: &mut WitnessMap,
29 opcode: &Expression,
30 ) -> Result<(), OpcodeResolutionError> {
31 let opcode = &ExpressionSolver::evaluate(opcode, initial_witness);
32 let mul_result =
34 ExpressionSolver::solve_mul_term(opcode, initial_witness).map_err(|_| {
35 OpcodeResolutionError::OpcodeNotSolvable(
36 OpcodeNotSolvable::ExpressionHasTooManyUnknowns(opcode.clone()),
37 )
38 })?;
39 let opcode_status = ExpressionSolver::solve_fan_in_term(opcode, initial_witness);
41
42 match (mul_result, opcode_status) {
43 (MulTerm::TooManyUnknowns, _) | (_, OpcodeStatus::OpcodeUnsolvable) => {
44 Err(OpcodeResolutionError::OpcodeNotSolvable(
45 OpcodeNotSolvable::ExpressionHasTooManyUnknowns(opcode.clone()),
46 ))
47 }
48 (MulTerm::OneUnknown(q, w1), OpcodeStatus::OpcodeSolvable(a, (b, w2))) => {
49 if w1 == w2 {
50 let total_sum = a + opcode.q_c;
52 if (q + b).is_zero() {
53 if !total_sum.is_zero() {
54 Err(OpcodeResolutionError::UnsatisfiedConstrain {
55 opcode_location: ErrorLocation::Unresolved,
56 payload: None,
57 })
58 } else {
59 Ok(())
60 }
61 } else {
62 let assignment = -total_sum / (q + b);
63 insert_value(&w1, assignment, initial_witness)
64 }
65 } else {
66 Err(OpcodeResolutionError::OpcodeNotSolvable(
68 OpcodeNotSolvable::ExpressionHasTooManyUnknowns(opcode.clone()),
69 ))
70 }
71 }
72 (
73 MulTerm::OneUnknown(partial_prod, unknown_var),
74 OpcodeStatus::OpcodeSatisfied(sum),
75 ) => {
76 let total_sum = sum + opcode.q_c;
81 if partial_prod.is_zero() {
82 if !total_sum.is_zero() {
83 Err(OpcodeResolutionError::UnsatisfiedConstrain {
84 opcode_location: ErrorLocation::Unresolved,
85 payload: None,
86 })
87 } else {
88 Ok(())
89 }
90 } else {
91 let assignment = -(total_sum / partial_prod);
92 insert_value(&unknown_var, assignment, initial_witness)
93 }
94 }
95 (MulTerm::Solved(a), OpcodeStatus::OpcodeSatisfied(b)) => {
96 if !(a + b + opcode.q_c).is_zero() {
99 Err(OpcodeResolutionError::UnsatisfiedConstrain {
100 opcode_location: ErrorLocation::Unresolved,
101 payload: None,
102 })
103 } else {
104 Ok(())
105 }
106 }
107 (
108 MulTerm::Solved(total_prod),
109 OpcodeStatus::OpcodeSolvable(partial_sum, (coeff, unknown_var)),
110 ) => {
111 let total_sum = total_prod + partial_sum + opcode.q_c;
115 if coeff.is_zero() {
116 if !total_sum.is_zero() {
117 Err(OpcodeResolutionError::UnsatisfiedConstrain {
118 opcode_location: ErrorLocation::Unresolved,
119 payload: None,
120 })
121 } else {
122 Ok(())
123 }
124 } else {
125 let assignment = -(total_sum / coeff);
126 insert_value(&unknown_var, assignment, initial_witness)
127 }
128 }
129 }
130 }
131
132 fn solve_mul_term(
137 arith_opcode: &Expression,
138 witness_assignments: &WitnessMap,
139 ) -> Result<MulTerm, OpcodeStatus> {
140 match arith_opcode.mul_terms.len() {
143 0 => Ok(MulTerm::Solved(FieldElement::zero())),
144 1 => Ok(ExpressionSolver::solve_mul_term_helper(
145 &arith_opcode.mul_terms[0],
146 witness_assignments,
147 )),
148 _ => Err(OpcodeStatus::OpcodeUnsolvable),
149 }
150 }
151
152 fn solve_mul_term_helper(
153 term: &(FieldElement, Witness, Witness),
154 witness_assignments: &WitnessMap,
155 ) -> MulTerm {
156 let (q_m, w_l, w_r) = term;
157 let w_l_value = witness_assignments.get(w_l);
159 let w_r_value = witness_assignments.get(w_r);
160
161 match (w_l_value, w_r_value) {
162 (None, None) => MulTerm::TooManyUnknowns,
163 (Some(w_l), Some(w_r)) => MulTerm::Solved(*q_m * *w_l * *w_r),
164 (None, Some(w_r)) => MulTerm::OneUnknown(*q_m * *w_r, *w_l),
165 (Some(w_l), None) => MulTerm::OneUnknown(*q_m * *w_l, *w_r),
166 }
167 }
168
169 fn solve_fan_in_term_helper(
170 term: &(FieldElement, Witness),
171 witness_assignments: &WitnessMap,
172 ) -> Option<FieldElement> {
173 let (q_l, w_l) = term;
174 let w_l_value = witness_assignments.get(w_l);
176 w_l_value.map(|a| *q_l * *a)
177 }
178
179 pub(super) fn solve_fan_in_term(
183 arith_opcode: &Expression,
184 witness_assignments: &WitnessMap,
185 ) -> OpcodeStatus {
186 let mut unknown_variable = (FieldElement::zero(), Witness::default());
190 let mut num_unknowns = 0;
191 let mut result = FieldElement::zero();
193
194 for term in arith_opcode.linear_combinations.iter() {
195 let value = ExpressionSolver::solve_fan_in_term_helper(term, witness_assignments);
196 match value {
197 Some(a) => result += a,
198 None => {
199 unknown_variable = *term;
200 num_unknowns += 1;
201 }
202 }
203
204 if num_unknowns > 1 {
206 return OpcodeStatus::OpcodeUnsolvable;
207 }
208 }
209
210 if num_unknowns == 0 {
211 return OpcodeStatus::OpcodeSatisfied(result);
212 }
213
214 OpcodeStatus::OpcodeSolvable(result, unknown_variable)
215 }
216
217 pub(crate) fn evaluate(expr: &Expression, initial_witness: &WitnessMap) -> Expression {
219 let mut result = Expression::default();
220 for &(c, w1, w2) in &expr.mul_terms {
221 let mul_result = ExpressionSolver::solve_mul_term_helper(&(c, w1, w2), initial_witness);
222 match mul_result {
223 MulTerm::OneUnknown(v, w) => {
224 if !v.is_zero() {
225 result.linear_combinations.push((v, w));
226 }
227 }
228 MulTerm::TooManyUnknowns => {
229 if !c.is_zero() {
230 result.mul_terms.push((c, w1, w2));
231 }
232 }
233 MulTerm::Solved(f) => result.q_c += f,
234 }
235 }
236 for &(c, w) in &expr.linear_combinations {
237 if let Some(f) = ExpressionSolver::solve_fan_in_term_helper(&(c, w), initial_witness) {
238 result.q_c += f;
239 } else if !c.is_zero() {
240 result.linear_combinations.push((c, w));
241 }
242 }
243 result.q_c += expr.q_c;
244 result
245 }
246}
247
/// Solves two chained linear opcodes and checks both derived witnesses.
#[test]
fn expression_solver_smoke_test() {
    let a = Witness(0);
    let b = Witness(1);
    let c = Witness(2);
    let d = Witness(3);

    // a - b - c - d = 0
    let opcode_a = Expression {
        mul_terms: vec![],
        linear_combinations: vec![
            (FieldElement::one(), a),
            (-FieldElement::one(), b),
            (-FieldElement::one(), c),
            (-FieldElement::one(), d),
        ],
        q_c: FieldElement::zero(),
    };

    // e - a - b = 0; only solvable once `a` has been derived from `opcode_a`.
    let e = Witness(4);
    let opcode_b = Expression {
        mul_terms: vec![],
        linear_combinations: vec![
            (FieldElement::one(), e),
            (-FieldElement::one(), a),
            (-FieldElement::one(), b),
        ],
        q_c: FieldElement::zero(),
    };

    let mut values = WitnessMap::new();
    values.insert(b, FieldElement::from(2_i128));
    values.insert(c, FieldElement::from(1_i128));
    values.insert(d, FieldElement::from(1_i128));

    assert_eq!(ExpressionSolver::solve(&mut values, &opcode_a), Ok(()));
    assert_eq!(ExpressionSolver::solve(&mut values, &opcode_b), Ok(()));

    // a = b + c + d = 4, then e = a + b = 6.
    assert_eq!(values.get(&a).unwrap(), &FieldElement::from(4_i128));
    assert_eq!(values.get(&e).unwrap(), &FieldElement::from(6_i128));
}