use acir::{
native_types::{Expression, Witness, WitnessMap},
FieldElement,
};
use super::{insert_value, ErrorLocation, OpcodeNotSolvable, OpcodeResolutionError};
/// Stateless namespace for the routines that partially evaluate and solve a
/// single [`Expression`] opcode against a [`WitnessMap`]; all of its functions
/// are associated functions taking no `self`.
pub(crate) struct ExpressionSolver;
/// Outcome of summing the linear (fan-in) terms of an expression against the
/// currently known witness values.
// Every variant carries the `Opcode` prefix, which clippy would otherwise flag.
#[allow(clippy::enum_variant_names)]
pub(super) enum OpcodeStatus {
    /// No unknown linear terms remain; carries the sum of all linear terms.
    OpcodeSatisfied(FieldElement),
    /// A single unknown remains; carries the sum of the known linear terms and
    /// the `(coefficient, witness)` pair of the unknown term.
    OpcodeSolvable(FieldElement, (FieldElement, Witness)),
    /// Too many unknowns remain for the expression to be solved.
    OpcodeUnsolvable,
}
/// Result of evaluating a single multiplication term `q_m * w_l * w_r`
/// against the currently known witness values.
pub(crate) enum MulTerm {
    /// Exactly one factor is known; carries `q_m * known_value` together with
    /// the witness that is still unknown.
    OneUnknown(FieldElement, Witness),
    /// Neither factor is known.
    TooManyUnknowns,
    /// Both factors are known; carries the value of the whole term.
    Solved(FieldElement),
}
impl ExpressionSolver {
pub(crate) fn solve(
initial_witness: &mut WitnessMap,
opcode: &Expression,
) -> Result<(), OpcodeResolutionError> {
let opcode = &ExpressionSolver::evaluate(opcode, initial_witness);
let mul_result =
ExpressionSolver::solve_mul_term(opcode, initial_witness).map_err(|_| {
OpcodeResolutionError::OpcodeNotSolvable(
OpcodeNotSolvable::ExpressionHasTooManyUnknowns(opcode.clone()),
)
})?;
let opcode_status = ExpressionSolver::solve_fan_in_term(opcode, initial_witness);
match (mul_result, opcode_status) {
(MulTerm::TooManyUnknowns, _) | (_, OpcodeStatus::OpcodeUnsolvable) => {
Err(OpcodeResolutionError::OpcodeNotSolvable(
OpcodeNotSolvable::ExpressionHasTooManyUnknowns(opcode.clone()),
))
}
(MulTerm::OneUnknown(q, w1), OpcodeStatus::OpcodeSolvable(a, (b, w2))) => {
if w1 == w2 {
let total_sum = a + opcode.q_c;
if (q + b).is_zero() {
if !total_sum.is_zero() {
Err(OpcodeResolutionError::UnsatisfiedConstrain {
opcode_location: ErrorLocation::Unresolved,
payload: None,
})
} else {
Ok(())
}
} else {
let assignment = -total_sum / (q + b);
insert_value(&w1, assignment, initial_witness)
}
} else {
Err(OpcodeResolutionError::OpcodeNotSolvable(
OpcodeNotSolvable::ExpressionHasTooManyUnknowns(opcode.clone()),
))
}
}
(
MulTerm::OneUnknown(partial_prod, unknown_var),
OpcodeStatus::OpcodeSatisfied(sum),
) => {
let total_sum = sum + opcode.q_c;
if partial_prod.is_zero() {
if !total_sum.is_zero() {
Err(OpcodeResolutionError::UnsatisfiedConstrain {
opcode_location: ErrorLocation::Unresolved,
payload: None,
})
} else {
Ok(())
}
} else {
let assignment = -(total_sum / partial_prod);
insert_value(&unknown_var, assignment, initial_witness)
}
}
(MulTerm::Solved(a), OpcodeStatus::OpcodeSatisfied(b)) => {
if !(a + b + opcode.q_c).is_zero() {
Err(OpcodeResolutionError::UnsatisfiedConstrain {
opcode_location: ErrorLocation::Unresolved,
payload: None,
})
} else {
Ok(())
}
}
(
MulTerm::Solved(total_prod),
OpcodeStatus::OpcodeSolvable(partial_sum, (coeff, unknown_var)),
) => {
let total_sum = total_prod + partial_sum + opcode.q_c;
if coeff.is_zero() {
if !total_sum.is_zero() {
Err(OpcodeResolutionError::UnsatisfiedConstrain {
opcode_location: ErrorLocation::Unresolved,
payload: None,
})
} else {
Ok(())
}
} else {
let assignment = -(total_sum / coeff);
insert_value(&unknown_var, assignment, initial_witness)
}
}
}
}
fn solve_mul_term(
arith_opcode: &Expression,
witness_assignments: &WitnessMap,
) -> Result<MulTerm, OpcodeStatus> {
match arith_opcode.mul_terms.len() {
0 => Ok(MulTerm::Solved(FieldElement::zero())),
1 => Ok(ExpressionSolver::solve_mul_term_helper(
&arith_opcode.mul_terms[0],
witness_assignments,
)),
_ => Err(OpcodeStatus::OpcodeUnsolvable),
}
}
fn solve_mul_term_helper(
term: &(FieldElement, Witness, Witness),
witness_assignments: &WitnessMap,
) -> MulTerm {
let (q_m, w_l, w_r) = term;
let w_l_value = witness_assignments.get(w_l);
let w_r_value = witness_assignments.get(w_r);
match (w_l_value, w_r_value) {
(None, None) => MulTerm::TooManyUnknowns,
(Some(w_l), Some(w_r)) => MulTerm::Solved(*q_m * *w_l * *w_r),
(None, Some(w_r)) => MulTerm::OneUnknown(*q_m * *w_r, *w_l),
(Some(w_l), None) => MulTerm::OneUnknown(*q_m * *w_l, *w_r),
}
}
fn solve_fan_in_term_helper(
term: &(FieldElement, Witness),
witness_assignments: &WitnessMap,
) -> Option<FieldElement> {
let (q_l, w_l) = term;
let w_l_value = witness_assignments.get(w_l);
w_l_value.map(|a| *q_l * *a)
}
pub(super) fn solve_fan_in_term(
arith_opcode: &Expression,
witness_assignments: &WitnessMap,
) -> OpcodeStatus {
let mut unknown_variable = (FieldElement::zero(), Witness::default());
let mut num_unknowns = 0;
let mut result = FieldElement::zero();
for term in arith_opcode.linear_combinations.iter() {
let value = ExpressionSolver::solve_fan_in_term_helper(term, witness_assignments);
match value {
Some(a) => result += a,
None => {
unknown_variable = *term;
num_unknowns += 1;
}
}
if num_unknowns > 1 {
return OpcodeStatus::OpcodeUnsolvable;
}
}
if num_unknowns == 0 {
return OpcodeStatus::OpcodeSatisfied(result);
}
OpcodeStatus::OpcodeSolvable(result, unknown_variable)
}
pub(crate) fn evaluate(expr: &Expression, initial_witness: &WitnessMap) -> Expression {
let mut result = Expression::default();
for &(c, w1, w2) in &expr.mul_terms {
let mul_result = ExpressionSolver::solve_mul_term_helper(&(c, w1, w2), initial_witness);
match mul_result {
MulTerm::OneUnknown(v, w) => {
if !v.is_zero() {
result.linear_combinations.push((v, w));
}
}
MulTerm::TooManyUnknowns => {
if !c.is_zero() {
result.mul_terms.push((c, w1, w2));
}
}
MulTerm::Solved(f) => result.q_c += f,
}
}
for &(c, w) in &expr.linear_combinations {
if let Some(f) = ExpressionSolver::solve_fan_in_term_helper(&(c, w), initial_witness) {
result.q_c += f;
} else if !c.is_zero() {
result.linear_combinations.push((c, w));
}
}
result.q_c += expr.q_c;
result
}
}
/// Solves two chained linear constraints:
/// `a = b + c + d` (with b, c, d known) and then `e = a + b`.
#[test]
fn expression_solver_smoke_test() {
    let (a, b, c, d, e) = (Witness(0), Witness(1), Witness(2), Witness(3), Witness(4));
    let one = FieldElement::one();
    let minus_one = -FieldElement::one();

    // a - b - c - d = 0
    let opcode_a = Expression {
        mul_terms: Vec::new(),
        linear_combinations: vec![(one, a), (minus_one, b), (minus_one, c), (minus_one, d)],
        q_c: FieldElement::zero(),
    };
    // e - a - b = 0
    let opcode_b = Expression {
        mul_terms: Vec::new(),
        linear_combinations: vec![(one, e), (minus_one, a), (minus_one, b)],
        q_c: FieldElement::zero(),
    };

    let mut values = WitnessMap::new();
    for &(witness, value) in &[(b, 2_i128), (c, 1_i128), (d, 1_i128)] {
        values.insert(witness, FieldElement::from(value));
    }

    assert_eq!(ExpressionSolver::solve(&mut values, &opcode_a), Ok(()));
    assert_eq!(ExpressionSolver::solve(&mut values, &opcode_b), Ok(()));
    assert_eq!(values.get(&a), Some(&FieldElement::from(4_i128)));
}