1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
use std::collections::{BTreeMap, BTreeSet, HashMap};
use crate::{
compiler::optimizers::GeneralOptimizer,
pwg::{
arithmetic::ExpressionSolver, blackbox::solve_range_opcode, directives::solve_directives,
BrilligSolver, BrilligSolverStatus,
},
};
use acir::{
circuit::{
brillig::{Brillig, BrilligInputs, BrilligOutputs},
directives::Directive,
opcodes::BlackBoxFuncCall,
Circuit, Opcode,
},
native_types::{Expression, Witness, WitnessMap},
};
use acvm_blackbox_solver::StubbedBlackBoxSolver;
/// `ConstantBackpropagationOptimizer` will attempt to determine any constant witnesses within the program.
/// It does this by attempting to solve the program without any inputs (i.e. using an empty witness map);
/// any values which it can determine are then enforced to be constant values.
///
/// The optimizer will then replace any witnesses wherever they appear within the circuit with these constant values.
/// Witnesses which form part of the circuit's interface (private/public parameters and return values), or which are
/// referenced by black box function calls or memory opcodes, are never replaced.
///
/// This is repeated until a round of optimization no longer reduces the number of opcodes in the circuit.
pub(crate) struct ConstantBackpropagationOptimizer {
/// The circuit being optimized. Each optimization round takes its opcodes out and rebuilds them.
circuit: Circuit,
}
impl ConstantBackpropagationOptimizer {
    /// Creates a new `ConstantBackpropagationOptimizer` which will optimize `circuit`.
    pub(crate) fn new(circuit: Circuit) -> Self {
        Self { circuit }
    }

    /// Computes the witnesses which can be resolved to constants before any opcode is rewritten.
    ///
    /// Returns `(known_witnesses, required_witnesses)` where:
    /// - `required_witnesses` are witnesses which must never be replaced by a constant: the circuit's
    ///   private/public parameters and return values, plus every witness referenced by a black box
    ///   function call or a memory opcode. `backpropagate_constants_iteration` does not rewrite the
    ///   witnesses inside those opcodes, so those witnesses must keep being assigned at runtime.
    /// - `known_witnesses` are the values obtained by solving `AssertZero` opcodes (visited in reverse
    ///   order) starting from an empty witness map, with every required witness removed.
    fn gather_known_witnesses(&self) -> (WitnessMap, BTreeSet<Witness>) {
        /// Yields every witness appearing in `expr`, whether in a multiplication or a linear term.
        fn expression_witnesses(expr: &Expression) -> impl Iterator<Item = Witness> + '_ {
            let mul_witnesses = expr.mul_terms.iter().flat_map(|(_, lhs, rhs)| [*lhs, *rhs]);
            let linear_witnesses = expr.linear_combinations.iter().map(|(_, witness)| *witness);
            mul_witnesses.chain(linear_witnesses)
        }

        // We do not want to affect the circuit's interface so avoid optimizing away these witnesses.
        let mut required_witnesses: BTreeSet<Witness> = self
            .circuit
            .private_parameters
            .union(&self.circuit.public_parameters.0)
            .chain(&self.circuit.return_values.0)
            .copied()
            .collect();

        for opcode in &self.circuit.opcodes {
            match opcode {
                Opcode::BlackBoxFuncCall(func_call) => {
                    required_witnesses.extend(
                        func_call.get_inputs_vec().into_iter().map(|func_input| func_input.witness),
                    );
                    required_witnesses.extend(func_call.get_outputs_vec());
                }
                Opcode::MemoryInit { init, .. } => {
                    required_witnesses.extend(init);
                }
                Opcode::MemoryOp { op, .. } => {
                    // The index and value are general expressions and need not be a single witness
                    // (e.g. a constant index mentions no witness at all). Collect every witness they
                    // reference instead of unwrapping a single one, which would panic in that case.
                    required_witnesses.extend(expression_witnesses(&op.index));
                    required_witnesses.extend(expression_witnesses(&op.value));
                }
                _ => (),
            }
        }

        let mut known_witnesses = WitnessMap::new();
        for opcode in self.circuit.opcodes.iter().rev() {
            if let Opcode::AssertZero(expr) = opcode {
                // It doesn't matter what the result is. We expect most opcodes to not be solved successfully so we
                // discard errors. At the same time, if the expression can be solved then we track this by the
                // updates to `known_witnesses`.
                let _ = ExpressionSolver::solve(&mut known_witnesses, expr);
            }
        }

        // We want to retain any references to required witnesses so we "forget" these assignments.
        let known_witnesses: BTreeMap<_, _> = known_witnesses
            .into_iter()
            .filter(|(witness, _)| !required_witnesses.contains(witness))
            .collect();

        (known_witnesses.into(), required_witnesses)
    }

    /// Returns a `Circuit` with any constant witnesses replaced by the constants they resolve to,
    /// together with `order_list` filtered down to the opcodes which remain.
    ///
    /// Optimization rounds are applied repeatedly until a round no longer reduces the number of
    /// opcodes. This is driven by a loop rather than recursion so that circuits requiring many
    /// rounds cannot exhaust the stack.
    #[tracing::instrument(level = "trace", skip_all)]
    pub(crate) fn backpropagate_constants(
        mut circuit: Circuit,
        mut order_list: Vec<usize>,
    ) -> (Circuit, Vec<usize>) {
        loop {
            let old_circuit_size = circuit.opcodes.len();

            let (new_circuit, new_order_list) =
                Self::new(circuit).backpropagate_constants_iteration(order_list);
            circuit = new_circuit;
            order_list = new_order_list;

            if circuit.opcodes.len() >= old_circuit_size {
                return (circuit, order_list);
            }
        }
    }

    /// Applies a single round of constant backpropagation to a `Circuit`.
    ///
    /// `order_list` is indexed by opcode position, so it must have an entry for every opcode
    /// (indexing panics otherwise). The returned list keeps, in order, only the entries of the
    /// opcodes which survive this round.
    ///
    /// Opcodes are visited in order, and the set of known witnesses grows as they are processed:
    /// - `AssertZero`: known witnesses are substituted in. The opcode is dropped if it simplifies to
    ///   zero. Otherwise it is kept and solved, where possible, to learn further witness values.
    /// - `Brillig`: inputs and predicate are simplified. If no output is a required witness and the
    ///   call can be executed now, its outputs become known and the opcode is dropped.
    /// - `ToLeRadix` directives: dropped if every output is already known, or if no output is
    ///   required and the directive can be solved now.
    /// - `RANGE` black box calls: dropped if `solve_range_opcode` succeeds against the known witnesses.
    /// - All other black box calls and memory opcodes are kept unchanged.
    pub(crate) fn backpropagate_constants_iteration(
        mut self,
        order_list: Vec<usize>,
    ) -> (Circuit, Vec<usize>) {
        let (mut known_witnesses, required_witnesses) = self.gather_known_witnesses();

        let opcodes = std::mem::take(&mut self.circuit.opcodes);

        /// Substitutes known witness values into `expression` and simplifies the result.
        fn remap_expression(known_witnesses: &WitnessMap, expression: Expression) -> Expression {
            GeneralOptimizer::optimize(ExpressionSolver::evaluate(&expression, known_witnesses))
        }

        let mut new_order_list = Vec::with_capacity(order_list.len());
        let mut new_opcodes = Vec::with_capacity(opcodes.len());
        for (idx, opcode) in opcodes.into_iter().enumerate() {
            let new_opcode = match opcode {
                Opcode::AssertZero(expression) => {
                    let new_expr = remap_expression(&known_witnesses, expression);
                    if new_expr.is_zero() {
                        continue;
                    }

                    // Attempt to solve the opcode to see if we can determine the value of any witnesses in the expression.
                    // We only do this _after_ we apply any simplifications to create the new opcode as we want to
                    // keep the constraint on the witness which we are solving for here.
                    //
                    // It doesn't matter what the result is. We expect most opcodes to not be solved successfully so
                    // we discard errors. Any progress is tracked by the updates to `known_witnesses`.
                    let _ = ExpressionSolver::solve(&mut known_witnesses, &new_expr);

                    Opcode::AssertZero(new_expr)
                }
                Opcode::Brillig(brillig) => {
                    let remapped_inputs = brillig
                        .inputs
                        .into_iter()
                        .map(|input| match input {
                            BrilligInputs::Single(expr) => {
                                BrilligInputs::Single(remap_expression(&known_witnesses, expr))
                            }
                            BrilligInputs::Array(expr_array) => {
                                let new_input: Vec<_> = expr_array
                                    .into_iter()
                                    .map(|expr| remap_expression(&known_witnesses, expr))
                                    .collect();

                                BrilligInputs::Array(new_input)
                            }
                            // Memory array inputs are passed through unchanged.
                            input @ BrilligInputs::MemoryArray(_) => input,
                        })
                        .collect();

                    let remapped_predicate = brillig
                        .predicate
                        .map(|predicate| remap_expression(&known_witnesses, predicate));

                    let new_brillig = Brillig {
                        inputs: remapped_inputs,
                        predicate: remapped_predicate,
                        ..brillig
                    };

                    let brillig_output_is_required_witness =
                        new_brillig.outputs.iter().any(|output| match output {
                            BrilligOutputs::Simple(witness) => required_witnesses.contains(witness),
                            BrilligOutputs::Array(witness_array) => witness_array
                                .iter()
                                .any(|witness| required_witnesses.contains(witness)),
                        });

                    if brillig_output_is_required_witness {
                        // If one of the brillig opcode's outputs is a required witness then we can't remove the opcode.
                        // In this case we can't replace all of the uses of this witness with the calculated constant,
                        // so we'd be attempting to use an uninitialized witness.
                        //
                        // We then do not attempt execution of this opcode and just simplify the inputs.
                        Opcode::Brillig(new_brillig)
                    } else if let Ok(mut solver) = BrilligSolver::new(
                        &known_witnesses,
                        &HashMap::new(),
                        &new_brillig,
                        &StubbedBlackBoxSolver,
                        idx,
                    ) {
                        match solver.solve() {
                            Ok(BrilligSolverStatus::Finished) => {
                                // Write execution outputs
                                if solver.finalize(&mut known_witnesses, &new_brillig).is_ok() {
                                    // If we've managed to execute the brillig opcode at compile time, we can now just
                                    // write in the results as constants for the rest of the circuit.
                                    continue;
                                }
                                Opcode::Brillig(new_brillig)
                            }
                            Ok(BrilligSolverStatus::InProgress) => unreachable!(
                                "Solver should either finish, block on foreign call, or error."
                            ),
                            Ok(BrilligSolverStatus::ForeignCallWait(_)) | Err(_) => {
                                Opcode::Brillig(new_brillig)
                            }
                        }
                    } else {
                        Opcode::Brillig(new_brillig)
                    }
                }

                Opcode::Directive(Directive::ToLeRadix { a, b, radix }) => {
                    if b.iter().all(|output| known_witnesses.contains_key(output)) {
                        // Every output already has a constant value, so the directive is redundant.
                        continue;
                    }

                    // If one of the directive's outputs is a required witness then we can't remove the opcode. In this
                    // case we can't replace all of the uses of this witness with the calculated constant, so we'd be
                    // attempting to use an uninitialized witness.
                    //
                    // We then do not attempt to solve the directive and just simplify its input.
                    let has_required_output =
                        b.iter().any(|witness| required_witnesses.contains(witness));

                    let directive = Directive::ToLeRadix {
                        a: remap_expression(&known_witnesses, a),
                        b,
                        radix,
                    };

                    if !has_required_output
                        && solve_directives(&mut known_witnesses, &directive).is_ok()
                    {
                        // The outputs have been recorded in `known_witnesses` as constants.
                        continue;
                    }
                    Opcode::Directive(directive)
                }

                Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { input }) => {
                    if solve_range_opcode(&known_witnesses, &input).is_ok() {
                        continue;
                    }
                    opcode
                }

                Opcode::BlackBoxFuncCall(_)
                | Opcode::MemoryOp { .. }
                | Opcode::MemoryInit { .. } => opcode,
            };

            new_opcodes.push(new_opcode);
            new_order_list.push(order_list[idx]);
        }

        self.circuit.opcodes = new_opcodes;

        (self.circuit, new_order_list)
    }
}
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use crate::compiler::optimizers::constant_backpropagation::ConstantBackpropagationOptimizer;
    use acir::{
        brillig::MemoryAddress,
        circuit::{
            brillig::{Brillig, BrilligOutputs},
            opcodes::{BlackBoxFuncCall, FunctionInput},
            Circuit, ExpressionWidth, Opcode, PublicInputs,
        },
        native_types::Witness,
    };
    use brillig_vm::brillig::Opcode as BrilligOpcode;

    /// Builds a minimal circuit containing `opcodes`, with no parameters and no return values.
    fn test_circuit(opcodes: Vec<Opcode>) -> Circuit {
        Circuit {
            current_witness_index: 1,
            expression_width: ExpressionWidth::Bounded { width: 4 },
            opcodes,
            private_parameters: BTreeSet::new(),
            public_parameters: PublicInputs::default(),
            return_values: PublicInputs::default(),
            assert_messages: Default::default(),
            recursive: false,
        }
    }

    #[test]
    fn retain_brillig_with_required_witness_outputs() {
        // A brillig call with no inputs which returns the constant 1 into `Witness(1)`.
        let brillig_opcode = Opcode::Brillig(Brillig {
            inputs: Vec::new(),
            outputs: vec![BrilligOutputs::Simple(Witness(1))],
            bytecode: vec![
                BrilligOpcode::Const {
                    destination: MemoryAddress(0),
                    bit_size: 32,
                    value: 1u128.into(),
                },
                BrilligOpcode::Stop { return_data_offset: 0, return_data_size: 1 },
            ],
            predicate: None,
        });

        // `Witness(1)` is an input to a black box function, which makes it a required witness.
        let blackbox_opcode = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::AND {
            lhs: FunctionInput { witness: Witness(1), num_bits: 64 },
            rhs: FunctionInput { witness: Witness(2), num_bits: 64 },
            output: Witness(3),
        });

        let opcodes = vec![brillig_opcode, blackbox_opcode];
        // The brillig call must be kept: removing it would leave the black box input unassigned.
        let circuit = test_circuit(opcodes);
        let acir_opcode_positions: Vec<usize> = (0..circuit.opcodes.len()).collect();
        let optimizer = ConstantBackpropagationOptimizer::new(circuit);

        let (optimized_circuit, new_opcode_positions) =
            optimizer.backpropagate_constants_iteration(acir_opcode_positions);

        assert_eq!(
            optimized_circuit.opcodes.len(),
            2,
            "The brillig opcode should not be removed as the output is needed as a witness"
        );
        assert!(
            matches!(optimized_circuit.opcodes[0], Opcode::Brillig(_)),
            "The brillig opcode should be retained in its original position"
        );
        assert_eq!(
            new_opcode_positions,
            vec![0, 1],
            "The opcode positions should be unchanged when no opcode is removed"
        );
    }
}