use std::collections::{HashMap, HashSet, LinkedList, VecDeque};
use std::hash::Hash;
use itertools::Itertools;
use thiserror::Error;
use hugr_core::hugr::patch::outline_cfg::OutlineCfg;
use hugr_core::hugr::views::HugrView;
use hugr_core::hugr::{Patch, hugrmut::HugrMut};
use hugr_core::ops::OpTag;
use hugr_core::ops::OpTrait;
use hugr_core::{Direction, Hugr, Node};
/// Read-only view of a control-flow graph whose nodes are identified by values of type `T`.
pub trait CfgNodeMap<T> {
/// The node at which the CFG is entered.
fn entry_node(&self) -> T;
/// The node at which the CFG is left.
fn exit_node(&self) -> T;
/// Iterates over the nodes that `node` has a control-flow edge to.
fn successors(&self, node: T) -> impl Iterator<Item = T>;
/// Iterates over the nodes that have a control-flow edge to `node`.
fn predecessors(&self, node: T) -> impl Iterator<Item = T>;
}
/// A [`CfgNodeMap`] that can additionally be mutated by nesting a region of the CFG.
pub trait CfgNester<T>: CfgNodeMap<T> {
/// Nests the region delimited by `entry_edge` and `exit_edge` (each a `(source, target)` pair).
///
/// Returns the node that callers should subsequently use as the source of the exit edge
/// (see [`transform_cfg_to_nested`], which replaces the exit edge's source with the result).
fn nest_sese_region(&mut self, entry_edge: (T, T), exit_edge: (T, T)) -> T;
}
/// Nests regions of the CFG presented by `view`, using the edge classes computed by
/// [`EdgeClassifier::get_edge_classes`].
///
/// Whenever two consecutive edges of the same class are found, the region between them is
/// passed to [`CfgNester::nest_sese_region`], except when that region is a single node with
/// only one successor.
///
/// # Panics
///
/// Panics if one of the internal consistency checks fails: an edge is visited twice, a node
/// has more than one outgoing edge in the class being searched for, or a sub-traversal finds
/// zero or several distinct edges of that class.
pub fn transform_cfg_to_nested<T: Copy + Eq + Hash + std::fmt::Debug>(
view: &mut impl CfgNester<T>,
) {
let edge_classes = EdgeClassifier::get_edge_classes(view);
// Group the classified edges by class; edges are removed from these sets as they are visited.
let mut rem_edges: HashMap<usize, HashSet<(T, T)>> = HashMap::new();
for (e, cls) in &edge_classes {
rem_edges.entry(*cls).or_default().insert(*e);
}
// Depth-first walk from `n`. Edges whose class equals `stop_at` are not followed; instead
// they are collected, and the single distinct such edge (if any) is returned.
fn traverse<T: Copy + Eq + Hash + std::fmt::Debug>(
view: &mut impl CfgNester<T>,
n: T,
edge_classes: &HashMap<(T, T), usize>,
rem_edges: &mut HashMap<usize, HashSet<(T, T)>>,
stop_at: Option<usize>,
) -> Option<(T, T)> {
let mut seen = HashSet::new();
let mut stack = Vec::new();
let mut exit_edges = Vec::new();
stack.push(n);
while let Some(n) = stack.pop() {
if !seen.insert(n) {
continue;
}
// Split the outgoing edges into those of the `stop_at` class and all the others.
let (exit, rest): (Vec<_>, Vec<_>) = view
.successors(n)
.map(|s| (n, s))
.partition(|e| stop_at.is_some() && edge_classes.get(e).copied() == stop_at);
// A node may have at most one outgoing edge of the `stop_at` class.
exit_edges.extend(exit.into_iter().at_most_one().unwrap());
for mut e in rest {
if let Some(cls) = edge_classes.get(&e) {
// Mark this edge as visited; it must not have been visited before.
assert!(rem_edges.get_mut(cls).unwrap().remove(&e));
// Keep going for as long as unvisited edges of the same class remain.
while !rem_edges.get_mut(cls).unwrap().is_empty() {
let prev_e = e;
// The recursive walk stops at, and returns, the next edge of class `cls`.
e = traverse(view, e.1, edge_classes, rem_edges, Some(*cls)).unwrap();
assert!(rem_edges.get_mut(cls).unwrap().remove(&e));
// Skip the nesting when the region is a single node with a single successor.
if prev_e.1 != e.0 || view.successors(e.0).count() > 1 {
// After nesting, the edge continues from the node returned by the nester.
e = (view.nest_sese_region(prev_e, e), e.1);
}
}
}
// Continue the walk beyond the (possibly updated) edge.
stack.push(e.1);
}
}
// All collected exit edges must be the same edge.
exit_edges.into_iter().unique().at_most_one().unwrap()
}
traverse(view, view.entry_node(), &edge_classes, &mut rem_edges, None);
}
/// Applies [`transform_cfg_to_nested`] to every CFG node found while walking the hierarchy
/// of `h` downwards from its entrypoint.
///
/// The children of each node are read only after that node has been processed, so nodes
/// created by a transformation are visited too.
pub fn transform_all_cfgs(h: &mut Hugr) {
    let mut pending = vec![h.entrypoint()];
    while let Some(node) = pending.pop() {
        let is_cfg = h.get_optype(node).tag() == OpTag::Cfg;
        if is_cfg {
            // View the Hugr with this CFG as its entrypoint for the duration of the transform.
            let mut cfg_view = IdentityCfgMap::new(h.with_entrypoint_mut(node));
            transform_cfg_to_nested(&mut cfg_view);
        }
        pending.extend(h.children(node));
    }
}
/// A directed CFG edge, stored as `(source, target)`.
type CfgEdge<T> = (T, T);
/// The far end of a CFG edge as seen from one of its endpoints, tagged with the direction
/// in which the underlying edge points.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
enum EdgeDest<T> {
    /// The underlying edge points from the current node to the contained node.
    Forward(T),
    /// The underlying edge points from the contained node to the current node.
    Backward(T),
}

impl<T: Copy + Clone + PartialEq + Eq + Hash> EdgeDest<T> {
    /// Returns the node at the other end of the edge, regardless of the edge's direction.
    pub fn target(&self) -> T {
        let (EdgeDest::Forward(node) | EdgeDest::Backward(node)) = *self;
        node
    }
}
/// Enumerates, without duplicates, every edge incident to `n` in either direction.
///
/// Successors are yielded first as [`EdgeDest::Forward`], then predecessors as
/// [`EdgeDest::Backward`]. When `n` is the exit node, an additional forward edge to the
/// entry node is included after the real successors.
fn all_edges<'a, T: Copy + Clone + PartialEq + Eq + Hash + 'a>(
    cfg: &'a impl CfgNodeMap<T>,
    n: T,
) -> impl Iterator<Item = EdgeDest<T>> + 'a {
    // `Some(entry)` only for the exit node; contributes nothing otherwise.
    let wraparound = (n == cfg.exit_node()).then(|| cfg.entry_node());
    let forward = cfg.successors(n).chain(wraparound).map(EdgeDest::Forward);
    let backward = cfg.predecessors(n).map(EdgeDest::Backward);
    forward.chain(backward).unique()
}
/// Re-expresses the edge `d`, seen from `src`, from the point of view of its other endpoint.
///
/// Returns that other endpoint together with an [`EdgeDest`] pointing back at `src` with
/// the opposite tag, so that both descriptions denote the same directed edge.
fn flip<T: Copy + Clone + PartialEq + Eq + Hash>(src: T, d: EdgeDest<T>) -> (T, EdgeDest<T>) {
    let seen_from_far_end = match d {
        EdgeDest::Forward(_) => EdgeDest::Backward(src),
        EdgeDest::Backward(_) => EdgeDest::Forward(src),
    };
    (d.target(), seen_from_far_end)
}
/// Converts the edge `d`, seen from node `s`, into a directed `(source, target)` pair.
fn cfg_edge<T: Copy + Clone + PartialEq + Eq + Hash>(s: T, d: EdgeDest<T>) -> CfgEdge<T> {
    let other = d.target();
    match d {
        // The edge leaves `s`.
        EdgeDest::Forward(_) => (s, other),
        // The edge arrives at `s`.
        EdgeDest::Backward(_) => (other, s),
    }
}
/// A [`CfgNodeMap`] over a Hugr view in which the CFG nodes are the Hugr's own nodes.
///
/// `entry` and `exit` are the first two children of the view's entrypoint, captured once
/// by [`IdentityCfgMap::new`].
pub struct IdentityCfgMap<H: HugrView> {
h: H,
entry: H::Node,
exit: H::Node,
}
impl<H: HugrView> IdentityCfgMap<H> {
    /// Wraps `h`, taking the first two children of its entrypoint as the entry and exit
    /// nodes respectively.
    ///
    /// # Panics
    ///
    /// Panics if the entrypoint has fewer than two children. In debug builds, also panics
    /// if the second child is not tagged [`OpTag::BasicBlockExit`].
    pub fn new(h: H) -> Self {
        let (entry, exit) = {
            // Scoped so the child iterator's borrow of `h` ends before `h` is moved below.
            let mut first_children = h.children(h.entrypoint());
            let entry = first_children.next().unwrap();
            let exit = first_children.next().unwrap();
            (entry, exit)
        };
        debug_assert_eq!(h.get_optype(exit).tag(), OpTag::BasicBlockExit);
        Self { h, entry, exit }
    }
}
impl<H: HugrView> CfgNodeMap<H::Node> for IdentityCfgMap<H> {
    /// Returns the entry node recorded at construction.
    fn entry_node(&self) -> H::Node {
        let Self { entry, .. } = self;
        *entry
    }

    /// Returns the exit node recorded at construction.
    fn exit_node(&self) -> H::Node {
        let Self { exit, .. } = self;
        *exit
    }

    /// Yields the neighbours of `node` along its outgoing edges in the wrapped Hugr.
    fn successors(&self, node: H::Node) -> impl Iterator<Item = H::Node> {
        let direction = Direction::Outgoing;
        self.h.neighbours(node, direction)
    }

    /// Yields the neighbours of `node` along its incoming edges in the wrapped Hugr.
    fn predecessors(&self, node: H::Node) -> impl Iterator<Item = H::Node> {
        let direction = Direction::Incoming;
        self.h.neighbours(node, direction)
    }
}
impl<H: HugrMut<Node = Node>> CfgNester<H::Node> for IdentityCfgMap<H> {
    /// Outlines the blocks between `entry_edge` and `exit_edge` using [`OutlineCfg`] and
    /// returns the first of the two nodes that patch reports.
    ///
    /// # Panics
    ///
    /// Panics if [`region_blocks`] rejects the two edges, if any endpoint of either edge is
    /// not a child of the view's entrypoint, or if applying the patch fails. Debug builds
    /// additionally check where the four endpoints ended up after the patch.
    fn nest_sese_region(
        &mut self,
        entry_edge: (H::Node, H::Node),
        exit_edge: (H::Node, H::Node),
    ) -> H::Node {
        let blocks = region_blocks(self, entry_edge, exit_edge).unwrap();

        let (outer_pred, region_first) = entry_edge;
        let (region_last, outer_succ) = exit_edge;

        // Before the patch, all four endpoints must sit directly under the entrypoint.
        let cfg = self.h.entrypoint();
        assert!(
            [outer_pred, region_first, region_last, outer_succ]
                .into_iter()
                .all(|n| self.h.get_parent(n) == Some(cfg))
        );

        let [new_block, new_cfg] = OutlineCfg::new(blocks).apply(&mut self.h).unwrap();

        // Afterwards, the nodes outside the region must still be under the entrypoint...
        debug_assert!(
            [outer_pred, outer_succ]
                .into_iter()
                .all(|n| self.h.get_parent(n) == Some(self.h.entrypoint()))
        );
        // ...whereas the region's first and last blocks must be under the second reported node.
        debug_assert!(
            [region_first, region_last]
                .into_iter()
                .all(|n| self.h.get_parent(n) == Some(new_cfg))
        );
        new_block
    }
}
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum RegionBlocksError<T> {
ExitEdgeNotPresent(T, T),
EntryEdgeNotPresent(T, T),
EntryEdgeSourceInRegion(T),
UnexpectedEntryEdges(Vec<T>),
}
/// Collects the nodes reachable from the target of `entry_edge` without crossing `exit_edge`.
///
/// Both edges are `(source, target)` pairs. After the search, the entry edge is checked to be
/// the only edge into its target from outside the collected set.
///
/// # Errors
///
/// * [`RegionBlocksError::ExitEdgeNotPresent`] if the exit edge's source is reached but does
///   not have the exit edge's target among its successors.
/// * [`RegionBlocksError::EntryEdgeSourceInRegion`] if the entry edge's source ends up in the
///   collected set.
/// * [`RegionBlocksError::EntryEdgeNotPresent`] if the entry edge's source is not a
///   predecessor of its target.
/// * [`RegionBlocksError::UnexpectedEntryEdges`] if the entry edge's target has other
///   predecessors outside the collected set.
pub fn region_blocks<T: Copy + Eq + Hash + std::fmt::Debug>(
    v: &impl CfgNodeMap<T>,
    entry_edge: (T, T),
    exit_edge: (T, T),
) -> Result<HashSet<T>, RegionBlocksError<T>> {
    let (entry_src, entry_tgt) = entry_edge;
    let (exit_src, exit_tgt) = exit_edge;

    // Breadth-first search forwards from the target of the entry edge.
    let mut region = HashSet::new();
    let mut frontier = VecDeque::from([entry_tgt]);
    while let Some(block) = frontier.pop_front() {
        if !region.insert(block) {
            continue;
        }
        if block != exit_src {
            frontier.extend(v.successors(block));
            continue;
        }
        // At the exit edge's source: follow every successor except the exit edge's target,
        // and require that the exit edge itself really is there.
        let mut saw_exit_edge = false;
        for succ in v.successors(block) {
            if succ == exit_tgt {
                saw_exit_edge = true;
            } else {
                frontier.push_back(succ);
            }
        }
        if !saw_exit_edge {
            return Err(RegionBlocksError::ExitEdgeNotPresent(exit_src, exit_tgt));
        }
    }

    if region.contains(&entry_src) {
        return Err(RegionBlocksError::EntryEdgeSourceInRegion(entry_src));
    }

    // Classify the distinct predecessors of the entry target that lie outside the region.
    let mut found_entry_edge = false;
    let mut strays = Vec::new();
    for pred in v.predecessors(entry_tgt).unique() {
        if region.contains(&pred) {
            continue;
        }
        if pred == entry_src {
            found_entry_edge = true;
        } else {
            strays.push(pred);
        }
    }
    if !found_entry_edge {
        return Err(RegionBlocksError::EntryEdgeNotPresent(entry_src, entry_tgt));
    }
    if !strays.is_empty() {
        return Err(RegionBlocksError::UnexpectedEntryEdges(strays));
    }
    Ok(region)
}
/// Result of a depth-first traversal of the CFG that follows edges in both directions.
///
/// `dfs_num` gives each visited node its position in visiting order (starting from 0);
/// `dfs_parents` records, for every visited node except the entry node, the edge along
/// which it was first reached, expressed from that node's point of view.
struct UndirectedDFSTree<T> {
dfs_num: HashMap<T, usize>,
dfs_parents: HashMap<T, EdgeDest<T>>,
}
impl<T: Copy + Clone + PartialEq + Eq + Hash> UndirectedDFSTree<T> {
    /// Builds the tree for `cfg`, numbering only nodes from which the exit node is reachable.
    ///
    /// # Panics
    ///
    /// Panics if the entry node is not among the numbered nodes, i.e. if the exit node
    /// cannot be reached from it.
    fn new(cfg: &impl CfgNodeMap<T>) -> Self {
        // Phase 1: walk predecessor edges from the exit to find every node that can reach it.
        let mut can_reach_exit = HashSet::new();
        let mut backward_queue = VecDeque::from([cfg.exit_node()]);
        while let Some(node) = backward_queue.pop_front() {
            if can_reach_exit.insert(node) {
                backward_queue.extend(cfg.predecessors(node));
            }
        }

        // Phase 2: depth-first walk from the entry over edges in both directions. Each stack
        // item pairs a node with the edge (seen from that node) along which it was reached.
        let mut dfs_num = HashMap::new();
        let mut dfs_parents = HashMap::new();
        let mut stack = vec![(cfg.entry_node(), EdgeDest::Backward(cfg.exit_node()))];
        while let Some((node, reached_via)) = stack.pop() {
            if dfs_num.contains_key(&node) || !can_reach_exit.contains(&node) {
                continue;
            }
            let next_num = dfs_num.len();
            dfs_num.insert(node, next_num);
            dfs_parents.insert(node, reached_via);
            stack.extend(all_edges(cfg, node).map(|edge| flip(node, edge)));
        }
        // The entry node is the root: drop the placeholder parent edge it was seeded with.
        dfs_parents.remove(&cfg.entry_node()).unwrap();

        Self {
            dfs_num,
            dfs_parents,
        }
    }
}
/// An entry in a [`BracketList`].
#[derive(Clone, PartialEq, Eq, Hash)]
enum Bracket<T> {
/// A bracket standing for an actual edge of the CFG.
Real(CfgEdge<T>),
/// An artificial bracket added by `EdgeClassifier::traverse`: the DFS number it is
/// registered (and later deleted) under, and the node at which it was added.
Capping(usize, T),
}
/// A list of [`Bracket`]s with lazy deletion.
///
/// `items` may still contain brackets that have been deleted (deletions are recorded in a
/// separate set and only physically removed by `tag`); `size` counts the live brackets only.
struct BracketList<T> {
items: LinkedList<Bracket<T>>, size: usize, }
impl<T: Copy + Clone + PartialEq + Eq + Hash> BracketList<T> {
pub fn new() -> Self {
BracketList {
items: LinkedList::new(),
size: 0,
}
}
pub fn tag(&mut self, deleted: &HashSet<Bracket<T>>) -> Option<(Bracket<T>, usize)> {
while let Some(e) = self.items.front() {
if deleted.contains(e) {
self.items.pop_front();
} else {
return Some((e.clone(), self.size));
}
}
None
}
pub fn concat(&mut self, other: BracketList<T>) {
let BracketList { mut items, size } = other;
self.items.append(&mut items);
assert!(items.is_empty());
self.size += size;
}
pub fn delete(&mut self, b: &Bracket<T>, deleted: &mut HashSet<Bracket<T>>) {
debug_assert!(self.items.contains(b)); let was_new = deleted.insert(b.clone());
assert!(was_new);
self.size -= 1;
}
pub fn push(&mut self, e: Bracket<T>) {
self.items.push_back(e);
self.size += 1;
}
}
/// Working state for assigning CFG edges to classes; see [`EdgeClassifier::get_edge_classes`].
pub struct EdgeClassifier<T> {
// Brackets that have been deleted from bracket lists (lists remove them lazily).
deleted_backedges: HashSet<Bracket<T>>,
// Capping brackets still awaiting deletion: DFS number -> nodes at which they were added.
capping_edges: HashMap<usize, Vec<T>>,
// The class assigned to each edge so far, as returned by `BracketList::tag`.
edge_classes: HashMap<CfgEdge<T>, Option<(Bracket<T>, usize)>>,
}
impl<T: Copy + Clone + PartialEq + Eq + Hash> EdgeClassifier<T> {
/// Computes a class index for the edges of `cfg`.
///
/// Edges that received the same internal class value get the same index; indices are
/// handed out in the iteration order of a `HashMap`, so the numbers are arbitrary.
/// The artificial edge from the exit node to the entry node is not part of the result.
///
/// # Panics
///
/// Panics if the entry node has no DFS number in the [`UndirectedDFSTree`], or if an
/// internal consistency check of the traversal fails.
pub fn get_edge_classes(cfg: &impl CfgNodeMap<T>) -> HashMap<CfgEdge<T>, usize> {
let tree = UndirectedDFSTree::new(cfg);
let mut s = Self {
deleted_backedges: HashSet::new(),
capping_edges: HashMap::new(),
edge_classes: HashMap::new(),
};
s.traverse(cfg, &tree, cfg.entry_node());
// Every capping bracket must have been deleted again by the end of the traversal.
assert!(s.capping_edges.is_empty());
// Drop the artificial exit -> entry edge that `all_edges` adds.
s.edge_classes.remove(&(cfg.exit_node(), cfg.entry_node()));
// Replace each distinct class value by a small integer.
let mut cycle_class_idxs = HashMap::new();
s.edge_classes
.into_iter()
.map(|(k, v)| {
let l = cycle_class_idxs.len();
(k, *cycle_class_idxs.entry(v).or_insert(l))
})
.collect()
}
/// Processes the DFS subtree rooted at `n`, recording classes in `self.edge_classes`.
///
/// Returns the smallest DFS number targeted by an upward edge from `n` or reported by any
/// of its children (`usize::MAX` if there is none), together with the bracket list for `n`.
fn traverse(
&mut self,
cfg: &impl CfgNodeMap<T>,
tree: &UndirectedDFSTree<T>,
n: T,
) -> (usize, BracketList<T>) {
// Among the edges to numbered nodes, `children` are those recorded as the DFS parent edge
// of their other endpoint; all the rest (the edge to `n`'s own parent included) go to
// `non_capping_backedges`.
let n_dfs = *tree.dfs_num.get(&n).unwrap(); let (children, non_capping_backedges): (Vec<_>, Vec<_>) = all_edges(cfg, n)
.filter(|e| tree.dfs_num.contains_key(&e.target()))
.partition(|e| {
let (tgt, from) = flip(n, *e);
tree.dfs_parents.get(&tgt) == Some(&from)
});
// Recurse into every child before doing anything at this node.
let child_results: Vec<_> = children
.iter()
.map(|c| self.traverse(cfg, tree, c.target()))
.collect();
// `min_dfs_target` tracks the smallest and second-smallest values returned by children.
let mut min_dfs_target: [Option<usize>; 2] = [None, None]; let mut bs = BracketList::new();
for (tgt, brs) in child_results {
if tgt < min_dfs_target[0].unwrap_or(usize::MAX) {
min_dfs_target = [Some(tgt), min_dfs_target[0]];
} else if tgt < min_dfs_target[1].unwrap_or(usize::MAX) {
min_dfs_target[1] = Some(tgt);
}
bs.concat(brs);
}
// If the second-smallest child value is a DFS number below this node's own, add a capping
// bracket and register it for deletion under that DFS number.
if let Some(min1dfs) = min_dfs_target[1]
&& min1dfs < n_dfs
{
bs.push(Bracket::Capping(min1dfs, n));
self.capping_edges.entry(min1dfs).or_default().push(n);
}
let parent_edge = tree.dfs_parents.get(&n);
// Split the remaining edges by whether the other end has a smaller DFS number (`be_up`)
// or not (`be_down`).
let (be_up, be_down): (Vec<_>, Vec<_>) = non_capping_backedges
.into_iter()
.map(|e| (*tree.dfs_num.get(&e.target()).unwrap(), e))
.partition(|(dfs, _)| *dfs < n_dfs);
// Brackets of the `be_down` edges end here: delete them, and give each such edge a class
// of its own if it has not been assigned one yet.
for (_, e) in be_down {
let e = cfg_edge(n, e);
let b = Bracket::Real(e);
bs.delete(&b, &mut self.deleted_backedges);
self.edge_classes.entry(e).or_insert_with(|| Some((b, 0)));
}
// Likewise delete the capping brackets registered under this node's DFS number.
for src in self.capping_edges.remove(&n_dfs).unwrap_or_default() {
bs.delete(&Bracket::Capping(n_dfs, src), &mut self.deleted_backedges);
}
// Open a bracket for every `be_up` edge other than the edge to the DFS parent.
be_up
.iter()
.filter(|(_, e)| Some(e) != parent_edge)
.for_each(|(_, e)| bs.push(Bracket::Real(cfg_edge(n, *e))));
let class = bs.tag(&self.deleted_backedges);
// When exactly one live bracket remains and it is a real edge, that edge shares the class.
if let Some((Bracket::Real(e), 1)) = &class {
self.edge_classes.insert(*e, class.clone());
}
// The edge to the DFS parent (absent for the entry node) is assigned the class.
if let Some(parent_edge) = tree.dfs_parents.get(&n) {
self.edge_classes.insert(cfg_edge(n, *parent_edge), class);
}
let highest_target = be_up
.into_iter()
.map(|(dfs, _)| dfs)
.chain(min_dfs_target[0]);
(highest_target.min().unwrap_or(usize::MAX), bs)
}
}
#[cfg(test)]
pub(crate) mod test {
use super::*;
use hugr_core::builder::{
BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder, endo_sig,
};
use hugr_core::extension::prelude::usize_t;
use hugr_core::Node;
use hugr_core::hugr::patch::insert_identity::{IdentityInsertion, IdentityInsertionError};
use hugr_core::ops::Value;
use hugr_core::ops::handle::{BasicBlockID, ConstID, NodeHandle};
use hugr_core::types::{EdgeKind, Signature};
use hugr_core::utils::depth;
/// Inverts a key -> value map: keys sharing a value are gathered into one sorted `Vec`,
/// and the set of all such groups is returned.
pub fn group_by<E: Eq + Hash + Ord, V: Eq + Hash>(h: HashMap<E, V>) -> HashSet<Vec<E>> {
    let mut keys_by_value: HashMap<V, Vec<E>> = HashMap::new();
    for (key, value) in h {
        keys_by_value.entry(value).or_default().push(key);
    }
    keys_by_value.into_values().map(sorted).collect()
}

/// Collects `items` into a `Vec` sorted in ascending order.
pub fn sorted<E: Ord>(items: impl IntoIterator<Item = E>) -> Vec<E> {
    let mut ordered = Vec::from_iter(items);
    ordered.sort();
    ordered
}
/// Builds entry -> split -> {left, right} -> merge -> head <-> tail -> exit, i.e. a
/// conditional followed by a loop with its own header, then checks both the edge classes
/// and where each block ends up after `transform_cfg_to_nested`.
#[test]
fn test_cond_then_loop_separate() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(Signature::new_endo([usize_t()]))?;
let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2"));
let const_unit = cfg_builder.add_constant(Value::unary_unit_sum());
let entry = n_identity(
cfg_builder.simple_entry_builder(vec![usize_t()].into(), 1)?,
&const_unit,
)?;
let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?;
cfg_builder.branch(&entry, 0, &split)?;
let head = n_identity(
cfg_builder.simple_block_builder(endo_sig([usize_t()]), 1)?,
&const_unit,
)?;
let tail = n_identity(
cfg_builder.simple_block_builder(endo_sig([usize_t()]), 2)?,
&pred_const,
)?;
// The loop: tail branches back to head, head falls through to tail; merge enters at head.
cfg_builder.branch(&tail, 1, &head)?;
cfg_builder.branch(&head, 0, &tail)?; cfg_builder.branch(&merge, 0, &head)?;
let exit = cfg_builder.exit_block();
cfg_builder.branch(&tail, 0, &exit)?;
let h = cfg_builder.finish_hugr()?;
let (entry, exit) = (entry.node(), exit.node());
let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node());
let mut v = IdentityCfgMap::new(h);
let edge_classes = EdgeClassifier::get_edge_classes(&v);
// Recover the two branch blocks as the successors of `split`.
let [&left, &right] = edge_classes
.keys()
.filter(|(s, _)| *s == split)
.map(|(_, t)| t)
.collect::<Vec<_>>()[..]
else {
panic!("Split node should have two successors");
};
let classes = group_by(edge_classes);
assert_eq!(
classes,
HashSet::from([
sorted([(split, left), (left, merge)]), sorted([(split, right), (right, merge)]), Vec::from([(head, tail)]), Vec::from([(tail, head)]), sorted([(entry, split), (merge, head), (tail, exit)]), ])
);
transform_cfg_to_nested(&mut v);
let h = v.h;
h.validate().unwrap();
// Entry and exit stay at their depth; the other six blocks must all be two levels deeper.
assert_eq!(3, depth(&h, entry));
assert_eq!(3, depth(&h, exit));
for n in [split, left, right, merge, head, tail] {
assert_eq!(5, depth(&h, n));
}
// The conditional's blocks share one parent, the loop's blocks share a different one.
let first = [split, left, right, merge]
.iter()
.map(|n| h.get_parent(*n).unwrap())
.unique()
.exactly_one()
.unwrap();
let second = [head, tail]
.iter()
.map(|n| h.get_parent(*n).unwrap())
.unique()
.exactly_one()
.unwrap();
assert_ne!(first, second);
Ok(())
}
/// Checks the edge classes of the CFG from `build_cond_then_loop_cfg`, in which the merge
/// block of the conditional is also the target of the loop's back edge.
#[test]
fn test_cond_then_loop_combined() -> Result<(), BuildError> {
let (h, merge, tail) = build_cond_then_loop_cfg()?;
let (merge, tail) = (merge.node(), tail.node());
// The first two children of the entrypoint are taken as the entry and exit blocks.
let [entry, exit]: [Node; 2] = h
.children(h.entrypoint())
.take(2)
.collect_vec()
.try_into()
.unwrap();
let v = IdentityCfgMap::new(h);
let edge_classes = EdgeClassifier::get_edge_classes(&v);
// Recover the two branch blocks as the successors of the entry block.
let [&left, &right] = edge_classes
.keys()
.filter(|(s, _)| *s == entry)
.map(|(_, t)| t)
.collect::<Vec<_>>()[..]
else {
panic!("Entry node should have two successors");
};
let classes = group_by(edge_classes);
assert_eq!(
classes,
HashSet::from([
sorted([(entry, left), (left, merge)]), sorted([(entry, right), (right, merge)]), Vec::from([(tail, exit)]), Vec::from([(merge, tail)]), Vec::from([(tail, merge)]), ])
);
Ok(())
}
/// Checks edge classes and nesting depths for a conditional inside a loop whose header
/// block (`head`) is distinct from the conditional's `split` block.
#[test]
fn test_cond_in_loop_separate_headers() -> Result<(), BuildError> {
let (mut h, head, tail) = build_conditional_in_loop_cfg(true)?;
let head = head.node();
let tail = tail.node();
// `split` is the sole successor of `head`; `merge` is the sole predecessor of `tail`.
let split = h.output_neighbours(head).exactly_one().ok().unwrap();
let merge = h.input_neighbours(tail).exactly_one().ok().unwrap();
let v = IdentityCfgMap::new(&h);
let edge_classes = EdgeClassifier::get_edge_classes(&v);
let IdentityCfgMap { h: _, entry, exit } = v;
let [&left, &right] = edge_classes
.keys()
.filter(|(s, _)| *s == split)
.map(|(_, t)| t)
.collect::<Vec<_>>()[..]
else {
panic!("Split node should have two successors");
};
let classes = group_by(edge_classes);
assert_eq!(
classes,
HashSet::from([
sorted([(split, left), (left, merge)]), sorted([(split, right), (right, merge)]), sorted([(head, split), (merge, tail)]), sorted([(entry, head), (tail, exit)]), Vec::from([(tail, head)]) ])
);
transform_cfg_to_nested(&mut IdentityCfgMap::new(&mut h));
h.validate().unwrap();
// Expected depths: entry/exit unchanged, head/tail two deeper, the conditional four deeper.
assert_eq!(3, depth(&h, entry));
assert_eq!(5, depth(&h, head));
for n in [split, left, right, merge] {
assert_eq!(7, depth(&h, n));
}
assert_eq!(5, depth(&h, tail));
assert_eq!(3, depth(&h, exit));
Ok(())
}
/// Checks the edge classes for a conditional inside a loop where the conditional's split
/// block doubles as the loop header.
#[test]
fn test_cond_in_loop_combined_headers() -> Result<(), BuildError> {
let (h, head, tail) = build_conditional_in_loop_cfg(false)?;
let head = head.node();
let tail = tail.node();
let v = IdentityCfgMap::new(h);
let edge_classes = EdgeClassifier::get_edge_classes(&v);
let IdentityCfgMap { h: _, entry, exit } = v;
// `merge` is the only block with an edge into `tail`.
let merge = *edge_classes
.keys()
.filter(|(_, t)| *t == tail)
.map(|(s, _)| s)
.exactly_one()
.unwrap();
let [&left, &right] = edge_classes
.keys()
.filter(|(s, _)| *s == head)
.map(|(_, t)| t)
.collect::<Vec<_>>()[..]
else {
panic!("Loop header should have two successors");
};
let classes = group_by(edge_classes);
assert_eq!(
classes,
HashSet::from([
sorted([(head, left), (left, merge)]), sorted([(head, right), (right, merge)]), Vec::from([(merge, tail)]), sorted([(entry, head), (tail, exit)]), Vec::from([(tail, head)]) ])
);
Ok(())
}
/// Applying an `IdentityInsertion` at the first input port of the loop's tail block must be
/// rejected with `InvalidPortKind(Some(EdgeKind::ControlFlow))`.
#[test]
fn incorrect_insertion() {
let (mut h, _, tail) = build_conditional_in_loop_cfg(false).unwrap();
let final_node = tail.node();
let final_node_input = h.node_inputs(final_node).next().unwrap();
let rw = IdentityInsertion::new(final_node, final_node_input);
let apply_result = h.apply_patch(rw);
assert_eq!(
apply_result,
Err(IdentityInsertionError::InvalidPortKind(Some(
EdgeKind::ControlFlow
)))
);
}
/// Finishes `dataflow_builder` with outputs consisting of the loaded constant `pred_const`
/// followed by all of the builder's input wires, unchanged.
fn n_identity<T: DataflowSubContainer>(
    mut dataflow_builder: T,
    pred_const: &ConstID,
) -> Result<T::ContainerHandle, BuildError> {
    let inputs = dataflow_builder.input_wires();
    let loaded = dataflow_builder.load_const(pred_const);
    let outputs = std::iter::once(loaded).chain(inputs);
    dataflow_builder.finish_with_outputs(outputs)
}
/// Adds a new two-way `split` block to `cfg` and hangs a then/else/merge diamond off it.
///
/// Returns `(split, merge)`.
fn build_if_then_else_merge<T: AsMut<Hugr> + AsRef<Hugr>>(
cfg: &mut CFGBuilder<T>,
const_pred: &ConstID,
unit_const: &ConstID,
) -> Result<(BasicBlockID, BasicBlockID), BuildError> {
let split = n_identity(
cfg.simple_block_builder(endo_sig([usize_t()]), 2)?,
const_pred,
)?;
let merge = build_then_else_merge_from_if(cfg, unit_const, split)?;
Ok((split, merge))
}
/// Adds `merge`, `left` and `right` blocks to `cfg` and wires up a diamond: branches 0 and 1
/// of the existing `split` block go to `left` and `right`, which both continue to `merge`.
///
/// Returns the `merge` block.
fn build_then_else_merge_from_if<T: AsMut<Hugr> + AsRef<Hugr>>(
cfg: &mut CFGBuilder<T>,
unit_const: &ConstID,
split: BasicBlockID,
) -> Result<BasicBlockID, BuildError> {
let merge = n_identity(
cfg.simple_block_builder(endo_sig([usize_t()]), 1)?,
unit_const,
)?;
let left = n_identity(
cfg.simple_block_builder(endo_sig([usize_t()]), 1)?,
unit_const,
)?;
let right = n_identity(
cfg.simple_block_builder(endo_sig([usize_t()]), 1)?,
unit_const,
)?;
cfg.branch(&split, 0, &left)?;
cfg.branch(&split, 1, &right)?;
cfg.branch(&left, 0, &merge)?;
cfg.branch(&right, 0, &merge)?;
Ok(merge)
}
/// Builds a CFG whose two-way entry block is the split of a conditional, and whose merge
/// block forms a loop with a `tail` block: merge -> tail, tail -> merge (branch 1),
/// tail -> exit (branch 0).
///
/// Returns the finished Hugr together with the `merge` and `tail` blocks.
fn build_cond_then_loop_cfg() -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> {
let mut cfg_builder = CFGBuilder::new(Signature::new_endo([usize_t()]))?;
let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2"));
let const_unit = cfg_builder.add_constant(Value::unary_unit_sum());
let entry = n_identity(
cfg_builder.simple_entry_builder(vec![usize_t()].into(), 2)?,
&pred_const,
)?;
let merge = build_then_else_merge_from_if(&mut cfg_builder, &const_unit, entry)?;
let tail = n_identity(
cfg_builder.simple_block_builder(endo_sig([usize_t()]), 2)?,
&pred_const,
)?;
cfg_builder.branch(&tail, 1, &merge)?;
cfg_builder.branch(&merge, 0, &tail)?; let exit = cfg_builder.exit_block();
cfg_builder.branch(&tail, 0, &exit)?;
let h = cfg_builder.finish_hugr()?;
Ok((h, merge, tail))
}
/// Builds a standalone CFG Hugr containing the structure created by
/// [`build_conditional_in_loop`].
///
/// Returns the finished Hugr together with the loop's `head` and `tail` blocks.
pub(crate) fn build_conditional_in_loop_cfg(
    separate_headers: bool,
) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> {
    let mut builder = CFGBuilder::new(Signature::new_endo([usize_t()]))?;
    let (head, tail) = build_conditional_in_loop(&mut builder, separate_headers)?;
    Ok((builder.finish_hugr()?, head, tail))
}
/// Populates `cfg_builder` with a loop whose body is a then/else/merge conditional:
/// entry -> head, (conditional), merge -> tail, tail -> head (branch 1), tail -> exit
/// (branch 0).
///
/// With `separate_headers`, `head` is an extra block that leads into the conditional's
/// split block; otherwise the split block itself serves as `head`.
///
/// Returns `(head, tail)`.
pub(crate) fn build_conditional_in_loop<T: AsMut<Hugr> + AsRef<Hugr>>(
cfg_builder: &mut CFGBuilder<T>,
separate_headers: bool,
) -> Result<(BasicBlockID, BasicBlockID), BuildError> {
let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2"));
let const_unit = cfg_builder.add_constant(Value::unary_unit_sum());
let entry = n_identity(
cfg_builder.simple_entry_builder(vec![usize_t()].into(), 1)?,
&const_unit,
)?;
let (split, merge) = build_if_then_else_merge(cfg_builder, &pred_const, &const_unit)?;
let head = if separate_headers {
let head = n_identity(
cfg_builder.simple_block_builder(endo_sig([usize_t()]), 1)?,
&const_unit,
)?;
cfg_builder.branch(&head, 0, &split)?;
head
} else {
split
};
let tail = n_identity(
cfg_builder.simple_block_builder(endo_sig([usize_t()]), 2)?,
&pred_const,
)?;
cfg_builder.branch(&tail, 1, &head)?;
cfg_builder.branch(&merge, 0, &tail)?;
let exit = cfg_builder.exit_block();
cfg_builder.branch(&entry, 0, &head)?;
cfg_builder.branch(&tail, 0, &exit)?;
Ok((head, tail))
}
}