use leo_span::Symbol;
use indexmap::{IndexMap, IndexSet};
use std::{fmt::Debug, hash::Hash};
/// A [`DiGraph`] whose nodes are `Symbol`s; per its name, intended for relationships between structs.
pub type StructGraph = DiGraph<Symbol>;
/// A [`DiGraph`] whose nodes are `Symbol`s; per its name, intended for caller/callee relationships.
pub type CallGraph = DiGraph<Symbol>;
/// A [`DiGraph`] whose nodes are `Symbol`s; per its name, intended for import relationships.
pub type ImportGraph = DiGraph<Symbol>;
/// Requirements for a type that can be used as a [`DiGraph`] node.
///
/// Nodes are copied freely, compared for equality, and stored in hash-based sets and maps.
/// They must also be `Debug` so that graph errors carrying nodes can be debug-printed.
/// (`PartialEq` is implied by the `Eq` supertrait.)
pub trait Node: 'static + Copy + Debug + Eq + Hash {}

impl Node for Symbol {}
#[derive(Debug)]
pub enum DiGraphError<N: Node> {
CycleDetected(Vec<N>),
}
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct DiGraph<N: Node> {
nodes: IndexSet<N>,
edges: IndexMap<N, IndexSet<N>>,
}
impl<N: Node> DiGraph<N> {
    /// Creates a graph containing the given nodes and no edges.
    pub fn new(nodes: IndexSet<N>) -> Self {
        Self { nodes, edges: IndexMap::new() }
    }

    /// Adds `node` to the graph. Adding a node that is already present has no effect.
    pub fn add_node(&mut self, node: N) {
        self.nodes.insert(node);
    }

    /// Adds the directed edge `from -> to`.
    ///
    /// Either endpoint is added as a node if it is not already present. Adding an existing
    /// edge has no effect.
    pub fn add_edge(&mut self, from: N, to: N) {
        self.nodes.insert(from);
        self.nodes.insert(to);
        self.edges.entry(from).or_default().insert(to);
    }

    /// Returns `true` if `node` is in the graph.
    pub fn contains_node(&self, node: N) -> bool {
        self.nodes.contains(&node)
    }

    /// Returns every node of the graph in depth-first post-order.
    ///
    /// Each node appears after all nodes reachable from it. Traversal starts from nodes in
    /// insertion order and visits children in the order their edges were added, so the result
    /// is deterministic.
    ///
    /// # Errors
    ///
    /// Returns [`DiGraphError::CycleDetected`] if a cycle is reachable. Its payload is the
    /// cycle as a path that starts and ends at the same node, e.g. `[a, b, c, a]`.
    pub fn post_order(&self) -> Result<IndexSet<N>, DiGraphError<N>> {
        let mut finished: IndexSet<N> = IndexSet::with_capacity(self.nodes.len());
        for &node in self.nodes.iter() {
            if finished.contains(&node) {
                continue;
            }
            // The nodes on the current DFS path, starting at `node`.
            let mut discovered: IndexSet<N> = IndexSet::new();
            if let Some(cycle_start) = self.contains_cycle_from(node, &mut discovered, &mut finished) {
                // `discovered` still holds the DFS path whose last node has an edge back to
                // `cycle_start`. Unwind it until `cycle_start` to recover the cycle.
                let mut path = vec![cycle_start];
                while let Some(next) = discovered.pop() {
                    path.push(next);
                    if next == cycle_start {
                        break;
                    }
                }
                path.reverse();
                return Err(DiGraphError::CycleDetected(path));
            }
        }
        Ok(finished)
    }

    /// Removes every node not contained in `nodes`, along with all edges into or out of it.
    ///
    /// Nodes in `nodes` that are not in the graph are not added. A retained node whose
    /// outgoing edges are all removed keeps an empty adjacency entry.
    pub fn retain_nodes(&mut self, nodes: &IndexSet<N>) {
        self.nodes.retain(|node| nodes.contains(node));
        self.edges.retain(|node, _| nodes.contains(node));
        for children in self.edges.values_mut() {
            children.retain(|child| nodes.contains(child));
        }
    }

    /// Runs a depth-first search from `node`, skipping nodes already in `finished`.
    ///
    /// Every node that is fully explored without hitting a cycle is appended to `finished`,
    /// which yields a post-order.
    ///
    /// If an edge leads to a node `n` on the current DFS path, returns `Some(n)` immediately.
    /// In that case `discovered` is left holding the DFS path, ending at the node whose edge
    /// closed the cycle. Otherwise returns `None`, after removing every node this call added
    /// to `discovered`.
    ///
    /// The search uses an explicit stack rather than recursion, so long dependency chains
    /// cannot overflow the call stack. The visiting order is the same as for the recursive
    /// formulation.
    fn contains_cycle_from(&self, node: N, discovered: &mut IndexSet<N>, finished: &mut IndexSet<N>) -> Option<N> {
        // Each frame holds a node on the current DFS path, its outgoing edges (if any), and
        // the index of the next child to examine. Frames and `discovered` grow and shrink together.
        let mut stack = vec![(node, self.edges.get(&node), 0usize)];
        discovered.insert(node);
        while let Some(frame) = stack.last_mut() {
            let (current, children, index) = *frame;
            match children.and_then(|children| children.get_index(index)).copied() {
                Some(child) => {
                    // Advance past `child` so this frame resumes at the next sibling later.
                    frame.2 += 1;
                    if discovered.contains(&child) {
                        return Some(child);
                    }
                    if !finished.contains(&child) {
                        discovered.insert(child);
                        stack.push((child, self.edges.get(&child), 0));
                    }
                }
                None => {
                    // All children of `current` are finished, so `current` is finished too.
                    stack.pop();
                    discovered.pop();
                    finished.insert(current);
                }
            }
        }
        None
    }
}
#[cfg(test)]
mod test {
    use super::*;

    impl Node for u32 {}

    /// Builds a graph containing exactly the given edges, added in slice order.
    fn graph_with_edges(edges: &[(u32, u32)]) -> DiGraph<u32> {
        let mut graph = DiGraph::new(IndexSet::new());
        for &(from, to) in edges {
            graph.add_edge(from, to);
        }
        graph
    }

    /// Asserts that `graph` is acyclic and that its post-order equals `expected`.
    fn check_post_order<N: Node>(graph: &DiGraph<N>, expected: &[N]) {
        let result = graph.post_order();
        assert!(result.is_ok());
        let order = result.unwrap().into_iter().collect::<Vec<N>>();
        assert_eq!(order, expected);
    }

    #[test]
    fn test_post_order() {
        let diamond = graph_with_edges(&[(1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]);
        check_post_order(&diamond, &[5, 4, 2, 3, 1]);

        let tree = graph_with_edges(&[(6, 2), (2, 1), (2, 4), (4, 3), (4, 5), (6, 7), (7, 9), (9, 8)]);
        check_post_order(&tree, &[1, 3, 5, 4, 2, 8, 9, 7, 6]);
    }

    #[test]
    fn test_cycle() {
        let graph = graph_with_edges(&[(1, 2), (2, 3), (2, 4), (4, 1)]);
        let result = graph.post_order();
        assert!(result.is_err());
        let DiGraphError::CycleDetected(cycle) = result.unwrap_err();
        assert_eq!(cycle, vec![1u32, 2, 4, 1]);
    }

    #[test]
    fn test_unconnected_graph() {
        let isolated = DiGraph::<u32>::new(IndexSet::from([1, 2, 3, 4, 5]));
        check_post_order(&isolated, &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_retain_nodes() {
        let mut graph =
            graph_with_edges(&[(1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (4, 5)]);
        graph.retain_nodes(&IndexSet::from([1, 2, 3]));

        let mut expected = graph_with_edges(&[(1, 2), (1, 3), (2, 3)]);
        // Node 3 lost its only outgoing edge (to 4) but keeps an empty adjacency entry.
        expected.edges.insert(3, IndexSet::new());
        assert_eq!(graph, expected);
    }
}