use std::collections::{HashMap, HashSet, VecDeque};
use crate::compiled::CompiledGraph;
use crate::state::GraphState;
use crate::viz::extract_edges;
/// Structural summary of a compiled graph, as returned by `CompiledGraph::analyze`.
///
/// Reachability, cycle detection and depth are all measured from the graph's
/// start node; edges into the `"__END__"` marker are never followed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GraphAnalysis {
/// Number of declared nodes, whether reachable or not.
pub node_count: usize,
/// Number of edges returned by `extract_edges`, counted without any filtering.
pub edge_count: usize,
/// Whether a cycle can be reached from the start node. Cycles made up only of
/// unreachable nodes are not reported. Always `false` when no start node is set.
pub has_cycle: bool,
/// Number of edges on the longest path leaving the start node. Reported as `0`
/// when `has_cycle` is `true` (the longest path is then unbounded) or when no
/// start node is set.
pub depth: usize,
/// Names of nodes that cannot be reached from the start node, sorted ascending.
/// Contains every node when no start node is set.
pub unreachable: Vec<String>,
}
impl<S: GraphState> CompiledGraph<S> {
    /// Computes a structural summary of this compiled graph.
    ///
    /// The successor map contains every declared node (sinks get an empty
    /// list) plus the edges returned by `extract_edges`. Reachability, cycle
    /// detection and depth are then measured from the start node, never
    /// following edges into `"__END__"`.
    ///
    /// When no start node is set, no node counts as reachable, `has_cycle` is
    /// `false` and `depth` is `0`. When a cycle is reachable, `depth` is also
    /// reported as `0` because the longest path is unbounded.
    pub fn analyze(&self) -> GraphAnalysis {
        // Seed an empty successor list for every declared node so that nodes
        // without outgoing edges still appear as keys.
        let mut successors: HashMap<&str, Vec<&str>> = self
            .graph
            .nodes
            .keys()
            .map(|name| (name.as_str(), Vec::new()))
            .collect();
        let edges = extract_edges(self);
        for edge in &edges {
            successors
                .entry(edge.from.as_str())
                .or_default()
                .push(edge.to.as_str());
        }

        let start = self.graph.start.as_deref();

        // Breadth-first flood fill from the start node. `"__END__"` is a
        // terminal marker rather than a node, so it is never enqueued.
        let mut reachable: HashSet<&str> = HashSet::new();
        if let Some(root) = start {
            reachable.insert(root);
            let mut frontier = VecDeque::from([root]);
            while let Some(current) = frontier.pop_front() {
                for &next in successors.get(current).into_iter().flatten() {
                    if next != "__END__" && reachable.insert(next) {
                        frontier.push_back(next);
                    }
                }
            }
        }

        let mut unreachable: Vec<String> = self
            .graph
            .nodes
            .keys()
            .filter(|name| !reachable.contains(name.as_str()))
            .cloned()
            .collect();
        unreachable.sort_unstable();

        let (has_cycle, depth) = match start {
            None => (false, 0),
            Some(root) => {
                if detect_cycle(root, &successors, &reachable) {
                    // A reachable cycle makes the longest path unbounded.
                    (true, 0)
                } else {
                    (false, longest_path(root, &successors, &reachable))
                }
            }
        };

        GraphAnalysis {
            node_count: self.graph.nodes.len(),
            edge_count: edges.len(),
            has_cycle,
            depth,
            unreachable,
        }
    }
}
/// Reports whether a cycle can be reached from `start`.
///
/// Runs a depth-first search using three-state marking. A node is `OnStack`
/// while its descendants are being explored and `Done` once they are finished.
/// Meeting an `OnStack` node again means the current path loops back on itself.
/// Edges into `"__END__"` and edges to nodes outside `reachable` are skipped.
///
/// The search keeps an explicit stack of `(node, next-successor index)` frames
/// instead of recursing, so long chains do not consume call-stack depth.
fn detect_cycle<'a>(
    start: &'a str,
    adj: &HashMap<&'a str, Vec<&'a str>>,
    reachable: &HashSet<&'a str>,
) -> bool {
    enum Mark {
        Unvisited,
        OnStack,
        Done,
    }

    let mut marks: HashMap<&str, Mark> = reachable
        .iter()
        .map(|&node| (node, Mark::Unvisited))
        .collect();

    let mut stack: Vec<(&'a str, usize)> = vec![(start, 0)];
    marks.insert(start, Mark::OnStack);

    while let Some(frame) = stack.last_mut() {
        let (node, cursor) = *frame;
        let successors: &[&'a str] = match adj.get(node) {
            Some(list) => list.as_slice(),
            None => &[],
        };
        match successors.get(cursor) {
            None => {
                // Every successor has been explored, so the node leaves the path.
                stack.pop();
                marks.insert(node, Mark::Done);
            }
            Some(&next) => {
                frame.1 += 1;
                if next == "__END__" || !reachable.contains(next) {
                    continue;
                }
                match marks.get(next) {
                    Some(Mark::OnStack) => return true,
                    Some(Mark::Unvisited) => {
                        marks.insert(next, Mark::OnStack);
                        stack.push((next, 0));
                    }
                    _ => {}
                }
            }
        }
    }
    false
}
/// Returns the number of edges on the longest path that starts at `start`.
///
/// Only successors contained in `reachable` are followed, and edges into
/// `"__END__"` are ignored. Each node's height is memoised, so it is computed
/// at most once. A node with no qualifying successors has height `0`.
///
/// Precondition: the subgraph explored from `start` must be acyclic (`analyze`
/// only calls this after `detect_cycle` returned `false`). On a cycle the
/// recursion never reaches a memoised node and does not terminate.
fn longest_path<'a>(
    start: &'a str,
    adj: &HashMap<&'a str, Vec<&'a str>>,
    reachable: &HashSet<&'a str>,
) -> usize {
    fn height<'a>(
        node: &'a str,
        adj: &HashMap<&'a str, Vec<&'a str>>,
        reachable: &HashSet<&'a str>,
        cache: &mut HashMap<&'a str, usize>,
    ) -> usize {
        if let Some(&known) = cache.get(node) {
            return known;
        }
        let longest = adj
            .get(node)
            .into_iter()
            .flatten()
            .filter(|&&child| child != "__END__" && reachable.contains(child))
            .map(|&child| 1 + height(child, adj, reachable, cache))
            .max()
            .unwrap_or(0);
        cache.insert(node, longest);
        longest
    }

    height(start, adj, reachable, &mut HashMap::new())
}
// Unit tests for `CompiledGraph::analyze`, with graphs built through the `Graph` builder.
#[cfg(test)]
mod tests {
use super::*;
use crate::builder::Graph;
use crate::goto::Goto;
use crate::node::{node_fn, NodeOut};
// Minimal state: `apply` discards every update, so these tests exercise graph
// structure only.
#[derive(Default, Clone)]
struct S;
#[derive(Default)]
struct SU;
impl GraphState for S {
type Update = SU;
fn apply(&mut self, _: Self::Update) {}
}
// Builds a node that produces an empty update and returns `Goto::end()`.
fn nf(name: &'static str) -> impl crate::Node<S> + 'static {
node_fn::<S, _, _>(name, |_, _| async move {
Ok(NodeOut {
update: SU,
goto: Goto::end(),
})
})
}
// a -> b -> c: two edges, longest path of two edges, every node reachable.
#[test]
fn linear_three_node_chain_depth_2() {
let g = Graph::<S>::new()
.node("a", nf("a"))
.node("b", nf("b"))
.node("c", nf("c"))
.edge("a", "b")
.edge("b", "c")
.start_at("a")
.compile()
.unwrap();
let r = g.analyze();
assert_eq!(r.node_count, 3);
assert_eq!(r.edge_count, 2);
assert!(!r.has_cycle);
assert_eq!(r.depth, 2);
assert!(r.unreachable.is_empty());
}
// Two routes a->b->d and a->c->d rejoin at d. The shared node is not a cycle,
// and depth is the longest route (2 edges), not the sum of both.
#[test]
fn diamond_depth_is_2() {
let g = Graph::<S>::new()
.node("a", nf("a"))
.node("b", nf("b"))
.node("c", nf("c"))
.node("d", nf("d"))
.edge("a", "b")
.edge("a", "c")
.edge("b", "d")
.edge("c", "d")
.start_at("a")
.compile()
.unwrap();
let r = g.analyze();
assert!(!r.has_cycle);
assert_eq!(r.depth, 2);
}
// A self-edge is the smallest possible cycle; depth is then reported as 0.
#[test]
fn self_loop_is_a_cycle() {
let g = Graph::<S>::new()
.node("a", nf("a"))
.edge("a", "a")
.start_at("a")
.compile()
.unwrap();
let r = g.analyze();
assert!(r.has_cycle);
assert_eq!(r.depth, 0, "depth undefined for graphs with cycles");
}
// c -> a closes the loop a -> b -> c -> a.
#[test]
fn detects_back_edge_cycle() {
let g = Graph::<S>::new()
.node("a", nf("a"))
.node("b", nf("b"))
.node("c", nf("c"))
.edge("a", "b")
.edge("b", "c")
.edge("c", "a")
.start_at("a")
.compile()
.unwrap();
let r = g.analyze();
assert!(r.has_cycle);
}
// c has no incoming edge, so it is listed as unreachable. Depth only covers
// the reachable part a -> b.
#[test]
fn unreachable_nodes_listed() {
let g = Graph::<S>::new()
.node("a", nf("a"))
.node("b", nf("b"))
.node("c", nf("c"))
.edge("a", "b")
.start_at("a")
.compile()
.unwrap();
let r = g.analyze();
assert_eq!(r.unreachable, vec!["c".to_string()]);
assert!(!r.has_cycle);
assert_eq!(r.depth, 1);
}
}