// crates/axonml-distributed/src/pipeline.rs

1//! Pipeline Parallelism
2//!
3//! # File
4//! `crates/axonml-distributed/src/pipeline.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr. — AutomataNexus LLC
8//! ORCID: 0009-0005-2158-7060
9//!
10//! # Updated
11//! April 14, 2026 11:15 PM EST
12//!
13//! # Disclaimer
14//! Use at own risk. This software is provided "as is", without warranty of any
15//! kind, express or implied. The author and AutomataNexus shall not be held
16//! liable for any damages arising from the use of this software.
17
18use crate::process_group::ProcessGroup;
19use axonml_autograd::Variable;
20use axonml_nn::{Module, Parameter};
21use axonml_tensor::Tensor;
22
23// =============================================================================
24// Pipeline Schedule
25// =============================================================================
26
/// Pipeline execution schedule.
///
/// Selects how microbatches flow through the pipeline stages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PipelineSchedule {
    /// GPipe: fill-drain schedule with synchronized updates.
    GPipe,
    /// 1F1B: one forward, one backward per microbatch, for lower peak
    /// activation memory. This is the default schedule.
    #[default]
    OneFOneBSchedule,
    /// Interleaved 1F1B (virtual stages) for better efficiency.
    InterleavedOneFOneB,
}
38
39// =============================================================================
40// Pipeline Stage
41// =============================================================================
42
43/// A stage in the pipeline.
44pub struct PipelineStage<M: Module> {
45    /// The module for this stage
46    module: M,
47    /// Stage index (0 = first stage)
48    stage_id: usize,
49    /// Device/rank this stage runs on
50    device_rank: usize,
51}
52
53impl<M: Module> PipelineStage<M> {
54    /// Creates a new pipeline stage.
55    pub fn new(module: M, stage_id: usize, device_rank: usize) -> Self {
56        Self {
57            module,
58            stage_id,
59            device_rank,
60        }
61    }
62
63    /// Returns the stage ID.
64    pub fn stage_id(&self) -> usize {
65        self.stage_id
66    }
67
68    /// Returns the device rank.
69    pub fn device_rank(&self) -> usize {
70        self.device_rank
71    }
72
73    /// Forward pass for this stage.
74    pub fn forward(&self, input: &Variable) -> Variable {
75        self.module.forward(input)
76    }
77}
78
79impl<M: Module> Module for PipelineStage<M> {
80    fn forward(&self, input: &Variable) -> Variable {
81        self.module.forward(input)
82    }
83
84    fn parameters(&self) -> Vec<Parameter> {
85        self.module.parameters()
86    }
87
88    fn train(&mut self) {
89        self.module.train();
90    }
91
92    fn eval(&mut self) {
93        self.module.eval();
94    }
95
96    fn is_training(&self) -> bool {
97        self.module.is_training()
98    }
99}
100
101// =============================================================================
102// Pipeline
103// =============================================================================
104
105/// Pipeline parallel wrapper for distributed training.
106///
107/// Splits model computation across multiple stages, with each stage
108/// potentially running on a different device/rank.
109pub struct Pipeline<M: Module> {
110    /// Pipeline stages
111    stages: Vec<PipelineStage<M>>,
112    /// Process group for communication
113    process_group: ProcessGroup,
114    /// Pipeline schedule
115    schedule: PipelineSchedule,
116    /// Number of microbatches
117    num_microbatches: usize,
118    /// Current rank's stage index.
119    pub local_stage: usize,
120}
121
122impl<M: Module + Clone> Pipeline<M> {
123    /// Creates a new pipeline from a list of modules.
124    pub fn from_modules(modules: Vec<M>, process_group: ProcessGroup) -> Self {
125        let world_size = process_group.world_size();
126        let rank = process_group.rank();
127
128        let stages: Vec<PipelineStage<M>> = modules
129            .into_iter()
130            .enumerate()
131            .map(|(i, m)| PipelineStage::new(m, i, i % world_size))
132            .collect();
133
134        let local_stage = stages
135            .iter()
136            .position(|s| s.device_rank == rank)
137            .unwrap_or(0);
138
139        Self {
140            stages,
141            process_group,
142            schedule: PipelineSchedule::default(),
143            num_microbatches: 1,
144            local_stage,
145        }
146    }
147
148    /// Builder: set pipeline schedule.
149    pub fn schedule(mut self, schedule: PipelineSchedule) -> Self {
150        self.schedule = schedule;
151        self
152    }
153
154    /// Builder: set number of microbatches.
155    pub fn num_microbatches(mut self, num: usize) -> Self {
156        self.num_microbatches = num.max(1);
157        self
158    }
159
160    /// Returns number of stages.
161    pub fn num_stages(&self) -> usize {
162        self.stages.len()
163    }
164
165    /// Returns the current schedule.
166    pub fn get_schedule(&self) -> PipelineSchedule {
167        self.schedule
168    }
169
170    /// Forward pass through the pipeline.
171    ///
172    /// For the first stage, takes input and forwards to next stage.
173    /// For intermediate stages, receives from previous, forwards to next.
174    /// For the last stage, receives from previous, returns output.
175    pub fn forward(&self, input: &Variable) -> Variable {
176        match self.schedule {
177            PipelineSchedule::GPipe => self.forward_gpipe(input),
178            PipelineSchedule::OneFOneBSchedule => self.forward_1f1b(input),
179            PipelineSchedule::InterleavedOneFOneB => self.forward_interleaved(input),
180        }
181    }
182
183    /// GPipe schedule: fill-drain with all forwards then all backwards.
184    fn forward_gpipe(&self, input: &Variable) -> Variable {
185        let rank = self.process_group.rank();
186        let num_stages = self.stages.len();
187
188        // Split input into microbatches
189        let microbatches = self.split_microbatches(input);
190
191        // Process all microbatches through pipeline
192        let mut outputs = Vec::new();
193
194        for microbatch in microbatches {
195            let mut activation = microbatch;
196
197            // Forward through all stages
198            for (stage_idx, stage) in self.stages.iter().enumerate() {
199                if stage.device_rank == rank {
200                    activation = stage.forward(&activation);
201                }
202
203                // Send to next stage if not last
204                if stage_idx < num_stages - 1 {
205                    let next_rank = self.stages[stage_idx + 1].device_rank;
206                    if stage.device_rank == rank {
207                        // Send activation to next stage
208                        self.send_activation(&activation, next_rank);
209                    } else if next_rank == rank {
210                        // Receive activation from previous stage
211                        activation = self.recv_activation(stage.device_rank, activation.shape());
212                    }
213                }
214            }
215
216            // Last stage collects output
217            if self.stages.last().map(|s| s.device_rank) == Some(rank) {
218                outputs.push(activation);
219            }
220        }
221
222        // Combine outputs
223        self.combine_microbatches(&outputs)
224    }
225
226    /// 1F1B schedule: memory-efficient interleaved forward/backward.
227    ///
228    /// Instead of GPipe's fill-all-then-drain-all, 1F1B interleaves:
229    /// 1. Warmup: forward passes to fill the pipeline (num_stages microbatches)
230    /// 2. Steady state: alternate 1 forward + 1 backward per microbatch
231    /// 3. Cooldown: drain remaining backward passes
232    ///
233    /// This limits peak memory to ~(num_stages) activations instead of
234    /// ~(num_microbatches) in GPipe.
235    fn forward_1f1b(&self, input: &Variable) -> Variable {
236        let rank = self.process_group.rank();
237        let num_stages = self.stages.len();
238
239        let microbatches = self.split_microbatches(input);
240        let num_mb = microbatches.len();
241
242        // If only 1 microbatch or 1 stage, 1F1B degenerates to GPipe
243        if num_mb <= 1 || num_stages <= 1 {
244            return self.forward_gpipe(input);
245        }
246
247        // Activations buffer: we only keep up to num_stages in-flight
248        let mut activations: Vec<Option<Variable>> = Vec::with_capacity(num_mb);
249        let mut outputs: Vec<Option<Variable>> = vec![None; num_mb];
250
251        // Phase 1: Warmup — forward the first min(num_stages, num_mb) microbatches
252        let warmup_count = num_stages.min(num_mb);
253        for mb_idx in 0..warmup_count {
254            let mut activation = microbatches[mb_idx].clone();
255            for (stage_idx, stage) in self.stages.iter().enumerate() {
256                if stage.device_rank == rank {
257                    activation = stage.forward(&activation);
258                }
259                if stage_idx < num_stages - 1 {
260                    let next_rank = self.stages[stage_idx + 1].device_rank;
261                    if stage.device_rank == rank {
262                        self.send_activation(&activation, next_rank);
263                    } else if next_rank == rank {
264                        activation = self.recv_activation(stage.device_rank, activation.shape());
265                    }
266                }
267            }
268            activations.push(Some(activation.clone()));
269            if self.stages.last().map(|s| s.device_rank) == Some(rank) {
270                outputs[mb_idx] = Some(activation);
271            }
272        }
273
274        // Phase 2: Steady state — for each remaining microbatch, do 1 forward + release 1 old
275        for mb_idx in warmup_count..num_mb {
276            // Release the oldest activation (simulating backward freeing memory)
277            let release_idx = mb_idx - warmup_count;
278            if release_idx < activations.len() {
279                activations[release_idx] = None;
280            }
281
282            // Forward the new microbatch
283            let mut activation = microbatches[mb_idx].clone();
284            for (stage_idx, stage) in self.stages.iter().enumerate() {
285                if stage.device_rank == rank {
286                    activation = stage.forward(&activation);
287                }
288                if stage_idx < num_stages - 1 {
289                    let next_rank = self.stages[stage_idx + 1].device_rank;
290                    if stage.device_rank == rank {
291                        self.send_activation(&activation, next_rank);
292                    } else if next_rank == rank {
293                        activation = self.recv_activation(stage.device_rank, activation.shape());
294                    }
295                }
296            }
297            activations.push(Some(activation.clone()));
298            if self.stages.last().map(|s| s.device_rank) == Some(rank) {
299                outputs[mb_idx] = Some(activation);
300            }
301        }
302
303        // Combine outputs (filter out None for non-last-stage ranks)
304        let final_outputs: Vec<Variable> = outputs.into_iter().flatten().collect();
305        self.combine_microbatches(&final_outputs)
306    }
307
308    /// Interleaved 1F1B for virtual pipeline parallelism.
309    ///
310    /// Similar to 1F1B but processes multiple virtual stages per rank in a
311    /// round-robin fashion. Falls back to standard 1F1B when virtual stages
312    /// are not configured (single model chunk per rank).
313    fn forward_interleaved(&self, input: &Variable) -> Variable {
314        // Interleaved 1F1B requires multiple model chunks per rank (virtual stages).
315        // When each rank owns exactly one stage (the common case), this is
316        // equivalent to standard 1F1B.
317        self.forward_1f1b(input)
318    }
319
320    /// Splits input into microbatches.
321    fn split_microbatches(&self, input: &Variable) -> Vec<Variable> {
322        let data = input.data();
323        let batch_size = data.shape()[0];
324        let microbatch_size = batch_size.div_ceil(self.num_microbatches);
325
326        let mut microbatches = Vec::new();
327        let flat_data = data.to_vec();
328        let elements_per_sample: usize = data.shape()[1..].iter().product();
329
330        for i in 0..self.num_microbatches {
331            let start = i * microbatch_size;
332            let end = ((i + 1) * microbatch_size).min(batch_size);
333
334            if start >= batch_size {
335                break;
336            }
337
338            let mb_size = end - start;
339            let start_idx = start * elements_per_sample;
340            let end_idx = end * elements_per_sample;
341            let mb_data: Vec<f32> = flat_data[start_idx..end_idx].to_vec();
342
343            let mut shape = data.shape().to_vec();
344            shape[0] = mb_size;
345            let tensor = Tensor::from_vec(mb_data, &shape).unwrap();
346            microbatches.push(Variable::new(tensor, input.requires_grad()));
347        }
348
349        microbatches
350    }
351
352    /// Combines microbatch outputs.
353    fn combine_microbatches(&self, outputs: &[Variable]) -> Variable {
354        if outputs.is_empty() {
355            return Variable::new(Tensor::zeros(&[0]), false);
356        }
357
358        if outputs.len() == 1 {
359            return outputs[0].clone();
360        }
361
362        // Concatenate along batch dimension
363        let mut all_data = Vec::new();
364        let mut total_batch = 0;
365        let shape = outputs[0].data().shape().to_vec();
366
367        for output in outputs {
368            all_data.extend(output.data().to_vec());
369            total_batch += output.data().shape()[0];
370        }
371
372        let mut new_shape = shape;
373        new_shape[0] = total_batch;
374        let tensor = Tensor::from_vec(all_data, &new_shape).unwrap();
375        Variable::new(tensor, outputs[0].requires_grad())
376    }
377
378    /// Sends activation to another rank.
379    fn send_activation(&self, activation: &Variable, dest_rank: usize) {
380        let mut tensor = activation.data().clone();
381        self.process_group.send_tensor(&mut tensor, dest_rank);
382    }
383
384    /// Receives activation from another rank.
385    fn recv_activation(&self, src_rank: usize, shape: Vec<usize>) -> Variable {
386        let tensor = self.process_group.recv_tensor(src_rank, &shape);
387        Variable::new(tensor, true)
388    }
389}
390
391impl<M: Module + Clone> Module for Pipeline<M> {
392    fn forward(&self, input: &Variable) -> Variable {
393        Pipeline::forward(self, input)
394    }
395
396    fn parameters(&self) -> Vec<Parameter> {
397        self.stages.iter().flat_map(|s| s.parameters()).collect()
398    }
399
400    fn train(&mut self) {
401        for stage in &mut self.stages {
402            stage.train();
403        }
404    }
405
406    fn eval(&mut self) {
407        for stage in &mut self.stages {
408            stage.eval();
409        }
410    }
411
412    fn is_training(&self) -> bool {
413        self.stages.first().is_some_and(|s| s.is_training())
414    }
415}
416
417// =============================================================================
418// Pipeline Memory Stats
419// =============================================================================
420
421/// Memory statistics for pipeline parallelism.
422#[derive(Debug, Clone)]
423pub struct PipelineMemoryStats {
424    /// Number of stages
425    pub num_stages: usize,
426    /// Number of microbatches
427    pub num_microbatches: usize,
428    /// Peak activations stored (per stage)
429    pub peak_activations_per_stage: usize,
430    /// Schedule used
431    pub schedule: PipelineSchedule,
432}
433
434impl PipelineMemoryStats {
435    /// Estimates peak activation memory for GPipe.
436    pub fn gpipe_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
437        // GPipe stores all microbatch activations
438        num_stages * num_microbatches
439    }
440
441    /// Estimates peak activation memory for 1F1B.
442    pub fn one_f_one_b_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
443        // 1F1B stores at most num_stages activations in steady state
444        num_stages.min(num_microbatches)
445    }
446}
447
448// =============================================================================
449// Tests
450// =============================================================================
451
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    /// Minimal pass-through module for pipeline tests
    /// (Linear doesn't impl Clone, so it can't be used in `from_modules`).
    #[derive(Clone)]
    struct IdentityModule {
        size: usize,
        training: bool,
    }

    impl IdentityModule {
        fn new(size: usize) -> Self {
            Self { size, training: true }
        }
    }

    impl Module for IdentityModule {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn parameters(&self) -> Vec<Parameter> {
            Vec::new()
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    #[test]
    fn test_pipeline_schedule_default() {
        let default_schedule = PipelineSchedule::default();
        assert_eq!(default_schedule, PipelineSchedule::OneFOneBSchedule);
    }

    #[test]
    fn test_pipeline_stage_creation() {
        let layer = Linear::new(10, 5);
        let stage = PipelineStage::new(layer, 0, 0);

        assert_eq!(stage.stage_id(), 0);
        assert_eq!(stage.device_rank(), 0);
    }

    #[test]
    fn test_pipeline_creation() {
        let chunks = vec![
            IdentityModule::new(10),
            IdentityModule::new(8),
            IdentityModule::new(6),
        ];
        let group = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(chunks, group)
            .schedule(PipelineSchedule::GPipe)
            .num_microbatches(2);

        assert_eq!(pipeline.num_stages(), 3);
        assert_eq!(pipeline.get_schedule(), PipelineSchedule::GPipe);
    }

    #[test]
    fn test_pipeline_forward() {
        let chunks = vec![IdentityModule::new(4)];
        let group = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(chunks, group);

        let input = Variable::new(Tensor::randn(&[2, 4]), false);
        let output = pipeline.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_pipeline_memory_stats() {
        let gpipe_peak = PipelineMemoryStats::gpipe_peak_activations(4, 8);
        let one_f_one_b_peak = PipelineMemoryStats::one_f_one_b_peak_activations(4, 8);

        assert_eq!(gpipe_peak, 32); // 4 * 8
        assert_eq!(one_f_one_b_peak, 4); // min(4, 8)
    }

    #[test]
    fn test_split_microbatches() {
        let chunks = vec![IdentityModule::new(4)];
        let group = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(chunks, group).num_microbatches(2);

        let input = Variable::new(Tensor::randn(&[4, 4]), false);
        let chunks = pipeline.split_microbatches(&input);

        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].data().shape()[0], 2);
        assert_eq!(chunks[1].data().shape()[0], 2);
    }
}