Crate axonml_distributed


Axonml Distributed - Distributed Training Utilities

This crate provides distributed training functionality for the Axonml ML framework:

  • Backend: Communication backend abstraction, with a mock implementation for testing
  • ProcessGroup: Process group abstraction for managing distributed processes
  • DDP: DistributedDataParallel wrapper for synchronizing gradients
  • Communication: High-level communication utilities (all-reduce, broadcast, etc.)
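The source of MockBackend is not reproduced on this page. The following is therefore only a sketch, and it assumes a mock backend behaves like an in-process simulation of several ranks. The `Op` enum and `all_reduce_in_process` function are hypothetical stand-ins, not this crate's `ReduceOp` type or its `all_reduce_*` API. They show what an all-reduce means: every rank contributes a buffer, and every rank ends up holding the same element-wise reduction.

```rust
/// Hypothetical stand-in for the kinds of reduction listed by `ReduceOp`;
/// this is NOT the crate's type.
#[derive(Clone, Copy, Debug)]
enum Op {
    Sum,
    Product,
    Min,
    Max,
    Mean,
}

/// Simulates an all-reduce across `buffers.len()` in-process "ranks":
/// afterwards every buffer holds the element-wise reduction of all buffers.
fn all_reduce_in_process(buffers: &mut [Vec<f32>], op: Op) {
    let world_size = buffers.len();
    if world_size == 0 {
        return;
    }
    let len = buffers[0].len();
    assert!(
        buffers.iter().all(|b| b.len() == len),
        "every rank must contribute a buffer of the same length"
    );

    // Reduce: fold every other rank's buffer into a copy of rank 0's buffer.
    let mut result = buffers[0].clone();
    for other in &buffers[1..] {
        for (acc, &x) in result.iter_mut().zip(other.iter()) {
            let cur = *acc;
            *acc = match op {
                Op::Sum | Op::Mean => cur + x,
                Op::Product => cur * x,
                Op::Min => cur.min(x),
                Op::Max => cur.max(x),
            };
        }
    }
    // Mean is the sum divided by the number of participating ranks.
    if let Op::Mean = op {
        for v in result.iter_mut() {
            *v /= world_size as f32;
        }
    }
    // "Distribute": every rank receives the same reduced buffer.
    for buf in buffers.iter_mut() {
        buf.copy_from_slice(&result);
    }
}

fn main() {
    let mut ranks: Vec<Vec<f32>> = vec![vec![1.0, 5.0], vec![3.0, 2.0]];
    all_reduce_in_process(&mut ranks, Op::Mean);
    println!("{:?}", ranks); // [[2.0, 3.5], [2.0, 3.5]]
}
```

In a real backend, the reduction and distribution steps would be network messages between processes. The mock collapses them into plain loops, which is what makes it useful for single-process tests.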

Example

```rust
use axonml_distributed::prelude::*;
use axonml_nn::Linear;

// Create a mock process group for testing
let world = World::mock();

// Wrap a model in DDP
let model = Linear::new(10, 5);
let ddp_model = DistributedDataParallel::new(model, world.default_group().clone());

// Synchronize gradients after backward pass
ddp_model.sync_gradients();
```
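In principle, gradient synchronization keeps replicas identical. Each rank computes gradients on its own data shard, the gradients are averaged across ranks, and every rank then applies the same update. The sketch below simulates this with two in-process replicas. `Replica`, `average_gradients`, and `sgd_step` are hypothetical names chosen for illustration. They are not the internals of DistributedDataParallel, which may, for example, bucket gradients or use the configured GradSyncStrategy.

```rust
/// Hypothetical model replica: each rank holds its own copy of the
/// parameters plus the gradients it computed on its local data shard.
#[derive(Clone, Debug, PartialEq)]
struct Replica {
    weights: Vec<f32>,
    grads: Vec<f32>,
}

/// Averages gradients across all replicas (an all-reduce with mean),
/// so every rank sees the same gradient afterwards.
fn average_gradients(replicas: &mut [Replica]) {
    if replicas.is_empty() {
        return;
    }
    let n = replicas.len() as f32;
    let mut mean = vec![0.0f32; replicas[0].grads.len()];
    for r in replicas.iter() {
        for (m, g) in mean.iter_mut().zip(&r.grads) {
            *m += g;
        }
    }
    for m in mean.iter_mut() {
        *m /= n;
    }
    for r in replicas.iter_mut() {
        r.grads.copy_from_slice(&mean);
    }
}

/// Plain SGD step, applied independently on each rank.
fn sgd_step(replica: &mut Replica, lr: f32) {
    for (w, g) in replica.weights.iter_mut().zip(&replica.grads) {
        *w -= lr * g;
    }
}

fn main() {
    // Both ranks start from identical weights but see different data.
    let mut replicas = vec![
        Replica { weights: vec![1.0, 1.0], grads: vec![2.0, 0.0] },
        Replica { weights: vec![1.0, 1.0], grads: vec![0.0, 4.0] },
    ];
    average_gradients(&mut replicas);
    for r in replicas.iter_mut() {
        sgd_step(r, 0.5);
    }
    // Identical averaged gradients keep the replicas in lockstep.
    assert_eq!(replicas[0], replicas[1]);
    println!("{:?}", replicas[0].weights); // [0.5, 0.0]
}
```

If the gradients were not synchronized, the two replicas would apply different updates and drift apart after the first step.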

@version 0.1.0 @author AutomataNexus Development Team

Re-exports

pub use backend::Backend;
pub use backend::MockBackend;
pub use backend::ReduceOp;
pub use comm::all_gather;
pub use comm::all_reduce_max;
pub use comm::all_reduce_mean;
pub use comm::all_reduce_min;
pub use comm::all_reduce_product;
pub use comm::all_reduce_sum;
pub use comm::barrier;
pub use comm::broadcast;
pub use comm::broadcast_from;
pub use comm::gather_tensor;
pub use comm::is_main_process;
pub use comm::rank;
pub use comm::reduce_scatter_mean;
pub use comm::reduce_scatter_sum;
pub use comm::scatter_tensor;
pub use comm::sync_gradient;
pub use comm::sync_gradients;
pub use comm::world_size;
pub use ddp::DistributedDataParallel;
pub use ddp::GradSyncStrategy;
pub use ddp::GradientBucket;
pub use ddp::GradientSynchronizer;
pub use process_group::ProcessGroup;
pub use process_group::World;
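The re-exported collectives include reduce_scatter_sum and all_gather. Their signatures are not shown on this page. The sketch below assumes they follow the usual MPI/NCCL definitions. The hypothetical helpers `reduce_scatter_sum_sim` and `all_gather_sim` simulate both operations over in-process buffers. They also show that a reduce-scatter followed by an all-gather produces the same result as an all-reduce sum.

```rust
/// Simulated reduce-scatter (sum) across in-process ranks.
/// Each rank contributes a buffer of length world_size * chunk;
/// rank i receives chunk i of the element-wise sum.
fn reduce_scatter_sum_sim(inputs: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let world = inputs.len();
    if world == 0 {
        return Vec::new();
    }
    let len = inputs[0].len();
    assert!(len % world == 0, "buffer length must divide evenly by world size");
    let chunk = len / world;
    (0..world)
        .map(|rank| {
            (rank * chunk..(rank + 1) * chunk)
                // Sum element j across all ranks.
                .map(|j| inputs.iter().map(|buf| buf[j]).sum::<f32>())
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// Simulated all-gather: every rank receives the concatenation of all
/// ranks' chunks, in rank order.
fn all_gather_sim(chunks: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let gathered: Vec<f32> = chunks.iter().flatten().copied().collect();
    vec![gathered; chunks.len()]
}

fn main() {
    let inputs = vec![vec![1.0, 2.0, 3.0, 4.0], vec![10.0, 20.0, 30.0, 40.0]];
    let scattered = reduce_scatter_sum_sim(&inputs);
    println!("{:?}", scattered); // [[11.0, 22.0], [33.0, 44.0]]
    let gathered = all_gather_sim(&scattered);
    println!("{:?}", gathered); // [[11.0, 22.0, 33.0, 44.0], [11.0, 22.0, 33.0, 44.0]]
}
```

Splitting an all-reduce into these two phases is a common way to spread the reduction work evenly across ranks. Each rank reduces only its own chunk.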

Modules

  • backend: Communication backend abstractions
  • comm: High-level communication utilities
  • ddp: Distributed Data Parallel
  • prelude: Common imports for distributed training
  • process_group: Process group abstraction
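The ddp module re-exports GradientBucket. DDP implementations such as PyTorch's DistributedDataParallel commonly pack gradients into buckets of bounded size, so that each all-reduce moves one larger buffer instead of many small ones. This page does not confirm that GradientBucket serves the same purpose; that is an assumption. The hypothetical `Bucket` type and `build_buckets` function below illustrate greedy, order-preserving packing by element count.

```rust
/// Hypothetical bucket: the parameters whose gradients would be flattened
/// into one buffer and reduced with a single all-reduce call.
#[derive(Debug, Default, PartialEq)]
struct Bucket {
    param_indices: Vec<usize>,
    num_elements: usize,
}

/// Greedily packs parameters, in order, into buckets of at most `cap`
/// elements. A parameter larger than `cap` gets a bucket of its own.
fn build_buckets(param_sizes: &[usize], cap: usize) -> Vec<Bucket> {
    let mut buckets = Vec::new();
    let mut current = Bucket::default();
    for (i, &size) in param_sizes.iter().enumerate() {
        // Close the current bucket if adding this parameter would overflow it.
        if !current.param_indices.is_empty() && current.num_elements + size > cap {
            buckets.push(std::mem::take(&mut current));
        }
        current.param_indices.push(i);
        current.num_elements += size;
    }
    if !current.param_indices.is_empty() {
        buckets.push(current);
    }
    buckets
}

fn main() {
    // Element counts per parameter, e.g. a 10x5 weight (50) followed by its bias (5).
    let buckets = build_buckets(&[50, 5, 40, 100, 10], 60);
    for b in &buckets {
        println!("{:?} ({} elements)", b.param_indices, b.num_elements);
    }
}
```

The cap is the main tuning knob. Larger buckets mean fewer collective calls with less per-call overhead. Smaller buckets let communication start earlier and overlap with the rest of the backward pass.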

Type Aliases

  • DDP: Type alias for DistributedDataParallel.