/// One stage of a pipeline-parallel split of a model into contiguous block ranges.
///
/// Stage `stage_id` of `num_stages` owns blocks `block_start..block_end`. Use
/// [`PipelineStage::new`] to build a validated, evenly balanced stage.
#[derive(Debug, Clone)]
pub struct PipelineStage {
    /// Zero-based index of this stage within the pipeline.
    pub stage_id: usize,
    /// Total number of stages in the pipeline.
    pub num_stages: usize,
    /// First block (inclusive) owned by this stage.
    pub block_start: usize,
    /// One past the last block owned by this stage.
    pub block_end: usize,
    /// Set by [`PipelineStage::new`] for stage 0 only.
    pub has_embedding: bool,
    /// Set by [`PipelineStage::new`] for the final stage only.
    pub has_lm_head: bool,
    /// Number of micro-batches pushed through the pipeline per schedule.
    pub num_micro_batches: usize,
}

impl PipelineStage {
    /// Creates stage `stage_id` of a `num_stages`-way pipeline over `num_blocks` blocks.
    ///
    /// Blocks are split as evenly as possible: the first `num_blocks % num_stages` stages
    /// each receive one extra block, and the ranges of consecutive stages are contiguous.
    ///
    /// # Panics
    ///
    /// Panics if `num_stages == 0`, if `stage_id >= num_stages`, or if
    /// `num_micro_batches < num_stages` (too few micro-batches to fill the pipeline).
    pub fn new(
        stage_id: usize,
        num_stages: usize,
        num_blocks: usize,
        num_micro_batches: usize,
    ) -> Self {
        assert!(num_stages > 0, "pipeline must have at least one stage");
        assert!(
            stage_id < num_stages,
            "stage_id {stage_id} out of range for {num_stages} stages"
        );
        assert!(
            num_micro_batches >= num_stages,
            "need at least {num_stages} micro-batches to fill pipeline, got {num_micro_batches}"
        );
        let blocks_per_stage = num_blocks / num_stages;
        let remainder = num_blocks % num_stages;
        // Every earlier stage contributed `blocks_per_stage`, plus one more for each of
        // them that was among the first `remainder` stages.
        let block_start = stage_id * blocks_per_stage + stage_id.min(remainder);
        let block_count = blocks_per_stage + usize::from(stage_id < remainder);
        Self {
            stage_id,
            num_stages,
            block_start,
            block_end: block_start + block_count,
            has_embedding: stage_id == 0,
            // Written as `+ 1 ==` so it cannot underflow.
            has_lm_head: stage_id + 1 == num_stages,
            num_micro_batches,
        }
    }

    /// Number of blocks owned by this stage.
    pub fn num_blocks(&self) -> usize {
        self.block_end - self.block_start
    }

    /// Whether this is the first stage of the pipeline.
    pub fn is_first(&self) -> bool {
        self.stage_id == 0
    }

    /// Whether this is the last stage of the pipeline.
    pub fn is_last(&self) -> bool {
        // `+ 1 ==` avoids underflowing `num_stages - 1` on a hand-built, zero-stage value.
        self.stage_id + 1 == self.num_stages
    }

    /// Pipeline-bubble overhead `(p - 1) / m` for `p` stages and `m` micro-batches.
    ///
    /// The bubble is measured relative to the ideal, bubble-free compute time, not to
    /// total wall time.
    pub fn bubble_fraction(&self) -> f64 {
        (self.num_stages as f64 - 1.0) / self.num_micro_batches as f64
    }

    /// Returns `1 - bubble_fraction()`.
    ///
    /// Because the bubble is measured against ideal compute time, this is a lower bound
    /// on the exact busy fraction `m / (m + p - 1)`.
    pub fn efficiency(&self) -> f64 {
        1.0 - self.bubble_fraction()
    }

    /// Builds this stage's one-forward-one-backward (1F1B) schedule.
    ///
    /// The schedule has three phases:
    /// - warm-up: `num_stages - stage_id - 1` forwards (capped at the micro-batch count);
    /// - steady state: alternating forward/backward pairs;
    /// - cool-down: the remaining backwards.
    ///
    /// Every micro-batch appears exactly once as a forward and once as a backward, each
    /// in increasing micro-batch order.
    pub fn schedule_1f1b(&self) -> Vec<PipelineAction> {
        let m = self.num_micro_batches;
        // Earlier stages must run ahead before the first gradient returns from
        // downstream; the last stage needs no warm-up. `saturating_sub` keeps a
        // hand-built stage with an out-of-range `stage_id` from underflowing.
        let warmup = self
            .num_stages
            .saturating_sub(self.stage_id + 1)
            .min(m);
        let steady = m - warmup;

        let mut actions = Vec::with_capacity(2 * m);
        actions.extend((0..warmup).map(PipelineAction::Forward));
        for i in 0..steady {
            actions.push(PipelineAction::Forward(warmup + i));
            actions.push(PipelineAction::Backward(i));
        }
        actions.extend((steady..m).map(PipelineAction::Backward));
        actions
    }
}

/// One unit of work in a pipeline schedule, tagged with its micro-batch index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PipelineAction {
    /// Forward pass for the given micro-batch.
    Forward(usize),
    /// Backward pass for the given micro-batch.
    Backward(usize),
}
/// Per-micro-batch storage for forward activations and backward gradients.
///
/// Every slot starts empty. Once filled, a slot must hold exactly `activation_size`
/// `f32` values.
#[derive(Debug, Clone)]
pub struct PipelineActivationBuffer {
    /// Forward activation per micro-batch; an empty `Vec` until stored.
    pub forward_activations: Vec<Vec<f32>>,
    /// Backward gradient per micro-batch; an empty `Vec` until stored.
    pub backward_gradients: Vec<Vec<f32>>,
    /// Number of micro-batch slots allocated by [`PipelineActivationBuffer::new`].
    pub num_micro_batches: usize,
    /// Required element count (`seq_len * hidden_size`) of every stored tensor.
    pub activation_size: usize,
}

impl PipelineActivationBuffer {
    /// Creates a buffer with `num_micro_batches` empty activation and gradient slots.
    ///
    /// Each slot expects `seq_len * hidden_size` elements once filled.
    ///
    /// # Panics
    ///
    /// Panics if `seq_len * hidden_size` overflows `usize`. A plain multiplication would
    /// silently wrap in release builds and yield a bogus size.
    pub fn new(num_micro_batches: usize, seq_len: usize, hidden_size: usize) -> Self {
        let activation_size = seq_len
            .checked_mul(hidden_size)
            .expect("seq_len * hidden_size overflows usize");
        Self {
            forward_activations: vec![Vec::new(); num_micro_batches],
            backward_gradients: vec![Vec::new(); num_micro_batches],
            num_micro_batches,
            activation_size,
        }
    }

    /// Stores the forward activation for `micro_batch`, replacing any previous value.
    ///
    /// # Panics
    ///
    /// Panics if `activation.len() != activation_size` or `micro_batch` has no slot.
    pub fn store_activation(&mut self, micro_batch: usize, activation: Vec<f32>) {
        assert_eq!(
            activation.len(),
            self.activation_size,
            "activation size mismatch: expected {}, got {}",
            self.activation_size,
            activation.len()
        );
        self.forward_activations[micro_batch] = activation;
    }

    /// Stores the backward gradient for `micro_batch`, replacing any previous value.
    ///
    /// # Panics
    ///
    /// Panics if `gradient.len() != activation_size` or `micro_batch` has no slot.
    pub fn store_gradient(&mut self, micro_batch: usize, gradient: Vec<f32>) {
        assert_eq!(
            gradient.len(),
            self.activation_size,
            "gradient size mismatch: expected {}, got {}",
            self.activation_size,
            gradient.len()
        );
        self.backward_gradients[micro_batch] = gradient;
    }

    /// Returns the stored activation for `micro_batch`.
    ///
    /// Returns an empty slice if nothing has been stored yet. Panics if `micro_batch` has
    /// no slot.
    pub fn get_activation(&self, micro_batch: usize) -> &[f32] {
        &self.forward_activations[micro_batch]
    }

    /// Returns the stored gradient for `micro_batch`.
    ///
    /// Returns an empty slice if nothing has been stored yet. Panics if `micro_batch` has
    /// no slot.
    pub fn get_gradient(&self, micro_batch: usize) -> &[f32] {
        &self.backward_gradients[micro_batch]
    }

    /// Bytes occupied by the stored activations and gradients.
    ///
    /// Counts stored elements (`len`), not allocated capacity.
    pub fn memory_bytes(&self) -> usize {
        let elements: usize = self
            .forward_activations
            .iter()
            .chain(&self.backward_gradients)
            .map(Vec::len)
            .sum();
        elements * std::mem::size_of::<f32>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Micro-batch ids of the forward (`want_forward == true`) or backward actions, in
    /// schedule order.
    fn ids_of(schedule: &[PipelineAction], want_forward: bool) -> Vec<usize> {
        schedule
            .iter()
            .filter_map(|action| match *action {
                PipelineAction::Forward(mb) if want_forward => Some(mb),
                PipelineAction::Backward(mb) if !want_forward => Some(mb),
                _ => None,
            })
            .collect()
    }

    #[test]
    fn test_pipeline_stage_basic() {
        let stage0 = PipelineStage::new(0, 2, 24, 4);
        let stage1 = PipelineStage::new(1, 2, 24, 4);
        assert_eq!(stage0.block_start, 0);
        assert_eq!(stage0.block_end, 12);
        assert_eq!(stage0.num_blocks(), 12);
        assert!(stage0.has_embedding);
        assert!(!stage0.has_lm_head);
        assert_eq!(stage1.block_start, 12);
        assert_eq!(stage1.block_end, 24);
        assert!(stage1.has_lm_head);
        assert!(!stage1.has_embedding);
    }

    #[test]
    fn test_pipeline_stage_4way() {
        for i in 0..4 {
            let stage = PipelineStage::new(i, 4, 24, 8);
            assert_eq!(stage.num_blocks(), 6);
            assert_eq!(stage.block_start, i * 6);
            assert_eq!(stage.block_end, (i + 1) * 6);
        }
    }

    #[test]
    fn test_pipeline_stage_uneven() {
        let s0 = PipelineStage::new(0, 3, 10, 6);
        let s1 = PipelineStage::new(1, 3, 10, 6);
        let s2 = PipelineStage::new(2, 3, 10, 6);
        assert_eq!(s0.num_blocks(), 4);
        assert_eq!(s1.num_blocks(), 3);
        assert_eq!(s2.num_blocks(), 3);
        assert_eq!(s0.block_end, s1.block_start);
        assert_eq!(s1.block_end, s2.block_start);
        assert_eq!(s2.block_end, 10);
    }

    #[test]
    fn test_pipeline_bubble_fraction() {
        let stage = PipelineStage::new(0, 2, 24, 4);
        assert!((stage.bubble_fraction() - 0.25).abs() < 1e-10);
        assert!((stage.efficiency() - 0.75).abs() < 1e-10);
        let stage = PipelineStage::new(0, 4, 24, 8);
        assert!((stage.bubble_fraction() - 0.375).abs() < 1e-10);
        let stage = PipelineStage::new(0, 2, 24, 16);
        assert!((stage.bubble_fraction() - 0.0625).abs() < 1e-10);
    }

    #[test]
    fn test_pipeline_1f1b_schedule() {
        use PipelineAction::{Backward as B, Forward as F};
        // Stage 0 of 2: one warm-up forward, alternating steady state, one cool-down backward.
        let stage0 = PipelineStage::new(0, 2, 24, 4);
        assert_eq!(
            stage0.schedule_1f1b(),
            vec![F(0), F(1), B(0), F(2), B(1), F(3), B(2), B(3)]
        );
        // The last stage has no warm-up: strict forward/backward alternation.
        let stage1 = PipelineStage::new(1, 2, 24, 4);
        assert_eq!(
            stage1.schedule_1f1b(),
            vec![F(0), B(0), F(1), B(1), F(2), B(2), F(3), B(3)]
        );
    }

    #[test]
    fn test_pipeline_1f1b_invariants() {
        let expected: Vec<usize> = (0..8).collect();
        for stage_id in 0..4 {
            let schedule = PipelineStage::new(stage_id, 4, 24, 8).schedule_1f1b();
            assert_eq!(schedule.len(), 16);
            // Each micro-batch runs forward and backward exactly once, in order.
            assert_eq!(ids_of(&schedule, true), expected);
            assert_eq!(ids_of(&schedule, false), expected);
            // A backward may only run after its own forward.
            for (pos, action) in schedule.iter().enumerate() {
                if let PipelineAction::Backward(mb) = *action {
                    assert!(
                        schedule[..pos].contains(&PipelineAction::Forward(mb)),
                        "stage {stage_id}: backward {mb} scheduled before its forward"
                    );
                }
            }
        }
    }

    #[test]
    fn test_pipeline_activation_buffer() {
        let mut buf = PipelineActivationBuffer::new(2, 512, 1024);
        assert_eq!(buf.activation_size, 512 * 1024);
        assert_eq!(buf.memory_bytes(), 0);
        let act = vec![1.0f32; 512 * 1024];
        buf.store_activation(0, act);
        assert_eq!(buf.get_activation(0).len(), 512 * 1024);
        assert_eq!(buf.get_activation(0)[0], 1.0);
        let grad = vec![0.5f32; 512 * 1024];
        buf.store_gradient(1, grad);
        assert_eq!(buf.get_gradient(1)[0], 0.5);
        assert_eq!(buf.memory_bytes(), 2 * 512 * 1024 * 4);
    }

    #[test]
    fn test_pipeline_first_last_stage() {
        let s0 = PipelineStage::new(0, 3, 12, 6);
        let s1 = PipelineStage::new(1, 3, 12, 6);
        let s2 = PipelineStage::new(2, 3, 12, 6);
        assert!(s0.is_first());
        assert!(!s0.is_last());
        assert!(!s1.is_first());
        assert!(!s1.is_last());
        assert!(!s2.is_first());
        assert!(s2.is_last());
    }

    #[test]
    #[should_panic(expected = "need at least")]
    fn test_pipeline_too_few_micro_batches() {
        PipelineStage::new(0, 4, 24, 2);
    }
}