//! Axonml Distributed - Distributed Training Utilities
//!
//! # File
//! `crates/axonml-distributed/src/lib.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
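//!
//! # Example
//!
//! A minimal single-process sketch using the mock backend, mirroring this
//! crate's own tests (`Linear` comes from `axonml_nn`):
//!
//! ```no_run
//! use axonml_distributed::prelude::*;
//! use axonml_nn::Linear;
//!
//! // Single-process world backed by the mock backend.
//! let world = World::mock();
//! assert!(world.is_main());
//!
//! // Wrap a model in DDP and run a forward pass.
//! let mut ddp = DistributedDataParallel::new(Linear::new(10, 5), world.default_group().clone());
//! ddp.train();
//!
//! let input = Variable::new(Tensor::from_vec(vec![1.0; 10], &[1, 10]).unwrap(), false);
//! let output = ddp.forward(&input);
//! assert_eq!(output.data().shape(), &[1, 5]);
//!
//! // Average gradients across ranks after backward (a no-op in a 1-process world).
//! ddp.sync_gradients();
//! ```
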
#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// ML/tensor-specific allowances
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]
#![allow(clippy::ptr_as_ptr)]
#![allow(clippy::ptr_cast_constness)]
#![allow(clippy::manual_slice_size_calculation)]
#![allow(clippy::needless_lifetimes)]

pub mod backend;
pub mod comm;
pub mod ddp;
pub mod fsdp;
#[cfg(feature = "nccl")]
pub mod nccl_backend;
pub mod pipeline;
pub mod process_group;

// =============================================================================
// Re-exports
// =============================================================================

pub use backend::{Backend, MockBackend, ReduceOp};
pub use comm::{
    all_gather, all_reduce_max, all_reduce_mean, all_reduce_min, all_reduce_product,
    all_reduce_sum, barrier, broadcast, broadcast_from, gather_tensor, is_main_process, rank,
    reduce_scatter_mean, reduce_scatter_sum, scatter_tensor, sync_gradient, sync_gradients,
    world_size,
};
pub use ddp::{DistributedDataParallel, GradSyncStrategy, GradientBucket, GradientSynchronizer};
pub use fsdp::{
    CPUOffload, ColumnParallelLinear, FSDPMemoryStats, FullyShardedDataParallel, RowParallelLinear,
    ShardingStrategy,
};
#[cfg(feature = "nccl")]
pub use nccl_backend::{NcclBackend, NcclError, NcclUniqueId};
pub use pipeline::{Pipeline, PipelineMemoryStats, PipelineSchedule, PipelineStage};
pub use process_group::{ProcessGroup, World};

// =============================================================================
// Prelude
// =============================================================================

/// Common imports for distributed training.
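///
/// A small sketch of the collective-communication helpers on a mock process
/// group, following `test_communication_functions` below (with a single
/// process, each collective is effectively a no-op):
///
/// ```no_run
/// use axonml_distributed::prelude::*;
///
/// let pg = ProcessGroup::mock();
/// let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
///
/// // Sum-reduce across all ranks, then broadcast from rank 0.
/// all_reduce_sum(&mut tensor, &pg);
/// broadcast(&mut tensor, &pg);
///
/// // Gather every rank's tensor into one stacked tensor: [world_size, 3].
/// let gathered = all_gather(&tensor, &pg);
/// assert_eq!(gathered.shape(), &[1, 3]);
///
/// // Wait for all ranks to reach this point.
/// barrier(&pg);
/// ```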
pub mod prelude {
    // Backend
    pub use crate::{Backend, MockBackend, ReduceOp};

    // Process groups
    pub use crate::{ProcessGroup, World};

    // DDP
    pub use crate::{
        DistributedDataParallel, GradSyncStrategy, GradientBucket, GradientSynchronizer,
    };

    // FSDP
    pub use crate::{
        CPUOffload, ColumnParallelLinear, FullyShardedDataParallel, RowParallelLinear,
        ShardingStrategy,
    };

    // Communication
    pub use crate::{
        all_gather, all_reduce_max, all_reduce_mean, all_reduce_min, all_reduce_product,
        all_reduce_sum, barrier, broadcast, broadcast_from, gather_tensor, is_main_process, rank,
        reduce_scatter_mean, reduce_scatter_sum, scatter_tensor, sync_gradient, sync_gradients,
        world_size,
    };

    pub use axonml_autograd::Variable;
    pub use axonml_nn::Module;
    pub use axonml_tensor::Tensor;
}

// =============================================================================
// Type Aliases
// =============================================================================

/// Type alias for `DistributedDataParallel`.
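///
/// A builder-style sketch on a mock process group, following
/// `test_ddp_builder_pattern` below:
///
/// ```no_run
/// use axonml_distributed::{DDP, ProcessGroup};
/// use axonml_nn::{Linear, Module};
///
/// let ddp = DDP::new(Linear::new(10, 5), ProcessGroup::mock())
///     .broadcast_buffers(false)
///     .gradient_as_bucket_view(false);
///
/// // Linear defaults to training mode, and DDP preserves it.
/// assert!(ddp.is_training());
/// ```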
pub type DDP<M> = DistributedDataParallel<M>;

/// Type alias for `FullyShardedDataParallel`.
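///
/// The related configuration types [`ShardingStrategy`] and [`CPUOffload`]
/// are re-exported from [`fsdp`] alongside this type.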
pub type FSDP<M> = FullyShardedDataParallel<M>;

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_nn::{Linear, Module};
    use axonml_tensor::Tensor;
    use std::sync::Arc;

    #[test]
    fn test_full_distributed_workflow() {
        // Initialize world
        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert!(world.is_main());

        // Create model and wrap in DDP
        let model = Linear::new(10, 5);
        let mut ddp = DDP::new(model, world.default_group().clone());

        // Forward pass
        let input = Variable::new(Tensor::from_vec(vec![1.0; 10], &[1, 10]).unwrap(), false);
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 5]);

        // Train mode
        ddp.train();
        assert!(ddp.is_training());

        // Sync parameters
        ddp.sync_parameters();

        // Sync gradients
        ddp.sync_gradients();
    }

    #[test]
    fn test_multiple_backends() {
        let backends = MockBackend::create_world(4);

        // All backends should have consistent world view
        for (i, backend) in backends.iter().enumerate() {
            assert_eq!(backend.rank(), i);
            assert_eq!(backend.world_size(), 4);
        }
    }

    #[test]
    fn test_process_group_creation() {
        let backends = MockBackend::create_world(2);
        let pg = ProcessGroup::new(Arc::new(backends.into_iter().next().unwrap()));

        assert_eq!(pg.size(), 2);
        assert_eq!(pg.ranks().len(), 2);
    }

    #[test]
    fn test_communication_functions() {
        let pg = ProcessGroup::mock();

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        // All reduce
        all_reduce_sum(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);

        // Broadcast
        broadcast(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);

        // All gather
        let gathered = all_gather(&tensor, &pg);
        assert_eq!(gathered.shape(), &[1, 3]);

        // Barrier
        barrier(&pg);
    }

    #[test]
    fn test_gradient_bucket_workflow() {
        let mut bucket = GradientBucket::new(1000);

        // Add gradients
        let grad1 = Tensor::from_vec(vec![0.1, 0.2], &[2]).unwrap();
        let grad2 = Tensor::from_vec(vec![0.3, 0.4, 0.5], &[3]).unwrap();

        bucket.add(&grad1);
        bucket.add(&grad2);

        assert_eq!(bucket.size(), 5);

        // Extract
        let tensors = bucket.extract();
        assert_eq!(tensors.len(), 2);
        assert_eq!(tensors[0].to_vec(), vec![0.1, 0.2]);
        assert_eq!(tensors[1].to_vec(), vec![0.3, 0.4, 0.5]);

        // Clear
        bucket.clear();
        assert!(bucket.is_empty());
    }

    #[test]
    fn test_gradient_synchronizer_workflow() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 1000);
        sync.prepare(10);

        // Add gradients
        let grad = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        sync.add_gradient(0, &grad);

        // Sync
        let pg = ProcessGroup::mock();
        sync.sync_all(&pg);

        // Clear
        sync.clear();
    }

    #[test]
    fn test_world_default_group() {
        let world = World::mock();
        let group = world.default_group();

        assert_eq!(group.rank(), 0);
        assert_eq!(group.world_size(), 1);
    }

    #[test]
    fn test_world_new_subgroup() {
        let world = World::mock();
        let subgroup = world.new_group(vec![0]);

        assert_eq!(subgroup.size(), 1);
        assert!(subgroup.contains(0));
    }

    #[test]
    fn test_ddp_builder_pattern() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();

        let ddp = DDP::new(model, pg)
            .broadcast_buffers(false)
            .gradient_as_bucket_view(false);

        // Linear defaults to training mode, DDP wraps it
        assert!(ddp.is_training());
    }

    #[test]
    fn test_reduce_op_all_variants() {
        let op_sum = ReduceOp::Sum;
        let op_prod = ReduceOp::Product;
        let op_min = ReduceOp::Min;
        let op_max = ReduceOp::Max;
        let op_avg = ReduceOp::Average;

        assert_eq!(op_sum.apply_f32(1.0, 2.0), 3.0);
        assert_eq!(op_prod.apply_f32(2.0, 3.0), 6.0);
        assert_eq!(op_min.apply_f32(2.0, 3.0), 2.0);
        assert_eq!(op_max.apply_f32(2.0, 3.0), 3.0);
        assert_eq!(op_avg.apply_f32(2.0, 4.0), 3.0);
    }
}