use oxicuda_driver::{CudaError, CudaResult};
/// Pipeline-parallel scheduling strategies supported by the engine.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PipelineSchedule {
/// All forwards for every microbatch first, then all backwards (fill-drain).
GPipe,
/// One-forward-one-backward steady state per stage (PipeDream-style).
PipeDream1F1B,
/// 1F1B over virtual stages; `validate` requires the stage count to be a
/// multiple of the interleave factor.
InterleavedStages,
/// 1F1B with weight-gradient work emitted as separate `WeightGrad*` slots.
ZeroBubble,
}
/// Lifecycle of a microbatch as replayed from a generated schedule
/// (see `PipelineEngine::microbatch_status_at`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MicrobatchStatus {
/// Not started yet (or no schedule has been generated).
Pending,
/// Forward pass in flight; payload is the stage id.
InForward(usize),
/// Backward pass in flight; payload is the stage id.
InBackward(usize),
/// Backward pass has ended at stage 0.
Complete,
}
/// Kinds of timeline events a scheduler can emit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventType {
ForwardStart,
ForwardEnd,
BackwardStart,
BackwardEnd,
/// Activation (or activation-gradient) handoff to an adjacent stage.
SendActivation,
/// Matching receive on the adjacent stage.
RecvActivation,
/// Deferred weight-gradient slot; emitted only by `ZeroBubbleScheduler`.
WeightGradStart,
WeightGradEnd,
}
/// One scheduled occurrence on the pipeline timeline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineEvent {
/// Discrete 0-based timestep at which the event occurs.
pub timestamp: usize,
pub event_type: EventType,
/// Index of the stage this event belongs to.
pub stage_id: usize,
/// Index of the microbatch this event belongs to.
pub microbatch_id: usize,
}
/// One stage of the partitioned model, placed on a device.
#[derive(Debug, Clone)]
pub struct PipelineStage {
/// 0-based position of this stage in the pipeline.
pub stage_id: usize,
/// Device ordinal the stage runs on — NOTE(review): presumably a CUDA
/// device id given the `oxicuda_driver` dependency; confirm with callers.
pub device_id: i32,
/// Human-readable stage label.
pub name: String,
/// Relative compute cost; defaults to 1.0 in `PipelineStage::new`.
pub compute_cost_estimate: f64,
}
impl PipelineStage {
    /// Create a stage with a default compute-cost estimate of 1.0.
    pub fn new(stage_id: usize, device_id: i32, name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            stage_id,
            device_id,
            name,
            compute_cost_estimate: 1.0,
        }
    }

    /// Builder-style override of the relative compute-cost estimate.
    pub fn with_compute_cost(self, cost: f64) -> Self {
        Self {
            compute_cost_estimate: cost,
            ..self
        }
    }
}
/// Full description of a pipeline-parallel run; check with `validate`.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
/// Stages in pipeline order; must be non-empty.
pub stages: Vec<PipelineStage>,
/// Microbatches per step; must be nonzero.
pub num_microbatches: usize,
pub schedule_type: PipelineSchedule,
/// Must be nonzero; for `InterleavedStages` it must also divide the
/// stage count evenly.
pub interleave_factor: usize,
}
impl PipelineConfig {
    /// Check that the configuration is internally consistent.
    ///
    /// # Errors
    /// `CudaError::InvalidValue` when there are no stages, no microbatches,
    /// a zero interleave factor, or an interleaved schedule whose stage
    /// count is not a multiple of the interleave factor.
    pub fn validate(&self) -> CudaResult<()> {
        // Reject all trivially-empty dimensions up front; this also
        // guarantees the modulo below cannot divide by zero.
        if self.stages.is_empty() || self.num_microbatches == 0 || self.interleave_factor == 0 {
            return Err(CudaError::InvalidValue);
        }
        let needs_even_split = self.schedule_type == PipelineSchedule::InterleavedStages;
        if needs_even_split && self.stages.len() % self.interleave_factor != 0 {
            return Err(CudaError::InvalidValue);
        }
        Ok(())
    }
}
/// Utilization summary produced by `PipelineEngine::analyze`.
#[derive(Debug, Clone)]
pub struct BubbleAnalysis {
/// Schedule makespan in timesteps (max event timestamp + 1).
pub total_time: usize,
/// Busy stage-time slots (one per `*Start` event) summed over all stages.
pub compute_time: usize,
/// Idle slots: `total_time * num_stages - compute_time` (saturating).
pub bubble_time: usize,
/// `bubble_time / (total_time * num_stages)`; 0.0 when capacity is zero.
pub bubble_ratio: f64,
/// Idle fraction per stage, indexed by stage id.
pub per_stage_idle: Vec<f64>,
}
/// Per-stage activation-memory strategy chosen by `ActivationCheckpointing`
/// from the stage's memory budget.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CheckpointDecision {
/// Keep activations resident (budget >= STORE_THRESHOLD).
Store,
/// Recompute activations during backward (budget >= RECOMPUTE_THRESHOLD).
Recompute,
/// Chosen when the budget is below both thresholds.
Offload,
}
pub struct GpipeScheduler;
impl GpipeScheduler {
    /// Build a GPipe (fill-drain) schedule: every microbatch runs forward
    /// through all stages first, then the whole backward phase follows.
    ///
    /// Stages are pipelined one timestep apart; events come back sorted by
    /// `(timestamp, stage_id, microbatch_id)`. Returns an empty schedule
    /// when either dimension is zero.
    pub fn schedule(num_stages: usize, num_microbatches: usize) -> Vec<PipelineEvent> {
        let mut events = Vec::new();
        let s = num_stages;
        let m = num_microbatches;
        // BUGFIX: without this guard, `s + m - 1` below underflows `usize`
        // when both inputs are zero (panic in debug builds). Also matches
        // the degenerate-input handling of the other schedulers.
        if s == 0 || m == 0 {
            return events;
        }
        for mb in 0..m {
            for st in 0..s {
                // Stage `st` runs microbatch `mb` forward at time st + mb.
                let t = st + mb;
                events.push(PipelineEvent {
                    timestamp: t,
                    event_type: EventType::ForwardStart,
                    stage_id: st,
                    microbatch_id: mb,
                });
                events.push(PipelineEvent {
                    timestamp: t,
                    event_type: EventType::ForwardEnd,
                    stage_id: st,
                    microbatch_id: mb,
                });
                if st + 1 < s {
                    // Hand the activation off to the next stage.
                    events.push(PipelineEvent {
                        timestamp: t,
                        event_type: EventType::SendActivation,
                        stage_id: st,
                        microbatch_id: mb,
                    });
                    events.push(PipelineEvent {
                        timestamp: t,
                        event_type: EventType::RecvActivation,
                        stage_id: st + 1,
                        microbatch_id: mb,
                    });
                }
            }
        }
        // Backward begins only after the last forward (at time s + m - 2)
        // has drained.
        let bwd_offset = s + m - 1;
        for mb in 0..m {
            for st in (0..s).rev() {
                let t = bwd_offset + (s - 1 - st) + mb;
                events.push(PipelineEvent {
                    timestamp: t,
                    event_type: EventType::BackwardStart,
                    stage_id: st,
                    microbatch_id: mb,
                });
                events.push(PipelineEvent {
                    timestamp: t,
                    event_type: EventType::BackwardEnd,
                    stage_id: st,
                    microbatch_id: mb,
                });
                if st > 0 {
                    // Activation gradients flow to the previous stage.
                    events.push(PipelineEvent {
                        timestamp: t,
                        event_type: EventType::SendActivation,
                        stage_id: st,
                        microbatch_id: mb,
                    });
                    events.push(PipelineEvent {
                        timestamp: t,
                        event_type: EventType::RecvActivation,
                        stage_id: st - 1,
                        microbatch_id: mb,
                    });
                }
            }
        }
        events.sort_by_key(|e| (e.timestamp, e.stage_id, e.microbatch_id));
        events
    }
}
pub struct OneFOneBScheduler;
impl OneFOneBScheduler {
pub fn schedule(num_stages: usize, num_microbatches: usize) -> Vec<PipelineEvent> {
let mut events = Vec::new();
let s = num_stages;
let m = num_microbatches;
if s == 0 || m == 0 {
return events;
}
let mut stage_time: Vec<usize> = (0..s).collect();
let mut fwd_mb = vec![0usize; s]; let mut bwd_mb = vec![0usize; s];
let warmup_count: Vec<usize> = (0..s).map(|st| (s - 1 - st).min(m)).collect();
for st in 0..s {
for _ in 0..warmup_count[st] {
let mb = fwd_mb[st];
if mb >= m {
break;
}
let t = stage_time[st];
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::ForwardStart,
stage_id: st,
microbatch_id: mb,
});
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::ForwardEnd,
stage_id: st,
microbatch_id: mb,
});
fwd_mb[st] = mb + 1;
stage_time[st] = t + 1;
}
}
for st in 0..s {
while fwd_mb[st] < m {
let mb_f = fwd_mb[st];
let t = stage_time[st];
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::ForwardStart,
stage_id: st,
microbatch_id: mb_f,
});
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::ForwardEnd,
stage_id: st,
microbatch_id: mb_f,
});
fwd_mb[st] = mb_f + 1;
stage_time[st] = t + 1;
let mb_b = bwd_mb[st];
if mb_b < m {
let t = stage_time[st];
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::BackwardStart,
stage_id: st,
microbatch_id: mb_b,
});
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::BackwardEnd,
stage_id: st,
microbatch_id: mb_b,
});
bwd_mb[st] = mb_b + 1;
stage_time[st] = t + 1;
}
}
}
for st in 0..s {
while bwd_mb[st] < m {
let mb_b = bwd_mb[st];
let t = stage_time[st];
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::BackwardStart,
stage_id: st,
microbatch_id: mb_b,
});
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::BackwardEnd,
stage_id: st,
microbatch_id: mb_b,
});
bwd_mb[st] = mb_b + 1;
stage_time[st] = t + 1;
}
}
events.sort_by_key(|e| (e.timestamp, e.stage_id, e.microbatch_id));
events
}
}
pub struct InterleavedScheduler;
impl InterleavedScheduler {
pub fn schedule(
num_stages: usize,
num_microbatches: usize,
interleave_factor: usize,
) -> Vec<PipelineEvent> {
let mut events = Vec::new();
let s = num_stages;
let m = num_microbatches;
if s == 0 || m == 0 || interleave_factor == 0 {
return events;
}
let num_physical = s / interleave_factor;
if num_physical == 0 {
return events;
}
let mut fwd_mb = vec![0usize; s];
let mut bwd_mb = vec![0usize; s];
let mut stage_time: Vec<usize> = (0..s)
.map(|v| {
let phys = v % num_physical;
let group_idx = v / num_physical;
phys + group_idx * num_physical
})
.collect();
let warmup_count: Vec<usize> = (0..s)
.map(|v| {
let effective = (s - 1 - v) / interleave_factor;
effective.min(m)
})
.collect();
for v in 0..s {
for _ in 0..warmup_count[v] {
let mb = fwd_mb[v];
if mb >= m {
break;
}
let t = stage_time[v];
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::ForwardStart,
stage_id: v,
microbatch_id: mb,
});
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::ForwardEnd,
stage_id: v,
microbatch_id: mb,
});
fwd_mb[v] = mb + 1;
stage_time[v] = t + 1;
}
}
for v in 0..s {
while fwd_mb[v] < m {
let mb_f = fwd_mb[v];
let t = stage_time[v];
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::ForwardStart,
stage_id: v,
microbatch_id: mb_f,
});
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::ForwardEnd,
stage_id: v,
microbatch_id: mb_f,
});
fwd_mb[v] = mb_f + 1;
stage_time[v] = t + 1;
let mb_b = bwd_mb[v];
if mb_b < m {
let t = stage_time[v];
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::BackwardStart,
stage_id: v,
microbatch_id: mb_b,
});
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::BackwardEnd,
stage_id: v,
microbatch_id: mb_b,
});
bwd_mb[v] = mb_b + 1;
stage_time[v] = t + 1;
}
}
}
for v in 0..s {
while bwd_mb[v] < m {
let mb_b = bwd_mb[v];
let t = stage_time[v];
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::BackwardStart,
stage_id: v,
microbatch_id: mb_b,
});
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::BackwardEnd,
stage_id: v,
microbatch_id: mb_b,
});
bwd_mb[v] = mb_b + 1;
stage_time[v] = t + 1;
}
}
events.sort_by_key(|e| (e.timestamp, e.stage_id, e.microbatch_id));
events
}
}
pub struct ZeroBubbleScheduler;
impl ZeroBubbleScheduler {
pub fn schedule(num_stages: usize, num_microbatches: usize) -> Vec<PipelineEvent> {
let mut events = Vec::new();
let s = num_stages;
let m = num_microbatches;
if s == 0 || m == 0 {
return events;
}
let mut stage_time: Vec<usize> = (0..s).collect();
let mut fwd_mb = vec![0usize; s];
let mut bwd_mb = vec![0usize; s]; let mut wgt_mb = vec![0usize; s];
let warmup_count: Vec<usize> = (0..s).map(|st| (s - 1 - st).min(m)).collect();
for st in 0..s {
for _ in 0..warmup_count[st] {
let mb = fwd_mb[st];
if mb >= m {
break;
}
let t = stage_time[st];
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::ForwardStart,
stage_id: st,
microbatch_id: mb,
});
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::ForwardEnd,
stage_id: st,
microbatch_id: mb,
});
fwd_mb[st] = mb + 1;
stage_time[st] = t + 1;
}
}
for st in 0..s {
while fwd_mb[st] < m {
let mb_f = fwd_mb[st];
let t = stage_time[st];
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::ForwardStart,
stage_id: st,
microbatch_id: mb_f,
});
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::ForwardEnd,
stage_id: st,
microbatch_id: mb_f,
});
fwd_mb[st] = mb_f + 1;
stage_time[st] = t + 1;
let mb_b = bwd_mb[st];
if mb_b < m {
let t = stage_time[st];
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::BackwardStart,
stage_id: st,
microbatch_id: mb_b,
});
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::BackwardEnd,
stage_id: st,
microbatch_id: mb_b,
});
bwd_mb[st] = mb_b + 1;
stage_time[st] = t + 1;
}
let mb_w = wgt_mb[st];
if mb_w < m {
let t = stage_time[st];
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::WeightGradStart,
stage_id: st,
microbatch_id: mb_w,
});
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::WeightGradEnd,
stage_id: st,
microbatch_id: mb_w,
});
wgt_mb[st] = mb_w + 1;
stage_time[st] = t + 1;
}
}
}
for st in 0..s {
while bwd_mb[st] < m || wgt_mb[st] < m {
if bwd_mb[st] < m {
let mb_b = bwd_mb[st];
let t = stage_time[st];
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::BackwardStart,
stage_id: st,
microbatch_id: mb_b,
});
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::BackwardEnd,
stage_id: st,
microbatch_id: mb_b,
});
bwd_mb[st] = mb_b + 1;
stage_time[st] = t + 1;
}
if wgt_mb[st] < m {
let mb_w = wgt_mb[st];
let t = stage_time[st];
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::WeightGradStart,
stage_id: st,
microbatch_id: mb_w,
});
events.push(PipelineEvent {
timestamp: t,
event_type: EventType::WeightGradEnd,
stage_id: st,
microbatch_id: mb_w,
});
wgt_mb[st] = mb_w + 1;
stage_time[st] = t + 1;
}
}
}
events.sort_by_key(|e| (e.timestamp, e.stage_id, e.microbatch_id));
events
}
}
/// Drives schedule generation and bubble analysis for a validated
/// `PipelineConfig`.
pub struct PipelineEngine {
// Configuration, validated in `new`.
config: PipelineConfig,
// Last generated schedule; filled by `generate_schedule`, read by
// `analyze` and `microbatch_status_at`.
schedule_cache: Option<Vec<PipelineEvent>>,
}
impl PipelineEngine {
pub fn new(config: PipelineConfig) -> CudaResult<Self> {
config.validate()?;
Ok(Self {
config,
schedule_cache: None,
})
}
pub fn config(&self) -> &PipelineConfig {
&self.config
}
pub fn num_stages(&self) -> usize {
self.config.stages.len()
}
pub fn num_microbatches(&self) -> usize {
self.config.num_microbatches
}
pub fn generate_schedule(&mut self) -> Vec<PipelineEvent> {
let s = self.config.stages.len();
let m = self.config.num_microbatches;
let schedule = match self.config.schedule_type {
PipelineSchedule::GPipe => GpipeScheduler::schedule(s, m),
PipelineSchedule::PipeDream1F1B => OneFOneBScheduler::schedule(s, m),
PipelineSchedule::InterleavedStages => {
InterleavedScheduler::schedule(s, m, self.config.interleave_factor)
}
PipelineSchedule::ZeroBubble => ZeroBubbleScheduler::schedule(s, m),
};
self.schedule_cache = Some(schedule.clone());
schedule
}
pub fn bubble_ratio(&mut self) -> f64 {
self.analyze().bubble_ratio
}
pub fn analyze(&mut self) -> BubbleAnalysis {
let events = if let Some(ref cached) = self.schedule_cache {
cached.clone()
} else {
self.generate_schedule()
};
let num_stages = self.config.stages.len();
Self::analyze_events(&events, num_stages)
}
fn analyze_events(events: &[PipelineEvent], num_stages: usize) -> BubbleAnalysis {
if events.is_empty() || num_stages == 0 {
return BubbleAnalysis {
total_time: 0,
compute_time: 0,
bubble_time: 0,
bubble_ratio: 0.0,
per_stage_idle: vec![0.0; num_stages],
};
}
let max_time = events.iter().map(|e| e.timestamp).max().unwrap_or(0) + 1;
let mut stage_compute = vec![0usize; num_stages];
for ev in events {
match ev.event_type {
EventType::ForwardStart | EventType::BackwardStart | EventType::WeightGradStart
if ev.stage_id < num_stages =>
{
stage_compute[ev.stage_id] += 1;
}
_ => {}
}
}
let compute_time: usize = stage_compute.iter().sum();
let total_available = max_time * num_stages;
let bubble_time = total_available.saturating_sub(compute_time);
let bubble_ratio = if total_available > 0 {
bubble_time as f64 / total_available as f64
} else {
0.0
};
let per_stage_idle: Vec<f64> = stage_compute
.iter()
.map(|&c| {
if max_time > 0 {
(max_time.saturating_sub(c)) as f64 / max_time as f64
} else {
0.0
}
})
.collect();
BubbleAnalysis {
total_time: max_time,
compute_time,
bubble_time,
bubble_ratio,
per_stage_idle,
}
}
pub fn steady_state_throughput(&mut self) -> f64 {
let analysis = self.analyze();
if analysis.total_time == 0 {
return 0.0;
}
self.config.num_microbatches as f64 / analysis.total_time as f64
}
pub fn microbatch_status_at(&self, microbatch_id: usize, time: usize) -> MicrobatchStatus {
let events = match &self.schedule_cache {
Some(cached) => cached,
None => return MicrobatchStatus::Pending,
};
let mut status = MicrobatchStatus::Pending;
for ev in events {
if ev.microbatch_id != microbatch_id || ev.timestamp > time {
continue;
}
match ev.event_type {
EventType::ForwardStart => {
status = MicrobatchStatus::InForward(ev.stage_id);
}
EventType::BackwardStart => {
status = MicrobatchStatus::InBackward(ev.stage_id);
}
EventType::BackwardEnd if ev.stage_id == 0 => {
status = MicrobatchStatus::Complete;
}
_ => {}
}
}
status
}
}
pub struct ActivationCheckpointing;
impl ActivationCheckpointing {
    /// Budgets at or above this keep activations resident.
    const STORE_THRESHOLD: usize = 1024;
    /// Budgets at or above this (but below STORE_THRESHOLD) recompute.
    const RECOMPUTE_THRESHOLD: usize = 256;

    /// Map one per-stage memory budget to a decision. Shared by `plan` and
    /// `plan_variable` so the threshold logic cannot drift apart (it was
    /// previously duplicated in both).
    fn decide(budget: usize) -> CheckpointDecision {
        if budget >= Self::STORE_THRESHOLD {
            CheckpointDecision::Store
        } else if budget >= Self::RECOMPUTE_THRESHOLD {
            CheckpointDecision::Recompute
        } else {
            CheckpointDecision::Offload
        }
    }

    /// Uniform plan: every stage gets the same budget, hence the same decision.
    pub fn plan(num_stages: usize, memory_budget_per_stage: usize) -> Vec<CheckpointDecision> {
        vec![Self::decide(memory_budget_per_stage); num_stages]
    }

    /// Per-stage plan from individual budgets, one decision per entry.
    pub fn plan_variable(budgets: &[usize]) -> Vec<CheckpointDecision> {
        budgets.iter().map(|&budget| Self::decide(budget)).collect()
    }
}
pub struct PipelineVisualizer;
impl PipelineVisualizer {
    /// Render a schedule as an ASCII grid: one row per stage, one column per
    /// timestep. Cells show "F<mb>"/"B<mb>"/"W<mb>" for compute starts and
    /// ".." for idle slots; cells are right-aligned to a uniform width.
    pub fn render_ascii(events: &[PipelineEvent], num_stages: usize) -> String {
        if num_stages == 0 || events.is_empty() {
            return String::new();
        }
        let horizon = 1 + events.iter().map(|e| e.timestamp).max().unwrap_or(0);
        let mut grid: Vec<Vec<String>> = vec![vec![String::from(".."); horizon]; num_stages];
        for ev in events {
            if ev.stage_id >= num_stages || ev.timestamp >= horizon {
                continue;
            }
            // Only compute-start events get a marker; everything else leaves
            // the idle placeholder in place.
            let marker = match ev.event_type {
                EventType::ForwardStart => Some('F'),
                EventType::BackwardStart => Some('B'),
                EventType::WeightGradStart => Some('W'),
                _ => None,
            };
            if let Some(c) = marker {
                grid[ev.stage_id][ev.timestamp] = format!("{c}{}", ev.microbatch_id);
            }
        }
        // Widest cell (at least 2 for "..") sets the column width.
        let col_width = grid
            .iter()
            .flatten()
            .map(String::len)
            .max()
            .unwrap_or(2)
            .max(2);
        let mut output = String::new();
        for (stage, row) in grid.iter().enumerate() {
            output.push_str(&format!("Stage {stage}: "));
            let cells: Vec<String> = row
                .iter()
                .map(|cell| format!("{cell:>col_width$}"))
                .collect();
            output.push_str(&cells.join(" "));
            output.push('\n');
        }
        output
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Build `n` stages named "stage_<i>", one per device ordinal.
fn make_stages(n: usize) -> Vec<PipelineStage> {
(0..n)
.map(|i| PipelineStage::new(i, i as i32, format!("stage_{i}")))
.collect()
}
// Convenience config with interleave_factor fixed at 1.
fn make_config(
num_stages: usize,
num_microbatches: usize,
schedule: PipelineSchedule,
) -> PipelineConfig {
PipelineConfig {
stages: make_stages(num_stages),
num_microbatches,
schedule_type: schedule,
interleave_factor: 1,
}
}
// GPipe: 4 stages x 8 microbatches -> 32 forwards and 32 backwards, and the
// whole backward phase starts strictly after the last forward ends.
#[test]
fn gpipe_schedule_4_stages_8_microbatches() {
let events = GpipeScheduler::schedule(4, 8);
let fwd_count = events
.iter()
.filter(|e| e.event_type == EventType::ForwardStart)
.count();
let bwd_count = events
.iter()
.filter(|e| e.event_type == EventType::BackwardStart)
.count();
assert_eq!(fwd_count, 32);
assert_eq!(bwd_count, 32);
let first_fwd = events
.iter()
.find(|e| e.event_type == EventType::ForwardStart)
.expect("should have a forward event");
assert_eq!(first_fwd.timestamp, 0);
assert_eq!(first_fwd.stage_id, 0);
assert_eq!(first_fwd.microbatch_id, 0);
let last_fwd_time = events
.iter()
.filter(|e| e.event_type == EventType::ForwardEnd)
.map(|e| e.timestamp)
.max()
.unwrap_or(0);
let first_bwd_time = events
.iter()
.filter(|e| e.event_type == EventType::BackwardStart)
.map(|e| e.timestamp)
.min()
.unwrap_or(0);
assert!(first_bwd_time > last_fwd_time);
}
// 1F1B: same event counts as GPipe, but the last stage alternates — its
// first backward immediately follows its first forward.
#[test]
fn one_f_one_b_schedule_correctness() {
let events = OneFOneBScheduler::schedule(4, 8);
let fwd_count = events
.iter()
.filter(|e| e.event_type == EventType::ForwardStart)
.count();
let bwd_count = events
.iter()
.filter(|e| e.event_type == EventType::BackwardStart)
.count();
assert_eq!(fwd_count, 32); assert_eq!(bwd_count, 32);
let stage3_events: Vec<&PipelineEvent> = events
.iter()
.filter(|e| e.stage_id == 3)
.filter(|e| {
matches!(
e.event_type,
EventType::ForwardStart | EventType::BackwardStart
)
})
.collect();
assert_eq!(stage3_events[0].event_type, EventType::ForwardStart);
assert_eq!(stage3_events[1].event_type, EventType::BackwardStart);
}
// Interleaving must not make the bubble worse than plain GPipe.
#[test]
fn interleaved_schedule_correctness() {
let events = InterleavedScheduler::schedule(4, 8, 2);
let fwd_count = events
.iter()
.filter(|e| e.event_type == EventType::ForwardStart)
.count();
let bwd_count = events
.iter()
.filter(|e| e.event_type == EventType::BackwardStart)
.count();
assert_eq!(fwd_count, 32); assert_eq!(bwd_count, 32);
let analysis_interleaved = PipelineEngine::analyze_events(&events, 4);
let gpipe_events = GpipeScheduler::schedule(4, 8);
let analysis_gpipe = PipelineEngine::analyze_events(&gpipe_events, 4);
assert!(
analysis_interleaved.bubble_ratio <= analysis_gpipe.bubble_ratio,
"interleaved ({}) should have <= bubble ratio than GPipe ({})",
analysis_interleaved.bubble_ratio,
analysis_gpipe.bubble_ratio
);
}
// Zero-bubble adds a third (weight-grad) slot per stage/microbatch, so
// compute time is 3 * 4 * 8 = 96 slots.
#[test]
fn zero_bubble_schedule() {
let events = ZeroBubbleScheduler::schedule(4, 8);
let fwd_count = events
.iter()
.filter(|e| e.event_type == EventType::ForwardStart)
.count();
let bwd_count = events
.iter()
.filter(|e| e.event_type == EventType::BackwardStart)
.count();
let wgt_count = events
.iter()
.filter(|e| e.event_type == EventType::WeightGradStart)
.count();
assert_eq!(fwd_count, 32);
assert_eq!(bwd_count, 32);
assert_eq!(wgt_count, 32);
let analysis = PipelineEngine::analyze_events(&events, 4);
assert_eq!(analysis.compute_time, 96); }
// GPipe analysis against hand-computed numbers for s=4, m=8.
#[test]
fn bubble_analysis_gpipe() {
let mut engine =
PipelineEngine::new(make_config(4, 8, PipelineSchedule::GPipe)).expect("create engine");
let analysis = engine.analyze();
assert_eq!(analysis.total_time, 22);
assert_eq!(analysis.compute_time, 64);
assert_eq!(analysis.bubble_time, 24);
assert!(
(analysis.bubble_ratio - 24.0 / 88.0).abs() < 1e-6,
"bubble_ratio = {}",
analysis.bubble_ratio
);
assert_eq!(analysis.per_stage_idle.len(), 4);
}
// Throughput = microbatches / makespan.
#[test]
fn steady_state_throughput() {
let mut engine =
PipelineEngine::new(make_config(4, 8, PipelineSchedule::GPipe)).expect("create engine");
let throughput = engine.steady_state_throughput();
assert!(throughput > 0.0);
assert!((throughput - 8.0 / 22.0).abs() < 1e-6);
}
// A single stage pipelines nothing, so there is no bubble at all.
#[test]
fn pipeline_single_stage() {
let events = GpipeScheduler::schedule(1, 4);
let fwd_count = events
.iter()
.filter(|e| e.event_type == EventType::ForwardStart)
.count();
let bwd_count = events
.iter()
.filter(|e| e.event_type == EventType::BackwardStart)
.count();
assert_eq!(fwd_count, 4);
assert_eq!(bwd_count, 4);
let analysis = PipelineEngine::analyze_events(&events, 1);
assert_eq!(analysis.total_time, 8);
assert_eq!(analysis.compute_time, 8);
assert_eq!(analysis.bubble_time, 0);
assert!((analysis.bubble_ratio).abs() < 1e-6);
}
// More microbatches amortize the warmup/cooldown bubble.
#[test]
fn pipeline_many_microbatches() {
let events = OneFOneBScheduler::schedule(4, 64);
let fwd_count = events
.iter()
.filter(|e| e.event_type == EventType::ForwardStart)
.count();
assert_eq!(fwd_count, 256);
let analysis_64 = PipelineEngine::analyze_events(&events, 4);
let events_8 = OneFOneBScheduler::schedule(4, 8);
let analysis_8 = PipelineEngine::analyze_events(&events_8, 4);
assert!(
analysis_64.bubble_ratio <= analysis_8.bubble_ratio,
"more microbatches should reduce bubble: {} vs {}",
analysis_64.bubble_ratio,
analysis_8.bubble_ratio
);
}
// Budget thresholds: >=1024 store, >=256 recompute, otherwise offload.
#[test]
fn activation_checkpointing_plan() {
let plan = ActivationCheckpointing::plan(4, 2048);
assert_eq!(plan.len(), 4);
assert!(plan.iter().all(|d| *d == CheckpointDecision::Store));
let plan = ActivationCheckpointing::plan(4, 512);
assert!(plan.iter().all(|d| *d == CheckpointDecision::Recompute));
let plan = ActivationCheckpointing::plan(4, 128);
assert!(plan.iter().all(|d| *d == CheckpointDecision::Offload));
let plan = ActivationCheckpointing::plan_variable(&[2048, 512, 128, 1024]);
assert_eq!(plan[0], CheckpointDecision::Store);
assert_eq!(plan[1], CheckpointDecision::Recompute);
assert_eq!(plan[2], CheckpointDecision::Offload);
assert_eq!(plan[3], CheckpointDecision::Store);
}
// One output line per stage with F/B cells and ".." idle markers.
#[test]
fn ascii_visualization() {
let events = GpipeScheduler::schedule(2, 3);
let output = PipelineVisualizer::render_ascii(&events, 2);
assert!(output.contains("Stage 0:"));
assert!(output.contains("Stage 1:"));
assert!(output.contains("F0"));
assert!(output.contains("F1"));
assert!(output.contains("F2"));
assert!(output.contains("B0"));
assert!(output.contains("B1"));
assert!(output.contains("B2"));
assert!(output.contains(".."));
let line_count = output.lines().count();
assert_eq!(line_count, 2);
}
// Schedules come back sorted by timestamp.
#[test]
fn event_ordering() {
let events = GpipeScheduler::schedule(4, 4);
for pair in events.windows(2) {
assert!(
pair[0].timestamp <= pair[1].timestamp,
"events not sorted: t={} before t={}",
pair[0].timestamp,
pair[1].timestamp
);
}
}
// Every (microbatch, stage) pair gets both a forward and a backward start.
#[test]
fn schedule_completeness() {
for schedule_type in [PipelineSchedule::GPipe, PipelineSchedule::PipeDream1F1B] {
let num_stages = 4;
let num_mb = 6;
let config = make_config(num_stages, num_mb, schedule_type);
let mut engine = PipelineEngine::new(config).expect("engine");
let events = engine.generate_schedule();
for mb in 0..num_mb {
for st in 0..num_stages {
let has_fwd = events.iter().any(|e| {
e.microbatch_id == mb
&& e.stage_id == st
&& e.event_type == EventType::ForwardStart
});
let has_bwd = events.iter().any(|e| {
e.microbatch_id == mb
&& e.stage_id == st
&& e.event_type == EventType::BackwardStart
});
assert!(
has_fwd,
"{schedule_type:?}: microbatch {mb} missing forward at stage {st}"
);
assert!(
has_bwd,
"{schedule_type:?}: microbatch {mb} missing backward at stage {st}"
);
}
}
}
}
// Status replay on 2x2 GPipe: forward through both stages, pending before
// start, complete after the backward pass ends at stage 0.
#[test]
fn microbatch_status_tracking() {
let config = make_config(2, 2, PipelineSchedule::GPipe);
let mut engine = PipelineEngine::new(config).expect("engine");
engine.generate_schedule();
let status = engine.microbatch_status_at(0, 0);
assert_eq!(status, MicrobatchStatus::InForward(0));
let status = engine.microbatch_status_at(0, 1);
assert_eq!(status, MicrobatchStatus::InForward(1));
let status = engine.microbatch_status_at(1, 0);
assert_eq!(status, MicrobatchStatus::Pending);
let status = engine.microbatch_status_at(0, 4);
assert_eq!(status, MicrobatchStatus::Complete);
}
// Each invalid configuration is rejected; the final one is accepted.
#[test]
fn config_validation() {
let config = PipelineConfig {
stages: vec![],
num_microbatches: 4,
schedule_type: PipelineSchedule::GPipe,
interleave_factor: 1,
};
assert!(config.validate().is_err());
let config = make_config(4, 0, PipelineSchedule::GPipe);
assert!(config.validate().is_err());
let config = PipelineConfig {
stages: make_stages(4),
num_microbatches: 4,
schedule_type: PipelineSchedule::InterleavedStages,
interleave_factor: 0,
};
assert!(config.validate().is_err());
let config = PipelineConfig {
stages: make_stages(3),
num_microbatches: 4,
schedule_type: PipelineSchedule::InterleavedStages,
interleave_factor: 2,
};
assert!(config.validate().is_err());
let config = PipelineConfig {
stages: make_stages(4),
num_microbatches: 8,
schedule_type: PipelineSchedule::InterleavedStages,
interleave_factor: 2,
};
assert!(config.validate().is_ok());
}
// GPipe bubble ratio matches the closed-form value for s=4, m=8.
#[test]
fn gpipe_bubble_ratio_theoretical() {
let mut engine =
PipelineEngine::new(make_config(4, 8, PipelineSchedule::GPipe)).expect("engine");
let ratio = engine.bubble_ratio();
assert!((ratio - 24.0 / 88.0).abs() < 1e-6);
}
// Builder chain sets the compute-cost estimate on top of `new` defaults.
#[test]
fn pipeline_stage_builder() {
let stage = PipelineStage::new(0, 0, "encoder").with_compute_cost(2.5);
assert_eq!(stage.stage_id, 0);
assert_eq!(stage.device_id, 0);
assert_eq!(stage.name, "encoder");
assert!((stage.compute_cost_estimate - 2.5).abs() < f64::EPSILON);
}
}