quil_rs/program/
memory.rs

1// Copyright 2021 Rigetti Computing
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16
17#[cfg(feature = "stubs")]
18use pyo3_stub_gen::derive::gen_stub_pyclass;
19
20use crate::expression::{Expression, FunctionCallExpression, InfixExpression, PrefixExpression};
21use crate::instruction::{
22    Arithmetic, ArithmeticOperand, BinaryLogic, BinaryOperand, CallResolutionError, Capture,
23    CircuitDefinition, Comparison, ComparisonOperand, Convert, Delay, Exchange, ExternSignatureMap,
24    Gate, GateDefinition, GateSpecification, Instruction, JumpUnless, JumpWhen, Load,
25    MeasureCalibrationDefinition, Measurement, MemoryReference, Move, Pulse, RawCapture,
26    SetFrequency, SetPhase, SetScale, Sharing, ShiftFrequency, ShiftPhase, Store, UnaryLogic,
27    Vector, WaveformInvocation,
28};
29use crate::pickleable_new;
30
/// A region of classical memory, described by its size and optional sharing information.
// NOTE(review): presumably this describes a region introduced by a `DECLARE` instruction;
// confirm against the code that builds a program's memory-region map.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "quil.program", eq, frozen, hash, get_all, subclass)
)]
pub struct MemoryRegion {
    /// The size of the region, expressed as a [`Vector`].
    pub size: Vector,
    /// Optional [`Sharing`] information for the region; `None` when there is none.
    pub sharing: Option<Sharing>,
}

// Generates the `MemoryRegion::new(size, sharing)` constructor. The macro name suggests it also
// wires up pickling support for the Python bindings — confirm against `pickleable_new!`.
pickleable_new! {
    impl MemoryRegion {
        pub fn new(size: Vector, sharing: Option<Sharing>);
    }
}
47
/// The names of the memory regions an instruction touches, grouped by mode of access.
///
/// Each set holds region names only (the `name` of a [`MemoryReference`]), never indices, so an
/// access to any element of a region is recorded as an access to the whole region.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct MemoryAccesses {
    /// Regions written by readout: `CAPTURE`, `RAW-CAPTURE`, and measurements with a target.
    pub captures: HashSet<String>,
    /// Regions read, including regions referenced inside expressions.
    pub reads: HashSet<String>,
    /// Regions written by classical instructions.
    pub writes: HashSet<String>,
}
54
/// Express a mode of memory access.
///
/// The three modes mirror the three name sets tracked by [`MemoryAccesses`]:
/// `Read` ↔ `reads`, `Write` ↔ `writes`, and `Capture` ↔ `captures`.
#[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
pub enum MemoryAccessType {
    /// Read from a memory location
    Read,

    /// Write to a memory location using classical instructions
    Write,

    /// Write to a memory location using readout (`CAPTURE` and `RAW-CAPTURE` instructions,
    /// as well as measurements that store their result into a memory reference)
    Capture,
}
67
/// Merge two owned `HashSet<String>`s into one, consuming both.
///
/// The left set is extended in place with the right set's names. The previous approach
/// cloned every element of both sets into a fresh union. Extending in place keeps repeated
/// merges (for example, folding over the body of a definition) linear in the total number of
/// names, instead of re-copying the whole accumulated set at every step.
macro_rules! merge_sets {
    ($left:expr, $right:expr) => {{
        let mut merged: HashSet<String> = $left;
        merged.extend($right);
        merged
    }};
}
73
/// Build a `HashSet<String>` by cloning each `&String` yielded by an iterable argument
/// (typically a `Vec<&String>`); duplicate names collapse into a single entry.
macro_rules! set_from_reference_vec {
    ($vec:expr) => {
        $vec.into_iter().cloned().collect::<HashSet<String>>()
    };
}
82
/// Collect the region name of an `Option<&MemoryReference>` into a `HashSet<String>`:
/// a one-element set for `Some`, an empty set for `None`.
macro_rules! set_from_optional_memory_reference {
    ($reference:expr) => {
        $reference
            .into_iter()
            .map(|reference| reference.name.clone())
            .collect::<HashSet<String>>()
    };
}
89
/// Build a `HashSet<String>` holding the region name of every memory reference yielded by
/// `.iter()` on the argument (an array or `Vec` of `&MemoryReference`).
///
/// Names are cloned, so the resulting set owns its strings; duplicate names collapse.
macro_rules! set_from_memory_references {
    ($references:expr) => {
        $references
            .iter()
            .map(|reference| reference.name.clone())
            .collect::<HashSet<String>>()
    };
}
96
/// Errors that can occur while computing an instruction's [`MemoryAccesses`].
#[derive(thiserror::Error, Debug, PartialEq, Clone)]
pub enum MemoryAccessesError {
    /// A `CALL` instruction could not be resolved against the provided
    /// [`ExternSignatureMap`]; its message and source are forwarded unchanged
    /// (`#[error(transparent)]`).
    #[error(transparent)]
    CallResolution(#[from] CallResolutionError),
}

/// The result of [`Instruction::get_memory_accesses`].
pub type MemoryAccessesResult = Result<MemoryAccesses, MemoryAccessesError>;
104
105impl Instruction {
106    /// Return all memory accesses by the instruction - in expressions, captures, and memory manipulation.
107    ///
108    /// This will fail if the program contains [`Instruction::Call`] instructions that cannot
109    /// be resolved against a signature in the provided [`ExternSignatureMap`] (either because
110    /// they call functions that don't appear in the map or because the types of the parameters
111    /// are wrong).
112    pub fn get_memory_accesses(
113        &self,
114        extern_signature_map: &ExternSignatureMap,
115    ) -> MemoryAccessesResult {
116        Ok(match self {
117            Instruction::Convert(Convert {
118                source,
119                destination,
120            }) => MemoryAccesses {
121                reads: set_from_memory_references![[source]],
122                writes: set_from_memory_references![[destination]],
123                ..Default::default()
124            },
125            Instruction::Call(call) => call.get_memory_accesses(extern_signature_map)?,
126            Instruction::Comparison(Comparison {
127                destination,
128                lhs,
129                rhs,
130                operator: _,
131            }) => {
132                let mut reads = HashSet::from([lhs.name.clone()]);
133                let writes = HashSet::from([destination.name.clone()]);
134                if let ComparisonOperand::MemoryReference(mem) = &rhs {
135                    reads.insert(mem.name.clone());
136                }
137
138                MemoryAccesses {
139                    reads,
140                    writes,
141                    ..Default::default()
142                }
143            }
144            Instruction::BinaryLogic(BinaryLogic {
145                destination,
146                source,
147                operator: _,
148            }) => {
149                let mut reads = HashSet::new();
150                let mut writes = HashSet::new();
151                reads.insert(destination.name.clone());
152                writes.insert(destination.name.clone());
153                if let BinaryOperand::MemoryReference(mem) = &source {
154                    reads.insert(mem.name.clone());
155                }
156
157                MemoryAccesses {
158                    reads,
159                    writes,
160                    ..Default::default()
161                }
162            }
163            Instruction::UnaryLogic(UnaryLogic { operand, .. }) => MemoryAccesses {
164                reads: HashSet::from([operand.name.clone()]),
165                writes: HashSet::from([operand.name.clone()]),
166                ..Default::default()
167            },
168            Instruction::Arithmetic(Arithmetic {
169                destination,
170                source,
171                ..
172            }) => MemoryAccesses {
173                writes: HashSet::from([destination.name.clone()]),
174                reads: set_from_optional_memory_reference![source.get_memory_reference()],
175                ..Default::default()
176            },
177            Instruction::Move(Move {
178                destination,
179                source,
180            }) => MemoryAccesses {
181                writes: set_from_memory_references![[destination]],
182                reads: set_from_optional_memory_reference![source.get_memory_reference()],
183                ..Default::default()
184            },
185            Instruction::CalibrationDefinition(definition) => {
186                let references: Vec<&MemoryReference> = definition
187                    .identifier
188                    .parameters
189                    .iter()
190                    .flat_map(|expr| expr.get_memory_references())
191                    .collect();
192                MemoryAccesses {
193                    reads: set_from_memory_references![references],
194                    ..Default::default()
195                }
196            }
197            Instruction::Capture(Capture {
198                memory_reference,
199                waveform,
200                ..
201            }) => MemoryAccesses {
202                captures: set_from_memory_references!([memory_reference]),
203                reads: set_from_memory_references!(waveform.get_memory_references()),
204                ..Default::default()
205            },
206            Instruction::CircuitDefinition(CircuitDefinition { instructions, .. })
207            | Instruction::MeasureCalibrationDefinition(MeasureCalibrationDefinition {
208                instructions,
209                ..
210            }) => instructions.iter().try_fold(
211                Default::default(),
212                |acc: MemoryAccesses, el| -> MemoryAccessesResult {
213                    let el_accesses = el.get_memory_accesses(extern_signature_map)?;
214                    Ok(MemoryAccesses {
215                        reads: merge_sets!(acc.reads, el_accesses.reads),
216                        writes: merge_sets!(acc.writes, el_accesses.writes),
217                        captures: merge_sets!(acc.captures, el_accesses.captures),
218                    })
219                },
220            )?,
221            Instruction::Delay(Delay { duration, .. }) => MemoryAccesses {
222                reads: set_from_memory_references!(duration.get_memory_references()),
223                ..Default::default()
224            },
225            Instruction::Exchange(Exchange { left, right }) => MemoryAccesses {
226                reads: set_from_memory_references![[left, right]],
227                writes: set_from_memory_references![[left, right]],
228                ..Default::default()
229            },
230            Instruction::Gate(Gate { parameters, .. }) => MemoryAccesses {
231                reads: set_from_memory_references!(parameters
232                    .iter()
233                    .flat_map(|param| param.get_memory_references())
234                    .collect::<Vec<&MemoryReference>>()),
235                ..Default::default()
236            },
237            Instruction::GateDefinition(GateDefinition { specification, .. }) => {
238                if let GateSpecification::Matrix(matrix) = specification {
239                    let references = matrix
240                        .iter()
241                        .flat_map(|row| row.iter().flat_map(|cell| cell.get_memory_references()))
242                        .collect::<Vec<&MemoryReference>>();
243                    MemoryAccesses {
244                        reads: set_from_memory_references!(references),
245                        ..Default::default()
246                    }
247                } else {
248                    Default::default()
249                }
250            }
251            Instruction::JumpWhen(JumpWhen {
252                target: _,
253                condition,
254            })
255            | Instruction::JumpUnless(JumpUnless {
256                target: _,
257                condition,
258            }) => MemoryAccesses {
259                reads: set_from_memory_references!([condition]),
260                ..Default::default()
261            },
262            Instruction::Load(Load {
263                destination,
264                source,
265                offset,
266            }) => MemoryAccesses {
267                writes: set_from_memory_references![[destination]],
268                reads: set_from_reference_vec![vec![source, &offset.name]],
269                ..Default::default()
270            },
271            Instruction::Measurement(Measurement { target, .. }) => MemoryAccesses {
272                captures: set_from_optional_memory_reference!(target.as_ref()),
273                ..Default::default()
274            },
275            Instruction::Pulse(Pulse { waveform, .. }) => MemoryAccesses {
276                reads: set_from_memory_references![waveform.get_memory_references()],
277                ..Default::default()
278            },
279            Instruction::RawCapture(RawCapture {
280                duration,
281                memory_reference,
282                ..
283            }) => MemoryAccesses {
284                reads: set_from_memory_references![duration.get_memory_references()],
285                captures: set_from_memory_references![[memory_reference]],
286                ..Default::default()
287            },
288            Instruction::SetPhase(SetPhase { phase: expr, .. })
289            | Instruction::SetScale(SetScale { scale: expr, .. })
290            | Instruction::ShiftPhase(ShiftPhase { phase: expr, .. }) => MemoryAccesses {
291                reads: set_from_memory_references!(expr.get_memory_references()),
292                ..Default::default()
293            },
294            Instruction::SetFrequency(SetFrequency { frequency, .. })
295            | Instruction::ShiftFrequency(ShiftFrequency { frequency, .. }) => MemoryAccesses {
296                reads: set_from_memory_references!(frequency.get_memory_references()),
297                ..Default::default()
298            },
299            Instruction::Store(Store {
300                destination,
301                offset,
302                source,
303            }) => {
304                let mut reads = vec![&offset.name];
305                if let Some(source) = source.get_memory_reference() {
306                    reads.push(&source.name);
307                }
308                MemoryAccesses {
309                    reads: set_from_reference_vec![reads],
310                    writes: set_from_reference_vec![vec![destination]],
311                    ..Default::default()
312                }
313            }
314            Instruction::Declaration(_)
315            | Instruction::Fence(_)
316            | Instruction::FrameDefinition(_)
317            | Instruction::Halt()
318            | Instruction::Wait()
319            | Instruction::Include(_)
320            | Instruction::Jump(_)
321            | Instruction::Label(_)
322            | Instruction::Nop()
323            | Instruction::Pragma(_)
324            | Instruction::Reset(_)
325            | Instruction::SwapPhases(_)
326            | Instruction::WaveformDefinition(_) => Default::default(),
327        })
328    }
329}
330
331impl ArithmeticOperand {
332    pub fn get_memory_reference(&self) -> Option<&MemoryReference> {
333        match self {
334            ArithmeticOperand::LiteralInteger(_) => None,
335            ArithmeticOperand::LiteralReal(_) => None,
336            ArithmeticOperand::MemoryReference(reference) => Some(reference),
337        }
338    }
339}
340
341impl Expression {
342    /// Return, if any, the memory references contained within this Expression.
343    pub fn get_memory_references(&self) -> Vec<&MemoryReference> {
344        match self {
345            Expression::Address(reference) => vec![reference],
346            Expression::FunctionCall(FunctionCallExpression { expression, .. }) => {
347                expression.get_memory_references()
348            }
349            Expression::Infix(InfixExpression { left, right, .. }) => {
350                let mut result = left.get_memory_references();
351                result.extend(right.get_memory_references());
352                result
353            }
354            Expression::Number(_) => vec![],
355            Expression::PiConstant() => vec![],
356            Expression::Prefix(PrefixExpression { expression, .. }) => {
357                expression.get_memory_references()
358            }
359            Expression::Variable(_) => vec![],
360        }
361    }
362}
363
364impl WaveformInvocation {
365    /// Return, if any, the memory references contained within this WaveformInvocation.
366    pub fn get_memory_references(&self) -> Vec<&MemoryReference> {
367        self.parameters
368            .values()
369            .flat_map(Expression::get_memory_references)
370            .collect()
371    }
372}
373
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rstest::rstest;

    use crate::expression::Expression;
    use crate::instruction::{
        ArithmeticOperand, Convert, Exchange, ExternSignatureMap, FrameIdentifier, Instruction,
        MemoryReference, Qubit, SetFrequency, ShiftFrequency, Store,
    };
    use crate::program::MemoryAccesses;

    /// A reference to the default index of the named memory region.
    fn memory_reference(name: &str) -> MemoryReference {
        MemoryReference {
            name: name.to_string(),
            index: Default::default(),
        }
    }

    /// The owned set of region names expected in one field of [`MemoryAccesses`].
    fn region_names(names: &[&str]) -> HashSet<String> {
        names.iter().map(|&name| name.to_owned()).collect()
    }

    /// The single-qubit frame shared by the frequency test cases.
    fn test_frame() -> FrameIdentifier {
        FrameIdentifier {
            name: "frame".to_string(),
            qubits: vec![Qubit::Fixed(0)],
        }
    }

    #[rstest]
    #[case(
        Instruction::Store(Store {
            destination: "destination".to_string(),
            offset: memory_reference("offset"),
            source: ArithmeticOperand::MemoryReference(memory_reference("source")),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_names(&["source", "offset"]),
            writes: region_names(&["destination"]),
        }
    )]
    #[case(
        Instruction::Convert(Convert {
            destination: memory_reference("destination"),
            source: memory_reference("source"),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_names(&["source"]),
            writes: region_names(&["destination"]),
        }
    )]
    #[case(
        Instruction::Exchange(Exchange {
            left: memory_reference("left"),
            right: memory_reference("right"),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_names(&["left", "right"]),
            writes: region_names(&["left", "right"]),
        }
    )]
    #[case(
        Instruction::SetFrequency(SetFrequency {
            frequency: Expression::Address(memory_reference("frequency")),
            frame: test_frame(),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_names(&["frequency"]),
            writes: HashSet::new(),
        }
    )]
    #[case(
        Instruction::ShiftFrequency(ShiftFrequency {
            frequency: Expression::Address(memory_reference("frequency")),
            frame: test_frame(),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_names(&["frequency"]),
            writes: HashSet::new(),
        }
    )]
    fn test_instruction_accesses(
        #[case] instruction: Instruction,
        #[case] expected: MemoryAccesses,
    ) {
        let actual = instruction
            .get_memory_accesses(&ExternSignatureMap::default())
            .expect("must be able to get memory accesses");
        // Compare field by field so a failure names the offending access mode.
        assert_eq!(actual.captures, expected.captures);
        assert_eq!(actual.reads, expected.reads);
        assert_eq!(actual.writes, expected.writes);
    }
}