# axonml-distributed

## Overview
axonml-distributed provides distributed training utilities for the AxonML machine learning framework. It includes backend abstractions for communication, process group management, DistributedDataParallel (DDP) wrappers for multi-GPU training, and high-level communication primitives like all-reduce, broadcast, and gather operations.
## Features
- **Backend Abstraction** - Pluggable communication backend trait with a mock implementation for testing
- **Process Groups** - Manage distributed processes with rank and world size information
- **DistributedDataParallel (DDP)** - Wrap models for automatic gradient synchronization across processes
- **Collective Operations** - All-reduce, broadcast, all-gather, reduce-scatter, and barrier primitives
- **Gradient Bucketing** - Efficient gradient accumulation and synchronization with configurable bucket sizes
- **Multiple Reduce Operations** - Sum, product, min, max, and average reduction strategies (see the sketch after this list)
- **Model Parallel Utilities** - Tensor scattering and gathering for model parallelism
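The usage examples further down only exercise `all_reduce_sum` and `all_reduce_mean`, so here is a minimal sketch of selecting a reduction strategy explicitly. The `ReduceOp` enum, its variants, and the `all_reduce` and `ProcessGroup::mock` signatures shown here are assumptions for illustration, not the crate's confirmed API.

```rust
// Hypothetical sketch: choosing a reduction strategy explicitly.
// `ReduceOp`, `all_reduce`, and `ProcessGroup::mock` are assumed names/signatures.
use axonml_distributed::*;

let pg = ProcessGroup::mock(0, 2);
let mut stats = Tensor::from_vec(vec![1.0, 2.0, 3.0]).unwrap();

// Element-wise maximum across all processes; Sum, Product, Min, and Avg
// would be selected the same way.
all_reduce(&mut stats, &pg, ReduceOp::Max);
```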
## Modules
| Module | Description |
|---|---|
| `backend` | Communication backend trait and `MockBackend` implementation for testing |
| `process_group` | `ProcessGroup` and `World` abstractions for managing distributed processes |
| `ddp` | `DistributedDataParallel` wrapper, `GradientBucket`, and `GradientSynchronizer` |
| `comm` | High-level communication utilities (`all_reduce`, `broadcast`, `gather`, etc.) |
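The usage examples below rely on a glob import. If you prefer explicit paths, the module layout above suggests imports along the following lines; the exact item names and paths are assumptions inferred from the table, not a confirmed API.

```rust
// Hypothetical module-qualified imports, inferred from the module table above.
use axonml_distributed::backend::MockBackend;
use axonml_distributed::process_group::{ProcessGroup, World};
use axonml_distributed::ddp::{DistributedDataParallel, GradientBucket, GradientSynchronizer};
use axonml_distributed::comm::{all_gather, all_reduce_sum, barrier, broadcast};
```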
## Usage

Add the dependency to your `Cargo.toml`:

```toml
[dependencies]
axonml-distributed = "0.1.0"
```
### Basic DDP Training

```rust
use axonml_distributed::*;
use axonml_nn::Linear;

// Initialize distributed world
let world = World::mock(0, 2); // Use mock for testing (rank 0 of 2)

// Create the model and wrap it in DDP
let model = Linear::new(784, 10);
let mut ddp = DistributedDataParallel::new(model, &world);

// Synchronize parameters from rank 0 at the start of training
ddp.sync_parameters();

// Training loop
ddp.train();
for batch in data_loader.iter() {
    // forward pass, backward pass, and gradient synchronization per step
}
```
### Communication Primitives

```rust
use axonml_distributed::*;

let pg = ProcessGroup::mock(0, 2);

// All-reduce operations
let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0]).unwrap();
all_reduce_sum(&mut tensor, &pg);
all_reduce_mean(&mut tensor, &pg);

// Broadcast from rank 0
broadcast(&mut tensor, 0, &pg);

// All-gather across processes
let gathered = all_gather(&tensor, &pg);

// Barrier synchronization
barrier(&pg);

// Query process information
let my_rank = rank(&pg);
let total_processes = world_size(&pg);
let is_main = is_main_process(&pg);
```
### Gradient Synchronization

```rust
use axonml_distributed::*;

let pg = ProcessGroup::mock(0, 2);

// Synchronize multiple gradients
let mut gradients = vec![
    Tensor::from_vec(vec![1.0, 2.0]).unwrap(),
    Tensor::from_vec(vec![3.0, 4.0]).unwrap(),
];
sync_gradients(&mut gradients, &pg);

// Or synchronize a single gradient
let mut grad = Tensor::from_vec(vec![0.5, 0.5]).unwrap();
sync_gradient(&mut grad, &pg);
```
### Gradient Bucketing

```rust
use axonml_distributed::*;

// Create a gradient bucket for efficient all-reduce
let mut bucket = GradientBucket::new(1000); // 1000-element capacity

let grad1 = Tensor::from_vec(vec![1.0, 2.0, 3.0]).unwrap();
let grad2 = Tensor::from_vec(vec![4.0, 5.0]).unwrap();
bucket.add(&grad1);
bucket.add(&grad2);

// All-reduce the flattened bucket data
let pg = ProcessGroup::mock(0, 2);
pg.backend.all_reduce(&mut bucket.data, ReduceOp::Sum);

// Extract the synchronized gradients
let synced_grads = bucket.extract();
```
### Custom Synchronization Strategy

```rust
use axonml_distributed::*;

let mut sync = GradientSynchronizer::new();
sync.prepare(10); // 10 parameters

// Add gradients during the backward pass
let grad = Tensor::from_vec(vec![0.1, 0.2, 0.3]).unwrap();
sync.add_gradient(0, grad); // gradient for parameter index 0

// Synchronize all buckets
let pg = ProcessGroup::mock(0, 2);
sync.sync_all(&pg);
sync.clear();
```
### Multi-Backend Setup

```rust
use axonml_distributed::*;
use std::sync::Arc;

// Create multiple mock backends (simulates multi-process)
let backends: Vec<Arc<MockBackend>> = create_world(4);

// Each process creates its own ProcessGroup over its backend
for (rank, backend) in backends.into_iter().enumerate() {
    let pg = ProcessGroup::new(backend, rank, 4);
    // ... per-process setup and training go here
}
```
### Process Subgroups

```rust
use axonml_distributed::*;

let world = World::mock(0, 4); // rank 0 of a 4-process world

// Create a subgroup with specific ranks
let subgroup = world.new_group(vec![0, 2]);
assert!(subgroup.contains(0));
assert_eq!(subgroup.size(), 2);
```
## Tests

Run the test suite with `cargo test`.
## License
Licensed under either of:
- MIT License (LICENSE-MIT or http://opensource.org/licenses/MIT)
- Apache License, Version 2.0 (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
at your option.