use std::collections::{HashMap, HashSet, VecDeque};
use tensorlogic_ir::EinsumGraph;
/// Reachability facts about a graph, keyed by node index.
///
/// A freshly constructed value is empty and has no topological order, so
/// [`ReachabilityAnalysis::is_dag`] reports `false` until `topo_order` is set.
#[derive(Debug, Clone, Default)]
pub struct ReachabilityAnalysis {
    /// For each source node, the set of nodes that can be reached from it.
    pub reachable_from: HashMap<usize, HashSet<usize>>,
    /// For each target node, the set of nodes that can reach it.
    pub can_reach: HashMap<usize, HashSet<usize>>,
    /// Strongly connected components, each stored as a set of node indices.
    pub sccs: Vec<HashSet<usize>>,
    /// A topological order of the nodes, or `None` when none is recorded.
    pub topo_order: Option<Vec<usize>>,
}

impl ReachabilityAnalysis {
    /// Creates an empty analysis: no reachability sets, no SCCs, no order.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` when `to` is recorded as reachable from `from`.
    ///
    /// A `from` node without an entry in `reachable_from` reaches nothing.
    pub fn is_reachable(&self, from: usize, to: usize) -> bool {
        match self.reachable_from.get(&from) {
            Some(targets) => targets.contains(&to),
            None => false,
        }
    }

    /// Returns a copy of the set reachable from `from` (empty if unknown).
    pub fn get_reachable(&self, from: usize) -> HashSet<usize> {
        match self.reachable_from.get(&from) {
            Some(targets) => targets.clone(),
            None => HashSet::new(),
        }
    }

    /// Returns a copy of the set of nodes that can reach `to` (empty if unknown).
    pub fn get_predecessors(&self, to: usize) -> HashSet<usize> {
        match self.can_reach.get(&to) {
            Some(sources) => sources.clone(),
            None => HashSet::new(),
        }
    }

    /// Returns `true` when a topological order has been recorded.
    pub fn is_dag(&self) -> bool {
        self.get_topo_order().is_some()
    }

    /// Borrows the recorded topological order, if any.
    pub fn get_topo_order(&self) -> Option<&[usize]> {
        self.topo_order.as_ref().map(|order| order.as_slice())
    }

    /// Returns the first stored component that contains `node`, if any.
    pub fn get_scc(&self, node: usize) -> Option<&HashSet<usize>> {
        for component in &self.sccs {
            if component.contains(&node) {
                return Some(component);
            }
        }
        None
    }
}
/// Dominance facts about a graph, keyed by node index.
#[derive(Debug, Clone, Default)]
pub struct DominanceAnalysis {
    /// Immediate dominator of each node that has one recorded.
    pub idom: HashMap<usize, usize>,
    /// Dominance frontier recorded for each node.
    pub dominance_frontier: HashMap<usize, HashSet<usize>>,
    /// Post-dominator set recorded for each node.
    pub post_dominators: HashMap<usize, HashSet<usize>>,
}

impl DominanceAnalysis {
    /// Creates an empty analysis with no dominance information.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the immediate dominator recorded for `node`, if any.
    pub fn get_idom(&self, node: usize) -> Option<usize> {
        self.idom.get(&node).copied()
    }

    /// Returns `true` when `dom` appears on the immediate-dominator chain
    /// above `node`.
    ///
    /// The walk starts at `idom[node]`, so `dominates(n, n)` is `false` unless
    /// the map itself leads back to `n`. The walk stops at a node that has no
    /// recorded immediate dominator or that is its own immediate dominator.
    ///
    /// `idom` is a public field and may therefore contain cycles. The walk is
    /// capped at `idom.len()` lookups so that such a map yields `false`
    /// instead of looping forever.
    pub fn dominates(&self, dom: usize, node: usize) -> bool {
        let mut current = node;
        // An acyclic chain uses each key of `idom` at most once, so this bound
        // never cuts a legitimate walk short.
        for _ in 0..self.idom.len() {
            match self.get_idom(current) {
                Some(parent) if parent == dom => return true,
                // Self-dominating root: nothing further up the chain.
                Some(parent) if parent == current => return false,
                Some(parent) => current = parent,
                None => return false,
            }
        }
        false
    }

    /// Returns a copy of the dominance frontier of `node` (empty if unknown).
    pub fn get_frontier(&self, node: usize) -> HashSet<usize> {
        self.dominance_frontier
            .get(&node)
            .cloned()
            .unwrap_or_default()
    }

    /// Returns a copy of the post-dominators of `node` (empty if unknown).
    pub fn get_post_dominators(&self, node: usize) -> HashSet<usize> {
        self.post_dominators.get(&node).cloned().unwrap_or_default()
    }
}
pub fn analyze_reachability(graph: &EinsumGraph) -> ReachabilityAnalysis {
let mut analysis = ReachabilityAnalysis::new();
let adj = build_adjacency_list(graph);
for node in 0..graph.nodes.len() {
let reachable = bfs_reachable(&adj, node);
analysis.reachable_from.insert(node, reachable);
}
let rev_adj = build_reverse_adjacency(graph);
for node in 0..graph.nodes.len() {
let can_reach = bfs_reachable(&rev_adj, node);
analysis.can_reach.insert(node, can_reach);
}
analysis.sccs = tarjan_scc(&adj);
analysis.topo_order = compute_topo_order(graph);
analysis
}
/// Runs the dominance analysis on `graph`.
///
/// An empty graph yields an empty analysis. Otherwise the immediate
/// dominators are computed first, then the dominance frontiers (from a
/// snapshot of those dominators), and finally the post-dominators on the
/// reversed graph.
pub fn analyze_dominance(graph: &EinsumGraph) -> DominanceAnalysis {
    let mut result = DominanceAnalysis::new();

    if !graph.nodes.is_empty() {
        let forward = build_adjacency_list(graph);
        compute_idom(&forward, &mut result);

        // The frontier pass needs `&mut result`, so it reads the dominators
        // from a copy rather than from `result.idom` directly.
        let idom_snapshot = result.idom.clone();
        compute_dominance_frontiers(&forward, &idom_snapshot, &mut result);

        let backward = build_reverse_adjacency(graph);
        compute_post_dominators(&backward, &mut result);
    }

    result
}
/// Builds the forward adjacency map of `graph`.
///
/// There is an edge `a -> b` (with `a != b`) when any output of node `a` is
/// listed among the inputs of node `b`. Successors are stored in ascending
/// index order, and nodes without successors get no entry in the map.
fn build_adjacency_list(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
    let mut successors_of: HashMap<usize, Vec<usize>> = HashMap::new();

    for (producer_idx, producer) in graph.nodes.iter().enumerate() {
        let consumers: Vec<usize> = graph
            .nodes
            .iter()
            .enumerate()
            .filter(|&(consumer_idx, consumer)| {
                consumer_idx != producer_idx
                    && producer
                        .outputs
                        .iter()
                        .any(|&out| consumer.inputs.contains(&out))
            })
            .map(|(consumer_idx, _)| consumer_idx)
            .collect();

        // Only nodes that actually feed another node appear as keys.
        if !consumers.is_empty() {
            successors_of.insert(producer_idx, consumers);
        }
    }

    successors_of
}
/// Builds the reversed adjacency map of `graph`: for every forward edge
/// `a -> b` the result maps `b` to a list containing `a`.
///
/// Nodes without predecessors get no entry in the map.
fn build_reverse_adjacency(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
    build_adjacency_list(graph)
        .into_iter()
        .flat_map(|(source, targets)| targets.into_iter().map(move |target| (target, source)))
        .fold(HashMap::new(), |mut reversed, (target, source)| {
            reversed.entry(target).or_insert_with(Vec::new).push(source);
            reversed
        })
}
/// Returns every node reachable from `start` by following edges in `adj`.
///
/// The result always contains `start` itself, even when `start` has no entry
/// in `adj`.
fn bfs_reachable(adj: &HashMap<usize, Vec<usize>>, start: usize) -> HashSet<usize> {
    let mut seen: HashSet<usize> = std::iter::once(start).collect();
    let mut frontier: VecDeque<usize> = std::iter::once(start).collect();

    while let Some(current) = frontier.pop_front() {
        // A node missing from `adj` simply has no outgoing edges.
        for &next in adj.get(&current).into_iter().flatten() {
            // `insert` returns true only the first time a node is seen.
            if seen.insert(next) {
                frontier.push_back(next);
            }
        }
    }

    seen
}
/// Computes the strongly connected components of `adj` with Tarjan's algorithm.
///
/// The node set is every key of `adj` plus every node named in a neighbor
/// list. Roots are visited in ascending node order, so the returned vector is
/// the same on every run. Each component is emitted only after every
/// component reachable from it has been emitted.
fn tarjan_scc(adj: &HashMap<usize, Vec<usize>>) -> Vec<HashSet<usize>> {
    // Sort and dedup so traversal does not depend on hash iteration order.
    let mut roots: Vec<usize> = adj
        .iter()
        .flat_map(|(&node, neighbors)| std::iter::once(node).chain(neighbors.iter().copied()))
        .collect();
    roots.sort_unstable();
    roots.dedup();

    let mut sccs = Vec::new();
    let mut index = 0;
    let mut stack = Vec::new();
    let mut indices: HashMap<usize, usize> = HashMap::with_capacity(roots.len());
    let mut lowlinks: HashMap<usize, usize> = HashMap::with_capacity(roots.len());
    let mut on_stack: HashSet<usize> = HashSet::new();

    for root in roots {
        if !indices.contains_key(&root) {
            strongconnect(
                root,
                adj,
                &mut index,
                &mut stack,
                &mut indices,
                &mut lowlinks,
                &mut on_stack,
                &mut sccs,
            );
        }
    }
    sccs
}

/// Explores everything reachable from the not-yet-indexed node `v` and appends
/// each completed strongly connected component to `sccs`.
///
/// The depth-first search keeps its own explicit stack of frames instead of
/// recursing, so deep graphs (e.g. long chains) cannot overflow the call stack.
/// `index`, `stack`, `indices`, `lowlinks` and `on_stack` are the shared state
/// of Tarjan's algorithm and persist across calls for different roots.
// The argument list mirrors the state Tarjan's algorithm has to share.
#[allow(clippy::too_many_arguments)]
fn strongconnect(
    v: usize,
    adj: &HashMap<usize, Vec<usize>>,
    index: &mut usize,
    stack: &mut Vec<usize>,
    indices: &mut HashMap<usize, usize>,
    lowlinks: &mut HashMap<usize, usize>,
    on_stack: &mut HashSet<usize>,
    sccs: &mut Vec<HashSet<usize>>,
) {
    /// Assigns the next discovery index to `node` and puts it on the SCC stack.
    fn discover(
        node: usize,
        index: &mut usize,
        stack: &mut Vec<usize>,
        indices: &mut HashMap<usize, usize>,
        lowlinks: &mut HashMap<usize, usize>,
        on_stack: &mut HashSet<usize>,
    ) {
        indices.insert(node, *index);
        lowlinks.insert(node, *index);
        *index += 1;
        stack.push(node);
        on_stack.insert(node);
    }

    // Each frame is (node, position of the next neighbor to examine).
    let mut frames: Vec<(usize, usize)> = Vec::new();
    discover(v, index, stack, indices, lowlinks, on_stack);
    frames.push((v, 0));

    while let Some(frame) = frames.last_mut() {
        let (node, cursor) = *frame;
        frame.1 += 1;

        let next = adj
            .get(&node)
            .and_then(|neighbors| neighbors.get(cursor))
            .copied();

        match next {
            Some(w) => match indices.get(&w).copied() {
                // Unvisited neighbor: descend into it.
                None => {
                    discover(w, index, stack, indices, lowlinks, on_stack);
                    frames.push((w, 0));
                }
                // Visited neighbor still on the stack: it belongs to the
                // component currently being built.
                Some(w_index) => {
                    if on_stack.contains(&w) {
                        if let Some(low) = lowlinks.get_mut(&node) {
                            *low = (*low).min(w_index);
                        }
                    }
                }
            },
            // All neighbors handled: finish `node`.
            None => {
                frames.pop();
                let node_low = lowlinks[&node];

                // `node` is the root of a component: pop it off the stack.
                if node_low == indices[&node] {
                    let mut scc = HashSet::new();
                    while let Some(member) = stack.pop() {
                        on_stack.remove(&member);
                        scc.insert(member);
                        if member == node {
                            break;
                        }
                    }
                    sccs.push(scc);
                }

                // Propagate the lowlink to the DFS parent, as the recursive
                // formulation does after the child call returns.
                if let Some(&(parent, _)) = frames.last() {
                    if let Some(low) = lowlinks.get_mut(&parent) {
                        *low = (*low).min(node_low);
                    }
                }
            }
        }
    }
}
fn compute_topo_order(graph: &EinsumGraph) -> Option<Vec<usize>> {
let adj = build_adjacency_list(graph);
let mut in_degree: HashMap<usize, usize> = HashMap::new();
for i in 0..graph.nodes.len() {
in_degree.insert(i, 0);
}
for neighbors in adj.values() {
for &neighbor in neighbors {
*in_degree.entry(neighbor).or_insert(0) += 1;
}
}
let mut queue: VecDeque<usize> = in_degree
.iter()
.filter(|(_, °)| deg == 0)
.map(|(&node, _)| node)
.collect();
let mut order = Vec::new();
while let Some(node) = queue.pop_front() {
order.push(node);
if let Some(neighbors) = adj.get(&node) {
for &neighbor in neighbors {
let deg = in_degree.get_mut(&neighbor).unwrap();
*deg -= 1;
if *deg == 0 {
queue.push_back(neighbor);
}
}
}
}
if order.len() == graph.nodes.len() {
Some(order)
} else {
None }
}
fn compute_idom(adj: &HashMap<usize, Vec<usize>>, analysis: &mut DominanceAnalysis) {
if let Some(&entry) = adj.keys().next() {
for &node in adj.keys() {
if node != entry {
analysis.idom.insert(node, entry);
}
}
}
}
/// Records an empty dominance frontier for every node that has an immediate
/// dominator in `idom`.
///
/// This is a placeholder: `_adj` is not consulted and no frontier members are
/// ever computed.
fn compute_dominance_frontiers(
    _adj: &HashMap<usize, Vec<usize>>,
    idom: &HashMap<usize, usize>,
    analysis: &mut DominanceAnalysis,
) {
    analysis
        .dominance_frontier
        .extend(idom.keys().map(|&node| (node, HashSet::new())));
}
/// Records an empty post-dominator set for every key of the reversed
/// adjacency map.
///
/// This is a placeholder: the edges themselves are not inspected and no
/// post-dominators are ever computed.
fn compute_post_dominators(
    rev_adj: &HashMap<usize, Vec<usize>>,
    analysis: &mut DominanceAnalysis,
) {
    analysis
        .post_dominators
        .extend(rev_adj.keys().map(|&node| (node, HashSet::new())));
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph that holds two tensors (`t0`, `t1`) and no operation nodes.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let _t0 = graph.add_tensor("t0");
        let _t1 = graph.add_tensor("t1");
        graph
    }

    #[test]
    fn test_reachability_empty_graph() {
        let graph = EinsumGraph::new();
        let analysis = analyze_reachability(&graph);
        assert!(analysis.reachable_from.is_empty());
        assert!(analysis.can_reach.is_empty());
        assert!(analysis.sccs.is_empty());
    }

    #[test]
    fn test_reachability_single_node() {
        let mut graph = create_test_graph();
        let t0 = 0;
        let t1 = 1;
        let n0 = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", t0, t1))
            .unwrap();
        let analysis = analyze_reachability(&graph);
        assert!(!analysis.reachable_from.is_empty());
        // Every analysed node reaches itself and, being alone, nothing else.
        assert!(analysis.is_reachable(n0, n0));
        assert_eq!(analysis.get_reachable(n0).len(), 1);
    }

    #[test]
    fn test_dominance_empty_graph() {
        let graph = EinsumGraph::new();
        let analysis = analyze_dominance(&graph);
        assert!(analysis.idom.is_empty());
    }

    #[test]
    fn test_is_dag() {
        // A graph without operation nodes has the empty topological order.
        let graph = create_test_graph();
        let analysis = analyze_reachability(&graph);
        assert!(analysis.is_dag());
        assert_eq!(analysis.get_topo_order(), Some(&[][..]));
    }

    #[test]
    fn test_dominates() {
        // No operation nodes means no dominance information at all.
        let graph = create_test_graph();
        let analysis = analyze_dominance(&graph);
        assert!(analysis.idom.is_empty());
        assert!(!analysis.dominates(0, 1));
    }

    #[test]
    fn test_build_adjacency() {
        let mut graph = create_test_graph();
        let t0 = 0;
        let t1 = 1;
        graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", t0, t1))
            .unwrap();
        // Self-edges are skipped, so a single node has no successors.
        let adj = build_adjacency_list(&graph);
        assert!(adj.is_empty());
    }

    #[test]
    fn test_scc_computation() {
        let mut adj = HashMap::new();
        adj.insert(0, vec![1]);
        adj.insert(1, vec![2]);
        adj.insert(2, vec![0]);
        let sccs = tarjan_scc(&adj);
        // A three-node cycle is exactly one component containing all nodes.
        assert_eq!(sccs.len(), 1);
        let expected: HashSet<usize> = vec![0, 1, 2].into_iter().collect();
        assert_eq!(sccs[0], expected);
    }

    #[test]
    fn test_topo_order() {
        let mut graph = create_test_graph();
        let t0 = 0;
        let t1 = 1;
        let t2 = 2;
        graph.add_tensor("t2");
        // Assumes `elem_unary(op, input, output)` and that `add_node` returns
        // the index of the new node — TODO confirm against tensorlogic_ir.
        let n0 = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", t0, t1))
            .unwrap();
        let n1 = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("log", t1, t2))
            .unwrap();
        let order = compute_topo_order(&graph);
        assert_eq!(order, Some(vec![n0, n1]));
    }

    #[test]
    fn test_reachability_chain() {
        let mut graph = create_test_graph();
        let t0 = 0;
        let t1 = 1;
        let t2 = 2;
        graph.add_tensor("t2");
        // Assumes `elem_unary(op, input, output)`, i.e. n0 feeds n1 through t1
        // — TODO confirm against tensorlogic_ir.
        let n0 = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", t0, t1))
            .unwrap();
        let n1 = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("log", t1, t2))
            .unwrap();
        let analysis = analyze_reachability(&graph);
        assert!(analysis.is_reachable(n0, n1));
        assert!(!analysis.is_reachable(n1, n0));
        assert!(analysis.get_predecessors(n1).contains(&n0));
        assert!(analysis.is_dag());
    }

    #[test]
    fn test_get_predecessors() {
        // Without operation nodes nothing is analysed, so the lookup is empty.
        let graph = create_test_graph();
        let analysis = analyze_reachability(&graph);
        let preds = analysis.get_predecessors(0);
        assert!(preds.is_empty());
    }
}