use hugr::hugr::hugrmut::HugrMut;
use hugr::ops::{OpTrait, OpType, Output, DFG};
use hugr::types::{Signature, SumType, TypeEnum};
use hugr::HugrView;
use hugr_core::hugr::internal::HugrMutInternals;
use itertools::Itertools;
use crate::{Circuit, CircuitMutError};
/// Replace the circuit's parent node with an equivalent `DFG` node.
///
/// Does nothing if the parent is already a `DFG`. If the parent is a
/// `DataflowBlock`, the leading single-variant empty sum output is first
/// stripped by `remove_cfg_empty_output_tuple`, and the resulting signature is
/// used for the new `DFG`.
///
/// # Errors
///
/// Propagates any error returned by `remove_cfg_empty_output_tuple`.
#[expect(unused)]
pub(super) fn rewrite_into_dfg(circ: &mut Circuit) -> Result<(), CircuitMutError> {
    let old_optype = circ.hugr.get_optype(circ.parent());
    if matches!(old_optype, OpType::DFG(_)) {
        return Ok(());
    }

    let signature = circ.circuit_signature().into_owned();
    // A CFG block carries an extra branch-selector sum output; drop it if possible.
    let signature = match old_optype {
        OpType::DataflowBlock(_) => remove_cfg_empty_output_tuple(circ, signature)?,
        _ => signature,
    };

    // The replacement op may have a different port layout than the old one, so
    // resize the node first. Incoming ports are the value inputs plus the
    // state-order port. Outgoing ports are the value outputs plus the
    // state-order port.
    circ.hugr.set_num_ports(
        circ.parent(),
        signature.input_count() + 1,
        signature.output_count() + 1,
    );
    circ.hugr.replace_op(circ.parent(), DFG { signature });
    Ok(())
}
/// Strip the leading unit-sum output from a `DataflowBlock` region.
///
/// The first input of the region's `Output` node is checked. If it is a
/// `Sum` with exactly one empty variant, and it is produced by a `Tag` node,
/// both the `Tag` node and the `Output` node are removed. A new `Output` node
/// without that first port is then inserted after the input node, and every
/// other incoming link is reconnected to it.
///
/// Returns `signature` unchanged when the pattern does not match. Otherwise it
/// returns a signature with the same inputs and the remaining output types.
///
/// # Panics
///
/// Panics if the region's output node has no dataflow signature.
fn remove_cfg_empty_output_tuple(
    circ: &mut Circuit,
    signature: Signature,
) -> Result<Signature, CircuitMutError> {
    let input_node = circ.input_node();
    let output_node = circ.output_node();
    // Work on an owned copy so the signature does not keep `circ.hugr` borrowed
    // across the mutations below.
    let output_op = circ.hugr.get_optype(output_node).clone();
    let output_sig = output_op
        .dataflow_signature()
        .expect("Exit node with no dataflow signature.");

    // Only a single-variant, empty sum can be dropped without losing information.
    if !matches!(
        output_sig.input[0].as_type_enum(),
        TypeEnum::Sum(SumType::Unit { size: 1 })
    ) {
        return Ok(signature);
    }

    // The sum must come from a `Tag` node, which is deleted along with it.
    let Some((tag_node, _)) = circ.hugr.single_linked_output(output_node, 0) else {
        return Ok(signature);
    };
    if !matches!(circ.hugr.get_optype(tag_node), OpType::Tag(_)) {
        return Ok(signature);
    }

    let hugr = circ.hugr_mut();
    // Record every link into the old output node except the sum at port 0.
    // Each link is paired with its destination port on the new node, which is
    // the original incoming port index shifted down by one. Keying on the
    // original index, rather than a running counter, keeps the mapping correct
    // when a port has several incoming links. This happens, for example, with
    // multiple order edges into the state-order port, which the shift maps onto
    // the new node's state-order port.
    let rewires = hugr
        .node_inputs(output_node)
        .skip(1)
        .flat_map(|in_port| {
            hugr.linked_outputs(output_node, in_port)
                .map(move |(src, src_port)| (src, src_port, in_port.index() - 1))
        })
        .collect_vec();

    // Hard-delete the old output and the tag, then build a replacement output
    // node that no longer carries the unit sum.
    hugr.remove_node(output_node);
    hugr.remove_node(tag_node);
    let new_types = output_sig.input[1..].to_vec();
    let new_output = hugr.add_node_after(
        input_node,
        Output {
            types: new_types.clone().into(),
        },
    );
    for (src, src_port, dst_port) in rewires {
        hugr.connect(src, src_port, new_output, dst_port);
    }

    Ok(Signature::new(signature.input, new_types))
}