# torsh-distributed

Distributed training support for ToRSh with a PyTorch-compatible API.

## Overview

This crate provides distributed and parallel training capabilities, including:

- **Data Parallel Training**: `DistributedDataParallel` (DDP)
- **Model Parallel Training**: Pipeline and tensor parallelism
- **Communication Backends**: NCCL, Gloo, and MPI support
- **RPC Framework**: Remote procedure calls for distributed computing
- **Collective Operations**: All-reduce, broadcast, gather, scatter
## Usage

### Basic Distributed Training

```rust
use torsh::prelude::*;
use torsh_distributed::prelude::*;

fn main() -> Result<()> {
    // Initialize the process group (rank and world size come from the environment)
    init_process_group("nccl")?;

    // Get this process's rank and the total number of processes
    let rank = get_rank();
    let world_size = get_world_size();
    println!("rank {rank} of {world_size}");

    // Create the model and wrap it with DDP
    let model = create_model();
    let ddp_model = DistributedDataParallel::new(model)?;

    // Distributed optimizer
    let optimizer = DistributedOptimizer::new(ddp_model.parameters())?;

    // Training loop
    for epoch in 0..num_epochs {
        // ... forward, backward, optimizer.step() ...
    }

    // Cleanup
    destroy_process_group()?;
    Ok(())
}
```
### Collective Operations

```rust
use torsh_distributed::prelude::*;

// All-reduce: sum tensors across all processes (in place)
let mut tensor = create_tensor();
all_reduce(&mut tensor, ReduceOp::Sum)?;

// Broadcast: send the tensor from rank 0 to all other ranks
broadcast(&mut tensor, 0)?;

// All-gather: collect tensors from all ranks
let gathered = all_gather(&tensor)?;

// Scatter: distribute chunks from the source rank to the other ranks
let chunks = scatter(&tensor, 0)?;

// Reduce: aggregate onto a specific rank
reduce(&mut tensor, 0, ReduceOp::Sum)?;
```
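Semantically, an all-reduce leaves every rank holding the same reduction of all ranks' inputs. A minimal single-process sketch in plain Rust (no `torsh-distributed` calls, just the semantics):

```rust
// Simulate all-reduce (sum) across several "ranks", each holding a vector.
fn all_reduce_sum(inputs: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let len = inputs[0].len();
    let mut sum = vec![0.0f32; len];
    for input in inputs {
        for (s, x) in sum.iter_mut().zip(input) {
            *s += x;
        }
    }
    // Every rank ends up with an identical copy of the reduction.
    vec![sum; inputs.len()]
}

fn main() {
    let per_rank = vec![
        vec![1.0, 2.0],
        vec![3.0, 4.0],
        vec![5.0, 6.0],
        vec![7.0, 8.0],
    ];
    let reduced = all_reduce_sum(&per_rank);
    assert!(reduced.iter().all(|r| *r == vec![16.0, 20.0]));
    println!("{:?}", reduced[0]); // [16.0, 20.0]
}
```

This is what makes DDP work: each rank computes gradients on its own shard of data, and one all-reduce (followed by dividing by `world_size`) gives every rank the same averaged gradient.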
### RPC Framework

```rust
use torsh_distributed::rpc::*;

// Initialize the RPC framework for this worker
init_rpc("worker0", rank, world_size)?;

// Asynchronous remote procedure call
let future = rpc_async("worker1", my_function, args)?;

// Block until the result arrives
let result = future.wait()?;

// Create a remote reference on another worker
let rref = remote("worker1", create_value, args)?;
let local_value = rref.to_here()?;

// Shut down RPC
shutdown_rpc()?;
```
### Pipeline Parallelism

```rust
use torsh_distributed::pipeline::*;

// Split the model into sequential stages
let stages = vec![stage1, stage2, stage3];

// Create the pipeline, specifying the number of micro-batches
let pipeline = Pipeline::new(stages, 4)?;

// Forward pass with automatic micro-batching
let output = pipeline.forward(&input)?;
```
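Micro-batching is what keeps the stages busy: the pipeline slices each batch into smaller chunks so stage 2 can start on chunk 1 while stage 1 works on chunk 2. A plain-Rust sketch of the slicing step (illustrative only, not the crate's API):

```rust
// Split a batch of N samples into micro-batches of at most `micro` samples,
// as a pipeline engine would before scheduling them through the stages.
fn split_into_micro_batches<T: Clone>(batch: &[T], micro: usize) -> Vec<Vec<T>> {
    batch.chunks(micro).map(|c| c.to_vec()).collect()
}

fn main() {
    let batch: Vec<i32> = (0..10).collect();
    let micro_batches = split_into_micro_batches(&batch, 4);
    // 10 samples with micro-batch size 4 -> sizes [4, 4, 2]
    assert_eq!(micro_batches.len(), 3);
    assert_eq!(micro_batches[2], vec![8, 9]);
    // With S stages and M micro-batches, the idle "bubble" occupies roughly
    // (S - 1) / (M + S - 1) of the schedule, so more micro-batches shrink it.
}
```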
### Model Parallelism

```rust
use torsh_distributed::tensor_parallel::*;

// Linear layer sharded across tensor-parallel ranks
let tp_linear = TensorParallelLinear::new(in_features, out_features)?;

// Multi-head attention with tensor parallelism
let tp_attention = TensorParallelAttention::new(embed_dim, num_heads)?;
```
### Gradient Compression

```rust
use torsh_distributed::compression::*;

// Configure gradient compression
let compressor = GradientCompressor::new()
    .algorithm(CompressionAlgorithm::TopK(0.01))
    .memory(true); // keep error-feedback memory

// Apply to DDP
let ddp_model = DistributedDataParallel::new(model)?
    .with_compression(compressor)?;
```
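Top-k sparsification with error feedback is one common scheme behind such a compressor: only the k largest-magnitude gradient components are communicated, and the dropped remainder is remembered and added back on the next step. A self-contained sketch (the function name and shape are illustrative, not the crate's API):

```rust
// Top-k gradient compression with error-feedback memory: send only the k
// largest-magnitude components; carry the dropped remainder into the next
// step so no gradient mass is lost permanently.
fn compress_top_k(grad: &[f32], memory: &mut [f32], k: usize) -> Vec<(usize, f32)> {
    // Add the residual from previous steps before selecting.
    let corrected: Vec<f32> = grad.iter().zip(memory.iter()).map(|(g, m)| g + m).collect();

    // Indices sorted by descending magnitude.
    let mut idx: Vec<usize> = (0..corrected.len()).collect();
    idx.sort_by(|&a, &b| corrected[b].abs().partial_cmp(&corrected[a].abs()).unwrap());

    // Keep the top-k entries; everything else becomes the new memory.
    let kept: Vec<(usize, f32)> = idx[..k].iter().map(|&i| (i, corrected[i])).collect();
    for (i, m) in memory.iter_mut().enumerate() {
        *m = corrected[i];
    }
    for &(i, _) in &kept {
        memory[i] = 0.0;
    }
    kept
}

fn main() {
    let grad = [0.1, -2.0, 0.3, 1.5];
    let mut memory = [0.0f32; 4];
    let sent = compress_top_k(&grad, &mut memory, 2);
    // The two largest-magnitude entries are sent...
    assert_eq!(sent, vec![(1, -2.0), (3, 1.5)]);
    // ...and the dropped ones are remembered for the next step.
    assert_eq!(memory, [0.1, 0.0, 0.3, 0.0]);
}
```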
### Fault Tolerance

```rust
use torsh_distributed::elastic::*;

// Elastic training with a dynamic worker pool
let elastic_agent = ElasticAgent::new()
    .min_workers(2)
    .max_workers(8)
    .checkpoint_dir("checkpoints/");

elastic_agent.run(train_fn)?;
```
### Monitoring

```rust
use torsh_distributed::monitoring::*;

// Track distributed training metrics
let monitor = DistributedMonitor::new();

// Log communication time for a collective
monitor.log_comm_time("all_reduce", elapsed);

// Retrieve aggregated statistics
let stats = monitor.get_stats();
println!("{stats:?}");
```
## Backends

### NCCL (NVIDIA GPUs)

- Optimized for NVIDIA GPU communication
- Supports GPUDirect and NVLink
- Best choice for multi-GPU training on NVIDIA hardware
### Gloo (CPU and GPU)

- Cross-platform communication
- Supports both TCP and InfiniBand
- Good default for CPU training
### MPI (HPC environments)

- Integrates with existing MPI implementations
- Optimized for HPC clusters
- Supports a variety of interconnects
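The guidance above usually reduces to a simple decision: NCCL when NVIDIA GPUs are present, MPI when running under an MPI launcher, Gloo as the portable fallback. A small self-contained helper encoding that rule (illustrative only, not part of the crate):

```rust
// Pick a communication backend from the available hardware/environment:
// NCCL for NVIDIA GPUs, MPI inside an MPI launcher, Gloo otherwise.
fn choose_backend(has_cuda: bool, under_mpi: bool) -> &'static str {
    if has_cuda {
        "nccl"
    } else if under_mpi {
        "mpi"
    } else {
        "gloo"
    }
}

fn main() {
    // Detect an MPI launcher via a conventional env var (OMPI_COMM_WORLD_SIZE
    // is set by Open MPI; the check here is illustrative).
    let under_mpi = std::env::var("OMPI_COMM_WORLD_SIZE").is_ok();
    assert_eq!(choose_backend(true, under_mpi), "nccl");
    assert_eq!(choose_backend(false, false), "gloo");
    println!("{}", choose_backend(false, true)); // prints "mpi"
}
```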
## Environment Variables

```bash
# Basic setup (PyTorch-style rendezvous)
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=29500
export RANK=0
export WORLD_SIZE=4

# NCCL specific
export NCCL_DEBUG=INFO
export NCCL_SOCKET_IFNAME=eth0

# Gloo specific
export GLOO_SOCKET_IFNAME=eth0
```
## License

Licensed under the Apache License, Version 2.0. See LICENSE for details.