use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::path::PathBuf;
use crate::types::{CycleGranularity, CycleReport, FunctionRef, SCC};
/// Converts a graph node into the textual label recorded for it in an SCC.
///
/// [`find_sccs`] calls this once for every node it places into a component,
/// so any node type used with it must provide an implementation.
pub trait ToSccString {
    /// Returns this node's label as a freshly allocated `String`.
    fn to_scc_string(&self) -> String;
}

impl ToSccString for String {
    /// A `String` node is its own label; an owned copy is returned.
    fn to_scc_string(&self) -> String {
        self.as_str().to_owned()
    }
}

impl ToSccString for &str {
    /// A string slice node is copied into an owned `String`.
    fn to_scc_string(&self) -> String {
        String::from(*self)
    }
}

impl ToSccString for PathBuf {
    /// Renders the path lossily: any non-UTF-8 sequence is replaced with
    /// U+FFFD, as done by `Path::to_string_lossy`.
    fn to_scc_string(&self) -> String {
        self.to_string_lossy().into_owned()
    }
}
impl ToSccString for FunctionRef {
    /// Builds the label `<file>:<name>`, where `<file>` is the `file` field
    /// rendered through its `display()` adapter and `<name>` is the `name`
    /// field.
    fn to_scc_string(&self) -> String {
        let path = self.file.display();
        let name = &self.name;
        format!("{}:{}", path, name)
    }
}
/// Stage of the iterative depth-first traversal that a work-stack entry in
/// `find_sccs` is waiting to execute for its node.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TarjanPhase {
    /// The node has just been discovered: it still needs its state recorded
    /// and to be pushed on the node stack.
    Entering,
    /// The node is walking through its successor list, one successor per
    /// visit.
    ProcessingSuccessors,
    /// All successors have been handled; the node may close a component and
    /// hands its lowlink to its parent.
    Finishing,
}
/// Bookkeeping that `find_sccs` keeps for every node it has discovered.
#[derive(Clone, Debug)]
struct NodeState {
    /// Discovery number handed out from the running index counter.
    index: usize,
    /// Smallest discovery number found so far via this node's successors;
    /// equal to `index` when the node is the root of its component.
    lowlink: usize,
    /// Whether the node currently sits on the Tarjan node stack.
    on_stack: bool,
    /// Position of the next entry to examine in the node's successor list.
    successor_idx: usize,
}

impl NodeState {
    /// Creates the state for a freshly discovered node: its lowlink starts at
    /// its own `index`, it is marked as on the stack, and no successor has
    /// been examined yet.
    fn new(index: usize) -> Self {
        let lowlink = index;
        NodeState {
            index,
            lowlink,
            on_stack: true,
            successor_idx: 0,
        }
    }
}
/// Computes the strongly connected components of a directed graph using an
/// iterative (explicit work-stack) form of Tarjan's algorithm, so deep graphs
/// cannot overflow the call stack.
///
/// # Arguments
///
/// * `nodes` - Start nodes for the traversal. Duplicates are harmless. Nodes
///   that are only reachable as successors in `edges` are still visited.
/// * `edges` - Adjacency list: each key maps to its direct successors. A node
///   without an entry is treated as having no successors.
///
/// # Returns
///
/// One [`SCC`] per component that was reached, including single-node
/// components; callers that only want cycles must filter by size. Each
/// component holds the [`ToSccString`] labels of its member nodes.
///
/// # Panics
///
/// Only if an internal invariant of the traversal is violated (a node on one
/// of the stacks has no recorded state, or the node stack runs empty while a
/// component is being closed).
pub fn find_sccs<N>(nodes: &[N], edges: &HashMap<N, Vec<N>>) -> Vec<SCC>
where
    N: Hash + Eq + Clone + ToSccString,
{
    const STATE_INVARIANT: &str = "every node on a Tarjan stack has a recorded state";

    let mut index_counter: usize = 0;
    let mut states: HashMap<N, NodeState> = HashMap::with_capacity(nodes.len());
    let mut node_stack: Vec<N> = Vec::new();
    let mut sccs: Vec<SCC> = Vec::new();
    let mut work_stack: Vec<(N, TarjanPhase)> = Vec::new();

    for start_node in nodes {
        if states.contains_key(start_node) {
            continue;
        }
        work_stack.push((start_node.clone(), TarjanPhase::Entering));

        while let Some((node, phase)) = work_stack.pop() {
            match phase {
                TarjanPhase::Entering => {
                    states.insert(node.clone(), NodeState::new(index_counter));
                    index_counter += 1;
                    node_stack.push(node.clone());
                    work_stack.push((node, TarjanPhase::ProcessingSuccessors));
                }
                TarjanPhase::ProcessingSuccessors => {
                    // Borrow the successor list straight from `edges`. Cloning
                    // it on every visit made a node with d successors cost
                    // O(d^2) element clones.
                    let successors: &[N] =
                        edges.get(&node).map(Vec::as_slice).unwrap_or_default();
                    let state = states.get_mut(&node).expect(STATE_INVARIANT);

                    let successor = match successors.get(state.successor_idx) {
                        Some(successor) => successor,
                        None => {
                            // Successor list exhausted.
                            work_stack.push((node, TarjanPhase::Finishing));
                            continue;
                        }
                    };
                    state.successor_idx += 1;

                    match states.get(successor) {
                        None => {
                            // Unvisited successor: resume this node after the
                            // successor's whole subtree has been processed.
                            work_stack.push((node, TarjanPhase::ProcessingSuccessors));
                            work_stack.push((successor.clone(), TarjanPhase::Entering));
                        }
                        Some(succ_state) => {
                            if succ_state.on_stack {
                                // Edge into the component currently being
                                // built: pull the lowlink down to its index.
                                let succ_index = succ_state.index;
                                let state = states.get_mut(&node).expect(STATE_INVARIANT);
                                state.lowlink = state.lowlink.min(succ_index);
                            }
                            // A successor that already left the stack belongs
                            // to a finished component and is ignored.
                            work_stack.push((node, TarjanPhase::ProcessingSuccessors));
                        }
                    }
                }
                TarjanPhase::Finishing => {
                    let (node_index, node_lowlink) = {
                        let state = states.get(&node).expect(STATE_INVARIANT);
                        (state.index, state.lowlink)
                    };

                    // A node whose lowlink is its own index is the root of a
                    // component: everything above it on the node stack (and
                    // the node itself) forms that component.
                    if node_lowlink == node_index {
                        let mut scc_nodes: Vec<String> = Vec::new();
                        loop {
                            let w = node_stack.pop().expect("Stack should not be empty");
                            states.get_mut(&w).expect(STATE_INVARIANT).on_stack = false;
                            scc_nodes.push(w.to_scc_string());
                            if w == node {
                                break;
                            }
                        }
                        sccs.push(SCC::new(scc_nodes));
                    }

                    // The entry now on top of the work stack, if any, is the
                    // parent that descended into this node; hand it our
                    // lowlink, as the recursive formulation does on return.
                    if let Some((parent_node, TarjanPhase::ProcessingSuccessors)) =
                        work_stack.last()
                    {
                        let parent_state = states.get_mut(parent_node).expect(STATE_INVARIANT);
                        parent_state.lowlink = parent_state.lowlink.min(node_lowlink);
                    }
                }
            }
        }
    }

    sccs
}
/// Finds the cycles in a call graph and packages them into a [`CycleReport`].
///
/// Every function that appears in `graph`, as a caller or as a callee, is fed
/// to [`find_sccs`]. Each resulting component with `size > 1` is reported as a
/// cycle together with the graph edges whose two endpoints both lie inside
/// that component.
///
/// # Arguments
///
/// * `graph` - Adjacency list mapping each function to the functions it
///   references.
/// * `granularity` - Passed unchanged to [`CycleReport::new`].
///
/// # Returns
///
/// The report after `with_explanation()` has been applied to it.
///
/// # Notes
///
/// * Components with `size <= 1` are skipped. Assuming `size` counts member
///   nodes, a function whose only cycle is a call to itself is therefore not
///   reported.
/// * Start nodes are collected through a `HashSet`, whose iteration order is
///   unspecified, so the order of cycles in the report is not guaranteed to
///   be stable across runs.
pub fn detect_cycles(
    graph: &HashMap<FunctionRef, Vec<FunctionRef>>,
    granularity: CycleGranularity,
) -> CycleReport {
    let mut all_nodes: HashSet<FunctionRef> = HashSet::new();
    for (src, dsts) in graph {
        all_nodes.insert(src.clone());
        all_nodes.extend(dsts.iter().cloned());
    }
    let nodes: Vec<FunctionRef> = all_nodes.into_iter().collect();

    let cyclic: Vec<_> = find_sccs(&nodes, graph)
        .into_iter()
        .filter(|scc| scc.size > 1)
        .collect();

    let mut report = CycleReport::new(granularity);
    if !cyclic.is_empty() {
        // Index successor lists by source label once. Rescanning the whole
        // graph (and re-formatting every source label) for each member of
        // each component was quadratic in the size of the graph.
        let mut outgoing: HashMap<String, Vec<&[FunctionRef]>> =
            HashMap::with_capacity(graph.len());
        for (src, dsts) in graph {
            outgoing
                .entry(src.to_scc_string())
                .or_default()
                .push(dsts.as_slice());
        }

        for scc in cyclic {
            let mut edges: Vec<(String, String)> = Vec::new();
            {
                let members: HashSet<&String> = scc.nodes.iter().collect();
                for node_str in &scc.nodes {
                    for dsts in outgoing.get(node_str).into_iter().flatten() {
                        for dst in dsts.iter() {
                            let dst_str = dst.to_scc_string();
                            // Keep only edges that stay inside the component.
                            if members.contains(&dst_str) {
                                edges.push((node_str.clone(), dst_str));
                            }
                        }
                    }
                }
            }
            report.add_cycle(scc.with_edges(edges));
        }
    }

    report.with_explanation()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds a `(nodes, adjacency list)` pair over `String` nodes from a
    /// list of `(source, destination)` edges. Node order is unspecified
    /// because the nodes are collected through a `HashSet`.
    fn make_string_graph(
        edges_list: &[(&str, &str)],
    ) -> (Vec<String>, HashMap<String, Vec<String>>) {
        let mut nodes: HashSet<String> = HashSet::new();
        let mut edges: HashMap<String, Vec<String>> = HashMap::new();
        for (src, dst) in edges_list {
            nodes.insert(src.to_string());
            nodes.insert(dst.to_string());
            edges
                .entry(src.to_string())
                .or_default()
                .push(dst.to_string());
        }
        (nodes.into_iter().collect(), edges)
    }

    #[test]
    fn tarjan_finds_simple_cycle() {
        let (nodes, edges) = make_string_graph(&[("A", "B"), ("B", "A")]);
        let sccs = find_sccs(&nodes, &edges);
        let cycles: Vec<_> = sccs.iter().filter(|scc| scc.size > 1).collect();
        assert_eq!(
            cycles.len(),
            1,
            "Expected one cycle, found {}",
            cycles.len()
        );
        assert_eq!(cycles[0].size, 2);
        assert!(cycles[0].nodes.contains(&"A".to_string()));
        assert!(cycles[0].nodes.contains(&"B".to_string()));
    }

    #[test]
    fn tarjan_finds_complex_scc() {
        let (nodes, edges) = make_string_graph(&[("A", "B"), ("B", "C"), ("C", "A")]);
        let sccs = find_sccs(&nodes, &edges);
        let cycles: Vec<_> = sccs.iter().filter(|scc| scc.size > 1).collect();
        assert_eq!(cycles.len(), 1, "Expected one cycle");
        assert_eq!(cycles[0].size, 3);
        assert!(cycles[0].nodes.contains(&"A".to_string()));
        assert!(cycles[0].nodes.contains(&"B".to_string()));
        assert!(cycles[0].nodes.contains(&"C".to_string()));
    }

    #[test]
    fn tarjan_no_false_positives() {
        let (nodes, edges) =
            make_string_graph(&[("root", "A"), ("root", "B"), ("A", "C"), ("B", "C")]);
        let sccs = find_sccs(&nodes, &edges);
        let cycles: Vec<_> = sccs.iter().filter(|scc| scc.size > 1).collect();
        assert!(cycles.is_empty(), "DAG should have no cycles");
    }

    #[test]
    fn tarjan_finds_multiple_sccs() {
        let (nodes, edges) =
            make_string_graph(&[("A", "B"), ("B", "A"), ("X", "Y"), ("Y", "Z"), ("Z", "X")]);
        let sccs = find_sccs(&nodes, &edges);
        let cycles: Vec<_> = sccs.iter().filter(|scc| scc.size > 1).collect();
        assert_eq!(cycles.len(), 2, "Expected two cycles");
        let sizes: Vec<_> = cycles.iter().map(|c| c.size).collect();
        assert!(sizes.contains(&2), "Expected 2-node cycle");
        assert!(sizes.contains(&3), "Expected 3-node cycle");
    }

    /// Two cycles joined by a one-way edge must stay separate components,
    /// whichever node the traversal happens to start from.
    #[test]
    fn tarjan_keeps_linked_sccs_separate() {
        let (nodes, edges) = make_string_graph(&[
            ("A", "B"),
            ("B", "A"),
            ("B", "C"),
            ("C", "D"),
            ("D", "C"),
        ]);
        let sccs = find_sccs(&nodes, &edges);
        let cycles: Vec<_> = sccs.iter().filter(|scc| scc.size > 1).collect();
        assert_eq!(cycles.len(), 2, "Expected two separate cycles");
        assert!(cycles.iter().all(|c| c.size == 2));
        let ab = cycles
            .iter()
            .find(|c| c.nodes.contains(&"A".to_string()))
            .expect("A should be part of a cycle");
        assert!(ab.nodes.contains(&"B".to_string()));
        assert!(!ab.nodes.contains(&"C".to_string()));
        assert!(!ab.nodes.contains(&"D".to_string()));
    }

    /// A node that is absent from `nodes` but reachable through `edges` is
    /// still visited and can take part in a cycle.
    #[test]
    fn tarjan_visits_nodes_only_listed_as_successors() {
        let nodes = vec!["A".to_string()];
        let mut edges: HashMap<String, Vec<String>> = HashMap::new();
        edges.insert("A".to_string(), vec!["B".to_string()]);
        edges.insert("B".to_string(), vec!["A".to_string()]);
        let sccs = find_sccs(&nodes, &edges);
        assert_eq!(sccs.len(), 1);
        assert!(sccs[0].nodes.contains(&"A".to_string()));
        assert!(sccs[0].nodes.contains(&"B".to_string()));
    }

    #[test]
    fn tarjan_handles_self_loop() {
        let nodes = vec!["A".to_string()];
        let mut edges: HashMap<String, Vec<String>> = HashMap::new();
        edges.insert("A".to_string(), vec!["A".to_string()]);
        let sccs = find_sccs(&nodes, &edges);
        // Pin the exact result instead of only checking for non-emptiness.
        assert_eq!(sccs.len(), 1, "Self-loop node should form exactly one SCC");
        assert_eq!(sccs[0].nodes.len(), 1);
        assert!(sccs[0].nodes.contains(&"A".to_string()));
    }

    #[test]
    fn tarjan_handles_disconnected_components() {
        let (nodes, edges) = make_string_graph(&[("A", "B"), ("B", "A"), ("X", "Y")]);
        let sccs = find_sccs(&nodes, &edges);
        let cycles: Vec<_> = sccs.iter().filter(|scc| scc.size > 1).collect();
        assert_eq!(cycles.len(), 1, "Expected one cycle in disconnected graph");
        assert!(cycles[0].nodes.contains(&"A".to_string()));
        assert!(cycles[0].nodes.contains(&"B".to_string()));
    }

    /// A 5000-node ring exercises the explicit work stack: the traversal
    /// descends through every node before any of them finishes.
    #[test]
    fn tarjan_handles_deep_chain() {
        const DEPTH: usize = 5000;
        let nodes: Vec<String> = (0..DEPTH).map(|i| i.to_string()).collect();
        let mut edges: HashMap<String, Vec<String>> = HashMap::new();
        for i in 0..DEPTH {
            let next = (i + 1) % DEPTH;
            edges.insert(i.to_string(), vec![next.to_string()]);
        }
        let sccs = find_sccs(&nodes, &edges);
        let cycles: Vec<_> = sccs.iter().filter(|scc| scc.size > 1).collect();
        assert_eq!(cycles.len(), 1, "Expected one large cycle");
        assert_eq!(cycles[0].size, DEPTH);
    }

    #[test]
    fn tarjan_handles_wide_graph() {
        const WIDTH: usize = 10000;
        let mut nodes: Vec<String> = vec!["root".to_string()];
        let mut edges: HashMap<String, Vec<String>> = HashMap::new();
        let mut root_edges = Vec::new();
        for i in 0..WIDTH {
            let node = format!("leaf_{}", i);
            nodes.push(node.clone());
            root_edges.push(node);
        }
        edges.insert("root".to_string(), root_edges);
        let sccs = find_sccs(&nodes, &edges);
        let cycles: Vec<_> = sccs.iter().filter(|scc| scc.size > 1).collect();
        assert!(cycles.is_empty(), "Wide DAG should have no cycles");
    }

    #[test]
    fn detect_cycles_with_function_refs() {
        let mut graph: HashMap<FunctionRef, Vec<FunctionRef>> = HashMap::new();
        let func_a = FunctionRef::new(PathBuf::from("a.py"), "func_a");
        let func_b = FunctionRef::new(PathBuf::from("b.py"), "func_b");
        let func_c = FunctionRef::new(PathBuf::from("c.py"), "func_c");
        graph.insert(func_a.clone(), vec![func_b.clone()]);
        graph.insert(func_b.clone(), vec![func_c.clone()]);
        graph.insert(func_c.clone(), vec![func_a.clone()]);
        let report = detect_cycles(&graph, CycleGranularity::Function);
        assert_eq!(report.summary.cycle_count, 1);
        assert_eq!(report.summary.largest_cycle, 3);
        assert!(report.explanation.contains("1 cycle"));
        assert!(report.explanation.contains("3 nodes"));
    }

    #[test]
    fn detect_cycles_report_edges() {
        let mut graph: HashMap<FunctionRef, Vec<FunctionRef>> = HashMap::new();
        let func_a = FunctionRef::new(PathBuf::from("test.py"), "a");
        let func_b = FunctionRef::new(PathBuf::from("test.py"), "b");
        graph.insert(func_a.clone(), vec![func_b.clone()]);
        graph.insert(func_b.clone(), vec![func_a.clone()]);
        let report = detect_cycles(&graph, CycleGranularity::Function);
        assert_eq!(report.cycles.len(), 1);
        assert!(!report.cycles[0].edges.is_empty());
    }

    #[test]
    fn detect_cycles_empty_graph() {
        let graph: HashMap<FunctionRef, Vec<FunctionRef>> = HashMap::new();
        let report = detect_cycles(&graph, CycleGranularity::Function);
        assert_eq!(report.summary.cycle_count, 0);
        assert!(report.cycles.is_empty());
    }
}