// entrenar/train/transformer_trainer/tensor_parallel.rs
/// Tensor-parallel partitioning parameters for a single rank.
///
/// Stores the full (unsharded) model dimensions together with this rank's position in
/// the tensor-parallel group; the `local_*` accessors derive the per-rank sizes.
#[derive(Debug, Clone)]
pub struct TensorParallelConfig {
    /// Index of this rank within the tensor-parallel group (`0..tp_size`).
    pub tp_rank: usize,
    /// Number of ranks in the tensor-parallel group.
    pub tp_size: usize,
    /// Full (unsharded) hidden size.
    pub hidden_size: usize,
    /// Full (unsharded) MLP intermediate size.
    pub intermediate_size: usize,
    /// Total number of query heads across all ranks.
    pub num_heads: usize,
    /// Total number of key/value heads across all ranks.
    pub num_kv_heads: usize,
    /// Per-head dimension, computed as `hidden_size / num_heads` (integer division).
    pub head_dim: usize,
}

impl TensorParallelConfig {
    /// Builds the tensor-parallel view of a model for rank `tp_rank` of `tp_size`.
    ///
    /// `head_dim` is derived as `hidden_size / num_heads` using integer division.
    ///
    /// # Panics
    ///
    /// Panics if `tp_size` is zero, if `tp_rank >= tp_size`, if `num_heads` is zero, or
    /// if `num_heads`, `num_kv_heads`, or `intermediate_size` is not divisible by
    /// `tp_size`.
    pub fn new(
        tp_rank: usize,
        tp_size: usize,
        hidden_size: usize,
        intermediate_size: usize,
        num_heads: usize,
        num_kv_heads: usize,
    ) -> Self {
        assert!(tp_size > 0, "tp_size must be at least 1");
        // An out-of-range rank would otherwise yield shards past the end of every tensor.
        assert!(tp_rank < tp_size, "tp_rank ({tp_rank}) must be less than tp_size ({tp_size})");
        // Also guards the `hidden_size / num_heads` division below.
        assert!(num_heads > 0, "num_heads must be at least 1");
        assert!(
            num_heads.is_multiple_of(tp_size),
            "num_heads ({num_heads}) must be divisible by tp_size ({tp_size})"
        );
        assert!(
            num_kv_heads.is_multiple_of(tp_size),
            "num_kv_heads ({num_kv_heads}) must be divisible by tp_size ({tp_size})"
        );
        assert!(
            intermediate_size.is_multiple_of(tp_size),
            "intermediate_size ({intermediate_size}) must be divisible by tp_size ({tp_size})"
        );

        let head_dim = hidden_size / num_heads;

        Self { tp_rank, tp_size, hidden_size, intermediate_size, num_heads, num_kv_heads, head_dim }
    }

    /// Number of query heads handled by this rank.
    pub fn local_num_heads(&self) -> usize {
        self.num_heads / self.tp_size
    }

    /// Number of key/value heads handled by this rank.
    pub fn local_num_kv_heads(&self) -> usize {
        self.num_kv_heads / self.tp_size
    }

    /// Width of this rank's query projection: `local_num_heads() * head_dim`.
    pub fn local_q_size(&self) -> usize {
        self.local_num_heads() * self.head_dim
    }

    /// Width of this rank's key (or value) projection: `local_num_kv_heads() * head_dim`.
    pub fn local_kv_size(&self) -> usize {
        self.local_num_kv_heads() * self.head_dim
    }

    /// This rank's share of the MLP intermediate dimension.
    pub fn local_intermediate_size(&self) -> usize {
        self.intermediate_size / self.tp_size
    }

    /// Returns `1 / tp_size`, the share of evenly sharded weights stored on this rank.
    ///
    /// NOTE(review): assumes every weight is sharded; any replicated tensors would make
    /// the true per-rank fraction larger — confirm against the caller's usage.
    pub fn weight_memory_fraction(&self) -> f64 {
        1.0 / self.tp_size as f64
    }
}

/// Column-parallel shard of a row-major `[input_dim, full_output_dim]` weight matrix.
///
/// Each rank keeps every row but only the contiguous output columns
/// `col_start..col_end`.
#[derive(Debug, Clone)]
pub struct ColumnParallelShard {
    /// Number of rows (input features); the same on every rank.
    pub input_dim: usize,
    /// Number of output columns owned by this rank.
    pub local_output_dim: usize,
    /// First output column owned by this rank (inclusive).
    pub col_start: usize,
    /// One past the last output column owned by this rank (exclusive).
    pub col_end: usize,
}

impl ColumnParallelShard {
    /// Computes the output-column range owned by `tp_rank` in an even `tp_size`-way split.
    ///
    /// # Panics
    ///
    /// Panics if `tp_size` is zero, if `tp_rank >= tp_size`, or if `full_output_dim` is
    /// not divisible by `tp_size`. Without the divisibility check, the trailing columns
    /// would silently belong to no rank.
    pub fn new(input_dim: usize, full_output_dim: usize, tp_rank: usize, tp_size: usize) -> Self {
        assert!(tp_size > 0, "tp_size must be at least 1");
        assert!(tp_rank < tp_size, "tp_rank ({tp_rank}) must be less than tp_size ({tp_size})");
        assert!(
            full_output_dim.is_multiple_of(tp_size),
            "full_output_dim ({full_output_dim}) must be divisible by tp_size ({tp_size})"
        );

        let local_output_dim = full_output_dim / tp_size;
        let col_start = tp_rank * local_output_dim;
        let col_end = col_start + local_output_dim;

        Self { input_dim, local_output_dim, col_start, col_end }
    }

    /// Number of weight elements held by this shard (`input_dim * local_output_dim`).
    pub fn num_elements(&self) -> usize {
        self.input_dim * self.local_output_dim
    }

    /// Copies this rank's columns out of the full row-major weight matrix.
    ///
    /// `full_weights` is read as `input_dim` rows of `full_output_dim` elements. The
    /// result is row-major `[input_dim, local_output_dim]`.
    ///
    /// # Panics
    ///
    /// Panics if `full_output_dim < col_end`, because rows would otherwise be read
    /// misaligned. Also panics if `full_weights` is too short to hold the requested rows.
    pub fn extract_shard(&self, full_weights: &[f32], full_output_dim: usize) -> Vec<f32> {
        assert!(
            self.col_end <= full_output_dim,
            "shard columns {}..{} exceed full_output_dim ({full_output_dim})",
            self.col_start,
            self.col_end
        );

        let mut shard = Vec::with_capacity(self.num_elements());
        for row in 0..self.input_dim {
            let row_start = row * full_output_dim;
            shard.extend_from_slice(
                &full_weights[row_start + self.col_start..row_start + self.col_end],
            );
        }
        shard
    }
}

/// Row-parallel shard of a row-major `[full_input_dim, output_dim]` weight matrix.
///
/// Each rank keeps every column but only the contiguous input rows `row_start..row_end`,
/// so its data is one contiguous span of the flattened matrix.
#[derive(Debug, Clone)]
pub struct RowParallelShard {
    /// Number of input rows owned by this rank.
    pub local_input_dim: usize,
    /// Number of columns (output features); the same on every rank.
    pub output_dim: usize,
    /// First input row owned by this rank (inclusive).
    pub row_start: usize,
    /// One past the last input row owned by this rank (exclusive).
    pub row_end: usize,
}

impl RowParallelShard {
    /// Computes the input-row range owned by `tp_rank` in an even `tp_size`-way split.
    ///
    /// # Panics
    ///
    /// Panics if `tp_size` is zero, if `tp_rank >= tp_size`, or if `full_input_dim` is
    /// not divisible by `tp_size`. Without the divisibility check, the trailing rows
    /// would silently belong to no rank.
    pub fn new(full_input_dim: usize, output_dim: usize, tp_rank: usize, tp_size: usize) -> Self {
        assert!(tp_size > 0, "tp_size must be at least 1");
        assert!(tp_rank < tp_size, "tp_rank ({tp_rank}) must be less than tp_size ({tp_size})");
        assert!(
            full_input_dim.is_multiple_of(tp_size),
            "full_input_dim ({full_input_dim}) must be divisible by tp_size ({tp_size})"
        );

        let local_input_dim = full_input_dim / tp_size;
        let row_start = tp_rank * local_input_dim;
        let row_end = row_start + local_input_dim;

        Self { local_input_dim, output_dim, row_start, row_end }
    }

    /// Number of weight elements held by this shard (`local_input_dim * output_dim`).
    pub fn num_elements(&self) -> usize {
        self.local_input_dim * self.output_dim
    }

    /// Copies this rank's rows out of the full row-major weight matrix.
    ///
    /// The owned rows are contiguous, so this is a single slice copy. The result is
    /// row-major `[local_input_dim, output_dim]`. The `_full_input_dim` argument is not
    /// read.
    ///
    /// # Panics
    ///
    /// Panics if `full_weights` holds fewer than `row_end * output_dim` elements.
    pub fn extract_shard(&self, full_weights: &[f32], _full_input_dim: usize) -> Vec<f32> {
        let start = self.row_start * self.output_dim;
        let end = self.row_end * self.output_dim;
        full_weights[start..end].to_vec()
    }
}

/// Back-of-the-envelope estimate of tensor-parallel all-reduce traffic per step.
#[derive(Debug, Clone)]
pub struct TpCommCost {
    /// Size of one all-reduce payload, in bytes.
    pub bytes_per_allreduce: usize,
    /// Number of all-reduces issued per transformer block.
    pub allreduces_per_block: usize,
    /// Number of transformer blocks.
    pub num_blocks: usize,
}

impl TpCommCost {
    /// Estimates traffic for `num_blocks` blocks.
    ///
    /// Each block issues two all-reduces of a `seq_len * hidden_size` buffer of `f32`.
    ///
    /// NOTE(review): the fixed count of two presumably covers the attention and MLP
    /// outputs of the forward pass only, and batch size is not modelled — confirm.
    pub fn estimate(seq_len: usize, hidden_size: usize, num_blocks: usize) -> Self {
        let elements = seq_len * hidden_size;
        let bytes_per_allreduce = elements * std::mem::size_of::<f32>();
        Self { bytes_per_allreduce, allreduces_per_block: 2, num_blocks }
    }

    /// Total bytes all-reduced in one step, summed over every block.
    pub fn total_bytes_per_step(&self) -> usize {
        let bytes_per_block = self.bytes_per_allreduce * self.allreduces_per_block;
        bytes_per_block * self.num_blocks
    }

    /// Milliseconds needed to move `total_bytes_per_step()` at `bandwidth_gbps`.
    ///
    /// `bandwidth_gbps` is treated as gigabits per second, and is divided by 8 to get
    /// bytes. Latency is not modelled. A bandwidth of `0.0` yields an infinite result,
    /// or NaN when there are zero bytes to move.
    pub fn estimated_overhead_ms(&self, bandwidth_gbps: f64) -> f64 {
        let bytes_per_second = bandwidth_gbps * 1e9 / 8.0;
        let bytes_per_ms = bytes_per_second / 1000.0;
        self.total_bytes_per_step() as f64 / bytes_per_ms
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `hidden_size / num_heads` for the 1024-wide, 16-head configs used below.
    const HEAD_DIM: usize = 1024 / 16;

    #[test]
    fn test_tp_config_basic() {
        let cfg = TensorParallelConfig::new(0, 2, 1024, 4096, 16, 4);

        assert_eq!(cfg.local_num_heads(), 8);
        assert_eq!(cfg.local_num_kv_heads(), 2);
        assert_eq!(cfg.local_q_size(), 8 * HEAD_DIM);
        assert_eq!(cfg.local_kv_size(), 2 * HEAD_DIM);
        assert_eq!(cfg.local_intermediate_size(), 2048);

        let expected_fraction = 0.5;
        assert!((cfg.weight_memory_fraction() - expected_fraction).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "num_heads")]
    fn test_tp_config_indivisible_heads() {
        // 16 query heads cannot be split evenly across 3 ranks.
        let _ = TensorParallelConfig::new(0, 3, 1024, 4096, 16, 4);
    }

    #[test]
    fn test_column_parallel_shard() {
        let [lo, hi] = [0, 1].map(|rank| ColumnParallelShard::new(1024, 1024, rank, 2));

        assert_eq!((lo.col_start, lo.col_end), (0, 512));
        assert_eq!(lo.local_output_dim, 512);
        assert_eq!(lo.num_elements(), 1024 * 512);

        assert_eq!((hi.col_start, hi.col_end), (512, 1024));
    }

    #[test]
    fn test_column_parallel_extract() {
        // Row-major 2 x 4 matrix:
        //   [1, 2, 3, 4]
        //   [5, 6, 7, 8]
        let full: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let [lo, hi] = [0, 1].map(|rank| ColumnParallelShard::new(2, 4, rank, 2));

        assert_eq!(lo.extract_shard(&full, 4), vec![1.0, 2.0, 5.0, 6.0]);
        assert_eq!(hi.extract_shard(&full, 4), vec![3.0, 4.0, 7.0, 8.0]);
    }

    #[test]
    fn test_row_parallel_shard() {
        let [lo, hi] = [0, 1].map(|rank| RowParallelShard::new(1024, 1024, rank, 2));

        assert_eq!((lo.row_start, lo.row_end), (0, 512));
        assert_eq!(lo.num_elements(), 512 * 1024);

        assert_eq!((hi.row_start, hi.row_end), (512, 1024));
    }

    #[test]
    fn test_row_parallel_extract() {
        // Row-major 4 x 2 matrix with rows [1, 2], [3, 4], [5, 6], [7, 8].
        let full: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let [lo, hi] = [0, 1].map(|rank| RowParallelShard::new(4, 2, rank, 2));

        assert_eq!(lo.extract_shard(&full, 4), vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(hi.extract_shard(&full, 4), vec![5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_tp_comm_cost() {
        let cost = TpCommCost::estimate(1024, 1024, 24);
        // One f32 activation of 1024 tokens x 1024 hidden.
        let one_allreduce = 1024 * 1024 * 4;

        assert_eq!(cost.bytes_per_allreduce, one_allreduce);
        assert_eq!(cost.allreduces_per_block, 2);
        assert_eq!(cost.total_bytes_per_step(), one_allreduce * 2 * 24);

        let overhead_ms = cost.estimated_overhead_ms(100.0);
        assert!(overhead_ms > 0.0 && overhead_ms < 100.0);
    }

    #[test]
    fn test_tp_config_4way() {
        let cfg = TensorParallelConfig::new(2, 4, 1024, 4096, 16, 4);

        assert_eq!(cfg.local_num_heads(), 4);
        assert_eq!(cfg.local_num_kv_heads(), 1);
        assert_eq!(cfg.local_q_size(), 4 * HEAD_DIM);
        assert_eq!(cfg.local_intermediate_size(), 1024);
    }
}