quil_rs/program/memory.rs
1// Copyright 2021 Rigetti Computing
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16
17#[cfg(feature = "stubs")]
18use pyo3_stub_gen::derive::gen_stub_pyclass;
19
20use crate::{
21 expression::{Expression, FunctionCallExpression, InfixExpression, PrefixExpression},
22 instruction::{CallResolutionError, MemoryReference, Sharing, Vector, WaveformInvocation},
23 pickleable_new,
24};
25
26#[derive(Clone, Debug, Hash, PartialEq, Eq)]
27#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
28#[cfg_attr(
29 feature = "python",
30 pyo3::pyclass(module = "quil.program", eq, frozen, hash, get_all, subclass)
31)]
32pub struct MemoryRegion {
33 pub size: Vector,
34 pub sharing: Option<Sharing>,
35}
36
37pickleable_new! {
38 impl MemoryRegion {
39 pub fn new(size: Vector, sharing: Option<Sharing>);
40 }
41}
42
/// How an instruction or sequence of instructions can access memory.
///
/// Each access is stored as the name of the region that was accessed. We do not store the
/// individual indices that were accessed.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct MemoryAccesses {
    /// All memory regions these instructions can read from.
    pub reads: HashSet<String>,

    /// All memory regions these instructions can write to from within the processor.
    ///
    /// The "within the processor" clause indicates that this covers the write to the destination
    /// of a [`MOVE`][crate::instruction::Instruction::Move], but not the write to the target of a
    /// [`MEASURE`][crate::instruction::Instruction::Measurement].
    pub writes: HashSet<String>,

    /// All memory regions these instructions can write to from outside the processor.
    ///
    /// The "outside the processor" clause indicates that this covers the write to the target of a
    /// [`MEASURE`][crate::instruction::Instruction::Measurement], but not the write to the
    /// destination of a [`MOVE`][crate::instruction::Instruction::Move].
    pub captures: HashSet<String>,
}

impl MemoryAccesses {
    /// An empty set of memory accesses; identical to [`MemoryAccesses::default`].
    #[inline]
    pub fn none() -> Self {
        Self::default()
    }

    /// Combine two sets of memory accesses.
    ///
    /// The result contains every region that appears in `reads`, `writes`, or `captures` of
    /// either operand, each kept in its own category. Both operands are consumed; the sets of
    /// `self` are extended with those of `rhs` and `self` is returned.
    #[must_use = "`union` consumes both operands and returns the merged accesses"]
    pub fn union(mut self, rhs: Self) -> Self {
        let Self {
            reads,
            writes,
            captures,
        } = rhs;
        self.reads.extend(reads);
        self.writes.extend(writes);
        self.captures.extend(captures);
        self
    }
}
86
/// The kind of access an instruction makes to a memory location.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryAccessType {
    /// The location is read from.
    Read,

    /// The location is written to by a classical instruction.
    Write,

    /// The location is written to by readout, i.e. by `CAPTURE` or `RAW-CAPTURE`.
    Capture,
}
99
100#[derive(Clone, PartialEq, Debug, thiserror::Error)]
101pub enum MemoryAccessesError {
102 #[error(transparent)]
103 CallResolution(#[from] CallResolutionError),
104
105 #[error("Instruction handler reported an error when constructing memory accesses: {0}")]
106 InstructionHandlerError(String),
107}
108
109pub mod expression {
110 // Contains an implementation of an iterator over all [`MemoryReference`]s read from by an
111 // expression. We construct a separate iterator here so we can avoid reifying a vector of
112 // [`MemoryReference`]s that we will then immediately consume.
113
114 use super::*;
115
116 /// An iterator over all the memory references contained in an expression.
117 #[derive(Clone, Debug)]
118 pub struct MemoryReferences<'a> {
119 pub(super) stack: Vec<&'a Expression>,
120 }
121
122 impl<'a> Iterator for MemoryReferences<'a> {
123 type Item = &'a MemoryReference;
124
125 fn next(&mut self) -> Option<Self::Item> {
126 // If we imagine collecting into a vector, this function is roughly
127 //
128 // ```
129 // fn collect_into(expr: &Expression, output: &mut Vec<&MemoryReference>) {
130 // match expr {
131 // Expression::NoReferences(_) => (),
132 // Expression::Address(reference) => output.push(reference),
133 // Expression::OneSubexpression(expression) => collect_into(expression, output),
134 // Expression::TwoSubexpressions(left, right) => {
135 // collect_into(left, output);
136 // collect_into(right, output);
137 // }
138 // }
139 // }
140 // ```
141 //
142 // In order to implement an iterator without allocating the whole vector, we still have
143 // to reify the stack for the two-subexpression case; that's what `self.stack` is.
144 //
145 // We then implement this function with two loops. The outer loop, `'stack_search`, is
146 // our stack traversal; it finds the nearest stack frame. The inner loop is effectively
147 // our *tail calls*, where we don't need to allocate another stack frame; instead, we
148 // assign to the current stack frame and keep going. We don't have to write back to the
149 // vector in the tail call case because no leaf node has more than one reference, so
150 // once we've popped a stack frame off it'll fully bottom out.
151 //
152 // Note also that in the two-subexpression case we actually swap the order from the
153 // `collect_into` example. This is because `collect_into` reuses the same `output`, so
154 // the state is preserved and `collect_into(right, output)` appends after `left`.
155 // However, when we are emitting our results immediately, the tail call will happen
156 // *before* the delayed stack frame.
157 //
158 // And yes, this would all be simpler with `gen` blocks.
159
160 let Self { stack } = self;
161
162 // Search through all parent expressions
163 'stack_search: while let Some(mut expr) = stack.pop() {
164 // An optimization for when there's only one child expression
165 loop {
166 match expr {
167 // We're done with this expression and didn't find anything; time to walk down
168 // another tree branch, if there are any left.
169 Expression::Number(_)
170 | Expression::PiConstant()
171 | Expression::Variable(_) => continue 'stack_search,
172
173 // We're done with this expression and it was successful; stop iterating here,
174 // and when we return, walk down another tree branch if there are any left.
175 Expression::Address(reference) => return Some(reference),
176
177 // This expression only has one subexpression; we can avoid pushing a new
178 // "stack frame" and immediately popping it by overwriting the current stack
179 // frame.
180 Expression::FunctionCall(FunctionCallExpression {
181 expression,
182 function: _,
183 })
184 | Expression::Prefix(PrefixExpression {
185 expression,
186 operator: _,
187 }) => expr = expression,
188
189 // This expression has two subexpressions; we delay searching through the
190 // right child by pushing it on the stack, and "tail call" to search through
191 // the left child immediately as we did with the single-subexpression case
192 // above.
193 Expression::Infix(InfixExpression {
194 left,
195 right,
196 operator: _,
197 }) => {
198 stack.push(right);
199 expr = left;
200 }
201 }
202 }
203 }
204
205 // We've finished our traversal; there are no subexpressions left.
206 None
207 }
208 }
209
210 impl std::iter::FusedIterator for MemoryReferences<'_> {}
211}
212
213impl Expression {
214 /// Return an iterator over all the memory references contained within this expression.
215 pub fn memory_references(&self) -> expression::MemoryReferences<'_> {
216 expression::MemoryReferences { stack: vec![self] }
217 }
218}
219
220impl WaveformInvocation {
221 /// Return, if any, the memory references contained within this `WaveformInvocation`.
222 pub fn memory_references(&self) -> impl std::iter::FusedIterator<Item = &MemoryReference> {
223 self.parameters
224 .values()
225 .flat_map(Expression::memory_references)
226 }
227}
228
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rstest::rstest;

    use crate::{
        expression::Expression,
        instruction::{
            ArithmeticOperand, Convert, DefaultHandler, Exchange, ExternSignatureMap,
            FrameIdentifier, Instruction, InstructionHandler as _, MemoryReference, Qubit,
            SetFrequency, ShiftFrequency, Store,
        },
        program::MemoryAccesses,
    };

    /// Owned region names, in the form stored by [`MemoryAccesses`].
    fn regions(names: &[&str]) -> HashSet<String> {
        names.iter().map(|&name| name.to_owned()).collect()
    }

    /// A reference to the default (zeroth) index of the named region.
    fn first_element(name: &str) -> MemoryReference {
        MemoryReference {
            name: name.to_string(),
            index: Default::default(),
        }
    }

    #[rstest]
    #[case(
        r#"
cis(func_ref[0]) ^
cos(func_ref[1]) +
exp(func_ref[2]) -
sin(func_ref[3]) /
sqrt(func_ref[4]) *

(infix_ref[0] ^ infix_ref[0]) ^
(infix_ref[1] + infix_ref[1]) +
(infix_ref[2] - infix_ref[2]) -
(infix_ref[3] / infix_ref[3]) /
(infix_ref[4] * infix_ref[4]) *

1.0 ^

pi +

(-prefix_ref) -

%variable
"#,
        &[
            ("func_ref", 0),
            ("func_ref", 1),
            ("func_ref", 2),
            ("func_ref", 3),
            ("func_ref", 4),
            ("infix_ref", 0),
            ("infix_ref", 0),
            ("infix_ref", 1),
            ("infix_ref", 1),
            ("infix_ref", 2),
            ("infix_ref", 2),
            ("infix_ref", 3),
            ("infix_ref", 3),
            ("infix_ref", 4),
            ("infix_ref", 4),
            ("prefix_ref", 0),
        ]
    )]
    fn expr_references(#[case] expr: &str, #[case] expected_refs: &[(&str, u64)]) {
        let parsed: Expression = expr.replace('\n', " ").parse().unwrap();

        let expected: Vec<MemoryReference> = expected_refs
            .iter()
            .map(|&(name, index)| MemoryReference {
                name: name.to_owned(),
                index,
            })
            .collect();

        let computed: Vec<MemoryReference> = parsed.memory_references().cloned().collect();

        assert_eq!(computed, expected);
    }

    #[rstest]
    #[case(
        Instruction::Store(Store {
            destination: "destination".to_string(),
            offset: first_element("offset"),
            source: ArithmeticOperand::MemoryReference(first_element("source")),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: regions(&["source", "offset"]),
            writes: regions(&["destination"]),
        }
    )]
    #[case(
        Instruction::Convert(Convert {
            destination: first_element("destination"),
            source: first_element("source"),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: regions(&["source"]),
            writes: regions(&["destination"]),
        }
    )]
    #[case(
        Instruction::Exchange(Exchange {
            left: first_element("left"),
            right: first_element("right"),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: regions(&["left", "right"]),
            writes: regions(&["left", "right"]),
        }
    )]
    #[case(
        Instruction::SetFrequency(SetFrequency {
            frequency: Expression::Address(first_element("frequency")),
            frame: FrameIdentifier {
                name: "frame".to_string(),
                qubits: vec![Qubit::Fixed(0)]
            }
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: regions(&["frequency"]),
            writes: HashSet::new(),
        }
    )]
    #[case(
        Instruction::ShiftFrequency(ShiftFrequency {
            frequency: Expression::Address(first_element("frequency")),
            frame: FrameIdentifier {
                name: "frame".to_string(),
                qubits: vec![Qubit::Fixed(0)]
            }
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: regions(&["frequency"]),
            writes: HashSet::new(),
        }
    )]
    fn test_instruction_accesses(
        #[case] instruction: Instruction,
        #[case] expected: MemoryAccesses,
    ) {
        let MemoryAccesses {
            reads,
            writes,
            captures,
        } = DefaultHandler
            .memory_accesses(&ExternSignatureMap::default(), &instruction)
            .expect("must be able to get memory accesses");

        assert_eq!(captures, expected.captures);
        assert_eq!(reads, expected.reads);
        assert_eq!(writes, expected.writes);
    }
}