use crate::process_group::ProcessGroup;
use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;

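/// Microbatch scheduling strategy for pipeline-parallel execution.
///
/// `GPipe` runs every microbatch's forward pass before any backward work,
/// `OneFOneBSchedule` interleaves one forward with one backward (1F1B), and
/// `InterleavedOneFOneB` is the 1F1B variant with interleaved stage chunks.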
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PipelineSchedule {
    GPipe,
    OneFOneBSchedule,
    InterleavedOneFOneB,
}

impl Default for PipelineSchedule {
    fn default() -> Self {
        Self::OneFOneBSchedule
    }
}

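/// A single pipeline stage: a module paired with its stage index and the
/// process-group rank that owns it.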
pub struct PipelineStage<M: Module> {
    module: M,
    stage_id: usize,
    device_rank: usize,
}

impl<M: Module> PipelineStage<M> {
    pub fn new(module: M, stage_id: usize, device_rank: usize) -> Self {
        Self {
            module,
            stage_id,
            device_rank,
        }
    }

    pub fn stage_id(&self) -> usize {
        self.stage_id
    }

    pub fn device_rank(&self) -> usize {
        self.device_rank
    }

    pub fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }
}

impl<M: Module> Module for PipelineStage<M> {
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    fn train(&mut self) {
        self.module.train();
    }

    fn eval(&mut self) {
        self.module.eval();
    }

    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}

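/// A pipeline-parallel wrapper: splits an ordered list of modules into
/// [`PipelineStage`]s, assigns them round-robin to the ranks of a
/// [`ProcessGroup`], and runs microbatches through them according to the
/// configured [`PipelineSchedule`].
///
/// A minimal usage sketch, mirroring the mock-group setup used in the tests
/// below; it assumes the stage modules implement `Clone` (required by
/// `from_modules`) and that a real deployment would use an actual
/// distributed process group rather than `ProcessGroup::mock()`:
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::Linear;
/// use axonml_tensor::Tensor;
///
/// let stages = vec![Linear::new(10, 8), Linear::new(8, 4)];
/// let pipeline = Pipeline::from_modules(stages, ProcessGroup::mock())
///     .schedule(PipelineSchedule::GPipe)
///     .num_microbatches(4);
///
/// let input = Variable::new(Tensor::randn(&[16, 10]), false);
/// // The combined output is produced on the rank that owns the last stage.
/// let output = pipeline.forward(&input);
/// ```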
pub struct Pipeline<M: Module> {
    stages: Vec<PipelineStage<M>>,
    process_group: ProcessGroup,
    schedule: PipelineSchedule,
    num_microbatches: usize,
    #[allow(dead_code)]
    local_stage: usize,
}

impl<M: Module + Clone> Pipeline<M> {
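    /// Builds a pipeline from an ordered list of modules. Stage `i` is
    /// assigned round-robin to rank `i % world_size`, and the first stage
    /// owned by the local rank (stage 0 if there is none) is recorded as the
    /// local stage.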
    pub fn from_modules(modules: Vec<M>, process_group: ProcessGroup) -> Self {
        let world_size = process_group.world_size();
        let rank = process_group.rank();

        let stages: Vec<PipelineStage<M>> = modules
            .into_iter()
            .enumerate()
            .map(|(i, m)| PipelineStage::new(m, i, i % world_size))
            .collect();

        let local_stage = stages
            .iter()
            .position(|s| s.device_rank == rank)
            .unwrap_or(0);

        Self {
            stages,
            process_group,
            schedule: PipelineSchedule::default(),
            num_microbatches: 1,
            local_stage,
        }
    }

    pub fn schedule(mut self, schedule: PipelineSchedule) -> Self {
        self.schedule = schedule;
        self
    }

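    /// Sets how many microbatches each input batch is split into; values
    /// below 1 are clamped to 1.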
    pub fn num_microbatches(mut self, num: usize) -> Self {
        self.num_microbatches = num.max(1);
        self
    }

    pub fn num_stages(&self) -> usize {
        self.stages.len()
    }

    pub fn get_schedule(&self) -> PipelineSchedule {
        self.schedule
    }

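    /// Runs a forward pass using the configured schedule.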
    pub fn forward(&self, input: &Variable) -> Variable {
        match self.schedule {
            PipelineSchedule::GPipe => self.forward_gpipe(input),
            PipelineSchedule::OneFOneBSchedule => self.forward_1f1b(input),
            PipelineSchedule::InterleavedOneFOneB => self.forward_interleaved(input),
        }
    }

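    /// GPipe-style forward pass: splits the input into microbatches, pushes
    /// each one through the stages in order, hands activations to the rank
    /// owning the next stage, and concatenates the outputs collected on the
    /// rank owning the last stage.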
    fn forward_gpipe(&self, input: &Variable) -> Variable {
        let rank = self.process_group.rank();
        let num_stages = self.stages.len();

        let microbatches = self.split_microbatches(input);

        let mut outputs = Vec::new();

        for microbatch in microbatches {
            let mut activation = microbatch;

            for (stage_idx, stage) in self.stages.iter().enumerate() {
                if stage.device_rank == rank {
                    activation = stage.forward(&activation);
                }

                if stage_idx < num_stages - 1 {
                    let next_rank = self.stages[stage_idx + 1].device_rank;
                    if stage.device_rank == rank {
                        self.send_activation(&activation, next_rank);
                    } else if next_rank == rank {
                        activation = self.recv_activation(stage.device_rank, activation.shape());
                    }
                }
            }

            if self.stages.last().map(|s| s.device_rank) == Some(rank) {
                outputs.push(activation);
            }
        }

        self.combine_microbatches(&outputs)
    }

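    /// 1F1B (one-forward-one-backward) forward pass; the forward path
    /// currently delegates to the GPipe implementation.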
    fn forward_1f1b(&self, input: &Variable) -> Variable {
        self.forward_gpipe(input)
    }

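    /// Interleaved 1F1B forward pass; currently delegates to the GPipe
    /// implementation.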
    fn forward_interleaved(&self, input: &Variable) -> Variable {
        self.forward_gpipe(input)
    }

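    /// Splits the input along the first (batch) dimension into at most
    /// `num_microbatches` chunks using ceiling division, so the final chunk
    /// may be smaller than the others.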
    fn split_microbatches(&self, input: &Variable) -> Vec<Variable> {
        let data = input.data();
        let batch_size = data.shape()[0];
        let microbatch_size = (batch_size + self.num_microbatches - 1) / self.num_microbatches;

        let mut microbatches = Vec::new();
        let flat_data = data.to_vec();
        let elements_per_sample: usize = data.shape()[1..].iter().product();

        for i in 0..self.num_microbatches {
            let start = i * microbatch_size;
            let end = ((i + 1) * microbatch_size).min(batch_size);

            if start >= batch_size {
                break;
            }

            let mb_size = end - start;
            let start_idx = start * elements_per_sample;
            let end_idx = end * elements_per_sample;
            let mb_data: Vec<f32> = flat_data[start_idx..end_idx].to_vec();

            let mut shape = data.shape().to_vec();
            shape[0] = mb_size;
            let tensor = Tensor::from_vec(mb_data, &shape).unwrap();
            microbatches.push(Variable::new(tensor, input.requires_grad()));
        }

        microbatches
    }

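    /// Concatenates microbatch outputs back into one batch along the first
    /// dimension; returns an empty tensor when there are no outputs.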
    fn combine_microbatches(&self, outputs: &[Variable]) -> Variable {
        if outputs.is_empty() {
            return Variable::new(Tensor::zeros(&[0]), false);
        }

        if outputs.len() == 1 {
            return outputs[0].clone();
        }

        let mut all_data = Vec::new();
        let mut total_batch = 0;
        let shape = outputs[0].data().shape().to_vec();

        for output in outputs {
            all_data.extend(output.data().to_vec());
            total_batch += output.data().shape()[0];
        }

        let mut new_shape = shape;
        new_shape[0] = total_batch;
        let tensor = Tensor::from_vec(all_data, &new_shape).unwrap();
        Variable::new(tensor, outputs[0].requires_grad())
    }

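    /// Sends the activation's underlying tensor to `dest_rank` via the
    /// process group.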
    fn send_activation(&self, activation: &Variable, dest_rank: usize) {
        let mut tensor = activation.data().clone();
        self.process_group.send_tensor(&mut tensor, dest_rank);
    }

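    /// Receives a tensor of the given shape from `src_rank` and wraps it in
    /// a gradient-tracking `Variable`.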
    fn recv_activation(&self, src_rank: usize, shape: Vec<usize>) -> Variable {
        let tensor = self.process_group.recv_tensor(src_rank, &shape);
        Variable::new(tensor, true)
    }
}

impl<M: Module + Clone> Module for Pipeline<M> {
    fn forward(&self, input: &Variable) -> Variable {
        Pipeline::forward(self, input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.stages
            .iter()
            .flat_map(|s| s.parameters())
            .collect()
    }

    fn train(&mut self) {
        for stage in &mut self.stages {
            stage.train();
        }
    }

    fn eval(&mut self) {
        for stage in &mut self.stages {
            stage.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.stages.first().map(|s| s.is_training()).unwrap_or(false)
    }
}

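/// Describes the activation-memory footprint implied by a pipeline
/// configuration (stage count, microbatch count, and schedule).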
#[derive(Debug, Clone)]
pub struct PipelineMemoryStats {
    pub num_stages: usize,
    pub num_microbatches: usize,
    pub peak_activations_per_stage: usize,
    pub schedule: PipelineSchedule,
}

impl PipelineMemoryStats {
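    /// Peak number of live microbatch activations under GPipe: every stage
    /// keeps one activation per microbatch until the backward phase starts,
    /// giving `num_stages * num_microbatches`.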
    pub fn gpipe_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
        num_stages * num_microbatches
    }

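    /// Peak number of live microbatch activations under 1F1B: in-flight
    /// microbatches are bounded by the pipeline depth, giving
    /// `min(num_stages, num_microbatches)`.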
    pub fn one_f_one_b_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
        num_stages.min(num_microbatches)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    #[derive(Clone)]
    struct IdentityModule {
        size: usize,
        training: bool,
    }

    impl IdentityModule {
        fn new(size: usize) -> Self {
            Self { size, training: true }
        }
    }

    impl Module for IdentityModule {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn parameters(&self) -> Vec<Parameter> {
            Vec::new()
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    #[test]
    fn test_pipeline_schedule_default() {
        assert_eq!(PipelineSchedule::default(), PipelineSchedule::OneFOneBSchedule);
    }

    #[test]
    fn test_pipeline_stage_creation() {
        let module = Linear::new(10, 5);
        let stage = PipelineStage::new(module, 0, 0);

        assert_eq!(stage.stage_id(), 0);
        assert_eq!(stage.device_rank(), 0);
    }

    #[test]
    fn test_pipeline_creation() {
        let modules = vec![
            IdentityModule::new(10),
            IdentityModule::new(8),
            IdentityModule::new(6),
        ];
        let pg = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(modules, pg)
            .schedule(PipelineSchedule::GPipe)
            .num_microbatches(2);

        assert_eq!(pipeline.num_stages(), 3);
        assert_eq!(pipeline.get_schedule(), PipelineSchedule::GPipe);
    }

    #[test]
    fn test_pipeline_forward() {
        let modules = vec![
            IdentityModule::new(4),
        ];
        let pg = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(modules, pg);

        let input = Variable::new(Tensor::randn(&[2, 4]), false);
        let output = pipeline.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_pipeline_memory_stats() {
        let gpipe = PipelineMemoryStats::gpipe_peak_activations(4, 8);
        let one_f_one_b = PipelineMemoryStats::one_f_one_b_peak_activations(4, 8);

        assert_eq!(gpipe, 32);
        assert_eq!(one_f_one_b, 4);
    }

    #[test]
    fn test_split_microbatches() {
        let modules = vec![IdentityModule::new(4)];
        let pg = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(modules, pg).num_microbatches(2);

        let input = Variable::new(Tensor::randn(&[4, 4]), false);
        let microbatches = pipeline.split_microbatches(&input);

        assert_eq!(microbatches.len(), 2);
        assert_eq!(microbatches[0].data().shape()[0], 2);
        assert_eq!(microbatches[1].data().shape()[0], 2);
    }
}