1use super::{
2 build_traits::SubContainer,
3 dataflow::{DFGBuilder, DFGWrapper},
4 handle::BuildHandle,
5 BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire,
6};
7
8use crate::extension::TO_BE_INFERRED;
9use crate::ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType};
10use crate::{extension::ExtensionSet, types::Signature};
11use crate::{hugr::views::HugrView, types::TypeRow};
12
13use crate::Node;
14use crate::{hugr::HugrMut, type_row, Hugr};
15
/// Builder for a control-flow graph (a CFG node and its basic-block children).
///
/// `T` gives access to the HUGR being built (see the `AsMut<Hugr> + AsRef<Hugr>`
/// bounds on the impls); e.g. `CFGBuilder<Hugr>` owns a standalone HUGR whose
/// root is the CFG node.
#[derive(Debug, PartialEq)]
pub struct CFGBuilder<T> {
    /// The HUGR (or handle to it) that the CFG is being built in.
    pub(super) base: T,
    /// The CFG container node; basic blocks are added as its children.
    pub(super) cfg_node: Node,
    /// Input types of the entry block. `Some` until the entry block is built,
    /// at which point it is taken (so a second entry block is rejected).
    pub(super) inputs: Option<TypeRow>,
    /// The exit block, created as soon as the builder is constructed.
    pub(super) exit_node: Node,
    /// Number of output wires of the CFG (the length of the exit block's outputs).
    pub(super) n_out_wires: usize,
}
120
121impl<B: AsMut<Hugr> + AsRef<Hugr>> Container for CFGBuilder<B> {
122 #[inline]
123 fn container_node(&self) -> Node {
124 self.cfg_node
125 }
126
127 #[inline]
128 fn hugr_mut(&mut self) -> &mut Hugr {
129 self.base.as_mut()
130 }
131
132 #[inline]
133 fn hugr(&self) -> &Hugr {
134 self.base.as_ref()
135 }
136}
137
138impl<H: AsMut<Hugr> + AsRef<Hugr>> SubContainer for CFGBuilder<H> {
139 type ContainerHandle = BuildHandle<CfgID>;
140 #[inline]
141 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
142 Ok((self.cfg_node, self.n_out_wires).into())
143 }
144}
145
146impl CFGBuilder<Hugr> {
147 pub fn new(signature: Signature) -> Result<Self, BuildError> {
149 let cfg_op = ops::CFG {
150 signature: signature.clone(),
151 };
152
153 let base = Hugr::new(cfg_op);
154 let cfg_node = base.root();
155 CFGBuilder::create(base, cfg_node, signature.input, signature.output)
156 }
157}
158
159impl HugrBuilder for CFGBuilder<Hugr> {
160 fn finish_hugr(mut self) -> Result<Hugr, crate::hugr::ValidationError> {
161 if cfg!(feature = "extension_inference") {
162 self.base.infer_extensions(false)?;
163 }
164 self.base.validate()?;
165 Ok(self.base)
166 }
167}
168
169impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
170 pub(super) fn create(
171 mut base: B,
172 cfg_node: Node,
173 input: TypeRow,
174 output: TypeRow,
175 ) -> Result<Self, BuildError> {
176 let n_out_wires = output.len();
177 let exit_block_type = OpType::ExitBlock(ExitBlock {
178 cfg_outputs: output,
179 });
180 let exit_node = base
181 .as_mut()
182 .add_node_with_parent(cfg_node, exit_block_type);
184 Ok(Self {
185 base,
186 cfg_node,
187 n_out_wires,
188 exit_node,
189 inputs: Some(input),
190 })
191 }
192
193 pub fn block_builder(
201 &mut self,
202 inputs: TypeRow,
203 sum_rows: impl IntoIterator<Item = TypeRow>,
204 other_outputs: TypeRow,
205 ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
206 self.block_builder_exts(inputs, sum_rows, other_outputs, TO_BE_INFERRED)
207 }
208
209 pub fn block_builder_exts(
217 &mut self,
218 inputs: TypeRow,
219 sum_rows: impl IntoIterator<Item = TypeRow>,
220 other_outputs: TypeRow,
221 extension_delta: impl Into<ExtensionSet>,
222 ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
223 self.any_block_builder(
224 inputs,
225 extension_delta.into(),
226 sum_rows,
227 other_outputs,
228 false,
229 )
230 }
231
232 fn any_block_builder(
233 &mut self,
234 inputs: TypeRow,
235 extension_delta: ExtensionSet,
236 sum_rows: impl IntoIterator<Item = TypeRow>,
237 other_outputs: TypeRow,
238 entry: bool,
239 ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
240 let sum_rows: Vec<_> = sum_rows.into_iter().collect();
241 let op = OpType::DataflowBlock(DataflowBlock {
242 inputs: inputs.clone(),
243 other_outputs: other_outputs.clone(),
244 sum_rows,
245 extension_delta,
246 });
247 let parent = self.container_node();
248 let block_n = if entry {
249 let exit = self.exit_node;
250 self.hugr_mut().add_node_before(exit, op)
252 } else {
253 self.hugr_mut().add_node_with_parent(parent, op)
255 };
256
257 BlockBuilder::create(self.hugr_mut(), block_n)
258 }
259
260 pub fn simple_block_builder(
268 &mut self,
269 signature: Signature,
270 n_cases: usize,
271 ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
272 self.block_builder_exts(
273 signature.input,
274 vec![type_row![]; n_cases],
275 signature.output,
276 signature.runtime_reqs,
277 )
278 }
279
280 pub fn entry_builder(
288 &mut self,
289 sum_rows: impl IntoIterator<Item = TypeRow>,
290 other_outputs: TypeRow,
291 ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
292 self.entry_builder_exts(sum_rows, other_outputs, TO_BE_INFERRED)
293 }
294
295 pub fn entry_builder_exts(
304 &mut self,
305 sum_rows: impl IntoIterator<Item = TypeRow>,
306 other_outputs: TypeRow,
307 extension_delta: impl Into<ExtensionSet>,
308 ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
309 let inputs = self
310 .inputs
311 .take()
312 .ok_or(BuildError::EntryBuiltError(self.cfg_node))?;
313 self.any_block_builder(
314 inputs,
315 extension_delta.into(),
316 sum_rows,
317 other_outputs,
318 true,
319 )
320 }
321
322 pub fn simple_entry_builder(
329 &mut self,
330 outputs: TypeRow,
331 n_cases: usize,
332 ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
333 self.entry_builder(vec![type_row![]; n_cases], outputs)
334 }
335
336 pub fn simple_entry_builder_exts(
344 &mut self,
345 outputs: TypeRow,
346 n_cases: usize,
347 extension_delta: impl Into<ExtensionSet>,
348 ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
349 self.entry_builder_exts(vec![type_row![]; n_cases], outputs, extension_delta)
350 }
351
352 pub fn exit_block(&self) -> BasicBlockID {
354 self.exit_node.into()
355 }
356
357 pub fn branch(
363 &mut self,
364 predecessor: &BasicBlockID,
365 branch: usize,
366 successor: &BasicBlockID,
367 ) -> Result<(), BuildError> {
368 let from = predecessor.node();
369 let to = successor.node();
370 self.hugr_mut().connect(from, branch, to, 0);
371 Ok(())
372 }
373}
374
375pub type BlockBuilder<B> = DFGWrapper<B, BasicBlockID>;
377
378impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
379 pub fn set_outputs(
382 &mut self,
383 branch_wire: Wire,
384 outputs: impl IntoIterator<Item = Wire>,
385 ) -> Result<(), BuildError> {
386 Dataflow::set_outputs(self, [branch_wire].into_iter().chain(outputs))
387 }
388 fn create(base: B, block_n: Node) -> Result<Self, BuildError> {
389 let block_op = base
390 .as_ref()
391 .get_optype(block_n)
392 .as_dataflow_block()
393 .unwrap();
394 let signature = block_op.inner_signature().into_owned();
395 let db = DFGBuilder::create_with_io(base, block_n, signature)?;
396 Ok(BlockBuilder::from_dfg_builder(db))
397 }
398
399 pub fn finish_with_outputs(
401 mut self,
402 branch_wire: Wire,
403 outputs: impl IntoIterator<Item = Wire>,
404 ) -> Result<<Self as SubContainer>::ContainerHandle, BuildError>
405 where
406 Self: Sized,
407 {
408 self.set_outputs(branch_wire, outputs)?;
409 self.finish_sub_container()
410 }
411}
412
413impl BlockBuilder<Hugr> {
414 pub fn new(
417 inputs: impl Into<TypeRow>,
418 sum_rows: impl IntoIterator<Item = TypeRow>,
419 other_outputs: impl Into<TypeRow>,
420 ) -> Result<Self, BuildError> {
421 Self::new_exts(inputs, sum_rows, other_outputs, TO_BE_INFERRED)
422 }
423
424 pub fn new_exts(
428 inputs: impl Into<TypeRow>,
429 sum_rows: impl IntoIterator<Item = TypeRow>,
430 other_outputs: impl Into<TypeRow>,
431 extension_delta: impl Into<ExtensionSet>,
432 ) -> Result<Self, BuildError> {
433 let inputs = inputs.into();
434 let sum_rows: Vec<_> = sum_rows.into_iter().collect();
435 let other_outputs = other_outputs.into();
436 let op = DataflowBlock {
437 inputs: inputs.clone(),
438 other_outputs: other_outputs.clone(),
439 sum_rows,
440 extension_delta: extension_delta.into(),
441 };
442
443 let base = Hugr::new(op);
444 let root = base.root();
445 Self::create(base, root)
446 }
447
448 pub fn finish_hugr_with_outputs(
450 mut self,
451 branch_wire: Wire,
452 outputs: impl IntoIterator<Item = Wire>,
453 ) -> Result<Hugr, BuildError> {
454 self.set_outputs(branch_wire, outputs)?;
455 self.finish_hugr().map_err(BuildError::InvalidHUGR)
456 }
457}
458
#[cfg(test)]
pub(crate) mod test {
    use crate::builder::{DataflowSubContainer, ModuleBuilder};

    use crate::extension::prelude::usize_t;
    use crate::hugr::validate::InterGraphEdgeError;
    use crate::hugr::ValidationError;
    use crate::type_row;
    use cool_asserts::assert_matches;

    use super::*;

    /// A CFG nested in a function definition inside a module builds and validates.
    #[test]
    fn basic_module_cfg() -> Result<(), BuildError> {
        let mut module_builder = ModuleBuilder::new();
        let mut func_builder = module_builder
            .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))?;
        let [int] = func_builder.input_wires_arr();

        let cfg_id = {
            let mut cfg_builder =
                func_builder.cfg_builder(vec![(usize_t(), int)], vec![usize_t()].into())?;
            build_basic_cfg(&mut cfg_builder)?;
            cfg_builder.finish_sub_container()?
        };
        let _f_id = func_builder.finish_with_outputs(cfg_id.outputs())?;

        // Propagate validation failures as the test's error instead of the
        // `is_ok()` + `unwrap_err()` pattern.
        module_builder
            .finish_hugr()
            .map_err(BuildError::InvalidHUGR)?;
        Ok(())
    }

    /// A standalone CFG-rooted HUGR builds and validates.
    #[test]
    fn basic_cfg_hugr() -> Result<(), BuildError> {
        let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?;
        build_basic_cfg(&mut cfg_builder)?;
        assert_matches!(cfg_builder.finish_hugr(), Ok(_));

        Ok(())
    }

    /// Populate `cfg_builder` with a small CFG over one `usize` value.
    ///
    /// The entry block outputs a two-variant Sum (always variant 1, carrying
    /// its input) and branches to `middle` on variant 0 or to the exit on
    /// variant 1. The middle block passes its input through and always
    /// branches to the exit.
    pub(crate) fn build_basic_cfg<T: AsMut<Hugr> + AsRef<Hugr>>(
        cfg_builder: &mut CFGBuilder<T>,
    ) -> Result<(), BuildError> {
        let usize_row: TypeRow = vec![usize_t()].into();
        let sum2_variants = vec![usize_row.clone(), usize_row];
        // The variants are needed again by `make_sum` below, hence the clone.
        let mut entry_b = cfg_builder.entry_builder_exts(
            sum2_variants.clone(),
            type_row![],
            ExtensionSet::new(),
        )?;
        let entry = {
            let [inw] = entry_b.input_wires_arr();

            let sum = entry_b.make_sum(1, sum2_variants, [inw])?;
            entry_b.finish_with_outputs(sum, [])?
        };
        let mut middle_b = cfg_builder
            .simple_block_builder(Signature::new(vec![usize_t()], vec![usize_t()]), 1)?;
        let middle = {
            let c = middle_b.add_load_const(ops::Value::unary_unit_sum());
            let [inw] = middle_b.input_wires_arr();
            middle_b.finish_with_outputs(c, [inw])?
        };
        let exit = cfg_builder.exit_block();
        cfg_builder.branch(&entry, 0, &middle)?;
        cfg_builder.branch(&middle, 0, &exit)?;
        cfg_builder.branch(&entry, 1, &exit)?;
        Ok(())
    }

    /// Using an entry-block input inside `middle` is accepted: the entry
    /// block dominates `middle`.
    #[test]
    fn test_dom_edge() -> Result<(), BuildError> {
        let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?;
        let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum());
        let sum_variants = vec![type_row![]];

        let mut entry_b =
            cfg_builder.entry_builder_exts(sum_variants, type_row![], ExtensionSet::new())?;
        let [inw] = entry_b.input_wires_arr();
        let entry = {
            let sum = entry_b.load_const(&sum_tuple_const);

            entry_b.finish_with_outputs(sum, [])?
        };
        let mut middle_b =
            cfg_builder.simple_block_builder(Signature::new(type_row![], vec![usize_t()]), 1)?;
        let middle = {
            let c = middle_b.load_const(&sum_tuple_const);
            middle_b.finish_with_outputs(c, [inw])?
        };
        let exit = cfg_builder.exit_block();
        cfg_builder.branch(&entry, 0, &middle)?;
        cfg_builder.branch(&middle, 0, &exit)?;
        assert_matches!(cfg_builder.finish_hugr(), Ok(_));

        Ok(())
    }

    /// Using a `middle`-block input inside the entry block is rejected:
    /// `middle` does not dominate the entry block.
    #[test]
    fn test_non_dom_edge() -> Result<(), BuildError> {
        let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?;
        let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum());
        let sum_variants = vec![type_row![]];
        let mut middle_b = cfg_builder
            .simple_block_builder(Signature::new(vec![usize_t()], vec![usize_t()]), 1)?;
        let [inw] = middle_b.input_wires_arr();
        let middle = {
            let c = middle_b.load_const(&sum_tuple_const);
            middle_b.finish_with_outputs(c, [inw])?
        };

        let mut entry_b = cfg_builder.entry_builder(sum_variants, vec![usize_t()].into())?;
        let entry = {
            let sum = entry_b.load_const(&sum_tuple_const);
            // `inw` originates in `middle`, which does not dominate `entry`;
            // validation below must reject this edge.
            entry_b.finish_with_outputs(sum, [inw])?
        };
        let exit = cfg_builder.exit_block();
        cfg_builder.branch(&entry, 0, &middle)?;
        cfg_builder.branch(&middle, 0, &exit)?;
        assert_matches!(
            cfg_builder.finish_hugr(),
            Err(ValidationError::InterGraphEdgeError(
                InterGraphEdgeError::NonDominatedAncestor { .. }
            ))
        );

        Ok(())
    }
}