use std::sync::Arc;
use super::stage::PipelineStage;
use crate::distributed::comm_utils::{recv_tensor_with_metadata, send_tensor_with_metadata};
use crate::error::{Error, Result};
use numr::dtype::DType;
use numr::ops::ShapeOps;
use numr::runtime::{Communicator, Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// GPipe-style pipeline-parallel schedule: one stage per rank, with
/// micro-batches streamed forward through the stages via `run`.
pub struct GpipeSchedule<R: Runtime> {
/// The pipeline stage executed on this rank.
stage: Box<dyn PipelineStage<R>>,
/// Number of micro-batches each input batch is split into; validated > 0 in `new`.
num_micro_batches: usize,
/// Communicator used for inter-stage send/recv between ranks.
comm: Arc<dyn Communicator>,
/// Device passed to the receive path when materializing incoming tensors.
device: R::Device,
}
impl<R: Runtime<DType = DType>> GpipeSchedule<R> {
    /// Builds a schedule from a stage, micro-batch count, communicator and device.
    ///
    /// # Errors
    ///
    /// Returns [`Error::DistributedError`] when `num_micro_batches` is 0.
    pub fn new(
        stage: Box<dyn PipelineStage<R>>,
        num_micro_batches: usize,
        comm: Arc<dyn Communicator>,
        device: R::Device,
    ) -> Result<Self> {
        if num_micro_batches == 0 {
            return Err(Error::DistributedError {
                reason: "num_micro_batches must be > 0".to_string(),
            });
        }
        Ok(Self {
            stage,
            num_micro_batches,
            comm,
            device,
        })
    }

    /// Executes the forward pipeline over `num_micro_batches` micro-batches.
    ///
    /// Rank 0 must supply `input`, which is chunked along dim 0; every other
    /// rank receives its micro-batch from `rank - 1`. Non-last ranks forward
    /// their stage output to `rank + 1` and return an empty `Vec`; only the
    /// last rank collects and returns per-micro-batch outputs.
    ///
    /// # Errors
    ///
    /// Fails when rank 0 is given no input, when chunking yields fewer
    /// micro-batches than expected, when a micro-batch tag does not fit in
    /// `u32`, or when an underlying send/recv/stage-forward fails.
    pub fn run<C>(&mut self, client: &C, input: Option<Tensor<R>>) -> Result<Vec<Tensor<R>>>
    where
        C: RuntimeClient<R> + ShapeOps<R>,
    {
        let world_size = self.comm.world_size();
        // Guard BEFORE any rank arithmetic: the original computed
        // `num_stages - 1` first, which underflows if a communicator ever
        // reports a world size of 0. The single-device path needs none of
        // the rank bookkeeping anyway.
        if world_size <= 1 {
            return self.run_single_device(client, input);
        }
        let rank = self.comm.rank();
        let num_stages = world_size; // one pipeline stage per rank
        let is_first = rank == 0;
        let is_last = rank == num_stages - 1;
        // Only the last stage accumulates outputs; every other rank returns empty.
        let mut outputs = Vec::new();
        // Rank 0 owns the full batch and splits it locally; downstream ranks
        // receive their micro-batches over the wire instead.
        let micro_batches: Vec<Tensor<R>> = if is_first {
            let inp = input.ok_or_else(|| Error::DistributedError {
                reason: "rank 0 must provide input".to_string(),
            })?;
            client.chunk(&inp, self.num_micro_batches, 0)?
        } else {
            Vec::new()
        };
        let mut mb_iter = micro_batches.into_iter();
        for mb_idx in 0..self.num_micro_batches {
            // Tag is `2 * mb_idx` (presumably leaving the odd tags free for a
            // second phase — matches the original scheme). Checked conversion
            // so a huge micro-batch count cannot silently wrap the u32 tag.
            let tag = u32::try_from(mb_idx * 2).map_err(|_| Error::DistributedError {
                reason: format!("micro-batch index {mb_idx} exceeds u32 tag range"),
            })?;
            let mb_input = if is_first {
                mb_iter.next().ok_or_else(|| Error::DistributedError {
                    reason: "fewer micro-batches than expected from chunk".to_string(),
                })?
            } else {
                // Safe: `is_first` guarantees rank > 0 here.
                recv_tensor_with_metadata::<R>(self.comm.as_ref(), rank - 1, tag, &self.device)?
            };
            let mb_output = self.stage.forward(mb_input)?;
            if is_last {
                outputs.push(mb_output);
            } else {
                send_tensor_with_metadata(self.comm.as_ref(), &mb_output, rank + 1, tag)?;
            }
        }
        Ok(outputs)
    }

    /// Single-device fallback: chunks `input` along dim 0 and runs the stage
    /// over each micro-batch sequentially, with no communication.
    ///
    /// # Errors
    ///
    /// Fails when `input` is `None`, or when chunking or a stage forward fails.
    fn run_single_device<C>(
        &mut self,
        client: &C,
        input: Option<Tensor<R>>,
    ) -> Result<Vec<Tensor<R>>>
    where
        C: RuntimeClient<R> + ShapeOps<R>,
    {
        let inp = input.ok_or_else(|| Error::DistributedError {
            reason: "input required for single-device pipeline".to_string(),
        })?;
        let micro_batches = client.chunk(&inp, self.num_micro_batches, 0)?;
        let mut outputs = Vec::with_capacity(self.num_micro_batches);
        for mb in micro_batches {
            outputs.push(self.stage.forward(mb)?);
        }
        Ok(outputs)
    }

    /// Receives tensor data from `src` directly into `buffer`.
    ///
    /// # Errors
    ///
    /// Propagates any failure from the underlying communicator receive.
    pub fn recv_into(&self, buffer: &Tensor<R>, src: usize, tag: u32) -> Result<()> {
        crate::distributed::comm_utils::recv_into_tensor(self.comm.as_ref(), buffer, src, tag)
    }

    /// Number of micro-batches each batch is split into.
    pub fn num_micro_batches(&self) -> usize {
        self.num_micro_batches
    }

    /// Borrow of the underlying communicator.
    pub fn communicator(&self) -> &dyn Communicator {
        self.comm.as_ref()
    }
}
pub type PipelineSchedule<R> = GpipeSchedule<R>;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::cpu_setup;
    use numr::runtime::NoOpCommunicator;
    use numr::runtime::cpu::CpuRuntime;

    /// Stage that multiplies every element by two.
    struct DoubleStage;

    impl PipelineStage<CpuRuntime> for DoubleStage {
        fn forward(&mut self, input: Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
            let values: Vec<f32> = input
                .to_vec::<f32>()
                .into_iter()
                .map(|v| v * 2.0)
                .collect();
            Ok(Tensor::from_slice(&values, input.shape(), input.device()))
        }
    }

    /// Stage that adds one to every element.
    struct AddOneStage;

    impl PipelineStage<CpuRuntime> for AddOneStage {
        fn forward(&mut self, input: Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
            let values: Vec<f32> = input
                .to_vec::<f32>()
                .into_iter()
                .map(|v| v + 1.0)
                .collect();
            Ok(Tensor::from_slice(&values, input.shape(), input.device()))
        }
    }

    #[test]
    fn test_gpipe_single_device() {
        let (client, device) = cpu_setup();
        let sched_result = GpipeSchedule::new(
            Box::new(DoubleStage),
            2,
            Arc::new(NoOpCommunicator),
            device.clone(),
        );
        let mut sched = sched_result.unwrap();
        let batch = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4, 1], &device);
        let got = sched.run(&client, Some(batch)).unwrap();
        assert_eq!(got.len(), 2);
        assert_eq!(got[0].to_vec::<f32>(), vec![2.0, 4.0]);
        assert_eq!(got[1].to_vec::<f32>(), vec![6.0, 8.0]);
    }

    #[test]
    fn test_gpipe_single_micro_batch() {
        let (client, device) = cpu_setup();
        let mut sched = GpipeSchedule::new(
            Box::new(AddOneStage),
            1,
            Arc::new(NoOpCommunicator),
            device.clone(),
        )
        .unwrap();
        let batch = Tensor::<CpuRuntime>::from_slice(&[10.0f32, 20.0], &[2, 1], &device);
        let got = sched.run(&client, Some(batch)).unwrap();
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].to_vec::<f32>(), vec![11.0, 21.0]);
    }

    #[test]
    fn test_gpipe_zero_micro_batches_error() {
        // Zero micro-batches must be rejected at construction time.
        let (_client, device) = cpu_setup();
        let built = GpipeSchedule::<CpuRuntime>::new(
            Box::new(DoubleStage),
            0,
            Arc::new(NoOpCommunicator),
            device,
        );
        assert!(built.is_err());
    }

    #[test]
    fn test_gpipe_no_input_error() {
        // Rank 0 (single device here) must supply an input tensor.
        let (client, device) = cpu_setup();
        let mut sched = GpipeSchedule::new(
            Box::new(DoubleStage),
            1,
            Arc::new(NoOpCommunicator),
            device.clone(),
        )
        .unwrap();
        assert!(sched.run(&client, None).is_err());
    }

    #[test]
    fn test_gpipe_recv_into() {
        let (_client, device) = cpu_setup();
        let sched = GpipeSchedule::new(
            Box::new(DoubleStage),
            1,
            Arc::new(NoOpCommunicator),
            device.clone(),
        )
        .unwrap();
        let buffer = Tensor::<CpuRuntime>::zeros(&[3], DType::F32, &device);
        sched.recv_into(&buffer, 0, 0).unwrap();
    }
}