hugr_core/extension/prelude/
unwrap_builder.rs1use std::iter;
2
3use crate::{
4 Wire,
5 builder::{BuildError, BuildHandle, Dataflow, DataflowSubContainer, SubContainer},
6 extension::prelude::{ConstError, PANIC_OP_ID},
7 ops::handle::DataflowOpID,
8 types::{SumType, Type, TypeArg, TypeRow},
9};
10use itertools::{Itertools as _, zip_eq};
11
12use super::PRELUDE;
13
14pub trait UnwrapBuilder: Dataflow {
16 fn add_panic(
18 &mut self,
19 err: ConstError,
20 output_row: impl IntoIterator<Item = Type>,
21 inputs: impl IntoIterator<Item = (Wire, Type)>,
22 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
23 let (input_wires, input_types): (Vec<_>, Vec<_>) = inputs.into_iter().unzip();
24 let input_arg: TypeArg = input_types
25 .into_iter()
26 .map(<TypeArg as From<_>>::from)
27 .collect_vec()
28 .into();
29 let output_arg: TypeArg = output_row
30 .into_iter()
31 .map(<TypeArg as From<_>>::from)
32 .collect_vec()
33 .into();
34 let op = PRELUDE.instantiate_extension_op(&PANIC_OP_ID, [input_arg, output_arg])?;
35 let err = self.add_load_value(err);
36 self.add_dataflow_op(op, iter::once(err).chain(input_wires))
37 }
38
39 fn build_unwrap_sum<const N: usize>(
42 &mut self,
43 tag: usize,
44 sum_type: SumType,
45 input: Wire,
46 ) -> Result<[Wire; N], BuildError> {
47 self.build_expect_sum(tag, sum_type, input, |i| {
48 format!("Expected variant {tag} but got variant {i}")
49 })
50 }
51
52 fn build_expect_sum<const N: usize, T: Into<ConstError>>(
66 &mut self,
67 tag: usize,
68 sum_type: SumType,
69 input: Wire,
70 mut error: impl FnMut(usize) -> T,
71 ) -> Result<[Wire; N], BuildError> {
72 let variants: Vec<TypeRow> = (0..sum_type.num_variants())
73 .map(|i| {
74 let tr_rv = sum_type.get_variant(i).unwrap().to_owned();
75 TypeRow::try_from(tr_rv)
76 })
77 .collect::<Result<_, _>>()?;
78
79 let output_row = variants.get(tag).unwrap();
81
82 let mut conditional =
83 self.conditional_builder((variants.clone(), input), [], output_row.clone())?;
84 for (i, variant) in variants.iter().enumerate() {
85 let mut case = conditional.case_builder(i)?;
86 if i == tag {
87 let outputs = case.input_wires();
88 case.finish_with_outputs(outputs)?;
89 } else {
90 let output_row = output_row.iter().cloned();
91 let inputs = zip_eq(case.input_wires(), variant.iter().cloned());
92 let err = error(i).into();
93 let outputs = case.add_panic(err, output_row, inputs)?.outputs();
94 case.finish_with_outputs(outputs)?;
95 }
96 }
97 Ok(conditional.finish_sub_container()?.outputs_arr())
98 }
99}
100
101impl<D: Dataflow> UnwrapBuilder for D {}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        builder::{DFGBuilder, DataflowHugr},
        extension::prelude::{bool_t, option_type},
        types::Signature,
    };

    /// Unwrapping the `Some` variant (tag 1) of an `Option<bool>` yields one
    /// `bool` wire, which can be returned from the DFG.
    #[test]
    fn test_build_unwrap() {
        let mut builder =
            DFGBuilder::new(Signature::new(Type::from(option_type(bool_t())), bool_t())).unwrap();

        let [opt] = builder.input_wires_arr();

        let [res] = builder
            .build_unwrap_sum(1, option_type(bool_t()), opt)
            .unwrap();
        builder.finish_hugr_with_outputs([res]).unwrap();
    }

    /// `build_expect_sum` accepts a caller-supplied closure that builds the
    /// panic message from the actual tag.
    #[test]
    fn test_build_expect_custom_message() {
        let mut builder =
            DFGBuilder::new(Signature::new(Type::from(option_type(bool_t())), bool_t())).unwrap();

        let [opt] = builder.input_wires_arr();

        let [res] = builder
            .build_expect_sum(1, option_type(bool_t()), opt, |i| {
                format!("Option was empty (got variant {i})")
            })
            .unwrap();
        builder.finish_hugr_with_outputs([res]).unwrap();
    }

    /// Asking for a tag beyond the sum type's variants is a caller bug and
    /// panics at build time.
    #[test]
    #[should_panic]
    fn test_build_unwrap_tag_out_of_range() {
        let mut builder =
            DFGBuilder::new(Signature::new(Type::from(option_type(bool_t())), bool_t())).unwrap();

        let [opt] = builder.input_wires_arr();

        let _: Result<[Wire; 1], _> = builder.build_unwrap_sum(2, option_type(bool_t()), opt);
    }
}
124}