Skip to main content

rumus_distributed/
tensor_parallel.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2//! Tensor Parallelism: ColumnParallelLinear and RowParallelLinear.
3
4use std::sync::Arc;
5
6use rumus::nn::Parameter;
7use rumus::tensor::Tensor;
8
9use crate::collective::CollectiveBarrier;
10
11// ---------------------------------------------------------------------------
12// ColumnParallelLinear
13// ---------------------------------------------------------------------------
14
15/// Linear layer with weight sharded along columns (N dimension).
16///
17/// Forward: `Y_t = X @ W_t` (no collective).
18/// Backward: `grad_X = AllReduce(grad_Y_t @ W_t^T)` — handled by normal
19///           autograd + external AllReduce call.
20pub struct ColumnParallelLinear {
21    pub weight: Parameter,
22    pub bias: Option<Parameter>,
23    pub rank: usize,
24    pub world_size: usize,
25    pub barrier: Arc<CollectiveBarrier>,
26}
27
28impl ColumnParallelLinear {
29    /// Forward: Y_t = X @ W_t (+ bias).  No collective in forward.
30    pub fn forward(&self, x: &Tensor) -> Tensor {
31        let y = x.matmul(&self.weight.tensor);
32        match &self.bias {
33            Some(b) => y.add_bias(&b.tensor),
34            None => y,
35        }
36    }
37
38    /// AllReduce grad_X after backward (called explicitly by the user/executor).
39    pub fn allreduce_grad_x(&self, grad_x: &Tensor) -> Tensor {
40        let data = {
41            let g = grad_x.data();
42            g.to_vec()
43        };
44        let reduced = self.barrier.reduce(data);
45        let t = Tensor::new(reduced, grad_x.shape().to_vec());
46        t.to_gpu();
47        t
48    }
49}
50
51// ---------------------------------------------------------------------------
52// RowParallelLinear
53// ---------------------------------------------------------------------------
54
55/// Linear layer with weight sharded along rows (K dimension).
56///
57/// Forward: `Y_t = X_t @ W_t` (partial sum), then `Y = AllReduce(Y_t)`.
58/// Backward: `grad_X_t = grad_Y @ W_t^T` (no collective).
59pub struct RowParallelLinear {
60    pub weight: Parameter,
61    pub bias: Option<Parameter>,
62    pub rank: usize,
63    pub world_size: usize,
64    pub barrier: Arc<CollectiveBarrier>,
65}
66
67impl RowParallelLinear {
68    /// Forward: Y_t = X_t @ W_t, then AllReduce → Y.
69    pub fn forward(&self, x_t: &Tensor) -> Tensor {
70        let y_partial = x_t.matmul(&self.weight.tensor);
71
72        // AllReduce the partial sums via CPU staging.
73        let data = {
74            let g = y_partial.data();
75            g.to_vec()
76        };
77        let reduced = self.barrier.reduce(data);
78        let y = Tensor::new(reduced, y_partial.shape().to_vec());
79        y.to_gpu();
80
81        match &self.bias {
82            Some(b) if self.rank == 0 => y.add_bias(&b.tensor),
83            _ => y,
84        }
85    }
86}