use std::sync::Arc;
use ferrotorch_core::storage::TensorStorage;
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
use ferrotorch_nn::Module;
use crate::backend::Backend;
/// Schedule requested for interleaving micro-batch forward and backward passes.
///
/// NOTE(review): `Pipeline::forward`/`Pipeline::backward` do not currently
/// branch on this value; it is stored and exposed via an accessor only.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PipelineSchedule {
    /// GPipe-style schedule: all forward passes, then all backward passes.
    GPipe,
    /// Presumably the "1F1B" (one-forward-one-backward) schedule — confirm.
    OneFOnEB,
}
/// One stage of a pipeline-parallel model, bound to a single rank.
///
/// During [`Pipeline::forward`] activations travel from rank `r` to `r + 1`;
/// during [`Pipeline::backward`] gradient messages travel from `r + 1` to `r`.
pub struct Pipeline<M, T>
where
    M: Module<T>,
    T: Float,
{
    /// The portion of the model executed on this rank.
    module: M,
    /// Transport used to exchange tensors with the neighbouring ranks.
    backend: Arc<dyn Backend>,
    /// How many micro-batches the rank-0 input is split into (validated > 0 by `new`).
    num_microbatches: usize,
    /// Schedule requested at construction time.
    schedule: PipelineSchedule,
    /// Carries the element type `T` without storing a value of it.
    _marker: std::marker::PhantomData<T>,
}
impl<M, T> std::fmt::Debug for Pipeline<M, T>
where
    M: Module<T>,
    T: Float,
{
    /// Prints the pipeline configuration; the module and backend are elided
    /// (rendered as `..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Pipeline");
        builder.field("num_microbatches", &self.num_microbatches);
        builder.field("schedule", &self.schedule);
        builder.finish_non_exhaustive()
    }
}
impl<M: Module<T>, T: Float> Pipeline<M, T> {
    /// Creates the pipeline stage that runs `module` on this rank.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] if `num_microbatches` is
    /// zero, or if `backend` reports a world size below 2 (a single rank has
    /// no neighbour to pipeline with).
    pub fn new(
        module: M,
        backend: Arc<dyn Backend>,
        num_microbatches: usize,
        schedule: PipelineSchedule,
    ) -> FerrotorchResult<Self> {
        if num_microbatches == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "Pipeline: num_microbatches must be > 0".into(),
            });
        }
        let world_size = backend.world_size();
        if world_size < 2 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Pipeline: world_size must be >= 2 for pipeline parallelism, got {}",
                    world_size,
                ),
            });
        }
        Ok(Self {
            module,
            backend,
            num_microbatches,
            schedule,
            _marker: std::marker::PhantomData,
        })
    }

    /// Runs this stage's forward pass over every micro-batch, in order.
    ///
    /// Rank 0 slices `input` along dimension 0 into `num_microbatches` pieces;
    /// every other rank ignores `input` and receives each micro-batch from
    /// `rank - 1`. Every rank except the last sends each output on to
    /// `rank + 1`. The per-micro-batch outputs are returned in order.
    ///
    /// NOTE(review): `self.schedule` is not consulted; all micro-batches are
    /// always processed sequentially regardless of the configured schedule.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] if this is rank 0 and
    /// `input` is `None`, and propagates module, slicing and communication
    /// errors. An error part-way through may leave peer ranks blocked on a
    /// receive.
    pub fn forward(&self, input: Option<&Tensor<T>>) -> FerrotorchResult<Vec<Tensor<T>>> {
        let rank = self.backend.rank();
        let world_size = self.backend.world_size();

        // Validate rank 0's input up front, before any communication happens.
        let rank0_input = if rank == 0 {
            Some(input.ok_or_else(|| FerrotorchError::InvalidArgument {
                message: "Pipeline: rank 0 must provide input".into(),
            })?)
        } else {
            None
        };

        let mut outputs = Vec::with_capacity(self.num_microbatches);
        for mb in 0..self.num_microbatches {
            let mb_input = match rank0_input {
                Some(full) => self.get_microbatch(full, mb)?,
                // rank > 0 here, so `rank - 1` cannot underflow.
                None => self.recv_activation(rank - 1)?,
            };
            let output = self.module.forward(&mb_input)?;
            if rank < world_size - 1 {
                self.send_activation(&output, rank + 1)?;
            }
            outputs.push(output);
        }
        Ok(outputs)
    }

    /// Runs this stage's backward pass over every micro-batch, last first.
    ///
    /// On the last rank, `grad_outputs[mb]` (when present) is installed as the
    /// gradient of `outputs[mb]`; `grad_outputs` is ignored on other ranks,
    /// which instead receive the gradient from `rank + 1`. After calling
    /// `ferrotorch_core::backward` on each output, every rank except rank 0
    /// sends a gradient message to `rank - 1`.
    ///
    /// NOTE(review): the message sent upstream is currently an all-zero
    /// placeholder shaped like this stage's *output*, not the gradient with
    /// respect to this stage's input. Inputs received in `forward` are built
    /// with `requires_grad = false` and are not retained, so upstream ranks
    /// receive zero gradients. TODO: retain per-micro-batch inputs and send
    /// their real gradients.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] if `outputs` holds fewer
    /// than `num_microbatches` tensors (previously this panicked), and
    /// propagates autograd and communication errors.
    pub fn backward(
        &self,
        outputs: &[Tensor<T>],
        grad_outputs: Option<&[Tensor<T>]>,
    ) -> FerrotorchResult<()> {
        let rank = self.backend.rank();
        let world_size = self.backend.world_size();

        if outputs.len() < self.num_microbatches {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Pipeline: backward expected at least {} outputs, got {}",
                    self.num_microbatches,
                    outputs.len(),
                ),
            });
        }

        for mb in (0..self.num_microbatches).rev() {
            let output = &outputs[mb];
            if rank == world_size - 1 {
                // A missing or short `grad_outputs` leaves the gradient unset,
                // matching the previous behaviour.
                if let Some(grad) = grad_outputs.and_then(|grads| grads.get(mb)) {
                    output.set_grad(Some(grad.clone()))?;
                }
            } else {
                let grad = self.recv_activation(rank + 1)?;
                output.set_grad(Some(grad))?;
            }
            ferrotorch_core::backward(output)?;
            if rank > 0 {
                // Placeholder gradient — see the NOTE(review) in the doc comment.
                let grad_input = Tensor::from_storage(
                    TensorStorage::cpu(vec![<T as num_traits::Zero>::zero(); output.numel()]),
                    output.shape().to_vec(),
                    false,
                )?;
                self.send_activation(&grad_input, rank - 1)?;
            }
        }
        Ok(())
    }

    /// Copies micro-batch `mb_idx` out of `input`, slicing along dimension 0.
    ///
    /// Each micro-batch holds `batch_size / num_microbatches` rows. The last
    /// one also absorbs the remainder. If `batch_size < num_microbatches`,
    /// every micro-batch but the last is empty.
    ///
    /// The result is built from a fresh copy of the data, so it is presumably
    /// not linked to `input`'s autograd history; only the `requires_grad`
    /// flag is carried over.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] if `input` has no
    /// dimensions.
    fn get_microbatch(&self, input: &Tensor<T>, mb_idx: usize) -> FerrotorchResult<Tensor<T>> {
        let shape = input.shape();
        if shape.is_empty() {
            return Err(FerrotorchError::InvalidArgument {
                message: "Pipeline: input tensor must have at least 1 dimension".into(),
            });
        }
        let batch_size = shape[0];
        let mb_size = batch_size / self.num_microbatches;
        let start = mb_idx * mb_size;
        let end = if mb_idx == self.num_microbatches - 1 {
            batch_size
        } else {
            start + mb_size
        };
        let data = input.data_vec()?;
        // Number of elements in one row (product of all non-batch dims).
        let stride: usize = shape[1..].iter().product();
        let mb_data = data[start * stride..end * stride].to_vec();
        let mut mb_shape = shape.to_vec();
        mb_shape[0] = end - start;
        Tensor::from_storage(TensorStorage::cpu(mb_data), mb_shape, input.requires_grad())
    }

    /// Returns `size_of::<T>()` if `T` can be serialized on the wire.
    ///
    /// Only 4-byte and 8-byte elements are supported; they are transmitted
    /// as little-endian `f32` / `f64` respectively.
    fn wire_element_size() -> FerrotorchResult<usize> {
        let size = std::mem::size_of::<T>();
        if size == 4 || size == 8 {
            Ok(size)
        } else {
            Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Pipeline: unsupported element size {} (only 4- and 8-byte floats can be transferred)",
                    size
                ),
            })
        }
    }

    /// Serializes `data` as a packed little-endian byte buffer.
    fn encode_elements(data: &[T]) -> FerrotorchResult<Vec<u8>> {
        let elem_size = Self::wire_element_size()?;
        let not_representable = || FerrotorchError::InvalidArgument {
            message: "Pipeline: tensor value cannot be converted for transfer".into(),
        };
        let mut bytes = Vec::with_capacity(data.len() * elem_size);
        for value in data {
            if elem_size == 4 {
                let v = num_traits::ToPrimitive::to_f32(value).ok_or_else(not_representable)?;
                bytes.extend_from_slice(&v.to_le_bytes());
            } else {
                let v = num_traits::ToPrimitive::to_f64(value).ok_or_else(not_representable)?;
                bytes.extend_from_slice(&v.to_le_bytes());
            }
        }
        Ok(bytes)
    }

    /// Deserializes a buffer produced by [`Self::encode_elements`].
    fn decode_elements(bytes: &[u8]) -> FerrotorchResult<Vec<T>> {
        let elem_size = Self::wire_element_size()?;
        bytes
            .chunks_exact(elem_size)
            .map(|chunk| {
                let value = if elem_size == 4 {
                    T::from(f32::from_le_bytes(
                        chunk.try_into().expect("chunks_exact yields 4-byte chunks"),
                    ))
                } else {
                    T::from(f64::from_le_bytes(
                        chunk.try_into().expect("chunks_exact yields 8-byte chunks"),
                    ))
                };
                value.ok_or_else(|| FerrotorchError::InvalidArgument {
                    message: "Pipeline: received value not representable in element type".into(),
                })
            })
            .collect()
    }

    /// Sends `tensor` to `dst_rank` as three messages.
    ///
    /// The messages are:
    /// 1. `ndim` as a little-endian `u64`;
    /// 2. all `ndim` dimensions, each a little-endian `u64`;
    /// 3. the packed little-endian element payload.
    ///
    /// [`Self::recv_activation`] reads exactly the same three messages, so the
    /// framing also holds on message-oriented backends.
    fn send_activation(&self, tensor: &Tensor<T>, dst_rank: usize) -> FerrotorchResult<()> {
        // Encode before sending anything so a failure cannot leave the peer
        // with a header but no payload.
        let payload = Self::encode_elements(&tensor.data_vec()?)?;
        let shape = tensor.shape();
        let mut dims = Vec::with_capacity(shape.len() * 8);
        for &d in shape {
            dims.extend_from_slice(&(d as u64).to_le_bytes());
        }
        self.backend.send(&(shape.len() as u64).to_le_bytes(), dst_rank)?;
        self.backend.send(&dims, dst_rank)?;
        self.backend.send(&payload, dst_rank)?;
        Ok(())
    }

    /// Receives a tensor sent by [`Self::send_activation`] from `src_rank`.
    ///
    /// The returned tensor has `requires_grad = false`.
    ///
    /// # Errors
    ///
    /// Fails if `T` has an unsupported element size, if the received header
    /// describes a size that overflows `usize`, or on communication errors.
    fn recv_activation(&self, src_rank: usize) -> FerrotorchResult<Tensor<T>> {
        let elem_size = Self::wire_element_size()?;
        let bad_header = || FerrotorchError::InvalidArgument {
            message: "Pipeline: received tensor header describes an oversized tensor".into(),
        };

        let mut ndim_buf = [0u8; 8];
        self.backend.recv(&mut ndim_buf, src_rank)?;
        let ndim = u64::from_le_bytes(ndim_buf) as usize;

        let mut dims_buf = vec![0u8; ndim.checked_mul(8).ok_or_else(bad_header)?];
        self.backend.recv(&mut dims_buf, src_rank)?;
        let shape: Vec<usize> = dims_buf
            .chunks_exact(8)
            .map(|c| {
                u64::from_le_bytes(c.try_into().expect("chunks_exact yields 8-byte chunks"))
                    as usize
            })
            .collect();

        // The header is peer-supplied, so guard the size computation.
        let payload_len = shape
            .iter()
            .try_fold(elem_size, |acc, &d| acc.checked_mul(d))
            .ok_or_else(bad_header)?;
        let mut payload = vec![0u8; payload_len];
        self.backend.recv(&mut payload, src_rank)?;

        let data = Self::decode_elements(&payload)?;
        Tensor::from_storage(TensorStorage::cpu(data), shape, false)
    }

    /// The schedule requested at construction time.
    pub fn schedule(&self) -> PipelineSchedule {
        self.schedule
    }

    /// Number of micro-batches each forward/backward call processes.
    pub fn num_microbatches(&self) -> usize {
        self.num_microbatches
    }

    /// Shared access to the wrapped module.
    pub fn module(&self) -> &M {
        &self.module
    }

    /// Exclusive access to the wrapped module.
    pub fn module_mut(&mut self) -> &mut M {
        &mut self.module
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SimulatedBackend;
    use ferrotorch_nn::Parameter;

    /// Parameter-free identity module shared by every test in this module.
    struct DummyModule;

    impl Module<f32> for DummyModule {
        fn forward(&self, input: &Tensor<f32>) -> FerrotorchResult<Tensor<f32>> {
            Ok(input.clone())
        }
        fn parameters(&self) -> Vec<&Parameter<f32>> {
            vec![]
        }
        fn parameters_mut(&mut self) -> Vec<&mut Parameter<f32>> {
            vec![]
        }
        fn named_parameters(&self) -> Vec<(String, &Parameter<f32>)> {
            vec![]
        }
        fn train(&mut self) {}
        fn eval(&mut self) {}
        fn is_training(&self) -> bool {
            true
        }
    }

    #[test]
    fn test_pipeline_new_validates_microbatches() {
        let group = SimulatedBackend::create_group(2).unwrap();
        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let result = Pipeline::new(DummyModule, b, 0, PipelineSchedule::GPipe);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("num_microbatches must be > 0"));
    }

    #[test]
    fn test_pipeline_new_validates_world_size() {
        let group = SimulatedBackend::create_group(1).unwrap();
        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let result = Pipeline::new(DummyModule, b, 2, PipelineSchedule::OneFOnEB);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("world_size must be >= 2"));
    }

    #[test]
    fn test_pipeline_schedule_accessors() {
        let group = SimulatedBackend::create_group(2).unwrap();
        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let pipeline = Pipeline::new(DummyModule, b, 4, PipelineSchedule::GPipe).unwrap();
        assert_eq!(pipeline.schedule(), PipelineSchedule::GPipe);
        assert_eq!(pipeline.num_microbatches(), 4);
    }

    #[test]
    fn test_get_microbatch_last_absorbs_remainder() {
        let group = SimulatedBackend::create_group(2).unwrap();
        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let pipeline = Pipeline::new(DummyModule, b, 2, PipelineSchedule::GPipe).unwrap();

        // Batch of 5 rows x 2 columns split into 2 micro-batches: 2 rows + 3 rows.
        let data: Vec<f32> = (0..10).map(|v| v as f32).collect();
        let input = Tensor::from_storage(TensorStorage::cpu(data), vec![5, 2], false).unwrap();

        let first = pipeline.get_microbatch(&input, 0).unwrap();
        assert_eq!(first.shape(), &[2usize, 2]);
        assert_eq!(&first.data_vec().unwrap()[..], &[0.0f32, 1.0, 2.0, 3.0]);

        let last = pipeline.get_microbatch(&input, 1).unwrap();
        assert_eq!(last.shape(), &[3usize, 2]);
        assert_eq!(
            &last.data_vec().unwrap()[..],
            &[4.0f32, 5.0, 6.0, 7.0, 8.0, 9.0]
        );
    }
}