use std::collections::HashMap;
use acir::{
brillig::ForeignCallResult,
circuit::{
brillig::BrilligBytecode, opcodes::BlockId, AssertionPayload, ErrorSelector,
ExpressionOrMemory, Opcode, OpcodeLocation, RawAssertionPayload, ResolvedAssertionPayload,
STRING_ERROR_SELECTOR,
},
native_types::{Expression, Witness, WitnessMap},
BlackBoxFunc, FieldElement,
};
use acvm_blackbox_solver::BlackBoxResolutionError;
use self::{
arithmetic::ExpressionSolver, blackbox::bigint::AcvmBigIntSolver, directives::solve_directives,
memory_op::MemoryOpSolver,
};
use crate::BlackBoxFunctionSolver;
use thiserror::Error;
pub(crate) mod arithmetic;
pub(crate) mod brillig;
pub(crate) mod directives;
pub(crate) mod blackbox;
mod memory_op;
pub use self::brillig::{BrilligSolver, BrilligSolverStatus};
pub use brillig::ForeignCallWaitInfo;
/// Execution state of an [`ACVM`].
#[derive(Debug, Clone, PartialEq)]
pub enum ACVMStatus {
    /// Every opcode has been executed (also the initial state when there are no opcodes).
    Solved,
    /// There are still opcodes left to execute.
    InProgress,
    /// Execution stopped because an opcode could not be resolved.
    Failure(OpcodeResolutionError),
    /// Execution is paused until the caller answers the foreign call through
    /// `ACVM::resolve_pending_foreign_call`.
    RequiresForeignCall(ForeignCallWaitInfo),
    /// Execution is paused until the caller supplies the outputs of an ACIR call through
    /// `ACVM::resolve_pending_acir_call`.
    RequiresAcirCall(AcirCallWaitInfo),
}
impl std::fmt::Display for ACVMStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ACVMStatus::Solved => write!(f, "Solved"),
ACVMStatus::InProgress => write!(f, "In progress"),
ACVMStatus::Failure(_) => write!(f, "Execution failure"),
ACVMStatus::RequiresForeignCall(_) => write!(f, "Waiting on foreign call"),
ACVMStatus::RequiresAcirCall(_) => write!(f, "Waiting on acir call"),
}
}
}
/// Outcome of [`ACVM::step_into_brillig`].
pub enum StepResult<'a, B: BlackBoxFunctionSolver> {
    /// The current opcode was handled without entering Brillig; carries the resulting status.
    Status(ACVMStatus),
    /// A freshly created Brillig solver for the current `BrilligCall` opcode. The caller drives
    /// it and hands it back through [`ACVM::finish_brillig_with_solver`].
    IntoBrillig(BrilligSolver<'a, B>),
}
/// Reasons an opcode cannot be solved with the witness values currently known.
#[derive(Clone, PartialEq, Eq, Debug, Error)]
pub enum OpcodeNotSolvable {
    /// The witness with this index has no assigned value.
    #[error("missing assignment for witness index {0}")]
    MissingAssignment(u32),
    /// A memory block was read before being initialized.
    /// NOTE(review): the `u32` is presumably the block id; it is not included in the message.
    #[error("Attempted to load uninitialized memory block")]
    MissingMemoryBlock(u32),
    /// The (partially evaluated) expression still has too many unknown witnesses to solve.
    #[error("expression has too many unknowns {0}")]
    ExpressionHasTooManyUnknowns(Expression),
}
/// Location of the opcode an error is attributed to, once it is known.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
pub enum ErrorLocation {
    /// No location has been attached yet (e.g. errors raised by helpers such as
    /// `insert_value`, which have no knowledge of the current opcode).
    #[default]
    Unresolved,
    /// The error was attributed to this opcode location.
    Resolved(OpcodeLocation),
}
impl std::fmt::Display for ErrorLocation {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ErrorLocation::Unresolved => write!(f, "unresolved"),
ErrorLocation::Resolved(location) => {
write!(f, "{location}")
}
}
}
}
/// Errors that can occur while resolving an opcode.
#[derive(Clone, PartialEq, Eq, Debug, Error)]
pub enum OpcodeResolutionError {
    /// The opcode lacks the information required to solve it.
    #[error("Cannot solve opcode: {0}")]
    OpcodeNotSolvable(#[from] OpcodeNotSolvable),
    /// A constraint does not hold, e.g. a witness would be assigned a value that conflicts with
    /// its existing assignment. `payload` carries the resolved assertion message, if one was
    /// registered for `opcode_location`.
    #[error("Cannot satisfy constraint")]
    UnsatisfiedConstrain {
        opcode_location: ErrorLocation,
        payload: Option<ResolvedAssertionPayload>,
    },
    /// An index exceeded the size of the array being accessed.
    #[error("Index out of bounds, array has size {array_size:?}, but index was {index:?}")]
    IndexOutOfBounds { opcode_location: ErrorLocation, index: u32, array_size: u32 },
    /// A black box function failed; carries the function and the failure reason.
    #[error("Failed to solve blackbox function: {0}, reason: {1}")]
    BlackBoxFunctionFailed(BlackBoxFunc, String),
    /// An unconstrained (Brillig) function failed. `call_stack` lists the opcode locations that
    /// led to the failure; `payload` is the resolved assertion message, if any.
    #[error("Failed to solve brillig function")]
    BrilligFunctionFailed {
        call_stack: Vec<OpcodeLocation>,
        payload: Option<ResolvedAssertionPayload>,
    },
    /// A `Call` opcode targeted function id 0, which is reserved for `main`.
    #[error("Attempted to call `main` with a `Call` opcode")]
    AcirMainCallAttempted { opcode_location: ErrorLocation },
    /// The number of values supplied for an ACIR call differs from the call's output count.
    #[error("{results_size:?} result values were provided for {outputs_size:?} call output witnesses, most likely due to bad ACIR codegen")]
    AcirCallOutputsMismatch { opcode_location: ErrorLocation, results_size: u32, outputs_size: u32 },
}
impl From<BlackBoxResolutionError> for OpcodeResolutionError {
    /// Wraps a black box solver failure, keeping the failing function and its reason.
    fn from(value: BlackBoxResolutionError) -> Self {
        // `BlackBoxResolutionError` has a single variant, so this pattern is irrefutable.
        let BlackBoxResolutionError::Failed(func, reason) = value;
        Self::BlackBoxFunctionFailed(func, reason)
    }
}
/// Virtual machine that executes a list of ACIR opcodes, filling in the witness map.
pub struct ACVM<'a, B: BlackBoxFunctionSolver> {
    // Current execution status.
    status: ACVMStatus,
    // Black box function solver, passed to black box opcodes and Brillig solvers.
    backend: &'a B,
    // One memory solver per memory block, created on first use of the block.
    block_solvers: HashMap<BlockId, MemoryOpSolver>,
    // State shared by bigint black box calls.
    bigint_solver: AcvmBigIntSolver,
    // Opcodes being executed.
    opcodes: &'a [Opcode],
    // Index into `opcodes` of the next opcode to execute.
    instruction_pointer: usize,
    // Current witness assignments.
    witness_map: WitnessMap,
    // Brillig solver paused on a foreign call; resumed when the `BrilligCall` opcode is retried.
    brillig_solver: Option<BrilligSolver<'a, B>>,
    // Number of entries of `acir_call_results` already consumed by `Call` opcodes.
    acir_call_counter: usize,
    // Outputs of ACIR calls, supplied via `resolve_pending_acir_call`, in call order.
    acir_call_results: Vec<Vec<FieldElement>>,
    // Brillig bytecode for each unconstrained function, indexed by the `BrilligCall` id.
    unconstrained_functions: &'a [BrilligBytecode],
    // Assertion payload descriptors keyed by the opcode location they belong to.
    assertion_payloads: &'a [(OpcodeLocation, AssertionPayload)],
}
impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
    /// Creates a VM that will execute `opcodes` starting from `initial_witness`.
    ///
    /// The VM starts out [`ACVMStatus::Solved`] when there are no opcodes and
    /// [`ACVMStatus::InProgress`] otherwise.
    pub fn new(
        backend: &'a B,
        opcodes: &'a [Opcode],
        initial_witness: WitnessMap,
        unconstrained_functions: &'a [BrilligBytecode],
        assertion_payloads: &'a [(OpcodeLocation, AssertionPayload)],
    ) -> Self {
        let status = if opcodes.is_empty() { ACVMStatus::Solved } else { ACVMStatus::InProgress };
        ACVM {
            status,
            backend,
            block_solvers: HashMap::default(),
            bigint_solver: AcvmBigIntSolver::default(),
            opcodes,
            instruction_pointer: 0,
            witness_map: initial_witness,
            brillig_solver: None,
            acir_call_counter: 0,
            acir_call_results: Vec::default(),
            unconstrained_functions,
            assertion_payloads,
        }
    }

    /// Returns the current witness assignments.
    pub fn witness_map(&self) -> &WitnessMap {
        &self.witness_map
    }

    /// Sets `witness` to `value` unconditionally, returning the previous value if there was one.
    ///
    /// Unlike opcode solving, no consistency check against an existing value is performed.
    pub fn overwrite_witness(
        &mut self,
        witness: Witness,
        value: FieldElement,
    ) -> Option<FieldElement> {
        self.witness_map.insert(witness, value)
    }

    /// Returns the opcodes being executed.
    pub fn opcodes(&self) -> &[Opcode] {
        self.opcodes
    }

    /// Returns the index of the next opcode to execute.
    pub fn instruction_pointer(&self) -> usize {
        self.instruction_pointer
    }

    /// Consumes the VM and returns the final witness map.
    ///
    /// # Panics
    /// Panics if the status is not [`ACVMStatus::Solved`].
    pub fn finalize(self) -> WitnessMap {
        if self.status != ACVMStatus::Solved {
            panic!("ACVM execution is not complete: ({})", self.status);
        }
        self.witness_map
    }

    /// Records `status` as the current status and returns it.
    fn status(&mut self, status: ACVMStatus) -> ACVMStatus {
        self.status = status.clone();
        status
    }

    /// Returns the current status.
    pub fn get_status(&self) -> &ACVMStatus {
        &self.status
    }

    /// Transitions to [`ACVMStatus::Failure`] with `error`.
    fn fail(&mut self, error: OpcodeResolutionError) -> ACVMStatus {
        self.status(ACVMStatus::Failure(error))
    }

    /// Transitions to [`ACVMStatus::RequiresForeignCall`].
    fn wait_for_foreign_call(&mut self, foreign_call: ForeignCallWaitInfo) -> ACVMStatus {
        self.status(ACVMStatus::RequiresForeignCall(foreign_call))
    }

    /// Returns the foreign call the VM is waiting on, if any.
    pub fn get_pending_foreign_call(&self) -> Option<&ForeignCallWaitInfo> {
        if let ACVMStatus::RequiresForeignCall(foreign_call) = &self.status {
            Some(foreign_call)
        } else {
            None
        }
    }

    /// Passes the result of the pending foreign call to the paused Brillig solver and marks the
    /// VM as in progress again.
    ///
    /// # Panics
    /// Panics if the VM is not waiting on a foreign call or has no paused Brillig solver.
    pub fn resolve_pending_foreign_call(&mut self, foreign_call_result: ForeignCallResult) {
        if !matches!(self.status, ACVMStatus::RequiresForeignCall(_)) {
            panic!("ACVM is not expecting a foreign call response as no call was made");
        }
        let brillig_solver = self.brillig_solver.as_mut().expect("No active Brillig solver");
        brillig_solver.resolve_pending_foreign_call(foreign_call_result);
        self.status(ACVMStatus::InProgress);
    }

    /// Transitions to [`ACVMStatus::RequiresAcirCall`].
    fn wait_for_acir_call(&mut self, acir_call: AcirCallWaitInfo) -> ACVMStatus {
        self.status(ACVMStatus::RequiresAcirCall(acir_call))
    }

    /// Stores the outputs of the pending ACIR call and marks the VM as in progress again. The
    /// stored values are consumed when the `Call` opcode is re-executed.
    ///
    /// # Panics
    /// Panics if the VM is not waiting on an ACIR call, or if an earlier result has not yet
    /// been consumed.
    pub fn resolve_pending_acir_call(&mut self, call_result: Vec<FieldElement>) {
        if !matches!(self.status, ACVMStatus::RequiresAcirCall(_)) {
            panic!("ACVM is not expecting an ACIR call response as no call was made");
        }
        // An unconsumed stored result means no call is actually outstanding.
        if self.acir_call_counter < self.acir_call_results.len() {
            panic!("No unresolved ACIR calls");
        }
        self.acir_call_results.push(call_result);
        self.status(ACVMStatus::InProgress);
    }

    /// Executes opcodes until the status is no longer [`ACVMStatus::InProgress`] and returns
    /// the resulting status.
    pub fn solve(&mut self) -> ACVMStatus {
        while self.status == ACVMStatus::InProgress {
            self.solve_opcode();
        }
        self.status.clone()
    }

    /// Executes the opcode at the instruction pointer and returns the new status.
    ///
    /// # Panics
    /// Panics if the instruction pointer is past the last opcode.
    pub fn solve_opcode(&mut self) -> ACVMStatus {
        let opcode = &self.opcodes[self.instruction_pointer];
        let resolution = match opcode {
            Opcode::AssertZero(expr) => ExpressionSolver::solve(&mut self.witness_map, expr),
            Opcode::BlackBoxFuncCall(bb_func) => blackbox::solve(
                self.backend,
                &mut self.witness_map,
                bb_func,
                &mut self.bigint_solver,
            ),
            Opcode::Directive(directive) => solve_directives(&mut self.witness_map, directive),
            Opcode::MemoryInit { block_id, init } => {
                let solver = self.block_solvers.entry(*block_id).or_default();
                solver.init(init, &self.witness_map)
            }
            Opcode::MemoryOp { block_id, op, predicate } => {
                let solver = self.block_solvers.entry(*block_id).or_default();
                solver.solve_memory_op(op, &mut self.witness_map, predicate)
            }
            // Pausing for a foreign/ACIR call must not advance the instruction pointer, so
            // those cases return directly instead of going through `handle_opcode_resolution`.
            Opcode::BrilligCall { .. } => match self.solve_brillig_call_opcode() {
                Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call),
                res => res.map(|_| ()),
            },
            Opcode::Call { .. } => match self.solve_call_opcode() {
                Ok(Some(input_values)) => return self.wait_for_acir_call(input_values),
                res => res.map(|_| ()),
            },
        };
        self.handle_opcode_resolution(resolution)
    }

    /// Advances past the current opcode on success, or records the failure.
    ///
    /// On failure, location-carrying errors get the current opcode location attached, and
    /// unsatisfied constraints additionally get their assertion payload resolved.
    fn handle_opcode_resolution(
        &mut self,
        resolution: Result<(), OpcodeResolutionError>,
    ) -> ACVMStatus {
        match resolution {
            Ok(()) => {
                self.instruction_pointer += 1;
                if self.instruction_pointer == self.opcodes.len() {
                    self.status(ACVMStatus::Solved)
                } else {
                    self.status(ACVMStatus::InProgress)
                }
            }
            Err(mut error) => {
                match &mut error {
                    OpcodeResolutionError::IndexOutOfBounds {
                        opcode_location: opcode_index,
                        ..
                    } => {
                        *opcode_index = ErrorLocation::Resolved(OpcodeLocation::Acir(
                            self.instruction_pointer(),
                        ));
                    }
                    OpcodeResolutionError::UnsatisfiedConstrain {
                        opcode_location: opcode_index,
                        payload: assertion_payload,
                    } => {
                        let location = OpcodeLocation::Acir(self.instruction_pointer());
                        *opcode_index = ErrorLocation::Resolved(location);
                        *assertion_payload = self.extract_assertion_payload(location);
                    }
                    _ => (),
                };
                self.fail(error)
            }
        }
    }

    /// Resolves the assertion payload registered for `location`, if any.
    ///
    /// Static strings are returned as-is. Dynamic payloads are evaluated from the current
    /// witness map and memory blocks; `None` is returned if any value is not yet available.
    /// Payloads tagged with `STRING_ERROR_SELECTOR` are decoded as one byte per field. If a
    /// field does not fit in a byte, the raw payload is returned instead of panicking.
    fn extract_assertion_payload(
        &self,
        location: OpcodeLocation,
    ) -> Option<ResolvedAssertionPayload> {
        let (_, found_assertion_payload) =
            self.assertion_payloads.iter().find(|(loc, _)| location == *loc)?;
        match found_assertion_payload {
            AssertionPayload::StaticString(string) => {
                Some(ResolvedAssertionPayload::String(string.clone()))
            }
            AssertionPayload::Dynamic(error_selector, expression) => {
                let mut fields = vec![];
                for expr in expression {
                    match expr {
                        ExpressionOrMemory::Expression(expr) => {
                            let value = get_value(expr, &self.witness_map).ok()?;
                            fields.push(value);
                        }
                        ExpressionOrMemory::Memory(block_id) => {
                            let memory_block = self.block_solvers.get(block_id)?;
                            fields.extend((0..memory_block.block_len).map(|memory_index| {
                                *memory_block
                                    .block_value
                                    .get(&memory_index)
                                    .expect("All memory is initialized on creation")
                            }));
                        }
                    }
                }
                let error_selector = ErrorSelector::new(*error_selector);
                // Decoding happens while reporting an already-failing constraint, so a
                // malformed string payload must not abort execution with a panic.
                let decoded_string = match error_selector {
                    STRING_ERROR_SELECTOR => fields
                        .iter()
                        .map(|field| {
                            let byte = u8::try_from(field.try_to_u64()?).ok()?;
                            Some(char::from(byte))
                        })
                        .collect::<Option<String>>(),
                    _ => None,
                };
                Some(match decoded_string {
                    Some(string) => ResolvedAssertionPayload::String(string),
                    None => ResolvedAssertionPayload::Raw(RawAssertionPayload {
                        selector: error_selector,
                        data: fields,
                    }),
                })
            }
        }
    }

    /// Executes (or resumes) the `BrilligCall` opcode at the instruction pointer.
    ///
    /// Returns `Ok(Some(_))` when the Brillig solver is paused on a foreign call; the solver is
    /// kept in `self.brillig_solver` so that a later call resumes it. Returns `Ok(None)` once
    /// the outputs have been written (or zeroed when the predicate is false).
    fn solve_brillig_call_opcode(
        &mut self,
    ) -> Result<Option<ForeignCallWaitInfo>, OpcodeResolutionError> {
        let Opcode::BrilligCall { id, inputs, outputs, predicate } =
            &self.opcodes[self.instruction_pointer]
        else {
            unreachable!("Not executing a BrilligCall opcode");
        };
        if is_predicate_false(&self.witness_map, predicate)? {
            return BrilligSolver::<B>::zero_out_brillig_outputs(&mut self.witness_map, outputs)
                .map(|_| None);
        }
        // Resume a solver paused on a foreign call, or start a fresh one.
        let mut solver: BrilligSolver<'_, B> = match self.brillig_solver.take() {
            Some(solver) => solver,
            None => BrilligSolver::new_call(
                &self.witness_map,
                &self.block_solvers,
                inputs,
                &self.unconstrained_functions[*id as usize].bytecode,
                self.backend,
                self.instruction_pointer,
            )?,
        };
        let result = solver.solve().map_err(|err| self.map_brillig_error(err))?;
        match result {
            BrilligSolverStatus::ForeignCallWait(foreign_call) => {
                self.brillig_solver = Some(solver);
                Ok(Some(foreign_call))
            }
            BrilligSolverStatus::InProgress => {
                unreachable!("Brillig solver still in progress")
            }
            BrilligSolverStatus::Finished => {
                solver.finalize(&mut self.witness_map, outputs)?;
                Ok(None)
            }
        }
    }

    /// Attaches a static assertion message to a Brillig failure when one is registered for the
    /// innermost location of its call stack. Other errors are returned unchanged.
    ///
    /// An empty call stack is tolerated (the payload is simply left as-is) rather than
    /// panicking while an error is being reported.
    fn map_brillig_error(&self, mut err: OpcodeResolutionError) -> OpcodeResolutionError {
        if let OpcodeResolutionError::BrilligFunctionFailed { call_stack, payload } = &mut err {
            if let Some(last_location) = call_stack.last() {
                let assertion_descriptor = self
                    .assertion_payloads
                    .iter()
                    .find(|(loc, _)| loc == last_location)
                    .map(|(_, descriptor)| descriptor);
                if let Some(AssertionPayload::StaticString(string)) = assertion_descriptor {
                    *payload = Some(ResolvedAssertionPayload::String(string.clone()));
                }
            }
        }
        err
    }

    /// Prepares to step through the current `BrilligCall` opcode under caller control.
    ///
    /// If the current opcode is not a `BrilligCall`, it is executed normally. If the predicate
    /// is false, the outputs are zeroed and the opcode completes. Otherwise a new Brillig
    /// solver is returned for the caller to drive.
    pub fn step_into_brillig(&mut self) -> StepResult<'a, B> {
        let Opcode::BrilligCall { id, inputs, outputs, predicate } =
            &self.opcodes[self.instruction_pointer]
        else {
            return StepResult::Status(self.solve_opcode());
        };
        let witness = &mut self.witness_map;
        let should_skip = match is_predicate_false(witness, predicate) {
            Ok(result) => result,
            Err(err) => return StepResult::Status(self.handle_opcode_resolution(Err(err))),
        };
        if should_skip {
            let resolution = BrilligSolver::<B>::zero_out_brillig_outputs(witness, outputs);
            return StepResult::Status(self.handle_opcode_resolution(resolution));
        }
        let solver = BrilligSolver::new_call(
            witness,
            &self.block_solvers,
            inputs,
            &self.unconstrained_functions[*id as usize].bytecode,
            self.backend,
            self.instruction_pointer,
        );
        match solver {
            Ok(solver) => StepResult::IntoBrillig(solver),
            Err(..) => StepResult::Status(self.handle_opcode_resolution(solver.map(|_| ()))),
        }
    }

    /// Completes the current `BrilligCall` opcode using a solver previously obtained from
    /// [`ACVM::step_into_brillig`].
    ///
    /// # Panics
    /// Panics if the current opcode is not a `BrilligCall`.
    pub fn finish_brillig_with_solver(&mut self, solver: BrilligSolver<'a, B>) -> ACVMStatus {
        if !matches!(self.opcodes[self.instruction_pointer], Opcode::BrilligCall { .. }) {
            unreachable!("Not executing a Brillig/BrilligCall opcode");
        }
        self.brillig_solver = Some(solver);
        self.solve_opcode()
    }

    /// Executes the `Call` opcode at the instruction pointer.
    ///
    /// Returns `Ok(Some(_))` with the call's inputs when no result has been supplied yet. Once
    /// a result is available (via [`ACVM::resolve_pending_acir_call`]), it is written to the
    /// output witnesses and `Ok(None)` is returned. When the predicate is false, the outputs
    /// are zeroed and no call is made.
    ///
    /// # Errors
    /// Returns an error when calling function id 0 (`main`), when an input witness is
    /// unassigned, or when the number of results differs from the number of outputs.
    pub fn solve_call_opcode(&mut self) -> Result<Option<AcirCallWaitInfo>, OpcodeResolutionError> {
        let Opcode::Call { id, inputs, outputs, predicate } =
            &self.opcodes[self.instruction_pointer]
        else {
            unreachable!("Not executing a Call opcode");
        };
        if *id == 0 {
            return Err(OpcodeResolutionError::AcirMainCallAttempted {
                opcode_location: ErrorLocation::Resolved(OpcodeLocation::Acir(
                    self.instruction_pointer(),
                )),
            });
        }
        if is_predicate_false(&self.witness_map, predicate)? {
            for output in outputs {
                insert_value(output, FieldElement::zero(), &mut self.witness_map)?;
            }
            return Ok(None);
        }
        // No result stored for this call yet: ask the caller to perform it. Inputs are
        // renumbered as witnesses 0..n of the callee.
        if self.acir_call_counter >= self.acir_call_results.len() {
            let mut initial_witness = WitnessMap::default();
            for (i, input_witness) in inputs.iter().enumerate() {
                let input_value = *witness_to_value(&self.witness_map, *input_witness)?;
                initial_witness.insert(Witness(i as u32), input_value);
            }
            return Ok(Some(AcirCallWaitInfo { id: *id, initial_witness }));
        }
        let result_values = &self.acir_call_results[self.acir_call_counter];
        if outputs.len() != result_values.len() {
            return Err(OpcodeResolutionError::AcirCallOutputsMismatch {
                opcode_location: ErrorLocation::Resolved(OpcodeLocation::Acir(
                    self.instruction_pointer(),
                )),
                results_size: result_values.len() as u32,
                outputs_size: outputs.len() as u32,
            });
        }
        for (output_witness, result_value) in outputs.iter().zip(result_values) {
            insert_value(output_witness, *result_value, &mut self.witness_map)?;
        }
        self.acir_call_counter += 1;
        Ok(None)
    }
}
/// Looks up the value assigned to `witness`.
///
/// # Errors
/// Returns [`OpcodeNotSolvable::MissingAssignment`] when the witness has no value yet.
pub fn witness_to_value(
    initial_witness: &WitnessMap,
    witness: Witness,
) -> Result<&FieldElement, OpcodeResolutionError> {
    initial_witness
        .get(&witness)
        .ok_or_else(|| OpcodeNotSolvable::MissingAssignment(witness.0).into())
}
/// Evaluates `expr` using the assignments in `initial_witness` and returns its value.
///
/// # Errors
/// Returns [`OpcodeNotSolvable::MissingAssignment`], naming one still-unassigned witness,
/// when the expression does not reduce to a constant.
pub fn get_value(
    expr: &Expression,
    initial_witness: &WitnessMap,
) -> Result<FieldElement, OpcodeResolutionError> {
    let expr = ExpressionSolver::evaluate(expr, initial_witness);
    match expr.to_const() {
        Some(value) => Ok(value),
        None => {
            // Assumes `to_const` only yields `None` while a multiplication or linear term
            // remains, so a witness can always be reported — verify against
            // `Expression::to_const`.
            let missing = any_witness_from_expression(&expr)
                .expect("non-constant expression must contain at least one witness term");
            Err(OpcodeNotSolvable::MissingAssignment(missing.0).into())
        }
    }
}
/// Assigns `value_to_insert` to `witness`.
///
/// The new value is always written. If the witness previously held a different value,
/// `UnsatisfiedConstrain` is returned with an unresolved location and no payload; callers
/// are expected to attach the location.
pub fn insert_value(
    witness: &Witness,
    value_to_insert: FieldElement,
    initial_witness: &mut WitnessMap,
) -> Result<(), OpcodeResolutionError> {
    match initial_witness.insert(*witness, value_to_insert) {
        Some(previous) if previous != value_to_insert => {
            Err(OpcodeResolutionError::UnsatisfiedConstrain {
                opcode_location: ErrorLocation::Unresolved,
                payload: None,
            })
        }
        _ => Ok(()),
    }
}
/// Returns some witness referenced by `expr`, used to report a missing assignment.
///
/// The first linear term is preferred; otherwise the first witness of the first
/// multiplication term is used. Returns `None` when the expression has no terms.
fn any_witness_from_expression(expr: &Expression) -> Option<Witness> {
    expr.linear_combinations
        .first()
        .map(|term| term.1)
        .or_else(|| expr.mul_terms.first().map(|term| term.1))
}
/// Reports whether an opcode's optional predicate evaluates to zero.
///
/// A missing predicate counts as true, so `Ok(false)` is returned.
///
/// # Errors
/// Propagates the error from evaluating the predicate (e.g. an unassigned witness).
pub(crate) fn is_predicate_false(
    witness: &WitnessMap,
    predicate: &Option<Expression>,
) -> Result<bool, OpcodeResolutionError> {
    let Some(pred) = predicate else {
        return Ok(false);
    };
    let pred_value = get_value(pred, witness)?;
    Ok(pred_value.is_zero())
}
/// Details of an ACIR call the VM is waiting on.
#[derive(Debug, Clone, PartialEq)]
pub struct AcirCallWaitInfo {
    /// Id of the ACIR function to call (never 0; calling `main` is rejected).
    pub id: u32,
    /// Inputs to the call, assigned to witnesses `0..n` in argument order.
    pub initial_witness: WitnessMap,
}