1use crate::hugr::views::HugrView;
2use crate::types::{Signature, TypeRow};
3
4use crate::ops::handle::{CaseID, NodeHandle};
5use crate::ops::{self};
6
7use super::HugrBuilder;
8use super::build_traits::SubContainer;
9use super::handle::BuildHandle;
10use super::{
11 BuildError, ConditionalID,
12 build_traits::Container,
13 dataflow::{DFGBuilder, DFGWrapper},
14};
15
16use crate::Node;
17use crate::{Hugr, hugr::HugrMut};
18
19use std::collections::HashSet;
20
21use thiserror::Error;
22
/// Builder for the dataflow graph inside an [`ops::Case`] node of a Conditional.
pub type CaseBuilder<B> = DFGWrapper<B, BuildHandle<CaseID>>;
25
/// Error returned by [`ConditionalBuilder`] while building the cases of a Conditional node.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum ConditionalBuildError {
    /// A builder was requested for a case that has already been built.
    #[error("Case {case} of Conditional node {conditional} has already been built.")]
    CaseBuilt { conditional: Node, case: usize },
    /// The requested case index does not correspond to a case of the Conditional.
    #[error("Conditional node {conditional} has no case with index {case}.")]
    NotCase { conditional: Node, case: usize },
    /// The Conditional was finished while some of its cases were still unbuilt.
    #[error("Cases {cases:?} of Conditional node {conditional} have not been built.")]
    NotAllCasesBuilt {
        /// The Conditional node being built.
        conditional: Node,
        /// Indices of the cases that were never built.
        cases: HashSet<usize>,
    },
}
42
/// Builder for the [`ops::Case`] children of an [`ops::Conditional`] node.
#[derive(Debug, Clone, PartialEq)]
pub struct ConditionalBuilder<T> {
    // The HUGR being built (e.g. an owned `Hugr`); the impls require `AsMut<Hugr> + AsRef<Hugr>`.
    pub(super) base: T,
    // The Conditional node; case nodes are created as its children.
    pub(super) conditional_node: Node,
    // Number of output wires of the Conditional, used to build the finished handle.
    pub(super) n_out_wires: usize,
    // One slot per case; `Some(node)` once that case's node has been created.
    pub(super) case_nodes: Vec<Option<Node>>,
}
51
52impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for ConditionalBuilder<T> {
53 #[inline]
54 fn container_node(&self) -> Node {
55 self.conditional_node
56 }
57
58 #[inline]
59 fn hugr_mut(&mut self) -> &mut Hugr {
60 self.base.as_mut()
61 }
62
63 #[inline]
64 fn hugr(&self) -> &Hugr {
65 self.base.as_ref()
66 }
67}
68
69impl<H: AsMut<Hugr> + AsRef<Hugr>> SubContainer for ConditionalBuilder<H> {
70 type ContainerHandle = BuildHandle<ConditionalID>;
71
72 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
73 let cases: HashSet<usize> = self
74 .case_nodes
75 .iter()
76 .enumerate()
77 .filter_map(|(i, node)| if node.is_none() { Some(i) } else { None })
78 .collect();
79 if !cases.is_empty() {
80 return Err(ConditionalBuildError::NotAllCasesBuilt {
81 conditional: self.conditional_node,
82 cases,
83 }
84 .into());
85 }
86 Ok((self.conditional_node, self.n_out_wires).into())
87 }
88}
89impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
90 pub fn case_builder(&mut self, case: usize) -> Result<CaseBuilder<&mut Hugr>, BuildError> {
101 let conditional = self.conditional_node;
102 let control_op = self.hugr().get_optype(self.conditional_node);
103
104 let cond: ops::Conditional = control_op
105 .clone()
106 .try_into()
107 .expect("Parent node does not have Conditional optype.");
108 let inputs = cond
109 .case_input_row(case)
110 .ok_or(ConditionalBuildError::NotCase { conditional, case })?;
111
112 if self.case_nodes.get(case).unwrap().is_some() {
113 return Err(ConditionalBuildError::CaseBuilt { conditional, case }.into());
114 }
115
116 let outputs = cond.outputs;
117 let case_op = ops::Case {
118 signature: Signature::new(inputs.clone(), outputs.clone()),
119 };
120 let case_node =
121 if let Some(&sibling_node) = self.case_nodes[case + 1..].iter().flatten().next() {
123 self.hugr_mut().add_node_before(sibling_node, case_op)
124 } else {
125 self.add_child_node(case_op)
126 };
127
128 self.case_nodes[case] = Some(case_node);
129
130 let dfg_builder = DFGBuilder::create_with_io(
131 self.hugr_mut(),
132 case_node,
133 Signature::new(inputs, outputs),
134 )?;
135
136 Ok(CaseBuilder::from_dfg_builder(dfg_builder))
137 }
138}
139
140impl HugrBuilder for ConditionalBuilder<Hugr> {
141 fn finish_hugr(self) -> Result<Hugr, crate::hugr::ValidationError<Node>> {
142 self.base.validate()?;
143 Ok(self.base)
144 }
145}
146
147impl ConditionalBuilder<Hugr> {
148 pub fn new(
150 sum_rows: impl IntoIterator<Item = TypeRow>,
151 other_inputs: impl Into<TypeRow>,
152 outputs: impl Into<TypeRow>,
153 ) -> Result<Self, BuildError> {
154 let sum_rows: Vec<_> = sum_rows.into_iter().collect();
155 let other_inputs = other_inputs.into();
156 let outputs: TypeRow = outputs.into();
157
158 let n_out_wires = outputs.len();
159 let n_cases = sum_rows.len();
160
161 let op = ops::Conditional {
162 sum_rows,
163 other_inputs,
164 outputs,
165 };
166 let base = Hugr::new_with_entrypoint(op).expect("Conditional entrypoint should be valid");
167 let conditional_node = base.entrypoint();
168
169 Ok(ConditionalBuilder {
170 base,
171 conditional_node,
172 n_out_wires,
173 case_nodes: vec![None; n_cases],
174 })
175 }
176}
177
178impl CaseBuilder<Hugr> {
179 pub fn new(signature: Signature) -> Result<Self, BuildError> {
181 let mut conditional =
183 ConditionalBuilder::new([signature.input.clone()], vec![], signature.output.clone())?;
184 let case = conditional.case_builder(0)?.finish_sub_container()?.node();
185
186 let mut base = std::mem::take(conditional.hugr_mut());
188 base.set_entrypoint(case);
189 let dfg_builder = DFGBuilder::create(base, case)?;
190 Ok(CaseBuilder::from_dfg_builder(dfg_builder))
191 }
192}
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::builder::{DataflowSubContainer, ModuleBuilder};

    use crate::extension::prelude::usize_t;
    use crate::{
        builder::{Dataflow, test::n_identity},
        ops::Value,
        type_row,
    };

    use super::*;

    #[test]
    fn basic_conditional_case() -> Result<(), BuildError> {
        let case_b = CaseBuilder::new(Signature::new_endo(vec![usize_t(), usize_t()]))?;
        let [in0, in1] = case_b.input_wires_arr();
        case_b.finish_with_outputs([in0, in1])?;
        Ok(())
    }

    #[test]
    fn basic_conditional() -> Result<(), BuildError> {
        let mut conditional_b =
            ConditionalBuilder::new([type_row![], type_row![]], vec![usize_t()], vec![usize_t()])?;

        n_identity(conditional_b.case_builder(1)?)?;
        n_identity(conditional_b.case_builder(0)?)?;
        Ok(())
    }

    #[test]
    fn basic_conditional_module() -> Result<(), BuildError> {
        let build_result: Result<Hugr, BuildError> = {
            let mut module_builder = ModuleBuilder::new();
            let mut fbuild = module_builder
                .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))?;
            let tru_const = fbuild.add_constant(Value::true_val());
            let _fdef = {
                let const_wire = fbuild.load_const(&tru_const);
                let [int] = fbuild.input_wires_arr();
                let conditional_id = {
                    let other_inputs = vec![(usize_t(), int)];
                    let outputs = vec![usize_t()].into();
                    let mut conditional_b = fbuild.conditional_builder(
                        ([type_row![], type_row![]], const_wire),
                        other_inputs,
                        outputs,
                    )?;

                    n_identity(conditional_b.case_builder(0)?)?;
                    n_identity(conditional_b.case_builder(1)?)?;

                    conditional_b.finish_sub_container()?
                };
                let [int] = conditional_id.outputs_arr();
                fbuild.finish_with_outputs([int])?
            };
            Ok(module_builder.finish_hugr()?)
        };

        assert_matches!(build_result, Ok(_));

        Ok(())
    }

    #[test]
    fn test_not_all_cases() -> Result<(), BuildError> {
        let mut builder =
            ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?;
        n_identity(builder.case_builder(0)?)?;
        assert_matches!(
            builder.finish_sub_container().map(|_| ()),
            Err(BuildError::ConditionalError(
                ConditionalBuildError::NotAllCasesBuilt { .. }
            ))
        );
        Ok(())
    }

    #[test]
    fn test_case_already_built() -> Result<(), BuildError> {
        let mut builder =
            ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?;
        n_identity(builder.case_builder(0)?)?;
        assert_matches!(
            builder.case_builder(0).map(|_| ()),
            Err(BuildError::ConditionalError(
                ConditionalBuildError::CaseBuilt { .. }
            ))
        );
        Ok(())
    }

    /// Requesting a case index past the last case must report `NotCase`.
    #[test]
    fn test_case_out_of_range() -> Result<(), BuildError> {
        let mut builder =
            ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?;
        assert_matches!(
            builder.case_builder(2).map(|_| ()),
            Err(BuildError::ConditionalError(
                ConditionalBuildError::NotCase { case: 2, .. }
            ))
        );
        Ok(())
    }

    /// Cases built in reverse order must still appear as children in case-index order.
    #[test]
    fn test_cases_ordered_by_index() -> Result<(), BuildError> {
        let mut builder =
            ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?;
        n_identity(builder.case_builder(1)?)?;
        n_identity(builder.case_builder(0)?)?;

        let expected: Vec<Node> = builder.case_nodes.iter().flatten().copied().collect();
        let children: Vec<Node> = builder.hugr().children(builder.container_node()).collect();
        assert_eq!(children, expected);
        Ok(())
    }
}