#![warn(missing_docs)]
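//! Distributed training for AxonML: process groups, collective
//! communication primitives, `DistributedDataParallel` (DDP), fully
//! sharded data parallelism (FSDP), tensor-parallel linear layers, and
//! pipeline parallelism, with a `MockBackend` for single-process testing.
//!
//! # Example
//!
//! A minimal single-rank sketch mirroring the tests in this file; the
//! crate name `axonml_distributed` is assumed here:
//!
//! ```ignore
//! use axonml_distributed::prelude::*;
//! use axonml_nn::Linear;
//!
//! let world = World::mock();
//! let mut ddp = DistributedDataParallel::new(Linear::new(10, 5), world.default_group().clone());
//!
//! let input = Variable::new(Tensor::from_vec(vec![1.0; 10], &[1, 10]).unwrap(), false);
//! let output = ddp.forward(&input);
//! ddp.sync_gradients();
//! ```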
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]

pub mod backend;
pub mod comm;
pub mod ddp;
pub mod fsdp;
pub mod pipeline;
pub mod process_group;

pub use backend::{Backend, MockBackend, ReduceOp};
pub use comm::{
    all_gather, all_reduce_max, all_reduce_mean, all_reduce_min, all_reduce_product,
    all_reduce_sum, barrier, broadcast, broadcast_from, gather_tensor, is_main_process, rank,
    reduce_scatter_mean, reduce_scatter_sum, scatter_tensor, sync_gradient, sync_gradients,
    world_size,
};
pub use ddp::{DistributedDataParallel, GradSyncStrategy, GradientBucket, GradientSynchronizer};
pub use fsdp::{
    ColumnParallelLinear, CPUOffload, FSDPMemoryStats, FullyShardedDataParallel,
    RowParallelLinear, ShardingStrategy,
};
pub use pipeline::{Pipeline, PipelineMemoryStats, PipelineSchedule, PipelineStage};
pub use process_group::{ProcessGroup, World};

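/// Convenience re-exports of the crate's most commonly used items,
/// intended for a single glob import in downstream code.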
pub mod prelude {
    pub use crate::{
        all_gather,
        all_reduce_max,
        all_reduce_mean,
        all_reduce_min,
        all_reduce_product,
        all_reduce_sum,
        barrier,
        broadcast,
        broadcast_from,
        gather_tensor,
        is_main_process,
        rank,
        reduce_scatter_mean,
        reduce_scatter_sum,
        scatter_tensor,
        sync_gradient,
        sync_gradients,
        world_size,
        Backend,
        DistributedDataParallel,
        GradSyncStrategy,
        GradientBucket,
        GradientSynchronizer,
        MockBackend,
        ColumnParallelLinear,
        CPUOffload,
        FullyShardedDataParallel,
        RowParallelLinear,
        ShardingStrategy,
        ProcessGroup,
        ReduceOp,
        World,
    };

    pub use axonml_autograd::Variable;
    pub use axonml_nn::Module;
    pub use axonml_tensor::Tensor;
}
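/// Short alias for [`DistributedDataParallel`].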
pub type DDP<M> = DistributedDataParallel<M>;
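/// Short alias for [`FullyShardedDataParallel`].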
pub type FSDP<M> = FullyShardedDataParallel<M>;

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_nn::{Linear, Module};
    use axonml_tensor::Tensor;
    use std::sync::Arc;

    #[test]
    fn test_full_distributed_workflow() {
        let world = World::mock();
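        // A mock world has exactly one rank, so rank 0 is the main process.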
        assert_eq!(world.rank(), 0);
        assert!(world.is_main());

        let model = Linear::new(10, 5);
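        // Wrap the module in DDP; gradient syncs go through this group.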
        let mut ddp = DDP::new(model, world.default_group().clone());

        let input = Variable::new(Tensor::from_vec(vec![1.0; 10], &[1, 10]).unwrap(), false);
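        // The wrapper forwards straight through the inner module.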
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 5]);

        ddp.train();
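        // train()/eval() toggle the wrapper's training flag.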
        assert!(ddp.is_training());

        ddp.sync_parameters();

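        // Gradient all-reduce; an identity operation on the single-rank mock.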
        ddp.sync_gradients();
    }

    #[test]
    fn test_multiple_backends() {
        let backends = MockBackend::create_world(4);

        for (i, backend) in backends.iter().enumerate() {
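            // Each simulated rank sees its own rank id and the shared world size.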
            assert_eq!(backend.rank(), i);
            assert_eq!(backend.world_size(), 4);
        }
    }

    #[test]
    fn test_process_group_creation() {
        let backends = MockBackend::create_world(2);
        let pg = ProcessGroup::new(Arc::new(backends.into_iter().next().unwrap()));

        assert_eq!(pg.size(), 2);
        assert_eq!(pg.ranks().len(), 2);
    }

    #[test]
    fn test_communication_functions() {
        let pg = ProcessGroup::mock();

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_sum(&mut tensor, &pg);
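        // With a single rank, the sum reduction leaves the values unchanged.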
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);

        broadcast(&mut tensor, &pg);
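        // Broadcasting from the only rank is likewise a no-op.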
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);

        let gathered = all_gather(&tensor, &pg);
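        // all_gather stacks per-rank tensors: [world_size, len] = [1, 3].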
        assert_eq!(gathered.shape(), &[1, 3]);

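        // A barrier with a single participant returns immediately.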
        barrier(&pg);
    }

    #[test]
    fn test_gradient_bucket_workflow() {
        let mut bucket = GradientBucket::new(1000);

        let grad1 = Tensor::from_vec(vec![0.1, 0.2], &[2]).unwrap();
        let grad2 = Tensor::from_vec(vec![0.3, 0.4, 0.5], &[3]).unwrap();

        bucket.add(&grad1);
        bucket.add(&grad2);

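        // size() counts elements across all bucketed gradients (2 + 3).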
        assert_eq!(bucket.size(), 5);

        let tensors = bucket.extract();
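        // extract() hands back the gradients in insertion order.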
        assert_eq!(tensors.len(), 2);
        assert_eq!(tensors[0].to_vec(), vec![0.1, 0.2]);
        assert_eq!(tensors[1].to_vec(), vec![0.3, 0.4, 0.5]);

        bucket.clear();
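        // After clearing, the bucket is ready for the next iteration.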
        assert!(bucket.is_empty());
    }

    #[test]
    fn test_gradient_synchronizer_workflow() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 1000);
        sync.prepare(10);

        let grad = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
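        // Register the gradient under parameter index 0.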
        sync.add_gradient(0, &grad);

        let pg = ProcessGroup::mock();
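        // Run the synchronization pass over the (single-rank) mock group.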
        sync.sync_all(&pg);

        sync.clear();
    }

    #[test]
    fn test_world_default_group() {
        let world = World::mock();
        let group = world.default_group();

        assert_eq!(group.rank(), 0);
        assert_eq!(group.world_size(), 1);
    }

    #[test]
    fn test_world_new_subgroup() {
        let world = World::mock();
        let subgroup = world.new_group(vec![0]);

        assert_eq!(subgroup.size(), 1);
        assert!(subgroup.contains(0));
    }

    #[test]
    fn test_ddp_builder_pattern() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();

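        // Builder-style setters consume and return the wrapper, so they chain.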
        let ddp = DDP::new(model, pg)
            .broadcast_buffers(false)
            .gradient_as_bucket_view(false);

        assert!(ddp.is_training());
    }

    #[test]
    fn test_reduce_op_all_variants() {
        let op_sum = ReduceOp::Sum;
        let op_prod = ReduceOp::Product;
        let op_min = ReduceOp::Min;
        let op_max = ReduceOp::Max;
        let op_avg = ReduceOp::Average;

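        // apply_f32 combines two scalar operands under the chosen reduction.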
        assert_eq!(op_sum.apply_f32(1.0, 2.0), 3.0);
        assert_eq!(op_prod.apply_f32(2.0, 3.0), 6.0);
        assert_eq!(op_min.apply_f32(2.0, 3.0), 2.0);
        assert_eq!(op_max.apply_f32(2.0, 3.0), 3.0);
        assert_eq!(op_avg.apply_f32(2.0, 4.0), 3.0);
    }
}