use crate::Location;
use leo_span::Symbol;
use indexmap::{IndexMap, IndexSet};
use std::{fmt::Debug, hash::Hash, rc::Rc};
/// A [`DiGraph`] whose nodes are [`Location`]s.
// NOTE(review): by its name this presumably tracks dependencies between composite types — confirm against callers.
pub type CompositeGraph = DiGraph<Location>;
/// A [`DiGraph`] whose nodes are [`Location`]s.
// NOTE(review): by its name this presumably records which function calls which — confirm against callers.
pub type CallGraph = DiGraph<Location>;
/// A [`DiGraph`] whose nodes are [`Symbol`]s.
// NOTE(review): by its name this presumably records imports between programs/modules — confirm against callers.
pub type ImportGraph = DiGraph<Symbol>;
/// Marker trait bundling every capability a type needs in order to be used as a graph node:
/// it must own its data (`'static`), be cloneable, comparable for equality, hashable and
/// printable with `Debug`.
///
/// The trait has no methods of its own; it only serves as shorthand for the bound list.
pub trait GraphNode: 'static + Clone + Debug + Eq + PartialEq + Hash {}

/// Blanket implementation: any type that satisfies the bounds is automatically a `GraphNode`,
/// so the trait never has to be implemented by hand.
impl<T> GraphNode for T where T: 'static + Clone + Debug + Eq + PartialEq + Hash {}
/// Errors that can be reported by operations on a directed graph.
#[derive(Debug)]
pub enum DiGraphError<N: GraphNode> {
/// The graph contains a cycle. The payload holds the nodes on that cycle, as collected by the
/// traversal that detected it.
CycleDetected(Vec<N>),
}
/// A directed graph over nodes of type `N`.
///
/// Nodes are stored behind `Rc`, so the node set and the edge map hold reference-counted
/// handles instead of separate copies of each node value.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DiGraph<N: GraphNode> {
/// Every node known to the graph, including nodes without any edges.
nodes: IndexSet<Rc<N>>,
/// Directed edges: each key is a source node, mapped to the set of nodes it points to.
edges: IndexMap<Rc<N>, IndexSet<Rc<N>>>,
}
/// The default graph is empty: it has no nodes and no edges.
impl<N: GraphNode> Default for DiGraph<N> {
    fn default() -> Self {
        Self { nodes: Default::default(), edges: Default::default() }
    }
}
impl<N: GraphNode> DiGraph<N> {
/// Builds a graph that contains the given nodes and no edges.
///
/// Each node is moved into its own `Rc`; the iteration order of `nodes` is preserved.
pub fn new(nodes: IndexSet<N>) -> Self {
    Self { edges: IndexMap::new(), nodes: nodes.into_iter().map(Rc::new).collect() }
}
/// Adds `node` to the graph without connecting it to anything.
///
/// If an equal node is already present, the set keeps its existing entry.
pub fn add_node(&mut self, node: N) {
    let shared = Rc::new(node);
    self.nodes.insert(shared);
}
/// Returns an iterator over references to all nodes of the graph, in the order kept by the
/// node set.
pub fn nodes(&self) -> impl Iterator<Item = &N> {
    self.nodes.iter().map(|shared| &**shared)
}
/// Adds a directed edge `from -> to`.
///
/// Either endpoint that is not yet part of the graph is added as a node first. Adding an edge
/// that already exists leaves the graph unchanged.
pub fn add_edge(&mut self, from: N, to: N) {
    // Resolve both endpoints to the handles stored in the node set (source first, then target).
    let source = self.get_or_insert(from);
    let target = self.get_or_insert(to);
    let targets = self.edges.entry(source).or_default();
    targets.insert(target);
}
pub fn remove_node(&mut self, node: &N) -> bool {
if let Some(rc_node) = self.nodes.shift_take(&Rc::new(node.clone())) {
self.edges.shift_remove(&rc_node);
for targets in self.edges.values_mut() {
targets.shift_remove(&rc_node);
}
true
} else {
false
}
}
/// Returns an iterator over the direct successors of `node`, i.e. the targets of its outgoing
/// edges.
///
/// The iterator is empty when `node` has no recorded outgoing edges.
pub fn neighbors(&self, node: &N) -> impl Iterator<Item = &N> {
    let targets = self.edges.get(node);
    targets.into_iter().flatten().map(|shared| &**shared)
}
/// Returns every node reachable from `node` by following one or more edges.
///
/// The starting node itself is only part of the result if it can be reached again through a
/// cycle. The traversal is depth-first, driven by an explicit stack.
pub fn transitive_closure(&self, node: &N) -> IndexSet<N> {
    let mut reachable = IndexSet::new();
    // Nodes that still have to be examined; seeded with the direct successors.
    let mut pending: Vec<&N> = self.neighbors(node).collect();

    while let Some(current) = pending.pop() {
        // Skip nodes that have already been recorded, otherwise cycles would never terminate.
        if reachable.contains(current) {
            continue;
        }
        reachable.insert(current.clone());
        pending.extend(self.neighbors(current));
    }

    reachable
}
pub fn contains_node(&self, node: N) -> bool {
self.nodes.contains(&Rc::new(node))
}
/// Returns the nodes of the graph in post-order, considering every node as a traversal root.
///
/// # Errors
///
/// Returns [`DiGraphError::CycleDetected`] if the graph contains a cycle.
pub fn post_order(&self) -> Result<IndexSet<N>, DiGraphError<N>> {
    let accept_all = |_: &N| true;
    self.post_order_with_filter(accept_all)
}
/// Returns a post-order of the nodes reachable from the roots selected by `filter`.
///
/// `filter` only decides which nodes are used as starting points of the depth-first search.
/// Nodes reached from such a root are part of the result even if `filter` rejects them.
///
/// # Errors
///
/// Returns [`DiGraphError::CycleDetected`] if a cycle is reachable from one of the roots. The
/// reported path starts and ends with the same node.
pub fn post_order_with_filter<F>(&self, filter: F) -> Result<IndexSet<N>, DiGraphError<N>>
where
    F: Fn(&N) -> bool,
{
    // Nodes whose subtree has been fully explored, in the order they were completed.
    let mut finished = IndexSet::with_capacity(self.nodes.len());

    for root in self.nodes.iter() {
        // Only selected roots that were not already covered by an earlier search are explored.
        if !filter(root.as_ref()) || finished.contains(root) {
            continue;
        }

        // The nodes on the current search path; left intact by the search when a cycle is hit.
        let mut discovered = IndexSet::new();
        let Some(cycle_node) = self.contains_cycle_from(root, &mut discovered, &mut finished) else {
            continue;
        };

        // Walk the search path backwards until the node that closes the cycle shows up again.
        let mut path = vec![(*cycle_node).clone()];
        while let Some(ancestor) = discovered.pop() {
            let closes_cycle = Rc::ptr_eq(&ancestor, &cycle_node);
            path.push((*ancestor).clone());
            if closes_cycle {
                break;
            }
        }
        path.reverse();
        return Err(DiGraphError::CycleDetected(path));
    }

    Ok(finished.into_iter().map(|shared| (*shared).clone()).collect())
}
pub fn retain_nodes(&mut self, keep: &IndexSet<N>) {
let keep: IndexSet<_> = keep.iter().map(|n| Rc::new(n.clone())).collect();
self.nodes.retain(|n| keep.contains(n));
self.edges.retain(|n, _| keep.contains(n));
for targets in self.edges.values_mut() {
targets.retain(|t| keep.contains(t));
}
}
fn contains_cycle_from(
&self,
node: &Rc<N>,
discovered: &mut IndexSet<Rc<N>>,
finished: &mut IndexSet<Rc<N>>,
) -> Option<Rc<N>> {
discovered.insert(node.clone());
if let Some(children) = self.edges.get(node) {
for child in children {
if discovered.contains(child) {
return Some(child.clone());
}
if !finished.contains(child)
&& let Some(cycle_node) = self.contains_cycle_from(child, discovered, finished)
{
return Some(cycle_node);
}
}
}
discovered.pop();
finished.insert(node.clone());
None
}
/// Returns the shared handle for `node`, adding the node to the graph first if it is missing.
///
/// Reusing the handle already stored in the node set means all references to one node share a
/// single allocation.
fn get_or_insert(&mut self, node: N) -> Rc<N> {
    match self.nodes.get(&node).cloned() {
        Some(shared) => shared,
        None => {
            let shared = Rc::new(node);
            self.nodes.insert(Rc::clone(&shared));
            shared
        }
    }
}
}
#[cfg(test)]
mod test {
use super::*;
/// Asserts that `graph` has a post-order and that it equals `expected`.
fn check_post_order<N: GraphNode>(graph: &DiGraph<N>, expected: &[N]) {
    let result = graph.post_order();
    assert!(result.is_ok());
    // The result is known to be `Ok` here, so flattening yields exactly the ordered nodes.
    let order: Vec<N> = result.into_iter().flatten().collect();
    assert_eq!(order, expected);
}
/// Post-order of two acyclic graphs built purely from edges.
#[test]
fn test_post_order() {
    // A diamond (1 -> {2, 3} -> 4) followed by a tail (4 -> 5).
    let mut diamond = DiGraph::<u32>::new(IndexSet::new());
    for (from, to) in [(1, 2), (1, 3), (2, 4), (3, 4), (4, 5)] {
        diamond.add_edge(from, to);
    }
    check_post_order(&diamond, &[5, 4, 2, 3, 1]);

    // A tree rooted at 6.
    let mut tree = DiGraph::<u32>::new(IndexSet::new());
    for (from, to) in [(6, 2), (2, 1), (2, 4), (4, 3), (4, 5), (6, 7), (7, 9), (9, 8)] {
        tree.add_edge(from, to);
    }
    check_post_order(&tree, &[1, 3, 5, 4, 2, 8, 9, 7, 6]);
}
/// A graph with the cycle 1 -> 2 -> 4 -> 1 must be rejected, and the cycle must be reported.
#[test]
fn test_cycle() {
    let mut graph = DiGraph::<u32>::new(IndexSet::new());
    for (from, to) in [(1, 2), (2, 3), (2, 4), (4, 1)] {
        graph.add_edge(from, to);
    }

    let result = graph.post_order();
    assert!(result.is_err());

    let DiGraphError::CycleDetected(cycle) = result.unwrap_err();
    assert_eq!(cycle, vec![1u32, 2, 4, 1]);
}
/// Transitive closures on a cyclic and on an acyclic graph.
#[test]
fn test_transitive_closure() {
    // Cyclic graph: 1 -> 2 -> 4 -> 1, with the side branch 2 -> 3 -> 5.
    let mut cyclic = DiGraph::<u32>::new(IndexSet::new());
    for (from, to) in [(1, 2), (2, 3), (2, 4), (4, 1), (3, 5)] {
        cyclic.add_edge(from, to);
    }
    let cases: [(u32, Vec<u32>); 3] = [(2, vec![4, 1, 2, 3, 5]), (3, vec![5]), (5, vec![])];
    for (start, reachable) in cases {
        let expected: IndexSet<u32> = reachable.into_iter().collect();
        assert_eq!(cyclic.transitive_closure(&start), expected);
    }

    // Acyclic graph.
    let mut acyclic = DiGraph::<u32>::new(IndexSet::new());
    for (from, to) in [(1, 2), (1, 3), (2, 5), (3, 5), (3, 4)] {
        acyclic.add_edge(from, to);
    }
    let cases: [(u32, Vec<u32>); 4] = [(1, vec![2, 5, 3, 4]), (2, vec![5]), (3, vec![5, 4]), (4, vec![])];
    for (start, reachable) in cases {
        let expected: IndexSet<u32> = reachable.into_iter().collect();
        assert_eq!(acyclic.transitive_closure(&start), expected);
    }
}
/// Without any edges, the post-order is simply the order in which the nodes were supplied.
#[test]
fn test_unconnected_graph() {
    let isolated: IndexSet<u32> = (1..=5).collect();
    let graph = DiGraph::new(isolated);
    check_post_order(&graph, &[1, 2, 3, 4, 5]);
}
/// Retaining the nodes {1, 2, 3} must drop the nodes 4 and 5 and every edge touching them.
#[test]
fn test_retain_nodes() {
    let mut graph = DiGraph::<u32>::new(IndexSet::new());
    for (from, to) in [(1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (4, 5)] {
        graph.add_edge(from, to);
    }

    let nodes: IndexSet<u32> = [1, 2, 3].into_iter().collect();
    graph.retain_nodes(&nodes);

    let mut expected = DiGraph::<u32>::new(IndexSet::new());
    for (from, to) in [(1, 2), (1, 3), (2, 3)] {
        expected.add_edge(from, to);
    }
    // Node 3 lost its only target (4), but keeps an empty entry in the edge map.
    expected.edges.insert(Rc::new(3), IndexSet::new());

    assert_eq!(graph, expected);
}
}