#![allow(dead_code)]
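//! Coordinator for 3D parallelism: data, tensor, and pipeline parallelism.
//!
//! A minimal usage sketch (hypothetical driver loop; construction of the
//! `ThreeDParallelismConfig` and `ProcessGroup` is assumed to happen elsewhere
//! in the crate, and `loss_backward` is a placeholder for the caller's loss):
//!
//! ```ignore
//! let mut coordinator = ThreeDParallelismCoordinator::new(config, process_group)?;
//! for micro_batch_id in 0..num_micro_batches {
//!     let output = coordinator.forward_pass(&input, micro_batch_id).await?;
//!     let grad_output = loss_backward(&output)?;
//!     coordinator.backward_pass(&grad_output, micro_batch_id).await?;
//! }
//! ```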
use crate::ProcessGroup;
use crate::TorshResult;
use std::sync::Arc;
use std::time::Instant;
use torsh_tensor::Tensor;
use super::{
    config::{PipelineSchedule, RankMapping, ThreeDParallelismConfig},
    gradient_sync::GradientSynchronizer,
    memory_management::MemoryManager,
    model_shards::ModelShards,
    performance::Performance3DMonitor,
    process_group::ProcessGroupManager,
};
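/// Coordinates 3D parallelism (data, tensor, and pipeline parallelism) for a
/// single rank. Holds the rank mapping, the per-dimension process groups, the
/// activation/memory manager, the gradient synchronizer, the model shards
/// owned by this rank, and a performance monitor.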
pub struct ThreeDParallelismCoordinator {
    config: ThreeDParallelismConfig,
    rank_mapping: RankMapping,
    process_group_manager: ProcessGroupManager,
    memory_manager: MemoryManager,
    gradient_synchronizer: GradientSynchronizer,
    performance_monitor: Performance3DMonitor,
    model_shards: ModelShards,
    current_micro_batch: usize,
}
impl ThreeDParallelismCoordinator {
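    /// Validates the configuration against the world size, derives this rank's
    /// position in the 3D grid, and constructs the per-dimension helpers.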
    pub fn new(
        config: ThreeDParallelismConfig,
        process_group: Arc<ProcessGroup>,
    ) -> TorshResult<Self> {
        let global_rank = process_group.rank();
        let world_size = process_group.world_size() as usize;
        config.validate(world_size)?;
        let rank_mapping = RankMapping::new(&config, global_rank as usize);
        let process_group_manager = ProcessGroupManager::new(&config, process_group.clone())?;
        let memory_manager = MemoryManager::new(&config, &rank_mapping)?;
        let gradient_synchronizer = GradientSynchronizer::new(&config, &rank_mapping)?;
        let performance_monitor = Performance3DMonitor::new(&rank_mapping);
        let model_shards = ModelShards::new(&config)?;
        Ok(Self {
            config,
            rank_mapping,
            process_group_manager,
            memory_manager,
            gradient_synchronizer,
            performance_monitor,
            model_shards,
            current_micro_batch: 0,
        })
    }
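    /// Runs the forward pass for one micro-batch: splits (or receives) the
    /// tensor-parallel input, dispatches to the configured pipeline schedule,
    /// and records timing statistics.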
    pub async fn forward_pass(
        &mut self,
        input: &Tensor<f32>,
        micro_batch_id: usize,
    ) -> TorshResult<Tensor<f32>> {
        let start_time = Instant::now();
        let sequence_length = input.shape().dims()[1];
        let tp_input = if self.rank_mapping.is_tp_head() {
            self.split_tensor_parallel(input)?
        } else {
            self.receive_tensor_parallel_input(input.shape().dims(), micro_batch_id)
                .await?
        };
        let mut current_activations = tp_input;
        match self.config.pipeline_schedule {
            PipelineSchedule::Interleaved => {
                current_activations = self
                    .interleaved_forward_pass(&current_activations, micro_batch_id)
                    .await?;
            }
            PipelineSchedule::GPipe => {
                current_activations = self
                    .gpipe_forward_pass(&current_activations, micro_batch_id)
                    .await?;
            }
            PipelineSchedule::OneForwardOneBackward => {
                current_activations = self
                    .f1b1_forward_pass(&current_activations, micro_batch_id)
                    .await?;
            }
            PipelineSchedule::RoundRobin => {
                current_activations = self
                    .round_robin_forward_pass(&current_activations, micro_batch_id)
                    .await?;
            }
        }
        self.performance_monitor
            .record_forward_pass(start_time.elapsed(), sequence_length)
            .await;
        Ok(current_activations)
    }
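    /// Runs the backward pass for one micro-batch through the configured
    /// pipeline schedule, and triggers gradient synchronization after the last
    /// micro-batch of the step.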
    pub async fn backward_pass(
        &mut self,
        grad_output: &Tensor<f32>,
        micro_batch_id: usize,
    ) -> TorshResult<()> {
        let start_time = Instant::now();
        let sequence_length = grad_output.shape().dims()[1];
        let _current_gradients = match self.config.pipeline_schedule {
            PipelineSchedule::Interleaved => {
                self.interleaved_backward_pass(grad_output, micro_batch_id)
                    .await?
            }
            PipelineSchedule::GPipe => {
                self.gpipe_backward_pass(grad_output, micro_batch_id)
                    .await?
            }
            PipelineSchedule::OneForwardOneBackward => {
                self.f1b1_backward_pass(grad_output, micro_batch_id).await?
            }
            PipelineSchedule::RoundRobin => {
                self.round_robin_backward_pass(grad_output, micro_batch_id)
                    .await?
            }
        };
        if micro_batch_id == self.get_num_micro_batches() - 1 {
            self.gradient_synchronizer
                .synchronize_gradients(&self.model_shards)
                .await?;
        }
        self.performance_monitor
            .record_backward_pass(start_time.elapsed(), sequence_length)
            .await;
        Ok(())
    }
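    /// Forward pass through this rank's pipeline stage: applies the stage's
    /// layers in order, stores activations through the memory manager when
    /// gradient checkpointing is enabled, and sends the output to the next
    /// pipeline stage if one exists.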
    async fn interleaved_forward_pass(
        &mut self,
        input: &Tensor<f32>,
        micro_batch_id: usize,
    ) -> TorshResult<Tensor<f32>> {
        let layers_per_stage = self.config.layers_per_stage();
        let mut activations = input.clone();
        for layer_idx in 0..layers_per_stage {
            let global_layer_idx = self.rank_mapping.pp_rank * layers_per_stage + layer_idx;
            activations = self.apply_layer(&activations, global_layer_idx).await?;
            if self.config.enable_gradient_checkpointing {
                self.memory_manager
                    .store_activation(&activations, global_layer_idx, micro_batch_id)
                    .await?;
            }
        }
        if let Some(next_rank) = self.rank_mapping.next_pp_rank() {
            self.process_group_manager
                .send_to_next_stage(&activations, next_rank, micro_batch_id)
                .await?;
        }
        Ok(activations)
    }
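    /// Backward pass through this rank's pipeline stage: walks the stage's
    /// layers in reverse, retrieves stored activations when gradient
    /// checkpointing is enabled, and sends the input gradients to the previous
    /// pipeline stage if one exists.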
    async fn interleaved_backward_pass(
        &mut self,
        grad_output: &Tensor<f32>,
        micro_batch_id: usize,
    ) -> TorshResult<Tensor<f32>> {
        let layers_per_stage = self.config.layers_per_stage();
        let mut gradients = grad_output.clone();
        for layer_idx in (0..layers_per_stage).rev() {
            let global_layer_idx = self.rank_mapping.pp_rank * layers_per_stage + layer_idx;
            let activation = if self.config.enable_gradient_checkpointing {
                self.memory_manager
                    .retrieve_activation(global_layer_idx, micro_batch_id)
                    .await?
            } else {
                Tensor::zeros(gradients.shape().dims(), gradients.device())?
            };
            gradients = self
                .apply_layer_backward(&gradients, &activation, global_layer_idx)
                .await?;
        }
        if let Some(prev_rank) = self.rank_mapping.prev_pp_rank() {
            self.process_group_manager
                .send_to_prev_stage(&gradients, prev_rank, micro_batch_id)
                .await?;
        }
        Ok(gradients)
    }
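    // The remaining pipeline schedules currently delegate to the interleaved
    // implementation.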
    async fn gpipe_forward_pass(
        &mut self,
        input: &Tensor<f32>,
        micro_batch_id: usize,
    ) -> TorshResult<Tensor<f32>> {
        self.interleaved_forward_pass(input, micro_batch_id).await
    }
    async fn gpipe_backward_pass(
        &mut self,
        grad_output: &Tensor<f32>,
        micro_batch_id: usize,
    ) -> TorshResult<Tensor<f32>> {
        self.interleaved_backward_pass(grad_output, micro_batch_id)
            .await
    }
    async fn f1b1_forward_pass(
        &mut self,
        input: &Tensor<f32>,
        micro_batch_id: usize,
    ) -> TorshResult<Tensor<f32>> {
        self.interleaved_forward_pass(input, micro_batch_id).await
    }
    async fn f1b1_backward_pass(
        &mut self,
        grad_output: &Tensor<f32>,
        micro_batch_id: usize,
    ) -> TorshResult<Tensor<f32>> {
        self.interleaved_backward_pass(grad_output, micro_batch_id)
            .await
    }
    async fn round_robin_forward_pass(
        &mut self,
        input: &Tensor<f32>,
        micro_batch_id: usize,
    ) -> TorshResult<Tensor<f32>> {
        self.interleaved_forward_pass(input, micro_batch_id).await
    }
    async fn round_robin_backward_pass(
        &mut self,
        grad_output: &Tensor<f32>,
        micro_batch_id: usize,
    ) -> TorshResult<Tensor<f32>> {
        self.interleaved_backward_pass(grad_output, micro_batch_id)
            .await
    }
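    /// Applies a single (simplified) layer: a matmul with the layer shard's
    /// weight plus an optional bias. Layers outside this rank's shards pass
    /// the input through unchanged.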
    async fn apply_layer(
        &mut self,
        input: &Tensor<f32>,
        layer_idx: usize,
    ) -> TorshResult<Tensor<f32>> {
        let stage_idx = layer_idx / self.config.layers_per_stage();
        let layer_in_stage = layer_idx % self.config.layers_per_stage();
        if stage_idx < self.model_shards.pipeline_stages.len()
            && layer_in_stage < self.model_shards.pipeline_stages[stage_idx].len()
        {
            let layer_shard = &self.model_shards.pipeline_stages[stage_idx][layer_in_stage];
            let output = input.matmul(&layer_shard.weight)?;
            if let Some(ref bias) = layer_shard.bias {
                let output_with_bias = output.add(bias)?;
                Ok(output_with_bias)
            } else {
                Ok(output)
            }
        } else {
            Ok(input.clone())
        }
    }
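    /// Backward of the simplified layer: propagates the input gradient via a
    /// matmul with the transposed weight. Weight and bias gradient
    /// accumulation is not performed here.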
    async fn apply_layer_backward(
        &mut self,
        grad_output: &Tensor<f32>,
        _activation: &Tensor<f32>,
        layer_idx: usize,
    ) -> TorshResult<Tensor<f32>> {
        let stage_idx = layer_idx / self.config.layers_per_stage();
        let layer_in_stage = layer_idx % self.config.layers_per_stage();
        if stage_idx < self.model_shards.pipeline_stages.len()
            && layer_in_stage < self.model_shards.pipeline_stages[stage_idx].len()
        {
            let layer_shard = &self.model_shards.pipeline_stages[stage_idx][layer_in_stage];
            let grad_input = grad_output.matmul(&layer_shard.weight.transpose(0, 1)?)?;
            Ok(grad_input)
        } else {
            Ok(grad_output.clone())
        }
    }
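    /// Computes this TP rank's chunk of the hidden dimension. Placeholder:
    /// returns a zero tensor with the chunk shape instead of slicing the
    /// input.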
    fn split_tensor_parallel(&self, tensor: &Tensor<f32>) -> TorshResult<Tensor<f32>> {
        let shape = tensor.shape();
        let dims = shape.dims();
        let hidden_dim = dims[dims.len() - 1];
        let chunk_size = hidden_dim / self.config.tp_size;
        // Chunk bounds for this tensor-parallel rank; the actual slice is not
        // taken yet, so the end index is currently unused.
        let start_idx = self.rank_mapping.tp_rank * chunk_size;
        let _end_idx = start_idx + chunk_size;
        let mut new_dims = dims.to_vec();
        new_dims[dims.len() - 1] = chunk_size;
        Ok(Tensor::zeros(&new_dims, tensor.device())?)
    }
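    /// Placeholder for receiving the tensor-parallel input from the TP head:
    /// currently returns a zero tensor of the requested shape on the CPU.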
    async fn receive_tensor_parallel_input(
        &self,
        shape: &[usize],
        _micro_batch_id: usize,
    ) -> TorshResult<Tensor<f32>> {
        Ok(Tensor::zeros(shape, torsh_core::DeviceType::Cpu)?)
    }
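    /// Number of micro-batches per optimizer step. Currently a hardcoded
    /// placeholder rather than a value taken from the configuration.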
    fn get_num_micro_batches(&self) -> usize {
        4
    }
    pub fn get_config(&self) -> &ThreeDParallelismConfig {
        &self.config
    }
    pub fn get_rank_mapping(&self) -> &RankMapping {
        &self.rank_mapping
    }
    pub fn get_model_shards(&self) -> &ModelShards {
        &self.model_shards
    }
    pub fn get_performance_stats(&self) -> super::performance::Performance3DStats {
        self.performance_monitor.get_stats()
    }
    pub fn update_config(&mut self, config: ThreeDParallelismConfig) -> TorshResult<()> {
        config.validate(self.rank_mapping.world_size)?;
        self.config = config;
        Ok(())
    }
}