quil_rs/program/scheduling/schedule.rs

1//! A Schedule represents a flattening of the [`DependencyGraph`] into a linear sequence of
2//! instructions, with each instruction assigned a start time and duration.
3
4use std::collections::HashMap;
5
6use itertools::Itertools;
7use petgraph::{
8    visit::{EdgeFiltered, Topo},
9    Direction,
10};
11
12use crate::{
13    instruction::{
14        AttributeValue, Capture, Delay, Instruction, Pulse, RawCapture, WaveformInvocation,
15    },
16    quil::Quil,
17    Program,
18};
19
20use super::{ExecutionDependency, ScheduledBasicBlock, ScheduledGraphNode};
21
/// A span or point in time, expressed as a number of seconds.
#[derive(Clone, Debug, Default, PartialEq, PartialOrd)]
#[cfg_attr(feature = "python", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
pub struct Seconds(pub f64);

impl std::ops::Add<Seconds> for Seconds {
    type Output = Seconds;

    /// Add two quantities of seconds.
    fn add(self, rhs: Seconds) -> Self::Output {
        let Seconds(lhs_value) = self;
        let Seconds(rhs_value) = rhs;
        Seconds(rhs_value + lhs_value)
    }
}

impl std::ops::Sub<Seconds> for Seconds {
    type Output = Seconds;

    /// Subtract `rhs` seconds from `self`.
    fn sub(self, rhs: Seconds) -> Self::Output {
        let Seconds(lhs_value) = self;
        let Seconds(rhs_value) = rhs;
        Seconds(lhs_value - rhs_value)
    }
}

/// A type with a distinguished zero value, used as the starting point of a schedule.
pub trait Zero: PartialEq + Sized {
    /// Return this type's zero value.
    fn zero() -> Self;

    /// Return `true` when `self` compares equal to [`Zero::zero`].
    fn is_zero(&self) -> bool {
        *self == Self::zero()
    }
}

impl Zero for Seconds {
    fn zero() -> Self {
        Seconds(0.0)
    }
}
55
56#[derive(Clone, Debug, PartialEq)]
57pub struct Schedule<TimeUnit> {
58    pub(crate) items: Vec<ComputedScheduleItem<TimeUnit>>,
59    /// The total duration of the block. This is the end time of the schedule when it starts at `TimeUnit::zero()`
60    duration: TimeUnit,
61}
62
63impl<TimeUnit> Schedule<TimeUnit> {
64    pub fn duration(&self) -> &TimeUnit {
65        &self.duration
66    }
67
68    pub fn items(&self) -> &[ComputedScheduleItem<TimeUnit>] {
69        self.items.as_ref()
70    }
71
72    pub fn into_items(self) -> Vec<ComputedScheduleItem<TimeUnit>> {
73        self.items
74    }
75}
76
77impl<TimeUnit: Clone + PartialOrd + std::ops::Add<TimeUnit, Output = TimeUnit> + Zero>
78    From<Vec<ComputedScheduleItem<TimeUnit>>> for Schedule<TimeUnit>
79{
80    fn from(items: Vec<ComputedScheduleItem<TimeUnit>>) -> Self {
81        let duration = items
82            .iter()
83            .map(|item| item.time_span.start_time.clone() + item.time_span.duration.clone())
84            .fold(TimeUnit::zero(), |acc, el| if el > acc { el } else { acc });
85        Self { items, duration }
86    }
87}
88
89impl<TimeUnit: Zero> Default for Schedule<TimeUnit> {
90    fn default() -> Self {
91        Self {
92            items: Default::default(),
93            duration: TimeUnit::zero(),
94        }
95    }
96}
97
98pub type ScheduleSeconds = Schedule<Seconds>;
99
100#[derive(Clone, Debug, PartialEq)]
101pub struct ComputedScheduleItem<TimeUnit> {
102    pub time_span: TimeSpan<TimeUnit>,
103    pub instruction_index: usize,
104}
105
106#[derive(Debug, thiserror::Error)]
107pub enum ComputedScheduleError {
108    #[error("unknown duration for instruction {}", instruction.to_quil_or_debug())]
109    UnknownDuration { instruction: Instruction },
110
111    #[error("internal error: invalid dependency graph")]
112    InvalidDependencyGraph,
113}
114
115pub type ComputedScheduleResult<T> = Result<T, ComputedScheduleError>;
116
/// Represents a span of time, for some unit of time
#[derive(Clone, Debug, PartialEq)]
pub struct TimeSpan<TimeUnit> {
    /// The inclusive start time of the described item
    pub start_time: TimeUnit,

    /// The described item's continuous duration
    pub duration: TimeUnit,
}

impl<TimeUnit> TimeSpan<TimeUnit> {
    /// Borrow the time at which the span begins.
    pub fn start_time(&self) -> &TimeUnit {
        &self.start_time
    }

    /// Borrow the length of the span.
    pub fn duration(&self) -> &TimeUnit {
        &self.duration
    }
}

impl<TimeUnit: Clone + std::ops::Add<TimeUnit, Output = TimeUnit>> TimeSpan<TimeUnit> {
    /// The time at which the span finishes: `start_time + duration`.
    pub fn end(&self) -> TimeUnit {
        let Self {
            start_time,
            duration,
        } = self;
        start_time.clone() + duration.clone()
    }
}

impl<
        TimeUnit: Clone
            + PartialOrd
            + std::ops::Add<TimeUnit, Output = TimeUnit>
            + std::ops::Sub<TimeUnit, Output = TimeUnit>,
    > TimeSpan<TimeUnit>
{
    /// Return the smallest span that covers both `self` and `rhs`.
    ///
    /// When the two spans are disjoint, the result also covers the gap between them.
    pub(crate) fn union(self, rhs: Self) -> Self {
        let Self {
            start_time: lhs_start,
            duration: lhs_duration,
        } = self;
        let Self {
            start_time: rhs_start,
            duration: rhs_duration,
        } = rhs;

        let lhs_end = lhs_start.clone() + lhs_duration;
        let rhs_end = rhs_start.clone() + rhs_duration;

        // Ties favour `self`, for both the start and the end.
        let start_time = if rhs_start < lhs_start {
            rhs_start
        } else {
            lhs_start
        };
        let end_time = if lhs_end < rhs_end { rhs_end } else { lhs_end };

        Self {
            duration: end_time - start_time.clone(),
            start_time,
        }
    }
}
171
172impl<'p> ScheduledBasicBlock<'p> {
173    /// Return the duration of a scheduled Quil instruction:
174    ///
175    /// * For PULSE and CAPTURE, this is the duration of the waveform at the frame's sample rate
176    /// * For DELAY and RAW-CAPTURE, it's the named duration
177    /// * For supporting instructions like SET-*, SHIFT-*, and FENCE, it's 0
178    ///
179    /// Return `None` for other instructions.
180    pub(crate) fn get_instruction_duration_seconds(
181        program: &Program,
182        instruction: &Instruction,
183    ) -> Option<Seconds> {
184        match instruction {
185            Instruction::Capture(Capture { waveform, .. })
186            | Instruction::Pulse(Pulse { waveform, .. }) => {
187                Self::get_waveform_duration_seconds(program, instruction, waveform)
188            }
189            Instruction::Delay(Delay { duration, .. })
190            | Instruction::RawCapture(RawCapture { duration, .. }) => {
191                duration.to_real().ok().map(Seconds)
192            }
193            Instruction::Fence(_)
194            | Instruction::SetFrequency(_)
195            | Instruction::SetPhase(_)
196            | Instruction::SetScale(_)
197            | Instruction::ShiftFrequency(_)
198            | Instruction::ShiftPhase(_)
199            | Instruction::SwapPhases(_) => Some(Seconds(0.0)),
200            _ => None,
201        }
202    }
203
204    /// Return the duration of a Quil waveform:
205    ///
206    /// If the waveform is defined in the program with `DEFWAVEFORM`, the duration is the sample count
207    /// divided by the sample rate.
208    ///
209    /// Otherwise, it's the `duration` parameter of the waveform invocation. This relies on the assumption that
210    /// all template waveforms in use have such a parameter in units of seconds.
211    fn get_waveform_duration_seconds(
212        program: &Program,
213        instruction: &Instruction,
214        WaveformInvocation { name, parameters }: &WaveformInvocation,
215    ) -> Option<Seconds> {
216        if let Some(definition) = program.waveforms.get(name) {
217            let sample_count = definition.matrix.len();
218            let common_sample_rate =
219                program
220                    .get_frames_for_instruction(instruction)
221                    .and_then(|frames| {
222                        frames
223                            .used
224                            .into_iter()
225                            .filter_map(|frame| {
226                                program
227                                    .frames
228                                    .get(frame)
229                                    .and_then(|frame_definition| {
230                                        frame_definition.get("SAMPLE-RATE")
231                                    })
232                                    .and_then(|sample_rate_expression| match sample_rate_expression
233                                    {
234                                        AttributeValue::String(_) => None,
235                                        AttributeValue::Expression(expression) => Some(expression),
236                                    })
237                                    .and_then(|expression| expression.to_real().ok())
238                            })
239                            .all_equal_value()
240                            .ok()
241                    });
242
243            common_sample_rate
244                .map(|sample_rate| sample_count as f64 / sample_rate)
245                .map(Seconds)
246        } else {
247            // Per the Quil spec, all waveform templates have a "duration"
248            // parameter, and "erf_square" also has "pad_left" and "pad_right".
249            // We explicitly choose to be more flexible here, and allow any
250            // built-in waveform templates to have "pad_*" parameters, as well
251            // as allow "erf_square" to omit them.
252            let parameter = |parameter_name| {
253                parameters
254                    .get(parameter_name)
255                    .and_then(|v| v.to_real().ok())
256                    .map(Seconds)
257            };
258            Some(
259                parameter("duration")?
260                    + parameter("pad_left").unwrap_or(Seconds::zero())
261                    + parameter("pad_right").unwrap_or(Seconds::zero()),
262            )
263        }
264    }
265
266    /// Compute the flattened schedule for this [`ScheduledBasicBlock`] in terms of seconds,
267    /// using a default built-in calculation for the duration of scheduled instructions.
268    pub fn as_schedule_seconds(
269        &self,
270        program: &Program,
271    ) -> ComputedScheduleResult<ScheduleSeconds> {
272        self.as_schedule(program, Self::get_instruction_duration_seconds)
273    }
274
275    /// Compute the flattened schedule for this [`ScheduledBasicBlock`] using a user-provided
276    /// closure for computation of instruction duration.
277    ///
278    /// Return an error if the schedule cannot be computed from the information provided.
279    pub fn as_schedule<
280        F,
281        TimeUnit: Clone + PartialOrd + std::ops::Add<TimeUnit, Output = TimeUnit> + Zero,
282    >(
283        &self,
284        program: &'p Program,
285        get_duration: F,
286    ) -> ComputedScheduleResult<Schedule<TimeUnit>>
287    where
288        F: Fn(&'p Program, &'p Instruction) -> Option<TimeUnit>,
289    {
290        let mut schedule = Schedule::default();
291        let mut end_time_by_instruction_index = HashMap::<usize, TimeUnit>::new();
292
293        let graph_filtered = EdgeFiltered::from_fn(&self.graph, |(_, _, dependencies)| {
294            dependencies.contains(&ExecutionDependency::Scheduled)
295        });
296        let mut topo = Topo::new(&graph_filtered);
297
298        while let Some(instruction_node) = topo.next(&graph_filtered) {
299            if let ScheduledGraphNode::InstructionIndex(index) = instruction_node {
300                let instruction = *self
301                    .basic_block()
302                    .instructions()
303                    .get(index)
304                    .ok_or_else(|| ComputedScheduleError::InvalidDependencyGraph)?;
305                let duration = get_duration(program, instruction).ok_or(
306                    ComputedScheduleError::UnknownDuration {
307                        instruction: instruction.clone(),
308                    },
309                )?;
310
311                let latest_previous_instruction_scheduler_end_time = self
312                    .graph
313                    .edges_directed(instruction_node, Direction::Incoming)
314                    .filter_map(|(source, _, dependencies)| {
315                        if dependencies.contains(&ExecutionDependency::Scheduled) {
316                            match source {
317                                ScheduledGraphNode::BlockStart => Ok(Some(TimeUnit::zero())),
318                                ScheduledGraphNode::InstructionIndex(previous_index) => {
319                                    end_time_by_instruction_index
320                                        .get(&previous_index)
321                                        .cloned()
322                                        .ok_or(ComputedScheduleError::InvalidDependencyGraph)
323                                        .map(Some)
324                                }
325                                ScheduledGraphNode::BlockEnd => unreachable!(),
326                            }
327                        } else {
328                            Ok(None)
329                        }
330                        .transpose()
331                    })
332                    .collect::<Result<Vec<TimeUnit>, _>>()?
333                    .into_iter()
334                    // this implementation allows us to require PartialOrd instead of Ord (required for `.max()`),
335                    // which is convenient for f64
336                    .fold(TimeUnit::zero(), |acc, el| if el > acc { el } else { acc });
337
338                let start_time = latest_previous_instruction_scheduler_end_time;
339                let end_time = start_time.clone() + duration.clone();
340                if schedule.duration < end_time {
341                    schedule.duration = end_time.clone();
342                }
343
344                end_time_by_instruction_index.insert(index, end_time);
345                schedule.items.push(ComputedScheduleItem {
346                    time_span: TimeSpan {
347                        start_time,
348                        duration,
349                    },
350                    instruction_index: index,
351                });
352            }
353        }
354
355        Ok(schedule)
356    }
357}
358
#[cfg(test)]
mod tests {
    use core::panic;
    use std::str::FromStr;

    use crate::{instruction::InstructionHandler, program::scheduling::TimeSpan, Program};

    /// Start time, in seconds, of every item in `schedule`, in schedule order.
    fn start_times(schedule: &super::ScheduleSeconds) -> Vec<f64> {
        schedule
            .items()
            .iter()
            .map(|item| item.time_span.start_time.0)
            .collect()
    }

    #[rstest::rstest]
    #[case("CAPTURE 0 \"a\" flat(duration: 1.0) ro", Some(1.0))]
    #[case("DELAY 0 \"a\" 1.0", Some(1.0))]
    #[case("FENCE", Some(0.0))]
    #[case("PULSE 0 \"a\" flat(duration: 1.0)", Some(1.0))]
    #[case("RAW-CAPTURE 0 \"a\" 1.0 ro", Some(1.0))]
    #[case("RESET", None)]
    #[case("SET-FREQUENCY 0 \"a\" 1.0", Some(0.0))]
    #[case("SET-PHASE 0 \"a\" 1.0", Some(0.0))]
    #[case("SET-SCALE 0 \"a\" 1.0", Some(0.0))]
    #[case("SHIFT-FREQUENCY 0 \"a\" 1.0", Some(0.0))]
    #[case("SHIFT-PHASE 0 \"a\" 1.0", Some(0.0))]
    #[case("SWAP-PHASES 0 \"a\" 0 \"b\"", Some(0.0))]
    fn instruction_duration_seconds(
        #[case] input_program: &str,
        #[case] expected_duration: Option<f64>,
    ) {
        // Durations are computed against a program with no frame or waveform definitions.
        let context = Program::new();
        let parsed = Program::from_str(input_program)
            .map_err(|e| e.to_string())
            .unwrap();
        let instruction = parsed.into_instructions().remove(0);
        let actual =
            crate::program::scheduling::ScheduledBasicBlock::get_instruction_duration_seconds(
                &context,
                &instruction,
            );
        let expected = expected_duration.map(crate::program::scheduling::Seconds);
        assert_eq!(expected, actual);
    }

    #[rstest::rstest]
    #[case(
        r#"FENCE
FENCE
FENCE
"#,
        Ok(vec![0.0, 0.0, 0.0])
    )]
    #[case(
        r#"DEFFRAME 0 "a":
    SAMPLE-RATE: 1e9
PULSE 0 "a" flat(duration: 1.0)
PULSE 0 "a" flat(duration: 1.0)
PULSE 0 "a" flat(duration: 1.0)
"#,
        Ok(vec![0.0, 1.0, 2.0])
    )]
    #[case(
        r#"DEFFRAME 0 "a":
    SAMPLE-RATE: 1e9
PULSE 0 "a" erf_square(duration: 1.0, pad_left: 0.2, pad_right: 0.3)
PULSE 0 "a" erf_square(duration: 0.1, pad_left: 0.7, pad_right: 0.7)
PULSE 0 "a" erf_square(duration: 0.5, pad_left: 0.6, pad_right: 0.4)
FENCE
"#,
        Ok(vec![0.0, 1.5, 3.0, 4.5])
    )]
    #[case(
        r#"DEFFRAME 0 "a":
    SAMPLE-RATE: 1e9
DEFFRAME 0 "b":
    SAMPLE-RATE: 1e9
NONBLOCKING PULSE 0 "a" flat(duration: 1.0)
NONBLOCKING PULSE 0 "b" flat(duration: 10.0)
FENCE
PULSE 0 "a" flat(duration: 1.0)
FENCE
PULSE 0 "a" flat(duration: 1.0)
"#,
        Ok(vec![0.0, 0.0, 10.0, 10.0, 11.0, 11.0])
    )]
    #[case(
        r#"DEFFRAME 0 "a":
    SAMPLE-RATE: 1e9
DEFFRAME 0 "b":
    SAMPLE-RATE: 1e9
DELAY 0 "a" 1.0
SET-PHASE 0 "a" 1.0
SHIFT-PHASE 0 "a" 1.0
SWAP-PHASES 0 "a" 0 "b"
SET-FREQUENCY 0 "a" 1.0
SHIFT-FREQUENCY 0 "a" 1.0
SET-SCALE 0 "a" 1.0
FENCE
PULSE 0 "a" flat(duration: 1.0)
"#,
        Ok(vec![0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
    )]
    #[case("RESET", Err(()))]
    fn schedule_seconds(#[case] input_program: &str, #[case] expected_times: Result<Vec<f64>, ()>) {
        let program: Program = input_program.parse().unwrap();
        let block: crate::program::analysis::BasicBlock = (&program).try_into().unwrap();
        let mut handler = InstructionHandler::default();
        let scheduled_block =
            crate::program::scheduling::ScheduledBasicBlock::build(block, &program, &mut handler)
                .unwrap();

        let outcome = scheduled_block.as_schedule_seconds(&program);
        match (outcome, expected_times) {
            (Ok(schedule), Ok(expected)) => assert_eq!(expected, start_times(&schedule)),
            (Err(_), Err(())) => {}
            (Ok(schedule), Err(())) => {
                panic!("expected error, got {:?}", start_times(&schedule))
            }
            (Err(error), Ok(_)) => panic!("expected success, got error: {error}"),
        }
    }

    #[rstest::rstest]
    #[case::identical((0, 10), (0, 10), (0, 10))]
    #[case::adjacent((0, 1), (1, 1), (0, 2))]
    #[case::disjoint((0, 10), (20, 10), (0, 30))]
    #[case::disjoint_reverse((20, 10), (0, 10), (0, 30))]
    fn time_span_union(
        #[case] a: (usize, usize),
        #[case] b: (usize, usize),
        #[case] expected: (usize, usize),
    ) {
        // Each case tuple is `(start_time, duration)`.
        let to_span = |(start_time, duration): (usize, usize)| TimeSpan {
            start_time,
            duration,
        };
        assert_eq!(to_span(expected), to_span(a).union(to_span(b)));
    }
}