use crate::capabilities::DeviceType;
use crate::error::ExecutorError;
use crate::placement::Device;
use crate::shape::TensorShape;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tensorlogic_ir::EinsumGraph;
/// How work is split across devices/processes in a distributed run.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ParallelismStrategy {
    /// Replicate the model on every device and split each batch across them.
    #[default]
    DataParallel,
    /// Partition the model's graph nodes/tensors across devices.
    ModelParallel,
    /// Partition the model into sequential stages executed as a pipeline.
    PipelineParallel,
    /// Combine data and model parallelism. `data_parallel_groups` is the
    /// number of replica groups (currently ignored by `setup_coordinators`).
    Hybrid { data_parallel_groups: usize },
}
/// Configuration for distributed execution (process group + tuning knobs).
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Parallelism strategy used to set up coordinators.
    pub parallelism: ParallelismStrategy,
    /// Number of local devices this process drives.
    pub num_devices: usize,
    /// Communication backend name (e.g. "gloo").
    pub backend: String,
    /// Rendezvous address of the master process, if any.
    pub master_addr: Option<String>,
    /// Rendezvous port of the master process, if any.
    pub master_port: Option<u16>,
    /// This process's rank within the group.
    pub rank: usize,
    /// Total number of processes in the group.
    pub world_size: usize,
    /// Whether gradients are compressed before communication.
    pub enable_gradient_compression: bool,
    /// Whether mixed-precision execution is enabled.
    pub enable_mixed_precision: bool,
    /// Gradient bucketing size in bytes.
    pub bucket_size: usize,
    /// Whether communication may overlap with computation.
    pub enable_async_communication: bool,
}
impl Default for DistributedConfig {
    /// Single-process defaults: one CPU device, rank 0 in a world of 1,
    /// "gloo" backend, 25 MiB gradient buckets, async communication on.
    fn default() -> Self {
        Self {
            parallelism: ParallelismStrategy::default(),
            num_devices: 1,
            backend: String::from("gloo"),
            master_addr: None,
            master_port: None,
            rank: 0,
            world_size: 1,
            enable_gradient_compression: false,
            enable_mixed_precision: false,
            // 25 MiB: a common gradient-bucketing default.
            bucket_size: 25 * 1024 * 1024,
            enable_async_communication: true,
        }
    }
}
/// Describes how one graph node's tensor is split across devices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShardingSpec {
    /// Graph node whose tensor is sharded.
    pub node_id: usize,
    /// Dimension index along which the tensor is split.
    pub shard_dim: usize,
    /// Number of shards (equals `shard_to_device.len()`).
    pub num_shards: usize,
    /// Device hosting each shard, indexed by shard id.
    pub shard_to_device: Vec<Device>,
}
impl ShardingSpec {
pub fn new(node_id: usize, shard_dim: usize, devices: Vec<Device>) -> Self {
let num_shards = devices.len();
ShardingSpec {
node_id,
shard_dim,
num_shards,
shard_to_device: devices,
}
}
pub fn device_for_shard(&self, shard_id: usize) -> Option<&Device> {
self.shard_to_device.get(shard_id)
}
pub fn is_valid_shard(&self, shard_id: usize) -> bool {
shard_id < self.num_shards
}
}
/// Node-to-device assignments plus sharding and communication metadata.
#[derive(Debug, Clone)]
pub struct DistributedPlacementPlan {
    /// Device assigned to each graph node id.
    pub node_placement: HashMap<usize, Device>,
    /// Sharding specifications, in insertion order.
    pub sharding_specs: Vec<ShardingSpec>,
    /// Per-node list of node ids it must communicate with.
    pub communication_deps: HashMap<usize, Vec<usize>>,
}
impl DistributedPlacementPlan {
pub fn new() -> Self {
DistributedPlacementPlan {
node_placement: HashMap::new(),
sharding_specs: Vec::new(),
communication_deps: HashMap::new(),
}
}
pub fn place_node(&mut self, node_id: usize, device: Device) {
self.node_placement.insert(node_id, device);
}
pub fn add_sharding(&mut self, spec: ShardingSpec) {
self.sharding_specs.push(spec);
}
pub fn get_device(&self, node_id: usize) -> Option<&Device> {
self.node_placement.get(&node_id)
}
pub fn get_sharding(&self, node_id: usize) -> Option<&ShardingSpec> {
self.sharding_specs.iter().find(|s| s.node_id == node_id)
}
pub fn is_sharded(&self, node_id: usize) -> bool {
self.get_sharding(node_id).is_some()
}
}
impl Default for DistributedPlacementPlan {
fn default() -> Self {
Self::new()
}
}
/// A collective or point-to-point communication operation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommunicationOp {
    /// Reduce across all ranks; result available on every rank.
    AllReduce {
        reduction: ReductionOp,
    },
    /// Copy from `src_rank` to all other ranks.
    Broadcast {
        src_rank: usize,
    },
    /// Distribute slices from `src_rank` to all ranks.
    Scatter {
        src_rank: usize,
    },
    /// Collect slices from all ranks onto `dst_rank`.
    Gather {
        dst_rank: usize,
    },
    /// Collect slices from all ranks onto every rank.
    AllGather,
    /// Reduce across ranks, then scatter the result slices.
    ReduceScatter {
        reduction: ReductionOp,
    },
    /// Point-to-point send to `dst_rank`.
    Send {
        dst_rank: usize,
    },
    /// Point-to-point receive from `src_rank`.
    Recv {
        src_rank: usize,
    },
}
/// Elementwise reduction applied by reduce-style collectives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionOp {
    Sum,
    Mean,
    Max,
    Min,
    Product,
}
/// Abstraction over a collective-communication library.
///
/// Tensors are identified by string id; how data is actually located and
/// moved is up to the implementation.
pub trait CommunicationBackend: Send + Sync {
    /// Joins the process group described by `config`.
    fn initialize(&mut self, config: &DistributedConfig) -> Result<(), ExecutorError>;
    /// Leaves the process group and releases resources.
    fn finalize(&mut self) -> Result<(), ExecutorError>;
    /// This process's rank within the group.
    fn rank(&self) -> usize;
    /// Total number of processes in the group.
    fn world_size(&self) -> usize;
    /// Reduces the tensor across all ranks; result on every rank.
    fn all_reduce(&self, tensor_id: &str, reduction: ReductionOp) -> Result<(), ExecutorError>;
    /// Copies the tensor from `src_rank` to all other ranks.
    fn broadcast(&self, tensor_id: &str, src_rank: usize) -> Result<(), ExecutorError>;
    /// Distributes slices of the tensor from `src_rank` to all ranks.
    fn scatter(&self, tensor_id: &str, src_rank: usize) -> Result<(), ExecutorError>;
    /// Collects tensor slices from all ranks onto `dst_rank`.
    fn gather(&self, tensor_id: &str, dst_rank: usize) -> Result<(), ExecutorError>;
    /// Collects tensor slices from all ranks onto every rank.
    fn all_gather(&self, tensor_id: &str) -> Result<(), ExecutorError>;
    /// Reduces across ranks, then scatters the result slices.
    fn reduce_scatter(&self, tensor_id: &str, reduction: ReductionOp) -> Result<(), ExecutorError>;
    /// Point-to-point send to `dst_rank`.
    fn send(&self, tensor_id: &str, dst_rank: usize) -> Result<(), ExecutorError>;
    /// Point-to-point receive from `src_rank`.
    fn recv(&self, tensor_id: &str, src_rank: usize) -> Result<(), ExecutorError>;
    /// Blocks until every rank reaches this call.
    fn barrier(&self) -> Result<(), ExecutorError>;
}
/// No-op backend for single-process use and tests; every operation
/// succeeds immediately without moving any data.
pub struct DummyCommunicationBackend {
    // Identity adopted from the config at initialize() time.
    rank: usize,
    world_size: usize,
}
impl DummyCommunicationBackend {
pub fn new() -> Self {
DummyCommunicationBackend {
rank: 0,
world_size: 1,
}
}
}
impl Default for DummyCommunicationBackend {
fn default() -> Self {
Self::new()
}
}
// Trivial trait implementation: records the configured identity and treats
// every communication primitive as an immediate success.
impl CommunicationBackend for DummyCommunicationBackend {
    fn initialize(&mut self, config: &DistributedConfig) -> Result<(), ExecutorError> {
        // Adopt the configured identity so rank()/world_size() reflect it.
        self.rank = config.rank;
        self.world_size = config.world_size;
        Ok(())
    }

    fn finalize(&mut self) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn rank(&self) -> usize {
        self.rank
    }

    fn world_size(&self) -> usize {
        self.world_size
    }

    fn all_reduce(&self, _tensor: &str, _op: ReductionOp) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn broadcast(&self, _tensor: &str, _src: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn scatter(&self, _tensor: &str, _src: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn gather(&self, _tensor: &str, _dst: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn all_gather(&self, _tensor: &str) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn reduce_scatter(&self, _tensor: &str, _op: ReductionOp) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn send(&self, _tensor: &str, _dst: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn recv(&self, _tensor: &str, _src: usize) -> Result<(), ExecutorError> {
        Ok(())
    }

    fn barrier(&self) -> Result<(), ExecutorError> {
        Ok(())
    }
}
/// Coordinates data-parallel execution: batch splitting across devices and
/// gradient synchronization through the communication backend.
pub struct DataParallelCoordinator {
    config: DistributedConfig,
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    // One CPU device per configured device slot (see `new`).
    devices: Vec<Device>,
}
impl DataParallelCoordinator {
    /// Builds a coordinator with one CPU device per configured device slot.
    pub fn new(config: DistributedConfig, backend: Arc<RwLock<dyn CommunicationBackend>>) -> Self {
        let mut devices = Vec::with_capacity(config.num_devices);
        for idx in 0..config.num_devices {
            devices.push(Device::new(DeviceType::CPU, idx));
        }
        DataParallelCoordinator {
            config,
            backend,
            devices,
        }
    }

    /// Splits `batch_size` samples across devices as `(offset, size)` pairs.
    /// The first `batch_size % num_devices` devices get one extra sample.
    pub fn distribute_batch(&self, batch_size: usize) -> Vec<(usize, usize)> {
        let n = self.config.num_devices;
        let (base, extra) = (batch_size / n, batch_size % n);
        let mut ranges = Vec::with_capacity(n);
        let mut start = 0;
        for device_idx in 0..n {
            let len = base + usize::from(device_idx < extra);
            ranges.push((start, len));
            start += len;
        }
        ranges
    }

    /// Mean-all-reduces the "gradients" tensor across all ranks.
    pub fn synchronize_gradients(&self) -> Result<(), ExecutorError> {
        self.backend
            .read()
            .expect("lock should not be poisoned")
            .all_reduce("gradients", ReductionOp::Mean)
    }

    /// Devices participating in data-parallel execution.
    pub fn devices(&self) -> &[Device] {
        &self.devices
    }
}
/// Coordinates model-parallel execution: assigns graph nodes to devices and
/// shards tensors along a chosen dimension.
pub struct ModelParallelCoordinator {
    config: DistributedConfig,
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    // Populated by `create_sharding_plan`.
    placement_plan: DistributedPlacementPlan,
}
impl ModelParallelCoordinator {
    /// Creates a coordinator with an empty placement plan.
    pub fn new(config: DistributedConfig, backend: Arc<RwLock<dyn CommunicationBackend>>) -> Self {
        Self {
            config,
            backend,
            placement_plan: DistributedPlacementPlan::new(),
        }
    }

    /// Assigns graph nodes to CPU devices in contiguous chunks of
    /// `ceil(node_count / num_devices)` nodes each.
    pub fn create_sharding_plan(&mut self, graph: &EinsumGraph) -> Result<(), ExecutorError> {
        let per_device = graph.nodes.len().div_ceil(self.config.num_devices);
        for node_id in 0..graph.nodes.len() {
            let device = Device::new(DeviceType::CPU, node_id / per_device);
            self.placement_plan.place_node(node_id, device);
        }
        Ok(())
    }

    /// Read-only view of the node-to-device assignments.
    pub fn placement_plan(&self) -> &DistributedPlacementPlan {
        &self.placement_plan
    }

    /// Splits `shape` along `shard_dim` into one shard shape per device; the
    /// first `size % num_devices` shards receive one extra element.
    ///
    /// # Errors
    /// `InvalidInput` if `shard_dim` is out of range or the dimension is
    /// dynamic.
    pub fn shard_tensor(
        &self,
        _node_id: usize,
        shape: &TensorShape,
        shard_dim: usize,
    ) -> Result<Vec<TensorShape>, ExecutorError> {
        if shard_dim >= shape.rank() {
            return Err(ExecutorError::InvalidInput(format!(
                "Shard dimension {} exceeds tensor rank {}",
                shard_dim,
                shape.rank()
            )));
        }
        let total = shape.dims[shard_dim].as_static().ok_or_else(|| {
            ExecutorError::InvalidInput("Cannot shard dynamic dimension".to_string())
        })?;
        let shards = self.config.num_devices;
        let (base, extra) = (total / shards, total % shards);
        Ok((0..shards)
            .map(|i| {
                let mut piece = shape.clone();
                piece.dims[shard_dim] = crate::shape::DimSize::Static(base + usize::from(i < extra));
                piece
            })
            .collect())
    }

    /// All-gathers the sharded tensor across ranks.
    pub fn gather_shards(&self, _shard_dim: usize) -> Result<(), ExecutorError> {
        self.backend
            .read()
            .expect("lock should not be poisoned")
            .all_gather("sharded_tensor")
    }
}
/// Coordinates pipeline-parallel execution: stage assignment and
/// activation/gradient exchange between adjacent stages.
pub struct PipelineParallelCoordinator {
    config: DistributedConfig,
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    // Number of pipeline stages; ranks map to stages round-robin.
    num_stages: usize,
    // Micro-batch size for pipeline scheduling; defaults to 1.
    micro_batch_size: usize,
}
impl PipelineParallelCoordinator {
    /// Creates a pipeline coordinator with `num_stages` stages and a
    /// micro-batch size of 1.
    pub fn new(
        config: DistributedConfig,
        backend: Arc<RwLock<dyn CommunicationBackend>>,
        num_stages: usize,
    ) -> Self {
        Self {
            config,
            backend,
            num_stages,
            micro_batch_size: 1,
        }
    }

    /// Sets the micro-batch size used when scheduling the pipeline.
    pub fn set_micro_batch_size(&mut self, size: usize) {
        self.micro_batch_size = size;
    }

    /// Maps a global rank to its pipeline stage (round-robin).
    pub fn stage_for_rank(&self, rank: usize) -> usize {
        rank % self.num_stages
    }

    /// Sends activations to the next stage; no-op for the last stage.
    /// NOTE(review): `stage + 1` is used directly as a destination rank,
    /// which assumes rank == stage — confirm for multi-replica pipelines.
    pub fn send_activations(&self, stage: usize) -> Result<(), ExecutorError> {
        if stage < self.num_stages - 1 {
            self.backend
                .read()
                .expect("lock should not be poisoned")
                .send("activations", stage + 1)?;
        }
        Ok(())
    }

    /// Receives activations from the previous stage; no-op for stage 0.
    pub fn recv_activations(&self, stage: usize) -> Result<(), ExecutorError> {
        if stage > 0 {
            self.backend
                .read()
                .expect("lock should not be poisoned")
                .recv("activations", stage - 1)?;
        }
        Ok(())
    }

    /// Sends gradients back to the previous stage; no-op for stage 0.
    pub fn send_gradients(&self, stage: usize) -> Result<(), ExecutorError> {
        if stage > 0 {
            self.backend
                .read()
                .expect("lock should not be poisoned")
                .send("gradients", stage - 1)?;
        }
        Ok(())
    }

    /// Receives gradients from the next stage; no-op for the last stage.
    pub fn recv_gradients(&self, stage: usize) -> Result<(), ExecutorError> {
        if stage < self.num_stages - 1 {
            self.backend
                .read()
                .expect("lock should not be poisoned")
                .recv("gradients", stage + 1)?;
        }
        Ok(())
    }

    /// Number of pipeline stages.
    pub fn num_stages(&self) -> usize {
        self.num_stages
    }

    /// Current micro-batch size.
    pub fn micro_batch_size(&self) -> usize {
        self.micro_batch_size
    }

    /// The distributed configuration this coordinator was built with.
    pub fn config(&self) -> &DistributedConfig {
        &self.config
    }
}
/// Top-level distributed execution handle: owns the backend connection and
/// the coordinator(s) implied by the configured parallelism strategy.
pub struct DistributedExecutor {
    config: DistributedConfig,
    backend: Arc<RwLock<dyn CommunicationBackend>>,
    // Exactly the coordinators matching `config.parallelism` are populated
    // by `setup_coordinators`; the others stay `None`.
    data_parallel: Option<DataParallelCoordinator>,
    model_parallel: Option<ModelParallelCoordinator>,
    pipeline_parallel: Option<PipelineParallelCoordinator>,
}
impl DistributedExecutor {
    /// Initializes the backend with `config`, then builds the coordinators
    /// implied by `config.parallelism`.
    ///
    /// # Errors
    /// Propagates failures from backend initialization or coordinator setup.
    pub fn new(
        config: DistributedConfig,
        backend: Arc<RwLock<dyn CommunicationBackend>>,
    ) -> Result<Self, ExecutorError> {
        backend
            .write()
            .expect("lock should not be poisoned")
            .initialize(&config)?;
        // `config` and `backend` have no further uses here, so move them in
        // directly instead of cloning (the original cloned both needlessly).
        let mut executor = DistributedExecutor {
            config,
            backend,
            data_parallel: None,
            model_parallel: None,
            pipeline_parallel: None,
        };
        executor.setup_coordinators()?;
        Ok(executor)
    }

    /// Populates the coordinator fields matching the configured strategy.
    /// Hybrid gets both data- and model-parallel coordinators; the group
    /// count is currently unused.
    fn setup_coordinators(&mut self) -> Result<(), ExecutorError> {
        match self.config.parallelism {
            ParallelismStrategy::DataParallel => {
                self.data_parallel = Some(DataParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                ));
            }
            ParallelismStrategy::ModelParallel => {
                self.model_parallel = Some(ModelParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                ));
            }
            ParallelismStrategy::PipelineParallel => {
                // One pipeline stage per configured device.
                let num_stages = self.config.num_devices;
                self.pipeline_parallel = Some(PipelineParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                    num_stages,
                ));
            }
            ParallelismStrategy::Hybrid {
                data_parallel_groups: _,
            } => {
                self.data_parallel = Some(DataParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                ));
                self.model_parallel = Some(ModelParallelCoordinator::new(
                    self.config.clone(),
                    self.backend.clone(),
                ));
            }
        }
        Ok(())
    }

    /// The configured parallelism strategy.
    pub fn strategy(&self) -> ParallelismStrategy {
        self.config.parallelism
    }

    /// This process's rank, as reported by the backend.
    pub fn rank(&self) -> usize {
        self.backend
            .read()
            .expect("lock should not be poisoned")
            .rank()
    }

    /// Total process count, as reported by the backend.
    pub fn world_size(&self) -> usize {
        self.backend
            .read()
            .expect("lock should not be poisoned")
            .world_size()
    }

    /// Blocks until every rank reaches this call.
    pub fn barrier(&self) -> Result<(), ExecutorError> {
        self.backend
            .read()
            .expect("lock should not be poisoned")
            .barrier()
    }

    /// Data-parallel coordinator, if the strategy uses one.
    pub fn data_parallel(&self) -> Option<&DataParallelCoordinator> {
        self.data_parallel.as_ref()
    }

    /// Model-parallel coordinator, if the strategy uses one.
    pub fn model_parallel(&self) -> Option<&ModelParallelCoordinator> {
        self.model_parallel.as_ref()
    }

    /// Pipeline-parallel coordinator, if the strategy uses one.
    pub fn pipeline_parallel(&self) -> Option<&PipelineParallelCoordinator> {
        self.pipeline_parallel.as_ref()
    }
}
impl Drop for DistributedExecutor {
    /// Best-effort backend finalization.
    ///
    /// Never panics: `drop` may run during unwinding (when the lock may be
    /// poisoned), and a panic there aborts the process. A poisoned lock or
    /// a finalize error is silently skipped instead.
    fn drop(&mut self) {
        if let Ok(mut backend) = self.backend.write() {
            let _ = backend.finalize();
        }
    }
}
/// Mixin for executors that can optionally run in distributed mode.
pub trait TlDistributedExecutor {
    /// The active distributed executor, if distributed mode is enabled.
    fn distributed_executor(&self) -> Option<&DistributedExecutor>;
    /// Enables distributed mode with the given configuration.
    fn enable_distributed(&mut self, config: DistributedConfig) -> Result<(), ExecutorError>;
    /// Disables distributed mode.
    fn disable_distributed(&mut self);
    /// Whether distributed mode is currently enabled.
    fn is_distributed(&self) -> bool;
    /// This process's rank; 0 when not distributed.
    fn rank(&self) -> usize {
        self.distributed_executor().map(|d| d.rank()).unwrap_or(0)
    }
    /// Total process count; 1 when not distributed.
    fn world_size(&self) -> usize {
        self.distributed_executor()
            .map(|d| d.world_size())
            .unwrap_or(1)
    }
}
/// Aggregate statistics about distributed communication.
#[derive(Debug, Clone, Default)]
pub struct DistributedStats {
    /// Total number of communication operations performed.
    pub total_communications: usize,
    /// Total bytes moved across all operations.
    pub total_bytes_communicated: u64,
    /// Number of gradient synchronization rounds.
    pub gradient_syncs: usize,
    /// Average wall-clock time per communication, in milliseconds.
    pub avg_communication_time_ms: f64,
    /// Load imbalance as a fraction (0.15 == 15%).
    pub load_imbalance: f64,
}
impl DistributedStats {
    /// Renders a one-line human-readable summary of the collected stats.
    pub fn summary(&self) -> String {
        let megabytes = self.total_bytes_communicated as f64 / 1_000_000.0;
        let imbalance_pct = self.load_imbalance * 100.0;
        format!(
            "Distributed Stats: {} communications, {:.2} MB transferred, {} gradient syncs, {:.2}ms avg comm time, {:.2}% load imbalance",
            self.total_communications,
            megabytes,
            self.gradient_syncs,
            self.avg_communication_time_ms,
            imbalance_pct
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults: data-parallel, single device, rank 0 of world 1.
    #[test]
    fn test_distributed_config_default() {
        let config = DistributedConfig::default();
        assert_eq!(config.parallelism, ParallelismStrategy::DataParallel);
        assert_eq!(config.num_devices, 1);
        assert_eq!(config.rank, 0);
        assert_eq!(config.world_size, 1);
    }

    // num_shards tracks the device list; shard-id validation is exclusive
    // of num_shards.
    #[test]
    fn test_sharding_spec() {
        let devices = vec![
            Device::new(DeviceType::CPU, 0),
            Device::new(DeviceType::CPU, 1),
            Device::new(DeviceType::CPU, 2),
        ];
        let spec = ShardingSpec::new(0, 1, devices);
        assert_eq!(spec.num_shards, 3);
        assert_eq!(spec.shard_dim, 1);
        assert!(spec.is_valid_shard(0));
        assert!(spec.is_valid_shard(2));
        assert!(!spec.is_valid_shard(3));
    }

    // Placement lookups hit only explicitly placed nodes.
    #[test]
    fn test_distributed_placement_plan() {
        let mut plan = DistributedPlacementPlan::new();
        plan.place_node(0, Device::new(DeviceType::CPU, 0));
        plan.place_node(1, Device::new(DeviceType::CPU, 1));
        assert!(plan.get_device(0).is_some());
        assert!(plan.get_device(1).is_some());
        assert!(plan.get_device(2).is_none());
    }

    // 10 samples over 4 devices: per-device sizes must sum back to 10.
    #[test]
    fn test_data_parallel_batch_distribution() {
        let config = DistributedConfig {
            num_devices: 4,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = DataParallelCoordinator::new(config, backend);
        let distribution = coordinator.distribute_batch(10);
        assert_eq!(distribution.len(), 4);
        let total: usize = distribution.iter().map(|(_, size)| size).sum();
        assert_eq!(total, 10);
    }

    // Sharding dim 0 of an 8x16 tensor over 4 devices: four 2x16 shards.
    #[test]
    fn test_model_parallel_sharding() {
        let config = DistributedConfig {
            num_devices: 4,
            parallelism: ParallelismStrategy::ModelParallel,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = ModelParallelCoordinator::new(config, backend);
        let shape = TensorShape::static_shape(vec![8, 16]);
        let shards = coordinator.shard_tensor(0, &shape, 0).expect("unwrap");
        assert_eq!(shards.len(), 4);
        assert_eq!(shards[0].dims[0].as_static().expect("unwrap"), 2);
    }

    // With 4 stages, ranks 0..4 map to stages 0..4 (identity under modulo).
    #[test]
    fn test_pipeline_parallel_stage_assignment() {
        let config = DistributedConfig {
            num_devices: 4,
            parallelism: ParallelismStrategy::PipelineParallel,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = PipelineParallelCoordinator::new(config, backend, 4);
        assert_eq!(coordinator.stage_for_rank(0), 0);
        assert_eq!(coordinator.stage_for_rank(1), 1);
        assert_eq!(coordinator.stage_for_rank(2), 2);
        assert_eq!(coordinator.stage_for_rank(3), 3);
    }

    // Constructing against the dummy backend succeeds and reports its
    // single-process identity.
    #[test]
    fn test_distributed_executor_creation() {
        let config = DistributedConfig::default();
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let executor = DistributedExecutor::new(config, backend);
        assert!(executor.is_ok());
        let executor = executor.expect("unwrap");
        assert_eq!(executor.rank(), 0);
        assert_eq!(executor.world_size(), 1);
    }

    // Different CommunicationOp variants compare unequal.
    #[test]
    fn test_communication_ops() {
        let op1 = CommunicationOp::AllReduce {
            reduction: ReductionOp::Sum,
        };
        let op2 = CommunicationOp::Broadcast { src_rank: 0 };
        assert_ne!(op1, op2);
    }

    // All five reduction variants exist.
    #[test]
    fn test_reduction_ops() {
        let ops = [
            ReductionOp::Sum,
            ReductionOp::Mean,
            ReductionOp::Max,
            ReductionOp::Min,
            ReductionOp::Product,
        ];
        assert_eq!(ops.len(), 5);
    }

    // The dummy backend accepts the full lifecycle without error.
    #[test]
    fn test_dummy_backend() {
        let mut backend = DummyCommunicationBackend::new();
        let config = DistributedConfig::default();
        assert!(backend.initialize(&config).is_ok());
        assert_eq!(backend.rank(), 0);
        assert_eq!(backend.world_size(), 1);
        assert!(backend.all_reduce("test", ReductionOp::Sum).is_ok());
        assert!(backend.barrier().is_ok());
        assert!(backend.finalize().is_ok());
    }

    // summary() embeds the raw counters in its text.
    #[test]
    fn test_distributed_stats() {
        let stats = DistributedStats {
            total_communications: 100,
            total_bytes_communicated: 1_000_000,
            gradient_syncs: 50,
            avg_communication_time_ms: 10.5,
            load_imbalance: 0.15,
        };
        let summary = stats.summary();
        assert!(summary.contains("100 communications"));
        assert!(summary.contains("50 gradient syncs"));
    }

    // Hybrid strategy populates both data- and model-parallel coordinators.
    #[test]
    fn test_hybrid_parallelism() {
        let config = DistributedConfig {
            parallelism: ParallelismStrategy::Hybrid {
                data_parallel_groups: 2,
            },
            num_devices: 8,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let executor = DistributedExecutor::new(config, backend).expect("unwrap");
        assert!(executor.data_parallel().is_some());
        assert!(executor.model_parallel().is_some());
    }

    // Sharding a rank-2 tensor along dim 5 must be rejected.
    #[test]
    fn test_sharding_invalid_dimension() {
        let config = DistributedConfig {
            num_devices: 4,
            ..Default::default()
        };
        let backend = Arc::new(RwLock::new(DummyCommunicationBackend::new()));
        let coordinator = ModelParallelCoordinator::new(config, backend);
        let shape = TensorShape::static_shape(vec![8, 16]);
        let result = coordinator.shard_tensor(0, &shape, 5);
        assert!(result.is_err());
    }
}