use std::collections::HashMap;
use std::sync::Arc;
use super::clock::{PipelineAction, PipelineClock};
use super::comm::{compute_loss_grad, recv_activation_tagged, send_activation_tagged};
use super::schedule_1f1b::PipelineOutput;
use super::stage::TrainablePipelineStage;
use crate::error::{Error, Result};
use numr::autograd::Var;
use numr::dtype::DType;
use numr::ops::ShapeOps;
use numr::runtime::{Communicator, Runtime, RuntimeClient};
use numr::tensor::Tensor;
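
/// Interleaved 1F1B ("virtual pipeline") schedule.
///
/// Each rank owns several virtual stages, placed round-robin across the
/// pipeline ranks: virtual stage `v` on rank `r` runs logical stage
/// `r + v * world_size`. Interleaving shrinks the pipeline bubble relative to
/// plain 1F1B at the cost of extra point-to-point communication.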
pub struct ScheduleInterleaved1F1B<R: Runtime> {
stages: Vec<Box<dyn TrainablePipelineStage<R>>>,
num_micro_batches: usize,
pp_comm: Arc<dyn Communicator>,
device: R::Device,
}
impl<R: Runtime<DType = DType>> ScheduleInterleaved1F1B<R> {
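    /// Creates a schedule over this rank's virtual stages.
    ///
    /// # Errors
    ///
    /// Returns [`Error::DistributedError`] if `stages` is empty or
    /// `num_micro_batches` is zero.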
pub fn new(
stages: Vec<Box<dyn TrainablePipelineStage<R>>>,
num_micro_batches: usize,
pp_comm: Arc<dyn Communicator>,
device: R::Device,
) -> Result<Self> {
if stages.is_empty() {
return Err(Error::DistributedError {
reason: "must have at least one virtual stage".to_string(),
});
}
if num_micro_batches == 0 {
return Err(Error::DistributedError {
reason: "num_micro_batches must be > 0".to_string(),
});
}
Ok(Self {
stages,
num_micro_batches,
pp_comm,
device,
})
}
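
    /// Runs one full step: every forward and backward pass for every
    /// micro-batch, in the order produced by the pipeline clock.
    ///
    /// Rank 0 must supply `micro_batches`; the rank owning the last logical
    /// stage is expected to supply `loss_fn`. Activations and gradients move
    /// between co-located virtual stages through an in-process buffer, and
    /// between ranks through tagged sends and receives.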
pub fn run<C>(
&mut self,
_client: &C,
micro_batches: Option<Vec<Tensor<R>>>,
loss_fn: Option<&super::schedule_1f1b::LossFn<'_, R>>,
) -> Result<PipelineOutput<R>>
where
C: RuntimeClient<R> + ShapeOps<R>,
{
let rank = self.pp_comm.rank();
let world_size = self.pp_comm.world_size().max(1);
let num_virtual = self.stages.len();
let total_stages = world_size * num_virtual;
let clock = PipelineClock::new(world_size, self.num_micro_batches, rank);
let schedule = clock.schedule_interleaved(num_virtual);
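        // One output slot per (virtual stage, micro-batch); only the last
        // logical stage keeps its forward outputs around for the loss.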
let mut forward_outputs: Vec<Vec<Option<Var<R>>>> = (0..num_virtual)
.map(|_| (0..self.num_micro_batches).map(|_| None).collect())
.collect();
let mut losses = Vec::new();
let mut input_grads: Vec<Option<Tensor<R>>> = vec![None; self.num_micro_batches];
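        // Only rank 0 feeds raw micro-batches into the pipeline (logical stage
        // 0); all other forward inputs arrive from a peer rank or the local
        // hand-off buffer.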
let mut mb_inputs: Vec<Option<Tensor<R>>> = if rank == 0 {
let mbs = micro_batches.ok_or_else(|| Error::DistributedError {
reason: "first rank must provide micro_batches".to_string(),
})?;
mbs.into_iter().map(Some).collect()
} else {
vec![None; self.num_micro_batches]
};
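        // Hand-off buffer for tensors exchanged between two virtual stages that
        // live on the same rank, keyed by (producing logical stage, micro-batch,
        // is_grad). Cross-rank exchanges go through the communicator instead.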
let mut local_buffers: HashMap<(usize, usize, bool), Tensor<R>> = HashMap::new();
for &(v_idx, ref action) in &schedule {
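            // Round-robin placement: virtual stage `v_idx` on this rank runs
            // logical stage `rank + v_idx * world_size`.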
let logical_stage = rank + v_idx * world_size;
let is_first_logical = logical_stage == 0;
let is_last_logical = logical_stage == total_stages - 1;
let prev_logical = if logical_stage > 0 {
Some(logical_stage - 1)
} else {
None
};
let next_logical = if logical_stage < total_stages - 1 {
Some(logical_stage + 1)
} else {
None
};
let prev_is_local = prev_logical
.map(|l| l % world_size == rank)
.unwrap_or(false);
let next_is_local = next_logical
.map(|l| l % world_size == rank)
.unwrap_or(false);
let prev_rank_id = prev_logical.map(|l| l % world_size);
let next_rank_id = next_logical.map(|l| l % world_size);
            // Tags must agree on both ends of a transfer, so derive the tag base
            // from the logical link (the lower logical stage of the pair) rather
            // than the local virtual-stage index: across a cross-rank transfer
            // the two sides can hold different `v_idx` values, but they always
            // agree on the link.
            let num_mb = self.num_micro_batches;
            let link_base_tag = |link: usize| {
                u32::try_from(link * num_mb * 4).map_err(|_| Error::DistributedError {
                    reason: "tag overflow".to_string(),
                })
            };
            let prev_link_tag = prev_logical.map(|l| link_base_tag(l)).transpose()?.unwrap_or(0);
            let next_link_tag = link_base_tag(logical_stage)?;
match *action {
PipelineAction::Forward(mb_id) => {
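                    // The forward input comes from one of three places: the raw
                    // micro-batch (first logical stage), the local hand-off
                    // buffer (previous stage lives on this rank), or a tagged
                    // recv from the previous rank.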
let input_tensor = if is_first_logical {
mb_inputs[mb_id]
.take()
.ok_or_else(|| Error::DistributedError {
reason: format!("micro-batch {mb_id} already consumed"),
})?
} else if prev_is_local {
let prev_l = prev_logical
.expect("prev_logical guaranteed Some when prev_is_local is true");
local_buffers
.remove(&(prev_l, mb_id, false))
.ok_or_else(|| Error::DistributedError {
reason: format!(
"no local activation for logical_stage={prev_l} mb={mb_id}"
),
})?
} else {
let src = prev_rank_id.ok_or_else(|| Error::DistributedError {
reason: "no previous stage for recv".to_string(),
})?;
recv_activation_tagged::<R>(
self.pp_comm.as_ref(),
src,
mb_id,
false,
                            prev_link_tag,
&self.device,
)?
};
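                    // Only the first logical stage marks its input as requiring
                    // grad; its input gradients are collected below and returned
                    // to the caller via `PipelineOutput::input_grads`.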
let input_var = Var::new(input_tensor, is_first_logical);
let output_var = self.stages[v_idx].forward(input_var)?;
if is_last_logical {
forward_outputs[v_idx][mb_id] = Some(output_var);
} else if next_is_local {
local_buffers
.insert((logical_stage, mb_id, false), output_var.tensor().clone());
} else {
let dest = next_rank_id.ok_or_else(|| Error::DistributedError {
reason: "no next stage for send".to_string(),
})?;
send_activation_tagged::<R>(
self.pp_comm.as_ref(),
output_var.tensor(),
dest,
mb_id,
false,
                            next_link_tag,
)?;
}
}
PipelineAction::Backward(mb_id) => {
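                    // The output gradient mirrors the forward input sources:
                    // the loss gradient (last logical stage), the local hand-off
                    // buffer (next stage lives on this rank), or a tagged recv
                    // from the next rank.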
let output_grad = if is_last_logical {
compute_loss_grad(
&mut forward_outputs[v_idx][mb_id],
mb_id,
loss_fn,
&mut losses,
&self.device,
)?
} else if next_is_local {
let next_l = next_logical
.expect("next_logical guaranteed Some when next_is_local is true");
local_buffers
.remove(&(next_l, mb_id, true))
.ok_or_else(|| Error::DistributedError {
reason: format!(
"no local grad for logical_stage={next_l} mb={mb_id}"
),
})?
} else {
let src = next_rank_id.ok_or_else(|| Error::DistributedError {
reason: "no next stage for backward recv".to_string(),
})?;
recv_activation_tagged::<R>(
self.pp_comm.as_ref(),
src,
mb_id,
true,
                            next_link_tag,
&self.device,
)?
};
let input_grad = self.stages[v_idx].backward(mb_id, output_grad)?;
if is_first_logical {
input_grads[mb_id] = Some(input_grad);
} else if prev_is_local {
local_buffers.insert((logical_stage, mb_id, true), input_grad);
} else {
let dest = prev_rank_id.ok_or_else(|| Error::DistributedError {
reason: "no prev stage for backward send".to_string(),
})?;
send_activation_tagged::<R>(
self.pp_comm.as_ref(),
&input_grad,
dest,
mb_id,
true,
                            prev_link_tag,
)?;
}
}
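                // Any other schedule slot is an idle tick (pipeline bubble):
                // nothing to do on this rank.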
_ => {}
}
}
Ok(PipelineOutput {
losses,
input_grads: input_grads.into_iter().flatten().collect(),
})
}
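
    /// Number of micro-batches processed per `run` call.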
pub fn num_micro_batches(&self) -> usize {
self.num_micro_batches
}
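
    /// Number of virtual pipeline stages owned by this rank.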
pub fn num_virtual_stages(&self) -> usize {
self.stages.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::cpu_setup;
use numr::autograd::Var;
use numr::runtime::NoOpCommunicator;
use numr::runtime::cpu::CpuRuntime;
use numr::tensor::TensorId;
struct IdentityTrainStage;
impl TrainablePipelineStage<CpuRuntime> for IdentityTrainStage {
fn forward(&mut self, input: Var<CpuRuntime>) -> Result<Var<CpuRuntime>> {
Ok(Var::new(input.tensor().clone(), false))
}
fn backward(
&mut self,
_mb_id: usize,
output_grad: Tensor<CpuRuntime>,
) -> Result<Tensor<CpuRuntime>> {
Ok(output_grad)
}
fn param_grads(&self) -> Result<Vec<(TensorId, Tensor<CpuRuntime>)>> {
Ok(Vec::new())
}
fn num_saved(&self) -> usize {
0
}
}
#[test]
fn test_interleaved_single_device() {
let (client, device) = cpu_setup();
let comm = Arc::new(NoOpCommunicator);
let stages: Vec<Box<dyn TrainablePipelineStage<CpuRuntime>>> =
vec![Box::new(IdentityTrainStage), Box::new(IdentityTrainStage)];
let mut schedule = ScheduleInterleaved1F1B::new(stages, 2, comm, device.clone()).unwrap();
let mb0 = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
let mb1 = Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device);
let loss_fn = |output: &Var<CpuRuntime>| -> Result<Var<CpuRuntime>> {
Ok(Var::new(output.tensor().clone(), false))
};
let result = schedule
.run(&client, Some(vec![mb0, mb1]), Some(&loss_fn))
.unwrap();
assert_eq!(result.losses.len(), 2);
}
#[test]
fn test_interleaved_empty_stages_error() {
let (_client, device) = cpu_setup();
let comm = Arc::new(NoOpCommunicator);
let stages: Vec<Box<dyn TrainablePipelineStage<CpuRuntime>>> = vec![];
let result = ScheduleInterleaved1F1B::<CpuRuntime>::new(stages, 2, comm, device);
assert!(result.is_err());
}
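
    // Companion to the empty-stages test: the constructor's other precondition,
    // `num_micro_batches > 0`, should be rejected the same way.
    #[test]
    fn test_interleaved_zero_micro_batches_error() {
        let (_client, device) = cpu_setup();
        let comm = Arc::new(NoOpCommunicator);
        let stages: Vec<Box<dyn TrainablePipelineStage<CpuRuntime>>> =
            vec![Box::new(IdentityTrainStage)];
        let result = ScheduleInterleaved1F1B::new(stages, 0, comm, device);
        assert!(result.is_err());
    }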
}