Skip to main content

axonml_distributed/
pipeline.rs

1//! Pipeline Parallelism
2//!
3//! # File
4//! `crates/axonml-distributed/src/pipeline.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use crate::process_group::ProcessGroup;
18use axonml_autograd::Variable;
19use axonml_nn::{Module, Parameter};
20use axonml_tensor::Tensor;
21
22// =============================================================================
23// Pipeline Schedule
24// =============================================================================
25
/// Order in which a pipeline processes its microbatches.
///
/// The default is [`PipelineSchedule::OneFOneBSchedule`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PipelineSchedule {
    /// GPipe: fill-drain schedule with synchronized updates.
    GPipe,
    /// 1F1B: alternate one forward and one backward step to bound memory.
    #[default]
    OneFOneBSchedule,
    /// Interleaved 1F1B over virtual stages for better efficiency.
    InterleavedOneFOneB,
}
37
38// =============================================================================
39// Pipeline Stage
40// =============================================================================
41
42/// A stage in the pipeline.
43pub struct PipelineStage<M: Module> {
44    /// The module for this stage
45    module: M,
46    /// Stage index (0 = first stage)
47    stage_id: usize,
48    /// Device/rank this stage runs on
49    device_rank: usize,
50}
51
52impl<M: Module> PipelineStage<M> {
53    /// Creates a new pipeline stage.
54    pub fn new(module: M, stage_id: usize, device_rank: usize) -> Self {
55        Self {
56            module,
57            stage_id,
58            device_rank,
59        }
60    }
61
62    /// Returns the stage ID.
63    pub fn stage_id(&self) -> usize {
64        self.stage_id
65    }
66
67    /// Returns the device rank.
68    pub fn device_rank(&self) -> usize {
69        self.device_rank
70    }
71
72    /// Forward pass for this stage.
73    pub fn forward(&self, input: &Variable) -> Variable {
74        self.module.forward(input)
75    }
76}
77
78impl<M: Module> Module for PipelineStage<M> {
79    fn forward(&self, input: &Variable) -> Variable {
80        self.module.forward(input)
81    }
82
83    fn parameters(&self) -> Vec<Parameter> {
84        self.module.parameters()
85    }
86
87    fn train(&mut self) {
88        self.module.train();
89    }
90
91    fn eval(&mut self) {
92        self.module.eval();
93    }
94
95    fn is_training(&self) -> bool {
96        self.module.is_training()
97    }
98}
99
100// =============================================================================
101// Pipeline
102// =============================================================================
103
104/// Pipeline parallel wrapper for distributed training.
105///
106/// Splits model computation across multiple stages, with each stage
107/// potentially running on a different device/rank.
108pub struct Pipeline<M: Module> {
109    /// Pipeline stages
110    stages: Vec<PipelineStage<M>>,
111    /// Process group for communication
112    process_group: ProcessGroup,
113    /// Pipeline schedule
114    schedule: PipelineSchedule,
115    /// Number of microbatches
116    num_microbatches: usize,
117    /// Current rank's stage index.
118    pub local_stage: usize,
119}
120
121impl<M: Module + Clone> Pipeline<M> {
122    /// Creates a new pipeline from a list of modules.
123    pub fn from_modules(modules: Vec<M>, process_group: ProcessGroup) -> Self {
124        let world_size = process_group.world_size();
125        let rank = process_group.rank();
126
127        let stages: Vec<PipelineStage<M>> = modules
128            .into_iter()
129            .enumerate()
130            .map(|(i, m)| PipelineStage::new(m, i, i % world_size))
131            .collect();
132
133        let local_stage = stages
134            .iter()
135            .position(|s| s.device_rank == rank)
136            .unwrap_or(0);
137
138        Self {
139            stages,
140            process_group,
141            schedule: PipelineSchedule::default(),
142            num_microbatches: 1,
143            local_stage,
144        }
145    }
146
147    /// Builder: set pipeline schedule.
148    pub fn schedule(mut self, schedule: PipelineSchedule) -> Self {
149        self.schedule = schedule;
150        self
151    }
152
153    /// Builder: set number of microbatches.
154    pub fn num_microbatches(mut self, num: usize) -> Self {
155        self.num_microbatches = num.max(1);
156        self
157    }
158
159    /// Returns number of stages.
160    pub fn num_stages(&self) -> usize {
161        self.stages.len()
162    }
163
164    /// Returns the current schedule.
165    pub fn get_schedule(&self) -> PipelineSchedule {
166        self.schedule
167    }
168
169    /// Forward pass through the pipeline.
170    ///
171    /// For the first stage, takes input and forwards to next stage.
172    /// For intermediate stages, receives from previous, forwards to next.
173    /// For the last stage, receives from previous, returns output.
174    pub fn forward(&self, input: &Variable) -> Variable {
175        match self.schedule {
176            PipelineSchedule::GPipe => self.forward_gpipe(input),
177            PipelineSchedule::OneFOneBSchedule => self.forward_1f1b(input),
178            PipelineSchedule::InterleavedOneFOneB => self.forward_interleaved(input),
179        }
180    }
181
182    /// GPipe schedule: fill-drain with all forwards then all backwards.
183    fn forward_gpipe(&self, input: &Variable) -> Variable {
184        let rank = self.process_group.rank();
185        let num_stages = self.stages.len();
186
187        // Split input into microbatches
188        let microbatches = self.split_microbatches(input);
189
190        // Process all microbatches through pipeline
191        let mut outputs = Vec::new();
192
193        for microbatch in microbatches {
194            let mut activation = microbatch;
195
196            // Forward through all stages
197            for (stage_idx, stage) in self.stages.iter().enumerate() {
198                if stage.device_rank == rank {
199                    activation = stage.forward(&activation);
200                }
201
202                // Send to next stage if not last
203                if stage_idx < num_stages - 1 {
204                    let next_rank = self.stages[stage_idx + 1].device_rank;
205                    if stage.device_rank == rank {
206                        // Send activation to next stage
207                        self.send_activation(&activation, next_rank);
208                    } else if next_rank == rank {
209                        // Receive activation from previous stage
210                        activation = self.recv_activation(stage.device_rank, activation.shape());
211                    }
212                }
213            }
214
215            // Last stage collects output
216            if self.stages.last().map(|s| s.device_rank) == Some(rank) {
217                outputs.push(activation);
218            }
219        }
220
221        // Combine outputs
222        self.combine_microbatches(&outputs)
223    }
224
225    /// 1F1B schedule: memory-efficient interleaved forward/backward.
226    ///
227    /// Instead of GPipe's fill-all-then-drain-all, 1F1B interleaves:
228    /// 1. Warmup: forward passes to fill the pipeline (num_stages microbatches)
229    /// 2. Steady state: alternate 1 forward + 1 backward per microbatch
230    /// 3. Cooldown: drain remaining backward passes
231    ///
232    /// This limits peak memory to ~(num_stages) activations instead of
233    /// ~(num_microbatches) in GPipe.
234    fn forward_1f1b(&self, input: &Variable) -> Variable {
235        let rank = self.process_group.rank();
236        let num_stages = self.stages.len();
237
238        let microbatches = self.split_microbatches(input);
239        let num_mb = microbatches.len();
240
241        // If only 1 microbatch or 1 stage, 1F1B degenerates to GPipe
242        if num_mb <= 1 || num_stages <= 1 {
243            return self.forward_gpipe(input);
244        }
245
246        // Activations buffer: we only keep up to num_stages in-flight
247        let mut activations: Vec<Option<Variable>> = Vec::with_capacity(num_mb);
248        let mut outputs: Vec<Option<Variable>> = vec![None; num_mb];
249
250        // Phase 1: Warmup — forward the first min(num_stages, num_mb) microbatches
251        let warmup_count = num_stages.min(num_mb);
252        for mb_idx in 0..warmup_count {
253            let mut activation = microbatches[mb_idx].clone();
254            for (stage_idx, stage) in self.stages.iter().enumerate() {
255                if stage.device_rank == rank {
256                    activation = stage.forward(&activation);
257                }
258                if stage_idx < num_stages - 1 {
259                    let next_rank = self.stages[stage_idx + 1].device_rank;
260                    if stage.device_rank == rank {
261                        self.send_activation(&activation, next_rank);
262                    } else if next_rank == rank {
263                        activation = self.recv_activation(stage.device_rank, activation.shape());
264                    }
265                }
266            }
267            activations.push(Some(activation.clone()));
268            if self.stages.last().map(|s| s.device_rank) == Some(rank) {
269                outputs[mb_idx] = Some(activation);
270            }
271        }
272
273        // Phase 2: Steady state — for each remaining microbatch, do 1 forward + release 1 old
274        for mb_idx in warmup_count..num_mb {
275            // Release the oldest activation (simulating backward freeing memory)
276            let release_idx = mb_idx - warmup_count;
277            if release_idx < activations.len() {
278                activations[release_idx] = None;
279            }
280
281            // Forward the new microbatch
282            let mut activation = microbatches[mb_idx].clone();
283            for (stage_idx, stage) in self.stages.iter().enumerate() {
284                if stage.device_rank == rank {
285                    activation = stage.forward(&activation);
286                }
287                if stage_idx < num_stages - 1 {
288                    let next_rank = self.stages[stage_idx + 1].device_rank;
289                    if stage.device_rank == rank {
290                        self.send_activation(&activation, next_rank);
291                    } else if next_rank == rank {
292                        activation = self.recv_activation(stage.device_rank, activation.shape());
293                    }
294                }
295            }
296            activations.push(Some(activation.clone()));
297            if self.stages.last().map(|s| s.device_rank) == Some(rank) {
298                outputs[mb_idx] = Some(activation);
299            }
300        }
301
302        // Combine outputs (filter out None for non-last-stage ranks)
303        let final_outputs: Vec<Variable> = outputs.into_iter().flatten().collect();
304        self.combine_microbatches(&final_outputs)
305    }
306
307    /// Interleaved 1F1B for virtual pipeline parallelism.
308    ///
309    /// Similar to 1F1B but processes multiple virtual stages per rank in a
310    /// round-robin fashion. Falls back to standard 1F1B when virtual stages
311    /// are not configured (single model chunk per rank).
312    fn forward_interleaved(&self, input: &Variable) -> Variable {
313        // Interleaved 1F1B requires multiple model chunks per rank (virtual stages).
314        // When each rank owns exactly one stage (the common case), this is
315        // equivalent to standard 1F1B.
316        self.forward_1f1b(input)
317    }
318
319    /// Splits input into microbatches.
320    fn split_microbatches(&self, input: &Variable) -> Vec<Variable> {
321        let data = input.data();
322        let batch_size = data.shape()[0];
323        let microbatch_size = batch_size.div_ceil(self.num_microbatches);
324
325        let mut microbatches = Vec::new();
326        let flat_data = data.to_vec();
327        let elements_per_sample: usize = data.shape()[1..].iter().product();
328
329        for i in 0..self.num_microbatches {
330            let start = i * microbatch_size;
331            let end = ((i + 1) * microbatch_size).min(batch_size);
332
333            if start >= batch_size {
334                break;
335            }
336
337            let mb_size = end - start;
338            let start_idx = start * elements_per_sample;
339            let end_idx = end * elements_per_sample;
340            let mb_data: Vec<f32> = flat_data[start_idx..end_idx].to_vec();
341
342            let mut shape = data.shape().to_vec();
343            shape[0] = mb_size;
344            let tensor = Tensor::from_vec(mb_data, &shape).unwrap();
345            microbatches.push(Variable::new(tensor, input.requires_grad()));
346        }
347
348        microbatches
349    }
350
351    /// Combines microbatch outputs.
352    fn combine_microbatches(&self, outputs: &[Variable]) -> Variable {
353        if outputs.is_empty() {
354            return Variable::new(Tensor::zeros(&[0]), false);
355        }
356
357        if outputs.len() == 1 {
358            return outputs[0].clone();
359        }
360
361        // Concatenate along batch dimension
362        let mut all_data = Vec::new();
363        let mut total_batch = 0;
364        let shape = outputs[0].data().shape().to_vec();
365
366        for output in outputs {
367            all_data.extend(output.data().to_vec());
368            total_batch += output.data().shape()[0];
369        }
370
371        let mut new_shape = shape;
372        new_shape[0] = total_batch;
373        let tensor = Tensor::from_vec(all_data, &new_shape).unwrap();
374        Variable::new(tensor, outputs[0].requires_grad())
375    }
376
377    /// Sends activation to another rank.
378    fn send_activation(&self, activation: &Variable, dest_rank: usize) {
379        let mut tensor = activation.data().clone();
380        self.process_group.send_tensor(&mut tensor, dest_rank);
381    }
382
383    /// Receives activation from another rank.
384    fn recv_activation(&self, src_rank: usize, shape: Vec<usize>) -> Variable {
385        let tensor = self.process_group.recv_tensor(src_rank, &shape);
386        Variable::new(tensor, true)
387    }
388}
389
390impl<M: Module + Clone> Module for Pipeline<M> {
391    fn forward(&self, input: &Variable) -> Variable {
392        Pipeline::forward(self, input)
393    }
394
395    fn parameters(&self) -> Vec<Parameter> {
396        self.stages.iter().flat_map(|s| s.parameters()).collect()
397    }
398
399    fn train(&mut self) {
400        for stage in &mut self.stages {
401            stage.train();
402        }
403    }
404
405    fn eval(&mut self) {
406        for stage in &mut self.stages {
407            stage.eval();
408        }
409    }
410
411    fn is_training(&self) -> bool {
412        self.stages.first().is_some_and(|s| s.is_training())
413    }
414}
415
416// =============================================================================
417// Pipeline Memory Stats
418// =============================================================================
419
420/// Memory statistics for pipeline parallelism.
421#[derive(Debug, Clone)]
422pub struct PipelineMemoryStats {
423    /// Number of stages
424    pub num_stages: usize,
425    /// Number of microbatches
426    pub num_microbatches: usize,
427    /// Peak activations stored (per stage)
428    pub peak_activations_per_stage: usize,
429    /// Schedule used
430    pub schedule: PipelineSchedule,
431}
432
433impl PipelineMemoryStats {
434    /// Estimates peak activation memory for GPipe.
435    pub fn gpipe_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
436        // GPipe stores all microbatch activations
437        num_stages * num_microbatches
438    }
439
440    /// Estimates peak activation memory for 1F1B.
441    pub fn one_f_one_b_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
442        // 1F1B stores at most num_stages activations in steady state
443        num_stages.min(num_microbatches)
444    }
445}
446
447// =============================================================================
448// Tests
449// =============================================================================
450
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    /// Pass-through module for exercising pipeline plumbing.
    ///
    /// `Pipeline` requires `M: Clone`, which `Linear` does not implement.
    #[derive(Clone)]
    struct IdentityModule {
        // Nominal layer width, kept so call sites read like real layer
        // construction; `forward` never uses it.
        #[allow(dead_code)]
        size: usize,
        training: bool,
    }

    impl IdentityModule {
        fn new(size: usize) -> Self {
            Self {
                size,
                training: true,
            }
        }
    }

    impl Module for IdentityModule {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn parameters(&self) -> Vec<Parameter> {
            Vec::new()
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Builds a single-stage identity pipeline on a mock process group.
    fn single_stage_pipeline(num_microbatches: usize) -> Pipeline<IdentityModule> {
        Pipeline::from_modules(vec![IdentityModule::new(4)], ProcessGroup::mock())
            .num_microbatches(num_microbatches)
    }

    /// Batch sizes of the microbatches `pipeline` cuts `input` into.
    fn microbatch_sizes(pipeline: &Pipeline<IdentityModule>, input: &Variable) -> Vec<usize> {
        pipeline
            .split_microbatches(input)
            .iter()
            .map(|mb| mb.data().shape()[0])
            .collect()
    }

    #[test]
    fn test_pipeline_schedule_default() {
        assert_eq!(
            PipelineSchedule::default(),
            PipelineSchedule::OneFOneBSchedule
        );
    }

    #[test]
    fn test_pipeline_stage_creation() {
        let module = Linear::new(10, 5);
        let stage = PipelineStage::new(module, 0, 0);

        assert_eq!(stage.stage_id(), 0);
        assert_eq!(stage.device_rank(), 0);
    }

    #[test]
    fn test_pipeline_creation() {
        let modules = vec![
            IdentityModule::new(10),
            IdentityModule::new(8),
            IdentityModule::new(6),
        ];
        let pg = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(modules, pg)
            .schedule(PipelineSchedule::GPipe)
            .num_microbatches(2);

        assert_eq!(pipeline.num_stages(), 3);
        assert_eq!(pipeline.get_schedule(), PipelineSchedule::GPipe);
    }

    #[test]
    fn test_pipeline_forward() {
        let pipeline = single_stage_pipeline(1);

        let input = Variable::new(Tensor::randn(&[2, 4]), false);
        let output = pipeline.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_pipeline_memory_stats() {
        let gpipe = PipelineMemoryStats::gpipe_peak_activations(4, 8);
        let one_f_one_b = PipelineMemoryStats::one_f_one_b_peak_activations(4, 8);

        assert_eq!(gpipe, 32); // 4 * 8
        assert_eq!(one_f_one_b, 4); // min(4, 8)
    }

    #[test]
    fn test_split_microbatches() {
        let pipeline = single_stage_pipeline(2);

        let input = Variable::new(Tensor::randn(&[4, 4]), false);
        let microbatches = pipeline.split_microbatches(&input);

        assert_eq!(microbatches.len(), 2);
        assert_eq!(microbatches[0].data().shape()[0], 2);
        assert_eq!(microbatches[1].data().shape()[0], 2);
    }

    #[test]
    fn test_split_microbatches_uneven() {
        let pipeline = single_stage_pipeline(2);
        let input = Variable::new(Tensor::randn(&[5, 3]), false);

        assert_eq!(microbatch_sizes(&pipeline, &input), vec![3, 2]);
    }

    #[test]
    fn test_split_more_microbatches_than_samples() {
        let pipeline = single_stage_pipeline(4);
        let input = Variable::new(Tensor::randn(&[3, 2]), false);

        assert_eq!(microbatch_sizes(&pipeline, &input), vec![1, 1, 1]);
    }

    #[test]
    fn test_zero_microbatches_clamped_to_one() {
        let pipeline = single_stage_pipeline(0);
        let input = Variable::new(Tensor::randn(&[2, 4]), false);

        assert_eq!(microbatch_sizes(&pipeline, &input), vec![2]);
    }

    #[test]
    fn test_microbatched_forward_preserves_data() {
        for schedule in [
            PipelineSchedule::GPipe,
            PipelineSchedule::OneFOneBSchedule,
            PipelineSchedule::InterleavedOneFOneB,
        ] {
            let pipeline = single_stage_pipeline(3).schedule(schedule);
            let input = Variable::new(Tensor::randn(&[6, 4]), false);
            let output = pipeline.forward(&input);

            assert_eq!(output.data().shape(), &[6, 4]);
            assert_eq!(output.data().to_vec(), input.data().to_vec());
        }
    }
}