1use super::{
2 BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire,
3 build_traits::SubContainer,
4 dataflow::{DFGBuilder, DFGWrapper},
5 handle::BuildHandle,
6};
7
8use crate::ops::{self, DataflowBlock, DataflowParent, ExitBlock, OpType, handle::NodeHandle};
9use crate::types::Signature;
10use crate::{hugr::views::HugrView, types::TypeRow};
11
12use crate::Node;
13use crate::{Hugr, hugr::HugrMut, type_row};
14
/// Builder for a [`ops::CFG`] node: a control-flow graph whose children are
/// basic blocks connected by control-flow edges.
///
/// `T` is the storage the builder writes into; the impls in this module
/// require `T: AsMut<Hugr> + AsRef<Hugr>` (e.g. an owned [`Hugr`], or a
/// `&mut Hugr` when the CFG is nested inside another container).
#[derive(Debug, PartialEq)]
pub struct CFGBuilder<T> {
    /// The HUGR (or handle to one) that nodes are added to.
    pub(super) base: T,
    /// The CFG container node; basic blocks are added as its children.
    pub(super) cfg_node: Node,
    /// The CFG's input row, taken when the entry block is built.
    /// `None` means the entry block has already been created.
    pub(super) inputs: Option<TypeRow>,
    /// The exit block node, created together with the builder.
    pub(super) exit_node: Node,
    /// Number of CFG outputs (the length of the exit block's output row).
    pub(super) n_out_wires: usize,
}
118
119impl<B: AsMut<Hugr> + AsRef<Hugr>> Container for CFGBuilder<B> {
120 #[inline]
121 fn container_node(&self) -> Node {
122 self.cfg_node
123 }
124
125 #[inline]
126 fn hugr_mut(&mut self) -> &mut Hugr {
127 self.base.as_mut()
128 }
129
130 #[inline]
131 fn hugr(&self) -> &Hugr {
132 self.base.as_ref()
133 }
134}
135
136impl<H: AsMut<Hugr> + AsRef<Hugr>> SubContainer for CFGBuilder<H> {
137 type ContainerHandle = BuildHandle<CfgID>;
138 #[inline]
139 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
140 Ok((self.cfg_node, self.n_out_wires).into())
141 }
142}
143
144impl CFGBuilder<Hugr> {
145 pub fn new(signature: Signature) -> Result<Self, BuildError> {
147 let cfg_op = ops::CFG {
148 signature: signature.clone(),
149 };
150
151 let base = Hugr::new_with_entrypoint(cfg_op).expect("CFG entrypoints be valid");
152 let cfg_node = base.entrypoint();
153 CFGBuilder::create(base, cfg_node, signature.input, signature.output)
154 }
155}
156
157impl HugrBuilder for CFGBuilder<Hugr> {
158 fn finish_hugr(self) -> Result<Hugr, crate::hugr::ValidationError<Node>> {
159 self.base.validate()?;
160 Ok(self.base)
161 }
162}
163
164impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
165 pub(super) fn create(
166 mut base: B,
167 cfg_node: Node,
168 input: TypeRow,
169 output: TypeRow,
170 ) -> Result<Self, BuildError> {
171 let n_out_wires = output.len();
172 let exit_block_type = OpType::ExitBlock(ExitBlock {
173 cfg_outputs: output,
174 });
175 let exit_node = base
176 .as_mut()
177 .add_node_with_parent(cfg_node, exit_block_type);
179 Ok(Self {
180 base,
181 cfg_node,
182 n_out_wires,
183 exit_node,
184 inputs: Some(input),
185 })
186 }
187
188 pub fn block_builder(
196 &mut self,
197 inputs: TypeRow,
198 sum_rows: impl IntoIterator<Item = TypeRow>,
199 other_outputs: TypeRow,
200 ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
201 self.any_block_builder(inputs, sum_rows, other_outputs, false)
202 }
203
204 fn any_block_builder(
205 &mut self,
206 inputs: TypeRow,
207 sum_rows: impl IntoIterator<Item = TypeRow>,
208 other_outputs: TypeRow,
209 entry: bool,
210 ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
211 let sum_rows: Vec<_> = sum_rows.into_iter().collect();
212 let op = OpType::DataflowBlock(DataflowBlock {
213 inputs: inputs.clone(),
214 other_outputs: other_outputs.clone(),
215 sum_rows,
216 });
217 let parent = self.container_node();
218 let block_n = if entry {
219 let exit = self.exit_node;
220 self.hugr_mut().add_node_before(exit, op)
222 } else {
223 self.hugr_mut().add_node_with_parent(parent, op)
225 };
226
227 BlockBuilder::create_with_io(self.hugr_mut(), block_n)
228 }
229
230 pub fn simple_block_builder(
238 &mut self,
239 signature: Signature,
240 n_cases: usize,
241 ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
242 self.block_builder(
243 signature.input,
244 vec![type_row![]; n_cases],
245 signature.output,
246 )
247 }
248
249 pub fn entry_builder(
256 &mut self,
257 sum_rows: impl IntoIterator<Item = TypeRow>,
258 other_outputs: TypeRow,
259 ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
260 let inputs = self
261 .inputs
262 .take()
263 .ok_or(BuildError::EntryBuiltError(self.cfg_node))?;
264 self.any_block_builder(inputs, sum_rows, other_outputs, true)
265 }
266
267 pub fn simple_entry_builder(
274 &mut self,
275 outputs: TypeRow,
276 n_cases: usize,
277 ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
278 self.entry_builder(vec![type_row![]; n_cases], outputs)
279 }
280
281 pub fn exit_block(&self) -> BasicBlockID {
283 self.exit_node.into()
284 }
285
286 pub fn branch(
292 &mut self,
293 predecessor: &BasicBlockID,
294 branch: usize,
295 successor: &BasicBlockID,
296 ) -> Result<(), BuildError> {
297 let from = predecessor.node();
298 let to = successor.node();
299 self.hugr_mut().connect(from, branch, to, 0);
300 Ok(())
301 }
302}
303
/// Builder for the dataflow graph inside a single basic block
/// ([`DataflowBlock`]) of a CFG.
pub type BlockBuilder<B> = DFGWrapper<B, BasicBlockID>;
306
307impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
308 pub fn set_outputs(
311 &mut self,
312 branch_wire: Wire,
313 outputs: impl IntoIterator<Item = Wire>,
314 ) -> Result<(), BuildError> {
315 Dataflow::set_outputs(self, [branch_wire].into_iter().chain(outputs))
316 }
317
318 fn create(base: B, block_n: Node) -> Result<Self, BuildError> {
327 let db = DFGBuilder::create(base, block_n)?;
328 Ok(BlockBuilder::from_dfg_builder(db))
329 }
330
331 fn create_with_io(base: B, block_n: Node) -> Result<Self, BuildError> {
340 let block_op = base
341 .as_ref()
342 .get_optype(block_n)
343 .as_dataflow_block()
344 .unwrap();
345 let signature = block_op.inner_signature().into_owned();
346 let db = DFGBuilder::create_with_io(base, block_n, signature)?;
347 Ok(BlockBuilder::from_dfg_builder(db))
348 }
349
350 pub fn finish_with_outputs(
352 mut self,
353 branch_wire: Wire,
354 outputs: impl IntoIterator<Item = Wire>,
355 ) -> Result<<Self as SubContainer>::ContainerHandle, BuildError>
356 where
357 Self: Sized,
358 {
359 self.set_outputs(branch_wire, outputs)?;
360 self.finish_sub_container()
361 }
362}
363
364impl BlockBuilder<Hugr> {
365 pub fn new(
367 inputs: impl Into<TypeRow>,
368 sum_rows: impl IntoIterator<Item = TypeRow>,
369 other_outputs: impl Into<TypeRow>,
370 ) -> Result<Self, BuildError> {
371 let inputs = inputs.into();
372 let sum_rows: Vec<_> = sum_rows.into_iter().collect();
373 let other_outputs: TypeRow = other_outputs.into();
374 let num_out_branches = sum_rows.len();
375
376 if let Some(row) = sum_rows.first() {
379 if sum_rows.iter().skip(1).any(|r2| row != r2) {
380 return Err(BuildError::BasicBlockTooComplex);
381 }
382 }
383 let cfg_outputs = sum_rows.first().cloned().unwrap_or_default();
384 let cfg_outputs = cfg_outputs.extend(other_outputs.as_slice());
385
386 let mut cfg = CFGBuilder::new(Signature::new(inputs, cfg_outputs))?;
387 let block = cfg.entry_builder(sum_rows, other_outputs)?;
388 let block = block.finish_sub_container()?;
389 for i in 0..num_out_branches {
390 cfg.branch(&block, i, &cfg.exit_block())?;
391 }
392 let mut base = std::mem::take(cfg.hugr_mut());
393 let root = block.node();
394 base.set_entrypoint(root);
395 Self::create(base, root)
396 }
397
398 pub fn finish_hugr_with_outputs(
400 mut self,
401 branch_wire: Wire,
402 outputs: impl IntoIterator<Item = Wire>,
403 ) -> Result<Hugr, BuildError> {
404 self.set_outputs(branch_wire, outputs)?;
405 self.finish_hugr().map_err(BuildError::InvalidHUGR)
406 }
407}
408
#[cfg(test)]
pub(crate) mod test {
    // Tests for `CFGBuilder` and `BlockBuilder`. Several tests depend on the
    // exact order in which blocks are built and wires are captured.
    use crate::builder::{DataflowSubContainer, ModuleBuilder};

    use crate::extension::prelude::{bool_t, usize_t};
    use crate::hugr::ValidationError;
    use crate::hugr::validate::InterGraphEdgeError;
    use crate::type_row;
    use cool_asserts::assert_matches;

    use super::*;
    /// Builds the basic CFG nested inside a function definition in a module,
    /// and checks that the module HUGR is built successfully.
    #[test]
    fn basic_module_cfg() -> Result<(), BuildError> {
        let build_result = {
            let mut module_builder = ModuleBuilder::new();
            let mut func_builder = module_builder
                .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))?;
            let _f_id = {
                let [int] = func_builder.input_wires_arr();

                let cfg_id = {
                    // The CFG consumes the function's `usize` input and produces one `usize`.
                    let mut cfg_builder =
                        func_builder.cfg_builder(vec![(usize_t(), int)], vec![usize_t()].into())?;
                    build_basic_cfg(&mut cfg_builder)?;

                    cfg_builder.finish_sub_container()?
                };

                func_builder.finish_with_outputs(cfg_id.outputs())?
            };
            module_builder.finish_hugr()
        };

        // The message (and `unwrap_err`) is only evaluated if the assertion fails.
        assert!(build_result.is_ok(), "{}", build_result.unwrap_err());

        Ok(())
    }
    /// Builds the basic CFG as a standalone HUGR and checks that it validates.
    #[test]
    fn basic_cfg_hugr() -> Result<(), BuildError> {
        let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?;
        build_basic_cfg(&mut cfg_builder)?;
        assert_matches!(cfg_builder.finish_hugr(), Ok(_));

        Ok(())
    }

    /// Checks `BlockBuilder::new`: differing sum rows are rejected, and a
    /// block with two identical rows builds a valid HUGR.
    #[test]
    fn basic_cfg_block() -> Result<(), BuildError> {
        assert_eq!(
            BlockBuilder::new(
                vec![],
                [vec![usize_t()].into(), vec![bool_t()].into()],
                vec![]
            ),
            Err(BuildError::BasicBlockTooComplex)
        );

        let sum_rows: Vec<TypeRow> = vec![vec![usize_t()].into(), vec![usize_t()].into()];
        let mut block_builder =
            BlockBuilder::new(vec![usize_t()], sum_rows.clone(), vec![usize_t()])?;
        let [inp] = block_builder.input_wires_arr();
        // Select branch 0, carrying the input as the variant's payload.
        let branch = block_builder.make_sum(0, sum_rows, [inp])?;
        let hugr = block_builder.finish_hugr_with_outputs(branch, [inp])?;

        hugr.validate().unwrap();

        Ok(())
    }

    /// Populates `cfg_builder` (signature `usize -> usize`) with:
    /// entry --(0)--> middle --(0)--> exit, and entry --(1)--> exit.
    /// The entry block always emits tag 1.
    pub(crate) fn build_basic_cfg<T: AsMut<Hugr> + AsRef<Hugr>>(
        cfg_builder: &mut CFGBuilder<T>,
    ) -> Result<(), BuildError> {
        let usize_row: TypeRow = vec![usize_t()].into();
        let sum2_variants = vec![usize_row.clone(), usize_row];
        let mut entry_b = cfg_builder.entry_builder(sum2_variants.clone(), type_row![])?;
        let entry = {
            let [inw] = entry_b.input_wires_arr();

            let sum = entry_b.make_sum(1, sum2_variants, [inw])?;
            entry_b.finish_with_outputs(sum, [])?
        };
        let mut middle_b = cfg_builder
            .simple_block_builder(Signature::new(vec![usize_t()], vec![usize_t()]), 1)?;
        let middle = {
            // A unary unit sum: the only branch is 0.
            let c = middle_b.add_load_const(ops::Value::unary_unit_sum());
            let [inw] = middle_b.input_wires_arr();
            middle_b.finish_with_outputs(c, [inw])?
        };
        let exit = cfg_builder.exit_block();
        cfg_builder.branch(&entry, 0, &middle)?;
        cfg_builder.branch(&middle, 0, &exit)?;
        cfg_builder.branch(&entry, 1, &exit)?;
        Ok(())
    }
    /// An edge from the entry block's input into the middle block is valid,
    /// because entry dominates middle. The constant is defined at CFG level
    /// and loaded in both blocks.
    #[test]
    fn test_dom_edge() -> Result<(), BuildError> {
        let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?;
        let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum());
        let sum_variants = vec![type_row![]];

        let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![])?;
        // Captured from the entry block and used later inside `middle`.
        let [inw] = entry_b.input_wires_arr();
        let entry = {
            let sum = entry_b.load_const(&sum_tuple_const);

            entry_b.finish_with_outputs(sum, [])?
        };
        let mut middle_b =
            cfg_builder.simple_block_builder(Signature::new(type_row![], vec![usize_t()]), 1)?;
        let middle = {
            let c = middle_b.load_const(&sum_tuple_const);
            middle_b.finish_with_outputs(c, [inw])?
        };
        let exit = cfg_builder.exit_block();
        cfg_builder.branch(&entry, 0, &middle)?;
        cfg_builder.branch(&middle, 0, &exit)?;
        assert_matches!(cfg_builder.finish_hugr(), Ok(_));

        Ok(())
    }

    /// An edge from the middle block's input into the entry block is
    /// rejected, because middle does not dominate entry. The middle block is
    /// deliberately built first so that its input wire is available when the
    /// entry block is built.
    #[test]
    fn test_non_dom_edge() -> Result<(), BuildError> {
        let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?;
        let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum());
        let sum_variants = vec![type_row![]];
        let mut middle_b = cfg_builder
            .simple_block_builder(Signature::new(vec![usize_t()], vec![usize_t()]), 1)?;
        let [inw] = middle_b.input_wires_arr();
        let middle = {
            let c = middle_b.load_const(&sum_tuple_const);
            middle_b.finish_with_outputs(c, [inw])?
        };

        let mut entry_b =
            cfg_builder.entry_builder(sum_variants.clone(), vec![usize_t()].into())?;
        let entry = {
            let sum = entry_b.load_const(&sum_tuple_const);
            // `inw` belongs to `middle`, which does not dominate `entry`.
            entry_b.finish_with_outputs(sum, [inw])?
        };
        let exit = cfg_builder.exit_block();
        cfg_builder.branch(&entry, 0, &middle)?;
        cfg_builder.branch(&middle, 0, &exit)?;
        assert_matches!(
            cfg_builder.finish_hugr(),
            Err(ValidationError::InterGraphEdgeError(
                InterGraphEdgeError::NonDominatedAncestor { .. }
            ))
        );

        Ok(())
    }
}