acvm/pwg/
memory_op.rs

1use std::collections::HashMap;
2
3use acir::{
4    circuit::opcodes::MemOp,
5    native_types::{Expression, Witness, WitnessMap},
6    FieldElement,
7};
8
9use super::{
10    arithmetic::ExpressionSolver, get_value, insert_value, is_predicate_false, witness_to_value,
11};
12use super::{ErrorLocation, OpcodeResolutionError};
13
14type MemoryIndex = u32;
15
16/// Maintains the state for solving [`MemoryInit`][`acir::circuit::Opcode::MemoryInit`] and [`MemoryOp`][`acir::circuit::Opcode::MemoryOp`] opcodes.
17#[derive(Default)]
18pub(crate) struct MemoryOpSolver {
19    pub(super) block_value: HashMap<MemoryIndex, FieldElement>,
20    pub(super) block_len: u32,
21}
22
23impl MemoryOpSolver {
24    fn write_memory_index(
25        &mut self,
26        index: MemoryIndex,
27        value: FieldElement,
28    ) -> Result<(), OpcodeResolutionError> {
29        if index >= self.block_len {
30            return Err(OpcodeResolutionError::IndexOutOfBounds {
31                opcode_location: ErrorLocation::Unresolved,
32                index,
33                array_size: self.block_len,
34            });
35        }
36        self.block_value.insert(index, value);
37        Ok(())
38    }
39
40    fn read_memory_index(&self, index: MemoryIndex) -> Result<FieldElement, OpcodeResolutionError> {
41        self.block_value.get(&index).copied().ok_or(OpcodeResolutionError::IndexOutOfBounds {
42            opcode_location: ErrorLocation::Unresolved,
43            index,
44            array_size: self.block_len,
45        })
46    }
47
48    /// Set the block_value from a MemoryInit opcode
49    pub(crate) fn init(
50        &mut self,
51        init: &[Witness],
52        initial_witness: &WitnessMap,
53    ) -> Result<(), OpcodeResolutionError> {
54        self.block_len = init.len() as u32;
55        for (memory_index, witness) in init.iter().enumerate() {
56            self.write_memory_index(
57                memory_index as MemoryIndex,
58                *witness_to_value(initial_witness, *witness)?,
59            )?;
60        }
61        Ok(())
62    }
63
64    pub(crate) fn solve_memory_op(
65        &mut self,
66        op: &MemOp,
67        initial_witness: &mut WitnessMap,
68        predicate: &Option<Expression>,
69    ) -> Result<(), OpcodeResolutionError> {
70        let operation = get_value(&op.operation, initial_witness)?;
71
72        // Find the memory index associated with this memory operation.
73        let index = get_value(&op.index, initial_witness)?;
74        let memory_index = index.try_to_u64().unwrap() as MemoryIndex;
75
76        // Calculate the value associated with this memory operation.
77        //
78        // In read operations, this corresponds to the witness index at which the value from memory will be written.
79        // In write operations, this corresponds to the expression which will be written to memory.
80        let value = ExpressionSolver::evaluate(&op.value, initial_witness);
81
82        // `operation == 0` implies a read operation. (`operation == 1` implies write operation).
83        let is_read_operation = operation.is_zero();
84
85        // Fetch whether or not the predicate is false (e.g. equal to zero)
86        let skip_operation = is_predicate_false(initial_witness, predicate)?;
87
88        if is_read_operation {
89            // `value_read = arr[memory_index]`
90            //
91            // This is the value that we want to read into; i.e. copy from the memory block
92            // into this value.
93            let value_read_witness = value.to_witness().expect(
94                "Memory must be read into a specified witness index, encountered an Expression",
95            );
96
97            // A zero predicate indicates that we should skip the read operation
98            // and zero out the operation's output.
99            let value_in_array = if skip_operation {
100                FieldElement::zero()
101            } else {
102                self.read_memory_index(memory_index)?
103            };
104            insert_value(&value_read_witness, value_in_array, initial_witness)
105        } else {
106            // `arr[memory_index] = value_write`
107            //
108            // This is the value that we want to write into; i.e. copy from `value_write`
109            // into the memory block.
110            let value_write = value;
111
112            // A zero predicate indicates that we should skip the write operation.
113            if skip_operation {
114                // We only want to write to already initialized memory.
115                // Do nothing if the predicate is zero.
116                Ok(())
117            } else {
118                let value_to_write = get_value(&value_write, initial_witness)?;
119                self.write_memory_index(memory_index, value_to_write)
120            }
121        }
122    }
123}
124
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use acir::{
        circuit::opcodes::MemOp,
        native_types::{Expression, Witness, WitnessMap},
        FieldElement,
    };

    use super::MemoryOpSolver;

    #[test]
    fn test_solver() {
        let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
            (Witness(1), FieldElement::from(1u128)),
            (Witness(2), FieldElement::from(1u128)),
            (Witness(3), FieldElement::from(2u128)),
        ]));

        let init = vec![Witness(1), Witness(2)];

        let trace = vec![
            MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
            MemOp::read_at_mem_index(FieldElement::one().into(), Witness(4)),
        ];

        let mut block_solver = MemoryOpSolver::default();
        block_solver.init(&init, &initial_witness).unwrap();

        for op in trace {
            block_solver.solve_memory_op(&op, &mut initial_witness, &None).unwrap();
        }

        assert_eq!(initial_witness[&Witness(4)], FieldElement::from(2u128));
    }

    #[test]
    fn test_index_out_of_bounds() {
        let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
            (Witness(1), FieldElement::from(1u128)),
            (Witness(2), FieldElement::from(1u128)),
            (Witness(3), FieldElement::from(2u128)),
        ]));

        let init = vec![Witness(1), Witness(2)];

        let invalid_trace = vec![
            MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
            MemOp::read_at_mem_index(FieldElement::from(2u128).into(), Witness(4)),
        ];
        let mut block_solver = MemoryOpSolver::default();
        block_solver.init(&init, &initial_witness).unwrap();
        let mut err = None;
        for op in invalid_trace {
            if err.is_none() {
                err = block_solver.solve_memory_op(&op, &mut initial_witness, &None).err();
            }
        }

        assert!(matches!(
            err,
            Some(crate::pwg::OpcodeResolutionError::IndexOutOfBounds {
                opcode_location: _,
                index: 2,
                array_size: 2
            })
        ));
    }

    #[test]
    fn test_predicate_on_read() {
        let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
            (Witness(1), FieldElement::from(1u128)),
            (Witness(2), FieldElement::from(1u128)),
            (Witness(3), FieldElement::from(2u128)),
        ]));

        let init = vec![Witness(1), Witness(2)];

        let invalid_trace = vec![
            MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
            MemOp::read_at_mem_index(FieldElement::from(2u128).into(), Witness(4)),
        ];
        let mut block_solver = MemoryOpSolver::default();
        block_solver.init(&init, &initial_witness).unwrap();
        let mut err = None;
        for op in invalid_trace {
            if err.is_none() {
                err = block_solver
                    .solve_memory_op(&op, &mut initial_witness, &Some(Expression::zero()))
                    .err();
            }
        }

        // Should have no index out of bounds error where predicate is zero
        assert_eq!(err, None);
        // The result of a read under a zero predicate should be zero
        assert_eq!(initial_witness[&Witness(4)], FieldElement::from(0u128));
    }

    #[test]
    fn test_predicate_on_write() {
        let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
            (Witness(1), FieldElement::from(1u128)),
            (Witness(2), FieldElement::from(1u128)),
            (Witness(3), FieldElement::from(2u128)),
        ]));

        let init = vec![Witness(1), Witness(2)];

        let invalid_trace = vec![
            MemOp::write_to_mem_index(FieldElement::from(2u128).into(), Witness(3).into()),
            MemOp::read_at_mem_index(FieldElement::from(0u128).into(), Witness(4)),
            MemOp::read_at_mem_index(FieldElement::from(1u128).into(), Witness(5)),
        ];
        let mut block_solver = MemoryOpSolver::default();
        block_solver.init(&init, &initial_witness).unwrap();
        let mut err = None;
        for op in invalid_trace {
            if err.is_none() {
                err = block_solver
                    .solve_memory_op(&op, &mut initial_witness, &Some(Expression::zero()))
                    .err();
            }
        }

        // Should have no index out of bounds error where predicate is zero
        assert_eq!(err, None);
        // Reads under a zero predicate return zero regardless of the memory contents
        assert_eq!(initial_witness[&Witness(4)], FieldElement::from(0u128));
        assert_eq!(initial_witness[&Witness(5)], FieldElement::from(0u128));
    }

    #[test]
    fn test_skipped_write_leaves_memory_unchanged() {
        let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
            (Witness(1), FieldElement::from(1u128)),
            (Witness(2), FieldElement::from(1u128)),
            (Witness(3), FieldElement::from(2u128)),
        ]));

        let init = vec![Witness(1), Witness(2)];
        let mut block_solver = MemoryOpSolver::default();
        block_solver.init(&init, &initial_witness).unwrap();

        // Attempt to overwrite slot 0 with `w3 = 2`, but under a zero predicate.
        let skipped_write =
            MemOp::write_to_mem_index(FieldElement::zero().into(), Witness(3).into());
        block_solver
            .solve_memory_op(&skipped_write, &mut initial_witness, &Some(Expression::zero()))
            .unwrap();

        // An active read of slot 0 must still observe the initial value `w1 = 1`.
        let read = MemOp::read_at_mem_index(FieldElement::zero().into(), Witness(4));
        block_solver.solve_memory_op(&read, &mut initial_witness, &None).unwrap();

        assert_eq!(initial_witness[&Witness(4)], FieldElement::from(1u128));
    }
}