//! Distributed training crate root.
//!
//! This file wires the crate together:
//!
//! - it declares the `backend`, `comm`, `ddp`, `fsdp`, `pipeline` and `process_group`
//!   modules (plus `nccl_backend` when the `nccl` feature is enabled),
//! - it re-exports the main items of those modules at the crate root,
//! - it provides a [`prelude`] module for glob imports, and
//! - it defines the short [`DDP`] and [`FSDP`] type aliases.

// Lint policy: every public item must be documented, and the complete `clippy::all` and
// `clippy::pedantic` sets are reported as warnings.
#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// The lints below are switched off for the whole crate. They must stay *after* the `warn`
// attributes above so that they override the group-level settings.
//
// NOTE(review): most of these are pedantic/style lints, but `clippy::erasing_op` and
// `clippy::not_unsafe_ptr_arg_deref` belong to Clippy's correctness group. Consider scoping
// those two to the specific items that need them instead of allowing them crate-wide.
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]
#![allow(clippy::ptr_as_ptr)]
#![allow(clippy::ptr_cast_constness)]
#![allow(clippy::manual_slice_size_calculation)]
#![allow(clippy::needless_lifetimes)]
90
91pub mod backend;
92pub mod comm;
93pub mod ddp;
94pub mod fsdp;
95#[cfg(feature = "nccl")]
96pub mod nccl_backend;
97pub mod pipeline;
98pub mod process_group;
99
100pub use backend::{Backend, MockBackend, ReduceOp};
105pub use comm::{
106 all_gather, all_reduce_max, all_reduce_mean, all_reduce_min, all_reduce_product,
107 all_reduce_sum, barrier, broadcast, broadcast_from, gather_tensor, is_main_process, rank,
108 reduce_scatter_mean, reduce_scatter_sum, scatter_tensor, sync_gradient, sync_gradients,
109 world_size,
110};
111pub use ddp::{DistributedDataParallel, GradSyncStrategy, GradientBucket, GradientSynchronizer};
112pub use fsdp::{
113 CPUOffload, ColumnParallelLinear, FSDPMemoryStats, FullyShardedDataParallel, RowParallelLinear,
114 ShardingStrategy,
115};
116#[cfg(feature = "nccl")]
117pub use nccl_backend::{NcclBackend, NcclError, NcclUniqueId};
118pub use pipeline::{Pipeline, PipelineMemoryStats, PipelineSchedule, PipelineStage};
119pub use process_group::{ProcessGroup, World};
120
121pub mod prelude {
127 pub use crate::{
128 Backend,
130 CPUOffload,
131 ColumnParallelLinear,
133 DistributedDataParallel,
135 FullyShardedDataParallel,
136 GradSyncStrategy,
137 GradientBucket,
138 GradientSynchronizer,
139 MockBackend,
140 ProcessGroup,
142 ReduceOp,
143 RowParallelLinear,
144 ShardingStrategy,
145 World,
146 all_gather,
147 all_reduce_max,
148 all_reduce_mean,
149 all_reduce_min,
150 all_reduce_product,
151 all_reduce_sum,
153 barrier,
154 broadcast,
155 broadcast_from,
156 gather_tensor,
157 is_main_process,
158 rank,
159 reduce_scatter_mean,
160 reduce_scatter_sum,
161 scatter_tensor,
162 sync_gradient,
163 sync_gradients,
164 world_size,
165 };
166
167 pub use axonml_autograd::Variable;
168 pub use axonml_nn::Module;
169 pub use axonml_tensor::Tensor;
170}
171
172pub type DDP<M> = DistributedDataParallel<M>;
178
179pub type FSDP<M> = FullyShardedDataParallel<M>;
181
#[cfg(test)]
mod tests {
    //! Integration-style tests that exercise the crate-root API against mock
    //! worlds, process groups and backends.

    use super::*;
    use axonml_autograd::Variable;
    use axonml_nn::{Linear, Module};
    use axonml_tensor::Tensor;
    use std::sync::Arc;

    /// Wraps a linear layer in `DDP` on a mock world and runs forward, train and sync calls.
    #[test]
    fn test_full_distributed_workflow() {
        // The mock world is expected to report rank 0 and to be the main process.
        let mock_world = World::mock();
        assert_eq!(mock_world.rank(), 0);
        assert!(mock_world.is_main());

        // Wrap a 10 -> 5 linear layer using the world's default group.
        let layer = Linear::new(10, 5);
        let group = mock_world.default_group().clone();
        let mut wrapped = DDP::new(layer, group);

        // A single row of ten ones goes through the wrapped model.
        let ones = Tensor::from_vec(vec![1.0; 10], &[1, 10]).unwrap();
        let input = Variable::new(ones, false);
        let output = wrapped.forward(&input);
        assert_eq!(output.data().shape(), &[1, 5]);

        // Switching to training mode must be reflected by `is_training`.
        wrapped.train();
        assert!(wrapped.is_training());

        // Both synchronisation entry points are called on the mock group.
        wrapped.sync_parameters();
        wrapped.sync_gradients();
    }

    /// Every backend of a four-member mock world reports its own index and the world size.
    #[test]
    fn test_multiple_backends() {
        let members = MockBackend::create_world(4);

        for (expected_rank, member) in members.iter().enumerate() {
            assert_eq!(member.rank(), expected_rank);
            assert_eq!(member.world_size(), 4);
        }
    }

    /// A process group built from one backend of a two-member world has two ranks.
    #[test]
    fn test_process_group_creation() {
        let members = MockBackend::create_world(2);
        let first_member = members.into_iter().next().unwrap();
        let group = ProcessGroup::new(Arc::new(first_member));

        assert_eq!(group.size(), 2);
        assert_eq!(group.ranks().len(), 2);
    }

    /// Collectives on a mock group leave a tensor's values unchanged.
    #[test]
    fn test_communication_functions() {
        let group = ProcessGroup::mock();

        let expected = vec![1.0, 2.0, 3.0];
        let mut values = Tensor::from_vec(expected.clone(), &[3]).unwrap();

        // Sum-reduce on the mock group: values stay as they were.
        all_reduce_sum(&mut values, &group);
        assert_eq!(values.to_vec(), expected);

        // Broadcast on the mock group: values stay as they were.
        broadcast(&mut values, &group);
        assert_eq!(values.to_vec(), expected);

        // Gathering adds a leading dimension of size one.
        let gathered = all_gather(&values, &group);
        assert_eq!(gathered.shape(), &[1, 3]);

        barrier(&group);
    }

    /// Gradients added to a bucket come back out in order, and `clear` empties it.
    #[test]
    fn test_gradient_bucket_workflow() {
        let mut bucket = GradientBucket::new(1000);

        let first_values = vec![0.1, 0.2];
        let second_values = vec![0.3, 0.4, 0.5];
        let first_grad = Tensor::from_vec(first_values.clone(), &[2]).unwrap();
        let second_grad = Tensor::from_vec(second_values.clone(), &[3]).unwrap();

        bucket.add(&first_grad);
        bucket.add(&second_grad);

        // Two plus three elements.
        assert_eq!(bucket.size(), 5);

        let extracted = bucket.extract();
        assert_eq!(extracted.len(), 2);
        assert_eq!(extracted[0].to_vec(), first_values);
        assert_eq!(extracted[1].to_vec(), second_values);

        bucket.clear();
        assert!(bucket.is_empty());
    }

    /// Smoke test: prepare, add one gradient, sync on a mock group, then clear.
    #[test]
    fn test_gradient_synchronizer_workflow() {
        let mut synchronizer = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 1000);
        synchronizer.prepare(10);

        let gradient = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        synchronizer.add_gradient(0, &gradient);

        let group = ProcessGroup::mock();
        synchronizer.sync_all(&group);

        synchronizer.clear();
    }

    /// The default group of a mock world is rank 0 in a world of size 1.
    #[test]
    fn test_world_default_group() {
        let mock_world = World::mock();
        let default_group = mock_world.default_group();

        assert_eq!(default_group.rank(), 0);
        assert_eq!(default_group.world_size(), 1);
    }

    /// A subgroup created from rank 0 has size 1 and contains rank 0.
    #[test]
    fn test_world_new_subgroup() {
        let mock_world = World::mock();
        let only_rank_zero = mock_world.new_group(vec![0]);

        assert_eq!(only_rank_zero.size(), 1);
        assert!(only_rank_zero.contains(0));
    }

    /// The builder-style setters can be applied and the wrapper still reports training mode.
    #[test]
    fn test_ddp_builder_pattern() {
        let layer = Linear::new(10, 5);
        let group = ProcessGroup::mock();

        let wrapped = DDP::new(layer, group);
        let wrapped = wrapped.broadcast_buffers(false);
        let wrapped = wrapped.gradient_as_bucket_view(false);

        assert!(wrapped.is_training());
    }

    /// Each `ReduceOp` variant combines two `f32` values as expected.
    #[test]
    fn test_reduce_op_all_variants() {
        // (operator, left operand, right operand, expected result)
        let cases: [(ReduceOp, f32, f32, f32); 5] = [
            (ReduceOp::Sum, 1.0, 2.0, 3.0),
            (ReduceOp::Product, 2.0, 3.0, 6.0),
            (ReduceOp::Min, 2.0, 3.0, 2.0),
            (ReduceOp::Max, 2.0, 3.0, 3.0),
            (ReduceOp::Average, 2.0, 4.0, 3.0),
        ];

        for (op, lhs, rhs, expected) in cases {
            assert_eq!(op.apply_f32(lhs, rhs), expected);
        }
    }
}