use std::collections::HashMap;
use std::sync::Arc;
use crate::distributed::comm_utils::all_reduce_tensor;
use crate::error::{Error, Result};
use numr::autograd::{GradStore, Var};
use numr::dtype::DType;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Communicator, ReduceOp, Runtime, RuntimeClient};
use numr::tensor::{Tensor, TensorId};
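/// One bucket of parameters whose gradients are flattened into a single
/// buffer and all-reduced in one collective call.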
struct Bucket<R: Runtime> {
param_ids: Vec<TensorId>,
param_numels: Vec<usize>,
param_shapes: Vec<Vec<usize>>,
dtype: DType,
received_grads: HashMap<TensorId, Tensor<R>>,
flat_buffer: Option<Tensor<R>>,
allreduce_launched: bool,
completion_event: Option<u64>,
}
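/// Groups parameter gradients into size-capped buckets and all-reduces each
/// bucket as soon as all of its gradients arrive, optionally overlapping
/// communication with backward compute via stream events.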
pub struct GradientBucketManager<R: Runtime> {
buckets: Vec<Bucket<R>>,
param_to_bucket: HashMap<TensorId, usize>,
comm: Arc<dyn Communicator>,
compute_stream_handle: Option<u64>,
}
impl<R: Runtime<DType = DType>> GradientBucketManager<R> {
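/// Builds buckets from `param_info` tuples of (tensor id, element count,
/// dtype). Params are packed greedily into a bucket until adding one would
/// exceed `bucket_size_bytes` or change the bucket's dtype.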
pub fn new(
param_info: &[(TensorId, usize, DType)],
comm: Arc<dyn Communicator>,
bucket_size_bytes: usize,
compute_stream_handle: Option<u64>,
) -> Self {
let mut buckets = Vec::new();
let mut param_to_bucket = HashMap::new();
let mut current_ids = Vec::new();
let mut current_numels = Vec::new();
let mut current_bytes = 0usize;
let mut current_dtype = DType::F32;
for &(id, numel, dtype) in param_info {
let elem_bytes = dtype.size_in_bytes();
let param_bytes = numel * elem_bytes;
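// Flush the current bucket if this param would overflow it or carries a
// different dtype (each flat buffer holds a single dtype).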
if !current_ids.is_empty()
&& (current_bytes + param_bytes > bucket_size_bytes || dtype != current_dtype)
{
let n = current_ids.len();
for &cid in &current_ids {
param_to_bucket.insert(cid, buckets.len());
}
buckets.push(Bucket {
param_ids: std::mem::take(&mut current_ids),
param_numels: std::mem::take(&mut current_numels),
param_shapes: Vec::with_capacity(n),
dtype: current_dtype,
received_grads: HashMap::new(),
flat_buffer: None,
allreduce_launched: false,
completion_event: None,
});
current_bytes = 0;
}
current_ids.push(id);
current_numels.push(numel);
current_bytes += param_bytes;
current_dtype = dtype;
}
if !current_ids.is_empty() {
let n = current_ids.len();
for &cid in &current_ids {
param_to_bucket.insert(cid, buckets.len());
}
buckets.push(Bucket {
param_ids: current_ids,
param_numels: current_numels,
param_shapes: Vec::with_capacity(n),
dtype: current_dtype,
received_grads: HashMap::new(),
flat_buffer: None,
allreduce_launched: false,
completion_event: None,
});
}
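// Compute/comm overlap needs stream-level synchronization from the
// communicator; without it, fall back to blocking all-reduces.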
let overlap_handle = if comm.as_stream_sync().is_some() {
compute_stream_handle
} else {
None
};
Self {
buckets,
param_to_bucket,
comm,
compute_stream_handle: overlap_handle,
}
}
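/// Records the gradient for `id` as ready. Once every param in its bucket
/// has a gradient, the bucket is flattened and all-reduced. Ids not managed
/// by any bucket, and buckets already launched, are ignored.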
pub fn mark_grad_ready<C>(&mut self, id: TensorId, grad: &Tensor<R>, client: &C) -> Result<()>
where
C: RuntimeClient<R> + TensorOps<R>,
{
let bucket_idx = match self.param_to_bucket.get(&id) {
Some(&idx) => idx,
None => return Ok(()),
};
let bucket = &mut self.buckets[bucket_idx];
if bucket.allreduce_launched {
return Ok(());
}
bucket.received_grads.insert(id, grad.clone());
if bucket.received_grads.len() < bucket.param_ids.len() {
return Ok(());
}
self.flatten_and_allreduce(bucket_idx, client)
}
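/// Validates dtypes, concatenates the bucket's flattened gradients into one
/// buffer, and launches the all-reduce (asynchronously on the comm stream
/// when overlap is enabled, otherwise blocking).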
fn flatten_and_allreduce<C>(&mut self, bucket_idx: usize, client: &C) -> Result<()>
where
C: RuntimeClient<R> + TensorOps<R>,
{
let bucket = &mut self.buckets[bucket_idx];
for &pid in &bucket.param_ids {
if let Some(g) = bucket.received_grads.get(&pid) {
if g.dtype() != bucket.dtype {
return Err(Error::DistributedError {
reason: format!(
"dtype mismatch in bucket {bucket_idx}: expected {:?}, got {:?}",
bucket.dtype,
g.dtype()
),
});
}
}
}
bucket.param_shapes.clear();
let mut flat_grads: Vec<Tensor<R>> = Vec::with_capacity(bucket.param_ids.len());
for &pid in &bucket.param_ids {
let g = bucket
.received_grads
.get(&pid)
.ok_or_else(|| Error::DistributedError {
reason: format!("gradient missing for param in bucket {bucket_idx}"),
})?;
bucket.param_shapes.push(g.shape().to_vec());
let flat = g.flatten().map_err(|e| Error::DistributedError {
reason: format!("flatten gradient failed: {e}"),
})?;
flat_grads.push(flat);
}
let refs: Vec<&Tensor<R>> = flat_grads.iter().collect();
let flat_buffer = client.cat(&refs, 0).map_err(|e| Error::DistributedError {
reason: format!("cat gradients failed: {e}"),
})?;
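// Overlap path: record a "ready" event on the compute stream, have the comm
// stream wait on it so the flat buffer is fully written before the
// all-reduce starts, then record a "completion" event on the comm stream for
// wait_and_unflatten to wait on later. The closure scopes the fallible steps
// so the ready event is always destroyed afterwards.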
if let Some(compute_stream) = self.compute_stream_handle {
let sync = self
.comm
.as_stream_sync()
.expect("compute_stream_handle is Some only when as_stream_sync() is Some");
let ready_event = sync.create_event().map_err(|e| Error::DistributedError {
reason: format!("create ready event failed: {e}"),
})?;
let overlap_result = (|| -> Result<u64> {
sync.record_on_stream(ready_event, compute_stream)
.map_err(|e| Error::DistributedError {
reason: format!("record ready event failed: {e}"),
})?;
sync.comm_stream_wait_event(ready_event)
.map_err(|e| Error::DistributedError {
reason: format!("comm stream wait for ready event failed: {e}"),
})?;
all_reduce_tensor(self.comm.as_ref(), &flat_buffer, ReduceOp::Sum)?;
let completion_event =
sync.create_event().map_err(|e| Error::DistributedError {
reason: format!("create completion event failed: {e}"),
})?;
if let Err(e) = sync.record_on_comm_stream(completion_event) {
let _ = sync.destroy_event(completion_event);
return Err(Error::DistributedError {
reason: format!("record completion event failed: {e}"),
});
}
Ok(completion_event)
})();
let _ = sync.destroy_event(ready_event);
bucket.completion_event = Some(overlap_result?);
} else {
all_reduce_tensor(self.comm.as_ref(), &flat_buffer, ReduceOp::Sum)?;
}
bucket.flat_buffer = Some(flat_buffer);
bucket.allreduce_launched = true;
Ok(())
}
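/// Waits for all launched all-reduces, then slices each gradient back out of
/// its bucket's flat buffer, reshapes it, averages it over `world_size`, and
/// stores it in `grads`.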
pub fn wait_and_unflatten<C>(&mut self, client: &C, grads: &mut GradStore<R>) -> Result<()>
where
C: RuntimeClient<R> + TensorOps<R> + ScalarOps<R>,
{
let world_size = self.comm.world_size();
let scale = 1.0 / world_size as f64;
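// In overlap mode, make the compute stream wait on each bucket's completion
// event; otherwise block until all collectives have finished.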
if let Some(compute_stream) = self.compute_stream_handle {
let sync = self
.comm
.as_stream_sync()
.expect("compute_stream_handle is Some only when as_stream_sync() is Some");
for bucket in &mut self.buckets {
if let Some(event) = bucket.completion_event.take() {
if let Err(e) = sync.stream_wait_event(compute_stream, event) {
let _ = sync.destroy_event(event);
return Err(Error::DistributedError {
reason: format!("compute stream wait for completion event failed: {e}"),
});
}
let _ = sync.destroy_event(event);
}
}
} else {
self.comm.sync().map_err(|e| Error::DistributedError {
reason: format!("sync after allreduce failed: {e}"),
})?;
}
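// Slice each param's gradient back out of the flat buffer; params occupy
// contiguous [offset, offset + numel) ranges in param_ids order.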
for bucket in &mut self.buckets {
let flat_buffer = match bucket.flat_buffer.take() {
Some(buf) => buf,
None => continue,
};
let mut offset = 0usize;
for (i, &pid) in bucket.param_ids.iter().enumerate() {
let numel = bucket.param_numels[i];
let shape = &bucket.param_shapes[i];
let flat_grad =
flat_buffer
.narrow(0, offset, numel)
.map_err(|e| Error::DistributedError {
reason: format!("narrow failed during unflatten: {e}"),
})?;
let reshaped = flat_grad
.reshape(shape)
.map_err(|e| Error::DistributedError {
reason: format!("reshape failed during unflatten: {e}"),
})?;
let averaged = if world_size > 1 {
client.mul_scalar(&reshaped, scale)?
} else {
reshaped
};
grads.insert(pid, averaged);
offset += numel;
}
}
Ok(())
}
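/// Clears per-iteration state so the manager can be reused on the next
/// backward pass, destroying any unconsumed completion events.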
pub fn reset(&mut self) {
let sync = self.comm.as_stream_sync();
for bucket in &mut self.buckets {
bucket.received_grads.clear();
bucket.allreduce_launched = false;
bucket.flat_buffer = None;
bucket.param_shapes.clear();
if let Some(event) = bucket.completion_event.take() {
if let Some(s) = sync {
let _ = s.destroy_event(event);
}
}
}
}
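/// Returns the number of gradient buckets.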
pub fn num_buckets(&self) -> usize {
self.buckets.len()
}
}
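/// Walks the autograd graph from `loss` and returns its leaf tensors
/// (parameters) in reverse topological order, approximating the order in
/// which their gradients become ready during the backward pass.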
pub fn param_order_from_graph<R: Runtime>(loss: &Var<R>) -> Vec<TensorId> {
use std::collections::HashSet;
let mut topo = Vec::new();
let mut visited = HashSet::new();
fn dfs<R: Runtime>(
id: TensorId,
grad_fn: Option<Arc<dyn numr::autograd::GradFn<R>>>,
visited: &mut HashSet<TensorId>,
topo: &mut Vec<(TensorId, bool)>,
) {
if visited.contains(&id) {
return;
}
visited.insert(id);
let input_ids: Vec<TensorId> = grad_fn
.as_ref()
.map(|gf| gf.inputs().to_vec())
.unwrap_or_default();
if let Some(gf) = &grad_fn {
for (input_id, input_grad_fn) in input_ids.iter().zip(gf.input_grad_fns()) {
dfs(*input_id, input_grad_fn, visited, topo);
}
}
topo.push((id, grad_fn.is_none()));
}
dfs(loss.id(), loss.grad_fn().cloned(), &mut visited, &mut topo);
topo.into_iter()
.rev()
.filter(|(_, is_leaf)| *is_leaf)
.map(|(id, _)| id)
.collect()
}
#[cfg(test)]
#[path = "bucket_manager_tests.rs"]
mod tests;