use rayon::prelude::*;
/// Axis of a row-major weight matrix along which a shard is taken.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardDim {
    /// Shard owns a contiguous range of rows (output features).
    Output,
    /// Shard owns a contiguous range of columns (input features).
    Input,
}

/// Describes which slice of a dimension of length `total_size` one shard owns.
#[derive(Debug, Clone, PartialEq)]
pub struct ShardInfo {
    /// Zero-based index of this shard.
    pub shard_id: usize,
    /// Total number of shards the dimension is split into.
    pub num_shards: usize,
    /// Dimension being split.
    pub dim: ShardDim,
    /// First index (along `dim`) owned by this shard.
    pub offset: usize,
    /// Number of consecutive indices owned by this shard.
    pub size: usize,
}

impl ShardInfo {
    /// Computes the slice of `total_size` owned by shard `shard_id`.
    ///
    /// Every shard gets `total_size / num_shards` elements; the last shard
    /// (`shard_id + 1 == num_shards`) additionally receives the remainder.
    /// `shard_id` is expected to be `< num_shards`; this is not checked.
    ///
    /// # Panics
    ///
    /// Panics if `num_shards` is zero.
    pub fn new(shard_id: usize, num_shards: usize, total_size: usize, dim: ShardDim) -> Self {
        assert!(num_shards > 0, "num_shards must be > 0");
        let base = total_size / num_shards;
        let remainder = total_size % num_shards;
        let is_last = shard_id + 1 == num_shards;
        Self {
            shard_id,
            num_shards,
            dim,
            offset: shard_id * base,
            size: if is_last { base + remainder } else { base },
        }
    }

    /// Returns the part of a row-major `rows x cols` matrix that this shard can
    /// borrow as one contiguous slice.
    ///
    /// * `ShardDim::Output`: rows `offset..offset + size`, clamped to the
    ///   matrix (`rows * cols`) and to `weights.len()`. A shard lying entirely
    ///   outside the matrix yields an empty slice.
    /// * `ShardDim::Input`: the whole of `weights` is returned unchanged,
    ///   because a column range is not contiguous in row-major storage.
    pub fn slice_weights<'a>(&self, weights: &'a [f32], rows: usize, cols: usize) -> &'a [f32] {
        match self.dim {
            ShardDim::Output => {
                let limit = (rows * cols).min(weights.len());
                let end = ((self.offset + self.size) * cols).min(limit);
                // Clamp the start as well so the range can never be inverted.
                let start = (self.offset * cols).min(end);
                &weights[start..end]
            }
            ShardDim::Input => weights,
        }
    }

    /// Returns `true` when this is the final shard (`shard_id + 1 == num_shards`).
    #[inline]
    pub fn is_last_shard(&self) -> bool {
        self.shard_id + 1 == self.num_shards
    }

    /// Returns `true` when `idx` lies in `offset..offset + size`.
    #[inline]
    pub fn covers_index(&self, idx: usize) -> bool {
        (self.offset..self.offset + self.size).contains(&idx)
    }
}
/// Selects how `tensor_parallel_forward` combines the per-shard outputs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorParallelMode {
/// Per-shard outputs are concatenated in shard order (`all_gather`).
ColumnParallel,
/// Per-shard outputs are summed element-wise (`all_reduce`).
RowParallel,
}
/// One shard of a linear layer: its local weights, optional bias and the
/// metadata describing which slice of the full layer it represents.
#[derive(Debug, Clone, PartialEq)]
pub struct ShardedLinear {
    /// Row-major shard-local weights. `forward` reads them as `shard.size`
    /// rows of `in_features` values for `ShardDim::Output`, and as
    /// `out_features` rows of `shard.size` values for `ShardDim::Input`.
    pub weights: Vec<f32>,
    /// Optional bias, indexed by the output row that `forward` produces.
    pub bias: Option<Vec<f32>>,
    /// Slice of the full layer covered by this shard.
    pub shard: ShardInfo,
    /// Input width of the full (unsharded) layer.
    pub in_features: usize,
    /// Output width of the full (unsharded) layer.
    pub out_features: usize,
}
impl ShardedLinear {
    /// Builds a bias-free shard from already-sliced `weights`.
    pub fn new(
        weights: Vec<f32>,
        shard: ShardInfo,
        in_features: usize,
        out_features: usize,
    ) -> Self {
        Self {
            bias: None,
            weights,
            shard,
            in_features,
            out_features,
        }
    }

    /// Consumes the shard and returns it with `bias` attached.
    pub fn with_bias(self, bias: Vec<f32>) -> Self {
        Self {
            bias: Some(bias),
            ..self
        }
    }

    /// Computes this shard's contribution for `input`.
    ///
    /// * `ShardDim::Output`: returns `shard.size` values, one dot product per
    ///   local weight row over at most `in_features` leading input elements,
    ///   plus the bias entry of that row when a bias is present.
    /// * `ShardDim::Input`: returns `out_features` partial sums computed from
    ///   `input[shard.offset..shard.offset + shard.size]`; the bias is added
    ///   only when this is the last shard.
    ///
    /// # Panics
    ///
    /// Panics when an accessed weight, bias or input index is out of bounds.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        match self.shard.dim {
            ShardDim::Output => {
                let width = self.in_features;
                (0..self.shard.size)
                    .map(|row| {
                        let base = row * width;
                        let dot = input
                            .iter()
                            .take(width)
                            .enumerate()
                            .fold(0.0f32, |acc, (col, &x)| acc + self.weights[base + col] * x);
                        match &self.bias {
                            Some(b) => dot + b[row],
                            None => dot,
                        }
                    })
                    .collect()
            }
            ShardDim::Input => {
                let width = self.shard.size;
                let first_col = self.shard.offset;
                // Only the last shard contributes the bias, so the reduced sum
                // across shards counts it exactly once.
                let bias = self
                    .bias
                    .as_deref()
                    .filter(|_| self.shard.is_last_shard());
                (0..self.out_features)
                    .map(|row| {
                        let base = row * width;
                        let dot = (0..width).fold(0.0f32, |acc, col| {
                            acc + self.weights[base + col] * input[first_col + col]
                        });
                        match bias {
                            Some(b) => dot + b[row],
                            None => dot,
                        }
                    })
                    .collect()
            }
        }
    }

    /// Length of the vector returned by `forward` for this shard.
    pub fn shard_output_size(&self) -> usize {
        match self.shard.dim {
            ShardDim::Input => self.out_features,
            ShardDim::Output => self.shard.size,
        }
    }

    /// Bytes occupied by the weight and bias elements held by this shard.
    pub fn memory_bytes(&self) -> usize {
        const ELEM_BYTES: usize = std::mem::size_of::<f32>();
        let bias_bytes = match &self.bias {
            Some(b) => b.len() * ELEM_BYTES,
            None => 0,
        };
        self.weights.len() * ELEM_BYTES + bias_bytes
    }
}
/// Splits a row-major `out_features x in_features` weight matrix into
/// `num_shards` shards along the output (row) dimension.
///
/// Each shard receives its contiguous block of rows and, when `bias` is given,
/// the matching range of bias entries. Returns an empty vector when
/// `num_shards` is zero.
///
/// # Panics
///
/// Panics if `weights` or `bias` is too short for a shard's range.
pub fn partition_column_parallel(
    weights: &[f32],
    bias: Option<&[f32]>,
    in_features: usize,
    out_features: usize,
    num_shards: usize,
) -> Vec<ShardedLinear> {
    let build_shard = |shard_id: usize| {
        let info = ShardInfo::new(shard_id, num_shards, out_features, ShardDim::Output);
        let first_row = info.offset;
        let past_last_row = info.offset + info.size;
        let block = weights[first_row * in_features..past_last_row * in_features].to_vec();
        let bias_block = bias.map(|b| b[first_row..past_last_row].to_vec());
        let layer = ShardedLinear::new(block, info, in_features, out_features);
        match bias_block {
            Some(b) => layer.with_bias(b),
            None => layer,
        }
    };
    (0..num_shards).map(build_shard).collect()
}
/// Splits a row-major `out_features x in_features` weight matrix into
/// `num_shards` shards along the input (column) dimension.
///
/// Each shard stores, row by row, its own range of columns. The full bias is
/// attached to the last shard only. Returns an empty vector when `num_shards`
/// is zero.
///
/// # Panics
///
/// Panics if `weights` is too short for a shard's column range.
pub fn partition_row_parallel(
    weights: &[f32],
    bias: Option<&[f32]>,
    in_features: usize,
    out_features: usize,
    num_shards: usize,
) -> Vec<ShardedLinear> {
    (0..num_shards)
        .map(|shard_id| {
            let info = ShardInfo::new(shard_id, num_shards, in_features, ShardDim::Input);
            let (first_col, col_count) = (info.offset, info.size);

            // Gather this shard's column range out of every row.
            let mut columns = Vec::with_capacity(out_features * col_count);
            for row_base in (0..out_features).map(|row| row * in_features) {
                let from = row_base + first_col;
                columns.extend_from_slice(&weights[from..from + col_count]);
            }

            let owned_bias = bias
                .filter(|_| info.is_last_shard())
                .map(|b| b.to_vec());
            let layer = ShardedLinear::new(columns, info, in_features, out_features);
            match owned_bias {
                Some(b) => layer.with_bias(b),
                None => layer,
            }
        })
        .collect()
}
/// Sums the partial vectors element-wise.
///
/// The result has the length of the first partial (empty when `partials` is
/// empty). Each partial is added over the positions it shares with the
/// result; elements beyond the result's length are ignored.
pub fn all_reduce(partials: &[Vec<f32>]) -> Vec<f32> {
    let width = partials.first().map_or(0, |first| first.len());
    let mut sums = vec![0.0f32; width];
    for part in partials {
        sums.iter_mut()
            .zip(part)
            .for_each(|(total, &value)| *total += value);
    }
    sums
}
/// Concatenates the partial vectors in order into a single vector.
pub fn all_gather(partials: &[Vec<f32>]) -> Vec<f32> {
    partials.concat()
}
pub fn tensor_parallel_forward(
shards: &[ShardedLinear],
input: &[f32],
parallel_mode: TensorParallelMode,
) -> Vec<f32> {
let partials: Vec<Vec<f32>> = shards
.par_iter()
.map(|shard| shard.forward(input))
.collect();
match parallel_mode {
TensorParallelMode::ColumnParallel => all_gather(&partials),
TensorParallelMode::RowParallel => all_reduce(&partials),
}
}
/// Sharding assignment for a single named layer within a `ShardingPlan`.
#[derive(Debug, Clone)]
pub struct LayerSharding {
/// Name the layer is registered (and looked up) under.
pub layer_name: String,
/// Parallelism mode assigned to the layer.
pub mode: TensorParallelMode,
/// Shard count; `ShardingPlan::add_layer` copies it from the plan.
pub num_shards: usize,
}
/// Collection of per-layer sharding assignments that share one shard count.
#[derive(Debug, Clone)]
pub struct ShardingPlan {
    /// Shard count recorded on every layer added to the plan.
    pub num_shards: usize,
    /// Assignments in insertion order.
    pub layer_assignments: Vec<LayerSharding>,
}
impl ShardingPlan {
    /// Creates a plan with no layers for `num_shards` shards.
    pub fn new(num_shards: usize) -> Self {
        Self {
            layer_assignments: Vec::new(),
            num_shards,
        }
    }

    /// Appends an assignment for `name` using the plan's shard count.
    ///
    /// Existing entries with the same name are not replaced.
    pub fn add_layer(&mut self, name: &str, mode: TensorParallelMode) {
        let entry = LayerSharding {
            layer_name: name.to_owned(),
            mode,
            num_shards: self.num_shards,
        };
        self.layer_assignments.push(entry);
    }

    /// Builds a plan with seven entries per layer, named `blk.{layer}.{suffix}`.
    ///
    /// `attn_q`, `attn_k`, `attn_v`, `ffn_gate` and `ffn_up` are column
    /// parallel; `attn_output` and `ffn_down` are row parallel.
    pub fn standard_transformer_plan(num_shards: usize, num_layers: usize) -> Self {
        // Suffixes in the order they are registered for each layer.
        const PROJECTIONS: [(&str, TensorParallelMode); 7] = [
            ("attn_q", TensorParallelMode::ColumnParallel),
            ("attn_k", TensorParallelMode::ColumnParallel),
            ("attn_v", TensorParallelMode::ColumnParallel),
            ("attn_output", TensorParallelMode::RowParallel),
            ("ffn_gate", TensorParallelMode::ColumnParallel),
            ("ffn_up", TensorParallelMode::ColumnParallel),
            ("ffn_down", TensorParallelMode::RowParallel),
        ];

        let mut plan = Self::new(num_shards);
        for layer in 0..num_layers {
            let prefix = format!("blk.{layer}");
            for &(suffix, mode) in PROJECTIONS.iter() {
                plan.add_layer(&format!("{prefix}.{suffix}"), mode);
            }
        }
        plan
    }

    /// Returns the first assignment whose name equals `layer_name`, if any.
    pub fn get(&self, layer_name: &str) -> Option<&LayerSharding> {
        for entry in &self.layer_assignments {
            if entry.layer_name == layer_name {
                return Some(entry);
            }
        }
        None
    }

    /// Estimates the per-device weight memory in bytes for `f32` parameters.
    ///
    /// Counts `4 * hidden * hidden` attention parameters plus
    /// `3 * hidden * intermediate` feed-forward parameters per layer, divides
    /// the total by the shard count (treated as at least 1, rounding down) and
    /// multiplies by the size of an `f32`.
    pub fn total_weight_memory_estimate(
        &self,
        hidden: usize,
        intermediate: usize,
        num_layers: usize,
    ) -> usize {
        let attention = 4 * hidden * hidden;
        let feed_forward = 2 * intermediate * hidden + hidden * intermediate;
        let params_per_device = num_layers * (attention + feed_forward) / self.num_shards.max(1);
        params_per_device * std::mem::size_of::<f32>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts equal length and element-wise agreement within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < tol);
        }
    }

    #[test]
    fn test_shard_info_even_split() {
        let first = ShardInfo::new(0, 4, 16, ShardDim::Output);
        assert_eq!((first.offset, first.size), (0, 4));

        let last = ShardInfo::new(3, 4, 16, ShardDim::Output);
        assert_eq!((last.offset, last.size), (12, 4));
    }

    #[test]
    fn test_shard_info_uneven_split_last_gets_remainder() {
        let shards: Vec<ShardInfo> = (0..3)
            .map(|id| ShardInfo::new(id, 3, 10, ShardDim::Output))
            .collect();

        let sizes: Vec<usize> = shards.iter().map(|s| s.size).collect();
        assert_eq!(sizes, vec![3, 3, 4]);

        // Consecutive shards must tile the range without gaps or overlap.
        for pair in shards.windows(2) {
            assert_eq!(pair[0].offset + pair[0].size, pair[1].offset);
        }
        let tail = &shards[2];
        assert_eq!(tail.offset + tail.size, 10);
    }

    #[test]
    fn test_shard_info_covers_index() {
        let info = ShardInfo::new(1, 4, 16, ShardDim::Output);
        let cases = [(3usize, false), (4, true), (7, true), (8, false)];
        for &(idx, expected) in cases.iter() {
            assert_eq!(info.covers_index(idx), expected);
        }
    }

    #[test]
    fn test_partition_column_parallel_count() {
        let weights = vec![1.0f32; 8 * 4];
        let count = partition_column_parallel(&weights, None, 4, 8, 4).len();
        assert_eq!(count, 4);
    }

    #[test]
    fn test_partition_column_parallel_output_sizes() {
        let weights = vec![1.0f32; 8 * 4];
        for part in partition_column_parallel(&weights, None, 4, 8, 4).iter() {
            assert_eq!(part.weights.len(), 2 * 4);
            assert_eq!(part.shard_output_size(), 2);
        }
    }

    #[test]
    fn test_partition_row_parallel_count() {
        let weights = vec![1.0f32; 4 * 8];
        let count = partition_row_parallel(&weights, None, 8, 4, 4).len();
        assert_eq!(count, 4);
    }

    #[test]
    fn test_sharded_linear_forward_column() {
        // Two rows that select the first and second input element.
        let weights = vec![1.0f32, 0.0, 0.0, 0.0, 1.0, 0.0];
        let layer = ShardedLinear::new(weights, ShardInfo::new(0, 1, 2, ShardDim::Output), 3, 2);
        let out = layer.forward(&[5.0f32, 7.0, 9.0]);
        assert_close(&out, &[5.0, 7.0], 1e-6);
    }

    #[test]
    fn test_all_reduce_sums_correctly() {
        let partials = [
            vec![1.0f32, 2.0, 3.0],
            vec![4.0f32, 5.0, 6.0],
            vec![7.0f32, 8.0, 9.0],
        ];
        assert_eq!(all_reduce(&partials), vec![12.0f32, 15.0, 18.0]);
    }

    #[test]
    fn test_all_gather_concatenates() {
        let partials = [vec![1.0f32, 2.0], vec![3.0f32, 4.0], vec![5.0f32, 6.0]];
        assert_eq!(all_gather(&partials), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_tensor_parallel_forward_column() {
        let weights = vec![1.0f32, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
        let shards = partition_column_parallel(&weights, None, 2, 4, 2);
        let out = tensor_parallel_forward(
            &shards,
            &[3.0f32, 7.0],
            TensorParallelMode::ColumnParallel,
        );
        assert_close(&out, &[3.0, 7.0, 3.0, 7.0], 1e-6);
    }

    #[test]
    fn test_tensor_parallel_forward_row() {
        let weights = vec![1.0f32, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0];
        let shards = partition_row_parallel(&weights, None, 4, 2, 2);
        let input = vec![1.0f32, 2.0, 3.0, 4.0];
        let out = tensor_parallel_forward(&shards, &input, TensorParallelMode::RowParallel);
        assert_eq!(out.len(), 2);
        assert!((out[0] - 10.0).abs() < 1e-5, "out[0]={}", out[0]);
        assert!((out[1] - 20.0).abs() < 1e-5, "out[1]={}", out[1]);
    }

    #[test]
    fn test_sharding_plan_standard_transformer() {
        // Seven entries per layer, two layers.
        let plan = ShardingPlan::standard_transformer_plan(4, 2);
        assert_eq!(plan.layer_assignments.len(), 14);
    }

    #[test]
    fn test_sharding_plan_get_layer() {
        let plan = ShardingPlan::standard_transformer_plan(4, 3);

        let query = plan.get("blk.0.attn_q").expect("layer should exist");
        assert_eq!(query.mode, TensorParallelMode::ColumnParallel);
        assert_eq!(query.num_shards, 4);

        let down = plan.get("blk.2.ffn_down").expect("layer should exist");
        assert_eq!(down.mode, TensorParallelMode::RowParallel);

        assert!(plan.get("blk.99.ffn_up").is_none());
    }
}