use std::collections::HashMap;
use std::sync::Arc;
use crate::distributed::grad_sync::{all_reduce_grads, broadcast_params};
use crate::error::Result;
use crate::ops::FusedOptimizerOps;
use crate::trainer::SimpleTrainer;
use crate::trainer::config::{TrainingConfig, TrainingMetrics};
use numr::autograd::GradStore;
use numr::dtype::DType;
use numr::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use numr::runtime::{Communicator, Runtime, RuntimeClient};
use numr::tensor::{Tensor, TensorId};
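/// Data-parallel trainer that wraps [`SimpleTrainer`] and synchronizes
/// gradients across ranks through a [`Communicator`] before each optimizer step.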
pub struct DistributedTrainer<R: Runtime<DType = DType>> {
inner: SimpleTrainer<R>,
comm: Arc<dyn Communicator>,
}
impl<R: Runtime<DType = DType>> DistributedTrainer<R> {
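/// Creates a distributed trainer from a training configuration and a
/// communicator shared by all participating ranks.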
pub fn new(config: TrainingConfig, comm: Arc<dyn Communicator>) -> Result<Self> {
let inner = SimpleTrainer::new(config)?;
Ok(Self { inner, comm })
}
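/// Broadcasts the given parameters from rank 0 to all other ranks so every
/// rank starts training from identical weights.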
pub fn broadcast_params(&self, params: &HashMap<TensorId, Tensor<R>>) -> Result<()> {
broadcast_params(self.comm.as_ref(), params, 0)
}
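/// Runs one training step: all-reduces the incoming gradients across ranks,
/// then delegates to the inner [`SimpleTrainer::step`]. Returns `Some` metrics
/// once an optimizer update is applied, or `None` while gradients are still
/// being accumulated.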
pub fn step<C>(
&mut self,
client: &C,
params: &mut HashMap<TensorId, Tensor<R>>,
grads: GradStore<R>,
loss_value: f64,
) -> Result<Option<TrainingMetrics>>
where
C: RuntimeClient<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ScalarOps<R>
+ ReduceOps<R>
+ FusedOptimizerOps<R>,
{
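// Synchronize gradients across all ranks before the local optimizer update.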
let mut synced_grads = grads;
all_reduce_grads(self.comm.as_ref(), client, &mut synced_grads)?;
self.inner.step(client, params, synced_grads, loss_value)
}
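/// Returns the global step counter maintained by the inner trainer.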
pub fn global_step(&self) -> u64 {
self.inner.global_step()
}
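/// Returns the communicator used for gradient synchronization and parameter broadcasts.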
pub fn communicator(&self) -> &dyn Communicator {
self.comm.as_ref()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::cpu_setup;
use crate::trainer::config::TrainingConfig;
use numr::runtime::NoOpCommunicator;
use numr::runtime::cpu::CpuRuntime;
#[test]
fn test_distributed_trainer_creation() {
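// NoOpCommunicator stands in for a real multi-rank communicator, keeping these tests single-process.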
let comm = Arc::new(NoOpCommunicator);
let config = TrainingConfig {
learning_rate: 1e-3,
weight_decay: 0.01,
grad_accum_steps: 1,
max_grad_norm: Some(1.0),
};
let trainer = DistributedTrainer::<CpuRuntime>::new(config, comm).unwrap();
assert_eq!(trainer.global_step(), 0);
}
#[test]
fn test_distributed_trainer_broadcast_noop() {
let comm = Arc::new(NoOpCommunicator);
let config = TrainingConfig {
learning_rate: 1e-3,
weight_decay: 0.01,
grad_accum_steps: 1,
max_grad_norm: None,
};
let trainer = DistributedTrainer::<CpuRuntime>::new(config, comm).unwrap();
let (_client, device) = cpu_setup();
let id = TensorId::new();
let t = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
let mut params = HashMap::new();
params.insert(id, t);
trainer.broadcast_params(&params).unwrap();
}
#[test]
fn test_distributed_trainer_step() {
let (client, device) = cpu_setup();
let comm = Arc::new(NoOpCommunicator);
let config = TrainingConfig {
learning_rate: 1e-4,
weight_decay: 0.0,
grad_accum_steps: 1,
max_grad_norm: None,
};
let mut trainer = DistributedTrainer::<CpuRuntime>::new(config, comm).unwrap();
let param_id = TensorId::new();
let param = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
let mut params = HashMap::new();
params.insert(param_id, param);
let grad = Tensor::<CpuRuntime>::from_slice(&[0.1f32, 0.2], &[2], &device);
let mut grads = GradStore::new();
grads.insert(param_id, grad);
let result = trainer.step(&client, &mut params, grads, 0.5).unwrap();
assert!(result.is_some());
let metrics = result.unwrap();
assert_eq!(metrics.step, 1);
}
#[test]
fn test_distributed_trainer_grad_accumulation() {
let (client, device) = cpu_setup();
let comm = Arc::new(NoOpCommunicator);
let config = TrainingConfig {
learning_rate: 1e-4,
weight_decay: 0.0,
grad_accum_steps: 2,
max_grad_norm: None,
};
let mut trainer = DistributedTrainer::<CpuRuntime>::new(config, comm).unwrap();
let param_id = TensorId::new();
let param = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
let mut params = HashMap::new();
params.insert(param_id, param);
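// First micro-batch: gradients are only accumulated, so no optimizer step yet.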
let grad1 = Tensor::<CpuRuntime>::from_slice(&[0.1f32, 0.2], &[2], &device);
let mut grads1 = GradStore::new();
grads1.insert(param_id, grad1);
let result = trainer.step(&client, &mut params, grads1, 0.5).unwrap();
assert!(result.is_none());
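// Second micro-batch completes the accumulation window and triggers the optimizer step.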
let grad2 = Tensor::<CpuRuntime>::from_slice(&[0.3f32, 0.4], &[2], &device);
let mut grads2 = GradStore::new();
grads2.insert(param_id, grad2);
let result = trainer.step(&client, &mut params, grads2, 0.6).unwrap();
assert!(result.is_some());
assert_eq!(result.unwrap().step, 1);
}
#[test]
fn test_distributed_trainer_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<DistributedTrainer<CpuRuntime>>();
}
}