entrenar/train/transformer_trainer/
pipeline.rs1#[derive(Debug, Clone)]
34pub struct PipelineStage {
35 pub stage_id: usize,
37 pub num_stages: usize,
39 pub block_start: usize,
41 pub block_end: usize,
43 pub has_embedding: bool,
45 pub has_lm_head: bool,
47 pub num_micro_batches: usize,
49}
50
51impl PipelineStage {
52 pub fn new(
63 stage_id: usize,
64 num_stages: usize,
65 num_blocks: usize,
66 num_micro_batches: usize,
67 ) -> Self {
68 assert!(
69 num_micro_batches >= num_stages,
70 "need at least {num_stages} micro-batches to fill pipeline, got {num_micro_batches}"
71 );
72
73 let blocks_per_stage = num_blocks / num_stages;
74 let remainder = num_blocks % num_stages;
75
76 let block_start = if stage_id < remainder {
77 stage_id * (blocks_per_stage + 1)
78 } else {
79 remainder * (blocks_per_stage + 1) + (stage_id - remainder) * blocks_per_stage
80 };
81
82 let block_end = if stage_id < remainder {
83 block_start + blocks_per_stage + 1
84 } else {
85 block_start + blocks_per_stage
86 };
87
88 Self {
89 stage_id,
90 num_stages,
91 block_start,
92 block_end,
93 has_embedding: stage_id == 0,
94 has_lm_head: stage_id == num_stages - 1,
95 num_micro_batches,
96 }
97 }
98
99 pub fn num_blocks(&self) -> usize {
101 self.block_end - self.block_start
102 }
103
104 pub fn is_first(&self) -> bool {
106 self.stage_id == 0
107 }
108
109 pub fn is_last(&self) -> bool {
111 self.stage_id == self.num_stages - 1
112 }
113
114 pub fn bubble_fraction(&self) -> f64 {
119 (self.num_stages as f64 - 1.0) / self.num_micro_batches as f64
120 }
121
122 pub fn efficiency(&self) -> f64 {
124 1.0 - self.bubble_fraction()
125 }
126
127 pub fn schedule_1f1b(&self) -> Vec<PipelineAction> {
132 let m = self.num_micro_batches;
133 let p = self.num_stages;
134 let mut actions = Vec::new();
135
136 let warmup_forwards = p - self.stage_id - 1;
138 for mb in 0..warmup_forwards.min(m) {
139 actions.push(PipelineAction::Forward(mb));
140 }
141
142 let steady_start = warmup_forwards.min(m);
144 let mut next_fwd = steady_start;
145 let mut next_bwd = 0;
146
147 while next_fwd < m || next_bwd < m {
148 if next_fwd < m {
150 actions.push(PipelineAction::Forward(next_fwd));
151 next_fwd += 1;
152 }
153 if next_bwd < m {
155 actions.push(PipelineAction::Backward(next_bwd));
156 next_bwd += 1;
157 }
158 }
159
160 actions
161 }
162}
163
/// One step of a per-stage pipeline schedule, tagged with the index of the
/// micro-batch it applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PipelineAction {
    /// Forward pass of the given micro-batch.
    Forward(usize),
    /// Backward pass of the given micro-batch.
    Backward(usize),
}

/// Per-stage storage for micro-batch forward activations and backward
/// gradients.
///
/// Slot `i` of each vector holds the tensor for micro-batch `i`, which must
/// contain exactly `activation_size = seq_len * hidden_size` values. Slots
/// start out empty and are filled by
/// [`store_activation`](Self::store_activation) and
/// [`store_gradient`](Self::store_gradient).
#[derive(Debug, Clone)]
pub struct PipelineActivationBuffer {
    /// Forward activations, indexed by micro-batch; empty until stored.
    pub forward_activations: Vec<Vec<f32>>,
    /// Backward gradients, indexed by micro-batch; empty until stored.
    pub backward_gradients: Vec<Vec<f32>>,
    /// Number of micro-batch slots.
    pub num_micro_batches: usize,
    /// Required length of every stored activation or gradient.
    pub activation_size: usize,
}

impl PipelineActivationBuffer {
    /// Creates a buffer with `num_micro_batches` empty slots. Each slot
    /// expects `seq_len * hidden_size` values.
    pub fn new(num_micro_batches: usize, seq_len: usize, hidden_size: usize) -> Self {
        let activation_size = seq_len * hidden_size;
        Self {
            forward_activations: vec![Vec::new(); num_micro_batches],
            backward_gradients: vec![Vec::new(); num_micro_batches],
            num_micro_batches,
            activation_size,
        }
    }

    /// Stores the forward activation for `micro_batch`, replacing any
    /// previous value.
    ///
    /// # Panics
    ///
    /// Panics if `activation.len() != self.activation_size`, or if
    /// `micro_batch` is out of range.
    pub fn store_activation(&mut self, micro_batch: usize, activation: Vec<f32>) {
        assert_eq!(
            activation.len(),
            self.activation_size,
            "activation size mismatch: expected {}, got {}",
            self.activation_size,
            activation.len()
        );
        self.forward_activations[micro_batch] = activation;
    }

    /// Stores the backward gradient for `micro_batch`, replacing any
    /// previous value.
    ///
    /// # Panics
    ///
    /// Panics if `gradient.len() != self.activation_size`, or if
    /// `micro_batch` is out of range.
    pub fn store_gradient(&mut self, micro_batch: usize, gradient: Vec<f32>) {
        assert_eq!(
            gradient.len(),
            self.activation_size,
            "gradient size mismatch: expected {}, got {}",
            self.activation_size,
            gradient.len()
        );
        self.backward_gradients[micro_batch] = gradient;
    }

    /// Returns the activation stored for `micro_batch`. The slice is empty if
    /// nothing has been stored yet.
    ///
    /// # Panics
    ///
    /// Panics if `micro_batch` is out of range.
    pub fn get_activation(&self, micro_batch: usize) -> &[f32] {
        &self.forward_activations[micro_batch]
    }

    /// Returns the gradient stored for `micro_batch`. The slice is empty if
    /// nothing has been stored yet.
    ///
    /// # Panics
    ///
    /// Panics if `micro_batch` is out of range.
    pub fn get_gradient(&self, micro_batch: usize) -> &[f32] {
        &self.backward_gradients[micro_batch]
    }

    /// Bytes occupied by the stored values, counted by vector length.
    /// Unused capacity is not included.
    pub fn memory_bytes(&self) -> usize {
        let bytes_per_value = std::mem::size_of::<f32>();
        self.forward_activations
            .iter()
            .chain(&self.backward_gradients)
            .map(|values| values.len() * bytes_per_value)
            .sum()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pipeline_stage_basic() {
        let stage0 = PipelineStage::new(0, 2, 24, 4);
        let stage1 = PipelineStage::new(1, 2, 24, 4);

        assert_eq!(stage0.block_start, 0);
        assert_eq!(stage0.block_end, 12);
        assert_eq!(stage0.num_blocks(), 12);
        assert!(stage0.has_embedding);
        assert!(!stage0.has_lm_head);

        assert_eq!(stage1.block_start, 12);
        assert_eq!(stage1.block_end, 24);
        assert!(stage1.has_lm_head);
        assert!(!stage1.has_embedding);
    }

    #[test]
    fn test_pipeline_stage_4way() {
        for i in 0..4 {
            let stage = PipelineStage::new(i, 4, 24, 8);
            assert_eq!(stage.num_blocks(), 6);
            assert_eq!(stage.block_start, i * 6);
            assert_eq!(stage.block_end, (i + 1) * 6);
        }
    }

    #[test]
    fn test_pipeline_stage_uneven() {
        let s0 = PipelineStage::new(0, 3, 10, 6);
        let s1 = PipelineStage::new(1, 3, 10, 6);
        let s2 = PipelineStage::new(2, 3, 10, 6);

        assert_eq!(s0.num_blocks(), 4);
        assert_eq!(s1.num_blocks(), 3);
        assert_eq!(s2.num_blocks(), 3);

        assert_eq!(s0.block_end, s1.block_start);
        assert_eq!(s1.block_end, s2.block_start);
        assert_eq!(s2.block_end, 10);
    }

    /// Every (stages, blocks) combination must tile `0..num_blocks`
    /// contiguously, with stage sizes non-increasing and differing by at
    /// most one.
    #[test]
    fn test_pipeline_partition_covers_all_blocks() {
        for num_stages in 1..=5 {
            for num_blocks in 0..=20 {
                let stages: Vec<_> = (0..num_stages)
                    .map(|s| PipelineStage::new(s, num_stages, num_blocks, num_stages))
                    .collect();

                assert_eq!(stages[0].block_start, 0);
                assert_eq!(stages[num_stages - 1].block_end, num_blocks);
                for pair in stages.windows(2) {
                    assert_eq!(pair[0].block_end, pair[1].block_start);
                    assert!(pair[0].num_blocks() >= pair[1].num_blocks());
                }
                assert!(stages[0].num_blocks() - stages[num_stages - 1].num_blocks() <= 1);
            }
        }
    }

    #[test]
    fn test_pipeline_bubble_fraction() {
        let stage = PipelineStage::new(0, 2, 24, 4);
        assert!((stage.bubble_fraction() - 0.25).abs() < 1e-10);
        assert!((stage.efficiency() - 0.75).abs() < 1e-10);

        let stage = PipelineStage::new(0, 4, 24, 8);
        assert!((stage.bubble_fraction() - 0.375).abs() < 1e-10);

        let stage = PipelineStage::new(0, 2, 24, 16);
        assert!((stage.bubble_fraction() - 0.0625).abs() < 1e-10);
    }

    #[test]
    fn test_pipeline_1f1b_schedule() {
        let stage = PipelineStage::new(0, 2, 24, 4);
        let schedule = stage.schedule_1f1b();

        let fwd_count = schedule
            .iter()
            .filter(|a| matches!(a, PipelineAction::Forward(_)))
            .count();
        let bwd_count = schedule
            .iter()
            .filter(|a| matches!(a, PipelineAction::Backward(_)))
            .count();

        assert_eq!(fwd_count, 4, "should have 4 forwards");
        assert_eq!(bwd_count, 4, "should have 4 backwards");

        let mut fwd_ids: Vec<_> = schedule
            .iter()
            .filter_map(|a| match a {
                PipelineAction::Forward(id) => Some(*id),
                _ => None,
            })
            .collect();
        fwd_ids.sort_unstable();
        assert_eq!(fwd_ids, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_pipeline_1f1b_exact_order() {
        use PipelineAction::{Backward as B, Forward as F};

        // Stage 0 of 2 runs one warm-up forward before it starts alternating.
        let first = PipelineStage::new(0, 2, 24, 4).schedule_1f1b();
        assert_eq!(first, vec![F(0), F(1), B(0), F(2), B(1), F(3), B(2), B(3)]);

        // The last stage has no warm-up and strictly alternates.
        let last = PipelineStage::new(1, 2, 24, 4).schedule_1f1b();
        assert_eq!(last, vec![F(0), B(0), F(1), B(1), F(2), B(2), F(3), B(3)]);
    }

    /// Checks the core 1F1B properties over many configurations:
    /// - each micro-batch appears once as a forward and once as a backward;
    /// - its backward comes after its forward;
    /// - stage `s` holds at most `p - s` activations in flight.
    #[test]
    fn test_pipeline_1f1b_invariants() {
        for p in 1..=4 {
            for m in p..=p + 4 {
                for s in 0..p {
                    let schedule = PipelineStage::new(s, p, 8, m).schedule_1f1b();
                    assert_eq!(schedule.len(), 2 * m);

                    for mb in 0..m {
                        let fwd = schedule
                            .iter()
                            .position(|&a| a == PipelineAction::Forward(mb));
                        let bwd = schedule
                            .iter()
                            .position(|&a| a == PipelineAction::Backward(mb));
                        match (fwd, bwd) {
                            (Some(f), Some(b)) => assert!(
                                f < b,
                                "backward before forward for mb {mb} (stage {s}/{p}, m={m})"
                            ),
                            _ => panic!("mb {mb} missing from schedule (stage {s}/{p}, m={m})"),
                        }
                    }

                    let mut in_flight = 0usize;
                    let mut peak = 0usize;
                    for action in &schedule {
                        match action {
                            PipelineAction::Forward(_) => {
                                in_flight += 1;
                                peak = peak.max(in_flight);
                            }
                            PipelineAction::Backward(_) => in_flight -= 1,
                        }
                    }
                    assert_eq!(peak, (p - s).min(m), "stage {s}/{p}, m={m}");
                }
            }
        }
    }

    #[test]
    fn test_pipeline_activation_buffer() {
        let mut buf = PipelineActivationBuffer::new(2, 512, 1024);
        assert_eq!(buf.activation_size, 512 * 1024);

        let act = vec![1.0f32; 512 * 1024];
        buf.store_activation(0, act.clone());
        assert_eq!(buf.get_activation(0).len(), 512 * 1024);
        assert_eq!(buf.get_activation(0)[0], 1.0);

        let grad = vec![0.5f32; 512 * 1024];
        buf.store_gradient(1, grad);
        assert_eq!(buf.get_gradient(1)[0], 0.5);
    }

    #[test]
    fn test_pipeline_activation_buffer_memory() {
        let mut buf = PipelineActivationBuffer::new(3, 4, 8);
        assert_eq!(buf.memory_bytes(), 0);
        assert!(buf.get_activation(2).is_empty());

        buf.store_activation(0, vec![0.0; 32]);
        buf.store_gradient(2, vec![0.0; 32]);
        assert_eq!(buf.memory_bytes(), 2 * 32 * std::mem::size_of::<f32>());
    }

    #[test]
    #[should_panic(expected = "activation size mismatch")]
    fn test_pipeline_activation_buffer_rejects_wrong_size() {
        let mut buf = PipelineActivationBuffer::new(1, 4, 8);
        buf.store_activation(0, vec![0.0; 31]);
    }

    #[test]
    fn test_pipeline_first_last_stage() {
        let s0 = PipelineStage::new(0, 3, 12, 6);
        let s1 = PipelineStage::new(1, 3, 12, 6);
        let s2 = PipelineStage::new(2, 3, 12, 6);

        assert!(s0.is_first());
        assert!(!s0.is_last());
        assert!(!s1.is_first());
        assert!(!s1.is_last());
        assert!(!s2.is_first());
        assert!(s2.is_last());
    }

    #[test]
    #[should_panic(expected = "need at least")]
    fn test_pipeline_too_few_micro_batches() {
        let _ = PipelineStage::new(0, 4, 24, 2);
    }
}