use crate::algorithms::astar::CfgGraphNode;
use std::collections::{HashMap, HashSet};
/// Outcome of a strongly-connected-component analysis produced by `tarjan_scc`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SccResult {
    /// Every strongly connected component, each given as the list of node ids
    /// that belong to it.
    pub components: Vec<Vec<u64>>,
    /// Maps a node id to the index (into `components`) of the component that
    /// contains it.
    pub node_to_component: HashMap<u64, usize>,
    /// Number of components that `tarjan_scc` classified as cycles.
    pub cycle_count: usize,
}
/// Mutable bookkeeping shared by every step of one Tarjan SCC run.
#[derive(Default)]
struct TarjanState {
    /// Next discovery index to hand out; incremented once per discovered node.
    index: u64,
    /// Nodes discovered but not yet assigned to a finished component.
    stack: Vec<u64>,
    /// Membership set mirroring `stack`, for constant-time "is it on the stack" checks.
    on_stack: HashSet<u64>,
    /// Discovery index of every node visited so far.
    indices: HashMap<u64, u64>,
    /// Smallest discovery index known to be reachable from each visited node.
    lowlinks: HashMap<u64, u64>,
    /// Components completed so far, in the order they were closed.
    components: Vec<Vec<u64>>,
}

impl TarjanState {
    /// Creates an empty state: index counter at zero and no nodes recorded.
    fn new() -> Self {
        Self::default()
    }
}
/// Computes the strongly connected components of the graph described by `nodes`
/// using Tarjan's algorithm.
///
/// Nodes are explored in slice order, each one's successors in the order they
/// are listed. Input quirks are handled as follows:
///
/// * A successor id that matches no entry in `nodes` is treated as a node
///   without outgoing edges and ends up in a single-node component of its own.
/// * If several entries share an id, only the last one's successors are followed.
///
/// # Returns
///
/// An [`SccResult`] holding the components, a node-id → component-index map and
/// `cycle_count`. A component counts as a cycle when it has more than one node,
/// or when its only node lists itself as a successor (a self-loop).
pub fn tarjan_scc(nodes: &[CfgGraphNode]) -> SccResult {
    let mut state = TarjanState::new();
    let node_map: HashMap<u64, &CfgGraphNode> = nodes.iter().map(|n| (n.id, n)).collect();

    for node in nodes {
        if !state.indices.contains_key(&node.id) {
            strongconnect(node.id, &node_map, &mut state);
        }
    }

    let node_to_component: HashMap<u64, usize> = state
        .components
        .iter()
        .enumerate()
        .flat_map(|(component_index, members)| {
            members.iter().map(move |&node_id| (node_id, component_index))
        })
        .collect();

    // A lone node is only cyclic when it has an edge back to itself; comparing
    // component sizes alone would overlook such self-loops.
    let is_cyclic = |members: &[u64]| match members {
        [only] => node_map
            .get(only)
            .map_or(false, |node| node.successors.contains(only)),
        _ => members.len() > 1,
    };
    let cycle_count = state
        .components
        .iter()
        .filter(|members| is_cyclic(members.as_slice()))
        .count();

    SccResult {
        components: state.components,
        node_to_component,
        cycle_count,
    }
}
/// Runs Tarjan's depth-first search starting at `node_id` and appends every
/// strongly connected component completed along the way to `state.components`.
///
/// The search keeps its own explicit work stack instead of recursing, so the
/// depth of the graph is limited by heap memory rather than by the call stack.
/// Visit order, discovery indices and the order of emitted components are the
/// same as in the textbook recursive formulation.
///
/// * `node_id` – start node; the caller must only pass ids that have not been
///   discovered yet (i.e. are absent from `state.indices`).
/// * `node_map` – id → node lookup; ids missing from it are treated as nodes
///   without successors.
/// * `state` – shared bookkeeping, updated in place.
fn strongconnect(node_id: u64, node_map: &HashMap<u64, &CfgGraphNode>, state: &mut TarjanState) {
    /// Assigns the next discovery index to `id` and pushes it onto the Tarjan stack.
    fn discover(id: u64, state: &mut TarjanState) {
        state.indices.insert(id, state.index);
        state.lowlinks.insert(id, state.index);
        state.index += 1;
        state.stack.push(id);
        state.on_stack.insert(id);
    }

    discover(node_id, state);
    // Each frame is (node, position of the next successor still to examine).
    let mut pending: Vec<(u64, usize)> = vec![(node_id, 0)];

    while let Some(&(current, cursor)) = pending.last() {
        let successors: &[u64] = match node_map.get(&current) {
            Some(node) => &node.successors,
            None => &[],
        };

        if let Some(&successor_id) = successors.get(cursor) {
            // Advance the cursor first so this edge is never examined twice.
            if let Some(frame) = pending.last_mut() {
                frame.1 += 1;
            }
            if !state.indices.contains_key(&successor_id) {
                // Tree edge: descend. The lowlink is propagated when the child finishes.
                discover(successor_id, state);
                pending.push((successor_id, 0));
            } else if state.on_stack.contains(&successor_id) {
                // Edge into the component currently being built.
                let successor_index = state.indices[&successor_id];
                if let Some(lowlink) = state.lowlinks.get_mut(&current) {
                    *lowlink = (*lowlink).min(successor_index);
                }
            }
            continue;
        }

        // Every successor of `current` has been handled: finish the node.
        pending.pop();
        let lowlink = state.lowlinks[&current];
        if lowlink == state.indices[&current] {
            // `current` is the root of a component: unwind the stack down to it.
            let mut component = Vec::new();
            while let Some(member) = state.stack.pop() {
                state.on_stack.remove(&member);
                component.push(member);
                if member == current {
                    break;
                }
            }
            state.components.push(component);
        }

        // Equivalent of the post-recursion step: hand the lowlink to the DFS parent.
        if let Some(&(parent, _)) = pending.last() {
            if let Some(parent_lowlink) = state.lowlinks.get_mut(&parent) {
                *parent_lowlink = (*parent_lowlink).min(lowlink);
            }
        }
    }
}
/// Returns the strongly connected components of `nodes` that contain a cycle.
///
/// A component is returned when it has more than one node, or when it consists
/// of a single node that lists itself as a successor (a self-loop). Each
/// returned vector holds the node ids of one such component.
pub fn find_cycles(nodes: &[CfgGraphNode]) -> Vec<Vec<u64>> {
    // Collecting into a map lets a later entry with the same id overwrite an
    // earlier one, matching the id lookup `tarjan_scc` builds internally.
    let has_self_loop: HashMap<u64, bool> = nodes
        .iter()
        .map(|node| (node.id, node.successors.contains(&node.id)))
        .collect();

    tarjan_scc(nodes)
        .components
        .into_iter()
        .filter(|members| match members.as_slice() {
            [only] => has_self_loop.get(only).copied().unwrap_or(false),
            _ => members.len() > 1,
        })
        .collect()
}
pub fn has_cycles(nodes: &[CfgGraphNode]) -> bool {
let result = tarjan_scc(nodes);
result.cycle_count > 0
}
/// Builds the condensation of the graph: one vertex per strongly connected
/// component, with an edge between two components whenever some node of the
/// first has a successor in the second.
///
/// # Returns
///
/// An adjacency list indexed by component index (the same indices as
/// `tarjan_scc(nodes).components`). Each inner vector lists the distinct target
/// components in ascending order, so the output is identical across runs.
/// Edges inside a single component are omitted.
pub fn condense_graph(nodes: &[CfgGraphNode]) -> Vec<Vec<usize>> {
    let result = tarjan_scc(nodes);
    let component_of = |node_id: &u64| result.node_to_component.get(node_id).copied();

    let mut condensed: Vec<HashSet<usize>> = vec![HashSet::new(); result.components.len()];
    for node in nodes {
        // An id without a component is skipped instead of being silently
        // attributed to component 0, which would fabricate edges.
        let from_component = match component_of(&node.id) {
            Some(component) => component,
            None => continue,
        };
        for successor_id in &node.successors {
            match component_of(successor_id) {
                Some(to_component) if to_component != from_component => {
                    condensed[from_component].insert(to_component);
                }
                _ => {}
            }
        }
    }

    condensed
        .into_iter()
        .map(|targets| {
            // Hash-set iteration order is arbitrary; sort for reproducible output.
            let mut edges: Vec<usize> = targets.into_iter().collect();
            edges.sort_unstable();
            edges
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a `CfgGraphNode` positioned at `($x, 0.0, 0.0)` with the given id
    // and successor ids, keeping the graph fixtures below compact.
    macro_rules! node {
        ($id:expr, $x:expr, [$($succ:expr),*]) => {
            CfgGraphNode {
                id: $id,
                x: $x,
                y: 0.0,
                z: 0.0,
                successors: vec![$($succ),*],
            }
        };
    }

    /// A straight chain 0 -> 1 -> 2 yields three components and no cycle.
    #[test]
    fn test_tarjan_scc_no_cycles() {
        let chain = vec![
            node!(0, 0.0, [1]),
            node!(1, 1.0, [2]),
            node!(2, 2.0, []),
        ];

        let scc = tarjan_scc(&chain);

        assert_eq!(scc.components.len(), 3);
        assert_eq!(scc.cycle_count, 0);
    }

    /// A ring 0 -> 1 -> 2 -> 0 collapses into one cyclic component.
    #[test]
    fn test_tarjan_scc_simple_cycle() {
        let ring = vec![
            node!(0, 0.0, [1]),
            node!(1, 1.0, [2]),
            node!(2, 2.0, [0]),
        ];

        let scc = tarjan_scc(&ring);

        assert_eq!(scc.components.len(), 1);
        assert_eq!(scc.cycle_count, 1);
    }

    /// `has_cycles` is false for 0 -> 1 and true for 0 <-> 1.
    #[test]
    fn test_has_cycles() {
        let nodes_no = vec![node!(0, 0.0, [1]), node!(1, 1.0, [])];
        assert!(!has_cycles(&nodes_no));

        let nodes_yes = vec![node!(0, 0.0, [1]), node!(1, 1.0, [0])];
        assert!(has_cycles(&nodes_yes));
    }

    /// Nodes 0 and 1 form one component and node 2 another, so the condensed
    /// graph has two vertices.
    #[test]
    fn test_condense_graph() {
        let graph = vec![
            node!(0, 0.0, [1, 2]),
            node!(1, 1.0, [0]),
            node!(2, 2.0, []),
        ];

        let condensed = condense_graph(&graph);

        assert_eq!(condensed.len(), 2);
    }
}