1use crate::process_group::ProcessGroup;
18use axonml_autograd::Variable;
19use axonml_nn::{Module, Parameter};
20use axonml_tensor::Tensor;
21
/// Strategy for ordering microbatch execution across pipeline stages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PipelineSchedule {
    /// GPipe-style: each microbatch is pushed through all stages back to
    /// back; peak retained activations scale with stages * microbatches
    /// (see `PipelineMemoryStats::gpipe_peak_activations`).
    GPipe,
    /// One-forward-one-backward; the default. Caps retained activations
    /// at `min(stages, microbatches)`.
    #[default]
    OneFOneBSchedule,
    /// Interleaved 1F1B; currently executed via the plain 1F1B path
    /// (see `forward_interleaved`).
    InterleavedOneFOneB,
}
37
/// A single stage of a model pipeline: one module bound to a logical
/// stage index and the process-group rank that executes it.
pub struct PipelineStage<M: Module> {
    /// The module this stage runs.
    module: M,
    /// Zero-based position of this stage in the pipeline.
    stage_id: usize,
    /// Rank (within the process group) that owns and executes this stage.
    device_rank: usize,
}
51
52impl<M: Module> PipelineStage<M> {
53 pub fn new(module: M, stage_id: usize, device_rank: usize) -> Self {
55 Self {
56 module,
57 stage_id,
58 device_rank,
59 }
60 }
61
62 pub fn stage_id(&self) -> usize {
64 self.stage_id
65 }
66
67 pub fn device_rank(&self) -> usize {
69 self.device_rank
70 }
71
72 pub fn forward(&self, input: &Variable) -> Variable {
74 self.module.forward(input)
75 }
76}
77
78impl<M: Module> Module for PipelineStage<M> {
79 fn forward(&self, input: &Variable) -> Variable {
80 self.module.forward(input)
81 }
82
83 fn parameters(&self) -> Vec<Parameter> {
84 self.module.parameters()
85 }
86
87 fn train(&mut self) {
88 self.module.train();
89 }
90
91 fn eval(&mut self) {
92 self.module.eval();
93 }
94
95 fn is_training(&self) -> bool {
96 self.module.is_training()
97 }
98}
99
/// A pipeline-parallel wrapper that partitions a model into sequential
/// stages distributed across the ranks of a process group.
pub struct Pipeline<M: Module> {
    /// Ordered stages; stage `i` feeds stage `i + 1`.
    stages: Vec<PipelineStage<M>>,
    /// Communicator used to exchange activations between ranks.
    process_group: ProcessGroup,
    /// Microbatch scheduling strategy (defaults to 1F1B).
    schedule: PipelineSchedule,
    /// Number of microbatches each input batch is split into (>= 1).
    num_microbatches: usize,
    /// Index of the first stage assigned to this rank (0 when none is).
    pub local_stage: usize,
}
120
121impl<M: Module + Clone> Pipeline<M> {
122 pub fn from_modules(modules: Vec<M>, process_group: ProcessGroup) -> Self {
124 let world_size = process_group.world_size();
125 let rank = process_group.rank();
126
127 let stages: Vec<PipelineStage<M>> = modules
128 .into_iter()
129 .enumerate()
130 .map(|(i, m)| PipelineStage::new(m, i, i % world_size))
131 .collect();
132
133 let local_stage = stages
134 .iter()
135 .position(|s| s.device_rank == rank)
136 .unwrap_or(0);
137
138 Self {
139 stages,
140 process_group,
141 schedule: PipelineSchedule::default(),
142 num_microbatches: 1,
143 local_stage,
144 }
145 }
146
147 pub fn schedule(mut self, schedule: PipelineSchedule) -> Self {
149 self.schedule = schedule;
150 self
151 }
152
153 pub fn num_microbatches(mut self, num: usize) -> Self {
155 self.num_microbatches = num.max(1);
156 self
157 }
158
159 pub fn num_stages(&self) -> usize {
161 self.stages.len()
162 }
163
164 pub fn get_schedule(&self) -> PipelineSchedule {
166 self.schedule
167 }
168
169 pub fn forward(&self, input: &Variable) -> Variable {
175 match self.schedule {
176 PipelineSchedule::GPipe => self.forward_gpipe(input),
177 PipelineSchedule::OneFOneBSchedule => self.forward_1f1b(input),
178 PipelineSchedule::InterleavedOneFOneB => self.forward_interleaved(input),
179 }
180 }
181
182 fn forward_gpipe(&self, input: &Variable) -> Variable {
184 let rank = self.process_group.rank();
185 let num_stages = self.stages.len();
186
187 let microbatches = self.split_microbatches(input);
189
190 let mut outputs = Vec::new();
192
193 for microbatch in microbatches {
194 let mut activation = microbatch;
195
196 for (stage_idx, stage) in self.stages.iter().enumerate() {
198 if stage.device_rank == rank {
199 activation = stage.forward(&activation);
200 }
201
202 if stage_idx < num_stages - 1 {
204 let next_rank = self.stages[stage_idx + 1].device_rank;
205 if stage.device_rank == rank {
206 self.send_activation(&activation, next_rank);
208 } else if next_rank == rank {
209 activation = self.recv_activation(stage.device_rank, activation.shape());
211 }
212 }
213 }
214
215 if self.stages.last().map(|s| s.device_rank) == Some(rank) {
217 outputs.push(activation);
218 }
219 }
220
221 self.combine_microbatches(&outputs)
223 }
224
225 fn forward_1f1b(&self, input: &Variable) -> Variable {
235 let rank = self.process_group.rank();
236 let num_stages = self.stages.len();
237
238 let microbatches = self.split_microbatches(input);
239 let num_mb = microbatches.len();
240
241 if num_mb <= 1 || num_stages <= 1 {
243 return self.forward_gpipe(input);
244 }
245
246 let mut activations: Vec<Option<Variable>> = Vec::with_capacity(num_mb);
248 let mut outputs: Vec<Option<Variable>> = vec![None; num_mb];
249
250 let warmup_count = num_stages.min(num_mb);
252 for mb_idx in 0..warmup_count {
253 let mut activation = microbatches[mb_idx].clone();
254 for (stage_idx, stage) in self.stages.iter().enumerate() {
255 if stage.device_rank == rank {
256 activation = stage.forward(&activation);
257 }
258 if stage_idx < num_stages - 1 {
259 let next_rank = self.stages[stage_idx + 1].device_rank;
260 if stage.device_rank == rank {
261 self.send_activation(&activation, next_rank);
262 } else if next_rank == rank {
263 activation = self.recv_activation(stage.device_rank, activation.shape());
264 }
265 }
266 }
267 activations.push(Some(activation.clone()));
268 if self.stages.last().map(|s| s.device_rank) == Some(rank) {
269 outputs[mb_idx] = Some(activation);
270 }
271 }
272
273 for mb_idx in warmup_count..num_mb {
275 let release_idx = mb_idx - warmup_count;
277 if release_idx < activations.len() {
278 activations[release_idx] = None;
279 }
280
281 let mut activation = microbatches[mb_idx].clone();
283 for (stage_idx, stage) in self.stages.iter().enumerate() {
284 if stage.device_rank == rank {
285 activation = stage.forward(&activation);
286 }
287 if stage_idx < num_stages - 1 {
288 let next_rank = self.stages[stage_idx + 1].device_rank;
289 if stage.device_rank == rank {
290 self.send_activation(&activation, next_rank);
291 } else if next_rank == rank {
292 activation = self.recv_activation(stage.device_rank, activation.shape());
293 }
294 }
295 }
296 activations.push(Some(activation.clone()));
297 if self.stages.last().map(|s| s.device_rank) == Some(rank) {
298 outputs[mb_idx] = Some(activation);
299 }
300 }
301
302 let final_outputs: Vec<Variable> = outputs.into_iter().flatten().collect();
304 self.combine_microbatches(&final_outputs)
305 }
306
307 fn forward_interleaved(&self, input: &Variable) -> Variable {
313 self.forward_1f1b(input)
317 }
318
319 fn split_microbatches(&self, input: &Variable) -> Vec<Variable> {
321 let data = input.data();
322 let batch_size = data.shape()[0];
323 let microbatch_size = batch_size.div_ceil(self.num_microbatches);
324
325 let mut microbatches = Vec::new();
326 let flat_data = data.to_vec();
327 let elements_per_sample: usize = data.shape()[1..].iter().product();
328
329 for i in 0..self.num_microbatches {
330 let start = i * microbatch_size;
331 let end = ((i + 1) * microbatch_size).min(batch_size);
332
333 if start >= batch_size {
334 break;
335 }
336
337 let mb_size = end - start;
338 let start_idx = start * elements_per_sample;
339 let end_idx = end * elements_per_sample;
340 let mb_data: Vec<f32> = flat_data[start_idx..end_idx].to_vec();
341
342 let mut shape = data.shape().to_vec();
343 shape[0] = mb_size;
344 let tensor = Tensor::from_vec(mb_data, &shape).unwrap();
345 microbatches.push(Variable::new(tensor, input.requires_grad()));
346 }
347
348 microbatches
349 }
350
351 fn combine_microbatches(&self, outputs: &[Variable]) -> Variable {
353 if outputs.is_empty() {
354 return Variable::new(Tensor::zeros(&[0]), false);
355 }
356
357 if outputs.len() == 1 {
358 return outputs[0].clone();
359 }
360
361 let mut all_data = Vec::new();
363 let mut total_batch = 0;
364 let shape = outputs[0].data().shape().to_vec();
365
366 for output in outputs {
367 all_data.extend(output.data().to_vec());
368 total_batch += output.data().shape()[0];
369 }
370
371 let mut new_shape = shape;
372 new_shape[0] = total_batch;
373 let tensor = Tensor::from_vec(all_data, &new_shape).unwrap();
374 Variable::new(tensor, outputs[0].requires_grad())
375 }
376
377 fn send_activation(&self, activation: &Variable, dest_rank: usize) {
379 let mut tensor = activation.data().clone();
380 self.process_group.send_tensor(&mut tensor, dest_rank);
381 }
382
383 fn recv_activation(&self, src_rank: usize, shape: Vec<usize>) -> Variable {
385 let tensor = self.process_group.recv_tensor(src_rank, &shape);
386 Variable::new(tensor, true)
387 }
388}
389
390impl<M: Module + Clone> Module for Pipeline<M> {
391 fn forward(&self, input: &Variable) -> Variable {
392 Pipeline::forward(self, input)
393 }
394
395 fn parameters(&self) -> Vec<Parameter> {
396 self.stages.iter().flat_map(|s| s.parameters()).collect()
397 }
398
399 fn train(&mut self) {
400 for stage in &mut self.stages {
401 stage.train();
402 }
403 }
404
405 fn eval(&mut self) {
406 for stage in &mut self.stages {
407 stage.eval();
408 }
409 }
410
411 fn is_training(&self) -> bool {
412 self.stages.first().is_some_and(|s| s.is_training())
413 }
414}
415
/// Static estimates of peak activation counts for the supported
/// pipeline schedules.
#[derive(Debug, Clone)]
pub struct PipelineMemoryStats {
    /// Number of pipeline stages.
    pub num_stages: usize,
    /// Number of microbatches per input batch.
    pub num_microbatches: usize,
    /// Estimated peak number of live activations held per stage.
    pub peak_activations_per_stage: usize,
    /// Schedule the estimate applies to.
    pub schedule: PipelineSchedule,
}
432
433impl PipelineMemoryStats {
434 pub fn gpipe_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
436 num_stages * num_microbatches
438 }
439
440 pub fn one_f_one_b_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
442 num_stages.min(num_microbatches)
444 }
445}
446
447#[cfg(test)]
452mod tests {
453 use super::*;
454 use axonml_nn::Linear;
455
    /// Pass-through test module: `forward` returns its input unchanged.
    #[derive(Clone)]
    struct IdentityModule {
        // Nominal feature size; stored but never read by the Module impl.
        size: usize,
        // Current train/eval mode, reported by `is_training`.
        training: bool,
    }
462
463 impl IdentityModule {
464 fn new(size: usize) -> Self {
465 Self {
466 size,
467 training: true,
468 }
469 }
470 }
471
472 impl Module for IdentityModule {
473 fn forward(&self, input: &Variable) -> Variable {
474 input.clone()
475 }
476
477 fn parameters(&self) -> Vec<Parameter> {
478 Vec::new()
479 }
480
481 fn train(&mut self) {
482 self.training = true;
483 }
484
485 fn eval(&mut self) {
486 self.training = false;
487 }
488
489 fn is_training(&self) -> bool {
490 self.training
491 }
492 }
493
494 #[test]
495 fn test_pipeline_schedule_default() {
496 assert_eq!(
497 PipelineSchedule::default(),
498 PipelineSchedule::OneFOneBSchedule
499 );
500 }
501
502 #[test]
503 fn test_pipeline_stage_creation() {
504 let module = Linear::new(10, 5);
505 let stage = PipelineStage::new(module, 0, 0);
506
507 assert_eq!(stage.stage_id(), 0);
508 assert_eq!(stage.device_rank(), 0);
509 }
510
511 #[test]
512 fn test_pipeline_creation() {
513 let modules = vec![
514 IdentityModule::new(10),
515 IdentityModule::new(8),
516 IdentityModule::new(6),
517 ];
518 let pg = ProcessGroup::mock();
519 let pipeline = Pipeline::from_modules(modules, pg)
520 .schedule(PipelineSchedule::GPipe)
521 .num_microbatches(2);
522
523 assert_eq!(pipeline.num_stages(), 3);
524 assert_eq!(pipeline.get_schedule(), PipelineSchedule::GPipe);
525 }
526
527 #[test]
528 fn test_pipeline_forward() {
529 let modules = vec![IdentityModule::new(4)];
530 let pg = ProcessGroup::mock();
531 let pipeline = Pipeline::from_modules(modules, pg);
532
533 let input = Variable::new(Tensor::randn(&[2, 4]), false);
534 let output = pipeline.forward(&input);
535
536 assert_eq!(output.data().shape(), &[2, 4]);
537 }
538
539 #[test]
540 fn test_pipeline_memory_stats() {
541 let gpipe = PipelineMemoryStats::gpipe_peak_activations(4, 8);
542 let one_f_one_b = PipelineMemoryStats::one_f_one_b_peak_activations(4, 8);
543
544 assert_eq!(gpipe, 32); assert_eq!(one_f_one_b, 4); }
547
548 #[test]
549 fn test_split_microbatches() {
550 let modules = vec![IdentityModule::new(4)];
551 let pg = ProcessGroup::mock();
552 let pipeline = Pipeline::from_modules(modules, pg).num_microbatches(2);
553
554 let input = Variable::new(Tensor::randn(&[4, 4]), false);
555 let microbatches = pipeline.split_microbatches(&input);
556
557 assert_eq!(microbatches.len(), 2);
558 assert_eq!(microbatches[0].data().shape()[0], 2);
559 assert_eq!(microbatches[1].data().shape()[0], 2);
560 }
561}