use std::collections::{BTreeSet, HashMap};
use hugr::core::HugrNode;
use hugr::HugrView;
use petgraph::unionfind::UnionFind;
use crate::Circuit;
/// Tracks a set of "unsupported" nodes, grouping them into connected components.
///
/// Components are maintained with a union-find structure: every recorded node gets its own
/// set, which is merged with the sets of already-tracked nodes it is connected to.
#[derive(Debug, Clone)]
pub struct UnsupportedTracker<N> {
    /// Every currently tracked node, with its per-node bookkeeping.
    nodes: HashMap<N, UnsupportedNode>,
    /// Union-find over component ids; nodes whose ids share a representative belong to the
    /// same component.
    components: UnionFind<ComponentId>,
}

/// Identifier of a set in the tracker's union-find structure.
type ComponentId = usize;

/// Per-node bookkeeping for a tracked node.
#[derive(Debug, Clone, Copy, Default)]
struct UnsupportedNode {
    /// The union-find set allocated for this node when it was first recorded.
    component: ComponentId,
}
impl<N: HugrNode> UnsupportedTracker<N> {
pub fn new(_circ: &Circuit<impl HugrView>) -> Self {
Self {
nodes: HashMap::new(),
components: UnionFind::new_empty(),
}
}
pub fn is_unsupported(&self, node: N) -> bool {
self.nodes.contains_key(&node)
}
pub fn record_node(&mut self, node: N, circ: &Circuit<impl HugrView<Node = N>>) {
let node_data = UnsupportedNode {
component: self.components.new_set(),
};
self.nodes.insert(node, node_data);
for neighbour in circ.hugr().input_neighbours(node) {
if let Some(neigh_data) = self.nodes.get(&neighbour) {
self.components
.union(neigh_data.component, node_data.component);
}
}
}
pub fn extract_component(&mut self, node: N) -> BTreeSet<N> {
let node_data = self.nodes.remove(&node).unwrap();
let component = node_data.component;
let representative = self.components.find_mut(component);
let mut nodes = BTreeSet::new();
nodes.insert(node);
for (&n, data) in &self.nodes {
if self.components.find_mut(data.component) == representative {
nodes.insert(n);
}
}
for n in &nodes {
self.nodes.remove(n);
}
nodes
}
pub fn iter(&self) -> impl Iterator<Item = N> + '_ {
self.nodes.keys().copied()
}
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
}
impl<N> Default for UnsupportedTracker<N> {
    /// Builds a tracker with no recorded nodes and an empty union-find structure.
    fn default() -> Self {
        let nodes = HashMap::new();
        let components = UnionFind::new_empty();
        Self { nodes, components }
    }
}