use super::*;
use crate::tensor::{
cuda_device_count, cuda_synchronize, test_device, DType, TensorOptions,
};
use super::NCCL_LOCK;
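// Gate for multi-GPU tests: require at least two CUDA devices that can
// actually allocate and run kernels, not merely enumerate.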
fn require_multi_gpu() -> bool {
if !test_device().is_cuda() || cuda_device_count() < 2 {
return false;
}
for i in 0..2 {
let opts = TensorOptions {
dtype: DType::Float32,
device: Device::CUDA(i),
};
if Tensor::zeros(&[1], opts).is_err() {
eprintln!(
"Device CUDA({i}) cannot run compute kernels, skipping multi-GPU test"
);
return false;
}
}
true
}
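// Ddp::wrap input validation: an empty model list and a model/device count
// mismatch must both be rejected.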
#[test]
fn test_ddp_requires_two_models() {
let result = Ddp::wrap(&[], &[]);
assert!(result.is_err());
}
#[test]
fn test_ddp_model_device_mismatch() {
let result = Ddp::wrap(
&[],
&[Device::CUDA(0), Device::CUDA(1)],
);
assert!(result.is_err());
}
#[test]
fn test_shard_sizes_equal() {
let ratios = vec![0.5, 0.5];
let state = mock_state(&ratios);
assert_eq!(state.compute_shard_sizes(10), vec![5, 5]);
assert_eq!(state.compute_shard_sizes(11), vec![6, 5]);
assert_eq!(state.compute_shard_sizes(3), vec![2, 1]);
}
#[test]
fn test_shard_sizes_unequal() {
let ratios = vec![0.7, 0.3];
let state = mock_state(&ratios);
assert_eq!(state.compute_shard_sizes(10), vec![7, 3]);
assert_eq!(state.compute_shard_sizes(100), vec![70, 30]);
}
#[test]
fn test_shard_sizes_three_devices() {
let ratios = vec![0.5, 0.3, 0.2];
let state = mock_state(&ratios);
let sizes = state.compute_shard_sizes(10);
assert_eq!(sizes.iter().sum::<i64>(), 10);
assert_eq!(sizes, vec![5, 3, 2]);
}
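/// Builds a `DistributedState` over `n` mock CUDA devices with the given
/// chunk ratios and a null NCCL communicator; only the pure shard-math and
/// rebalancing paths are exercised here, never a real collective.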
fn mock_state(ratios: &[f64]) -> DistributedState {
let n = ratios.len();
DistributedState {
replicas: Vec::new(),
comms: unsafe { mock_nccl_comms(n) },
devices: (0..n as u8)
.map(Device::CUDA)
.collect(),
optimizers: Vec::new(),
chunk_ratios: ratios.to_vec(),
param_groups: Vec::new(),
buffer_groups: Vec::new(),
last_timing: None,
last_shard_sizes: vec![0; n],
ema_throughput: vec![0.0; n],
step_count: 0,
calibration_steps: DEFAULT_CALIBRATION_STEPS,
rebalance_interval: DEFAULT_REBALANCE_INTERVAL,
el_che: None,
last_el_che_counts: Vec::new(),
last_el_che_sync: None,
max_grad_norm: None,
timeline: None,
}
}
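/// SAFETY: wraps a null communicator pointer. The result must only be used
/// by tests that never launch an NCCL collective.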
unsafe fn mock_nccl_comms(n: usize) -> NcclComms {
let devices: Vec<Device> = (0..n as u8).map(Device::CUDA).collect();
unsafe { NcclComms::from_raw(std::ptr::null_mut(), devices) }
}
#[test]
fn test_is_balanced_equal() {
let state = mock_state(&[0.5, 0.5]);
assert!(state.is_balanced());
}
#[test]
fn test_is_balanced_unequal() {
let state = mock_state(&[0.7, 0.3]);
assert!(!state.is_balanced());
}
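// Rebalancing assigns chunk ratios proportional to measured EMA throughput:
// with throughputs 30 and 10, device 0 should get 30 / (30 + 10) = 0.75.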
#[test]
fn test_rebalance_proportional() {
let mut state = mock_state(&[0.5, 0.5]);
state.ema_throughput = vec![30.0, 10.0];
state.rebalance();
assert!((state.chunk_ratios[0] - 0.75).abs() < 0.01,
"fast GPU should get ~75%, got {}", state.chunk_ratios[0]);
assert!((state.chunk_ratios[1] - 0.25).abs() < 0.01,
"slow GPU should get ~25%, got {}", state.chunk_ratios[1]);
let sum: f64 = state.chunk_ratios.iter().sum();
assert!((sum - 1.0).abs() < 1e-9, "ratios must sum to 1.0, got {sum}");
}
#[test]
fn test_rebalance_three_devices() {
let mut state = mock_state(&[1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]);
state.ema_throughput = vec![50.0, 30.0, 20.0];
state.rebalance();
assert!((state.chunk_ratios[0] - 0.50).abs() < 0.01);
assert!((state.chunk_ratios[1] - 0.30).abs() < 0.01);
assert!((state.chunk_ratios[2] - 0.20).abs() < 0.01);
let sum: f64 = state.chunk_ratios.iter().sum();
assert!((sum - 1.0).abs() < 1e-9);
}
#[test]
fn test_rebalance_respects_min_ratio() {
let mut state = mock_state(&[0.5, 0.5]);
state.ema_throughput = vec![100.0, 1.0];
state.rebalance();
assert!(state.chunk_ratios[1] >= MIN_CHUNK_RATIO,
"slow GPU should get at least MIN_CHUNK_RATIO, got {}", state.chunk_ratios[1]);
let sum: f64 = state.chunk_ratios.iter().sum();
assert!((sum - 1.0).abs() < 1e-9);
}
#[test]
fn test_rebalance_no_data() {
let mut state = mock_state(&[0.5, 0.5]);
state.ema_throughput = vec![0.0, 0.0];
state.rebalance();
assert_eq!(state.chunk_ratios, vec![0.5, 0.5]);
}
#[test]
fn test_update_balance_calibration_timing() {
let mut state = mock_state(&[0.5, 0.5]);
for _ in 0..DEFAULT_CALIBRATION_STEPS - 1 {
let rebalanced = state.update_balance().unwrap();
assert!(!rebalanced, "should not rebalance during calibration");
}
let rebalanced = state.update_balance().unwrap();
assert!(rebalanced, "should rebalance at calibration boundary");
}
#[test]
fn test_update_balance_interval() {
let mut state = mock_state(&[0.5, 0.5]);
state.step_count = DEFAULT_CALIBRATION_STEPS;
for _ in 0..DEFAULT_REBALANCE_INTERVAL - 1 {
let rebalanced = state.update_balance().unwrap();
assert!(!rebalanced);
}
let rebalanced = state.update_balance().unwrap();
assert!(rebalanced);
}
#[test]
fn test_ema_throughput_init() {
let mut state = mock_state(&[0.5, 0.5]);
state.ema_throughput = vec![0.0, 0.0];
let throughput_0 = 10.0;
    state.ema_throughput[0] = throughput_0;
    assert_eq!(state.ema_throughput[0], 10.0);
}
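// EMA update: new = EMA_ALPHA * measurement + (1 - EMA_ALPHA) * old. The
// expected 13.0 below implies EMA_ALPHA = 0.3: 0.3 * 20 + 0.7 * 10 = 13.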
#[test]
fn test_ema_throughput_smoothing() {
let mut state = mock_state(&[0.5, 0.5]);
state.ema_throughput = vec![10.0, 5.0];
let new_measurement = 20.0;
state.ema_throughput[0] =
EMA_ALPHA * new_measurement + (1.0 - EMA_ALPHA) * state.ema_throughput[0];
assert!((state.ema_throughput[0] - 13.0).abs() < 1e-9);
}
#[test]
fn test_shard_sizes_after_rebalance() {
let mut state = mock_state(&[0.5, 0.5]);
state.ema_throughput = vec![70.0, 30.0];
state.rebalance();
let sizes = state.compute_shard_sizes(100);
assert_eq!(sizes.iter().sum::<i64>(), 100);
assert_eq!(sizes[0], 70);
assert_eq!(sizes[1], 30);
}
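// Scatter/compute/gather across two GPUs: the input is chunked on device 0,
// one shard moves to device 1 for its matmul, and the outputs are moved back
// and concatenated. Backward must route gradients through the same device
// transfers so each weight's gradient lands on its own device.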
#[test]
fn test_cross_device_autograd_gradient_flow() {
if !require_multi_gpu() {
return;
}
let opts0 = TensorOptions {
dtype: DType::Float32,
device: Device::CUDA(0),
};
let opts1 = TensorOptions {
dtype: DType::Float32,
device: Device::CUDA(1),
};
let w0 = Variable::new(Tensor::ones(&[4, 3], opts0).unwrap(), true);
let w1 = Variable::new(Tensor::ones(&[4, 3], opts1).unwrap(), true);
let input = Variable::new(
Tensor::ones(&[4, 4], opts0).unwrap(),
false,
);
let chunks = input.chunk(2, 0).unwrap();
assert_eq!(chunks.len(), 2);
let out0 = chunks[0].matmul(&w0).unwrap();
let shard1_dev1 = chunks[1].to_device(Device::CUDA(1)).unwrap();
    let out1_dev1 = shard1_dev1.matmul(&w1).unwrap();
    let out1_dev0 = out1_dev1.to_device(Device::CUDA(0)).unwrap();
let gathered = Variable::cat_many(&[&out0, &out1_dev0], 0).unwrap();
let loss = gathered.sum().unwrap();
loss.backward().unwrap();
let grad0 = w0.grad();
let grad1 = w1.grad();
assert!(
grad0.is_some(),
"w0 on device 0 should have gradient after backward"
);
assert!(
grad1.is_some(),
"w1 on device 1 should have gradient after backward"
);
let g0 = grad0.unwrap();
let g1 = grad1.unwrap();
assert_eq!(g0.device(), Device::CUDA(0), "w0 gradient should be on device 0");
assert_eq!(g1.device(), Device::CUDA(1), "w1 gradient should be on device 1");
let g0_sum = g0.sum().unwrap().item().unwrap();
let g1_sum = g1.sum().unwrap().item().unwrap();
assert!(
g0_sum.abs() > 1e-6,
"w0 gradient should be non-zero, got {g0_sum}"
);
assert!(
g1_sum.abs() > 1e-6,
"w1 gradient should be non-zero, got {g1_sum}"
);
cuda_synchronize(0);
cuda_synchronize(1);
}
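// Numerical cross-check: each weight replica sees only half of the input
// rows, so the two per-replica gradients must sum elementwise to the
// gradient from the equivalent single-device matmul over the full input.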
#[test]
fn test_cross_device_autograd_values() {
if !require_multi_gpu() {
return;
}
let w_data = Tensor::from_f32(
&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
&[4, 2],
Device::CUDA(0),
)
.unwrap();
let w_ref = Variable::new(w_data.clone(), true);
let x = Tensor::from_f32(
&[1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
&[4, 4],
Device::CUDA(0),
)
.unwrap();
let x_var = Variable::new(x.clone(), false);
let out_ref = x_var.matmul(&w_ref).unwrap();
let loss_ref = out_ref.sum().unwrap();
loss_ref.backward().unwrap();
let grad_ref = w_ref.grad().unwrap();
let grad_ref_vals = grad_ref.to_f32_vec().unwrap();
let w0 = Variable::new(
Tensor::from_f32(
&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
&[4, 2],
Device::CUDA(0),
)
.unwrap(),
true,
);
let w1 = Variable::new(
Tensor::from_f32(
&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
&[4, 2],
Device::CUDA(1),
)
.unwrap(),
true,
);
let x_var2 = Variable::new(x, false);
let chunks = x_var2.chunk(2, 0).unwrap();
let out0 = chunks[0].matmul(&w0).unwrap();
let shard1 = chunks[1].to_device(Device::CUDA(1)).unwrap();
let out1_dev1 = shard1.matmul(&w1).unwrap();
let out1_dev0 = out1_dev1.to_device(Device::CUDA(0)).unwrap();
let gathered = Variable::cat_many(&[&out0, &out1_dev0], 0).unwrap();
let loss = gathered.sum().unwrap();
loss.backward().unwrap();
let g0 = w0.grad().unwrap().to_f32_vec().unwrap();
let g1 = w1.grad().unwrap().to_f32_vec().unwrap();
for i in 0..g0.len() {
let cross_sum = g0[i] + g1[i];
let diff = (cross_sum - grad_ref_vals[i]).abs();
assert!(
diff < 1e-5,
"gradient mismatch at index {i}: cross-device sum {cross_sum} vs reference {}",
grad_ref_vals[i]
);
}
cuda_synchronize(0);
cuda_synchronize(1);
}
#[test]
fn test_graph_set_optimizer_and_step() {
use crate::graph::FlowBuilder;
use crate::nn::{Adam, Linear, ReLU, mse_loss};
let model = FlowBuilder::from(Linear::new(4, 8).unwrap())
.through(ReLU::new())
.through(Linear::new(8, 2).unwrap())
.build()
.unwrap();
model.set_optimizer(|p| Adam::new(p, 0.01));
model.set_training(true);
let params_before: Vec<f32> = model
.parameters()
.iter()
.flat_map(|p| p.variable.data().to_f32_vec().unwrap())
.collect();
let x = Variable::new(
Tensor::randn(&[4, 4], Default::default()).unwrap(),
false,
);
let target = Variable::new(
Tensor::randn(&[4, 2], Default::default()).unwrap(),
false,
);
let out = model.forward(&x).unwrap();
let loss = mse_loss(&out, &target).unwrap();
loss.backward().unwrap();
model.step().unwrap();
let params_after: Vec<f32> = model
.parameters()
.iter()
.flat_map(|p| p.variable.data().to_f32_vec().unwrap())
.collect();
let changed = params_before
.iter()
        .zip(&params_after)
.any(|(a, b)| (a - b).abs() > 1e-8);
assert!(changed, "parameters should change after step()");
}
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: make cuda-test-all"]
fn test_graph_distribute_adapts_to_hardware() {
use crate::graph::FlowBuilder;
use crate::nn::Linear;
use crate::tensor::usable_cuda_devices;
let _lock = NCCL_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let model = FlowBuilder::from(Linear::new(4, 2).unwrap())
.build()
.unwrap();
let result = model.distribute(|dev| {
FlowBuilder::from(Linear::on_device(4, 2, dev)?).build()
});
assert!(result.is_ok());
let usable = usable_cuda_devices();
if usable.len() >= 2 {
assert!(model.is_distributed());
assert_eq!(model.world_size(), usable.len());
} else {
assert!(!model.is_distributed());
assert_eq!(model.world_size(), 1);
}
}
#[test]
fn test_ddp_auto_single_gpu() {
if cuda_device_count() >= 2 {
return;
}
use crate::graph::FlowBuilder;
use crate::nn::{Adam, Linear, ReLU, mse_loss};
let model = FlowBuilder::from(Linear::new(4, 8).unwrap())
.through(ReLU::new())
.through(Linear::new(8, 2).unwrap())
.build()
.unwrap();
Ddp::setup(
&model,
|dev| {
FlowBuilder::from(Linear::on_device(4, 8, dev)?)
.through(ReLU::new())
.through(Linear::on_device(8, 2, dev)?)
.build()
},
|p| Adam::new(p, 0.001),
)
.unwrap();
let x = Variable::new(
Tensor::randn(&[4, 4], Default::default()).unwrap(),
false,
);
let target = Variable::new(
Tensor::randn(&[4, 2], Default::default()).unwrap(),
false,
);
let out = model.forward(&x).unwrap();
let loss = mse_loss(&out, &target).unwrap();
loss.backward().unwrap();
model.step().unwrap();
assert!(!model.is_distributed());
}
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: make cuda-test-nccl"]
fn test_ddp_auto_multi_gpu() {
if !require_multi_gpu() {
return;
}
let _lock = NCCL_LOCK.lock().unwrap_or_else(|e| e.into_inner());
use crate::graph::FlowBuilder;
use crate::nn::{Adam, Linear, ReLU, mse_loss};
let model = FlowBuilder::from(
Linear::on_device(4, 8, Device::CUDA(0)).unwrap(),
)
.through(ReLU::new())
.through(Linear::on_device(8, 2, Device::CUDA(0)).unwrap())
.build()
.unwrap();
Ddp::setup(
&model,
|dev| {
FlowBuilder::from(Linear::on_device(4, 8, dev)?)
.through(ReLU::new())
.through(Linear::on_device(8, 2, dev)?)
.build()
},
|p| Adam::new(p, 0.001),
)
.unwrap();
assert!(model.is_distributed());
assert_eq!(model.world_size(), 2);
let opts = TensorOptions {
dtype: DType::Float32,
device: Device::CUDA(0),
};
let x = Variable::new(
Tensor::randn(&[8, 4], opts).unwrap(),
false,
);
let target = Variable::new(
Tensor::randn(&[8, 2], opts).unwrap(),
false,
);
let out = model.forward(&x).unwrap();
let loss = mse_loss(&out, &target).unwrap();
loss.backward().unwrap();
model.step().unwrap();
cuda_synchronize(0);
cuda_synchronize(1);
}
#[test]
fn test_graph_step_without_optimizer() {
use crate::graph::FlowBuilder;
use crate::nn::Linear;
let model = FlowBuilder::from(Linear::new(4, 2).unwrap())
.build()
.unwrap();
let result = model.step();
assert!(result.is_ok());
}
#[test]
fn test_graph_set_lr() {
use crate::graph::FlowBuilder;
use crate::nn::{Adam, Linear};
let model = FlowBuilder::from(Linear::new(4, 2).unwrap())
.build()
.unwrap();
model.set_optimizer(|p| Adam::new(p, 0.01));
model.set_lr(0.001);
}
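// "El Che" cadence tests. As the assertions below exercise it: the anchor is
// the number of batches the slowest device runs between syncs, faster devices
// run proportionally more, and report_timing() both discovers speed ratios
// from observed per-device times and grows the anchor until sync overhead
// falls under the configured target.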
#[test]
fn test_cadence_initial_equal() {
let c = ElChe::new(2, 10);
assert_eq!(c.batches(0), 10);
assert_eq!(c.batches(1), 10);
assert_eq!(c.total_batches(), 20);
assert_eq!(c.anchor(), 10);
assert!(!c.is_calibrated());
}
#[test]
fn test_cadence_initial_three_devices() {
let c = ElChe::new(3, 15);
assert_eq!(c.batches(0), 15);
assert_eq!(c.batches(1), 15);
assert_eq!(c.batches(2), 15);
assert_eq!(c.total_batches(), 45);
}
#[test]
fn test_cadence_ratio_discovery_2x() {
let mut c = ElChe::new(2, 10)
        .with_overhead_target(0.50);
    let bc = c.batch_counts().to_vec();
    c.report_timing(&[500.0, 1000.0], &bc, 10.0);
assert!(c.is_calibrated());
assert_eq!(c.batches(1), 10);
assert_eq!(c.batches(0), 20);
}
#[test]
fn test_cadence_ratio_discovery_fbrl_like() {
let mut c = ElChe::new(2, 10)
.with_overhead_target(0.50);
    let bc = c.batch_counts().to_vec();
    c.report_timing(&[730.0, 1640.0], &bc, 50.0);
assert!(c.is_calibrated());
    assert_eq!(c.batches(1), 10);
    let fast = c.batches(0);
assert!(
(22..=23).contains(&fast),
"expected ~22-23, got {fast}"
);
}
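// Anchor auto-tuning math: 1000 ms for 10 batches is 100 ms per batch, so a
// 500 ms sync is 50% overhead at anchor 10. Meeting the 10% target needs
// 500 / (100 * anchor) <= 0.10, i.e. anchor = 50.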
#[test]
fn test_cadence_anchor_auto_tune() {
let mut c = ElChe::new(2, 10)
.with_overhead_target(0.10);
    let bc = c.batch_counts().to_vec();
    c.report_timing(&[1000.0, 1000.0], &bc, 500.0);
assert_eq!(c.anchor(), 50);
assert_eq!(c.batches(0), 50);
assert_eq!(c.batches(1), 50);
}
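// Same tuning with a 2x speed skew: the slower device (100 ms per batch)
// sets the anchor via 400 / (100 * anchor) <= 0.10, giving anchor 40, and
// the faster device keeps its 2x multiple.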
#[test]
fn test_cadence_anchor_auto_tune_with_speed_ratio() {
let mut c = ElChe::new(2, 10)
.with_overhead_target(0.10);
    let bc = c.batch_counts().to_vec();
    c.report_timing(&[500.0, 1000.0], &bc, 400.0);
assert_eq!(c.anchor(), 40);
    assert_eq!(c.batches(1), 40);
    assert_eq!(c.batches(0), 80);
}
#[test]
fn test_cadence_anchor_capped_at_max() {
let mut c = ElChe::new(2, 10)
.with_overhead_target(0.01)
.with_max_anchor(30);
    let bc = c.batch_counts().to_vec();
    c.report_timing(&[100.0, 100.0], &bc, 500.0);
assert_eq!(c.anchor(), 30);
assert_eq!(c.batches(0), 30);
}
#[test]
fn test_cadence_stable_when_overhead_low() {
let mut c = ElChe::new(2, 10)
.with_overhead_target(0.10);
    let bc = c.batch_counts().to_vec();
    c.report_timing(&[1000.0, 1000.0], &bc, 5.0);
    assert_eq!(c.anchor(), 10);
}
#[test]
fn test_cadence_three_devices_mixed_speed() {
let mut c = ElChe::new(3, 10)
.with_overhead_target(0.50);
    let bc = c.batch_counts().to_vec();
    c.report_timing(&[333.0, 500.0, 1000.0], &bc, 10.0);
    assert_eq!(c.batches(2), 10);
    assert_eq!(c.batches(0), 30);
assert_eq!(c.batches(1), 20);
}
#[test]
fn test_cadence_successive_reports_refine() {
let mut c = ElChe::new(2, 10)
.with_overhead_target(0.50);
    let bc = c.batch_counts().to_vec();
    c.report_timing(&[500.0, 1000.0], &bc, 10.0);
assert_eq!(c.batches(0), 20);
assert_eq!(c.batches(1), 10);
    let bc = c.batch_counts().to_vec();
    c.report_timing(&[1000.0, 1000.0], &bc, 10.0);
assert_eq!(c.batches(0), 20);
assert_eq!(c.batches(1), 10);
}
#[test]
fn test_cadence_clamp_total() {
let mut c = ElChe::new(2, 10)
.with_overhead_target(0.50);
    let bc = c.batch_counts().to_vec();
    c.report_timing(&[500.0, 1000.0], &bc, 10.0);
let clamped = c.clamp_total(15);
assert_eq!(clamped.iter().sum::<usize>(), 15);
assert!(clamped[0] >= clamped[1], "fast device should still get more");
}
#[test]
fn test_cadence_clamp_total_no_op_when_within() {
let c = ElChe::new(2, 10);
let clamped = c.clamp_total(30);
assert_eq!(clamped, vec![10, 10]);
}
#[test]
fn test_cadence_builders() {
let c = ElChe::new(2, 10)
.with_overhead_target(0.20)
.with_max_anchor(100);
assert_eq!(c.anchor(), 10);
assert!(!c.is_calibrated());
let c2 = ElChe::new(2, 5)
        .with_overhead_target(0.001);
    let _ = c2;
}
#[test]
fn test_cadence_max_batch_diff() {
let c = ElChe::new(2, 10).with_max_batch_diff(5);
assert_eq!(c.max_batch_diff(), Some(5));
let c2 = ElChe::new(2, 10);
assert_eq!(c2.max_batch_diff(), None);
}
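// With max_batch_diff(3), a device's batch count may drop by at most 3 per
// report, even when timings swing from 5x faster (20 ms vs 100 ms) to
// markedly slower (450 ms vs 100 ms).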
#[test]
fn test_batch_count_clamped_to_max_diff() {
let mut c = ElChe::new(2, 10).with_max_batch_diff(3);
    let bc = c.batch_counts().to_vec();
    c.report_timing(&[100.0, 20.0], &bc, 0.0);
assert!(c.is_calibrated());
let counts_after_cal = c.batch_counts().to_vec();
assert_eq!(counts_after_cal[0], 10);
assert_eq!(counts_after_cal[1], 50);
    let bc = c.batch_counts().to_vec();
    c.report_timing(&[100.0, 450.0], &bc, 0.0);
let counts = c.batch_counts();
assert!(counts[1] >= counts_after_cal[1] - 3,
"batch count drop should be clamped to 3, was {} now {}",
counts_after_cal[1], counts[1]);
}
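// Smoke check: one batch count is exposed per device, which (per the test
// name) is the input the weighted all-reduce consumes.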
#[test]
fn test_cadence_weighted_allreduce_validation() {
let c = ElChe::new(2, 10);
assert_eq!(c.batch_counts().len(), 2);
}
#[test]
#[should_panic(expected = "El Che requires at least 2 devices")]
fn test_cadence_requires_two_devices() {
ElChe::new(1, 10);
}
#[test]
#[should_panic(expected = "anchor must be >= 1")]
fn test_cadence_requires_positive_anchor() {
ElChe::new(2, 0);
}
#[test]
fn test_cadence_speed_ratio_2x() {
let c = ElChe::new(2, 10).with_speed_ratio(1, 2.0);
assert_eq!(c.batches(0), 20);
assert_eq!(c.batches(1), 10);
}
#[test]
fn test_cadence_speed_ratio_fbrl() {
let c = ElChe::new(2, 10).with_speed_ratio(1, 2.3);
assert_eq!(c.batches(0), 23);
assert_eq!(c.batches(1), 10);
}
#[test]
fn test_cadence_speed_ratio_slow_rank_0() {
let c = ElChe::new(2, 10).with_speed_ratio(0, 3.0);
assert_eq!(c.batches(0), 10);
assert_eq!(c.batches(1), 30);
}
#[test]
fn test_cadence_speed_ratio_equal() {
let c = ElChe::new(2, 10).with_speed_ratio(1, 1.0);
assert_eq!(c.batches(0), 10);
assert_eq!(c.batches(1), 10);
}
#[test]
fn test_cadence_speed_ratio_three_devices() {
let c = ElChe::new(3, 10).with_speed_ratio(2, 3.0);
assert_eq!(c.batches(0), 30);
assert_eq!(c.batches(1), 30);
assert_eq!(c.batches(2), 10);
}
#[test]
fn test_cadence_speed_ratio_three_devices_mid_slow() {
let c = ElChe::new(3, 10).with_speed_ratio(1, 2.0);
assert_eq!(c.batches(0), 20);
assert_eq!(c.batches(1), 10);
assert_eq!(c.batches(2), 20);
}
#[test]
fn test_cadence_max_anchor_one() {
let mut c = ElChe::new(2, 1)
.with_max_anchor(1)
.with_speed_ratio(1, 2.0);
assert_eq!(c.batches(0), 2);
assert_eq!(c.batches(1), 1);
    let bc = c.batch_counts().to_vec();
    c.report_timing(&[100.0, 200.0], &bc, 500.0);
assert_eq!(c.anchor(), 1);
}
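// nudge_anchor_down(f) scales the anchor by f, clamps the result to at least
// 1, never increases it, and rescales per-device batch counts to match.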
#[test]
fn test_nudge_anchor_down() {
let mut c = ElChe::new(2, 20)
        .with_overhead_target(0.50);
    let bc = c.batch_counts().to_vec();
c.report_timing(&[50.0, 100.0], &bc, 0.0);
assert!(c.is_calibrated());
assert_eq!(c.anchor(), 20);
    assert_eq!(c.batches(0), 40);
    assert_eq!(c.batches(1), 20);
c.nudge_anchor_down(0.5);
assert_eq!(c.anchor(), 10);
assert_eq!(c.batches(0), 20);
assert_eq!(c.batches(1), 10);
}
#[test]
fn test_nudge_anchor_down_clamped_to_one() {
let mut c = ElChe::new(2, 5);
assert_eq!(c.anchor(), 5);
c.nudge_anchor_down(0.1);
assert_eq!(c.anchor(), 1, "should clamp to 1");
}
#[test]
fn test_nudge_anchor_down_never_increases() {
let mut c = ElChe::new(2, 10);
c.nudge_anchor_down(2.0);
assert_eq!(c.anchor(), 10, "should never increase");
}
#[test]
fn test_cadence_speed_ratio_self_corrects() {
let mut c = ElChe::new(2, 10)
.with_overhead_target(0.50)
.with_speed_ratio(0, 2.0);
assert_eq!(c.batches(0), 10);
assert_eq!(c.batches(1), 20);
    let bc = c.batch_counts().to_vec();
    c.report_timing(&[500.0, 2000.0], &bc, 10.0);
assert_eq!(c.batches(1), c.anchor());
assert!(c.batches(0) > c.batches(1), "fast device should get more batches");
}
#[test]
fn test_ddp_config_defaults() {
let c = DdpConfig::new();
assert!(c.speed_hint.is_none());
assert!(c.overhead_target.is_none());
assert!(c.max_anchor.is_none());
}
#[test]
fn test_ddp_config_builder() {
let c = DdpConfig::new()
.speed_hint(1, 2.5)
.overhead_target(0.05)
.max_anchor(Some(20));
assert_eq!(c.speed_hint, Some((1, 2.5)));
assert_eq!(c.overhead_target, Some(0.05));
assert_eq!(c.max_anchor, Some(20));
}
#[test]
fn test_ddp_config_disable_el_che() {
let c = DdpConfig::new().max_anchor(Some(0));
assert_eq!(c.max_anchor, Some(0));
}
#[test]
fn test_configure_el_che_creates_from_config() {
let mut state = mock_state(&[0.5, 0.5]);
let config = DdpConfig::new().speed_hint(1, 2.0).overhead_target(0.15);
state.configure_el_che(&config);
assert!(state.el_che.is_some());
let el = state.el_che.as_ref().unwrap();
assert_eq!(el.batches(1), el.anchor());
assert!(el.batches(0) > el.batches(1));
}
#[test]
fn test_configure_el_che_disabled() {
let mut state = mock_state(&[0.5, 0.5]);
let config = DdpConfig::new().max_anchor(Some(0));
state.configure_el_che(&config);
assert!(state.el_che.is_none());
}
#[test]
fn test_configure_el_che_single_device_noop() {
let mut state = mock_state(&[1.0]);
let config = DdpConfig::new();
state.configure_el_che(&config);
assert!(state.el_che.is_none());
}
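// End-to-end El Che run: with a 2x speed hint and the anchor capped at 3,
// the epoch loop must terminate, take at least one optimizer step, and never
// exceed the 20 batches the loader yields (200 samples / batch size 10).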
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: make cuda-test-nccl"]
fn test_el_che_full_training_loop() {
if !require_multi_gpu() {
return;
}
let _lock = NCCL_LOCK.lock().unwrap_or_else(|e| e.into_inner());
use crate::graph::FlowBuilder;
use crate::nn::{Adam, Linear, ReLU, mse_loss};
use crate::data::{DataLoader, DataSet};
struct TinyData;
impl DataSet for TinyData {
fn len(&self) -> usize { 200 }
fn get(&self, index: usize) -> crate::tensor::Result<Vec<Tensor>> {
let x = Tensor::from_f32(
&[index as f32; 4], &[4], Device::CPU,
)?;
let y = Tensor::from_f32(
&[(index as f32) * 0.1; 2], &[2], Device::CPU,
)?;
Ok(vec![x, y])
}
}
let model = FlowBuilder::from(
Linear::on_device(4, 8, Device::CUDA(0)).unwrap(),
)
.through(ReLU::new())
.through(Linear::on_device(8, 2, Device::CUDA(0)).unwrap())
.build()
.unwrap();
Ddp::setup_with(
&model,
|dev| {
FlowBuilder::from(Linear::on_device(4, 8, dev)?)
.through(ReLU::new())
.through(Linear::on_device(8, 2, dev)?)
.build()
},
|p| Adam::new(p, 0.001),
DdpConfig::new().speed_hint(1, 2.0).max_anchor(Some(3)),
)
.unwrap();
assert!(model.is_distributed());
assert!(model.has_el_che());
assert_eq!(model.world_size(), 2);
let loader = DataLoader::from_dataset(TinyData)
.batch_size(10)
.names(&["input", "target"])
.build()
.unwrap();
model.set_data_loader(loader, "input").unwrap();
let mut step_count = 0;
for batch in model.epoch(0).activate() {
let b = batch.unwrap();
let out = model.forward_batch(&b).unwrap();
let target = Variable::new(b["target"].clone(), false);
let loss = mse_loss(&out, &target).unwrap();
loss.backward().unwrap();
model.step().unwrap();
step_count += 1;
}
assert!(step_count > 0, "should have trained at least one step");
assert!(step_count <= 20, "should not have more steps than batches");
cuda_synchronize(0);
cuda_synchronize(1);
}
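// Tagged intermediates must be gathered across replicas: the "hidden"
// activation keeps its feature width of 8 and spans at least one local
// batch of 10 rows after gathering.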
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: make cuda-test-nccl"]
fn test_el_che_tagged_outputs_gathered() {
if !require_multi_gpu() {
return;
}
let _lock = NCCL_LOCK.lock().unwrap_or_else(|e| e.into_inner());
use crate::graph::FlowBuilder;
use crate::nn::{Adam, Linear, ReLU, mse_loss};
use crate::data::{DataLoader, DataSet};
struct TinyData;
impl DataSet for TinyData {
fn len(&self) -> usize { 100 }
fn get(&self, index: usize) -> crate::tensor::Result<Vec<Tensor>> {
let x = Tensor::from_f32(
&[index as f32; 4], &[4], Device::CPU,
)?;
let y = Tensor::from_f32(
&[(index as f32) * 0.1; 2], &[2], Device::CPU,
)?;
Ok(vec![x, y])
}
}
let model = FlowBuilder::from(
Linear::on_device(4, 8, Device::CUDA(0)).unwrap(),
)
.through(ReLU::new())
.tag("hidden")
.through(Linear::on_device(8, 2, Device::CUDA(0)).unwrap())
.build()
.unwrap();
Ddp::setup_with(
&model,
|dev| {
FlowBuilder::from(Linear::on_device(4, 8, dev)?)
.through(ReLU::new())
.tag("hidden")
.through(Linear::on_device(8, 2, dev)?)
.build()
},
|p| Adam::new(p, 0.001),
DdpConfig::new().max_anchor(Some(2)),
)
.unwrap();
let loader = DataLoader::from_dataset(TinyData)
.batch_size(10)
.names(&["input", "target"])
.build()
.unwrap();
model.set_data_loader(loader, "input").unwrap();
let mut iter = model.epoch(0).activate();
if let Some(batch) = iter.next() {
let b = batch.unwrap();
let out = model.forward_batch(&b).unwrap();
let hidden = model.tagged("hidden");
assert!(hidden.is_some(), "tagged output should be gathered");
let h = hidden.unwrap();
assert_eq!(h.shape()[1], 8);
assert!(h.shape()[0] >= 10, "gathered hidden should span multiple batches");
let target = Variable::new(b["target"].clone(), false);
let loss = mse_loss(&out, &target).unwrap();
loss.backward().unwrap();
model.step().unwrap();
}
cuda_synchronize(0);
cuda_synchronize(1);
}