hugr_core/builder/
conditional.rs

1use crate::extension::TO_BE_INFERRED;
2use crate::hugr::views::HugrView;
3use crate::ops::dataflow::DataflowOpTrait;
4use crate::types::{Signature, TypeRow};
5
6use crate::ops;
7use crate::ops::handle::CaseID;
8
9use super::build_traits::SubContainer;
10use super::handle::BuildHandle;
11use super::HugrBuilder;
12use super::{
13    build_traits::Container,
14    dataflow::{DFGBuilder, DFGWrapper},
15    BuildError, ConditionalID,
16};
17
18use crate::Node;
19use crate::{extension::ExtensionSet, hugr::HugrMut, Hugr};
20
21use std::collections::HashSet;
22
23use thiserror::Error;
24
/// Builder for a [`ops::Case`] child graph.
///
/// A dataflow builder whose container is a Case node. Obtained from
/// [`ConditionalBuilder::case_builder`], or created as the root of a fresh
/// HUGR with `CaseBuilder::new`.
pub type CaseBuilder<B> = DFGWrapper<B, BuildHandle<CaseID>>;
27
28#[derive(Debug, Clone, PartialEq, Eq, Error)]
29#[non_exhaustive]
30pub enum ConditionalBuildError {
31    /// Case already built.
32    #[error("Case {case} of Conditional node {conditional} has already been built.")]
33    CaseBuilt { conditional: Node, case: usize },
34    /// Case already built.
35    #[error("Conditional node {conditional} has no case with index {case}.")]
36    NotCase { conditional: Node, case: usize },
37    /// Not all cases of Conditional built.
38    #[error("Cases {cases:?} of Conditional node {conditional} have not been built.")]
39    NotAllCasesBuilt {
40        conditional: Node,
41        cases: HashSet<usize>,
42    },
43}
44
/// Builder for a [`ops::Conditional`] node's children.
///
/// Every case must be built exactly once (via `case_builder`) before the
/// builder is finished; building a case twice, or finishing with cases
/// missing, is reported as a [`ConditionalBuildError`].
#[derive(Debug, Clone, PartialEq)]
pub struct ConditionalBuilder<T> {
    /// Owner of, or handle to, the HUGR being built into.
    pub(super) base: T,
    /// The Conditional node whose Case children are being built.
    pub(super) conditional_node: Node,
    /// Number of output wires of the Conditional node.
    pub(super) n_out_wires: usize,
    /// One slot per case index; `Some(node)` once that case's node exists.
    pub(super) case_nodes: Vec<Option<Node>>,
}
53
54impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for ConditionalBuilder<T> {
55    #[inline]
56    fn container_node(&self) -> Node {
57        self.conditional_node
58    }
59
60    #[inline]
61    fn hugr_mut(&mut self) -> &mut Hugr {
62        self.base.as_mut()
63    }
64
65    #[inline]
66    fn hugr(&self) -> &Hugr {
67        self.base.as_ref()
68    }
69}
70
71impl<H: AsMut<Hugr> + AsRef<Hugr>> SubContainer for ConditionalBuilder<H> {
72    type ContainerHandle = BuildHandle<ConditionalID>;
73
74    fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
75        let cases: HashSet<usize> = self
76            .case_nodes
77            .iter()
78            .enumerate()
79            .filter_map(|(i, node)| if node.is_none() { Some(i) } else { None })
80            .collect();
81        if !cases.is_empty() {
82            return Err(ConditionalBuildError::NotAllCasesBuilt {
83                conditional: self.conditional_node,
84                cases,
85            }
86            .into());
87        }
88        Ok((self.conditional_node, self.n_out_wires).into())
89    }
90}
91impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
92    /// Return a builder the Case node with index `case`.
93    ///
94    /// # Panics
95    ///
96    /// Panics if the parent node is not of type [`ops::Conditional`].
97    ///
98    /// # Errors
99    ///
100    /// This function will return an error if the case has already been built,
101    /// `case` is not a valid index or if there is an error adding nodes.
102    pub fn case_builder(&mut self, case: usize) -> Result<CaseBuilder<&mut Hugr>, BuildError> {
103        let conditional = self.conditional_node;
104        let control_op = self.hugr().get_optype(self.conditional_node);
105
106        let cond: ops::Conditional = control_op
107            .clone()
108            .try_into()
109            .expect("Parent node does not have Conditional optype.");
110        let extension_delta = cond.signature().runtime_reqs.clone();
111        let inputs = cond
112            .case_input_row(case)
113            .ok_or(ConditionalBuildError::NotCase { conditional, case })?;
114
115        if self.case_nodes.get(case).unwrap().is_some() {
116            return Err(ConditionalBuildError::CaseBuilt { conditional, case }.into());
117        }
118
119        let outputs = cond.outputs;
120        let case_op = ops::Case {
121            signature: Signature::new(inputs.clone(), outputs.clone())
122                .with_extension_delta(extension_delta.clone()),
123        };
124        let case_node =
125            // add case before any existing subsequent cases
126            if let Some(&sibling_node) = self.case_nodes[case + 1..].iter().flatten().next() {
127                self.hugr_mut().add_node_before(sibling_node, case_op)
128            } else {
129                self.add_child_node(case_op)
130            };
131
132        self.case_nodes[case] = Some(case_node);
133
134        let dfg_builder = DFGBuilder::create_with_io(
135            self.hugr_mut(),
136            case_node,
137            Signature::new(inputs, outputs).with_extension_delta(extension_delta),
138        )?;
139
140        Ok(CaseBuilder::from_dfg_builder(dfg_builder))
141    }
142}
143
144impl HugrBuilder for ConditionalBuilder<Hugr> {
145    fn finish_hugr(mut self) -> Result<Hugr, crate::hugr::ValidationError> {
146        if cfg!(feature = "extension_inference") {
147            self.base.infer_extensions(false)?;
148        }
149        self.base.validate()?;
150        Ok(self.base)
151    }
152}
153
154impl ConditionalBuilder<Hugr> {
155    /// Initialize a Conditional rooted HUGR builder, extension delta will be inferred.
156    pub fn new(
157        sum_rows: impl IntoIterator<Item = TypeRow>,
158        other_inputs: impl Into<TypeRow>,
159        outputs: impl Into<TypeRow>,
160    ) -> Result<Self, BuildError> {
161        Self::new_exts(sum_rows, other_inputs, outputs, TO_BE_INFERRED)
162    }
163
164    /// Initialize a Conditional rooted HUGR builder,
165    /// `extension_delta` explicitly specified. Alternatively,
166    /// [new](Self::new) may be used to infer it.
167    pub fn new_exts(
168        sum_rows: impl IntoIterator<Item = TypeRow>,
169        other_inputs: impl Into<TypeRow>,
170        outputs: impl Into<TypeRow>,
171        extension_delta: impl Into<ExtensionSet>,
172    ) -> Result<Self, BuildError> {
173        let sum_rows: Vec<_> = sum_rows.into_iter().collect();
174        let other_inputs = other_inputs.into();
175        let outputs = outputs.into();
176
177        let n_out_wires = outputs.len();
178        let n_cases = sum_rows.len();
179
180        let op = ops::Conditional {
181            sum_rows,
182            other_inputs,
183            outputs,
184            extension_delta: extension_delta.into(),
185        };
186        let base = Hugr::new(op);
187        let conditional_node = base.root();
188
189        Ok(ConditionalBuilder {
190            base,
191            conditional_node,
192            n_out_wires,
193            case_nodes: vec![None; n_cases],
194        })
195    }
196}
197
198impl CaseBuilder<Hugr> {
199    /// Initialize a Case rooted HUGR
200    pub fn new(signature: Signature) -> Result<Self, BuildError> {
201        let op = ops::Case {
202            signature: signature.clone(),
203        };
204        let base = Hugr::new(op);
205        let root = base.root();
206        let dfg_builder = DFGBuilder::create_with_io(base, root, signature)?;
207
208        Ok(CaseBuilder::from_dfg_builder(dfg_builder))
209    }
210}
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::builder::{DataflowSubContainer, ModuleBuilder};

    use crate::extension::prelude::usize_t;
    use crate::{
        builder::{test::n_identity, Dataflow},
        ops::Value,
        type_row,
    };

    use super::*;

    #[test]
    fn basic_conditional() -> Result<(), BuildError> {
        let mut conditional_b = ConditionalBuilder::new_exts(
            [type_row![], type_row![]],
            vec![usize_t()],
            vec![usize_t()],
            ExtensionSet::new(),
        )?;

        // Build cases out of order to exercise sibling insertion.
        n_identity(conditional_b.case_builder(1)?)?;
        n_identity(conditional_b.case_builder(0)?)?;
        Ok(())
    }

    #[test]
    fn basic_conditional_module() -> Result<(), BuildError> {
        let build_result: Result<Hugr, BuildError> = {
            let mut module_builder = ModuleBuilder::new();
            let mut fbuild = module_builder
                .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))?;
            let tru_const = fbuild.add_constant(Value::true_val());
            let _fdef = {
                let const_wire = fbuild.load_const(&tru_const);
                let [int] = fbuild.input_wires_arr();
                let conditional_id = {
                    let other_inputs = vec![(usize_t(), int)];
                    let outputs = vec![usize_t()].into();
                    let mut conditional_b = fbuild.conditional_builder(
                        ([type_row![], type_row![]], const_wire),
                        other_inputs,
                        outputs,
                    )?;

                    n_identity(conditional_b.case_builder(0)?)?;
                    n_identity(conditional_b.case_builder(1)?)?;

                    conditional_b.finish_sub_container()?
                };
                let [int] = conditional_id.outputs_arr();
                fbuild.finish_with_outputs([int])?
            };
            Ok(module_builder.finish_hugr()?)
        };

        assert_matches!(build_result, Ok(_));

        Ok(())
    }

    #[test]
    fn test_not_all_cases() -> Result<(), BuildError> {
        let mut builder =
            ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?;
        n_identity(builder.case_builder(0)?)?;
        assert_matches!(
            builder.finish_sub_container().map(|_| ()),
            Err(BuildError::ConditionalError(
                ConditionalBuildError::NotAllCasesBuilt { .. }
            ))
        );
        Ok(())
    }

    #[test]
    fn test_case_already_built() -> Result<(), BuildError> {
        let mut builder =
            ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?;
        n_identity(builder.case_builder(0)?)?;
        assert_matches!(
            builder.case_builder(0).map(|_| ()),
            Err(BuildError::ConditionalError(
                ConditionalBuildError::CaseBuilt { .. }
            ))
        );
        Ok(())
    }

    /// Requesting a case index beyond the number of sum rows is an error,
    /// not a panic.
    #[test]
    fn test_case_out_of_range() -> Result<(), BuildError> {
        let mut builder =
            ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?;
        assert_matches!(
            builder.case_builder(2).map(|_| ()),
            Err(BuildError::ConditionalError(
                ConditionalBuildError::NotCase { case: 2, .. }
            ))
        );
        Ok(())
    }
}