use std::collections::HashMap;
use std::sync::Arc;
use crate::distributed::bucket_manager::GradientBucketManager;
use crate::distributed::grad_sync::broadcast_params;
use crate::error::Result;
use crate::ops::FusedOptimizerOps;
use crate::trainer::SimpleTrainer;
use crate::trainer::config::{TrainingConfig, TrainingMetrics};
use numr::autograd::{BackwardHook, Var, backward_with_hooks};
use numr::dtype::DType;
use numr::ops::{BinaryOps, ReduceOps, ScalarOps, TensorOps, UnaryOps};
use numr::runtime::{Communicator, Runtime, RuntimeClient};
use numr::tensor::{Tensor, TensorId};
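/// Distributed trainer that wraps a [`SimpleTrainer`] and synchronizes
/// gradients through a [`Communicator`], grouping them into buckets managed
/// by a [`GradientBucketManager`].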
pub struct BucketedTrainer<R: Runtime<DType = DType>> {
    inner: SimpleTrainer<R>,
    bucket_manager: GradientBucketManager<R>,
    comm: Arc<dyn Communicator>,
}
impl<R: Runtime<DType = DType>> BucketedTrainer<R> {
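    /// Creates a bucketed trainer from a training config, a communicator, and
    /// per-parameter metadata (`TensorId`, element count, dtype) used to plan
    /// gradient buckets sized by `bucket_size_bytes`. The optional
    /// `compute_stream_handle` is forwarded to the bucket manager.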
    pub fn new(
        config: TrainingConfig,
        comm: Arc<dyn Communicator>,
        param_info: &[(TensorId, usize, DType)],
        bucket_size_bytes: usize,
        compute_stream_handle: Option<u64>,
    ) -> Result<Self> {
        let inner = SimpleTrainer::new(config)?;
        let bucket_manager = GradientBucketManager::new(
            param_info,
            comm.clone(),
            bucket_size_bytes,
            compute_stream_handle,
        );
        Ok(Self {
            inner,
            bucket_manager,
            comm,
        })
    }
    pub fn broadcast_params(&self, params: &HashMap<TensorId, Tensor<R>>) -> Result<()> {
        broadcast_params(self.comm.as_ref(), params, 0)
    }
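    /// Runs the backward pass with a bucketing hook that hands each leaf
    /// gradient to the bucket manager as it becomes ready, waits for all
    /// bucket synchronization to finish, and then performs the optimizer
    /// step via the inner `SimpleTrainer`.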
    pub fn backward_and_step<C>(
        &mut self,
        loss: &Var<R>,
        client: &C,
        params: &mut HashMap<TensorId, Tensor<R>>,
    ) -> Result<Option<TrainingMetrics>>
    where
        C: RuntimeClient<R>
            + TensorOps<R>
            + BinaryOps<R>
            + UnaryOps<R>
            + ScalarOps<R>
            + ReduceOps<R>
            + FusedOptimizerOps<R>,
    {
        let loss_value =
            loss.tensor()
                .item::<f32>()
                .map_err(|e| crate::error::Error::DistributedError {
                    reason: format!("failed to extract scalar loss: {e}"),
                })? as f64;
        self.bucket_manager.reset();
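        // The hook stores a raw pointer to the bucket manager rather than a
        // `&mut` borrow of `self`. The manager is owned by `self` and outlives
        // the hook, which is only used for the duration of
        // `backward_with_hooks` below.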
        let manager_ptr = &mut self.bucket_manager as *mut GradientBucketManager<R>;
        let mut hook = BucketHook::<R, C> {
            manager: manager_ptr,
            client,
        };
        let mut grads = backward_with_hooks(loss, client, &mut hook)?;
        self.bucket_manager.wait_and_unflatten(client, &mut grads)?;
        self.inner.step(client, params, grads, loss_value)
    }
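    /// Global step counter, delegated to the inner trainer.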
    pub fn global_step(&self) -> u64 {
        self.inner.global_step()
    }
    pub fn communicator(&self) -> &dyn Communicator {
        self.comm.as_ref()
    }
    pub fn num_buckets(&self) -> usize {
        self.bucket_manager.num_buckets()
    }
}
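/// Backward hook that forwards each leaf gradient to the
/// `GradientBucketManager` as soon as autograd produces it. It stores a raw
/// pointer to a manager that must outlive the hook.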
struct BucketHook<'a, R: Runtime, C: RuntimeClient<R> + TensorOps<R>> {
    manager: *mut GradientBucketManager<R>,
    client: &'a C,
}
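// SAFETY: the raw pointer makes `BucketHook` `!Send` by default. The hook only
// lives for the duration of a single backward pass, and the bucket manager it
// points to outlives it, so the pointer never dangles even if the hook is
// moved to another thread.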
unsafe impl<R: Runtime + Send, C: RuntimeClient<R> + TensorOps<R>> Send for BucketHook<'_, R, C> {}
impl<R: Runtime<DType = DType>, C: RuntimeClient<R> + TensorOps<R>> BackwardHook<R>
    for BucketHook<'_, R, C>
{
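    // Hands the gradient to the bucket manager as soon as autograd finishes a
    // leaf. Errors only trip a `debug_assert!`, so they are ignored in release
    // builds.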
    fn on_leaf_grad_ready(&mut self, id: TensorId, grad: &Tensor<R>) {
        let manager = unsafe { &mut *self.manager };
        if let Err(_e) = manager.mark_grad_ready(id, grad, self.client) {
            debug_assert!(false, "mark_grad_ready failed for {id:?}: {_e}");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::bucket_manager::GradientBucketManager;
    use crate::test_utils::cpu_setup;
    use crate::trainer::config::TrainingConfig;
    use numr::autograd::{backward, var_mul, var_sum};
    use numr::runtime::NoOpCommunicator;
    use numr::runtime::cpu::CpuRuntime;
    #[test]
    fn test_bucketed_trainer_creation() {
        let comm = Arc::new(NoOpCommunicator);
        let config = TrainingConfig {
            learning_rate: 1e-3,
            weight_decay: 0.01,
            grad_accum_steps: 1,
            max_grad_norm: Some(1.0),
        };
        let id = TensorId::new();
        let params = vec![(id, 100, DType::F32)];
        let trainer =
            BucketedTrainer::<CpuRuntime>::new(config, comm, &params, 25 * 1024 * 1024, None)
                .unwrap();
        assert_eq!(trainer.global_step(), 0);
        assert_eq!(trainer.num_buckets(), 1);
    }
    #[test]
    fn test_bucketed_backward_produces_same_grads_as_regular() {
        let (client, device) = cpu_setup();
        let w_id = TensorId::new();
        let w_tensor = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let w_var = Var::with_id(w_tensor.clone(), w_id, true);
        let sq = var_mul(&w_var, &w_var, &client).unwrap();
        let loss = var_sum(&sq, &[0], false, &client).unwrap();
        let regular_grads = backward(&loss, &client).unwrap();
        let regular_grad: Vec<f32> = regular_grads
            .get(w_id)
            .expect("regular grad should exist")
            .to_vec();
        let w_var2 = Var::with_id(w_tensor.clone(), w_id, true);
        let sq2 = var_mul(&w_var2, &w_var2, &client).unwrap();
        let loss2 = var_sum(&sq2, &[0], false, &client).unwrap();
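        // Re-run the same computation, this time routing gradients through the
        // bucket manager via the backward hook.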
        let comm = Arc::new(NoOpCommunicator);
        let param_info = vec![(w_id, 3, DType::F32)];
        let mut mgr =
            GradientBucketManager::<CpuRuntime>::new(&param_info, comm, 25 * 1024 * 1024, None);
        let manager_ptr = &mut mgr as *mut GradientBucketManager<CpuRuntime>;
        let mut hook = BucketHook {
            manager: manager_ptr,
            client: &client,
        };
        let mut grads = backward_with_hooks(&loss2, &client, &mut hook).unwrap();
        mgr.wait_and_unflatten(&client, &mut grads).unwrap();
        let bucketed_grad: Vec<f32> = grads
            .get(w_id)
            .expect("bucketed grad should exist")
            .to_vec();
        for (a, b) in regular_grad.iter().zip(bucketed_grad.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "Grad mismatch: regular={a}, bucketed={b}"
            );
        }
    }
    #[test]
    fn test_bucketed_trainer_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<BucketedTrainer<CpuRuntime>>();
    }
}