quil_rs/program/scheduling/
schedule.rs

1//! A Schedule represents a flattening of the [`DependencyGraph`] into a linear sequence of
2//! instructions, with each instruction assigned a start time and duration.
3
4use std::collections::HashMap;
5
6use itertools::Itertools;
7use petgraph::{
8    visit::{EdgeFiltered, Topo},
9    Direction,
10};
11
12use crate::{
13    instruction::{
14        AttributeValue, Capture, Delay, Instruction, InstructionHandler, Pulse, RawCapture,
15        WaveformInvocation,
16    },
17    quil::Quil,
18    Program,
19};
20
21use super::{ExecutionDependency, ScheduledBasicBlock, ScheduledGraphNode};
22
/// A length of time (or a point in time relative to the start of a block), in seconds.
///
/// This is the time unit produced by the built-in duration calculation used by
/// [`ScheduleSeconds`].
#[derive(Clone, Debug, Default, PartialEq, PartialOrd)]
#[cfg_attr(feature = "python", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
pub struct Seconds(pub f64);

impl std::ops::Add<Seconds> for Seconds {
    type Output = Seconds;

    /// Sum two amounts of time.
    fn add(self, rhs: Seconds) -> Self::Output {
        let Seconds(left) = self;
        let Seconds(right) = rhs;
        Seconds(left + right)
    }
}

impl std::ops::Sub<Seconds> for Seconds {
    type Output = Seconds;

    /// Subtract `rhs` from `self`.
    fn sub(self, rhs: Seconds) -> Self::Output {
        let Seconds(left) = self;
        let Seconds(right) = rhs;
        Seconds(left - right)
    }
}

/// A time unit with an additive identity.
///
/// Schedules start at [`Zero::zero`], and it is the duration of an empty schedule.
pub trait Zero: PartialEq + Sized {
    /// Return the additive identity (no time elapsed).
    fn zero() -> Self;

    /// Whether `self` compares equal to [`Zero::zero`].
    fn is_zero(&self) -> bool {
        *self == Self::zero()
    }
}

impl Zero for Seconds {
    fn zero() -> Self {
        Seconds(0.0)
    }
}
56
57#[derive(Clone, Debug, PartialEq)]
58pub struct Schedule<TimeUnit> {
59    pub(crate) items: Vec<ComputedScheduleItem<TimeUnit>>,
60    /// The total duration of the block. This is the end time of the schedule when it starts at `TimeUnit::zero()`
61    duration: TimeUnit,
62}
63
64impl<TimeUnit> Schedule<TimeUnit> {
65    pub fn duration(&self) -> &TimeUnit {
66        &self.duration
67    }
68
69    pub fn items(&self) -> &[ComputedScheduleItem<TimeUnit>] {
70        self.items.as_ref()
71    }
72
73    pub fn into_items(self) -> Vec<ComputedScheduleItem<TimeUnit>> {
74        self.items
75    }
76}
77
78impl<TimeUnit: Clone + PartialOrd + std::ops::Add<TimeUnit, Output = TimeUnit> + Zero>
79    From<Vec<ComputedScheduleItem<TimeUnit>>> for Schedule<TimeUnit>
80{
81    fn from(items: Vec<ComputedScheduleItem<TimeUnit>>) -> Self {
82        let duration = items
83            .iter()
84            .map(|item| item.time_span.start_time.clone() + item.time_span.duration.clone())
85            .fold(TimeUnit::zero(), |acc, el| if el > acc { el } else { acc });
86        Self { items, duration }
87    }
88}
89
90impl<TimeUnit: Zero> Default for Schedule<TimeUnit> {
91    fn default() -> Self {
92        Self {
93            items: Default::default(),
94            duration: TimeUnit::zero(),
95        }
96    }
97}
98
99pub type ScheduleSeconds = Schedule<Seconds>;
100
101#[derive(Clone, Debug, PartialEq)]
102pub struct ComputedScheduleItem<TimeUnit> {
103    pub time_span: TimeSpan<TimeUnit>,
104    pub instruction_index: usize,
105}
106
107#[derive(Debug, thiserror::Error)]
108pub enum ComputedScheduleError {
109    #[error("unknown duration for instruction {}", instruction.to_quil_or_debug())]
110    UnknownDuration { instruction: Instruction },
111
112    #[error("internal error: invalid dependency graph")]
113    InvalidDependencyGraph,
114}
115
116pub type ComputedScheduleResult<T> = Result<T, ComputedScheduleError>;
117
/// A span of time, expressed in some unit of time.
#[derive(Clone, Debug, PartialEq)]
pub struct TimeSpan<TimeUnit> {
    /// The inclusive start time of the described item
    pub start_time: TimeUnit,

    /// The described item's continuous duration
    pub duration: TimeUnit,
}

impl<TimeUnit> TimeSpan<TimeUnit> {
    /// Borrow the start time of the span.
    pub fn start_time(&self) -> &TimeUnit {
        &self.start_time
    }

    /// Borrow the duration of the span.
    pub fn duration(&self) -> &TimeUnit {
        &self.duration
    }
}

impl<TimeUnit: Clone + std::ops::Add<TimeUnit, Output = TimeUnit>> TimeSpan<TimeUnit> {
    /// The time at which the span ends, i.e. `start_time + duration`.
    pub fn end(&self) -> TimeUnit {
        let Self {
            start_time,
            duration,
        } = self;
        start_time.clone() + duration.clone()
    }
}

impl<
        TimeUnit: Clone
            + PartialOrd
            + std::ops::Add<TimeUnit, Output = TimeUnit>
            + std::ops::Sub<TimeUnit, Output = TimeUnit>,
    > TimeSpan<TimeUnit>
{
    /// The smallest span that covers both `self` and `rhs`, including any gap between them.
    ///
    /// It starts at the earlier start time (preferring `self` on a tie) and ends at the
    /// later end time.
    pub(crate) fn union(self, rhs: Self) -> Self {
        // Compute both end times first so the start times can be moved rather than cloned.
        let self_end_time = self.start_time.clone() + self.duration;
        let rhs_end_time = rhs.start_time.clone() + rhs.duration;

        let start_time = if rhs.start_time < self.start_time {
            rhs.start_time
        } else {
            self.start_time
        };
        let end_time = if self_end_time < rhs_end_time {
            rhs_end_time
        } else {
            self_end_time
        };

        Self {
            duration: end_time - start_time.clone(),
            start_time,
        }
    }
}
172
173impl<'p> ScheduledBasicBlock<'p> {
174    /// Return the duration of a scheduled Quil instruction:
175    ///
176    /// * For PULSE and CAPTURE, this is the duration of the waveform at the frame's sample rate
177    /// * For DELAY and RAW-CAPTURE, it's the named duration
178    /// * For supporting instructions like SET-*, SHIFT-*, and FENCE, it's 0
179    ///
180    /// Return `None` for other instructions.
181    pub(crate) fn instruction_duration_seconds<H: InstructionHandler>(
182        program: &Program,
183        instruction: &Instruction,
184        handler: &H,
185    ) -> Option<Seconds> {
186        match instruction {
187            Instruction::Capture(Capture { waveform, .. })
188            | Instruction::Pulse(Pulse { waveform, .. }) => {
189                Self::waveform_duration_seconds(program, instruction, waveform, handler)
190            }
191            Instruction::Delay(Delay { duration, .. })
192            | Instruction::RawCapture(RawCapture { duration, .. }) => {
193                duration.to_real().ok().map(Seconds)
194            }
195            Instruction::Fence(_)
196            | Instruction::SetFrequency(_)
197            | Instruction::SetPhase(_)
198            | Instruction::SetScale(_)
199            | Instruction::ShiftFrequency(_)
200            | Instruction::ShiftPhase(_)
201            | Instruction::SwapPhases(_) => Some(Seconds(0.0)),
202            _ => None,
203        }
204    }
205
206    /// Return the duration of a Quil waveform:
207    ///
208    /// If the waveform is defined in the program with `DEFWAVEFORM`, the duration is the sample count
209    /// divided by the sample rate.
210    ///
211    /// Otherwise, it's the `duration` parameter of the waveform invocation. This relies on the assumption that
212    /// all template waveforms in use have such a parameter in units of seconds.
213    fn waveform_duration_seconds<H: InstructionHandler>(
214        program: &Program,
215        instruction: &Instruction,
216        WaveformInvocation { name, parameters }: &WaveformInvocation,
217        handler: &H,
218    ) -> Option<Seconds> {
219        if let Some(definition) = program.waveforms.get(name) {
220            let sample_count = definition.matrix.len();
221            let common_sample_rate =
222                handler
223                    .matching_frames(program, instruction)
224                    .and_then(|frames| {
225                        frames
226                            .used
227                            .into_iter()
228                            .filter_map(|frame| {
229                                program
230                                    .frames
231                                    .get(frame)
232                                    .and_then(|frame_definition| {
233                                        frame_definition.get("SAMPLE-RATE")
234                                    })
235                                    .and_then(|sample_rate_expression| match sample_rate_expression
236                                    {
237                                        AttributeValue::String(_) => None,
238                                        AttributeValue::Expression(expression) => Some(expression),
239                                    })
240                                    .and_then(|expression| expression.to_real().ok())
241                            })
242                            .all_equal_value()
243                            .ok()
244                    });
245
246            common_sample_rate
247                .map(|sample_rate| sample_count as f64 / sample_rate)
248                .map(Seconds)
249        } else {
250            // Per the Quil spec, all waveform templates have a "duration"
251            // parameter, and "erf_square" also has "pad_left" and "pad_right".
252            // We explicitly choose to be more flexible here, and allow any
253            // built-in waveform templates to have "pad_*" parameters, as well
254            // as allow "erf_square" to omit them.
255            let parameter = |parameter_name| {
256                parameters
257                    .get(parameter_name)
258                    .and_then(|v| v.to_real().ok())
259                    .map(Seconds)
260            };
261            Some(
262                parameter("duration")?
263                    + parameter("pad_left").unwrap_or(Seconds::zero())
264                    + parameter("pad_right").unwrap_or(Seconds::zero()),
265            )
266        }
267    }
268
269    /// Compute the flattened schedule for this [`ScheduledBasicBlock`] in terms of seconds,
270    /// using a default built-in calculation for the duration of scheduled instructions.
271    pub fn as_schedule_seconds<H: InstructionHandler>(
272        &self,
273        program: &Program,
274        handler: &H,
275    ) -> ComputedScheduleResult<ScheduleSeconds> {
276        self.as_schedule(program, |prog, instr| {
277            Self::instruction_duration_seconds(prog, instr, handler)
278        })
279    }
280
281    /// Compute the flattened schedule for this [`ScheduledBasicBlock`] using a user-provided
282    /// closure for computation of instruction duration.
283    ///
284    /// Return an error if the schedule cannot be computed from the information provided.
285    pub fn as_schedule<
286        F,
287        TimeUnit: Clone + PartialOrd + std::ops::Add<TimeUnit, Output = TimeUnit> + Zero,
288    >(
289        &self,
290        program: &'p Program,
291        get_duration: F,
292    ) -> ComputedScheduleResult<Schedule<TimeUnit>>
293    where
294        F: Fn(&'p Program, &'p Instruction) -> Option<TimeUnit>,
295    {
296        let mut schedule = Schedule::default();
297        let mut end_time_by_instruction_index = HashMap::<usize, TimeUnit>::new();
298
299        let graph_filtered = EdgeFiltered::from_fn(&self.graph, |(_, _, dependencies)| {
300            dependencies.contains(&ExecutionDependency::Scheduled)
301        });
302        let mut topo = Topo::new(&graph_filtered);
303
304        while let Some(instruction_node) = topo.next(&graph_filtered) {
305            if let ScheduledGraphNode::InstructionIndex(index) = instruction_node {
306                let instruction = *self
307                    .basic_block()
308                    .instructions()
309                    .get(index)
310                    .ok_or_else(|| ComputedScheduleError::InvalidDependencyGraph)?;
311                let duration = get_duration(program, instruction).ok_or(
312                    ComputedScheduleError::UnknownDuration {
313                        instruction: instruction.clone(),
314                    },
315                )?;
316
317                let latest_previous_instruction_scheduler_end_time = self
318                    .graph
319                    .edges_directed(instruction_node, Direction::Incoming)
320                    .filter_map(|(source, _, dependencies)| {
321                        if dependencies.contains(&ExecutionDependency::Scheduled) {
322                            match source {
323                                ScheduledGraphNode::BlockStart => Ok(Some(TimeUnit::zero())),
324                                ScheduledGraphNode::InstructionIndex(previous_index) => {
325                                    end_time_by_instruction_index
326                                        .get(&previous_index)
327                                        .cloned()
328                                        .ok_or(ComputedScheduleError::InvalidDependencyGraph)
329                                        .map(Some)
330                                }
331                                ScheduledGraphNode::BlockEnd => unreachable!(),
332                            }
333                        } else {
334                            Ok(None)
335                        }
336                        .transpose()
337                    })
338                    .collect::<Result<Vec<TimeUnit>, _>>()?
339                    .into_iter()
340                    // this implementation allows us to require PartialOrd instead of Ord (required for `.max()`),
341                    // which is convenient for f64
342                    .fold(TimeUnit::zero(), |acc, el| if el > acc { el } else { acc });
343
344                let start_time = latest_previous_instruction_scheduler_end_time;
345                let end_time = start_time.clone() + duration.clone();
346                if schedule.duration < end_time {
347                    schedule.duration = end_time.clone();
348                }
349
350                end_time_by_instruction_index.insert(index, end_time);
351                schedule.items.push(ComputedScheduleItem {
352                    time_span: TimeSpan {
353                        start_time,
354                        duration,
355                    },
356                    instruction_index: index,
357                });
358            }
359        }
360
361        Ok(schedule)
362    }
363}
364
#[cfg(test)]
mod tests {
    // NOTE(review): `panic!` is already in the prelude, so this import appears redundant.
    use core::panic;
    use std::str::FromStr;

    use crate::{instruction::DefaultHandler, program::scheduling::TimeSpan, Program};

    // Checks `instruction_duration_seconds` for each supported instruction kind, in seconds.
    // An empty program is passed as context, so no `DEFWAVEFORM` exists. PULSE and CAPTURE
    // therefore take their duration from the template's `duration` parameter. RESET has no
    // defined duration, so `None` is expected.
    #[rstest::rstest]
    #[case("CAPTURE 0 \"a\" flat(duration: 1.0) ro", Some(1.0))]
    #[case("DELAY 0 \"a\" 1.0", Some(1.0))]
    #[case("FENCE", Some(0.0))]
    #[case("PULSE 0 \"a\" flat(duration: 1.0)", Some(1.0))]
    #[case("RAW-CAPTURE 0 \"a\" 1.0 ro", Some(1.0))]
    #[case("RESET", None)]
    #[case("SET-FREQUENCY 0 \"a\" 1.0", Some(0.0))]
    #[case("SET-PHASE 0 \"a\" 1.0", Some(0.0))]
    #[case("SET-SCALE 0 \"a\" 1.0", Some(0.0))]
    #[case("SHIFT-FREQUENCY 0 \"a\" 1.0", Some(0.0))]
    #[case("SHIFT-PHASE 0 \"a\" 1.0", Some(0.0))]
    #[case("SWAP-PHASES 0 \"a\" 0 \"b\"", Some(0.0))]
    fn instruction_duration_seconds(
        #[case] input_program: &str,
        #[case] expected_duration: Option<f64>,
    ) {
        let empty_program = Program::new();
        let program = Program::from_str(input_program)
            .map_err(|e| e.to_string())
            .unwrap();
        // Each input is a single instruction; take it out of the parsed program.
        let instruction = program.into_instructions().remove(0);
        let duration =
            crate::program::scheduling::ScheduledBasicBlock::instruction_duration_seconds(
                &empty_program,
                &instruction,
                &DefaultHandler,
            );
        assert_eq!(
            expected_duration.map(crate::program::scheduling::Seconds),
            duration
        );
    }

    // Checks the start times produced by `as_schedule_seconds`, listed in the order of
    // `schedule.items()`. `Err(())` means scheduling is expected to fail.
    #[rstest::rstest]
    // Zero-duration FENCEs all start at 0.
    #[case(
        r#"FENCE
FENCE
FENCE
"#,
        Ok(vec![0.0, 0.0, 0.0])
    )]
    // Back-to-back 1 s pulses on the same frame run sequentially.
    #[case(
        r#"DEFFRAME 0 "a":
    SAMPLE-RATE: 1e9
PULSE 0 "a" flat(duration: 1.0)
PULSE 0 "a" flat(duration: 1.0)
PULSE 0 "a" flat(duration: 1.0)
"#,
        Ok(vec![0.0, 1.0, 2.0])
    )]
    // erf_square durations include pad_left and pad_right: each pulse totals 1.5 s.
    #[case(
        r#"DEFFRAME 0 "a":
    SAMPLE-RATE: 1e9
PULSE 0 "a" erf_square(duration: 1.0, pad_left: 0.2, pad_right: 0.3)
PULSE 0 "a" erf_square(duration: 0.1, pad_left: 0.7, pad_right: 0.7)
PULSE 0 "a" erf_square(duration: 0.5, pad_left: 0.6, pad_right: 0.4)
FENCE
"#,
        Ok(vec![0.0, 1.5, 3.0, 4.5])
    )]
    // Both NONBLOCKING pulses start at 0; the following FENCE starts once the longer
    // (10 s) pulse has ended.
    #[case(
        r#"DEFFRAME 0 "a":
    SAMPLE-RATE: 1e9
DEFFRAME 0 "b":
    SAMPLE-RATE: 1e9
NONBLOCKING PULSE 0 "a" flat(duration: 1.0)
NONBLOCKING PULSE 0 "b" flat(duration: 10.0)
FENCE
PULSE 0 "a" flat(duration: 1.0)
FENCE
PULSE 0 "a" flat(duration: 1.0)
"#,
        Ok(vec![0.0, 0.0, 10.0, 10.0, 11.0, 11.0])
    )]
    // After a 1 s DELAY, the zero-duration frame-mutating instructions, FENCE, and PULSE
    // all start at 1.
    #[case(
        r#"DEFFRAME 0 "a":
    SAMPLE-RATE: 1e9
DEFFRAME 0 "b":
    SAMPLE-RATE: 1e9
DELAY 0 "a" 1.0
SET-PHASE 0 "a" 1.0
SHIFT-PHASE 0 "a" 1.0
SWAP-PHASES 0 "a" 0 "b"
SET-FREQUENCY 0 "a" 1.0
SHIFT-FREQUENCY 0 "a" 1.0
SET-SCALE 0 "a" 1.0
FENCE
PULSE 0 "a" flat(duration: 1.0)
"#,
        Ok(vec![0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
    )]
    // RESET has no known duration, so computing the schedule must fail.
    #[case("RESET", Err(()))]
    fn schedule_seconds(#[case] input_program: &str, #[case] expected_times: Result<Vec<f64>, ()>) {
        let program: Program = input_program.parse().unwrap();
        // NOTE(review): assumes each fixture converts into a single basic block.
        let block: crate::program::analysis::BasicBlock = (&program).try_into().unwrap();
        let scheduled_block = crate::program::scheduling::ScheduledBasicBlock::build(
            block,
            &program,
            &DefaultHandler,
        )
        .unwrap();
        match (
            scheduled_block.as_schedule_seconds(&program, &DefaultHandler),
            expected_times,
        ) {
            (Ok(schedule), Ok(expected_times)) => {
                let times = schedule
                    .items()
                    .iter()
                    .map(|item| item.time_span.start_time.0)
                    .collect::<Vec<_>>();
                assert_eq!(expected_times, times);
            }
            (Err(_), Err(_)) => {}
            (Ok(schedule), Err(_)) => {
                // Report the unexpected start times to make the failure easier to diagnose.
                let times = schedule
                    .items()
                    .iter()
                    .map(|item| item.time_span.start_time.0)
                    .collect::<Vec<_>>();
                panic!("expected error, got {:?}", times);
            }
            (Err(error), Ok(_)) => {
                panic!("expected success, got error: {error}")
            }
        }
    }

    // `TimeSpan::union` covers both spans, including any gap, regardless of argument order.
    // Tuples are `(start_time, duration)`.
    #[rstest::rstest]
    #[case::identical((0, 10), (0, 10), (0, 10))]
    #[case::adjacent((0, 1), (1, 1), (0, 2))]
    #[case::disjoint((0, 10), (20, 10), (0, 30))]
    #[case::disjoint_reverse((20, 10), (0, 10), (0, 30))]
    fn time_span_union(
        #[case] a: (usize, usize),
        #[case] b: (usize, usize),
        #[case] expected: (usize, usize),
    ) {
        let a = TimeSpan {
            start_time: a.0,
            duration: a.1,
        };
        let b = TimeSpan {
            start_time: b.0,
            duration: b.1,
        };
        let expected = TimeSpan {
            start_time: expected.0,
            duration: expected.1,
        };
        assert_eq!(expected, a.union(b));
    }
}
523        assert_eq!(expected, a.union(b));
524    }
525}