Skip to main content

entrenar/train/transformer_trainer/
pipeline.rs

//! Pipeline parallelism for transformer pretraining.
//!
//! Splits transformer blocks across multiple workers by layer range.
//! Each worker runs forward/backward only for its assigned blocks.
//! Inter-worker communication passes activations (forward) and
//! gradients (backward) between adjacent pipeline stages.
//!
//! # Architecture
//!
//! ```text
//! Worker 0 (blocks 0-11)    Worker 1 (blocks 12-23)
//! ─────────────────────    ──────────────────────────
//! embed → block[0..12]  →  block[12..24] → lm_head
//!                       ←  grad_activations
//! ```
//!
//! # Schedule: 1F1B (One Forward, One Backward)
//!
//! With M micro-batches and P pipeline stages:
//! 1. Warmup: stage `s` fills the pipeline with `P - 1 - s` forward passes
//! 2. Steady state: alternate 1 forward + 1 backward
//! 3. Cooldown: drain remaining backward passes
//!
//! Pipeline bubble = (P-1) / M of total compute, where P = pipeline stages.
//!
//! # Contract (C-PIPE-001)
//!
//! - Each block is owned by exactly one pipeline stage
//! - Activation tensor shapes are consistent at stage boundaries
//! - Gradient flow is continuous across stage boundaries
31
/// Pipeline stage assignment for a worker.
///
/// Records which contiguous range of transformer blocks the worker owns,
/// whether it also owns the embedding layer or the LM head, and how many
/// micro-batches its 1F1B schedule processes.
#[derive(Debug, Clone)]
pub struct PipelineStage {
    /// Stage index (0 = first, closest to embedding)
    pub stage_id: usize,
    /// Total number of pipeline stages
    pub num_stages: usize,
    /// First block index (inclusive)
    pub block_start: usize,
    /// Last block index (exclusive)
    pub block_end: usize,
    /// Whether this stage owns the embedding layer
    pub has_embedding: bool,
    /// Whether this stage owns the LM head
    pub has_lm_head: bool,
    /// Number of micro-batches for 1F1B schedule
    pub num_micro_batches: usize,
}

impl PipelineStage {
    /// Create a pipeline stage assignment.
    ///
    /// Blocks are split as evenly as possible: every stage receives
    /// `num_blocks / num_stages` blocks and the first
    /// `num_blocks % num_stages` stages receive one extra, so the ranges of
    /// consecutive stages are contiguous and together cover `0..num_blocks`.
    /// When `num_blocks < num_stages` the trailing stages own an empty range.
    ///
    /// # Arguments
    /// * `stage_id` - This worker's stage (0-indexed, must be `< num_stages`)
    /// * `num_stages` - Total pipeline stages (typically 2 or 4, must be > 0)
    /// * `num_blocks` - Total transformer blocks
    /// * `num_micro_batches` - Micro-batches for 1F1B (must be >= num_stages)
    ///
    /// # Panics
    /// Panics if `num_micro_batches < num_stages` (can't fill pipeline),
    /// if `num_stages == 0`, or if `stage_id >= num_stages`.
    pub fn new(
        stage_id: usize,
        num_stages: usize,
        num_blocks: usize,
        num_micro_batches: usize,
    ) -> Self {
        assert!(
            num_micro_batches >= num_stages,
            "need at least {num_stages} micro-batches to fill pipeline, got {num_micro_batches}"
        );
        // Without these two checks a bad argument either divides by zero or
        // silently produces a block range outside `0..num_blocks`.
        assert!(num_stages > 0, "pipeline needs at least one stage");
        assert!(
            stage_id < num_stages,
            "stage_id {stage_id} out of range for {num_stages} pipeline stages"
        );

        let blocks_per_stage = num_blocks / num_stages;
        let remainder = num_blocks % num_stages;

        // Every earlier stage contributes `blocks_per_stage` blocks, and each
        // of the first `remainder` stages contributes one more.
        let block_start = stage_id * blocks_per_stage + stage_id.min(remainder);
        let block_end = block_start + blocks_per_stage + usize::from(stage_id < remainder);

        Self {
            stage_id,
            num_stages,
            block_start,
            block_end,
            has_embedding: stage_id == 0,
            has_lm_head: stage_id == num_stages - 1,
            num_micro_batches,
        }
    }

    /// Number of blocks in this stage.
    pub fn num_blocks(&self) -> usize {
        self.block_end - self.block_start
    }

    /// Whether this is the first pipeline stage.
    pub fn is_first(&self) -> bool {
        self.stage_id == 0
    }

    /// Whether this is the last pipeline stage.
    ///
    /// Returns `false` (instead of underflowing) when `num_stages` is 0,
    /// which can only happen if the public fields were set by hand.
    pub fn is_last(&self) -> bool {
        self.num_stages.checked_sub(1) == Some(self.stage_id)
    }

    /// Compute pipeline bubble fraction.
    ///
    /// Returns `(P - 1) / M` where P = stages, M = micro-batches. The value
    /// shrinks towards 0 as the number of micro-batches grows.
    pub fn bubble_fraction(&self) -> f64 {
        (self.num_stages as f64 - 1.0) / self.num_micro_batches as f64
    }

    /// Compute pipeline efficiency (1 - bubble fraction).
    pub fn efficiency(&self) -> f64 {
        1.0 - self.bubble_fraction()
    }

    /// Generate 1F1B schedule for this stage.
    ///
    /// Returns the ordered list of actions this stage performs:
    /// 1. warmup — `P - 1 - stage_id` forwards (capped at M),
    /// 2. steady state — alternating one forward and one backward,
    /// 3. cooldown — the remaining backwards once all forwards are issued.
    ///
    /// The result always contains exactly M forwards and M backwards, each
    /// with micro-batch ids `0..M` in increasing order.
    pub fn schedule_1f1b(&self) -> Vec<PipelineAction> {
        let m = self.num_micro_batches;

        // Saturating arithmetic keeps a hand-built stage with
        // `stage_id >= num_stages` from underflowing; it simply gets no warmup.
        let warmup_forwards = self
            .num_stages
            .saturating_sub(self.stage_id)
            .saturating_sub(1)
            .min(m);

        let mut actions = Vec::with_capacity(m.saturating_mul(2));

        // Warmup phase: forward passes to fill pipeline
        actions.extend((0..warmup_forwards).map(PipelineAction::Forward));

        // Steady state + cooldown: forwards never lag behind backwards, so
        // each of the M iterations emits one backward, preceded by a forward
        // while any remain.
        let mut next_fwd = warmup_forwards;
        for next_bwd in 0..m {
            if next_fwd < m {
                actions.push(PipelineAction::Forward(next_fwd));
                next_fwd += 1;
            }
            actions.push(PipelineAction::Backward(next_bwd));
        }

        actions
    }
}

/// Action in a 1F1B pipeline schedule.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PipelineAction {
    /// Run forward pass for micro-batch N
    Forward(usize),
    /// Run backward pass for micro-batch N
    Backward(usize),
}
172
/// Activation buffer between pipeline stages.
///
/// Stores intermediate activations sent from stage N to stage N+1
/// during forward, and gradients sent from stage N+1 to stage N
/// during backward. Each micro-batch has one slot per direction; a slot is
/// an empty vector until something is stored in it.
#[derive(Debug, Clone)]
pub struct PipelineActivationBuffer {
    /// Stored activations per micro-batch: `[micro_batch][seq_len * hidden_size]`
    pub forward_activations: Vec<Vec<f32>>,
    /// Stored gradients per micro-batch (from downstream stage)
    pub backward_gradients: Vec<Vec<f32>>,
    /// Number of micro-batches
    pub num_micro_batches: usize,
    /// Elements per activation tensor (seq_len * hidden_size)
    pub activation_size: usize,
}

impl PipelineActivationBuffer {
    /// Create a new activation buffer with every slot empty.
    ///
    /// # Arguments
    /// * `num_micro_batches` - Number of micro-batches
    /// * `seq_len` - Sequence length
    /// * `hidden_size` - Hidden dimension
    ///
    /// # Panics
    /// Panics if `seq_len * hidden_size` does not fit in `usize`.
    pub fn new(num_micro_batches: usize, seq_len: usize, hidden_size: usize) -> Self {
        // A wrapped product would make every later size check accept tensors
        // of the wrong length, so fail loudly instead.
        let activation_size = seq_len
            .checked_mul(hidden_size)
            .expect("seq_len * hidden_size overflows usize");
        Self {
            forward_activations: vec![Vec::new(); num_micro_batches],
            backward_gradients: vec![Vec::new(); num_micro_batches],
            num_micro_batches,
            activation_size,
        }
    }

    /// Panic with a descriptive message when `micro_batch` has no slot.
    fn check_micro_batch(slots: &[Vec<f32>], micro_batch: usize) {
        assert!(
            micro_batch < slots.len(),
            "micro-batch index {} out of range (buffer holds {} micro-batches)",
            micro_batch,
            slots.len()
        );
    }

    /// Store forward activation for a micro-batch, replacing any previous one.
    ///
    /// # Panics
    /// Panics if `activation.len() != activation_size` or if `micro_batch`
    /// is not a valid slot index.
    pub fn store_activation(&mut self, micro_batch: usize, activation: Vec<f32>) {
        assert_eq!(
            activation.len(),
            self.activation_size,
            "activation size mismatch: expected {}, got {}",
            self.activation_size,
            activation.len()
        );
        Self::check_micro_batch(&self.forward_activations, micro_batch);
        self.forward_activations[micro_batch] = activation;
    }

    /// Store backward gradient for a micro-batch, replacing any previous one.
    ///
    /// # Panics
    /// Panics if `gradient.len() != activation_size` or if `micro_batch`
    /// is not a valid slot index.
    pub fn store_gradient(&mut self, micro_batch: usize, gradient: Vec<f32>) {
        assert_eq!(
            gradient.len(),
            self.activation_size,
            "gradient size mismatch: expected {}, got {}",
            self.activation_size,
            gradient.len()
        );
        Self::check_micro_batch(&self.backward_gradients, micro_batch);
        self.backward_gradients[micro_batch] = gradient;
    }

    /// Get forward activation for a micro-batch.
    ///
    /// Returns an empty slice if nothing has been stored for it yet.
    ///
    /// # Panics
    /// Panics if `micro_batch` is not a valid slot index.
    pub fn get_activation(&self, micro_batch: usize) -> &[f32] {
        Self::check_micro_batch(&self.forward_activations, micro_batch);
        &self.forward_activations[micro_batch]
    }

    /// Get backward gradient for a micro-batch.
    ///
    /// Returns an empty slice if nothing has been stored for it yet.
    ///
    /// # Panics
    /// Panics if `micro_batch` is not a valid slot index.
    pub fn get_gradient(&self, micro_batch: usize) -> &[f32] {
        Self::check_micro_batch(&self.backward_gradients, micro_batch);
        &self.backward_gradients[micro_batch]
    }

    /// Total memory used by this buffer in bytes.
    ///
    /// Counts only the `f32` elements currently stored (vector lengths, not
    /// capacities) across both the activation and gradient slots.
    pub fn memory_bytes(&self) -> usize {
        let stored_elements: usize = self
            .forward_activations
            .iter()
            .chain(&self.backward_gradients)
            .map(Vec::len)
            .sum();
        stored_elements * std::mem::size_of::<f32>()
    }
}
248
#[cfg(test)]
mod tests {
    //! Unit tests for block partitioning, the 1F1B schedule, and the
    //! inter-stage activation buffer.

    use super::*;

    /// Two stages over 24 blocks split 12/12; only the ends own embed/head.
    #[test]
    fn test_pipeline_stage_basic() {
        // 24 blocks, 2 stages, 4 micro-batches
        let stage0 = PipelineStage::new(0, 2, 24, 4);
        let stage1 = PipelineStage::new(1, 2, 24, 4);

        assert_eq!(stage0.block_start, 0);
        assert_eq!(stage0.block_end, 12);
        assert_eq!(stage0.num_blocks(), 12);
        assert!(stage0.has_embedding);
        assert!(!stage0.has_lm_head);

        assert_eq!(stage1.block_start, 12);
        assert_eq!(stage1.block_end, 24);
        assert!(stage1.has_lm_head);
        assert!(!stage1.has_embedding);
    }

    /// An evenly divisible split gives every stage the same block count.
    #[test]
    fn test_pipeline_stage_4way() {
        // 24 blocks, 4 stages → 6 each
        for i in 0..4 {
            let stage = PipelineStage::new(i, 4, 24, 8);
            assert_eq!(stage.num_blocks(), 6);
            assert_eq!(stage.block_start, i * 6);
            assert_eq!(stage.block_end, (i + 1) * 6);
        }
    }

    /// Leftover blocks go to the earliest stages and ranges stay contiguous.
    #[test]
    fn test_pipeline_stage_uneven() {
        // 10 blocks, 3 stages → 4, 3, 3
        let s0 = PipelineStage::new(0, 3, 10, 6);
        let s1 = PipelineStage::new(1, 3, 10, 6);
        let s2 = PipelineStage::new(2, 3, 10, 6);

        assert_eq!(s0.num_blocks(), 4);
        assert_eq!(s1.num_blocks(), 3);
        assert_eq!(s2.num_blocks(), 3);

        // Complete coverage
        assert_eq!(s0.block_start, 0);
        assert_eq!(s0.block_end, s1.block_start);
        assert_eq!(s1.block_end, s2.block_start);
        assert_eq!(s2.block_end, 10);
    }

    /// Bubble fraction follows (P - 1) / M; efficiency is its complement.
    #[test]
    fn test_pipeline_bubble_fraction() {
        // 2 stages, 4 micro-batches → bubble = 1/4 = 25%
        let stage = PipelineStage::new(0, 2, 24, 4);
        assert!((stage.bubble_fraction() - 0.25).abs() < 1e-10);
        assert!((stage.efficiency() - 0.75).abs() < 1e-10);

        // 4 stages, 8 micro-batches → bubble = 3/8 = 37.5%
        let stage = PipelineStage::new(0, 4, 24, 8);
        assert!((stage.bubble_fraction() - 0.375).abs() < 1e-10);

        // 2 stages, 16 micro-batches → bubble = 1/16 = 6.25%
        let stage = PipelineStage::new(0, 2, 24, 16);
        assert!((stage.bubble_fraction() - 0.0625).abs() < 1e-10);
    }

    /// First stage of two: one warmup forward, then 1F1B, then one drain.
    #[test]
    fn test_pipeline_1f1b_schedule() {
        let stage = PipelineStage::new(0, 2, 24, 4);
        let schedule = stage.schedule_1f1b();

        // Count forwards and backwards
        let fwd_count = schedule.iter().filter(|a| matches!(a, PipelineAction::Forward(_))).count();
        let bwd_count =
            schedule.iter().filter(|a| matches!(a, PipelineAction::Backward(_))).count();

        assert_eq!(fwd_count, 4, "should have 4 forwards");
        assert_eq!(bwd_count, 4, "should have 4 backwards");

        // All micro-batches covered
        let mut fwd_ids: Vec<_> = schedule
            .iter()
            .filter_map(|a| match a {
                PipelineAction::Forward(id) => Some(*id),
                PipelineAction::Backward(_) => None,
            })
            .collect();
        fwd_ids.sort_unstable();
        assert_eq!(fwd_ids, vec![0, 1, 2, 3]);

        // Exact ordering: counting alone would not catch a backward being
        // scheduled before its own forward.
        assert_eq!(
            schedule,
            vec![
                PipelineAction::Forward(0),
                PipelineAction::Forward(1),
                PipelineAction::Backward(0),
                PipelineAction::Forward(2),
                PipelineAction::Backward(1),
                PipelineAction::Forward(3),
                PipelineAction::Backward(2),
                PipelineAction::Backward(3),
            ]
        );
    }

    /// The last stage has no warmup, so it strictly alternates F/B.
    #[test]
    fn test_pipeline_1f1b_schedule_last_stage() {
        let stage = PipelineStage::new(1, 2, 24, 4);
        let expected: Vec<_> = (0..4)
            .flat_map(|mb| [PipelineAction::Forward(mb), PipelineAction::Backward(mb)])
            .collect();
        assert_eq!(stage.schedule_1f1b(), expected);
    }

    /// Stored tensors are readable per slot and counted by `memory_bytes`.
    #[test]
    fn test_pipeline_activation_buffer() {
        let mut buf = PipelineActivationBuffer::new(2, 512, 1024);
        assert_eq!(buf.activation_size, 512 * 1024);
        assert_eq!(buf.memory_bytes(), 0);

        let act = vec![1.0f32; 512 * 1024];
        buf.store_activation(0, act);
        assert_eq!(buf.get_activation(0).len(), 512 * 1024);
        assert_eq!(buf.get_activation(0)[0], 1.0);

        let grad = vec![0.5f32; 512 * 1024];
        buf.store_gradient(1, grad);
        assert_eq!(buf.get_gradient(1)[0], 0.5);

        // One activation + one gradient, 4 bytes per f32 element.
        assert_eq!(buf.memory_bytes(), 2 * 512 * 1024 * 4);
    }

    /// Only stage 0 is first and only stage P-1 is last.
    #[test]
    fn test_pipeline_first_last_stage() {
        let s0 = PipelineStage::new(0, 3, 12, 6);
        let s1 = PipelineStage::new(1, 3, 12, 6);
        let s2 = PipelineStage::new(2, 3, 12, 6);

        assert!(s0.is_first());
        assert!(!s0.is_last());
        assert!(!s1.is_first());
        assert!(!s1.is_last());
        assert!(!s2.is_first());
        assert!(s2.is_last());
    }

    /// Fewer micro-batches than stages cannot fill the pipeline.
    #[test]
    #[should_panic(expected = "need at least")]
    fn test_pipeline_too_few_micro_batches() {
        PipelineStage::new(0, 4, 24, 2); // 2 < 4 stages
    }
}