use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use thiserror::Error;
/// Errors reported by the automatic parallelization analysis.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum AutoParallelError {
/// The task graph contains a cycle, so no topological schedule exists.
#[error("Dependency cycle detected: {0}")]
DependencyCycle(String),
/// The graph is malformed, e.g. a node references an unknown dependency.
#[error("Invalid graph: {0}")]
InvalidGraph(String),
/// Cost estimation failed (variant not constructed in the visible code).
#[error("Cost model error: {0}")]
CostModelError(String),
/// Work could not be distributed across workers.
#[error("Partitioning failed: {0}")]
PartitioningFailed(String),
}
/// Identifier of a node in the task graph (currently a plain string).
pub type NodeId = String;
/// How aggressively the analyzer recommends parallel workers
/// (see `AutoParallelizer::recommend_worker_count`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ParallelizationStrategy {
/// Use at most half of the available workers (but at least one).
Conservative,
/// Match the measured parallelism factor, capped at `max_workers`.
Balanced,
/// Always use every available worker.
Aggressive,
/// Scale with parallelism only when it exceeds 2.0; otherwise back off.
CostBased,
}
/// Source of node cost estimates.
///
/// NOTE(review): this setting is stored and configurable, but the visible
/// analysis code does not branch on it yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CostModel {
Heuristic,
ProfileBased,
Analytical,
Hybrid,
}
/// Kind of edge between graph nodes.
///
/// NOTE(review): the scheduling code currently treats all dependency kinds
/// identically (the type is ignored when building the graph).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DependencyType {
Data,
Control,
Memory,
}
/// Static description of one node in the task graph to be parallelized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
/// Unique identifier of this node within the graph.
pub id: NodeId,
/// Operation kind (e.g. "input", "compute"); also the key for profile data.
pub op_type: String,
/// Estimated execution cost, memory footprint (presumably bytes — see
/// `estimate_communication_overhead`), and the (dependency id, kind) edges
/// this node waits on.
pub estimated_cost: f64, pub memory_size: usize, pub dependencies: Vec<(NodeId, DependencyType)>,
/// Whether this node may run concurrently with others (not consulted by the
/// visible scheduling code).
pub can_parallelize: bool,
}
/// One topological level of the graph: nodes whose dependencies all completed
/// in earlier stages and which may therefore execute concurrently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelStage {
/// Zero-based position of this stage in the schedule.
pub stage_id: usize,
/// Nodes scheduled in this stage.
pub nodes: Vec<NodeId>,
/// Stage duration estimate: the maximum single-node cost in the stage.
pub estimated_time: f64,
/// Sum of the memory sizes of all nodes in the stage.
pub memory_requirement: usize,
/// Stage ids that must complete first (currently just the previous stage).
pub predecessors: Vec<usize>, }
/// The set of nodes assigned to a single worker by `partition_work`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkPartition {
/// Index of the worker this partition belongs to.
pub worker_id: usize,
/// Nodes assigned to this worker, in assignment order.
pub nodes: Vec<NodeId>,
/// Accumulated cost of the assigned nodes (currently 1.0 per node).
pub estimated_load: f64,
}
/// Summary produced by `AutoParallelizer::analyze`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelizationAnalysis {
/// Number of sequential stages (equals `stages.len()`).
pub num_stages: usize,
/// The staged schedule in execution order.
pub stages: Vec<ParallelStage>,
/// Sum of per-stage maxima: estimated runtime with unlimited workers.
pub critical_path_length: f64,
/// Sum of all node costs: estimated single-worker runtime.
pub total_work: f64,
/// `total_work / critical_path_length` (1.0 for empty graphs), plus the
/// heuristic communication cost added to parallel runtime.
pub parallelism_factor: f64, pub communication_overhead: f64,
/// Worker count suggested by the configured strategy.
pub recommended_workers: usize,
}
/// Complete plan produced by `AutoParallelizer::generate_plan`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelExecutionPlan {
/// Sequential stages of concurrently-executable nodes.
pub stages: Vec<ParallelStage>,
/// Per-worker node assignments.
pub partitions: Vec<WorkPartition>,
/// Sequential time divided by (critical path + communication overhead).
pub estimated_speedup: f64,
/// Average load / max load across partitions, in (0, 1]; 1.0 is perfect.
pub load_balance_ratio: f64,
}
/// Analyzes a DAG of tasks and produces staged, partitioned parallel
/// execution plans.
pub struct AutoParallelizer {
/// How aggressively to recommend workers.
strategy: ParallelizationStrategy,
/// How node costs are estimated (stored but not consulted by the visible
/// analysis code).
cost_model: CostModel,
/// Upper bound on the recommended worker count.
max_workers: usize,
/// Per-task scheduling overhead, link bandwidth (presumably GB/s — see
/// `estimate_communication_overhead`), and EWMA of observed op times
/// keyed by `op_type`.
overhead_per_task: f64, communication_bandwidth: f64, profile_data: HashMap<String, f64>, }
impl AutoParallelizer {
    /// Creates a parallelizer with defaults: `Balanced` strategy,
    /// `Heuristic` cost model, and one worker per logical CPU.
    pub fn new() -> Self {
        Self {
            strategy: ParallelizationStrategy::Balanced,
            cost_model: CostModel::Heuristic,
            max_workers: num_cpus::get(),
            // Defaults: 10 units of scheduling overhead per task and a
            // nominal bandwidth of 100 (interpreted as GB/s by
            // estimate_communication_overhead).
            overhead_per_task: 10.0,
            communication_bandwidth: 100.0,
            profile_data: HashMap::new(),
        }
    }

    /// Sets the parallelization strategy (builder style).
    pub fn with_strategy(mut self, strategy: ParallelizationStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Sets the cost model (builder style).
    pub fn with_cost_model(mut self, model: CostModel) -> Self {
        self.cost_model = model;
        self
    }

    /// Caps the number of workers considered for scheduling (builder style).
    pub fn with_max_workers(mut self, workers: usize) -> Self {
        self.max_workers = workers;
        self
    }

    /// Folds an observed execution time for `op_type` into a running
    /// exponential moving average (alpha = 0.1).
    ///
    /// Note: the first sample is averaged against an implicit 0.0, so the
    /// estimate warms up over several observations rather than jumping to
    /// the first measurement.
    pub fn update_profile(&mut self, op_type: String, time_us: f64) {
        let entry = self.profile_data.entry(op_type).or_insert(0.0);
        *entry = 0.9 * *entry + 0.1 * time_us;
    }

    /// Analyzes the task graph: validates it, levels it into stages, and
    /// derives critical-path, parallelism, and worker-count estimates.
    ///
    /// # Errors
    /// Returns `InvalidGraph` if a node references an unknown dependency
    /// and `DependencyCycle` if the graph is cyclic.
    pub fn analyze(
        &self,
        nodes: &[NodeInfo],
    ) -> Result<ParallelizationAnalysis, AutoParallelError> {
        let dep_graph = self.build_dependency_graph(nodes)?;
        let stages = self.compute_stages(nodes, &dep_graph)?;
        let critical_path_length = self.calculate_critical_path(&stages);
        let total_work: f64 = nodes.iter().map(|n| n.estimated_cost).sum();
        let communication_overhead = self.estimate_communication_overhead(&stages, nodes);
        // Average parallelism = total work / critical path; degenerate
        // (empty / zero-cost) graphs report 1.0.
        let parallelism_factor = if critical_path_length > 0.0 {
            total_work / critical_path_length
        } else {
            1.0
        };
        let recommended_workers = self.recommend_worker_count(parallelism_factor);
        Ok(ParallelizationAnalysis {
            num_stages: stages.len(),
            stages,
            critical_path_length,
            total_work,
            parallelism_factor,
            communication_overhead,
            recommended_workers,
        })
    }

    /// Produces a full execution plan: staged schedule, per-worker
    /// partitions, and speedup / load-balance estimates.
    ///
    /// # Errors
    /// Propagates any error from [`Self::analyze`] or partitioning.
    pub fn generate_plan(
        &self,
        nodes: &[NodeInfo],
    ) -> Result<ParallelExecutionPlan, AutoParallelError> {
        let analysis = self.analyze(nodes)?;
        let partitions = self.partition_work(&analysis)?;
        let sequential_time = analysis.total_work;
        let parallel_time = analysis.critical_path_length + analysis.communication_overhead;
        let estimated_speedup = if parallel_time > 0.0 {
            sequential_time / parallel_time
        } else {
            1.0
        };
        let load_balance_ratio = self.calculate_load_balance(&partitions);
        Ok(ParallelExecutionPlan {
            stages: analysis.stages,
            partitions,
            estimated_speedup,
            load_balance_ratio,
        })
    }

    /// Builds the adjacency map `node -> set of its dependencies`,
    /// validating that every referenced dependency exists and that the
    /// resulting graph is acyclic.
    fn build_dependency_graph(
        &self,
        nodes: &[NodeInfo],
    ) -> Result<HashMap<NodeId, HashSet<NodeId>>, AutoParallelError> {
        let mut graph: HashMap<NodeId, HashSet<NodeId>> = HashMap::new();
        // Ensure every node has an entry, even with zero dependencies.
        for node in nodes {
            graph.entry(node.id.clone()).or_default();
        }
        for node in nodes {
            for (dep_id, _dep_type) in &node.dependencies {
                if !graph.contains_key(dep_id) {
                    return Err(AutoParallelError::InvalidGraph(format!(
                        "Unknown dependency: {}",
                        dep_id
                    )));
                }
                graph
                    .entry(node.id.clone())
                    .or_default()
                    .insert(dep_id.clone());
            }
        }
        self.check_cycles(&graph)?;
        Ok(graph)
    }

    /// DFS-based cycle check over the dependency graph.
    fn check_cycles(
        &self,
        graph: &HashMap<NodeId, HashSet<NodeId>>,
    ) -> Result<(), AutoParallelError> {
        let mut visited = HashSet::new();
        let mut rec_stack = HashSet::new();
        for node in graph.keys() {
            if !visited.contains(node)
                && self.has_cycle_util(node, graph, &mut visited, &mut rec_stack)?
            {
                return Err(AutoParallelError::DependencyCycle(format!(
                    "Cycle detected involving node: {}",
                    node
                )));
            }
        }
        Ok(())
    }

    /// Recursive DFS helper: returns true if a back edge (a neighbor still
    /// on the current recursion stack) is reachable from `node`.
    ///
    /// NOTE(review): recursion depth is bounded by the longest dependency
    /// chain; extremely deep graphs could overflow the call stack.
    fn has_cycle_util(
        &self,
        node: &NodeId,
        graph: &HashMap<NodeId, HashSet<NodeId>>,
        visited: &mut HashSet<NodeId>,
        rec_stack: &mut HashSet<NodeId>,
    ) -> Result<bool, AutoParallelError> {
        visited.insert(node.clone());
        rec_stack.insert(node.clone());
        if let Some(neighbors) = graph.get(node) {
            for neighbor in neighbors {
                if !visited.contains(neighbor) {
                    if self.has_cycle_util(neighbor, graph, visited, rec_stack)? {
                        return Ok(true);
                    }
                } else if rec_stack.contains(neighbor) {
                    return Ok(true);
                }
            }
        }
        rec_stack.remove(node);
        Ok(false)
    }

    /// Levels the DAG with Kahn's algorithm: each stage holds every node
    /// whose dependencies completed in earlier stages.
    ///
    /// A stage's `estimated_time` is the maximum node cost in the stage
    /// (its nodes run concurrently); `memory_requirement` is the sum.
    fn compute_stages(
        &self,
        nodes: &[NodeInfo],
        dep_graph: &HashMap<NodeId, HashSet<NodeId>>,
    ) -> Result<Vec<ParallelStage>, AutoParallelError> {
        let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
        let mut node_map: HashMap<NodeId, &NodeInfo> = HashMap::new();
        // Reverse adjacency (dependency -> dependents) built once, so that
        // releasing a node is O(out-degree) instead of a scan of all nodes.
        let mut dependents: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
        for node in nodes {
            node_map.insert(node.id.clone(), node);
            let deps = dep_graph
                .get(&node.id)
                .expect("dep_graph built from same nodes");
            in_degree.insert(node.id.clone(), deps.len());
            for dep in deps {
                dependents.entry(dep.clone()).or_default().push(node.id.clone());
            }
        }
        // Seed the queue with all roots (nodes that depend on nothing).
        let mut current_level: VecDeque<NodeId> = in_degree
            .iter()
            .filter(|&(_, &degree)| degree == 0)
            .map(|(node_id, _)| node_id.clone())
            .collect();
        let mut stages = Vec::new();
        let mut stage_id = 0;
        while !current_level.is_empty() {
            let mut stage_nodes = Vec::new();
            let mut estimated_time: f64 = 0.0;
            let mut memory_requirement = 0;
            // Drain exactly the nodes that were ready when the stage began;
            // nodes released below land in the next stage.
            for _ in 0..current_level.len() {
                if let Some(node_id) = current_level.pop_front() {
                    let node = node_map[&node_id];
                    estimated_time = estimated_time.max(node.estimated_cost);
                    memory_requirement += node.memory_size;
                    if let Some(succs) = dependents.get(&node_id) {
                        for succ in succs {
                            if let Some(degree) = in_degree.get_mut(succ) {
                                *degree -= 1;
                                if *degree == 0 {
                                    current_level.push_back(succ.clone());
                                }
                            }
                        }
                    }
                    stage_nodes.push(node_id);
                }
            }
            if !stage_nodes.is_empty() {
                stages.push(ParallelStage {
                    stage_id,
                    nodes: stage_nodes,
                    estimated_time,
                    memory_requirement,
                    // Stages form a linear chain: each waits on the previous.
                    predecessors: if stage_id > 0 {
                        vec![stage_id - 1]
                    } else {
                        vec![]
                    },
                });
                stage_id += 1;
            }
        }
        // Defense in depth: check_cycles should have caught cycles already,
        // but incomplete leveling also indicates one.
        if stages.iter().map(|s| s.nodes.len()).sum::<usize>() != nodes.len() {
            return Err(AutoParallelError::DependencyCycle(
                "Not all nodes were processed - cycle detected".to_string(),
            ));
        }
        Ok(stages)
    }

    /// Critical-path approximation: stages execute sequentially, so the
    /// path length is the sum of per-stage maxima.
    fn calculate_critical_path(&self, stages: &[ParallelStage]) -> f64 {
        stages.iter().map(|s| s.estimated_time).sum()
    }

    /// Heuristic communication cost: per-task scheduling overhead plus the
    /// time to move the stage's memory across the link, charged only to
    /// stages that actually run more than one node.
    fn estimate_communication_overhead(
        &self,
        stages: &[ParallelStage],
        _nodes: &[NodeInfo],
    ) -> f64 {
        let mut overhead = 0.0;
        for stage in stages {
            if stage.nodes.len() > 1 {
                overhead += self.overhead_per_task * stage.nodes.len() as f64;
                // bytes / (GB/s * 1e9) gives seconds; * 1e6 converts to us
                // (units assumed — confirm against callers).
                let transfer_time =
                    stage.memory_requirement as f64 / (self.communication_bandwidth * 1e9) * 1e6;
                overhead += transfer_time;
            }
        }
        overhead
    }

    /// Maps the measured parallelism factor to a worker count according to
    /// the configured strategy, capped at `max_workers`.
    fn recommend_worker_count(&self, parallelism_factor: f64) -> usize {
        let ideal = parallelism_factor.ceil() as usize;
        match self.strategy {
            ParallelizationStrategy::Conservative => ideal.min(self.max_workers / 2).max(1),
            ParallelizationStrategy::Balanced => ideal.min(self.max_workers),
            ParallelizationStrategy::Aggressive => self.max_workers,
            ParallelizationStrategy::CostBased => {
                if parallelism_factor > 2.0 {
                    ideal.min(self.max_workers)
                } else {
                    (ideal / 2).max(1)
                }
            }
        }
    }

    /// Greedy partitioning: within each stage, assign every node to the
    /// currently least-loaded worker.
    ///
    /// NOTE(review): per-node costs are not carried on the analysis result,
    /// so each node currently weighs 1.0 — partitions balance node counts,
    /// not estimated runtimes.
    fn partition_work(
        &self,
        analysis: &ParallelizationAnalysis,
    ) -> Result<Vec<WorkPartition>, AutoParallelError> {
        let num_workers = analysis.recommended_workers;
        let mut partitions: Vec<WorkPartition> = (0..num_workers)
            .map(|i| WorkPartition {
                worker_id: i,
                nodes: Vec::new(),
                estimated_load: 0.0,
            })
            .collect();
        for stage in &analysis.stages {
            let mut stage_nodes: Vec<(NodeId, f64)> = stage
                .nodes
                .iter()
                .map(|id| (id.clone(), 1.0))
                .collect();
            // Heaviest-first improves the greedy balance once real costs
            // replace the 1.0 placeholder; a no-op today.
            stage_nodes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            for (node_id, cost) in stage_nodes {
                let min_partition = partitions
                    .iter_mut()
                    .min_by(|a, b| {
                        a.estimated_load
                            .partial_cmp(&b.estimated_load)
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .ok_or_else(|| {
                        AutoParallelError::PartitioningFailed("No partitions available".to_string())
                    })?;
                min_partition.nodes.push(node_id);
                min_partition.estimated_load += cost;
            }
        }
        Ok(partitions)
    }

    /// Load-balance quality in (0, 1]: average load divided by maximum
    /// load. 1.0 means perfectly balanced (or no partitions / zero load).
    fn calculate_load_balance(&self, partitions: &[WorkPartition]) -> f64 {
        if partitions.is_empty() {
            return 1.0;
        }
        let loads: Vec<f64> = partitions.iter().map(|p| p.estimated_load).collect();
        let max_load = loads.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let avg_load = loads.iter().sum::<f64>() / loads.len() as f64;
        if max_load > 0.0 {
            avg_load / max_load
        } else {
            1.0
        }
    }
}
impl Default for AutoParallelizer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `NodeInfo` with `Data`-typed dependencies on `deps`.
    fn node(id: &str, op_type: &str, cost: f64, mem: usize, deps: &[&str]) -> NodeInfo {
        NodeInfo {
            id: id.to_string(),
            op_type: op_type.to_string(),
            estimated_cost: cost,
            memory_size: mem,
            dependencies: deps
                .iter()
                .map(|d| (d.to_string(), DependencyType::Data))
                .collect(),
            can_parallelize: true,
        }
    }

    /// Diamond graph: a -> {b, c} -> d.
    fn create_test_nodes() -> Vec<NodeInfo> {
        let mut d = node("d", "output", 10.0, 1000, &["b", "c"]);
        d.can_parallelize = false;
        vec![
            node("a", "input", 10.0, 1000, &[]),
            node("b", "compute", 20.0, 2000, &["a"]),
            node("c", "compute", 15.0, 1500, &["a"]),
            d,
        ]
    }

    #[test]
    fn test_auto_parallelizer_creation() {
        let parallelizer = AutoParallelizer::new();
        assert_eq!(parallelizer.strategy, ParallelizationStrategy::Balanced);
        assert_eq!(parallelizer.cost_model, CostModel::Heuristic);
    }

    #[test]
    fn test_builder_pattern() {
        let parallelizer = AutoParallelizer::new()
            .with_strategy(ParallelizationStrategy::Aggressive)
            .with_cost_model(CostModel::ProfileBased)
            .with_max_workers(8);
        assert_eq!(parallelizer.strategy, ParallelizationStrategy::Aggressive);
        assert_eq!(parallelizer.cost_model, CostModel::ProfileBased);
        assert_eq!(parallelizer.max_workers, 8);
    }

    #[test]
    fn test_dependency_graph_building() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();
        let graph = parallelizer
            .build_dependency_graph(&nodes)
            .expect("diamond graph is valid");
        assert_eq!(graph.len(), 4);
        assert!(graph["b"].contains("a"));
        assert!(graph["c"].contains("a"));
        assert!(graph["d"].contains("b"));
        assert!(graph["d"].contains("c"));
    }

    #[test]
    fn test_cycle_detection() {
        let parallelizer = AutoParallelizer::new();
        // a and b depend on each other: a 2-cycle.
        let nodes = vec![
            node("a", "compute", 10.0, 1000, &["b"]),
            node("b", "compute", 10.0, 1000, &["a"]),
        ];
        let result = parallelizer.build_dependency_graph(&nodes);
        assert!(matches!(result, Err(AutoParallelError::DependencyCycle(_))));
    }

    #[test]
    fn test_stage_computation() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();
        let analysis = parallelizer.analyze(&nodes).expect("valid diamond graph");
        assert_eq!(analysis.num_stages, 3);
        assert_eq!(analysis.stages[0].nodes, vec!["a"]);
        // b and c are independent, so they share the middle stage
        // (intra-stage order is unspecified).
        assert_eq!(analysis.stages[1].nodes.len(), 2);
        assert!(analysis.stages[1].nodes.contains(&"b".to_string()));
        assert!(analysis.stages[1].nodes.contains(&"c".to_string()));
        assert_eq!(analysis.stages[2].nodes, vec!["d"]);
    }

    #[test]
    fn test_critical_path_calculation() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();
        let analysis = parallelizer.analyze(&nodes).expect("valid diamond graph");
        // Per-stage maxima: 10 (a) + 20 (max of b, c) + 10 (d) = 40.
        assert_eq!(analysis.critical_path_length, 40.0);
    }

    #[test]
    fn test_parallelism_factor() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();
        let analysis = parallelizer.analyze(&nodes).expect("valid diamond graph");
        // total work 55 / critical path 40 = 1.375.
        assert!((analysis.parallelism_factor - 1.375).abs() < 0.01);
    }

    #[test]
    fn test_execution_plan_generation() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();
        let plan = parallelizer.generate_plan(&nodes).expect("valid diamond graph");
        assert_eq!(plan.stages.len(), 3);
        assert!(!plan.partitions.is_empty());
        assert!(plan.estimated_speedup > 0.0);
        assert!(plan.load_balance_ratio > 0.0 && plan.load_balance_ratio <= 1.0);
    }

    #[test]
    fn test_profile_update() {
        let mut parallelizer = AutoParallelizer::new();
        parallelizer.update_profile("compute".to_string(), 100.0);
        parallelizer.update_profile("compute".to_string(), 200.0);
        // EWMA with alpha = 0.1 starting from the implicit 0.0:
        // 0.9 * (0.1 * 100) + 0.1 * 200 = 29.0.
        let avg = parallelizer.profile_data["compute"];
        assert!((avg - 29.0).abs() < 1e-9);
    }

    #[test]
    fn test_strategy_variations() {
        let nodes = create_test_nodes();
        let conservative = AutoParallelizer::new()
            .with_strategy(ParallelizationStrategy::Conservative)
            .analyze(&nodes)
            .expect("valid diamond graph");
        let aggressive = AutoParallelizer::new()
            .with_strategy(ParallelizationStrategy::Aggressive)
            .analyze(&nodes)
            .expect("valid diamond graph");
        assert!(aggressive.recommended_workers >= conservative.recommended_workers);
    }

    #[test]
    fn test_sequential_graph() {
        let parallelizer = AutoParallelizer::new();
        // Pure chain a -> b: no parallelism available.
        let nodes = vec![
            node("a", "compute", 10.0, 1000, &[]),
            node("b", "compute", 10.0, 1000, &["a"]),
        ];
        let analysis = parallelizer.analyze(&nodes).expect("valid chain");
        assert_eq!(analysis.num_stages, 2);
        assert_eq!(analysis.parallelism_factor, 1.0);
    }

    #[test]
    fn test_fully_parallel_graph() {
        let parallelizer = AutoParallelizer::new();
        // Three independent nodes collapse into a single stage.
        let nodes = vec![
            node("a", "compute", 10.0, 1000, &[]),
            node("b", "compute", 10.0, 1000, &[]),
            node("c", "compute", 10.0, 1000, &[]),
        ];
        let analysis = parallelizer.analyze(&nodes).expect("independent nodes");
        assert_eq!(analysis.num_stages, 1);
        assert_eq!(analysis.parallelism_factor, 3.0);
    }

    #[test]
    fn test_load_balancing() {
        let parallelizer = AutoParallelizer::new().with_max_workers(2);
        let nodes = create_test_nodes();
        let plan = parallelizer.generate_plan(&nodes).expect("valid diamond graph");
        assert!(!plan.partitions.is_empty());
        assert!(plan.load_balance_ratio > 0.0 && plan.load_balance_ratio <= 1.0);
    }

    #[test]
    fn test_invalid_graph() {
        let parallelizer = AutoParallelizer::new();
        let nodes = vec![node("a", "compute", 10.0, 1000, &["unknown"])];
        let result = parallelizer.build_dependency_graph(&nodes);
        assert!(matches!(result, Err(AutoParallelError::InvalidGraph(_))));
    }
}