/// Static configuration describing how one transformer model is partitioned
/// across a tensor-parallel group of `tp_size` ranks.
#[derive(Debug, Clone)]
pub struct TensorParallelConfig {
    /// This rank's 0-based index within the tensor-parallel group.
    pub tp_rank: usize,
    /// Number of ranks the weights are split across.
    pub tp_size: usize,
    /// Full (unsharded) model hidden dimension.
    pub hidden_size: usize,
    /// Full (unsharded) MLP intermediate dimension.
    pub intermediate_size: usize,
    /// Total number of query heads across all ranks.
    pub num_heads: usize,
    /// Total number of key/value heads across all ranks.
    pub num_kv_heads: usize,
    /// Dimension of a single attention head (`hidden_size / num_heads`).
    pub head_dim: usize,
}

impl TensorParallelConfig {
    /// Build a validated tensor-parallel configuration.
    ///
    /// # Panics
    /// Panics when `tp_size` is zero or `tp_rank >= tp_size`, when
    /// `num_heads`, `num_kv_heads`, or `intermediate_size` is not divisible
    /// by `tp_size`, or when `hidden_size` is not divisible by `num_heads`
    /// (which would silently truncate `head_dim`).
    pub fn new(
        tp_rank: usize,
        tp_size: usize,
        hidden_size: usize,
        intermediate_size: usize,
        num_heads: usize,
        num_kv_heads: usize,
    ) -> Self {
        assert!(tp_size > 0, "tp_size must be positive");
        assert!(tp_rank < tp_size, "tp_rank ({tp_rank}) must be < tp_size ({tp_size})");
        assert!(
            num_heads.is_multiple_of(tp_size),
            "num_heads ({num_heads}) must be divisible by tp_size ({tp_size})"
        );
        assert!(
            num_kv_heads.is_multiple_of(tp_size),
            "num_kv_heads ({num_kv_heads}) must be divisible by tp_size ({tp_size})"
        );
        assert!(
            intermediate_size.is_multiple_of(tp_size),
            "intermediate_size ({intermediate_size}) must be divisible by tp_size ({tp_size})"
        );
        // Guard against silent truncation of head_dim; also rejects num_heads == 0
        // (0.is_multiple_of(0) holds only when hidden_size is 0 too).
        assert!(
            num_heads > 0 && hidden_size.is_multiple_of(num_heads),
            "hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        );
        let head_dim = hidden_size / num_heads;
        Self { tp_rank, tp_size, hidden_size, intermediate_size, num_heads, num_kv_heads, head_dim }
    }

    /// Query heads served by this rank.
    pub fn local_num_heads(&self) -> usize {
        self.num_heads / self.tp_size
    }

    /// Key/value heads served by this rank.
    pub fn local_num_kv_heads(&self) -> usize {
        self.num_kv_heads / self.tp_size
    }

    /// Width of this rank's query projection output.
    pub fn local_q_size(&self) -> usize {
        self.local_num_heads() * self.head_dim
    }

    /// Width of this rank's key (or value) projection output.
    pub fn local_kv_size(&self) -> usize {
        self.local_num_kv_heads() * self.head_dim
    }

    /// MLP intermediate width held by this rank.
    pub fn local_intermediate_size(&self) -> usize {
        self.intermediate_size / self.tp_size
    }

    /// Ideal fraction of total weight memory each rank holds (ignores any
    /// parameters that are replicated rather than sharded).
    pub fn weight_memory_fraction(&self) -> f64 {
        1.0 / self.tp_size as f64
    }
}
/// Shard of a column-parallel linear layer: the weight matrix
/// `[input_dim, full_output_dim]` is split along its output (column) axis,
/// with each rank owning the contiguous columns `[col_start, col_end)`.
/// Per-rank outputs are concatenated; no reduction is required.
#[derive(Debug, Clone)]
pub struct ColumnParallelShard {
    /// Replicated input dimension (rows of the weight matrix).
    pub input_dim: usize,
    /// Number of output columns owned by this rank.
    pub local_output_dim: usize,
    /// First column (inclusive) of this rank's slice.
    pub col_start: usize,
    /// One past the last column of this rank's slice.
    pub col_end: usize,
}

impl ColumnParallelShard {
    /// Describe rank `tp_rank`'s column slice of a `[input_dim, full_output_dim]` weight.
    ///
    /// # Panics
    /// Panics when `tp_size` is zero, `tp_rank >= tp_size`, or
    /// `full_output_dim` is not divisible by `tp_size` (which would silently
    /// drop the trailing columns).
    pub fn new(input_dim: usize, full_output_dim: usize, tp_rank: usize, tp_size: usize) -> Self {
        assert!(tp_size > 0, "tp_size must be positive");
        assert!(tp_rank < tp_size, "tp_rank ({tp_rank}) must be < tp_size ({tp_size})");
        assert!(
            full_output_dim.is_multiple_of(tp_size),
            "full_output_dim ({full_output_dim}) must be divisible by tp_size ({tp_size})"
        );
        let local_output_dim = full_output_dim / tp_size;
        let col_start = tp_rank * local_output_dim;
        let col_end = col_start + local_output_dim;
        Self { input_dim, local_output_dim, col_start, col_end }
    }

    /// Number of weight elements in this rank's shard.
    pub fn num_elements(&self) -> usize {
        self.input_dim * self.local_output_dim
    }

    /// Copy this rank's columns out of a row-major `[input_dim, full_output_dim]` matrix.
    ///
    /// # Panics
    /// Panics when `full_weights.len() != input_dim * full_output_dim`, or
    /// when the shard's column range does not fit inside `full_output_dim`.
    pub fn extract_shard(&self, full_weights: &[f32], full_output_dim: usize) -> Vec<f32> {
        assert_eq!(
            full_weights.len(),
            self.input_dim * full_output_dim,
            "full_weights length does not match input_dim * full_output_dim"
        );
        assert!(self.col_end <= full_output_dim, "shard columns exceed full_output_dim");
        let mut shard = Vec::with_capacity(self.num_elements());
        for row in 0..self.input_dim {
            let row_start = row * full_output_dim;
            shard.extend_from_slice(
                &full_weights[row_start + self.col_start..row_start + self.col_end],
            );
        }
        shard
    }
}
/// Shard of a row-parallel linear layer: the weight matrix
/// `[full_input_dim, output_dim]` is split along its input (row) axis, with
/// each rank owning rows `[row_start, row_end)`. Per-rank partial products
/// are combined with an all-reduce.
#[derive(Debug, Clone)]
pub struct RowParallelShard {
    /// Number of input rows owned by this rank.
    pub local_input_dim: usize,
    /// Replicated output dimension (columns of the weight matrix).
    pub output_dim: usize,
    /// First row (inclusive) of this rank's slice.
    pub row_start: usize,
    /// One past the last row of this rank's slice.
    pub row_end: usize,
}

impl RowParallelShard {
    /// Describe rank `tp_rank`'s row slice of a `[full_input_dim, output_dim]` weight.
    ///
    /// # Panics
    /// Panics when `tp_size` is zero, `tp_rank >= tp_size`, or
    /// `full_input_dim` is not divisible by `tp_size` (which would silently
    /// drop the trailing rows).
    pub fn new(full_input_dim: usize, output_dim: usize, tp_rank: usize, tp_size: usize) -> Self {
        assert!(tp_size > 0, "tp_size must be positive");
        assert!(tp_rank < tp_size, "tp_rank ({tp_rank}) must be < tp_size ({tp_size})");
        assert!(
            full_input_dim.is_multiple_of(tp_size),
            "full_input_dim ({full_input_dim}) must be divisible by tp_size ({tp_size})"
        );
        let local_input_dim = full_input_dim / tp_size;
        let row_start = tp_rank * local_input_dim;
        let row_end = row_start + local_input_dim;
        Self { local_input_dim, output_dim, row_start, row_end }
    }

    /// Number of weight elements in this rank's shard.
    pub fn num_elements(&self) -> usize {
        self.local_input_dim * self.output_dim
    }

    /// Copy this rank's rows out of a row-major `[full_input_dim, output_dim]`
    /// matrix. Rows are contiguous in row-major layout, so this is one slice copy.
    ///
    /// # Panics
    /// Panics when `full_weights.len() != full_input_dim * output_dim`, or
    /// when the shard's row range does not fit inside `full_input_dim`.
    pub fn extract_shard(&self, full_weights: &[f32], full_input_dim: usize) -> Vec<f32> {
        // The dimension argument was previously ignored; use it to validate.
        assert_eq!(
            full_weights.len(),
            full_input_dim * self.output_dim,
            "full_weights length does not match full_input_dim * output_dim"
        );
        assert!(self.row_end <= full_input_dim, "shard rows exceed full_input_dim");
        let start = self.row_start * self.output_dim;
        let end = self.row_end * self.output_dim;
        full_weights[start..end].to_vec()
    }
}
/// Back-of-the-envelope model of tensor-parallel all-reduce traffic.
#[derive(Debug, Clone)]
pub struct TpCommCost {
    /// Payload of a single all-reduce, in bytes.
    pub bytes_per_allreduce: usize,
    /// All-reduces issued per transformer block (one after attention, one after MLP).
    pub allreduces_per_block: usize,
    /// Transformer blocks traversed per forward step.
    pub num_blocks: usize,
}

impl TpCommCost {
    /// Estimate the traffic for f32 activations of shape `[seq_len, hidden_size]`,
    /// assuming the standard two all-reduces per transformer block.
    pub fn estimate(seq_len: usize, hidden_size: usize, num_blocks: usize) -> Self {
        let activation_bytes = seq_len * hidden_size * std::mem::size_of::<f32>();
        TpCommCost {
            bytes_per_allreduce: activation_bytes,
            allreduces_per_block: 2,
            num_blocks,
        }
    }

    /// Total bytes moved by all-reduces over one forward step.
    pub fn total_bytes_per_step(&self) -> usize {
        [self.bytes_per_allreduce, self.allreduces_per_block, self.num_blocks]
            .iter()
            .product()
    }

    /// Rough wall-clock overhead in milliseconds, given interconnect bandwidth
    /// in gigabits per second (hence the divide-by-8 bits-to-bytes conversion).
    pub fn estimated_overhead_ms(&self, bandwidth_gbps: f64) -> f64 {
        let payload_bytes = self.total_bytes_per_step() as f64;
        let bytes_per_ms = bandwidth_gbps * 1e9 / 8.0 / 1000.0;
        payload_bytes / bytes_per_ms
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tp_config_basic() {
        // 2-way TP over 16 query / 4 KV heads, head_dim = 1024 / 16 = 64.
        let cfg = TensorParallelConfig::new(0, 2, 1024, 4096, 16, 4);
        assert_eq!(cfg.local_num_heads(), 8);
        assert_eq!(cfg.local_num_kv_heads(), 2);
        assert_eq!(cfg.local_q_size(), 8 * 64);
        assert_eq!(cfg.local_kv_size(), 2 * 64);
        assert_eq!(cfg.local_intermediate_size(), 2048);
        assert!((cfg.weight_memory_fraction() - 0.5).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "num_heads")]
    fn test_tp_config_indivisible_heads() {
        // 16 heads cannot be split 3 ways.
        let _ = TensorParallelConfig::new(0, 3, 1024, 4096, 16, 4);
    }

    #[test]
    fn test_column_parallel_shard() {
        let lo = ColumnParallelShard::new(1024, 1024, 0, 2);
        let hi = ColumnParallelShard::new(1024, 1024, 1, 2);
        assert_eq!((lo.col_start, lo.col_end), (0, 512));
        assert_eq!(lo.local_output_dim, 512);
        assert_eq!(lo.num_elements(), 1024 * 512);
        assert_eq!((hi.col_start, hi.col_end), (512, 1024));
    }

    #[test]
    fn test_column_parallel_extract() {
        // 2x4 row-major matrix: each rank takes half the columns of every row.
        let full: Vec<f32> = (1..=8).map(|v| v as f32).collect();
        let cases = [
            (0usize, vec![1.0, 2.0, 5.0, 6.0]),
            (1, vec![3.0, 4.0, 7.0, 8.0]),
        ];
        for (rank, expected) in cases {
            let shard = ColumnParallelShard::new(2, 4, rank, 2);
            assert_eq!(shard.extract_shard(&full, 4), expected);
        }
    }

    #[test]
    fn test_row_parallel_shard() {
        let lo = RowParallelShard::new(1024, 1024, 0, 2);
        let hi = RowParallelShard::new(1024, 1024, 1, 2);
        assert_eq!((lo.row_start, lo.row_end), (0, 512));
        assert_eq!(lo.num_elements(), 512 * 1024);
        assert_eq!((hi.row_start, hi.row_end), (512, 1024));
    }

    #[test]
    fn test_row_parallel_extract() {
        // 4x2 row-major matrix: each rank takes a contiguous block of rows.
        let full: Vec<f32> = (1..=8).map(|v| v as f32).collect();
        let cases = [
            (0usize, vec![1.0, 2.0, 3.0, 4.0]),
            (1, vec![5.0, 6.0, 7.0, 8.0]),
        ];
        for (rank, expected) in cases {
            let shard = RowParallelShard::new(4, 2, rank, 2);
            assert_eq!(shard.extract_shard(&full, 4), expected);
        }
    }

    #[test]
    fn test_tp_comm_cost() {
        let cost = TpCommCost::estimate(1024, 1024, 24);
        assert_eq!(cost.bytes_per_allreduce, 1024 * 1024 * 4);
        assert_eq!(cost.allreduces_per_block, 2);
        assert_eq!(cost.total_bytes_per_step(), 4 * 1024 * 1024 * 2 * 24);
        let overhead = cost.estimated_overhead_ms(100.0);
        assert!(overhead > 0.0 && overhead < 100.0);
    }

    #[test]
    fn test_tp_config_4way() {
        // Non-zero rank in a 4-way group; GQA shrinks to one KV head per rank.
        let cfg = TensorParallelConfig::new(2, 4, 1024, 4096, 16, 4);
        assert_eq!(cfg.local_num_heads(), 4);
        assert_eq!(cfg.local_num_kv_heads(), 1);
        assert_eq!(cfg.local_q_size(), 4 * 64);
        assert_eq!(cfg.local_intermediate_size(), 1024);
    }
}