use super::{DistributedOps, ProcessGroup, ReduceOp};
use crate::error::{RusTorchError, RusTorchResult};
use crate::gpu::DeviceType;
use crate::tensor::Tensor;
use num_traits::Float;
use std::collections::HashMap;
use std::ffi::c_void;
use std::ptr;
use std::sync::{Arc, Mutex};
#[cfg(feature = "nccl")]
/// Per-device communicator state used by the collective wrappers in this module.
///
/// Stores rank/size bookkeeping together with opaque raw handles. The
/// constructor visible in this file creates the communicator handle as a null
/// pointer and the stream list as an empty vector.
pub struct NCCLCommunicator {
/// Opaque communicator handle; presumably an `ncclComm_t` once native
/// initialization is wired in — TODO confirm.
comm: *mut c_void,
/// Device index this communicator was created for (exposed by `device_id()`).
device_id: usize,
/// Rank of this participant (exposed by `rank()`).
rank: usize,
/// Number of participating ranks; `broadcast` rejects roots `>= nranks`.
nranks: usize,
/// Opaque stream handles; presumably CUDA streams — TODO confirm.
streams: Vec<*mut c_void>,
}
#[cfg(feature = "nccl")]
// SAFETY: NOTE(review) — `NCCLCommunicator` holds raw `*mut c_void` handles, which
// are not `Send` by default. This impl asserts that moving the handles to another
// thread is sound; nothing in this file demonstrates that invariant — confirm
// against the native library's threading rules.
unsafe impl Send for NCCLCommunicator {}
#[cfg(feature = "nccl")]
// SAFETY: NOTE(review) — asserts that shared references to the raw handles may be
// used from several threads at once. The collective methods take `&self`, so this
// relies on the underlying handles tolerating concurrent use — confirm.
unsafe impl Sync for NCCLCommunicator {}
#[cfg(feature = "nccl")]
impl NCCLCommunicator {
    /// Creates the communicator bookkeeping for one rank on one device.
    ///
    /// The native handle is left as a null pointer and no streams are created;
    /// `comm_id` is accepted but not consumed by this implementation.
    ///
    /// # Errors
    /// The current implementation always returns `Ok`.
    pub fn new(
        rank: usize,
        nranks: usize,
        device_id: usize,
        comm_id: &NCCLUniqueId,
    ) -> RusTorchResult<Self> {
        // Explicitly discard the id: it is part of the signature but unused here.
        let _ = comm_id;
        Ok(Self {
            comm: ptr::null_mut(),
            device_id,
            rank,
            nranks,
            streams: Vec::new(),
        })
    }

    /// Shared precondition of every collective: the tensor must live on a CUDA device.
    fn ensure_cuda<T: Float>(tensor: &Tensor<T>) -> RusTorchResult<()> {
        match tensor.device {
            crate::tensor::device::Device::Cuda(_) => Ok(()),
            _ => Err(RusTorchError::gpu(
                "Tensor must be on CUDA device for NCCL".to_string(),
            )),
        }
    }

    /// All-reduce entry point for a single tensor.
    ///
    /// Only the device check is performed; the tensor contents are left
    /// untouched by this implementation.
    ///
    /// # Errors
    /// Returns a GPU error when `tensor` is not on a CUDA device.
    pub fn all_reduce<T: Float>(&self, tensor: &mut Tensor<T>, op: ReduceOp) -> RusTorchResult<()> {
        Self::ensure_cuda(tensor)?;
        let _ = (tensor, op);
        Ok(())
    }

    /// All-gather entry point.
    ///
    /// Returns `nranks` clones of the input tensor.
    ///
    /// # Errors
    /// Returns a GPU error when `tensor` is not on a CUDA device.
    pub fn all_gather<T: Float>(&self, tensor: &Tensor<T>) -> RusTorchResult<Vec<Tensor<T>>> {
        Self::ensure_cuda(tensor)?;
        let gathered: Vec<Tensor<T>> = (0..self.nranks).map(|_| tensor.clone()).collect();
        Ok(gathered)
    }

    /// Broadcast entry point.
    ///
    /// Validates the device and the root rank; the tensor contents are left
    /// untouched by this implementation.
    ///
    /// # Errors
    /// Returns a GPU error when `tensor` is not on a CUDA device, or a
    /// distributed error when `root` is not smaller than the communicator size.
    pub fn broadcast<T: Float>(&self, tensor: &mut Tensor<T>, root: usize) -> RusTorchResult<()> {
        Self::ensure_cuda(tensor)?;
        let root_in_range = root < self.nranks;
        if !root_in_range {
            return Err(RusTorchError::distributed(&format!(
                "Root rank {} exceeds communicator size {}",
                root, self.nranks
            )));
        }
        let _ = (tensor, root);
        Ok(())
    }

    /// Device index this communicator was created for.
    pub fn device_id(&self) -> usize {
        self.device_id
    }

    /// Rank of this participant.
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Number of participating ranks.
    pub fn nranks(&self) -> usize {
        self.nranks
    }
}
#[cfg(feature = "nccl")]
impl Drop for NCCLCommunicator {
    /// Visits the communicator handle and every stream handle without releasing
    /// anything: both branches are empty placeholders.
    fn drop(&mut self) {
        let has_comm = !self.comm.is_null();
        if has_comm {
            // Placeholder: no teardown call is made for the communicator handle.
        }
        self.streams
            .iter()
            .filter(|stream| !stream.is_null())
            .for_each(|_stream| {
                // Placeholder: no teardown call is made for stream handles.
            });
    }
}
#[cfg(feature = "nccl")]
/// Fixed-size, 128-byte identifier passed to `NCCLCommunicator::new`.
///
/// NOTE(review): presumably mirrors the native `ncclUniqueId` layout — confirm.
#[derive(Debug, Clone)]
pub struct NCCLUniqueId {
/// Raw identifier bytes; always exactly 128 bytes long.
id: [u8; 128], }
#[cfg(feature = "nccl")]
impl NCCLUniqueId {
    /// Creates an identifier whose 128 bytes are all zero.
    ///
    /// # Errors
    /// The current implementation always returns `Ok`.
    pub fn new() -> RusTorchResult<Self> {
        Ok(Self { id: [0u8; 128] })
    }

    /// Borrows the identifier as a 128-byte slice.
    pub fn as_bytes(&self) -> &[u8] {
        &self.id[..]
    }

    /// Rebuilds an identifier from a byte slice.
    ///
    /// # Errors
    /// Returns a distributed error when `bytes` is not exactly 128 bytes long.
    pub fn from_bytes(bytes: &[u8]) -> RusTorchResult<Self> {
        // The slice-to-array conversion succeeds only for an exact length match.
        <[u8; 128] as std::convert::TryFrom<&[u8]>>::try_from(bytes)
            .map(|id| Self { id })
            .map_err(|_| RusTorchError::distributed("Invalid NCCL ID length"))
    }
}
#[cfg(feature = "nccl")]
impl Default for NCCLUniqueId {
/// Returns the value produced by `NCCLUniqueId::new`.
///
/// # Panics
/// Panics if `NCCLUniqueId::new` returns an error, because the result is
/// unwrapped.
fn default() -> Self {
Self::new().unwrap()
}
}
#[cfg(feature = "nccl")]
/// Backend that owns one shared communicator per device plus tuning options.
pub struct NCCLBackendOptimized {
/// Communicators keyed by device id; each is shared behind `Arc<Mutex<_>>`.
communicators: HashMap<usize, Arc<Mutex<NCCLCommunicator>>>,
/// Process group supplied at construction; its `rank` and `world_size` are
/// used to create every communicator.
process_group: ProcessGroup,
/// Current tuning options; starts as `NCCLOptimizations::default()` and can be
/// replaced through `configure_optimizations`.
optimizations: NCCLOptimizations,
}
#[cfg(feature = "nccl")]
/// Tuning options for the optimized backend.
///
/// Only `bucket_size_mb` is read by the code in this file; the remaining fields
/// are stored and returned but not consulted here.
#[derive(Debug, Clone)]
pub struct NCCLOptimizations {
/// Presumably toggles gradient compression — not read in this file, TODO confirm.
pub compression_enabled: bool,
/// Bucket size limit in mebibytes; multiplied by `1024 * 1024` when tensors are
/// grouped into buckets.
pub bucket_size_mb: usize,
/// Presumably the number of streams to use — not read in this file, TODO confirm.
pub num_streams: usize,
/// Presumably enables asynchronous operation — not read in this file, TODO confirm.
pub async_enabled: bool,
/// Presumably an operation timeout in seconds — not read in this file, TODO confirm.
pub timeout_seconds: u64,
}
#[cfg(feature = "nccl")]
impl Default for NCCLOptimizations {
    /// Baseline options: compression off, 25 MB buckets, 4 streams, async on,
    /// 300 second timeout.
    fn default() -> Self {
        const DEFAULT_BUCKET_SIZE_MB: usize = 25;
        const DEFAULT_NUM_STREAMS: usize = 4;
        const DEFAULT_TIMEOUT_SECONDS: u64 = 300;

        Self {
            compression_enabled: false,
            bucket_size_mb: DEFAULT_BUCKET_SIZE_MB,
            num_streams: DEFAULT_NUM_STREAMS,
            async_enabled: true,
            timeout_seconds: DEFAULT_TIMEOUT_SECONDS,
        }
    }
}
#[cfg(feature = "nccl")]
impl NCCLBackendOptimized {
    /// Builds a backend with one communicator per entry of `device_ids`.
    ///
    /// Every communicator is created with the process group's `rank` and
    /// `world_size`. If a device id appears more than once, the later
    /// communicator replaces the earlier one in the map.
    ///
    /// # Errors
    /// Propagates any error returned by `NCCLCommunicator::new`.
    pub fn new(
        process_group: ProcessGroup,
        device_ids: &[usize],
        comm_id: &NCCLUniqueId,
    ) -> RusTorchResult<Self> {
        let mut communicators = HashMap::with_capacity(device_ids.len());
        for &device_id in device_ids {
            let comm = NCCLCommunicator::new(
                process_group.rank,
                process_group.world_size,
                device_id,
                comm_id,
            )?;
            communicators.insert(device_id, Arc::new(Mutex::new(comm)));
        }
        Ok(Self {
            communicators,
            process_group,
            optimizations: NCCLOptimizations::default(),
        })
    }

    /// Replaces the current tuning options.
    pub fn configure_optimizations(&mut self, opts: NCCLOptimizations) {
        self.optimizations = opts;
    }

    /// Returns a shared handle to the communicator registered for `device_id`,
    /// or `None` when no communicator exists for that device.
    pub fn get_communicator(&self, device_id: usize) -> Option<Arc<Mutex<NCCLCommunicator>>> {
        self.communicators.get(&device_id).cloned()
    }

    /// All-reduces `tensors` in place, walking them bucket by bucket.
    ///
    /// Buckets are contiguous index ranges over `tensors`, so the reduction is
    /// applied to the caller's tensors themselves rather than to temporary
    /// copies whose results would be dropped. Within a bucket the tensors are
    /// currently reduced one at a time.
    ///
    /// # Errors
    /// Stops at, and returns, the first error reported for any tensor.
    pub fn all_reduce_bucketed<T: Float + 'static>(
        &self,
        tensors: &mut [Tensor<T>],
        op: ReduceOp,
    ) -> RusTorchResult<()> {
        let buckets = self.create_gradient_buckets(tensors);
        for bucket in buckets {
            for tensor in &mut tensors[bucket] {
                self.all_reduce_single(tensor, op)?;
            }
        }
        Ok(())
    }

    /// Splits `tensors` into contiguous index ranges whose byte size stays within
    /// `bucket_size_mb`.
    ///
    /// A bucket is closed as soon as adding the next tensor would exceed the
    /// limit, unless the bucket is still empty; a tensor larger than the limit
    /// therefore starts a bucket of its own. Sizes use saturating arithmetic so
    /// extreme configurations cannot overflow.
    fn create_gradient_buckets<T: Float + 'static>(
        &self,
        tensors: &[Tensor<T>],
    ) -> Vec<std::ops::Range<usize>> {
        let bucket_size_bytes = self.optimizations.bucket_size_mb.saturating_mul(1024 * 1024);
        let mut buckets = Vec::new();
        let mut bucket_start = 0usize;
        let mut current_size = 0usize;
        for (index, tensor) in tensors.iter().enumerate() {
            let tensor_size = tensor.numel().saturating_mul(std::mem::size_of::<T>());
            let bucket_has_tensors = index > bucket_start;
            if bucket_has_tensors && current_size.saturating_add(tensor_size) > bucket_size_bytes {
                buckets.push(bucket_start..index);
                bucket_start = index;
                current_size = 0;
            }
            current_size = current_size.saturating_add(tensor_size);
        }
        if bucket_start < tensors.len() {
            buckets.push(bucket_start..tensors.len());
        }
        buckets
    }

    /// All-reduces one tensor through the communicator of the tensor's device.
    ///
    /// Tensors that are not on a CUDA device are looked up under device `0`; the
    /// communicator's own device check then rejects them.
    ///
    /// # Errors
    /// Returns a distributed error when no communicator is registered for the
    /// device or when the communicator's mutex is poisoned, and propagates any
    /// error from the communicator itself.
    fn all_reduce_single<T: Float>(
        &self,
        tensor: &mut Tensor<T>,
        op: ReduceOp,
    ) -> RusTorchResult<()> {
        let device_id = match tensor.device {
            crate::tensor::device::Device::Cuda(id) => id,
            _ => 0,
        };
        let comm_arc = self.communicators.get(&device_id).ok_or_else(|| {
            RusTorchError::distributed(&format!(
                "No NCCL communicator found for device {}",
                device_id
            ))
        })?;
        // Report a poisoned lock as an error instead of panicking in library code.
        let comm = comm_arc
            .lock()
            .map_err(|_| RusTorchError::distributed("NCCL communicator mutex poisoned"))?;
        comm.all_reduce(tensor, op)
    }
}
#[cfg(feature = "nccl")]
/// Namespace type for the multi-GPU helper functions; it carries no state.
pub struct NCCLOps;
#[cfg(feature = "nccl")]
impl NCCLOps {
    /// Creates one communicator per device, assigning ranks in slice order.
    ///
    /// All communicators share a single freshly created `NCCLUniqueId`, and the
    /// number of ranks equals `device_ids.len()`.
    ///
    /// # Errors
    /// Returns a distributed error when `device_ids` is empty or contains the
    /// same device twice (the map is keyed by device id, so a duplicate would
    /// otherwise silently replace an earlier rank). Errors from
    /// `NCCLUniqueId::new` and `NCCLCommunicator::new` are propagated.
    pub fn init_multi_gpu(
        device_ids: &[usize],
    ) -> RusTorchResult<HashMap<usize, NCCLCommunicator>> {
        if device_ids.is_empty() {
            return Err(RusTorchError::distributed("No device IDs provided"));
        }
        let nranks = device_ids.len();
        let comm_id = NCCLUniqueId::new()?;
        let mut communicators = HashMap::with_capacity(nranks);
        for (rank, &device_id) in device_ids.iter().enumerate() {
            if communicators.contains_key(&device_id) {
                return Err(RusTorchError::distributed(&format!(
                    "Duplicate device ID {} provided",
                    device_id
                )));
            }
            let comm = NCCLCommunicator::new(rank, nranks, device_id, &comm_id)?;
            communicators.insert(device_id, comm);
        }
        Ok(communicators)
    }

    /// All-reduces every CUDA tensor through the communicator of its device and
    /// then synchronizes the streams.
    ///
    /// Tensors that are not on a CUDA device are skipped.
    ///
    /// # Errors
    /// Returns a distributed error when a CUDA tensor's device has no entry in
    /// `communicators`, matching the behaviour of the bucketed backend path, and
    /// propagates any error from the communicator or from stream
    /// synchronization.
    pub fn all_reduce_multi_gpu<T: Float>(
        tensors: &mut [Tensor<T>],
        communicators: &HashMap<usize, NCCLCommunicator>,
        op: ReduceOp,
    ) -> RusTorchResult<()> {
        for tensor in tensors.iter_mut() {
            let device_id = match tensor.device {
                crate::tensor::device::Device::Cuda(device_id) => device_id,
                // Non-CUDA tensors do not take part in the collective.
                _ => continue,
            };
            let comm = communicators.get(&device_id).ok_or_else(|| {
                RusTorchError::distributed(&format!(
                    "No NCCL communicator found for device {}",
                    device_id
                ))
            })?;
            comm.all_reduce(tensor, op)?;
        }
        Self::synchronize_all_streams(communicators)
    }

    /// Synchronization hook; the current implementation does nothing and
    /// returns `Ok`.
    fn synchronize_all_streams(
        _communicators: &HashMap<usize, NCCLCommunicator>,
    ) -> RusTorchResult<()> {
        Ok(())
    }

    /// Picks tuning options from the GPU count and per-GPU memory.
    ///
    /// * bucket size: 50 MB when `gpu_memory_gb > 16.0`, otherwise 25 MB
    /// * streams: 8 when `num_gpus > 4`, otherwise 4
    /// * compression: enabled only when `num_gpus > 8`
    /// * async is always enabled and the timeout is always 300 seconds
    pub fn get_optimal_config(num_gpus: usize, gpu_memory_gb: f32) -> NCCLOptimizations {
        let bucket_size_mb = if gpu_memory_gb > 16.0 { 50 } else { 25 };
        let num_streams = if num_gpus > 4 { 8 } else { 4 };
        NCCLOptimizations {
            compression_enabled: num_gpus > 8,
            bucket_size_mb,
            num_streams,
            async_enabled: true,
            timeout_seconds: 300,
        }
    }
}
#[cfg(feature = "nccl")]
/// Collects per-operation timings and bandwidth figures behind shared mutexes.
pub struct NCCLProfiler {
/// Recorded durations in milliseconds, keyed by operation name; shared with
/// every `TimingGuard` handed out by `start_timing`.
timing_data: Arc<Mutex<HashMap<String, Vec<f64>>>>,
/// Most recently computed bandwidth per operation name, as written by
/// `calculate_bandwidth`.
bandwidth_data: Arc<Mutex<HashMap<String, f64>>>,
}
#[cfg(feature = "nccl")]
impl NCCLProfiler {
    /// Creates a profiler with no recorded timings or bandwidth figures.
    pub fn new() -> Self {
        let timing_data = Arc::new(Mutex::new(HashMap::new()));
        let bandwidth_data = Arc::new(Mutex::new(HashMap::new()));
        Self {
            timing_data,
            bandwidth_data,
        }
    }

    /// Starts timing `operation`; the elapsed time is recorded when the
    /// returned guard is dropped.
    pub fn start_timing(&self, operation: &str) -> TimingGuard {
        let shared_timings = Arc::clone(&self.timing_data);
        TimingGuard::new(operation.to_string(), shared_timings)
    }

    /// Summarizes the recorded durations of every operation that has at least
    /// one sample.
    ///
    /// # Panics
    /// Panics if the timing mutex is poisoned.
    pub fn get_timing_stats(&self) -> HashMap<String, TimingStats> {
        let recorded = self.timing_data.lock().unwrap();
        let summary: HashMap<String, TimingStats> = recorded
            .iter()
            .filter(|(_, samples)| !samples.is_empty())
            .map(|(name, samples)| {
                let total_ms: f64 = samples.iter().sum();
                let min_ms = samples.iter().copied().fold(f64::INFINITY, f64::min);
                let max_ms = samples.iter().copied().fold(f64::NEG_INFINITY, f64::max);
                let entry = TimingStats {
                    count: samples.len(),
                    average_ms: total_ms / samples.len() as f64,
                    min_ms,
                    max_ms,
                    total_ms,
                };
                (name.clone(), entry)
            })
            .collect();
        summary
    }

    /// Stores the bandwidth of `operation` in gigabits per second, replacing any
    /// earlier value for the same name.
    ///
    /// Nothing is recorded unless `duration_ms` is greater than zero.
    ///
    /// # Panics
    /// Panics if the bandwidth mutex is poisoned.
    pub fn calculate_bandwidth(&self, operation: &str, bytes_transferred: usize, duration_ms: f64) {
        if duration_ms > 0.0 {
            let bits_transferred = bytes_transferred as f64 * 8.0;
            // bits / (ms * 1e6) == (bits / s) / 1e9
            let scaled_duration = duration_ms * 1_000_000.0;
            let bandwidth_gbps = bits_transferred / scaled_duration;
            self.bandwidth_data
                .lock()
                .unwrap()
                .insert(operation.to_string(), bandwidth_gbps);
        }
    }
}
#[cfg(feature = "nccl")]
/// Guard that measures the time between its creation and its drop, then appends
/// that duration to the shared timing map.
pub struct TimingGuard {
/// Name under which the measured duration is recorded.
operation: String,
/// Instant captured when the guard was created.
start_time: std::time::Instant,
/// Shared map of recorded durations in milliseconds, keyed by operation name.
timing_data: Arc<Mutex<HashMap<String, Vec<f64>>>>,
}
#[cfg(feature = "nccl")]
impl TimingGuard {
    /// Captures the current instant and remembers where to record the elapsed
    /// time for `operation`.
    fn new(operation: String, timing_data: Arc<Mutex<HashMap<String, Vec<f64>>>>) -> Self {
        let start_time = std::time::Instant::now();
        Self {
            operation,
            start_time,
            timing_data,
        }
    }
}
#[cfg(feature = "nccl")]
impl Drop for TimingGuard {
    /// Records the elapsed time, in milliseconds, under the guard's operation name.
    ///
    /// This never panics on a poisoned mutex: a guard may be dropped while the
    /// thread is already unwinding, and a second panic at that point would abort
    /// the process. The map only ever receives appended samples, so it remains
    /// usable after another holder panicked.
    fn drop(&mut self) {
        let duration_ms = self.start_time.elapsed().as_secs_f64() * 1000.0;
        let mut data = self
            .timing_data
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // The guard is being destroyed, so its name can be moved out instead of cloned.
        data.entry(std::mem::take(&mut self.operation))
            .or_default()
            .push(duration_ms);
    }
}
/// Summary of the durations recorded for one operation, in milliseconds.
#[derive(Debug, Clone)]
pub struct TimingStats {
/// Number of recorded samples.
pub count: usize,
/// Arithmetic mean of the samples (`total_ms / count`).
pub average_ms: f64,
/// Smallest recorded sample.
pub min_ms: f64,
/// Largest recorded sample.
pub max_ms: f64,
/// Sum of all recorded samples.
pub total_ms: f64,
}
#[cfg(not(feature = "nccl"))]
pub mod fallback {
    //! Stand-ins compiled when the `nccl` feature is disabled.
    use super::*;

    /// Message attached to every "backend unavailable" error from this module.
    const NCCL_UNAVAILABLE_MESSAGE: &str =
        "NCCL backend not available. Compile with --features nccl";

    /// Builds the error reported when NCCL support was not compiled in.
    pub fn nccl_not_available_error() -> RusTorchError {
        RusTorchError::backend_unavailable(NCCL_UNAVAILABLE_MESSAGE)
    }

    /// Always fails with the "backend unavailable" error; the device ids are ignored.
    pub fn init_multi_gpu(_device_ids: &[usize]) -> RusTorchResult<()> {
        let unavailable = nccl_not_available_error();
        Err(unavailable)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An id survives a round trip through its byte representation.
    #[cfg(feature = "nccl")]
    #[test]
    fn test_nccl_unique_id() {
        let id1 = NCCLUniqueId::new().unwrap();
        let bytes = id1.as_bytes();
        let id2 = NCCLUniqueId::from_bytes(bytes).unwrap();
        assert_eq!(id1.as_bytes(), id2.as_bytes());
    }

    /// `get_optimal_config` thresholds: compression needs more than 8 GPUs,
    /// 8 streams need more than 4 GPUs, 50 MB buckets need more than 16 GB.
    #[cfg(feature = "nccl")]
    #[test]
    fn test_nccl_optimizations() {
        // Large cluster: every threshold is exceeded.
        let opts = NCCLOps::get_optimal_config(16, 32.0);
        assert!(opts.compression_enabled);
        assert_eq!(opts.bucket_size_mb, 50);
        assert_eq!(opts.num_streams, 8);

        // Exactly 8 GPUs sits on the compression boundary, which is exclusive.
        let opts = NCCLOps::get_optimal_config(8, 32.0);
        assert!(!opts.compression_enabled);
        assert_eq!(opts.bucket_size_mb, 50);
        assert_eq!(opts.num_streams, 8);

        // Small setup: every threshold falls back to its lower value.
        let opts = NCCLOps::get_optimal_config(2, 8.0);
        assert!(!opts.compression_enabled);
        assert_eq!(opts.bucket_size_mb, 25);
        assert_eq!(opts.num_streams, 4);
    }

    /// Dropping a timing guard records exactly one sample of at least the slept time.
    #[cfg(feature = "nccl")]
    #[test]
    fn test_nccl_profiler() {
        let profiler = NCCLProfiler::new();
        {
            let _guard = profiler.start_timing("test_op");
            std::thread::sleep(std::time::Duration::from_millis(10));
        }
        let stats = profiler.get_timing_stats();
        assert!(stats.contains_key("test_op"));
        assert_eq!(stats["test_op"].count, 1);
        assert!(stats["test_op"].average_ms >= 10.0);
    }

    /// Without the `nccl` feature the fallback reports `BackendUnavailable`.
    #[cfg(not(feature = "nccl"))]
    #[test]
    fn test_fallback_error() {
        let error = fallback::nccl_not_available_error();
        match error {
            RusTorchError::BackendUnavailable { .. } => (),
            _ => panic!("Expected BackendUnavailable error"),
        }
    }
}