use crate::error::{NeuralError, Result as NeuralResult};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::rngs::SmallRng;
use scirs2_core::random::{Rng, RngExt, SeedableRng};
use std::collections::HashMap;
/// Order in which forward and backward passes of micro-batches are executed
/// across the pipeline stages.
///
/// Marked `#[non_exhaustive]` so new schedules can be added without a
/// breaking change; match sites must keep a fallback arm.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PipelineSchedule {
    /// Run all forward passes first, then all backward passes in reverse
    /// micro-batch order.
    FThenB,
    /// After a warm-up of forwards, alternate one forward with one backward
    /// per step, then drain remaining backwards.
    Interleaved1F1B,
}
/// Configuration for [`PipelineParallel`].
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Number of pipeline stages; must equal the number of stages passed to
    /// [`PipelineParallel::new`], which validates the count.
    pub n_stages: usize,
    /// Number of micro-batches the input batch is split into per step.
    pub n_micro_batches: usize,
    /// Execution schedule used by [`PipelineParallel::run_schedule`].
    pub schedule: PipelineSchedule,
}
impl Default for PipelineConfig {
fn default() -> Self {
Self {
n_stages: 2,
n_micro_batches: 4,
schedule: PipelineSchedule::FThenB,
}
}
}
/// One slice of the global batch, processed independently through the
/// pipeline. Produced by `PipelineParallel::split_into_micro_batches`,
/// which assigns sequential `id`s starting at 0.
#[derive(Debug, Clone)]
pub struct MicroBatch {
    /// Sequential micro-batch index; also used to key per-stage stashes.
    pub id: usize,
    /// Input rows for this micro-batch, shape `(rows, n_features)`.
    pub data: Array2<f64>,
    /// Optional per-row scalar targets; `run_schedule` errors when absent.
    pub labels: Option<Array1<f64>>,
}
/// A single linear-plus-ReLU pipeline stage with per-micro-batch stashes of
/// intermediate values needed for the backward pass.
pub struct PipelineStage {
    /// Position of this stage in the pipeline (0-based).
    pub stage_id: usize,
    /// Input feature count (rows of `weight`).
    pub n_input_features: usize,
    /// Output feature count (columns of `weight`).
    pub n_output_features: usize,
    // Linear weights, shape (n_input_features, n_output_features).
    weight: Array2<f64>,
    // Per-output-feature bias, length n_output_features.
    bias: Array1<f64>,
    // micro-batch id -> post-ReLU activation, consumed by `backward`.
    activation_stash: HashMap<usize, Array2<f64>>,
    // micro-batch id -> stage input. NOTE(review): stashed in `forward` and
    // removed in `backward` but never read — presumably reserved for a
    // future weight-gradient computation; confirm before removing.
    input_stash: HashMap<usize, Array2<f64>>,
}
impl PipelineStage {
pub fn new(stage_id: usize, n_in: usize, n_out: usize, seed: u64) -> Self {
let mut rng = SmallRng::seed_from_u64(seed);
let scale = (2.0_f64 / (n_in + n_out) as f64).sqrt();
let weight =
Array2::from_shape_fn((n_in, n_out), |_| (rng.random::<f64>() * 2.0 - 1.0) * scale);
let bias = Array1::zeros(n_out);
Self {
stage_id,
n_input_features: n_in,
n_output_features: n_out,
weight,
bias,
activation_stash: HashMap::new(),
input_stash: HashMap::new(),
}
}
pub fn forward(&mut self, micro_batch_id: usize, input: &Array2<f64>) -> Array2<f64> {
let z = input.dot(&self.weight) + &self.bias; let activation = z.mapv(|v| v.max(0.0));
self.input_stash.insert(micro_batch_id, input.clone());
self.activation_stash
.insert(micro_batch_id, activation.clone());
activation
}
pub fn backward(&mut self, micro_batch_id: usize, grad_output: &Array2<f64>) -> Array2<f64> {
let activation = self
.activation_stash
.remove(µ_batch_id)
.unwrap_or_else(|| Array2::zeros(grad_output.raw_dim()));
let relu_mask = activation.mapv(|v| if v > 0.0 { 1.0 } else { 0.0 });
let grad_pre_relu = grad_output * &relu_mask;
let grad_input = grad_pre_relu.dot(&self.weight.t());
self.input_stash.remove(µ_batch_id);
grad_input
}
pub fn clear_stash(&mut self) {
self.activation_stash.clear();
self.input_stash.clear();
}
}
/// Orchestrates a sequence of [`PipelineStage`]s over micro-batches using a
/// configurable [`PipelineSchedule`].
pub struct PipelineParallel {
    // Stages in forward order; backward traverses them reversed.
    stages: Vec<PipelineStage>,
    // Validated against `stages.len()` in `new`.
    config: PipelineConfig,
}
impl PipelineParallel {
    /// Builds a pipeline from `stages`, validating the count against `config`.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::ConfigError`] if `stages.len() != config.n_stages`.
    pub fn new(stages: Vec<PipelineStage>, config: PipelineConfig) -> NeuralResult<Self> {
        if stages.len() != config.n_stages {
            return Err(NeuralError::ConfigError(format!(
                "PipelineConfig.n_stages={} but {} stages provided",
                config.n_stages,
                stages.len()
            )));
        }
        Ok(Self { stages, config })
    }

    /// Runs one forward/backward pass over `input` with the configured
    /// schedule and returns the unweighted mean of per-micro-batch MSE
    /// losses. All stage stashes are cleared before returning.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::ComputationError`] when a micro-batch is
    /// missing labels, and [`NeuralError::NotImplemented`] for an unknown
    /// schedule variant (the enum is `#[non_exhaustive]`).
    pub fn run_schedule(&mut self, input: &Array2<f64>, labels: &Array1<f64>) -> NeuralResult<f64> {
        let micro_batches =
            Self::split_into_micro_batches(input, labels, self.config.n_micro_batches);
        // (micro-batch id, final-stage output). Ids are sequential from 0,
        // so `micro_batches[id]` recovers the matching labels below.
        let mut outputs: Vec<(usize, Array2<f64>)> = Vec::with_capacity(micro_batches.len());
        #[allow(unreachable_patterns)]
        match self.config.schedule {
            PipelineSchedule::FThenB => {
                // All forwards first...
                for mb in &micro_batches {
                    let out = self.forward_one_micro_batch(mb);
                    outputs.push((mb.id, out));
                }
                // ...then all backwards in reverse micro-batch order.
                for (mb_id, out) in outputs.iter().rev() {
                    let mb_labels = micro_batches[*mb_id]
                        .labels
                        .as_ref()
                        .ok_or_else(|| NeuralError::ComputationError("missing labels".into()))?;
                    let grad = Self::mse_grad(out, mb_labels);
                    self.backward_one_micro_batch(*mb_id, &grad);
                }
            }
            PipelineSchedule::Interleaved1F1B => {
                // Warm-up: fill the pipeline with (n_stages - 1) forwards.
                // `saturating_sub` avoids an underflow panic when the
                // pipeline was built with zero stages.
                let warmup = self
                    .config
                    .n_stages
                    .saturating_sub(1)
                    .min(micro_batches.len());
                for mb in &micro_batches[..warmup] {
                    let out = self.forward_one_micro_batch(mb);
                    outputs.push((mb.id, out));
                }
                // Steady state: one forward, then one backward per step.
                let steady_start = warmup;
                let steady_end = micro_batches.len();
                let mut backward_idx = 0_usize;
                for fwd_idx in steady_start..steady_end {
                    let mb = &micro_batches[fwd_idx];
                    let out = self.forward_one_micro_batch(mb);
                    outputs.push((mb.id, out));
                    if backward_idx < outputs.len().saturating_sub(warmup) {
                        let (bwd_mb_id, bwd_out) = &outputs[backward_idx];
                        let mb_labels =
                            micro_batches[*bwd_mb_id].labels.as_ref().ok_or_else(|| {
                                NeuralError::ComputationError("missing labels".into())
                            })?;
                        let grad = Self::mse_grad(bwd_out, mb_labels);
                        self.backward_one_micro_batch(*bwd_mb_id, &grad);
                        backward_idx += 1;
                    }
                }
                // Cool-down: drain the remaining backwards.
                while backward_idx < outputs.len() {
                    let (bwd_mb_id, bwd_out) = &outputs[backward_idx];
                    let mb_labels = micro_batches[*bwd_mb_id]
                        .labels
                        .as_ref()
                        .ok_or_else(|| NeuralError::ComputationError("missing labels".into()))?;
                    let grad = Self::mse_grad(bwd_out, mb_labels);
                    self.backward_one_micro_batch(*bwd_mb_id, &grad);
                    backward_idx += 1;
                }
            }
            _ => {
                return Err(NeuralError::NotImplemented(
                    "unsupported pipeline schedule variant".into(),
                ));
            }
        }
        // Unweighted mean over micro-batch losses: micro-batches may differ
        // in row count by one, but each contributes equally here.
        let total_loss: f64 = outputs
            .iter()
            .filter_map(|(mb_id, out)| {
                let mb_labels = micro_batches[*mb_id].labels.as_ref()?;
                Some(Self::mse_loss(out, mb_labels))
            })
            .sum();
        let n = outputs.len().max(1) as f64;
        for stage in self.stages.iter_mut() {
            stage.clear_stash();
        }
        Ok(total_loss / n)
    }

    /// Splits `input`/`labels` row-wise into at most `n` micro-batches of
    /// near-equal size with sequential ids starting at 0. Every row is
    /// covered; the last micro-batch may be smaller.
    fn split_into_micro_batches(
        input: &Array2<f64>,
        labels: &Array1<f64>,
        n: usize,
    ) -> Vec<MicroBatch> {
        // Guard: `div_ceil` panics on a zero divisor, so treat `n == 0` as
        // a request for a single micro-batch instead of panicking.
        let n = n.max(1);
        let batch_size = input.shape()[0];
        let chunk = batch_size.div_ceil(n);
        let mut result = Vec::with_capacity(n);
        let mut start = 0_usize;
        let mut id = 0_usize;
        while start < batch_size && id < n {
            let end = (start + chunk).min(batch_size);
            let data = input
                .slice(scirs2_core::ndarray::s![start..end, ..])
                .to_owned();
            let lbl = labels
                .slice(scirs2_core::ndarray::s![start..end])
                .to_owned();
            result.push(MicroBatch {
                id,
                data,
                labels: Some(lbl),
            });
            start = end;
            id += 1;
        }
        result
    }

    /// Pushes one micro-batch forward through every stage in order and
    /// returns the final stage's output.
    fn forward_one_micro_batch(&mut self, mb: &MicroBatch) -> Array2<f64> {
        let mut x = mb.data.clone();
        for stage in self.stages.iter_mut() {
            x = stage.forward(mb.id, &x);
        }
        x
    }

    /// Propagates `grad` backward through the stages in reverse order,
    /// consuming each stage's stash entry for `mb_id`.
    fn backward_one_micro_batch(&mut self, mb_id: usize, grad: &Array2<f64>) {
        let mut g = grad.clone();
        for stage in self.stages.iter_mut().rev() {
            g = stage.backward(mb_id, &g);
        }
    }

    /// Mean squared error between the last output column (treated as the
    /// scalar prediction) and `labels`. Returns 0.0 for an empty batch or a
    /// zero-column output (the latter would otherwise panic in `index_axis`).
    fn mse_loss(output: &Array2<f64>, labels: &Array1<f64>) -> f64 {
        let batch = output.shape()[0];
        let n_out = output.shape()[1];
        if batch == 0 || n_out == 0 {
            return 0.0;
        }
        let pred = output.index_axis(Axis(1), n_out - 1);
        let diff = &pred - labels;
        diff.mapv(|v| v * v).mean().unwrap_or(0.0)
    }

    /// Gradient of [`Self::mse_loss`] w.r.t. the output: `2 * (pred - y) / batch`
    /// in the last column, zero elsewhere. Returns an all-zero gradient for an
    /// empty batch or a zero-column output (avoids an out-of-bounds axis index).
    fn mse_grad(output: &Array2<f64>, labels: &Array1<f64>) -> Array2<f64> {
        let batch = output.shape()[0];
        let n_out = output.shape()[1];
        let mut grad = Array2::<f64>::zeros((batch, n_out));
        if batch == 0 || n_out == 0 {
            return grad;
        }
        let n_out_idx = n_out - 1;
        let pred = output.index_axis(Axis(1), n_out_idx);
        let diff = (&pred - labels).mapv(|v| 2.0 * v / batch as f64);
        grad.index_axis_mut(Axis(1), n_out_idx).assign(&diff);
        grad
    }

    /// Number of stages in the pipeline.
    pub fn n_stages(&self) -> usize {
        self.stages.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    // Builds a pipeline of `n_stages` stages: 8 features into each hidden
    // stage, 2 features out of the final stage, seeded per stage id.
    fn make_pp(n_stages: usize, n_micro: usize, schedule: PipelineSchedule) -> PipelineParallel {
        let mut stages = Vec::new();
        let mut n_in = 8_usize;
        for s in 0..n_stages {
            let n_out = if s < n_stages - 1 { 8 } else { 2 };
            stages.push(PipelineStage::new(s, n_in, n_out, s as u64));
            n_in = n_out;
        }
        let cfg = PipelineConfig {
            n_stages,
            n_micro_batches: n_micro,
            schedule,
        };
        PipelineParallel::new(stages, cfg).expect("valid pipeline")
    }

    // 16 rows / 4 micro-batches splits evenly into exactly 4 pieces.
    #[test]
    fn test_split_into_micro_batches_count() {
        let input = Array2::<f64>::ones((16, 8));
        let labels = Array1::<f64>::ones(16);
        let mbs = PipelineParallel::split_into_micro_batches(&input, &labels, 4);
        assert_eq!(mbs.len(), 4);
    }

    // Micro-batch ids must be 0..n in order — run_schedule indexes
    // `micro_batches` by id.
    #[test]
    fn test_split_micro_batch_ids_sequential() {
        let input = Array2::<f64>::ones((12, 4));
        let labels = Array1::<f64>::zeros(12);
        let mbs = PipelineParallel::split_into_micro_batches(&input, &labels, 3);
        for (i, mb) in mbs.iter().enumerate() {
            assert_eq!(mb.id, i, "micro-batch id should be sequential");
        }
    }

    // 17 rows / 4 micro-batches: uneven split must still cover every row.
    #[test]
    fn test_split_micro_batches_cover_all_rows() {
        let batch = 17;
        let input = Array2::<f64>::ones((batch, 4));
        let labels = Array1::<f64>::zeros(batch);
        let mbs = PipelineParallel::split_into_micro_batches(&input, &labels, 4);
        let total_rows: usize = mbs.iter().map(|mb| mb.data.shape()[0]).sum();
        assert_eq!(total_rows, batch, "all rows must be covered");
    }

    // Forward maps (batch, n_in) -> (batch, n_out).
    #[test]
    fn test_stage_forward_output_shape() {
        let mut stage = PipelineStage::new(0, 4, 8, 0);
        let input = Array2::<f64>::ones((5, 4));
        let out = stage.forward(0, &input);
        assert_eq!(out.shape(), [5, 8]);
    }

    // Backward maps the output-shaped gradient back to the input shape.
    #[test]
    fn test_stage_backward_output_shape() {
        let mut stage = PipelineStage::new(0, 4, 8, 0);
        let input = Array2::<f64>::ones((5, 4));
        let _ = stage.forward(0, &input);
        let grad_out = Array2::<f64>::ones((5, 8));
        let grad_in = stage.backward(0, &grad_out);
        assert_eq!(
            grad_in.shape(),
            [5, 4],
            "backward shape must match input shape"
        );
    }

    // Smoke test: forward-then-backward schedule yields a finite loss.
    #[test]
    fn test_pipeline_fthenb_loss_finite() {
        let mut pp = make_pp(2, 4, PipelineSchedule::FThenB);
        let input = Array2::<f64>::ones((16, 8));
        let labels = Array1::<f64>::ones(16);
        let loss = pp.run_schedule(&input, &labels).expect("run ok");
        assert!(loss.is_finite(), "loss should be finite, got {loss}");
    }

    // Smoke test: interleaved 1F1B schedule yields a finite loss.
    #[test]
    fn test_pipeline_interleaved_loss_finite() {
        let mut pp = make_pp(2, 4, PipelineSchedule::Interleaved1F1B);
        let input = Array2::<f64>::ones((16, 8));
        let labels = Array1::<f64>::ones(16);
        let loss = pp.run_schedule(&input, &labels).expect("run ok");
        assert!(loss.is_finite(), "loss should be finite, got {loss}");
    }

    // Both schedules run cleanly on identical data and pipelines.
    #[test]
    fn test_pipeline_both_schedules_same_order_data() {
        let input = Array2::<f64>::ones((8, 8));
        let labels = Array1::<f64>::zeros(8);
        let mut pp_fthenb = make_pp(2, 2, PipelineSchedule::FThenB);
        let loss_ftb = pp_fthenb.run_schedule(&input, &labels).expect("fthenb ok");
        let mut pp_1f1b = make_pp(2, 2, PipelineSchedule::Interleaved1F1B);
        let loss_1f1b = pp_1f1b.run_schedule(&input, &labels).expect("1f1b ok");
        assert!(loss_ftb.is_finite());
        assert!(loss_1f1b.is_finite());
    }

    // With labels taken from the output's own last column, MSE must be 0.
    // Weights are overwritten with a scaled identity so the ReLU output is
    // strictly positive and deterministic.
    #[test]
    fn test_mse_loss_zero_when_predictions_equal_labels() {
        let n = 4;
        let mut stage = PipelineStage::new(0, n, n, 42);
        stage.weight = Array2::eye(n) * 10.0;
        stage.bias = Array1::zeros(n);
        let input = Array2::from_shape_fn((4, n), |(_i, j)| j as f64 * 0.1 + 0.5);
        let out = stage.forward(0, &input);
        let last_col = out.index_axis(Axis(1), n - 1).to_owned();
        let loss = PipelineParallel::mse_loss(&out, &last_col);
        assert!(loss.abs() < 1e-12, "loss should be 0; got {loss}");
    }

    // Accessor reports the constructed stage count.
    #[test]
    fn test_pipeline_n_stages() {
        let pp = make_pp(3, 4, PipelineSchedule::FThenB);
        assert_eq!(pp.n_stages(), 3);
    }

    // new() must reject a stage vector whose length disagrees with config.
    #[test]
    fn test_pipeline_stage_count_mismatch_error() {
        let stages = vec![PipelineStage::new(0, 4, 4, 0)];
        let cfg = PipelineConfig {
            n_stages: 3,
            n_micro_batches: 2,
            schedule: PipelineSchedule::FThenB,
        };
        assert!(
            PipelineParallel::new(stages, cfg).is_err(),
            "mismatched stage count must return error"
        );
    }

    // Degenerate pipeline (1 stage, 1 micro-batch) still runs end to end.
    #[test]
    fn test_pipeline_1_stage_matches_single_forward() {
        let n_in = 6;
        let n_out = 3;
        let seed = 77;
        let input = Array2::from_shape_fn((4, n_in), |(i, j)| (i * n_in + j) as f64 * 0.01);
        let labels = Array1::<f64>::zeros(4);
        let stage_ref = PipelineStage::new(0, n_in, n_out, seed);
        let cfg = PipelineConfig {
            n_stages: 1,
            n_micro_batches: 1,
            schedule: PipelineSchedule::FThenB,
        };
        let mut pp = PipelineParallel::new(vec![stage_ref], cfg).expect("ok");
        let loss = pp.run_schedule(&input, &labels).expect("run ok");
        assert!(loss.is_finite());
    }
}