1use crate::process_group::ProcessGroup;
19use axonml_autograd::Variable;
20use axonml_nn::{Module, Parameter};
21use axonml_tensor::Tensor;
22
/// Strategy for scheduling microbatches across pipeline stages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PipelineSchedule {
    /// GPipe-style: run all microbatch forwards before any backward.
    /// Peak activation memory scales with `stages * microbatches`
    /// (see `PipelineMemoryStats::gpipe_peak_activations`).
    GPipe,
    /// One-forward-one-backward: bounds live activations to roughly the
    /// pipeline depth (see `PipelineMemoryStats::one_f_one_b_peak_activations`).
    /// This is the default schedule.
    #[default]
    OneFOneBSchedule,
    /// Interleaved 1F1B (multiple stage chunks per rank).
    /// NOTE(review): currently executes the plain 1F1B path — see
    /// `Pipeline::forward_interleaved`.
    InterleavedOneFOneB,
}
38
/// A single pipeline stage: one module pinned to a specific device rank.
pub struct PipelineStage<M: Module> {
    /// The wrapped module this stage executes.
    module: M,
    /// Zero-based position of this stage within the pipeline.
    stage_id: usize,
    /// Rank that owns (executes) this stage.
    device_rank: usize,
}
52
53impl<M: Module> PipelineStage<M> {
54 pub fn new(module: M, stage_id: usize, device_rank: usize) -> Self {
56 Self {
57 module,
58 stage_id,
59 device_rank,
60 }
61 }
62
63 pub fn stage_id(&self) -> usize {
65 self.stage_id
66 }
67
68 pub fn device_rank(&self) -> usize {
70 self.device_rank
71 }
72
73 pub fn forward(&self, input: &Variable) -> Variable {
75 self.module.forward(input)
76 }
77}
78
/// `Module` passthrough so a stage can be used anywhere a plain module is
/// expected. Every method delegates directly to the wrapped module.
impl<M: Module> Module for PipelineStage<M> {
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    fn train(&mut self) {
        self.module.train();
    }

    fn eval(&mut self) {
        self.module.eval();
    }

    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}
100
/// Pipeline-parallel wrapper that shards a sequence of modules across ranks.
pub struct Pipeline<M: Module> {
    /// Stages in pipeline order; stage `i` is placed on rank `i % world_size`
    /// by `from_modules`.
    stages: Vec<PipelineStage<M>>,
    /// Communicator used to exchange activations between stage owners.
    process_group: ProcessGroup,
    /// Microbatch schedule used by `forward`.
    schedule: PipelineSchedule,
    /// Number of microbatches the input batch is split into (>= 1).
    num_microbatches: usize,
    /// Index of the first stage hosted on this rank (0 if this rank hosts none).
    pub local_stage: usize,
}
121
122impl<M: Module + Clone> Pipeline<M> {
123 pub fn from_modules(modules: Vec<M>, process_group: ProcessGroup) -> Self {
125 let world_size = process_group.world_size();
126 let rank = process_group.rank();
127
128 let stages: Vec<PipelineStage<M>> = modules
129 .into_iter()
130 .enumerate()
131 .map(|(i, m)| PipelineStage::new(m, i, i % world_size))
132 .collect();
133
134 let local_stage = stages
135 .iter()
136 .position(|s| s.device_rank == rank)
137 .unwrap_or(0);
138
139 Self {
140 stages,
141 process_group,
142 schedule: PipelineSchedule::default(),
143 num_microbatches: 1,
144 local_stage,
145 }
146 }
147
148 pub fn schedule(mut self, schedule: PipelineSchedule) -> Self {
150 self.schedule = schedule;
151 self
152 }
153
154 pub fn num_microbatches(mut self, num: usize) -> Self {
156 self.num_microbatches = num.max(1);
157 self
158 }
159
160 pub fn num_stages(&self) -> usize {
162 self.stages.len()
163 }
164
165 pub fn get_schedule(&self) -> PipelineSchedule {
167 self.schedule
168 }
169
170 pub fn forward(&self, input: &Variable) -> Variable {
176 match self.schedule {
177 PipelineSchedule::GPipe => self.forward_gpipe(input),
178 PipelineSchedule::OneFOneBSchedule => self.forward_1f1b(input),
179 PipelineSchedule::InterleavedOneFOneB => self.forward_interleaved(input),
180 }
181 }
182
183 fn forward_gpipe(&self, input: &Variable) -> Variable {
185 let rank = self.process_group.rank();
186 let num_stages = self.stages.len();
187
188 let microbatches = self.split_microbatches(input);
190
191 let mut outputs = Vec::new();
193
194 for microbatch in microbatches {
195 let mut activation = microbatch;
196
197 for (stage_idx, stage) in self.stages.iter().enumerate() {
199 if stage.device_rank == rank {
200 activation = stage.forward(&activation);
201 }
202
203 if stage_idx < num_stages - 1 {
205 let next_rank = self.stages[stage_idx + 1].device_rank;
206 if stage.device_rank == rank {
207 self.send_activation(&activation, next_rank);
209 } else if next_rank == rank {
210 activation = self.recv_activation(stage.device_rank, activation.shape());
212 }
213 }
214 }
215
216 if self.stages.last().map(|s| s.device_rank) == Some(rank) {
218 outputs.push(activation);
219 }
220 }
221
222 self.combine_microbatches(&outputs)
224 }
225
226 fn forward_1f1b(&self, input: &Variable) -> Variable {
236 let rank = self.process_group.rank();
237 let num_stages = self.stages.len();
238
239 let microbatches = self.split_microbatches(input);
240 let num_mb = microbatches.len();
241
242 if num_mb <= 1 || num_stages <= 1 {
244 return self.forward_gpipe(input);
245 }
246
247 let mut activations: Vec<Option<Variable>> = Vec::with_capacity(num_mb);
249 let mut outputs: Vec<Option<Variable>> = vec![None; num_mb];
250
251 let warmup_count = num_stages.min(num_mb);
253 for mb_idx in 0..warmup_count {
254 let mut activation = microbatches[mb_idx].clone();
255 for (stage_idx, stage) in self.stages.iter().enumerate() {
256 if stage.device_rank == rank {
257 activation = stage.forward(&activation);
258 }
259 if stage_idx < num_stages - 1 {
260 let next_rank = self.stages[stage_idx + 1].device_rank;
261 if stage.device_rank == rank {
262 self.send_activation(&activation, next_rank);
263 } else if next_rank == rank {
264 activation = self.recv_activation(stage.device_rank, activation.shape());
265 }
266 }
267 }
268 activations.push(Some(activation.clone()));
269 if self.stages.last().map(|s| s.device_rank) == Some(rank) {
270 outputs[mb_idx] = Some(activation);
271 }
272 }
273
274 for mb_idx in warmup_count..num_mb {
276 let release_idx = mb_idx - warmup_count;
278 if release_idx < activations.len() {
279 activations[release_idx] = None;
280 }
281
282 let mut activation = microbatches[mb_idx].clone();
284 for (stage_idx, stage) in self.stages.iter().enumerate() {
285 if stage.device_rank == rank {
286 activation = stage.forward(&activation);
287 }
288 if stage_idx < num_stages - 1 {
289 let next_rank = self.stages[stage_idx + 1].device_rank;
290 if stage.device_rank == rank {
291 self.send_activation(&activation, next_rank);
292 } else if next_rank == rank {
293 activation = self.recv_activation(stage.device_rank, activation.shape());
294 }
295 }
296 }
297 activations.push(Some(activation.clone()));
298 if self.stages.last().map(|s| s.device_rank) == Some(rank) {
299 outputs[mb_idx] = Some(activation);
300 }
301 }
302
303 let final_outputs: Vec<Variable> = outputs.into_iter().flatten().collect();
305 self.combine_microbatches(&final_outputs)
306 }
307
308 fn forward_interleaved(&self, input: &Variable) -> Variable {
314 self.forward_1f1b(input)
318 }
319
320 fn split_microbatches(&self, input: &Variable) -> Vec<Variable> {
322 let data = input.data();
323 let batch_size = data.shape()[0];
324 let microbatch_size = batch_size.div_ceil(self.num_microbatches);
325
326 let mut microbatches = Vec::new();
327 let flat_data = data.to_vec();
328 let elements_per_sample: usize = data.shape()[1..].iter().product();
329
330 for i in 0..self.num_microbatches {
331 let start = i * microbatch_size;
332 let end = ((i + 1) * microbatch_size).min(batch_size);
333
334 if start >= batch_size {
335 break;
336 }
337
338 let mb_size = end - start;
339 let start_idx = start * elements_per_sample;
340 let end_idx = end * elements_per_sample;
341 let mb_data: Vec<f32> = flat_data[start_idx..end_idx].to_vec();
342
343 let mut shape = data.shape().to_vec();
344 shape[0] = mb_size;
345 let tensor = Tensor::from_vec(mb_data, &shape).unwrap();
346 microbatches.push(Variable::new(tensor, input.requires_grad()));
347 }
348
349 microbatches
350 }
351
352 fn combine_microbatches(&self, outputs: &[Variable]) -> Variable {
354 if outputs.is_empty() {
355 return Variable::new(Tensor::zeros(&[0]), false);
356 }
357
358 if outputs.len() == 1 {
359 return outputs[0].clone();
360 }
361
362 let mut all_data = Vec::new();
364 let mut total_batch = 0;
365 let shape = outputs[0].data().shape().to_vec();
366
367 for output in outputs {
368 all_data.extend(output.data().to_vec());
369 total_batch += output.data().shape()[0];
370 }
371
372 let mut new_shape = shape;
373 new_shape[0] = total_batch;
374 let tensor = Tensor::from_vec(all_data, &new_shape).unwrap();
375 Variable::new(tensor, outputs[0].requires_grad())
376 }
377
378 fn send_activation(&self, activation: &Variable, dest_rank: usize) {
380 let mut tensor = activation.data().clone();
381 self.process_group.send_tensor(&mut tensor, dest_rank);
382 }
383
384 fn recv_activation(&self, src_rank: usize, shape: Vec<usize>) -> Variable {
386 let tensor = self.process_group.recv_tensor(src_rank, &shape);
387 Variable::new(tensor, true)
388 }
389}
390
391impl<M: Module + Clone> Module for Pipeline<M> {
392 fn forward(&self, input: &Variable) -> Variable {
393 Pipeline::forward(self, input)
394 }
395
396 fn parameters(&self) -> Vec<Parameter> {
397 self.stages.iter().flat_map(|s| s.parameters()).collect()
398 }
399
400 fn train(&mut self) {
401 for stage in &mut self.stages {
402 stage.train();
403 }
404 }
405
406 fn eval(&mut self) {
407 for stage in &mut self.stages {
408 stage.eval();
409 }
410 }
411
412 fn is_training(&self) -> bool {
413 self.stages.first().is_some_and(|s| s.is_training())
414 }
415}
416
417#[derive(Debug, Clone)]
423pub struct PipelineMemoryStats {
424 pub num_stages: usize,
426 pub num_microbatches: usize,
428 pub peak_activations_per_stage: usize,
430 pub schedule: PipelineSchedule,
432}
433
434impl PipelineMemoryStats {
435 pub fn gpipe_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
437 num_stages * num_microbatches
439 }
440
441 pub fn one_f_one_b_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
443 num_stages.min(num_microbatches)
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    /// Pass-through module used to exercise pipeline plumbing without any
    /// real computation.
    #[derive(Clone)]
    struct IdentityModule {
        size: usize,
        training: bool,
    }

    impl IdentityModule {
        fn new(size: usize) -> Self {
            IdentityModule {
                size,
                training: true,
            }
        }
    }

    impl Module for IdentityModule {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn parameters(&self) -> Vec<Parameter> {
            Vec::new()
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    #[test]
    fn test_pipeline_schedule_default() {
        let default_schedule = PipelineSchedule::default();
        assert_eq!(default_schedule, PipelineSchedule::OneFOneBSchedule);
    }

    #[test]
    fn test_pipeline_stage_creation() {
        let stage = PipelineStage::new(Linear::new(10, 5), 0, 0);

        assert_eq!(stage.stage_id(), 0);
        assert_eq!(stage.device_rank(), 0);
    }

    #[test]
    fn test_pipeline_creation() {
        let stage_modules = vec![
            IdentityModule::new(10),
            IdentityModule::new(8),
            IdentityModule::new(6),
        ];
        let pipeline = Pipeline::from_modules(stage_modules, ProcessGroup::mock())
            .schedule(PipelineSchedule::GPipe)
            .num_microbatches(2);

        assert_eq!(pipeline.num_stages(), 3);
        assert_eq!(pipeline.get_schedule(), PipelineSchedule::GPipe);
    }

    #[test]
    fn test_pipeline_forward() {
        let pipeline =
            Pipeline::from_modules(vec![IdentityModule::new(4)], ProcessGroup::mock());
        let input = Variable::new(Tensor::randn(&[2, 4]), false);

        let output = pipeline.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_pipeline_memory_stats() {
        assert_eq!(PipelineMemoryStats::gpipe_peak_activations(4, 8), 32);
        assert_eq!(PipelineMemoryStats::one_f_one_b_peak_activations(4, 8), 4);
    }

    #[test]
    fn test_split_microbatches() {
        let pipeline =
            Pipeline::from_modules(vec![IdentityModule::new(4)], ProcessGroup::mock())
                .num_microbatches(2);
        let input = Variable::new(Tensor::randn(&[4, 4]), false);

        let halves = pipeline.split_microbatches(&input);

        assert_eq!(halves.len(), 2);
        assert_eq!(halves[0].data().shape()[0], 2);
        assert_eq!(halves[1].data().shape()[0], 2);
    }
}