// rumus_distributed/pipeline.rs
use std::sync::mpsc;
use std::sync::Mutex;

use rumus::autograd::{backward_with_grad, context, GradientStore, Tape};
use rumus::tensor::{GradId, Tensor};

9pub struct PipelineStage {
15 pub device_index: usize,
16 pub forward_fn: Box<dyn Fn(&Tensor) -> Tensor + Send + Sync>,
17}
18
19pub struct PipelineExecutor {
25 pub stages: Vec<PipelineStage>,
26 pub num_micro_batches: usize,
27}
28
29impl PipelineExecutor {
30 pub fn new(stages: Vec<PipelineStage>, num_micro_batches: usize) -> Self {
31 Self { stages, num_micro_batches }
32 }
33
34 pub fn run(
36 &self,
37 input: &Tensor,
38 loss_fn: &(dyn Fn(&Tensor) -> Tensor + Send + Sync),
39 ) -> Vec<GradientStore> {
40 let p = self.stages.len();
41 let m = self.num_micro_batches;
42 let batch = input.shape()[0];
43 assert!(batch % m == 0);
44 let micro_size = batch / m;
45 let micros: Vec<Tensor> = (0..m)
46 .map(|i| input.slice_range(0, i * micro_size, (i + 1) * micro_size))
47 .collect();
48
49 let mut fwd_tx_opts: Vec<Option<mpsc::SyncSender<Tensor>>> = Vec::new();
52 let mut fwd_rx_opts: Vec<Option<mpsc::Receiver<Tensor>>> = Vec::new();
53 let mut bwd_tx_opts: Vec<Option<mpsc::SyncSender<Tensor>>> = Vec::new();
54 let mut bwd_rx_opts: Vec<Option<mpsc::Receiver<Tensor>>> = Vec::new();
55
56 fwd_rx_opts.push(None);
58 bwd_tx_opts.push(None);
59
60 for _ in 0..p.saturating_sub(1) {
61 let (ftx, frx) = mpsc::sync_channel(m);
62 let (btx, brx) = mpsc::sync_channel(m);
63 fwd_tx_opts.push(None); fwd_rx_opts.push(Some(frx));
65 bwd_tx_opts.push(Some(btx));
66 bwd_rx_opts.push(None); let last_fwd_tx = fwd_tx_opts.len() - 1;
69 fwd_tx_opts[last_fwd_tx - 1] = Some(ftx);
70 let last_bwd_rx = bwd_rx_opts.len() - 1;
71 bwd_rx_opts[last_bwd_rx - 1] = Some(brx);
72 }
73 let grad_stores: Vec<std::sync::Mutex<GradientStore>> =
77 (0..p).map(|_| std::sync::Mutex::new(GradientStore::new())).collect();
78
79 std::thread::scope(|scope| {
80 let micros_ref = µs;
81 let stages_ref = &self.stages;
82 let gs_ref = &grad_stores;
83
84 let mut handles = Vec::with_capacity(p);
85 for s in 0..p {
86 let my_fwd_rx = fwd_rx_opts[s].take();
87 let my_fwd_tx = fwd_tx_opts[s].take();
88 let my_bwd_rx = bwd_rx_opts[s].take();
89 let my_bwd_tx = bwd_tx_opts[s].take();
90
91 handles.push(scope.spawn(move || {
92 run_stage(
93 s, p, m, stages_ref, micros_ref,
94 my_fwd_rx, my_fwd_tx, my_bwd_rx, my_bwd_tx,
95 gs_ref, loss_fn,
96 );
97 }));
98 }
99
100 for h in handles { h.join().expect("pipeline thread panicked"); }
101 });
102
103 grad_stores.into_iter().map(|m| m.into_inner().unwrap()).collect()
104 }
105}
106
107fn run_stage(
108 stage: usize,
109 num_stages: usize,
110 num_micro: usize,
111 stages: &[PipelineStage],
112 micros: &[Tensor],
113 fwd_rx: Option<mpsc::Receiver<Tensor>>,
114 fwd_tx: Option<mpsc::SyncSender<Tensor>>,
115 bwd_rx: Option<mpsc::Receiver<Tensor>>,
116 bwd_tx: Option<mpsc::SyncSender<Tensor>>,
117 grad_stores: &[std::sync::Mutex<GradientStore>],
118 loss_fn: &(dyn Fn(&Tensor) -> Tensor + Send + Sync),
119) {
120 let is_first = stage == 0;
121 let is_last = stage == num_stages - 1;
122
123 let mut saved_tapes: Vec<Option<Tape>> = (0..num_micro).map(|_| None).collect();
124 let mut saved_outputs: Vec<Option<Tensor>> = (0..num_micro).map(|_| None).collect();
125 let mut saved_input_gids: Vec<Option<GradId>> = (0..num_micro).map(|_| None).collect();
126
127 for mb in 0..num_micro {
129 let mut input_t = if is_first {
130 micros[mb].clone()
131 } else {
132 fwd_rx.as_ref().unwrap().recv().expect("fwd recv failed")
133 };
134
135 if !is_first {
137 input_t.set_requires_grad(true);
138 saved_input_gids[mb] = input_t.grad_id();
139 }
140
141 context::install_tape(Tape::new());
143 let output = (stages[stage].forward_fn)(&input_t);
144 saved_tapes[mb] = context::take_tape();
145 saved_outputs[mb] = Some(output.clone());
146
147 if !is_last {
148 fwd_tx.as_ref().unwrap().send(output).expect("fwd send failed");
149 }
150 }
151
152 for mb in (0..num_micro).rev() {
154 let output = saved_outputs[mb].take().unwrap();
155
156 if is_last {
157 context::install_tape(saved_tapes[mb].take().unwrap());
159 let loss = loss_fn(&output);
160 let mut grads = rumus::autograd::backward(&loss).expect("backward failed");
161
162 if !is_first {
164 if let Some(gid) = saved_input_gids[mb] {
165 if let Some(gi) = grads.remove(gid) {
166 bwd_tx.as_ref().unwrap().send(gi).expect("bwd send failed");
167 }
168 }
169 }
170 grad_stores[stage].lock().unwrap().merge_from(&mut grads);
171 } else {
172 let grad_output = bwd_rx.as_ref().unwrap().recv().expect("bwd recv failed");
174 context::install_tape(saved_tapes[mb].take().unwrap());
175 let mut grads = backward_with_grad(&output, grad_output).expect("bwd_with_grad failed");
176
177 if !is_first {
178 if let Some(gid) = saved_input_gids[mb] {
179 if let Some(gi) = grads.remove(gid) {
180 bwd_tx.as_ref().unwrap().send(gi).expect("bwd send failed");
181 }
182 }
183 }
184 grad_stores[stage].lock().unwrap().merge_from(&mut grads);
185 }
186 }
187}