# axonml-distributed

## Overview
`axonml-distributed` provides distributed training primitives for AxonML:
data parallelism (DDP), fully sharded data parallelism (FSDP, ZeRO-2 /
ZeRO-3 with `HybridShard` and CPU offload), pipeline parallelism (`Pipeline`
with GPipe / 1F1B / Interleaved 1F1B schedules), tensor parallelism
(`ColumnParallelLinear`, `RowParallelLinear`), collective ops (all-reduce,
broadcast, gather/scatter, reduce-scatter, barrier), a pluggable `Backend`
trait, a deterministic `MockBackend` for tests, and an optional NCCL backend
behind the `nccl` feature.
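The `MockBackend` mentioned above simulates collectives with in-process shared state. As a rough illustration of that idea (plain `std` Rust, independent of the crate's actual types and names), an in-process all-reduce can be built from a mutex-guarded accumulator and a barrier:

```rust
use std::sync::{Arc, Barrier, Mutex};
use std::thread;

// Toy in-process all-reduce (sum): each "rank" adds its local buffer into
// shared state, waits at a barrier, then reads the reduced result back.
fn mock_all_reduce_sum(locals: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
    let world_size = locals.len();
    let accum = Arc::new(Mutex::new(vec![0.0f32; locals[0].len()]));
    let barrier = Arc::new(Barrier::new(world_size));
    let handles: Vec<_> = locals
        .into_iter()
        .map(|local| {
            let accum = Arc::clone(&accum);
            let barrier = Arc::clone(&barrier);
            thread::spawn(move || {
                {
                    let mut acc = accum.lock().unwrap();
                    for (a, x) in acc.iter_mut().zip(&local) {
                        *a += *x;
                    }
                } // release the lock before waiting
                barrier.wait(); // every rank has contributed
                accum.lock().unwrap().clone() // all ranks observe the same sum
            })
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}

fn main() {
    // 2 "ranks" with buffers [1, 2] and [10, 20]: both end up with [11, 22].
    let out = mock_all_reduce_sum(vec![vec![1.0, 2.0], vec![10.0, 20.0]]);
    assert!(out.iter().all(|r| *r == vec![11.0, 22.0]));
}
```

The real backend additionally keys shared state by process group and operation, but the contribute-then-rendezvous structure is the same.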
## Features

- **Backend Abstraction** — `Backend` trait with `MockBackend` (in-process shared-state simulation) and optional `NcclBackend` (dynamic `libcudart`/`libnccl` loading via `libloading`)
- **Process Groups** — `ProcessGroup`/`World` abstractions with rank, world size, subgroups, default group
- **DistributedDataParallel (DDP)** — model wrapper with gradient bucketing, sync strategies (`Synchronous`, `Overlapped`, `NoSync`), parameter broadcast, buffer sync toggle, `DDP<M>` type alias
- **FullyShardedDataParallel (FSDP)** — parameter sharding with `FullShard` (ZeRO-3), `ShardGradOp` (ZeRO-2), `NoShard`, `HybridShard`; optional `CPUOffload::{None, Params, Full}`; mixed precision toggle; `gather_parameters`/`reshard_parameters`; `clip_grad_norm`; `FSDPMemoryStats` diagnostics; `FSDP<M>` type alias
- **Pipeline Parallelism** — `Pipeline` with `GPipe`, `OneFOneBSchedule` (default), `InterleavedOneFOneB`; `PipelineStage`; `PipelineMemoryStats` with `gpipe_peak_activations` and `one_f_one_b_peak_activations`
- **Tensor Parallelism** — `ColumnParallelLinear`, `RowParallelLinear`
- **Collective Operations** — `all_reduce_{sum,mean,min,max,product}`, `broadcast`, `broadcast_from`, `all_gather`, `reduce_scatter_sum`, `reduce_scatter_mean`, `gather_tensor`, `scatter_tensor`, `barrier`, `sync_gradient`, `sync_gradients`, `rank`, `world_size`, `is_main_process`
- **Reduce Operations** — `ReduceOp::{Sum, Product, Min, Max, Average}`
- **Gradient Bucketing** — `GradientBucket`, `GradientSynchronizer`, `GradSyncStrategy`
## Modules

| Module | Description |
|---|---|
| `backend` | `Backend` trait, `MockBackend`, `ReduceOp` |
| `process_group` | `ProcessGroup`, `World` (with `new_group` subgroups, `default_group`, `mock` constructor) |
| `comm` | Collective ops (`all_reduce_*`, `broadcast*`, `all_gather`, `reduce_scatter_*`, `gather_tensor`, `scatter_tensor`, `barrier`, `sync_gradient(s)`, query helpers) |
| `ddp` | `DistributedDataParallel`, `GradientBucket`, `GradientSynchronizer`, `GradSyncStrategy` |
| `fsdp` | `FullyShardedDataParallel`, `ShardingStrategy`, `CPUOffload`, `FSDPMemoryStats`, `ColumnParallelLinear`, `RowParallelLinear` |
| `pipeline` | `Pipeline`, `PipelineStage`, `PipelineSchedule`, `PipelineMemoryStats` |
| `nccl_backend` (feature: `nccl`) | `NcclBackend`, `NcclUniqueId`, `NcclError`, version/device query, multi-node init |
## Feature Flags

| Flag | Effect |
|---|---|
| `nccl` | Enables the `NcclBackend` module and pulls in `libloading` for runtime NCCL discovery |
## Usage

Add the dependency to your `Cargo.toml`:

```toml
[dependencies]
axonml-distributed = "0.6.1"

# Or with NCCL support:
axonml-distributed = { version = "0.6.1", features = ["nccl"] }
```
### Basic DDP Training

```rust
use axonml_distributed::*;
use axonml_nn::Linear; // model-layer crate path illustrative

// Initialize distributed world (mock backend for testing; sizes illustrative)
let world = World::mock(4, 0); // world size 4, this process is rank 0

// Create model and wrap in DDP
let model = Linear::new(784, 10);
let mut ddp = DistributedDataParallel::new(model, &world);

// Synchronize parameters from rank 0 at start of training
ddp.sync_parameters();

// Training loop
ddp.train();
for batch in data_loader.iter() {
    // forward / backward / optimizer step; gradients sync per the chosen strategy
}
```
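Conceptually, the gradient synchronization DDP performs after backward reduces to an elementwise mean across replicas, so every rank takes an identical optimizer step. A crate-independent sketch:

```rust
// What DDP's gradient sync amounts to: average per-rank gradients elementwise
// so every replica applies the same update.
fn ddp_average_gradients(per_rank_grads: &[Vec<f32>]) -> Vec<f32> {
    let world_size = per_rank_grads.len() as f32;
    let mut avg = vec![0.0f32; per_rank_grads[0].len()];
    for grads in per_rank_grads {
        for (a, g) in avg.iter_mut().zip(grads) {
            *a += *g;
        }
    }
    for a in avg.iter_mut() {
        *a /= world_size;
    }
    avg
}

fn main() {
    // Two ranks computed different local gradients; the synced gradient is their mean.
    let synced = ddp_average_gradients(&[vec![1.0, 3.0], vec![3.0, 5.0]]);
    assert_eq!(synced, vec![2.0, 4.0]);
}
```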
### DDP Builder

```rust
use axonml_distributed::*;

let ddp = DDP::new(model, &world) // builder arguments illustrative
    .broadcast_buffers(true)
    .gradient_as_bucket_view(true);
```
### FSDP (ZeRO-3 / ZeRO-2 / Hybrid)

```rust
use axonml_distributed::*;

let fsdp = FullyShardedDataParallel::new(model, &world) // arguments illustrative
    .sharding_strategy(ShardingStrategy::FullShard) // ZeRO-3
    .cpu_offload(CPUOffload::Params)
    .mixed_precision(true);

// Gather full parameters for a forward pass, then reshard
fsdp.gather_parameters();
// ... forward / backward ...
fsdp.reshard_parameters();
fsdp.sync_gradients();

// Gradient clipping
let grad_norm = fsdp.clip_grad_norm(1.0); // max-norm value illustrative

// Memory accounting
let stats = fsdp.memory_estimate();
println!("{:?}", stats);
```
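The memory savings behind the sharding strategies are simple arithmetic: ZeRO-2 shards gradients and optimizer state, ZeRO-3 additionally shards parameters. A standalone sketch of the per-rank element counts (assuming Adam-style optimizer state of two extra copies per parameter; the crate's `FSDPMemoryStats` is the authoritative accounting):

```rust
// Rough per-rank element counts for the three basic strategies. Illustrative
// arithmetic only, not the crate's actual implementation.
#[derive(Debug, PartialEq)]
struct PerRank {
    params: usize,
    grads: usize,
    optim: usize,
}

fn per_rank_elements(total_params: usize, world_size: usize, strategy: &str) -> PerRank {
    let shard = total_params / world_size;
    match strategy {
        // Plain DDP-style replication: everything is full-size on every rank.
        "NoShard" => PerRank { params: total_params, grads: total_params, optim: 2 * total_params },
        // ZeRO-2: parameters replicated, gradients + optimizer state sharded.
        "ShardGradOp" => PerRank { params: total_params, grads: shard, optim: 2 * shard },
        // ZeRO-3: parameters, gradients, and optimizer state all sharded.
        "FullShard" => PerRank { params: shard, grads: shard, optim: 2 * shard },
        _ => panic!("unknown strategy"),
    }
}

fn main() {
    // 1B parameters over 8 ranks: ZeRO-3 keeps 1/8th of everything per rank.
    let z3 = per_rank_elements(1_000_000_000, 8, "FullShard");
    assert_eq!(z3, PerRank { params: 125_000_000, grads: 125_000_000, optim: 250_000_000 });
}
```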
### Pipeline Parallelism

```rust
use axonml_distributed::*;

let pipe = Pipeline::from_modules(vec![stage0, stage1, stage2]) // stage modules illustrative
    .schedule(PipelineSchedule::OneFOneB)
    .num_microbatches(8);

let output = pipe.forward(input);

// Memory accounting
let stats = pipe.memory_stats(); // accessor name illustrative
let peak = stats.one_f_one_b_peak_activations;
```
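The two peak-activation statistics reflect a standard pipeline-parallel trade-off: GPipe runs every microbatch's forward before any backward, while 1F1B bounds in-flight microbatches by pipeline depth. A crate-independent sketch of that arithmetic (the crate's own accounting may differ in detail):

```rust
// Peak number of microbatch activations a pipeline stage must hold.

// GPipe: all forwards complete before the first backward, so every
// microbatch's activations are alive at once.
fn gpipe_peak(num_microbatches: usize) -> usize {
    num_microbatches
}

// 1F1B: once warm-up ends, each completed backward frees one activation
// before the next forward, capping in-flight microbatches at pipeline depth.
fn one_f_one_b_peak(num_stages: usize, num_microbatches: usize) -> usize {
    num_stages.min(num_microbatches)
}

fn main() {
    // 4 stages, 16 microbatches: 1F1B holds 4 activations at peak vs 16 for GPipe.
    assert_eq!(gpipe_peak(16), 16);
    assert_eq!(one_f_one_b_peak(4, 16), 4);
}
```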
### Tensor Parallelism

```rust
use axonml_distributed::*;

// Shard along the output dimension
let col = ColumnParallelLinear::new(1024, 4096, &world); // dimensions illustrative
// Shard along the input dimension
let row = RowParallelLinear::new(4096, 1024, &world);
```
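The column/row split is just blocked matrix multiplication: column sharding produces disjoint slices of the output (gathered by concatenation), while row sharding produces full-shape partial outputs (combined by summation, i.e. an all-reduce). A self-contained sketch of the column-parallel case:

```rust
// Naive row-major matmul: x is n×k, w is k×m, result is n×m.
fn matmul(x: &[f32], w: &[f32], n: usize, k: usize, m: usize) -> Vec<f32> {
    let mut y = vec![0.0f32; n * m];
    for i in 0..n {
        for j in 0..m {
            for p in 0..k {
                y[i * m + j] += x[i * k + p] * w[p * m + j];
            }
        }
    }
    y
}

fn main() {
    // X is 1×2, W is 2×2 (rows [1,2] and [3,4]); full output is [7, 10].
    let x = [1.0, 2.0];
    let w = [1.0, 2.0, 3.0, 4.0];
    let full = matmul(&x, &w, 1, 2, 2);

    // Column-parallel over 2 "ranks": rank 0 owns W's column 0 ([1, 3]),
    // rank 1 owns column 1 ([2, 4]); each computes its slice of Y.
    let y0 = matmul(&x, &[1.0, 3.0], 1, 2, 1);
    let y1 = matmul(&x, &[2.0, 4.0], 1, 2, 1);

    // Gathering the slices along the output dimension recovers the full result.
    assert_eq!(vec![y0[0], y1[0]], full);
}
```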
### Communication Primitives

```rust
use axonml_distributed::*;

let pg = ProcessGroup::mock(4, 0); // world size 4, rank 0 (values illustrative)

// All-reduce operations (argument shapes illustrative throughout)
let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
all_reduce_sum(&pg, &mut tensor);
all_reduce_mean(&pg, &mut tensor);
all_reduce_max(&pg, &mut tensor);
all_reduce_min(&pg, &mut tensor);
all_reduce_product(&pg, &mut tensor);

// Broadcast from rank 0 (or from any source rank)
broadcast(&pg, &mut tensor);
broadcast_from(&pg, &mut tensor, 2); // source rank 2

// All-gather / reduce-scatter / gather / scatter
let gathered = all_gather(&pg, &tensor);
let scattered_sum = reduce_scatter_sum(&pg, &tensor);
let scattered_mean = reduce_scatter_mean(&pg, &tensor);
let _ = gather_tensor(&pg, &tensor, 0);   // gather to rank 0
let _ = scatter_tensor(&pg, &tensor, 0);  // scatter from rank 0

// Barrier synchronization
barrier(&pg);

// Query process information
let my_rank = rank(&pg);
let total_processes = world_size(&pg);
let is_main = is_main_process(&pg);
```
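`reduce_scatter` and `all_gather` compose into `all_reduce`: reduce-scatter leaves each rank holding the sum of one chunk, and gathering those chunks reproduces the fully reduced vector everywhere. A standalone simulation of that identity (no crate types involved):

```rust
// Reduce-scatter(sum) over equal chunks: rank r ends up with the elementwise
// sum of every rank's r-th chunk.
fn reduce_scatter_sum(per_rank: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let n = per_rank.len(); // world size; chunk r goes to rank r
    let chunk = per_rank[0].len() / n;
    (0..n)
        .map(|r| {
            (0..chunk)
                .map(|i| per_rank.iter().map(|v| v[r * chunk + i]).sum())
                .collect()
        })
        .collect()
}

// All-gather: concatenate every rank's chunk back into one vector.
fn all_gather(chunks: &[Vec<f32>]) -> Vec<f32> {
    chunks.iter().flatten().copied().collect()
}

fn main() {
    // 2 ranks, 4 elements each (chunk size 2).
    let ranks = vec![vec![1.0, 2.0, 3.0, 4.0], vec![10.0, 20.0, 30.0, 40.0]];
    let scattered = reduce_scatter_sum(&ranks);
    assert_eq!(scattered, vec![vec![11.0, 22.0], vec![33.0, 44.0]]);
    // Gathering the reduced chunks equals a full all-reduce(sum).
    assert_eq!(all_gather(&scattered), vec![11.0, 22.0, 33.0, 44.0]);
}
```

This decomposition is why ring all-reduce costs about twice the data volume of the input per rank, independent of world size.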
### Gradient Synchronization

```rust
use axonml_distributed::*;

// Synchronize multiple gradients (tensors shown are illustrative)
let mut gradients = vec![grad_a, grad_b, grad_c];
sync_gradients(&pg, &mut gradients);

// Or synchronize a single gradient
let mut grad = Tensor::from_vec(vec![0.5, 1.5], &[2]).unwrap();
sync_gradient(&pg, &mut grad);
```
### Gradient Bucketing

```rust
use axonml_distributed::*;

// Create a gradient bucket for efficient all-reduce
let mut bucket = GradientBucket::new(1000); // 1000-element capacity
let grad1 = Tensor::from_vec(vec![1.0; 100], &[100]).unwrap(); // values illustrative
let grad2 = Tensor::from_vec(vec![2.0; 200], &[200]).unwrap();
bucket.add(grad1);
bucket.add(grad2);

// Extract synchronized gradients
let synced_grads = bucket.extract();
```
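The point of bucketing is to flatten many small gradients into one contiguous buffer so a single collective replaces many small ones. A crate-independent sketch of the flatten/reduce/split cycle (the real `GradientBucket` API may differ):

```rust
// Minimal gradient bucket: record each gradient's length, pack the values
// into one flat buffer, run one collective over it, then split it back.
struct Bucket {
    flat: Vec<f32>,
    sizes: Vec<usize>,
}

impl Bucket {
    fn new() -> Self {
        Bucket { flat: Vec::new(), sizes: Vec::new() }
    }

    fn add(&mut self, grad: &[f32]) {
        self.sizes.push(grad.len());
        self.flat.extend_from_slice(grad);
    }

    // Stand-in for the collective: divide in place, as all_reduce_mean would.
    fn all_reduce_mean(&mut self, world_size: f32) {
        for g in self.flat.iter_mut() {
            *g /= world_size;
        }
    }

    // Split the flat buffer back into per-parameter gradients.
    fn extract(&self) -> Vec<Vec<f32>> {
        let mut out = Vec::new();
        let mut offset = 0;
        for &len in &self.sizes {
            out.push(self.flat[offset..offset + len].to_vec());
            offset += len;
        }
        out
    }
}

fn main() {
    let mut bucket = Bucket::new();
    bucket.add(&[2.0, 4.0]);
    bucket.add(&[6.0]);
    bucket.all_reduce_mean(2.0); // pretend world size is 2
    assert_eq!(bucket.extract(), vec![vec![1.0, 2.0], vec![3.0]]);
}
```

One large all-reduce amortizes per-call latency, which is why DDP implementations bucket gradients rather than reducing each tensor individually.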
### Custom Synchronization Strategy

```rust
use axonml_distributed::*;

let mut sync = GradientSynchronizer::new(GradSyncStrategy::Overlapped);
sync.prepare(10); // 10 parameters

// Add gradients during the backward pass (arguments illustrative)
let grad = Tensor::from_vec(vec![0.1, 0.2], &[2]).unwrap();
sync.add_gradient(0, grad);

// Synchronize all buckets
sync.sync_all(&pg);
sync.clear();
```
### Multi-Backend Setup

```rust
use axonml_distributed::*;
use std::sync::Arc;

// Create multiple mock backends (simulates multi-process)
let backends = MockBackend::create_world(4);

// Each process creates its ProcessGroup
for backend in backends {
    let pg = ProcessGroup::new(Arc::new(backend)); // constructor shape illustrative
    // ... run the per-rank training loop ...
}
```
### Process Subgroups

```rust
use axonml_distributed::*;

let world = World::mock(8, 0); // world size 8, rank 0 (values illustrative)

// Create a subgroup with specific ranks
let subgroup = world.new_group(vec![0, 2, 4, 6]);
assert!(subgroup.contains(0));    // accessor names illustrative
assert_eq!(subgroup.size(), 4);
```
### NCCL Backend (feature-gated)

```rust
#[cfg(feature = "nccl")]
use axonml_distributed::nccl_backend::{NcclBackend, NcclUniqueId};

#[cfg(feature = "nccl")]
fn init_distributed() {
    // Query NCCL version / device count, then initialize across nodes with a
    // shared NcclUniqueId (see the nccl_backend module for the exact API).
}
```
## Tests

```shell
cargo test

# With the NCCL backend enabled:
cargo test --features nccl
```
## License

Licensed under either of:

- MIT License (LICENSE-MIT or http://opensource.org/licenses/MIT)
- Apache License, Version 2.0 (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)

at your option.