1use crate::process_group::ProcessGroup;
18use axonml_autograd::Variable;
19use axonml_nn::{Module, Parameter};
20use axonml_tensor::Tensor;
21
/// Scheduling strategy used to order microbatch work across pipeline stages.
///
/// NOTE(review): all three variants currently share the same forward path
/// (`forward_1f1b` / `forward_interleaved` delegate to `forward_gpipe`); the
/// names follow the standard pipeline-parallel schedules they are meant to
/// implement.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PipelineSchedule {
    /// GPipe-style schedule.
    GPipe,
    /// One-forward-one-backward schedule; the default.
    #[default]
    OneFOneBSchedule,
    /// Interleaved 1F1B schedule.
    InterleavedOneFOneB,
}
37
/// One stage of a pipeline: a module pinned to a stage index and a device rank.
///
/// The `M: Module` bound lives on the impl blocks rather than the struct
/// definition (idiomatic Rust: bounds on the struct force them onto every
/// mention of the type without adding expressiveness).
pub struct PipelineStage<M> {
    /// The wrapped module this stage executes.
    module: M,
    /// 0-based position of this stage in the pipeline.
    stage_id: usize,
    /// Rank of the process that owns and runs this stage.
    device_rank: usize,
}
51
52impl<M: Module> PipelineStage<M> {
53 pub fn new(module: M, stage_id: usize, device_rank: usize) -> Self {
55 Self {
56 module,
57 stage_id,
58 device_rank,
59 }
60 }
61
62 pub fn stage_id(&self) -> usize {
64 self.stage_id
65 }
66
67 pub fn device_rank(&self) -> usize {
69 self.device_rank
70 }
71
72 pub fn forward(&self, input: &Variable) -> Variable {
74 self.module.forward(input)
75 }
76}
77
/// A `PipelineStage` is itself a `Module`: every trait method delegates
/// straight to the wrapped module, so a stage can be used anywhere a plain
/// module is expected.
impl<M: Module> Module for PipelineStage<M> {
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    fn train(&mut self) {
        self.module.train();
    }

    fn eval(&mut self) {
        self.module.eval();
    }

    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}
99
/// Pipeline-parallel wrapper: splits each batch into microbatches and routes
/// them through `stages` in order, moving activations between ranks via the
/// process group.
pub struct Pipeline<M: Module> {
    /// All stages in execution order (stage `i` feeds stage `i + 1`).
    stages: Vec<PipelineStage<M>>,
    /// Communicator used to ship activations between stage-owning ranks.
    process_group: ProcessGroup,
    /// Microbatch scheduling strategy (defaults to 1F1B).
    schedule: PipelineSchedule,
    /// Number of microbatches each batch is split into (kept >= 1).
    num_microbatches: usize,
    /// Index into `stages` of the first stage owned by this rank (0 if none).
    #[allow(dead_code)]
    local_stage: usize,
}
121
122impl<M: Module + Clone> Pipeline<M> {
123 pub fn from_modules(modules: Vec<M>, process_group: ProcessGroup) -> Self {
125 let world_size = process_group.world_size();
126 let rank = process_group.rank();
127
128 let stages: Vec<PipelineStage<M>> = modules
129 .into_iter()
130 .enumerate()
131 .map(|(i, m)| PipelineStage::new(m, i, i % world_size))
132 .collect();
133
134 let local_stage = stages
135 .iter()
136 .position(|s| s.device_rank == rank)
137 .unwrap_or(0);
138
139 Self {
140 stages,
141 process_group,
142 schedule: PipelineSchedule::default(),
143 num_microbatches: 1,
144 local_stage,
145 }
146 }
147
148 pub fn schedule(mut self, schedule: PipelineSchedule) -> Self {
150 self.schedule = schedule;
151 self
152 }
153
154 pub fn num_microbatches(mut self, num: usize) -> Self {
156 self.num_microbatches = num.max(1);
157 self
158 }
159
160 pub fn num_stages(&self) -> usize {
162 self.stages.len()
163 }
164
165 pub fn get_schedule(&self) -> PipelineSchedule {
167 self.schedule
168 }
169
170 pub fn forward(&self, input: &Variable) -> Variable {
176 match self.schedule {
177 PipelineSchedule::GPipe => self.forward_gpipe(input),
178 PipelineSchedule::OneFOneBSchedule => self.forward_1f1b(input),
179 PipelineSchedule::InterleavedOneFOneB => self.forward_interleaved(input),
180 }
181 }
182
183 fn forward_gpipe(&self, input: &Variable) -> Variable {
185 let rank = self.process_group.rank();
186 let num_stages = self.stages.len();
187
188 let microbatches = self.split_microbatches(input);
190
191 let mut outputs = Vec::new();
193
194 for microbatch in microbatches {
195 let mut activation = microbatch;
196
197 for (stage_idx, stage) in self.stages.iter().enumerate() {
199 if stage.device_rank == rank {
200 activation = stage.forward(&activation);
201 }
202
203 if stage_idx < num_stages - 1 {
205 let next_rank = self.stages[stage_idx + 1].device_rank;
206 if stage.device_rank == rank {
207 self.send_activation(&activation, next_rank);
209 } else if next_rank == rank {
210 activation = self.recv_activation(stage.device_rank, activation.shape());
212 }
213 }
214 }
215
216 if self.stages.last().map(|s| s.device_rank) == Some(rank) {
218 outputs.push(activation);
219 }
220 }
221
222 self.combine_microbatches(&outputs)
224 }
225
226 fn forward_1f1b(&self, input: &Variable) -> Variable {
228 self.forward_gpipe(input)
231 }
232
233 fn forward_interleaved(&self, input: &Variable) -> Variable {
235 self.forward_gpipe(input)
237 }
238
239 fn split_microbatches(&self, input: &Variable) -> Vec<Variable> {
241 let data = input.data();
242 let batch_size = data.shape()[0];
243 let microbatch_size = batch_size.div_ceil(self.num_microbatches);
244
245 let mut microbatches = Vec::new();
246 let flat_data = data.to_vec();
247 let elements_per_sample: usize = data.shape()[1..].iter().product();
248
249 for i in 0..self.num_microbatches {
250 let start = i * microbatch_size;
251 let end = ((i + 1) * microbatch_size).min(batch_size);
252
253 if start >= batch_size {
254 break;
255 }
256
257 let mb_size = end - start;
258 let start_idx = start * elements_per_sample;
259 let end_idx = end * elements_per_sample;
260 let mb_data: Vec<f32> = flat_data[start_idx..end_idx].to_vec();
261
262 let mut shape = data.shape().to_vec();
263 shape[0] = mb_size;
264 let tensor = Tensor::from_vec(mb_data, &shape).unwrap();
265 microbatches.push(Variable::new(tensor, input.requires_grad()));
266 }
267
268 microbatches
269 }
270
271 fn combine_microbatches(&self, outputs: &[Variable]) -> Variable {
273 if outputs.is_empty() {
274 return Variable::new(Tensor::zeros(&[0]), false);
275 }
276
277 if outputs.len() == 1 {
278 return outputs[0].clone();
279 }
280
281 let mut all_data = Vec::new();
283 let mut total_batch = 0;
284 let shape = outputs[0].data().shape().to_vec();
285
286 for output in outputs {
287 all_data.extend(output.data().to_vec());
288 total_batch += output.data().shape()[0];
289 }
290
291 let mut new_shape = shape;
292 new_shape[0] = total_batch;
293 let tensor = Tensor::from_vec(all_data, &new_shape).unwrap();
294 Variable::new(tensor, outputs[0].requires_grad())
295 }
296
297 fn send_activation(&self, activation: &Variable, dest_rank: usize) {
299 let mut tensor = activation.data().clone();
300 self.process_group.send_tensor(&mut tensor, dest_rank);
301 }
302
303 fn recv_activation(&self, src_rank: usize, shape: Vec<usize>) -> Variable {
305 let tensor = self.process_group.recv_tensor(src_rank, &shape);
306 Variable::new(tensor, true)
307 }
308}
309
/// A `Pipeline` is itself a `Module`, aggregating/broadcasting over stages.
impl<M: Module + Clone> Module for Pipeline<M> {
    fn forward(&self, input: &Variable) -> Variable {
        // Qualified call disambiguates the inherent method (with schedule
        // dispatch) from this trait method.
        Pipeline::forward(self, input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.stages.iter().flat_map(|s| s.parameters()).collect()
    }

    fn train(&mut self) {
        for stage in &mut self.stages {
            stage.train();
        }
    }

    fn eval(&mut self) {
        for stage in &mut self.stages {
            stage.eval();
        }
    }

    fn is_training(&self) -> bool {
        // train()/eval() toggle all stages together, so the first stage is
        // representative; an empty pipeline reports not-training.
        self.stages.first().is_some_and(|s| s.is_training())
    }
}
335
/// Memory-planning statistics for a pipeline configuration.
#[derive(Debug, Clone)]
pub struct PipelineMemoryStats {
    /// Number of pipeline stages.
    pub num_stages: usize,
    /// Number of microbatches per batch.
    pub num_microbatches: usize,
    /// Peak number of live activations held by a single stage.
    pub peak_activations_per_stage: usize,
    /// Schedule these statistics were computed for.
    pub schedule: PipelineSchedule,
}
352
353impl PipelineMemoryStats {
354 pub fn gpipe_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
356 num_stages * num_microbatches
358 }
359
360 pub fn one_f_one_b_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
362 num_stages.min(num_microbatches)
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    /// Minimal pass-through module for exercising the pipeline machinery.
    #[derive(Clone)]
    struct IdentityModule {
        // Mirrors a real module's configuration; not read by forward().
        size: usize,
        training: bool,
    }

    impl IdentityModule {
        fn new(size: usize) -> Self {
            Self {
                size,
                training: true,
            }
        }
    }

    impl Module for IdentityModule {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn parameters(&self) -> Vec<Parameter> {
            Vec::new()
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    #[test]
    fn test_pipeline_schedule_default() {
        let default_schedule = PipelineSchedule::default();
        assert_eq!(default_schedule, PipelineSchedule::OneFOneBSchedule);
    }

    #[test]
    fn test_pipeline_stage_creation() {
        let stage = PipelineStage::new(Linear::new(10, 5), 0, 0);

        assert_eq!(stage.stage_id(), 0);
        assert_eq!(stage.device_rank(), 0);
    }

    #[test]
    fn test_pipeline_creation() {
        let modules = vec![
            IdentityModule::new(10),
            IdentityModule::new(8),
            IdentityModule::new(6),
        ];
        let pipeline = Pipeline::from_modules(modules, ProcessGroup::mock())
            .schedule(PipelineSchedule::GPipe)
            .num_microbatches(2);

        assert_eq!(pipeline.num_stages(), 3);
        assert_eq!(pipeline.get_schedule(), PipelineSchedule::GPipe);
    }

    #[test]
    fn test_pipeline_forward() {
        let pipeline =
            Pipeline::from_modules(vec![IdentityModule::new(4)], ProcessGroup::mock());
        let input = Variable::new(Tensor::randn(&[2, 4]), false);

        let output = pipeline.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_pipeline_memory_stats() {
        assert_eq!(PipelineMemoryStats::gpipe_peak_activations(4, 8), 32);
        assert_eq!(PipelineMemoryStats::one_f_one_b_peak_activations(4, 8), 4);
    }

    #[test]
    fn test_split_microbatches() {
        let pipeline =
            Pipeline::from_modules(vec![IdentityModule::new(4)], ProcessGroup::mock())
                .num_microbatches(2);
        let input = Variable::new(Tensor::randn(&[4, 4]), false);

        let microbatches = pipeline.split_microbatches(&input);

        assert_eq!(microbatches.len(), 2);
        assert_eq!(microbatches[0].data().shape()[0], 2);
        assert_eq!(microbatches[1].data().shape()[0], 2);
    }
}