1#![warn(missing_docs)]
18#![warn(clippy::all)]
19#![warn(clippy::pedantic)]
20#![allow(clippy::cast_possible_truncation)]
22#![allow(clippy::cast_sign_loss)]
23#![allow(clippy::cast_precision_loss)]
24#![allow(clippy::cast_possible_wrap)]
25#![allow(clippy::missing_errors_doc)]
26#![allow(clippy::missing_panics_doc)]
27#![allow(clippy::must_use_candidate)]
28#![allow(clippy::module_name_repetitions)]
29#![allow(clippy::similar_names)]
30#![allow(clippy::many_single_char_names)]
31#![allow(clippy::too_many_arguments)]
32#![allow(clippy::doc_markdown)]
33#![allow(clippy::cast_lossless)]
34#![allow(clippy::needless_pass_by_value)]
35#![allow(clippy::redundant_closure_for_method_calls)]
36#![allow(clippy::uninlined_format_args)]
37#![allow(clippy::ptr_arg)]
38#![allow(clippy::return_self_not_must_use)]
39#![allow(clippy::not_unsafe_ptr_arg_deref)]
40#![allow(clippy::items_after_statements)]
41#![allow(clippy::unreadable_literal)]
42#![allow(clippy::if_same_then_else)]
43#![allow(clippy::needless_range_loop)]
44#![allow(clippy::trivially_copy_pass_by_ref)]
45#![allow(clippy::unnecessary_wraps)]
46#![allow(clippy::match_same_arms)]
47#![allow(clippy::unused_self)]
48#![allow(clippy::too_many_lines)]
49#![allow(clippy::single_match_else)]
50#![allow(clippy::fn_params_excessive_bools)]
51#![allow(clippy::struct_excessive_bools)]
52#![allow(clippy::format_push_string)]
53#![allow(clippy::erasing_op)]
54#![allow(clippy::type_repetition_in_bounds)]
55#![allow(clippy::iter_without_into_iter)]
56#![allow(clippy::should_implement_trait)]
57#![allow(clippy::use_debug)]
58#![allow(clippy::case_sensitive_file_extension_comparisons)]
59#![allow(clippy::large_enum_variant)]
60#![allow(clippy::panic)]
61#![allow(clippy::struct_field_names)]
62#![allow(clippy::missing_fields_in_debug)]
63#![allow(clippy::upper_case_acronyms)]
64#![allow(clippy::assigning_clones)]
65#![allow(clippy::option_if_let_else)]
66#![allow(clippy::manual_let_else)]
67#![allow(clippy::explicit_iter_loop)]
68#![allow(clippy::default_trait_access)]
69#![allow(clippy::only_used_in_recursion)]
70#![allow(clippy::manual_clamp)]
71#![allow(clippy::ref_option)]
72#![allow(clippy::multiple_bound_locations)]
73#![allow(clippy::comparison_chain)]
74#![allow(clippy::manual_assert)]
75#![allow(clippy::unnecessary_debug_formatting)]
76#![allow(clippy::ptr_as_ptr)]
77#![allow(clippy::ptr_cast_constness)]
78#![allow(clippy::manual_slice_size_calculation)]
79#![allow(clippy::needless_lifetimes)]
80
81pub mod backend;
82pub mod comm;
83pub mod ddp;
84pub mod fsdp;
85#[cfg(feature = "nccl")]
86pub mod nccl_backend;
87pub mod pipeline;
88pub mod process_group;
89
90pub use backend::{Backend, MockBackend, ReduceOp};
95pub use comm::{
96 all_gather, all_reduce_max, all_reduce_mean, all_reduce_min, all_reduce_product,
97 all_reduce_sum, barrier, broadcast, broadcast_from, gather_tensor, is_main_process, rank,
98 reduce_scatter_mean, reduce_scatter_sum, scatter_tensor, sync_gradient, sync_gradients,
99 world_size,
100};
101pub use ddp::{DistributedDataParallel, GradSyncStrategy, GradientBucket, GradientSynchronizer};
102pub use fsdp::{
103 CPUOffload, ColumnParallelLinear, FSDPMemoryStats, FullyShardedDataParallel, RowParallelLinear,
104 ShardingStrategy,
105};
106#[cfg(feature = "nccl")]
107pub use nccl_backend::{NcclBackend, NcclError, NcclUniqueId};
108pub use pipeline::{Pipeline, PipelineMemoryStats, PipelineSchedule, PipelineStage};
109pub use process_group::{ProcessGroup, World};
110
/// Convenience prelude.
///
/// `use …::prelude::*;` brings the most commonly used distributed-training
/// types and free-function collectives into scope in one line, plus the
/// core `Variable`, `Module`, and `Tensor` types from the `axonml_*` crates.
pub mod prelude {
    // Everything re-exported at this crate's root (backends, DDP/FSDP
    // wrappers, process groups, and the collective-communication functions).
    pub use crate::{
        Backend,
        CPUOffload,
        ColumnParallelLinear,
        DistributedDataParallel,
        FullyShardedDataParallel,
        GradSyncStrategy,
        GradientBucket,
        GradientSynchronizer,
        MockBackend,
        ProcessGroup,
        ReduceOp,
        RowParallelLinear,
        ShardingStrategy,
        World,
        all_gather,
        all_reduce_max,
        all_reduce_mean,
        all_reduce_min,
        all_reduce_product,
        all_reduce_sum,
        barrier,
        broadcast,
        broadcast_from,
        gather_tensor,
        is_main_process,
        rank,
        reduce_scatter_mean,
        reduce_scatter_sum,
        scatter_tensor,
        sync_gradient,
        sync_gradients,
        world_size,
    };

    // Companion types users almost always need alongside the above.
    pub use axonml_autograd::Variable;
    pub use axonml_nn::Module;
    pub use axonml_tensor::Tensor;
}
161
/// Short alias for [`DistributedDataParallel`], mirroring the common
/// PyTorch-style `DDP` naming.
pub type DDP<M> = DistributedDataParallel<M>;
168
/// Short alias for [`FullyShardedDataParallel`], mirroring the common
/// PyTorch-style `FSDP` naming.
pub type FSDP<M> = FullyShardedDataParallel<M>;
171
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_nn::{Linear, Module};
    use axonml_tensor::Tensor;
    use std::sync::Arc;

    /// End-to-end smoke test: mock world → DDP wrapper → forward → sync.
    #[test]
    fn test_full_distributed_workflow() {
        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert!(world.is_main());

        let mut ddp = DDP::new(Linear::new(10, 5), world.default_group().clone());

        let x = Variable::new(Tensor::from_vec(vec![1.0; 10], &[1, 10]).unwrap(), false);
        let y = ddp.forward(&x);
        assert_eq!(y.data().shape(), &[1, 5]);

        ddp.train();
        assert!(ddp.is_training());

        // Exercise both sync paths; on the single-rank mock these are
        // expected to complete without effect.
        ddp.sync_parameters();
        ddp.sync_gradients();
    }

    /// Each backend in a mock world of 4 reports its own rank and the
    /// shared world size.
    #[test]
    fn test_multiple_backends() {
        let ring = MockBackend::create_world(4);
        for (expected_rank, backend) in ring.iter().enumerate() {
            assert_eq!(backend.rank(), expected_rank);
            assert_eq!(backend.world_size(), 4);
        }
    }

    /// A process group built from one backend of a 2-rank mock world
    /// still reflects the full world in its size and rank list.
    #[test]
    fn test_process_group_creation() {
        let first = MockBackend::create_world(2).into_iter().next().unwrap();
        let pg = ProcessGroup::new(Arc::new(first));

        assert_eq!(pg.size(), 2);
        assert_eq!(pg.ranks().len(), 2);
    }

    /// Collectives on a single-rank mock group leave values unchanged.
    #[test]
    fn test_communication_functions() {
        let pg = ProcessGroup::mock();
        let mut t = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_sum(&mut t, &pg);
        assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0]);

        broadcast(&mut t, &pg);
        assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0]);

        // Gathering from a world of one stacks a single copy.
        let stacked = all_gather(&t, &pg);
        assert_eq!(stacked.shape(), &[1, 3]);

        barrier(&pg);
    }

    /// Buckets accumulate gradients, report total element count, and
    /// hand the tensors back in insertion order.
    #[test]
    fn test_gradient_bucket_workflow() {
        let mut bucket = GradientBucket::new(1000);

        bucket.add(&Tensor::from_vec(vec![0.1, 0.2], &[2]).unwrap());
        bucket.add(&Tensor::from_vec(vec![0.3, 0.4, 0.5], &[3]).unwrap());
        assert_eq!(bucket.size(), 5);

        let extracted = bucket.extract();
        assert_eq!(extracted.len(), 2);
        assert_eq!(extracted[0].to_vec(), vec![0.1, 0.2]);
        assert_eq!(extracted[1].to_vec(), vec![0.3, 0.4, 0.5]);

        bucket.clear();
        assert!(bucket.is_empty());
    }

    /// Smoke test of the synchronizer lifecycle: prepare → add → sync → clear.
    #[test]
    fn test_gradient_synchronizer_workflow() {
        let mut synchronizer = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 1000);
        synchronizer.prepare(10);

        synchronizer.add_gradient(0, &Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap());
        synchronizer.sync_all(&ProcessGroup::mock());
        synchronizer.clear();
    }

    /// The mock world's default group is a single-rank group at rank 0.
    #[test]
    fn test_world_default_group() {
        let group = World::mock().default_group();

        assert_eq!(group.rank(), 0);
        assert_eq!(group.world_size(), 1);
    }

    /// A subgroup over rank 0 alone has size 1 and contains that rank.
    #[test]
    fn test_world_new_subgroup() {
        let subgroup = World::mock().new_group(vec![0]);

        assert_eq!(subgroup.size(), 1);
        assert!(subgroup.contains(0));
    }

    /// Builder-style configuration chains and leaves the wrapper in
    /// training mode by default.
    #[test]
    fn test_ddp_builder_pattern() {
        let ddp = DDP::new(Linear::new(10, 5), ProcessGroup::mock())
            .broadcast_buffers(false)
            .gradient_as_bucket_view(false);

        assert!(ddp.is_training());
    }

    /// Every `ReduceOp` variant applies its element-wise f32 combiner.
    #[test]
    fn test_reduce_op_all_variants() {
        assert_eq!(ReduceOp::Sum.apply_f32(1.0, 2.0), 3.0);
        assert_eq!(ReduceOp::Product.apply_f32(2.0, 3.0), 6.0);
        assert_eq!(ReduceOp::Min.apply_f32(2.0, 3.0), 2.0);
        assert_eq!(ReduceOp::Max.apply_f32(2.0, 3.0), 3.0);
        assert_eq!(ReduceOp::Average.apply_f32(2.0, 4.0), 3.0);
    }
}