#![allow(clippy::result_large_err)]
use crate::{Device, Result, Tensor, TensorError};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Collective communication primitives this module models.
///
/// NOTE(review): no variant is consumed anywhere in this file — semantics
/// below follow standard collective-communication naming; confirm against
/// the executing backend before relying on them.
#[derive(Debug, Clone)]
pub enum CollectiveOp {
    /// Reduce tensors across all devices with the given reduction op,
    /// leaving the result on every device.
    AllReduce(ReductionOp),
    /// Copy a tensor from `src_device` to every device in the group.
    Broadcast { src_device: Device },
    /// Gather every device's tensor onto all devices.
    AllGather,
    /// Reduce across devices, then scatter shards of the result.
    ReduceScatter(ReductionOp),
    /// Point-to-point send from `src_device` to `dst_device`.
    Send {
        src_device: Device,
        dst_device: Device,
    },
    /// Point-to-point receive on `dst_device` from `src_device`.
    Recv {
        src_device: Device,
        dst_device: Device,
    },
}
/// Element-wise reduction applied by reduce-style collectives.
#[derive(Debug, Clone, Copy)]
pub enum ReductionOp {
    /// Element-wise sum.
    Sum,
    /// Element-wise mean (sum divided by the group size — see
    /// `CollectiveManager::simple_all_reduce`).
    Mean,
    /// Element-wise maximum.
    Max,
    /// Element-wise minimum.
    Min,
    /// Element-wise product.
    Product,
}
/// An ordered set of devices participating in collective operations.
///
/// A device's *rank* is its index in the device list passed to `new`.
#[derive(Debug, Clone)]
pub struct CommunicationGroup {
    // Devices in rank order.
    devices: Vec<Device>,
    // Reverse lookup: device -> rank (index into `devices`).
    rank_map: HashMap<Device, usize>,
}
impl CommunicationGroup {
pub fn new(devices: Vec<Device>) -> Self {
let rank_map = devices
.iter()
.enumerate()
.map(|(rank, &device)| (device, rank))
.collect();
Self { devices, rank_map }
}
pub fn devices(&self) -> &[Device] {
&self.devices
}
pub fn rank(&self, device: &Device) -> Option<usize> {
self.rank_map.get(device).copied()
}
pub fn device_at_rank(&self, rank: usize) -> Option<Device> {
self.devices.get(rank).copied()
}
pub fn size(&self) -> usize {
self.devices.len()
}
}
/// Registry of named communication groups, plus the name of the default
/// group used when an operation is invoked with `group_name = None`.
pub struct CollectiveManager {
    // Named groups keyed by user-chosen group name.
    groups: HashMap<String, CommunicationGroup>,
    // Name of the group used when no explicit group is given; the first
    // group created becomes the default (see `create_group`).
    default_group: Option<String>,
}
impl CollectiveManager {
    /// Creates an empty manager with no registered groups.
    pub fn new() -> Self {
        Self {
            groups: HashMap::new(),
            default_group: None,
        }
    }

    /// Registers `devices` as a communication group under `name`.
    ///
    /// The first group ever registered becomes the default group.
    /// Re-registering an existing name replaces the previous group.
    ///
    /// # Errors
    /// Fails if `devices` is empty.
    pub fn create_group(&mut self, name: String, devices: Vec<Device>) -> Result<()> {
        if devices.is_empty() {
            return Err(TensorError::invalid_argument(
                "Communication group cannot be empty".to_string(),
            ));
        }
        self.groups
            .insert(name.clone(), CommunicationGroup::new(devices));
        // The first created group implicitly becomes the default.
        if self.default_group.is_none() {
            self.default_group = Some(name);
        }
        Ok(())
    }

    /// Makes an existing group the default for operations that pass
    /// `group_name = None`.
    ///
    /// # Errors
    /// Fails if no group named `name` has been created.
    pub fn set_default_group(&mut self, name: String) -> Result<()> {
        if !self.groups.contains_key(&name) {
            return Err(TensorError::invalid_argument(format!(
                "Group '{name}' does not exist"
            )));
        }
        self.default_group = Some(name);
        Ok(())
    }

    /// Looks up a group by name.
    pub fn get_group(&self, name: &str) -> Option<&CommunicationGroup> {
        self.groups.get(name)
    }

    /// Returns the default group, if one has been set and still exists.
    pub fn get_default_group(&self) -> Option<&CommunicationGroup> {
        self.default_group
            .as_ref()
            .and_then(|name| self.groups.get(name))
    }

    /// Resolves `group_name` to a registered group, falling back to the
    /// default group when `None` is given.
    ///
    /// # Errors
    /// Fails when the named group does not exist, or when `None` is given
    /// and no default group is set.
    fn resolve_group(&self, group_name: Option<&str>) -> Result<&CommunicationGroup> {
        match group_name {
            Some(name) => self
                .get_group(name)
                .ok_or_else(|| TensorError::invalid_argument(format!("Group '{name}' not found"))),
            None => self
                .get_default_group()
                .ok_or_else(|| TensorError::invalid_argument("No default group set".to_string())),
        }
    }

    /// All-reduces `tensor` across the given (or default) group with `op`.
    ///
    /// # Errors
    /// Fails when the group cannot be resolved or a device transfer fails.
    pub fn all_reduce<T>(
        &self,
        tensor: &Tensor<T>,
        op: ReductionOp,
        group_name: Option<&str>,
    ) -> Result<Tensor<T>>
    where
        T: Clone
            + Default
            + Send
            + Sync
            + 'static
            + bytemuck::Pod
            + scirs2_core::num_traits::Zero
            + scirs2_core::num_traits::One
            + std::ops::Add<Output = T>
            + PartialOrd
            + std::ops::Mul<Output = T>
            + scirs2_core::num_traits::Float,
    {
        let group = self.resolve_group(group_name)?;
        self.simple_all_reduce(tensor, op, group)
    }

    /// Broadcasts `tensor` from `src_device` to every device in the group,
    /// returning one copy per group device (in rank order).
    ///
    /// # Errors
    /// Fails when the group cannot be resolved, `src_device` is not a group
    /// member, `tensor` does not live on `src_device`, or a transfer fails.
    pub fn broadcast<T>(
        &self,
        tensor: &Tensor<T>,
        src_device: Device,
        group_name: Option<&str>,
    ) -> Result<Vec<Tensor<T>>>
    where
        T: Clone
            + Default
            + Send
            + Sync
            + 'static
            + bytemuck::Pod
            + scirs2_core::num_traits::Zero
            + scirs2_core::num_traits::One,
    {
        let group = self.resolve_group(group_name)?;
        if !group.devices().contains(&src_device) {
            return Err(TensorError::invalid_argument(
                "Source device not in communication group".to_string(),
            ));
        }
        if tensor.device() != &src_device {
            return Err(TensorError::device_mismatch(
                "broadcast",
                &src_device.to_string(),
                &tensor.device().to_string(),
            ));
        }
        // Fan the tensor out to every member device; stop at the first
        // failed transfer.
        group
            .devices()
            .iter()
            .map(|&device| tensor.to_device(device))
            .collect()
    }

    /// All-gather: returns one copy of `tensor` per group device.
    ///
    /// NOTE(review): a true all-gather would concatenate per-rank shards;
    /// this single-process placeholder just replicates the local tensor.
    ///
    /// # Errors
    /// Fails when the group cannot be resolved or a transfer fails.
    pub fn all_gather<T>(
        &self,
        tensor: &Tensor<T>,
        group_name: Option<&str>,
    ) -> Result<Vec<Tensor<T>>>
    where
        T: Clone
            + Default
            + Send
            + Sync
            + 'static
            + bytemuck::Pod
            + scirs2_core::num_traits::Zero
            + scirs2_core::num_traits::One,
    {
        let group = self.resolve_group(group_name)?;
        // Stage through host memory once, then fan out to each device.
        let cpu_tensor = tensor.to_cpu()?;
        group
            .devices()
            .iter()
            .map(|&device| cpu_tensor.to_device(device))
            .collect()
    }

    /// Single-process simulation of all-reduce.
    ///
    /// NOTE(review): only this process's replica is visible here, so Sum,
    /// Max, Min and Product reduce to the identity; Mean divides by the
    /// group size to mimic gradient averaging. A real backend would
    /// exchange data between devices at this point.
    fn simple_all_reduce<T>(
        &self,
        tensor: &Tensor<T>,
        op: ReductionOp,
        group: &CommunicationGroup,
    ) -> Result<Tensor<T>>
    where
        T: Clone
            + Default
            + Send
            + Sync
            + 'static
            + bytemuck::Pod
            + scirs2_core::num_traits::Zero
            + scirs2_core::num_traits::One
            + std::ops::Add<Output = T>
            + PartialOrd
            + std::ops::Mul<Output = T>
            + scirs2_core::num_traits::Float,
    {
        let cpu_tensor = tensor.to_cpu()?;
        let group_size = group.size();
        // A single participant makes every reduction the identity.
        if group_size <= 1 {
            return cpu_tensor.to_device(tensor.device().clone());
        }
        match op {
            ReductionOp::Mean => {
                if let Some(data) = cpu_tensor.as_slice() {
                    // Hoist the divisor conversion out of the per-element map.
                    let divisor = T::from(group_size)
                        .expect("group_size should convert to numeric type");
                    let mean_data: Vec<T> = data.iter().map(|&x| x / divisor).collect();
                    let mean_tensor = Tensor::from_vec(mean_data, cpu_tensor.shape().dims())?;
                    mean_tensor.to_device(tensor.device().clone())
                } else {
                    // Data not accessible as a contiguous slice; pass through.
                    cpu_tensor.to_device(tensor.device().clone())
                }
            }
            // Identity in this single-process simulation (see NOTE above).
            ReductionOp::Sum | ReductionOp::Max | ReductionOp::Min | ReductionOp::Product => {
                cpu_tensor.to_device(tensor.device().clone())
            }
        }
    }

    /// Mean-reduces each gradient tensor across the group (the usual
    /// data-parallel gradient averaging step).
    ///
    /// # Errors
    /// Fails when the group cannot be resolved or any reduction fails.
    pub fn all_reduce_gradients<T>(
        &self,
        gradients: &[Tensor<T>],
        group_name: Option<&str>,
    ) -> Result<Vec<Tensor<T>>>
    where
        T: Clone
            + Default
            + Send
            + Sync
            + 'static
            + bytemuck::Pod
            + scirs2_core::num_traits::Zero
            + scirs2_core::num_traits::One
            + std::ops::Add<Output = T>
            + PartialOrd
            + std::ops::Mul<Output = T>
            + scirs2_core::num_traits::Float,
    {
        let group = self.resolve_group(group_name)?;
        gradients
            .iter()
            .map(|gradient| self.simple_all_reduce(gradient, ReductionOp::Mean, group))
            .collect()
    }

    /// Broadcasts every parameter tensor from `src_device` to all group
    /// devices, returning one broadcast result (a Vec of per-device copies)
    /// per parameter.
    ///
    /// # Errors
    /// Fails when the group cannot be resolved or any broadcast fails.
    pub fn sync_parameters<T>(
        &self,
        parameters: &[Tensor<T>],
        src_device: Device,
        group_name: Option<&str>,
    ) -> Result<Vec<Vec<Tensor<T>>>>
    where
        T: Clone
            + Default
            + Send
            + Sync
            + 'static
            + bytemuck::Pod
            + scirs2_core::num_traits::Zero
            + scirs2_core::num_traits::One,
    {
        // Validate the group up front so a bad name fails before any copy.
        self.resolve_group(group_name)?;
        parameters
            .iter()
            .map(|parameter| self.broadcast(parameter, src_device, group_name))
            .collect()
    }

    /// Ring all-reduce entry point.
    ///
    /// NOTE(review): a true ring algorithm would chunk the tensor and
    /// rotate shards around the ring; this placeholder falls back to the
    /// simple mean-based all-reduce.
    ///
    /// # Errors
    /// Fails when the group cannot be resolved or the reduction fails.
    pub fn ring_all_reduce<T>(
        &self,
        tensor: &Tensor<T>,
        group_name: Option<&str>,
    ) -> Result<Tensor<T>>
    where
        T: Clone
            + Default
            + Send
            + Sync
            + 'static
            + bytemuck::Pod
            + scirs2_core::num_traits::Zero
            + scirs2_core::num_traits::One
            + std::ops::Add<Output = T>
            + PartialOrd
            + std::ops::Mul<Output = T>
            + scirs2_core::num_traits::Float,
    {
        let group = self.resolve_group(group_name)?;
        if group.size() <= 1 {
            // Nothing to exchange with a single participant.
            return Ok(tensor.clone());
        }
        self.simple_all_reduce(tensor, ReductionOp::Mean, group)
    }
}
impl Default for CollectiveManager {
fn default() -> Self {
Self::new()
}
}
// Process-wide collective state. The slot holds a shared handle so that
// `get_collective_manager` can hand out the SAME manager to every caller.
//
// BUG FIX: the previous implementation stored a bare `CollectiveManager`
// and `get_collective_manager` returned `Arc::new(Mutex::new(
// CollectiveManager::new()))` — a fresh, empty manager on every call.
// Groups created through `create_process_group` were therefore written
// into a throwaway manager and never visible to `all_reduce`/`broadcast`.
static COLLECTIVE_MANAGER: Mutex<Option<Arc<Mutex<CollectiveManager>>>> = Mutex::new(None);

/// Initializes the global collective manager if it does not exist yet.
/// Idempotent: subsequent calls are no-ops.
pub fn init_collective() -> Result<()> {
    let mut slot = COLLECTIVE_MANAGER
        .lock()
        .expect("lock should not be poisoned");
    if slot.is_none() {
        *slot = Some(Arc::new(Mutex::new(CollectiveManager::new())));
    }
    Ok(())
}

/// Returns a shared handle to the global collective manager.
///
/// # Errors
/// Fails if [`init_collective`] has not been called yet.
pub fn get_collective_manager() -> Result<Arc<Mutex<CollectiveManager>>> {
    let slot = COLLECTIVE_MANAGER
        .lock()
        .expect("lock should not be poisoned");
    slot.clone().ok_or_else(|| {
        TensorError::invalid_argument(
            "Collective not initialized. Call init_collective() first".to_string(),
        )
    })
}
/// Initializes the collective subsystem if needed, then registers a new
/// process group with the given devices in the global manager.
///
/// # Errors
/// Fails when `devices` is empty.
pub fn create_process_group(name: String, devices: Vec<Device>) -> Result<()> {
    init_collective()?;
    get_collective_manager()?
        .lock()
        .expect("lock should not be poisoned")
        .create_group(name, devices)
}
/// Module-level convenience wrapper: all-reduces `tensor` through the
/// global collective manager (see [`CollectiveManager::all_reduce`]).
///
/// # Errors
/// Fails when the collective subsystem is uninitialized or the group
/// cannot be resolved.
pub fn all_reduce<T>(
    tensor: &Tensor<T>,
    op: ReductionOp,
    group_name: Option<&str>,
) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One
        + std::ops::Add<Output = T>
        + PartialOrd
        + std::ops::Mul<Output = T>
        + scirs2_core::num_traits::Float,
{
    get_collective_manager()?
        .lock()
        .expect("lock should not be poisoned")
        .all_reduce(tensor, op, group_name)
}
/// Module-level convenience wrapper: broadcasts `tensor` from `src_device`
/// through the global collective manager (see
/// [`CollectiveManager::broadcast`]).
///
/// # Errors
/// Fails when the collective subsystem is uninitialized or the broadcast
/// preconditions are not met.
pub fn broadcast<T>(
    tensor: &Tensor<T>,
    src_device: Device,
    group_name: Option<&str>,
) -> Result<Vec<Tensor<T>>>
where
    T: Clone
        + Default
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One,
{
    get_collective_manager()?
        .lock()
        .expect("lock should not be poisoned")
        .broadcast(tensor, src_device, group_name)
}
/// Module-level convenience wrapper: all-gathers `tensor` through the
/// global collective manager (see [`CollectiveManager::all_gather`]).
///
/// # Errors
/// Fails when the collective subsystem is uninitialized or the group
/// cannot be resolved.
pub fn all_gather<T>(tensor: &Tensor<T>, group_name: Option<&str>) -> Result<Vec<Tensor<T>>>
where
    T: Clone
        + Default
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One,
{
    get_collective_manager()?
        .lock()
        .expect("lock should not be poisoned")
        .all_gather(tensor, group_name)
}
/// Module-level convenience wrapper: mean-reduces each gradient through
/// the global collective manager (see
/// [`CollectiveManager::all_reduce_gradients`]).
///
/// # Errors
/// Fails when the collective subsystem is uninitialized or the group
/// cannot be resolved.
pub fn all_reduce_gradients<T>(
    gradients: &[Tensor<T>],
    group_name: Option<&str>,
) -> Result<Vec<Tensor<T>>>
where
    T: Clone
        + Default
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One
        + std::ops::Add<Output = T>
        + PartialOrd
        + std::ops::Mul<Output = T>
        + scirs2_core::num_traits::Float,
{
    get_collective_manager()?
        .lock()
        .expect("lock should not be poisoned")
        .all_reduce_gradients(gradients, group_name)
}
/// Module-level convenience wrapper: broadcasts every parameter from
/// `src_device` through the global collective manager (see
/// [`CollectiveManager::sync_parameters`]).
///
/// # Errors
/// Fails when the collective subsystem is uninitialized or any broadcast
/// precondition is not met.
pub fn sync_parameters<T>(
    parameters: &[Tensor<T>],
    src_device: Device,
    group_name: Option<&str>,
) -> Result<Vec<Vec<Tensor<T>>>>
where
    T: Clone
        + Default
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One,
{
    get_collective_manager()?
        .lock()
        .expect("lock should not be poisoned")
        .sync_parameters(parameters, src_device, group_name)
}
/// Module-level convenience wrapper: ring all-reduce through the global
/// collective manager (see [`CollectiveManager::ring_all_reduce`]).
///
/// # Errors
/// Fails when the collective subsystem is uninitialized or the group
/// cannot be resolved.
pub fn ring_all_reduce<T>(tensor: &Tensor<T>, group_name: Option<&str>) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One
        + std::ops::Add<Output = T>
        + PartialOrd
        + std::ops::Mul<Output = T>
        + scirs2_core::num_traits::Float,
{
    get_collective_manager()?
        .lock()
        .expect("lock should not be poisoned")
        .ring_all_reduce(tensor, group_name)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Ranks must follow the order of the device list passed to `new`,
    // under both the CPU-only and GPU feature configurations.
    #[test]
    fn test_communication_group_creation() {
        #[cfg(feature = "gpu")]
        let devices = vec![Device::Cpu, Device::Gpu(0), Device::Gpu(1)];
        #[cfg(not(feature = "gpu"))]
        let devices = vec![Device::Cpu];
        let group = CommunicationGroup::new(devices.clone());
        #[cfg(feature = "gpu")]
        {
            assert_eq!(group.size(), 3);
            assert_eq!(group.devices(), &devices);
            assert_eq!(group.rank(&Device::Cpu), Some(0));
            assert_eq!(group.rank(&Device::Gpu(0)), Some(1));
            assert_eq!(group.rank(&Device::Gpu(1)), Some(2));
        }
        #[cfg(not(feature = "gpu"))]
        {
            assert_eq!(group.size(), 1);
            assert_eq!(group.devices(), &devices);
            assert_eq!(group.rank(&Device::Cpu), Some(0));
        }
    }

    // A created group must be retrievable by name and report the size of
    // the device list it was created with.
    #[test]
    fn test_collective_manager() {
        let mut manager = CollectiveManager::new();
        #[cfg(feature = "gpu")]
        let devices = vec![Device::Cpu, Device::Gpu(0)];
        #[cfg(not(feature = "gpu"))]
        let devices = vec![Device::Cpu];
        manager
            .create_group("test_group".to_string(), devices)
            .expect("test: operation should succeed");
        let group = manager
            .get_group("test_group")
            .expect("test: get_group should succeed");
        #[cfg(feature = "gpu")]
        assert_eq!(group.size(), 2);
        #[cfg(not(feature = "gpu"))]
        assert_eq!(group.size(), 1);
    }

    // Broadcasting in a single-device (CPU-only) group must yield exactly
    // one copy, resident on the CPU.
    #[test]
    fn test_broadcast_operation() {
        let mut manager = CollectiveManager::new();
        let devices = vec![Device::Cpu];
        manager
            .create_group("test_group".to_string(), devices)
            .expect("test: operation should succeed");
        let tensor = Tensor::<f32>::ones(&[2, 2]);
        let results = manager
            .broadcast(&tensor, Device::Cpu, Some("test_group"))
            .expect("test: operation should succeed");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].device(), &Device::Cpu);
    }
}