quil_rs/program/
memory.rs

1// Copyright 2021 Rigetti Computing
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16
17#[cfg(feature = "stubs")]
18use pyo3_stub_gen::derive::gen_stub_pyclass;
19
20use crate::{
21    expression::{Expression, FunctionCallExpression, InfixExpression, PrefixExpression},
22    instruction::{CallResolutionError, MemoryReference, Sharing, Vector, WaveformInvocation},
23    pickleable_new,
24};
25
26#[derive(Clone, Debug, Hash, PartialEq, Eq)]
27#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
28#[cfg_attr(
29    feature = "python",
30    pyo3::pyclass(module = "quil.program", eq, frozen, hash, get_all, subclass)
31)]
32pub struct MemoryRegion {
33    pub size: Vector,
34    pub sharing: Option<Sharing>,
35}
36
37pickleable_new! {
38    impl MemoryRegion {
39        pub fn new(size: Vector, sharing: Option<Sharing>);
40    }
41}
42
/// How an instruction or sequence of instructions can access memory.
///
/// Each access is stored as the name of the region that was accessed.  We do not store the
/// individual indices that were accessed.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct MemoryAccesses {
    /// All memory regions these instructions can read from.
    pub reads: HashSet<String>,

    /// All memory regions these instructions can write to from within the processor.
    ///
    /// The "within the processor" clause indicates that this covers the write to the destination
    /// of a [`MOVE`][crate::instruction::Instruction::Move], but not the write to the target of a
    /// [`MEASURE`][crate::instruction::Instruction::Measurement].
    pub writes: HashSet<String>,

    /// All memory regions these instructions can write to from outside the processor.
    ///
    /// The "outside the processor" clause indicates that this covers the write to the target of a
    /// [`MEASURE`][crate::instruction::Instruction::Measurement], but not the write to the
    /// destination of a [`MOVE`][crate::instruction::Instruction::Move].
    pub captures: HashSet<String>,
}

impl MemoryAccesses {
    /// An empty set of memory accesses: no reads, no writes, and no captures.
    #[inline]
    pub fn none() -> Self {
        Self::default()
    }

    /// Merge two sets of memory accesses.
    ///
    /// Each field of the result (`reads`, `writes`, and `captures`) is the set union of the
    /// corresponding fields of `self` and `rhs`.  The sets of `self` are reused as the result's
    /// storage, and the entries of `rhs` are moved into them.
    #[must_use = "`union` consumes `self`; the merged accesses are only available in the return value"]
    pub fn union(mut self, rhs: Self) -> Self {
        // Destructure `rhs` exhaustively so that adding a field to `MemoryAccesses` forces this
        // method to be revisited.
        let Self {
            reads,
            writes,
            captures,
        } = rhs;
        self.reads.extend(reads);
        self.writes.extend(writes);
        self.captures.extend(captures);
        self
    }
}
86
/// The kinds of access an instruction can make to a memory location.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryAccessType {
    /// The location is read from.
    Read,

    /// The location is written to by classical instructions.
    Write,

    /// The location is written to by readout, i.e. by `CAPTURE` and `RAW-CAPTURE` instructions.
    Capture,
}
99
100#[derive(Clone, PartialEq, Debug, thiserror::Error)]
101pub enum MemoryAccessesError {
102    #[error(transparent)]
103    CallResolution(#[from] CallResolutionError),
104
105    #[error("Instruction handler reported an error when constructing memory accesses: {0}")]
106    InstructionHandlerError(String),
107}
108
109pub mod expression {
110    // Contains an implementation of an iterator over all [`MemoryReference`]s read from by an
111    // expression.  We construct a separate iterator here so we can avoid reifying a vector of
112    // [`MemoryReference`]s that we will then immediately consume.
113
114    use super::*;
115
116    /// An iterator over all the memory references contained in an expression.
117    #[derive(Clone, Debug)]
118    pub struct MemoryReferences<'a> {
119        pub(super) stack: Vec<&'a Expression>,
120    }
121
122    impl<'a> Iterator for MemoryReferences<'a> {
123        type Item = &'a MemoryReference;
124
125        fn next(&mut self) -> Option<Self::Item> {
126            // If we imagine collecting into a vector, this function is roughly
127            //
128            // ```
129            // fn collect_into(expr: &Expression, output: &mut Vec<&MemoryReference>) {
130            //     match expr {
131            //         Expression::NoReferences(_) => (),
132            //         Expression::Address(reference) => output.push(reference),
133            //         Expression::OneSubexpression(expression) => collect_into(expression, output),
134            //         Expression::TwoSubexpressions(left, right) => {
135            //             collect_into(left, output);
136            //             collect_into(right, output);
137            //         }
138            //     }
139            // }
140            // ```
141            //
142            // In order to implement an iterator without allocating the whole vector, we still have
143            // to reify the stack for the two-subexpression case; that's what `self.stack` is.
144            //
145            // We then implement this function with two loops.  The outer loop, `'stack_search`, is
146            // our stack traversal; it finds the nearest stack frame.  The inner loop is effectively
147            // our *tail calls*, where we don't need to allocate another stack frame; instead, we
148            // assign to the current stack frame and keep going.  We don't have to write back to the
149            // vector in the tail call case because no leaf node has more than one reference, so
150            // once we've popped a stack frame off it'll fully bottom out.
151            //
152            // Note also that in the two-subexpression case we actually swap the order from the
153            // `collect_into` example.  This is because `collect_into` reuses the same `output`, so
154            // the state is preserved and `collect_into(right, output)` appends after `left`.
155            // However, when we are emitting our results immediately, the tail call will happen
156            // *before* the delayed stack frame.
157            //
158            // And yes, this would all be simpler with `gen` blocks.
159
160            let Self { stack } = self;
161
162            // Search through all parent expressions
163            'stack_search: while let Some(mut expr) = stack.pop() {
164                // An optimization for when there's only one child expression
165                loop {
166                    match expr {
167                        // We're done with this expression and didn't find anything; time to walk down
168                        // another tree branch, if there are any left.
169                        Expression::Number(_)
170                        | Expression::PiConstant()
171                        | Expression::Variable(_) => continue 'stack_search,
172
173                        // We're done with this expression and it was successful; stop iterating here,
174                        // and when we return, walk down another tree branch if there are any left.
175                        Expression::Address(reference) => return Some(reference),
176
177                        // This expression only has one subexpression; we can avoid pushing a new
178                        // "stack frame" and immediately popping it by overwriting the current stack
179                        // frame.
180                        Expression::FunctionCall(FunctionCallExpression {
181                            expression,
182                            function: _,
183                        })
184                        | Expression::Prefix(PrefixExpression {
185                            expression,
186                            operator: _,
187                        }) => expr = expression,
188
189                        // This expression has two subexpressions; we delay searching through the
190                        // right child by pushing it on the stack, and "tail call" to search through
191                        // the left child immediately as we did with the single-subexpression case
192                        // above.
193                        Expression::Infix(InfixExpression {
194                            left,
195                            right,
196                            operator: _,
197                        }) => {
198                            stack.push(right);
199                            expr = left;
200                        }
201                    }
202                }
203            }
204
205            // We've finished our traversal; there are no subexpressions left.
206            None
207        }
208    }
209
210    impl std::iter::FusedIterator for MemoryReferences<'_> {}
211}
212
213impl Expression {
214    /// Return an iterator over all the memory references contained within this expression.
215    pub fn memory_references(&self) -> expression::MemoryReferences<'_> {
216        expression::MemoryReferences { stack: vec![self] }
217    }
218}
219
220impl WaveformInvocation {
221    /// Return, if any, the memory references contained within this `WaveformInvocation`.
222    pub fn memory_references(&self) -> impl std::iter::FusedIterator<Item = &MemoryReference> {
223        self.parameters
224            .values()
225            .flat_map(Expression::memory_references)
226    }
227}
228
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rstest::rstest;

    use crate::{
        expression::Expression,
        instruction::{
            ArithmeticOperand, Convert, DefaultHandler, Exchange, ExternSignatureMap,
            FrameIdentifier, Instruction, InstructionHandler as _, MemoryReference, Qubit,
            SetFrequency, ShiftFrequency, Store,
        },
        program::MemoryAccesses,
    };

    /// A reference to element 0 of the named memory region.
    fn mem_ref(name: &str) -> MemoryReference {
        MemoryReference {
            name: name.to_string(),
            index: 0,
        }
    }

    /// The given region names, collected into the set type used by [`MemoryAccesses`].
    fn region_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// The frame `"frame"` on qubit 0, shared by the frequency-instruction cases.
    fn frame_on_qubit_zero() -> FrameIdentifier {
        FrameIdentifier {
            name: "frame".to_string(),
            qubits: vec![Qubit::Fixed(0)],
        }
    }

    #[rstest]
    #[case(
        r#"
cis(func_ref[0]) ^
cos(func_ref[1]) +
exp(func_ref[2]) -
sin(func_ref[3]) /
sqrt(func_ref[4]) *

(infix_ref[0] ^ infix_ref[0]) ^
(infix_ref[1] + infix_ref[1]) +
(infix_ref[2] - infix_ref[2]) -
(infix_ref[3] / infix_ref[3]) /
(infix_ref[4] * infix_ref[4]) *

1.0 ^

pi +

(-prefix_ref) -

%variable
"#,
        &[
            ("func_ref", 0),
            ("func_ref", 1),
            ("func_ref", 2),
            ("func_ref", 3),
            ("func_ref", 4),
            ("infix_ref", 0),
            ("infix_ref", 0),
            ("infix_ref", 1),
            ("infix_ref", 1),
            ("infix_ref", 2),
            ("infix_ref", 2),
            ("infix_ref", 3),
            ("infix_ref", 3),
            ("infix_ref", 4),
            ("infix_ref", 4),
            ("prefix_ref", 0),
        ]
    )]
    fn expr_references(#[case] expr: &str, #[case] expected_refs: &[(&str, u64)]) {
        // The expression is written across several lines for readability; flatten it first.
        let parsed: Expression = expr.replace('\n', " ").parse().unwrap();

        let expected: Vec<MemoryReference> = expected_refs
            .iter()
            .map(|&(name, index)| MemoryReference {
                name: name.to_owned(),
                index,
            })
            .collect();

        let actual: Vec<MemoryReference> = parsed.memory_references().cloned().collect();

        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case(
        Instruction::Store(Store {
            destination: "destination".to_string(),
            offset: mem_ref("offset"),
            source: ArithmeticOperand::MemoryReference(mem_ref("source")),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_set(&["source", "offset"]),
            writes: region_set(&["destination"]),
        }
    )]
    #[case(
        Instruction::Convert(Convert {
            destination: mem_ref("destination"),
            source: mem_ref("source"),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_set(&["source"]),
            writes: region_set(&["destination"]),
        }
    )]
    #[case(
        Instruction::Exchange(Exchange {
            left: mem_ref("left"),
            right: mem_ref("right"),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_set(&["left", "right"]),
            writes: region_set(&["left", "right"]),
        }
    )]
    #[case(
        Instruction::SetFrequency(SetFrequency {
            frequency: Expression::Address(mem_ref("frequency")),
            frame: frame_on_qubit_zero(),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_set(&["frequency"]),
            writes: HashSet::new(),
        }
    )]
    #[case(
        Instruction::ShiftFrequency(ShiftFrequency {
            frequency: Expression::Address(mem_ref("frequency")),
            frame: frame_on_qubit_zero(),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_set(&["frequency"]),
            writes: HashSet::new(),
        }
    )]
    fn test_instruction_accesses(
        #[case] instruction: Instruction,
        #[case] expected: MemoryAccesses,
    ) {
        let actual = DefaultHandler
            .memory_accesses(&ExternSignatureMap::default(), &instruction)
            .expect("must be able to get memory accesses");

        // Compare field by field so a failure names the kind of access that differs.
        let MemoryAccesses {
            reads,
            writes,
            captures,
        } = actual;
        assert_eq!(captures, expected.captures, "captures differ");
        assert_eq!(reads, expected.reads, "reads differ");
        assert_eq!(writes, expected.writes, "writes differ");
    }
}