Skip to main content

axonml_distributed/
pipeline.rs

//! Pipeline Parallelism
//!
//! Implements pipeline parallelism for training large models across multiple devices.
//! The model is split into stages, and each stage runs on a different device.
//! Note: the 1F1B and interleaved schedules currently fall back to the GPipe forward loop.
//!
//! # Example
//! ```rust,ignore
//! use axonml_distributed::pipeline::{Pipeline, PipelineSchedule};
//!
//! // Split the model into 4 stages across 4 GPUs
//! let pipeline = Pipeline::from_modules(stages, process_group)
//!     .schedule(PipelineSchedule::GPipe)
//!     .num_microbatches(4);
//!
//! let output = pipeline.forward(&input);
//! ```
//!
//! @version 0.1.0
19
20use crate::process_group::ProcessGroup;
21use axonml_autograd::Variable;
22use axonml_nn::{Module, Parameter};
23use axonml_tensor::Tensor;
24
25// =============================================================================
26// Pipeline Schedule
27// =============================================================================
28
/// Order in which microbatch work is scheduled across pipeline stages.
///
/// The default is [`PipelineSchedule::OneFOneBSchedule`].
///
/// NOTE: `Pipeline` currently executes every variant with the GPipe
/// fill-drain forward loop; the 1F1B variants record intent only for now.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PipelineSchedule {
    /// GPipe: fill-drain schedule; all microbatch forwards run before any backward.
    GPipe,
    /// 1F1B: alternates one forward with one backward to bound the number of
    /// activations kept alive.
    #[default]
    OneFOneBSchedule,
    /// Interleaved 1F1B for virtual pipeline parallelism.
    InterleavedOneFOneB,
}
45
46// =============================================================================
47// Pipeline Stage
48// =============================================================================
49
50/// A stage in the pipeline.
51pub struct PipelineStage<M: Module> {
52    /// The module for this stage
53    module: M,
54    /// Stage index (0 = first stage)
55    stage_id: usize,
56    /// Device/rank this stage runs on
57    device_rank: usize,
58}
59
60impl<M: Module> PipelineStage<M> {
61    /// Creates a new pipeline stage.
62    pub fn new(module: M, stage_id: usize, device_rank: usize) -> Self {
63        Self {
64            module,
65            stage_id,
66            device_rank,
67        }
68    }
69
70    /// Returns the stage ID.
71    pub fn stage_id(&self) -> usize {
72        self.stage_id
73    }
74
75    /// Returns the device rank.
76    pub fn device_rank(&self) -> usize {
77        self.device_rank
78    }
79
80    /// Forward pass for this stage.
81    pub fn forward(&self, input: &Variable) -> Variable {
82        self.module.forward(input)
83    }
84}
85
86impl<M: Module> Module for PipelineStage<M> {
87    fn forward(&self, input: &Variable) -> Variable {
88        self.module.forward(input)
89    }
90
91    fn parameters(&self) -> Vec<Parameter> {
92        self.module.parameters()
93    }
94
95    fn train(&mut self) {
96        self.module.train();
97    }
98
99    fn eval(&mut self) {
100        self.module.eval();
101    }
102
103    fn is_training(&self) -> bool {
104        self.module.is_training()
105    }
106}
107
108// =============================================================================
109// Pipeline
110// =============================================================================
111
112/// Pipeline parallel wrapper for distributed training.
113///
114/// Splits model computation across multiple stages, with each stage
115/// potentially running on a different device/rank.
116pub struct Pipeline<M: Module> {
117    /// Pipeline stages
118    stages: Vec<PipelineStage<M>>,
119    /// Process group for communication
120    process_group: ProcessGroup,
121    /// Pipeline schedule
122    schedule: PipelineSchedule,
123    /// Number of microbatches
124    num_microbatches: usize,
125    /// Current rank's stage index (for future use with multi-GPU)
126    #[allow(dead_code)]
127    local_stage: usize,
128}
129
130impl<M: Module + Clone> Pipeline<M> {
131    /// Creates a new pipeline from a list of modules.
132    pub fn from_modules(modules: Vec<M>, process_group: ProcessGroup) -> Self {
133        let world_size = process_group.world_size();
134        let rank = process_group.rank();
135
136        let stages: Vec<PipelineStage<M>> = modules
137            .into_iter()
138            .enumerate()
139            .map(|(i, m)| PipelineStage::new(m, i, i % world_size))
140            .collect();
141
142        let local_stage = stages
143            .iter()
144            .position(|s| s.device_rank == rank)
145            .unwrap_or(0);
146
147        Self {
148            stages,
149            process_group,
150            schedule: PipelineSchedule::default(),
151            num_microbatches: 1,
152            local_stage,
153        }
154    }
155
156    /// Builder: set pipeline schedule.
157    pub fn schedule(mut self, schedule: PipelineSchedule) -> Self {
158        self.schedule = schedule;
159        self
160    }
161
162    /// Builder: set number of microbatches.
163    pub fn num_microbatches(mut self, num: usize) -> Self {
164        self.num_microbatches = num.max(1);
165        self
166    }
167
168    /// Returns number of stages.
169    pub fn num_stages(&self) -> usize {
170        self.stages.len()
171    }
172
173    /// Returns the current schedule.
174    pub fn get_schedule(&self) -> PipelineSchedule {
175        self.schedule
176    }
177
178    /// Forward pass through the pipeline.
179    ///
180    /// For the first stage, takes input and forwards to next stage.
181    /// For intermediate stages, receives from previous, forwards to next.
182    /// For the last stage, receives from previous, returns output.
183    pub fn forward(&self, input: &Variable) -> Variable {
184        match self.schedule {
185            PipelineSchedule::GPipe => self.forward_gpipe(input),
186            PipelineSchedule::OneFOneBSchedule => self.forward_1f1b(input),
187            PipelineSchedule::InterleavedOneFOneB => self.forward_interleaved(input),
188        }
189    }
190
191    /// GPipe schedule: fill-drain with all forwards then all backwards.
192    fn forward_gpipe(&self, input: &Variable) -> Variable {
193        let rank = self.process_group.rank();
194        let num_stages = self.stages.len();
195
196        // Split input into microbatches
197        let microbatches = self.split_microbatches(input);
198
199        // Process all microbatches through pipeline
200        let mut outputs = Vec::new();
201
202        for microbatch in microbatches {
203            let mut activation = microbatch;
204
205            // Forward through all stages
206            for (stage_idx, stage) in self.stages.iter().enumerate() {
207                if stage.device_rank == rank {
208                    activation = stage.forward(&activation);
209                }
210
211                // Send to next stage if not last
212                if stage_idx < num_stages - 1 {
213                    let next_rank = self.stages[stage_idx + 1].device_rank;
214                    if stage.device_rank == rank {
215                        // Send activation to next stage
216                        self.send_activation(&activation, next_rank);
217                    } else if next_rank == rank {
218                        // Receive activation from previous stage
219                        activation = self.recv_activation(stage.device_rank, activation.shape());
220                    }
221                }
222            }
223
224            // Last stage collects output
225            if self.stages.last().map(|s| s.device_rank) == Some(rank) {
226                outputs.push(activation);
227            }
228        }
229
230        // Combine outputs
231        self.combine_microbatches(&outputs)
232    }
233
234    /// 1F1B schedule: memory-efficient interleaved forward/backward.
235    fn forward_1f1b(&self, input: &Variable) -> Variable {
236        // For simplicity, fall back to GPipe in this implementation
237        // Full 1F1B requires careful scheduling of forward/backward passes
238        self.forward_gpipe(input)
239    }
240
241    /// Interleaved 1F1B for virtual pipeline parallelism.
242    fn forward_interleaved(&self, input: &Variable) -> Variable {
243        // For simplicity, fall back to GPipe
244        self.forward_gpipe(input)
245    }
246
247    /// Splits input into microbatches.
248    fn split_microbatches(&self, input: &Variable) -> Vec<Variable> {
249        let data = input.data();
250        let batch_size = data.shape()[0];
251        let microbatch_size = (batch_size + self.num_microbatches - 1) / self.num_microbatches;
252
253        let mut microbatches = Vec::new();
254        let flat_data = data.to_vec();
255        let elements_per_sample: usize = data.shape()[1..].iter().product();
256
257        for i in 0..self.num_microbatches {
258            let start = i * microbatch_size;
259            let end = ((i + 1) * microbatch_size).min(batch_size);
260
261            if start >= batch_size {
262                break;
263            }
264
265            let mb_size = end - start;
266            let start_idx = start * elements_per_sample;
267            let end_idx = end * elements_per_sample;
268            let mb_data: Vec<f32> = flat_data[start_idx..end_idx].to_vec();
269
270            let mut shape = data.shape().to_vec();
271            shape[0] = mb_size;
272            let tensor = Tensor::from_vec(mb_data, &shape).unwrap();
273            microbatches.push(Variable::new(tensor, input.requires_grad()));
274        }
275
276        microbatches
277    }
278
279    /// Combines microbatch outputs.
280    fn combine_microbatches(&self, outputs: &[Variable]) -> Variable {
281        if outputs.is_empty() {
282            return Variable::new(Tensor::zeros(&[0]), false);
283        }
284
285        if outputs.len() == 1 {
286            return outputs[0].clone();
287        }
288
289        // Concatenate along batch dimension
290        let mut all_data = Vec::new();
291        let mut total_batch = 0;
292        let shape = outputs[0].data().shape().to_vec();
293
294        for output in outputs {
295            all_data.extend(output.data().to_vec());
296            total_batch += output.data().shape()[0];
297        }
298
299        let mut new_shape = shape;
300        new_shape[0] = total_batch;
301        let tensor = Tensor::from_vec(all_data, &new_shape).unwrap();
302        Variable::new(tensor, outputs[0].requires_grad())
303    }
304
305    /// Sends activation to another rank.
306    fn send_activation(&self, activation: &Variable, dest_rank: usize) {
307        let mut tensor = activation.data().clone();
308        self.process_group.send_tensor(&mut tensor, dest_rank);
309    }
310
311    /// Receives activation from another rank.
312    fn recv_activation(&self, src_rank: usize, shape: Vec<usize>) -> Variable {
313        let tensor = self.process_group.recv_tensor(src_rank, &shape);
314        Variable::new(tensor, true)
315    }
316}
317
318impl<M: Module + Clone> Module for Pipeline<M> {
319    fn forward(&self, input: &Variable) -> Variable {
320        Pipeline::forward(self, input)
321    }
322
323    fn parameters(&self) -> Vec<Parameter> {
324        self.stages
325            .iter()
326            .flat_map(|s| s.parameters())
327            .collect()
328    }
329
330    fn train(&mut self) {
331        for stage in &mut self.stages {
332            stage.train();
333        }
334    }
335
336    fn eval(&mut self) {
337        for stage in &mut self.stages {
338            stage.eval();
339        }
340    }
341
342    fn is_training(&self) -> bool {
343        self.stages.first().map(|s| s.is_training()).unwrap_or(false)
344    }
345}
346
347// =============================================================================
348// Pipeline Memory Stats
349// =============================================================================
350
351/// Memory statistics for pipeline parallelism.
352#[derive(Debug, Clone)]
353pub struct PipelineMemoryStats {
354    /// Number of stages
355    pub num_stages: usize,
356    /// Number of microbatches
357    pub num_microbatches: usize,
358    /// Peak activations stored (per stage)
359    pub peak_activations_per_stage: usize,
360    /// Schedule used
361    pub schedule: PipelineSchedule,
362}
363
364impl PipelineMemoryStats {
365    /// Estimates peak activation memory for GPipe.
366    pub fn gpipe_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
367        // GPipe stores all microbatch activations
368        num_stages * num_microbatches
369    }
370
371    /// Estimates peak activation memory for 1F1B.
372    pub fn one_f_one_b_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
373        // 1F1B stores at most num_stages activations in steady state
374        num_stages.min(num_microbatches)
375    }
376}
377
378// =============================================================================
379// Tests
380// =============================================================================
381
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    /// Pass-through module for pipeline tests.
    ///
    /// `Linear` cannot be used here because it does not implement `Clone`,
    /// which the `Module` impl for `Pipeline` requires.
    #[derive(Clone)]
    struct IdentityModule {
        /// Nominal feature width, recorded only to make test setups readable.
        #[allow(dead_code)] // never read: `forward` returns its input unchanged
        size: usize,
        training: bool,
    }

    impl IdentityModule {
        fn new(size: usize) -> Self {
            Self { size, training: true }
        }
    }

    impl Module for IdentityModule {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn parameters(&self) -> Vec<Parameter> {
            Vec::new()
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Builds a single-stage identity pipeline on a mock process group.
    fn identity_pipeline(num_microbatches: usize) -> Pipeline<IdentityModule> {
        Pipeline::from_modules(vec![IdentityModule::new(4)], ProcessGroup::mock())
            .num_microbatches(num_microbatches)
    }

    #[test]
    fn test_pipeline_schedule_default() {
        assert_eq!(PipelineSchedule::default(), PipelineSchedule::OneFOneBSchedule);
    }

    #[test]
    fn test_pipeline_stage_creation() {
        let module = Linear::new(10, 5);
        let stage = PipelineStage::new(module, 0, 0);

        assert_eq!(stage.stage_id(), 0);
        assert_eq!(stage.device_rank(), 0);
    }

    #[test]
    fn test_pipeline_creation() {
        let modules = vec![
            IdentityModule::new(10),
            IdentityModule::new(8),
            IdentityModule::new(6),
        ];
        let pg = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(modules, pg)
            .schedule(PipelineSchedule::GPipe)
            .num_microbatches(2);

        assert_eq!(pipeline.num_stages(), 3);
        assert_eq!(pipeline.get_schedule(), PipelineSchedule::GPipe);
    }

    #[test]
    fn test_pipeline_forward() {
        let modules = vec![IdentityModule::new(4)];
        let pg = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(modules, pg);

        let input = Variable::new(Tensor::randn(&[2, 4]), false);
        let output = pipeline.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_pipeline_memory_stats() {
        let gpipe = PipelineMemoryStats::gpipe_peak_activations(4, 8);
        let one_f_one_b = PipelineMemoryStats::one_f_one_b_peak_activations(4, 8);

        assert_eq!(gpipe, 32); // 4 * 8
        assert_eq!(one_f_one_b, 4); // min(4, 8)
    }

    #[test]
    fn test_split_microbatches() {
        let modules = vec![IdentityModule::new(4)];
        let pg = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(modules, pg).num_microbatches(2);

        let input = Variable::new(Tensor::randn(&[4, 4]), false);
        let microbatches = pipeline.split_microbatches(&input);

        assert_eq!(microbatches.len(), 2);
        assert_eq!(microbatches[0].data().shape()[0], 2);
        assert_eq!(microbatches[1].data().shape()[0], 2);
    }

    #[test]
    fn test_split_microbatches_uneven() {
        let pipeline = identity_pipeline(2);
        let input = Variable::new(Tensor::randn(&[5, 4]), false);
        let sizes: Vec<usize> = pipeline
            .split_microbatches(&input)
            .iter()
            .map(|mb| mb.data().shape()[0])
            .collect();

        // ceil(5 / 2) = 3 rows in the first microbatch, the remainder in the second.
        assert_eq!(sizes, vec![3, 2]);
    }

    #[test]
    fn test_split_more_microbatches_than_rows() {
        let pipeline = identity_pipeline(4);
        let input = Variable::new(Tensor::randn(&[2, 4]), false);
        let microbatches = pipeline.split_microbatches(&input);

        // Only as many one-row microbatches as there are rows.
        assert_eq!(microbatches.len(), 2);
        assert!(microbatches.iter().all(|mb| mb.data().shape() == &[1, 4]));
    }

    #[test]
    fn test_split_then_combine_roundtrip() {
        let pipeline = identity_pipeline(3);
        let input = Variable::new(Tensor::randn(&[5, 4]), false);
        let combined = pipeline.combine_microbatches(&pipeline.split_microbatches(&input));

        assert_eq!(combined.data().shape(), &[5, 4]);
        assert_eq!(combined.data().to_vec(), input.data().to_vec());
    }
}