//! Distributed training support: communication backends, collective
//! operations, process groups, and a data-parallel (`DDP`) model wrapper.
//!
//! The flat re-exports below let callers write
//! `use axonml_distributed::all_reduce_sum;` (or pull everything in via
//! [`prelude`]) without naming the submodules.
#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// Blanket allowances relaxing pedantic/style lints for this crate.
// NOTE(review): this list is long and unexplained — consider pruning it or
// annotating why each allowance is required, per house lint policy.
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]
89
/// Communication backends: the [`Backend`] trait, the in-process
/// [`MockBackend`], and the [`ReduceOp`] reduction selector.
pub mod backend;
/// Collective communication operations: `all_reduce_*`, `broadcast*`,
/// gather/scatter, `barrier`, and gradient-sync helpers.
pub mod comm;
/// Distributed data-parallel training: [`DistributedDataParallel`] plus
/// gradient bucketing/synchronization machinery.
pub mod ddp;
/// Process-group abstractions: [`ProcessGroup`] and the top-level [`World`].
pub mod process_group;

// Flat re-exports so callers can use the crate's API without naming the
// submodule (e.g. `axonml_distributed::all_reduce_sum`).
pub use backend::{Backend, MockBackend, ReduceOp};
pub use comm::{
    all_gather, all_reduce_max, all_reduce_mean, all_reduce_min, all_reduce_product,
    all_reduce_sum, barrier, broadcast, broadcast_from, gather_tensor, is_main_process, rank,
    reduce_scatter_mean, reduce_scatter_sum, scatter_tensor, sync_gradient, sync_gradients,
    world_size,
};
pub use ddp::{DistributedDataParallel, GradSyncStrategy, GradientBucket, GradientSynchronizer};
pub use process_group::{ProcessGroup, World};
108
109pub mod prelude {
115 pub use crate::{
116 all_gather,
117 all_reduce_max,
118 all_reduce_mean,
119 all_reduce_min,
120 all_reduce_product,
121 all_reduce_sum,
123 barrier,
124 broadcast,
125 broadcast_from,
126 gather_tensor,
127 is_main_process,
128 rank,
129 reduce_scatter_mean,
130 reduce_scatter_sum,
131 scatter_tensor,
132 sync_gradient,
133 sync_gradients,
134 world_size,
135 Backend,
137 DistributedDataParallel,
139 GradSyncStrategy,
140 GradientBucket,
141 GradientSynchronizer,
142 MockBackend,
143 ProcessGroup,
145 ReduceOp,
146 World,
147 };
148
149 pub use axonml_autograd::Variable;
150 pub use axonml_nn::Module;
151 pub use axonml_tensor::Tensor;
152}
153
/// Short alias for [`DistributedDataParallel`], mirroring the common
/// `DDP` abbreviation.
pub type DDP<M> = DistributedDataParallel<M>;
160
// Integration-style tests exercising the public API end to end against the
// single-process mock backend.
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_nn::{Linear, Module};
    use axonml_tensor::Tensor;
    use std::sync::Arc;

    // End-to-end: mock world -> DDP-wrapped Linear -> forward -> sync.
    #[test]
    fn test_full_distributed_workflow() {
        // Mock world is single-process: rank 0 is the (only) main process.
        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert!(world.is_main());

        let model = Linear::new(10, 5);
        let mut ddp = DDP::new(model, world.default_group().clone());

        // Batch of 1, 10 input features; Linear(10, 5) -> [1, 5] output.
        let input = Variable::new(Tensor::from_vec(vec![1.0; 10], &[1, 10]).unwrap(), false);
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 5]);

        ddp.train();
        assert!(ddp.is_training());

        // With a single-rank mock group these should be no-ops that return
        // without error — TODO confirm against the mock backend impl.
        ddp.sync_parameters();

        ddp.sync_gradients();
    }

    // create_world(n) yields one backend per rank, all agreeing on n.
    #[test]
    fn test_multiple_backends() {
        let backends = MockBackend::create_world(4);

        for (i, backend) in backends.iter().enumerate() {
            assert_eq!(backend.rank(), i);
            assert_eq!(backend.world_size(), 4);
        }
    }

    // A ProcessGroup built from one rank's backend still reports the full
    // world size and rank list.
    #[test]
    fn test_process_group_creation() {
        let backends = MockBackend::create_world(2);
        let pg = ProcessGroup::new(Arc::new(backends.into_iter().next().unwrap()));

        assert_eq!(pg.size(), 2);
        assert_eq!(pg.ranks().len(), 2);
    }

    // Collective ops on a single-rank mock group: reductions and broadcast
    // leave the tensor unchanged (reducing over one rank is the identity).
    #[test]
    fn test_communication_functions() {
        let pg = ProcessGroup::mock();

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_sum(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);

        broadcast(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);

        // all_gather stacks per-rank tensors along a new leading axis:
        // one rank of shape [3] -> [1, 3].
        let gathered = all_gather(&tensor, &pg);
        assert_eq!(gathered.shape(), &[1, 3]);

        // Barrier on a single rank should return immediately.
        barrier(&pg);
    }

    // Bucket accumulates gradients, reports total element count, and
    // extracts tensors in insertion order.
    #[test]
    fn test_gradient_bucket_workflow() {
        // 1000 is the bucket capacity — presumably in elements or bytes;
        // verify against GradientBucket::new.
        let mut bucket = GradientBucket::new(1000);

        let grad1 = Tensor::from_vec(vec![0.1, 0.2], &[2]).unwrap();
        let grad2 = Tensor::from_vec(vec![0.3, 0.4, 0.5], &[3]).unwrap();

        bucket.add(&grad1);
        bucket.add(&grad2);

        // size() counts total elements across added gradients (2 + 3).
        assert_eq!(bucket.size(), 5);

        let tensors = bucket.extract();
        assert_eq!(tensors.len(), 2);
        assert_eq!(tensors[0].to_vec(), vec![0.1, 0.2]);
        assert_eq!(tensors[1].to_vec(), vec![0.3, 0.4, 0.5]);

        bucket.clear();
        assert!(bucket.is_empty());
    }

    // Synchronizer smoke test: prepare, add, sync, clear — asserts only
    // that the sequence completes without panicking.
    #[test]
    fn test_gradient_synchronizer_workflow() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 1000);
        sync.prepare(10);

        let grad = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        sync.add_gradient(0, &grad);

        let pg = ProcessGroup::mock();
        sync.sync_all(&pg);

        sync.clear();
    }

    // The mock world's default group is a single-rank group at rank 0.
    #[test]
    fn test_world_default_group() {
        let world = World::mock();
        let group = world.default_group();

        assert_eq!(group.rank(), 0);
        assert_eq!(group.world_size(), 1);
    }

    // Subgroup creation from an explicit rank list.
    #[test]
    fn test_world_new_subgroup() {
        let world = World::mock();
        let subgroup = world.new_group(vec![0]);

        assert_eq!(subgroup.size(), 1);
        assert!(subgroup.contains(0));
    }

    // Builder-style configuration chains on DDP and leaves it in training
    // mode by default.
    #[test]
    fn test_ddp_builder_pattern() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();

        let ddp = DDP::new(model, pg)
            .broadcast_buffers(false)
            .gradient_as_bucket_view(false);

        assert!(ddp.is_training());
    }

    // apply_f32 semantics for every ReduceOp variant.
    #[test]
    fn test_reduce_op_all_variants() {
        let op_sum = ReduceOp::Sum;
        let op_prod = ReduceOp::Product;
        let op_min = ReduceOp::Min;
        let op_max = ReduceOp::Max;
        let op_avg = ReduceOp::Average;

        assert_eq!(op_sum.apply_f32(1.0, 2.0), 3.0);
        assert_eq!(op_prod.apply_f32(2.0, 3.0), 6.0);
        assert_eq!(op_min.apply_f32(2.0, 3.0), 2.0);
        assert_eq!(op_max.apply_f32(2.0, 3.0), 3.0);
        assert_eq!(op_avg.apply_f32(2.0, 4.0), 3.0);
    }
}