//! NCCL-backed communicator for multi-GPU collective operations.

use super::Communicator;
use crate::errors::TrustformersError;
#[cfg(feature = "nccl")]
use crate::tensor::Tensor;
use std::sync::Arc;

type Result<T> = std::result::Result<T, TrustformersError>;
#[cfg(feature = "nccl")]
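/// NCCL-backed communicator for collective operations across CUDA devices.
///
/// Collective methods currently validate their arguments and log the
/// requested operation; the underlying NCCL calls are not yet issued.
///
/// Intended usage, as a sketch only (not compiled; mirrors the unit tests
/// below, driving two ranks from a single process):
///
/// ```ignore
/// let comms = NcclCommunicator::init_all(2, &[0, 1])?;
/// let mut tensor = Tensor::ones(&[4, 4])?;
/// comms[0].all_reduce(&mut tensor)?; // `Communicator` trait method
/// ```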
pub struct NcclCommunicator {
rank: usize,
world_size: usize,
_comm_handle: NcclCommHandle,
device_id: i32,
}
#[cfg(feature = "nccl")]
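/// Opaque handle to the underlying NCCL communicator; currently a
/// placeholder that carries no FFI state.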
struct NcclCommHandle {
_placeholder: std::marker::PhantomData<()>,
}
#[cfg(feature = "nccl")]
impl NcclCommunicator {
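    /// Creates a communicator for `rank` on CUDA device `device_id` within a
    /// group of `world_size` ranks. Returns an error if `rank >= world_size`.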
pub fn new(rank: usize, world_size: usize, device_id: i32) -> Result<Self> {
if rank >= world_size {
return Err(TrustformersError::invalid_input(format!(
"Rank {} must be less than world_size {}",
rank, world_size
)));
}
let comm_handle = NcclCommHandle {
_placeholder: std::marker::PhantomData,
};
Ok(Self {
rank,
world_size,
_comm_handle: comm_handle,
device_id,
})
}
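    /// Creates one communicator per rank, assigning rank `i` to `device_ids[i]`.
    /// Returns an error if `device_ids.len() != world_size`.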
pub fn init_all(world_size: usize, device_ids: &[i32]) -> Result<Vec<Self>> {
        if device_ids.len() != world_size {
            return Err(TrustformersError::invalid_input(format!(
                "Number of device IDs ({}) must match world_size ({})",
                device_ids.len(),
                world_size
            )));
        }
let mut communicators = Vec::with_capacity(world_size);
        for (rank, &device_id) in device_ids.iter().enumerate() {
let comm = Self::new(rank, world_size, device_id)?;
communicators.push(comm);
}
Ok(communicators)
}
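    /// Rank of this communicator within the group.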
pub fn rank(&self) -> usize {
self.rank
}
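    /// Total number of ranks in the group.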
pub fn world_size(&self) -> usize {
self.world_size
}
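    /// CUDA device this communicator is bound to.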
pub fn device_id(&self) -> i32 {
self.device_id
}
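    /// Reduces `tensor` in place across all ranks (currently validates and logs only).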
fn nccl_all_reduce(&self, tensor: &mut Tensor) -> Result<()> {
if tensor.shape().is_empty() {
return Err(TrustformersError::invalid_input(
"Cannot perform all-reduce on empty tensor".to_string(),
));
}
log::debug!(
"NCCL all-reduce on device {} for tensor with shape {:?}",
self.device_id,
tensor.shape()
);
Ok(())
}
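    /// Gathers `tensor` from every rank into a tensor with a new leading
    /// dimension of size `world_size` (currently returns zeros of that shape).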
fn nccl_all_gather(&self, tensor: &Tensor) -> Result<Tensor> {
let input_shape = tensor.shape();
let mut output_shape = vec![self.world_size];
output_shape.extend_from_slice(&input_shape);
log::debug!(
"NCCL all-gather on device {} for tensor with shape {:?} -> {:?}",
self.device_id,
input_shape,
output_shape
);
Tensor::zeros(&output_shape)
}
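    /// Broadcasts `tensor` from rank `root` to all other ranks (currently logs
    /// only). Returns an error if `root >= world_size`.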
fn nccl_broadcast(&self, tensor: &mut Tensor, root: usize) -> Result<()> {
if root >= self.world_size {
return Err(TrustformersError::invalid_input(format!(
"Root rank {} must be less than world_size {}",
root, self.world_size
)));
}
log::debug!(
"NCCL broadcast on device {} from root {} for tensor with shape {:?}",
self.device_id,
root,
tensor.shape()
);
Ok(())
}
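    /// Synchronizes all ranks (currently logs only).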
pub fn barrier(&self) -> Result<()> {
log::debug!("NCCL barrier on device {}", self.device_id);
Ok(())
}
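    /// Releases communicator resources; also invoked automatically on `Drop`.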
pub fn destroy(&mut self) -> Result<()> {
log::debug!("Destroying NCCL communicator on device {}", self.device_id);
Ok(())
}
}
#[cfg(feature = "nccl")]
impl Communicator for NcclCommunicator {
fn all_reduce(&self, tensor: &mut Tensor) -> Result<()> {
self.nccl_all_reduce(tensor)
}
fn all_gather(&self, tensor: &Tensor, _split_dim: usize) -> Result<Tensor> {
self.nccl_all_gather(tensor)
}
fn reduce_scatter(&self, tensor: &Tensor, _split_dim: usize) -> Result<Tensor> {
log::debug!(
"NCCL reduce-scatter on device {} for tensor with shape {:?}",
self.device_id,
tensor.shape()
);
        let input_shape = tensor.shape();
        if input_shape.is_empty() {
            return Err(TrustformersError::invalid_input(
                "Cannot perform reduce-scatter on empty tensor".to_string(),
            ));
        }
        if input_shape[0] % self.world_size != 0 {
            return Err(TrustformersError::invalid_input(format!(
                "Leading dimension {} must be divisible by world_size {} for reduce-scatter",
                input_shape[0], self.world_size
            )));
        }
        let mut output_shape = input_shape.to_vec();
        output_shape[0] /= self.world_size;
        Tensor::zeros(&output_shape)
}
fn broadcast(&self, tensor: &mut Tensor, root: usize) -> Result<()> {
self.nccl_broadcast(tensor, root)
}
fn send(&self, _tensor: &Tensor, _dest: usize) -> Result<()> {
Err(TrustformersError::runtime_error(
"Point-to-point send not yet implemented for NCCL backend".to_string(),
))
}
fn recv(&self, _shape: &[usize], _src: usize) -> Result<Tensor> {
Err(TrustformersError::runtime_error(
"Point-to-point recv not yet implemented for NCCL backend".to_string(),
))
}
}
#[cfg(feature = "nccl")]
impl Drop for NcclCommunicator {
fn drop(&mut self) {
if let Err(e) = self.destroy() {
log::error!("Failed to destroy NCCL communicator: {}", e);
}
}
}
#[cfg(feature = "nccl")]
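/// Builds an [`NcclCommunicator`] and returns it as a shared [`Communicator`]
/// trait object.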
pub fn create_nccl_communicator(
rank: usize,
world_size: usize,
device_id: i32,
) -> Result<Arc<dyn Communicator>> {
let comm = NcclCommunicator::new(rank, world_size, device_id)?;
Ok(Arc::new(comm))
}
#[cfg(not(feature = "nccl"))]
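/// Fallback when the crate is built without the `nccl` feature; always
/// returns a configuration error.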
pub fn create_nccl_communicator(
_rank: usize,
_world_size: usize,
_device_id: i32,
) -> Result<Arc<dyn Communicator>> {
Err(TrustformersError::invalid_config(
"NCCL feature not enabled. Compile with --features nccl to use NCCL communicator"
.to_string(),
))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
#[cfg(feature = "nccl")]
fn test_nccl_communicator_creation() {
        let comm = NcclCommunicator::new(0, 2, 0).expect("failed to create NCCL communicator");
assert_eq!(comm.rank(), 0);
assert_eq!(comm.world_size(), 2);
assert_eq!(comm.device_id(), 0);
}
#[test]
#[cfg(feature = "nccl")]
fn test_invalid_rank() {
let result = NcclCommunicator::new(2, 2, 0);
assert!(result.is_err());
}
#[test]
#[cfg(feature = "nccl")]
fn test_nccl_operations() -> Result<()> {
let comm = NcclCommunicator::new(0, 2, 0)?;
let mut tensor = Tensor::ones(&[4, 4])?;
comm.all_reduce(&mut tensor)?;
let gathered = comm.all_gather(&tensor, 0)?;
assert_eq!(gathered.shape()[0], 2);
comm.broadcast(&mut tensor, 0)?;
comm.barrier()?;
Ok(())
}
#[test]
fn test_create_nccl_communicator_factory() {
let result = create_nccl_communicator(0, 2, 0);
#[cfg(feature = "nccl")]
assert!(result.is_ok());
#[cfg(not(feature = "nccl"))]
assert!(result.is_err());
}
}