1use std::collections::HashSet;
16
17#[cfg(feature = "stubs")]
18use pyo3_stub_gen::derive::gen_stub_pyclass;
19
20use crate::expression::{Expression, FunctionCallExpression, InfixExpression, PrefixExpression};
21use crate::instruction::{
22 Arithmetic, ArithmeticOperand, BinaryLogic, BinaryOperand, CallResolutionError, Capture,
23 CircuitDefinition, Comparison, ComparisonOperand, Convert, Delay, Exchange, ExternSignatureMap,
24 Gate, GateDefinition, GateSpecification, Instruction, JumpUnless, JumpWhen, Load,
25 MeasureCalibrationDefinition, Measurement, MemoryReference, Move, Pulse, RawCapture,
26 SetFrequency, SetPhase, SetScale, Sharing, ShiftFrequency, ShiftPhase, Store, UnaryLogic,
27 Vector, WaveformInvocation,
28};
29use crate::pickleable_new;
30
/// A region of classical memory, described by its size ([`Vector`]) and optional
/// [`Sharing`] information.
///
/// When the `python` feature is enabled this type is exposed to Python as
/// `quil.program.MemoryRegion`. It is frozen, supports equality and hashing, and has
/// getters for every field. Python code may subclass it.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "quil.program", eq, frozen, hash, get_all, subclass)
)]
pub struct MemoryRegion {
    /// The size of the region.
    /// NOTE(review): presumably the element data type and length from the declaring
    /// `DECLARE` statement; confirm against the definition of `Vector`.
    pub size: Vector,
    /// Optional sharing information for this region; `None` when absent.
    /// NOTE(review): presumably mirrors Quil's `DECLARE ... SHARING` clause; confirm
    /// against the definition of `Sharing`.
    pub sharing: Option<Sharing>,
}
41
// Declare the `MemoryRegion::new(size, sharing)` constructor through the crate's
// `pickleable_new!` macro.
// NOTE(review): the macro is defined elsewhere in the crate. Presumably it also exposes this
// constructor to Python with pickle support; confirm in its definition.
pickleable_new! {
    impl MemoryRegion {
        pub fn new(size: Vector, sharing: Option<Sharing>);
    }
}
47
/// The names of the memory regions touched by an instruction (or group of instructions),
/// split by kind of access.
///
/// Only region names are recorded, not indices. A read of `ro[3]` therefore appears as
/// `"ro"` in [`MemoryAccesses::reads`].
///
/// `Eq` is derived alongside `PartialEq` because every field has total equality. This lets
/// the type be used wherever `Eq` is required.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct MemoryAccesses {
    /// Regions that receive data from a capture or measurement.
    pub captures: HashSet<String>,
    /// Regions whose contents are read.
    pub reads: HashSet<String>,
    /// Regions whose contents are written.
    pub writes: HashSet<String>,
}
54
/// One kind of access that an instruction can make to a classical memory region.
///
/// NOTE(review): these presumably correspond to the `reads`, `writes`, and `captures` sets of
/// `MemoryAccesses`. That mapping is not established in this file chunk; confirm at the
/// call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryAccessType {
    /// A read of the region.
    Read,

    /// A write to the region.
    Write,

    /// A capture into the region.
    Capture,
}
67
// Produce a new `HashSet<String>` holding every name from both operands. Both operands are
// only borrowed, and duplicate names collapse into a single entry.
macro_rules! merge_sets {
    ($left:expr, $right:expr) => {
        $left
            .iter()
            .chain($right.iter())
            .cloned()
            .collect::<HashSet<String>>()
    };
}
73
// Clone every `&String` yielded by `$vec` (anything iterable) into an owned
// `HashSet<String>`. Repeated names are stored once.
macro_rules! set_from_reference_vec {
    ($vec:expr) => {{
        let mut names = HashSet::<String>::new();
        for name in $vec {
            names.insert(name.clone());
        }
        names
    }};
}
82
// Turn an `Option<&MemoryReference>` into a set of region names. The set is empty for `None`
// and holds the single referenced region name for `Some`. `Option` is itself iterable, so
// its mapped name is handed straight to `set_from_reference_vec!`.
macro_rules! set_from_optional_memory_reference {
    ($reference:expr) => {
        set_from_reference_vec![$reference.map(|memory_reference| &memory_reference.name)]
    };
}
89
// Collect the region names of a collection of memory references (anything exposing
// `.iter()` over references to values with a `name: String` field) into an owned
// `HashSet<String>`. References to different indices of the same region collapse into
// one name.
macro_rules! set_from_memory_references {
    ($references:expr) => {
        $references
            .iter()
            .map(|memory_reference| memory_reference.name.clone())
            .collect::<HashSet<String>>()
    };
}
96
/// Errors that can occur while computing the memory accesses of an [`Instruction`].
#[derive(thiserror::Error, Debug, PartialEq, Clone)]
pub enum MemoryAccessesError {
    /// Wraps a [`CallResolutionError`]. The `#[from]` conversion lets `?` lift it into this
    /// type, and `#[error(transparent)]` forwards its message and source unchanged.
    #[error(transparent)]
    CallResolution(#[from] CallResolutionError),
}

/// The result of [`Instruction::get_memory_accesses`].
pub type MemoryAccessesResult = Result<MemoryAccesses, MemoryAccessesError>;
104
105impl Instruction {
106 pub fn get_memory_accesses(
113 &self,
114 extern_signature_map: &ExternSignatureMap,
115 ) -> MemoryAccessesResult {
116 Ok(match self {
117 Instruction::Convert(Convert {
118 source,
119 destination,
120 }) => MemoryAccesses {
121 reads: set_from_memory_references![[source]],
122 writes: set_from_memory_references![[destination]],
123 ..Default::default()
124 },
125 Instruction::Call(call) => call.get_memory_accesses(extern_signature_map)?,
126 Instruction::Comparison(Comparison {
127 destination,
128 lhs,
129 rhs,
130 operator: _,
131 }) => {
132 let mut reads = HashSet::from([lhs.name.clone()]);
133 let writes = HashSet::from([destination.name.clone()]);
134 if let ComparisonOperand::MemoryReference(mem) = &rhs {
135 reads.insert(mem.name.clone());
136 }
137
138 MemoryAccesses {
139 reads,
140 writes,
141 ..Default::default()
142 }
143 }
144 Instruction::BinaryLogic(BinaryLogic {
145 destination,
146 source,
147 operator: _,
148 }) => {
149 let mut reads = HashSet::new();
150 let mut writes = HashSet::new();
151 reads.insert(destination.name.clone());
152 writes.insert(destination.name.clone());
153 if let BinaryOperand::MemoryReference(mem) = &source {
154 reads.insert(mem.name.clone());
155 }
156
157 MemoryAccesses {
158 reads,
159 writes,
160 ..Default::default()
161 }
162 }
163 Instruction::UnaryLogic(UnaryLogic { operand, .. }) => MemoryAccesses {
164 reads: HashSet::from([operand.name.clone()]),
165 writes: HashSet::from([operand.name.clone()]),
166 ..Default::default()
167 },
168 Instruction::Arithmetic(Arithmetic {
169 destination,
170 source,
171 ..
172 }) => MemoryAccesses {
173 writes: HashSet::from([destination.name.clone()]),
174 reads: set_from_optional_memory_reference![source.get_memory_reference()],
175 ..Default::default()
176 },
177 Instruction::Move(Move {
178 destination,
179 source,
180 }) => MemoryAccesses {
181 writes: set_from_memory_references![[destination]],
182 reads: set_from_optional_memory_reference![source.get_memory_reference()],
183 ..Default::default()
184 },
185 Instruction::CalibrationDefinition(definition) => {
186 let references: Vec<&MemoryReference> = definition
187 .identifier
188 .parameters
189 .iter()
190 .flat_map(|expr| expr.get_memory_references())
191 .collect();
192 MemoryAccesses {
193 reads: set_from_memory_references![references],
194 ..Default::default()
195 }
196 }
197 Instruction::Capture(Capture {
198 memory_reference,
199 waveform,
200 ..
201 }) => MemoryAccesses {
202 captures: set_from_memory_references!([memory_reference]),
203 reads: set_from_memory_references!(waveform.get_memory_references()),
204 ..Default::default()
205 },
206 Instruction::CircuitDefinition(CircuitDefinition { instructions, .. })
207 | Instruction::MeasureCalibrationDefinition(MeasureCalibrationDefinition {
208 instructions,
209 ..
210 }) => instructions.iter().try_fold(
211 Default::default(),
212 |acc: MemoryAccesses, el| -> MemoryAccessesResult {
213 let el_accesses = el.get_memory_accesses(extern_signature_map)?;
214 Ok(MemoryAccesses {
215 reads: merge_sets!(acc.reads, el_accesses.reads),
216 writes: merge_sets!(acc.writes, el_accesses.writes),
217 captures: merge_sets!(acc.captures, el_accesses.captures),
218 })
219 },
220 )?,
221 Instruction::Delay(Delay { duration, .. }) => MemoryAccesses {
222 reads: set_from_memory_references!(duration.get_memory_references()),
223 ..Default::default()
224 },
225 Instruction::Exchange(Exchange { left, right }) => MemoryAccesses {
226 reads: set_from_memory_references![[left, right]],
227 writes: set_from_memory_references![[left, right]],
228 ..Default::default()
229 },
230 Instruction::Gate(Gate { parameters, .. }) => MemoryAccesses {
231 reads: set_from_memory_references!(parameters
232 .iter()
233 .flat_map(|param| param.get_memory_references())
234 .collect::<Vec<&MemoryReference>>()),
235 ..Default::default()
236 },
237 Instruction::GateDefinition(GateDefinition { specification, .. }) => {
238 if let GateSpecification::Matrix(matrix) = specification {
239 let references = matrix
240 .iter()
241 .flat_map(|row| row.iter().flat_map(|cell| cell.get_memory_references()))
242 .collect::<Vec<&MemoryReference>>();
243 MemoryAccesses {
244 reads: set_from_memory_references!(references),
245 ..Default::default()
246 }
247 } else {
248 Default::default()
249 }
250 }
251 Instruction::JumpWhen(JumpWhen {
252 target: _,
253 condition,
254 })
255 | Instruction::JumpUnless(JumpUnless {
256 target: _,
257 condition,
258 }) => MemoryAccesses {
259 reads: set_from_memory_references!([condition]),
260 ..Default::default()
261 },
262 Instruction::Load(Load {
263 destination,
264 source,
265 offset,
266 }) => MemoryAccesses {
267 writes: set_from_memory_references![[destination]],
268 reads: set_from_reference_vec![vec![source, &offset.name]],
269 ..Default::default()
270 },
271 Instruction::Measurement(Measurement { target, .. }) => MemoryAccesses {
272 captures: set_from_optional_memory_reference!(target.as_ref()),
273 ..Default::default()
274 },
275 Instruction::Pulse(Pulse { waveform, .. }) => MemoryAccesses {
276 reads: set_from_memory_references![waveform.get_memory_references()],
277 ..Default::default()
278 },
279 Instruction::RawCapture(RawCapture {
280 duration,
281 memory_reference,
282 ..
283 }) => MemoryAccesses {
284 reads: set_from_memory_references![duration.get_memory_references()],
285 captures: set_from_memory_references![[memory_reference]],
286 ..Default::default()
287 },
288 Instruction::SetPhase(SetPhase { phase: expr, .. })
289 | Instruction::SetScale(SetScale { scale: expr, .. })
290 | Instruction::ShiftPhase(ShiftPhase { phase: expr, .. }) => MemoryAccesses {
291 reads: set_from_memory_references!(expr.get_memory_references()),
292 ..Default::default()
293 },
294 Instruction::SetFrequency(SetFrequency { frequency, .. })
295 | Instruction::ShiftFrequency(ShiftFrequency { frequency, .. }) => MemoryAccesses {
296 reads: set_from_memory_references!(frequency.get_memory_references()),
297 ..Default::default()
298 },
299 Instruction::Store(Store {
300 destination,
301 offset,
302 source,
303 }) => {
304 let mut reads = vec![&offset.name];
305 if let Some(source) = source.get_memory_reference() {
306 reads.push(&source.name);
307 }
308 MemoryAccesses {
309 reads: set_from_reference_vec![reads],
310 writes: set_from_reference_vec![vec![destination]],
311 ..Default::default()
312 }
313 }
314 Instruction::Declaration(_)
315 | Instruction::Fence(_)
316 | Instruction::FrameDefinition(_)
317 | Instruction::Halt()
318 | Instruction::Wait()
319 | Instruction::Include(_)
320 | Instruction::Jump(_)
321 | Instruction::Label(_)
322 | Instruction::Nop()
323 | Instruction::Pragma(_)
324 | Instruction::Reset(_)
325 | Instruction::SwapPhases(_)
326 | Instruction::WaveformDefinition(_) => Default::default(),
327 })
328 }
329}
330
331impl ArithmeticOperand {
332 pub fn get_memory_reference(&self) -> Option<&MemoryReference> {
333 match self {
334 ArithmeticOperand::LiteralInteger(_) => None,
335 ArithmeticOperand::LiteralReal(_) => None,
336 ArithmeticOperand::MemoryReference(reference) => Some(reference),
337 }
338 }
339}
340
341impl Expression {
342 pub fn get_memory_references(&self) -> Vec<&MemoryReference> {
344 match self {
345 Expression::Address(reference) => vec![reference],
346 Expression::FunctionCall(FunctionCallExpression { expression, .. }) => {
347 expression.get_memory_references()
348 }
349 Expression::Infix(InfixExpression { left, right, .. }) => {
350 let mut result = left.get_memory_references();
351 result.extend(right.get_memory_references());
352 result
353 }
354 Expression::Number(_) => vec![],
355 Expression::PiConstant() => vec![],
356 Expression::Prefix(PrefixExpression { expression, .. }) => {
357 expression.get_memory_references()
358 }
359 Expression::Variable(_) => vec![],
360 }
361 }
362}
363
364impl WaveformInvocation {
365 pub fn get_memory_references(&self) -> Vec<&MemoryReference> {
367 self.parameters
368 .values()
369 .flat_map(Expression::get_memory_references)
370 .collect()
371 }
372}
373
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::expression::Expression;
    use crate::instruction::{
        ArithmeticOperand, Convert, Exchange, ExternSignatureMap, FrameIdentifier, Instruction,
        MemoryReference, Qubit, SetFrequency, ShiftFrequency, Store,
    };
    use crate::program::MemoryAccesses;
    use std::collections::HashSet;

    /// A reference to region `name` at the default index.
    fn memory_reference(name: &str) -> MemoryReference {
        MemoryReference {
            name: name.to_string(),
            index: Default::default(),
        }
    }

    /// The owned set of region names used in expectations.
    fn names(items: &[&str]) -> HashSet<String> {
        items.iter().map(|&item| item.to_owned()).collect()
    }

    /// The single-qubit frame shared by the frequency cases.
    fn frame() -> FrameIdentifier {
        FrameIdentifier {
            name: "frame".to_string(),
            qubits: vec![Qubit::Fixed(0)],
        }
    }

    #[rstest]
    #[case(
        Instruction::Store(Store {
            destination: "destination".to_string(),
            offset: memory_reference("offset"),
            source: ArithmeticOperand::MemoryReference(memory_reference("source")),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: names(&["source", "offset"]),
            writes: names(&["destination"]),
        }
    )]
    #[case(
        Instruction::Convert(Convert {
            destination: memory_reference("destination"),
            source: memory_reference("source"),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: names(&["source"]),
            writes: names(&["destination"]),
        }
    )]
    #[case(
        Instruction::Exchange(Exchange {
            left: memory_reference("left"),
            right: memory_reference("right"),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: names(&["left", "right"]),
            writes: names(&["left", "right"]),
        }
    )]
    #[case(
        Instruction::SetFrequency(SetFrequency {
            frequency: Expression::Address(memory_reference("frequency")),
            frame: frame(),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: names(&["frequency"]),
            writes: HashSet::new(),
        }
    )]
    #[case(
        Instruction::ShiftFrequency(ShiftFrequency {
            frequency: Expression::Address(memory_reference("frequency")),
            frame: frame(),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: names(&["frequency"]),
            writes: HashSet::new(),
        }
    )]
    fn test_instruction_accesses(
        #[case] instruction: Instruction,
        #[case] expected: MemoryAccesses,
    ) {
        let actual = instruction
            .get_memory_accesses(&ExternSignatureMap::default())
            .expect("must be able to get memory accesses");
        assert_eq!(actual.captures, expected.captures);
        assert_eq!(actual.reads, expected.reads);
        assert_eq!(actual.writes, expected.writes);
    }
}