// entrenar/train/transformer_trainer/tensor_parallel.rs

//! Tensor parallelism for transformer pretraining.
//!
//! Splits weight matrices across GPUs along the hidden/head dimension.
//! Each GPU holds a shard of the Q/K/V/O projections and the FFN gate/up/down projections.
//!
//! # Architecture (Megatron-LM style)
//!
//! ## Attention (column + row parallel):
//! ```text
//! Input X [S, H]
//!   ├── GPU 0: Q₀ = X × W_q[:, :H/2]  →  heads 0..N/2
//!   └── GPU 1: Q₁ = X × W_q[:, H/2:]  →  heads N/2..N
//!       (K and V are split the same way, by KV head)
//!                    ↓ attention ↓
//!   ├── GPU 0: O₀ = attn₀ × W_o[:H/2, :]
//!   └── GPU 1: O₁ = attn₁ × W_o[H/2:, :]
//!              AllReduce(O₀ + O₁)
//! ```
//!
//! ## FFN (column + row parallel):
//! ```text
//! Input X [S, H]
//!   ├── GPU 0: gate₀ = X × W_gate[:, :I/2], up₀ = X × W_up[:, :I/2]
//!   └── GPU 1: gate₁ = X × W_gate[:, I/2:], up₁ = X × W_up[:, I/2:]
//!          ↓ act = SiLU(gate) * up  (elementwise, on each GPU's own shard) ↓
//!   ├── GPU 0: down₀ = act₀ × W_down[:I/2, :]
//!   └── GPU 1: down₁ = act₁ × W_down[I/2:, :]
//!              AllReduce(down₀ + down₁)
//! ```
//!
//! # Communication
//!
//! 2 AllReduce calls per block in the forward pass (attention output + FFN output);
//! the backward pass needs another 2 for the input gradients.
//! Each AllReduce moves an f32 [S, H] tensor: seq_len × hidden_size × 4 bytes.
//! For S=1024, H=1024: 4 MiB per AllReduce, 8 MiB per block (forward pass).
//!
//! # Contract (C-TP-001)
//!
//! - Column parallel: output[i, j] = input[i, :] · weight[:, col_start + j]
//! - Row parallel: after AllReduce, output = Σ(partial_outputs) across shards
//! - Weight shards are equal-sized contiguous slices along the parallelism dimension,
//!   so that dimension must be divisible by tp_size

/// Tensor parallel configuration for a single GPU.
///
/// Describes how one rank of a tensor-parallel (TP) group shards the attention
/// heads and the FFN intermediate dimension of a transformer block.
#[derive(Debug, Clone)]
pub struct TensorParallelConfig {
    /// This GPU's rank within the TP group (in `0..tp_size`)
    pub tp_rank: usize,
    /// Total GPUs in the TP group
    pub tp_size: usize,
    /// Full hidden size (before sharding)
    pub hidden_size: usize,
    /// Full intermediate size (before sharding)
    pub intermediate_size: usize,
    /// Number of attention heads (must be divisible by tp_size)
    pub num_heads: usize,
    /// Number of KV heads (must be divisible by tp_size)
    pub num_kv_heads: usize,
    /// Head dimension, derived as `hidden_size / num_heads`
    pub head_dim: usize,
}

impl TensorParallelConfig {
    /// Create a new TP config.
    ///
    /// `head_dim` is derived as `hidden_size / num_heads`.
    ///
    /// # Panics
    /// Panics if
    /// - `tp_size` is zero or `tp_rank >= tp_size`,
    /// - `num_heads` is zero, or `hidden_size` is not divisible by `num_heads`
    ///   (the derived `head_dim` would otherwise be silently truncated),
    /// - `num_heads`, `num_kv_heads`, or `intermediate_size` is not divisible by `tp_size`.
    pub fn new(
        tp_rank: usize,
        tp_size: usize,
        hidden_size: usize,
        intermediate_size: usize,
        num_heads: usize,
        num_kv_heads: usize,
    ) -> Self {
        // Reject a degenerate group up front so the divisibility messages below
        // never report "divisible by tp_size (0)".
        assert!(tp_size > 0, "tp_size must be at least 1");
        assert!(
            tp_rank < tp_size,
            "tp_rank ({tp_rank}) must be less than tp_size ({tp_size})"
        );
        // `num_heads == 0` would pass every divisibility check and then divide by zero.
        assert!(num_heads > 0, "num_heads must be at least 1");
        assert!(
            num_heads.is_multiple_of(tp_size),
            "num_heads ({num_heads}) must be divisible by tp_size ({tp_size})"
        );
        assert!(
            num_kv_heads.is_multiple_of(tp_size),
            "num_kv_heads ({num_kv_heads}) must be divisible by tp_size ({tp_size})"
        );
        assert!(
            intermediate_size.is_multiple_of(tp_size),
            "intermediate_size ({intermediate_size}) must be divisible by tp_size ({tp_size})"
        );
        // Otherwise head_dim truncates and local_q_size() * tp_size != hidden_size.
        assert!(
            hidden_size.is_multiple_of(num_heads),
            "hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        );

        let head_dim = hidden_size / num_heads;

        Self { tp_rank, tp_size, hidden_size, intermediate_size, num_heads, num_kv_heads, head_dim }
    }

    /// Number of Q heads on this GPU.
    pub fn local_num_heads(&self) -> usize {
        self.num_heads / self.tp_size
    }

    /// Number of KV heads on this GPU.
    pub fn local_num_kv_heads(&self) -> usize {
        self.num_kv_heads / self.tp_size
    }

    /// Local Q projection size: local_heads × head_dim.
    pub fn local_q_size(&self) -> usize {
        self.local_num_heads() * self.head_dim
    }

    /// Local KV projection size: local_kv_heads × head_dim.
    pub fn local_kv_size(&self) -> usize {
        self.local_num_kv_heads() * self.head_dim
    }

    /// Local intermediate (FFN) size.
    pub fn local_intermediate_size(&self) -> usize {
        self.intermediate_size / self.tp_size
    }

    /// Fraction of the sharded (attention + FFN) weight memory held by this GPU.
    ///
    /// Returns `1 / tp_size`. Embeddings and norms are replicated on every GPU
    /// and are not accounted for here. The whole-model fraction per GPU is
    /// therefore somewhat higher than this value.
    pub fn weight_memory_fraction(&self) -> f64 {
        1.0 / self.tp_size as f64
    }
}

/// Weight shard specification for a column-parallel layer.
///
/// Column parallel: input is replicated, output is sharded.
/// Each GPU holds `weight[:, col_start..col_end]` of a row-major
/// `[input_dim, full_output_dim]` weight matrix.
#[derive(Debug, Clone)]
pub struct ColumnParallelShard {
    /// Input dimension (full, not sharded)
    pub input_dim: usize,
    /// Output dimension per GPU (sharded)
    pub local_output_dim: usize,
    /// Start column index in the full weight matrix
    pub col_start: usize,
    /// End column index (exclusive)
    pub col_end: usize,
}

impl ColumnParallelShard {
    /// Create a column-parallel shard for Q/K/V projection or FFN gate/up.
    ///
    /// # Panics
    /// Panics if `tp_size` is zero, `tp_rank >= tp_size`, or `full_output_dim`
    /// is not divisible by `tp_size`. In the last case the trailing columns
    /// would belong to no shard.
    pub fn new(input_dim: usize, full_output_dim: usize, tp_rank: usize, tp_size: usize) -> Self {
        assert!(tp_size > 0, "tp_size must be at least 1");
        assert!(
            tp_rank < tp_size,
            "tp_rank ({tp_rank}) must be less than tp_size ({tp_size})"
        );
        assert!(
            full_output_dim.is_multiple_of(tp_size),
            "full_output_dim ({full_output_dim}) must be divisible by tp_size ({tp_size})"
        );

        let local_output_dim = full_output_dim / tp_size;
        let col_start = tp_rank * local_output_dim;
        let col_end = col_start + local_output_dim;

        Self { input_dim, local_output_dim, col_start, col_end }
    }

    /// Number of elements in the local weight shard.
    pub fn num_elements(&self) -> usize {
        self.input_dim * self.local_output_dim
    }

    /// Extract this shard from a full weight matrix (row-major).
    ///
    /// Full weight shape: [input_dim, full_output_dim]
    /// Returns: [local_input_dim, local_output_dim] gathered row by row
    /// (i.e. [input_dim, local_output_dim]).
    ///
    /// # Panics
    /// Panics if `col_end > full_output_dim`, or if `full_weights.len()` is not
    /// `input_dim * full_output_dim`. A mismatched `full_output_dim` would
    /// otherwise silently read misaligned rows.
    pub fn extract_shard(&self, full_weights: &[f32], full_output_dim: usize) -> Vec<f32> {
        assert!(
            self.col_end <= full_output_dim,
            "shard columns {}..{} exceed full_output_dim ({full_output_dim})",
            self.col_start,
            self.col_end
        );
        assert_eq!(
            full_weights.len(),
            self.input_dim * full_output_dim,
            "full_weights length must equal input_dim ({}) × full_output_dim ({full_output_dim})",
            self.input_dim
        );

        let mut shard = Vec::with_capacity(self.num_elements());
        for row in 0..self.input_dim {
            let row_start = row * full_output_dim;
            shard.extend_from_slice(
                &full_weights[row_start + self.col_start..row_start + self.col_end],
            );
        }
        shard
    }
}

/// Weight shard specification for a row-parallel layer.
///
/// Row parallel: input is sharded, output is partial (needs AllReduce).
/// Each GPU holds `weight[row_start..row_end, :]` of a row-major
/// `[full_input_dim, output_dim]` weight matrix.
#[derive(Debug, Clone)]
pub struct RowParallelShard {
    /// Input dimension per GPU (sharded)
    pub local_input_dim: usize,
    /// Output dimension (full, not sharded)
    pub output_dim: usize,
    /// Start row index in the full weight matrix
    pub row_start: usize,
    /// End row index (exclusive)
    pub row_end: usize,
}

impl RowParallelShard {
    /// Create a row-parallel shard for O projection or FFN down.
    ///
    /// # Panics
    /// Panics if `tp_size` is zero, `tp_rank >= tp_size`, or `full_input_dim`
    /// is not divisible by `tp_size`. In the last case the trailing rows would
    /// belong to no shard, and their contribution would be missing from the
    /// AllReduced sum.
    pub fn new(full_input_dim: usize, output_dim: usize, tp_rank: usize, tp_size: usize) -> Self {
        assert!(tp_size > 0, "tp_size must be at least 1");
        assert!(
            tp_rank < tp_size,
            "tp_rank ({tp_rank}) must be less than tp_size ({tp_size})"
        );
        assert!(
            full_input_dim.is_multiple_of(tp_size),
            "full_input_dim ({full_input_dim}) must be divisible by tp_size ({tp_size})"
        );

        let local_input_dim = full_input_dim / tp_size;
        let row_start = tp_rank * local_input_dim;
        let row_end = row_start + local_input_dim;

        Self { local_input_dim, output_dim, row_start, row_end }
    }

    /// Number of elements in the local weight shard.
    pub fn num_elements(&self) -> usize {
        self.local_input_dim * self.output_dim
    }

    /// Extract this shard from a full weight matrix (row-major).
    ///
    /// Full weight shape: [full_input_dim, output_dim]
    /// Returns: [local_input_dim, output_dim]
    ///
    /// The shard consists of whole rows of a row-major matrix, so it is one
    /// contiguous slice of `full_weights`.
    ///
    /// # Panics
    /// Panics if `row_end > full_input_dim`, or if `full_weights.len()` is not
    /// `full_input_dim * output_dim` (for example, when the wrong dimension or
    /// layout was passed in).
    pub fn extract_shard(&self, full_weights: &[f32], full_input_dim: usize) -> Vec<f32> {
        assert!(
            self.row_end <= full_input_dim,
            "shard rows {}..{} exceed full_input_dim ({full_input_dim})",
            self.row_start,
            self.row_end
        );
        assert_eq!(
            full_weights.len(),
            full_input_dim * self.output_dim,
            "full_weights length must equal full_input_dim ({full_input_dim}) × output_dim ({})",
            self.output_dim
        );

        let start = self.row_start * self.output_dim;
        let end = self.row_end * self.output_dim;
        full_weights[start..end].to_vec()
    }
}

/// Communication cost estimate for tensor parallelism.
///
/// Only the forward-pass AllReduce calls are counted: one after attention and
/// one after the FFN in every block. The backward pass of a Megatron-style
/// block issues the same number again, so a full training step moves roughly
/// twice what `total_bytes_per_step` reports.
#[derive(Debug, Clone)]
pub struct TpCommCost {
    /// Bytes per AllReduce call (one f32 `[seq_len, hidden_size]` tensor)
    pub bytes_per_allreduce: usize,
    /// Number of AllReduce calls per block (2 in the forward pass: attention + FFN)
    pub allreduces_per_block: usize,
    /// Total blocks
    pub num_blocks: usize,
}

impl TpCommCost {
    /// Estimate TP communication cost for an activation of
    /// `seq_len × hidden_size` elements exchanged as `f32` (4 bytes each).
    pub fn estimate(seq_len: usize, hidden_size: usize, num_blocks: usize) -> Self {
        Self {
            bytes_per_allreduce: seq_len * hidden_size * std::mem::size_of::<f32>(),
            allreduces_per_block: 2,
            num_blocks,
        }
    }

    /// Total AllReduce payload bytes, summed over all blocks.
    ///
    /// Counts forward-pass AllReduce calls only; see the type-level docs.
    pub fn total_bytes_per_step(&self) -> usize {
        self.bytes_per_allreduce * self.allreduces_per_block * self.num_blocks
    }

    /// Estimated communication time in milliseconds for a link of `bandwidth_gbps`.
    ///
    /// `bandwidth_gbps` is in gigabits per second; it is divided by 8 to get bytes,
    /// and no default bandwidth is assumed. The estimate is payload ÷ bandwidth only.
    /// It ignores latency and collective-algorithm overhead; for example, a ring
    /// AllReduce sends about 2(p-1)/p × the payload per GPU. Treat the result as a
    /// rough estimate.
    pub fn estimated_overhead_ms(&self, bandwidth_gbps: f64) -> f64 {
        let total_bytes = self.total_bytes_per_step() as f64;
        // Gbit/s → bytes/s (÷ 8) → bytes/ms (÷ 1000).
        let bandwidth_bytes_per_ms = bandwidth_gbps * 1e9 / 8.0 / 1000.0;
        total_bytes / bandwidth_bytes_per_ms
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Naive row-major matmul: `[m, k] × [k, n] → [m, n]`.
    ///
    /// Used to check the C-TP-001 contract numerically. The inputs built by
    /// `small_ints` are small integers, so every f32 product and partial sum is
    /// exact and equality comparisons are safe regardless of summation order.
    fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
        let mut out = vec![0.0; m * n];
        for i in 0..m {
            for p in 0..k {
                let a_ip = a[i * k + p];
                for j in 0..n {
                    out[i * n + j] += a_ip * b[p * n + j];
                }
            }
        }
        out
    }

    /// Deterministic small-integer values `(v % modulus) - offset` for `v in 0..len`.
    fn small_ints(len: usize, modulus: usize, offset: f32) -> Vec<f32> {
        (0..len).map(|v| (v % modulus) as f32 - offset).collect()
    }

    #[test]
    fn test_tp_config_basic() {
        // 350M: H=1024, I=4096, 16 heads, 4 KV heads, TP=2
        let tp = TensorParallelConfig::new(0, 2, 1024, 4096, 16, 4);
        assert_eq!(tp.local_num_heads(), 8);
        assert_eq!(tp.local_num_kv_heads(), 2);
        assert_eq!(tp.local_q_size(), 8 * 64); // 512
        assert_eq!(tp.local_kv_size(), 2 * 64); // 128
        assert_eq!(tp.local_intermediate_size(), 2048);
        assert!((tp.weight_memory_fraction() - 0.5).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "num_heads")]
    fn test_tp_config_indivisible_heads() {
        TensorParallelConfig::new(0, 3, 1024, 4096, 16, 4); // 16 % 3 != 0
    }

    #[test]
    fn test_column_parallel_shard() {
        // Q projection: [1024, 1024] split across 2 GPUs
        let shard0 = ColumnParallelShard::new(1024, 1024, 0, 2);
        let shard1 = ColumnParallelShard::new(1024, 1024, 1, 2);

        assert_eq!(shard0.col_start, 0);
        assert_eq!(shard0.col_end, 512);
        assert_eq!(shard0.local_output_dim, 512);
        assert_eq!(shard0.num_elements(), 1024 * 512);

        assert_eq!(shard1.col_start, 512);
        assert_eq!(shard1.col_end, 1024);
    }

    #[test]
    fn test_column_parallel_extract() {
        // Small example: [2, 4] split into [2, 2] each
        let full = vec![
            1.0, 2.0, 3.0, 4.0, // row 0
            5.0, 6.0, 7.0, 8.0, // row 1
        ];
        let shard0 = ColumnParallelShard::new(2, 4, 0, 2);
        let shard1 = ColumnParallelShard::new(2, 4, 1, 2);

        let s0 = shard0.extract_shard(&full, 4);
        assert_eq!(s0, vec![1.0, 2.0, 5.0, 6.0]);

        let s1 = shard1.extract_shard(&full, 4);
        assert_eq!(s1, vec![3.0, 4.0, 7.0, 8.0]);
    }

    #[test]
    fn test_column_parallel_shards_reassemble_full_matrix() {
        // Interleaving each shard's rows, in rank order, must reproduce the full matrix.
        let (rows, cols, tp) = (3, 8, 4);
        let full: Vec<f32> = (0..rows * cols).map(|v| v as f32).collect();
        let shards: Vec<Vec<f32>> = (0..tp)
            .map(|rank| ColumnParallelShard::new(rows, cols, rank, tp).extract_shard(&full, cols))
            .collect();

        let local = cols / tp;
        let mut rebuilt = Vec::with_capacity(full.len());
        for row in 0..rows {
            for shard in &shards {
                rebuilt.extend_from_slice(&shard[row * local..(row + 1) * local]);
            }
        }
        assert_eq!(rebuilt, full);
    }

    #[test]
    fn test_column_parallel_output_matches_full_matmul_slice() {
        // C-TP-001: output[i, j] = input[i, :] · weight[:, col_start + j]
        let (m, k, n, tp) = (2, 3, 4, 2);
        let x = small_ints(m * k, 5, 2.0);
        let w = small_ints(k * n, 7, 3.0);
        let full_out = matmul(&x, &w, m, k, n);

        for rank in 0..tp {
            let shard = ColumnParallelShard::new(k, n, rank, tp);
            let local = shard.local_output_dim;
            let local_out = matmul(&x, &shard.extract_shard(&w, n), m, k, local);
            for i in 0..m {
                assert_eq!(
                    &local_out[i * local..(i + 1) * local],
                    &full_out[i * n + shard.col_start..i * n + shard.col_end],
                );
            }
        }
    }

    #[test]
    fn test_row_parallel_shard() {
        // O projection: [1024, 1024] split across 2 GPUs
        let shard0 = RowParallelShard::new(1024, 1024, 0, 2);
        let shard1 = RowParallelShard::new(1024, 1024, 1, 2);

        assert_eq!(shard0.row_start, 0);
        assert_eq!(shard0.row_end, 512);
        assert_eq!(shard0.num_elements(), 512 * 1024);

        assert_eq!(shard1.row_start, 512);
        assert_eq!(shard1.row_end, 1024);
    }

    #[test]
    fn test_row_parallel_extract() {
        // Small example: [4, 2] split into [2, 2] each
        let full = vec![
            1.0, 2.0, // row 0
            3.0, 4.0, // row 1
            5.0, 6.0, // row 2
            7.0, 8.0, // row 3
        ];
        let shard0 = RowParallelShard::new(4, 2, 0, 2);
        let shard1 = RowParallelShard::new(4, 2, 1, 2);

        let s0 = shard0.extract_shard(&full, 4);
        assert_eq!(s0, vec![1.0, 2.0, 3.0, 4.0]);

        let s1 = shard1.extract_shard(&full, 4);
        assert_eq!(s1, vec![5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_row_parallel_shards_reassemble_full_matrix() {
        // Row shards are contiguous, so concatenating them in rank order is the identity.
        let (rows, cols, tp) = (8, 3, 4);
        let full: Vec<f32> = (0..rows * cols).map(|v| v as f32).collect();
        let rebuilt: Vec<f32> = (0..tp)
            .flat_map(|rank| RowParallelShard::new(rows, cols, rank, tp).extract_shard(&full, rows))
            .collect();
        assert_eq!(rebuilt, full);
    }

    #[test]
    fn test_row_parallel_partial_sums_match_full_matmul() {
        // C-TP-001: after AllReduce, output = Σ(partial_outputs) across shards
        let (m, k, n, tp) = (2, 4, 3, 2);
        let x = small_ints(m * k, 5, 2.0);
        let w = small_ints(k * n, 7, 3.0);
        let expected = matmul(&x, &w, m, k, n);

        let mut reduced = vec![0.0f32; m * n];
        for rank in 0..tp {
            let shard = RowParallelShard::new(k, n, rank, tp);
            let w_local = shard.extract_shard(&w, k);
            // Each rank only sees its slice of the input's inner dimension.
            let mut x_local = Vec::with_capacity(m * shard.local_input_dim);
            for i in 0..m {
                x_local.extend_from_slice(&x[i * k + shard.row_start..i * k + shard.row_end]);
            }
            let partial = matmul(&x_local, &w_local, m, shard.local_input_dim, n);
            // Simulated AllReduce (sum).
            for (acc, p) in reduced.iter_mut().zip(&partial) {
                *acc += *p;
            }
        }
        assert_eq!(reduced, expected);
    }

    #[test]
    fn test_tp_comm_cost() {
        // 350M: S=1024, H=1024, L=24
        let cost = TpCommCost::estimate(1024, 1024, 24);
        assert_eq!(cost.bytes_per_allreduce, 1024 * 1024 * 4); // 4 MiB
        assert_eq!(cost.allreduces_per_block, 2);
        assert_eq!(cost.total_bytes_per_step(), 4 * 1024 * 1024 * 2 * 24); // 192 MiB

        // At 100 Gbit/s (12.5e6 bytes/ms): 201_326_592 / 12_500_000 ≈ 16.1 ms
        let overhead = cost.estimated_overhead_ms(100.0);
        assert!((overhead - 16.106_127_36).abs() < 1e-6);
    }

    #[test]
    fn test_tp_config_4way() {
        let tp = TensorParallelConfig::new(2, 4, 1024, 4096, 16, 4);
        assert_eq!(tp.local_num_heads(), 4);
        assert_eq!(tp.local_num_kv_heads(), 1);
        assert_eq!(tp.local_q_size(), 4 * 64);
        assert_eq!(tp.local_kv_size(), 64);
        assert_eq!(tp.local_intermediate_size(), 1024);
    }
}