use crate::graph::EinsumGraph;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
/// A directed cycle found in the tensor-level dependency graph.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Cycle {
    /// Tensor indices along the cycle in traversal order; the edge from the last entry
    /// back to the first one closes the cycle.
    pub tensors: Vec<usize>,
    /// Node indices participating in the cycle.
    /// NOTE(review): `find_cycles` in this module currently leaves this empty.
    pub nodes: Vec<usize>,
}
/// Detects directed cycles among the tensors of `graph` with a depth-first search.
///
/// Tensors are tried as search roots in ascending index order, skipping any tensor that
/// an earlier root already reached. Every back edge met during the search contributes
/// one [`Cycle`]. A back edge is an edge into a tensor that is still on the current DFS
/// path. The result therefore witnesses cyclicity rather than enumerating every
/// elementary cycle; an empty result means the tensor graph is acyclic. The `nodes`
/// field of the reported cycles is not populated.
pub fn find_cycles(graph: &EinsumGraph) -> Vec<Cycle> {
    let adjacency = build_tensor_adjacency(graph);
    let mut found: Vec<Cycle> = Vec::new();
    let mut seen: HashSet<usize> = HashSet::new();
    let mut on_path: HashSet<usize> = HashSet::new();
    let mut trail: Vec<usize> = Vec::new();

    for root in 0..graph.tensors.len() {
        if seen.contains(&root) {
            continue;
        }
        dfs_find_cycles(root, &adjacency, &mut seen, &mut on_path, &mut trail, &mut found);
    }
    found
}
/// Depth-first search from `tensor` that records one [`Cycle`] per back edge.
///
/// The four mutable arguments have these roles:
/// * `visited` accumulates every tensor reached and is shared across roots by the caller.
/// * `rec_stack` holds exactly the tensors on the current DFS path.
/// * `path` holds those same tensors in visiting order.
/// * `cycles` receives the cycles found.
///
/// When an edge `u -> w` reaches a tensor `w` that is still on the path, the suffix of
/// `path` starting at `w` is a directed cycle. That suffix is pushed onto `cycles` with
/// an empty `nodes` list. On return, `path` and `rec_stack` are back in their entry state.
///
/// The search keeps an explicit frame stack instead of recursing, so very long tensor
/// chains cannot overflow the thread's call stack. Visiting order and reported cycles
/// are identical to the equivalent recursive search.
fn dfs_find_cycles(
    tensor: usize,
    adjacency: &HashMap<usize, Vec<usize>>,
    visited: &mut HashSet<usize>,
    rec_stack: &mut HashSet<usize>,
    path: &mut Vec<usize>,
    cycles: &mut Vec<Cycle>,
) {
    visited.insert(tensor);
    rec_stack.insert(tensor);
    path.push(tensor);
    // Each frame is (tensor, index of the next outgoing edge to examine).
    let mut frames: Vec<(usize, usize)> = vec![(tensor, 0)];

    while let Some(frame) = frames.last_mut() {
        let vertex = frame.0;
        let next = adjacency
            .get(&vertex)
            .and_then(|successors| successors.get(frame.1))
            .copied();
        frame.1 += 1;

        match next {
            None => {
                // Every outgoing edge explored: unwind exactly like returning from recursion.
                frames.pop();
                path.pop();
                rec_stack.remove(&vertex);
            }
            Some(neighbor) => {
                if !visited.contains(&neighbor) {
                    // Tree edge: descend into `neighbor`.
                    visited.insert(neighbor);
                    rec_stack.insert(neighbor);
                    path.push(neighbor);
                    frames.push((neighbor, 0));
                } else if rec_stack.contains(&neighbor) {
                    // Back edge to an ancestor on the current path: the path suffix from
                    // that ancestor onward is a cycle.
                    if let Some(cycle_start) = path.iter().position(|&t| t == neighbor) {
                        cycles.push(Cycle {
                            tensors: path[cycle_start..].to_vec(),
                            nodes: Vec::new(),
                        });
                    }
                }
            }
        }
    }
}
fn build_tensor_adjacency(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
for node in &graph.nodes {
for &input_tensor in &node.inputs {
for &output_tensor in &node.outputs {
adjacency
.entry(input_tensor)
.or_default()
.push(output_tensor);
}
}
}
adjacency
}
/// One strongly connected component of the tensor-level dependency graph.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct StronglyConnectedComponent {
    /// Tensor indices in the component, in the order they were popped off Tarjan's stack.
    pub tensors: Vec<usize>,
    /// Node indices belonging to the component.
    /// NOTE(review): `TarjanSCC` in this module currently always leaves this empty.
    pub nodes: Vec<usize>,
}
pub fn strongly_connected_components(graph: &EinsumGraph) -> Vec<StronglyConnectedComponent> {
let mut tarjan = TarjanSCC::new(graph);
tarjan.find_sccs();
tarjan.sccs
}
/// Working state for Tarjan's strongly-connected-components algorithm over tensors.
struct TarjanSCC<'a> {
    /// Graph being analysed; after construction only its tensor count is read.
    graph: &'a EinsumGraph,
    /// Tensor successor lists, built once from `graph` by `build_tensor_adjacency`.
    adjacency: HashMap<usize, Vec<usize>>,
    /// Next DFS discovery index to hand out.
    index: usize,
    /// Discovery index of every tensor visited so far.
    indices: HashMap<usize, usize>,
    /// Tarjan low-link value per tensor: the smallest discovery index reachable from it
    /// through its DFS subtree plus at most one edge to a tensor still on `stack`.
    lowlinks: HashMap<usize, usize>,
    /// Set mirroring the contents of `stack`, for O(1) membership checks.
    on_stack: HashSet<usize>,
    /// Tarjan's stack of tensors whose component has not been emitted yet.
    stack: Vec<usize>,
    /// Components found so far, in the order they were completed.
    sccs: Vec<StronglyConnectedComponent>,
}
impl<'a> TarjanSCC<'a> {
    /// Creates empty working state for `graph`, building its tensor adjacency once.
    fn new(graph: &'a EinsumGraph) -> Self {
        TarjanSCC {
            graph,
            adjacency: build_tensor_adjacency(graph),
            index: 0,
            indices: HashMap::new(),
            lowlinks: HashMap::new(),
            on_stack: HashSet::new(),
            stack: Vec::new(),
            sccs: Vec::new(),
        }
    }

    /// Runs the search from every not-yet-visited tensor in ascending index order and
    /// appends each completed component to `self.sccs`.
    fn find_sccs(&mut self) {
        for tensor_idx in 0..self.graph.tensors.len() {
            if !self.indices.contains_key(&tensor_idx) {
                self.strong_connect(tensor_idx);
            }
        }
    }

    /// Tarjan's `strongconnect(v)`, written without recursion.
    ///
    /// Tensors are discovered and components emitted in exactly the same order as the
    /// textbook recursive version. The DFS path lives in the local `frames` vector instead
    /// of the call stack, so long tensor chains cannot overflow the thread stack.
    fn strong_connect(&mut self, v: usize) {
        self.discover(v);
        // Simulated call stack: (tensor, index of its next outgoing edge to examine).
        let mut frames: Vec<(usize, usize)> = vec![(v, 0)];

        while let Some(frame) = frames.last_mut() {
            let u = frame.0;
            let next = self
                .adjacency
                .get(&u)
                .and_then(|successors| successors.get(frame.1))
                .copied();
            frame.1 += 1;

            match next {
                Some(w) if !self.indices.contains_key(&w) => {
                    // Tree edge: this is where the recursive version would call itself.
                    self.discover(w);
                    frames.push((w, 0));
                }
                Some(w) => {
                    if self.on_stack.contains(&w) {
                        // Edge to a tensor still on the stack: it belongs to u's component.
                        let w_index = *self
                            .indices
                            .get(&w)
                            .expect("index must exist for visited node");
                        self.lower_lowlink(u, w_index);
                    }
                }
                None => {
                    // All successors of `u` explored: "return" from u's frame.
                    frames.pop();
                    self.pop_component_if_root(u);
                    if let Some(&(parent, _)) = frames.last() {
                        let u_lowlink = *self
                            .lowlinks
                            .get(&u)
                            .expect("lowlink must exist for visited node");
                        self.lower_lowlink(parent, u_lowlink);
                    }
                }
            }
        }
    }

    /// Assigns the next discovery index to `v` and pushes it onto Tarjan's stack.
    fn discover(&mut self, v: usize) {
        self.indices.insert(v, self.index);
        self.lowlinks.insert(v, self.index);
        self.index += 1;
        self.stack.push(v);
        self.on_stack.insert(v);
    }

    /// Lowers the low-link of `v` to `candidate` if `candidate` is smaller.
    fn lower_lowlink(&mut self, v: usize, candidate: usize) {
        let lowlink = self
            .lowlinks
            .get_mut(&v)
            .expect("lowlink must exist for visited node");
        *lowlink = (*lowlink).min(candidate);
    }

    /// Called once every successor of `v` has been explored. If `v` roots a component
    /// (its low-link equals its own index), pops that component off the stack and records it.
    fn pop_component_if_root(&mut self, v: usize) {
        if self.lowlinks[&v] != self.indices[&v] {
            return;
        }
        let mut scc_tensors = Vec::new();
        loop {
            let w = self
                .stack
                .pop()
                .expect("stack must be non-empty when processing SCC");
            self.on_stack.remove(&w);
            scc_tensors.push(w);
            if w == v {
                break;
            }
        }
        self.sccs.push(StronglyConnectedComponent {
            tensors: scc_tensors,
            nodes: Vec::new(),
        });
    }
}
pub fn topological_sort(graph: &EinsumGraph) -> Option<Vec<usize>> {
let adjacency = build_tensor_adjacency(graph);
let mut in_degree = vec![0; graph.tensors.len()];
for neighbors in adjacency.values() {
for &neighbor in neighbors {
in_degree[neighbor] += 1;
}
}
let mut queue: VecDeque<usize> = in_degree
.iter()
.enumerate()
.filter(|(_, °)| deg == 0)
.map(|(idx, _)| idx)
.collect();
let mut result = Vec::new();
while let Some(tensor) = queue.pop_front() {
result.push(tensor);
if let Some(neighbors) = adjacency.get(&tensor) {
for &neighbor in neighbors {
in_degree[neighbor] -= 1;
if in_degree[neighbor] == 0 {
queue.push_back(neighbor);
}
}
}
}
if result.len() == graph.tensors.len() {
Some(result)
} else {
None
}
}
/// Returns `true` when the tensor dependency graph of `graph` has no directed cycle.
///
/// Delegates to [`topological_sort`], which yields an ordering exactly when the graph
/// is acyclic.
pub fn is_dag(graph: &EinsumGraph) -> bool {
    let ordering = topological_sort(graph);
    ordering.is_some()
}
/// Outcome of [`are_isomorphic`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum IsomorphismResult {
    /// The graphs are isomorphic. `mapping` sends each tensor index of the first graph to
    /// the corresponding tensor index of the second graph.
    Isomorphic { mapping: HashMap<usize, usize> },
    /// The graphs were found not to be isomorphic.
    NotIsomorphic,
}
/// Decides whether `g1` and `g2` have isomorphic tensor-level dependency graphs.
///
/// Cheap invariants are checked first: tensor count, node count and the sorted per-tensor
/// degree sequence. If any of them differ, no search is attempted. Otherwise a backtracking
/// search looks for a tensor bijection, whose worst case is exponential in the tensor
/// count. Only graph structure is compared; node operations are not inspected here.
pub fn are_isomorphic(g1: &EinsumGraph, g2: &EinsumGraph) -> IsomorphismResult {
    let same_sizes =
        g1.tensors.len() == g2.tensors.len() && g1.nodes.len() == g2.nodes.len();
    if !same_sizes || compute_degree_sequence(g1) != compute_degree_sequence(g2) {
        return IsomorphismResult::NotIsomorphic;
    }

    let mut mapping = HashMap::new();
    if !backtrack_isomorphism(g1, g2, &mut mapping, 0) {
        return IsomorphismResult::NotIsomorphic;
    }
    IsomorphismResult::Isomorphic { mapping }
}
/// Returns the sorted list of `(in_degree, out_degree)` pairs, one pair per tensor.
///
/// `in_degree` counts how often the tensor is listed among some node's outputs, and
/// `out_degree` how often it is listed among some node's inputs (repeats included).
/// Sorting makes the result independent of tensor numbering. `are_isomorphic` compares
/// these sequences before attempting a search.
///
/// # Panics
///
/// Panics if a node refers to a tensor index outside `0..graph.tensors.len()`.
fn compute_degree_sequence(graph: &EinsumGraph) -> Vec<(usize, usize)> {
    // One (in, out) counter pair per tensor, filled in a single pass over the nodes.
    let mut pairs = vec![(0usize, 0usize); graph.tensors.len()];
    for node in &graph.nodes {
        for &consumed in &node.inputs {
            pairs[consumed].1 += 1;
        }
        for &produced in &node.outputs {
            pairs[produced].0 += 1;
        }
    }
    pairs.sort_unstable();
    pairs
}
fn backtrack_isomorphism(
g1: &EinsumGraph,
g2: &EinsumGraph,
mapping: &mut HashMap<usize, usize>,
tensor_idx: usize,
) -> bool {
if tensor_idx >= g1.tensors.len() {
return verify_isomorphism(g1, g2, mapping);
}
let mapped_values: HashSet<usize> = mapping.values().copied().collect();
for candidate in 0..g2.tensors.len() {
if !mapped_values.contains(&candidate) {
mapping.insert(tensor_idx, candidate);
if backtrack_isomorphism(g1, g2, mapping, tensor_idx + 1) {
return true;
}
mapping.remove(&tensor_idx);
}
}
false
}
fn verify_isomorphism(g1: &EinsumGraph, g2: &EinsumGraph, mapping: &HashMap<usize, usize>) -> bool {
let adj1 = build_tensor_adjacency(g1);
let adj2 = build_tensor_adjacency(g2);
for (u, neighbors) in &adj1 {
let u_mapped = mapping[u];
for &v in neighbors {
let v_mapped = mapping[&v];
if let Some(adj2_neighbors) = adj2.get(&u_mapped) {
if !adj2_neighbors.contains(&v_mapped) {
return false;
}
} else {
return false;
}
}
}
true
}
/// Result of [`critical_path_analysis`]: the heaviest path through the tensor graph.
///
/// `Eq` is not derived because `length` is an `f64`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CriticalPath {
    /// Tensor indices along the path, from its first tensor to its last.
    pub tensors: Vec<usize>,
    /// Node indices along the path.
    /// NOTE(review): `critical_path_analysis` currently always leaves this empty.
    pub nodes: Vec<usize>,
    /// Accumulated weight of the path: the per-tensor weights, or 1.0 per edge by default.
    pub length: f64,
}
/// Finds the heaviest path through the tensor dependency graph of `graph`.
///
/// `weights` maps a tensor index to the cost charged when a path *enters* that tensor;
/// tensors missing from the map cost `1.0`. The first tensor of a path contributes
/// nothing, so with no weights `length` is the number of edges on the path.
///
/// Lengths are computed by longest-path dynamic programming in topological order. Every
/// tensor starts at `0.0` and is only updated to a strictly larger value, so lengths never
/// drop below `0.0`.
///
/// Returns `None` if the graph contains a cycle or has no tensors. When several tensors
/// end a path of maximal length, the one that comes first in topological order is chosen,
/// so the result is the same on every run. `nodes` is left empty.
pub fn critical_path_analysis(
    graph: &EinsumGraph,
    weights: &HashMap<usize, f64>,
) -> Option<CriticalPath> {
    // `topological_sort` yields `None` exactly when the graph is cyclic, so a separate
    // `is_dag` call would only repeat the same work.
    let topo_order = topological_sort(graph)?;
    let adjacency = build_tensor_adjacency(graph);

    // Visiting tensors in topological order guarantees `distances[&u]` is final before
    // `u`'s outgoing edges are relaxed.
    let mut distances: HashMap<usize, f64> = topo_order.iter().map(|&t| (t, 0.0)).collect();
    let mut predecessors: HashMap<usize, usize> = HashMap::new();
    for &u in &topo_order {
        let u_dist = distances[&u];
        for &v in adjacency.get(&u).into_iter().flatten() {
            let new_dist = u_dist + weights.get(&v).copied().unwrap_or(1.0);
            if new_dist > distances.get(&v).copied().unwrap_or(0.0) {
                distances.insert(v, new_dist);
                predecessors.insert(v, u);
            }
        }
    }

    // Scan in topological order (not HashMap order) and keep the first strict maximum,
    // so ties are broken deterministically.
    let mut best: Option<(usize, f64)> = None;
    for &tensor in &topo_order {
        let dist = distances[&tensor];
        let improves = match best {
            None => true,
            Some((_, best_dist)) => dist > best_dist,
        };
        if improves {
            best = Some((tensor, dist));
        }
    }
    let (end_tensor, max_dist) = best?;

    // Predecessor links always point to an earlier tensor in topological order, so this
    // walk back to the path's first tensor terminates.
    let mut path: Vec<usize> =
        std::iter::successors(Some(end_tensor), |t| predecessors.get(t).copied()).collect();
    path.reverse();

    Some(CriticalPath {
        tensors: path,
        nodes: Vec::new(),
        length: max_dist,
    })
}
/// Returns the largest finite hop count between any ordered pair of tensors.
///
/// A breadth-first search is run from every tensor over the directed tensor graph.
/// Unreachable pairs are ignored, so this is the largest shortest-path distance among
/// reachable pairs. The result is always `Some`; an empty or edgeless graph gives `Some(0)`.
pub fn graph_diameter(graph: &EinsumGraph) -> Option<usize> {
    let adjacency = build_tensor_adjacency(graph);
    let longest = (0..graph.tensors.len())
        .filter_map(|source| bfs_distances(&adjacency, source).values().max().copied())
        .max()
        .unwrap_or(0);
    Some(longest)
}
/// Computes unweighted shortest-path hop counts from `source` to every reachable tensor.
///
/// `source` itself maps to `0`. Tensors that cannot be reached from it are absent from
/// the result, and `source` need not appear as a key of `adjacency`.
fn bfs_distances(adjacency: &HashMap<usize, Vec<usize>>, source: usize) -> HashMap<usize, usize> {
    let mut hops: HashMap<usize, usize> = HashMap::new();
    hops.insert(source, 0);

    // Expand one BFS layer at a time: a tensor first reached while building layer `depth`
    // is exactly `depth` edges away from `source`.
    let mut frontier = vec![source];
    let mut depth = 0;
    while !frontier.is_empty() {
        depth += 1;
        let mut next_layer = Vec::new();
        for tensor in frontier {
            for &succ in adjacency.get(&tensor).into_iter().flatten() {
                if let std::collections::hash_map::Entry::Vacant(slot) = hops.entry(succ) {
                    slot.insert(depth);
                    next_layer.push(succ);
                }
            }
        }
        frontier = next_layer;
    }
    hops
}
/// Enumerates every simple directed path (no repeated tensor) from `from` to `to`.
///
/// Paths are listed in depth-first order, following successor lists in adjacency order.
/// When `from == to` the result is the single trivial path `[from]`. The number of simple
/// paths can grow exponentially with graph size, so use this only on small graphs or
/// narrow regions.
pub fn find_all_paths(graph: &EinsumGraph, from: usize, to: usize) -> Vec<Vec<usize>> {
    let adjacency = build_tensor_adjacency(graph);
    let mut found: Vec<Vec<usize>> = Vec::new();
    let mut trail: Vec<usize> = Vec::new();
    let mut on_trail: HashSet<usize> = HashSet::new();
    dfs_all_paths(from, to, &adjacency, &mut trail, &mut on_trail, &mut found);
    found
}
/// Depth-first helper for `find_all_paths`: appends to `paths` every simple path that
/// starts with the current contents of `path` followed by `tensor` and ends at `target`.
///
/// `visited` tracks the tensors on the current path, which keeps every recorded path
/// simple. Both `path` and `visited` are back in their entry state on return. The search
/// never continues past `target`: anything beyond it would have to revisit `target` to be
/// recorded.
///
/// The search uses an explicit frame stack instead of recursion, so very long tensor
/// chains cannot overflow the thread's call stack. Paths are produced in the same order
/// as the equivalent recursive search.
fn dfs_all_paths(
    tensor: usize,
    target: usize,
    adjacency: &HashMap<usize, Vec<usize>>,
    path: &mut Vec<usize>,
    visited: &mut HashSet<usize>,
    paths: &mut Vec<Vec<usize>>,
) {
    path.push(tensor);
    if tensor == target {
        paths.push(path.clone());
        path.pop();
        return;
    }
    visited.insert(tensor);
    // Each frame is (tensor, index of the next outgoing edge to examine).
    let mut frames: Vec<(usize, usize)> = vec![(tensor, 0)];

    while let Some(frame) = frames.last_mut() {
        let at = frame.0;
        let next = adjacency
            .get(&at)
            .and_then(|successors| successors.get(frame.1))
            .copied();
        frame.1 += 1;

        match next {
            None => {
                // Every successor tried: step back off `at`.
                frames.pop();
                path.pop();
                visited.remove(&at);
            }
            Some(neighbor) if visited.contains(&neighbor) => {}
            Some(neighbor) => {
                path.push(neighbor);
                if neighbor == target {
                    paths.push(path.clone());
                    path.pop();
                } else {
                    visited.insert(neighbor);
                    frames.push((neighbor, 0));
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{EinsumNode, OpType};

    /// Builds the single-node graph `C = einsum("ij,jk->ik", A, B)`.
    ///
    /// The assertions below assume `add_tensor` hands out indices 0, 1, 2 to A, B, C.
    fn create_simple_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        let node = EinsumNode {
            op: OpType::Einsum {
                spec: "ij,jk->ik".to_string(),
            },
            inputs: vec![a, b],
            outputs: vec![c],
            metadata: Default::default(),
        };
        graph
            .add_node(node)
            .expect("adding a matmul node to a fresh graph should succeed");
        graph
    }

    #[test]
    fn test_acyclic_graph_no_cycles() {
        let graph = create_simple_graph();
        let cycles = find_cycles(&graph);
        assert!(cycles.is_empty());
    }

    #[test]
    fn test_is_dag() {
        let graph = create_simple_graph();
        assert!(is_dag(&graph));
    }

    #[test]
    fn test_topological_sort() {
        let graph = create_simple_graph();
        let order = topological_sort(&graph).expect("an acyclic graph has a topological order");
        // Both sources come first in index order; C depends on both of them.
        assert_eq!(order, vec![0, 1, 2]);
    }

    #[test]
    fn test_strongly_connected_components() {
        let graph = create_simple_graph();
        let sccs = strongly_connected_components(&graph);
        assert_eq!(sccs.len(), 3);
        assert!(
            sccs.iter().all(|scc| scc.tensors.len() == 1),
            "an acyclic graph has only singleton components"
        );
    }

    #[test]
    fn test_graph_diameter() {
        let graph = create_simple_graph();
        assert_eq!(graph_diameter(&graph), Some(1));
    }

    #[test]
    fn test_critical_path() {
        let graph = create_simple_graph();
        let weights = HashMap::new();
        let critical = critical_path_analysis(&graph, &weights)
            .expect("an acyclic, non-empty graph has a critical path");
        assert_eq!(critical.tensors, vec![0, 2]);
        assert!((critical.length - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_find_all_paths() {
        let graph = create_simple_graph();
        assert_eq!(find_all_paths(&graph, 0, 2), vec![vec![0, 2]]);
        assert!(
            find_all_paths(&graph, 0, 1).is_empty(),
            "A and B are independent inputs"
        );
        assert_eq!(find_all_paths(&graph, 2, 2), vec![vec![2]]);
    }

    #[test]
    fn test_isomorphism_identical_graphs() {
        let g1 = create_simple_graph();
        let g2 = create_simple_graph();
        match are_isomorphic(&g1, &g2) {
            IsomorphismResult::Isomorphic { mapping } => assert_eq!(mapping.len(), 3),
            IsomorphismResult::NotIsomorphic => panic!("identical graphs must be isomorphic"),
        }
    }

    #[test]
    fn test_isomorphism_different_sizes() {
        let g1 = create_simple_graph();
        let mut g2 = EinsumGraph::new();
        g2.add_tensor("A");
        let result = are_isomorphic(&g1, &g2);
        assert_eq!(result, IsomorphismResult::NotIsomorphic);
    }

    #[test]
    fn test_tensor_adjacency() {
        let graph = create_simple_graph();
        let adj = build_tensor_adjacency(&graph);
        assert_eq!(adj.get(&0), Some(&vec![2]));
        assert_eq!(adj.get(&1), Some(&vec![2]));
        assert!(!adj.contains_key(&2), "C has no consumers");
    }

    #[test]
    fn test_degree_sequence() {
        let graph = create_simple_graph();
        // (in, out) per tensor: A and B are each consumed once, C is produced once.
        assert_eq!(compute_degree_sequence(&graph), vec![(0, 1), (0, 1), (1, 0)]);
    }

    #[test]
    fn test_bfs_distances() {
        let mut adj = HashMap::new();
        adj.insert(0, vec![1, 2]);
        adj.insert(1, vec![3]);
        adj.insert(2, vec![3]);
        let distances = bfs_distances(&adj, 0);
        assert_eq!(distances.len(), 4);
        assert_eq!(distances[&0], 0);
        assert_eq!(distances[&1], 1);
        assert_eq!(distances[&2], 1);
        assert_eq!(distances[&3], 2);
    }
}