// crates/axonml-distributed/src/pipeline.rs
//! Pipeline Parallelism
//!
//! # File
//! `crates/axonml-distributed/src/pipeline.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;

use crate::process_group::ProcessGroup;

// =============================================================================
// Pipeline Schedule
// =============================================================================

/// Strategy that decides how forward/backward work for microbatches is
/// ordered across the pipeline stages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PipelineSchedule {
    /// GPipe: all forwards first, then all backwards (fill-drain), with
    /// synchronized parameter updates at the end.
    GPipe,
    /// 1F1B: alternate one forward with one backward to bound the number of
    /// live activations; the default schedule.
    #[default]
    OneFOneBSchedule,
    /// Interleaved 1F1B (virtual pipeline stages) for better bubble efficiency.
    InterleavedOneFOneB,
}

38// =============================================================================
39// Pipeline Stage
40// =============================================================================
41
42/// A stage in the pipeline.
43pub struct PipelineStage<M: Module> {
44    /// The module for this stage
45    module: M,
46    /// Stage index (0 = first stage)
47    stage_id: usize,
48    /// Device/rank this stage runs on
49    device_rank: usize,
50}
51
52impl<M: Module> PipelineStage<M> {
53    /// Creates a new pipeline stage.
54    pub fn new(module: M, stage_id: usize, device_rank: usize) -> Self {
55        Self {
56            module,
57            stage_id,
58            device_rank,
59        }
60    }
61
62    /// Returns the stage ID.
63    pub fn stage_id(&self) -> usize {
64        self.stage_id
65    }
66
67    /// Returns the device rank.
68    pub fn device_rank(&self) -> usize {
69        self.device_rank
70    }
71
72    /// Forward pass for this stage.
73    pub fn forward(&self, input: &Variable) -> Variable {
74        self.module.forward(input)
75    }
76}
77
78impl<M: Module> Module for PipelineStage<M> {
79    fn forward(&self, input: &Variable) -> Variable {
80        self.module.forward(input)
81    }
82
83    fn parameters(&self) -> Vec<Parameter> {
84        self.module.parameters()
85    }
86
87    fn train(&mut self) {
88        self.module.train();
89    }
90
91    fn eval(&mut self) {
92        self.module.eval();
93    }
94
95    fn is_training(&self) -> bool {
96        self.module.is_training()
97    }
98}
99
100// =============================================================================
101// Pipeline
102// =============================================================================
103
104/// Pipeline parallel wrapper for distributed training.
105///
106/// Splits model computation across multiple stages, with each stage
107/// potentially running on a different device/rank.
108pub struct Pipeline<M: Module> {
109    /// Pipeline stages
110    stages: Vec<PipelineStage<M>>,
111    /// Process group for communication
112    process_group: ProcessGroup,
113    /// Pipeline schedule
114    schedule: PipelineSchedule,
115    /// Number of microbatches
116    num_microbatches: usize,
117    /// Current rank's stage index (for future use with multi-GPU)
118    #[allow(dead_code)]
119    local_stage: usize,
120}
121
122impl<M: Module + Clone> Pipeline<M> {
123    /// Creates a new pipeline from a list of modules.
124    pub fn from_modules(modules: Vec<M>, process_group: ProcessGroup) -> Self {
125        let world_size = process_group.world_size();
126        let rank = process_group.rank();
127
128        let stages: Vec<PipelineStage<M>> = modules
129            .into_iter()
130            .enumerate()
131            .map(|(i, m)| PipelineStage::new(m, i, i % world_size))
132            .collect();
133
134        let local_stage = stages
135            .iter()
136            .position(|s| s.device_rank == rank)
137            .unwrap_or(0);
138
139        Self {
140            stages,
141            process_group,
142            schedule: PipelineSchedule::default(),
143            num_microbatches: 1,
144            local_stage,
145        }
146    }
147
148    /// Builder: set pipeline schedule.
149    pub fn schedule(mut self, schedule: PipelineSchedule) -> Self {
150        self.schedule = schedule;
151        self
152    }
153
154    /// Builder: set number of microbatches.
155    pub fn num_microbatches(mut self, num: usize) -> Self {
156        self.num_microbatches = num.max(1);
157        self
158    }
159
160    /// Returns number of stages.
161    pub fn num_stages(&self) -> usize {
162        self.stages.len()
163    }
164
165    /// Returns the current schedule.
166    pub fn get_schedule(&self) -> PipelineSchedule {
167        self.schedule
168    }
169
170    /// Forward pass through the pipeline.
171    ///
172    /// For the first stage, takes input and forwards to next stage.
173    /// For intermediate stages, receives from previous, forwards to next.
174    /// For the last stage, receives from previous, returns output.
175    pub fn forward(&self, input: &Variable) -> Variable {
176        match self.schedule {
177            PipelineSchedule::GPipe => self.forward_gpipe(input),
178            PipelineSchedule::OneFOneBSchedule => self.forward_1f1b(input),
179            PipelineSchedule::InterleavedOneFOneB => self.forward_interleaved(input),
180        }
181    }
182
183    /// GPipe schedule: fill-drain with all forwards then all backwards.
184    fn forward_gpipe(&self, input: &Variable) -> Variable {
185        let rank = self.process_group.rank();
186        let num_stages = self.stages.len();
187
188        // Split input into microbatches
189        let microbatches = self.split_microbatches(input);
190
191        // Process all microbatches through pipeline
192        let mut outputs = Vec::new();
193
194        for microbatch in microbatches {
195            let mut activation = microbatch;
196
197            // Forward through all stages
198            for (stage_idx, stage) in self.stages.iter().enumerate() {
199                if stage.device_rank == rank {
200                    activation = stage.forward(&activation);
201                }
202
203                // Send to next stage if not last
204                if stage_idx < num_stages - 1 {
205                    let next_rank = self.stages[stage_idx + 1].device_rank;
206                    if stage.device_rank == rank {
207                        // Send activation to next stage
208                        self.send_activation(&activation, next_rank);
209                    } else if next_rank == rank {
210                        // Receive activation from previous stage
211                        activation = self.recv_activation(stage.device_rank, activation.shape());
212                    }
213                }
214            }
215
216            // Last stage collects output
217            if self.stages.last().map(|s| s.device_rank) == Some(rank) {
218                outputs.push(activation);
219            }
220        }
221
222        // Combine outputs
223        self.combine_microbatches(&outputs)
224    }
225
226    /// 1F1B schedule: memory-efficient interleaved forward/backward.
227    fn forward_1f1b(&self, input: &Variable) -> Variable {
228        // For simplicity, fall back to GPipe in this implementation
229        // Full 1F1B requires careful scheduling of forward/backward passes
230        self.forward_gpipe(input)
231    }
232
233    /// Interleaved 1F1B for virtual pipeline parallelism.
234    fn forward_interleaved(&self, input: &Variable) -> Variable {
235        // For simplicity, fall back to GPipe
236        self.forward_gpipe(input)
237    }
238
239    /// Splits input into microbatches.
240    fn split_microbatches(&self, input: &Variable) -> Vec<Variable> {
241        let data = input.data();
242        let batch_size = data.shape()[0];
243        let microbatch_size = batch_size.div_ceil(self.num_microbatches);
244
245        let mut microbatches = Vec::new();
246        let flat_data = data.to_vec();
247        let elements_per_sample: usize = data.shape()[1..].iter().product();
248
249        for i in 0..self.num_microbatches {
250            let start = i * microbatch_size;
251            let end = ((i + 1) * microbatch_size).min(batch_size);
252
253            if start >= batch_size {
254                break;
255            }
256
257            let mb_size = end - start;
258            let start_idx = start * elements_per_sample;
259            let end_idx = end * elements_per_sample;
260            let mb_data: Vec<f32> = flat_data[start_idx..end_idx].to_vec();
261
262            let mut shape = data.shape().to_vec();
263            shape[0] = mb_size;
264            let tensor = Tensor::from_vec(mb_data, &shape).unwrap();
265            microbatches.push(Variable::new(tensor, input.requires_grad()));
266        }
267
268        microbatches
269    }
270
271    /// Combines microbatch outputs.
272    fn combine_microbatches(&self, outputs: &[Variable]) -> Variable {
273        if outputs.is_empty() {
274            return Variable::new(Tensor::zeros(&[0]), false);
275        }
276
277        if outputs.len() == 1 {
278            return outputs[0].clone();
279        }
280
281        // Concatenate along batch dimension
282        let mut all_data = Vec::new();
283        let mut total_batch = 0;
284        let shape = outputs[0].data().shape().to_vec();
285
286        for output in outputs {
287            all_data.extend(output.data().to_vec());
288            total_batch += output.data().shape()[0];
289        }
290
291        let mut new_shape = shape;
292        new_shape[0] = total_batch;
293        let tensor = Tensor::from_vec(all_data, &new_shape).unwrap();
294        Variable::new(tensor, outputs[0].requires_grad())
295    }
296
297    /// Sends activation to another rank.
298    fn send_activation(&self, activation: &Variable, dest_rank: usize) {
299        let mut tensor = activation.data().clone();
300        self.process_group.send_tensor(&mut tensor, dest_rank);
301    }
302
303    /// Receives activation from another rank.
304    fn recv_activation(&self, src_rank: usize, shape: Vec<usize>) -> Variable {
305        let tensor = self.process_group.recv_tensor(src_rank, &shape);
306        Variable::new(tensor, true)
307    }
308}
309
310impl<M: Module + Clone> Module for Pipeline<M> {
311    fn forward(&self, input: &Variable) -> Variable {
312        Pipeline::forward(self, input)
313    }
314
315    fn parameters(&self) -> Vec<Parameter> {
316        self.stages.iter().flat_map(|s| s.parameters()).collect()
317    }
318
319    fn train(&mut self) {
320        for stage in &mut self.stages {
321            stage.train();
322        }
323    }
324
325    fn eval(&mut self) {
326        for stage in &mut self.stages {
327            stage.eval();
328        }
329    }
330
331    fn is_training(&self) -> bool {
332        self.stages.first().is_some_and(|s| s.is_training())
333    }
334}
335
336// =============================================================================
337// Pipeline Memory Stats
338// =============================================================================
339
340/// Memory statistics for pipeline parallelism.
341#[derive(Debug, Clone)]
342pub struct PipelineMemoryStats {
343    /// Number of stages
344    pub num_stages: usize,
345    /// Number of microbatches
346    pub num_microbatches: usize,
347    /// Peak activations stored (per stage)
348    pub peak_activations_per_stage: usize,
349    /// Schedule used
350    pub schedule: PipelineSchedule,
351}
352
353impl PipelineMemoryStats {
354    /// Estimates peak activation memory for GPipe.
355    pub fn gpipe_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
356        // GPipe stores all microbatch activations
357        num_stages * num_microbatches
358    }
359
360    /// Estimates peak activation memory for 1F1B.
361    pub fn one_f_one_b_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
362        // 1F1B stores at most num_stages activations in steady state
363        num_stages.min(num_microbatches)
364    }
365}
366
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    /// Minimal pass-through module for building pipelines in tests
    /// (`Linear` does not implement `Clone`).
    #[derive(Clone)]
    struct IdentityModule {
        size: usize,
        training: bool,
    }

    impl IdentityModule {
        fn new(size: usize) -> Self {
            Self { size, training: true }
        }
    }

    impl Module for IdentityModule {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn parameters(&self) -> Vec<Parameter> {
            Vec::new()
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    #[test]
    fn test_pipeline_schedule_default() {
        let default_schedule = PipelineSchedule::default();
        assert_eq!(default_schedule, PipelineSchedule::OneFOneBSchedule);
    }

    #[test]
    fn test_pipeline_stage_creation() {
        let stage = PipelineStage::new(Linear::new(10, 5), 0, 0);

        assert_eq!(stage.stage_id(), 0);
        assert_eq!(stage.device_rank(), 0);
    }

    #[test]
    fn test_pipeline_creation() {
        let segments = vec![
            IdentityModule::new(10),
            IdentityModule::new(8),
            IdentityModule::new(6),
        ];
        let group = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(segments, group)
            .schedule(PipelineSchedule::GPipe)
            .num_microbatches(2);

        assert_eq!(pipeline.num_stages(), 3);
        assert_eq!(pipeline.get_schedule(), PipelineSchedule::GPipe);
    }

    #[test]
    fn test_pipeline_forward() {
        let group = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(vec![IdentityModule::new(4)], group);

        let input = Variable::new(Tensor::randn(&[2, 4]), false);
        let output = pipeline.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_pipeline_memory_stats() {
        // GPipe holds every microbatch activation: 4 stages * 8 microbatches.
        assert_eq!(PipelineMemoryStats::gpipe_peak_activations(4, 8), 32);
        // 1F1B caps live activations at min(stages, microbatches).
        assert_eq!(PipelineMemoryStats::one_f_one_b_peak_activations(4, 8), 4);
    }

    #[test]
    fn test_split_microbatches() {
        let group = ProcessGroup::mock();
        let pipeline =
            Pipeline::from_modules(vec![IdentityModule::new(4)], group).num_microbatches(2);

        let input = Variable::new(Tensor::randn(&[4, 4]), false);
        let chunks = pipeline.split_microbatches(&input);

        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].data().shape()[0], 2);
        assert_eq!(chunks[1].data().shape()[0], 2);
    }
}