use crate::backend::BackendType;
use crate::{ProcessGroup, ReduceOp, TorshDistributedError, TorshResult};
use log::info;
use torsh_core::dtype::FloatElement;
use torsh_tensor::Tensor;
#[cfg(feature = "nccl")]
use crate::backend::NcclBackend;
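
/// All-reduce `tensor` in place across the process group, dispatching to the
/// NCCL path when the `nccl` feature is enabled and falling back to the
/// generic collective in `crate::collectives` otherwise.
///
/// Minimal usage sketch, assuming a single-process group on localhost (the
/// same setup the tests below use):
///
/// ```ignore
/// let pg = init_process_group(BackendType::Nccl, 0, 1, "127.0.0.1", 29500).await?;
/// let mut grads: Tensor<f32> = Tensor::from_vec(vec![1.0; 4], &[4])?;
/// nccl_all_reduce(&mut grads, ReduceOp::Sum, &pg).await?;
/// ```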
pub async fn nccl_all_reduce<T>(
tensor: &mut Tensor<T>,
op: ReduceOp,
group: &ProcessGroup,
) -> TorshResult<()>
where
T: FloatElement
+ Default
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>,
{
match group.backend_type() {
#[cfg(feature = "nccl")]
BackendType::Nccl => nccl_all_reduce_impl(tensor, op, group).await,
_ => {
crate::collectives::all_reduce(tensor, op, group).await
}
}
}
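
/// Broadcast `tensor` from `src_rank` to every rank in the group, using the
/// NCCL path when available and the generic collective otherwise.
///
/// Sketch, assuming `pg` is a `ProcessGroup` initialized as in the tests
/// below:
///
/// ```ignore
/// let mut weights: Tensor<f32> = Tensor::from_vec(vec![0.0; 4], &[4])?;
/// nccl_broadcast(&mut weights, 0, &pg).await?;
/// ```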
pub async fn nccl_broadcast<T: FloatElement>(
tensor: &mut Tensor<T>,
src_rank: u32,
group: &ProcessGroup,
) -> TorshResult<()> {
match group.backend_type() {
#[cfg(feature = "nccl")]
BackendType::Nccl => nccl_broadcast_impl(tensor, src_rank, group).await,
_ => {
crate::collectives::broadcast(tensor, src_rank, group).await
}
}
}
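
/// Reduce `input` across all ranks with `op`, then scatter the result so each
/// rank's `output` holds its own `numel / world_size` chunk. Without the
/// `nccl` feature this is emulated as an all-reduce followed by a local slice.
///
/// Sketch, assuming a 4-rank group and a `ProcessGroup` `pg`; each rank would
/// keep 2 of the 8 elements:
///
/// ```ignore
/// let input: Tensor<f32> = Tensor::from_vec(vec![1.0; 8], &[8])?;
/// let mut output: Tensor<f32> = Tensor::from_vec(vec![0.0; 2], &[2])?;
/// nccl_reduce_scatter(&input, &mut output, ReduceOp::Sum, &pg).await?;
/// ```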
pub async fn nccl_reduce_scatter<T: FloatElement + Default + Copy>(
input: &Tensor<T>,
output: &mut Tensor<T>,
op: ReduceOp,
group: &ProcessGroup,
) -> TorshResult<()> {
match group.backend_type() {
#[cfg(feature = "nccl")]
BackendType::Nccl => nccl_reduce_scatter_impl(input, output, op, group).await,
_ => {
            // Emulate reduce-scatter: all-reduce the full tensor, then keep
            // only this rank's contiguous chunk.
            let mut temp = input.clone();
            crate::collectives::all_reduce(&mut temp, op, group).await?;
            let world_size = group.world_size() as usize;
            let chunk_size = temp.numel() / world_size;
            let start_idx = group.rank() as usize * chunk_size;
            let end_idx = start_idx + chunk_size;
            let temp_data = temp.to_vec()?;
            if chunk_size * world_size == temp_data.len() {
                let chunk_data = temp_data[start_idx..end_idx].to_vec();
                *output = Tensor::from_vec(chunk_data, &[chunk_size])?;
            } else {
                // Tensor does not divide evenly across ranks: fall back to the
                // full reduced tensor instead of silently dropping elements.
                *output = temp.clone();
            }
Ok(())
}
}
}
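
/// Gather `input` from every rank into `output`, which afterwards holds one
/// tensor per rank (`world_size` entries).
///
/// Sketch, assuming a `ProcessGroup` `pg`:
///
/// ```ignore
/// let input: Tensor<f32> = Tensor::from_vec(vec![1.0, 2.0], &[2])?;
/// let mut gathered: Vec<Tensor<f32>> = Vec::new();
/// nccl_all_gather(&input, &mut gathered, &pg).await?;
/// assert_eq!(gathered.len(), pg.world_size() as usize);
/// ```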
pub async fn nccl_all_gather<T: FloatElement>(
input: &Tensor<T>,
output: &mut Vec<Tensor<T>>,
group: &ProcessGroup,
) -> TorshResult<()> {
match group.backend_type() {
#[cfg(feature = "nccl")]
BackendType::Nccl => nccl_all_gather_impl(input, output, group).await,
_ => {
crate::collectives::all_gather(output, input, group).await
}
}
}
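
// NOTE: The `_impl` functions below simulate NCCL behavior (latency via
// `tokio::time::sleep`, synthetic data transforms) rather than issuing real
// NCCL kernel calls.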
#[cfg(feature = "nccl")]
async fn nccl_all_reduce_impl<T>(
tensor: &mut Tensor<T>,
op: ReduceOp,
group: &ProcessGroup,
) -> TorshResult<()>
where
T: FloatElement
+ Default
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>,
{
let backend = group.backend();
let (rank, world_size, device_id) = {
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(TorshDistributedError::BackendNotInitialized);
}
let device_id = backend_guard
.as_any()
.downcast_ref::<NcclBackend>()
.map(|b| b.device_id());
(backend_guard.rank(), backend_guard.world_size(), device_id)
};
tracing::debug!(
"NCCL All-Reduce: {} elements, op: {:?}, rank: {}, world_size: {}",
tensor.numel(),
op,
rank,
world_size
);
    if let Some(device_id) = device_id {
        tokio::time::sleep(tokio::time::Duration::from_micros(10)).await;
        tracing::debug!(
            "NCCL All-Reduce completed on device {} ({} elements)",
            device_id,
            tensor.numel()
        );
    }
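    // Simulate the reduction locally: every rank holds identical data in this
    // mock, so Sum is modeled by scaling each element by world_size and
    // Product by raising each element to the world_size-th power.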
match op {
        ReduceOp::Sum => {
            let ws = world_size as f32;
            let current_data = tensor.to_vec()?;
            // Keep x unchanged if ws cannot be represented in T, rather than
            // silently zeroing the tensor.
            let simulated_sum: Vec<T> = current_data
                .iter()
                .map(|&x| T::from(ws).map(|w| x * w).unwrap_or(x))
                .collect();
            *tensor = Tensor::from_vec(simulated_sum, tensor.shape().dims())?;
        }
ReduceOp::Product => {
let current_data = tensor.to_vec()?;
let simulated_product: Vec<T> = current_data
.iter()
.map(|&x| {
let mut result = x;
for _ in 1..world_size {
result = result * x;
}
result
})
.collect();
*tensor = Tensor::from_vec(simulated_product, tensor.shape().dims())?;
}
ReduceOp::Min | ReduceOp::Max => {
tracing::debug!("Min/Max operation: tensor values unchanged in mock");
}
_ => {
return Err(TorshDistributedError::InvalidArgument {
arg: "reduce_op".to_string(),
reason: format!("Unsupported reduce operation: {:?}", op),
expected: "Sum, Product, Min, or Max".to_string(),
});
}
}
Ok(())
}
#[cfg(feature = "nccl")]
async fn nccl_broadcast_impl<T: FloatElement>(
tensor: &mut Tensor<T>,
src_rank: u32,
group: &ProcessGroup,
) -> TorshResult<()> {
let backend = group.backend();
let (rank, device_id) = {
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(TorshDistributedError::BackendNotInitialized);
}
if src_rank >= backend_guard.world_size() {
return Err(TorshDistributedError::RankOutOfBounds {
rank: src_rank,
world_size: backend_guard.world_size(),
});
}
let rank = backend_guard.rank();
let device_id = backend_guard
.as_any()
.downcast_ref::<NcclBackend>()
.map(|b| b.device_id());
tracing::debug!(
"NCCL Broadcast: {} elements from rank {} to rank {}",
tensor.numel(),
src_rank,
rank
);
(rank, device_id)
};
if let Some(device_id) = device_id {
tokio::time::sleep(tokio::time::Duration::from_micros(5)).await;
tracing::debug!(
"NCCL Broadcast completed on device {} ({} elements)",
device_id,
tensor.numel()
);
}
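    // Mock receive path: non-source ranks apply a small src_rank-derived
    // offset so the simulated broadcast is observable.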
if rank != src_rank {
let current_data = tensor.to_vec()?;
let broadcast_data: Vec<T> = current_data
.iter()
.map(|&x| {
if let Some(offset) = T::from(0.1 * src_rank as f32) {
x + offset
} else {
x
}
})
.collect();
*tensor = Tensor::from_vec(broadcast_data, tensor.shape().dims())?;
tracing::debug!(
"Rank {} received broadcast data from rank {}",
rank,
src_rank
);
} else {
tracing::debug!("Source rank {} broadcasting data", src_rank);
}
Ok(())
}
#[cfg(feature = "nccl")]
async fn nccl_reduce_scatter_impl<T: FloatElement>(
input: &Tensor<T>,
output: &mut Tensor<T>,
op: ReduceOp,
group: &ProcessGroup,
) -> TorshResult<()> {
let backend = group.backend();
let (world_size, rank, tensor_size_bytes, device_id) = {
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(TorshDistributedError::BackendNotInitialized);
}
let world_size = backend_guard.world_size() as usize;
let rank = backend_guard.rank() as usize;
let tensor_size_bytes = input.numel() * std::mem::size_of::<T>();
let device_id = backend_guard
.as_any()
.downcast_ref::<NcclBackend>()
.map(|b| b.device_id());
info!(
"🔀 NCCL Reduce-Scatter: {} elements, op: {:?}, rank: {}/{}",
input.numel(),
op,
rank,
world_size
);
(world_size, rank, tensor_size_bytes, device_id)
};
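    // Simulate the reduce-scatter schedule: world_size - 1 communication
    // steps, each transferring and reducing one numel / world_size chunk
    // (the pattern of a ring reduce-scatter).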
for _step in 0..world_size - 1 {
let chunk_size_bytes = tensor_size_bytes / world_size;
let transfer_time_us = (chunk_size_bytes as f64 * 0.01).max(50.0);
tokio::time::sleep(tokio::time::Duration::from_micros(transfer_time_us as u64)).await;
        let reduction_time_us = match op {
            ReduceOp::Sum | ReduceOp::Mean => (chunk_size_bytes as f64 * 0.001).max(10.0),
            ReduceOp::Max | ReduceOp::Min => (chunk_size_bytes as f64 * 0.002).max(15.0),
            ReduceOp::Product => (chunk_size_bytes as f64 * 0.003).max(20.0),
            _ => (chunk_size_bytes as f64 * 0.002).max(15.0),
        };
tokio::time::sleep(tokio::time::Duration::from_micros(reduction_time_us as u64)).await;
}
#[cfg(feature = "nccl")]
if let Some(device_id) = device_id {
info!(" NCCL Reduce-Scatter completed on device {}", device_id);
} else {
info!(" NCCL Reduce-Scatter completed (mock implementation)");
}
    // Scatter phase: each rank keeps its own chunk of the reduced tensor.
    let chunk_size = input.numel() / world_size;
    let start_idx = rank * chunk_size;
    let end_idx = (start_idx + chunk_size).min(input.numel());
    let input_data = input.to_vec()?;
    if chunk_size > 0 && end_idx <= input_data.len() {
        let chunk_data = input_data[start_idx..end_idx].to_vec();
        *output = Tensor::from_vec(chunk_data, &[chunk_size])?;
    } else {
        // Degenerate case (world_size > numel): return the whole input.
        *output = input.clone();
    }
Ok(())
}
#[cfg(feature = "nccl")]
async fn nccl_all_gather_impl<T: FloatElement>(
input: &Tensor<T>,
output: &mut Vec<Tensor<T>>,
group: &ProcessGroup,
) -> TorshResult<()> {
let backend = group.backend();
let (device_id, world_size) = {
let backend_guard = backend.read();
if !backend_guard.is_ready() {
return Err(TorshDistributedError::BackendNotInitialized);
}
let rank = backend_guard.rank();
let world_size = backend_guard.world_size();
let device_id = backend_guard
.as_any()
.downcast_ref::<NcclBackend>()
.map(|b| b.device_id());
tracing::debug!(
"NCCL All-Gather: {} elements from rank {}",
input.numel(),
rank
);
(device_id, world_size)
};
if let Some(device_id) = device_id {
tokio::time::sleep(tokio::time::Duration::from_micros(15)).await;
tracing::debug!("NCCL All-Gather completed on device {}", device_id);
}
    output.clear();
    // The source data is identical for every simulated rank, so read it once
    // outside the loop.
    let input_data = input.to_vec()?;
    for rank in 0..world_size {
        let rank_data: Vec<T> = input_data
.iter()
.map(|&x| {
if let Some(offset) = T::from(0.01 * rank as f32) {
x + offset
} else {
x
}
})
.collect();
let rank_tensor = Tensor::from_vec(rank_data, input.shape().dims())?;
output.push(rank_tensor);
}
tracing::debug!(
"All-Gather collected {} tensors from {} ranks",
output.len(),
world_size
);
Ok(())
}
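
/// Builder that queues several collectives for a single batched launch,
/// conceptually similar to NCCL's `ncclGroupStart`/`ncclGroupEnd` grouping.
/// Tensors are referenced by caller-assigned IDs.
///
/// Usage sketch (mirrors `test_nccl_batch` below; `pg` is an NCCL-backed
/// `ProcessGroup`):
///
/// ```ignore
/// let mut batch = NcclBatch::new();
/// batch
///     .all_reduce(0, ReduceOp::Sum)
///     .broadcast(1, 0)
///     .reduce_scatter(2, 3, ReduceOp::Sum);
/// batch.execute(&pg).await?;
/// ```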
pub struct NcclBatch {
operations: Vec<NcclOperation>,
}
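
/// A queued collective, referencing tensors by caller-assigned IDs.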
#[derive(Debug)]
enum NcclOperation {
AllReduce {
tensor_id: usize,
op: ReduceOp,
},
Broadcast {
tensor_id: usize,
src_rank: u32,
},
ReduceScatter {
input_id: usize,
output_id: usize,
op: ReduceOp,
},
}
impl NcclBatch {
pub fn new() -> Self {
Self {
operations: Vec::new(),
}
}
pub fn all_reduce(&mut self, tensor_id: usize, op: ReduceOp) -> &mut Self {
self.operations
.push(NcclOperation::AllReduce { tensor_id, op });
self
}
pub fn broadcast(&mut self, tensor_id: usize, src_rank: u32) -> &mut Self {
self.operations.push(NcclOperation::Broadcast {
tensor_id,
src_rank,
});
self
}
pub fn reduce_scatter(&mut self, input_id: usize, output_id: usize, op: ReduceOp) -> &mut Self {
self.operations.push(NcclOperation::ReduceScatter {
input_id,
output_id,
op,
});
self
}
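    /// Run every queued operation as a single batch on the group's backend.
    /// Only the NCCL backend supports batching; any other backend returns a
    /// `feature_not_available` error.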
pub async fn execute(&self, group: &ProcessGroup) -> TorshResult<()> {
match group.backend_type() {
#[cfg(feature = "nccl")]
BackendType::Nccl => self.execute_nccl_batch(group).await,
_ => {
Err(TorshDistributedError::feature_not_available(
"Batch operations",
"nccl feature flag",
))
}
}
}
#[cfg(feature = "nccl")]
async fn execute_nccl_batch(&self, _group: &ProcessGroup) -> TorshResult<()> {
info!(
" Executing NCCL batch with {} operations",
self.operations.len()
);
tracing::debug!(
"Starting NCCL group execution with {} operations",
self.operations.len()
);
let start_time = std::time::Instant::now();
for (i, op) in self.operations.iter().enumerate() {
match op {
NcclOperation::AllReduce { tensor_id, op } => {
tracing::debug!(
" Queuing {}. All-Reduce tensor {} with op {:?}",
i + 1,
tensor_id,
op
);
}
NcclOperation::Broadcast {
tensor_id,
src_rank,
} => {
tracing::debug!(
" Queuing {}. Broadcast tensor {} from rank {}",
i + 1,
tensor_id,
src_rank
);
}
NcclOperation::ReduceScatter {
input_id,
output_id,
op,
} => {
tracing::debug!(
" Queuing {}. Reduce-Scatter tensor {} -> {} with op {:?}",
i + 1,
input_id,
output_id,
op
);
}
}
}
        // Simulated cost model: fixed launch overhead plus a small
        // per-operation cost for the batched launch.
        let base_delay = 20;
        let per_op_delay = 5;
        let total_delay = base_delay + (per_op_delay * self.operations.len());
tokio::time::sleep(tokio::time::Duration::from_micros(total_delay as u64)).await;
let execution_time = start_time.elapsed();
tracing::debug!(
"NCCL batch execution completed in {:?} ({} operations)",
execution_time,
self.operations.len()
);
Ok(())
}
}
impl Default for NcclBatch {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{init_process_group, BackendType};
use torsh_tensor::Tensor;
#[tokio::test]
async fn test_nccl_all_reduce() {
let pg = init_process_group(BackendType::Nccl, 0, 1, "127.0.0.1", 29500)
.await
.unwrap();
let mut tensor: Tensor<f32> = Tensor::from_vec(vec![0.0; 10], &[10]).unwrap();
let result = nccl_all_reduce(&mut tensor, ReduceOp::Sum, &pg).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_nccl_broadcast() {
let pg = init_process_group(BackendType::Nccl, 0, 1, "127.0.0.1", 29500)
.await
.unwrap();
let mut tensor: Tensor<f32> = Tensor::from_vec(vec![0.0; 10], &[10]).unwrap();
let result = nccl_broadcast(&mut tensor, 0, &pg).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_nccl_batch() {
let pg = init_process_group(BackendType::Nccl, 0, 1, "127.0.0.1", 29500)
.await
.unwrap();
let mut batch = NcclBatch::new();
batch
.all_reduce(0, ReduceOp::Sum)
.broadcast(1, 0)
.reduce_scatter(2, 3, ReduceOp::Sum);
let result = batch.execute(&pg).await;
assert!(result.is_ok());
}
}