use std::collections::HashMap;
use acir::circuit::{AssertionPayload, Circuit, ExpressionWidth, OpcodeLocation};
mod optimizers;
mod transformers;
pub use optimizers::optimize;
use optimizers::optimize_internal;
pub use transformers::transform;
use transformers::transform_internal;
/// Relates opcode indices of the circuit passed to `compile` ("old" indices) to the indices of
/// the opcodes they were turned into in the circuit `compile` returns ("new" indices).
///
/// A single old opcode may map to several new opcodes. An old opcode with no entry has no
/// counterpart in the output circuit.
#[derive(Debug)]
pub struct AcirTransformationMap {
    /// Keyed by old ACIR opcode index. Each value lists the new indices derived from that opcode,
    /// in ascending order (they are pushed while enumerating the new positions in order).
    old_indices_to_new_indices: HashMap<usize, Vec<usize>>,
}

impl AcirTransformationMap {
    /// Builds the map from `acir_opcode_positions`, where element `i` is the old index of the
    /// opcode now located at new index `i`.
    fn new(acir_opcode_positions: Vec<usize>) -> Self {
        // The explicit type lets `or_default()` resolve to `Vec<usize>` before `.push` is called.
        let mut old_indices_to_new_indices: HashMap<usize, Vec<usize>> =
            HashMap::with_capacity(acir_opcode_positions.len());
        for (new_index, old_index) in acir_opcode_positions.into_iter().enumerate() {
            old_indices_to_new_indices.entry(old_index).or_default().push(new_index);
        }
        AcirTransformationMap { old_indices_to_new_indices }
    }

    /// Returns every location in the output circuit that corresponds to `old_location`.
    ///
    /// Only the ACIR opcode index is remapped. For a `Brillig` location, the `brillig_index` is
    /// carried over unchanged. Locations are yielded in ascending order of new ACIR index. The
    /// iterator is empty when the old ACIR index has no entry in the map.
    pub fn new_locations(
        &self,
        old_location: OpcodeLocation,
    ) -> impl Iterator<Item = OpcodeLocation> + '_ {
        let old_acir_index = match old_location {
            OpcodeLocation::Acir(index) => index,
            OpcodeLocation::Brillig { acir_index, .. } => acir_index,
        };

        self.old_indices_to_new_indices
            .get(&old_acir_index)
            .into_iter()
            .flatten()
            .map(move |&new_index| match old_location {
                OpcodeLocation::Acir(_) => OpcodeLocation::Acir(new_index),
                OpcodeLocation::Brillig { brillig_index, .. } => {
                    OpcodeLocation::Brillig { acir_index: new_index, brillig_index }
                }
            })
    }
}
/// Rewrites the locations of `assert_messages` so they refer to opcodes in the output circuit.
///
/// A message is cloned once for every new location its old location maps to. Messages whose
/// location has no entry in `map` are dropped, because `new_locations` yields nothing for them.
fn transform_assert_messages(
    assert_messages: Vec<(OpcodeLocation, AssertionPayload)>,
    map: &AcirTransformationMap,
) -> Vec<(OpcodeLocation, AssertionPayload)> {
    assert_messages
        .into_iter()
        .flat_map(|(location, message)| {
            // `new_locations` already returns an iterator; no `into_iter()` conversion is needed.
            map.new_locations(location).map(move |new_location| (new_location, message.clone()))
        })
        .collect()
}
/// Runs the optimization passes and then the transformation passes (for `expression_width`) over
/// `acir`.
///
/// Returns the resulting circuit together with an [`AcirTransformationMap`] built from the final
/// opcode positions. The circuit's `assert_messages` are remapped through that map so that they
/// point at locations in the returned circuit.
pub fn compile(
    acir: Circuit,
    expression_width: ExpressionWidth,
) -> (Circuit, AcirTransformationMap) {
    let (optimized_circuit, optimized_positions) = optimize_internal(acir);
    // `final_positions[i]` is interpreted by `AcirTransformationMap::new` as the old index of
    // output opcode `i`.
    let (mut compiled_circuit, final_positions) =
        transform_internal(optimized_circuit, expression_width, optimized_positions);

    let transformation_map = AcirTransformationMap::new(final_positions);

    // Assertion messages are keyed by opcode location, so they have to follow the opcodes.
    let remapped_messages =
        transform_assert_messages(compiled_circuit.assert_messages, &transformation_map);
    compiled_circuit.assert_messages = remapped_messages;

    (compiled_circuit, transformation_map)
}