use super::*;
use crate::nn::Module;
use crate::tensor::TensorError;
use std::collections::HashMap;
use std::sync::Arc;
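// Smoke tests for the enums and control messages exchanged between
// workers and the coordinator: variant counts, equality, constructibility.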
#[test]
fn test_apply_policy_variants() {
let policies = [ApplyPolicy::Sync, ApplyPolicy::Cadence, ApplyPolicy::Async];
assert_eq!(policies.len(), 3);
assert_eq!(ApplyPolicy::Sync, ApplyPolicy::Sync);
assert_ne!(ApplyPolicy::Sync, ApplyPolicy::Async);
}
#[test]
fn test_average_backend_variants() {
let backends = [AverageBackend::Nccl, AverageBackend::Cpu];
assert_eq!(backends.len(), 2);
assert_eq!(AverageBackend::Nccl, AverageBackend::Nccl);
assert_ne!(AverageBackend::Nccl, AverageBackend::Cpu);
}
#[test]
fn test_control_msg_variants() {
let _req = ControlMsg::RequestParams;
let _sync = ControlMsg::SyncNow;
let _throttle = ControlMsg::Throttle;
let _start = ControlMsg::StartEpoch(EpochPlan {
epoch: 0, partition_offset: 0, partition_size: 1000,
});
let _ckpt = ControlMsg::Checkpoint { version: 42 };
let _shutdown = ControlMsg::Shutdown;
let _update = ControlMsg::Update(AveragedParams {
params: vec![],
buffers: vec![],
version: 0,
});
}
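// Compile-time Send checks: every type that crosses a channel boundary
// must be Send, so an empty generic function is enough to verify it.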
#[test]
fn test_timing_msg_send() {
fn assert_send<T: Send>() {}
assert_send::<TimingMsg>();
}
#[test]
fn test_metrics_msg_send() {
fn assert_send<T: Send>() {}
assert_send::<MetricsMsg>();
}
#[test]
fn test_param_snapshot_send() {
fn assert_send<T: Send>() {}
assert_send::<ParamSnapshot>();
}
#[test]
fn test_averaged_params_send() {
fn assert_send<T: Send>() {}
assert_send::<AveragedParams>();
}
#[test]
fn test_control_msg_send() {
fn assert_send<T: Send>() {}
assert_send::<ControlMsg>();
}
#[test]
fn test_worker_config_send() {
fn assert_send<T: Send>() {}
assert_send::<WorkerConfig>();
}
#[test]
fn test_worker_config_clone() {
let cfg = WorkerConfig {
rank: 0,
world_size: 2,
device: Device::CPU,
initial_params: vec![],
initial_buffers: vec![],
total_samples: 10000,
batch_size: 32,
seed: 42,
max_grad_norm: None,
timeline: None,
policy: ApplyPolicy::Sync,
};
let cfg2 = cfg.clone();
assert_eq!(cfg2.rank, 0);
assert_eq!(cfg2.world_size, 2);
assert_eq!(cfg2.total_samples, 10000);
}
use std::sync::mpsc;
use crate::autograd::Variable;
use crate::nn::Linear;
use crate::tensor::{test_device, test_opts, Tensor, TensorOptions, DType};
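// Tiny synthetic dataset: every batch is a random (input, target) pair
// shaped for a 4-in / 2-out linear model.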
struct TestDataset {
n: usize,
}
impl crate::data::BatchDataSet for TestDataset {
fn len(&self) -> usize { self.n }
fn get_batch(&self, indices: &[usize]) -> crate::tensor::Result<Vec<Tensor>> {
let n = indices.len() as i64;
let opts = TensorOptions { dtype: DType::Float32, device: Device::CPU };
Ok(vec![
Tensor::randn(&[n, 4], opts)?,
Tensor::randn(&[n, 2], opts)?,
])
}
}
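// Mean-squared-error training closure shared by the worker tests below.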
fn mse_train(model: &Linear, batch: &[Tensor]) -> Result<Variable> {
let input = Variable::new(batch[0].clone(), false);
let target = Variable::new(batch[1].clone(), false);
let output = model.forward(&input)?;
let diff = output.sub(&target)?;
diff.mul(&diff)?.mean()
}
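// Builds a GpuWorker over a small Linear model together with the channel
// set a coordinator would normally hold. A throwaway model is constructed
// first purely to produce the initial parameter/buffer tensors.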
fn make_test_worker() -> (GpuWorker<Linear>, WorkerChannels) {
make_test_worker_with(0, 1, 4)
}
fn make_test_worker_with(
rank: usize,
world_size: usize,
dataset_size: usize,
) -> (GpuWorker<Linear>, WorkerChannels) {
let dev = test_device();
let tmp_model = Linear::on_device(4, 2, dev).unwrap();
let tmp_params: Vec<Tensor> = tmp_model.parameters().iter()
.map(|p| p.variable.data())
.collect();
let tmp_buffers: Vec<Tensor> = tmp_model.buffers().iter()
.map(|b| b.get())
.collect();
drop(tmp_model);
let config = WorkerConfig {
rank,
world_size,
device: dev,
initial_params: tmp_params,
initial_buffers: tmp_buffers,
total_samples: dataset_size,
batch_size: 4,
seed: 42,
max_grad_norm: None,
timeline: None,
policy: ApplyPolicy::Sync,
};
let ((timing_tx, metrics_tx, param_tx, final_param_tx, control_rx), channels) =
GpuWorker::<Linear>::channels();
let dataset: Arc<dyn crate::data::BatchDataSet> =
Arc::new(TestDataset { n: dataset_size });
let worker = GpuWorker::new(
&config,
|d| Linear::on_device(4, 2, d),
|params| crate::nn::SGD::new(params, 0.01, 0.0),
dataset,
None, None, timing_tx,
metrics_tx,
param_tx,
final_param_tx,
control_rx,
).unwrap();
(worker, channels)
}
#[test]
fn test_worker_new_and_accessors() {
let (worker, _ch) = make_test_worker();
assert_eq!(worker.rank(), 0);
assert_eq!(worker.local_step(), 0);
assert_eq!(worker.current_version(), 0);
assert_eq!(worker.param_vars.len(), 2);
}
#[test]
fn test_worker_snapshot_params() {
let (worker, _ch) = make_test_worker();
let snap = worker.snapshot_params();
assert_eq!(snap.rank, 0);
assert_eq!(snap.params.len(), 2);
assert_eq!(snap.buffers.len(), 0);
// a fresh worker already reports batch_count 1, presumably a floor so
// weighted averaging never sees a zero weight
assert_eq!(snap.batch_count, 1);
assert_eq!(snap.params[0].shape(), &[2, 4]);
assert_eq!(snap.params[1].shape(), &[2]);
}
#[test]
fn test_worker_snapshot_is_send() {
let (worker, _ch) = make_test_worker();
let snap = worker.snapshot_params();
let (tx, rx) = mpsc::channel::<ParamSnapshot>();
tx.send(snap).unwrap();
let received = rx.recv().unwrap();
assert_eq!(received.rank, 0);
assert_eq!(received.params.len(), 2);
}
#[test]
fn test_worker_load_averaged() {
let (mut worker, _ch) = make_test_worker();
let cpu = TensorOptions { dtype: DType::Float32, device: Device::CPU };
let new_weight = Tensor::ones(&[2, 4], cpu).unwrap();
let new_bias = Tensor::ones(&[2], cpu).unwrap();
let update = AveragedParams {
params: vec![new_weight, new_bias],
buffers: vec![],
version: 42,
};
worker.load_averaged(&update).unwrap();
// on CUDA, make sure any asynchronous device copies issued by
// load_averaged have completed before reading the params back
let dev = test_device();
if let Device::CUDA(idx) = dev {
crate::tensor::cuda_synchronize(idx);
}
assert_eq!(worker.current_version(), 42);
let snap = worker.snapshot_params();
let w_sum: f64 = snap.params[0].sum().unwrap().item().unwrap();
assert!((w_sum - 8.0).abs() < 1e-5, "weight should be all ones (sum=8), got {w_sum}");
let b_sum: f64 = snap.params[1].sum().unwrap().item().unwrap();
assert!((b_sum - 2.0).abs() < 1e-5, "bias should be all ones (sum=2), got {b_sum}");
}
#[test]
fn test_worker_load_averaged_wrong_count() {
let (mut worker, _ch) = make_test_worker();
let update = AveragedParams {
params: vec![], buffers: vec![],
version: 1,
};
assert!(worker.load_averaged(&update).is_err());
}
#[test]
fn test_worker_train_step() {
let (mut worker, ch) = make_test_worker();
let opts = test_opts();
let batch = vec![
Tensor::randn(&[4, 4], opts).unwrap(),
Tensor::randn(&[4, 2], opts).unwrap(),
];
let (loss, ms) = worker.train_step(&batch, &mse_train).unwrap();
assert!(ms > 0.0);
assert!(loss > 0.0);
assert_eq!(worker.local_step(), 1);
// train_step does not emit timing itself; that happens via report_timing
assert!(ch.timing_rx.try_recv().is_err());
}
#[test]
fn test_worker_report_timing() {
let (worker, ch) = make_test_worker();
worker.report_timing(12.5, None, 0.5, None).unwrap();
let msg = ch.timing_rx.recv().unwrap();
match msg {
TimingMsg::Batch { rank, batch_ms, step_count, .. } => {
assert_eq!(rank, 0);
assert!((batch_ms - 12.5).abs() < 1e-10);
assert_eq!(step_count, 0);
}
_ => panic!("expected Batch"),
}
}
#[test]
fn test_worker_report_epoch() {
let (worker, ch) = make_test_worker();
worker.report_epoch(0.5, 100, 5000.0).unwrap();
let msg = ch.metrics_rx.recv().unwrap();
assert_eq!(msg.rank, 0);
assert_eq!(msg.epoch, 0);
assert!((msg.avg_loss - 0.5).abs() < 1e-10);
assert_eq!(msg.batches_processed, 100);
}
#[test]
fn test_worker_handle_control_request_params() {
let (mut worker, ch) = make_test_worker();
ch.control_tx.send(ControlMsg::RequestParams).unwrap();
let shutdown = worker.handle_control().unwrap();
assert!(!shutdown);
let snap = ch.param_rx.recv().unwrap();
assert_eq!(snap.rank, 0);
assert_eq!(snap.params.len(), 2);
}
#[test]
fn test_worker_handle_control_update() {
let (mut worker, ch) = make_test_worker();
let dev = test_device();
let opts = TensorOptions { dtype: DType::Float32, device: dev };
let update = AveragedParams {
params: vec![
Tensor::zeros(&[2, 4], opts).unwrap(),
Tensor::zeros(&[2], opts).unwrap(),
],
buffers: vec![],
version: 7,
};
ch.control_tx.send(ControlMsg::Update(update)).unwrap();
let shutdown = worker.handle_control().unwrap();
assert!(!shutdown);
assert_eq!(worker.current_version(), 7);
}
#[test]
fn test_worker_handle_control_start_epoch() {
let (mut worker, ch) = make_test_worker();
assert!(worker.pending_plan.is_none());
ch.control_tx.send(ControlMsg::StartEpoch(EpochPlan {
epoch: 1, partition_offset: 0, partition_size: 750,
})).unwrap();
worker.handle_control().unwrap();
let plan = worker.pending_plan.take();
assert!(plan.is_some());
assert_eq!(plan.unwrap().partition_size, 750);
assert!(worker.pending_plan.is_none());
}
#[test]
fn test_worker_handle_control_shutdown() {
let (mut worker, ch) = make_test_worker();
ch.control_tx.send(ControlMsg::Shutdown).unwrap();
let shutdown = worker.handle_control().unwrap();
assert!(shutdown);
}
#[test]
fn test_worker_handle_control_sync_now_noop() {
let (mut worker, ch) = make_test_worker();
ch.control_tx.send(ControlMsg::SyncNow).unwrap();
let shutdown = worker.handle_control().unwrap();
assert!(!shutdown);
}
#[test]
fn test_worker_full_roundtrip() {
let (mut worker, ch) = make_test_worker();
let opts = test_opts();
let batch = vec![
Tensor::randn(&[4, 4], opts).unwrap(),
Tensor::randn(&[4, 2], opts).unwrap(),
];
worker.train_step(&batch, &mse_train).unwrap();
assert_eq!(worker.local_step(), 1);
ch.control_tx.send(ControlMsg::RequestParams).unwrap();
worker.handle_control().unwrap();
let snap = ch.param_rx.recv().unwrap();
assert_eq!(snap.batch_count, 1);
let update = AveragedParams {
params: snap.params,
buffers: snap.buffers,
version: 1,
};
ch.control_tx.send(ControlMsg::Update(update)).unwrap();
worker.handle_control().unwrap();
assert_eq!(worker.current_version(), 1);
let batch2 = vec![
Tensor::randn(&[4, 4], opts).unwrap(),
Tensor::randn(&[4, 2], opts).unwrap(),
];
worker.train_step(&batch2, &mse_train).unwrap();
assert_eq!(worker.local_step(), 2);
}
#[test]
fn test_worker_epoch_from_plan() {
let (mut worker, _ch) = make_test_worker();
assert_eq!(worker.current_epoch, 0);
worker.current_epoch = 3;
assert_eq!(worker.current_epoch, 3);
}
#[test]
fn test_worker_channels_create() {
let ((timing_tx, metrics_tx, param_tx, _final_param_tx, _control_rx), ch) =
GpuWorker::<Linear>::channels();
timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 1.0, step_count: 0, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
let msg = ch.timing_rx.recv().unwrap();
assert!(matches!(msg, TimingMsg::Batch { rank: 0, .. }));
metrics_tx.send(MetricsMsg {
rank: 0, epoch: 0, avg_loss: 0.5, batches_processed: 10, epoch_ms: 100.0,
samples_processed: 320, scalars: HashMap::new(),
}).unwrap();
let msg = ch.metrics_rx.recv().unwrap();
assert_eq!(msg.batches_processed, 10);
param_tx.send(ParamSnapshot {
rank: 0, params: vec![], buffers: vec![], batch_count: 0,
}).unwrap();
let snap = ch.param_rx.recv().unwrap();
assert_eq!(snap.rank, 0);
ch.control_tx.send(ControlMsg::Shutdown).unwrap();
}
use crate::distributed::ddp::ElChe;
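// Test double for the worker side of a Coordinator: the harness keeps the
// sender halves (timing/metrics/params) and the per-rank control receivers
// so tests can script worker traffic deterministically.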
struct CoordTestHarness {
coord: Coordinator,
timing_tx: mpsc::Sender<TimingMsg>,
metrics_tx: mpsc::Sender<MetricsMsg>,
param_tx: mpsc::Sender<ParamSnapshot>,
control_rxs: Vec<mpsc::Receiver<ControlMsg>>,
}
fn make_coord_harness(
n: usize,
policy: ApplyPolicy,
backend: AverageBackend,
) -> CoordTestHarness {
make_coord_harness_with_timeout(n, policy, backend, 5)
}
fn make_coord_harness_with_timeout(
n: usize,
policy: ApplyPolicy,
backend: AverageBackend,
snapshot_timeout_secs: u64,
) -> CoordTestHarness {
let (timing_tx, timing_rx) = mpsc::channel();
let (metrics_tx, metrics_rx) = mpsc::channel();
let (param_tx, param_rx) = mpsc::channel();
let mut control_txs = Vec::new();
let mut control_rxs = Vec::new();
let mut final_param_rxs = Vec::new();
for _ in 0..n {
let (tx, rx) = mpsc::channel();
control_txs.push(tx);
control_rxs.push(rx);
let (_ftx, frx) = mpsc::channel();
final_param_rxs.push(frx);
}
let el_che = ElChe::new(n, 10);
let coord = Coordinator::builder(
timing_rx, metrics_rx, param_rx,
final_param_rxs,
control_txs,
policy, backend,
n, 10000, el_che,
)
.snapshot_timeout_secs(snapshot_timeout_secs)
.build();
CoordTestHarness { coord, timing_tx, metrics_tx, param_tx, control_rxs }
}
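// Coordinator unit tests: initial state, timing drains, and the
// should_average decision under each ApplyPolicy.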
#[test]
fn test_coordinator_initial_state() {
let h = make_coord_harness(2, ApplyPolicy::Sync, AverageBackend::Nccl);
assert_eq!(h.coord.version(), 0);
assert!(!h.coord.is_calibrated());
assert_eq!(h.coord.steps_since_avg(), &[0, 0]);
}
#[test]
fn test_coordinator_drain_timing() {
let mut h = make_coord_harness(2, ApplyPolicy::Sync, AverageBackend::Nccl);
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 10.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 20.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
assert_eq!(h.coord.steps_since_avg(), &[1, 1]);
}
#[test]
fn test_coordinator_should_average_sync() {
let mut h = make_coord_harness(2, ApplyPolicy::Sync, AverageBackend::Nccl);
assert!(!h.coord.should_average());
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 10.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
assert!(!h.coord.should_average());
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 20.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
assert!(h.coord.should_average());
}
#[test]
fn test_coordinator_should_average_async() {
let mut h = make_coord_harness(2, ApplyPolicy::Async, AverageBackend::Nccl);
for _ in 0..9 {
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 10.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 20.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
}
h.coord.drain_timing();
assert!(!h.coord.should_average());
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 10.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 20.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
assert!(h.coord.should_average());
}
#[test]
fn test_coordinator_should_average_wall_time() {
let mut h = make_coord_harness(2, ApplyPolicy::Cadence, AverageBackend::Nccl);
for i in 0..10 {
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: i + 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 10.0, step_count: i + 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
}
h.coord.drain_timing();
assert!(h.coord.should_average());
h.coord.trigger_averaging().unwrap();
for rx in &h.control_rxs { while rx.try_recv().is_ok() {} }
assert!(h.coord.is_calibrated());
let target = h.coord.el_che.anchor_wall_ms();
assert!(target > 0.0, "anchor_wall_ms should be positive after calibration");
for i in 0..10 {
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: 11 + i, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 10.0, step_count: 11 + i, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
}
h.coord.drain_timing();
assert!(!h.coord.should_average(), "fast rank wall time < target");
for i in 0..10 {
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: 21 + i, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
}
h.coord.drain_timing();
assert!(h.coord.should_average(), "both ranks at target wall time");
}
#[test]
fn test_async_uses_batch_count_not_wall_time() {
let mut h = make_coord_harness(2, ApplyPolicy::Async, AverageBackend::Nccl);
for i in 0..10 {
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: i + 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 10.0, step_count: i + 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
}
h.coord.drain_timing();
assert!(h.coord.should_average());
h.coord.trigger_averaging().unwrap();
for rx in &h.control_rxs { while rx.try_recv().is_ok() {} }
assert!(h.coord.is_calibrated());
let counts = h.coord.el_che.batch_counts();
let mut step0 = 11usize;
let mut step1 = 11usize;
for _ in 0..counts[0] {
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: step0, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
step0 += 1;
}
for _ in 0..counts[1] {
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 10.0, step_count: step1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
step1 += 1;
}
h.coord.drain_timing();
assert!(h.coord.should_average(), "async triggers on batch counts, not wall time");
}
#[test]
fn test_coordinator_trigger_nccl() {
let mut h = make_coord_harness(2, ApplyPolicy::Sync, AverageBackend::Nccl);
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 10.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 20.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
h.coord.trigger_averaging().unwrap();
for rx in &h.control_rxs {
match rx.recv().unwrap() {
ControlMsg::SyncNow => {}
other => panic!("expected SyncNow, got {:?}", std::mem::discriminant(&other)),
}
}
assert_eq!(h.coord.version(), 1);
assert_eq!(h.coord.steps_since_avg(), &[0, 0]);
}
#[test]
fn test_coordinator_trigger_cpu_averaging() {
let mut h = make_coord_harness(2, ApplyPolicy::Sync, AverageBackend::Cpu);
let dev = test_device();
let opts = TensorOptions { dtype: DType::Float32, device: dev };
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 10.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 20.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
h.coord.trigger_averaging().unwrap();
for rx in &h.control_rxs {
match rx.recv().unwrap() {
ControlMsg::RequestParams => {}
other => panic!("expected RequestParams, got {:?}", std::mem::discriminant(&other)),
}
match rx.recv().unwrap() {
ControlMsg::Throttle => {}
other => panic!("expected Throttle, got {:?}", std::mem::discriminant(&other)),
}
}
h.param_tx.send(ParamSnapshot {
rank: 0,
params: vec![Tensor::ones(&[2, 3], opts).unwrap()],
buffers: vec![],
batch_count: 10,
}).unwrap();
h.param_tx.send(ParamSnapshot {
rank: 1,
params: vec![Tensor::full(&[2, 3], 3.0, opts).unwrap()],
buffers: vec![],
batch_count: 10,
}).unwrap();
for _ in 0..100 {
h.coord.poll_cpu_averaging().unwrap();
if h.coord.version() > 0 {
break;
}
std::thread::sleep(std::time::Duration::from_millis(10));
}
assert_eq!(h.coord.version(), 1);
for rx in &h.control_rxs {
match rx.recv().unwrap() {
ControlMsg::Update(avg) => {
let sum: f64 = avg.params[0].sum().unwrap().item().unwrap();
// equal batch counts: the weighted mean of 1.0 and 3.0 is 2.0 per
// element, over 6 elements
let expected = 2.0 * 6.0;
assert!((sum - expected).abs() < 1e-4,
"expected sum={expected}, got {sum}");
assert_eq!(avg.version, 1);
}
other => panic!("expected Update, got {:?}", std::mem::discriminant(&other)),
}
}
}
#[test]
fn test_coordinator_average_params_weighted() {
let dev = test_device();
let opts = TensorOptions { dtype: DType::Float32, device: dev };
let snapshots = vec![
ParamSnapshot {
rank: 0,
params: vec![Tensor::ones(&[4], opts).unwrap()],
buffers: vec![],
batch_count: 1,
},
ParamSnapshot {
rank: 1,
params: vec![Tensor::full(&[4], 5.0, opts).unwrap()],
buffers: vec![],
batch_count: 3,
},
];
let avg = Coordinator::average_params(&snapshots, 42).unwrap();
assert_eq!(avg.version, 42);
assert_eq!(avg.params.len(), 1);
let sum: f64 = avg.params[0].sum().unwrap().item().unwrap();
// weighted by batch_count: (1.0*1 + 5.0*3) / 4 = 4.0 per element, 4 elements
let expected = 4.0 * 4.0;
assert!((sum - expected).abs() < 1e-4, "expected sum={expected}, got {sum}");
}
#[test]
fn test_coordinator_tick_sync_flow() {
let mut h = make_coord_harness(2, ApplyPolicy::Sync, AverageBackend::Nccl);
let metrics = h.coord.tick().unwrap();
assert!(metrics.is_empty());
assert_eq!(h.coord.version(), 0);
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 10.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 20.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
let metrics = h.coord.tick().unwrap();
assert!(metrics.is_empty());
assert_eq!(h.coord.version(), 1);
for rx in &h.control_rxs {
assert!(matches!(rx.recv().unwrap(), ControlMsg::SyncNow));
}
}
#[test]
fn test_coordinator_drain_metrics() {
let mut h = make_coord_harness(2, ApplyPolicy::Sync, AverageBackend::Nccl);
h.metrics_tx.send(MetricsMsg {
rank: 0, epoch: 1, avg_loss: 0.3, batches_processed: 50, epoch_ms: 2000.0,
samples_processed: 1600, scalars: HashMap::new(),
}).unwrap();
let metrics = h.coord.drain_metrics();
assert_eq!(metrics.len(), 1);
assert_eq!(metrics[0].rank, 0);
assert_eq!(metrics[0].epoch, 1);
}
#[test]
fn test_coordinator_compute_partition_sizes() {
let h = make_coord_harness(2, ApplyPolicy::Cadence, AverageBackend::Nccl);
let sizes = h.coord.compute_partition_sizes();
assert_eq!(sizes.len(), 2);
// an uncalibrated coordinator splits the 10000 samples evenly
assert_eq!(sizes[0], 5000);
assert_eq!(sizes[1], 5000);
}
#[test]
fn test_divergence_correction_nudges_anchor_down() {
let mut h = make_coord_harness(2, ApplyPolicy::Async, AverageBackend::Cpu);
let steps = vec![10; 2];
let wall_ms = vec![100.0; 2];
h.coord.finish_averaging_cpu(0.0, &steps, &wall_ms, None);
let overshoot_before = h.coord.max_overshoot;
for i in 0..3 {
let div = 0.10 + i as f64 * 0.05;
h.coord.finish_averaging_cpu(0.0, &[10, 10], &[100.0, 100.0],
Some(super::convergence::DivergenceReport {
deltas: vec![div, div],
pre_norms: None,
post_norm: None,
}));
}
assert!(h.coord.max_overshoot <= overshoot_before + 2,
"3rd interval should suppress overshoot growth, got {}", h.coord.max_overshoot);
}
#[test]
fn test_divergence_below_threshold_no_correction() {
let mut h = make_coord_harness(2, ApplyPolicy::Async, AverageBackend::Cpu);
let steps = vec![10; 2];
let wall_ms = vec![100.0; 2];
h.coord.finish_averaging_cpu(0.0, &steps, &wall_ms, None);
let anchor_after_calibration = h.coord.el_che.anchor();
let steps2 = vec![10; 2];
let wall_ms2 = vec![100.0; 2];
h.coord.finish_averaging_cpu(0.0, &steps2, &wall_ms2, Some(super::convergence::DivergenceReport {
deltas: vec![0.01, 0.01],
pre_norms: None,
post_norm: None,
}));
assert_eq!(h.coord.el_che.anchor(), anchor_after_calibration);
}
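// Like make_coord_harness, but with an explicit max_batch_diff on ElChe
// so throttling can be exercised directly.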
fn make_throttle_harness(
n: usize,
max_batch_diff: usize,
) -> CoordTestHarness {
let (timing_tx, timing_rx) = mpsc::channel();
let (metrics_tx, metrics_rx) = mpsc::channel();
let (param_tx, param_rx) = mpsc::channel();
let mut control_txs = Vec::new();
let mut control_rxs = Vec::new();
let mut final_param_rxs = Vec::new();
for _ in 0..n {
let (tx, rx) = mpsc::channel();
control_txs.push(tx);
control_rxs.push(rx);
let (_ftx, frx) = mpsc::channel();
final_param_rxs.push(frx);
}
let el_che = ElChe::new(n, 10).with_max_batch_diff(max_batch_diff);
let coord = Coordinator::builder(
timing_rx, metrics_rx, param_rx,
final_param_rxs,
control_txs,
ApplyPolicy::Async, AverageBackend::Cpu,
n, 10000, el_che,
).build();
CoordTestHarness { coord, timing_tx, metrics_tx, param_tx, control_rxs }
}
#[test]
fn test_throttle_sends_when_diff_exceeded() {
let mut h = make_throttle_harness(2, 3);
for i in 0..5 {
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: i, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
}
h.coord.drain_timing();
h.coord.check_throttle();
match h.control_rxs[0].try_recv() {
Ok(ControlMsg::Throttle) => {}
_ => panic!("expected Throttle for rank 0"),
}
assert!(h.control_rxs[1].try_recv().is_err(), "rank 1 should not be throttled");
}
#[test]
fn test_throttle_no_send_within_limit() {
let mut h = make_throttle_harness(2, 5);
for i in 0..3 {
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: i, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
}
h.coord.drain_timing();
h.coord.check_throttle();
assert!(h.control_rxs[0].try_recv().is_err());
assert!(h.control_rxs[1].try_recv().is_err());
}
#[test]
fn test_throttle_zero_is_strict_lockstep() {
let mut h = make_throttle_harness(2, 0);
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: 0, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
h.coord.check_throttle();
match h.control_rxs[0].try_recv() {
Ok(ControlMsg::Throttle) => {}
_ => panic!("expected Throttle for rank 0"),
}
}
#[test]
fn test_throttle_skipped_for_nccl() {
let (timing_tx, timing_rx) = mpsc::channel();
let (_metrics_tx, metrics_rx) = mpsc::channel();
let (_param_tx, param_rx) = mpsc::channel();
let mut control_txs = Vec::new();
let mut control_rxs = Vec::new();
let mut final_param_rxs = Vec::new();
for _ in 0..2 {
let (tx, rx) = mpsc::channel();
control_txs.push(tx);
control_rxs.push(rx);
let (_ftx, frx) = mpsc::channel();
final_param_rxs.push(frx);
}
let el_che = ElChe::new(2, 10).with_max_batch_diff(3);
let mut coord = Coordinator::builder(
timing_rx, metrics_rx, param_rx,
final_param_rxs,
control_txs,
ApplyPolicy::Cadence, AverageBackend::Nccl,
2, 10000, el_che,
).build();
for i in 0..10 {
timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: i, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
}
coord.drain_timing();
coord.check_throttle();
assert!(control_rxs[0].try_recv().is_err(),
"NCCL backend must not throttle (AllReduce is the coordination mechanism)");
}
#[test]
fn test_throttle_disabled_when_none() {
let mut h = make_coord_harness(2, ApplyPolicy::Async, AverageBackend::Nccl);
for i in 0..50 {
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: i, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
}
h.coord.drain_timing();
h.coord.check_throttle();
assert!(h.control_rxs[0].try_recv().is_err());
}
#[test]
fn test_throttle_worker_unblocks_on_sync_now() {
let (mut worker, ch) = make_test_worker();
ch.control_tx.send(ControlMsg::Throttle).unwrap();
ch.control_tx.send(ControlMsg::SyncNow).unwrap();
let shutdown = worker.handle_control().unwrap();
assert!(!shutdown, "should not shutdown");
}
#[test]
fn test_throttle_worker_unblocks_on_shutdown() {
let (mut worker, ch) = make_test_worker();
ch.control_tx.send(ControlMsg::Throttle).unwrap();
ch.control_tx.send(ControlMsg::Shutdown).unwrap();
let shutdown = worker.handle_control().unwrap();
assert!(shutdown, "should signal shutdown");
}
#[test]
fn test_async_ddp_config_max_batch_diff() {
let config = DdpRunConfig::new().with_max_batch_diff(5);
assert_eq!(config.max_batch_diff, Some(5));
let config2 = DdpRunConfig::new();
assert_eq!(config2.max_batch_diff, None);
}
#[test]
fn test_async_ddp_single_gpu_fallback() {
let ddp = DdpHandle::auto(
|dev| Linear::on_device(4, 2, dev),
|params| crate::nn::SGD::new(params, 0.01, 0.0),
mse_train,
Arc::new(TestDataset { n: 100 }),
4, // batch_size
2, // num_epochs
ApplyPolicy::Sync,
AverageBackend::Cpu,
).unwrap();
assert!(ddp.world_size() >= 1);
let state = ddp.join().unwrap();
assert_eq!(state.params.len(), 2);
assert_eq!(state.buffers.len(), 0);
}
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: make cuda-test-nccl"]
fn test_async_ddp_multi_gpu_nccl() {
if crate::tensor::usable_cuda_devices().len() < 2 {
return;
}
let ddp = DdpHandle::auto(
|dev| Linear::on_device(4, 2, dev),
|params| crate::nn::SGD::new(params, 0.01, 0.0),
mse_train,
Arc::new(TestDataset { n: 256 }),
32, // batch_size
2, // num_epochs
ApplyPolicy::Sync,
AverageBackend::Nccl,
).unwrap();
assert!(ddp.world_size() >= 2);
let state = ddp.join().unwrap();
assert_eq!(state.params.len(), 2);
}
#[test]
fn test_async_ddp_send_sync() {
fn assert_send<T: Send>() {}
assert_send::<DdpHandle>();
assert_send::<TrainedState>();
}
#[test]
fn test_builder_with_defaults() {
let ddp = DdpHandle::builder(
|dev| Linear::on_device(4, 2, dev),
|params| crate::nn::SGD::new(params, 0.01, 0.0),
mse_train,
)
.dataset(Arc::new(TestDataset { n: 100 }))
.batch_size(4)
.num_epochs(2)
.backend(AverageBackend::Cpu)
.run()
.unwrap();
assert!(ddp.world_size() >= 1);
let state = ddp.join().unwrap();
assert_eq!(state.params.len(), 2);
}
#[test]
fn test_builder_with_all_options() {
let ddp = DdpHandle::builder(
|dev| Linear::on_device(4, 2, dev),
|params| crate::nn::SGD::new(params, 0.01, 0.0),
mse_train,
)
.dataset(Arc::new(TestDataset { n: 100 }))
.batch_size(4)
.num_epochs(2)
.policy(ApplyPolicy::Sync)
.backend(AverageBackend::Cpu)
.overhead_target(0.15)
.max_anchor(100)
.anchor(5)
.divergence_threshold(0.1)
.max_batch_diff(10)
.run()
.unwrap();
let state = ddp.join().unwrap();
assert_eq!(state.params.len(), 2);
}
#[test]
#[should_panic(expected = "dataset is required")]
fn test_builder_missing_dataset_panics() {
let _ = DdpHandle::builder(
|dev| Linear::on_device(4, 2, dev),
|params| crate::nn::SGD::new(params, 0.01, 0.0),
mse_train,
)
.batch_size(4)
.num_epochs(2)
.run();
}
#[test]
#[should_panic(expected = "batch_size is required")]
fn test_builder_missing_batch_size_panics() {
let _ = DdpHandle::builder(
|dev| Linear::on_device(4, 2, dev),
|params| crate::nn::SGD::new(params, 0.01, 0.0),
mse_train,
)
.dataset(Arc::new(TestDataset { n: 100 }))
.num_epochs(2)
.run();
}
#[test]
#[should_panic(expected = "num_epochs is required")]
fn test_builder_missing_num_epochs_panics() {
let _ = DdpHandle::builder(
|dev| Linear::on_device(4, 2, dev),
|params| crate::nn::SGD::new(params, 0.01, 0.0),
mse_train,
)
.dataset(Arc::new(TestDataset { n: 100 }))
.batch_size(4)
.run();
}
#[test]
fn test_worker_current_epoch_accessor() {
let (mut worker, _ch) = make_test_worker();
assert_eq!(worker.current_epoch(), 0);
worker.current_epoch = 1;
assert_eq!(worker.current_epoch(), 1);
}
#[test]
fn test_worker_set_lr() {
let (mut worker, _ch) = make_test_worker();
worker.set_lr(0.1);
let opts = test_opts();
let batch = vec![
Tensor::randn(&[4, 4], opts).unwrap(),
Tensor::randn(&[4, 2], opts).unwrap(),
];
let (loss, _) = worker.train_step(&batch, &mse_train).unwrap();
assert!(loss > 0.0);
}
#[test]
fn test_epoch_fn_called_per_epoch() {
use std::sync::atomic::{AtomicUsize, Ordering};
let counter = Arc::new(AtomicUsize::new(0));
let epochs_seen = Arc::new(std::sync::Mutex::new(Vec::new()));
let counter_c = counter.clone();
let epochs_c = epochs_seen.clone();
let num_epochs = 3;
let ddp = DdpHandle::builder(
|dev| Linear::on_device(4, 2, dev),
|params| crate::nn::SGD::new(params, 0.01, 0.0),
mse_train,
)
.dataset(Arc::new(TestDataset { n: 100 }))
.batch_size(4)
.num_epochs(num_epochs)
.backend(AverageBackend::Cpu)
.epoch_fn(move |epoch, worker| {
counter_c.fetch_add(1, Ordering::Relaxed);
epochs_c.lock().unwrap().push(epoch);
assert_eq!(worker.current_epoch(), epoch);
})
.run()
.unwrap();
let world = ddp.world_size();
let _state = ddp.join().unwrap();
assert_eq!(counter.load(Ordering::Relaxed), num_epochs * world);
let mut seen = epochs_seen.lock().unwrap().clone();
seen.sort();
let mut expected: Vec<usize> = (0..num_epochs).cycle().take(num_epochs * world).collect();
expected.sort();
assert_eq!(seen, expected);
}
#[test]
fn test_epoch_fn_set_lr() {
use std::sync::atomic::{AtomicUsize, Ordering};
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_c = call_count.clone();
let ddp = DdpHandle::builder(
|dev| Linear::on_device(4, 2, dev),
|params| crate::nn::SGD::new(params, 0.01, 0.0),
mse_train,
)
.dataset(Arc::new(TestDataset { n: 100 }))
.batch_size(4)
.num_epochs(3)
.backend(AverageBackend::Cpu)
.epoch_fn(move |epoch, worker| {
let lr = 0.01 * (1.0 - epoch as f64 * 0.3);
worker.set_lr(lr);
call_count_c.fetch_add(1, Ordering::Relaxed);
})
.run()
.unwrap();
let world = ddp.world_size();
let _state = ddp.join().unwrap();
assert_eq!(call_count.load(Ordering::Relaxed), 3 * world);
}
#[test]
fn test_worker_send_final_snapshot() {
let (worker, ch) = make_test_worker();
worker.send_final_snapshot();
let snap = ch.final_param_rx.recv().unwrap();
assert_eq!(snap.params.len(), 2);
assert_eq!(snap.rank, 0);
}
#[test]
fn test_collect_final_state_averages() {
let (timing_tx, timing_rx) = mpsc::channel();
let (_metrics_tx, metrics_rx) = mpsc::channel();
let (_param_tx, param_rx) = mpsc::channel();
let mut control_txs = Vec::new();
let mut final_param_rxs = Vec::new();
let mut final_param_txs = Vec::new();
for _ in 0..2 {
let (ctx, _crx) = mpsc::channel();
control_txs.push(ctx);
let (ftx, frx) = mpsc::channel();
final_param_txs.push(ftx);
final_param_rxs.push(frx);
}
let el_che = ElChe::new(2, 10);
let coord = Coordinator::builder(
timing_rx, metrics_rx, param_rx,
final_param_rxs,
control_txs,
ApplyPolicy::Sync, AverageBackend::Cpu,
2, 1000, el_che,
).build();
let opts = crate::tensor::test_opts();
let t1 = Tensor::full(&[3], 2.0, opts).unwrap();
let t2 = Tensor::full(&[3], 4.0, opts).unwrap();
final_param_txs[0].send(ParamSnapshot {
rank: 0, params: vec![t1], buffers: vec![], batch_count: 1,
}).unwrap();
final_param_txs[1].send(ParamSnapshot {
rank: 1, params: vec![t2], buffers: vec![], batch_count: 1,
}).unwrap();
let state = coord.collect_final_state().unwrap();
assert_eq!(state.params.len(), 1);
let vals: Vec<f64> = state.params[0].to_f64_vec().unwrap();
assert!(vals.iter().all(|v| (v - 3.0).abs() < 1e-5), "expected all ~3.0, got {vals:?}");
drop(timing_tx);
}
#[test]
fn test_collect_final_state_single_survivor() {
let (_timing_tx, timing_rx) = mpsc::channel();
let (_metrics_tx, metrics_rx) = mpsc::channel();
let (_param_tx, param_rx) = mpsc::channel();
let mut control_txs = Vec::new();
let mut final_param_rxs = Vec::new();
let mut final_param_txs = Vec::new();
for _ in 0..2 {
let (ctx, _crx) = mpsc::channel();
control_txs.push(ctx);
let (ftx, frx) = mpsc::channel();
final_param_txs.push(ftx);
final_param_rxs.push(frx);
}
let el_che = ElChe::new(2, 10);
let coord = Coordinator::builder(
timing_rx, metrics_rx, param_rx,
final_param_rxs,
control_txs,
ApplyPolicy::Sync, AverageBackend::Cpu,
2, 1000, el_che,
).build();
let opts = crate::tensor::test_opts();
let t = Tensor::full(&[3], 7.0, opts).unwrap();
final_param_txs[0].send(ParamSnapshot {
rank: 0, params: vec![t], buffers: vec![], batch_count: 5,
}).unwrap();
let state = coord.collect_final_state().unwrap();
assert_eq!(state.params.len(), 1);
let vals: Vec<f64> = state.params[0].to_f64_vec().unwrap();
assert!(vals.iter().all(|v| (v - 7.0).abs() < 1e-5), "single survivor should return its own params");
}
#[test]
fn test_checkpoint_msg_is_send() {
fn assert_send<T: Send>() {}
assert_send::<ControlMsg>();
}
#[test]
fn test_checkpoint_fn_called_on_dispatch() {
use std::sync::atomic::{AtomicU64, Ordering};
let (mut worker, ch) = make_test_worker();
let called_version = Arc::new(AtomicU64::new(0));
let cv = called_version.clone();
worker.checkpoint_fn = Some(Arc::new(move |ver, _model| {
cv.store(ver, Ordering::Relaxed);
Ok(())
}));
ch.control_tx.send(ControlMsg::Checkpoint { version: 7 }).unwrap();
worker.handle_control().unwrap();
assert_eq!(called_version.load(Ordering::Relaxed), 7);
}
#[test]
fn test_checkpoint_error_logged_not_propagated() {
let (mut worker, ch) = make_test_worker();
worker.checkpoint_fn = Some(Arc::new(|_ver, _model| {
Err(TensorError::new("disk full"))
}));
ch.control_tx.send(ControlMsg::Checkpoint { version: 1 }).unwrap();
let shutdown = worker.handle_control().unwrap();
assert!(!shutdown);
}
#[test]
fn test_coordinator_sends_checkpoint_every_n_epochs() {
use crate::distributed::ddp::ElChe;
let n = 2;
let (_timing_tx, timing_rx) = mpsc::channel();
let (_metrics_tx, metrics_rx) = mpsc::channel();
let (_param_tx, param_rx) = mpsc::channel();
let mut control_txs = Vec::new();
let mut control_rxs = Vec::new();
let mut final_param_rxs = Vec::new();
for _ in 0..n {
let (tx, rx) = mpsc::channel();
control_txs.push(tx);
control_rxs.push(rx);
let (_ftx, frx) = mpsc::channel();
final_param_rxs.push(frx);
}
let el_che = ElChe::new(n, 10);
let mut coord = Coordinator::builder(
timing_rx, metrics_rx, param_rx,
final_param_rxs,
control_txs,
ApplyPolicy::Sync, AverageBackend::Nccl,
n, 10000, el_che,
)
.num_epochs(10)
.checkpoint_every(2)
.build();
for epoch in 0..3 {
coord.on_epoch_aggregated(epoch);
}
let mut checkpoint_versions = Vec::new();
for rx in &control_rxs {
while let Ok(msg) = rx.try_recv() {
if let ControlMsg::Checkpoint { version } = msg {
checkpoint_versions.push(version);
}
}
}
assert_eq!(checkpoint_versions, vec![2], "should checkpoint once (at epoch 2) after 3 epochs with every=2");
}
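// Shared (rank, step, loss) log for the 2-GPU convergence tests; the rank
// is inferred from the batch's CUDA device index inside the train closure.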
type LossLog = Arc<std::sync::Mutex<Vec<(usize, usize, f64)>>>;
fn make_loss_tracker() -> LossLog {
Arc::new(std::sync::Mutex::new(Vec::new()))
}
fn run_2gpu_training(
backend: AverageBackend,
policy: ApplyPolicy,
num_epochs: usize,
) -> (Vec<f64>, Vec<f64>) {
let log = make_loss_tracker();
let log_clone = log.clone();
let ddp = DdpHandle::auto(
|dev| Linear::on_device(4, 2, dev),
|params| crate::nn::SGD::new(params, 0.01, 0.0),
move |model: &Linear, batch: &[Tensor]| {
let input = Variable::new(batch[0].clone(), false);
let target = Variable::new(batch[1].clone(), false);
let output = model.forward(&input)?;
let diff = output.sub(&target)?;
let loss = diff.mul(&diff)?.mean()?;
let loss_val: f64 = loss.data().item()?;
// infer rank from the CUDA device index (CPU fallback counts as rank 0)
let rank = match batch[0].device() {
Device::CUDA(idx) => idx as usize,
Device::CPU => 0,
};
// per-rank step index = number of entries already logged for this rank
{
let mut lg = log_clone.lock().unwrap();
let step = lg.iter().filter(|(r, _, _)| *r == rank).count();
lg.push((rank, step, loss_val));
}
Ok(loss)
},
Arc::new(TestDataset { n: 512 }),
32,
num_epochs,
policy,
backend,
).unwrap();
let _state = ddp.join().unwrap();
let entries = log.lock().unwrap();
let r0: Vec<f64> = entries.iter().filter(|(r, _, _)| *r == 0).map(|(_, _, l)| *l).collect();
let r1: Vec<f64> = entries.iter().filter(|(r, _, _)| *r == 1).map(|(_, _, l)| *l).collect();
(r0, r1)
}
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: make cuda-test-nccl"]
fn test_async_ddp_2gpu_cpu_backend_loss_decreases() {
if crate::tensor::usable_cuda_devices().len() < 2 {
return;
}
let (r0, r1) = run_2gpu_training(AverageBackend::Cpu, ApplyPolicy::Sync, 5);
assert!(!r0.is_empty(), "rank 0 should have loss entries");
assert!(!r1.is_empty(), "rank 1 should have loss entries");
let check_converged = |losses: &[f64], rank: usize| {
let n = losses.len();
let quarter = (n / 4).max(1);
let last_avg: f64 = losses[n - quarter..].iter().sum::<f64>() / quarter as f64;
assert!(last_avg.is_finite() && last_avg < 2.0,
"rank {rank} should converge: last_avg={last_avg:.4}");
};
check_converged(&r0, 0);
check_converged(&r1, 1);
}
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: make cuda-test-nccl"]
fn test_async_ddp_2gpu_nccl_backend_loss_decreases() {
if crate::tensor::usable_cuda_devices().len() < 2 {
return;
}
let (r0, r1) = run_2gpu_training(AverageBackend::Nccl, ApplyPolicy::Sync, 5);
assert!(!r0.is_empty(), "rank 0 should have loss entries");
assert!(!r1.is_empty(), "rank 1 should have loss entries");
let check_converged = |losses: &[f64], rank: usize| {
let n = losses.len();
let quarter = (n / 4).max(1);
let last_avg: f64 = losses[n - quarter..].iter().sum::<f64>() / quarter as f64;
assert!(last_avg.is_finite() && last_avg < 2.0,
"rank {rank} should converge: last_avg={last_avg:.4}");
};
check_converged(&r0, 0);
check_converged(&r1, 1);
}
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: make cuda-test-nccl"]
fn test_async_ddp_ab_cpu_vs_nccl() {
if crate::tensor::usable_cuda_devices().len() < 2 {
return;
}
let epochs = 5;
let (cpu_r0, cpu_r1) = run_2gpu_training(AverageBackend::Cpu, ApplyPolicy::Sync, epochs);
let (nccl_r0, nccl_r1) = run_2gpu_training(AverageBackend::Nccl, ApplyPolicy::Sync, epochs);
let final_avg = |losses: &[f64]| -> f64 {
let n = losses.len();
let quarter = n / 4;
if quarter == 0 { return f64::MAX; }
losses[n - quarter..].iter().sum::<f64>() / quarter as f64
};
let cpu_final = (final_avg(&cpu_r0) + final_avg(&cpu_r1)) / 2.0;
let nccl_final = (final_avg(&nccl_r0) + final_avg(&nccl_r1)) / 2.0;
assert!(cpu_final < 2.0,
"CPU backend final loss too high: {cpu_final:.4}");
assert!(nccl_final < 2.0,
"NCCL backend final loss too high: {nccl_final:.4}");
let ratio = cpu_final.max(nccl_final) / cpu_final.min(nccl_final);
eprintln!(" A/B: CPU final={cpu_final:.4} NCCL final={nccl_final:.4} ratio={ratio:.2}");
assert!(ratio < 3.0,
"CPU vs NCCL final loss ratio too large: {ratio:.2} (CPU={cpu_final:.4} NCCL={nccl_final:.4})");
}
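// Integration-style coordinator tests: heterogeneous timing under Cadence,
// divergence correction during CPU averaging, throttling while an average
// is in flight, and the CPU-averaging state machine end to end.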
#[test]
fn test_cadence_heterogeneous_timing() {
let mut h = make_coord_harness(2, ApplyPolicy::Cadence, AverageBackend::Nccl);
for _ in 0..10 {
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: 0, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 10.0, step_count: 0, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
if h.coord.should_average() {
h.coord.trigger_averaging().unwrap();
for rx in &h.control_rxs {
while rx.try_recv().is_ok() {}
}
}
}
if h.coord.is_calibrated() {
let counts = h.coord.el_che.batch_counts();
assert!(counts[0] >= counts[1],
"fast rank should get more batches: {:?}", counts);
}
}
#[test]
fn test_cpu_averaging_divergence_correction() {
let dev = test_device();
let opts = TensorOptions { dtype: DType::Float32, device: dev };
let mut h = make_coord_harness(2, ApplyPolicy::Async, AverageBackend::Cpu);
assert_eq!(h.coord.el_che.anchor(), 10);
for _ in 0..10 {
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: 0, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 5.0, step_count: 0, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
}
h.coord.drain_timing();
assert!(h.coord.should_average());
h.coord.trigger_averaging().unwrap();
h.param_tx.send(ParamSnapshot {
rank: 0,
params: vec![Tensor::ones(&[100], opts).unwrap()],
buffers: vec![],
batch_count: 1,
}).unwrap();
h.param_tx.send(ParamSnapshot {
rank: 1,
params: vec![Tensor::full(&[100], 100.0, opts).unwrap()],
buffers: vec![],
batch_count: 1,
}).unwrap();
let v_before = h.coord.version();
for _ in 0..100 {
h.coord.poll_cpu_averaging().unwrap();
if h.coord.version() > v_before {
break;
}
std::thread::sleep(std::time::Duration::from_millis(10));
}
assert!(h.coord.version() > v_before, "averaging should have completed");
for rx in &h.control_rxs {
while rx.try_recv().is_ok() {}
}
let anchor = h.coord.el_che.anchor();
assert!(anchor < 200,
"divergence correction should have kept anchor below max, got {}", anchor);
assert!(h.coord.is_calibrated());
}
#[test]
fn test_throttle_during_cpu_averaging() {
let mut h = make_coord_harness(2, ApplyPolicy::Cadence, AverageBackend::Cpu);
let el_che = ElChe::new(2, 1).with_max_batch_diff(2);
h.coord.el_che = el_che;
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 5.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
assert!(h.coord.should_average());
h.coord.trigger_averaging().unwrap();
assert!(h.coord.is_cpu_averaging());
assert!(!h.coord.should_average());
for rx in &h.control_rxs {
match rx.try_recv() {
Ok(ControlMsg::RequestParams) => {}
other => panic!("expected RequestParams, got {:?}", other.map(|m| std::mem::discriminant(&m))),
}
}
for i in 0..5 {
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 1.0, step_count: 2 + i, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
}
h.coord.drain_timing();
h.coord.check_throttle();
match h.control_rxs[0].try_recv() {
Ok(ControlMsg::Throttle) => {}
other => panic!("expected Throttle for rank 0, got {:?}", other.map(|m| std::mem::discriminant(&m))),
}
assert!(h.control_rxs[1].try_recv().is_err(), "rank 1 should not be throttled");
}
#[test]
fn test_cpu_avg_state_machine_full_cycle() {
let mut h = make_coord_harness(2, ApplyPolicy::Sync, AverageBackend::Cpu);
let dev = test_device();
let opts = TensorOptions { dtype: DType::Float32, device: dev };
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 10.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 20.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
assert_eq!(h.coord.version(), 0);
assert!(!h.coord.is_cpu_averaging());
h.coord.trigger_averaging().unwrap();
assert!(h.coord.is_cpu_averaging());
h.coord.poll_cpu_averaging().unwrap();
assert!(h.coord.is_cpu_averaging());
h.param_tx.send(ParamSnapshot {
rank: 0,
params: vec![Tensor::ones(&[4], opts).unwrap()],
buffers: vec![],
batch_count: 5,
}).unwrap();
h.param_tx.send(ParamSnapshot {
rank: 1,
params: vec![Tensor::full(&[4], 3.0, opts).unwrap()],
buffers: vec![],
batch_count: 5,
}).unwrap();
h.coord.poll_cpu_averaging().unwrap();
for _ in 0..100 {
h.coord.poll_cpu_averaging().unwrap();
if !h.coord.is_cpu_averaging() {
break;
}
std::thread::sleep(std::time::Duration::from_millis(5));
}
assert!(!h.coord.is_cpu_averaging());
assert_eq!(h.coord.version(), 1);
for rx in &h.control_rxs {
let mut got_request = false;
let mut got_update = false;
while let Ok(msg) = rx.try_recv() {
match msg {
ControlMsg::RequestParams => got_request = true,
ControlMsg::Update(avg) => {
got_update = true;
assert_eq!(avg.version, 1);
}
_ => {}
}
}
assert!(got_request, "worker should have received RequestParams");
assert!(got_update, "worker should have received Update");
}
}
#[test]
fn test_cpu_avg_collection_timeout() {
let mut h = make_coord_harness_with_timeout(
2, ApplyPolicy::Sync, AverageBackend::Cpu, 1,
);
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 5.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
h.coord.trigger_averaging().unwrap();
assert!(h.coord.is_cpu_averaging());
std::thread::sleep(std::time::Duration::from_secs(2));
h.coord.poll_cpu_averaging().unwrap();
assert!(!h.coord.is_cpu_averaging());
assert_eq!(h.coord.version(), 0);
assert!(h.coord.should_average());
}
#[test]
fn test_stale_snapshot_after_timeout() {
let mut h = make_coord_harness_with_timeout(
2, ApplyPolicy::Sync, AverageBackend::Cpu, 1,
);
let dev = test_device();
let opts = TensorOptions { dtype: DType::Float32, device: dev };
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 5.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
h.coord.trigger_averaging().unwrap();
h.param_tx.send(ParamSnapshot {
rank: 0,
params: vec![Tensor::full(&[4], 999.0, opts).unwrap()],
buffers: vec![],
batch_count: 1,
}).unwrap();
std::thread::sleep(std::time::Duration::from_secs(2));
h.coord.poll_cpu_averaging().unwrap();
assert!(!h.coord.is_cpu_averaging());
assert_eq!(h.coord.version(), 0);
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: 2, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 5.0, step_count: 2, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
h.coord.trigger_averaging().unwrap();
h.param_tx.send(ParamSnapshot {
rank: 0,
params: vec![Tensor::ones(&[4], opts).unwrap()],
buffers: vec![],
batch_count: 1,
}).unwrap();
h.param_tx.send(ParamSnapshot {
rank: 1,
params: vec![Tensor::full(&[4], 3.0, opts).unwrap()],
buffers: vec![],
batch_count: 1,
}).unwrap();
for _ in 0..100 {
h.coord.poll_cpu_averaging().unwrap();
if h.coord.version() > 0 {
break;
}
std::thread::sleep(std::time::Duration::from_millis(10));
}
assert_eq!(h.coord.version(), 1);
for rx in &h.control_rxs {
let mut found_update = false;
while let Ok(msg) = rx.try_recv() {
if let ControlMsg::Update(avg) = msg {
let sum: f64 = avg.params[0].sum().unwrap().item().unwrap();
// only the fresh snapshots count: mean of 1.0 and 3.0 is 2.0 per
// element, over 4 elements
let expected = 2.0 * 4.0;
assert!(
(sum - expected).abs() < 1e-4,
"expected sum={expected}, got {sum} (stale data leaked?)"
);
found_update = true;
}
}
assert!(found_update, "worker should have received Update");
}
}
#[test]
fn test_elche_calibration_produces_proportional_sizes() {
let mut h = make_coord_harness(2, ApplyPolicy::Sync, AverageBackend::Nccl);
for _ in 0..5 {
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: 0, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 10.0, step_count: 0, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
if h.coord.should_average() {
h.coord.trigger_averaging().unwrap();
for rx in &h.control_rxs {
while rx.try_recv().is_ok() {}
}
}
}
assert!(h.coord.is_calibrated(), "ElChe should have calibrated");
let sizes = h.coord.compute_partition_sizes();
assert_eq!(sizes.len(), 2);
let total: usize = sizes.iter().sum();
assert!(total <= 10000, "partitions should not exceed total: {total}");
}
#[test]
fn test_wall_ms_accumulation() {
let mut h = make_coord_harness(2, ApplyPolicy::Sync, AverageBackend::Nccl);
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: 0, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 7.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 10.0, step_count: 0, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 12.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
assert!((h.coord.wall_ms_accum[0] - 12.0).abs() < 1e-10, "rank 0 should be 5+7=12");
assert!((h.coord.wall_ms_accum[1] - 22.0).abs() < 1e-10, "rank 1 should be 10+12=22");
}
#[test]
fn test_config_defaults() {
let cfg = DdpRunConfig::new();
assert!(cfg.overhead_target.is_none());
assert!(cfg.max_anchor.is_none());
assert!(cfg.anchor.is_none());
assert!(cfg.divergence_threshold.is_none());
}
#[test]
fn test_config_builder() {
let cfg = DdpRunConfig::new()
.with_overhead_target(0.05)
.with_max_anchor(100)
.with_anchor(20)
.with_divergence_threshold(0.01);
assert_eq!(cfg.overhead_target, Some(0.05));
assert_eq!(cfg.max_anchor, Some(100));
assert_eq!(cfg.anchor, Some(20));
assert_eq!(cfg.divergence_threshold, Some(0.01));
}
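// make_partition(offset, size, total, epoch, seed) appears to slice a
// deterministic, seeded per-epoch permutation of 0..total, so equal
// (epoch, seed) inputs yield disjoint slices across ranks.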
#[test]
fn test_make_partition_basic() {
let p0 = make_partition(0, 50, 100, 0, 42);
let p1 = make_partition(50, 50, 100, 0, 42);
assert_eq!(p0.len(), 50);
assert_eq!(p1.len(), 50);
let mut all: Vec<usize> = p0.iter().chain(p1.iter()).copied().collect();
all.sort();
all.dedup();
assert_eq!(all.len(), 100, "partitions should be non-overlapping");
}
#[test]
fn test_make_partition_different_epochs() {
let p_e0 = make_partition(0, 50, 100, 0, 42);
let p_e1 = make_partition(0, 50, 100, 1, 42);
assert_ne!(p_e0, p_e1);
}
#[test]
fn test_make_partition_deterministic() {
let p1 = make_partition(0, 50, 100, 5, 42);
let p2 = make_partition(0, 50, 100, 5, 42);
assert_eq!(p1, p2, "same params should produce same partition");
}
#[test]
fn test_worker_partition_changes_with_epoch() {
let (mut worker, _ch) = make_test_worker();
let plan0 = EpochPlan { epoch: 0, partition_offset: 0, partition_size: 1000 };
worker.run_epoch_plan(&plan0, &mse_train).unwrap();
let partition0 = worker.partition.clone();
let plan1 = EpochPlan { epoch: 1, partition_offset: 0, partition_size: 1000 };
worker.run_epoch_plan(&plan1, &mse_train).unwrap();
assert_ne!(worker.partition, partition0);
}
#[test]
fn test_worker_epoch_plan_applies_partition_size() {
let (mut worker, _ch) = make_test_worker_with(0, 1, 1000);
let plan = EpochPlan { epoch: 0, partition_offset: 0, partition_size: 200 };
worker.run_epoch_plan(&plan, &mse_train).unwrap();
assert_eq!(worker.partition.len(), 200);
}
#[test]
fn test_worker_run_epoch_plan() {
let (mut worker, ch) = make_test_worker_with(0, 1, 40);
let plan = EpochPlan { epoch: 0, partition_offset: 0, partition_size: 40 };
let shutdown = worker.run_epoch_plan(&plan, &mse_train).unwrap();
assert!(!shutdown);
assert_eq!(worker.current_epoch, 0);
let mut count = 0;
while ch.timing_rx.try_recv().is_ok() {
count += 1;
}
assert!(count > 0, "should have sent timing messages");
let metrics = ch.metrics_rx.recv().unwrap();
assert_eq!(metrics.epoch, 0);
assert!(metrics.avg_loss > 0.0);
assert!(metrics.batches_processed > 0);
}
#[test]
fn test_worker_run_epoch_plan_loss_decreases() {
let (mut worker, _ch) = make_test_worker_with(0, 1, 80);
for epoch in 0..5 {
let plan = EpochPlan { epoch, partition_offset: 0, partition_size: 80 };
worker.run_epoch_plan(&plan, &mse_train).unwrap();
}
let opts = test_opts();
let batch = vec![
Tensor::randn(&[4, 4], opts).unwrap(),
Tensor::randn(&[4, 2], opts).unwrap(),
];
let loss_after: f64 = mse_train(worker.model(), &batch).unwrap().data().item().unwrap();
assert!(loss_after.is_finite());
}
#[test]
fn test_worker_run_epoch_plan_shutdown_mid_epoch() {
let (mut worker, ch) = make_test_worker_with(0, 1, 400);
ch.control_tx.send(ControlMsg::Shutdown).unwrap();
let plan = EpochPlan { epoch: 0, partition_offset: 0, partition_size: 400 };
let shutdown = worker.run_epoch_plan(&plan, &mse_train).unwrap();
assert!(shutdown, "should detect shutdown during epoch");
}
#[test]
fn test_cpu_averaging_end_to_end() {
let (mut w0, _ch0) = make_test_worker_with(0, 2, 40);
let (mut w1, _ch1) = make_test_worker_with(1, 2, 40);
let plan0 = EpochPlan { epoch: 0, partition_offset: 0, partition_size: 20 };
let plan1 = EpochPlan { epoch: 0, partition_offset: 20, partition_size: 20 };
w0.run_epoch_plan(&plan0, &mse_train).unwrap();
w1.run_epoch_plan(&plan1, &mse_train).unwrap();
let snap0 = w0.snapshot_params();
let snap1 = w1.snapshot_params();
let averaged = Coordinator::average_params(&[snap0, snap1], 1).unwrap();
w0.load_averaged(&averaged).unwrap();
w1.load_averaged(&averaged).unwrap();
assert_eq!(w0.current_version(), 1);
assert_eq!(w1.current_version(), 1);
let s0 = w0.snapshot_params();
let s1 = w1.snapshot_params();
for (p0, p1) in s0.params.iter().zip(&s1.params) {
let diff: f64 = p0.sub(p1).unwrap().abs().unwrap().sum().unwrap().item().unwrap();
assert!(diff < 1e-5, "params should be identical after averaging, diff={diff}");
}
}
#[test]
fn test_proportional_sharding() {
let mut h = make_coord_harness(2, ApplyPolicy::Cadence, AverageBackend::Nccl);
for _ in 0..3 {
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 5.0, step_count: 0, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 10.0, step_count: 0, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
if h.coord.should_average() {
h.coord.trigger_averaging().unwrap();
for rx in &h.control_rxs {
while rx.try_recv().is_ok() {}
}
}
}
if h.coord.is_calibrated() {
let sizes = h.coord.compute_partition_sizes();
assert_eq!(sizes.len(), 2);
assert!(sizes[0] > sizes[1],
"fast rank should get more samples: {:?}", sizes);
let total: usize = sizes.iter().sum();
assert!(total <= 10000, "partitions should not exceed total: {total}");
}
}
#[test]
fn test_partition_non_overlapping_equal_sizes() {
let total = 300;
let per_rank = total / 3;
let p0 = make_partition(0, per_rank, total, 5, 42);
let p1 = make_partition(100, per_rank, total, 5, 42);
let p2 = make_partition(200, per_rank, total, 5, 42);
assert_eq!(p0.len(), 100);
assert_eq!(p1.len(), 100);
assert_eq!(p2.len(), 100);
let set0: std::collections::HashSet<usize> = p0.iter().copied().collect();
let set1: std::collections::HashSet<usize> = p1.iter().copied().collect();
let set2: std::collections::HashSet<usize> = p2.iter().copied().collect();
assert_eq!(set0.intersection(&set1).count(), 0, "rank 0/1 should not overlap");
assert_eq!(set0.intersection(&set2).count(), 0, "rank 0/2 should not overlap");
assert_eq!(set1.intersection(&set2).count(), 0, "rank 1/2 should not overlap");
}
#[test]
fn test_partition_non_overlapping_smaller_sizes() {
let total = 300;
let p0 = make_partition(0, 50, total, 5, 42);
let p1 = make_partition(50, 80, total, 5, 42);
let p2 = make_partition(130, 60, total, 5, 42);
let set0: std::collections::HashSet<usize> = p0.iter().copied().collect();
let set1: std::collections::HashSet<usize> = p1.iter().copied().collect();
let set2: std::collections::HashSet<usize> = p2.iter().copied().collect();
assert_eq!(set0.intersection(&set1).count(), 0, "rank 0/1 should not overlap");
assert_eq!(set0.intersection(&set2).count(), 0, "rank 0/2 should not overlap");
assert_eq!(set1.intersection(&set2).count(), 0, "rank 1/2 should not overlap");
}
#[test]
fn test_partition_benign_overlap_different_epochs() {
let p0_e5 = make_partition(0, 50, 100, 5, 42);
let p1_e6 = make_partition(50, 50, 100, 6, 42);
let set0: std::collections::HashSet<usize> = p0_e5.iter().copied().collect();
let set1: std::collections::HashSet<usize> = p1_e6.iter().copied().collect();
assert!(set0.iter().all(|&i| i < 100));
assert!(set1.iter().all(|&i| i < 100));
}
#[test]
fn test_self_managed_epochs() {
let (mut worker, ch) = make_test_worker_with(0, 1, 40);
for epoch in 0..3 {
let plan = EpochPlan { epoch, partition_offset: 0, partition_size: 40 };
let shutdown = worker.run_epoch_plan(&plan, &mse_train).unwrap();
assert!(!shutdown);
}
assert_eq!(worker.current_epoch, 2);
let mut epoch_msgs = Vec::new();
while let Ok(msg) = ch.metrics_rx.try_recv() {
epoch_msgs.push(msg);
}
assert_eq!(epoch_msgs.len(), 3);
assert_eq!(epoch_msgs[0].epoch, 0);
assert_eq!(epoch_msgs[1].epoch, 1);
assert_eq!(epoch_msgs[2].epoch, 2);
}
#[test]
fn test_epoch_plan_partition_size_at_epoch_boundary() {
let (mut worker, _ch) = make_test_worker_with(0, 1, 80);
let plan0 = EpochPlan { epoch: 0, partition_offset: 0, partition_size: 80 };
worker.run_epoch_plan(&plan0, &mse_train).unwrap();
assert_eq!(worker.partition.len(), 80);
let plan1 = EpochPlan { epoch: 1, partition_offset: 0, partition_size: 20 };
worker.run_epoch_plan(&plan1, &mse_train).unwrap();
assert_eq!(worker.partition.len(), 20);
}
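// record_scalar accumulates a (sum, count) pair per tag in thread-local
// storage; drain_scalars returns the current map and resets it, so values
// recorded on one thread are invisible to another.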
#[test]
fn test_record_scalar_accumulates() {
drain_scalars();
record_scalar("loss", 1.0);
record_scalar("loss", 2.0);
record_scalar("loss", 3.0);
let map = drain_scalars();
assert_eq!(map.len(), 1);
let (sum, count) = map["loss"];
assert_eq!(sum, 6.0);
assert_eq!(count, 3);
}
#[test]
fn test_record_scalar_multiple_tags() {
drain_scalars();
record_scalar("a", 1.0);
record_scalar("b", 2.0);
record_scalar("a", 3.0);
let map = drain_scalars();
assert_eq!(map.len(), 2);
assert_eq!(map["a"], (4.0, 2));
assert_eq!(map["b"], (2.0, 1));
}
#[test]
fn test_drain_scalars_clears() {
drain_scalars();
record_scalar("x", 1.0);
let first = drain_scalars();
assert_eq!(first.len(), 1);
let second = drain_scalars();
assert!(second.is_empty());
record_scalar("y", 5.0);
let third = drain_scalars();
assert_eq!(third.len(), 1);
assert!(!third.contains_key("x"));
assert_eq!(third["y"], (5.0, 1));
}
#[test]
fn test_record_scalar_thread_isolation() {
drain_scalars();
record_scalar("main", 1.0);
let child_result = std::thread::spawn(|| {
let empty = drain_scalars();
assert!(empty.is_empty());
record_scalar("child", 42.0);
drain_scalars()
}).join().unwrap();
assert_eq!(child_result.len(), 1);
assert_eq!(child_result["child"], (42.0, 1));
let main_result = drain_scalars();
assert_eq!(main_result.len(), 1);
assert_eq!(main_result["main"], (1.0, 1));
}
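// aggregate_epoch_metrics weights avg_loss by batches_processed, merges each
// scalar tag by summing (sum, count) pairs across ranks, and takes epoch_ms
// as the maximum (slowest rank) rather than the sum.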
#[test]
fn test_aggregate_epoch_metrics() {
use super::coordinator::aggregate_epoch_metrics;
let mut scalars_r0 = HashMap::new();
scalars_r0.insert("loss".to_string(), (3.0, 3_usize)); scalars_r0.insert("acc".to_string(), (1.8, 3));
let mut scalars_r1 = HashMap::new();
scalars_r1.insert("loss".to_string(), (4.0, 2_usize)); scalars_r1.insert("acc".to_string(), (0.8, 2));
let msgs = vec![
MetricsMsg {
rank: 0, epoch: 0, avg_loss: 0.5, batches_processed: 60,
epoch_ms: 1000.0, samples_processed: 1920, scalars: scalars_r0,
},
MetricsMsg {
rank: 1, epoch: 0, avg_loss: 0.7, batches_processed: 40,
epoch_ms: 1200.0, samples_processed: 1280, scalars: scalars_r1,
},
];
let dev_indices = vec![0_u8, 1];
let m = aggregate_epoch_metrics(0, &msgs, &dev_indices);
assert_eq!(m.epoch, 0);
assert!((m.avg_loss - 0.58).abs() < 1e-9);
assert_eq!(m.epoch_ms, 1200.0);
assert!((m.scalars["loss"] - 1.4).abs() < 1e-9);
assert!((m.scalars["acc"] - 0.52).abs() < 1e-9);
assert_eq!(m.per_rank.len(), 2);
assert!((m.per_rank[0]["loss"] - 1.0).abs() < 1e-9);
assert!((m.per_rank[1]["loss"] - 2.0).abs() < 1e-9);
assert!((m.per_rank_throughput[0] - 1.92).abs() < 1e-9);
assert!((m.per_rank_throughput[1] - 1280.0 / 1200.0).abs() < 1e-9);
assert!((m.per_rank_batch_share[0] - 0.6).abs() < 1e-9);
assert!((m.per_rank_batch_share[1] - 0.4).abs() < 1e-9);
assert_eq!(m.device_indices, vec![0, 1]);
}
#[test]
fn test_aggregate_epoch_metrics_progressive() {
use super::coordinator::aggregate_epoch_metrics;
let msgs = vec![
MetricsMsg {
rank: 0, epoch: 0, avg_loss: 0.5, batches_processed: 20,
epoch_ms: 300.0, samples_processed: 640,
scalars: [("loss".to_string(), (2.0, 2_usize))].into(),
},
MetricsMsg {
rank: 0, epoch: 0, avg_loss: 0.4, batches_processed: 20,
epoch_ms: 600.0, samples_processed: 640,
scalars: [("loss".to_string(), (1.6, 2_usize))].into(),
},
MetricsMsg {
rank: 0, epoch: 0, avg_loss: 0.6, batches_processed: 20,
epoch_ms: 900.0, samples_processed: 640,
scalars: [("loss".to_string(), (1.8, 2_usize))].into(),
},
MetricsMsg {
rank: 1, epoch: 0, avg_loss: 0.7, batches_processed: 20,
epoch_ms: 500.0, samples_processed: 640,
scalars: [("loss".to_string(), (2.8, 2_usize))].into(),
},
MetricsMsg {
rank: 1, epoch: 0, avg_loss: 0.8, batches_processed: 20,
epoch_ms: 1000.0, samples_processed: 640,
scalars: [("loss".to_string(), (3.2, 2_usize))].into(),
},
];
let dev_indices = vec![0_u8, 1];
let m = aggregate_epoch_metrics(0, &msgs, &dev_indices);
assert_eq!(m.per_rank_throughput.len(), 2, "should have world_size entries");
assert_eq!(m.per_rank_batch_share.len(), 2);
assert_eq!(m.per_rank.len(), 2);
assert_eq!(m.device_indices, vec![0, 1]);
assert!((m.per_rank_throughput[0] - 1920.0 / 900.0).abs() < 1e-6);
assert!((m.per_rank_throughput[1] - 1280.0 / 1000.0).abs() < 1e-6);
assert!((m.per_rank_batch_share[0] - 0.6).abs() < 1e-9);
assert!((m.per_rank_batch_share[1] - 0.4).abs() < 1e-9);
assert_eq!(m.epoch_ms, 1000.0);
assert!((m.per_rank[0]["loss"] - 0.9).abs() < 1e-9);
assert!((m.per_rank[1]["loss"] - 1.5).abs() < 1e-9);
assert!((m.scalars["loss"] - 1.14).abs() < 1e-9);
}
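// drain_until_shutdown must consume control messages until Shutdown arrives,
// ignoring SyncNow and Checkpoint while stashing any StartEpoch as
// pending_plan rather than dropping it.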
#[test]
fn test_drain_until_shutdown_skips_sync_now() {
let (mut worker, ch) = make_test_worker();
ch.control_tx.send(ControlMsg::SyncNow).unwrap();
ch.control_tx.send(ControlMsg::Shutdown).unwrap();
worker.drain_until_shutdown();
}
#[test]
fn test_drain_until_shutdown_handles_multiple_sync_now() {
let (mut worker, ch) = make_test_worker();
ch.control_tx.send(ControlMsg::SyncNow).unwrap();
ch.control_tx.send(ControlMsg::SyncNow).unwrap();
ch.control_tx.send(ControlMsg::SyncNow).unwrap();
ch.control_tx.send(ControlMsg::Shutdown).unwrap();
worker.drain_until_shutdown();
}
#[test]
fn test_drain_until_shutdown_handles_interleaved_messages() {
let (mut worker, ch) = make_test_worker();
ch.control_tx.send(ControlMsg::SyncNow).unwrap();
ch.control_tx.send(ControlMsg::Checkpoint { version: 99 }).unwrap();
ch.control_tx.send(ControlMsg::StartEpoch(EpochPlan {
epoch: 5, partition_offset: 0, partition_size: 100,
})).unwrap();
ch.control_tx.send(ControlMsg::SyncNow).unwrap();
ch.control_tx.send(ControlMsg::Shutdown).unwrap();
worker.drain_until_shutdown();
assert!(worker.pending_plan.is_some());
}
#[test]
fn test_abort_nccl_no_panic_without_comm() {
let (mut worker, _ch) = make_test_worker();
worker.abort_nccl();
worker.abort_nccl();
}
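// Dropping rank 1's final-param sender simulates a crashed worker;
// collect_final_state must return the surviving rank's snapshot promptly
// instead of waiting out the full receive timeout.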
#[test]
fn test_collect_final_state_disconnected_worker() {
let (_timing_tx, timing_rx) = mpsc::channel();
let (_metrics_tx, metrics_rx) = mpsc::channel();
let (_param_tx, param_rx) = mpsc::channel();
let mut control_txs = Vec::new();
let mut final_param_rxs = Vec::new();
let mut final_param_txs = Vec::new();
for _ in 0..2 {
let (ctx, _crx) = mpsc::channel();
control_txs.push(ctx);
let (ftx, frx) = mpsc::channel();
final_param_txs.push(ftx);
final_param_rxs.push(frx);
}
let el_che = ElChe::new(2, 10);
let coord = Coordinator::builder(
timing_rx, metrics_rx, param_rx,
final_param_rxs,
control_txs,
ApplyPolicy::Sync, AverageBackend::Cpu,
2, 1000, el_che,
).build();
let opts = crate::tensor::test_opts();
let t = Tensor::full(&[3], 5.0, opts).unwrap();
final_param_txs[0].send(ParamSnapshot {
rank: 0, params: vec![t], buffers: vec![], batch_count: 1,
}).unwrap();
drop(final_param_txs.remove(1));
let start = std::time::Instant::now();
let state = coord.collect_final_state();
let elapsed = start.elapsed();
assert!(state.is_some(), "should get state from surviving worker");
assert!(elapsed.as_secs() < 2, "disconnect should be fast, not 10s timeout");
assert_eq!(state.unwrap().params.len(), 1);
}
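// Smoke test of the Arc<AtomicBool> shutdown-flag pattern the worker error
// path uses; it exercises only the flag itself, not an actual worker failure.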
#[test]
fn test_worker_error_triggers_shutdown_flag() {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
let shutdown = Arc::new(AtomicBool::new(false));
let shutdown_check = shutdown.clone();
shutdown.store(true, Ordering::Relaxed);
assert!(shutdown_check.load(Ordering::Relaxed));
}
#[test]
fn test_coordinator_active_count_prevents_averaging_after_exit() {
let mut h = make_coord_harness(2, ApplyPolicy::Sync, AverageBackend::Nccl);
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 10.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 20.0, step_count: 1, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
assert!(h.coord.should_average(), "both ranks reported, should average");
h.coord.trigger_averaging().unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 0, batch_ms: 10.0, step_count: 2, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.timing_tx.send(TimingMsg::Batch { rank: 1, batch_ms: 20.0, step_count: 2, param_norm: None, batch_loss: 0.1, sync_divergence: None }).unwrap();
h.coord.drain_timing();
assert!(h.coord.should_average());
h.timing_tx.send(TimingMsg::Exiting { rank: 1 }).unwrap();
h.coord.drain_timing();
assert_eq!(h.coord.active_count, 1);
assert!(!h.coord.should_average(),
"should NOT average when active_count < world_size");
}
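// Harness for progressive (streaming) dispatch: wraps CoordTestHarness and
// additionally exposes the coordinator's epoch-metrics channel.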
struct StreamingTestHarness {
inner: CoordTestHarness,
epoch_metrics_rx: mpsc::Receiver<EpochMetrics>,
}
fn make_streaming_harness(
n: usize,
num_epochs: usize,
total_samples: usize,
batch_size: usize,
max_overshoot: Option<usize>,
) -> CoordTestHarness {
make_streaming_harness_with_metrics(n, num_epochs, total_samples, batch_size, max_overshoot).inner
}
fn make_streaming_harness_with_metrics(
n: usize,
num_epochs: usize,
total_samples: usize,
batch_size: usize,
max_overshoot: Option<usize>,
) -> StreamingTestHarness {
let (timing_tx, timing_rx) = mpsc::channel();
let (metrics_tx, metrics_rx) = mpsc::channel();
let (param_tx, param_rx) = mpsc::channel();
let (epoch_metrics_tx, epoch_metrics_rx) = mpsc::channel();
let mut control_txs = Vec::new();
let mut control_rxs = Vec::new();
let mut final_param_rxs = Vec::new();
for _ in 0..n {
let (tx, rx) = mpsc::channel();
control_txs.push(tx);
control_rxs.push(rx);
let (_ftx, frx) = mpsc::channel();
final_param_rxs.push(frx);
}
let el_che = ElChe::new(n, 10);
let coord = Coordinator::builder(
timing_rx, metrics_rx, param_rx,
final_param_rxs,
control_txs,
ApplyPolicy::Async, AverageBackend::Cpu,
n, total_samples, el_che,
)
.progressive(true)
.batch_size(batch_size)
.num_epochs(num_epochs)
.max_overshoot(max_overshoot)
.epoch_metrics_tx(epoch_metrics_tx)
.build();
StreamingTestHarness {
inner: CoordTestHarness { coord, timing_tx, metrics_tx, param_tx, control_rxs },
epoch_metrics_rx,
}
}
#[test]
fn test_streaming_cross_epoch_dispatch() {
let mut h = make_streaming_harness(2, 3, 20, 10, Some(5));
h.coord.send_all_plans(0);
let mut rank0_plan = None;
while let Ok(msg) = h.control_rxs[0].try_recv() {
if let ControlMsg::StartEpoch(p) = msg { rank0_plan = Some(p); }
}
let plan = rank0_plan.expect("rank 0 should get initial chunk");
assert_eq!(plan.epoch, 0);
let dispatched = plan.partition_size;
h.metrics_tx.send(MetricsMsg {
rank: 0, epoch: 0, avg_loss: 0.1,
batches_processed: dispatched / 10,
epoch_ms: 50.0, samples_processed: dispatched,
scalars: Default::default(),
}).unwrap();
h.coord.drain_metrics();
// Completing the chunk may trigger further, possibly cross-epoch, dispatch;
// any plan handed out must stay within the configured 3 epochs.
let mut epochs_dispatched = Vec::new();
while let Ok(msg) = h.control_rxs[0].try_recv() {
if let ControlMsg::StartEpoch(p) = msg { epochs_dispatched.push(p.epoch); }
}
for &epoch in &epochs_dispatched {
assert!(epoch < 3, "dispatched epoch {epoch} exceeds num_epochs");
}
}
#[test]
fn test_streaming_global_epoch_event_fires_when_all_complete() {
let mut h = make_streaming_harness(2, 2, 20, 10, Some(5));
let pool = super::coordinator::ChunkPool::new(0, 20, 2);
h.coord.chunk_pools.insert(0, pool);
h.coord.chunk_pools.get_mut(&0).unwrap().take_chunk(10, 0);
h.coord.chunk_pools.get_mut(&0).unwrap().take_chunk(10, 1);
h.metrics_tx.send(MetricsMsg {
rank: 0, epoch: 0, avg_loss: 0.1, batches_processed: 1,
epoch_ms: 10.0, samples_processed: 10,
scalars: Default::default(),
}).unwrap();
h.metrics_tx.send(MetricsMsg {
rank: 1, epoch: 0, avg_loss: 0.2, batches_processed: 1,
epoch_ms: 20.0, samples_processed: 10,
scalars: Default::default(),
}).unwrap();
h.coord.drain_metrics();
assert_eq!(h.coord.last_aggregated_epoch, Some(0),
"epoch 0 should be aggregated when both ranks complete");
}
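// Overshoot gate: with max_overshoot = 0 on the CPU backend, a rank whose
// steps_since_avg runs ahead of the slowest rank must not receive chunks
// from the next epoch; the NCCL backend skips the gate entirely.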
#[test]
fn test_overshoot_gate_blocks_runaway() {
let mut h = make_streaming_harness(2, 3, 100, 10, Some(0));
let pool = super::coordinator::ChunkPool::new(0, 100, 2);
h.coord.chunk_pools.insert(0, pool);
h.coord.chunk_pools.get_mut(&0).unwrap().take_chunk(50, 0);
h.coord.chunk_pools.get_mut(&0).unwrap().take_chunk(50, 1);
h.coord.steps_since_avg[0] = 10;
h.coord.steps_since_avg[1] = 3;
h.metrics_tx.send(MetricsMsg {
rank: 0, epoch: 0, avg_loss: 0.1, batches_processed: 5,
epoch_ms: 50.0, samples_processed: 50,
scalars: Default::default(),
}).unwrap();
h.coord.drain_metrics();
let mut got_epoch_1 = false;
while let Ok(msg) = h.control_rxs[0].try_recv() {
if let ControlMsg::StartEpoch(p) = msg {
if p.epoch == 1 { got_epoch_1 = true; }
}
}
assert!(!got_epoch_1,
"overshoot gate should prevent cross-epoch dispatch when at limit");
}
#[test]
fn test_overshoot_gate_skipped_for_nccl() {
let (_timing_tx, timing_rx) = mpsc::channel();
let (metrics_tx, metrics_rx) = mpsc::channel();
let (_param_tx, param_rx) = mpsc::channel();
let mut control_txs = Vec::new();
let mut control_rxs = Vec::new();
let mut final_param_rxs = Vec::new();
for _ in 0..2 {
let (tx, rx) = mpsc::channel();
control_txs.push(tx);
control_rxs.push(rx);
let (_ftx, frx) = mpsc::channel();
final_param_rxs.push(frx);
}
let el_che = ElChe::new(2, 10);
let mut coord = Coordinator::builder(
timing_rx, metrics_rx, param_rx,
final_param_rxs,
control_txs,
ApplyPolicy::Cadence, AverageBackend::Nccl,
2, 100, el_che,
)
.progressive(true)
.batch_size(10)
.num_epochs(3)
.max_overshoot(Some(0))
.build();
let pool = super::coordinator::ChunkPool::new(0, 100, 2);
coord.chunk_pools.insert(0, pool);
coord.chunk_pools.get_mut(&0).unwrap().take_chunk(50, 0);
coord.chunk_pools.get_mut(&0).unwrap().take_chunk(50, 1);
coord.steps_since_avg[0] = 10;
coord.steps_since_avg[1] = 3;
metrics_tx.send(MetricsMsg {
rank: 0, epoch: 0, avg_loss: 0.1, batches_processed: 5,
epoch_ms: 50.0, samples_processed: 50,
scalars: Default::default(),
}).unwrap();
coord.drain_metrics();
let mut got_start_epoch = false;
while let Ok(msg) = control_rxs[0].try_recv() {
if let ControlMsg::StartEpoch(_) = msg {
got_start_epoch = true;
}
}
assert!(got_start_epoch,
"NCCL backend must skip overshoot gate (AllReduce handles coordination)");
}
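// Auto-tuning: when max_overshoot is not user-set, each completed averaging
// round may grow it by one; the tests below check that a sustained divergence
// trend suppresses that growth and that an explicit value is never tuned.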
#[test]
fn test_overshoot_auto_tune_grows() {
let mut h = make_streaming_harness(2, 3, 1000, 10, None);
let initial = h.coord.max_overshoot;
assert!(initial >= 2, "initial overshoot should be at least 2");
h.coord.steps_since_avg = vec![10, 10];
h.coord.wall_ms_accum = vec![100.0, 200.0];
h.coord.finish_averaging_nccl();
assert_eq!(h.coord.max_overshoot, initial + 1,
"overshoot should grow by 1 after successful averaging");
}
#[test]
fn test_overshoot_auto_tune_suppressed_on_divergence_trend() {
let mut h = make_streaming_harness(2, 3, 1000, 10, None);
for _ in 0..3 {
h.coord.steps_since_avg = vec![10, 10];
h.coord.wall_ms_accum = vec![100.0, 200.0];
h.coord.finish_averaging_nccl();
}
let overshoot_after_growth = h.coord.max_overshoot;
for i in 0..3 {
let div = 0.10 + i as f64 * 0.05;
h.coord.finish_averaging_cpu(
10.0,
&[5_usize, 5],
&[50.0, 100.0],
Some(super::convergence::DivergenceReport {
deltas: vec![div, div],
pre_norms: None,
post_norm: None,
}),
);
}
assert!(h.coord.max_overshoot <= overshoot_after_growth + 2,
"3rd CPU round should suppress overshoot growth, got {}", h.coord.max_overshoot);
}
#[test]
fn test_overshoot_user_override_no_autotune() {
let mut h = make_streaming_harness(2, 3, 1000, 10, Some(7));
assert_eq!(h.coord.max_overshoot, 7);
assert!(!h.coord.overshoot_auto);
h.coord.steps_since_avg = vec![10, 10];
h.coord.wall_ms_accum = vec![100.0, 200.0];
h.coord.finish_averaging_nccl();
assert_eq!(h.coord.max_overshoot, 7,
"user-set overshoot should not auto-tune");
}
#[test]
fn test_multi_pool_completion_tracking() {
let mut h = make_streaming_harness(2, 3, 100, 10, Some(10));
let mut pool0 = super::coordinator::ChunkPool::new(0, 100, 2);
pool0.take_chunk(50, 0);
pool0.take_chunk(50, 1);
h.coord.chunk_pools.insert(0, pool0);
let mut pool1 = super::coordinator::ChunkPool::new(1, 100, 2);
pool1.take_chunk(30, 0);
h.coord.chunk_pools.insert(1, pool1);
h.metrics_tx.send(MetricsMsg {
rank: 0, epoch: 0, avg_loss: 0.1, batches_processed: 5,
epoch_ms: 50.0, samples_processed: 50,
scalars: Default::default(),
}).unwrap();
h.metrics_tx.send(MetricsMsg {
rank: 0, epoch: 1, avg_loss: 0.2, batches_processed: 3,
epoch_ms: 30.0, samples_processed: 30,
scalars: Default::default(),
}).unwrap();
h.coord.drain_metrics();
if let Some(pool) = h.coord.chunk_pools.get(&0) {
assert_eq!(pool.completed[0], 50, "epoch 0 pool should track rank 0 completion");
}
if let Some(pool) = h.coord.chunk_pools.get(&1) {
assert_eq!(pool.completed[0], 30, "epoch 1 pool should track rank 0 completion");
}
}
#[test]
fn test_shutdown_with_streaming_pools() {
let mut h = make_streaming_harness(2, 2, 20, 10, Some(5));
let mut pool0 = super::coordinator::ChunkPool::new(0, 20, 2);
pool0.take_chunk(10, 0);
pool0.take_chunk(10, 1);
h.coord.chunk_pools.insert(0, pool0);
h.metrics_tx.send(MetricsMsg {
rank: 0, epoch: 0, avg_loss: 0.1, batches_processed: 1,
epoch_ms: 10.0, samples_processed: 10,
scalars: Default::default(),
}).unwrap();
h.metrics_tx.send(MetricsMsg {
rank: 1, epoch: 0, avg_loss: 0.2, batches_processed: 1,
epoch_ms: 20.0, samples_processed: 10,
scalars: Default::default(),
}).unwrap();
h.coord.drain_metrics();
assert_eq!(h.coord.last_aggregated_epoch, Some(0));
for rx in &h.control_rxs {
while rx.try_recv().is_ok() {}
}
h.coord.chunk_pools.remove(&1);
let mut pool1 = super::coordinator::ChunkPool::new(1, 20, 2);
pool1.take_chunk(10, 0);
pool1.take_chunk(10, 1);
h.coord.chunk_pools.insert(1, pool1);
h.metrics_tx.send(MetricsMsg {
rank: 0, epoch: 1, avg_loss: 0.05, batches_processed: 1,
epoch_ms: 10.0, samples_processed: 10,
scalars: Default::default(),
}).unwrap();
h.metrics_tx.send(MetricsMsg {
rank: 1, epoch: 1, avg_loss: 0.06, batches_processed: 1,
epoch_ms: 20.0, samples_processed: 10,
scalars: Default::default(),
}).unwrap();
h.coord.drain_metrics();
assert_eq!(h.coord.last_aggregated_epoch, Some(1));
let mut shutdowns = 0;
for rx in &h.control_rxs {
while let Ok(msg) = rx.try_recv() {
if matches!(msg, ControlMsg::Shutdown) {
shutdowns += 1;
}
}
}
assert_eq!(shutdowns, 2, "both ranks should receive Shutdown after final epoch");
}
#[test]
fn test_ddp_run_config_max_overshoot() {
let config = DdpRunConfig::new().with_max_overshoot(5);
assert_eq!(config.max_overshoot, Some(5));
let config2 = DdpRunConfig::new();
assert_eq!(config2.max_overshoot, None);
}
#[test]
fn test_epoch_event_fires_with_mixed_epoch_ranks() {
let mut sh = make_streaming_harness_with_metrics(2, 3, 60, 10, Some(10));
let mut pool1 = super::coordinator::ChunkPool::new(1, 60, 2);
pool1.take_chunk(30, 0);
pool1.take_chunk(30, 1);
sh.inner.coord.chunk_pools.insert(1, pool1);
let mut pool2 = super::coordinator::ChunkPool::new(2, 60, 2);
pool2.take_chunk(20, 0);
sh.inner.coord.chunk_pools.insert(2, pool2);
sh.inner.coord.rank_epoch[0] = 2;
sh.inner.coord.rank_epoch[1] = 1;
sh.inner.metrics_tx.send(MetricsMsg {
rank: 0, epoch: 1, avg_loss: 0.10, batches_processed: 3,
epoch_ms: 30.0, samples_processed: 30,
scalars: [("loss".to_string(), (0.30, 3_usize))].into(),
}).unwrap();
sh.inner.coord.drain_metrics();
assert!(sh.inner.coord.last_aggregated_epoch.is_none()
|| sh.inner.coord.last_aggregated_epoch == Some(0),
"epoch 1 should not aggregate with only rank 0 complete");
assert!(sh.inner.coord.chunk_pools.contains_key(&1),
"epoch 1 pool should still exist");
sh.inner.metrics_tx.send(MetricsMsg {
rank: 1, epoch: 1, avg_loss: 0.20, batches_processed: 3,
epoch_ms: 60.0, samples_processed: 30,
scalars: [("loss".to_string(), (0.60, 3_usize))].into(),
}).unwrap();
sh.inner.coord.drain_metrics();
assert_eq!(sh.inner.coord.last_aggregated_epoch, Some(1),
"epoch 1 should aggregate when both ranks complete");
assert!(!sh.inner.coord.chunk_pools.contains_key(&1),
"epoch 1 pool should be removed after aggregation");
assert!(sh.inner.coord.chunk_pools.contains_key(&2),
"epoch 2 pool should survive epoch 1 aggregation");
let em = sh.epoch_metrics_rx.try_recv()
.expect("epoch metrics should have been sent for epoch 1");
assert_eq!(em.epoch, 1);
assert!((em.avg_loss - 0.15).abs() < 1e-9,
"avg_loss should be batch-weighted mean: got {}", em.avg_loss);
assert_eq!(em.per_rank_batch_share.len(), 2);
assert!((em.per_rank_batch_share[0] - 0.5).abs() < 1e-9);
assert!((em.per_rank_batch_share[1] - 0.5).abs() < 1e-9);
assert!((em.scalars["loss"] - 0.15).abs() < 1e-9,
"loss scalar should be batch-weighted: got {}", em.scalars["loss"]);
assert!(em.epoch_ms > 0.0, "epoch_ms should be positive");
}
#[test]
fn test_dispatch_skips_aggregated_epochs() {
let mut h = make_streaming_harness(2, 5, 100, 10, None);
let mut pool0 = super::coordinator::ChunkPool::new(0, 100, 2);
pool0.take_chunk(70, 0);
pool0.take_chunk(30, 1);
h.coord.chunk_pools.insert(0, pool0);
let mut pool1 = super::coordinator::ChunkPool::new(1, 100, 2);
pool1.take_chunk(100, 0);
h.coord.chunk_pools.insert(1, pool1);
h.coord.rank_epoch[0] = 1;
h.coord.rank_epoch[1] = 0;
h.metrics_tx.send(MetricsMsg {
rank: 0, epoch: 0, avg_loss: 0.1, batches_processed: 7,
epoch_ms: 50.0, samples_processed: 70,
scalars: Default::default(),
}).unwrap();
h.metrics_tx.send(MetricsMsg {
rank: 1, epoch: 0, avg_loss: 0.2, batches_processed: 3,
epoch_ms: 80.0, samples_processed: 30,
scalars: Default::default(),
}).unwrap();
h.metrics_tx.send(MetricsMsg {
rank: 0, epoch: 1, avg_loss: 0.1, batches_processed: 10,
epoch_ms: 100.0, samples_processed: 100,
scalars: Default::default(),
}).unwrap();
h.coord.drain_metrics();
assert_eq!(h.coord.last_aggregated_epoch, Some(1),
"both epoch 0 and 1 should be aggregated");
for &epoch in h.coord.chunk_pools.keys() {
assert!(epoch >= 2,
"found orphan pool for already-aggregated epoch {epoch}");
}
assert!(h.coord.rank_epoch[1] >= 2,
"slow GPU should be on epoch 2+, got epoch {}",
h.coord.rank_epoch[1]);
}
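// Minimal Scheduler impls for the LR tests below: a constant LR and a
// linearly increasing LR keyed off the global step.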
struct ConstLr(f64);
impl crate::nn::Scheduler for ConstLr {
fn lr(&self, _step: usize) -> f64 { self.0 }
}
struct LinearLr { slope: f64 }
impl crate::nn::Scheduler for LinearLr {
fn lr(&self, step: usize) -> f64 { step as f64 * self.slope }
}
#[test]
fn test_worker_scheduler_drives_optimizer_lr() {
let (mut worker, _ch) = make_test_worker();
worker.set_lr(0.0);
worker.set_scheduler(Arc::new(ConstLr(0.05)));
let opts = test_opts();
let batch = vec![
Tensor::randn(&[4, 4], opts).unwrap(),
Tensor::randn(&[4, 2], opts).unwrap(),
];
worker.train_step(&batch, &mse_train).unwrap();
assert!((worker.current_lr() - 0.05).abs() < 1e-9,
"expected optimizer LR 0.05, got {}", worker.current_lr());
}
#[test]
fn test_worker_lr_scale_multiplies_scheduler_output() {
let (mut worker, _ch) = make_test_worker();
worker.set_scheduler(Arc::new(ConstLr(0.05)));
worker.set_lr_scale(2.0);
let opts = test_opts();
let batch = vec![
Tensor::randn(&[4, 4], opts).unwrap(),
Tensor::randn(&[4, 2], opts).unwrap(),
];
worker.train_step(&batch, &mse_train).unwrap();
assert!((worker.current_lr() - 0.10).abs() < 1e-9,
"expected optimizer LR 0.10 (sched 0.05 * scale 2.0), got {}",
worker.current_lr());
}
#[test]
fn test_worker_scheduler_step_advances_with_global_progress() {
let (mut worker, _ch) = make_test_worker();
worker.set_scheduler(Arc::new(LinearLr { slope: 0.01 }));
let opts = test_opts();
let batch = vec![
Tensor::randn(&[4, 4], opts).unwrap(),
Tensor::randn(&[4, 2], opts).unwrap(),
];
worker.train_step(&batch, &mse_train).unwrap();
assert!((worker.current_lr() - 0.0).abs() < 1e-9);
worker.train_step(&batch, &mse_train).unwrap();
assert!((worker.current_lr() - 0.01).abs() < 1e-9,
"step 1: got {}", worker.current_lr());
worker.train_step(&batch, &mse_train).unwrap();
assert!((worker.current_lr() - 0.02).abs() < 1e-9,
"step 2: got {}", worker.current_lr());
}
struct RecordingSched {
inner: crate::nn::MultiStepLR,
queries: std::sync::Mutex<Vec<(usize, f64)>>,
}
impl RecordingSched {
fn new(base_lr: f64, milestones: &[usize], gamma: f64) -> Self {
Self {
inner: crate::nn::MultiStepLR::new(base_lr, milestones, gamma),
queries: std::sync::Mutex::new(Vec::new()),
}
}
}
impl crate::nn::Scheduler for RecordingSched {
fn lr(&self, step: usize) -> f64 {
let lr = self.inner.lr(step);
self.queries.lock().unwrap().push((step, lr));
lr
}
}
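// Cross-mode parity: the same MultiStepLR schedule driven three ways (manual
// SGD loop, GpuWorker, FlowBuilder graph) must produce identical per-step LR
// trajectories, including the milestone drops.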
#[test]
fn test_cross_mode_lr_parity_solo_vs_worker_vs_graph() {
use crate::graph::FlowBuilder;
use crate::nn::{Module, Optimizer, SGD};
let dev = test_device();
let opts = test_opts();
let n_steps: usize = 12;
let base_lr = 0.1f64;
let milestones = vec![4usize, 8];
let gamma = 0.1f64;
let manual_lrs: Vec<f64> = {
let model = Linear::on_device(4, 2, dev).unwrap();
let mut opt = SGD::new(&model.parameters(), 0.0, 0.0);
let sched = crate::nn::MultiStepLR::new(base_lr, &milestones, gamma);
let batch = [
Tensor::randn(&[4, 4], opts).unwrap(),
Tensor::randn(&[4, 2], opts).unwrap(),
];
let mut lrs = Vec::with_capacity(n_steps);
for step in 0..n_steps {
opt.set_lr(sched.lr(step));
let v = Variable::new(batch[0].clone(), false);
let t = Variable::new(batch[1].clone(), false);
let pred = model.forward(&v).unwrap();
let loss = pred.sub(&t).unwrap();
loss.mul(&loss).unwrap().mean().unwrap().backward().unwrap();
opt.step().unwrap();
opt.zero_grad();
lrs.push(opt.lr());
}
lrs
};
let worker_lrs: Vec<f64> = {
let (mut worker, _ch) = make_test_worker();
worker.set_scheduler(Arc::new(RecordingSched::new(base_lr, &milestones, gamma)));
let batch = vec![
Tensor::randn(&[4, 4], opts).unwrap(),
Tensor::randn(&[4, 2], opts).unwrap(),
];
let mut lrs = Vec::with_capacity(n_steps);
for _ in 0..n_steps {
worker.train_step(&batch, &mse_train).unwrap();
lrs.push(worker.current_lr());
}
lrs
};
let graph_lrs: Vec<f64> = {
let graph = FlowBuilder::from(Linear::on_device(4, 2, dev).unwrap())
.build()
.unwrap();
graph.set_optimizer(|p| SGD::new(p, 0.0, 0.0));
graph.set_scheduler(Arc::new(RecordingSched::new(base_lr, &milestones, gamma)));
let x = Variable::new(Tensor::randn(&[4, 4], opts).unwrap(), false);
let t = Variable::new(Tensor::randn(&[4, 2], opts).unwrap(), false);
let mut lrs = Vec::with_capacity(n_steps);
for _ in 0..n_steps {
let pred = graph.forward(&x).unwrap();
let loss = pred.sub(&t).unwrap();
loss.mul(&loss).unwrap().mean().unwrap().backward().unwrap();
graph.step().unwrap();
let lr = graph.optimizer.borrow().as_ref().map(|o| o.lr()).unwrap();
lrs.push(lr);
}
lrs
};
assert_eq!(manual_lrs.len(), n_steps);
assert_eq!(worker_lrs.len(), n_steps);
assert_eq!(graph_lrs.len(), n_steps);
for step in 0..n_steps {
assert!((manual_lrs[step] - worker_lrs[step]).abs() < 1e-9,
"step {step}: solo={} vs worker={}", manual_lrs[step], worker_lrs[step]);
assert!((manual_lrs[step] - graph_lrs[step]).abs() < 1e-9,
"step {step}: solo={} vs graph={}", manual_lrs[step], graph_lrs[step]);
}
let mut transitions = 0;
for w in manual_lrs.windows(2) {
if (w[0] - w[1]).abs() > 1e-9 { transitions += 1; }
}
assert_eq!(transitions, 2,
"expected 2 LR drops over 12 steps with milestones [4, 8]; got {transitions}. \
trajectory: {manual_lrs:?}");
}
#[test]
fn test_cross_mode_lr_parity_with_lr_scale() {
use crate::graph::FlowBuilder;
use crate::nn::{Module, Optimizer, SGD};
let dev = test_device();
let opts = test_opts();
let n_steps: usize = 8;
let scale = 2.5;
let manual_lrs: Vec<f64> = {
let model = Linear::on_device(4, 2, dev).unwrap();
let mut opt = SGD::new(&model.parameters(), 0.0, 0.0);
let sched = crate::nn::MultiStepLR::new(0.1, &[4], 0.1);
let batch = [
Tensor::randn(&[4, 4], opts).unwrap(),
Tensor::randn(&[4, 2], opts).unwrap(),
];
let mut lrs = Vec::with_capacity(n_steps);
for step in 0..n_steps {
opt.set_lr(sched.lr(step) * scale);
let v = Variable::new(batch[0].clone(), false);
let t = Variable::new(batch[1].clone(), false);
let pred = model.forward(&v).unwrap();
let loss = pred.sub(&t).unwrap();
loss.mul(&loss).unwrap().mean().unwrap().backward().unwrap();
opt.step().unwrap();
opt.zero_grad();
lrs.push(opt.lr());
}
lrs
};
let worker_lrs: Vec<f64> = {
let (mut worker, _ch) = make_test_worker();
worker.set_scheduler(Arc::new(crate::nn::MultiStepLR::new(0.1, &[4], 0.1)));
worker.set_lr_scale(scale);
let batch = vec![
Tensor::randn(&[4, 4], opts).unwrap(),
Tensor::randn(&[4, 2], opts).unwrap(),
];
let mut lrs = Vec::with_capacity(n_steps);
for _ in 0..n_steps {
worker.train_step(&batch, &mse_train).unwrap();
lrs.push(worker.current_lr());
}
lrs
};
let graph_lrs: Vec<f64> = {
let graph = FlowBuilder::from(Linear::on_device(4, 2, dev).unwrap())
.build()
.unwrap();
graph.set_optimizer(|p| SGD::new(p, 0.0, 0.0));
graph.set_scheduler(Arc::new(crate::nn::MultiStepLR::new(0.1, &[4], 0.1)));
graph.set_lr_scale(scale);
let x = Variable::new(Tensor::randn(&[4, 4], opts).unwrap(), false);
let t = Variable::new(Tensor::randn(&[4, 2], opts).unwrap(), false);
let mut lrs = Vec::with_capacity(n_steps);
for _ in 0..n_steps {
let pred = graph.forward(&x).unwrap();
let loss = pred.sub(&t).unwrap();
loss.mul(&loss).unwrap().mean().unwrap().backward().unwrap();
graph.step().unwrap();
let lr = graph.optimizer.borrow().as_ref().map(|o| o.lr()).unwrap();
lrs.push(lr);
}
lrs
};
for step in 0..n_steps {
assert!((manual_lrs[step] - worker_lrs[step]).abs() < 1e-9,
"step {step}: solo*scale={} vs worker={}", manual_lrs[step], worker_lrs[step]);
assert!((manual_lrs[step] - graph_lrs[step]).abs() < 1e-9,
"step {step}: solo*scale={} vs graph={}", manual_lrs[step], graph_lrs[step]);
}
}