hugr_core/builder/conditional.rs

1use crate::hugr::views::HugrView;
2use crate::types::{Signature, TypeRow};
3
4use crate::ops::handle::{CaseID, NodeHandle};
5use crate::ops::{self};
6
7use super::HugrBuilder;
8use super::build_traits::SubContainer;
9use super::handle::BuildHandle;
10use super::{
11    BuildError, ConditionalID,
12    build_traits::Container,
13    dataflow::{DFGBuilder, DFGWrapper},
14};
15
16use crate::Node;
17use crate::{Hugr, hugr::HugrMut};
18
19use std::collections::HashSet;
20
21use thiserror::Error;
22
23/// Builder for a [`ops::Case`] child graph.
24pub type CaseBuilder<B> = DFGWrapper<B, BuildHandle<CaseID>>;
25
26#[derive(Debug, Clone, PartialEq, Eq, Error)]
27#[non_exhaustive]
28pub enum ConditionalBuildError {
29    /// Case already built.
30    #[error("Case {case} of Conditional node {conditional} has already been built.")]
31    CaseBuilt { conditional: Node, case: usize },
32    /// Case already built.
33    #[error("Conditional node {conditional} has no case with index {case}.")]
34    NotCase { conditional: Node, case: usize },
35    /// Not all cases of Conditional built.
36    #[error("Cases {cases:?} of Conditional node {conditional} have not been built.")]
37    NotAllCasesBuilt {
38        conditional: Node,
39        cases: HashSet<usize>,
40    },
41}
42
43/// Builder for a [`ops::Conditional`] node's children.
44#[derive(Debug, Clone, PartialEq)]
45pub struct ConditionalBuilder<T> {
46    pub(super) base: T,
47    pub(super) conditional_node: Node,
48    pub(super) n_out_wires: usize,
49    pub(super) case_nodes: Vec<Option<Node>>,
50}
51
52impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for ConditionalBuilder<T> {
53    #[inline]
54    fn container_node(&self) -> Node {
55        self.conditional_node
56    }
57
58    #[inline]
59    fn hugr_mut(&mut self) -> &mut Hugr {
60        self.base.as_mut()
61    }
62
63    #[inline]
64    fn hugr(&self) -> &Hugr {
65        self.base.as_ref()
66    }
67}
68
69impl<H: AsMut<Hugr> + AsRef<Hugr>> SubContainer for ConditionalBuilder<H> {
70    type ContainerHandle = BuildHandle<ConditionalID>;
71
72    fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
73        let cases: HashSet<usize> = self
74            .case_nodes
75            .iter()
76            .enumerate()
77            .filter_map(|(i, node)| if node.is_none() { Some(i) } else { None })
78            .collect();
79        if !cases.is_empty() {
80            return Err(ConditionalBuildError::NotAllCasesBuilt {
81                conditional: self.conditional_node,
82                cases,
83            }
84            .into());
85        }
86        Ok((self.conditional_node, self.n_out_wires).into())
87    }
88}
89impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
90    /// Return a builder the Case node with index `case`.
91    ///
92    /// # Panics
93    ///
94    /// Panics if the parent node is not of type [`ops::Conditional`].
95    ///
96    /// # Errors
97    ///
98    /// This function will return an error if the case has already been built,
99    /// `case` is not a valid index or if there is an error adding nodes.
100    pub fn case_builder(&mut self, case: usize) -> Result<CaseBuilder<&mut Hugr>, BuildError> {
101        let conditional = self.conditional_node;
102        let control_op = self.hugr().get_optype(self.conditional_node);
103
104        let cond: ops::Conditional = control_op
105            .clone()
106            .try_into()
107            .expect("Parent node does not have Conditional optype.");
108        let inputs = cond
109            .case_input_row(case)
110            .ok_or(ConditionalBuildError::NotCase { conditional, case })?;
111
112        if self.case_nodes.get(case).unwrap().is_some() {
113            return Err(ConditionalBuildError::CaseBuilt { conditional, case }.into());
114        }
115
116        let outputs = cond.outputs;
117        let case_op = ops::Case {
118            signature: Signature::new(inputs.clone(), outputs.clone()),
119        };
120        let case_node =
121            // add case before any existing subsequent cases
122            if let Some(&sibling_node) = self.case_nodes[case + 1..].iter().flatten().next() {
123                self.hugr_mut().add_node_before(sibling_node, case_op)
124            } else {
125                self.add_child_node(case_op)
126            };
127
128        self.case_nodes[case] = Some(case_node);
129
130        let dfg_builder = DFGBuilder::create_with_io(
131            self.hugr_mut(),
132            case_node,
133            Signature::new(inputs, outputs),
134        )?;
135
136        Ok(CaseBuilder::from_dfg_builder(dfg_builder))
137    }
138}
139
140impl HugrBuilder for ConditionalBuilder<Hugr> {
141    fn finish_hugr(self) -> Result<Hugr, crate::hugr::ValidationError<Node>> {
142        self.base.validate()?;
143        Ok(self.base)
144    }
145}
146
147impl ConditionalBuilder<Hugr> {
148    /// Initialize a Conditional rooted HUGR builder.
149    pub fn new(
150        sum_rows: impl IntoIterator<Item = TypeRow>,
151        other_inputs: impl Into<TypeRow>,
152        outputs: impl Into<TypeRow>,
153    ) -> Result<Self, BuildError> {
154        let sum_rows: Vec<_> = sum_rows.into_iter().collect();
155        let other_inputs = other_inputs.into();
156        let outputs: TypeRow = outputs.into();
157
158        let n_out_wires = outputs.len();
159        let n_cases = sum_rows.len();
160
161        let op = ops::Conditional {
162            sum_rows,
163            other_inputs,
164            outputs,
165        };
166        let base = Hugr::new_with_entrypoint(op).expect("Conditional entrypoint should be valid");
167        let conditional_node = base.entrypoint();
168
169        Ok(ConditionalBuilder {
170            base,
171            conditional_node,
172            n_out_wires,
173            case_nodes: vec![None; n_cases],
174        })
175    }
176}
177
178impl CaseBuilder<Hugr> {
179    /// Initialize a Case rooted HUGR
180    pub fn new(signature: Signature) -> Result<Self, BuildError> {
181        // Start by building a conditional with a single case
182        let mut conditional =
183            ConditionalBuilder::new([signature.input.clone()], vec![], signature.output.clone())?;
184        let case = conditional.case_builder(0)?.finish_sub_container()?.node();
185
186        // Extract the half-finished hugr, and wrap it in an owned case builder
187        let mut base = std::mem::take(conditional.hugr_mut());
188        base.set_entrypoint(case);
189        let dfg_builder = DFGBuilder::create(base, case)?;
190        Ok(CaseBuilder::from_dfg_builder(dfg_builder))
191    }
192}
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::builder::{DataflowSubContainer, ModuleBuilder};

    use crate::extension::prelude::usize_t;
    use crate::{
        builder::{Dataflow, test::n_identity},
        ops::Value,
        type_row,
    };

    use super::*;

    #[test]
    fn basic_conditional_case() -> Result<(), BuildError> {
        let case_b = CaseBuilder::new(Signature::new_endo(vec![usize_t(), usize_t()]))?;
        let [in0, in1] = case_b.input_wires_arr();
        case_b.finish_with_outputs([in0, in1])?;
        Ok(())
    }

    #[test]
    fn basic_conditional() -> Result<(), BuildError> {
        let mut conditional_b =
            ConditionalBuilder::new([type_row![], type_row![]], vec![usize_t()], vec![usize_t()])?;

        n_identity(conditional_b.case_builder(1)?)?;
        n_identity(conditional_b.case_builder(0)?)?;
        Ok(())
    }

    #[test]
    fn basic_conditional_module() -> Result<(), BuildError> {
        let build_result: Result<Hugr, BuildError> = {
            let mut module_builder = ModuleBuilder::new();
            let mut fbuild = module_builder
                .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))?;
            let tru_const = fbuild.add_constant(Value::true_val());
            let _fdef = {
                let const_wire = fbuild.load_const(&tru_const);
                let [int] = fbuild.input_wires_arr();
                let conditional_id = {
                    let other_inputs = vec![(usize_t(), int)];
                    let outputs = vec![usize_t()].into();
                    let mut conditional_b = fbuild.conditional_builder(
                        ([type_row![], type_row![]], const_wire),
                        other_inputs,
                        outputs,
                    )?;

                    n_identity(conditional_b.case_builder(0)?)?;
                    n_identity(conditional_b.case_builder(1)?)?;

                    conditional_b.finish_sub_container()?
                };
                let [int] = conditional_id.outputs_arr();
                fbuild.finish_with_outputs([int])?
            };
            Ok(module_builder.finish_hugr()?)
        };

        assert_matches!(build_result, Ok(_));

        Ok(())
    }

    #[test]
    fn test_not_all_cases() -> Result<(), BuildError> {
        let mut builder =
            ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?;
        n_identity(builder.case_builder(0)?)?;
        assert_matches!(
            builder.finish_sub_container().map(|_| ()),
            Err(BuildError::ConditionalError(
                ConditionalBuildError::NotAllCasesBuilt { .. }
            ))
        );
        Ok(())
    }

    #[test]
    fn test_case_already_built() -> Result<(), BuildError> {
        let mut builder =
            ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?;
        n_identity(builder.case_builder(0)?)?;
        assert_matches!(
            builder.case_builder(0).map(|_| ()),
            Err(BuildError::ConditionalError(
                ConditionalBuildError::CaseBuilt { .. }
            ))
        );
        Ok(())
    }

    /// Requesting a case index past the last case is an error. It must not
    /// corrupt the builder: the remaining cases can still be built and finished.
    #[test]
    fn test_case_out_of_range() -> Result<(), BuildError> {
        let mut builder =
            ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?;
        n_identity(builder.case_builder(0)?)?;
        assert_matches!(
            builder.case_builder(2).map(|_| ()),
            Err(BuildError::ConditionalError(
                ConditionalBuildError::NotCase { case: 2, .. }
            ))
        );
        n_identity(builder.case_builder(1)?)?;
        builder.finish_sub_container()?;
        Ok(())
    }
}