#![allow(unused_variables)]
use super::model_parallel::{
ModelParallelContext, PipelineOp, PipelineSchedule, PipelineScheduleType,
};
use crate::errors::{runtime_error, Result};
use crate::Tensor;
use parking_lot::{Mutex, RwLock};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
pub trait PipelineLayer: Send + Sync {
fn forward(&self, input: &Tensor) -> Result<Tensor>;
fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor>;
}
pub struct PipelineStage {
pub stage_id: usize,
pub layers: Vec<Box<dyn PipelineLayer>>,
pub device_id: usize,
pub requires_grad: bool,
}
impl PipelineStage {
pub fn new(stage_id: usize, device_id: usize) -> Self {
Self {
stage_id,
layers: Vec::new(),
device_id,
requires_grad: true,
}
}
pub fn add_layer(&mut self, layer: Box<dyn PipelineLayer>) {
self.layers.push(layer);
}
pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
let mut output = input.clone();
for layer in &self.layers {
output = layer.forward(&output)?;
}
Ok(output)
}
pub fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor> {
let mut grad = grad_output.clone();
for layer in self.layers.iter_mut().rev() {
grad = layer.backward(&grad)?;
}
Ok(grad)
}
}
pub struct PipelineModel {
pub stages: Vec<PipelineStage>,
pub mp_context: Arc<ModelParallelContext>,
pub local_stage_id: Option<usize>,
}
impl PipelineModel {
pub fn new(mp_context: Arc<ModelParallelContext>) -> Self {
Self {
stages: Vec::new(),
mp_context,
local_stage_id: None,
}
}
pub fn add_stage(&mut self, stage: PipelineStage) {
if stage.device_id == self.mp_context.rank() {
self.local_stage_id = Some(stage.stage_id);
}
self.stages.push(stage);
}
pub fn local_stage(&self) -> Result<&PipelineStage> {
let stage_id =
self.local_stage_id.ok_or_else(|| runtime_error("No local stage assigned"))?;
self.stages.get(stage_id).ok_or_else(|| runtime_error("Invalid stage ID"))
}
pub fn local_stage_mut(&mut self) -> Result<&mut PipelineStage> {
let stage_id =
self.local_stage_id.ok_or_else(|| runtime_error("No local stage assigned"))?;
self.stages.get_mut(stage_id).ok_or_else(|| runtime_error("Invalid stage ID"))
}
pub fn num_stages(&self) -> usize {
self.stages.len()
}
}
#[derive(Clone)]
pub struct Microbatch {
pub id: usize,
pub input: Option<Tensor>,
pub output: Option<Tensor>,
pub grad_output: Option<Tensor>,
pub grad_input: Option<Tensor>,
pub labels: Option<Tensor>,
}
impl Microbatch {
pub fn new(id: usize) -> Self {
Self {
id,
input: None,
output: None,
grad_output: None,
grad_input: None,
labels: None,
}
}
}
pub struct MicrobatchManager {
microbatches: Vec<Microbatch>,
checkpoint_activations: bool,
forward_queue: VecDeque<usize>,
backward_queue: VecDeque<usize>,
}
impl MicrobatchManager {
pub fn new(num_microbatches: usize, checkpoint_activations: bool) -> Self {
let microbatches = (0..num_microbatches).map(Microbatch::new).collect();
Self {
microbatches,
checkpoint_activations,
forward_queue: VecDeque::new(),
backward_queue: VecDeque::new(),
}
}
pub fn get(&self, id: usize) -> Result<&Microbatch> {
self.microbatches
.get(id)
.ok_or_else(|| runtime_error(format!("Invalid microbatch ID: {}", id)))
}
pub fn get_mut(&mut self, id: usize) -> Result<&mut Microbatch> {
self.microbatches
.get_mut(id)
.ok_or_else(|| runtime_error(format!("Invalid microbatch ID: {}", id)))
}
pub fn enqueue_forward(&mut self, mb_id: usize) {
self.forward_queue.push_back(mb_id);
}
pub fn enqueue_backward(&mut self, mb_id: usize) {
self.backward_queue.push_back(mb_id);
}
pub fn dequeue_forward(&mut self) -> Option<usize> {
self.forward_queue.pop_front()
}
pub fn dequeue_backward(&mut self) -> Option<usize> {
self.backward_queue.pop_front()
}
pub fn maybe_clear_activation(&mut self, mb_id: usize) -> Result<()> {
if self.checkpoint_activations {
let mb = self.get_mut(mb_id)?;
mb.output = None; }
Ok(())
}
pub fn maybe_recompute_activation(
&mut self,
mb_id: usize,
stage: &PipelineStage,
) -> Result<()> {
let should_recompute = self.checkpoint_activations;
let mb = self.get_mut(mb_id)?;
if should_recompute && mb.output.is_none() {
if let Some(input) = &mb.input {
mb.output = Some(stage.forward(input)?);
}
}
Ok(())
}
}
pub struct PipelineExecutor {
model: Arc<RwLock<PipelineModel>>,
schedule: PipelineSchedule,
mb_manager: Arc<Mutex<MicrobatchManager>>,
#[allow(dead_code)]
send_buffers: HashMap<usize, Tensor>,
_recv_buffers: HashMap<usize, Tensor>,
}
impl PipelineExecutor {
pub fn new(
model: Arc<RwLock<PipelineModel>>,
num_microbatches: usize,
checkpoint_activations: bool,
) -> Result<Self> {
let num_stages = {
let model_read = model.read();
model_read.num_stages()
};
let schedule = PipelineSchedule::new(
num_stages,
num_microbatches,
PipelineScheduleType::OneForwardOneBackward,
);
let mb_manager = Arc::new(Mutex::new(MicrobatchManager::new(
num_microbatches,
checkpoint_activations,
)));
Ok(Self {
model,
schedule,
mb_manager,
send_buffers: HashMap::new(),
_recv_buffers: HashMap::new(),
})
}
pub fn execute_step(&mut self, inputs: Vec<Tensor>, labels: Vec<Tensor>) -> Result<f32> {
let num_inputs = inputs.len();
self.prepare_microbatches(inputs, labels)?;
let stage_id = {
let model = self.model.read();
model.local_stage_id.ok_or_else(|| runtime_error("No local stage"))?
};
let ops = self.schedule.get_stage_schedule(stage_id);
let mut total_loss = 0.0;
for op in ops {
match op {
PipelineOp::Forward { microbatch_id } => {
self.execute_forward(microbatch_id)?;
},
PipelineOp::Backward { microbatch_id } => {
let loss = self.execute_backward(microbatch_id)?;
total_loss += loss;
},
PipelineOp::SendActivation { to_stage } => {
self.send_activation(to_stage)?;
},
PipelineOp::RecvActivation { from_stage } => {
self.recv_activation(from_stage)?;
},
PipelineOp::SendGradient { to_stage } => {
self.send_gradient(to_stage)?;
},
PipelineOp::RecvGradient { from_stage } => {
self.recv_gradient(from_stage)?;
},
}
}
Ok(total_loss / num_inputs as f32)
}
fn prepare_microbatches(&mut self, inputs: Vec<Tensor>, labels: Vec<Tensor>) -> Result<()> {
let mut mb_manager = self.mb_manager.lock();
for (i, (input, label)) in inputs.into_iter().zip(labels).enumerate() {
let mb = mb_manager.get_mut(i)?;
mb.input = Some(input);
mb.labels = Some(label);
mb_manager.enqueue_forward(i);
}
Ok(())
}
fn execute_forward(&mut self, mb_id: usize) -> Result<()> {
let mut model = self.model.write();
let stage = model.local_stage_mut()?;
let mut mb_manager = self.mb_manager.lock();
let mb = mb_manager.get_mut(mb_id)?;
let input = if stage.stage_id == 0 {
mb.input.as_ref().ok_or_else(|| runtime_error("Missing input"))?
} else {
mb.output.as_ref().ok_or_else(|| runtime_error("Missing activation"))?
};
let output = stage.forward(input)?;
mb.output = Some(output);
mb_manager.maybe_clear_activation(mb_id)?;
Ok(())
}
fn execute_backward(&mut self, mb_id: usize) -> Result<f32> {
let (is_last_stage, stage_id) = {
let model = self.model.read();
let stage = model.local_stage()?;
(stage.stage_id == model.num_stages() - 1, stage.stage_id)
};
let mut model = self.model.write();
let stage = model.local_stage_mut()?;
let mut mb_manager = self.mb_manager.lock();
mb_manager.maybe_recompute_activation(mb_id, stage)?;
let mb = mb_manager.get_mut(mb_id)?;
let loss = if is_last_stage {
1.0
} else {
0.0
};
let grad_output = if is_last_stage {
mb.output.as_ref().ok_or_else(|| runtime_error("Missing output"))?.clone()
} else {
mb.grad_output
.as_ref()
.ok_or_else(|| runtime_error("Missing grad_output"))?
.clone()
};
let grad_input = stage.backward(&grad_output)?;
mb.grad_input = Some(grad_input);
Ok(loss)
}
fn send_activation(&mut self, to_stage: usize) -> Result<()> {
Ok(())
}
fn recv_activation(&mut self, from_stage: usize) -> Result<()> {
Ok(())
}
fn send_gradient(&mut self, to_stage: usize) -> Result<()> {
Ok(())
}
fn recv_gradient(&mut self, from_stage: usize) -> Result<()> {
Ok(())
}
}
pub struct PipelineOptimizer {
#[allow(dead_code)]
lr: f32,
_weight_decay: f32,
accumulation_steps: usize,
current_step: usize,
accumulated_grads: HashMap<String, Tensor>,
}
impl PipelineOptimizer {
pub fn new(lr: f32, weight_decay: f32, accumulation_steps: usize) -> Self {
Self {
lr,
_weight_decay: weight_decay,
accumulation_steps,
current_step: 0,
accumulated_grads: HashMap::new(),
}
}
pub fn accumulate_gradients(&mut self, grads: HashMap<String, Tensor>) -> Result<()> {
for (name, grad) in grads {
if let Some(acc_grad) = self.accumulated_grads.get_mut(&name) {
*acc_grad = acc_grad.add(&grad)?;
} else {
self.accumulated_grads.insert(name, grad);
}
}
self.current_step += 1;
Ok(())
}
pub fn step(&mut self, model: &mut PipelineModel) -> Result<bool> {
if self.current_step < self.accumulation_steps {
return Ok(false);
}
let scale = 1.0 / self.accumulation_steps as f32;
self.accumulated_grads.clear();
self.current_step = 0;
Ok(true)
}
}
pub struct PipelineModelBuilder {
mp_context: Arc<ModelParallelContext>,
stages: Vec<PipelineStage>,
layers_per_stage: Option<usize>,
}
impl PipelineModelBuilder {
pub fn new(mp_context: Arc<ModelParallelContext>) -> Self {
Self {
mp_context,
stages: Vec::new(),
layers_per_stage: None,
}
}
pub fn layers_per_stage(mut self, layers_per_stage: usize) -> Self {
self.layers_per_stage = Some(layers_per_stage);
self
}
pub fn add_stage(mut self, stage: PipelineStage) -> Self {
self.stages.push(stage);
self
}
pub fn build(self) -> Result<PipelineModel> {
let mut model = PipelineModel::new(self.mp_context);
for stage in self.stages {
model.add_stage(stage);
}
Ok(model)
}
}
#[cfg(test)]
mod tests {
use super::super::model_parallel::{
CommunicationBackend, ModelParallelConfig, ModelParallelStrategy,
};
use super::*;
#[test]
fn test_pipeline_stage() {
let stage = PipelineStage::new(0, 0);
assert_eq!(stage.stage_id, 0);
assert_eq!(stage.device_id, 0);
assert!(stage.requires_grad);
}
#[test]
fn test_microbatch_manager() {
let mut manager = MicrobatchManager::new(4, true);
manager.enqueue_forward(0);
manager.enqueue_forward(1);
assert_eq!(manager.dequeue_forward(), Some(0));
assert_eq!(manager.dequeue_forward(), Some(1));
assert_eq!(manager.dequeue_forward(), None);
}
#[test]
fn test_pipeline_model_builder() {
let config = ModelParallelConfig {
num_devices: 4,
device_ids: vec![0, 1, 2, 3],
strategy: ModelParallelStrategy::Pipeline,
comm_backend: CommunicationBackend::Custom,
..Default::default()
};
let mp_context =
Arc::new(ModelParallelContext::new(config).expect("operation failed in test"));
let model = PipelineModelBuilder::new(mp_context)
.add_stage(PipelineStage::new(0, 0))
.add_stage(PipelineStage::new(1, 1))
.build()
.expect("operation failed in test");
assert_eq!(model.num_stages(), 2);
}
}