Axonml Distributed - Distributed Training Utilities
Comprehensive distributed training support for scaling ML workloads across multiple GPUs and machines. The API is modeled on PyTorch's distributed stack: DistributedDataParallel, fully sharded data parallelism, pipeline parallelism, and collective communication primitives.
§Features
§Data Parallelism
- DDP - DistributedDataParallel for gradient synchronization across replicas
- FSDP - Fully Sharded Data Parallel with ZeRO-2 and ZeRO-3 optimizations
§Model Parallelism
- Pipeline Parallelism - Split model across devices with microbatching (GPipe-style)
- Tensor Parallelism - Layer-wise model sharding for large models
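Tensor parallelism is surfaced through the ColumnParallelLinear and RowParallelLinear layers re-exported below. Their constructors are not documented on this page, so the following is only a hedged sketch: it assumes a new(in_features, out_features, process_group) signature and the usual Megatron-style pairing of a column-split layer followed by a row-split layer; treat the exact arguments as assumptions.
use axonml_distributed::prelude::*;
use axonml_distributed::{ColumnParallelLinear, RowParallelLinear};
let world = World::mock();
let group = world.default_group().clone();
// Assumed constructor shape: (in_features, out_features, process_group).
// The column-parallel layer shards its output features across ranks and the
// row-parallel layer shards its input features, so the pair forms a
// Megatron-style block with a single all-reduce on the final output.
let up = ColumnParallelLinear::new(1024, 4096, group.clone());
let down = RowParallelLinear::new(4096, 1024, group);
// `input` is a batch tensor, elided here as in the examples below.
let hidden = up.forward(&input);
let output = down.forward(&hidden);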
§Communication
- Collective Operations: all-reduce, all-gather, broadcast, reduce-scatter, barrier (a usage sketch follows this list)
- Point-to-Point: send, recv for direct tensor communication
- Process Groups: Flexible grouping for hierarchical parallelism
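The comm module re-exported below wraps these primitives in free functions (all_reduce_sum, all_reduce_mean, broadcast, barrier, rank, world_size, and so on). Their exact signatures are not shown on this page, so the sketch below is only illustrative: it assumes an in-place-on-a-tensor calling convention and that the helpers default to the world's process group.
use axonml_distributed::prelude::*;
use axonml_distributed::comm::{all_reduce_mean, barrier, broadcast, is_main_process, rank, world_size};
let _world = World::mock(); // mock backend: runs in a single process without GPUs
println!("rank {} of {}", rank(), world_size());
// Assumed: averages `grad` (a gradient tensor built elsewhere, elided here)
// across all replicas, in place.
all_reduce_mean(&mut grad);
// Assumed: rank 0 prepares `params` (e.g. loads a checkpoint) and
// broadcast copies rank 0's version to every other rank.
if is_main_process() {
    // checkpoint loading elided
}
broadcast(&mut params);
// Block until every rank has reached this point.
barrier();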
§Backends
- Mock backend for testing without real hardware (see the test sketch after this list)
- Extensible Backend trait for NCCL, Gloo, MPI integration
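Because the DDP example below only needs World::mock(), the same pieces can be used to exercise distributed code paths in ordinary unit tests. This is a minimal sketch under that assumption; the exact semantics of MockBackend (for example, whether collectives are local no-ops) are not documented on this page.
use axonml_distributed::prelude::*;
use axonml_nn::Linear;

#[test]
fn ddp_wrapper_runs_on_the_mock_backend() {
    // No GPUs, NCCL, or multi-process launch required.
    let world = World::mock();
    let model = Linear::new(4, 2);
    let ddp = DistributedDataParallel::new(model, world.default_group().clone());
    // Assumed: on the mock backend a manual gradient sync is a cheap local op,
    // so it is safe to call directly in a test.
    ddp.sync_gradients();
}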
§DDP Example
use axonml_distributed::prelude::*;
use axonml_nn::Linear;
let world = World::mock();
let model = Linear::new(10, 5);
let ddp_model = DistributedDataParallel::new(model, world.default_group().clone());
// Forward pass (`input` is a batch tensor, elided here)
let output = ddp_model.forward(&input);
// Compute a loss from `output` (elided), then backpropagate
loss.backward();
// Gradient synchronization happens automatically, or can be triggered manually:
ddp_model.sync_gradients();
§FSDP Example (ZeRO-3)
use axonml_distributed::{FSDP, FSDPConfig, ShardingStrategy};
let config = FSDPConfig {
sharding_strategy: ShardingStrategy::FullShard, // ZeRO-3
cpu_offload: true,
..Default::default()
};
// `model`, `process_group`, and `input` are constructed elsewhere (elided here)
let fsdp_model = FSDP::new(model, process_group, config);
let output = fsdp_model.forward(&input);
§Pipeline Parallelism Example
use axonml_distributed::{PipelineParallel, PipelineConfig, PipelineSchedule};
let config = PipelineConfig {
num_stages: 4,
num_microbatches: 8,
schedule: PipelineSchedule::GPipe,
..Default::default()
};
// `stages` is the ordered list of model partitions (one per pipeline stage);
// `process_group` and `input` are constructed elsewhere (elided here)
let pipeline = PipelineParallel::new(stages, process_group, config);
let output = pipeline.forward(&input);
@version 0.2.6
@author AutomataNexus Development Team
Re-exports§
pub use backend::Backend;
pub use backend::MockBackend;
pub use backend::ReduceOp;
pub use comm::all_gather;
pub use comm::all_reduce_max;
pub use comm::all_reduce_mean;
pub use comm::all_reduce_min;
pub use comm::all_reduce_product;
pub use comm::all_reduce_sum;
pub use comm::barrier;
pub use comm::broadcast;
pub use comm::broadcast_from;
pub use comm::gather_tensor;
pub use comm::is_main_process;
pub use comm::rank;
pub use comm::reduce_scatter_mean;
pub use comm::reduce_scatter_sum;
pub use comm::scatter_tensor;
pub use comm::sync_gradient;
pub use comm::sync_gradients;
pub use comm::world_size;
pub use ddp::DistributedDataParallel;
pub use ddp::GradSyncStrategy;
pub use ddp::GradientBucket;
pub use ddp::GradientSynchronizer;
pub use fsdp::ColumnParallelLinear;
pub use fsdp::CPUOffload;
pub use fsdp::FSDPMemoryStats;
pub use fsdp::FullyShardedDataParallel;
pub use fsdp::RowParallelLinear;
pub use fsdp::ShardingStrategy;
pub use pipeline::Pipeline;
pub use pipeline::PipelineMemoryStats;
pub use pipeline::PipelineSchedule;
pub use pipeline::PipelineStage;
pub use process_group::ProcessGroup;
pub use process_group::World;
Modules§
- backend - Backend - Communication Backend Abstractions
- comm - Communication - High-level Communication Utilities
- ddp - DDP - Distributed Data Parallel
- fsdp - FSDP - Fully Sharded Data Parallel
- pipeline - Pipeline Parallelism
- prelude - Common imports for distributed training.
- process_group - ProcessGroup - Process Group Abstraction