use std::fmt;
/// Errors raised while building tensor-parallel configs and layers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TpError {
    /// A tensor dimension `dim` cannot be split evenly across `world_size` ranks.
    WorldSizeNotDivisible { dim: usize, world_size: usize },
    /// `rank` is not in `0..world_size`.
    RankOutOfRange { rank: usize, world_size: usize },
    /// Generic shape mismatch; not constructed anywhere in the code visible here.
    DimensionMismatch,
}
impl fmt::Display for TpError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::WorldSizeNotDivisible { dim, world_size } => write!(
f,
"dimension {dim} is not divisible by world_size {world_size}"
),
Self::RankOutOfRange { rank, world_size } => write!(
f,
"rank {rank} is out of range for world_size {world_size}"
),
Self::DimensionMismatch => write!(f, "tensor dimension mismatch"),
}
}
}
impl std::error::Error for TpError {}
/// Describes one rank's position in a tensor-parallel group.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorParallelConfig {
    /// Total number of ranks each tensor is sharded across.
    pub world_size: usize,
    /// This process's rank; valid configs keep `rank < world_size`.
    pub rank: usize,
    // Presumably toggles scatter/gather communication optimization; not read
    // anywhere in this file — TODO confirm with callers.
    pub scatter_gather: bool,
    // Presumably enables sequence parallelism; not read in this file — TODO confirm.
    pub sequence_parallel: bool,
}
impl Default for TensorParallelConfig {
fn default() -> Self {
Self {
world_size: 1,
rank: 0,
scatter_gather: false,
sequence_parallel: false,
}
}
}
impl TensorParallelConfig {
    /// Creates a validated config with the optional features disabled.
    ///
    /// # Errors
    /// Returns `TpError::RankOutOfRange` when `rank >= world_size` (this also
    /// rejects `world_size == 0`, since no rank can satisfy `rank < 0`).
    pub fn new(world_size: usize, rank: usize) -> Result<Self, TpError> {
        if rank >= world_size {
            return Err(TpError::RankOutOfRange { rank, world_size });
        }
        Ok(Self {
            world_size,
            rank,
            ..Self::default()
        })
    }

    /// Checks that `dim` can be split evenly over `self.world_size` ranks.
    ///
    /// Fix: guard `world_size == 0` explicitly. The fields of this struct are
    /// public, so a config with `world_size == 0` can be built via a struct
    /// literal (bypassing `new`); previously `dim % self.world_size` then
    /// panicked with a divide-by-zero instead of returning an error.
    fn check_divisible(&self, dim: usize) -> Result<(), TpError> {
        if self.world_size == 0 || dim % self.world_size != 0 {
            Err(TpError::WorldSizeNotDivisible { dim, world_size: self.world_size })
        } else {
            Ok(())
        }
    }
}
/// Linear layer sharded along the output (column) dimension: each rank owns
/// `local_out = out_features / world_size` output columns.
pub struct ColumnParallelLinear {
    /// Row-major `(in_features, local_out)` weight shard for this rank.
    pub weight: Vec<f32>,
    /// Full input width (not sharded).
    pub in_features: usize,
    /// Full, unsharded output width across all ranks.
    pub out_features: usize,
    /// Number of output columns owned by this rank.
    pub local_out: usize,
    /// Rank/world-size information used to size the shard.
    pub config: TensorParallelConfig,
}
impl ColumnParallelLinear {
    /// Builds a column-parallel layer holding this rank's zero-initialized
    /// shard of the weight matrix (`out_features / world_size` columns).
    ///
    /// # Errors
    /// `RankOutOfRange` if the config's rank is invalid (re-checked here
    /// defensively, since config fields are public and mutable), or
    /// `WorldSizeNotDivisible` if `out_features` does not shard evenly.
    pub fn new(
        in_features: usize,
        out_features: usize,
        config: TensorParallelConfig,
    ) -> Result<Self, TpError> {
        if config.rank >= config.world_size {
            return Err(TpError::RankOutOfRange {
                rank: config.rank,
                world_size: config.world_size,
            });
        }
        config.check_divisible(out_features)?;
        let local_out = out_features / config.world_size;
        Ok(Self {
            weight: vec![0.0f32; in_features * local_out],
            in_features,
            out_features,
            local_out,
            config,
        })
    }

    /// Computes `x @ W_local` for this rank's shard.
    ///
    /// `x` is row-major `(batch_size, in_features)`; the result is row-major
    /// `(batch_size, local_out)`.
    pub fn forward(&self, x: &[f32], batch_size: usize) -> Vec<f32> {
        let mut out = vec![0.0f32; batch_size * self.local_out];
        for b in 0..batch_size {
            for j in 0..self.local_out {
                // Dot product of input row `b` with local weight column `j`,
                // accumulated in ascending-k order (bitwise-stable for f32).
                out[b * self.local_out + j] = (0..self.in_features)
                    .map(|k| {
                        x[b * self.in_features + k] * self.weight[k * self.local_out + j]
                    })
                    .sum();
            }
        }
        out
    }

    /// Concatenates per-rank output shards end to end (simulated all-gather).
    /// NOTE(review): a flat concatenation only yields the full-feature layout
    /// when `batch_size == 1`; use `all_gather_output_batched` otherwise.
    pub fn all_gather_output(local_outputs: &[Vec<f32>], _world_size: usize) -> Vec<f32> {
        local_outputs
            .iter()
            .flat_map(|shard| shard.iter().copied())
            .collect()
    }

    /// Interleaves per-rank shards (each row-major `(batch_size, local_out)`)
    /// into a single row-major `(batch_size, world_size * local_out)` matrix,
    /// so rank `r`'s columns land at offset `r * local_out` of each row.
    pub fn all_gather_output_batched(
        local_outputs: &[Vec<f32>],
        batch_size: usize,
        local_out: usize,
    ) -> Vec<f32> {
        let world_size = local_outputs.len();
        let out_features = world_size * local_out;
        let mut result = vec![0.0f32; batch_size * out_features];
        for (rank, shard) in local_outputs.iter().enumerate() {
            for b in 0..batch_size {
                let dst = b * out_features + rank * local_out;
                let src = b * local_out;
                result[dst..dst + local_out]
                    .copy_from_slice(&shard[src..src + local_out]);
            }
        }
        result
    }
}
/// Linear layer sharded along the input (row) dimension: each rank owns
/// `local_in = in_features / world_size` input rows; per-rank outputs must be
/// summed (all-reduced) to recover the full result.
pub struct RowParallelLinear {
    /// Row-major `(local_in, out_features)` weight shard for this rank.
    pub weight: Vec<f32>,
    /// Full, unsharded input width across all ranks.
    pub in_features: usize,
    /// Number of input rows owned by this rank.
    pub local_in: usize,
    /// Full output width (not sharded).
    pub out_features: usize,
    /// Rank/world-size information used to size the shard.
    pub config: TensorParallelConfig,
}
impl RowParallelLinear {
    /// Builds a row-parallel layer holding this rank's zero-initialized
    /// shard of the weight matrix (`in_features / world_size` rows).
    ///
    /// # Errors
    /// `RankOutOfRange` if the config's rank is invalid (re-checked here
    /// defensively, since config fields are public and mutable), or
    /// `WorldSizeNotDivisible` if `in_features` does not shard evenly.
    pub fn new(
        in_features: usize,
        out_features: usize,
        config: TensorParallelConfig,
    ) -> Result<Self, TpError> {
        if config.rank >= config.world_size {
            return Err(TpError::RankOutOfRange {
                rank: config.rank,
                world_size: config.world_size,
            });
        }
        config.check_divisible(in_features)?;
        let local_in = in_features / config.world_size;
        Ok(Self {
            weight: vec![0.0f32; local_in * out_features],
            in_features,
            local_in,
            out_features,
            config,
        })
    }

    /// Computes this rank's partial product: the slice of `x` owned by this
    /// rank (columns `rank*local_in .. (rank+1)*local_in` of each row) times
    /// the local weight shard. The partials across ranks must be summed via
    /// `all_reduce_output` to obtain the full `(batch_size, out_features)` result.
    pub fn forward(&self, x: &[f32], batch_size: usize) -> Vec<f32> {
        let offset = self.config.rank * self.local_in;
        let mut out = vec![0.0f32; batch_size * self.out_features];
        for b in 0..batch_size {
            for j in 0..self.out_features {
                // Ascending-k accumulation matches a plain loop bitwise for f32.
                out[b * self.out_features + j] = (0..self.local_in)
                    .map(|k| {
                        x[b * self.in_features + offset + k]
                            * self.weight[k * self.out_features + j]
                    })
                    .sum();
            }
        }
        out
    }

    /// Element-wise sum of per-rank partial outputs (simulated all-reduce).
    /// The result length follows the first partial; returns empty for no input.
    pub fn all_reduce_output(partial_outputs: &[Vec<f32>]) -> Vec<f32> {
        let len = match partial_outputs.first() {
            Some(first) => first.len(),
            None => return Vec::new(),
        };
        let mut reduced = vec![0.0f32; len];
        for partial in partial_outputs {
            for (dst, &src) in reduced.iter_mut().zip(partial.iter()) {
                *dst += src;
            }
        }
        reduced
    }
}
/// Embedding table sharded along the vocabulary dimension: each rank stores
/// rows for tokens in `[vocab_start_idx, vocab_end_idx)` and contributes
/// zeros for tokens outside its shard (summed away by an all-reduce).
pub struct VocabParallelEmbedding {
    /// Row-major `(local_vocab_size, hidden_size)` embedding shard.
    pub embedding: Vec<f32>,
    /// Number of vocabulary rows owned by this rank.
    pub local_vocab_size: usize,
    /// Embedding dimension (not sharded).
    pub hidden_size: usize,
    /// First global token id owned by this rank (inclusive).
    pub vocab_start_idx: usize,
    /// One past the last global token id owned by this rank (exclusive).
    pub vocab_end_idx: usize,
    /// Rank/world-size information used to size the shard.
    pub config: TensorParallelConfig,
}
impl VocabParallelEmbedding {
    /// Builds a vocab-parallel embedding holding this rank's zero-initialized
    /// contiguous slice of `vocab_size / world_size` vocabulary rows.
    ///
    /// # Errors
    /// `RankOutOfRange` if the config's rank is invalid (re-checked here
    /// defensively, since config fields are public and mutable), or
    /// `WorldSizeNotDivisible` if `vocab_size` does not shard evenly.
    pub fn new(
        vocab_size: usize,
        hidden_size: usize,
        config: TensorParallelConfig,
    ) -> Result<Self, TpError> {
        if config.rank >= config.world_size {
            return Err(TpError::RankOutOfRange {
                rank: config.rank,
                world_size: config.world_size,
            });
        }
        config.check_divisible(vocab_size)?;
        let local_vocab_size = vocab_size / config.world_size;
        let vocab_start_idx = config.rank * local_vocab_size;
        Ok(Self {
            embedding: vec![0.0f32; local_vocab_size * hidden_size],
            local_vocab_size,
            hidden_size,
            vocab_start_idx,
            vocab_end_idx: vocab_start_idx + local_vocab_size,
            config,
        })
    }

    /// Looks up each token in this rank's shard. Tokens outside
    /// `[vocab_start_idx, vocab_end_idx)` yield zero rows, to be filled in by
    /// other ranks via `all_reduce_embeddings`. Output is row-major
    /// `(token_ids.len(), hidden_size)`.
    pub fn forward(&self, token_ids: &[u32]) -> Vec<f32> {
        let mut out = vec![0.0f32; token_ids.len() * self.hidden_size];
        for (i, &tok) in token_ids.iter().enumerate() {
            let global = tok as usize;
            // Leave the zero row in place for tokens owned by other ranks.
            if !(self.vocab_start_idx..self.vocab_end_idx).contains(&global) {
                continue;
            }
            let src = (global - self.vocab_start_idx) * self.hidden_size;
            let dst = i * self.hidden_size;
            out[dst..dst + self.hidden_size]
                .copy_from_slice(&self.embedding[src..src + self.hidden_size]);
        }
        out
    }

    /// Element-wise sum of per-rank partial embeddings (simulated all-reduce).
    /// The result length follows the first partial; returns empty for no input.
    pub fn all_reduce_embeddings(partial: &[Vec<f32>]) -> Vec<f32> {
        let len = match partial.first() {
            Some(first) => first.len(),
            None => return Vec::new(),
        };
        let mut reduced = vec![0.0f32; len];
        for shard in partial {
            for (dst, &src) in reduced.iter_mut().zip(shard.iter()) {
                *dst += src;
            }
        }
        reduced
    }
}
/// Either flavor of tensor-parallel linear layer, for uniform dispatch.
pub enum TensorParallelLinear {
    /// Output-dimension sharding; produces a `local_out`-wide shard.
    Column(ColumnParallelLinear),
    /// Input-dimension sharding; produces a full-width partial sum.
    Row(RowParallelLinear),
}
impl TensorParallelLinear {
pub fn forward(&self, x: &[f32], batch_size: usize) -> Vec<f32> {
match self {
Self::Column(layer) => layer.forward(x, batch_size),
Self::Row(layer) => layer.forward(x, batch_size),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-major `rows x cols` matrix with 1.0 on the main diagonal.
    fn identity_weight(rows: usize, cols: usize) -> Vec<f32> {
        let mut w = vec![0.0f32; rows * cols];
        for d in 0..rows.min(cols) {
            w[d * cols + d] = 1.0;
        }
        w
    }

    #[test]
    fn test_column_parallel_output_shape() {
        let config = TensorParallelConfig::new(4, 0).expect("valid config");
        let layer = ColumnParallelLinear::new(8, 16, config).expect("valid layer");
        let input = vec![1.0f32; 2 * 8];
        let out = layer.forward(&input, 2);
        // 16 outputs over 4 ranks -> 4 per rank, times batch of 2.
        assert_eq!(out.len(), 2 * 4);
    }

    #[test]
    fn test_column_parallel_correct_output() {
        let config = TensorParallelConfig::new(2, 0).expect("valid config");
        let mut layer = ColumnParallelLinear::new(4, 4, config).expect("valid layer");
        layer.weight = identity_weight(4, 2);
        let input = vec![1.0f32, 2.0, 3.0, 4.0];
        let out = layer.forward(&input, 1);
        assert!((out[0] - 1.0).abs() < 1e-6);
        assert!((out[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_row_parallel_correct_output() {
        let config = TensorParallelConfig::new(2, 0).expect("valid config");
        let mut layer = RowParallelLinear::new(4, 4, config).expect("valid layer");
        layer.weight = identity_weight(2, 4);
        let input = vec![1.0f32, 2.0, 3.0, 4.0];
        let partial = layer.forward(&input, 1);
        assert!((partial[0] - 1.0).abs() < 1e-6);
        assert!((partial[1] - 2.0).abs() < 1e-6);
        assert!((partial[2]).abs() < 1e-6);
        assert!((partial[3]).abs() < 1e-6);
    }

    #[test]
    fn test_vocab_parallel_in_range() {
        // Rank 1 of 4 over vocab 8 owns global tokens 2..4.
        let config = TensorParallelConfig::new(4, 1).expect("valid config");
        let mut emb = VocabParallelEmbedding::new(8, 3, config).expect("valid embedding");
        emb.embedding[0] = 1.0;
        emb.embedding[1] = 2.0;
        emb.embedding[2] = 3.0;
        let out = emb.forward(&[2u32]);
        assert!((out[0] - 1.0).abs() < 1e-6);
        assert!((out[1] - 2.0).abs() < 1e-6);
        assert!((out[2] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_vocab_parallel_out_of_range() {
        // Rank 0 owns tokens 0..2; token 5 must produce an all-zero row.
        let config = TensorParallelConfig::new(4, 0).expect("valid config");
        let emb = VocabParallelEmbedding::new(8, 3, config).expect("valid embedding");
        let out = emb.forward(&[5u32]);
        assert!(out.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_all_gather_output_concatenation() {
        let shard0 = vec![1.0f32, 2.0, 3.0];
        let shard1 = vec![4.0f32, 5.0, 6.0];
        let gathered = ColumnParallelLinear::all_gather_output(&[shard0, shard1], 2);
        assert_eq!(gathered, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_all_gather_output_batched() {
        let shard0 = vec![1.0f32, 2.0, 3.0, 4.0];
        let shard1 = vec![5.0f32, 6.0, 7.0, 8.0];
        let gathered =
            ColumnParallelLinear::all_gather_output_batched(&[shard0, shard1], 2, 2);
        // Shards interleave per batch row, not back to back.
        assert_eq!(gathered, vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]);
    }

    #[test]
    fn test_all_reduce_output_sum() {
        let p0 = vec![1.0f32, 0.0, 0.0, 0.0];
        let p1 = vec![0.0f32, 2.0, 0.0, 0.0];
        let p2 = vec![0.0f32, 0.0, 3.0, 0.0];
        let reduced = RowParallelLinear::all_reduce_output(&[p0, p1, p2]);
        assert!((reduced[0] - 1.0).abs() < 1e-6);
        assert!((reduced[1] - 2.0).abs() < 1e-6);
        assert!((reduced[2] - 3.0).abs() < 1e-6);
        assert!((reduced[3]).abs() < 1e-6);
    }

    #[test]
    fn test_single_rank_column_linear() {
        // world_size 1 degenerates to an ordinary (unsharded) linear layer.
        let config = TensorParallelConfig::default();
        let mut layer = ColumnParallelLinear::new(3, 3, config).expect("valid layer");
        layer.weight = identity_weight(3, 3);
        let input = vec![1.0f32, 2.0, 3.0];
        let out = layer.forward(&input, 1);
        assert!((out[0] - 1.0).abs() < 1e-6);
        assert!((out[1] - 2.0).abs() < 1e-6);
        assert!((out[2] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_world_size_not_divisible_error() {
        let config = TensorParallelConfig::new(4, 0).expect("valid config");
        let result = ColumnParallelLinear::new(8, 7, config);
        assert!(matches!(
            result,
            Err(TpError::WorldSizeNotDivisible { dim: 7, world_size: 4 })
        ));
    }

    #[test]
    fn test_rank_out_of_range_error() {
        let result = TensorParallelConfig::new(4, 4);
        assert!(matches!(
            result,
            Err(TpError::RankOutOfRange { rank: 4, world_size: 4 })
        ));
    }

    #[test]
    fn test_vocab_parallel_each_rank() {
        // Simulate all 4 ranks looking up every token, then all-reduce; the
        // result must equal the full (unsharded) embedding table.
        let hidden = 2usize;
        let mut all_results: Vec<Vec<f32>> = Vec::new();
        for rank in 0..4usize {
            let config = TensorParallelConfig::new(4, rank).expect("valid config");
            let mut emb =
                VocabParallelEmbedding::new(8, hidden, config).expect("valid embedding");
            for local_i in 0..2usize {
                let global_tok = rank * 2 + local_i;
                emb.embedding[local_i * hidden] = global_tok as f32 * 10.0;
                emb.embedding[local_i * hidden + 1] = global_tok as f32 * 10.0 + 1.0;
            }
            all_results.push(emb.forward(&[0u32, 1, 2, 3, 4, 5, 6, 7]));
        }
        let final_emb = VocabParallelEmbedding::all_reduce_embeddings(&all_results);
        assert!((final_emb[0] - 0.0).abs() < 1e-6);
        assert!((final_emb[1] - 1.0).abs() < 1e-6);
        assert!((final_emb[3 * hidden] - 30.0).abs() < 1e-6);
        assert!((final_emb[3 * hidden + 1] - 31.0).abs() < 1e-6);
        assert!((final_emb[7 * hidden] - 70.0).abs() < 1e-6);
        assert!((final_emb[7 * hidden + 1] - 71.0).abs() < 1e-6);
    }

    #[test]
    fn test_tensor_parallel_linear_dispatch() {
        let config = TensorParallelConfig::new(2, 0).expect("valid config");
        let layer = TensorParallelLinear::Column(
            ColumnParallelLinear::new(4, 4, config).expect("valid layer"),
        );
        let input = vec![0.0f32; 1 * 4];
        let out = layer.forward(&input, 1);
        // 4 outputs over 2 ranks -> 2 per rank.
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn test_attention_head_split() {
        let world_size = 4usize;
        let in_features = 64usize;
        let out_features = 128usize;
        let mut layers: Vec<ColumnParallelLinear> = Vec::new();
        for rank in 0..world_size {
            let config = TensorParallelConfig::new(world_size, rank).expect("valid config");
            layers.push(
                ColumnParallelLinear::new(in_features, out_features, config)
                    .expect("valid layer"),
            );
        }
        let input = vec![1.0f32; in_features];
        let mut shards: Vec<Vec<f32>> = Vec::new();
        for layer in &layers {
            shards.push(layer.forward(&input, 1));
        }
        assert!(shards.iter().all(|s| s.len() == 32));
        let gathered = ColumnParallelLinear::all_gather_output_batched(&shards, 1, 32);
        assert_eq!(gathered.len(), out_features);
    }

    #[test]
    fn test_row_parallel_multi_rank_simulation() {
        let world_size = 2usize;
        let in_features = 4usize;
        let out_features = 4usize;
        let input: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let mut partials: Vec<Vec<f32>> = Vec::new();
        for rank in 0..world_size {
            let config = TensorParallelConfig::new(world_size, rank).expect("valid config");
            let mut layer =
                RowParallelLinear::new(in_features, out_features, config).expect("valid layer");
            layer.weight = if rank == 0 {
                identity_weight(2, 4)
            } else {
                // Rank 1 maps its two input rows onto output columns 2 and 3.
                let mut w = vec![0.0f32; 2 * 4];
                w[0 * 4 + 2] = 1.0;
                w[1 * 4 + 3] = 1.0;
                w
            };
            partials.push(layer.forward(&input, 1));
        }
        let result = RowParallelLinear::all_reduce_output(&partials);
        assert!((result[0] - 1.0).abs() < 1e-6);
        assert!((result[1] - 2.0).abs() < 1e-6);
        assert!((result[2] - 3.0).abs() < 1e-6);
        assert!((result[3] - 4.0).abs() < 1e-6);
    }
}