use crate::graph::{Graph, TensorID};
use crate::Float;
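/// A group of graph nodes at the same dependency depth; all of their inputs
/// are produced by earlier levels, so the whole group can be evaluated
/// concurrently.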
#[derive(Debug, Clone)]
pub struct ParallelLevel {
pub level: usize,
pub nodes: Vec<TensorID>,
pub width: usize,
}
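/// The longest dependency chain in the graph. `length` counts edges, so an
/// isolated node has length 0; `total_cost` mirrors `length` under a
/// unit-cost model.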
#[derive(Debug, Clone)]
pub struct CriticalPath {
pub path: Vec<TensorID>,
pub length: usize,
pub total_cost: f64,
}
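/// One schedulable unit of work: a graph node plus the metadata a scheduler
/// needs (its priority, its input tensors, and how many nodes depend on it).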
#[derive(Debug, Clone)]
pub struct ReadyTask {
pub node_id: TensorID,
pub priority: usize,
pub op_name: String,
pub inputs: Vec<TensorID>,
pub num_dependents: usize,
}
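/// A static assignment of tasks to per-worker queues, together with summary
/// statistics about the scheduled graph.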
#[derive(Debug, Clone)]
pub struct WorkStealingSchedule {
pub num_workers: usize,
pub worker_queues: Vec<Vec<ReadyTask>>,
pub total_tasks: usize,
pub max_parallelism: usize,
pub critical_path_length: usize,
pub estimated_speedup: f64,
}
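/// Work/span style summary of a graph's parallel structure: its level
/// decomposition, critical path, and derived parallelism metrics.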
#[derive(Debug, Clone)]
pub struct ParallelAnalysis {
pub levels: Vec<ParallelLevel>,
pub critical_path: CriticalPath,
pub max_parallelism: usize,
pub average_parallelism: f64,
pub total_work: usize,
pub span: usize,
}
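/// Computes each node's depth: the length in edges of the longest path from
/// any source node, using one pass in topological order.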
fn compute_node_depths<F: Float>(graph: &Graph<F>) -> Vec<usize> {
let nodes = graph.node_set.borrow();
let n = nodes.len();
let mut depth = vec![0usize; n];
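// Visit nodes in topological order so every predecessor's depth is final
// before its consumers are processed.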
let mut order: Vec<TensorID> = (0..n).collect();
order.sort_by_key(|&id| nodes[id].topo_rank);
for &id in &order {
for inc in &nodes[id].incoming_nodes {
let pid = inc.id;
if pid < n {
let candidate = depth[pid] + 1;
if candidate > depth[id] {
depth[id] = candidate;
}
}
}
}
depth
}
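/// Groups nodes by depth. Each returned `ParallelLevel` holds nodes whose
/// dependencies are all satisfied by earlier levels, so the level can run in
/// parallel; `width` is the level's node count.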
pub fn level_decomposition<F: Float>(graph: &Graph<F>) -> Vec<ParallelLevel> {
let depths = compute_node_depths(graph);
let n = depths.len();
if n == 0 {
return Vec::new();
}
let max_depth = depths.iter().copied().max().unwrap_or(0);
let mut levels: Vec<Vec<TensorID>> = vec![Vec::new(); max_depth + 1];
for id in 0..n {
levels[depths[id]].push(id);
}
levels
.into_iter()
.enumerate()
.filter(|(_, nodes)| !nodes.is_empty())
.map(|(level, nodes)| {
let width = nodes.len();
ParallelLevel {
level,
nodes,
width,
}
})
.collect()
}
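/// Finds the longest dependency chain via longest-path dynamic programming in
/// topological order, then reconstructs it by walking predecessor links back
/// from the deepest node.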
pub fn critical_path<F: Float>(graph: &Graph<F>) -> CriticalPath {
let nodes = graph.node_set.borrow();
let n = nodes.len();
if n == 0 {
return CriticalPath {
path: Vec::new(),
length: 0,
total_cost: 0.0,
};
}
let mut dist = vec![0usize; n];
let mut predecessor = vec![None::<TensorID>; n];
let mut order: Vec<TensorID> = (0..n).collect();
order.sort_by_key(|&id| nodes[id].topo_rank);
for &id in &order {
for inc in &nodes[id].incoming_nodes {
let pid = inc.id;
if pid < n {
let candidate = dist[pid] + 1;
if candidate > dist[id] {
dist[id] = candidate;
predecessor[id] = Some(pid);
}
}
}
}
let end_node = (0..n).max_by_key(|&id| dist[id]).unwrap_or(0);
let length = dist[end_node];
let mut path = Vec::new();
let mut current = Some(end_node);
while let Some(nid) = current {
path.push(nid);
current = predecessor[nid];
}
path.reverse();
CriticalPath {
path,
length,
total_cost: length as f64,
}
}
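/// Computes each node's priority as its height: the length in edges of the
/// longest path from that node to any sink. Critical-path nodes receive the
/// highest priorities and are therefore scheduled first.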
fn compute_priorities<F: Float>(graph: &Graph<F>) -> Vec<usize> {
let nodes = graph.node_set.borrow();
let n = nodes.len();
let mut priority = vec![0usize; n];
let mut children: Vec<Vec<TensorID>> = vec![Vec::new(); n];
for node in nodes.iter() {
for inc in &node.incoming_nodes {
if inc.id < n {
children[inc.id].push(node.id);
}
}
}
let mut order: Vec<TensorID> = (0..n).collect();
order.sort_by_key(|&id| std::cmp::Reverse(nodes[id].topo_rank));
for &id in &order {
for &child in &children[id] {
let candidate = priority[child] + 1;
if candidate > priority[id] {
priority[id] = candidate;
}
}
}
priority
}
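/// Builds a static schedule for `num_workers` workers: tasks are sorted by
/// descending priority (critical-path first) and dealt round-robin into the
/// worker queues. `estimated_speedup` is the upper bound given by total work
/// divided by the critical-path length.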
pub fn work_stealing_schedule<F: Float>(
graph: &Graph<F>,
num_workers: usize,
) -> WorkStealingSchedule {
let num_workers = num_workers.max(1);
let nodes_ref = graph.node_set.borrow();
let n = nodes_ref.len();
if n == 0 {
return WorkStealingSchedule {
num_workers,
worker_queues: vec![Vec::new(); num_workers],
total_tasks: 0,
max_parallelism: 0,
critical_path_length: 0,
estimated_speedup: 1.0,
};
}
let node_data: Vec<(String, Vec<TensorID>)> = nodes_ref
.iter()
.map(|nd| {
let op_name = nd
.op
.as_ref()
.map(|o| o.name().to_owned())
.unwrap_or_else(|| "source".to_owned());
let inputs: Vec<TensorID> = nd.incoming_nodes.iter().map(|inc| inc.id).collect();
(op_name, inputs)
})
.collect();
let mut children_count: Vec<usize> = vec![0; n];
for nd in nodes_ref.iter() {
for inc in &nd.incoming_nodes {
if inc.id < n {
children_count[inc.id] += 1;
}
}
}
drop(nodes_ref);
let priorities = compute_priorities(graph);
let levels = level_decomposition(graph);
let cp = critical_path(graph);
let mut all_tasks: Vec<ReadyTask> = (0..n)
.map(|id| ReadyTask {
node_id: id,
priority: priorities[id],
op_name: node_data[id].0.clone(),
inputs: node_data[id].1.clone(),
num_dependents: children_count[id],
})
.collect();
all_tasks.sort_by(|a, b| b.priority.cmp(&a.priority));
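// Deal tasks round-robin across workers in priority order; a run-time
// work-stealing scheduler can rebalance the queues during execution.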
let mut worker_queues: Vec<Vec<ReadyTask>> = vec![Vec::new(); num_workers];
for (i, task) in all_tasks.into_iter().enumerate() {
worker_queues[i % num_workers].push(task);
}
let max_parallelism = levels.iter().map(|l| l.width).max().unwrap_or(1);
let cp_len = cp.length;
let speedup = if cp_len > 0 {
n as f64 / cp_len as f64
} else {
1.0
};
WorkStealingSchedule {
num_workers,
worker_queues,
total_tasks: n,
max_parallelism,
critical_path_length: cp_len,
estimated_speedup: speedup,
}
}
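/// Computes a work/span summary: `total_work` is the node count, `span` is
/// the critical-path length in edges, and `average_parallelism` is their
/// ratio (falling back to `total_work` when the span is zero).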
pub fn parallel_analysis<F: Float>(graph: &Graph<F>) -> ParallelAnalysis {
let levels = level_decomposition(graph);
let cp = critical_path(graph);
let total_work: usize = levels.iter().map(|l| l.width).sum();
let max_par = levels.iter().map(|l| l.width).max().unwrap_or(0);
let span = cp.length;
let avg_par = if span > 0 {
total_work as f64 / span as f64
} else if total_work > 0 {
total_work as f64
} else {
0.0
};
ParallelAnalysis {
levels,
critical_path: cp,
max_parallelism: max_par,
average_parallelism: avg_par,
total_work,
span,
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::graph::AsGraph;
use crate::tensor_ops as T;
use crate::VariableEnvironment;
#[test]
fn test_level_decomposition_linear() {
let env = VariableEnvironment::<f32>::new();
env.run(|ctx| {
let a = T::zeros(&[2], ctx);
let b = a + T::ones(&[2], ctx);
let _ = b * T::ones(&[2], ctx);
let levels = level_decomposition(ctx.as_graph());
assert!(!levels.is_empty());
assert_eq!(levels[0].level, 0);
});
}
#[test]
fn test_level_decomposition_wide() {
let env = VariableEnvironment::<f32>::new();
env.run(|ctx| {
let a = T::zeros(&[2], ctx);
let b = T::ones(&[2], ctx);
let c = T::zeros(&[2], ctx);
let d = a + b;
let _ = d + c;
let levels = level_decomposition(ctx.as_graph());
assert!(levels[0].width >= 2, "Expected wide first level");
});
}
#[test]
fn test_critical_path() {
let env = VariableEnvironment::<f32>::new();
env.run(|ctx| {
let a = T::zeros(&[2], ctx);
let b = a + T::ones(&[2], ctx);
let c = b * T::ones(&[2], ctx);
let _ = c + T::ones(&[2], ctx);
let cp = critical_path(ctx.as_graph());
assert!(cp.length >= 1, "Critical path should have length >= 1");
assert!(!cp.path.is_empty());
});
}
#[test]
fn test_critical_path_empty_graph() {
let env = VariableEnvironment::<f32>::new();
env.run(|ctx| {
let cp = critical_path(ctx.as_graph());
assert_eq!(cp.length, 0);
assert!(cp.path.is_empty());
});
}
#[test]
fn test_work_stealing_schedule() {
let env = VariableEnvironment::<f32>::new();
env.run(|ctx| {
let a = T::zeros(&[3], ctx);
let b = T::ones(&[3], ctx);
let c = a + b;
let d = a * b;
let _ = c + d;
let ws = work_stealing_schedule(ctx.as_graph(), 4);
assert_eq!(ws.num_workers, 4);
assert!(ws.total_tasks > 0);
assert!(ws.max_parallelism >= 1);
let total_distributed: usize = ws.worker_queues.iter().map(|q| q.len()).sum();
assert_eq!(total_distributed, ws.total_tasks);
});
}
#[test]
fn test_work_stealing_single_worker() {
let env = VariableEnvironment::<f32>::new();
env.run(|ctx| {
let a = T::zeros(&[2], ctx);
let _ = a + T::ones(&[2], ctx);
let ws = work_stealing_schedule(ctx.as_graph(), 1);
assert_eq!(ws.num_workers, 1);
assert_eq!(ws.worker_queues.len(), 1);
assert_eq!(ws.worker_queues[0].len(), ws.total_tasks);
});
}
#[test]
fn test_parallel_analysis() {
let env = VariableEnvironment::<f32>::new();
env.run(|ctx| {
let a = T::zeros(&[4], ctx);
let b = T::ones(&[4], ctx);
let c = a + b;
let d = a * b;
let _ = c + d;
let analysis = parallel_analysis(ctx.as_graph());
assert!(analysis.total_work > 0);
assert!(analysis.max_parallelism >= 1);
assert!(analysis.average_parallelism > 0.0);
});
}
#[test]
fn test_parallel_analysis_empty() {
let env = VariableEnvironment::<f32>::new();
env.run(|ctx| {
let analysis = parallel_analysis(ctx.as_graph());
assert_eq!(analysis.total_work, 0);
assert_eq!(analysis.max_parallelism, 0);
});
}
#[test]
fn test_task_priorities() {
let env = VariableEnvironment::<f32>::new();
env.run(|ctx| {
let a = T::zeros(&[2], ctx);
let b = a + T::ones(&[2], ctx);
let _ = b * T::ones(&[2], ctx);
let priorities = compute_priorities(ctx.as_graph());
assert!(!priorities.is_empty());
});
}
}