rumus_distributed/tensor_parallel.rs
1// SPDX-License-Identifier: Apache-2.0 OR MIT
2//! Tensor Parallelism: ColumnParallelLinear and RowParallelLinear.
3
4use std::sync::Arc;
5
6use rumus::nn::Parameter;
7use rumus::tensor::Tensor;
8
9use crate::collective::CollectiveBarrier;
10
11// ---------------------------------------------------------------------------
12// ColumnParallelLinear
13// ---------------------------------------------------------------------------
14
/// Linear layer with weight sharded along columns (N dimension).
///
/// Each rank holds one column shard `W_t` of the full weight and produces the
/// matching column slice `Y_t` of the output.
///
/// Forward: `Y_t = X @ W_t` (no collective).
/// Backward: `grad_X = AllReduce(grad_Y_t @ W_t^T)` — handled by normal
/// autograd + external AllReduce call (see [`ColumnParallelLinear::allreduce_grad_x`]).
pub struct ColumnParallelLinear {
    /// This rank's column shard of the weight; used as the right-hand operand of
    /// the matmul in `forward`.
    pub weight: Parameter,
    /// Optional bias added to this rank's output shard in `forward`.
    /// NOTE(review): presumably sharded along N to match `weight` — confirm.
    pub bias: Option<Parameter>,
    /// This rank's index, presumably in `0..world_size`. Not read by `forward`
    /// or `allreduce_grad_x`.
    pub rank: usize,
    /// Number of ranks the weight is (presumably) sharded across. Not read by
    /// `forward` or `allreduce_grad_x`.
    pub world_size: usize,
    /// Collective whose `reduce` is called by `allreduce_grad_x` to combine
    /// host-staged gradients; presumably shared by every rank in the group.
    pub barrier: Arc<CollectiveBarrier>,
}
27
28impl ColumnParallelLinear {
29 /// Forward: Y_t = X @ W_t (+ bias). No collective in forward.
30 pub fn forward(&self, x: &Tensor) -> Tensor {
31 let y = x.matmul(&self.weight.tensor);
32 match &self.bias {
33 Some(b) => y.add_bias(&b.tensor),
34 None => y,
35 }
36 }
37
38 /// AllReduce grad_X after backward (called explicitly by the user/executor).
39 pub fn allreduce_grad_x(&self, grad_x: &Tensor) -> Tensor {
40 let data = {
41 let g = grad_x.data();
42 g.to_vec()
43 };
44 let reduced = self.barrier.reduce(data);
45 let t = Tensor::new(reduced, grad_x.shape().to_vec());
46 t.to_gpu();
47 t
48 }
49}
50
51// ---------------------------------------------------------------------------
52// RowParallelLinear
53// ---------------------------------------------------------------------------
54
/// Linear layer with weight sharded along rows (K dimension).
///
/// Each rank holds the row shard `W_t` that matches its input shard `X_t`, so
/// the local product is only a partial sum of the full output.
///
/// Forward: `Y_t = X_t @ W_t` (partial sum), then `Y = AllReduce(Y_t)`.
/// Backward: `grad_X_t = grad_Y @ W_t^T` (no collective).
pub struct RowParallelLinear {
    /// This rank's row shard of the weight; used as the right-hand operand of
    /// the matmul in `forward`.
    pub weight: Parameter,
    /// Optional bias; see [`RowParallelLinear::forward`] for where and on which
    /// ranks it is applied.
    pub bias: Option<Parameter>,
    /// This rank's index, presumably in `0..world_size`.
    pub rank: usize,
    /// Number of ranks the weight is (presumably) sharded across.
    pub world_size: usize,
    /// Collective whose `reduce` is called by `forward` to combine the
    /// host-staged partial outputs; presumably shared by every rank in the group.
    pub barrier: Arc<CollectiveBarrier>,
}
66
67impl RowParallelLinear {
68 /// Forward: Y_t = X_t @ W_t, then AllReduce → Y.
69 pub fn forward(&self, x_t: &Tensor) -> Tensor {
70 let y_partial = x_t.matmul(&self.weight.tensor);
71
72 // AllReduce the partial sums via CPU staging.
73 let data = {
74 let g = y_partial.data();
75 g.to_vec()
76 };
77 let reduced = self.barrier.reduce(data);
78 let y = Tensor::new(reduced, y_partial.shape().to_vec());
79 y.to_gpu();
80
81 match &self.bias {
82 Some(b) if self.rank == 0 => y.add_bias(&b.tensor),
83 _ => y,
84 }
85 }
86}