//! Distributed data parallel (DDP) utilities: a module wrapper that keeps
//! parameters and gradients synchronized across ranks, plus gradient
//! bucketing helpers for batched all-reduce.

use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;
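/// Wraps a [`Module`] so that model replicas on different ranks stay in sync:
/// parameters can be broadcast from rank 0 and gradients averaged across the
/// process group after each backward pass.
///
/// Illustrative sketch of a training step (assumes a single-process mock
/// group as used in the tests below, and that `input` and the backward pass
/// are provided by the caller; the optimizer step is elided):
///
/// ```ignore
/// let model = Linear::new(4, 2);
/// let pg = ProcessGroup::mock();
/// let mut ddp = DistributedDataParallel::new(model, pg);
/// ddp.sync_parameters();            // start from identical weights on every rank
/// let output = ddp.forward(&input); // local forward pass
/// // ... compute the loss and run backward ...
/// ddp.sync_gradients();             // average gradients across ranks
/// ```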
pub struct DistributedDataParallel<M: Module> {
module: M,
process_group: ProcessGroup,
broadcast_buffers: bool,
gradient_as_bucket_view: bool,
}
impl<M: Module> DistributedDataParallel<M> {
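    /// Creates a DDP wrapper around `module` using the given process group.
    /// Buffer broadcasting and bucket-view gradients are enabled by default.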
pub fn new(module: M, process_group: ProcessGroup) -> Self {
Self {
module,
process_group,
broadcast_buffers: true,
gradient_as_bucket_view: true,
}
}
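    /// Sets whether module buffers should be broadcast from rank 0 along with
    /// the parameters (builder-style).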
pub fn broadcast_buffers(mut self, broadcast: bool) -> Self {
self.broadcast_buffers = broadcast;
self
}
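    /// Sets the `gradient_as_bucket_view` flag, controlling whether gradients
    /// are intended to alias bucket storage rather than being copied
    /// (builder-style).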
pub fn gradient_as_bucket_view(mut self, bucket_view: bool) -> Self {
self.gradient_as_bucket_view = bucket_view;
self
}
pub fn module(&self) -> &M {
&self.module
}
pub fn module_mut(&mut self) -> &mut M {
&mut self.module
}
pub fn process_group(&self) -> &ProcessGroup {
&self.process_group
}
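    /// Broadcasts every parameter tensor from rank 0 so that all ranks start
    /// training from identical weights.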
pub fn sync_parameters(&mut self) {
for param in self.module.parameters() {
let mut tensor = param.data().clone();
self.process_group.broadcast_tensor(&mut tensor, 0);
}
}
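    /// All-reduces each parameter's gradient (if present) across the process
    /// group, averaging the result over ranks.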
pub fn sync_gradients(&self) {
for param in self.module.parameters() {
if let Some(grad) = param.grad() {
let mut grad_tensor = grad.clone();
self.process_group
.all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
}
}
}
pub fn forward(&self, input: &Variable) -> Variable {
self.module.forward(input)
}
}
impl<M: Module> Module for DistributedDataParallel<M> {
fn forward(&self, input: &Variable) -> Variable {
self.module.forward(input)
}
fn parameters(&self) -> Vec<Parameter> {
self.module.parameters()
}
fn train(&mut self) {
self.module.train();
}
fn eval(&mut self) {
self.module.eval();
}
fn is_training(&self) -> bool {
self.module.is_training()
}
}
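/// Packs several gradient tensors into one contiguous `Vec<f32>` so they can
/// be all-reduced with a single collective call. Shapes and element counts
/// are recorded so the original tensors can be rebuilt with
/// [`GradientBucket::extract`].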
pub struct GradientBucket {
data: Vec<f32>,
shapes: Vec<(Vec<usize>, usize)>,
capacity: usize,
}
impl GradientBucket {
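    /// Creates an empty bucket that can hold up to `capacity` f32 elements.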
#[must_use]
pub fn new(capacity: usize) -> Self {
Self {
data: Vec::with_capacity(capacity),
shapes: Vec::new(),
capacity,
}
}
#[must_use]
pub fn is_full(&self) -> bool {
self.data.len() >= self.capacity
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
#[must_use]
pub fn size(&self) -> usize {
self.data.len()
}
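    /// Appends a tensor's data to the bucket, recording its shape.
    /// Returns `false` (and leaves the bucket unchanged) if the tensor would
    /// not fit within the remaining capacity.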
pub fn add(&mut self, tensor: &Tensor<f32>) -> bool {
let data = tensor.to_vec();
if self.data.len() + data.len() > self.capacity {
return false;
}
self.shapes.push((tensor.shape().to_vec(), data.len()));
self.data.extend(data);
true
}
#[must_use]
pub fn data(&self) -> &[f32] {
&self.data
}
pub fn data_mut(&mut self) -> &mut [f32] {
&mut self.data
}
pub fn clear(&mut self) {
self.data.clear();
self.shapes.clear();
}
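    /// Rebuilds the individual tensors from the packed buffer using the
    /// recorded shapes, in the order they were added.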
#[must_use]
pub fn extract(&self) -> Vec<Tensor<f32>> {
let mut result = Vec::new();
let mut offset = 0;
for (shape, size) in &self.shapes {
let end = offset + size;
let data = self.data[offset..end].to_vec();
result.push(Tensor::from_vec(data, shape).unwrap());
offset = end;
}
result
}
}
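/// How gradients are synchronized across ranks during the backward pass.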
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradSyncStrategy {
    /// All-reduce every bucket once the backward pass has finished.
    Synchronous,
    /// Overlap bucket all-reduce with the remaining backward computation.
    Overlapped,
    /// Skip gradient synchronization (e.g. during gradient accumulation).
    NoSync,
}
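/// Packs gradients into fixed-size [`GradientBucket`]s and all-reduces each
/// bucket in a single collective call, averaging over ranks. With
/// [`GradSyncStrategy::NoSync`] the synchronization step is skipped entirely.
///
/// Illustrative sketch (assumes a single-process mock group and that a
/// gradient tensor `grad` is already available):
///
/// ```ignore
/// let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
/// sync.prepare(10);                       // allocate buckets for 10 parameters
/// sync.add_gradient(0, &grad);            // pack a gradient into bucket 0
/// sync.sync_all(&ProcessGroup::mock());   // all-reduce every non-empty bucket
/// sync.clear();                           // reset buckets for the next step
/// ```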
pub struct GradientSynchronizer {
strategy: GradSyncStrategy,
bucket_size: usize,
buckets: Vec<GradientBucket>,
}
impl GradientSynchronizer {
#[must_use]
pub fn new(strategy: GradSyncStrategy, bucket_size: usize) -> Self {
Self {
strategy,
bucket_size,
buckets: Vec::new(),
}
}
#[must_use]
pub fn strategy(&self) -> GradSyncStrategy {
self.strategy
}
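    /// Allocates `num_params.div_ceil(bucket_size)` empty buckets, each with
    /// capacity `bucket_size`.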
pub fn prepare(&mut self, num_params: usize) {
let num_buckets = num_params.div_ceil(self.bucket_size);
self.buckets = (0..num_buckets)
.map(|_| GradientBucket::new(self.bucket_size))
.collect();
}
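    /// Adds a gradient tensor to the bucket at `bucket_idx`; out-of-range
    /// indices (and tensors that do not fit) are silently ignored.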
pub fn add_gradient(&mut self, bucket_idx: usize, tensor: &Tensor<f32>) {
if bucket_idx < self.buckets.len() {
self.buckets[bucket_idx].add(tensor);
}
}
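    /// All-reduces (averages) the contents of every non-empty bucket across
    /// the process group; a no-op when the strategy is
    /// [`GradSyncStrategy::NoSync`].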
pub fn sync_all(&mut self, process_group: &ProcessGroup) {
if self.strategy == GradSyncStrategy::NoSync {
return;
}
for bucket in &mut self.buckets {
if !bucket.is_empty() {
let mut data = bucket.data().to_vec();
let len = data.len();
process_group
.backend()
.all_reduce(&mut data, ReduceOp::Average);
bucket.data_mut()[..len].copy_from_slice(&data);
}
}
}
pub fn clear(&mut self) {
for bucket in &mut self.buckets {
bucket.clear();
}
}
}
impl Default for GradientSynchronizer {
fn default() -> Self {
        Self::new(GradSyncStrategy::Synchronous, 25_000_000)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use axonml_nn::Linear;
#[test]
fn test_ddp_creation() {
let module = Linear::new(10, 5);
let pg = ProcessGroup::mock();
let ddp = DistributedDataParallel::new(module, pg);
assert_eq!(ddp.process_group().rank(), 0);
assert_eq!(ddp.process_group().world_size(), 1);
}
#[test]
fn test_ddp_forward() {
let module = Linear::new(4, 2);
let pg = ProcessGroup::mock();
let ddp = DistributedDataParallel::new(module, pg);
let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
let output = ddp.forward(&input);
assert_eq!(output.data().shape(), &[1, 2]);
}
#[test]
fn test_ddp_module_access() {
let module = Linear::new(10, 5);
let pg = ProcessGroup::mock();
let mut ddp = DistributedDataParallel::new(module, pg);
let _ = ddp.module();
let _ = ddp.module_mut();
}
#[test]
fn test_ddp_train_eval() {
let module = Linear::new(10, 5);
let pg = ProcessGroup::mock();
let mut ddp = DistributedDataParallel::new(module, pg);
assert!(ddp.is_training());
ddp.train();
ddp.eval();
let _ = ddp.is_training();
}
#[test]
fn test_ddp_parameters() {
let module = Linear::new(10, 5);
let pg = ProcessGroup::mock();
let ddp = DistributedDataParallel::new(module, pg);
let params = ddp.parameters();
assert!(!params.is_empty());
}
#[test]
fn test_gradient_bucket() {
let mut bucket = GradientBucket::new(100);
assert!(bucket.is_empty());
let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
assert!(bucket.add(&tensor1));
assert!(!bucket.is_empty());
assert_eq!(bucket.size(), 3);
let tensor2 = Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap();
assert!(bucket.add(&tensor2));
assert_eq!(bucket.size(), 5);
assert_eq!(bucket.data(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
}
#[test]
fn test_gradient_bucket_extract() {
let mut bucket = GradientBucket::new(100);
let tensor1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
let tensor2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();
bucket.add(&tensor1);
bucket.add(&tensor2);
let extracted = bucket.extract();
assert_eq!(extracted.len(), 2);
assert_eq!(extracted[0].to_vec(), vec![1.0, 2.0]);
assert_eq!(extracted[1].to_vec(), vec![3.0, 4.0, 5.0]);
}
#[test]
fn test_gradient_bucket_full() {
let mut bucket = GradientBucket::new(5);
let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
assert!(bucket.add(&tensor1));
let tensor2 = Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();
        assert!(!bucket.add(&tensor2));
    }
#[test]
fn test_gradient_bucket_clear() {
let mut bucket = GradientBucket::new(100);
let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
bucket.add(&tensor);
bucket.clear();
assert!(bucket.is_empty());
}
#[test]
fn test_gradient_synchronizer() {
let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
sync.prepare(10);
assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
}
#[test]
fn test_gradient_synchronizer_no_sync() {
let mut sync = GradientSynchronizer::new(GradSyncStrategy::NoSync, 100);
sync.prepare(10);
let pg = ProcessGroup::mock();
        sync.sync_all(&pg);
    }
#[test]
fn test_gradient_synchronizer_default() {
let sync = GradientSynchronizer::default();
assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
}
#[test]
fn test_grad_sync_strategy() {
assert_eq!(GradSyncStrategy::Synchronous, GradSyncStrategy::Synchronous);
assert_ne!(GradSyncStrategy::Synchronous, GradSyncStrategy::NoSync);
}
}