use hugr::ops::OpTrait;
use itertools::{Either, Itertools};
use std::collections::BTreeSet;
use crate::serialize::pytket::PytketEncodeError;
use hugr::core::HugrNode;
use hugr::hugr::views::SiblingSubgraph;
use hugr::hugr::views::sibling_subgraph::InvalidSubgraph;
use hugr::types::Signature;
use hugr::{Direction, Hugr, HugrView, IncomingPort, OutgoingPort};
/// A set of nodes of a HUGR together with a description of its value boundary.
///
/// Built by `OpaqueSubgraph::try_from_nodes`. It records which value ports of the nodes are
/// linked to nodes outside the set (within `region`), and the signature formed by those
/// ports' types.
#[derive(Debug, Clone)]
pub struct OpaqueSubgraph<N> {
    /// The nodes that make up the subgraph.
    nodes: BTreeSet<N>,
    /// Incoming value ports of subgraph nodes that are linked to a node outside the set
    /// whose parent is `region`.
    incoming_ports: Vec<(N, IncomingPort)>,
    /// Outgoing value ports of subgraph nodes that are linked to a node outside the set
    /// whose parent is `region`.
    outgoing_ports: Vec<(N, OutgoingPort)>,
    /// Signature whose inputs are the types of `incoming_ports` and whose outputs are the
    /// types of `outgoing_ports`, in the same order.
    signature: Signature,
    /// The region the subgraph is taken to live in: the parent of the first node in
    /// `nodes`, or the HUGR entrypoint if there is no such parent.
    region: N,
    /// Whether `try_from_nodes` judged the node set to be convertible into a
    /// `SiblingSubgraph`.
    sibling_subgraph_compatible: bool,
}
impl<N: HugrNode> OpaqueSubgraph<N> {
    /// Builds an [`OpaqueSubgraph`] describing the given set of `nodes` of `hugr`.
    ///
    /// The subgraph's region is the parent of the first (smallest) node in `nodes`. If
    /// `nodes` is empty or that node has no parent, the HUGR entrypoint is used instead.
    ///
    /// For every node with a dataflow signature, each value port linked to a node outside
    /// `nodes` whose parent is the region is recorded as a boundary port. Its type is
    /// appended to the subgraph's input types (incoming ports) or output types (outgoing
    /// ports). Ports are collected in node order and, per node, incoming before outgoing.
    ///
    /// The subgraph is flagged as sibling-subgraph compatible only if no node:
    /// - has a parent other than the region,
    /// - has a value port linked to a node outside both `nodes` and the region,
    /// - has children (is itself a region parent), or
    /// - has a static or order edge to a node outside `nodes`.
    ///
    /// # Errors
    ///
    /// This function does not currently return an error.
    pub(in crate::serialize::pytket) fn try_from_nodes(
        nodes: BTreeSet<N>,
        hugr: &impl HugrView<Node = N>,
    ) -> Result<Self, PytketEncodeError<N>> {
        let region = nodes
            .first()
            .and_then(|n| hugr.get_parent(*n))
            .unwrap_or_else(|| hugr.entrypoint());

        let mut incoming_ports = Vec::new();
        let mut outgoing_ports = Vec::new();
        let mut input_types = Vec::new();
        let mut output_types = Vec::new();
        let mut sibling_subgraph_compatible = true;

        for &node in &nodes {
            let op = hugr.get_optype(node);

            // Every node of a sibling subgraph must share the same parent.
            let outside_region = hugr.get_parent(node) != Some(region);
            // Container nodes cannot be part of a flat sibling subgraph.
            let is_region_parent = hugr.first_child(node).is_some();
            // Static and order edges leaving the node set are not part of the value
            // boundary recorded below. These checks apply to every node, including
            // nodes without a dataflow signature.
            let non_value_boundary = [
                op.static_port(Direction::Incoming),
                op.static_port(Direction::Outgoing),
                op.other_port(Direction::Incoming),
                op.other_port(Direction::Outgoing),
            ]
            .into_iter()
            .flatten()
            .any(|p| hugr.linked_ports(node, p).any(|(n, _)| !nodes.contains(&n)));

            let mut has_nonlocal_boundary = false;
            if let Some(signature) = op.dataflow_signature() {
                for port in signature
                    .ports(Direction::Incoming)
                    .chain(signature.ports(Direction::Outgoing))
                {
                    let mut is_local_boundary = false;
                    for (n, _) in hugr.linked_ports(node, port) {
                        if nodes.contains(&n) {
                            continue;
                        }
                        if hugr.get_parent(n) == Some(region) {
                            is_local_boundary = true;
                        } else {
                            has_nonlocal_boundary = true;
                        }
                        // Both flags are set: further links cannot change anything.
                        if is_local_boundary && has_nonlocal_boundary {
                            break;
                        }
                    }
                    if !is_local_boundary {
                        continue;
                    }
                    let ty = signature
                        .port_type(port)
                        .expect("port was enumerated from this signature")
                        .clone();
                    match port.as_directed() {
                        Either::Left(inc) => {
                            incoming_ports.push((node, inc));
                            input_types.push(ty);
                        }
                        Either::Right(out) => {
                            outgoing_ports.push((node, out));
                            output_types.push(ty);
                        }
                    }
                }
            }

            sibling_subgraph_compatible &= !outside_region
                && !has_nonlocal_boundary
                && !is_region_parent
                && !non_value_boundary;
        }

        let signature = Signature::new(input_types, output_types);
        Ok(Self {
            nodes,
            incoming_ports,
            outgoing_ports,
            signature,
            region,
            sibling_subgraph_compatible,
        })
    }

    /// The nodes that make up this subgraph.
    pub fn nodes(&self) -> &BTreeSet<N> {
        &self.nodes
    }

    /// Incoming value ports on the subgraph boundary, in node order.
    ///
    /// Their types, in the same order, are the inputs of [`Self::signature`].
    pub fn incoming_ports(&self) -> &[(N, IncomingPort)] {
        &self.incoming_ports
    }

    /// Outgoing value ports on the subgraph boundary, in node order.
    ///
    /// Their types, in the same order, are the outputs of [`Self::signature`].
    pub fn outgoing_ports(&self) -> &[(N, OutgoingPort)] {
        &self.outgoing_ports
    }

    /// The dataflow signature formed by the boundary ports' types.
    pub fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The region (parent node) this subgraph was computed relative to.
    pub fn region(&self) -> N {
        self.region
    }

    /// Whether the node set was judged convertible into a [`SiblingSubgraph`].
    pub fn is_sibling_subgraph_compatible(&self) -> bool {
        self.sibling_subgraph_compatible
    }

    /// Extracts the nodes of this subgraph into a new standalone [`Hugr`].
    ///
    /// # Errors
    ///
    /// Returns the [`InvalidSubgraph`] produced by [`SiblingSubgraph::try_from_nodes`]
    /// if the node set is not a valid sibling subgraph of `hugr`. Previously this
    /// case panicked.
    pub fn extract_subgraph(
        &self,
        hugr: &impl HugrView<Node = N>,
    ) -> Result<Hugr, InvalidSubgraph<N>> {
        let nodes = self.nodes.iter().copied().collect_vec();
        let subgraph = SiblingSubgraph::try_from_nodes(nodes, hugr)?;
        Ok(subgraph.extract_subgraph(hugr, ""))
    }
}