#![allow(dead_code)]
use crate::TorshResult;
use std::collections::HashMap;
use torsh_tensor::Tensor;
use super::config::ThreeDParallelismConfig;
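/// Fills a new CPU tensor of the given shape with pseudo-random values in
/// `[-1.0, 1.0)`.
///
/// Note that the RNG is re-seeded with a fixed seed (42) on every call, so
/// repeated calls with the same shape return identical data: initialization
/// is deterministic, but all layers of a given shape start from the same values.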
fn randn(shape: &[usize]) -> TorshResult<Tensor<f32>> {
use scirs2_core::random::Random;
let total_elements: usize = shape.iter().product();
let mut data = Vec::with_capacity(total_elements);
let mut random = Random::seed(42);
for _ in 0..total_elements {
data.push(random.gen_range(-1.0..1.0));
}
Ok(Tensor::from_data(
data,
shape.to_vec(),
torsh_core::DeviceType::Cpu,
)?)
}
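/// All model shards for a 3-D parallel (pipeline x tensor x data) layout.
///
/// Layers are grouped by pipeline stage in `pipeline_stages[stage][layer]`;
/// `shards` keeps a name-keyed `ModelShard` (flat parameter/gradient buffers)
/// per layer, and `layer_mapping` indexes a global layer id back to its
/// `(stage, layer-in-stage)` position.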
#[derive(Debug)]
pub struct ModelShards {
pub pipeline_stages: Vec<Vec<LayerShard>>,
pub total_parameters: usize,
pub parameters_per_stage: Vec<usize>,
pub shards: HashMap<String, ModelShard>,
    layer_mapping: HashMap<usize, (usize, usize)>,
}
impl ModelShards {
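    /// Builds shard metadata for every pipeline stage and layer in `config`.
    ///
    /// A minimal usage sketch (assumes a `ThreeDParallelismConfig` value is
    /// already in hand; its constructor lives in `super::config`):
    ///
    /// ```ignore
    /// let shards = ModelShards::new(&config)?;
    /// assert_eq!(shards.pipeline_stages.len(), config.pp_size);
    /// assert_eq!(shards.parameters_per_stage.len(), config.pp_size);
    /// ```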
pub fn new(config: &ThreeDParallelismConfig) -> TorshResult<Self> {
let layers_per_stage = config.layers_per_stage();
let mut pipeline_stages = Vec::new();
let mut parameters_per_stage = Vec::new();
let mut shards = HashMap::new();
let mut layer_mapping = HashMap::new();
let mut total_parameters = 0;
for stage_idx in 0..config.pp_size {
let mut stage_layers = Vec::new();
let mut stage_params = 0;
for layer_in_stage in 0..layers_per_stage {
let global_layer_id = stage_idx * layers_per_stage + layer_in_stage;
let layer_shard = LayerShard::new(global_layer_id, config.tp_size)?;
let param_count = layer_shard.parameter_count();
stage_params += param_count;
total_parameters += param_count;
let layer_name = format!("stage_{}_layer_{}", stage_idx, layer_in_stage);
let model_shard = ModelShard {
parameters: vec![0.0; param_count],
gradients: Some(vec![0.0; param_count]),
shard_info: ShardInfo {
stage_id: stage_idx,
layer_id: global_layer_id,
                        tp_rank: 0,
                        dp_rank: 0,
                    },
};
shards.insert(layer_name, model_shard);
layer_mapping.insert(global_layer_id, (stage_idx, layer_in_stage));
stage_layers.push(layer_shard);
}
pipeline_stages.push(stage_layers);
parameters_per_stage.push(stage_params);
}
Ok(Self {
pipeline_stages,
total_parameters,
parameters_per_stage,
shards,
layer_mapping,
})
}
pub fn get_layer_shard(&self, layer_id: usize) -> Option<&LayerShard> {
if let Some(&(stage_idx, layer_in_stage)) = self.layer_mapping.get(&layer_id) {
self.pipeline_stages
.get(stage_idx)
.and_then(|stage| stage.get(layer_in_stage))
} else {
None
}
}
pub fn get_layer_shard_mut(&mut self, layer_id: usize) -> Option<&mut LayerShard> {
if let Some(&(stage_idx, layer_in_stage)) = self.layer_mapping.get(&layer_id) {
self.pipeline_stages
.get_mut(stage_idx)
.and_then(|stage| stage.get_mut(layer_in_stage))
} else {
None
}
}
pub fn get_stage_layers(&self, stage_idx: usize) -> Option<&Vec<LayerShard>> {
self.pipeline_stages.get(stage_idx)
}
pub fn get_stage_layers_mut(&mut self, stage_idx: usize) -> Option<&mut Vec<LayerShard>> {
self.pipeline_stages.get_mut(stage_idx)
}
pub fn get_model_shard(&self, name: &str) -> Option<&ModelShard> {
self.shards.get(name)
}
pub fn get_model_shard_mut(&mut self, name: &str) -> Option<&mut ModelShard> {
self.shards.get_mut(name)
}
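    /// Stores freshly computed gradients on the given layer; unknown
    /// `layer_id`s are silently ignored.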
pub fn update_layer_gradients(
&mut self,
layer_id: usize,
weight_grad: Option<Tensor<f32>>,
bias_grad: Option<Tensor<f32>>,
) -> TorshResult<()> {
if let Some(layer) = self.get_layer_shard_mut(layer_id) {
layer.grad_weight = weight_grad;
layer.grad_bias = bias_grad;
}
Ok(())
}
pub fn zero_gradients(&mut self) -> TorshResult<()> {
for stage_layers in &mut self.pipeline_stages {
for layer in stage_layers {
layer.zero_gradients()?;
}
}
Ok(())
}
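    /// Total bytes held by every layer shard across all pipeline stages,
    /// counting parameters and any allocated gradient buffers.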
pub fn memory_usage_bytes(&self) -> usize {
let mut total_bytes = 0;
for stage_layers in &self.pipeline_stages {
for layer in stage_layers {
total_bytes += layer.memory_usage_bytes();
}
}
total_bytes
}
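    /// Derives a tensor-parallel sharding plan for every layer in the model.
    ///
    /// A short sketch of reading the plan back out (hypothetical indices):
    ///
    /// ```ignore
    /// let plan = shards.create_tp_sharding_plan(config.tp_size);
    /// if let Some(layer_plan) = plan.get_layer_plan(0, 0) {
    ///     println!("{:?} -> {:?}", layer_plan.layer_type, layer_plan.communication_pattern);
    /// }
    /// ```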
pub fn create_tp_sharding_plan(&self, tp_size: usize) -> TensorParallelShardingPlan {
let mut sharding_plan = TensorParallelShardingPlan::new(tp_size);
for (stage_idx, stage_layers) in self.pipeline_stages.iter().enumerate() {
for (layer_idx, layer) in stage_layers.iter().enumerate() {
let layer_plan = self.create_layer_tp_plan(layer, tp_size);
sharding_plan.add_layer_plan(stage_idx, layer_idx, layer_plan);
}
}
sharding_plan
}
fn create_layer_tp_plan(&self, layer: &LayerShard, _tp_size: usize) -> LayerTensorParallelPlan {
let binding = layer.weight.shape();
let weight_dims = binding.dims();
        let shard_strategies = match layer.layer_type {
            LayerType::Linear => vec![ShardStrategy::ColumnParallel],
            LayerType::Attention | LayerType::MLP => {
                vec![ShardStrategy::ColumnParallel, ShardStrategy::RowParallel]
            }
            LayerType::Embedding => vec![ShardStrategy::VocabParallel],
        };
LayerTensorParallelPlan {
layer_id: layer.layer_id,
layer_type: layer.layer_type,
weight_shape: weight_dims.to_vec(),
shard_strategies,
communication_pattern: self.determine_communication_pattern(&layer.layer_type),
}
}
fn determine_communication_pattern(&self, layer_type: &LayerType) -> CommunicationPattern {
match layer_type {
LayerType::Linear => CommunicationPattern::AllReduce,
LayerType::Attention => CommunicationPattern::AllGatherThenReduceScatter,
LayerType::MLP => CommunicationPattern::AllGatherThenReduceScatter,
LayerType::Embedding => CommunicationPattern::AllReduce,
}
}
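    /// Subtracts the given per-layer updates element-wise from each named
    /// shard's parameters (a plain SGD-style step; the caller is expected to
    /// scale updates by the learning rate beforehand).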
pub fn apply_weight_updates(
&mut self,
updates: &HashMap<String, Tensor<f32>>,
) -> TorshResult<()> {
for (layer_name, update) in updates {
if let Some(shard) = self.shards.get_mut(layer_name) {
let update_data = update.data()?;
                // Element-wise step; `zip` stops at the shorter buffer, so an
                // oversized update cannot index out of bounds.
                for (param, &update_val) in shard.parameters.iter_mut().zip(update_data.iter()) {
                    *param -= update_val;
                }
}
}
Ok(())
}
}
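/// One layer's tensor-parallel slice of parameters, plus optional gradients.
///
/// MLP layers additionally carry a down-projection weight that maps the
/// widened hidden dimension back to `hidden_size`.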
#[derive(Debug)]
pub struct LayerShard {
pub layer_id: usize,
pub layer_type: LayerType,
pub weight: Tensor<f32>,
pub bias: Option<Tensor<f32>>,
pub grad_weight: Option<Tensor<f32>>,
pub grad_bias: Option<Tensor<f32>>,
pub down_projection_weight: Option<Tensor<f32>>,
pub grad_down_projection: Option<Tensor<f32>>,
}
impl LayerShard {
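    /// Builds a demo layer shard whose type cycles with `layer_id % 4`
    /// (Embedding, Attention, MLP, Linear) over a fixed `hidden_size` of 512.
    ///
    /// Assumes `tp_size` divides 512 evenly; otherwise the integer division
    /// truncates the shard width. Note the bias is always `shard_size` long,
    /// even for the widened Attention (3x) and MLP (4x) weights.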
pub fn new(layer_id: usize, tp_size: usize) -> TorshResult<Self> {
let layer_type = match layer_id % 4 {
0 => LayerType::Embedding,
1 => LayerType::Attention,
2 => LayerType::MLP,
_ => LayerType::Linear,
};
let hidden_size = 512;
let shard_size = hidden_size / tp_size;
        let weight = match layer_type {
            LayerType::Linear | LayerType::Embedding => randn(&[hidden_size, shard_size])?,
            // Attention packs the Q, K, and V projections into one sharded weight.
            LayerType::Attention => randn(&[hidden_size, 3 * shard_size])?,
            // MLP up-projection widens the hidden dimension 4x.
            LayerType::MLP => randn(&[hidden_size, 4 * shard_size])?,
        };
let bias = Some(Tensor::zeros(&[shard_size], torsh_core::DeviceType::Cpu)?);
let down_projection_weight = if matches!(layer_type, LayerType::MLP) {
Some(randn(&[4 * shard_size, hidden_size])?)
} else {
None
};
Ok(Self {
layer_id,
layer_type,
weight,
bias,
grad_weight: None,
grad_bias: None,
down_projection_weight,
grad_down_projection: None,
})
}
pub fn parameter_count(&self) -> usize {
let weight_params = self.weight.numel();
let bias_params = self.bias.as_ref().map(|b| b.numel()).unwrap_or(0);
let down_proj_params = self
.down_projection_weight
.as_ref()
.map(|w| w.numel())
.unwrap_or(0);
weight_params + bias_params + down_proj_params
}
pub fn memory_usage_bytes(&self) -> usize {
let mut bytes = self.weight.numel() * std::mem::size_of::<f32>();
if let Some(ref bias) = self.bias {
bytes += bias.numel() * std::mem::size_of::<f32>();
}
if let Some(ref down_proj) = self.down_projection_weight {
bytes += down_proj.numel() * std::mem::size_of::<f32>();
}
if self.grad_weight.is_some() {
bytes += self.weight.numel() * std::mem::size_of::<f32>();
}
if self.grad_bias.is_some() {
bytes +=
self.bias.as_ref().map(|b| b.numel()).unwrap_or(0) * std::mem::size_of::<f32>();
}
if self.grad_down_projection.is_some() {
bytes += self
.down_projection_weight
.as_ref()
.map(|w| w.numel())
.unwrap_or(0)
* std::mem::size_of::<f32>();
}
bytes
}
    /// Resets any allocated gradient buffers to zero by rebuilding them as
    /// fresh zero tensors; gradients that were never initialized stay `None`.
    pub fn zero_gradients(&mut self) -> TorshResult<()> {
        if self.grad_weight.is_some() {
            self.grad_weight = Some(Tensor::zeros(self.weight.shape().dims(), self.weight.device())?);
        }
        if self.grad_bias.is_some() {
            if let Some(ref bias) = self.bias {
                self.grad_bias = Some(Tensor::zeros(bias.shape().dims(), bias.device())?);
            }
        }
        if self.grad_down_projection.is_some() {
            if let Some(ref dp) = self.down_projection_weight {
                self.grad_down_projection = Some(Tensor::zeros(dp.shape().dims(), dp.device())?);
            }
        }
        Ok(())
    }
pub fn init_gradients(&mut self) -> TorshResult<()> {
self.grad_weight = Some(Tensor::zeros(
self.weight.shape().dims(),
self.weight.device(),
)?);
if let Some(ref bias) = self.bias {
self.grad_bias = Some(Tensor::zeros(bias.shape().dims(), bias.device())?);
}
if let Some(ref down_proj) = self.down_projection_weight {
self.grad_down_projection =
Some(Tensor::zeros(down_proj.shape().dims(), down_proj.device())?);
}
Ok(())
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LayerType {
Linear,
Attention,
MLP,
Embedding,
}
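/// A flat, CPU-resident copy of one layer's parameters and optional
/// gradients, addressed by name in `ModelShards::shards`.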
#[derive(Debug)]
pub struct ModelShard {
pub parameters: Vec<f32>,
pub gradients: Option<Vec<f32>>,
pub shard_info: ShardInfo,
}
#[derive(Debug, Clone)]
pub struct ShardInfo {
pub stage_id: usize,
pub layer_id: usize,
pub tp_rank: usize,
pub dp_rank: usize,
}
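/// Per-layer tensor-parallel decisions, keyed by `(stage_idx, layer_idx)`.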
#[derive(Debug)]
pub struct TensorParallelShardingPlan {
tp_size: usize,
    layer_plans: HashMap<(usize, usize), LayerTensorParallelPlan>,
}
impl TensorParallelShardingPlan {
fn new(tp_size: usize) -> Self {
Self {
tp_size,
layer_plans: HashMap::new(),
}
}
fn add_layer_plan(
&mut self,
stage_idx: usize,
layer_idx: usize,
plan: LayerTensorParallelPlan,
) {
self.layer_plans.insert((stage_idx, layer_idx), plan);
}
pub fn get_layer_plan(
&self,
stage_idx: usize,
layer_idx: usize,
) -> Option<&LayerTensorParallelPlan> {
self.layer_plans.get(&(stage_idx, layer_idx))
}
}
#[derive(Debug, Clone)]
pub struct LayerTensorParallelPlan {
pub layer_id: usize,
pub layer_type: LayerType,
pub weight_shape: Vec<usize>,
pub shard_strategies: Vec<ShardStrategy>,
pub communication_pattern: CommunicationPattern,
}
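/// How a weight matrix is split across tensor-parallel ranks: along output
/// columns, along input rows, along the vocabulary dimension (embeddings),
/// or fully replicated on every rank.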
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ShardStrategy {
ColumnParallel,
RowParallel,
VocabParallel,
Replicated,
}
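/// The collective needed to reassemble a layer's output from its shards.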
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CommunicationPattern {
AllReduce,
AllGatherThenReduceScatter,
ReduceScatterThenAllGather,
None,
}
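
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sanity check against this module's own API; it deliberately
    // avoids `ThreeDParallelismConfig`, whose construction lives elsewhere.
    #[test]
    fn layer_shard_shapes_follow_layer_type_and_tp_size() {
        // layer_id % 4 == 2 selects LayerType::MLP, which also carries a
        // down-projection weight mapping 4 * shard_size back to hidden_size.
        let shard = LayerShard::new(2, 4).expect("shard construction");
        assert_eq!(shard.layer_type, LayerType::MLP);
        assert!(shard.down_projection_weight.is_some());
        // hidden_size = 512 sharded 4 ways gives shard_size = 128; the MLP
        // up-projection widens it 4x back to 512 columns.
        assert_eq!(shard.weight.shape().dims().to_vec(), vec![512, 512]);
        // weight (512 * 512) + bias (128) + down projection (512 * 512).
        assert_eq!(shard.parameter_count(), 512 * 512 + 128 + 512 * 512);
    }
}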