use acir::{
native_types::{Expression, Witness},
FieldElement,
};
use std::collections::BTreeMap;
use crate::{OpcodeNotSolvable, OpcodeResolutionError};
/// Namespace type for the arithmetic gate solver.
///
/// It carries no state; all functionality lives in associated functions that
/// take the witness map and the `Expression` to solve as arguments.
pub struct ArithmeticSolver;
/// Result of evaluating the linear (fan-in) terms of an expression against the
/// currently known witness values.
// Every variant shares the `Gate` prefix, which trips clippy's
// `enum_variant_names` lint; the lint is silenced rather than renaming them.
#[allow(clippy::enum_variant_names)]
enum GateStatus {
/// Every witness in the linear terms is known; holds the sum of
/// `coefficient * value` over all of them.
GateSatisfied(FieldElement),
/// Exactly one witness is unknown; holds the sum of the known terms and the
/// `(coefficient, witness)` pair of the unknown term.
GateSolvable(FieldElement, (FieldElement, Witness)),
/// More than one linear term refers to an unknown witness.
GateUnsolvable,
}
/// Result of evaluating the multiplication term of an expression against the
/// currently known witness values.
enum MulTerm {
    /// Both operands are known (or there is no multiplication term at all);
    /// holds the fully evaluated product.
    Solved(FieldElement),
    /// Exactly one operand is known; holds the coefficient multiplied by the
    /// known operand, together with the witness that is still unknown.
    OneUnknown(FieldElement, Witness),
    /// Neither operand is known.
    TooManyUnknowns,
}
impl ArithmeticSolver {
    /// Tries to solve `gate` for a single unknown witness.
    ///
    /// The expression is interpreted as the constraint
    /// `q_m * w_l * w_r + sum(q_i * w_i) + q_c = 0`. Known witness values are
    /// read from `initial_witness`; if exactly one witness is unknown and its
    /// effective coefficient is non-zero, its value is computed and inserted
    /// into `initial_witness`.
    ///
    /// When the unknown witness ends up with a zero coefficient, nothing is
    /// inserted: the call succeeds if the remaining constant is zero and fails
    /// otherwise.
    ///
    /// # Errors
    ///
    /// * `OpcodeResolutionError::OpcodeNotSolvable` when more than one distinct
    ///   witness is unknown.
    /// * `OpcodeResolutionError::UnsatisfiedConstrain` when no witness can be
    ///   solved for and the expression does not evaluate to zero.
    ///
    /// # Panics
    ///
    /// Panics if `gate` contains more than one multiplication term.
    pub fn solve(
        initial_witness: &mut BTreeMap<Witness, FieldElement>,
        gate: &Expression,
    ) -> Result<(), OpcodeResolutionError> {
        let mul_result = Self::solve_mul_term(gate, initial_witness);
        let gate_status = Self::solve_fan_in_term(gate, initial_witness);

        match (mul_result, gate_status) {
            (MulTerm::TooManyUnknowns, _) | (_, GateStatus::GateUnsolvable) => {
                Err(Self::too_many_unknowns(gate))
            }
            (MulTerm::OneUnknown(q, w1), GateStatus::GateSolvable(a, (b, w2))) => {
                // Two different unknowns cannot be solved from one equation.
                if w1 != w2 {
                    return Err(Self::too_many_unknowns(gate));
                }
                // Same witness on both sides: (q + b) * w + a + q_c = 0.
                Self::assign_or_check(initial_witness, q + b, a + gate.q_c, w1)
            }
            (MulTerm::OneUnknown(partial_prod, unknown_var), GateStatus::GateSatisfied(sum)) => {
                Self::assign_or_check(initial_witness, partial_prod, sum + gate.q_c, unknown_var)
            }
            (MulTerm::Solved(a), GateStatus::GateSatisfied(b)) => {
                // Nothing left to solve; just verify the constraint holds.
                if (a + b + gate.q_c).is_zero() {
                    Ok(())
                } else {
                    Err(OpcodeResolutionError::UnsatisfiedConstrain)
                }
            }
            (
                MulTerm::Solved(total_prod),
                GateStatus::GateSolvable(partial_sum, (coeff, unknown_var)),
            ) => Self::assign_or_check(
                initial_witness,
                coeff,
                total_prod + partial_sum + gate.q_c,
                unknown_var,
            ),
        }
    }

    /// Builds the error reported when `gate` has more than one unknown witness.
    fn too_many_unknowns(gate: &Expression) -> OpcodeResolutionError {
        OpcodeResolutionError::OpcodeNotSolvable(OpcodeNotSolvable::ExpressionHasTooManyUnknowns(
            gate.clone(),
        ))
    }

    /// Solves `coeff * unknown + constant = 0` and records the solution.
    ///
    /// With a zero `coeff` the unknown has dropped out of the equation, so no
    /// value is assigned and the equation merely has to hold (`constant == 0`).
    fn assign_or_check(
        witness_assignments: &mut BTreeMap<Witness, FieldElement>,
        coeff: FieldElement,
        constant: FieldElement,
        unknown: Witness,
    ) -> Result<(), OpcodeResolutionError> {
        if coeff.is_zero() {
            return if constant.is_zero() {
                Ok(())
            } else {
                Err(OpcodeResolutionError::UnsatisfiedConstrain)
            };
        }
        witness_assignments.insert(unknown, -constant / coeff);
        Ok(())
    }

    /// Evaluates the multiplication term of `arith_gate` as far as the known
    /// witness values allow.
    ///
    /// An expression without a multiplication term evaluates to
    /// `MulTerm::Solved(0)`.
    ///
    /// # Panics
    ///
    /// Panics if the expression has more than one multiplication term.
    fn solve_mul_term(
        arith_gate: &Expression,
        witness_assignments: &BTreeMap<Witness, FieldElement>,
    ) -> MulTerm {
        match arith_gate.mul_terms.as_slice() {
            [] => MulTerm::Solved(FieldElement::zero()),
            [term] => {
                let (q_m, w_l, w_r) = (&term.0, &term.1, &term.2);
                match (witness_assignments.get(w_l), witness_assignments.get(w_r)) {
                    (Some(lhs), Some(rhs)) => MulTerm::Solved(*q_m * *lhs * *rhs),
                    (Some(lhs), None) => MulTerm::OneUnknown(*q_m * *lhs, *w_r),
                    (None, Some(rhs)) => MulTerm::OneUnknown(*q_m * *rhs, *w_l),
                    (None, None) => MulTerm::TooManyUnknowns,
                }
            }
            _ => panic!("Mul term in the arithmetic gate must contain either zero or one term"),
        }
    }

    /// Evaluates the linear terms of `arith_gate` as far as the known witness
    /// values allow.
    ///
    /// Known terms are summed. Each term whose witness is unknown counts as a
    /// separate unknown; as soon as a second one is encountered the result is
    /// `GateStatus::GateUnsolvable`.
    fn solve_fan_in_term(
        arith_gate: &Expression,
        witness_assignments: &BTreeMap<Witness, FieldElement>,
    ) -> GateStatus {
        let mut unknown: Option<(FieldElement, Witness)> = None;
        let mut known_sum = FieldElement::zero();

        for term in &arith_gate.linear_combinations {
            match witness_assignments.get(&term.1) {
                Some(value) => known_sum += term.0 * *value,
                // A second unknown term makes the linear part unsolvable.
                None if unknown.is_some() => return GateStatus::GateUnsolvable,
                None => unknown = Some(*term),
            }
        }

        match unknown {
            Some(term) => GateStatus::GateSolvable(known_sum, term),
            None => GateStatus::GateSatisfied(known_sum),
        }
    }
}
/// Solves two chained linear gates and checks both derived witnesses.
///
/// `gate_a` encodes `a - b - c - d = 0` and `gate_b` encodes `e - a - b = 0`.
/// With `b = 2`, `c = 1`, `d = 1` the solver must derive `a = 4` and then,
/// using that value, `e = 6`.
#[test]
fn arithmetic_smoke_test() {
    let a = Witness(0);
    let b = Witness(1);
    let c = Witness(2);
    let d = Witness(3);
    let e = Witness(4);

    // a = b + c + d
    let gate_a = Expression {
        mul_terms: vec![],
        linear_combinations: vec![
            (FieldElement::one(), a),
            (-FieldElement::one(), b),
            (-FieldElement::one(), c),
            (-FieldElement::one(), d),
        ],
        q_c: FieldElement::zero(),
    };

    // e = a + b
    let gate_b = Expression {
        mul_terms: vec![],
        linear_combinations: vec![
            (FieldElement::one(), e),
            (-FieldElement::one(), a),
            (-FieldElement::one(), b),
        ],
        q_c: FieldElement::zero(),
    };

    let mut values: BTreeMap<Witness, FieldElement> = BTreeMap::new();
    values.insert(b, FieldElement::from(2_i128));
    values.insert(c, FieldElement::from(1_i128));
    values.insert(d, FieldElement::from(1_i128));

    assert_eq!(ArithmeticSolver::solve(&mut values, &gate_a), Ok(()));
    assert_eq!(values.get(&a), Some(&FieldElement::from(4_i128)));

    // gate_b depends on the value of `a` derived above.
    assert_eq!(ArithmeticSolver::solve(&mut values, &gate_b), Ok(()));
    assert_eq!(values.get(&e), Some(&FieldElement::from(6_i128)));
}