hugr_core/extension/prelude/unwrap_builder.rs

1use std::iter;
2
3use crate::{
4    Wire,
5    builder::{BuildError, BuildHandle, Dataflow, DataflowSubContainer, SubContainer},
6    extension::prelude::{ConstError, PANIC_OP_ID},
7    ops::handle::DataflowOpID,
8    types::{SumType, Type, TypeArg, TypeRow},
9};
10use itertools::{Itertools as _, zip_eq};
11
12use super::PRELUDE;
13
14/// Extend dataflow builders with methods for building unwrap operations.
15pub trait UnwrapBuilder: Dataflow {
16    /// Add a panic operation to the dataflow with the given error.
17    fn add_panic(
18        &mut self,
19        err: ConstError,
20        output_row: impl IntoIterator<Item = Type>,
21        inputs: impl IntoIterator<Item = (Wire, Type)>,
22    ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
23        let (input_wires, input_types): (Vec<_>, Vec<_>) = inputs.into_iter().unzip();
24        let input_arg: TypeArg = input_types
25            .into_iter()
26            .map(<TypeArg as From<_>>::from)
27            .collect_vec()
28            .into();
29        let output_arg: TypeArg = output_row
30            .into_iter()
31            .map(<TypeArg as From<_>>::from)
32            .collect_vec()
33            .into();
34        let op = PRELUDE.instantiate_extension_op(&PANIC_OP_ID, [input_arg, output_arg])?;
35        let err = self.add_load_value(err);
36        self.add_dataflow_op(op, iter::once(err).chain(input_wires))
37    }
38
39    /// Build an unwrap operation for a sum type to extract the variant at the given tag
40    /// or panic if the tag is not the expected value.
41    fn build_unwrap_sum<const N: usize>(
42        &mut self,
43        tag: usize,
44        sum_type: SumType,
45        input: Wire,
46    ) -> Result<[Wire; N], BuildError> {
47        self.build_expect_sum(tag, sum_type, input, |i| {
48            format!("Expected variant {tag} but got variant {i}")
49        })
50    }
51
52    /// Build an unwrap operation for a sum type to extract the variant at the given tag
53    /// or panic with given message if the tag is not the expected value.
54    ///
55    /// `error` is a function that takes the actual tag and returns the error message
56    /// for cases where the tag is not the expected value.
57    ///
58    /// # Panics
59    ///
60    /// If `tag` is greater than the number of variants in the sum type.
61    ///
62    /// # Errors
63    ///
64    /// Errors in building the unwrapping conditional.
65    fn build_expect_sum<const N: usize, T: Into<ConstError>>(
66        &mut self,
67        tag: usize,
68        sum_type: SumType,
69        input: Wire,
70        mut error: impl FnMut(usize) -> T,
71    ) -> Result<[Wire; N], BuildError> {
72        let variants: Vec<TypeRow> = (0..sum_type.num_variants())
73            .map(|i| {
74                let tr_rv = sum_type.get_variant(i).unwrap().to_owned();
75                TypeRow::try_from(tr_rv)
76            })
77            .collect::<Result<_, _>>()?;
78
79        // TODO don't panic if tag >= num_variants
80        let output_row = variants.get(tag).unwrap();
81
82        let mut conditional =
83            self.conditional_builder((variants.clone(), input), [], output_row.clone())?;
84        for (i, variant) in variants.iter().enumerate() {
85            let mut case = conditional.case_builder(i)?;
86            if i == tag {
87                let outputs = case.input_wires();
88                case.finish_with_outputs(outputs)?;
89            } else {
90                let output_row = output_row.iter().cloned();
91                let inputs = zip_eq(case.input_wires(), variant.iter().cloned());
92                let err = error(i).into();
93                let outputs = case.add_panic(err, output_row, inputs)?.outputs();
94                case.finish_with_outputs(outputs)?;
95            }
96        }
97        Ok(conditional.finish_sub_container()?.outputs_arr())
98    }
99}
100
101impl<D: Dataflow> UnwrapBuilder for D {}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        builder::{DFGBuilder, DataflowHugr},
        extension::prelude::{bool_t, option_type},
        types::Signature,
    };

    #[test]
    fn test_build_unwrap() {
        let mut builder =
            DFGBuilder::new(Signature::new(Type::from(option_type(bool_t())), bool_t())).unwrap();

        let [opt] = builder.input_wires_arr();

        let [res] = builder
            .build_unwrap_sum(1, option_type(bool_t()), opt)
            .unwrap();
        builder.finish_hugr_with_outputs([res]).unwrap();
    }

    /// `build_expect_sum` with a caller-supplied error message builds a valid Hugr.
    #[test]
    fn test_build_expect_sum_custom_error() {
        let mut builder =
            DFGBuilder::new(Signature::new(Type::from(option_type(bool_t())), bool_t())).unwrap();

        let [opt] = builder.input_wires_arr();

        let [res] = builder
            .build_expect_sum(1, option_type(bool_t()), opt, |i| {
                format!("unexpected variant {i} while unwrapping")
            })
            .unwrap();
        builder.finish_hugr_with_outputs([res]).unwrap();
    }

    /// Asking for a tag past the last variant panics rather than building anything.
    #[test]
    #[should_panic]
    fn test_build_unwrap_tag_out_of_range() {
        let mut builder =
            DFGBuilder::new(Signature::new(Type::from(option_type(bool_t())), bool_t())).unwrap();

        let [opt] = builder.input_wires_arr();

        // An option sum type has two variants, so tag 2 is out of range.
        let _ = builder.build_unwrap_sum::<1>(2, option_type(bool_t()), opt);
    }
}