use std::collections::HashSet;
#[cfg(feature = "stubs")]
use pyo3_stub_gen::derive::gen_stub_pyclass;
use crate::{
expression::{Expression, FunctionCallExpression, InfixExpression, PrefixExpression},
instruction::{CallResolutionError, MemoryReference, Sharing, Vector, WaveformInvocation},
pickleable_new,
};
/// A memory region, described by its size and optional sharing information.
///
/// When the `python` feature is enabled this type is also exposed as a Python
/// class in the `quil.program` module: frozen, hashable, comparable for
/// equality, subclassable, and with every field readable (`get_all`).
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
feature = "python",
pyo3::pyclass(module = "quil.program", eq, frozen, hash, get_all, subclass)
)]
pub struct MemoryRegion {
/// Size of the region, expressed as a [`Vector`].
pub size: Vector,
/// Sharing information for the region, or `None` when it has none.
pub sharing: Option<Sharing>,
}
// Declares `MemoryRegion::new(size, sharing)` through the crate's
// `pickleable_new!` macro.
// NOTE(review): presumably the macro also generates whatever Python pickling
// needs to rebuild the value from these arguments — confirm against the macro
// definition, which is not visible from this file.
pickleable_new! {
impl MemoryRegion {
pub fn new(size: Vector, sharing: Option<Sharing>);
}
}
/// The memory touched by some piece of a program, grouped by kind of access.
///
/// Each set holds names as plain strings; the `Default` value has all three
/// sets empty.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct MemoryAccesses {
/// Names that are read.
pub reads: HashSet<String>,
/// Names that are written.
pub writes: HashSet<String>,
/// Names that are captured into.
/// NOTE(review): presumably the targets of capture-style instructions —
/// confirm against the instruction handler that fills this set.
pub captures: HashSet<String>,
}
impl MemoryAccesses {
#[inline]
pub fn none() -> Self {
Self::default()
}
pub fn union(mut self, rhs: Self) -> Self {
let Self {
captures,
reads,
writes,
} = rhs;
self.captures.extend(captures);
self.reads.extend(reads);
self.writes.extend(writes);
self
}
}
/// The kind of a single memory access.
///
/// The variants carry the same names as the three sets tracked by
/// [`MemoryAccesses`] (`reads`, `writes`, `captures`).
#[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
pub enum MemoryAccessType {
/// The memory is read.
Read,
/// The memory is written.
Write,
/// The memory is captured into.
Capture,
}
/// Errors that can occur while computing memory accesses.
#[derive(Clone, PartialEq, Debug, thiserror::Error)]
pub enum MemoryAccessesError {
/// A [`CallResolutionError`], passed through unchanged: its own message and
/// source are used (`transparent`), and `#[from]` lets `?` convert it.
#[error(transparent)]
CallResolution(#[from] CallResolutionError),
/// An instruction handler reported a failure; the payload is its message.
#[error("Instruction handler reported an error when constructing memory accesses: {0}")]
InstructionHandlerError(String),
}
pub mod expression {
    use super::*;

    /// Iterator over the [`MemoryReference`]s contained in an [`Expression`].
    ///
    /// The expression tree is walked depth-first, visiting the left operand of
    /// an infix expression before its right operand. A reference that occurs
    /// several times in the expression is yielded once per occurrence.
    #[derive(Clone, Debug)]
    pub struct MemoryReferences<'a> {
        /// Subexpressions still waiting to be visited; the last element is
        /// visited next.
        pub(super) stack: Vec<&'a Expression>,
    }

    impl<'a> Iterator for MemoryReferences<'a> {
        type Item = &'a MemoryReference;

        fn next(&mut self) -> Option<Self::Item> {
            loop {
                // Nothing left to visit means the traversal is over; `?`
                // returns `None` from `next` in that case.
                let mut node = self.stack.pop()?;

                // Descend from `node` until a leaf is reached. Right operands
                // met on the way down are parked on the stack for later.
                let found = loop {
                    match node {
                        Expression::Address(reference) => break Some(reference),
                        Expression::Number(_)
                        | Expression::PiConstant()
                        | Expression::Variable(_) => break None,
                        Expression::Prefix(PrefixExpression {
                            operator: _,
                            expression,
                        })
                        | Expression::FunctionCall(FunctionCallExpression {
                            function: _,
                            expression,
                        }) => node = expression,
                        Expression::Infix(InfixExpression {
                            left,
                            operator: _,
                            right,
                        }) => {
                            self.stack.push(right);
                            node = left;
                        }
                    }
                };

                // A leaf without a reference just moves on to the next pending
                // subexpression.
                if let Some(reference) = found {
                    return Some(reference);
                }
            }
        }
    }

    // Once the stack is empty, `next` can only ever return `None` again.
    impl std::iter::FusedIterator for MemoryReferences<'_> {}
}
impl Expression {
    /// Returns an iterator over the memory references contained in this
    /// expression.
    ///
    /// The iterator starts with this expression as the only pending item on
    /// its stack.
    pub fn memory_references(&self) -> expression::MemoryReferences<'_> {
        let stack = vec![self];
        expression::MemoryReferences { stack }
    }
}
impl WaveformInvocation {
pub fn memory_references(&self) -> impl std::iter::FusedIterator<Item = &MemoryReference> {
self.parameters
.values()
.flat_map(Expression::memory_references)
}
}
// Unit tests for expression memory-reference iteration and for the memory
// accesses reported for individual instructions.
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use rstest::rstest;
use crate::{
expression::Expression,
instruction::{
ArithmeticOperand, Convert, DefaultHandler, Exchange, ExternSignatureMap,
FrameIdentifier, Instruction, InstructionHandler as _, MemoryReference, Qubit,
SetFrequency, ShiftFrequency, Store,
},
program::MemoryAccesses,
};
// One expression mixing function calls, infix operators, a prefix operator, a
// numeric literal, `pi`, and a variable. The expected list is compared in
// order against what `memory_references` yields:
// - a reference used twice (`infix_ref[k] op infix_ref[k]`) is expected twice;
// - `1.0`, `pi`, and `%variable` contribute no entries;
// - the bare `prefix_ref` is expected with index 0.
#[rstest]
#[case(
r#"
cis(func_ref[0]) ^
cos(func_ref[1]) +
exp(func_ref[2]) -
sin(func_ref[3]) /
sqrt(func_ref[4]) *
(infix_ref[0] ^ infix_ref[0]) ^
(infix_ref[1] + infix_ref[1]) +
(infix_ref[2] - infix_ref[2]) -
(infix_ref[3] / infix_ref[3]) /
(infix_ref[4] * infix_ref[4]) *
1.0 ^
pi +
(-prefix_ref) -
%variable
"#,
&[
("func_ref", 0),
("func_ref", 1),
("func_ref", 2),
("func_ref", 3),
("func_ref", 4),
("infix_ref", 0),
("infix_ref", 0),
("infix_ref", 1),
("infix_ref", 1),
("infix_ref", 2),
("infix_ref", 2),
("infix_ref", 3),
("infix_ref", 3),
("infix_ref", 4),
("infix_ref", 4),
("prefix_ref", 0),
]
)]
fn expr_references(#[case] expr: &str, #[case] expected_refs: &[(&str, u64)]) {
// The multi-line literal is joined into a single line before parsing.
let expr = expr.replace('\n', " ").parse::<Expression>().unwrap();
let computed_refs: Vec<_> = expr.memory_references().cloned().collect();
// Turn the `(name, index)` pairs into owned `MemoryReference`s for comparison.
let expected_refs: Vec<_> = expected_refs
.iter()
.map(|(name, index)| MemoryReference {
name: (*name).to_owned(),
index: *index,
})
.collect();
assert_eq!(computed_refs, expected_refs);
}
// Each case pairs an instruction with the `MemoryAccesses` that
// `DefaultHandler` is expected to report for it.
#[rstest]
// Store: expected reads are `source` and `offset`; expected write is `destination`.
#[case(
Instruction::Store(Store {
destination: "destination".to_string(),
offset: MemoryReference {
name: "offset".to_string(),
index: Default::default()
},
source: ArithmeticOperand::MemoryReference(MemoryReference {
name: "source".to_string(),
index: Default::default()
}),
}),
MemoryAccesses {
captures: HashSet::new(),
reads: ["source", "offset"].iter().cloned().map(String::from).collect(),
writes: ["destination"].iter().cloned().map(String::from).collect(),
}
)]
// Convert: expected read is `source`; expected write is `destination`.
#[case(
Instruction::Convert(Convert {
destination: MemoryReference {
name: "destination".to_string(),
index: Default::default()
},
source: MemoryReference {
name: "source".to_string(),
index: Default::default()
},
}),
MemoryAccesses {
captures: HashSet::new(),
reads: ["source"].iter().cloned().map(String::from).collect(),
writes: ["destination"].iter().cloned().map(String::from).collect(),
}
)]
// Exchange: both operands are expected in the read set and in the write set.
#[case(
Instruction::Exchange(Exchange {
left: MemoryReference {
name: "left".to_string(),
index: Default::default()
},
right: MemoryReference {
name: "right".to_string(),
index: Default::default()
},
}),
MemoryAccesses {
captures: HashSet::new(),
reads: ["left", "right"].iter().cloned().map(String::from).collect(),
writes: ["left", "right"].iter().cloned().map(String::from).collect(),
}
)]
// SetFrequency: the frequency expression's reference is expected as a read only.
#[case(
Instruction::SetFrequency(SetFrequency {
frequency: Expression::Address(MemoryReference {
name: "frequency".to_string(),
index: Default::default()
}),
frame: FrameIdentifier {
name: "frame".to_string(),
qubits: vec![Qubit::Fixed(0)]
}
}),
MemoryAccesses {
captures: HashSet::new(),
reads: ["frequency"].iter().cloned().map(String::from).collect(),
writes: HashSet::new(),
}
)]
// ShiftFrequency: same expectation as SetFrequency.
#[case(
Instruction::ShiftFrequency(ShiftFrequency {
frequency: Expression::Address(MemoryReference {
name: "frequency".to_string(),
index: Default::default()
}),
frame: FrameIdentifier {
name: "frame".to_string(),
qubits: vec![Qubit::Fixed(0)]
}
}),
MemoryAccesses {
captures: HashSet::new(),
reads: ["frequency"].iter().cloned().map(String::from).collect(),
writes: HashSet::new(),
}
)]
fn test_instruction_accesses(
#[case] instruction: Instruction,
#[case] expected: MemoryAccesses,
) {
// An empty (default) extern signature map is passed to the handler.
let memory_accesses = DefaultHandler
.memory_accesses(&ExternSignatureMap::default(), &instruction)
.expect("must be able to get memory accesses");
// Each access set is asserted on its own, so a failing assertion identifies
// which set differs.
assert_eq!(memory_accesses.captures, expected.captures);
assert_eq!(memory_accesses.reads, expected.reads);
assert_eq!(memory_accesses.writes, expected.writes);
}
}