use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use parking_lot::ReentrantMutex;
use crate::handle::NodeId;
/// Opaque identifier of a partition (connected subgraph) in the registry.
///
/// The wrapped value is the raw id of the node that is currently the
/// partition's union-find root, so the id of a partition can change when its
/// root node is removed and another node is promoted in its place.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct SubgraphId(pub(crate) u64);

impl SubgraphId {
    /// Wraps the raw id of `node` (expected to be a partition root).
    #[must_use]
    pub(crate) fn from_node(node: NodeId) -> Self {
        let raw_id = node.raw();
        SubgraphId(raw_id)
    }

    /// Returns the underlying numeric id.
    #[must_use]
    pub fn raw(self) -> u64 {
        let SubgraphId(inner) = self;
        inner
    }
}
/// Human-readable form: `subgraph#<raw id>`.
impl std::fmt::Display for SubgraphId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let id = self.0;
        f.write_fmt(format_args!("subgraph#{}", id))
    }
}
/// Lock container owned by one partition root and shared out via `Arc`.
///
/// Callers compare boxes by `Arc` pointer identity (see
/// `SubgraphRegistry::lock_for_validate`), so every call to `new` yields a
/// distinct identity.
pub struct SubgraphLockBox {
    /// Reentrant mutex held while working on the partition.
    /// NOTE(review): the exact "wave" ownership semantics live in the callers —
    /// confirm there.
    pub(crate) wave_owner: Arc<ReentrantMutex<()>>,
}

impl SubgraphLockBox {
    /// Allocates a fresh, unlocked box behind a new `Arc`.
    fn new() -> Arc<Self> {
        let mutex = ReentrantMutex::new(());
        let wave_owner = Arc::new(mutex);
        Arc::new(SubgraphLockBox { wave_owner })
    }
}
/// Retry budget for lock acquisition.
/// NOTE(review): not referenced in this module; presumably bounds a
/// `lock_for` / `lock_for_validate` retry loop in the caller — confirm.
pub(crate) const MAX_LOCK_RETRIES: u32 = 100;

/// Union-find registry mapping nodes to partitions, with one lock box per
/// partition root.
pub struct SubgraphRegistry {
    /// Union-find parent pointers; a root maps to itself.
    parent: HashMap<NodeId, NodeId>,
    /// Union-by-rank bound on tree height (only consulted for roots).
    rank: HashMap<NodeId, u32>,
    /// Reverse index of `parent`: node -> nodes whose parent it is.
    children: HashMap<NodeId, HashSet<NodeId>>,
    /// Lock box of each partition, keyed by the partition's root.
    boxes: HashMap<NodeId, Arc<SubgraphLockBox>>,
    /// Wrapping counter advanced on structural changes.
    epoch: u64,
}
impl SubgraphRegistry {
    /// Creates an empty registry: no nodes, no partitions, epoch 0.
    #[must_use]
    pub(crate) fn new() -> Self {
        Self {
            parent: HashMap::new(),
            rank: HashMap::new(),
            children: HashMap::new(),
            boxes: HashMap::new(),
            epoch: 0,
        }
    }

    /// Returns the current structural epoch.
    ///
    /// The epoch is advanced (wrapping) by every mutation that changes the
    /// node -> partition mapping or the set of partitions: registering a new
    /// node, a merging union, removing a node, and splitting a partition.
    /// Path compression inside [`Self::find`] does not advance it.
    #[must_use]
    pub(crate) fn epoch(&self) -> u64 {
        self.epoch
    }

    /// Registers `node` as a singleton partition with a fresh lock box.
    ///
    /// Idempotent: an already-registered node is left untouched and the epoch
    /// is not advanced.
    pub(crate) fn ensure_registered(&mut self, node: NodeId) {
        if self.parent.contains_key(&node) {
            return;
        }
        self.parent.insert(node, node);
        self.rank.insert(node, 0);
        self.children.insert(node, HashSet::new());
        self.boxes.insert(node, SubgraphLockBox::new());
        self.bump_epoch();
    }

    /// Returns the union-find root of `node`, compressing the path to it.
    ///
    /// Every node on the path that does not already point directly at the
    /// root is re-parented onto the root, and the `children` reverse index is
    /// updated to match.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not registered.
    #[must_use]
    pub(crate) fn find(&mut self, node: NodeId) -> NodeId {
        // Pass 1: walk parent pointers up to the root.
        let mut cur = node;
        loop {
            let parent = *self
                .parent
                .get(&cur)
                .expect("subgraph_registry::find: node not registered");
            if parent == cur {
                break;
            }
            cur = parent;
        }
        let root = cur;

        // Pass 2: point every node on the path directly at `root`.
        let mut walker = node;
        while walker != root {
            let parent = *self
                .parent
                .get(&walker)
                .expect("walker on path-to-root must be registered");
            if parent != root {
                self.parent.insert(walker, root);
                if let Some(old_kids) = self.children.get_mut(&parent) {
                    old_kids.remove(&walker);
                }
                self.children.entry(root).or_default().insert(walker);
            }
            walker = parent;
        }
        root
    }

    /// Merges the partitions containing `a` and `b` (union by rank).
    ///
    /// No-op when both already share a root. Otherwise the lower-rank root is
    /// attached under the higher-rank one, and the absorbed root's lock box is
    /// dropped: the survivor's box now guards the merged partition.
    pub(crate) fn union_nodes(&mut self, a: NodeId, b: NodeId) {
        debug_assert!(
            self.parent.contains_key(&a) && self.parent.contains_key(&b),
            "union_nodes: both nodes must be registered first"
        );
        debug_assert!(
            a != b,
            "union_nodes called with self-edge — \
             Core's cycle detection bypassed?"
        );
        let mut root_a = self.find(a);
        let mut root_b = self.find(b);
        if root_a == root_b {
            return;
        }
        let rank_a = self.rank.get(&root_a).copied().unwrap_or(0);
        let rank_b = self.rank.get(&root_b).copied().unwrap_or(0);
        if rank_a < rank_b {
            std::mem::swap(&mut root_a, &mut root_b);
        }
        // `root_a` now has the larger-or-equal rank and survives.
        self.parent.insert(root_b, root_a);
        self.children.entry(root_a).or_default().insert(root_b);
        if rank_a == rank_b {
            // Only an equal-rank union can grow the height bound.
            self.rank.insert(root_a, rank_a + 1);
        }
        self.boxes.remove(&root_b);
        self.bump_epoch();
    }

    /// Removes `node` from the registry while keeping the rest of its tree linked.
    ///
    /// * Unregistered `node`: no-op (epoch unchanged).
    /// * Root with children: one direct child is promoted to root and inherits
    ///   the partition's lock box (so boxes already handed out stay valid);
    ///   the other direct children are re-parented onto it.
    /// * Root without children: the singleton partition and its box are removed.
    /// * Non-root: its direct children are re-parented onto its parent.
    ///
    /// This does not check whether removing `node` disconnects its partition;
    /// re-partitioning is done by [`Self::split_partition`].
    ///
    /// The epoch is advanced whenever `node` was registered, because removal
    /// can delete a partition or change its root (and hence its [`SubgraphId`]).
    // `dead_code` is allowed because this method may have no non-test caller
    // in the crate.
    #[allow(dead_code)]
    pub(crate) fn cleanup_node(&mut self, node: NodeId) {
        let Some(parent) = self.parent.get(&node).copied() else {
            return;
        };
        let direct_children: Vec<NodeId> = self
            .children
            .get(&node)
            .map(|s| s.iter().copied().collect())
            .unwrap_or_default();
        if parent == node {
            if let Some(&new_root) = direct_children.first() {
                // Promote one direct child to be the new root.
                self.parent.insert(new_root, new_root);
                // Defensive: no child should list `node` among its own kids.
                for child in &direct_children {
                    if let Some(kids) = self.children.get_mut(child) {
                        kids.remove(&node);
                    }
                }
                // Hang the remaining siblings directly under the new root.
                let new_root_kids = self.children.entry(new_root).or_default();
                for child in direct_children.iter().skip(1).copied() {
                    self.parent.insert(child, new_root);
                    new_root_kids.insert(child);
                }
                // The partition keeps its lock box, now keyed by the new root.
                if let Some(box_arc) = self.boxes.remove(&node) {
                    self.boxes.insert(new_root, box_arc);
                }
                // The rebuilt tree is no taller than the old one, so taking the
                // larger of the two ranks keeps rank an upper bound on height.
                let old_rank = self.rank.get(&node).copied().unwrap_or(0);
                let new_rank = self.rank.entry(new_root).or_insert(0);
                *new_rank = (*new_rank).max(old_rank);
            } else {
                // Singleton partition: it disappears with its only node.
                self.boxes.remove(&node);
            }
        } else {
            // Splice `node` out of the tree: its children now hang off its parent.
            let parent_kids = self.children.entry(parent).or_default();
            parent_kids.remove(&node);
            parent_kids.extend(direct_children.iter().copied());
            for &child in &direct_children {
                self.parent.insert(child, parent);
            }
        }
        self.children.remove(&node);
        self.parent.remove(&node);
        self.rank.remove(&node);
        self.bump_epoch();
    }

    /// Hook for edge removal; intentionally does nothing.
    ///
    /// Union-find cannot undo a union, so removing an edge never changes the
    /// registry on its own. NOTE(review): presumably the caller detects a
    /// disconnection and applies it via [`Self::split_partition`] — confirm.
    pub(crate) fn on_edge_removed(&mut self, _from: NodeId, _to: NodeId) {}

    /// Rebuilds one component's union-find state after edges were removed.
    ///
    /// * `component_nodes`: every node of the former component (non-empty).
    /// * `keep_side_nodes`: non-empty subset that keeps the original lock box.
    /// * `edges_in_component`: the edges that still connect `component_nodes`.
    ///
    /// Every member is reset to a singleton and the surviving edges are
    /// re-unioned. The original box is then attached to the keep side's root,
    /// so boxes already handed out for keep-side nodes remain valid under
    /// [`Self::lock_for_validate`]. Every other resulting root receives a fresh
    /// box. Normally that is exactly one root (the orphan side), but if the
    /// remaining edges leave several disconnected pieces, each piece gets its
    /// own box, so that every root keeps a box.
    ///
    /// # Panics
    ///
    /// Panics if the root of `component_nodes[0]` has no lock box, if
    /// `keep_side_nodes` is empty, or if every node is on the keep side. Debug
    /// builds additionally assert that the keep and orphan sides end up
    /// disconnected.
    pub(crate) fn split_partition(
        &mut self,
        component_nodes: &[NodeId],
        keep_side_nodes: &[NodeId],
        edges_in_component: &[(NodeId, NodeId)],
    ) {
        debug_assert!(
            !component_nodes.is_empty(),
            "component_nodes must be non-empty"
        );
        debug_assert!(
            !keep_side_nodes.is_empty(),
            "keep_side_nodes must be non-empty"
        );
        let keep_side: HashSet<NodeId> = keep_side_nodes.iter().copied().collect();
        debug_assert!(
            component_nodes.iter().any(|n| !keep_side.contains(n)),
            "orphan side must be non-empty (no-op caller)"
        );

        let original_root = self.find(component_nodes[0]);
        let original_box = self
            .boxes
            .remove(&original_root)
            .expect("original_root must have a registered box");

        // Reset every member to a singleton, then rebuild connectivity from
        // the surviving edges.
        for &n in component_nodes {
            self.parent.insert(n, n);
            self.rank.insert(n, 0);
            self.children.insert(n, HashSet::new());
        }
        for &(a, b) in edges_in_component {
            if a != b {
                self.union_nodes(a, b);
            }
        }

        let keep_root = self.find(keep_side_nodes[0]);
        let orphan_repr = *component_nodes
            .iter()
            .find(|n| !keep_side.contains(n))
            .expect("non-empty orphan side");
        let orphan_root = self.find(orphan_repr);
        debug_assert!(
            keep_root != orphan_root,
            "split_partition: keep_root {keep_root:?} and orphan_root {orphan_root:?} \
             must be distinct after re-union — caller's BFS must have asserted \
             disconnection"
        );

        // The keep side inherits the original box. Every other root in the
        // rebuilt component (the orphan side, plus any further pieces) gets a
        // fresh one, so no root is left without a box.
        self.boxes.insert(keep_root, original_box);
        let mut boxed_roots: HashSet<NodeId> = HashSet::new();
        boxed_roots.insert(keep_root);
        for &n in component_nodes {
            let root = self.find(n);
            if boxed_roots.insert(root) {
                self.boxes.insert(root, SubgraphLockBox::new());
            }
        }
        self.bump_epoch();
    }

    /// Returns the partition id and lock box for `node`.
    ///
    /// Returns `None` if `node` is unregistered, or (defensively) if its root
    /// has no box.
    #[must_use]
    pub(crate) fn lock_for(
        &mut self,
        node: NodeId,
    ) -> Option<(SubgraphId, Arc<SubgraphLockBox>)> {
        if !self.parent.contains_key(&node) {
            return None;
        }
        let root = self.find(node);
        let box_arc = self.boxes.get(&root).map(Arc::clone)?;
        Some((SubgraphId::from_node(root), box_arc))
    }

    /// Returns `true` iff `node` is registered and its partition is still
    /// guarded by exactly `expected_box`, compared by `Arc` pointer identity.
    ///
    /// This detects a box obtained earlier from [`Self::lock_for`] that has
    /// since been superseded, e.g. by a union that made a different root
    /// survive.
    #[must_use]
    pub(crate) fn lock_for_validate(
        &mut self,
        node: NodeId,
        expected_box: &Arc<SubgraphLockBox>,
    ) -> bool {
        if !self.parent.contains_key(&node) {
            return false;
        }
        let root = self.find(node);
        match self.boxes.get(&root) {
            Some(actual) => Arc::ptr_eq(actual, expected_box),
            None => false,
        }
    }

    /// Returns every partition (id and box), sorted by [`SubgraphId`].
    #[must_use]
    pub(crate) fn all_partitions(&self) -> Vec<(SubgraphId, Arc<SubgraphLockBox>)> {
        let mut out: Vec<(SubgraphId, Arc<SubgraphLockBox>)> = self
            .boxes
            .iter()
            .map(|(root, box_arc)| (SubgraphId::from_node(*root), Arc::clone(box_arc)))
            .collect();
        out.sort_unstable_by_key(|(sid, _)| *sid);
        out
    }

    /// Number of registered nodes.
    #[must_use]
    pub fn node_count(&self) -> usize {
        self.parent.len()
    }

    /// All registered nodes, in unspecified (hash-map) order.
    #[must_use]
    pub(crate) fn registered_nodes(&self) -> Vec<NodeId> {
        self.parent.keys().copied().collect()
    }

    /// Number of partitions, i.e. the number of roots that own a lock box.
    #[must_use]
    pub fn component_count(&self) -> usize {
        self.boxes.len()
    }

    /// Returns the id of the partition containing `node`, or `None` if the
    /// node is unregistered.
    #[must_use]
    pub fn partition_of(&mut self, node: NodeId) -> Option<SubgraphId> {
        if !self.parent.contains_key(&node) {
            return None;
        }
        Some(SubgraphId::from_node(self.find(node)))
    }

    /// Advances the structural epoch, wrapping on overflow.
    fn bump_epoch(&mut self) {
        self.epoch = self.epoch.wrapping_add(1);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn n(raw: u64) -> NodeId {
        NodeId::new(raw)
    }

    #[test]
    fn singleton_register_creates_one_partition() {
        let mut r = SubgraphRegistry::new();
        r.ensure_registered(n(1));
        assert_eq!(r.node_count(), 1);
        assert_eq!(r.component_count(), 1);
        assert_eq!(r.find(n(1)), n(1));
    }

    #[test]
    fn ensure_registered_is_idempotent() {
        let mut r = SubgraphRegistry::new();
        r.ensure_registered(n(1));
        let epoch_after_first = r.epoch();
        r.ensure_registered(n(1));
        assert_eq!(r.epoch(), epoch_after_first);
        assert_eq!(r.node_count(), 1);
        assert_eq!(r.component_count(), 1);
    }

    #[test]
    fn union_merges_two_singletons() {
        let mut r = SubgraphRegistry::new();
        r.ensure_registered(n(1));
        r.ensure_registered(n(2));
        assert_eq!(r.component_count(), 2);
        r.union_nodes(n(1), n(2));
        assert_eq!(r.component_count(), 1);
        assert_eq!(r.find(n(1)), r.find(n(2)));
    }

    #[test]
    fn union_idempotent_within_same_component() {
        let mut r = SubgraphRegistry::new();
        r.ensure_registered(n(1));
        r.ensure_registered(n(2));
        r.union_nodes(n(1), n(2));
        let comp_before = r.component_count();
        r.union_nodes(n(1), n(2));
        assert_eq!(r.component_count(), comp_before);
    }

    #[test]
    fn cleanup_singleton_removes_partition() {
        let mut r = SubgraphRegistry::new();
        r.ensure_registered(n(1));
        r.cleanup_node(n(1));
        assert_eq!(r.node_count(), 0);
        assert_eq!(r.component_count(), 0);
    }

    #[test]
    fn cleanup_root_promotes_child() {
        let mut r = SubgraphRegistry::new();
        r.ensure_registered(n(1));
        r.ensure_registered(n(2));
        r.union_nodes(n(1), n(2));
        let root_before = r.find(n(1));
        let child = if root_before == n(1) { n(2) } else { n(1) };
        r.cleanup_node(root_before);
        assert_eq!(r.find(child), child);
        assert_eq!(r.component_count(), 1);
    }

    #[test]
    fn cleanup_non_root_re_links_grandchildren_to_parent() {
        let mut r = SubgraphRegistry::new();
        for i in 1..=3 {
            r.ensure_registered(n(i));
        }
        r.union_nodes(n(1), n(2));
        r.union_nodes(n(2), n(3));
        let root_before = r.find(n(1));
        let non_root = if root_before == n(1) {
            n(2)
        } else if root_before == n(2) {
            n(1)
        } else {
            n(2)
        };
        r.cleanup_node(non_root);
        let other = (1..=3u64)
            .map(n)
            .find(|x| *x != root_before && *x != non_root)
            .expect("third node");
        assert_eq!(r.find(root_before), r.find(other));
    }

    #[test]
    fn lock_for_returns_same_box_for_same_component() {
        let mut r = SubgraphRegistry::new();
        r.ensure_registered(n(1));
        r.ensure_registered(n(2));
        r.union_nodes(n(1), n(2));
        let (_sid_a, box_a) = r.lock_for(n(1)).expect("registered");
        let (_sid_b, box_b) = r.lock_for(n(2)).expect("registered");
        assert!(Arc::ptr_eq(&box_a, &box_b));
    }

    #[test]
    fn lookups_on_unregistered_node_return_none_or_false() {
        let mut r = SubgraphRegistry::new();
        r.ensure_registered(n(1));
        assert!(r.lock_for(n(2)).is_none());
        assert!(r.partition_of(n(2)).is_none());
        let (_sid, box_1) = r.lock_for(n(1)).expect("registered");
        assert!(!r.lock_for_validate(n(2), &box_1));
        assert!(r.lock_for_validate(n(1), &box_1));
    }

    #[test]
    fn lock_for_validate_detects_redirect_after_union() {
        let mut r = SubgraphRegistry::new();
        for i in 1..=4 {
            r.ensure_registered(n(i));
        }
        r.union_nodes(n(2), n(3));
        r.union_nodes(n(2), n(4));
        let n2_root = r.find(n(2));
        let (_sid_before, box_1_alone) = r.lock_for(n(1)).expect("registered");
        let n1_root_before = r.find(n(1));
        assert_eq!(n1_root_before, n(1), "n(1) is still its own root");
        r.union_nodes(n(1), n(2));
        let n1_root_after = r.find(n(1));
        assert_eq!(
            n1_root_after, n2_root,
            "union-by-rank promoted n(2)'s tree; n(1)'s root displaced"
        );
        let still_valid = r.lock_for_validate(n(1), &box_1_alone);
        assert!(
            !still_valid,
            "lock_for_validate must detect the box-redirect after union promotes a different root"
        );
        let (_sid_after, box_after) = r.lock_for(n(1)).expect("registered");
        assert!(
            !Arc::ptr_eq(&box_1_alone, &box_after),
            "stale box and active box must be distinct Arc identities"
        );
    }

    #[test]
    fn split_partition_keeps_original_box_on_keep_side() {
        let mut r = SubgraphRegistry::new();
        for i in 1..=3 {
            r.ensure_registered(n(i));
        }
        r.union_nodes(n(1), n(2));
        r.union_nodes(n(2), n(3));
        assert_eq!(r.component_count(), 1);
        let (_sid, original_box) = r.lock_for(n(1)).expect("registered");
        let epoch_before = r.epoch();

        // Edge (2, 3) was removed: n(3) is now disconnected from {1, 2}.
        r.split_partition(&[n(1), n(2), n(3)], &[n(1), n(2)], &[(n(1), n(2))]);

        assert_eq!(r.component_count(), 2);
        assert_eq!(r.find(n(1)), r.find(n(2)));
        assert_ne!(r.find(n(1)), r.find(n(3)));
        assert_ne!(r.epoch(), epoch_before);

        let (_, keep_box) = r.lock_for(n(2)).expect("registered");
        let (_, orphan_box) = r.lock_for(n(3)).expect("registered");
        assert!(Arc::ptr_eq(&keep_box, &original_box));
        assert!(!Arc::ptr_eq(&orphan_box, &original_box));
        assert!(r.lock_for_validate(n(1), &original_box));
        assert!(!r.lock_for_validate(n(3), &original_box));
    }

    #[test]
    fn all_partitions_one_per_component_sorted() {
        let mut r = SubgraphRegistry::new();
        for i in 1..=4 {
            r.ensure_registered(n(i));
        }
        r.union_nodes(n(3), n(4));
        let parts = r.all_partitions();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts.len(), r.component_count());
        assert!(parts.windows(2).all(|w| w[0].0 < w[1].0));
    }

    #[test]
    fn partition_of_distinct_singletons_differ() {
        let mut r = SubgraphRegistry::new();
        r.ensure_registered(n(1));
        r.ensure_registered(n(2));
        let p1 = r.partition_of(n(1)).expect("registered");
        let p2 = r.partition_of(n(2)).expect("registered");
        assert_ne!(p1, p2);
    }

    #[test]
    fn partition_of_unioned_nodes_match() {
        let mut r = SubgraphRegistry::new();
        r.ensure_registered(n(1));
        r.ensure_registered(n(2));
        r.union_nodes(n(1), n(2));
        let p1 = r.partition_of(n(1)).expect("registered");
        let p2 = r.partition_of(n(2)).expect("registered");
        assert_eq!(p1, p2);
    }
}