use crate::autograd::Variable;
use crate::graph::Graph;
use crate::nn::{Buffer, Module, Optimizer, Parameter};
use super::cuda_event::CudaEvent;
use super::nccl::{NcclComms, ReduceOp};
use super::ddp_run::{DdpBuilder, DdpHandle};
pub use super::el_che::ElChe;
use crate::tensor::{Device, Result, Tensor, TensorError};
/// Serializes tests that exercise NCCL, so communicators are never set up
/// from two tests at once.
#[cfg(test)]
pub(crate) static NCCL_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// Number of steps of throughput measurement before the first rebalance.
pub(crate) const DEFAULT_CALIBRATION_STEPS: usize = 10;
/// Number of steps between rebalances once calibration completes.
pub(crate) const DEFAULT_REBALANCE_INTERVAL: usize = 50;
/// Smoothing factor for the per-device throughput EMA.
const EMA_ALPHA: f64 = 0.3;
/// Floor on any single device's share of the batch.
const MIN_CHUNK_RATIO: f64 = 0.05;
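/// Per-device state behind a distributed `Graph`: the model replicas and
/// their optimizers, the NCCL communicators that link them, and the timing
/// statistics that drive dynamic batch rebalancing across uneven GPUs.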
pub(crate) struct DistributedState {
    pub replicas: Vec<Box<dyn Module>>,
    pub comms: NcclComms,
    pub devices: Vec<Device>,
    pub optimizers: Vec<Box<dyn Optimizer>>,
    /// Fraction of each batch routed to each device; kept normalized to 1.0.
    pub chunk_ratios: Vec<f64>,
    /// One group per parameter, holding that parameter's replica on every device.
    pub param_groups: Vec<Vec<Variable>>,
    pub buffer_groups: Vec<Vec<Buffer>>,
    /// Start/end CUDA events recorded around the previous step, per device.
    pub last_timing: Option<Vec<(CudaEvent, CudaEvent)>>,
    pub last_shard_sizes: Vec<i64>,
    /// Exponential moving average of samples per millisecond, per device.
    pub ema_throughput: Vec<f64>,
    pub step_count: usize,
    pub calibration_steps: usize,
    pub rebalance_interval: usize,
    pub el_che: Option<ElChe>,
    pub last_el_che_counts: Vec<usize>,
    pub last_el_che_sync: Option<std::time::Instant>,
    pub max_grad_norm: Option<f64>,
}
impl DistributedState {
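    /// Averages gradients across replicas with an NCCL all-reduce. Groups
    /// whose rank-0 replica has no gradient (e.g. frozen or unused
    /// parameters) are skipped entirely.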
pub fn all_reduce_gradients(&self) -> Result<()> {
for group in &self.param_groups {
if group[0].grad().is_none() {
continue;
}
let grads: Vec<Tensor> = group
.iter()
.map(|v| v.grad().expect("gradient missing on replica"))
.collect();
let refs: Vec<&Tensor> = grads.iter().collect();
self.comms.all_reduce(&refs, ReduceOp::Avg)?;
}
Ok(())
}
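    /// Broadcasts rank 0's buffers (e.g. batch-norm running statistics) to
    /// every other replica.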
pub fn sync_buffers(&self) -> Result<()> {
for group in &self.buffer_groups {
let tensors: Vec<Tensor> = group.iter().map(|b| b.get()).collect();
let refs: Vec<&Tensor> = tensors.iter().collect();
self.comms.broadcast(&refs, 0)?;
}
Ok(())
}
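    /// Broadcasts rank 0's parameters to all replicas, then syncs buffers,
    /// leaving every replica in an identical state.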
pub fn sync_params(&self) -> Result<()> {
for group in &self.param_groups {
let tensors: Vec<Tensor> = group.iter().map(|v| v.data()).collect();
let refs: Vec<&Tensor> = tensors.iter().collect();
self.comms.broadcast(&refs, 0)?;
}
self.sync_buffers()
}
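    /// Splits `batch_size` across devices in proportion to `chunk_ratios`.
    /// Earlier devices round and clamp their share; the last device absorbs
    /// the remainder, so the sizes always sum to `batch_size`. Assumes the
    /// batch is at least as large as the device count.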
pub fn compute_shard_sizes(&self, batch_size: i64) -> Vec<i64> {
let n = self.devices.len();
let mut sizes = Vec::with_capacity(n);
let mut remaining = batch_size;
for i in 0..n {
if i == n - 1 {
sizes.push(remaining);
} else {
let s = (batch_size as f64 * self.chunk_ratios[i]).round() as i64;
                let s = s.max(1).min(remaining - (n - i - 1) as i64);
                sizes.push(s);
remaining -= s;
}
}
sizes
}
pub fn world_size(&self) -> usize {
self.devices.len()
}
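    /// True when every chunk ratio is numerically equal, i.e. rebalancing
    /// has not (yet) skewed the split.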
pub fn is_balanced(&self) -> bool {
let first = self.chunk_ratios[0];
self.chunk_ratios.iter().all(|r| (r - first).abs() < 1e-6)
}
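    /// All-reduces gradients weighted by each replica's shard of the batch:
    /// each gradient is pre-scaled by `shard_size / batch_size` and the
    /// scaled copies are summed, which yields the batch-weighted average.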
    pub fn weighted_all_reduce_gradients(&self, batch_size: i64) -> Result<()> {
        if batch_size <= 0 {
            return Err(TensorError::new(
                "weighted_all_reduce: batch_size must be positive",
            ));
        }
for group in &self.param_groups {
if group[0].grad().is_none() {
continue;
}
            let grads: Vec<Tensor> = group
                .iter()
                .enumerate()
                .map(|(rank, v)| {
                    let g = v.grad().expect("gradient missing on replica");
                    let weight = self.last_shard_sizes[rank] as f64 / batch_size as f64;
                    // Propagate scaling failures instead of silently dropping them.
                    g.mul_scalar_(weight)?;
                    Ok(g)
                })
                .collect::<Result<Vec<_>>>()?;
let refs: Vec<&Tensor> = grads.iter().collect();
self.comms.all_reduce(&refs, ReduceOp::Sum)?;
}
Ok(())
}
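    /// Folds the previous step's per-device timing into the throughput EMA,
    /// then rebalances chunk ratios once calibration finishes and every
    /// `rebalance_interval` steps after that. Returns `Ok(true)` when a
    /// rebalance ran.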
pub fn update_balance(&mut self) -> Result<bool> {
self.step_count += 1;
if let Some(timing) = self.last_timing.take() {
for (rank, (start, end)) in timing.iter().enumerate() {
let ms = CudaEvent::elapsed_time(start, end)?;
if ms > 0.0 && self.last_shard_sizes[rank] > 0 {
let throughput = self.last_shard_sizes[rank] as f64 / ms as f64;
if self.ema_throughput[rank] == 0.0 {
self.ema_throughput[rank] = throughput;
} else {
self.ema_throughput[rank] =
EMA_ALPHA * throughput + (1.0 - EMA_ALPHA) * self.ema_throughput[rank];
}
}
}
}
let should_rebalance = if self.step_count == self.calibration_steps {
true
} else if self.step_count > self.calibration_steps {
(self.step_count - self.calibration_steps) % self.rebalance_interval == 0
} else {
false
};
if should_rebalance {
self.rebalance();
return Ok(true);
}
Ok(false)
}
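    /// Recomputes chunk ratios in proportion to EMA throughput, flooring
    /// each ratio at `MIN_CHUNK_RATIO` and renormalizing so they sum to 1.0.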
fn rebalance(&mut self) {
let total: f64 = self.ema_throughput.iter().sum();
        if total <= 0.0 {
            return;
        }
let n = self.devices.len();
let min_total = MIN_CHUNK_RATIO * n as f64;
let mut ratios: Vec<f64> = self.ema_throughput.iter().map(|t| t / total).collect();
let mut deficit = 0.0;
let mut unclamped = 0;
for r in &mut ratios {
if *r < MIN_CHUNK_RATIO {
deficit += MIN_CHUNK_RATIO - *r;
*r = MIN_CHUNK_RATIO;
} else {
unclamped += 1;
}
}
if deficit > 0.0 && unclamped > 0 {
let unclamped_total: f64 = ratios
.iter()
.filter(|&&r| r > MIN_CHUNK_RATIO + 1e-9)
.sum();
if unclamped_total > min_total {
for r in &mut ratios {
if *r > MIN_CHUNK_RATIO + 1e-9 {
*r -= deficit * (*r / unclamped_total);
*r = r.max(MIN_CHUNK_RATIO);
}
}
}
}
let sum: f64 = ratios.iter().sum();
if sum > 0.0 {
for r in &mut ratios {
*r /= sum;
}
}
self.chunk_ratios = ratios;
}
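    /// Builds the optional `ElChe` component from `config` when running on
    /// two or more devices. `max_anchor == Some(0)` disables it entirely,
    /// and a speed hint both seeds `ElChe` and pre-skews the chunk ratios.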
pub(crate) fn configure_el_che(&mut self, config: &DdpConfig) {
let n = self.devices.len();
if n < 2 {
return;
}
if config.max_anchor == Some(0) {
self.el_che = None;
return;
}
        let anchor = 10;
        let mut el_che = ElChe::new(n, anchor);
if let Some(target) = config.overhead_target {
el_che = el_che.with_overhead_target(target);
}
if let Some(max) = config.max_anchor {
el_che = el_che.with_max_anchor(max);
}
if let Some((slow_rank, ratio)) = config.speed_hint {
el_che = el_che.with_speed_ratio(slow_rank, ratio);
self.apply_speed_hint(slow_rank, ratio);
}
self.el_che = Some(el_che);
self.max_grad_norm = config.max_grad_norm;
}
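    /// Pre-skews chunk ratios from a user-provided hint: the slow rank keeps
    /// weight 1.0 while every other rank gets `ratio` (clamped to >= 1.0),
    /// and the weights are normalized into ratios.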
fn apply_speed_hint(&mut self, slow_rank: usize, ratio: f64) {
let n = self.devices.len();
if slow_rank >= n {
return;
}
let ratio = ratio.max(1.0);
let mut weights = vec![ratio; n];
weights[slow_rank] = 1.0;
let total: f64 = weights.iter().sum();
self.chunk_ratios = weights.iter().map(|w| w / total).collect();
}
}
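/// Hand-driven data parallelism over caller-owned model replicas: `Ddp`
/// wires up NCCL broadcast and all-reduce across the replicas' parameters,
/// gradients, and buffers, while the caller runs forward/backward and the
/// optimizer step on each device itself.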
pub struct Ddp {
comms: NcclComms,
devices: Vec<Device>,
param_groups: Vec<Vec<Variable>>,
buffer_groups: Vec<Vec<Buffer>>,
}
impl Ddp {
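    /// Groups the parameters and buffers of `models` (one replica per
    /// device) for collective communication. Requires at least two models,
    /// one device per model, and identical parameter counts on every
    /// replica.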
pub fn wrap(models: &[&dyn Module], devices: &[Device]) -> Result<Self> {
if models.len() < 2 {
return Err(TensorError::new("Ddp::wrap requires at least 2 models"));
}
if models.len() != devices.len() {
return Err(TensorError::new(
"Ddp::wrap: model count must match device count",
));
}
let comms = NcclComms::new(devices)?;
let all_params: Vec<Vec<Parameter>> =
models.iter().map(|m| m.parameters()).collect();
let n_params = all_params[0].len();
for (rank, params) in all_params.iter().enumerate().skip(1) {
if params.len() != n_params {
return Err(TensorError::new(&format!(
"Ddp: replica {} has {} parameters, expected {}",
rank,
params.len(),
n_params
)));
}
}
let mut param_groups = Vec::with_capacity(n_params);
for pi in 0..n_params {
let group: Vec<Variable> =
all_params.iter().map(|p| p[pi].variable.clone()).collect();
param_groups.push(group);
}
let all_buffers: Vec<Vec<Buffer>> =
models.iter().map(|m| m.buffers()).collect();
let n_buffers = all_buffers[0].len();
let mut buffer_groups = Vec::with_capacity(n_buffers);
for bi in 0..n_buffers {
let group: Vec<Buffer> =
all_buffers.iter().map(|b| b[bi].clone()).collect();
buffer_groups.push(group);
}
Ok(Ddp {
comms,
devices: devices.to_vec(),
param_groups,
buffer_groups,
})
}
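    /// Broadcasts rank 0's parameters and buffers to all other replicas.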
pub fn sync_params(&self) -> Result<()> {
for group in &self.param_groups {
let tensors: Vec<Tensor> = group.iter().map(|v| v.data()).collect();
let refs: Vec<&Tensor> = tensors.iter().collect();
self.comms.broadcast(&refs, 0)?;
}
for group in &self.buffer_groups {
let tensors: Vec<Tensor> = group.iter().map(|b| b.get()).collect();
let refs: Vec<&Tensor> = tensors.iter().collect();
self.comms.broadcast(&refs, 0)?;
}
Ok(())
}
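    /// Averages gradients across replicas with an NCCL all-reduce, skipping
    /// parameter groups whose rank-0 replica has no gradient.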
pub fn all_reduce_gradients(&self) -> Result<()> {
for group in &self.param_groups {
if group[0].grad().is_none() {
continue;
}
let grads: Vec<Tensor> = group
.iter()
.map(|v| v.grad().expect("gradient missing on replica"))
.collect();
let refs: Vec<&Tensor> = grads.iter().collect();
self.comms.all_reduce(&refs, ReduceOp::Avg)?;
}
Ok(())
}
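    /// Broadcasts rank 0's buffers to every other replica.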
pub fn sync_buffers(&self) -> Result<()> {
for group in &self.buffer_groups {
let tensors: Vec<Tensor> = group.iter().map(|b| b.get()).collect();
let refs: Vec<&Tensor> = tensors.iter().collect();
self.comms.broadcast(&refs, 0)?;
}
Ok(())
}
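    /// All-reduces gradients weighted by each replica's actual sample count,
    /// so unevenly sized shards still produce a correct global average.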
pub fn weighted_all_reduce_gradients(&self, batch_counts: &[usize]) -> Result<()> {
if batch_counts.len() != self.devices.len() {
return Err(TensorError::new(&format!(
"weighted_all_reduce: batch_counts len ({}) != device count ({})",
batch_counts.len(),
self.devices.len(),
)));
}
let total: usize = batch_counts.iter().sum();
if total == 0 {
return Err(TensorError::new("weighted_all_reduce: total batch count is 0"));
}
for group in &self.param_groups {
if group[0].grad().is_none() {
continue;
}
            let grads: Vec<Tensor> = group
                .iter()
                .enumerate()
                .map(|(rank, v)| {
                    let g = v.grad().expect("gradient missing on replica");
                    let weight = batch_counts[rank] as f64 / total as f64;
                    // Propagate scaling failures instead of silently dropping them.
                    g.mul_scalar_(weight)?;
                    Ok(g)
                })
                .collect::<Result<Vec<_>>>()?;
let refs: Vec<&Tensor> = grads.iter().collect();
self.comms.all_reduce(&refs, ReduceOp::Sum)?;
}
Ok(())
}
pub fn world_size(&self) -> usize {
self.devices.len()
}
pub fn devices(&self) -> &[Device] {
&self.devices
}
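    /// Distributes `model` across all visible devices, installs the
    /// optimizer, and enables training mode; `ElChe` defaults are applied
    /// automatically when the GPUs are heterogeneous.
    ///
    /// A minimal sketch (`MyNet` and `Sgd` stand in for user-provided types):
    ///
    /// ```ignore
    /// Ddp::setup(&graph, |dev| MyNet::new(dev), |params| Sgd::new(params, 1e-2))?;
    /// ```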
pub fn setup<F, M, G, O>(
model: &Graph,
builder: F,
optimizer: G,
) -> Result<()>
where
F: Fn(Device) -> Result<M>,
M: Module + 'static,
G: Fn(&[Parameter]) -> O,
O: Optimizer + 'static,
{
Self::print_device_summary();
model.distribute(builder)?;
model.set_optimizer(optimizer);
model.set_training(true);
if Self::is_heterogeneous() {
model.configure_el_che(&DdpConfig::new());
}
Ok(())
}
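    /// Like `setup`, but always applies the supplied `DdpConfig` rather
    /// than auto-detecting heterogeneity.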
pub fn setup_with<F, M, G, O>(
model: &Graph,
builder: F,
optimizer: G,
config: DdpConfig,
) -> Result<()>
where
F: Fn(Device) -> Result<M>,
M: Module + 'static,
G: Fn(&[Parameter]) -> O,
O: Optimizer + 'static,
{
Self::print_device_summary();
model.distribute(builder)?;
model.set_optimizer(optimizer);
model.set_training(true);
model.configure_el_che(&config);
Ok(())
}
#[deprecated(since = "0.3.0", note = "Renamed to Ddp::setup()")]
pub fn auto<F, M, G, O>(
model: &Graph,
builder: F,
optimizer: G,
) -> Result<()>
where
F: Fn(Device) -> Result<M>,
M: Module + 'static,
G: Fn(&[Parameter]) -> O,
O: Optimizer + 'static,
{
Self::setup(model, builder, optimizer)
}
#[deprecated(since = "0.3.0", note = "Renamed to Ddp::setup_with()")]
pub fn auto_with<F, M, G, O>(
model: &Graph,
builder: F,
optimizer: G,
config: DdpConfig,
) -> Result<()>
where
F: Fn(Device) -> Result<M>,
M: Module + 'static,
G: Fn(&[Parameter]) -> O,
O: Optimizer + 'static,
{
Self::setup_with(model, builder, optimizer, config)
}
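    /// Entry point for the builder-style API in `ddp_run`: supply factories
    /// for the model and optimizer plus a training closure that maps one
    /// replica and its input shard to a loss `Variable`.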
pub fn builder<F, M, G, O, T>(
model_factory: F,
optim_factory: G,
train_fn: T,
) -> DdpBuilder<F, M, G, O, T>
where
F: Fn(Device) -> Result<M> + Send + Sync + 'static,
M: Module + 'static,
G: Fn(&[Parameter]) -> O + Send + Sync + 'static,
O: Optimizer + 'static,
T: Fn(&M, &[Tensor]) -> Result<Variable> + Send + Sync + 'static,
{
DdpHandle::new_builder(model_factory, optim_factory, train_fn)
}
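    /// True when two or more CUDA devices are visible and their device
    /// names differ; this is what triggers `ElChe` defaults in `setup`.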
fn is_heterogeneous() -> bool {
use crate::tensor::{cuda_available, cuda_device_count, cuda_device_name_idx};
        if !cuda_available() {
            return false;
        }
        let n = cuda_device_count();
        if n < 2 {
            return false;
        }
        let names: Vec<Option<String>> = (0..n).map(cuda_device_name_idx).collect();
names.windows(2).any(|w| w[0] != w[1])
}
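    /// Logs a one-line stderr summary of the visible GPUs (short name,
    /// total VRAM, and whether the set is heterogeneous).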
fn print_device_summary() {
use crate::tensor::{
cuda_available, cuda_device_count,
cuda_device_name_idx, cuda_memory_info_idx,
};
use crate::monitor::format_bytes;
if !cuda_available() || cuda_device_count() == 0 {
eprintln!(" ddp: no CUDA available | CPU mode");
return;
}
let n = cuda_device_count();
let mut names = Vec::with_capacity(n as usize);
let mut parts = Vec::with_capacity(n as usize);
for i in 0..n {
let raw_name = cuda_device_name_idx(i)
.unwrap_or_else(|| format!("CUDA({})", i));
let short = raw_name
.strip_prefix("NVIDIA ")
.unwrap_or(&raw_name)
.to_string();
let vram = cuda_memory_info_idx(i)
.map(|(_, total)| format!(" ({})", format_bytes(total)))
.unwrap_or_default();
parts.push(format!("{}{}", short, vram));
names.push(raw_name);
}
let heterogeneous = names.windows(2).any(|w| w[0] != w[1]);
if n == 1 {
eprintln!(" ddp: 1 GPU | {} | single-device mode", parts[0]);
} else if heterogeneous {
eprintln!(
" ddp: {} GPUs (heterogeneous) | {}",
n,
parts.join(" | "),
);
} else {
eprintln!(" ddp: {} GPUs | {}", n, parts.join(" | "));
}
}
}
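/// Tuning knobs consumed by `configure_el_che`: an optional
/// `(slow_rank, ratio)` speed hint, an `ElChe` overhead target, an anchor
/// cap (`Some(0)` disables `ElChe`), and an optional gradient-norm clip.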
#[derive(Debug, Clone)]
pub struct DdpConfig {
pub speed_hint: Option<(usize, f64)>,
pub overhead_target: Option<f64>,
pub max_anchor: Option<usize>,
pub max_grad_norm: Option<f64>,
}
impl DdpConfig {
pub fn new() -> Self {
DdpConfig {
speed_hint: None,
overhead_target: None,
max_anchor: None,
max_grad_norm: None,
}
}
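    /// Declares that `slow_rank` runs about `ratio`x slower than the other
    /// ranks, letting setup skew the initial batch split before any
    /// throughput has been measured.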
pub fn speed_hint(mut self, slow_rank: usize, ratio: f64) -> Self {
self.speed_hint = Some((slow_rank, ratio));
self
}
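    /// Sets `ElChe`'s overhead target; values are clamped to `0.01..=0.50`.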
pub fn overhead_target(mut self, target: f64) -> Self {
self.overhead_target = Some(target.clamp(0.01, 0.50));
self
}
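    /// Caps `ElChe`'s anchor; passing `Some(0)` disables `ElChe` entirely.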
pub fn max_anchor(mut self, max: Option<usize>) -> Self {
self.max_anchor = max;
self
}
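    /// Sets the gradient-norm clipping threshold stored on the distributed
    /// state.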
pub fn max_grad_norm(mut self, max_norm: f64) -> Self {
self.max_grad_norm = Some(max_norm);
self
}
}
impl Default for DdpConfig {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
#[path = "ddp_tests.rs"]
mod tests;