1use crate::extension::TO_BE_INFERRED;
2use crate::hugr::views::HugrView;
3use crate::ops::dataflow::DataflowOpTrait;
4use crate::types::{Signature, TypeRow};
5
6use crate::ops;
7use crate::ops::handle::CaseID;
8
9use super::build_traits::SubContainer;
10use super::handle::BuildHandle;
11use super::HugrBuilder;
12use super::{
13 build_traits::Container,
14 dataflow::{DFGBuilder, DFGWrapper},
15 BuildError, ConditionalID,
16};
17
18use crate::Node;
19use crate::{extension::ExtensionSet, hugr::HugrMut, Hugr};
20
21use std::collections::HashSet;
22
23use thiserror::Error;
24
/// Builder for the dataflow body of a single case of a conditional.
///
/// This is a [`DFGWrapper`] over the base `B` whose handle type is
/// [`BuildHandle<CaseID>`].
pub type CaseBuilder<B> = DFGWrapper<B, BuildHandle<CaseID>>;
27
/// Errors reported while building the cases of a `Conditional` node.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum ConditionalBuildError {
    /// A builder was requested for case `case` of node `conditional`, but that
    /// case has already been built.
    #[error("Case {case} of Conditional node {conditional} has already been built.")]
    CaseBuilt { conditional: Node, case: usize },
    /// A builder was requested for index `case`, but node `conditional` has no
    /// case with that index.
    #[error("Conditional node {conditional} has no case with index {case}.")]
    NotCase { conditional: Node, case: usize },
    /// The builder for node `conditional` was finished while the case indices
    /// listed in `cases` were still unbuilt.
    #[error("Cases {cases:?} of Conditional node {conditional} have not been built.")]
    NotAllCasesBuilt {
        conditional: Node,
        cases: HashSet<usize>,
    },
}
44
/// Builder for a `Conditional` node, tracking which of its cases have been built.
#[derive(Debug, Clone, PartialEq)]
pub struct ConditionalBuilder<T> {
    /// Underlying storage; the trait impls in this file access it as a `Hugr`
    /// through `AsRef`/`AsMut`.
    pub(super) base: T,
    /// The `Conditional` node whose cases are being built.
    pub(super) conditional_node: Node,
    /// Number of output wires reported in the handle returned on finish.
    pub(super) n_out_wires: usize,
    /// One slot per case: `None` until the case is built, then `Some` with the
    /// node created for it.
    pub(super) case_nodes: Vec<Option<Node>>,
}
53
54impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for ConditionalBuilder<T> {
55 #[inline]
56 fn container_node(&self) -> Node {
57 self.conditional_node
58 }
59
60 #[inline]
61 fn hugr_mut(&mut self) -> &mut Hugr {
62 self.base.as_mut()
63 }
64
65 #[inline]
66 fn hugr(&self) -> &Hugr {
67 self.base.as_ref()
68 }
69}
70
71impl<H: AsMut<Hugr> + AsRef<Hugr>> SubContainer for ConditionalBuilder<H> {
72 type ContainerHandle = BuildHandle<ConditionalID>;
73
74 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
75 let cases: HashSet<usize> = self
76 .case_nodes
77 .iter()
78 .enumerate()
79 .filter_map(|(i, node)| if node.is_none() { Some(i) } else { None })
80 .collect();
81 if !cases.is_empty() {
82 return Err(ConditionalBuildError::NotAllCasesBuilt {
83 conditional: self.conditional_node,
84 cases,
85 }
86 .into());
87 }
88 Ok((self.conditional_node, self.n_out_wires).into())
89 }
90}
91impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
92 pub fn case_builder(&mut self, case: usize) -> Result<CaseBuilder<&mut Hugr>, BuildError> {
103 let conditional = self.conditional_node;
104 let control_op = self.hugr().get_optype(self.conditional_node);
105
106 let cond: ops::Conditional = control_op
107 .clone()
108 .try_into()
109 .expect("Parent node does not have Conditional optype.");
110 let extension_delta = cond.signature().runtime_reqs.clone();
111 let inputs = cond
112 .case_input_row(case)
113 .ok_or(ConditionalBuildError::NotCase { conditional, case })?;
114
115 if self.case_nodes.get(case).unwrap().is_some() {
116 return Err(ConditionalBuildError::CaseBuilt { conditional, case }.into());
117 }
118
119 let outputs = cond.outputs;
120 let case_op = ops::Case {
121 signature: Signature::new(inputs.clone(), outputs.clone())
122 .with_extension_delta(extension_delta.clone()),
123 };
124 let case_node =
125 if let Some(&sibling_node) = self.case_nodes[case + 1..].iter().flatten().next() {
127 self.hugr_mut().add_node_before(sibling_node, case_op)
128 } else {
129 self.add_child_node(case_op)
130 };
131
132 self.case_nodes[case] = Some(case_node);
133
134 let dfg_builder = DFGBuilder::create_with_io(
135 self.hugr_mut(),
136 case_node,
137 Signature::new(inputs, outputs).with_extension_delta(extension_delta),
138 )?;
139
140 Ok(CaseBuilder::from_dfg_builder(dfg_builder))
141 }
142}
143
144impl HugrBuilder for ConditionalBuilder<Hugr> {
145 fn finish_hugr(mut self) -> Result<Hugr, crate::hugr::ValidationError> {
146 if cfg!(feature = "extension_inference") {
147 self.base.infer_extensions(false)?;
148 }
149 self.base.validate()?;
150 Ok(self.base)
151 }
152}
153
154impl ConditionalBuilder<Hugr> {
155 pub fn new(
157 sum_rows: impl IntoIterator<Item = TypeRow>,
158 other_inputs: impl Into<TypeRow>,
159 outputs: impl Into<TypeRow>,
160 ) -> Result<Self, BuildError> {
161 Self::new_exts(sum_rows, other_inputs, outputs, TO_BE_INFERRED)
162 }
163
164 pub fn new_exts(
168 sum_rows: impl IntoIterator<Item = TypeRow>,
169 other_inputs: impl Into<TypeRow>,
170 outputs: impl Into<TypeRow>,
171 extension_delta: impl Into<ExtensionSet>,
172 ) -> Result<Self, BuildError> {
173 let sum_rows: Vec<_> = sum_rows.into_iter().collect();
174 let other_inputs = other_inputs.into();
175 let outputs = outputs.into();
176
177 let n_out_wires = outputs.len();
178 let n_cases = sum_rows.len();
179
180 let op = ops::Conditional {
181 sum_rows,
182 other_inputs,
183 outputs,
184 extension_delta: extension_delta.into(),
185 };
186 let base = Hugr::new(op);
187 let conditional_node = base.root();
188
189 Ok(ConditionalBuilder {
190 base,
191 conditional_node,
192 n_out_wires,
193 case_nodes: vec![None; n_cases],
194 })
195 }
196}
197
198impl CaseBuilder<Hugr> {
199 pub fn new(signature: Signature) -> Result<Self, BuildError> {
201 let op = ops::Case {
202 signature: signature.clone(),
203 };
204 let base = Hugr::new(op);
205 let root = base.root();
206 let dfg_builder = DFGBuilder::create_with_io(base, root, signature)?;
207
208 Ok(CaseBuilder::from_dfg_builder(dfg_builder))
209 }
210}
// Unit tests for `ConditionalBuilder`: standalone use, use inside a module
// function, and the two error paths for case bookkeeping.
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::builder::{DataflowSubContainer, ModuleBuilder};

    use crate::extension::prelude::usize_t;
    use crate::{
        builder::{test::n_identity, Dataflow},
        ops::Value,
        type_row,
    };

    use super::*;

    // A standalone two-case conditional whose cases are built out of order
    // (case 1 first, then case 0).
    #[test]
    fn basic_conditional() -> Result<(), BuildError> {
        let mut conditional_b = ConditionalBuilder::new_exts(
            [type_row![], type_row![]],
            vec![usize_t()],
            vec![usize_t()],
            ExtensionSet::new(),
        )?;

        // NOTE(review): `n_identity` is defined in `builder::test`; presumably
        // it finishes the case by wiring inputs to outputs -- confirm there.
        n_identity(conditional_b.case_builder(1)?)?;
        n_identity(conditional_b.case_builder(0)?)?;
        Ok(())
    }

    // A two-case conditional built inside a "main" function of a module; the
    // whole module must finish without error.
    #[test]
    fn basic_conditional_module() -> Result<(), BuildError> {
        let build_result: Result<Hugr, BuildError> = {
            let mut module_builder = ModuleBuilder::new();
            let mut fbuild = module_builder
                .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))?;
            let tru_const = fbuild.add_constant(Value::true_val());
            let _fdef = {
                // The loaded constant is the wire passed alongside the sum rows.
                let const_wire = fbuild.load_const(&tru_const);
                let [int] = fbuild.input_wires_arr();
                let conditional_id = {
                    let other_inputs = vec![(usize_t(), int)];
                    let outputs = vec![usize_t()].into();
                    let mut conditional_b = fbuild.conditional_builder(
                        ([type_row![], type_row![]], const_wire),
                        other_inputs,
                        outputs,
                    )?;

                    n_identity(conditional_b.case_builder(0)?)?;
                    n_identity(conditional_b.case_builder(1)?)?;

                    conditional_b.finish_sub_container()?
                };
                // The conditional's single output becomes the function's output.
                let [int] = conditional_id.outputs_arr();
                fbuild.finish_with_outputs([int])?
            };
            Ok(module_builder.finish_hugr()?)
        };

        assert_matches!(build_result, Ok(_));

        Ok(())
    }

    // Finishing with only case 0 of two built must report `NotAllCasesBuilt`.
    #[test]
    fn test_not_all_cases() -> Result<(), BuildError> {
        let mut builder =
            ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?;
        n_identity(builder.case_builder(0)?)?;
        assert_matches!(
            builder.finish_sub_container().map(|_| ()),
            Err(BuildError::ConditionalError(
                ConditionalBuildError::NotAllCasesBuilt { .. }
            ))
        );
        Ok(())
    }

    // Requesting a builder for case 0 a second time must report `CaseBuilt`.
    #[test]
    fn test_case_already_built() -> Result<(), BuildError> {
        let mut builder =
            ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?;
        n_identity(builder.case_builder(0)?)?;
        assert_matches!(
            builder.case_builder(0).map(|_| ()),
            Err(BuildError::ConditionalError(
                ConditionalBuildError::CaseBuilt { .. }
            ))
        );
        Ok(())
    }
}