/// A contiguous slice of model layers owned by one stage of a pipeline.
///
/// The stage owns the half-open layer range `layer_start..layer_end`.
#[derive(Debug, Clone)]
pub struct PipelineStage {
    /// Zero-based index of this stage.
    pub stage_id: usize,
    /// Total number of stages in the pipeline.
    pub num_stages: usize,
    /// First layer (inclusive) assigned to this stage.
    pub layer_start: usize,
    /// One past the last layer (exclusive) assigned to this stage.
    pub layer_end: usize,
    /// `true` when `stage_id == 0`.
    pub is_first: bool,
    /// `true` when `stage_id == num_stages - 1`.
    pub is_last: bool,
}

impl PipelineStage {
    /// Builds stage `stage_id` of a `num_stages`-way split of `num_layers` layers.
    ///
    /// Every stage gets `num_layers / num_stages` layers. The
    /// `num_layers % num_stages` leftover layers are handed out one each to the
    /// trailing stages, so stage sizes never differ by more than one. Example:
    /// 10 layers over 3 stages gives 3, 3, 4. Stages are contiguous and together
    /// cover `0..num_layers` exactly.
    ///
    /// # Panics
    ///
    /// Panics if `num_stages == 0` or `stage_id >= num_stages`.
    pub fn new(stage_id: usize, num_stages: usize, num_layers: usize) -> Self {
        assert!(num_stages > 0, "num_stages must be > 0");
        assert!(
            stage_id < num_stages,
            "stage_id ({stage_id}) must be < num_stages ({num_stages})"
        );
        let base = num_layers / num_stages;
        let remainder = num_layers % num_stages;
        // Stages with index >= `first_long` receive one extra layer.
        // `remainder < num_stages`, so this never underflows.
        let first_long = num_stages - remainder;
        // Start layer of stage `i`: `i` base shares plus one extra layer for
        // every earlier stage that received a leftover layer.
        let boundary = |i: usize| i * base + i.saturating_sub(first_long);
        Self {
            stage_id,
            num_stages,
            layer_start: boundary(stage_id),
            layer_end: boundary(stage_id + 1),
            is_first: stage_id == 0,
            is_last: stage_id + 1 == num_stages,
        }
    }

    /// Number of layers owned by this stage.
    pub fn layer_count(&self) -> usize {
        self.layer_end - self.layer_start
    }

    /// Returns `true` if `layer_idx` falls inside this stage's layer range.
    pub fn contains_layer(&self, layer_idx: usize) -> bool {
        layer_idx >= self.layer_start && layer_idx < self.layer_end
    }

    /// The half-open range of layer indices owned by this stage.
    pub fn layer_range(&self) -> std::ops::Range<usize> {
        self.layer_start..self.layer_end
    }
}
/// Splits `num_layers` layers across `num_stages` pipeline stages.
///
/// Returns one [`PipelineStage`] per stage, ordered by `stage_id`. See
/// [`PipelineStage::new`] for how layers are assigned to each stage. When
/// `num_stages` is zero, no stage is constructed and the result is empty.
pub fn partition_layers(num_layers: usize, num_stages: usize) -> Vec<PipelineStage> {
    let mut stages = Vec::with_capacity(num_stages);
    for stage_id in 0..num_stages {
        stages.push(PipelineStage::new(stage_id, num_stages, num_layers));
    }
    stages
}
/// One micro-batch travelling through the pipeline.
///
/// A micro-batch starts as token ids at stage 0 ([`MicroBatch::new`]). Later
/// it may be rebuilt from hidden states for a later stage
/// ([`MicroBatch::with_hidden_states`]).
#[derive(Debug, Clone)]
pub struct MicroBatch {
    /// Identifier of this micro-batch.
    pub batch_id: usize,
    /// Input token ids; empty when the batch was built from hidden states.
    pub tokens: Vec<u32>,
    /// Hidden-state buffer handed between stages, if any.
    /// NOTE(review): flat layout/shape is not established here — confirm with producers.
    pub hidden_states: Option<Vec<f32>>,
    /// Index of the stage this micro-batch is associated with (0 for token input).
    pub stage_id: usize,
}

impl MicroBatch {
    /// Creates a stage-0 micro-batch from token ids, with no hidden states.
    pub fn new(batch_id: usize, tokens: Vec<u32>) -> Self {
        Self {
            batch_id,
            tokens,
            hidden_states: None,
            stage_id: 0,
        }
    }

    /// Creates a micro-batch carrying `hidden_states` for stage `stage_id`.
    ///
    /// `tokens` is left empty.
    pub fn with_hidden_states(batch_id: usize, hidden_states: Vec<f32>, stage_id: usize) -> Self {
        Self {
            batch_id,
            tokens: Vec::new(),
            hidden_states: Some(hidden_states),
            stage_id,
        }
    }

    /// Returns `true` if this micro-batch is associated with stage 0.
    #[inline]
    pub fn is_first_stage(&self) -> bool {
        self.stage_id == 0
    }
}
/// Static description of a pipeline schedule: how many stages the model is
/// split into and how many micro-batches flow through them.
#[derive(Debug, Clone)]
pub struct PipelineSchedule {
    /// Number of pipeline stages.
    pub num_stages: usize,
    /// Number of micro-batches scheduled through the pipeline.
    pub num_micro_batches: usize,
}

impl PipelineSchedule {
    /// Creates a schedule for `num_stages` stages and `num_micro_batches` micro-batches.
    pub fn new(num_stages: usize, num_micro_batches: usize) -> Self {
        Self {
            num_stages,
            num_micro_batches,
        }
    }

    /// Returns the GPipe schedule as a serialized list of
    /// `(stage, micro_batch, is_forward)` steps.
    ///
    /// All forward steps come first. They are grouped by stage in ascending
    /// order: stage 0 processes every micro-batch, then stage 1, and so on. So
    /// the forward of `(s, mb)` always follows the forward of `(s - 1, mb)`.
    ///
    /// The backward steps are the forward steps in reverse order: last stage
    /// first, and within a stage the last micro-batch first. So the backward of
    /// `(s, mb)` always follows the backward of `(s + 1, mb)`.
    ///
    /// The result has exactly [`total_steps`](Self::total_steps) entries.
    pub fn gpipe_schedule(&self) -> Vec<(usize, usize, bool)> {
        let mut schedule = Vec::with_capacity(self.total_steps());
        for stage in 0..self.num_stages {
            schedule.extend((0..self.num_micro_batches).map(|mb| (stage, mb, true)));
        }
        // Gradients flow from the last stage back to the first, so the backward
        // half must visit stages in descending order (mirroring the forward half).
        for stage in (0..self.num_stages).rev() {
            schedule.extend(
                (0..self.num_micro_batches)
                    .rev()
                    .map(|mb| (stage, mb, false)),
            );
        }
        schedule
    }

    /// Total number of steps (forward plus backward) produced by
    /// [`gpipe_schedule`](Self::gpipe_schedule).
    pub fn total_steps(&self) -> usize {
        2 * self.num_stages * self.num_micro_batches
    }

    /// Pipeline-bubble fraction of a GPipe schedule, `(p - 1) / (p - 1 + m)`.
    /// Here `p` is `num_stages` and `m` is `num_micro_batches`.
    ///
    /// Returns `0.0` when either count is zero.
    pub fn bubble_fraction(&self) -> f32 {
        if self.num_stages == 0 || self.num_micro_batches == 0 {
            return 0.0;
        }
        // With p >= 1 and m >= 1 the denominator is >= 1, so no further
        // division-by-zero guard is needed.
        let idle = self.num_stages as f32 - 1.0;
        idle / (idle + self.num_micro_batches as f32)
    }
}
/// Estimates the bytes one pipeline stage needs for weights plus activations.
///
/// Weight memory assumes 4-byte (`f32`) parameters split evenly across
/// `num_stages`. This uses floor division, so leftover bytes are not
/// attributed to any stage. When `num_stages` is zero the weight term is
/// treated as zero. Activation memory is
/// `activation_bytes_per_token * micro_batch_tokens`.
///
/// All arithmetic is done in `u128`, so the result is exact whenever it fits
/// in `usize`. Larger results saturate to `usize::MAX` rather than
/// overflowing, and `total_params` is never truncated on 32-bit targets.
pub fn pipeline_memory_per_stage(
    total_params: u64,
    num_stages: usize,
    activation_bytes_per_token: usize,
    micro_batch_tokens: usize,
) -> usize {
    // `usize as u128` is lossless on every supported target (usize <= 64 bits).
    let bytes_per_param = std::mem::size_of::<f32>() as u128;
    let weight_bytes = (u128::from(total_params) * bytes_per_param)
        .checked_div(num_stages as u128)
        .unwrap_or(0);
    let activation_bytes = activation_bytes_per_token as u128 * micro_batch_tokens as u128;
    // u64 * 4 + usize * usize cannot overflow u128.
    usize::try_from(weight_bytes + activation_bytes).unwrap_or(usize::MAX)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pipeline_stage_partition_even() {
        let stages = partition_layers(8, 4);
        assert_eq!(stages.len(), 4);
        assert_eq!(stages[0].layer_start, 0);
        assert_eq!(stages[0].layer_end, 2);
        assert_eq!(stages[1].layer_start, 2);
        assert_eq!(stages[1].layer_end, 4);
        assert_eq!(stages[2].layer_start, 4);
        assert_eq!(stages[2].layer_end, 6);
        assert_eq!(stages[3].layer_start, 6);
        assert_eq!(stages[3].layer_end, 8);
    }

    #[test]
    fn test_pipeline_stage_partition_uneven() {
        let stages = partition_layers(10, 3);
        assert_eq!(stages.len(), 3);
        assert_eq!(stages[0].layer_count(), 3);
        assert_eq!(stages[1].layer_count(), 3);
        assert_eq!(stages[2].layer_count(), 4);
        assert_eq!(stages[0].layer_end, stages[1].layer_start);
        assert_eq!(stages[1].layer_end, stages[2].layer_start);
        assert_eq!(stages[2].layer_end, 10);
    }

    /// Every partition must be contiguous, ordered by stage id, and cover
    /// `0..num_layers` exactly — including when there are fewer layers than stages.
    #[test]
    fn test_partition_layers_is_contiguous_and_complete() {
        for num_layers in 0..16 {
            for num_stages in 1..6 {
                let stages = partition_layers(num_layers, num_stages);
                assert_eq!(stages.len(), num_stages);
                let mut next = 0;
                for (id, stage) in stages.iter().enumerate() {
                    assert_eq!(stage.stage_id, id);
                    assert_eq!(stage.layer_start, next);
                    assert!(stage.layer_end >= stage.layer_start);
                    next = stage.layer_end;
                }
                assert_eq!(next, num_layers);
            }
        }
    }

    #[test]
    fn test_partition_layers_zero_stages_is_empty() {
        assert!(partition_layers(8, 0).is_empty());
    }

    #[test]
    #[should_panic(expected = "num_stages must be > 0")]
    fn test_pipeline_stage_zero_stages_panics() {
        let _ = PipelineStage::new(0, 0, 8);
    }

    #[test]
    fn test_pipeline_stage_contains_layer() {
        let stage = PipelineStage::new(1, 4, 8);
        assert!(!stage.contains_layer(1));
        assert!(stage.contains_layer(2));
        assert!(stage.contains_layer(3));
        assert!(!stage.contains_layer(4));
    }

    #[test]
    fn test_pipeline_stage_is_first_last() {
        let stages = partition_layers(6, 3);
        assert!(stages[0].is_first);
        assert!(!stages[0].is_last);
        assert!(!stages[1].is_first);
        assert!(!stages[1].is_last);
        assert!(!stages[2].is_first);
        assert!(stages[2].is_last);
    }

    #[test]
    fn test_micro_batch_new() {
        let mb = MicroBatch::new(42, vec![1, 2, 3]);
        assert_eq!(mb.batch_id, 42);
        assert_eq!(mb.tokens, vec![1u32, 2, 3]);
        assert!(mb.hidden_states.is_none());
        assert_eq!(mb.stage_id, 0);
        assert!(mb.is_first_stage());
    }

    #[test]
    fn test_micro_batch_with_hidden_states() {
        let hs = vec![0.1f32, 0.2, 0.3];
        let mb = MicroBatch::with_hidden_states(7, hs.clone(), 2);
        assert_eq!(mb.batch_id, 7);
        assert_eq!(mb.stage_id, 2);
        assert!(!mb.is_first_stage());
        assert_eq!(mb.hidden_states.as_deref(), Some(hs.as_slice()));
    }

    #[test]
    fn test_pipeline_schedule_gpipe() {
        let sched = PipelineSchedule::new(2, 3);
        let steps = sched.gpipe_schedule();
        assert_eq!(steps.len(), 12);
        assert_eq!(steps.len(), sched.total_steps());
        for &(_, _, is_fwd) in &steps[..6] {
            assert!(is_fwd);
        }
        for &(_, _, is_fwd) in &steps[6..] {
            assert!(!is_fwd);
        }
        assert_eq!(steps[0], (0, 0, true));
        assert_eq!(steps[1], (0, 1, true));
        assert_eq!(steps[2], (0, 2, true));
        assert_eq!(steps[3], (1, 0, true));
        // Every (stage, micro-batch) pair must appear exactly once per direction.
        for stage in 0..2 {
            for mb in 0..3 {
                for fwd in [true, false] {
                    let count = steps.iter().filter(|&&s| s == (stage, mb, fwd)).count();
                    assert_eq!(count, 1, "step {:?}", (stage, mb, fwd));
                }
            }
        }
    }

    #[test]
    fn test_pipeline_schedule_bubble_fraction() {
        let sched = PipelineSchedule::new(4, 8);
        let bubble = sched.bubble_fraction();
        let expected = 3.0f32 / 11.0;
        assert!((bubble - expected).abs() < 1e-5, "bubble={bubble}");
        let sched1 = PipelineSchedule::new(1, 4);
        assert!((sched1.bubble_fraction()).abs() < 1e-6);
    }

    #[test]
    fn test_pipeline_schedule_bubble_fraction_degenerate() {
        assert!(PipelineSchedule::new(0, 4).bubble_fraction().abs() < 1e-6);
        assert!(PipelineSchedule::new(4, 0).bubble_fraction().abs() < 1e-6);
    }

    #[test]
    fn test_pipeline_memory_per_stage() {
        let total_params: u64 = 8_000_000_000;
        let num_stages = 4;
        let activation_bytes_per_token = 512;
        let micro_batch_tokens = 128;
        let mem = pipeline_memory_per_stage(
            total_params,
            num_stages,
            activation_bytes_per_token,
            micro_batch_tokens,
        );
        let expected_weights = (total_params as usize) * 4 / num_stages;
        let expected_act = activation_bytes_per_token * micro_batch_tokens;
        assert_eq!(mem, expected_weights + expected_act);
    }
}