// axonml_distributed/lib.rs

//! Axonml Distributed - Distributed Training Utilities
//!
//! Comprehensive distributed training support for scaling ML workloads across
//! multiple GPUs and machines. Provides PyTorch-equivalent functionality.
//!
//! # Features
//!
//! ## Data Parallelism
//! - **DDP** - `DistributedDataParallel` for gradient synchronization across replicas
//! - **FSDP** - Fully Sharded Data Parallel with ZeRO-2 and ZeRO-3 optimizations
//!
//! ## Model Parallelism
//! - **Pipeline Parallelism** - Split a model across devices with microbatching (GPipe-style)
//! - **Tensor Parallelism** - Layer-wise model sharding for large models
//!
//! ## Communication
//! - **Collective Operations**: all-reduce, all-gather, broadcast, reduce-scatter, barrier
//! - **Point-to-Point**: send/recv for direct tensor communication
//! - **Process Groups**: flexible grouping for hierarchical parallelism
//!
//! ## Backends
//! - Mock backend for testing without real hardware
//! - Extensible `Backend` trait for NCCL, Gloo, and MPI integration
//!
//! # DDP Example
//!
//! ```ignore
//! use axonml_distributed::prelude::*;
//! use axonml_nn::Linear;
//!
//! let world = World::mock();
//! let model = Linear::new(10, 5);
//! let ddp_model = DistributedDataParallel::new(model, world.default_group().clone());
//!
//! // Forward pass
//! let input = Variable::new(Tensor::from_vec(vec![1.0; 10], &[1, 10]).unwrap(), false);
//! let output = ddp_model.forward(&input);
//!
//! // ... compute `loss` from `output`, then backpropagate ...
//! loss.backward();
//!
//! // Gradient sync happens automatically or manually:
//! ddp_model.sync_gradients();
//! ```
//!
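//! # Gradient Synchronizer Example
//!
//! For finer-grained control than `sync_gradients`, the lower-level
//! `GradientSynchronizer` can be driven directly. A minimal sketch mirroring the
//! workflow exercised in this crate's tests; the bucket capacity (1000) and
//! parameter count (10) are illustrative values:
//!
//! ```ignore
//! use axonml_distributed::prelude::*;
//!
//! let pg = ProcessGroup::mock();
//! let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 1000);
//! sync.prepare(10);
//!
//! // Register a per-parameter gradient, then synchronize across the process group.
//! let grad = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
//! sync.add_gradient(0, &grad);
//! sync.sync_all(&pg);
//! sync.clear();
//! ```
//!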
//! # FSDP Example (ZeRO-3)
//!
//! ```ignore
//! use axonml_distributed::{FSDP, FSDPConfig, ShardingStrategy};
//!
//! let config = FSDPConfig {
//!     sharding_strategy: ShardingStrategy::FullShard, // ZeRO-3
//!     cpu_offload: true,
//!     ..Default::default()
//! };
//!
//! let fsdp_model = FSDP::new(model, process_group, config);
//! let output = fsdp_model.forward(&input);
//! ```
//!
//! # Pipeline Parallelism Example
//!
//! ```ignore
//! use axonml_distributed::{PipelineParallel, PipelineConfig, PipelineSchedule};
//!
//! let config = PipelineConfig {
//!     num_stages: 4,
//!     num_microbatches: 8,
//!     schedule: PipelineSchedule::GPipe,
//!     ..Default::default()
//! };
//!
//! let pipeline = PipelineParallel::new(stages, process_group, config);
//! let output = pipeline.forward(&input);
//! ```
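//!
//! # Collective Communication Example
//!
//! The collective operations listed under Features operate on a process group.
//! A minimal sketch using the mock process group, following the call patterns
//! exercised in this crate's tests:
//!
//! ```ignore
//! use axonml_distributed::prelude::*;
//!
//! let pg = ProcessGroup::mock();
//! let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
//!
//! // Reduce the tensor across all ranks in place, then broadcast it to the group.
//! all_reduce_sum(&mut tensor, &pg);
//! broadcast(&mut tensor, &pg);
//!
//! // Collect contributions from every rank and wait at a barrier.
//! let gathered = all_gather(&tensor, &pg);
//! barrier(&pg);
//! ```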
//!
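//! # Process Group Example
//!
//! Process groups wrap a communication backend; the mock backend exercises the full
//! API without real hardware. A minimal sketch mirroring the setup used in this
//! crate's tests:
//!
//! ```ignore
//! use axonml_distributed::{MockBackend, ProcessGroup, World};
//! use std::sync::Arc;
//!
//! // A single-process world with a default group.
//! let world = World::mock();
//! assert!(world.is_main());
//! let group = world.default_group();
//!
//! // Or build a group from an explicit mock backend (one backend per simulated rank).
//! let backends = MockBackend::create_world(2);
//! let pg = ProcessGroup::new(Arc::new(backends.into_iter().next().unwrap()));
//! assert_eq!(pg.size(), 2);
//! ```
//!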
//! @version 0.2.6
//! @author `AutomataNexus` Development Team

#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// ML/tensor-specific allowances
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]

pub mod backend;
pub mod comm;
pub mod ddp;
pub mod fsdp;
pub mod pipeline;
pub mod process_group;

// =============================================================================
// Re-exports
// =============================================================================

pub use backend::{Backend, MockBackend, ReduceOp};
pub use comm::{
    all_gather, all_reduce_max, all_reduce_mean, all_reduce_min, all_reduce_product,
    all_reduce_sum, barrier, broadcast, broadcast_from, gather_tensor, is_main_process, rank,
    reduce_scatter_mean, reduce_scatter_sum, scatter_tensor, sync_gradient, sync_gradients,
    world_size,
};
pub use ddp::{DistributedDataParallel, GradSyncStrategy, GradientBucket, GradientSynchronizer};
pub use fsdp::{
    ColumnParallelLinear, CPUOffload, FSDPMemoryStats, FullyShardedDataParallel,
    RowParallelLinear, ShardingStrategy,
};
pub use pipeline::{Pipeline, PipelineMemoryStats, PipelineSchedule, PipelineStage};
pub use process_group::{ProcessGroup, World};

// =============================================================================
// Prelude
// =============================================================================

/// Common imports for distributed training.
pub mod prelude {
    pub use crate::{
        // Communication
        all_gather, all_reduce_max, all_reduce_mean, all_reduce_min, all_reduce_product,
        all_reduce_sum, barrier, broadcast, broadcast_from, gather_tensor, is_main_process,
        rank, reduce_scatter_mean, reduce_scatter_sum, scatter_tensor, sync_gradient,
        sync_gradients, world_size,
        // Backend
        Backend, MockBackend, ReduceOp,
        // DDP
        DistributedDataParallel, GradSyncStrategy, GradientBucket, GradientSynchronizer,
        // FSDP
        ColumnParallelLinear, CPUOffload, FullyShardedDataParallel, RowParallelLinear,
        ShardingStrategy,
        // Process groups
        ProcessGroup, World,
    };

    pub use axonml_autograd::Variable;
    pub use axonml_nn::Module;
    pub use axonml_tensor::Tensor;
}

// =============================================================================
// Type Aliases
// =============================================================================

/// Type alias for `DistributedDataParallel`.
pub type DDP<M> = DistributedDataParallel<M>;

/// Type alias for `FullyShardedDataParallel`.
pub type FSDP<M> = FullyShardedDataParallel<M>;

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_nn::{Linear, Module};
    use axonml_tensor::Tensor;
    use std::sync::Arc;

    #[test]
    fn test_full_distributed_workflow() {
        // Initialize world
        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert!(world.is_main());

        // Create model and wrap in DDP
        let model = Linear::new(10, 5);
        let mut ddp = DDP::new(model, world.default_group().clone());

        // Forward pass
        let input = Variable::new(Tensor::from_vec(vec![1.0; 10], &[1, 10]).unwrap(), false);
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 5]);

        // Train mode
        ddp.train();
        assert!(ddp.is_training());

        // Sync parameters
        ddp.sync_parameters();

        // Sync gradients
        ddp.sync_gradients();
    }

    #[test]
    fn test_multiple_backends() {
        let backends = MockBackend::create_world(4);

        // All backends should have a consistent world view
        for (i, backend) in backends.iter().enumerate() {
            assert_eq!(backend.rank(), i);
            assert_eq!(backend.world_size(), 4);
        }
    }

    #[test]
    fn test_process_group_creation() {
        let backends = MockBackend::create_world(2);
        let pg = ProcessGroup::new(Arc::new(backends.into_iter().next().unwrap()));

        assert_eq!(pg.size(), 2);
        assert_eq!(pg.ranks().len(), 2);
    }

    #[test]
    fn test_communication_functions() {
        let pg = ProcessGroup::mock();

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        // All reduce
        all_reduce_sum(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);

        // Broadcast
        broadcast(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);

        // All gather
        let gathered = all_gather(&tensor, &pg);
        assert_eq!(gathered.shape(), &[1, 3]);

        // Barrier
        barrier(&pg);
    }

    #[test]
    fn test_gradient_bucket_workflow() {
        let mut bucket = GradientBucket::new(1000);

        // Add gradients
        let grad1 = Tensor::from_vec(vec![0.1, 0.2], &[2]).unwrap();
        let grad2 = Tensor::from_vec(vec![0.3, 0.4, 0.5], &[3]).unwrap();

        bucket.add(&grad1);
        bucket.add(&grad2);

        assert_eq!(bucket.size(), 5);

        // Extract
        let tensors = bucket.extract();
        assert_eq!(tensors.len(), 2);
        assert_eq!(tensors[0].to_vec(), vec![0.1, 0.2]);
        assert_eq!(tensors[1].to_vec(), vec![0.3, 0.4, 0.5]);

        // Clear
        bucket.clear();
        assert!(bucket.is_empty());
    }

    #[test]
    fn test_gradient_synchronizer_workflow() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 1000);
        sync.prepare(10);

        // Add gradients
        let grad = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        sync.add_gradient(0, &grad);

        // Sync
        let pg = ProcessGroup::mock();
        sync.sync_all(&pg);

        // Clear
        sync.clear();
    }

    #[test]
    fn test_world_default_group() {
        let world = World::mock();
        let group = world.default_group();

        assert_eq!(group.rank(), 0);
        assert_eq!(group.world_size(), 1);
    }

    #[test]
    fn test_world_new_subgroup() {
        let world = World::mock();
        let subgroup = world.new_group(vec![0]);

        assert_eq!(subgroup.size(), 1);
        assert!(subgroup.contains(0));
    }

    #[test]
    fn test_ddp_builder_pattern() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();

        let ddp = DDP::new(model, pg)
            .broadcast_buffers(false)
            .gradient_as_bucket_view(false);

        // Linear defaults to training mode, DDP wraps it
        assert!(ddp.is_training());
    }

    #[test]
    fn test_reduce_op_all_variants() {
        let op_sum = ReduceOp::Sum;
        let op_prod = ReduceOp::Product;
        let op_min = ReduceOp::Min;
        let op_max = ReduceOp::Max;
        let op_avg = ReduceOp::Average;

        assert_eq!(op_sum.apply_f32(1.0, 2.0), 3.0);
        assert_eq!(op_prod.apply_f32(2.0, 3.0), 6.0);
        assert_eq!(op_min.apply_f32(2.0, 3.0), 2.0);
        assert_eq!(op_max.apply_f32(2.0, 3.0), 3.0);
        assert_eq!(op_avg.apply_f32(2.0, 4.0), 3.0);
    }
}