use std::collections::HashMap;
use std::sync::Arc;
use super::Communicator;
use crate::error::{Error, Result};
/// Names a dimension of parallelism.
///
/// Values are `Copy`, comparable and hashable, so they can serve directly as
/// map keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ParallelDim {
    /// Data-parallel dimension.
    Data,
    /// Tensor-parallel dimension.
    Tensor,
    /// Pipeline-parallel dimension.
    Pipeline,
    /// Expert-parallel dimension.
    Expert,
}
/// Bundles a world communicator with the per-dimension communicators
/// associated with it, keyed by [`ParallelDim`].
pub struct CommunicatorGroup {
/// The communicator handed to `new`; exposed unchanged through `world()`.
world: Arc<dyn Communicator>,
/// Communicators per parallel dimension. A dimension that was never split
/// or registered simply has no entry.
dims: HashMap<ParallelDim, Arc<dyn Communicator>>,
}
impl CommunicatorGroup {
pub fn new(
world: Arc<dyn Communicator>,
tp_size: usize,
pp_size: usize,
dp_size: usize,
) -> Result<Self> {
let ws = world.world_size();
if tp_size * pp_size * dp_size != ws {
return Err(Error::Backend(format!(
"CommunicatorGroup: tp({tp_size}) * pp({pp_size}) * dp({dp_size}) = {} != world_size({ws})",
tp_size * pp_size * dp_size,
)));
}
let rank = world.rank();
let mut dims = HashMap::new();
let tp_idx = rank % tp_size;
let pp_idx = (rank / tp_size) % pp_size;
let dp_idx = rank / (tp_size * pp_size);
if tp_size > 1 {
let tp_color = (dp_idx * pp_size + pp_idx) as u32;
if let Some(comm) = world.split(tp_color, tp_idx as u32)? {
dims.insert(ParallelDim::Tensor, Arc::from(comm));
}
}
if pp_size > 1 {
let color_offset = dp_size * pp_size;
let pp_color = (color_offset + dp_idx * tp_size + tp_idx) as u32;
if let Some(comm) = world.split(pp_color, pp_idx as u32)? {
dims.insert(ParallelDim::Pipeline, Arc::from(comm));
}
}
if dp_size > 1 {
let color_offset = dp_size * pp_size + dp_size * tp_size;
let dp_color = (color_offset + pp_idx * tp_size + tp_idx) as u32;
if let Some(comm) = world.split(dp_color, dp_idx as u32)? {
dims.insert(ParallelDim::Data, Arc::from(comm));
}
}
Ok(Self { world, dims })
}
pub fn world(&self) -> &Arc<dyn Communicator> {
&self.world
}
pub fn tp(&self) -> Option<&Arc<dyn Communicator>> {
self.dims.get(&ParallelDim::Tensor)
}
pub fn pp(&self) -> Option<&Arc<dyn Communicator>> {
self.dims.get(&ParallelDim::Pipeline)
}
pub fn dp(&self) -> Option<&Arc<dyn Communicator>> {
self.dims.get(&ParallelDim::Data)
}
pub fn get(&self, dim: ParallelDim) -> Option<&Arc<dyn Communicator>> {
self.dims.get(&dim)
}
pub fn set_expert(&mut self, comm: Arc<dyn Communicator>) {
self.dims.insert(ParallelDim::Expert, comm);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::communicator::NoOpCommunicator;

    /// Builds a world communicator backed by `NoOpCommunicator`.
    fn noop_world() -> Arc<dyn Communicator> {
        Arc::new(NoOpCommunicator)
    }

    /// A variant equals itself and differs from any other variant.
    #[test]
    fn test_parallel_dim_eq() {
        let (data, tensor) = (ParallelDim::Data, ParallelDim::Tensor);
        assert_eq!(data, ParallelDim::Data);
        assert_ne!(data, tensor);
    }

    /// With every dimension of size 1 nothing is split, so no per-dimension
    /// communicator is registered.
    #[test]
    fn test_single_rank_group() {
        let group = CommunicatorGroup::new(noop_world(), 1, 1, 1).unwrap();

        assert_eq!(group.world().world_size(), 1);
        assert!(group.tp().is_none());
        assert!(group.pp().is_none());
        assert!(group.dp().is_none());
    }

    /// Sizes whose product (2 * 2 * 2) does not match the world size are
    /// rejected.
    #[test]
    fn test_invalid_dimensions() {
        assert!(CommunicatorGroup::new(noop_world(), 2, 2, 2).is_err());
    }
}