Crate axonml_distributed


Axonml Distributed - Distributed Training Utilities

Comprehensive distributed training support for scaling ML workloads across multiple GPUs and machines, mirroring PyTorch's distributed training APIs (DistributedDataParallel, FSDP, pipeline parallelism).

§Features

§Data Parallelism

  • DDP - DistributedDataParallel for gradient synchronization across replicas
  • FSDP - Fully Sharded Data Parallel with ZeRO-2 and ZeRO-3 optimizations

§Model Parallelism

  • Pipeline Parallelism - Split model across devices with microbatching (GPipe-style)
  • Tensor Parallelism - Intra-layer sharding (column- and row-parallel linear layers) for models too large for a single device (a sketch follows the examples below)

§Communication

  • Collective Operations: all-reduce, all-gather, broadcast, reduce-scatter, barrier (see the sketch after this list)
  • Point-to-Point: send, recv for direct tensor communication
  • Process Groups: Flexible grouping for hierarchical parallelism
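
The comm helpers expose the collectives as free functions. Below is a minimal sketch: the function names come from the comm re-exports listed further down, while the exact signatures (in-place reduction, an implicit default process group) and the grad_tensor placeholder are assumptions for illustration.

use axonml_distributed::prelude::*;

let world = World::mock();

// Topology queries (assumed to consult the default process group).
let my_rank = rank();
let n_ranks = world_size();

// Collective: sum `grad_tensor` across all replicas
// (`grad_tensor` is a placeholder; in-place reduction is assumed).
all_reduce_sum(&mut grad_tensor);

// Rank 0 is conventionally the main process for logging and checkpointing.
if is_main_process() {
    println!("synchronized on rank {my_rank} of {n_ranks}");
}

// Block until every rank reaches this point.
barrier();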

§Backends

  • Mock backend for testing without real hardware (see the sketch after this list)
  • Extensible Backend trait for NCCL, Gloo, MPI integration
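
For tests, the mock backend lets distributed wrappers run in a single process. The sketch below assumes World::mock() installs a MockBackend behind the default process group; a real deployment would construct the World over a Backend implementation backed by NCCL, Gloo, or MPI.

use axonml_distributed::prelude::*;
use axonml_nn::Linear;

// Single-process world (assumption: World::mock() uses MockBackend for its
// default group).
let world = World::mock();
let group = world.default_group();

// The same wrapper code runs against the mock group in unit tests and against
// a real backend in production.
let ddp = DistributedDataParallel::new(Linear::new(10, 5), group.clone());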

§DDP Example

use axonml_distributed::prelude::*;
use axonml_nn::Linear;

let world = World::mock();
let model = Linear::new(10, 5);
let ddp_model = DistributedDataParallel::new(model, world.default_group().clone());

// Forward pass (`input` is a placeholder tensor; `loss` stands for the loss
// computed from `output`).
let output = ddp_model.forward(&input);
loss.backward();

// Gradient sync happens automatically, or can be triggered manually:
ddp_model.sync_gradients();

§FSDP Example (ZeRO-3)

use axonml_distributed::{FSDP, FSDPConfig, ShardingStrategy};

let config = FSDPConfig {
    sharding_strategy: ShardingStrategy::FullShard, // ZeRO-3
    cpu_offload: true,
    ..Default::default()
};

// `model`, `process_group`, and `input` are placeholders for your module,
// a ProcessGroup, and an input batch.
let fsdp_model = FSDP::new(model, process_group, config);
let output = fsdp_model.forward(&input);

§Pipeline Parallelism Example

use axonml_distributed::{PipelineParallel, PipelineConfig, PipelineSchedule};

let config = PipelineConfig {
    num_stages: 4,       // split the model across 4 pipeline stages
    num_microbatches: 8, // each minibatch is split into 8 microbatches
    schedule: PipelineSchedule::GPipe,
    ..Default::default()
};

// `stages`, `process_group`, and `input` are placeholders for the per-stage
// modules, a ProcessGroup, and an input batch.
let pipeline = PipelineParallel::new(stages, process_group, config);
let output = pipeline.forward(&input);
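
§Tensor Parallelism Example

ColumnParallelLinear and RowParallelLinear (re-exported from fsdp) shard individual layers in the style of Megatron-LM's column/row-parallel linear layers. The sketch below pairs them as in a sharded MLP; the constructor signatures (in_features, out_features, process group) are assumptions for illustration, and `input` is a placeholder tensor.

use axonml_distributed::{ColumnParallelLinear, RowParallelLinear};
use axonml_distributed::prelude::*;

let world = World::mock();
let group = world.default_group().clone();

// The first projection splits its output features across ranks
// (column-parallel); the second splits its input features (row-parallel),
// so the pair needs only one all-reduce on the way out.
let up = ColumnParallelLinear::new(1024, 4096, group.clone());
let down = RowParallelLinear::new(4096, 1024, group);

let hidden = up.forward(&input);
let output = down.forward(&hidden);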

@version 0.2.6 @author AutomataNexus Development Team

Re-exports§

pub use backend::Backend;
pub use backend::MockBackend;
pub use backend::ReduceOp;
pub use comm::all_gather;
pub use comm::all_reduce_max;
pub use comm::all_reduce_mean;
pub use comm::all_reduce_min;
pub use comm::all_reduce_product;
pub use comm::all_reduce_sum;
pub use comm::barrier;
pub use comm::broadcast;
pub use comm::broadcast_from;
pub use comm::gather_tensor;
pub use comm::is_main_process;
pub use comm::rank;
pub use comm::reduce_scatter_mean;
pub use comm::reduce_scatter_sum;
pub use comm::scatter_tensor;
pub use comm::sync_gradient;
pub use comm::sync_gradients;
pub use comm::world_size;
pub use ddp::DistributedDataParallel;
pub use ddp::GradSyncStrategy;
pub use ddp::GradientBucket;
pub use ddp::GradientSynchronizer;
pub use fsdp::ColumnParallelLinear;
pub use fsdp::CPUOffload;
pub use fsdp::FSDPMemoryStats;
pub use fsdp::FullyShardedDataParallel;
pub use fsdp::RowParallelLinear;
pub use fsdp::ShardingStrategy;
pub use pipeline::Pipeline;
pub use pipeline::PipelineMemoryStats;
pub use pipeline::PipelineSchedule;
pub use pipeline::PipelineStage;
pub use process_group::ProcessGroup;
pub use process_group::World;

Modules§

backend
Backend - Communication Backend Abstractions
comm
Communication - High-level Communication Utilities
ddp
DDP - Distributed Data Parallel
fsdp
FSDP - Fully Sharded Data Parallel
pipeline
Pipeline Parallelism
prelude
Common imports for distributed training.
process_group
ProcessGroup - Process Group Abstraction

Type Aliases§

DDP
Type alias for DistributedDataParallel.
FSDP
Type alias for FullyShardedDataParallel.