use crate::algorithms::astar::CfgGraphNode;
use std::collections::{HashMap, HashSet};
/// Reachability information for a successor graph, as produced by [`transitive_closure`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ReachabilityResult {
    /// Maps each source node id to the set of node ids reachable from it by following one or
    /// more successor edges.
    pub reachable: HashMap<u64, HashSet<u64>>,
    /// Number of reachable `(from, to)` pairs counted by the producer of this result.
    pub total_pairs: usize,
}
/// Computes the transitive closure of the successor graph described by `nodes`.
///
/// For every node id in `nodes`, the result maps that id to the set of ids reachable by
/// following one or more successor edges. A node appears in its own set only if it lies on a
/// cycle. Successor ids with no corresponding entry in `nodes` are still recorded as reachable,
/// but are treated as having no outgoing edges.
///
/// If several entries in `nodes` share an id, the last one's successors are used (the id lookup
/// map keeps the last entry), and that id contributes to `total_pairs` only once.
///
/// `total_pairs` is the total number of `(from, to)` pairs in `reachable`.
pub fn transitive_closure(nodes: &[CfgGraphNode]) -> ReachabilityResult {
    let node_map: HashMap<u64, &CfgGraphNode> = nodes.iter().map(|n| (n.id, n)).collect();
    let mut reachable: HashMap<u64, HashSet<u64>> = HashMap::with_capacity(node_map.len());

    for node in nodes {
        // The traversal always starts from the map entry for `node.id`, so a duplicated id would
        // yield an identical set; compute each id only once.
        if reachable.contains_key(&node.id) {
            continue;
        }

        let mut reachable_from_node: HashSet<u64> = HashSet::new();
        let mut worklist = vec![node.id];
        while let Some(current) = worklist.pop() {
            if let Some(current_node) = node_map.get(&current) {
                for &succ in &current_node.successors {
                    // `insert` returns true only on first sight, so each node is expanded once.
                    if reachable_from_node.insert(succ) {
                        worklist.push(succ);
                    }
                }
            }
        }
        reachable.insert(node.id, reachable_from_node);
    }

    // Derived from the final map so duplicate ids cannot be double-counted.
    let total_pairs = reachable.values().map(|set| set.len()).sum();
    ReachabilityResult {
        reachable,
        total_pairs,
    }
}
pub fn transitive_reduction(nodes: &[CfgGraphNode]) -> Vec<(u64, u64)> {
let closure = transitive_closure(nodes);
let mut keep_edges: Vec<(u64, u64)> = Vec::new();
for node in nodes {
for &succ in &node.successors {
let mut has_alternative_path = false;
for &other_succ in &node.successors {
if other_succ != succ {
if let Some(other_reachable) = closure.reachable.get(&other_succ) {
if other_reachable.contains(&succ) {
has_alternative_path = true;
break;
}
}
}
}
if !has_alternative_path {
keep_edges.push((node.id, succ));
}
}
}
keep_edges
}
/// Reports whether `to` is in the reachable set that `closure` records for `from`.
///
/// Yields `false` when `closure` has no entry for `from` at all.
pub fn is_reachable(from: u64, to: u64, closure: &ReachabilityResult) -> bool {
    match closure.reachable.get(&from) {
        Some(targets) => targets.contains(&to),
        None => false,
    }
}
/// Returns an owned copy of the set of node ids reachable from `node` according to `closure`.
///
/// An empty set is returned when `closure` has no entry for `node`.
pub fn get_reachable_from(node: u64, closure: &ReachabilityResult) -> HashSet<u64> {
    match closure.reachable.get(&node) {
        Some(descendants) => descendants.clone(),
        None => HashSet::new(),
    }
}
/// Collects every source id in `closure` whose reachable set contains `node`.
///
/// This scans all entries of `closure.reachable`, so the cost grows with the size of the closure.
pub fn get_reachable_to(node: u64, closure: &ReachabilityResult) -> HashSet<u64> {
    closure
        .reachable
        .iter()
        .filter(|(_, targets)| targets.contains(&node))
        .map(|(&source, _)| source)
        .collect()
}
pub fn count_ancestors(node: u64, closure: &ReachabilityResult) -> usize {
get_reachable_to(node, closure).len()
}
pub fn count_descendants(node: u64, closure: &ReachabilityResult) -> usize {
get_reachable_from(node, closure).len()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `CfgGraphNode` at `(id, 0, 0)` with the listed successor ids.
    ///
    /// This is a macro rather than a fn so the coordinate values stay untyped. The fixtures then
    /// do not depend on the float type of `CfgGraphNode`'s coordinate fields.
    macro_rules! node {
        ($id:expr => [$($succ:expr),* $(,)?]) => {
            CfgGraphNode {
                id: $id,
                x: $id as _,
                y: 0.0,
                z: 0.0,
                successors: vec![$($succ),*],
            }
        };
    }

    /// 0 -> 1 -> 2 -> 3
    fn four_node_chain() -> Vec<CfgGraphNode> {
        vec![
            node!(0 => [1]),
            node!(1 => [2]),
            node!(2 => [3]),
            node!(3 => []),
        ]
    }

    /// 0 -> 1 -> 2
    fn three_node_chain() -> Vec<CfgGraphNode> {
        vec![node!(0 => [1]), node!(1 => [2]), node!(2 => [])]
    }

    /// 0 -> {1, 2}, both of which lead to 3.
    fn diamond() -> Vec<CfgGraphNode> {
        vec![
            node!(0 => [1, 2]),
            node!(1 => [3]),
            node!(2 => [3]),
            node!(3 => []),
        ]
    }

    #[test]
    fn test_transitive_closure_linear() {
        let result = transitive_closure(&four_node_chain());
        let from_start = &result.reachable[&0];
        assert_eq!(from_start.len(), 3);
        assert!(from_start.contains(&1));
        assert!(from_start.contains(&3));
        assert_eq!(result.reachable[&3].len(), 0);
    }

    #[test]
    fn test_transitive_closure_branching() {
        let result = transitive_closure(&diamond());
        let from_start = &result.reachable[&0];
        assert_eq!(from_start.len(), 3);
        for id in 1..=3 {
            assert!(from_start.contains(&id));
        }
    }

    #[test]
    fn test_transitive_reduction_linear() {
        let edges = transitive_reduction(&four_node_chain());
        assert_eq!(edges.len(), 3);
    }

    #[test]
    fn test_transitive_reduction_with_redundant() {
        let graph = vec![node!(0 => [1, 2]), node!(1 => [2]), node!(2 => [])];
        let edges = transitive_reduction(&graph);
        assert_eq!(edges.len(), 2);
        assert!(edges.contains(&(0, 1)));
        assert!(edges.contains(&(1, 2)));
    }

    #[test]
    fn test_is_reachable() {
        let closure = transitive_closure(&three_node_chain());
        assert!(is_reachable(0, 2, &closure));
        assert!(is_reachable(0, 1, &closure));
        assert!(!is_reachable(2, 0, &closure));
    }

    #[test]
    fn test_count_ancestors_descendants() {
        let closure = transitive_closure(&diamond());
        assert_eq!(count_descendants(0, &closure), 3);
        assert_eq!(count_ancestors(3, &closure), 3);
        assert_eq!(count_descendants(3, &closure), 0);
    }

    #[test]
    fn test_get_reachable_from() {
        let graph = vec![node!(0 => [1, 2]), node!(1 => []), node!(2 => [])];
        let closure = transitive_closure(&graph);
        let reachable = get_reachable_from(0, &closure);
        assert_eq!(reachable.len(), 2);
        assert!(reachable.contains(&1));
        assert!(reachable.contains(&2));
    }

    #[test]
    fn test_total_pairs() {
        let result = transitive_closure(&three_node_chain());
        assert_eq!(result.total_pairs, 3);
    }
}