#![allow(unused_variables)]
use super::model_parallel::{DistributedTensor, ModelParallelContext, TensorPartition};
use crate::errors::{tensor_op_error, Result};
use crate::Tensor;
use std::sync::Arc;
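/// Linear layer whose weight is sharded along the output (column) dimension,
/// Megatron-style: each rank computes a slice of the output features and the
/// result is returned as a `DistributedTensor`, typically consumed by a
/// following row-parallel layer.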
pub struct ColumnParallelLinear {
weight: Tensor,
bias: Option<Tensor>,
mp_context: Arc<ModelParallelContext>,
#[allow(dead_code)]
in_features: usize,
out_features: usize,
}
impl ColumnParallelLinear {
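    /// Shards `out_features` across the tensor-parallel group with `div_ceil`,
    /// so the last rank may hold a smaller remainder slice. The bias, when
    /// enabled, covers the full `out_features` and is sliced per rank in
    /// `forward`.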
pub fn new(
in_features: usize,
out_features: usize,
bias: bool,
mp_context: Arc<ModelParallelContext>,
) -> Result<Self> {
let world_size = mp_context.world_size();
let rank = mp_context.rank();
let out_features_per_device = out_features.div_ceil(world_size);
let local_out_start = rank * out_features_per_device;
let local_out_end = ((rank + 1) * out_features_per_device).min(out_features);
let local_out_features = local_out_end - local_out_start;
let weight = Tensor::randn(&[in_features, local_out_features])?;
        // The full bias is kept on every rank; `forward` slices out the
        // entries for this rank's output columns before adding it.
        let bias = if bias { Some(Tensor::zeros(&[out_features])?) } else { None };
Ok(Self {
weight,
bias,
mp_context,
in_features,
out_features,
})
}
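    /// Multiplies the (replicated) input by this rank's weight shard and wraps
    /// the result in a `DistributedTensor` recording which output columns the
    /// shard covers.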
    pub fn forward(&self, input: &Tensor) -> Result<DistributedTensor> {
        // Local GEMM against this rank's column shard of the weight.
        let output = input.matmul(&self.weight)?;
        // Recompute the shard bounds with the same div_ceil split used in `new`,
        // so the bias slice and the partition metadata stay consistent.
        let rank = self.mp_context.rank();
        let world_size = self.mp_context.world_size();
        let out_features_per_device = self.out_features.div_ceil(world_size);
        let local_out_start = rank * out_features_per_device;
        let local_out_end = ((rank + 1) * out_features_per_device).min(self.out_features);
        let output = if let Some(ref bias) = self.bias {
            // Add only the bias entries that belong to this rank's output columns.
            let local_bias = bias.slice(0, local_out_start, local_out_end)?;
            output.add(&local_bias)?
        } else {
            output
        };
        // The distributed view records the full (unsharded) output shape and how
        // this shard maps into it along the last dimension.
        let mut global_shape = input.shape().to_vec();
        let last_dim = global_shape.len() - 1;
        global_shape[last_dim] = self.out_features;
        let partition = TensorPartition {
            split_dim: last_dim,
            start_idx: local_out_start,
            end_idx: local_out_end,
            num_partitions: world_size,
            partition_rank: rank,
        };
        Ok(DistributedTensor::new(output, global_shape, partition, rank))
}
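    /// Backward pass: the weight-shard gradient is computed locally (and
    /// currently discarded), while the input gradient is a partial sum that is
    /// all-reduced across the tensor-parallel group before being returned.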
pub fn backward(&mut self, grad_output: &DistributedTensor, input: &Tensor) -> Result<Tensor> {
let input_ndim = input.shape().len();
let grad_weight = input
.transpose(input_ndim.saturating_sub(2), input_ndim.saturating_sub(1))?
.matmul(&grad_output.local_shard)?;
let weight_ndim = self.weight.shape().len();
let mut grad_input = grad_output.local_shard.matmul(
&self
.weight
.transpose(weight_ndim.saturating_sub(2), weight_ndim.saturating_sub(1))?,
)?;
self.mp_context.all_reduce(&mut grad_input)?;
Ok(grad_input)
}
}
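/// Linear layer whose weight is sharded along the input (row) dimension.
/// Each rank multiplies its input shard by its weight shard; the partial
/// outputs are summed with an all-reduce and the replicated bias is added
/// once afterwards.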
pub struct RowParallelLinear {
weight: Tensor,
bias: Option<Tensor>,
mp_context: Arc<ModelParallelContext>,
#[allow(dead_code)]
in_features: usize,
_out_features: usize,
}
impl RowParallelLinear {
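    /// Shards `in_features` across the tensor-parallel group with `div_ceil`;
    /// the bias, when enabled, is the full `out_features` vector replicated on
    /// every rank and applied after the forward all-reduce.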
pub fn new(
in_features: usize,
out_features: usize,
bias: bool,
mp_context: Arc<ModelParallelContext>,
) -> Result<Self> {
let world_size = mp_context.world_size();
let rank = mp_context.rank();
let in_features_per_device = in_features.div_ceil(world_size);
let local_in_start = rank * in_features_per_device;
let local_in_end = ((rank + 1) * in_features_per_device).min(in_features);
let local_in_features = local_in_end - local_in_start;
let weight = Tensor::randn(&[local_in_features, out_features])?;
let bias = if bias { Some(Tensor::zeros(&[out_features])?) } else { None };
Ok(Self {
weight,
bias,
mp_context,
in_features,
_out_features: out_features,
})
}
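    /// Consumes a last-dimension-sharded input and returns the full output,
    /// replicated on every rank after the all-reduce.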
pub fn forward(&self, input: &DistributedTensor) -> Result<Tensor> {
let local_output = input.local_shard.matmul(&self.weight)?;
let mut output = local_output;
self.mp_context.all_reduce(&mut output)?;
if let Some(bias) = &self.bias {
output = output.add(bias)?;
}
Ok(output)
}
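    /// Backward pass: the local weight gradient is computed (and currently
    /// discarded), while the input gradient is returned sharded with the same
    /// partition as the incoming activation, so no communication is required.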
pub fn backward(
&mut self,
grad_output: &Tensor,
input: &DistributedTensor,
) -> Result<DistributedTensor> {
let input_ndim = input.local_shard.shape().len();
let grad_weight = input
.local_shard
.transpose(input_ndim.saturating_sub(2), input_ndim.saturating_sub(1))?
.matmul(grad_output)?;
let weight_ndim = self.weight.shape().len();
let grad_input_local = grad_output.matmul(
&self
.weight
.transpose(weight_ndim.saturating_sub(2), weight_ndim.saturating_sub(1))?,
)?;
let partition = input.partition.clone();
Ok(DistributedTensor::new(
grad_input_local,
input.global_shape.clone(),
partition,
self.mp_context.rank(),
))
}
}
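/// Multi-head attention with heads split across the tensor-parallel group:
/// Q/K/V use column-parallel projections so each rank attends over its own
/// subset of heads, and the output projection is row-parallel so a single
/// all-reduce restores the full hidden dimension.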
pub struct ParallelMultiHeadAttention {
#[allow(dead_code)]
num_heads_per_device: usize,
num_heads: usize,
head_dim: usize,
hidden_size: usize,
q_proj: ColumnParallelLinear,
k_proj: ColumnParallelLinear,
v_proj: ColumnParallelLinear,
o_proj: RowParallelLinear,
mp_context: Arc<ModelParallelContext>,
}
impl ParallelMultiHeadAttention {
pub fn new(
hidden_size: usize,
num_heads: usize,
mp_context: Arc<ModelParallelContext>,
) -> Result<Self> {
let world_size = mp_context.world_size();
if !num_heads.is_multiple_of(world_size) {
return Err(tensor_op_error(
"ParallelMultiHeadAttention::new",
format!(
"Number of heads {} must be divisible by world size {}",
num_heads, world_size
),
));
}
        // `head_dim` below uses integer division, and the per-rank reshape in
        // `forward` assumes hidden_size == num_heads * head_dim exactly.
        if !hidden_size.is_multiple_of(num_heads) {
            return Err(tensor_op_error(
                "ParallelMultiHeadAttention::new",
                format!(
                    "Hidden size {} must be divisible by number of heads {}",
                    hidden_size, num_heads
                ),
            ));
        }
        let num_heads_per_device = num_heads / world_size;
        let head_dim = hidden_size / num_heads;
let q_proj =
ColumnParallelLinear::new(hidden_size, hidden_size, false, mp_context.clone())?;
let k_proj =
ColumnParallelLinear::new(hidden_size, hidden_size, false, mp_context.clone())?;
let v_proj =
ColumnParallelLinear::new(hidden_size, hidden_size, false, mp_context.clone())?;
let o_proj = RowParallelLinear::new(hidden_size, hidden_size, false, mp_context.clone())?;
Ok(Self {
num_heads_per_device,
num_heads,
head_dim,
hidden_size,
q_proj,
k_proj,
v_proj,
o_proj,
mp_context,
})
}
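    /// Expects `hidden_states` of shape `[batch, seq_len, hidden_size]`; the
    /// optional `attention_mask` is added to the attention scores, which have
    /// shape `[batch, heads_local, seq_len, seq_len]` on each rank.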
pub fn forward(
&self,
hidden_states: &Tensor,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let batch_size = hidden_states.shape()[0];
let seq_len = hidden_states.shape()[1];
let q = self.q_proj.forward(hidden_states)?;
let k = self.k_proj.forward(hidden_states)?;
let v = self.v_proj.forward(hidden_states)?;
let q = q.local_shard.clone();
let k = k.local_shard.clone();
let v = v.local_shard.clone();
let num_heads_local = self.num_heads / self.mp_context.world_size();
let q = q.reshape(&[batch_size, seq_len, num_heads_local, self.head_dim])?;
let k = k.reshape(&[batch_size, seq_len, num_heads_local, self.head_dim])?;
let v = v.reshape(&[batch_size, seq_len, num_heads_local, self.head_dim])?;
let q = q.transpose(1, 2)?;
let k = k.transpose(1, 2)?;
let v = v.transpose(1, 2)?;
let k_ndim = k.shape().len();
let scores = q.matmul(&k.transpose(k_ndim.saturating_sub(2), k_ndim.saturating_sub(1))?)?;
let scores = scores.scalar_mul(1.0 / (self.head_dim as f32).sqrt())?;
let scores = if let Some(mask) = attention_mask { scores.add(mask)? } else { scores };
let scores_ndim = scores.shape().len();
let attn_probs = scores.softmax((scores_ndim as i32) - 1)?;
let attn_output = attn_probs.matmul(&v)?;
let attn_output = attn_output.transpose(1, 2)?;
let hidden_size_local = num_heads_local * self.head_dim;
let attn_output = attn_output.reshape(&[batch_size, seq_len, hidden_size_local])?;
let attn_distributed = DistributedTensor::new(
attn_output,
vec![batch_size, seq_len, self.hidden_size],
TensorPartition {
split_dim: 2,
start_idx: self.mp_context.rank() * self.hidden_size / self.mp_context.world_size(),
end_idx: ((self.mp_context.rank() + 1) * self.hidden_size
/ self.mp_context.world_size())
.min(self.hidden_size),
num_partitions: self.mp_context.world_size(),
partition_rank: self.mp_context.rank(),
},
self.mp_context.rank(),
);
self.o_proj.forward(&attn_distributed)
}
}
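/// Two-layer feed-forward block: a column-parallel expansion, an elementwise
/// activation applied to the local shard, and a row-parallel contraction
/// whose all-reduce restores the full hidden size.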
pub struct ParallelMLP {
fc1: ColumnParallelLinear,
fc2: RowParallelLinear,
activation: ActivationType,
#[allow(dead_code)]
mp_context: Arc<ModelParallelContext>,
}
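/// Activation applied between the two projections of `ParallelMLP`.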
#[derive(Debug, Clone, Copy)]
pub enum ActivationType {
Relu,
Gelu,
GeluNew,
Swiglu,
}
impl ParallelMLP {
pub fn new(
hidden_size: usize,
intermediate_size: usize,
activation: ActivationType,
mp_context: Arc<ModelParallelContext>,
) -> Result<Self> {
let fc1 =
ColumnParallelLinear::new(hidden_size, intermediate_size, false, mp_context.clone())?;
let fc2 =
RowParallelLinear::new(intermediate_size, hidden_size, false, mp_context.clone())?;
Ok(Self {
fc1,
fc2,
activation,
mp_context,
})
}
    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        // Column-parallel expansion: each rank holds a shard of the
        // intermediate activations along the last dimension.
        let hidden = self.fc1.forward(hidden_states)?;
        // The activation is elementwise on the local shard. The ReLU/GELU
        // variants preserve its shape, so fc1's distributed metadata can be
        // reused below; SwiGLU halves the last dimension, which neither this
        // metadata nor fc2's input width currently accounts for.
        let activated = self.apply_activation(&hidden.local_shard)?;
        let hidden_distributed = DistributedTensor::new(
            activated,
            hidden.global_shape.clone(),
            hidden.partition.clone(),
            hidden.device_id,
        );
        // Row-parallel contraction all-reduces back to the full hidden size.
        self.fc2.forward(&hidden_distributed)
}
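    /// Applies the configured activation to a local shard. For SwiGLU the last
    /// dimension is split in half into gate and up tensors, so the result has
    /// half the input's last dimension.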
fn apply_activation(&self, tensor: &Tensor) -> Result<Tensor> {
use crate::ops::activations::{gelu, gelu_new, relu, swiglu};
match self.activation {
ActivationType::Relu => Ok(relu(tensor)?),
ActivationType::Gelu => Ok(gelu(tensor)?),
ActivationType::GeluNew => Ok(gelu_new(tensor)?),
ActivationType::Swiglu => {
let shape = tensor.shape();
if !shape[shape.len() - 1].is_multiple_of(2) {
return Err(tensor_op_error(
"ParallelMLP::apply_activation",
"SwiGLU requires even dimension for splitting",
));
}
                // Split the last dimension in half: the first half gates the
                // second (up-projection) half.
                let last_axis = shape.len() - 1;
                let split_size = shape[last_axis] / 2;
                let gate_tensor = tensor.slice(last_axis, 0, split_size)?;
                let up_tensor = tensor.slice(last_axis, split_size, shape[last_axis])?;
Ok(swiglu(&gate_tensor, &up_tensor)?)
},
}
}
}
#[cfg(test)]
mod tests {
use super::super::model_parallel::{CommunicationBackend, ModelParallelConfig};
use super::*;
#[test]
fn test_column_parallel_linear() {
let config = ModelParallelConfig {
num_devices: 2,
device_ids: vec![0, 1],
comm_backend: CommunicationBackend::Custom,
..Default::default()
};
let mp_context =
Arc::new(ModelParallelContext::new(config).expect("operation failed in test"));
let layer = ColumnParallelLinear::new(512, 2048, true, mp_context)
.expect("operation failed in test");
        assert_eq!(layer.weight.shape(), &[512, 1024]);
    }
#[test]
fn test_parallel_attention_heads() {
let config = ModelParallelConfig {
num_devices: 4,
device_ids: vec![0, 1, 2, 3],
comm_backend: CommunicationBackend::Custom,
..Default::default()
};
let mp_context =
Arc::new(ModelParallelContext::new(config).expect("operation failed in test"));
let attn =
ParallelMultiHeadAttention::new(768, 12, mp_context).expect("operation failed in test");
        assert_eq!(attn.num_heads_per_device, 3);
        assert_eq!(attn.head_dim, 64);
    }
}