use std::collections::{HashMap, VecDeque};
use rustc_hash::FxHashSet;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use crate::wtype::WInstance;
/// Set of nodes reachable from a root by following parent -> child arcs
/// forward, kept up to date as arcs are added and removed.
///
/// Built with `from_instance`, then maintained through `insert_edge` /
/// `delete_edge`, which return the nodes whose reachability changed.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ReachabilityIndex {
    /// Nodes currently reachable from `root` (the root itself included).
    reachable: FxHashSet<u32>,
    /// Outgoing arcs: parent id -> child ids, one entry per arc (repeated
    /// arcs are pushed again, not deduplicated).
    children: HashMap<u32, SmallVec<u32, 4>>,
    /// One recorded parent per child; a node with several incoming arcs has
    /// only one of them stored here.
    parents: HashMap<u32, u32>,
    /// Traversal root. `from_instance` always sets this to `Some`.
    root: Option<u32>,
}
impl ReachabilityIndex {
#[must_use]
pub fn from_instance(inst: &WInstance) -> Self {
let mut reachable = FxHashSet::default();
let mut children: HashMap<u32, SmallVec<u32, 4>> = HashMap::new();
let mut parents: HashMap<u32, u32> = HashMap::new();
for &(parent, child, _) in &inst.arcs {
children.entry(parent).or_default().push(child);
parents.insert(child, parent);
}
let mut queue = VecDeque::new();
queue.push_back(inst.root);
reachable.insert(inst.root);
while let Some(current) = queue.pop_front() {
if let Some(kids) = children.get(¤t) {
for &child in kids {
if reachable.insert(child) {
queue.push_back(child);
}
}
}
}
Self {
reachable,
children,
parents,
root: Some(inst.root),
}
}
pub fn insert_edge(&mut self, parent: u32, child: u32) -> Vec<u32> {
self.children.entry(parent).or_default().push(child);
self.parents.insert(child, parent);
if !self.reachable.contains(&parent) || self.reachable.contains(&child) {
return Vec::new();
}
let mut newly_reachable = Vec::new();
let mut queue = VecDeque::new();
queue.push_back(child);
while let Some(current) = queue.pop_front() {
if self.reachable.insert(current) {
newly_reachable.push(current);
if let Some(kids) = self.children.get(¤t) {
for &kid in kids {
if !self.reachable.contains(&kid) {
queue.push_back(kid);
}
}
}
}
}
newly_reachable
}
pub fn delete_edge(&mut self, parent: u32, child: u32) -> Vec<u32> {
if let Some(kids) = self.children.get_mut(&parent) {
if let Some(pos) = kids.iter().position(|&k| k == child) {
kids.remove(pos);
}
}
if self.parents.get(&child).copied() == Some(parent) {
self.parents.remove(&child);
}
if !self.reachable.contains(&child) {
return Vec::new();
}
if self.has_path_to_root(child) {
return Vec::new();
}
let mut newly_unreachable = Vec::new();
let mut queue = VecDeque::new();
queue.push_back(child);
while let Some(current) = queue.pop_front() {
if self.reachable.remove(¤t) {
newly_unreachable.push(current);
if let Some(kids) = self.children.get(¤t) {
for &kid in kids {
if self.reachable.contains(&kid) && !self.has_path_to_root(kid) {
queue.push_back(kid);
}
}
}
}
}
newly_unreachable
}
#[must_use]
pub fn is_reachable(&self, node: u32) -> bool {
self.reachable.contains(&node)
}
#[must_use]
pub const fn root(&self) -> Option<u32> {
self.root
}
#[must_use]
pub fn parent_of(&self, node: u32) -> Option<u32> {
self.parents.get(&node).copied()
}
#[must_use]
pub fn children_of(&self, node: u32) -> &[u32] {
self.children
.get(&node)
.map_or(&[], smallvec::SmallVec::as_slice)
}
fn has_path_to_root(&self, node: u32) -> bool {
let Some(root) = self.root else {
return false;
};
if node == root {
return true;
}
let mut visited = FxHashSet::default();
let mut queue = VecDeque::new();
queue.push_back(node);
visited.insert(node);
while let Some(current) = queue.pop_front() {
if let Some(&p) = self.parents.get(¤t) {
if p == root {
return true;
}
if self.reachable.contains(&p) && visited.insert(p) {
queue.push_back(p);
}
}
for (&candidate, kids) in &self.children {
if kids.contains(¤t) && candidate != node && visited.insert(candidate) {
if candidate == root {
return true;
}
if self.reachable.contains(&candidate) {
queue.push_back(candidate);
}
}
}
}
false
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
#[allow(clippy::expect_used)]
mod tests {
    use std::collections::HashMap;

    use panproto_gat::Name;
    use panproto_schema::Edge;

    use crate::metadata::Node;
    use crate::wtype::WInstance;

    use super::*;

    /// Node fixture with the given id and anchor.
    fn test_node(id: u32, anchor: &str) -> Node {
        Node::new(id, anchor)
    }

    /// The single arc label shared by every fixture arc.
    fn prop_edge() -> Edge {
        Edge {
            src: Name::from("a"),
            tgt: Name::from("b"),
            kind: Name::from("prop"),
            name: None,
        }
    }

    /// Instance rooted at `root` (anchored "root"), whose remaining nodes are
    /// the endpoints of `arcs` (anchored "v"). Every arc carries `prop_edge()`.
    fn make_instance(root: u32, arcs: &[(u32, u32)]) -> WInstance {
        let mut nodes = HashMap::new();
        nodes.insert(root, test_node(root, "root"));
        for id in arcs.iter().flat_map(|&(p, c)| [p, c]) {
            nodes.entry(id).or_insert_with(|| test_node(id, "v"));
        }
        let label = prop_edge();
        let labelled: Vec<(u32, u32, Edge)> = arcs
            .iter()
            .map(|&(p, c)| (p, c, label.clone()))
            .collect();
        WInstance::new(nodes, labelled, Vec::new(), root, Name::from("root"))
    }

    #[test]
    fn from_instance_marks_all_reachable() {
        let idx = ReachabilityIndex::from_instance(&make_instance(0, &[(0, 1), (1, 2)]));
        for node in [0, 1, 2] {
            assert!(idx.is_reachable(node));
        }
        assert_eq!(idx.root(), Some(0));
    }

    #[test]
    fn delete_edge_marks_subtree_unreachable() {
        let mut idx = ReachabilityIndex::from_instance(&make_instance(0, &[(0, 1), (1, 2)]));
        let lost = idx.delete_edge(0, 1);
        for node in [1, 2] {
            assert!(lost.contains(&node));
            assert!(!idx.is_reachable(node));
        }
        assert!(idx.is_reachable(0));
    }

    #[test]
    fn insert_edge_makes_unreachable_reachable() {
        let mut idx = ReachabilityIndex::from_instance(&make_instance(0, &[(0, 1), (1, 2)]));
        idx.delete_edge(0, 1);
        for node in [1, 2] {
            assert!(!idx.is_reachable(node));
        }
        let regained = idx.insert_edge(0, 1);
        for node in [1, 2] {
            assert!(regained.contains(&node));
            assert!(idx.is_reachable(node));
        }
    }

    #[test]
    fn diamond_graph_one_path_removal() {
        let nodes: HashMap<_, _> = (0..=3).map(|id| (id, test_node(id, "v"))).collect();
        let label = prop_edge();
        let arcs: Vec<(u32, u32, Edge)> = [(0, 1), (0, 2), (1, 3), (2, 3)]
            .iter()
            .map(|&(p, c)| (p, c, label.clone()))
            .collect();
        let inst = WInstance::new(nodes, arcs, Vec::new(), 0, Name::from("root"));
        let mut idx = ReachabilityIndex::from_instance(&inst);

        // Dropping one side of the diamond must leave node 3 reachable via the other.
        let lost = idx.delete_edge(1, 3);
        assert!(idx.is_reachable(3));
        assert!(!lost.contains(&3));
        assert!(idx.is_reachable(1));
    }

    #[test]
    fn empty_instance() {
        let idx = ReachabilityIndex::from_instance(&make_instance(0, &[]));
        assert!(idx.is_reachable(0));
        assert!(!idx.is_reachable(1));
        assert_eq!(idx.root(), Some(0));
    }

    #[test]
    fn deep_chain() {
        let chain: Vec<(u32, u32)> = (0..9).map(|i| (i, i + 1)).collect();
        let mut idx = ReachabilityIndex::from_instance(&make_instance(0, &chain));
        assert!((0..=9).all(|i| idx.is_reachable(i)));

        // Cutting the middle link detaches the whole tail of the chain.
        let lost = idx.delete_edge(4, 5);
        assert!((0..=4).all(|i| idx.is_reachable(i)));
        for i in 5..=9 {
            assert!(!idx.is_reachable(i));
            assert!(lost.contains(&i));
        }
    }
}