use std::collections::{BTreeSet, HashMap};
use hugr::HugrView;
use hugr::core::HugrNode;
use petgraph::unionfind::UnionFind;
use crate::serialize::pytket::PytketEncodeError;
use crate::serialize::pytket::opaque::OpaqueSubgraph;
#[derive(Debug, Clone)]
pub struct UnsupportedTracker<N> {
nodes: HashMap<N, UnsupportedNode>,
components: UnionFind<usize>,
}
/// Per-node bookkeeping stored by [`UnsupportedTracker`].
#[derive(Clone, Copy, Debug, Default)]
struct UnsupportedNode {
    /// Union-find key of the set allocated for this node when it was recorded.
    /// Its current component representative is obtained by running a `find`
    /// on the tracker's union-find structure.
    component: ComponentId,
}

/// Key identifying a set in the tracker's union-find structure.
type ComponentId = usize;
impl<N: HugrNode> UnsupportedTracker<N> {
pub fn new(_hugr: &impl HugrView<Node = N>) -> Self {
Self {
nodes: HashMap::new(),
components: UnionFind::new_empty(),
}
}
pub fn is_unsupported(&self, node: N) -> bool {
self.nodes.contains_key(&node)
}
pub fn record_node(&mut self, node: N, hugr: &impl HugrView<Node = N>) {
let node_data = UnsupportedNode {
component: self.components.new_set(),
};
self.nodes.insert(node, node_data);
for neighbour in hugr.input_neighbours(node) {
if let Some(neigh_data) = self.nodes.get(&neighbour) {
self.components
.union(neigh_data.component, node_data.component);
}
}
}
pub fn extract_component(
&mut self,
node: N,
hugr: &impl HugrView<Node = N>,
) -> Result<OpaqueSubgraph<N>, PytketEncodeError<N>> {
let node_data = self.nodes.remove(&node).unwrap();
let component = node_data.component;
let representative = self.components.find_mut(component);
let mut nodes = BTreeSet::new();
nodes.insert(node);
for (&n, data) in &self.nodes {
if self.components.find_mut(data.component) == representative {
nodes.insert(n);
}
}
for n in &nodes {
self.nodes.remove(n);
}
OpaqueSubgraph::try_from_nodes(nodes, hugr)
}
pub fn iter(&self) -> impl Iterator<Item = N> + '_ {
self.nodes.keys().copied()
}
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
}
impl<N> Default for UnsupportedTracker<N> {
    /// Builds a tracker with no recorded nodes and an empty union-find.
    fn default() -> Self {
        let components = UnionFind::new_empty();
        let nodes = HashMap::default();
        Self { nodes, components }
    }
}