use crate::process_group::ProcessGroup;
use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;
/// Scheduling strategy for ordering microbatch execution across pipeline
/// stages.
///
/// NOTE(review): the 1F1B and interleaved schedules currently fall back to
/// the GPipe forward path (see `Pipeline::forward_1f1b` /
/// `Pipeline::forward_interleaved`); only the dispatch differs today.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PipelineSchedule {
    /// GPipe-style schedule: all microbatch forwards run before backwards.
    GPipe,
    /// One-forward-one-backward steady-state schedule (the default).
    #[default]
    OneFOneBSchedule,
    /// Interleaved 1F1B (multiple stage chunks per device).
    InterleavedOneFOneB,
}
/// One stage of a pipeline: a single module bound to a position in the
/// pipeline and to the process rank that executes it.
pub struct PipelineStage<M: Module> {
    // The wrapped module this stage runs.
    module: M,
    // 0-based position of this stage in the pipeline.
    stage_id: usize,
    // Rank of the process/device that owns and executes this stage.
    device_rank: usize,
}
impl<M: Module> PipelineStage<M> {
    /// Wraps `module` as pipeline stage `stage_id`, executed by `device_rank`.
    pub fn new(module: M, stage_id: usize, device_rank: usize) -> Self {
        Self {
            device_rank,
            stage_id,
            module,
        }
    }

    /// Returns this stage's 0-based position in the pipeline.
    pub fn stage_id(&self) -> usize {
        self.stage_id
    }

    /// Returns the rank of the process that executes this stage.
    pub fn device_rank(&self) -> usize {
        self.device_rank
    }

    /// Runs the wrapped module's forward pass on `input`.
    pub fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }
}
// `Module` implementation that transparently delegates every trait method
// to the wrapped module, so a stage can be used anywhere a module can.
impl<M: Module> Module for PipelineStage<M> {
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }
    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }
    fn train(&mut self) {
        self.module.train();
    }
    fn eval(&mut self) {
        self.module.eval();
    }
    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}
/// A pipeline-parallel model wrapper: partitions an ordered list of stages
/// across process ranks and runs the forward pass microbatch by microbatch.
pub struct Pipeline<M: Module> {
    // All stages in execution order. `from_modules` assigns stage `i` to
    // rank `i % world_size` (round-robin).
    stages: Vec<PipelineStage<M>>,
    // Communicator used to send/receive activations between ranks.
    process_group: ProcessGroup,
    // Schedule that `forward` dispatches on.
    schedule: PipelineSchedule,
    // How many microbatches the input batch is split into (always >= 1).
    num_microbatches: usize,
    // Index of the first stage owned by the local rank (0 when this rank
    // owns no stage). Currently unused beyond construction.
    #[allow(dead_code)]
    local_stage: usize,
}
impl<M: Module + Clone> Pipeline<M> {
    /// Builds a pipeline from an ordered list of modules.
    ///
    /// Stage `i` is assigned round-robin to rank `i % world_size`. The
    /// schedule defaults to `PipelineSchedule::default()` (1F1B) and the
    /// batch is initially not split (`num_microbatches == 1`); use the
    /// builder methods below to change either.
    pub fn from_modules(modules: Vec<M>, process_group: ProcessGroup) -> Self {
        let world_size = process_group.world_size();
        let rank = process_group.rank();
        let stages: Vec<PipelineStage<M>> = modules
            .into_iter()
            .enumerate()
            .map(|(i, m)| PipelineStage::new(m, i, i % world_size))
            .collect();
        // First stage owned by the local rank; 0 when this rank owns none.
        let local_stage = stages
            .iter()
            .position(|s| s.device_rank == rank)
            .unwrap_or(0);
        Self {
            stages,
            process_group,
            schedule: PipelineSchedule::default(),
            num_microbatches: 1,
            local_stage,
        }
    }
    /// Builder: selects the pipeline schedule.
    pub fn schedule(mut self, schedule: PipelineSchedule) -> Self {
        self.schedule = schedule;
        self
    }
    /// Builder: sets the number of microbatches (clamped to at least 1).
    pub fn num_microbatches(mut self, num: usize) -> Self {
        self.num_microbatches = num.max(1);
        self
    }
    /// Number of stages in the pipeline.
    pub fn num_stages(&self) -> usize {
        self.stages.len()
    }
    /// Currently selected schedule.
    pub fn get_schedule(&self) -> PipelineSchedule {
        self.schedule
    }
    /// Runs a forward pass using the configured schedule.
    ///
    /// NOTE(review): `forward_1f1b` and `forward_interleaved` currently just
    /// delegate to `forward_gpipe`, so all three arms behave identically.
    pub fn forward(&self, input: &Variable) -> Variable {
        match self.schedule {
            PipelineSchedule::GPipe => self.forward_gpipe(input),
            PipelineSchedule::OneFOneBSchedule => self.forward_1f1b(input),
            PipelineSchedule::InterleavedOneFOneB => self.forward_interleaved(input),
        }
    }
    /// GPipe-style forward: each microbatch flows through every stage in
    /// order; at each stage boundary the owning rank sends the activation
    /// to the rank that owns the next stage.
    fn forward_gpipe(&self, input: &Variable) -> Variable {
        let rank = self.process_group.rank();
        let num_stages = self.stages.len();
        let microbatches = self.split_microbatches(input);
        let mut outputs = Vec::new();
        for microbatch in microbatches {
            let mut activation = microbatch;
            for (stage_idx, stage) in self.stages.iter().enumerate() {
                // Only the rank that owns this stage executes it.
                if stage.device_rank == rank {
                    activation = stage.forward(&activation);
                }
                // Hand off to the owner of the next stage (if any).
                if stage_idx < num_stages - 1 {
                    let next_rank = self.stages[stage_idx + 1].device_rank;
                    if stage.device_rank == rank {
                        self.send_activation(&activation, next_rank);
                    } else if next_rank == rank {
                        // NOTE(review): receives into the shape of the
                        // *local* activation, which on a non-owning rank is
                        // still the stage's input shape — not necessarily
                        // the remote stage's output shape. Confirm stages
                        // are shape-preserving across rank boundaries.
                        activation = self.recv_activation(stage.device_rank, activation.shape());
                    }
                }
            }
            // Only the rank owning the last stage holds the final output.
            if self.stages.last().map(|s| s.device_rank) == Some(rank) {
                outputs.push(activation);
            }
        }
        self.combine_microbatches(&outputs)
    }
    // 1F1B schedule — placeholder: falls back to the GPipe path for now.
    fn forward_1f1b(&self, input: &Variable) -> Variable {
        self.forward_gpipe(input)
    }
    // Interleaved 1F1B — placeholder: falls back to the GPipe path for now.
    fn forward_interleaved(&self, input: &Variable) -> Variable {
        self.forward_gpipe(input)
    }
    /// Splits `input` along dimension 0 into up to `num_microbatches`
    /// chunks. The final chunk may be smaller (or dropped entirely) when
    /// the batch size does not divide evenly.
    fn split_microbatches(&self, input: &Variable) -> Vec<Variable> {
        let data = input.data();
        let batch_size = data.shape()[0];
        // Ceiling division so every sample lands in some microbatch.
        let microbatch_size = batch_size.div_ceil(self.num_microbatches);
        let mut microbatches = Vec::new();
        let flat_data = data.to_vec();
        let elements_per_sample: usize = data.shape()[1..].iter().product();
        for i in 0..self.num_microbatches {
            let start = i * microbatch_size;
            let end = ((i + 1) * microbatch_size).min(batch_size);
            // More microbatches requested than samples: stop early.
            if start >= batch_size {
                break;
            }
            let mb_size = end - start;
            let start_idx = start * elements_per_sample;
            let end_idx = end * elements_per_sample;
            let mb_data: Vec<f32> = flat_data[start_idx..end_idx].to_vec();
            let mut shape = data.shape().to_vec();
            shape[0] = mb_size;
            let tensor = Tensor::from_vec(mb_data, &shape).unwrap();
            // Propagate the input's requires_grad flag to each microbatch.
            microbatches.push(Variable::new(tensor, input.requires_grad()));
        }
        microbatches
    }
    /// Concatenates microbatch outputs back along dimension 0.
    /// Returns an empty (shape `[0]`) variable when there are no outputs,
    /// which is the case on ranks that do not own the last stage.
    fn combine_microbatches(&self, outputs: &[Variable]) -> Variable {
        if outputs.is_empty() {
            return Variable::new(Tensor::zeros(&[0]), false);
        }
        if outputs.len() == 1 {
            return outputs[0].clone();
        }
        let mut all_data = Vec::new();
        let mut total_batch = 0;
        // Assumes all outputs share trailing dims; only dim 0 is summed.
        let shape = outputs[0].data().shape().to_vec();
        for output in outputs {
            all_data.extend(output.data().to_vec());
            total_batch += output.data().shape()[0];
        }
        let mut new_shape = shape;
        new_shape[0] = total_batch;
        let tensor = Tensor::from_vec(all_data, &new_shape).unwrap();
        Variable::new(tensor, outputs[0].requires_grad())
    }
    /// Sends a (cloned) activation tensor to `dest_rank`.
    fn send_activation(&self, activation: &Variable, dest_rank: usize) {
        let mut tensor = activation.data().clone();
        self.process_group.send_tensor(&mut tensor, dest_rank);
    }
    /// Receives an activation of the given shape from `src_rank`. The
    /// result is created with `requires_grad = true`.
    fn recv_activation(&self, src_rank: usize, shape: Vec<usize>) -> Variable {
        let tensor = self.process_group.recv_tensor(src_rank, &shape);
        Variable::new(tensor, true)
    }
}
// `Module` implementation so a whole `Pipeline` can be used anywhere a
// single module is expected.
impl<M: Module + Clone> Module for Pipeline<M> {
    /// Dispatches to the schedule-aware pipeline forward pass.
    fn forward(&self, input: &Variable) -> Variable {
        Pipeline::forward(self, input)
    }

    /// Collects the parameters of every stage, in stage order.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        for stage in &self.stages {
            params.extend(stage.parameters());
        }
        params
    }

    /// Switches every stage into training mode.
    fn train(&mut self) {
        self.stages.iter_mut().for_each(|stage| stage.train());
    }

    /// Switches every stage into evaluation mode.
    fn eval(&mut self) {
        self.stages.iter_mut().for_each(|stage| stage.eval());
    }

    /// Reports the training flag of the first stage; `false` for an empty
    /// pipeline.
    fn is_training(&self) -> bool {
        match self.stages.first() {
            Some(stage) => stage.is_training(),
            None => false,
        }
    }
}
/// Summary of the activation-memory profile of a pipeline configuration.
#[derive(Debug, Clone)]
pub struct PipelineMemoryStats {
    // Number of pipeline stages.
    pub num_stages: usize,
    // Number of microbatches per step.
    pub num_microbatches: usize,
    // Peak count of concurrently-live activations per stage.
    pub peak_activations_per_stage: usize,
    // Schedule the stats were computed for.
    pub schedule: PipelineSchedule,
}
impl PipelineMemoryStats {
    /// Models the peak live-activation count for a GPipe schedule:
    /// `num_stages * num_microbatches`.
    pub fn gpipe_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
        num_microbatches * num_stages
    }

    /// Models the peak live-activation count for a 1F1B schedule:
    /// the smaller of `num_stages` and `num_microbatches`.
    pub fn one_f_one_b_peak_activations(num_stages: usize, num_microbatches: usize) -> usize {
        if num_stages < num_microbatches {
            num_stages
        } else {
            num_microbatches
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    // Minimal `Module` fixture whose forward pass returns its input
    // unchanged, so pipeline plumbing can be tested without real layers.
    #[derive(Clone)]
    struct IdentityModule {
        // NOTE(review): stored but never read anywhere in this module —
        // likely a dead_code-warning candidate.
        size: usize,
        // Training-mode flag toggled by `train` / `eval`.
        training: bool,
    }
    impl IdentityModule {
        fn new(size: usize) -> Self {
            Self {
                size,
                training: true,
            }
        }
    }
    impl Module for IdentityModule {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }
        fn parameters(&self) -> Vec<Parameter> {
            Vec::new()
        }
        fn train(&mut self) {
            self.training = true;
        }
        fn eval(&mut self) {
            self.training = false;
        }
        fn is_training(&self) -> bool {
            self.training
        }
    }

    // The default schedule should be 1F1B.
    #[test]
    fn test_pipeline_schedule_default() {
        assert_eq!(
            PipelineSchedule::default(),
            PipelineSchedule::OneFOneBSchedule
        );
    }

    // Stage accessors report the ids passed at construction.
    #[test]
    fn test_pipeline_stage_creation() {
        let module = Linear::new(10, 5);
        let stage = PipelineStage::new(module, 0, 0);
        assert_eq!(stage.stage_id(), 0);
        assert_eq!(stage.device_rank(), 0);
    }

    // Builder methods configure stage count and schedule.
    #[test]
    fn test_pipeline_creation() {
        let modules = vec![
            IdentityModule::new(10),
            IdentityModule::new(8),
            IdentityModule::new(6),
        ];
        let pg = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(modules, pg)
            .schedule(PipelineSchedule::GPipe)
            .num_microbatches(2);
        assert_eq!(pipeline.num_stages(), 3);
        assert_eq!(pipeline.get_schedule(), PipelineSchedule::GPipe);
    }

    // A single identity stage on a mock process group preserves shape.
    #[test]
    fn test_pipeline_forward() {
        let modules = vec![IdentityModule::new(4)];
        let pg = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(modules, pg);
        let input = Variable::new(Tensor::randn(&[2, 4]), false);
        let output = pipeline.forward(&input);
        assert_eq!(output.data().shape(), &[2, 4]);
    }

    // GPipe peak = stages * microbatches; 1F1B peak = min(stages, microbatches).
    #[test]
    fn test_pipeline_memory_stats() {
        let gpipe = PipelineMemoryStats::gpipe_peak_activations(4, 8);
        let one_f_one_b = PipelineMemoryStats::one_f_one_b_peak_activations(4, 8);
        assert_eq!(gpipe, 32);
        assert_eq!(one_f_one_b, 4);
    }

    // An evenly-divisible batch splits into equal-sized microbatches.
    #[test]
    fn test_split_microbatches() {
        let modules = vec![IdentityModule::new(4)];
        let pg = ProcessGroup::mock();
        let pipeline = Pipeline::from_modules(modules, pg).num_microbatches(2);
        let input = Variable::new(Tensor::randn(&[4, 4]), false);
        let microbatches = pipeline.split_microbatches(&input);
        assert_eq!(microbatches.len(), 2);
        assert_eq!(microbatches[0].data().shape()[0], 2);
        assert_eq!(microbatches[1].data().shape()[0], 2);
    }
}