#![allow(dead_code)]
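//! Process-group bookkeeping and collective communication entry points.
//!
//! `CommunicationGroup` and `GroupManager` track sub-groups of ranks (for example
//! data-parallel and model-parallel groups), while the async functions below mirror
//! the usual collective API (all-reduce, broadcast, scatter, ...). Most collectives
//! currently validate their arguments and log intent rather than moving tensor data;
//! they are placeholders for the backend implementations.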
use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use crate::TorshResult;
use log::info;
use std::collections::HashMap;
use std::sync::Arc;
use torsh_core::dtype::FloatElement;
use torsh_tensor::Tensor;
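/// A named sub-group of global ranks with bidirectional global-to-local rank maps.
///
/// Ranks are stored sorted, so local rank `i` corresponds to the `i`-th smallest
/// global rank in the group.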
#[derive(Debug, Clone)]
pub struct CommunicationGroup {
pub group_id: String,
pub ranks: Vec<u32>,
pub local_rank: u32,
pub group_size: u32,
pub global_to_local: HashMap<u32, u32>,
pub local_to_global: HashMap<u32, u32>,
}
impl CommunicationGroup {
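/// Builds a group from an explicit rank list.
///
/// Fails if `ranks` is empty or if `current_global_rank` is not a member; the rank
/// list is sorted before local ranks are assigned.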
pub fn new(group_id: String, ranks: Vec<u32>, current_global_rank: u32) -> TorshResult<Self> {
if ranks.is_empty() {
return Err(crate::TorshDistributedError::invalid_argument(
"ranks",
"Communication group cannot be empty",
"non-empty vector of ranks",
));
}
if !ranks.contains(&current_global_rank) {
return Err(crate::TorshDistributedError::invalid_argument(
"current_global_rank",
format!(
"Current rank {} not in group {:?}",
current_global_rank, ranks
),
"rank that exists in the group",
));
}
let mut sorted_ranks = ranks.clone();
sorted_ranks.sort_unstable();
let mut global_to_local = HashMap::new();
let mut local_to_global = HashMap::new();
for (local_idx, &global_rank) in sorted_ranks.iter().enumerate() {
global_to_local.insert(global_rank, local_idx as u32);
local_to_global.insert(local_idx as u32, global_rank);
}
let local_rank = global_to_local[&current_global_rank];
let group_size = sorted_ranks.len() as u32;
Ok(Self {
group_id,
ranks: sorted_ranks,
local_rank,
group_size,
global_to_local,
local_to_global,
})
}
pub fn from_range(
group_id: String,
start_rank: u32,
end_rank: u32,
current_global_rank: u32,
) -> TorshResult<Self> {
if start_rank >= end_rank {
return Err(crate::TorshDistributedError::invalid_argument(
"rank_range",
"start_rank must be less than end_rank",
"valid rank range where start < end",
));
}
let ranks: Vec<u32> = (start_rank..end_rank).collect();
Self::new(group_id, ranks, current_global_rank)
}
pub fn contains_rank(&self, global_rank: u32) -> bool {
self.global_to_local.contains_key(&global_rank)
}
pub fn global_to_local_rank(&self, global_rank: u32) -> Option<u32> {
self.global_to_local.get(&global_rank).copied()
}
pub fn local_to_global_rank(&self, local_rank: u32) -> Option<u32> {
self.local_to_global.get(&local_rank).copied()
}
}
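/// Registry of `CommunicationGroup`s keyed by group id, scoped to one global rank.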
#[derive(Debug, Default)]
pub struct GroupManager {
groups: HashMap<String, Arc<CommunicationGroup>>,
current_global_rank: u32,
}
impl GroupManager {
pub fn new(current_global_rank: u32) -> Self {
Self {
groups: HashMap::new(),
current_global_rank,
}
}
pub fn create_group(
&mut self,
group_id: String,
ranks: Vec<u32>,
) -> TorshResult<Arc<CommunicationGroup>> {
if self.groups.contains_key(&group_id) {
return Err(crate::TorshDistributedError::invalid_argument(
"group_id",
format!("Group '{}' already exists", group_id),
"unique group identifier",
));
}
let group = Arc::new(CommunicationGroup::new(
group_id.clone(),
ranks,
self.current_global_rank,
)?);
self.groups.insert(group_id, Arc::clone(&group));
Ok(group)
}
pub fn create_group_from_range(
&mut self,
group_id: String,
start_rank: u32,
end_rank: u32,
) -> TorshResult<Arc<CommunicationGroup>> {
if self.groups.contains_key(&group_id) {
return Err(crate::TorshDistributedError::invalid_argument(
"group_id",
format!("Group '{}' already exists", group_id),
"unique group identifier",
));
}
let group = Arc::new(CommunicationGroup::from_range(
group_id.clone(),
start_rank,
end_rank,
self.current_global_rank,
)?);
self.groups.insert(group_id, Arc::clone(&group));
Ok(group)
}
pub fn get_group(&self, group_id: &str) -> Option<Arc<CommunicationGroup>> {
self.groups.get(group_id).cloned()
}
pub fn remove_group(&mut self, group_id: &str) -> bool {
self.groups.remove(group_id).is_some()
}
pub fn list_groups(&self) -> Vec<String> {
self.groups.keys().cloned().collect()
}
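/// Creates the standard data-parallel and model-parallel groups for this rank.
///
/// Ranks are laid out as `rank = dp_rank * model_parallel_size + mp_rank`. For
/// example, with `world_size = 4`, `data_parallel_size = 2` and
/// `model_parallel_size = 2`, the data-parallel groups are `{0, 2}` and `{1, 3}`
/// and the model-parallel groups are `{0, 1}` and `{2, 3}`; only the groups that
/// contain the current rank are registered.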
pub fn create_standard_groups(
&mut self,
world_size: u32,
data_parallel_size: u32,
model_parallel_size: u32,
) -> TorshResult<()> {
if data_parallel_size * model_parallel_size != world_size {
return Err(crate::TorshDistributedError::invalid_argument(
"parallelism_configuration",
"data_parallel_size * model_parallel_size must equal world_size",
format!(
"configuration where {} * {} = {}",
data_parallel_size, model_parallel_size, world_size
),
));
}
for mp_rank in 0..model_parallel_size {
let mut dp_ranks = Vec::new();
for dp_rank in 0..data_parallel_size {
dp_ranks.push(dp_rank * model_parallel_size + mp_rank);
}
// `CommunicationGroup::new` rejects groups that do not contain the current rank,
// so only materialize the groups this rank actually participates in.
if dp_ranks.contains(&self.current_global_rank) {
let group_id = format!("data_parallel_{}", mp_rank);
self.create_group(group_id, dp_ranks)?;
}
}
for dp_rank in 0..data_parallel_size {
let start_rank = dp_rank * model_parallel_size;
let end_rank = start_rank + model_parallel_size;
if (start_rank..end_rank).contains(&self.current_global_rank) {
let group_id = format!("model_parallel_{}", dp_rank);
self.create_group_from_range(group_id, start_rank, end_rank)?;
}
}
Ok(())
}
}
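/// All-reduce across every rank of `group`, combining values with `op`.
///
/// The current body only validates the backend and inspects the world size; the
/// actual reduction is left to the backend and is not yet wired up here.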
pub async fn all_reduce<T>(
_tensor: &mut Tensor<T>,
op: ReduceOp,
group: &ProcessGroup,
) -> TorshResult<()>
where
T: FloatElement
+ Default
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>,
{
use crate::communication::with_backend_read;
with_backend_read(group, |backend_guard| {
if let ReduceOp::Sum = op {
let world_size = backend_guard.world_size();
if world_size > 1 {
}
}
Ok(())
})
}
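/// All-gather: collects one tensor per rank into `output`.
///
/// As a placeholder, this fills `output` with `world_size` clones of the local
/// `input` instead of exchanging data between ranks.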
pub async fn all_gather<T: FloatElement>(
output: &mut Vec<Tensor<T>>,
input: &Tensor<T>,
group: &ProcessGroup,
) -> TorshResult<()> {
use crate::communication::validate_backend_initialized;
let backend = group.backend();
let backend_guard = backend.read();
validate_backend_initialized(&**backend_guard)?;
let world_size = backend_guard.world_size();
output.clear();
for _ in 0..world_size {
output.push(input.clone());
}
Ok(())
}
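/// Broadcast from `src_rank` to every rank in `group` (argument validation only for now).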
pub async fn broadcast<T: FloatElement>(
_tensor: &mut Tensor<T>,
src_rank: u32,
group: &ProcessGroup,
) -> TorshResult<()> {
use crate::communication::{validate_rank, with_backend_read};
with_backend_read(group, |backend_guard| {
validate_rank(src_rank, backend_guard.world_size())?;
Ok(())
})
}
pub async fn reduce<T>(
_tensor: &mut Tensor<T>,
dst_rank: u32,
op: ReduceOp,
group: &ProcessGroup,
) -> TorshResult<()>
where
T: FloatElement
+ Default
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>,
{
use crate::communication::{validate_rank, with_backend_read};
with_backend_read(group, |backend_guard| {
validate_rank(dst_rank, backend_guard.world_size())?;
if backend_guard.rank() == dst_rank && matches!(op, ReduceOp::Sum) {
let world_size = backend_guard.world_size();
if world_size > 1 {
}
}
Ok(())
})
}
pub async fn scatter<T: FloatElement>(
output: &mut Tensor<T>,
input: Option<&[Tensor<T>]>,
src_rank: u32,
group: &ProcessGroup,
) -> TorshResult<()> {
use crate::communication::{validate_rank, with_backend_read};
with_backend_read(group, |backend_guard| {
validate_rank(src_rank, backend_guard.world_size())?;
if backend_guard.rank() == src_rank {
let tensors = input.ok_or_else(|| {
crate::TorshDistributedError::invalid_argument(
"input_tensors",
"Input tensors required for source rank",
"non-empty vector of tensors for scatter operation",
)
})?;
if tensors.len() != backend_guard.world_size() as usize {
return Err(crate::TorshDistributedError::invalid_argument(
"tensors",
format!(
"Expected {} tensors, got {}",
backend_guard.world_size(),
tensors.len()
),
format!("{} tensors (one per rank)", backend_guard.world_size()),
));
}
*output = tensors[backend_guard.rank() as usize].clone();
}
Ok(())
})
}
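/// Blocks until all ranks in `group` reach the barrier.
///
/// This is the one collective here that calls into the backend; the write guard is
/// intentionally held across the `await`, hence the `clippy::await_holding_lock` allow.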
#[allow(clippy::await_holding_lock)]
pub async fn barrier(group: &ProcessGroup) -> TorshResult<()> {
let backend = group.backend();
let mut backend_guard = backend.write();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
backend_guard.barrier().await
}
pub async fn send<T: FloatElement>(
_tensor: &Tensor<T>,
dst_rank: u32,
tag: u32,
group: &ProcessGroup,
) -> TorshResult<()> {
use crate::communication::{validate_rank, with_backend_read};
with_backend_read(group, |backend_guard| {
validate_rank(dst_rank, backend_guard.world_size())?;
info!(
"📤 Rank {} sending tensor with tag {} to rank {}",
backend_guard.rank(),
tag,
dst_rank
);
Ok(())
})
}
pub async fn recv<T: FloatElement>(
_tensor: &mut Tensor<T>,
src_rank: u32,
tag: u32,
group: &ProcessGroup,
) -> TorshResult<()> {
use crate::communication::{validate_rank, with_backend_read};
with_backend_read(group, |backend_guard| {
validate_rank(src_rank, backend_guard.world_size())?;
info!(
"📥 Rank {} receiving tensor with tag {} from rank {}",
backend_guard.rank(),
tag,
src_rank
);
Ok(())
})
}
pub async fn isend<T: FloatElement>(
tensor: &Tensor<T>,
dst_rank: u32,
tag: u32,
group: &ProcessGroup,
) -> TorshResult<()> {
send(tensor, dst_rank, tag, group).await
}
pub async fn irecv<T: FloatElement>(
tensor: &mut Tensor<T>,
src_rank: u32,
tag: u32,
group: &ProcessGroup,
) -> TorshResult<()> {
recv(tensor, src_rank, tag, group).await
}
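/// Group-scoped all-reduce. Ranks outside `comm_group` return immediately; members
/// currently only validate state and log the intended operation.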
pub async fn all_reduce_group<T>(
_tensor: &mut Tensor<T>,
op: ReduceOp,
comm_group: &CommunicationGroup,
process_group: &ProcessGroup,
) -> TorshResult<()>
where
T: FloatElement
+ Default
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>,
{
let backend = process_group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let current_global_rank = backend_guard.rank();
if !comm_group.contains_rank(current_global_rank) {
return Ok(());
}
if let ReduceOp::Sum = op {
let group_size = comm_group.group_size;
if group_size > 1 {
}
}
info!(
" All-reduce in group '{}': rank {} (local: {}) with {} participants",
comm_group.group_id, current_global_rank, comm_group.local_rank, comm_group.group_size
);
Ok(())
}
pub async fn broadcast_group<T: FloatElement>(
_tensor: &mut Tensor<T>,
src_local_rank: u32,
comm_group: &CommunicationGroup,
process_group: &ProcessGroup,
) -> TorshResult<()> {
use crate::communication::{validate_rank, with_backend_read};
with_backend_read(process_group, |backend_guard| {
let current_global_rank = backend_guard.rank();
if !comm_group.contains_rank(current_global_rank) {
return Ok(());
}
validate_rank(src_local_rank, comm_group.group_size)?;
let src_global_rank = comm_group
.local_to_global_rank(src_local_rank)
.ok_or_else(|| {
crate::TorshDistributedError::invalid_argument(
"src_local_rank",
format!(
"Invalid local rank {} in group '{}'",
src_local_rank, comm_group.group_id
),
format!("valid local rank in range 0..{}", comm_group.group_size),
)
})?;
info!(
" Broadcast in group '{}': from local rank {} (global: {}) to {} participants",
comm_group.group_id, src_local_rank, src_global_rank, comm_group.group_size
);
Ok(())
})
}
pub async fn all_gather_group<T: FloatElement>(
output: &mut Vec<Tensor<T>>,
input: &Tensor<T>,
comm_group: &CommunicationGroup,
process_group: &ProcessGroup,
) -> TorshResult<()> {
use crate::communication::with_backend_read;
with_backend_read(process_group, |backend_guard| {
let current_global_rank = backend_guard.rank();
if !comm_group.contains_rank(current_global_rank) {
return Ok(());
}
output.clear();
for _ in 0..comm_group.group_size {
output.push(input.clone());
}
info!(
"🔗 All-gather in group '{}': rank {} collecting from {} participants",
comm_group.group_id, current_global_rank, comm_group.group_size
);
Ok(())
})
}
pub async fn reduce_group<T>(
_tensor: &mut Tensor<T>,
dst_local_rank: u32,
op: ReduceOp,
comm_group: &CommunicationGroup,
process_group: &ProcessGroup,
) -> TorshResult<()>
where
T: FloatElement
+ Default
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>,
{
let backend = process_group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let current_global_rank = backend_guard.rank();
if !comm_group.contains_rank(current_global_rank) {
return Ok(());
}
if dst_local_rank >= comm_group.group_size {
return Err(crate::TorshDistributedError::RankOutOfBounds {
rank: dst_local_rank,
world_size: comm_group.group_size,
});
}
let dst_global_rank = comm_group
.local_to_global_rank(dst_local_rank)
.ok_or_else(|| crate::TorshDistributedError::InvalidArgument {
arg: "rank".to_string(),
reason: format!(
"Invalid local rank {} in group '{}'",
dst_local_rank, comm_group.group_id
),
expected: "valid local rank within the communication group".to_string(),
})?;
if current_global_rank == dst_global_rank && matches!(op, ReduceOp::Sum) {
let group_size = comm_group.group_size;
if group_size > 1 {
}
}
info!(
"⬇️ Reduce in group '{}': to local rank {} (global: {}) from {} participants",
comm_group.group_id, dst_local_rank, dst_global_rank, comm_group.group_size
);
Ok(())
}
pub async fn barrier_group(
comm_group: &CommunicationGroup,
process_group: &ProcessGroup,
) -> TorshResult<()> {
let backend = process_group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let current_global_rank = backend_guard.rank();
if !comm_group.contains_rank(current_global_rank) {
return Ok(());
}
info!(
"🚧 Barrier in group '{}': rank {} waiting for {} participants",
comm_group.group_id, current_global_rank, comm_group.group_size
);
Ok(())
}
pub async fn send_group<T: FloatElement>(
_tensor: &Tensor<T>,
dst_local_rank: u32,
tag: u32,
comm_group: &CommunicationGroup,
process_group: &ProcessGroup,
) -> TorshResult<()> {
let backend = process_group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let current_global_rank = backend_guard.rank();
if !comm_group.contains_rank(current_global_rank) {
return Err(crate::TorshDistributedError::InvalidArgument {
arg: "rank".to_string(),
reason: format!(
"Rank {} not in group '{}'",
current_global_rank, comm_group.group_id
),
expected: "rank must be member of the communication group".to_string(),
});
}
if dst_local_rank >= comm_group.group_size {
return Err(crate::TorshDistributedError::RankOutOfBounds {
rank: dst_local_rank,
world_size: comm_group.group_size,
});
}
let dst_global_rank = comm_group
.local_to_global_rank(dst_local_rank)
.ok_or_else(|| crate::TorshDistributedError::InvalidArgument {
arg: "rank".to_string(),
reason: format!(
"Invalid local rank {} in group '{}'",
dst_local_rank, comm_group.group_id
),
expected: "valid local rank within the communication group".to_string(),
})?;
info!(
"📤 Group send in '{}': from rank {} to local rank {} (global: {}) with tag {}",
comm_group.group_id, current_global_rank, dst_local_rank, dst_global_rank, tag
);
Ok(())
}
pub async fn recv_group<T: FloatElement>(
_tensor: &mut Tensor<T>,
src_local_rank: u32,
tag: u32,
comm_group: &CommunicationGroup,
process_group: &ProcessGroup,
) -> TorshResult<()> {
let backend = process_group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let current_global_rank = backend_guard.rank();
if !comm_group.contains_rank(current_global_rank) {
return Err(crate::TorshDistributedError::InvalidArgument {
arg: "rank".to_string(),
reason: format!(
"Rank {} not in group '{}'",
current_global_rank, comm_group.group_id
),
expected: "rank must be member of the communication group".to_string(),
});
}
if src_local_rank >= comm_group.group_size {
return Err(crate::TorshDistributedError::RankOutOfBounds {
rank: src_local_rank,
world_size: comm_group.group_size,
});
}
let src_global_rank = comm_group
.local_to_global_rank(src_local_rank)
.ok_or_else(|| {
crate::TorshDistributedError::invalid_argument(
"src_local_rank",
format!(
"Invalid local rank {} in group '{}'",
src_local_rank, comm_group.group_id
),
format!("valid local rank in range 0..{}", comm_group.group_size),
)
})?;
info!(
"📥 Group recv in '{}': from local rank {} (global: {}) to rank {} with tag {}",
comm_group.group_id, src_local_rank, src_global_rank, current_global_rank, tag
);
Ok(())
}
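/// Reduce-scatter: each rank should end up with its reduced chunk of the input.
/// The placeholder implementation copies the whole `input` into `output`.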
pub async fn reduce_scatter<T>(
output: &mut Tensor<T>,
input: &Tensor<T>,
op: ReduceOp,
group: &ProcessGroup,
) -> TorshResult<()>
where
T: FloatElement
+ Default
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>,
{
let backend = group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let rank = backend_guard.rank();
let world_size = backend_guard.world_size();
*output = input.clone();
if let ReduceOp::Sum = op {
let factor = world_size;
if factor > 1 {
}
}
info!(
" Reduce-scatter: rank {} processing chunk of reduced tensor",
rank
);
Ok(())
}
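/// All-to-all exchange of one tensor per peer.
///
/// Requires exactly `world_size` input tensors; the placeholder simply copies them
/// into `output` without any cross-rank exchange.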
pub async fn all_to_all<T: FloatElement>(
output: &mut Vec<Tensor<T>>,
input: &[Tensor<T>],
group: &ProcessGroup,
) -> TorshResult<()> {
let backend = group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let rank = backend_guard.rank();
let world_size = backend_guard.world_size() as usize;
if input.len() != world_size {
return Err(crate::TorshDistributedError::InvalidArgument {
arg: "rank".to_string(),
reason: format!(
"Input must have {} tensors for all-to-all, got {}",
world_size,
input.len()
),
expected: format!("{} tensors (one per rank)", world_size),
});
}
output.clear();
// `input.len()` was validated to equal `world_size` above, so copy one tensor per rank.
output.extend(input.iter().cloned());
info!(
" All-to-all: rank {} exchanging data with {} ranks",
rank, world_size
);
Ok(())
}
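/// All-reduce using a ring schedule (reduce-scatter phase followed by all-gather).
/// Currently only validates the backend and logs the ring size.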
pub async fn ring_all_reduce<T>(
_tensor: &mut Tensor<T>,
op: ReduceOp,
group: &ProcessGroup,
) -> TorshResult<()>
where
T: FloatElement
+ Default
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>,
{
let backend = group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let rank = backend_guard.rank();
let world_size = backend_guard.world_size();
if let ReduceOp::Sum = op {
let world_size_f = world_size;
if world_size_f > 1 {
}
}
info!(
" Ring all-reduce: rank {} in {}-node ring topology",
rank, world_size
);
Ok(())
}
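/// Two-level all-reduce: intra-node reduction followed by inter-node reduction.
///
/// Requires `world_size % ranks_per_node == 0`; a rank is decomposed as
/// `node_id = rank / ranks_per_node` and `local_rank = rank % ranks_per_node`.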
pub async fn hierarchical_all_reduce<T>(
_tensor: &mut Tensor<T>,
op: ReduceOp,
group: &ProcessGroup,
ranks_per_node: u32,
) -> TorshResult<()>
where
T: FloatElement
+ Default
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>,
{
let backend = group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let rank = backend_guard.rank();
let world_size = backend_guard.world_size();
if world_size % ranks_per_node != 0 {
return Err(crate::TorshDistributedError::InvalidArgument {
arg: "rank".to_string(),
reason: "World size must be divisible by ranks_per_node for hierarchical all-reduce"
.to_string(),
expected: format!(
"world_size divisible by ranks_per_node ({})",
ranks_per_node
),
});
}
let node_id = rank / ranks_per_node;
let local_rank = rank % ranks_per_node;
let num_nodes = world_size / ranks_per_node;
if let ReduceOp::Sum = op {
let world_size_f = world_size;
if world_size_f > 1 {
}
}
info!(
" Hierarchical all-reduce: rank {} (node {}, local rank {}) with {} nodes × {} ranks/node",
rank, node_id, local_rank, num_nodes, ranks_per_node
);
Ok(())
}
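/// All-reduce over a list of tensors, grouped into buckets of at most
/// `max_bucket_size_mb` so many small tensors can share one collective call.
/// The current body only computes the bucket count and logs it.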
pub async fn bucket_all_reduce<T>(
tensors: &mut [Tensor<T>],
op: ReduceOp,
group: &ProcessGroup,
max_bucket_size_mb: f32,
) -> TorshResult<()>
where
T: FloatElement
+ Default
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>,
{
let backend = group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let rank = backend_guard.rank();
let world_size = backend_guard.world_size();
if tensors.is_empty() {
return Ok(());
}
let max_bucket_size_bytes = (max_bucket_size_mb * 1024.0 * 1024.0) as usize;
let mut current_bucket_size = 0;
let mut bucket_count = 0;
for tensor in tensors.iter_mut() {
let tensor_size = tensor.numel() * std::mem::size_of::<T>();
if current_bucket_size + tensor_size > max_bucket_size_bytes && current_bucket_size > 0 {
bucket_count += 1;
current_bucket_size = tensor_size;
} else {
current_bucket_size += tensor_size;
}
if let ReduceOp::Sum = op {
let world_size_f = world_size;
if world_size_f > 1 {
}
}
}
if current_bucket_size > 0 {
bucket_count += 1;
}
info!(
" Bucket all-reduce: rank {} processed {} tensors in {} buckets (max {:.1} MB/bucket)",
rank,
tensors.len(),
bucket_count,
max_bucket_size_mb
);
Ok(())
}
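/// All-reduce that fuses consecutive tensors into groups no larger than
/// `fusion_threshold_bytes` before (eventually) launching one collective per group.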
pub async fn fused_all_reduce<T>(
tensors: &mut [Tensor<T>],
op: ReduceOp,
group: &ProcessGroup,
fusion_threshold_bytes: usize,
) -> TorshResult<()>
where
T: FloatElement
+ Default
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>,
{
let backend = group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let rank = backend_guard.rank();
let world_size = backend_guard.world_size();
if tensors.is_empty() {
return Ok(());
}
let mut fusion_groups = Vec::new();
let mut current_group = Vec::new();
let mut current_size = 0;
for (idx, tensor) in tensors.iter().enumerate() {
let tensor_size = tensor.numel() * std::mem::size_of::<T>();
if current_size + tensor_size > fusion_threshold_bytes && !current_group.is_empty() {
fusion_groups.push(std::mem::take(&mut current_group));
current_size = tensor_size;
current_group.push(idx);
} else {
current_size += tensor_size;
current_group.push(idx);
}
}
if !current_group.is_empty() {
fusion_groups.push(current_group);
}
for (group_idx, tensor_indices) in fusion_groups.iter().enumerate() {
for &_tensor_idx in tensor_indices {
if let ReduceOp::Sum = op {
let world_size_f = world_size;
if world_size_f > 1 {
}
}
}
info!(
"🔗 Fused all-reduce group {}: rank {} processed {} tensors",
group_idx,
rank,
tensor_indices.len()
);
}
info!(
" Fused all-reduce complete: rank {} processed {} tensors in {} fusion groups",
rank,
tensors.len(),
fusion_groups.len()
);
Ok(())
}
pub async fn all_gather_varsize<T: FloatElement>(
output: &mut Vec<Tensor<T>>,
input: &Tensor<T>,
group: &ProcessGroup,
) -> TorshResult<()> {
let backend = group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let rank = backend_guard.rank();
let world_size = backend_guard.world_size();
output.clear();
for i in 0..world_size {
let _scale_factor = 1.0 + (i as f32 * 0.1);
let scaled_tensor = input.clone();
output.push(scaled_tensor);
}
info!(
" Variable-size all-gather: rank {} collected from {} ranks with varying sizes",
rank, world_size
);
Ok(())
}
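/// Broadcast along a binary tree rooted at `src_rank`; the tree depth is
/// `ceil(log2(world_size))`. Only the tree position is computed and logged for now.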
pub async fn tree_broadcast<T: FloatElement>(
_tensor: &mut Tensor<T>,
src_rank: u32,
group: &ProcessGroup,
) -> TorshResult<()> {
let backend = group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let rank = backend_guard.rank();
let world_size = backend_guard.world_size();
if src_rank >= world_size {
return Err(crate::TorshDistributedError::RankOutOfBounds {
rank: src_rank,
world_size,
});
}
let tree_depth = (world_size as f32).log2().ceil() as u32;
// Position this rank in a binary tree rooted at `src_rank` by rotating ranks so the
// root maps to 0; this also avoids the unsigned underflow of `(rank - 1) / 2` at rank 0.
let relative_rank = (rank + world_size - src_rank) % world_size;
let is_root = relative_rank == 0;
let parent_rank = if is_root {
None
} else {
Some((src_rank + (relative_rank - 1) / 2) % world_size)
};
info!(
"🌳 Tree broadcast: rank {} (root: {}, parent: {:?}) in {}-deep tree from root {}",
rank, is_root, parent_rank, tree_depth, src_rank
);
Ok(())
}
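/// All-reduce that splits the tensor into `pipeline_chunks` pieces so communication
/// of one chunk can overlap with reduction of the next. `chunk_size` is the ceiling
/// division of the element count by the chunk count.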
pub async fn pipelined_all_reduce<T>(
tensor: &mut Tensor<T>,
op: ReduceOp,
group: &ProcessGroup,
pipeline_chunks: usize,
) -> TorshResult<()>
where
T: FloatElement
+ Default
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>,
{
let backend = group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let rank = backend_guard.rank();
let world_size = backend_guard.world_size();
if pipeline_chunks == 0 {
return Err(crate::TorshDistributedError::InvalidArgument {
arg: "rank".to_string(),
reason: "Pipeline chunks must be greater than 0".to_string(),
expected: "pipeline_chunks > 0".to_string(),
});
}
let chunk_size = tensor.numel().div_ceil(pipeline_chunks);
if let ReduceOp::Sum = op {
let world_size_f = world_size;
if world_size_f > 1 {
}
}
info!(
"⚡ Pipelined all-reduce: rank {} processed tensor in {} chunks ({} elements/chunk)",
rank, pipeline_chunks, chunk_size
);
Ok(())
}
pub async fn double_buffered_all_reduce<T>(
_current_buffer: &mut Tensor<T>,
_next_buffer: &mut Tensor<T>,
op: ReduceOp,
group: &ProcessGroup,
) -> TorshResult<()>
where
T: FloatElement
+ Default
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>,
{
let backend = group.backend();
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(crate::TorshDistributedError::BackendNotInitialized);
}
let rank = backend_guard.rank();
let world_size = backend_guard.world_size();
if let ReduceOp::Sum = op {
let world_size_f = world_size;
if world_size_f > 1 {
}
}
info!(
" Double-buffered all-reduce: rank {} processed buffers with overlap",
rank
);
Ok(())
}
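// A few sanity checks for the backend-free bookkeeping above. They cover only the
// synchronous types (`CommunicationGroup`, `GroupManager`); the async collectives
// need a live `ProcessGroup` and are not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn communication_group_maps_global_and_local_ranks() {
        // Ranks are sorted on construction, so local ranks follow the sorted order.
        let group = CommunicationGroup::new("test".to_string(), vec![4, 2, 6], 4).unwrap();
        assert_eq!(group.ranks, vec![2, 4, 6]);
        assert_eq!(group.group_size, 3);
        assert_eq!(group.local_rank, 1);
        assert_eq!(group.global_to_local_rank(6), Some(2));
        assert_eq!(group.local_to_global_rank(0), Some(2));
        assert!(!group.contains_rank(5));
    }

    #[test]
    fn group_manager_rejects_duplicate_ids() {
        let mut manager = GroupManager::new(1);
        let group = manager.create_group("tp".to_string(), vec![0, 1, 2, 3]).unwrap();
        assert_eq!(group.local_rank, 1);
        assert!(manager.create_group("tp".to_string(), vec![0, 1]).is_err());
        assert!(manager.get_group("tp").is_some());
        assert!(manager.remove_group("tp"));
        assert!(manager.list_groups().is_empty());
    }

    #[test]
    fn standard_groups_only_include_member_groups() {
        // world_size = 4 split into 2-way data parallelism × 2-way model parallelism.
        let mut manager = GroupManager::new(0);
        manager.create_standard_groups(4, 2, 2).unwrap();
        assert_eq!(manager.get_group("data_parallel_0").unwrap().ranks, vec![0, 2]);
        assert_eq!(manager.get_group("model_parallel_0").unwrap().ranks, vec![0, 1]);
        // Rank 0 is not a member of the second data-parallel group, so it is not created.
        assert!(manager.get_group("data_parallel_1").is_none());
    }
}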