axonml_distributed/
lib.rs

//! Axonml Distributed - Distributed Training Utilities
//!
//! This crate provides distributed training functionality for the Axonml ML framework:
//!
//! - **Backend**: Communication backend abstraction with a mock implementation for testing
//! - **`ProcessGroup`**: Process group abstraction for managing distributed processes
//! - **DDP**: `DistributedDataParallel` wrapper for synchronizing gradients
//! - **Communication**: High-level communication utilities (all-reduce, broadcast, etc.)
//!
//! # Example
//!
//! ```ignore
//! use axonml_distributed::prelude::*;
//! use axonml_nn::Linear;
//!
//! // Create a mock process group for testing
//! let world = World::mock();
//!
//! // Wrap a model in DDP
//! let model = Linear::new(10, 5);
//! let ddp_model = DistributedDataParallel::new(model, world.default_group().clone());
//!
//! // Synchronize gradients after backward pass
//! ddp_model.sync_gradients();
//! ```
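//!
//! # Communication Example
//!
//! The collectives can also be called directly on a process group. A minimal
//! sketch using the mock backend (mirroring the unit tests at the bottom of
//! this file):
//!
//! ```ignore
//! use axonml_distributed::prelude::*;
//!
//! let pg = ProcessGroup::mock();
//! let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
//!
//! // Sum-reduce the tensor across all ranks (a no-op on the single-rank mock group).
//! all_reduce_sum(&mut tensor, &pg);
//!
//! // Replicate rank 0's tensor to every rank.
//! broadcast(&mut tensor, &pg);
//!
//! // Block until every rank reaches this point.
//! barrier(&pg);
//! ```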
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// ML/tensor-specific allowances
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]

pub mod backend;
pub mod comm;
pub mod ddp;
pub mod process_group;

// =============================================================================
// Re-exports
// =============================================================================

pub use backend::{Backend, MockBackend, ReduceOp};
pub use comm::{
    all_gather, all_reduce_max, all_reduce_mean, all_reduce_min, all_reduce_product,
    all_reduce_sum, barrier, broadcast, broadcast_from, gather_tensor, is_main_process, rank,
    reduce_scatter_mean, reduce_scatter_sum, scatter_tensor, sync_gradient, sync_gradients,
    world_size,
};
pub use ddp::{DistributedDataParallel, GradSyncStrategy, GradientBucket, GradientSynchronizer};
pub use process_group::{ProcessGroup, World};

// =============================================================================
// Prelude
// =============================================================================

/// Common imports for distributed training.
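///
/// Illustrative usage, assuming a single-process mock world (the same setup
/// the tests below use):
///
/// ```ignore
/// use axonml_distributed::prelude::*;
///
/// let world = World::mock();
/// assert!(world.is_main());
/// ```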
pub mod prelude {
    // Communication
    pub use crate::{
        all_gather, all_reduce_max, all_reduce_mean, all_reduce_min, all_reduce_product,
        all_reduce_sum, barrier, broadcast, broadcast_from, gather_tensor, is_main_process, rank,
        reduce_scatter_mean, reduce_scatter_sum, scatter_tensor, sync_gradient, sync_gradients,
        world_size,
    };
    // Backend
    pub use crate::{Backend, MockBackend, ReduceOp};
    // DDP
    pub use crate::{DistributedDataParallel, GradSyncStrategy, GradientBucket, GradientSynchronizer};
    // Process groups
    pub use crate::{ProcessGroup, World};

    pub use axonml_autograd::Variable;
    pub use axonml_nn::Module;
    pub use axonml_tensor::Tensor;
}

// =============================================================================
// Type Aliases
// =============================================================================

/// Type alias for `DistributedDataParallel`.
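///
/// A shorthand sketch, equivalent to the crate-level example above:
///
/// ```ignore
/// use axonml_distributed::{DDP, ProcessGroup};
/// use axonml_nn::Linear;
///
/// let ddp = DDP::new(Linear::new(10, 5), ProcessGroup::mock());
/// ```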
pub type DDP<M> = DistributedDataParallel<M>;

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_nn::{Linear, Module};
    use axonml_tensor::Tensor;
    use std::sync::Arc;

    #[test]
    fn test_full_distributed_workflow() {
        // Initialize world
        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert!(world.is_main());

        // Create model and wrap in DDP
        let model = Linear::new(10, 5);
        let mut ddp = DDP::new(model, world.default_group().clone());

        // Forward pass
        let input = Variable::new(Tensor::from_vec(vec![1.0; 10], &[1, 10]).unwrap(), false);
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 5]);

        // Train mode
        ddp.train();
        assert!(ddp.is_training());

        // Sync parameters
        ddp.sync_parameters();

        // Sync gradients
        ddp.sync_gradients();
    }

    #[test]
    fn test_multiple_backends() {
        let backends = MockBackend::create_world(4);

        // All backends should have consistent world view
        for (i, backend) in backends.iter().enumerate() {
            assert_eq!(backend.rank(), i);
            assert_eq!(backend.world_size(), 4);
        }
    }

    #[test]
    fn test_process_group_creation() {
        let backends = MockBackend::create_world(2);
        let pg = ProcessGroup::new(Arc::new(backends.into_iter().next().unwrap()));

        assert_eq!(pg.size(), 2);
        assert_eq!(pg.ranks().len(), 2);
    }

    #[test]
    fn test_communication_functions() {
        let pg = ProcessGroup::mock();

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        // All reduce
        all_reduce_sum(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);

        // Broadcast
        broadcast(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);

        // All gather
        let gathered = all_gather(&tensor, &pg);
        assert_eq!(gathered.shape(), &[1, 3]);

        // Barrier
        barrier(&pg);
    }

    #[test]
    fn test_gradient_bucket_workflow() {
        let mut bucket = GradientBucket::new(1000);

        // Add gradients
        let grad1 = Tensor::from_vec(vec![0.1, 0.2], &[2]).unwrap();
        let grad2 = Tensor::from_vec(vec![0.3, 0.4, 0.5], &[3]).unwrap();

        bucket.add(&grad1);
        bucket.add(&grad2);

        assert_eq!(bucket.size(), 5);

        // Extract
        let tensors = bucket.extract();
        assert_eq!(tensors.len(), 2);
        assert_eq!(tensors[0].to_vec(), vec![0.1, 0.2]);
        assert_eq!(tensors[1].to_vec(), vec![0.3, 0.4, 0.5]);

        // Clear
        bucket.clear();
        assert!(bucket.is_empty());
    }

    #[test]
    fn test_gradient_synchronizer_workflow() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 1000);
        sync.prepare(10);

        // Add gradients
        let grad = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        sync.add_gradient(0, &grad);

        // Sync
        let pg = ProcessGroup::mock();
        sync.sync_all(&pg);

        // Clear
        sync.clear();
    }

    #[test]
    fn test_world_default_group() {
        let world = World::mock();
        let group = world.default_group();

        assert_eq!(group.rank(), 0);
        assert_eq!(group.world_size(), 1);
    }

    #[test]
    fn test_world_new_subgroup() {
        let world = World::mock();
        let subgroup = world.new_group(vec![0]);

        assert_eq!(subgroup.size(), 1);
        assert!(subgroup.contains(0));
    }

    #[test]
    fn test_ddp_builder_pattern() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();

        let ddp = DDP::new(model, pg)
            .broadcast_buffers(false)
            .gradient_as_bucket_view(false);

        // Linear defaults to training mode, DDP wraps it
        assert!(ddp.is_training());
    }

    #[test]
    fn test_reduce_op_all_variants() {
        let op_sum = ReduceOp::Sum;
        let op_prod = ReduceOp::Product;
        let op_min = ReduceOp::Min;
        let op_max = ReduceOp::Max;
        let op_avg = ReduceOp::Average;

        assert_eq!(op_sum.apply_f32(1.0, 2.0), 3.0);
        assert_eq!(op_prod.apply_f32(2.0, 3.0), 6.0);
        assert_eq!(op_min.apply_f32(2.0, 3.0), 2.0);
        assert_eq!(op_max.apply_f32(2.0, 3.0), 3.0);
        assert_eq!(op_avg.apply_f32(2.0, 4.0), 3.0);
    }
}
329}