use acir::{
circuit::{brillig::BrilligOutputs, directives::Directive, Circuit, ExpressionWidth, Opcode},
native_types::{Expression, Witness},
FieldElement,
};
use indexmap::IndexMap;
mod csat;
pub(crate) use csat::CSatTransformer;
use super::{transform_assert_messages, AcirTransformationMap};
/// Transforms `acir` so that its `AssertZero` opcodes are rewritten for `expression_width`.
///
/// Returns the transformed circuit together with an [`AcirTransformationMap`]. The map is built
/// from the final opcode position list: one entry per opcode in the returned circuit, holding
/// the index of the opcode in the input circuit that it was produced from. The circuit's assert
/// messages are re-keyed through that map before returning.
pub fn transform(
    acir: Circuit,
    expression_width: ExpressionWidth,
) -> (Circuit, AcirTransformationMap) {
    // Before any rewriting, every opcode originates from its own index.
    let identity_positions: Vec<usize> = (0..acir.opcodes.len()).collect();

    let (mut transformed, final_positions) =
        transform_internal(acir, expression_width, identity_positions);

    let transformation_map = AcirTransformationMap::new(final_positions);
    transformed.assert_messages =
        transform_assert_messages(transformed.assert_messages, &transformation_map);

    (transformed, transformation_map)
}
#[tracing::instrument(level = "trace", name = "transform_acir", skip(acir, acir_opcode_positions))]
pub(super) fn transform_internal(
acir: Circuit,
expression_width: ExpressionWidth,
acir_opcode_positions: Vec<usize>,
) -> (Circuit, Vec<usize>) {
let mut transformer = match &expression_width {
ExpressionWidth::Unbounded => {
return (acir, acir_opcode_positions);
}
ExpressionWidth::Bounded { width } => {
let mut csat = CSatTransformer::new(*width);
for value in acir.circuit_arguments() {
csat.mark_solvable(value);
}
csat
}
};
let mut new_acir_opcode_positions: Vec<usize> = Vec::with_capacity(acir_opcode_positions.len());
let mut transformed_opcodes = Vec::new();
let mut next_witness_index = acir.current_witness_index + 1;
let mut intermediate_variables: IndexMap<Expression, (FieldElement, Witness)> = IndexMap::new();
for (index, opcode) in acir.opcodes.into_iter().enumerate() {
match opcode {
Opcode::AssertZero(arith_expr) => {
let len = intermediate_variables.len();
let arith_expr = transformer.transform(
arith_expr,
&mut intermediate_variables,
&mut next_witness_index,
);
next_witness_index += (intermediate_variables.len() - len) as u32;
let mut new_opcodes = Vec::new();
for (g, (norm, w)) in intermediate_variables.iter().skip(len) {
let mut intermediate_opcode = g * *norm;
intermediate_opcode.linear_combinations.push((-FieldElement::one(), *w));
intermediate_opcode.sort();
new_opcodes.push(intermediate_opcode);
}
new_opcodes.push(arith_expr);
for opcode in new_opcodes {
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(Opcode::AssertZero(opcode));
}
}
Opcode::BlackBoxFuncCall(ref func) => {
for witness in func.get_outputs_vec() {
transformer.mark_solvable(witness);
}
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(opcode);
}
Opcode::Directive(ref directive) => {
match directive {
Directive::ToLeRadix { b, .. } => {
for witness in b {
transformer.mark_solvable(*witness);
}
}
}
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(opcode);
}
Opcode::MemoryInit { .. } => {
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(opcode);
}
Opcode::MemoryOp { ref op, .. } => {
for (_, witness1, witness2) in &op.value.mul_terms {
transformer.mark_solvable(*witness1);
transformer.mark_solvable(*witness2);
}
for (_, witness) in &op.value.linear_combinations {
transformer.mark_solvable(*witness);
}
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(opcode);
}
Opcode::BrilligCall { ref outputs, .. } => {
for output in outputs {
match output {
BrilligOutputs::Simple(w) => transformer.mark_solvable(*w),
BrilligOutputs::Array(v) => {
for witness in v {
transformer.mark_solvable(*witness);
}
}
}
}
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(opcode);
}
Opcode::Call { ref outputs, .. } => {
for witness in outputs {
transformer.mark_solvable(*witness);
}
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(opcode);
}
}
}
let current_witness_index = next_witness_index - 1;
let acir = Circuit {
current_witness_index,
expression_width,
opcodes: transformed_opcodes,
..acir
};
(acir, new_acir_opcode_positions)
}