use std::fmt;
/// Validation errors for pipeline-parallel configuration and inputs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PpError {
    /// `num_stages` was 0; a pipeline needs at least one stage.
    ZeroStages,
    /// `stage_rank` does not index a valid stage in `0..num_stages`.
    StageRankOutOfRange { stage_rank: usize, num_stages: usize },
    /// Too few layers to give every stage a share.
    InsufficientLayers { total_layers: usize, num_stages: usize },
    /// A micro-batch id outside the expected range.
    InvalidMicroBatch(usize),
}

impl fmt::Display for PpError {
    /// Renders a human-readable, one-line description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PpError::ZeroStages => f.write_str("num_stages must be at least 1"),
            PpError::StageRankOutOfRange { stage_rank, num_stages } => {
                write!(f, "stage_rank {stage_rank} is out of range for num_stages {num_stages}")
            }
            PpError::InsufficientLayers { total_layers, num_stages } => {
                write!(f, "cannot partition {total_layers} layers into {num_stages} stages")
            }
            PpError::InvalidMicroBatch(id) => write!(f, "invalid micro-batch id: {id}"),
        }
    }
}

// No underlying cause to expose, so the default `source` (None) is correct.
impl std::error::Error for PpError {}
/// Micro-batch execution order used by `PipelineScheduler`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PipelineSchedule {
    /// All forward passes first, then all backward passes.
    GPipe,
    /// One-forward-one-backward steady-state schedule.
    OneFOneBubble,
    /// Interleaved schedule with several virtual stages per rank.
    /// (The scheduler in this file currently falls back to 1F1B for it.)
    Interleaved { num_virtual_stages: usize },
}

/// Static configuration for a single pipeline-parallel rank.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    pub num_stages: usize,
    pub num_micro_batches: usize,
    pub stage_rank: usize,
    pub schedule: PipelineSchedule,
}

impl Default for PipelineConfig {
    /// A degenerate single-stage (no pipelining) GPipe setup with
    /// 8 micro-batches, running as rank 0.
    fn default() -> Self {
        PipelineConfig {
            schedule: PipelineSchedule::GPipe,
            stage_rank: 0,
            num_micro_batches: 8,
            num_stages: 1,
        }
    }
}
/// Metadata for the contiguous slice of model layers owned by one pipeline
/// stage; the layers span the half-open range `[layer_start, layer_end)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineStageInfo {
    pub stage_id: usize,
    pub num_layers: usize,
    pub layer_start: usize,
    pub layer_end: usize,
}

/// Splits `total_layers` into `num_stages` contiguous chunks whose sizes
/// differ by at most one; the first `total_layers % num_stages` stages each
/// absorb one extra layer.
///
/// Returns an empty vector when either argument is 0.
pub fn partition_layers_evenly(
    total_layers: usize,
    num_stages: usize,
) -> Vec<PipelineStageInfo> {
    if num_stages == 0 || total_layers == 0 {
        return Vec::new();
    }
    let base = total_layers / num_stages;
    let remainder = total_layers % num_stages;
    let mut stages = Vec::with_capacity(num_stages);
    let mut layer_start = 0usize;
    for stage_id in 0..num_stages {
        // Early stages take one extra layer until the remainder is used up.
        let num_layers = base + usize::from(stage_id < remainder);
        let layer_end = layer_start + num_layers;
        stages.push(PipelineStageInfo { stage_id, num_layers, layer_start, layer_end });
        layer_start = layer_end;
    }
    stages
}

/// Greedily partitions layers into `num_stages` contiguous chunks of roughly
/// equal FLOP cost: a stage is closed once its accumulated cost reaches the
/// per-stage target `ceil(total_flops / num_stages)`, provided enough layers
/// remain to give each later stage at least one layer.
///
/// Returns an empty vector when `num_stages == 0` or `layer_flops` is empty.
/// When `layer_flops.len() < num_stages`, trailing stages come back empty
/// (`num_layers == 0`) rather than panicking.
pub fn partition_layers_by_flops(
    layer_flops: &[u64],
    num_stages: usize,
) -> Vec<PipelineStageInfo> {
    let total_layers = layer_flops.len();
    if num_stages == 0 || total_layers == 0 {
        return Vec::new();
    }
    if num_stages == 1 {
        return vec![PipelineStageInfo {
            stage_id: 0,
            num_layers: total_layers,
            layer_start: 0,
            layer_end: total_layers,
        }];
    }
    let total_flops: u64 = layer_flops.iter().sum();
    // Ceiling division so stages close eagerly and the target is never
    // rounded below the mean.
    let target = (total_flops + num_stages as u64 - 1) / num_stages as u64;
    let mut boundaries: Vec<usize> = vec![0];
    let mut cumulative = 0u64;
    let mut stages_opened = 1usize;
    for (idx, &flops) in layer_flops.iter().enumerate() {
        cumulative += flops;
        let remaining_layers = total_layers - idx - 1;
        let remaining_stages = num_stages - stages_opened;
        // Split either when the open stage has reached its FLOP target (and
        // enough layers remain for the rest), or — the bug fix — when only
        // one layer (or fewer) is left per remaining stage. Without the
        // forced split, front-loaded FLOP profiles (e.g. [100, 1, 1, 1] into
        // 3 stages) left `boundaries` short and the indexing below panicked.
        if stages_opened < num_stages
            && ((cumulative >= target && remaining_layers >= remaining_stages)
                || remaining_layers <= remaining_stages)
        {
            boundaries.push(idx + 1);
            stages_opened += 1;
            cumulative = 0;
        }
    }
    let mut stages = Vec::with_capacity(num_stages);
    for stage_id in 0..num_stages {
        // Any stage that never received a boundary (possible only when there
        // are fewer layers than stages) becomes an empty trailing stage
        // instead of an out-of-bounds panic.
        let layer_start = boundaries.get(stage_id).copied().unwrap_or(total_layers);
        let layer_end = boundaries.get(stage_id + 1).copied().unwrap_or(total_layers);
        let num_layers = layer_end - layer_start;
        stages.push(PipelineStageInfo { stage_id, num_layers, layer_start, layer_end });
    }
    stages
}
/// One micro-batch of data flowing through the pipeline.
#[derive(Debug, Clone)]
pub struct MicroBatch {
    // Position of this micro-batch within the global batch.
    pub micro_batch_id: usize,
    // Flattened payload; `shape` records its logical 2-D layout
    // (presumably rows x cols of `data` — not validated here).
    pub data: Vec<f32>,
    pub shape: (usize, usize),
    // Flag marking the last micro-batch (as named; not consumed in this file).
    pub is_last: bool,
}

impl MicroBatch {
    /// Bundles a payload and its metadata into a `MicroBatch`; performs no
    /// validation of `shape` against `data.len()`.
    pub fn new(
        micro_batch_id: usize,
        data: Vec<f32>,
        shape: (usize, usize),
        is_last: bool,
    ) -> Self {
        MicroBatch { is_last, shape, data, micro_batch_id }
    }
}
/// Analytical estimates of pipeline idle time ("bubble") and memory cost
/// for a given stage count and micro-batch count.
#[derive(Debug, Clone)]
pub struct PipelineBubbleStats {
    pub num_stages: usize,
    pub num_micro_batches: usize,
    pub schedule: PipelineSchedule,
}

impl PipelineBubbleStats {
    /// Bubble fraction under GPipe: `(p - 1) / (m + p - 1)` for `p` stages
    /// and `m` micro-batches.
    pub fn bubble_fraction_gpipe(&self) -> f32 {
        let stages = self.num_stages as f32;
        let micro_batches = self.num_micro_batches as f32;
        (stages - 1.0) / (micro_batches + stages - 1.0)
    }

    /// Bubble fraction under 1F1B: `(p - 1) / (2 * (m - 1) + p)`.
    pub fn bubble_fraction_1f1b(&self) -> f32 {
        let stages = self.num_stages as f32;
        let micro_batches = self.num_micro_batches as f32;
        (stages - 1.0) / (2.0 * (micro_batches - 1.0) + stages)
    }

    /// Heuristic ratio `(p + m) / (p * m)`. Its exact interpretation is not
    /// documented in this file — verify against the callers.
    pub fn memory_footprint_ratio(&self) -> f32 {
        let stages = self.num_stages as f32;
        let micro_batches = self.num_micro_batches as f32;
        (stages + micro_batches) / (stages * micro_batches)
    }

    /// Rule-of-thumb micro-batch count: four micro-batches per stage.
    pub fn optimal_num_micro_batches(&self) -> usize {
        self.num_stages * 4
    }
}
/// One primitive action in a stage's execution schedule, as produced by
/// `PipelineScheduler::schedule_steps`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PipelineStep {
    /// Run the forward pass for the given micro-batch on this stage.
    Forward { micro_batch_id: usize },
    /// Run the backward pass for the given micro-batch on this stage.
    Backward { micro_batch_id: usize },
    /// Wait for a forward result from `from_stage`.
    /// NOTE(review): never emitted by the schedulers in this file.
    WaitForward { from_stage: usize },
    /// Wait for the gradient arriving back from `from_stage` (presumably
    /// blocks until the downstream backward pass completes — confirm with
    /// the runtime that consumes these steps).
    WaitBackward { from_stage: usize },
    /// Send this stage's output activations for a micro-batch to `to_stage`.
    SendActivations { to_stage: usize, micro_batch_id: usize },
    /// Receive input activations for a micro-batch from `from_stage`.
    RecvActivations { from_stage: usize, micro_batch_id: usize },
}
/// Per-rank scheduler: holds this rank's configuration plus the even layer
/// partition for all stages (computed at construction).
pub struct PipelineScheduler {
    pub config: PipelineConfig,
    pub stages: Vec<PipelineStageInfo>,
}
impl PipelineScheduler {
    /// Creates a scheduler for `config`, splitting `total_layers` evenly
    /// across `config.num_stages` via `partition_layers_evenly`.
    pub fn new(config: PipelineConfig, total_layers: usize) -> Self {
        let stages = partition_layers_evenly(total_layers, config.num_stages);
        Self { config, stages }
    }

    /// Returns the layer-partition metadata for this rank's stage.
    ///
    /// NOTE(review): an out-of-range `stage_rank` silently falls back to
    /// stage 0 instead of surfacing `PpError::StageRankOutOfRange`; confirm
    /// this fallback is intentional. Panics only if `stages` is empty
    /// (i.e. the scheduler was built with 0 stages or 0 layers).
    pub fn this_stage(&self) -> &PipelineStageInfo {
        self.stages
            .get(self.config.stage_rank)
            .unwrap_or_else(|| self.stages.first().expect("pipeline must have at least one stage"))
    }

    /// True for rank 0 — no upstream stage to receive activations from.
    pub fn is_first_stage(&self) -> bool {
        self.config.stage_rank == 0
    }

    /// True for the final rank — no downstream stage to send activations to.
    pub fn is_last_stage(&self) -> bool {
        self.config.stage_rank + 1 == self.config.num_stages
    }

    /// Builds this rank's ordered step list for `num_micro_batches`
    /// micro-batches, dispatching on the configured schedule.
    pub fn schedule_steps(&self, num_micro_batches: usize) -> Vec<PipelineStep> {
        match &self.config.schedule {
            PipelineSchedule::GPipe => {
                self.schedule_gpipe(num_micro_batches)
            },
            PipelineSchedule::OneFOneBubble => {
                self.schedule_1f1b(num_micro_batches)
            },
            PipelineSchedule::Interleaved { num_virtual_stages } => {
                self.schedule_interleaved(num_micro_batches, *num_virtual_stages)
            },
        }
    }

    /// GPipe: all forwards in micro-batch order, then all backwards in
    /// reverse order. Non-first stages receive activations before each
    /// forward; non-last stages send activations after each forward and wait
    /// for the downstream gradient before each backward.
    fn schedule_gpipe(&self, num_micro_batches: usize) -> Vec<PipelineStep> {
        let rank = self.config.stage_rank;
        let mut steps = Vec::new();
        // Phase 1: every forward pass, in micro-batch order.
        for mb in 0..num_micro_batches {
            if !self.is_first_stage() {
                steps.push(PipelineStep::RecvActivations {
                    from_stage: rank - 1,
                    micro_batch_id: mb,
                });
            }
            steps.push(PipelineStep::Forward { micro_batch_id: mb });
            if !self.is_last_stage() {
                steps.push(PipelineStep::SendActivations {
                    to_stage: rank + 1,
                    micro_batch_id: mb,
                });
            }
        }
        // Phase 2: every backward pass, in reverse micro-batch order.
        for mb in (0..num_micro_batches).rev() {
            if !self.is_last_stage() {
                steps.push(PipelineStep::WaitBackward { from_stage: rank + 1 });
            }
            steps.push(PipelineStep::Backward { micro_batch_id: mb });
        }
        steps
    }

    /// 1F1B: `min(num_stages - rank, m)` warm-up forwards, then a steady
    /// state alternating one forward (micro-batch `warmup + i`) with one
    /// backward (micro-batch `i`), then a cool-down draining the remaining
    /// backwards. Every micro-batch gets exactly one forward and one
    /// backward step.
    ///
    /// NOTE(review): `num_stages - rank` underflows if `stage_rank` exceeds
    /// `num_stages`; the config is assumed pre-validated.
    fn schedule_1f1b(&self, num_micro_batches: usize) -> Vec<PipelineStep> {
        let rank = self.config.stage_rank;
        let num_stages = self.config.num_stages;
        // Deeper (higher-rank) stages need fewer in-flight forwards.
        let warmup_steps = (num_stages - rank).min(num_micro_batches);
        let mut steps = Vec::new();
        // Warm-up: forwards only, filling the pipeline.
        for mb in 0..warmup_steps {
            if !self.is_first_stage() {
                steps.push(PipelineStep::RecvActivations {
                    from_stage: rank - 1,
                    micro_batch_id: mb,
                });
            }
            steps.push(PipelineStep::Forward { micro_batch_id: mb });
            if !self.is_last_stage() {
                steps.push(PipelineStep::SendActivations {
                    to_stage: rank + 1,
                    micro_batch_id: mb,
                });
            }
        }
        // Steady state: each iteration pairs one new forward with the oldest
        // pending backward.
        let steady_start_fwd = warmup_steps;
        let steady_start_bwd = 0usize;
        let steady_count = num_micro_batches - warmup_steps;
        for i in 0..steady_count {
            let fwd_mb = steady_start_fwd + i;
            let bwd_mb = steady_start_bwd + i;
            if !self.is_first_stage() {
                steps.push(PipelineStep::RecvActivations {
                    from_stage: rank - 1,
                    micro_batch_id: fwd_mb,
                });
            }
            steps.push(PipelineStep::Forward { micro_batch_id: fwd_mb });
            if !self.is_last_stage() {
                steps.push(PipelineStep::SendActivations {
                    to_stage: rank + 1,
                    micro_batch_id: fwd_mb,
                });
            }
            if !self.is_last_stage() {
                steps.push(PipelineStep::WaitBackward { from_stage: rank + 1 });
            }
            steps.push(PipelineStep::Backward { micro_batch_id: bwd_mb });
        }
        // Cool-down: drain the backwards left over from the warm-up forwards.
        let cooldown_start = steady_count;
        for i in 0..warmup_steps {
            let bwd_mb = cooldown_start + i;
            if !self.is_last_stage() {
                steps.push(PipelineStep::WaitBackward { from_stage: rank + 1 });
            }
            steps.push(PipelineStep::Backward { micro_batch_id: bwd_mb });
        }
        steps
    }

    /// Interleaved schedule stub: currently ignores `num_virtual_stages`
    /// and falls back to plain 1F1B.
    fn schedule_interleaved(
        &self,
        num_micro_batches: usize,
        _num_virtual_stages: usize,
    ) -> Vec<PipelineStep> {
        self.schedule_1f1b(num_micro_batches)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Even split with no remainder: 12 layers over 4 stages -> 3 each.
    #[test]
    fn test_partition_layers_evenly_exact() {
        let stages = partition_layers_evenly(12, 4);
        assert_eq!(stages.len(), 4);
        for s in &stages {
            assert_eq!(s.num_layers, 3);
        }
        assert_eq!(stages[0].layer_start, 0);
        assert_eq!(stages[0].layer_end, 3);
        assert_eq!(stages[3].layer_end, 12);
    }

    // The remainder goes to the earliest stages: 10 over 3 -> [4, 3, 3].
    #[test]
    fn test_partition_layers_evenly_remainder() {
        let stages = partition_layers_evenly(10, 3);
        assert_eq!(stages.len(), 3);
        assert_eq!(stages[0].num_layers, 4);
        assert_eq!(stages[1].num_layers, 3);
        assert_eq!(stages[2].num_layers, 3);
        assert_eq!(stages[0].layer_start, 0);
        assert_eq!(stages[2].layer_end, 10);
    }

    // Uniform FLOPs should split into equal-sized stages.
    #[test]
    fn test_partition_by_flops_balanced() {
        let flops = vec![100u64; 8];
        let stages = partition_layers_by_flops(&flops, 4);
        assert_eq!(stages.len(), 4);
        for s in &stages {
            assert_eq!(s.num_layers, 2);
        }
    }

    // One heavy trailing layer: both stages end up with 400 FLOPs despite
    // very different layer counts (4 vs 1).
    #[test]
    fn test_partition_by_flops_uneven() {
        let flops = vec![100u64, 100, 100, 100, 400];
        let stages = partition_layers_by_flops(&flops, 2);
        assert_eq!(stages.len(), 2);
        assert_eq!(stages[0].num_layers, 4);
        assert_eq!(stages[1].num_layers, 1);
        let s0_flops: u64 = flops[stages[0].layer_start..stages[0].layer_end].iter().sum();
        let s1_flops: u64 = flops[stages[1].layer_start..stages[1].layer_end].iter().sum();
        assert_eq!(s0_flops, 400);
        assert_eq!(s1_flops, 400);
    }

    // GPipe bubble formula: (p-1)/(m+p-1) with p=4, m=8 -> 3/11.
    #[test]
    fn test_bubble_fraction_gpipe() {
        let stats =
            PipelineBubbleStats { num_stages: 4, num_micro_batches: 8, schedule: PipelineSchedule::GPipe };
        let expected = 3.0f32 / 11.0;
        assert!((stats.bubble_fraction_gpipe() - expected).abs() < 1e-6);
    }

    // 1F1B bubble formula: (p-1)/(2(m-1)+p) with p=4, m=8 -> 3/18.
    #[test]
    fn test_bubble_fraction_1f1b() {
        let stats = PipelineBubbleStats {
            num_stages: 4,
            num_micro_batches: 8,
            schedule: PipelineSchedule::OneFOneBubble,
        };
        let expected = 3.0f32 / 18.0;
        assert!((stats.bubble_fraction_1f1b() - expected).abs() < 1e-6);
    }

    // Rank 0 is first-but-not-last; rank p-1 is last-but-not-first.
    #[test]
    fn test_first_last_stage_detection() {
        let cfg_first = PipelineConfig {
            num_stages: 4,
            num_micro_batches: 8,
            stage_rank: 0,
            schedule: PipelineSchedule::GPipe,
        };
        let sched_first = PipelineScheduler::new(cfg_first, 8);
        assert!(sched_first.is_first_stage());
        assert!(!sched_first.is_last_stage());
        let cfg_last = PipelineConfig {
            num_stages: 4,
            num_micro_batches: 8,
            stage_rank: 3,
            schedule: PipelineSchedule::GPipe,
        };
        let sched_last = PipelineScheduler::new(cfg_last, 8);
        assert!(!sched_last.is_first_stage());
        assert!(sched_last.is_last_stage());
    }

    // Every micro-batch gets exactly one Forward and one Backward step.
    #[test]
    fn test_micro_batch_count_in_schedule() {
        let cfg = PipelineConfig {
            num_stages: 2,
            num_micro_batches: 4,
            stage_rank: 0,
            schedule: PipelineSchedule::GPipe,
        };
        let sched = PipelineScheduler::new(cfg, 4);
        let steps = sched.schedule_steps(4);
        let fwd_count = steps
            .iter()
            .filter(|s| matches!(s, PipelineStep::Forward { .. }))
            .count();
        let bwd_count = steps
            .iter()
            .filter(|s| matches!(s, PipelineStep::Backward { .. }))
            .count();
        assert_eq!(fwd_count, 4);
        assert_eq!(bwd_count, 4);
    }

    // Stage 0 never receives activations but sends one per micro-batch.
    #[test]
    fn test_schedule_steps_stage_0_gpipe() {
        let cfg = PipelineConfig {
            num_stages: 3,
            num_micro_batches: 3,
            stage_rank: 0,
            schedule: PipelineSchedule::GPipe,
        };
        let sched = PipelineScheduler::new(cfg, 6);
        let steps = sched.schedule_steps(3);
        assert!(!steps.iter().any(|s| matches!(s, PipelineStep::RecvActivations { .. })));
        assert_eq!(
            steps.iter().filter(|s| matches!(s, PipelineStep::Forward { .. })).count(),
            3
        );
        assert_eq!(
            steps
                .iter()
                .filter(|s| matches!(s, PipelineStep::SendActivations { .. }))
                .count(),
            3
        );
    }

    // Rule of thumb: 4 micro-batches per stage (8 stages -> 32).
    #[test]
    fn test_optimal_num_micro_batches() {
        let stats = PipelineBubbleStats {
            num_stages: 8,
            num_micro_batches: 32,
            schedule: PipelineSchedule::GPipe,
        };
        assert_eq!(stats.optimal_num_micro_batches(), 32);
    }

    // Memory ratio formula: (p+m)/(p*m) with p=4, m=8 -> 12/32.
    #[test]
    fn test_memory_footprint_ratio() {
        let stats = PipelineBubbleStats {
            num_stages: 4,
            num_micro_batches: 8,
            schedule: PipelineSchedule::GPipe,
        };
        let expected = 12.0f32 / 32.0;
        assert!((stats.memory_footprint_ratio() - expected).abs() < 1e-6);
    }

    // num_stages == 1 short-circuits to a single stage owning all layers.
    #[test]
    fn test_flops_partition_single_stage() {
        let flops = vec![100u64, 200, 300, 400];
        let stages = partition_layers_by_flops(&flops, 1);
        assert_eq!(stages.len(), 1);
        assert_eq!(stages[0].num_layers, 4);
        assert_eq!(stages[0].layer_start, 0);
        assert_eq!(stages[0].layer_end, 4);
    }

    // Rank 2 of 4 over 8 layers owns layers [4, 6).
    #[test]
    fn test_this_stage_metadata() {
        let cfg = PipelineConfig {
            num_stages: 4,
            num_micro_batches: 8,
            stage_rank: 2,
            schedule: PipelineSchedule::GPipe,
        };
        let sched = PipelineScheduler::new(cfg, 8);
        let stage = sched.this_stage();
        assert_eq!(stage.stage_id, 2);
        assert_eq!(stage.num_layers, 2);
        assert_eq!(stage.layer_start, 4);
        assert_eq!(stage.layer_end, 6);
    }

    // 1F1B (warmup + steady + cooldown) still covers every micro-batch once
    // in each direction.
    #[test]
    fn test_1f1b_schedule_has_fwd_and_bwd() {
        let cfg = PipelineConfig {
            num_stages: 4,
            num_micro_batches: 8,
            stage_rank: 1,
            schedule: PipelineSchedule::OneFOneBubble,
        };
        let sched = PipelineScheduler::new(cfg, 8);
        let steps = sched.schedule_steps(8);
        let fwd = steps.iter().filter(|s| matches!(s, PipelineStep::Forward { .. })).count();
        let bwd = steps.iter().filter(|s| matches!(s, PipelineStep::Backward { .. })).count();
        assert_eq!(fwd, 8);
        assert_eq!(bwd, 8);
    }

    // MicroBatch::new stores all fields verbatim.
    #[test]
    fn test_micro_batch_creation() {
        let mb = MicroBatch::new(3, vec![1.0f32, 2.0, 3.0], (1, 3), true);
        assert_eq!(mb.micro_batch_id, 3);
        assert_eq!(mb.shape, (1, 3));
        assert!(mb.is_last);
        assert_eq!(mb.data.len(), 3);
    }

    // Display output carries the offending values for diagnostics.
    #[test]
    fn test_pp_error_display() {
        let e = PpError::ZeroStages;
        assert!(e.to_string().contains("num_stages"));
        let e2 = PpError::StageRankOutOfRange { stage_rank: 5, num_stages: 4 };
        assert!(e2.to_string().contains("5"));
    }
}