Skip to main content

rumus_distributed/
pipeline.rs

// SPDX-License-Identifier: Apache-2.0 OR MIT
//! Pipeline parallelism: 1F1B (one-forward-one-backward) micro-batch schedule with per-micro-batch tapes.
3
4use std::sync::mpsc;
5
6use rumus::autograd::{backward_with_grad, context, GradientStore, Tape};
7use rumus::tensor::{GradId, Tensor};
8
// ---------------------------------------------------------------------------
// PipelineStage
// ---------------------------------------------------------------------------
12
13/// A single pipeline stage on a specific device.
14pub struct PipelineStage {
15    pub device_index: usize,
16    pub forward_fn: Box<dyn Fn(&Tensor) -> Tensor + Send + Sync>,
17}
18
// ---------------------------------------------------------------------------
// PipelineExecutor
// ---------------------------------------------------------------------------
22
23/// 1F1B pipeline executor with per-micro-batch isolated tapes.
24pub struct PipelineExecutor {
25    pub stages: Vec<PipelineStage>,
26    pub num_micro_batches: usize,
27}
28
29impl PipelineExecutor {
30    pub fn new(stages: Vec<PipelineStage>, num_micro_batches: usize) -> Self {
31        Self { stages, num_micro_batches }
32    }
33
34    /// Run the pipeline.  Returns per-stage gradient stores.
35    pub fn run(
36        &self,
37        input: &Tensor,
38        loss_fn: &(dyn Fn(&Tensor) -> Tensor + Send + Sync),
39    ) -> Vec<GradientStore> {
40        let p = self.stages.len();
41        let m = self.num_micro_batches;
42        let batch = input.shape()[0];
43        assert!(batch % m == 0);
44        let micro_size = batch / m;
45        let micros: Vec<Tensor> = (0..m)
46            .map(|i| input.slice_range(0, i * micro_size, (i + 1) * micro_size))
47            .collect();
48
49        // Channels: stage s sends activations to stage s+1 (fwd) and
50        // stage s+1 sends gradients back to stage s (bwd).
51        let mut fwd_tx_opts: Vec<Option<mpsc::SyncSender<Tensor>>> = Vec::new();
52        let mut fwd_rx_opts: Vec<Option<mpsc::Receiver<Tensor>>> = Vec::new();
53        let mut bwd_tx_opts: Vec<Option<mpsc::SyncSender<Tensor>>> = Vec::new();
54        let mut bwd_rx_opts: Vec<Option<mpsc::Receiver<Tensor>>> = Vec::new();
55
56        // Stage 0: no fwd_rx, no bwd_tx.
57        fwd_rx_opts.push(None);
58        bwd_tx_opts.push(None);
59
60        for _ in 0..p.saturating_sub(1) {
61            let (ftx, frx) = mpsc::sync_channel(m);
62            let (btx, brx) = mpsc::sync_channel(m);
63            fwd_tx_opts.push(None); // placeholder, will be set below
64            fwd_rx_opts.push(Some(frx));
65            bwd_tx_opts.push(Some(btx));
66            bwd_rx_opts.push(None); // placeholder
67            // Fix: assign properly.
68            let last_fwd_tx = fwd_tx_opts.len() - 1;
69            fwd_tx_opts[last_fwd_tx - 1] = Some(ftx);
70            let last_bwd_rx = bwd_rx_opts.len() - 1;
71            bwd_rx_opts[last_bwd_rx - 1] = Some(brx);
72        }
73        // Last stage: no fwd_tx, no bwd_rx.
74        // Already None by construction.
75
76        let grad_stores: Vec<std::sync::Mutex<GradientStore>> =
77            (0..p).map(|_| std::sync::Mutex::new(GradientStore::new())).collect();
78
79        std::thread::scope(|scope| {
80            let micros_ref = &micros;
81            let stages_ref = &self.stages;
82            let gs_ref = &grad_stores;
83
84            let mut handles = Vec::with_capacity(p);
85            for s in 0..p {
86                let my_fwd_rx = fwd_rx_opts[s].take();
87                let my_fwd_tx = fwd_tx_opts[s].take();
88                let my_bwd_rx = bwd_rx_opts[s].take();
89                let my_bwd_tx = bwd_tx_opts[s].take();
90
91                handles.push(scope.spawn(move || {
92                    run_stage(
93                        s, p, m, stages_ref, micros_ref,
94                        my_fwd_rx, my_fwd_tx, my_bwd_rx, my_bwd_tx,
95                        gs_ref, loss_fn,
96                    );
97                }));
98            }
99
100            for h in handles { h.join().expect("pipeline thread panicked"); }
101        });
102
103        grad_stores.into_iter().map(|m| m.into_inner().unwrap()).collect()
104    }
105}
106
107fn run_stage(
108    stage: usize,
109    num_stages: usize,
110    num_micro: usize,
111    stages: &[PipelineStage],
112    micros: &[Tensor],
113    fwd_rx: Option<mpsc::Receiver<Tensor>>,
114    fwd_tx: Option<mpsc::SyncSender<Tensor>>,
115    bwd_rx: Option<mpsc::Receiver<Tensor>>,
116    bwd_tx: Option<mpsc::SyncSender<Tensor>>,
117    grad_stores: &[std::sync::Mutex<GradientStore>],
118    loss_fn: &(dyn Fn(&Tensor) -> Tensor + Send + Sync),
119) {
120    let is_first = stage == 0;
121    let is_last = stage == num_stages - 1;
122
123    let mut saved_tapes: Vec<Option<Tape>> = (0..num_micro).map(|_| None).collect();
124    let mut saved_outputs: Vec<Option<Tensor>> = (0..num_micro).map(|_| None).collect();
125    let mut saved_input_gids: Vec<Option<GradId>> = (0..num_micro).map(|_| None).collect();
126
127    // === Forward all micro-batches ===
128    for mb in 0..num_micro {
129        let mut input_t = if is_first {
130            micros[mb].clone()
131        } else {
132            fwd_rx.as_ref().unwrap().recv().expect("fwd recv failed")
133        };
134
135        // Track incoming tensor for gradient extraction.
136        if !is_first {
137            input_t.set_requires_grad(true);
138            saved_input_gids[mb] = input_t.grad_id();
139        }
140
141        // Fresh isolated tape.
142        context::install_tape(Tape::new());
143        let output = (stages[stage].forward_fn)(&input_t);
144        saved_tapes[mb] = context::take_tape();
145        saved_outputs[mb] = Some(output.clone());
146
147        if !is_last {
148            fwd_tx.as_ref().unwrap().send(output).expect("fwd send failed");
149        }
150    }
151
152    // === Backward all micro-batches (reverse) ===
153    for mb in (0..num_micro).rev() {
154        let output = saved_outputs[mb].take().unwrap();
155
156        if is_last {
157            // Last stage: loss + standard backward.
158            context::install_tape(saved_tapes[mb].take().unwrap());
159            let loss = loss_fn(&output);
160            let mut grads = rumus::autograd::backward(&loss).expect("backward failed");
161
162            // Send grad_input to prev stage.
163            if !is_first {
164                if let Some(gid) = saved_input_gids[mb] {
165                    if let Some(gi) = grads.remove(gid) {
166                        bwd_tx.as_ref().unwrap().send(gi).expect("bwd send failed");
167                    }
168                }
169            }
170            grad_stores[stage].lock().unwrap().merge_from(&mut grads);
171        } else {
172            // Middle/first: receive grad from next, inject into local tape.
173            let grad_output = bwd_rx.as_ref().unwrap().recv().expect("bwd recv failed");
174            context::install_tape(saved_tapes[mb].take().unwrap());
175            let mut grads = backward_with_grad(&output, grad_output).expect("bwd_with_grad failed");
176
177            if !is_first {
178                if let Some(gid) = saved_input_gids[mb] {
179                    if let Some(gi) = grads.remove(gid) {
180                        bwd_tx.as_ref().unwrap().send(gi).expect("bwd send failed");
181                    }
182                }
183            }
184            grad_stores[stage].lock().unwrap().merge_from(&mut grads);
185        }
186    }
187}