//! Distributed training for AxonML — data, model, pipeline, and tensor parallelism.
//!
//! `DDP` (DistributedDataParallel with gradient bucketing), `FSDP` (Fully
//! Sharded Data Parallel — ZeRO-2/ZeRO-3 + HybridShard + CPU offload),
//! `Pipeline` (GPipe/1F1B/Interleaved microbatch scheduling), collective ops
//! (all-reduce with 5 strategies, broadcast, all-gather, reduce-scatter,
//! gather, scatter, reduce, send/recv, barrier), `ProcessGroup` / `World`
//! abstraction, `NcclBackend` (dynamic libcudart/libnccl loading, multi-node
//! init via NcclUniqueId), and `MockBackend` (shared-state in-process
//! simulation for deterministic testing).
//!
//! # File
//! `crates/axonml-distributed/src/lib.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
// Crate-wide lint policy: warn on missing docs and all clippy lints (incl.
// pedantic), then allow the specific lints listed below.
#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// ML/tensor-specific allowances
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]
#![allow(clippy::ptr_as_ptr)]
#![allow(clippy::ptr_cast_constness)]
#![allow(clippy::manual_slice_size_calculation)]
#![allow(clippy::needless_lifetimes)]
91pub mod backend;
92pub mod comm;
93pub mod ddp;
94pub mod fsdp;
95#[cfg(feature = "nccl")]
96pub mod nccl_backend;
97pub mod pipeline;
98pub mod process_group;
99
100// =============================================================================
101// Re-exports
102// =============================================================================
103
104pub use backend::{Backend, MockBackend, ReduceOp};
105pub use comm::{
106    all_gather, all_reduce_max, all_reduce_mean, all_reduce_min, all_reduce_product,
107    all_reduce_sum, barrier, broadcast, broadcast_from, gather_tensor, is_main_process, rank,
108    reduce_scatter_mean, reduce_scatter_sum, scatter_tensor, sync_gradient, sync_gradients,
109    world_size,
110};
111pub use ddp::{DistributedDataParallel, GradSyncStrategy, GradientBucket, GradientSynchronizer};
112pub use fsdp::{
113    CPUOffload, ColumnParallelLinear, FSDPMemoryStats, FullyShardedDataParallel, RowParallelLinear,
114    ShardingStrategy,
115};
116#[cfg(feature = "nccl")]
117pub use nccl_backend::{NcclBackend, NcclError, NcclUniqueId};
118pub use pipeline::{Pipeline, PipelineMemoryStats, PipelineSchedule, PipelineStage};
119pub use process_group::{ProcessGroup, World};
120
121// =============================================================================
122// Prelude
123// =============================================================================
124
125/// Common imports for distributed training.
126pub mod prelude {
127    pub use crate::{
128        // Backend
129        Backend,
130        CPUOffload,
131        // FSDP
132        ColumnParallelLinear,
133        // DDP
134        DistributedDataParallel,
135        FullyShardedDataParallel,
136        GradSyncStrategy,
137        GradientBucket,
138        GradientSynchronizer,
139        MockBackend,
140        // Process groups
141        ProcessGroup,
142        ReduceOp,
143        RowParallelLinear,
144        ShardingStrategy,
145        World,
146        all_gather,
147        all_reduce_max,
148        all_reduce_mean,
149        all_reduce_min,
150        all_reduce_product,
151        // Communication
152        all_reduce_sum,
153        barrier,
154        broadcast,
155        broadcast_from,
156        gather_tensor,
157        is_main_process,
158        rank,
159        reduce_scatter_mean,
160        reduce_scatter_sum,
161        scatter_tensor,
162        sync_gradient,
163        sync_gradients,
164        world_size,
165    };
166
167    pub use axonml_autograd::Variable;
168    pub use axonml_nn::Module;
169    pub use axonml_tensor::Tensor;
170}
171
172// =============================================================================
173// Type Aliases
174// =============================================================================
175
176/// Type alias for `DistributedDataParallel`.
177pub type DDP<M> = DistributedDataParallel<M>;
178
179/// Type alias for `FullyShardedDataParallel`.
180pub type FSDP<M> = FullyShardedDataParallel<M>;
181
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_nn::{Linear, Module};
    use axonml_tensor::Tensor;
    use std::sync::Arc;

    /// End-to-end DDP workflow on a mock single-rank world.
    #[test]
    fn test_full_distributed_workflow() {
        // Initialize world
        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert!(world.is_main());

        // Create model and wrap in DDP
        let model = Linear::new(10, 5);
        let mut ddp = DDP::new(model, world.default_group().clone());

        // Forward pass
        let input = Variable::new(Tensor::from_vec(vec![1.0; 10], &[1, 10]).unwrap(), false);
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 5]);

        // Train mode
        ddp.train();
        assert!(ddp.is_training());

        // Sync parameters
        ddp.sync_parameters();

        // Sync gradients
        ddp.sync_gradients();
    }

    /// Each backend in a mock world reports its own rank and the shared size.
    #[test]
    fn test_multiple_backends() {
        let backends = MockBackend::create_world(4);

        // All backends should have consistent world view
        for (i, backend) in backends.iter().enumerate() {
            assert_eq!(backend.rank(), i);
            assert_eq!(backend.world_size(), 4);
        }
    }

    /// A ProcessGroup built from one backend of a 2-rank world spans both ranks.
    #[test]
    fn test_process_group_creation() {
        let backends = MockBackend::create_world(2);
        let pg = ProcessGroup::new(Arc::new(backends.into_iter().next().unwrap()));

        assert_eq!(pg.size(), 2);
        assert_eq!(pg.ranks().len(), 2);
    }

    /// Collectives on a single-rank mock group are identity operations.
    #[test]
    fn test_communication_functions() {
        let pg = ProcessGroup::mock();

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        // All reduce
        all_reduce_sum(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);

        // Broadcast
        broadcast(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);

        // All gather
        let gathered = all_gather(&tensor, &pg);
        assert_eq!(gathered.shape(), &[1, 3]);

        // Barrier
        barrier(&pg);
    }

    /// Gradients added to a bucket are extractable in order and clearable.
    #[test]
    fn test_gradient_bucket_workflow() {
        let mut bucket = GradientBucket::new(1000);

        // Add gradients
        let grad1 = Tensor::from_vec(vec![0.1, 0.2], &[2]).unwrap();
        let grad2 = Tensor::from_vec(vec![0.3, 0.4, 0.5], &[3]).unwrap();

        bucket.add(&grad1);
        bucket.add(&grad2);

        assert_eq!(bucket.size(), 5);

        // Extract
        let tensors = bucket.extract();
        assert_eq!(tensors.len(), 2);
        assert_eq!(tensors[0].to_vec(), vec![0.1, 0.2]);
        assert_eq!(tensors[1].to_vec(), vec![0.3, 0.4, 0.5]);

        // Clear
        bucket.clear();
        assert!(bucket.is_empty());
    }

    /// Synchronous gradient synchronizer: prepare, add, sync, clear.
    #[test]
    fn test_gradient_synchronizer_workflow() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 1000);
        sync.prepare(10);

        // Add gradients
        let grad = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        sync.add_gradient(0, &grad);

        // Sync
        let pg = ProcessGroup::mock();
        sync.sync_all(&pg);

        // Clear
        sync.clear();
    }

    /// The mock world's default group is rank 0 of a size-1 world.
    #[test]
    fn test_world_default_group() {
        let world = World::mock();
        let group = world.default_group();

        assert_eq!(group.rank(), 0);
        assert_eq!(group.world_size(), 1);
    }

    /// A subgroup over rank 0 contains exactly that rank.
    #[test]
    fn test_world_new_subgroup() {
        let world = World::mock();
        let subgroup = world.new_group(vec![0]);

        assert_eq!(subgroup.size(), 1);
        assert!(subgroup.contains(0));
    }

    /// DDP builder methods chain and preserve training mode.
    #[test]
    fn test_ddp_builder_pattern() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();

        let ddp = DDP::new(model, pg)
            .broadcast_buffers(false)
            .gradient_as_bucket_view(false);

        // Linear defaults to training mode, DDP wraps it
        assert!(ddp.is_training());
    }

    /// Every ReduceOp variant applies its binary operation correctly.
    #[test]
    fn test_reduce_op_all_variants() {
        let op_sum = ReduceOp::Sum;
        let op_prod = ReduceOp::Product;
        let op_min = ReduceOp::Min;
        let op_max = ReduceOp::Max;
        let op_avg = ReduceOp::Average;

        assert_eq!(op_sum.apply_f32(1.0, 2.0), 3.0);
        assert_eq!(op_prod.apply_f32(2.0, 3.0), 6.0);
        assert_eq!(op_min.apply_f32(2.0, 3.0), 2.0);
        assert_eq!(op_max.apply_f32(2.0, 3.0), 3.0);
        assert_eq!(op_avg.apply_f32(2.0, 4.0), 3.0);
    }
}