#![allow(dead_code)]
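
//! Process-group management for 3D parallelism.
//!
//! [`ProcessGroupManager`] owns the data-parallel (DP), tensor-parallel (TP),
//! and pipeline-parallel (PP) process groups derived from a
//! [`ThreeDParallelismConfig`], exposes lookups for the groups the local rank
//! belongs to, and provides pipeline point-to-point and collective helpers.
//! The communication paths in this module are currently simulated (short
//! sleeps and zero-filled tensors) rather than backed by a real transport.
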
use crate::{ProcessGroup, TorshDistributedError, TorshResult};
use std::collections::HashMap;
use std::sync::Arc;
use torsh_tensor::Tensor;
use super::config::{CommunicationStrategy, ProcessGroupIds, RankMapping, ThreeDParallelismConfig};

/// Manages the process groups used by 3D parallelism: data parallel (DP),
/// tensor parallel (TP), and pipeline parallel (PP).
pub struct ProcessGroupManager {
    /// Global process group spanning all ranks.
    main_process_group: Arc<ProcessGroup>,
    /// Per-dimension sub-groups, keyed by group name (one per fixed pair of
    /// coordinates in the other two dimensions).
    dp_process_groups: HashMap<String, Arc<ProcessGroup>>,
    tp_process_groups: HashMap<String, Arc<ProcessGroup>>,
    pp_process_groups: HashMap<String, Arc<ProcessGroup>>,
    group_ids: ProcessGroupIds,
    /// Mapping from this process's global rank to its (dp, tp, pp) coordinates.
    rank_mapping: RankMapping,
    comm_strategy: CommunicationStrategy,
}

impl ProcessGroupManager {
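    /// Creates a manager from the 3D-parallelism configuration and the main
    /// (global) process group, deriving this rank's (dp, tp, pp) coordinates
    /// and building the per-dimension sub-groups.
    ///
    /// A minimal usage sketch (not compiled here; `config` and `main_pg` are
    /// assumed to have been constructed elsewhere):
    ///
    /// ```ignore
    /// let manager = ProcessGroupManager::new(&config, main_pg)?;
    /// let dp_group = manager.get_dp_process_group();
    /// ```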
    pub fn new(
        config: &ThreeDParallelismConfig,
        main_process_group: Arc<ProcessGroup>,
    ) -> TorshResult<Self> {
        // Derive this process's (dp, tp, pp) coordinates from its global rank.
        let global_rank = main_process_group.rank();
        let rank_mapping = RankMapping::new(config, global_rank as usize);
        let group_ids = ProcessGroupIds::new(config);

        // Build the per-dimension sub-groups up front.
        let dp_process_groups = Self::create_dp_process_groups(config, &main_process_group)?;
        let tp_process_groups = Self::create_tp_process_groups(config, &main_process_group)?;
        let pp_process_groups = Self::create_pp_process_groups(config, &main_process_group)?;

        Ok(Self {
            main_process_group,
            dp_process_groups,
            tp_process_groups,
            pp_process_groups,
            group_ids,
            rank_mapping,
            comm_strategy: config.comm_strategy,
        })
    }

    /// Builds the data-parallel groups: one group per (tp, pp) coordinate,
    /// containing every dp rank at that coordinate.
    fn create_dp_process_groups(
        config: &ThreeDParallelismConfig,
        main_pg: &Arc<ProcessGroup>,
    ) -> TorshResult<HashMap<String, Arc<ProcessGroup>>> {
        let mut dp_groups = HashMap::new();
        for tp_rank in 0..config.tp_size {
            for pp_rank in 0..config.pp_size {
                let group_name = format!("dp_group_tp{}_pp{}", tp_rank, pp_rank);
                let mut group_ranks = Vec::new();
                for dp_rank in 0..config.dp_size {
                    let global_rank =
                        RankMapping::from_3d_coords(config, dp_rank, tp_rank, pp_rank);
                    group_ranks.push(global_rank);
                }
                // Placeholder: no backend sub-group is created from `group_ranks`
                // yet; the main process group is reused for every named group.
                let pg = Arc::clone(main_pg);
                dp_groups.insert(group_name, pg);
            }
        }
        Ok(dp_groups)
    }

    /// Builds the tensor-parallel groups: one group per (dp, pp) coordinate,
    /// containing every tp rank at that coordinate.
    fn create_tp_process_groups(
        config: &ThreeDParallelismConfig,
        main_pg: &Arc<ProcessGroup>,
    ) -> TorshResult<HashMap<String, Arc<ProcessGroup>>> {
        let mut tp_groups = HashMap::new();
        for dp_rank in 0..config.dp_size {
            for pp_rank in 0..config.pp_size {
                let group_name = format!("tp_group_dp{}_pp{}", dp_rank, pp_rank);
                let mut group_ranks = Vec::new();
                for tp_rank in 0..config.tp_size {
                    let global_rank =
                        RankMapping::from_3d_coords(config, dp_rank, tp_rank, pp_rank);
                    group_ranks.push(global_rank);
                }
                // Placeholder: the main process group is reused; see
                // `create_dp_process_groups`.
                let pg = Arc::clone(main_pg);
                tp_groups.insert(group_name, pg);
            }
        }
        Ok(tp_groups)
    }

    /// Builds the pipeline-parallel groups: one group per (dp, tp) coordinate,
    /// containing every pp rank at that coordinate.
    fn create_pp_process_groups(
        config: &ThreeDParallelismConfig,
        main_pg: &Arc<ProcessGroup>,
    ) -> TorshResult<HashMap<String, Arc<ProcessGroup>>> {
        let mut pp_groups = HashMap::new();
        for dp_rank in 0..config.dp_size {
            for tp_rank in 0..config.tp_size {
                let group_name = format!("pp_group_dp{}_tp{}", dp_rank, tp_rank);
                let mut group_ranks = Vec::new();
                for pp_rank in 0..config.pp_size {
                    let global_rank =
                        RankMapping::from_3d_coords(config, dp_rank, tp_rank, pp_rank);
                    group_ranks.push(global_rank);
                }
                // Placeholder: the main process group is reused; see
                // `create_dp_process_groups`.
                let pg = Arc::clone(main_pg);
                pp_groups.insert(group_name, pg);
            }
        }
        Ok(pp_groups)
    }

    /// Returns the data-parallel group this rank belongs to, if any.
    pub fn get_dp_process_group(&self) -> Option<&Arc<ProcessGroup>> {
        let group_id = self
            .group_ids
            .get_dp_group_id(self.rank_mapping.tp_rank, self.rank_mapping.pp_rank)?;
        self.dp_process_groups.get(group_id)
    }

    /// Returns the tensor-parallel group this rank belongs to, if any.
    pub fn get_tp_process_group(&self) -> Option<&Arc<ProcessGroup>> {
        let group_id = self
            .group_ids
            .get_tp_group_id(self.rank_mapping.dp_rank, self.rank_mapping.pp_rank)?;
        self.tp_process_groups.get(group_id)
    }

    /// Returns the pipeline-parallel group this rank belongs to, if any.
    pub fn get_pp_process_group(&self) -> Option<&Arc<ProcessGroup>> {
        let group_id = self
            .group_ids
            .get_pp_group_id(self.rank_mapping.dp_rank, self.rank_mapping.tp_rank)?;
        self.pp_process_groups.get(group_id)
    }
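
    /// Sends a micro-batch tensor to the next pipeline stage.
    ///
    /// A minimal sketch of how a stage might pair this with the matching
    /// receive on the following stage (not compiled here; `manager`, ranks,
    /// and shapes are illustrative):
    ///
    /// ```ignore
    /// // On stage s:
    /// manager.send_to_next_stage(&activations, next_rank, micro_batch_id).await?;
    /// // On stage s + 1:
    /// let activations = manager
    ///     .receive_from_prev_stage(&[batch, hidden], prev_rank, micro_batch_id)
    ///     .await?;
    /// ```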
    pub async fn send_to_next_stage(
        &self,
        tensor: &Tensor<f32>,
        next_rank: usize,
        micro_batch_id: usize,
    ) -> TorshResult<()> {
        if let Some(pp_pg) = self.get_pp_process_group() {
            let comm_req = CommunicationRequest {
                tensor: tensor.clone(),
                target_rank: next_rank,
                micro_batch_id,
                comm_type: CommunicationType::PipelineForward,
            };
            self.execute_communication(&comm_req, pp_pg).await?;
        }
        Ok(())
    }

    /// Sends a micro-batch tensor back to the previous pipeline stage
    /// (used on the backward path).
    pub async fn send_to_prev_stage(
        &self,
        tensor: &Tensor<f32>,
        prev_rank: usize,
        micro_batch_id: usize,
    ) -> TorshResult<()> {
        if let Some(pp_pg) = self.get_pp_process_group() {
            let comm_req = CommunicationRequest {
                tensor: tensor.clone(),
                target_rank: prev_rank,
                micro_batch_id,
                comm_type: CommunicationType::PipelineBackward,
            };
            self.execute_communication(&comm_req, pp_pg).await?;
        }
        Ok(())
    }

    /// Receives a tensor of `shape` from the previous pipeline stage.
    ///
    /// Placeholder: returns a zero-filled CPU tensor instead of performing a
    /// real receive.
    pub async fn receive_from_prev_stage(
        &self,
        shape: &[usize],
        _prev_rank: usize,
        _micro_batch_id: usize,
    ) -> TorshResult<Tensor<f32>> {
        if let Some(_pp_pg) = self.get_pp_process_group() {
            let tensor = Tensor::zeros(shape, torsh_core::DeviceType::Cpu)?;
            Ok(tensor)
        } else {
            Err(TorshDistributedError::InternalError(
                "Pipeline parallel process group not found".to_string(),
            ))
        }
    }

    /// Receives a tensor of `shape` from the next pipeline stage.
    ///
    /// Placeholder: returns a zero-filled CPU tensor instead of performing a
    /// real receive.
    pub async fn receive_from_next_stage(
        &self,
        shape: &[usize],
        _next_rank: usize,
        _micro_batch_id: usize,
    ) -> TorshResult<Tensor<f32>> {
        if let Some(_pp_pg) = self.get_pp_process_group() {
            let tensor = Tensor::zeros(shape, torsh_core::DeviceType::Cpu)?;
            Ok(tensor)
        } else {
            Err(TorshDistributedError::InternalError(
                "Pipeline parallel process group not found".to_string(),
            ))
        }
    }
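
    /// All-reduces `tensor` across the data-parallel group using the
    /// configured communication strategy; typically used to synchronize
    /// gradients after the backward pass.
    ///
    /// A minimal sketch (not compiled here; `manager` and `grad` are
    /// illustrative):
    ///
    /// ```ignore
    /// manager.all_reduce_dp(&mut grad).await?;
    /// ```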
    pub async fn all_reduce_dp(&self, tensor: &mut Tensor<f32>) -> TorshResult<()> {
        if let Some(dp_pg) = self.get_dp_process_group() {
            self.execute_all_reduce(tensor, dp_pg, self.comm_strategy)
                .await?;
        }
        Ok(())
    }

    /// All-reduces `tensor` across the tensor-parallel group.
    pub async fn all_reduce_tp(&self, tensor: &mut Tensor<f32>) -> TorshResult<()> {
        if let Some(tp_pg) = self.get_tp_process_group() {
            self.execute_all_reduce(tensor, tp_pg, self.comm_strategy)
                .await?;
        }
        Ok(())
    }

    /// All-gathers `tensor` across the tensor-parallel group; if this rank has
    /// no TP group, a clone of the input is returned as-is.
    pub async fn all_gather_tp(&self, tensor: &Tensor<f32>) -> TorshResult<Tensor<f32>> {
        if let Some(tp_pg) = self.get_tp_process_group() {
            self.execute_all_gather(tensor, tp_pg).await
        } else {
            Ok(tensor.clone())
        }
    }

    /// Reduce-scatters `tensor` across the tensor-parallel group; if this rank
    /// has no TP group, a clone of the input is returned as-is.
    pub async fn reduce_scatter_tp(&self, tensor: &Tensor<f32>) -> TorshResult<Tensor<f32>> {
        if let Some(tp_pg) = self.get_tp_process_group() {
            self.execute_reduce_scatter(tensor, tp_pg).await
        } else {
            Ok(tensor.clone())
        }
    }

    /// Executes a pipeline point-to-point transfer for `request`.
    ///
    /// Placeholder: simulates the transfer with a short delay instead of
    /// sending data over the process group.
    async fn execute_communication(
        &self,
        request: &CommunicationRequest,
        _process_group: &Arc<ProcessGroup>,
    ) -> TorshResult<()> {
        match request.comm_type {
            CommunicationType::PipelineForward | CommunicationType::PipelineBackward => {
                tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
                Ok(())
            }
        }
    }

    /// Dispatches an all-reduce to the algorithm selected by `strategy`.
    async fn execute_all_reduce(
        &self,
        tensor: &mut Tensor<f32>,
        process_group: &Arc<ProcessGroup>,
        strategy: CommunicationStrategy,
    ) -> TorshResult<()> {
        match strategy {
            CommunicationStrategy::AllReduce => {
                self.standard_all_reduce(tensor, process_group).await
            }
            CommunicationStrategy::HierarchicalAllReduce => {
                self.hierarchical_all_reduce(tensor, process_group).await
            }
            CommunicationStrategy::RingAllReduce => {
                self.ring_all_reduce(tensor, process_group).await
            }
            CommunicationStrategy::TreeAllReduce => {
                self.tree_all_reduce(tensor, process_group).await
            }
            CommunicationStrategy::Adaptive => {
                // Adaptive: tree all-reduce for small tensors (< 1 MiB),
                // ring all-reduce for larger ones.
                let tensor_size = tensor.numel() * std::mem::size_of::<f32>();
                if tensor_size < 1024 * 1024 {
                    self.tree_all_reduce(tensor, process_group).await
                } else {
                    self.ring_all_reduce(tensor, process_group).await
                }
            }
        }
    }

    // The following all-reduce variants are placeholders: each simulates its
    // relative cost with a short sleep and does not modify the tensor.

    async fn standard_all_reduce(
        &self,
        _tensor: &mut Tensor<f32>,
        _process_group: &Arc<ProcessGroup>,
    ) -> TorshResult<()> {
        tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
        Ok(())
    }

    async fn hierarchical_all_reduce(
        &self,
        _tensor: &mut Tensor<f32>,
        _process_group: &Arc<ProcessGroup>,
    ) -> TorshResult<()> {
        tokio::time::sleep(tokio::time::Duration::from_micros(80)).await;
        Ok(())
    }

    async fn ring_all_reduce(
        &self,
        _tensor: &mut Tensor<f32>,
        _process_group: &Arc<ProcessGroup>,
    ) -> TorshResult<()> {
        tokio::time::sleep(tokio::time::Duration::from_micros(120)).await;
        Ok(())
    }

    async fn tree_all_reduce(
        &self,
        _tensor: &mut Tensor<f32>,
        _process_group: &Arc<ProcessGroup>,
    ) -> TorshResult<()> {
        tokio::time::sleep(tokio::time::Duration::from_micros(60)).await;
        Ok(())
    }

    /// All-gather placeholder: allocates the gathered shape (world_size copies
    /// along dim 0) and returns it zero-filled after a simulated delay.
    async fn execute_all_gather(
        &self,
        tensor: &Tensor<f32>,
        process_group: &Arc<ProcessGroup>,
    ) -> TorshResult<Tensor<f32>> {
        let gathered_shape = {
            let mut shape = tensor.shape().dims().to_vec();
            shape[0] *= process_group.world_size() as usize;
            shape
        };
        let gathered_tensor = Tensor::zeros(&gathered_shape, tensor.device())?;
        tokio::time::sleep(tokio::time::Duration::from_micros(50)).await;
        Ok(gathered_tensor)
    }

    /// Reduce-scatter placeholder: allocates the scattered shape (dim 0 split
    /// across world_size) and returns it zero-filled after a simulated delay.
    async fn execute_reduce_scatter(
        &self,
        tensor: &Tensor<f32>,
        process_group: &Arc<ProcessGroup>,
    ) -> TorshResult<Tensor<f32>> {
        let scattered_shape = {
            let mut shape = tensor.shape().dims().to_vec();
            shape[0] /= process_group.world_size() as usize;
            shape
        };
        let scattered_tensor = Tensor::zeros(&scattered_shape, tensor.device())?;
        tokio::time::sleep(tokio::time::Duration::from_micros(60)).await;
        Ok(scattered_tensor)
    }

    /// Returns communication statistics.
    /// Placeholder: the values below are hardcoded, not measured.
    pub fn get_communication_stats(&self) -> CommunicationStats {
        CommunicationStats {
            total_communications: 1000,
            total_bytes_communicated: 1024 * 1024 * 100,
            average_latency_ms: 5.2,
            bandwidth_gbps: 25.6,
        }
    }
}

/// A single point-to-point pipeline communication request.
#[derive(Debug, Clone)]
struct CommunicationRequest {
    tensor: Tensor<f32>,
    target_rank: usize,
    micro_batch_id: usize,
    comm_type: CommunicationType,
}

/// Direction of a pipeline-parallel transfer.
#[derive(Debug, Clone, Copy)]
enum CommunicationType {
    PipelineForward,
    PipelineBackward,
}

/// Aggregate communication statistics reported by [`ProcessGroupManager`].
#[derive(Debug, Clone)]
pub struct CommunicationStats {
    pub total_communications: u64,
    pub total_bytes_communicated: u64,
    pub average_latency_ms: f64,
    pub bandwidth_gbps: f64,
}