use std::sync::Arc;
use super::clock::{PipelineAction, PipelineClock};
use super::comm::{compute_loss_grad, recv_activation, send_activation};
use super::stage::TrainablePipelineStage;
use crate::error::{Error, Result};
use numr::autograd::Var;
use numr::dtype::DType;
use numr::ops::ShapeOps;
use numr::runtime::{Communicator, Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub type LossFn<'a, R> = dyn Fn(&Var<R>) -> Result<Var<R>> + 'a;
pub struct PipelineOutput<R: Runtime> {
pub losses: Vec<f64>,
pub input_grads: Vec<Tensor<R>>,
}
pub struct Schedule1F1B<R: Runtime> {
stage: Box<dyn TrainablePipelineStage<R>>,
num_micro_batches: usize,
pp_comm: Arc<dyn Communicator>,
device: R::Device,
}
impl<R: Runtime<DType = DType>> Schedule1F1B<R> {
pub fn new(
stage: Box<dyn TrainablePipelineStage<R>>,
num_micro_batches: usize,
pp_comm: Arc<dyn Communicator>,
device: R::Device,
) -> Result<Self> {
if num_micro_batches == 0 {
return Err(Error::DistributedError {
reason: "num_micro_batches must be > 0".to_string(),
});
}
Ok(Self {
stage,
num_micro_batches,
pp_comm,
device,
})
}
pub fn run<C>(
&mut self,
_client: &C,
micro_batches: Option<Vec<Tensor<R>>>,
loss_fn: Option<&LossFn<'_, R>>,
) -> Result<PipelineOutput<R>>
where
C: RuntimeClient<R> + ShapeOps<R>,
{
let rank = self.pp_comm.rank();
let world_size = self.pp_comm.world_size();
let num_stages = world_size.max(1);
let is_first = rank == 0;
let is_last = rank == num_stages - 1;
let clock = PipelineClock::new(num_stages, self.num_micro_batches, rank);
let actions = clock.schedule_1f1b();
let mut forward_outputs: Vec<Option<Var<R>>> = vec![None; self.num_micro_batches];
let mut losses = Vec::new();
let mut input_grads: Vec<Option<Tensor<R>>> = vec![None; self.num_micro_batches];
let mb_inputs: Vec<Option<Tensor<R>>> = if is_first {
let mbs = micro_batches.ok_or_else(|| Error::DistributedError {
reason: "stage 0 must provide micro_batches".to_string(),
})?;
if mbs.len() != self.num_micro_batches {
return Err(Error::DistributedError {
reason: format!(
"expected {} micro-batches, got {}",
self.num_micro_batches,
mbs.len()
),
});
}
mbs.into_iter().map(Some).collect()
} else {
vec![None; self.num_micro_batches]
};
let mut mb_inputs = mb_inputs;
for action in &actions {
match *action {
PipelineAction::Forward(mb_id) => {
let input_tensor = if is_first {
mb_inputs[mb_id]
.take()
.ok_or_else(|| Error::DistributedError {
reason: format!("micro-batch {mb_id} already consumed"),
})?
} else {
recv_activation::<R>(
self.pp_comm.as_ref(),
rank - 1,
mb_id,
false,
&self.device,
)?
};
let input_var = Var::new(input_tensor, is_first);
let output_var = self.stage.forward(input_var)?;
if is_last {
forward_outputs[mb_id] = Some(output_var);
} else {
send_activation::<R>(
self.pp_comm.as_ref(),
output_var.tensor(),
rank + 1,
mb_id,
false,
)?;
}
}
PipelineAction::Backward(mb_id) => {
let output_grad = if is_last {
compute_loss_grad(
&mut forward_outputs[mb_id],
mb_id,
loss_fn,
&mut losses,
&self.device,
)?
} else {
recv_activation::<R>(
self.pp_comm.as_ref(),
rank + 1,
mb_id,
true,
&self.device,
)?
};
let input_grad = self.stage.backward(mb_id, output_grad)?;
if is_first {
input_grads[mb_id] = Some(input_grad);
} else {
send_activation::<R>(
self.pp_comm.as_ref(),
&input_grad,
rank - 1,
mb_id,
true,
)?;
}
}
_ => {
}
}
}
Ok(PipelineOutput {
losses,
input_grads: input_grads.into_iter().flatten().collect(),
})
}
pub fn num_micro_batches(&self) -> usize {
self.num_micro_batches
}
pub fn communicator(&self) -> &dyn Communicator {
self.pp_comm.as_ref()
}
pub fn param_grads(&self) -> Result<Vec<(numr::tensor::TensorId, Tensor<R>)>> {
self.stage.param_grads()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::cpu_setup;
use numr::autograd::Var;
use numr::runtime::NoOpCommunicator;
use numr::runtime::cpu::CpuRuntime;
use numr::tensor::TensorId;
struct DoubleTrainStage {
saved_count: usize,
}
impl DoubleTrainStage {
fn new() -> Self {
Self { saved_count: 0 }
}
}
impl TrainablePipelineStage<CpuRuntime> for DoubleTrainStage {
fn forward(&mut self, input: Var<CpuRuntime>) -> Result<Var<CpuRuntime>> {
let data = input.tensor().to_vec::<f32>();
let doubled: Vec<f32> = data.iter().map(|x| x * 2.0).collect();
let out = Tensor::from_slice(&doubled, input.tensor().shape(), input.tensor().device());
self.saved_count += 1;
Ok(Var::new(out, false))
}
fn backward(
&mut self,
_micro_batch_id: usize,
output_grad: Tensor<CpuRuntime>,
) -> Result<Tensor<CpuRuntime>> {
self.saved_count = self.saved_count.saturating_sub(1);
let data = output_grad.to_vec::<f32>();
let grad: Vec<f32> = data.iter().map(|x| x * 2.0).collect();
Ok(Tensor::from_slice(
&grad,
output_grad.shape(),
output_grad.device(),
))
}
fn param_grads(&self) -> Result<Vec<(TensorId, Tensor<CpuRuntime>)>> {
Ok(Vec::new())
}
fn num_saved(&self) -> usize {
self.saved_count
}
}
#[test]
fn test_schedule_1f1b_single_device() {
let (_client, device) = cpu_setup();
let comm = Arc::new(NoOpCommunicator);
let stage = Box::new(DoubleTrainStage::new());
let mut schedule = Schedule1F1B::<CpuRuntime>::new(stage, 2, comm, device.clone()).unwrap();
let mb0 = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
let mb1 = Tensor::<CpuRuntime>::from_slice(&[3.0f32, 4.0], &[2], &device);
let loss_fn = |output: &Var<CpuRuntime>| -> Result<Var<CpuRuntime>> {
let data = output.tensor().to_vec::<f32>();
let sum: f32 = data.iter().sum();
let loss_t = Tensor::from_slice(&[sum], &[1], output.tensor().device());
Ok(Var::new(loss_t, false))
};
let result = schedule
.run(&_client, Some(vec![mb0, mb1]), Some(&loss_fn))
.unwrap();
assert_eq!(result.losses.len(), 2);
assert!((result.losses[0] - 6.0).abs() < 1e-5);
assert!((result.losses[1] - 14.0).abs() < 1e-5);
}
#[test]
fn test_schedule_1f1b_zero_mb_error() {
let (_client, device) = cpu_setup();
let comm = Arc::new(NoOpCommunicator);
let stage = Box::new(DoubleTrainStage::new());
let result = Schedule1F1B::<CpuRuntime>::new(stage, 0, comm, device);
assert!(result.is_err());
}
}