use crate::autograd::Variable;
use crate::error::RusTorchResult;
use crate::gpu::DeviceType;
use crate::tensor::Tensor;
use num_traits::Float;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Scalar element type accepted by the distributed APIs.
///
/// Bundles floating-point arithmetic (`Float`), conversion from primitive
/// numbers (`FromPrimitive`), use as an `ndarray` scalar operand,
/// `Debug` formatting, and the thread-safety bounds (`Send + Sync + 'static`)
/// needed to move values between threads. Implemented for `f64` and `f32`.
pub trait DistributedScalar
where
    Self: Float
        + num_traits::FromPrimitive
        + ndarray::ScalarOperand
        + std::fmt::Debug
        + Send
        + Sync
        + 'static,
{
}

impl DistributedScalar for f64 {}
impl DistributedScalar for f32 {}
/// Interface for data-parallel wrappers that span several devices.
///
/// NOTE(review): the per-method notes below are inferred from the method
/// names; the concrete semantics live in the implementors — confirm there.
pub trait DistributedDataParallelTrait<T>
where
    T: DistributedScalar,
{
    /// Device indices this wrapper operates on (presumably GPU ordinals — confirm).
    fn device_ids(&self) -> &[usize];

    /// Runs the wrapped forward computation on `input` and returns its output.
    fn distributed_forward(&self, input: &Variable<T>) -> RusTorchResult<Variable<T>>;

    /// Synchronizes gradients across participants (presumably via an all-reduce — confirm).
    fn sync_gradients(&self) -> RusTorchResult<()>;
}
pub use multi_gpu_validation::{
BenchmarkResults, GpuDeviceInfo, MemoryUsage, MultiGpuValidator, ValidationMetrics,
};
/// Communication backend selected for a `ProcessGroup`.
///
/// Variant order is significant: it fixes the discriminant values.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DistributedBackend {
    /// NVIDIA NCCL; `ProcessGroup::init` fails unless built with the `nccl` feature.
    NCCL,
    /// Gloo backend.
    Gloo,
    /// MPI backend; `ProcessGroup::init` currently rejects it.
    MPI,
    /// Plain TCP backend.
    TCP,
}
/// Describes this process's membership in a distributed job.
///
/// Field order is kept stable because it determines the derived `Debug` output.
#[derive(Clone, Debug)]
pub struct ProcessGroup {
    /// Index of this process within the group; expected to lie in `0..world_size`.
    pub rank: usize,
    /// Total number of processes in the group.
    pub world_size: usize,
    /// Backend used to initialize communication.
    pub backend: DistributedBackend,
    /// Address of the coordinating endpoint (presumably the rank-0 host — confirm).
    pub master_addr: String,
    /// Port of the coordinating endpoint.
    pub master_port: u16,
}
impl ProcessGroup {
pub fn new(
rank: usize,
world_size: usize,
backend: DistributedBackend,
master_addr: String,
master_port: u16,
) -> Self {
Self {
rank,
world_size,
backend,
master_addr,
master_port,
}
}
pub fn init(&self) -> crate::error::RusTorchResult<()> {
match self.backend {
DistributedBackend::NCCL => self.init_nccl(),
DistributedBackend::Gloo => self.init_gloo(),
DistributedBackend::MPI => self.init_mpi(),
DistributedBackend::TCP => self.init_tcp(),
}
}
fn init_nccl(&self) -> crate::error::RusTorchResult<()> {
#[cfg(feature = "nccl")]
{
Ok(())
}
#[cfg(not(feature = "nccl"))]
{
Err(crate::error::RusTorchError::distributed(
"NCCL not compiled",
))
}
}
fn init_gloo(&self) -> crate::error::RusTorchResult<()> {
Ok(())
}
fn init_mpi(&self) -> crate::error::RusTorchResult<()> {
Err(crate::error::RusTorchError::distributed(
"MPI not supported - use TCP or NCCL",
))
}
fn init_tcp(&self) -> crate::error::RusTorchResult<()> {
Ok(())
}
}
/// Collective communication primitives over tensors.
///
/// NOTE(review): the method notes follow the usual collective-communication
/// vocabulary implied by the names; exact semantics (which ranks receive
/// results, how `root` is interpreted) are defined by implementors — confirm.
pub trait DistributedOps<T>
where
    T: Float,
{
    /// Combines `tensor` across ranks with `op`, writing the result into `tensor`.
    fn all_reduce(&self, tensor: &mut Tensor<T>, op: ReduceOp) -> RusTorchResult<()>;

    /// Collects `tensor` from every rank; presumably one entry per rank.
    fn all_gather(&self, tensor: &Tensor<T>) -> RusTorchResult<Vec<Tensor<T>>>;

    /// Overwrites `tensor` with the value held by rank `root`.
    fn broadcast(&self, tensor: &mut Tensor<T>, root: usize) -> RusTorchResult<()>;

    /// Combines `tensor` across ranks with `op`; the result is meant for rank `root`.
    fn reduce(&self, tensor: &mut Tensor<T>, root: usize, op: ReduceOp) -> RusTorchResult<()>;

    /// Distributes `tensors` from rank `root`, returning this rank's share.
    fn scatter(&self, tensors: &[Tensor<T>], root: usize) -> RusTorchResult<Tensor<T>>;

    /// Collects `tensor` from every rank at rank `root`.
    fn gather(&self, tensor: &Tensor<T>, root: usize) -> RusTorchResult<Vec<Tensor<T>>>;
}
/// Reduction applied by reducing collectives such as `all_reduce` and `reduce`.
///
/// Variant order is significant: it fixes the discriminant values.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReduceOp {
    /// Sum of the contributions.
    Sum,
    /// Product of the contributions.
    Product,
    /// Minimum of the contributions.
    Min,
    /// Maximum of the contributions.
    Max,
    /// Mean of the contributions (presumably sum / participant count — confirm in backends).
    Average,
}
/// Process-wide distributed state, created lazily by `get_distributed_state`.
///
/// `OnceLock` gives thread-safe one-time initialization without a `static mut`,
/// so no `unsafe` block or `static_mut_refs` suppression is needed to read it.
static DISTRIBUTED_STATE: std::sync::OnceLock<Arc<Mutex<DistributedState>>> =
    std::sync::OnceLock::new();

/// Bookkeeping for the distributed runtime of this process.
#[derive(Debug)]
pub struct DistributedState {
    /// Active process group; `None` until `set_process_group` is called.
    pub process_group: Option<ProcessGroup>,
    /// Devices tracked for this process (not populated within this module — TODO confirm owner).
    pub devices: Vec<DeviceType>,
    /// Device lists keyed by a `usize`; NOTE(review): presumably the rank — confirm with callers.
    pub device_map: HashMap<usize, Vec<DeviceType>>,
}

impl Default for DistributedState {
    fn default() -> Self {
        Self::new()
    }
}

impl DistributedState {
    /// Creates an empty state with no process group and no devices.
    pub fn new() -> Self {
        Self {
            process_group: None,
            devices: Vec::new(),
            device_map: HashMap::new(),
        }
    }

    /// Installs `pg` as the active process group, replacing any previous one.
    pub fn set_process_group(&mut self, pg: ProcessGroup) {
        self.process_group = Some(pg);
    }

    /// Rank of the active process group, or `None` when uninitialized.
    pub fn rank(&self) -> Option<usize> {
        self.process_group.as_ref().map(|pg| pg.rank)
    }

    /// World size of the active process group, or `None` when uninitialized.
    pub fn world_size(&self) -> Option<usize> {
        self.process_group.as_ref().map(|pg| pg.world_size)
    }

    /// Whether a process group has been installed.
    pub fn is_initialized(&self) -> bool {
        self.process_group.is_some()
    }
}

/// Returns the process-wide [`DistributedState`], creating it on first call.
///
/// Every call returns a reference to the same `Arc`; lock its `Mutex` to read
/// or update the state.
pub fn get_distributed_state() -> &'static Arc<Mutex<DistributedState>> {
    DISTRIBUTED_STATE.get_or_init(|| Arc::new(Mutex::new(DistributedState::new())))
}
/// Builds a [`ProcessGroup`], initializes its backend and installs it as the
/// global process group.
///
/// # Errors
///
/// Propagates any error from `ProcessGroup::init`; in that case the global
/// state is not touched.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn init_distributed(
    backend: DistributedBackend,
    rank: usize,
    world_size: usize,
    master_addr: String,
    master_port: u16,
) -> RusTorchResult<()> {
    let group = ProcessGroup::new(rank, world_size, backend, master_addr, master_port);
    group.init()?;
    get_distributed_state()
        .lock()
        .unwrap()
        .set_process_group(group);
    Ok(())
}
/// Reports whether a global process group is currently installed.
///
/// Despite the name, this reflects initialization state (set by
/// `init_distributed`, cleared by `finalize`), not build-time support.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn is_available() -> bool {
    get_distributed_state().lock().unwrap().is_initialized()
}
/// Rank of the global process group, or `None` if none is installed.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn get_rank() -> Option<usize> {
    get_distributed_state().lock().unwrap().rank()
}
/// World size of the global process group, or `None` if none is installed.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn get_world_size() -> Option<usize> {
    get_distributed_state().lock().unwrap().world_size()
}
/// Resets the global distributed state.
///
/// Removes the process group and clears the device list and device map.
/// Always returns `Ok(())`.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn finalize() -> RusTorchResult<()> {
    let mut guard = get_distributed_state().lock().unwrap();
    guard.devices.clear();
    guard.device_map.clear();
    guard.process_group = None;
    Ok(())
}
pub mod api;
pub mod async_gradient;
pub mod backends;
pub mod cluster;
pub mod common;
pub mod data_parallel;
pub mod ddp;
pub mod model_parallel;
pub mod multi_gpu_validation;
pub mod nccl_integration;
pub mod optimizer;
pub mod performance;
pub mod simple_ddp;
pub use api::*;
pub use ddp::{wrap_module, DistributedDataParallel};
pub use simple_ddp::{wrap_simple, SimpleDistributedDataParallel};
pub use async_gradient::{AsyncConfig, AsyncGradientSynchronizer, Priority};
#[cfg(feature = "nccl")]
pub use nccl_integration::{NCCLBackendOptimized, NCCLOps, NCCLOptimizations, NCCLProfiler};
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_process_group_creation() {
        let pg = ProcessGroup::new(
            0,
            4,
            DistributedBackend::TCP,
            "localhost".to_string(),
            12345,
        );
        assert_eq!(pg.rank, 0);
        assert_eq!(pg.world_size, 4);
        assert_eq!(pg.backend, DistributedBackend::TCP);
        assert_eq!(pg.master_addr, "localhost");
        assert_eq!(pg.master_port, 12345);
    }

    #[test]
    fn test_process_group_init_backends() {
        // Does not touch the global state, so it is safe to run in parallel.
        let make = |backend| {
            ProcessGroup::new(0, 1, backend, "localhost".to_string(), 29500)
        };
        assert!(make(DistributedBackend::TCP).init().is_ok());
        assert!(make(DistributedBackend::Gloo).init().is_ok());
        assert!(make(DistributedBackend::MPI).init().is_err());
        assert_eq!(
            make(DistributedBackend::NCCL).init().is_ok(),
            cfg!(feature = "nccl")
        );
    }

    #[test]
    fn test_distributed_state() {
        let mut state = DistributedState::new();
        assert!(!state.is_initialized());
        let pg = ProcessGroup::new(
            1,
            2,
            DistributedBackend::Gloo,
            "127.0.0.1".to_string(),
            29500,
        );
        state.set_process_group(pg);
        assert!(state.is_initialized());
        assert_eq!(state.rank(), Some(1));
        assert_eq!(state.world_size(), Some(2));
    }

    #[test]
    fn test_distributed_state_is_singleton() {
        // Read-only: only compares identities, never mutates the global state.
        let a = get_distributed_state();
        let b = get_distributed_state();
        assert!(std::sync::Arc::ptr_eq(a, b));
    }

    #[test]
    fn test_reduce_op_variants() {
        // Exhaustive match with no wildcard: adding a variant breaks compilation
        // here, forcing this list and test to be updated.
        fn index(op: ReduceOp) -> usize {
            match op {
                ReduceOp::Sum => 0,
                ReduceOp::Product => 1,
                ReduceOp::Min => 2,
                ReduceOp::Max => 3,
                ReduceOp::Average => 4,
            }
        }
        let ops = [
            ReduceOp::Sum,
            ReduceOp::Product,
            ReduceOp::Min,
            ReduceOp::Max,
            ReduceOp::Average,
        ];
        for (i, &a) in ops.iter().enumerate() {
            assert_eq!(index(a), i);
            for (j, &b) in ops.iter().enumerate() {
                // Equality must distinguish every pair of distinct variants.
                assert_eq!(a == b, i == j);
            }
        }
    }
}