1use std::collections::HashMap;
2
3use acir::{
4 circuit::opcodes::MemOp,
5 native_types::{Expression, Witness, WitnessMap},
6 FieldElement,
7};
8
9use super::{
10 arithmetic::ExpressionSolver, get_value, insert_value, is_predicate_false, witness_to_value,
11};
12use super::{ErrorLocation, OpcodeResolutionError};
13
14type MemoryIndex = u32;
15
16#[derive(Default)]
18pub(crate) struct MemoryOpSolver {
19 pub(super) block_value: HashMap<MemoryIndex, FieldElement>,
20 pub(super) block_len: u32,
21}
22
23impl MemoryOpSolver {
24 fn write_memory_index(
25 &mut self,
26 index: MemoryIndex,
27 value: FieldElement,
28 ) -> Result<(), OpcodeResolutionError> {
29 if index >= self.block_len {
30 return Err(OpcodeResolutionError::IndexOutOfBounds {
31 opcode_location: ErrorLocation::Unresolved,
32 index,
33 array_size: self.block_len,
34 });
35 }
36 self.block_value.insert(index, value);
37 Ok(())
38 }
39
40 fn read_memory_index(&self, index: MemoryIndex) -> Result<FieldElement, OpcodeResolutionError> {
41 self.block_value.get(&index).copied().ok_or(OpcodeResolutionError::IndexOutOfBounds {
42 opcode_location: ErrorLocation::Unresolved,
43 index,
44 array_size: self.block_len,
45 })
46 }
47
48 pub(crate) fn init(
50 &mut self,
51 init: &[Witness],
52 initial_witness: &WitnessMap,
53 ) -> Result<(), OpcodeResolutionError> {
54 self.block_len = init.len() as u32;
55 for (memory_index, witness) in init.iter().enumerate() {
56 self.write_memory_index(
57 memory_index as MemoryIndex,
58 *witness_to_value(initial_witness, *witness)?,
59 )?;
60 }
61 Ok(())
62 }
63
64 pub(crate) fn solve_memory_op(
65 &mut self,
66 op: &MemOp,
67 initial_witness: &mut WitnessMap,
68 predicate: &Option<Expression>,
69 ) -> Result<(), OpcodeResolutionError> {
70 let operation = get_value(&op.operation, initial_witness)?;
71
72 let index = get_value(&op.index, initial_witness)?;
74 let memory_index = index.try_to_u64().unwrap() as MemoryIndex;
75
76 let value = ExpressionSolver::evaluate(&op.value, initial_witness);
81
82 let is_read_operation = operation.is_zero();
84
85 let skip_operation = is_predicate_false(initial_witness, predicate)?;
87
88 if is_read_operation {
89 let value_read_witness = value.to_witness().expect(
94 "Memory must be read into a specified witness index, encountered an Expression",
95 );
96
97 let value_in_array = if skip_operation {
100 FieldElement::zero()
101 } else {
102 self.read_memory_index(memory_index)?
103 };
104 insert_value(&value_read_witness, value_in_array, initial_witness)
105 } else {
106 let value_write = value;
111
112 if skip_operation {
114 Ok(())
117 } else {
118 let value_to_write = get_value(&value_write, initial_witness)?;
119 self.write_memory_index(memory_index, value_to_write)
120 }
121 }
122 }
123}
124
125#[cfg(test)]
126mod tests {
127 use std::collections::BTreeMap;
128
129 use acir::{
130 circuit::opcodes::MemOp,
131 native_types::{Expression, Witness, WitnessMap},
132 FieldElement,
133 };
134
135 use super::MemoryOpSolver;
136
137 #[test]
138 fn test_solver() {
139 let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
140 (Witness(1), FieldElement::from(1u128)),
141 (Witness(2), FieldElement::from(1u128)),
142 (Witness(3), FieldElement::from(2u128)),
143 ]));
144
145 let init = vec![Witness(1), Witness(2)];
146
147 let trace = vec![
148 MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
149 MemOp::read_at_mem_index(FieldElement::one().into(), Witness(4)),
150 ];
151
152 let mut block_solver = MemoryOpSolver::default();
153 block_solver.init(&init, &initial_witness).unwrap();
154
155 for op in trace {
156 block_solver.solve_memory_op(&op, &mut initial_witness, &None).unwrap();
157 }
158
159 assert_eq!(initial_witness[&Witness(4)], FieldElement::from(2u128));
160 }
161
162 #[test]
163 fn test_index_out_of_bounds() {
164 let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
165 (Witness(1), FieldElement::from(1u128)),
166 (Witness(2), FieldElement::from(1u128)),
167 (Witness(3), FieldElement::from(2u128)),
168 ]));
169
170 let init = vec![Witness(1), Witness(2)];
171
172 let invalid_trace = vec![
173 MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
174 MemOp::read_at_mem_index(FieldElement::from(2u128).into(), Witness(4)),
175 ];
176 let mut block_solver = MemoryOpSolver::default();
177 block_solver.init(&init, &initial_witness).unwrap();
178 let mut err = None;
179 for op in invalid_trace {
180 if err.is_none() {
181 err = block_solver.solve_memory_op(&op, &mut initial_witness, &None).err();
182 }
183 }
184
185 assert!(matches!(
186 err,
187 Some(crate::pwg::OpcodeResolutionError::IndexOutOfBounds {
188 opcode_location: _,
189 index: 2,
190 array_size: 2
191 })
192 ));
193 }
194
195 #[test]
196 fn test_predicate_on_read() {
197 let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
198 (Witness(1), FieldElement::from(1u128)),
199 (Witness(2), FieldElement::from(1u128)),
200 (Witness(3), FieldElement::from(2u128)),
201 ]));
202
203 let init = vec![Witness(1), Witness(2)];
204
205 let invalid_trace = vec![
206 MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
207 MemOp::read_at_mem_index(FieldElement::from(2u128).into(), Witness(4)),
208 ];
209 let mut block_solver = MemoryOpSolver::default();
210 block_solver.init(&init, &initial_witness).unwrap();
211 let mut err = None;
212 for op in invalid_trace {
213 if err.is_none() {
214 err = block_solver
215 .solve_memory_op(&op, &mut initial_witness, &Some(Expression::zero()))
216 .err();
217 }
218 }
219
220 assert_eq!(err, None);
222 assert_eq!(initial_witness[&Witness(4)], FieldElement::from(0u128));
224 }
225
226 #[test]
227 fn test_predicate_on_write() {
228 let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
229 (Witness(1), FieldElement::from(1u128)),
230 (Witness(2), FieldElement::from(1u128)),
231 (Witness(3), FieldElement::from(2u128)),
232 ]));
233
234 let init = vec![Witness(1), Witness(2)];
235
236 let invalid_trace = vec![
237 MemOp::write_to_mem_index(FieldElement::from(2u128).into(), Witness(3).into()),
238 MemOp::read_at_mem_index(FieldElement::from(0u128).into(), Witness(4)),
239 MemOp::read_at_mem_index(FieldElement::from(1u128).into(), Witness(5)),
240 ];
241 let mut block_solver = MemoryOpSolver::default();
242 block_solver.init(&init, &initial_witness).unwrap();
243 let mut err = None;
244 for op in invalid_trace {
245 if err.is_none() {
246 err = block_solver
247 .solve_memory_op(&op, &mut initial_witness, &Some(Expression::zero()))
248 .err();
249 }
250 }
251
252 assert_eq!(err, None);
254 assert_eq!(initial_witness[&Witness(4)], FieldElement::from(0u128));
256 assert_eq!(initial_witness[&Witness(5)], FieldElement::from(0u128));
257 }
258}