use std::collections::{HashMap, HashSet, VecDeque};
use super::EinsumGraph;
use crate::error::IrError;
/// A set of graph nodes that were assigned the same dependency level.
///
/// `analyze_parallelization` builds one group for every level that contains
/// at least one node.
#[derive(Debug, Clone, PartialEq)]
pub struct ParallelGroup {
/// Indices of the nodes in this group (positions in the graph's node list).
pub nodes: Vec<usize>,
/// Cost estimate for the group; `analyze_parallelization` fills it with the
/// largest per-node estimate among `nodes`.
pub estimated_cost: f64,
/// Dependency level shared by every node in the group (levels start at 0).
pub level: usize,
}
/// Result of `analyze_parallelization`: level groups plus summary metrics.
#[derive(Debug, Clone)]
pub struct ParallelizationAnalysis {
/// One group per non-empty dependency level, in ascending level order.
pub parallel_groups: Vec<ParallelGroup>,
/// Number of nodes in the largest group (0 when there are no groups).
pub max_parallelism: usize,
/// Total node count divided by the number of groups (0.0 when there are no groups).
pub avg_parallelism: f64,
/// Number of nodes on `critical_path`.
pub critical_path_length: usize,
/// Node indices on the chain found by `find_critical_path`, earliest node first.
pub critical_path: Vec<usize>,
/// Sum of per-node cost estimates divided by the sum of per-group cost
/// estimates; 1.0 when the group sum is not positive.
pub estimated_speedup: f64,
}
impl ParallelizationAnalysis {
    /// Creates an analysis with no groups, an empty critical path, zeroed
    /// counters and a neutral speedup of `1.0`.
    pub fn new() -> Self {
        let parallel_groups = Vec::new();
        let critical_path = Vec::new();
        Self {
            parallel_groups,
            critical_path,
            max_parallelism: 0,
            critical_path_length: 0,
            avg_parallelism: 0.0,
            estimated_speedup: 1.0,
        }
    }

    /// Returns `true` when at least one group holds two or more nodes, as
    /// recorded in `max_parallelism`.
    pub fn has_parallelism(&self) -> bool {
        self.max_parallelism >= 2
    }

    /// Returns the number of nodes across all parallel groups.
    pub fn total_nodes(&self) -> usize {
        self.parallel_groups
            .iter()
            .fold(0, |count, group| count + group.nodes.len())
    }
}
impl Default for ParallelizationAnalysis {
fn default() -> Self {
Self::new()
}
}
pub fn analyze_parallelization(graph: &EinsumGraph) -> Result<ParallelizationAnalysis, IrError> {
if graph.nodes.is_empty() {
return Ok(ParallelizationAnalysis::new());
}
let (dependencies, dependents) = build_dependency_graph(graph);
let node_levels = compute_node_levels(graph, &dependencies);
let mut level_groups: HashMap<usize, Vec<usize>> = HashMap::new();
for (node_idx, &level) in node_levels.iter().enumerate() {
level_groups.entry(level).or_default().push(node_idx);
}
let mut parallel_groups = Vec::new();
let max_level = node_levels.iter().max().copied().unwrap_or(0);
for level in 0..=max_level {
if let Some(nodes) = level_groups.get(&level) {
let estimated_cost = estimate_group_cost(graph, nodes);
parallel_groups.push(ParallelGroup {
nodes: nodes.clone(),
estimated_cost,
level,
});
}
}
let max_parallelism = parallel_groups
.iter()
.map(|g| g.nodes.len())
.max()
.unwrap_or(0);
let total_nodes: usize = parallel_groups.iter().map(|g| g.nodes.len()).sum();
let avg_parallelism = if !parallel_groups.is_empty() {
total_nodes as f64 / parallel_groups.len() as f64
} else {
0.0
};
let (critical_path, critical_path_length) =
find_critical_path(graph, &node_levels, &dependents);
let sequential_cost: f64 = (0..graph.nodes.len())
.map(|i| estimate_node_cost(graph, i))
.sum();
let parallel_cost: f64 = parallel_groups.iter().map(|g| g.estimated_cost).sum();
let estimated_speedup = if parallel_cost > 0.0 {
sequential_cost / parallel_cost
} else {
1.0
};
Ok(ParallelizationAnalysis {
parallel_groups,
max_parallelism,
avg_parallelism,
critical_path_length,
critical_path,
estimated_speedup,
})
}
/// Derives node-to-node dependency edges from tensor producer/consumer links.
///
/// Returns `(dependencies, dependents)`:
/// * `dependencies[n]` lists, in input order, the node that produces each of
///   `n`'s inputs. Every node gets an entry, possibly empty.
/// * `dependents[p]` lists the nodes that consume an output of `p`. Nodes
///   without consumers have no entry.
///
/// Inputs with no producing node are ignored, a node never depends on itself,
/// and a producer appears once per input it feeds (duplicates are kept). When
/// several nodes list the same output tensor, the highest node index is the
/// producer.
fn build_dependency_graph(
    graph: &EinsumGraph,
) -> (HashMap<usize, Vec<usize>>, HashMap<usize, Vec<usize>>) {
    // Map every tensor to the node that writes it.
    let mut producer_of: HashMap<usize, usize> = HashMap::new();
    for (producer, node) in graph.nodes.iter().enumerate() {
        for tensor in node.outputs.iter().copied() {
            producer_of.insert(tensor, producer);
        }
    }

    let mut dependencies: HashMap<usize, Vec<usize>> = HashMap::new();
    let mut dependents: HashMap<usize, Vec<usize>> = HashMap::new();
    for (consumer, node) in graph.nodes.iter().enumerate() {
        let producers: Vec<usize> = node
            .inputs
            .iter()
            .filter_map(|tensor| producer_of.get(tensor).copied())
            .filter(|&producer| producer != consumer)
            .collect();
        for &producer in &producers {
            dependents.entry(producer).or_default().push(consumer);
        }
        dependencies.insert(consumer, producers);
    }

    (dependencies, dependents)
}
fn compute_node_levels(
graph: &EinsumGraph,
dependencies: &HashMap<usize, Vec<usize>>,
) -> Vec<usize> {
let mut levels = vec![0; graph.nodes.len()];
let mut in_degree = vec![0; graph.nodes.len()];
for (node_idx, deps) in dependencies.iter() {
in_degree[*node_idx] = deps.len();
}
let mut queue: VecDeque<usize> = VecDeque::new();
for (node_idx, °ree) in in_degree.iter().enumerate() {
if degree == 0 && node_idx < graph.nodes.len() {
queue.push_back(node_idx);
levels[node_idx] = 0;
}
}
let mut dependents: HashMap<usize, Vec<usize>> = HashMap::new();
for (node_idx, deps) in dependencies.iter() {
for &dep in deps {
dependents.entry(dep).or_default().push(*node_idx);
}
}
let mut visited = HashSet::new();
while let Some(node_idx) = queue.pop_front() {
if visited.contains(&node_idx) {
continue;
}
visited.insert(node_idx);
let current_level = levels[node_idx];
if let Some(deps) = dependents.get(&node_idx) {
for &dep_idx in deps {
if dep_idx < graph.nodes.len() {
levels[dep_idx] = levels[dep_idx].max(current_level + 1);
queue.push_back(dep_idx);
}
}
}
}
levels
}
/// Returns the largest `estimate_node_cost` among `nodes`, or `0.0` when
/// `nodes` is empty.
///
/// Costs that cannot be ordered against each other are treated as equal, and
/// on ties the later node's cost is kept.
fn estimate_group_cost(graph: &EinsumGraph, nodes: &[usize]) -> f64 {
    let mut costliest: Option<f64> = None;
    for &idx in nodes {
        let cost = estimate_node_cost(graph, idx);
        costliest = match costliest {
            // Keep the running maximum only when it is strictly greater.
            Some(best) if best.partial_cmp(&cost) == Some(std::cmp::Ordering::Greater) => {
                Some(best)
            }
            _ => Some(cost),
        };
    }
    costliest.unwrap_or(0.0)
}
/// Placeholder cost model: every node costs one unit.
///
/// Both arguments are currently ignored.
fn estimate_node_cost(_graph: &EinsumGraph, _node_idx: usize) -> f64 {
    const UNIT_COST: f64 = 1.0;
    UNIT_COST
}
/// Walks backwards from the first node at the highest level, stepping each
/// time to the predecessor with the highest level, and returns the visited
/// nodes in execution order together with their count.
///
/// * `node_levels` - level of each node, indexed by node index.
/// * `_dependents` - unused; kept so the signature stays stable.
///
/// Returns `(Vec::new(), 0)` when `node_levels` is empty.
///
/// Only predecessors that are in range of `node_levels` and sit on a
/// strictly lower level than the current node are followed. With consistent
/// levels on an acyclic graph this excludes nothing; with self-referencing
/// nodes or inconsistent levels it guarantees the walk terminates instead of
/// revisiting nodes forever. The walk stops early when no such predecessor
/// exists.
fn find_critical_path(
    graph: &EinsumGraph,
    node_levels: &[usize],
    _dependents: &HashMap<usize, Vec<usize>>,
) -> (Vec<usize>, usize) {
    let max_level = match node_levels.iter().copied().max() {
        Some(level) => level,
        None => return (Vec::new(), 0),
    };
    // First node (lowest index) that sits on the deepest level.
    let mut current = match node_levels.iter().position(|&level| level == max_level) {
        Some(idx) => idx,
        None => return (Vec::new(), 0),
    };

    let mut path = vec![current];
    while node_levels[current] > 0 {
        let current_level = node_levels[current];
        // Each step strictly lowers the level, so the loop runs at most
        // `max_level` times.
        let next = get_predecessors(graph, current)
            .into_iter()
            .filter(|&pred| {
                node_levels
                    .get(pred)
                    .map_or(false, |&level| level < current_level)
            })
            .max_by_key(|&pred| node_levels[pred]);
        match next {
            Some(pred) => {
                path.push(pred);
                current = pred;
            }
            None => break,
        }
    }

    path.reverse();
    let length = path.len();
    (path, length)
}
/// Returns the nodes that produce the inputs of node `node_idx`, in input
/// order (one entry per input that has a producer, duplicates kept).
///
/// A node is never listed as its own predecessor, matching the dependency
/// edges built by `build_dependency_graph`. Inputs without a producing node
/// are skipped, and an out-of-range `node_idx` yields an empty list. When
/// several nodes list the same output tensor, the highest node index counts
/// as the producer.
fn get_predecessors(graph: &EinsumGraph, node_idx: usize) -> Vec<usize> {
    let node = match graph.nodes.get(node_idx) {
        Some(node) => node,
        // Nothing to look up, so skip building the producer map.
        None => return Vec::new(),
    };

    let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
    for (idx, candidate) in graph.nodes.iter().enumerate() {
        for &output in &candidate.outputs {
            tensor_producer.insert(output, idx);
        }
    }

    node.inputs
        .iter()
        .filter_map(|input| tensor_producer.get(input).copied())
        .filter(|&producer| producer != node_idx)
        .collect()
}
/// Splits the graph's nodes into groups that share no dependency edges with
/// each other (connected components of the dependency graph, with edge
/// direction ignored).
///
/// Each inner vector lists the node indices of one component in the order
/// the depth-first walk reached them. Components are ordered by their
/// lowest node index. An empty graph yields an empty vector.
///
/// # Errors
///
/// The current implementation always returns `Ok`.
pub fn partition_independent_subgraphs(graph: &EinsumGraph) -> Result<Vec<Vec<usize>>, IrError> {
    if graph.nodes.is_empty() {
        return Ok(Vec::new());
    }

    let (dependencies, dependents) = build_dependency_graph(graph);
    let mut visited: HashSet<usize> = HashSet::new();
    let mut subgraphs = Vec::new();

    for node_idx in 0..graph.nodes.len() {
        if visited.contains(&node_idx) {
            continue;
        }

        let mut subgraph = Vec::new();
        let mut stack = vec![node_idx];
        while let Some(current) = stack.pop() {
            // `insert` returns false when the node was already claimed.
            if !visited.insert(current) {
                continue;
            }
            subgraph.push(current);
            // Follow edges in both directions so the whole component is reached.
            if let Some(deps) = dependencies.get(&current) {
                stack.extend(deps.iter().copied());
            }
            if let Some(users) = dependents.get(&current) {
                stack.extend(users.iter().copied());
            }
        }
        subgraphs.push(subgraph);
    }

    Ok(subgraphs)
}
// Unit tests for the parallelization analysis and the subgraph partitioner.
// NOTE(review): the graphs below assume `EinsumNode::elem_unary(op, input, output)`
// argument order — confirm against the `EinsumNode` definition.
#[cfg(test)]
mod tests {
use super::*;
use crate::graph::EinsumNode;
// The default analysis reports no parallelism.
#[test]
fn test_parallelization_analysis_default() {
let analysis = ParallelizationAnalysis::default();
assert_eq!(analysis.max_parallelism, 0);
assert!(!analysis.has_parallelism());
}
// A graph without nodes produces an analysis with no groups.
#[test]
fn test_analyze_empty_graph() {
let graph = EinsumGraph::new();
let analysis = analyze_parallelization(&graph).expect("unwrap");
assert_eq!(analysis.max_parallelism, 0);
assert_eq!(analysis.total_nodes(), 0);
}
// One node forms one group of size one.
#[test]
fn test_analyze_single_node() {
let mut graph = EinsumGraph::new();
let a = graph.add_tensor("A");
let b = graph.add_tensor("B");
graph
.add_node(EinsumNode::elem_unary("relu", a, b))
.expect("unwrap");
let analysis = analyze_parallelization(&graph).expect("unwrap");
assert_eq!(analysis.max_parallelism, 1);
assert_eq!(analysis.total_nodes(), 1);
}
// Two nodes that touch disjoint tensors are expected to share a level.
#[test]
fn test_analyze_parallel_nodes() {
let mut graph = EinsumGraph::new();
let a = graph.add_tensor("A");
let b = graph.add_tensor("B");
let c = graph.add_tensor("C");
let d = graph.add_tensor("D");
graph
.add_node(EinsumNode::elem_unary("relu", a, b))
.expect("unwrap");
graph
.add_node(EinsumNode::elem_unary("tanh", c, d))
.expect("unwrap");
let analysis = analyze_parallelization(&graph).expect("unwrap");
assert_eq!(analysis.max_parallelism, 2);
assert!(analysis.has_parallelism());
}
// The second node is built from tensor `b`, which the first node was also
// built with; the critical path is expected to contain both nodes.
#[test]
fn test_analyze_sequential_nodes() {
let mut graph = EinsumGraph::new();
let a = graph.add_tensor("A");
let b = graph.add_tensor("B");
let c = graph.add_tensor("C");
graph
.add_node(EinsumNode::elem_unary("relu", a, b))
.expect("unwrap");
graph
.add_node(EinsumNode::elem_unary("tanh", b, c))
.expect("unwrap");
let analysis = analyze_parallelization(&graph).expect("unwrap");
assert_eq!(analysis.critical_path_length, 2);
}
// Partitioning a graph without nodes returns no subgraphs.
#[test]
fn test_partition_empty_graph() {
let graph = EinsumGraph::new();
let subgraphs = partition_independent_subgraphs(&graph).expect("unwrap");
assert!(subgraphs.is_empty());
}
// A single node ends up alone in a single subgraph.
#[test]
fn test_partition_single_node() {
let mut graph = EinsumGraph::new();
let a = graph.add_tensor("A");
let b = graph.add_tensor("B");
graph
.add_node(EinsumNode::elem_unary("relu", a, b))
.expect("unwrap");
let subgraphs = partition_independent_subgraphs(&graph).expect("unwrap");
assert_eq!(subgraphs.len(), 1);
assert_eq!(subgraphs[0].len(), 1);
}
// Two nodes that touch disjoint tensors land in separate subgraphs.
#[test]
fn test_partition_independent_nodes() {
let mut graph = EinsumGraph::new();
let a = graph.add_tensor("A");
let b = graph.add_tensor("B");
let c = graph.add_tensor("C");
let d = graph.add_tensor("D");
graph
.add_node(EinsumNode::elem_unary("relu", a, b))
.expect("unwrap");
graph
.add_node(EinsumNode::elem_unary("tanh", c, d))
.expect("unwrap");
let subgraphs = partition_independent_subgraphs(&graph).expect("unwrap");
assert_eq!(subgraphs.len(), 2);
}
// The per-node estimate is 1.0; the index is not looked up in the graph,
// so an empty graph is sufficient here.
#[test]
fn test_estimate_node_cost() {
let graph = EinsumGraph::new();
let cost = estimate_node_cost(&graph, 0);
assert_eq!(cost, 1.0);
}
// A group's estimate is the largest per-node estimate, hence 1.0 as well.
#[test]
fn test_estimate_group_cost() {
let graph = EinsumGraph::new();
let cost = estimate_group_cost(&graph, &[0, 1, 2]);
assert_eq!(cost, 1.0); }
// Field values given at construction are readable back unchanged.
#[test]
fn test_parallel_group_creation() {
let group = ParallelGroup {
nodes: vec![0, 1, 2],
estimated_cost: 3.5,
level: 1,
};
assert_eq!(group.nodes.len(), 3);
assert_eq!(group.estimated_cost, 3.5);
assert_eq!(group.level, 1);
}
}