use acir::{
circuit::{
opcodes::{BlackBoxFuncCall, FunctionInput},
Circuit, Opcode,
},
native_types::Witness,
};
use std::collections::{BTreeMap, HashSet};
pub(crate) struct RangeOptimizer {
lists: BTreeMap<Witness, u32>,
circuit: Circuit,
}
impl RangeOptimizer {
pub(crate) fn new(circuit: Circuit) -> Self {
let range_list = Self::collect_ranges(&circuit);
Self { circuit, lists: range_list }
}
fn collect_ranges(circuit: &Circuit) -> BTreeMap<Witness, u32> {
let mut witness_to_bit_sizes: BTreeMap<Witness, u32> = BTreeMap::new();
for opcode in &circuit.opcodes {
let Some((witness, num_bits)) = (match opcode {
Opcode::AssertZero(expr) => {
if expr.is_degree_one_univariate() {
let (k, witness) = expr.linear_combinations[0];
let constant = expr.q_c;
let witness_value = -constant / k;
if witness_value.is_zero() {
Some((witness, 0))
} else {
let implied_range_constraint_bits = witness_value.num_bits() - 1;
Some((witness, implied_range_constraint_bits))
}
} else {
None
}
}
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE {
input: FunctionInput { witness, num_bits },
}) => Some((*witness, *num_bits)),
_ => None,
}) else {
continue;
};
witness_to_bit_sizes
.entry(witness)
.and_modify(|old_range_bits| {
*old_range_bits = std::cmp::min(*old_range_bits, num_bits);
})
.or_insert(num_bits);
}
witness_to_bit_sizes
}
pub(crate) fn replace_redundant_ranges(self, order_list: Vec<usize>) -> (Circuit, Vec<usize>) {
let mut already_seen_witness = HashSet::new();
let mut new_order_list = Vec::with_capacity(order_list.len());
let mut optimized_opcodes = Vec::with_capacity(self.circuit.opcodes.len());
for (idx, opcode) in self.circuit.opcodes.into_iter().enumerate() {
let (witness, num_bits) = match &opcode {
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { input }) => {
(input.witness, input.num_bits)
}
_ => {
optimized_opcodes.push(opcode);
new_order_list.push(order_list[idx]);
continue;
}
};
let already_added = already_seen_witness.contains(&witness);
if already_added {
continue;
}
let stored_num_bits = self.lists.get(&witness).expect("Could not find witness. This should never be the case if `collect_ranges` is called");
let is_lowest_bit_size = num_bits <= *stored_num_bits;
if is_lowest_bit_size {
already_seen_witness.insert(witness);
new_order_list.push(order_list[idx]);
optimized_opcodes.push(opcode);
}
}
(Circuit { opcodes: optimized_opcodes, ..self.circuit }, new_order_list)
}
}
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use crate::compiler::optimizers::redundant_range::RangeOptimizer;
    use acir::{
        circuit::{
            opcodes::{BlackBoxFuncCall, FunctionInput},
            Circuit, ExpressionWidth, Opcode, PublicInputs,
        },
        native_types::{Expression, Witness},
    };

    /// Builds a `RANGE` black box opcode constraining `witness` to `num_bits` bits.
    fn range_opcode(witness: Witness, num_bits: u32) -> Opcode {
        Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE {
            input: FunctionInput { witness, num_bits },
        })
    }

    /// Builds a minimal circuit containing one `RANGE` opcode per `(witness, bits)` pair, in
    /// the given order.
    fn test_circuit(ranges: Vec<(Witness, u32)>) -> Circuit {
        let mut opcodes = Vec::with_capacity(ranges.len());
        for (witness, num_bits) in ranges {
            opcodes.push(range_opcode(witness, num_bits));
        }

        Circuit {
            current_witness_index: 1,
            expression_width: ExpressionWidth::Bounded { width: 4 },
            opcodes,
            private_parameters: BTreeSet::new(),
            public_parameters: PublicInputs::default(),
            return_values: PublicInputs::default(),
            assert_messages: Default::default(),
            recursive: false,
        }
    }

    /// The identity opcode ordering for `circuit`: position `i` maps to `i`.
    fn identity_positions(circuit: &Circuit) -> Vec<usize> {
        (0..circuit.opcodes.len()).collect()
    }

    #[test]
    fn retain_lowest_range_size() {
        // Two range constraints on the same witness: only the 16-bit one should survive.
        let circuit = test_circuit(vec![(Witness(1), 32), (Witness(1), 16)]);
        let positions = identity_positions(&circuit);

        let optimizer = RangeOptimizer::new(circuit);
        let recorded_bits = optimizer
            .lists
            .get(&Witness(1))
            .copied()
            .expect("Witness(1) was inserted, but it is missing from the map");
        assert_eq!(
            recorded_bits, 16,
            "expected a range size of 16 since that was the lowest bit size provided"
        );

        let (optimized, _) = optimizer.replace_redundant_ranges(positions);
        assert_eq!(optimized.opcodes.len(), 1);
        assert_eq!(optimized.opcodes[0], range_opcode(Witness(1), 16));
    }

    #[test]
    fn remove_duplicates() {
        // Exact duplicates per witness: one of each should remain, in original order.
        let circuit = test_circuit(vec![
            (Witness(1), 16),
            (Witness(1), 16),
            (Witness(2), 23),
            (Witness(2), 23),
        ]);
        let positions = identity_positions(&circuit);

        let (optimized, _) = RangeOptimizer::new(circuit).replace_redundant_ranges(positions);

        assert_eq!(optimized.opcodes.len(), 2);
        assert_eq!(optimized.opcodes[0], range_opcode(Witness(1), 16));
        assert_eq!(optimized.opcodes[1], range_opcode(Witness(2), 23));
    }

    #[test]
    fn non_range_opcodes() {
        // Non-range opcodes must pass through untouched, even alongside duplicate ranges.
        let mut circuit = test_circuit(vec![(Witness(1), 16), (Witness(1), 16)]);
        for _ in 0..4 {
            circuit.opcodes.push(Opcode::AssertZero(Expression::default()));
        }
        let positions = identity_positions(&circuit);

        let (optimized, _) = RangeOptimizer::new(circuit).replace_redundant_ranges(positions);

        assert_eq!(optimized.opcodes.len(), 5);
    }

    #[test]
    fn constant_implied_ranges() {
        // `w1 == 0` implies a 0-bit range, so the explicit 16-bit range is redundant.
        let mut circuit = test_circuit(vec![(Witness(1), 16)]);
        circuit.opcodes.push(Opcode::AssertZero(Witness(1).into()));
        let positions = identity_positions(&circuit);

        let (optimized, _) = RangeOptimizer::new(circuit).replace_redundant_ranges(positions);

        assert_eq!(optimized.opcodes.len(), 1);
        assert_eq!(optimized.opcodes[0], Opcode::AssertZero(Witness(1).into()));
    }
}