#![allow(dead_code)] #![allow(deprecated)]
use super::model_parallel::{DistributedTensor, ModelParallelContext, TensorPartition};
use crate::errors::{invalid_input, Result};
use crate::Tensor;
use std::sync::Arc;
/// Tensor-parallel helper operations that work through a shared
/// [`ModelParallelContext`].
pub struct TensorParallelOps {
/// Shared model-parallel context; the methods on this type read the rank and
/// world size from it and delegate collective operations to it.
mp_context: Arc<ModelParallelContext>,
}
impl TensorParallelOps {
pub fn new(mp_context: Arc<ModelParallelContext>) -> Self {
Self { mp_context }
}
/// Returns a copy of `tensor` after running it through the communicator's
/// `broadcast` with `root_rank` as the root.
///
/// The caller's tensor is never modified: the broadcast operates on a clone.
///
/// # Errors
/// Propagates any error reported by the communicator.
pub fn broadcast_tensor(&self, tensor: &Tensor, root_rank: usize) -> Result<Tensor> {
    let communicator = &self.mp_context.communicator;
    let mut replica = tensor.clone();
    communicator.broadcast(&mut replica, root_rank)?;
    Ok(replica)
}
/// Partitions `tensor` along `scatter_dim` by delegating to the context's
/// `partition_tensor`; the result (or error) of that call is returned unchanged.
pub fn scatter_tensor(&self, tensor: &Tensor, scatter_dim: usize) -> Result<DistributedTensor> {
self.mp_context.partition_tensor(tensor, scatter_dim)
}
/// Reassembles a tensor from `distributed` by delegating to the context's
/// `all_gather`; the result (or error) of that call is returned unchanged.
pub fn gather_tensor(&self, distributed: &DistributedTensor) -> Result<Tensor> {
self.mp_context.all_gather(distributed)
}
/// All-to-all exchange: splits `tensor` into one chunk per rank along
/// `split_dim`, exchanges the chunks, and concatenates the received chunks
/// along `concat_dim`.
///
/// The first `split_size % world_size` chunks are one element longer than the
/// rest, so chunk lengths differ by at most one. The exchange itself goes
/// through `simulate_point_to_point_exchange`, so what is "received" is
/// whatever that helper returns for the chunk that was sent.
///
/// With a world size of 1 the input is returned as a clone without any work.
///
/// # Errors
/// Returns `invalid_input` when
/// * `split_dim` or `concat_dim` is not a valid dimension of `tensor`,
/// * the context reports `world_size == 0` or `rank >= world_size`,
/// * the split dimension is smaller than the world size.
///
/// Errors from slicing, the exchange helper and concatenation are propagated.
pub fn all_to_all(
    &self,
    tensor: &Tensor,
    split_dim: usize,
    concat_dim: usize,
) -> Result<Tensor> {
    let world_size = self.mp_context.world_size();
    let rank = self.mp_context.rank();
    let shape = tensor.shape();
    if split_dim >= shape.len() || concat_dim >= shape.len() {
        return Err(invalid_input(format!(
            "Invalid dimensions: split_dim={}, concat_dim={}, tensor_dims={}",
            split_dim,
            concat_dim,
            shape.len()
        )));
    }
    if world_size == 1 {
        return Ok(tensor.clone());
    }
    // Without this guard a zero world size divides by zero below and an
    // out-of-range rank panics when indexing the chunk vectors.
    if world_size == 0 || rank >= world_size {
        return Err(invalid_input(format!(
            "Invalid parallel configuration: rank={}, world_size={}",
            rank, world_size
        )));
    }
    let split_size = shape[split_dim];
    if split_size < world_size {
        return Err(invalid_input(format!(
            "Split dimension size ({}) must be at least world_size ({})",
            split_size, world_size
        )));
    }
    let chunk_size = split_size / world_size;
    let remainder = split_size % world_size;

    // `split_size >= world_size` guarantees `chunk_size >= 1`, so every chunk
    // is a real slice and no empty placeholder chunk is ever needed.
    let mut send_chunks = Vec::with_capacity(world_size);
    let mut offset = 0;
    for i in 0..world_size {
        let len = if i < remainder { chunk_size + 1 } else { chunk_size };
        send_chunks.push(tensor.slice(split_dim, offset, offset + len)?);
        offset += len;
    }

    // One slot per source rank. Over the `world_size` phases `recv_from`
    // visits every index exactly once, so every slot ends up filled and no
    // pre-allocated zero tensor is required.
    let mut receive_chunks: Vec<Option<Tensor>> = (0..world_size).map(|_| None).collect();
    for phase in 0..world_size {
        let send_to = (rank + phase) % world_size;
        let recv_from = (rank + world_size - phase) % world_size;
        let received = if phase == 0 {
            // Phase 0 is the local chunk: nothing has to be exchanged.
            send_chunks[rank].clone()
        } else {
            self.simulate_point_to_point_exchange(&send_chunks[send_to], recv_from)?
        };
        receive_chunks[recv_from] = Some(received);
    }

    // Chunks can only be empty when some other dimension of the input is 0.
    let non_empty_chunks: Vec<Tensor> = receive_chunks
        .into_iter()
        .flatten()
        .filter(|chunk| chunk.shape().iter().all(|&dim| dim > 0))
        .collect();
    if non_empty_chunks.is_empty() {
        let mut result_shape = shape.to_vec();
        result_shape[concat_dim] = 0;
        return Tensor::zeros(&result_shape);
    }
    self.concatenate_tensors(&non_empty_chunks, concat_dim)
}
/// Stand-in for a point-to-point exchange with another rank.
///
/// No communication happens here: the function returns a clone of `data` and
/// ignores `_target_rank`.
fn simulate_point_to_point_exchange(
&self,
data: &Tensor,
_target_rank: usize,
) -> Result<Tensor> {
Ok(data.clone())
}
/// Joins `tensors` along `concat_dim` after checking that they are compatible.
///
/// A single-element list yields a clone of that element. For longer lists every
/// tensor must have the same number of dimensions as the first one and the same
/// size in every dimension other than `concat_dim`.
///
/// # Errors
/// Returns `invalid_input` for an empty list or for the first incompatible
/// tensor found; errors from `Tensor::concat` are propagated.
fn concatenate_tensors(&self, tensors: &[Tensor], concat_dim: usize) -> Result<Tensor> {
    let (head, rest) = match tensors {
        [] => {
            return Err(invalid_input(
                "Cannot concatenate empty tensor list".to_string(),
            ))
        }
        [only] => return Ok(only.clone()),
        [head, rest @ ..] => (head, rest),
    };

    let reference = head.shape();
    for (offset, candidate) in rest.iter().enumerate() {
        // `rest` starts at the second tensor, so shift back to the caller's index.
        let index = offset + 1;
        let dims = candidate.shape();
        if dims.len() != reference.len() {
            return Err(invalid_input(format!(
                "Tensor {} has {} dimensions, expected {}",
                index,
                dims.len(),
                reference.len()
            )));
        }
        let mismatch = dims
            .iter()
            .zip(reference.iter())
            .enumerate()
            .find(|(axis, (got, want))| *axis != concat_dim && got != want);
        if let Some((axis, (got, want))) = mismatch {
            return Err(invalid_input(format!(
                "Tensor {} has size {} in dimension {}, expected {}",
                index, got, axis, want
            )));
        }
    }

    Ok(crate::tensor::Tensor::concat(tensors, concat_dim)?)
}
pub fn split_column_wise(&self, tensor: &Tensor) -> Result<Tensor> {
let world_size = self.mp_context.world_size();
let rank = self.mp_context.rank();
let shape = tensor.shape();
let last_dim = shape.len() - 1;
let columns = shape[last_dim];
let columns_per_rank = columns.div_ceil(world_size);
let start_col = rank * columns_per_rank;
let end_col = ((rank + 1) * columns_per_rank).min(columns);
tensor.slice(last_dim, start_col, end_col)
}
pub fn split_row_wise(&self, tensor: &Tensor) -> Result<Tensor> {
let world_size = self.mp_context.world_size();
let rank = self.mp_context.rank();
let shape = tensor.shape();
let second_last_dim = shape.len() - 2;
let rows = shape[second_last_dim];
let rows_per_rank = rows.div_ceil(world_size);
let start_row = rank * rows_per_rank;
let end_row = ((rank + 1) * rows_per_rank).min(rows);
tensor.slice(second_last_dim, start_row, end_row)
}
/// Delegates to the context's `reduce_scatter` for `tensor` along
/// `scatter_dim` and returns its result (or error) unchanged.
pub fn reduce_scatter_sum(&self, tensor: &Tensor, scatter_dim: usize) -> Result<Tensor> {
self.mp_context.reduce_scatter(tensor, scatter_dim)
}
/// Delegates to the context's `all_reduce`, which operates on `tensor` in
/// place; its result (or error) is returned unchanged.
pub fn all_reduce_sum(&self, tensor: &mut Tensor) -> Result<()> {
self.mp_context.all_reduce(tensor)
}
/// Creates the zero-filled local shard of a tensor with `global_shape` for the
/// current rank, together with the metadata describing which part it covers.
///
/// * `RowWise(dim)` / `ColumnWise(dim)`: `dim` is cut into
///   `ceil(size / world_size)`-sized ranges; this rank owns range number `rank`.
/// * `Replicated`: the shard has the full global shape; the partition is
///   recorded as a single partition covering dimension 0.
/// * `Block2D { row_dim, col_dim }`: ranks are laid out on a square grid with
///   side `floor(sqrt(world_size))`; the recorded partition describes the row
///   range only.
///
/// All ranges are clamped to the dimension size, so a rank whose range would
/// start past the end (uneven splits, or ranks outside the square grid) gets an
/// empty shard.
///
/// # Errors
/// Returns `invalid_input` when a sharding dimension is out of bounds or when a
/// sharded strategy is requested with a world size of 0; errors from
/// `Tensor::zeros` are propagated.
pub fn create_distributed_tensor(
    &self,
    global_shape: &[usize],
    sharding_strategy: ShardingStrategy,
) -> Result<DistributedTensor> {
    /// Half-open range of `extent` owned by piece `index` out of `parts`
    /// ceil-sized pieces, clamped so that `start <= end <= extent`.
    fn shard_range(index: usize, parts: usize, extent: usize) -> (usize, usize) {
        let chunk = extent.div_ceil(parts);
        let start = index.saturating_mul(chunk).min(extent);
        let end = start.saturating_add(chunk).min(extent);
        (start, end)
    }

    let world_size = self.mp_context.world_size();
    let rank = self.mp_context.rank();

    let (local_shape, split_dim, start_idx, end_idx, num_partitions) = match sharding_strategy {
        // Both 1-D strategies shard a single, caller-chosen dimension.
        ShardingStrategy::RowWise(dim) | ShardingStrategy::ColumnWise(dim) => {
            if dim >= global_shape.len() {
                return Err(invalid_input(format!(
                    "Sharding dimension {} out of bounds for shape {:?}",
                    dim, global_shape
                )));
            }
            if world_size == 0 {
                return Err(invalid_input(
                    "Cannot shard a tensor across a world size of 0".to_string(),
                ));
            }
            let (start_idx, end_idx) = shard_range(rank, world_size, global_shape[dim]);
            let mut local_shape = global_shape.to_vec();
            local_shape[dim] = end_idx - start_idx;
            (local_shape, dim, start_idx, end_idx, world_size)
        }
        ShardingStrategy::Replicated => {
            // A shape with no dimensions has no dimension 0; record an empty range.
            let extent = global_shape.first().copied().unwrap_or(0);
            (global_shape.to_vec(), 0, 0, extent, 1)
        }
        ShardingStrategy::Block2D { row_dim, col_dim } => {
            if row_dim >= global_shape.len() || col_dim >= global_shape.len() {
                return Err(invalid_input(format!(
                    "Sharding dimensions ({}, {}) out of bounds for shape {:?}",
                    row_dim, col_dim, global_shape
                )));
            }
            let grid_size = (world_size as f32).sqrt() as usize;
            if grid_size == 0 {
                return Err(invalid_input(
                    "Cannot shard a tensor across a world size of 0".to_string(),
                ));
            }
            let row_rank = rank / grid_size;
            let col_rank = rank % grid_size;
            let (row_start, row_end) = shard_range(row_rank, grid_size, global_shape[row_dim]);
            let (col_start, col_end) = shard_range(col_rank, grid_size, global_shape[col_dim]);
            let mut local_shape = global_shape.to_vec();
            local_shape[row_dim] = row_end - row_start;
            local_shape[col_dim] = col_end - col_start;
            (local_shape, row_dim, row_start, row_end, world_size)
        }
    };

    let local_shard = Tensor::zeros(&local_shape)?;
    Ok(DistributedTensor {
        local_shard,
        global_shape: global_shape.to_vec(),
        partition: TensorPartition {
            split_dim,
            start_idx,
            end_idx,
            num_partitions,
            partition_rank: rank,
        },
        device_id: rank,
    })
}
/// Optionally clips the gradients in `tensors`, then averages each of them
/// across ranks.
///
/// When `clip_norm` is `Some(max_norm)` and the combined L2 norm of all tensors
/// exceeds `max_norm`, every element is scaled by `max_norm / norm` before the
/// reduction. Each tensor is then passed through `all_reduce_sum` and divided
/// by the world size.
///
/// # Errors
/// Propagates errors from accessing tensor data and from the all-reduce.
pub fn sync_gradients(&self, tensors: &mut [Tensor], clip_norm: Option<f32>) -> Result<()> {
    if let Some(max_norm) = clip_norm {
        let total_norm = Self::compute_gradient_norm(tensors)?;
        if total_norm > max_norm {
            let scale = max_norm / total_norm;
            for grad in tensors.iter_mut() {
                let values = grad.data_mut()?;
                values.iter_mut().for_each(|v| *v *= scale);
            }
        }
    }

    let rank_count = self.mp_context.world_size() as f32;
    for grad in tensors.iter_mut() {
        self.all_reduce_sum(grad)?;
        let values = grad.data_mut()?;
        values.iter_mut().for_each(|v| *v /= rank_count);
    }
    Ok(())
}
/// Computes the L2 norm over the elements of all `tensors` taken together:
/// the square root of the sum of every element squared.
///
/// The sum is accumulated in `f32`, element by element, in tensor order.
///
/// # Errors
/// Propagates errors from accessing tensor data.
fn compute_gradient_norm(tensors: &[Tensor]) -> Result<f32> {
    let mut sum_of_squares = 0.0f32;
    for grad in tensors.iter() {
        let values = grad.data()?;
        sum_of_squares = values.iter().fold(sum_of_squares, |acc, x| acc + x * x);
    }
    Ok(sum_of_squares.sqrt())
}
}
pub struct TensorParallelInit;
impl TensorParallelInit {
pub fn column_parallel_weight(
in_features: usize,
out_features: usize,
world_size: usize,
rank: usize,
init_method: InitMethod,
) -> Result<Tensor> {
let out_features_per_rank = out_features.div_ceil(world_size);
let local_out_features = if rank == world_size - 1 {
out_features - rank * out_features_per_rank
} else {
out_features_per_rank
};
Self::init_weight(&[in_features, local_out_features], init_method)
}
pub fn row_parallel_weight(
in_features: usize,
out_features: usize,
world_size: usize,
rank: usize,
init_method: InitMethod,
) -> Result<Tensor> {
let in_features_per_rank = in_features.div_ceil(world_size);
let local_in_features = if rank == world_size - 1 {
in_features - rank * in_features_per_rank
} else {
in_features_per_rank
};
Self::init_weight(&[local_in_features, out_features], init_method)
}
fn init_weight(shape: &[usize], method: InitMethod) -> Result<Tensor> {
use scirs2_core::random::*;
let mut rng = thread_rng();
let size = shape.iter().product();
match method {
InitMethod::Normal { mean, std } => {
let normal = Normal::new(mean as f64, std as f64)
.map_err(|e| invalid_input(format!("Normal distribution parameters: {}", e)))?;
let data: Vec<f32> = (0..size).map(|_| normal.sample(&mut rng) as f32).collect();
Ok(Tensor::from_data(data, shape)?)
},
InitMethod::Uniform { low, high } => {
if low >= high {
return Err(invalid_input(format!(
"Uniform distribution bounds: low ({}) must be less than high ({})",
low, high
)));
}
let uniform = UniformDist::new(low as f64, high as f64).map_err(|e| {
invalid_input(format!("Uniform distribution parameters: {}", e))
})?;
let data: Vec<f32> = (0..size).map(|_| uniform.sample(&mut rng) as f32).collect();
Ok(Tensor::from_data(data, shape)?)
},
InitMethod::Xavier => {
let fan_in = shape[0] as f32;
let fan_out = shape[1] as f32;
let std = (2.0 / (fan_in + fan_out)).sqrt();
let normal = Normal::new(0.0, std as f64).map_err(|e| {
invalid_input(format!("Xavier initialization parameters: {}", e))
})?;
let data: Vec<f32> = (0..size).map(|_| normal.sample(&mut rng) as f32).collect();
Ok(Tensor::from_data(data, shape)?)
},
InitMethod::Kaiming => {
let fan_in = shape[0] as f32;
let std = (2.0 / fan_in).sqrt();
let normal = Normal::new(0.0, std as f64).map_err(|e| {
invalid_input(format!("Kaiming initialization parameters: {}", e))
})?;
let data: Vec<f32> = (0..size).map(|_| normal.sample(&mut rng) as f32).collect();
Ok(Tensor::from_data(data, shape)?)
},
}
}
}
/// How `TensorParallelOps::create_distributed_tensor` lays a global tensor out
/// across ranks.
#[derive(Debug, Clone, Copy)]
pub enum ShardingStrategy {
/// Shard the given dimension into one range per rank.
RowWise(usize),
/// Shard the given dimension into one range per rank.
/// NOTE(review): `create_distributed_tensor` treats this exactly like
/// `RowWise`; only the caller-supplied dimension differs.
ColumnWise(usize),
/// Every rank holds a shard with the full global shape.
Replicated,
/// Shard `row_dim` and `col_dim` over a square grid of ranks whose side is
/// `floor(sqrt(world_size))`.
Block2D { row_dim: usize, col_dim: usize },
}
/// Random initialisation scheme used by `TensorParallelInit`.
#[derive(Debug, Clone, Copy)]
pub enum InitMethod {
/// Samples from a normal distribution with the given mean and standard deviation.
Normal { mean: f32, std: f32 },
/// Samples from a uniform distribution between `low` and `high`;
/// `low` must be strictly less than `high`.
Uniform { low: f32, high: f32 },
/// Zero-mean normal with `std = sqrt(2 / (fan_in + fan_out))`, where fan-in and
/// fan-out are the first and second dimension of the local weight shape.
Xavier,
/// Zero-mean normal with `std = sqrt(2 / fan_in)`, where fan-in is the first
/// dimension of the local weight shape.
Kaiming,
}
/// Pure shape arithmetic for tensor-parallel layouts; no tensors are touched.
pub struct TensorParallelShapes;

impl TensorParallelShapes {
    /// Local shape of `rank`'s shard when the last dimension of `global_shape`
    /// is split across `world_size` ranks.
    ///
    /// Each rank owns `ceil(columns / world_size)` consecutive columns, clamped
    /// to what is left. For an uneven split the trailing ranks therefore get a
    /// smaller (possibly zero) extent, and the extents of all ranks add up to
    /// the global size.
    ///
    /// # Panics
    /// Panics if `global_shape` is empty or `world_size` is 0.
    pub fn column_parallel_shape(
        global_shape: &[usize],
        world_size: usize,
        rank: usize,
    ) -> Vec<usize> {
        let last_dim = global_shape
            .len()
            .checked_sub(1)
            .expect("column_parallel_shape requires at least one dimension");
        let mut local_shape = global_shape.to_vec();
        local_shape[last_dim] = Self::local_extent(global_shape[last_dim], world_size, rank);
        local_shape
    }

    /// Local shape of `rank`'s shard when the second-to-last dimension of
    /// `global_shape` is split across `world_size` ranks.
    ///
    /// Uses the same clamped `ceil(rows / world_size)` ranges as
    /// [`Self::column_parallel_shape`].
    ///
    /// # Panics
    /// Panics if `global_shape` has fewer than two dimensions or `world_size`
    /// is 0.
    pub fn row_parallel_shape(
        global_shape: &[usize],
        world_size: usize,
        rank: usize,
    ) -> Vec<usize> {
        let second_last_dim = global_shape
            .len()
            .checked_sub(2)
            .expect("row_parallel_shape requires at least two dimensions");
        let mut local_shape = global_shape.to_vec();
        local_shape[second_last_dim] =
            Self::local_extent(global_shape[second_last_dim], world_size, rank);
        local_shape
    }

    /// Sizes of the `world_size` pieces of a `total_size`-long dimension when
    /// the remainder is spread over the first pieces, so sizes differ by at
    /// most one. Returns an empty vector when `world_size` is 0.
    pub fn split_sizes(total_size: usize, world_size: usize) -> Vec<usize> {
        if world_size == 0 {
            // Nothing to split across; avoids dividing by zero below.
            return Vec::new();
        }
        let base_size = total_size / world_size;
        let remainder = total_size % world_size;
        (0..world_size)
            .map(|i| base_size + usize::from(i < remainder))
            .collect()
    }

    /// Number of elements of a `total`-sized dimension owned by `rank` when it
    /// is cut into `ceil(total / world_size)`-sized ranges.
    ///
    /// Start and end are both clamped to `total`; the previous
    /// `total - rank * per_rank` form underflowed on the last rank whenever the
    /// preceding ranks had already covered the whole dimension.
    fn local_extent(total: usize, world_size: usize, rank: usize) -> usize {
        assert!(world_size > 0, "world_size must be at least 1");
        let per_rank = total.div_ceil(world_size);
        let start = rank.saturating_mul(per_rank).min(total);
        let end = start.saturating_add(per_rank).min(total);
        end - start
    }
}
/// Asynchronous-style front end for collective operations.
///
/// NOTE(review): neither method below touches `mp_context`; each one only wraps
/// its argument in a handle whose `completed` flag starts out `false`.
pub struct AsyncTensorParallel {
    /// Shared model-parallel context this front end was created with.
    mp_context: Arc<ModelParallelContext>,
}

impl AsyncTensorParallel {
    /// Creates the front end on top of a shared model-parallel context.
    pub fn new(mp_context: Arc<ModelParallelContext>) -> Self {
        AsyncTensorParallel { mp_context }
    }

    /// Wraps `tensor` in a not-yet-completed [`AllReduceHandle`].
    pub async fn all_reduce_async(&self, tensor: Tensor) -> Result<AllReduceHandle> {
        let handle = AllReduceHandle {
            tensor,
            completed: false,
        };
        Ok(handle)
    }

    /// Wraps `distributed` in a not-yet-completed [`AllGatherHandle`].
    pub async fn all_gather_async(
        &self,
        distributed: DistributedTensor,
    ) -> Result<AllGatherHandle> {
        let handle = AllGatherHandle {
            distributed,
            completed: false,
        };
        Ok(handle)
    }
}
/// Handle returned by `AsyncTensorParallel::all_reduce_async`.
pub struct AllReduceHandle {
    /// Tensor handed back by [`AllReduceHandle::wait`].
    tensor: Tensor,
    /// Completion flag reported by [`AllReduceHandle::is_complete`].
    completed: bool,
}

impl AllReduceHandle {
    /// Consumes the handle, marks it completed and returns the tensor it holds.
    pub async fn wait(mut self) -> Result<Tensor> {
        self.completed = true;
        let AllReduceHandle { tensor, .. } = self;
        Ok(tensor)
    }

    /// Reports the current value of the completion flag.
    pub fn is_complete(&self) -> bool {
        let AllReduceHandle { completed, .. } = self;
        *completed
    }
}
/// Handle returned by `AsyncTensorParallel::all_gather_async`.
pub struct AllGatherHandle {
    /// Distributed tensor whose local shard is handed back by `wait`.
    distributed: DistributedTensor,
    /// Completion flag; starts out `false` and is set by `wait`.
    completed: bool,
}

impl AllGatherHandle {
    /// Consumes the handle, marks it completed and returns a clone of the local
    /// shard of the wrapped distributed tensor.
    pub async fn wait(mut self) -> Result<Tensor> {
        self.completed = true;
        let shard = &self.distributed.local_shard;
        Ok(shard.clone())
    }
}
// Unit tests for the shape helpers, weight initialisers, gradient-norm helper
// and the async handles. None of them constructs a `ModelParallelContext`.
#[cfg(test)]
mod tests {
use super::*;
// Even split of a 3-D shape, checked for rank 0 only.
#[test]
fn test_tensor_parallel_shapes() {
let global_shape = vec![32, 64, 768];
let world_size = 4;
let col_shape = TensorParallelShapes::column_parallel_shape(&global_shape, world_size, 0);
// Last dimension: 768 / 4 = 192.
assert_eq!(col_shape, vec![32, 64, 192]);
let row_shape = TensorParallelShapes::row_parallel_shape(&global_shape, world_size, 0);
// Second-to-last dimension: 64 / 4 = 16.
assert_eq!(row_shape, vec![32, 16, 768]); }
// `split_sizes` gives the extra element of an uneven split to the first piece.
#[test]
fn test_split_sizes() {
assert_eq!(
TensorParallelShapes::split_sizes(100, 4),
vec![25, 25, 25, 25]
);
assert_eq!(
TensorParallelShapes::split_sizes(101, 4),
vec![26, 25, 25, 25]
);
}
// Column-parallel weight for rank 0: 2048 output features / 4 ranks = 512.
#[test]
fn test_weight_initialization() {
let weight =
TensorParallelInit::column_parallel_weight(512, 2048, 4, 0, InitMethod::Xavier)
.expect("operation failed in test");
assert_eq!(weight.shape(), &[512, 512]); }
// 2-D shape: 768 / 4 = 192 columns on rank 0.
#[test]
fn test_column_parallel_shape_simple() {
let shape = vec![32, 768];
let result = TensorParallelShapes::column_parallel_shape(&shape, 4, 0);
assert_eq!(result, vec![32, 192]); }
// A single rank keeps the whole shape.
#[test]
fn test_column_parallel_shape_world_size_1() {
let shape = vec![32, 768];
let result = TensorParallelShapes::column_parallel_shape(&shape, 1, 0);
assert_eq!(result, vec![32, 768]);
}
// 3-D shape: only the last dimension is split (128 / 2 = 64).
#[test]
fn test_column_parallel_shape_3d() {
let shape = vec![4, 32, 128];
let result = TensorParallelShapes::column_parallel_shape(&shape, 2, 0);
assert_eq!(result, vec![4, 32, 64]); }
// 2-D shape: the second-to-last dimension is the first one (768 / 4 = 192).
#[test]
fn test_row_parallel_shape_simple() {
let shape = vec![768, 256];
let result = TensorParallelShapes::row_parallel_shape(&shape, 4, 0);
assert_eq!(result, vec![192, 256]);
}
// 3-D shape: only the second-to-last dimension is split (64 / 2 = 32).
#[test]
fn test_row_parallel_shape_3d() {
let shape = vec![4, 64, 128];
let result = TensorParallelShapes::row_parallel_shape(&shape, 2, 0);
assert_eq!(result, vec![4, 32, 128]); }
// Same assertion as the first half of `test_split_sizes`.
#[test]
fn test_split_sizes_even() {
assert_eq!(
TensorParallelShapes::split_sizes(100, 4),
vec![25, 25, 25, 25]
);
}
// Same assertion as the second half of `test_split_sizes`.
#[test]
fn test_split_sizes_uneven() {
assert_eq!(
TensorParallelShapes::split_sizes(101, 4),
vec![26, 25, 25, 25]
);
}
#[test]
fn test_split_sizes_one_process() {
assert_eq!(TensorParallelShapes::split_sizes(50, 1), vec![50]);
}
// More pieces than elements: one entry per piece, and the sizes still add up
// to the total.
#[test]
fn test_split_sizes_more_procs_than_size() {
let sizes = TensorParallelShapes::split_sizes(3, 5);
assert_eq!(sizes.len(), 5);
let total: usize = sizes.iter().sum();
assert_eq!(total, 3);
}
// 1024 output features / 4 ranks = 256 on the first rank.
#[test]
fn test_column_parallel_weight_rank_0() {
let weight =
TensorParallelInit::column_parallel_weight(256, 1024, 4, 0, InitMethod::Xavier)
.expect("operation failed in test");
assert_eq!(weight.shape(), &[256, 256]);
}
// Last rank of an even split: 1024 - 3 * 256 = 256.
#[test]
fn test_column_parallel_weight_last_rank() {
let weight =
TensorParallelInit::column_parallel_weight(256, 1024, 4, 3, InitMethod::Xavier)
.expect("operation failed in test");
assert_eq!(weight.shape(), &[256, 256]); }
// Row-parallel weight: 1024 input features / 4 ranks = 256 rows.
#[test]
fn test_row_parallel_weight() {
let weight = TensorParallelInit::row_parallel_weight(1024, 256, 4, 0, InitMethod::Xavier)
.expect("operation failed in test");
assert_eq!(weight.shape(), &[256, 256]); }
// Only the resulting shape is checked, not the sampled values.
#[test]
fn test_normal_init() {
let weight = TensorParallelInit::column_parallel_weight(
64,
64,
1,
0,
InitMethod::Normal {
mean: 0.0,
std: 0.02,
},
)
.expect("operation failed in test");
assert_eq!(weight.shape(), &[64, 64]);
}
// Checks the shape and that every sample lies within the requested bounds.
#[test]
fn test_uniform_init() {
let weight = TensorParallelInit::column_parallel_weight(
64,
64,
1,
0,
InitMethod::Uniform {
low: -1.0,
high: 1.0,
},
)
.expect("operation failed in test");
assert_eq!(weight.shape(), &[64, 64]);
let data = weight.data().expect("data extraction failed");
for &val in data.iter() {
assert!((-1.0..=1.0).contains(&val), "Value out of range: {}", val);
}
}
// 128 output features / 2 ranks = 64 on rank 0.
#[test]
fn test_kaiming_init() {
let weight = TensorParallelInit::column_parallel_weight(64, 128, 2, 0, InitMethod::Kaiming)
.expect("operation failed in test");
assert_eq!(weight.shape(), &[64, 64]); }
// Construction-only check: every `ShardingStrategy` variant can be built.
#[test]
fn test_sharding_strategy_variants() {
let _col = ShardingStrategy::ColumnWise(1);
let _row = ShardingStrategy::RowWise(0);
let _rep = ShardingStrategy::Replicated;
let _block = ShardingStrategy::Block2D {
row_dim: 0,
col_dim: 1,
};
}
// Construction-only check: every `InitMethod` variant can be built.
#[test]
fn test_init_method_variants() {
let _normal = InitMethod::Normal {
mean: 0.0,
std: 1.0,
};
let _uniform = InitMethod::Uniform {
low: -1.0,
high: 1.0,
};
let _xavier = InitMethod::Xavier;
let _kaiming = InitMethod::Kaiming;
}
// sqrt(3^2 + 4^2) = 5.
#[test]
fn test_compute_gradient_norm_single() {
let data = vec![3.0, 4.0]; let tensor = Tensor::from_vec(data, &[2]).expect("tensor creation failed");
let norm =
TensorParallelOps::compute_gradient_norm(&[tensor]).expect("norm computation failed");
assert!((norm - 5.0).abs() < 1e-5);
}
// The norm spans all tensors together: sqrt(1 + 0 + 0 + 1) = sqrt(2).
#[test]
fn test_compute_gradient_norm_multiple() {
let t1 = Tensor::from_data(vec![1.0, 0.0], &[2]).expect("tensor creation failed");
let t2 = Tensor::from_data(vec![0.0, 1.0], &[2]).expect("tensor creation failed");
let norm =
TensorParallelOps::compute_gradient_norm(&[t1, t2]).expect("norm computation failed");
assert!((norm - std::f32::consts::SQRT_2).abs() < 1e-5);
}
// An all-zero tensor has norm 0.
#[test]
fn test_compute_gradient_norm_zero() {
let tensor = Tensor::zeros(&[4]).expect("tensor creation failed");
let norm =
TensorParallelOps::compute_gradient_norm(&[tensor]).expect("norm computation failed");
assert!(norm.abs() < 1e-6);
}
// A handle built with `completed: false` reports itself as not complete.
#[test]
fn test_all_reduce_handle_not_complete() {
let tensor = Tensor::ones(&[2]).expect("tensor creation failed");
let handle = AllReduceHandle {
tensor,
completed: false,
};
assert!(!handle.is_complete());
}
// Builds a handle around a hand-made distributed tensor (first half of a
// length-4 tensor split in two) and checks the initial flag.
#[test]
fn test_all_gather_handle_creation() {
let local = Tensor::ones(&[2]).expect("tensor creation failed");
let dist = DistributedTensor {
local_shard: local,
global_shape: vec![4],
partition: TensorPartition {
split_dim: 0,
start_idx: 0,
end_idx: 2,
num_partitions: 2,
partition_rank: 0,
},
device_id: 0,
};
let handle = AllGatherHandle {
distributed: dist,
completed: false,
};
assert!(!handle.completed);
}
}