use super::common::{BackendOptimizations, CommonOps};
use super::{DistributedOps, ProcessGroup, ReduceOp};
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use num_traits::Float;
/// Distributed backend used for `DistributedBackend::NCCL` when the `nccl` feature is enabled.
///
/// NOTE(review): `comm` is only ever initialised to null (in `new`) and no method in this file
/// reads it; the collectives in this backend's `DistributedOps` impl run through the shared
/// `CommonOps` fallbacks rather than NCCL itself.
#[cfg(feature = "nccl")]
pub struct NCCLBackend {
    // Process group this backend serves; its `world_size` is passed to the gather fallbacks.
    process_group: ProcessGroup,
    // Opaque communicator handle, presumably a future `ncclComm_t` — TODO confirm.
    // Holding a raw pointer makes the struct !Send/!Sync by default, hence the impls below.
    comm: *mut std::ffi::c_void,
}

// SAFETY: `comm` is the field that prevents the auto-derived `Send`. In the code shown it is
// always null and never dereferenced, so moving the backend to another thread cannot touch any
// foreign state through it. This impl also vouches for `process_group`: the Gloo and TCP
// backends hold only a `ProcessGroup` and are boxed as `Send + Sync` in `BackendFactory`
// without any unsafe impl, so `ProcessGroup` is itself `Send + Sync`.
// Revisit once `comm` points at a real communicator.
#[cfg(feature = "nccl")]
unsafe impl Send for NCCLBackend {}

// SAFETY: shared `&NCCLBackend` access never reads or writes `comm` in the code shown, so
// concurrent shared references cannot race on it; `process_group` is `Sync` for the same reason
// given above. NCCL's documentation says a communicator must not be used concurrently from
// several threads, so this must be re-examined (e.g. guard `comm` with a lock) once real NCCL
// calls are added.
#[cfg(feature = "nccl")]
unsafe impl Sync for NCCLBackend {}
#[cfg(feature = "nccl")]
impl NCCLBackend {
    /// Creates an NCCL backend for `process_group`.
    ///
    /// No NCCL communicator is created here: the `comm` handle starts out null.
    ///
    /// # Errors
    /// Currently infallible; the `Result` keeps the constructor signature uniform with the
    /// other backends.
    pub fn new(process_group: ProcessGroup) -> RusTorchResult<Self> {
        let comm: *mut std::ffi::c_void = std::ptr::null_mut();
        let backend = NCCLBackend {
            process_group,
            comm,
        };
        Ok(backend)
    }
}
#[cfg(feature = "nccl")]
impl<T: Float + Send + Sync + 'static> DistributedOps<T> for NCCLBackend {
fn all_reduce(&self, tensor: &mut Tensor<T>, op: ReduceOp) -> RusTorchResult<()> {
CommonOps::validate_tensor(tensor)?;
CommonOps::default_all_reduce(tensor, op)
}
fn all_gather(&self, tensor: &Tensor<T>) -> RusTorchResult<Vec<Tensor<T>>> {
CommonOps::validate_tensor(tensor)?;
CommonOps::default_all_gather(tensor, self.process_group.world_size)
}
fn broadcast(&self, tensor: &mut Tensor<T>, root: usize) -> RusTorchResult<()> {
CommonOps::validate_tensor(tensor)?;
CommonOps::default_broadcast(tensor, root)
}
fn gather(&self, tensor: &Tensor<T>, root: usize) -> RusTorchResult<Vec<Tensor<T>>> {
CommonOps::validate_tensor(tensor)?;
CommonOps::default_gather(tensor, self.process_group.world_size, root)
}
fn scatter(&self, tensors: &[Tensor<T>], _root: usize) -> RusTorchResult<Tensor<T>> {
if tensors.is_empty() {
return Err(RusTorchError::tensor_op(
"Empty tensor array for scatter operation",
));
}
Ok(tensors[0].clone())
}
fn reduce(&self, _tensor: &mut Tensor<T>, _root: usize, _op: ReduceOp) -> RusTorchResult<()> {
Ok(())
}
}
#[cfg(feature = "nccl")]
impl<T: Float + Send + Sync + 'static> BackendOptimizations<T> for NCCLBackend {}
/// Backend used for `DistributedBackend::Gloo` (see `BackendFactory::create_backend`).
///
/// It stores only the process group; the collectives in its `DistributedOps` impl validate
/// their input and delegate to the shared `CommonOps` fallbacks.
pub struct GlooBackend {
    // Process group this backend serves; its `world_size` is passed to the gather fallbacks.
    process_group: ProcessGroup,
}
/// Transport kinds a Gloo backend could use to exchange data between ranks.
///
/// Besides `Copy`, the enum derives `PartialEq`/`Eq`/`Hash` so callers can compare transports
/// and use them as map or set keys.
/// NOTE(review): nothing in this file selects or reads a transport yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GlooTransport {
    /// TCP transport.
    TCP,
    /// InfiniBand transport.
    InfiniBand,
    /// Shared-memory transport.
    SharedMemory,
}
/// Placeholder for Gloo-specific state; it currently has no fields.
///
/// `Debug` and `Default` are derived so it can be logged and constructed generically, like
/// other public types.
#[derive(Debug, Default)]
pub struct GlooContext {}
impl GlooBackend {
pub fn new(process_group: ProcessGroup) -> RusTorchResult<Self> {
let _context = GlooContext {};
Ok(Self { process_group })
}
}
/// Collectives for the Gloo backend: each validates its input and hands off to the shared
/// `CommonOps` fallbacks.
impl<T: Float + Send + Sync + 'static> DistributedOps<T> for GlooBackend {
    // --- in-place collectives -------------------------------------------------------------

    /// Validates `tensor`, then all-reduces it in place with `op`.
    fn all_reduce(&self, tensor: &mut Tensor<T>, op: ReduceOp) -> RusTorchResult<()> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_all_reduce(tensor, op)
    }

    /// Same path as `all_reduce`; `root` is accepted for API compatibility but not consulted.
    fn reduce(&self, tensor: &mut Tensor<T>, _root: usize, op: ReduceOp) -> RusTorchResult<()> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_all_reduce(tensor, op)
    }

    /// Validates `tensor`, then broadcasts it from `root`.
    fn broadcast(&self, tensor: &mut Tensor<T>, root: usize) -> RusTorchResult<()> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_broadcast(tensor, root)
    }

    // --- collectives producing one result per rank ----------------------------------------

    /// Validates `tensor`, then gathers it using this group's `world_size`.
    fn all_gather(&self, tensor: &Tensor<T>) -> RusTorchResult<Vec<Tensor<T>>> {
        CommonOps::validate_tensor(tensor)?;
        let world_size = self.process_group.world_size;
        CommonOps::default_all_gather(tensor, world_size)
    }

    /// Validates `tensor`, then gathers it to `root` using this group's `world_size`.
    fn gather(&self, tensor: &Tensor<T>, root: usize) -> RusTorchResult<Vec<Tensor<T>>> {
        CommonOps::validate_tensor(tensor)?;
        let world_size = self.process_group.world_size;
        CommonOps::default_gather(tensor, world_size, root)
    }

    // --- scatter ---------------------------------------------------------------------------

    /// Returns a clone of the first tensor in `tensors`; `root` is not consulted.
    ///
    /// # Errors
    /// Returns a tensor-op error when `tensors` is empty.
    fn scatter(&self, tensors: &[Tensor<T>], _root: usize) -> RusTorchResult<Tensor<T>> {
        match tensors.first() {
            Some(first) => Ok(first.clone()),
            None => Err(RusTorchError::tensor_op("Empty tensor array for scatter operation")),
        }
    }
}

// No Gloo-specific optimizations: the trait defaults apply.
impl<T: Float + Send + Sync + 'static> BackendOptimizations<T> for GlooBackend {}
/// Backend used for `DistributedBackend::TCP` (see `BackendFactory::create_backend`).
///
/// It stores only the process group; no socket handling is visible in its impls, and the
/// collectives in its `DistributedOps` impl delegate to the shared `CommonOps` fallbacks.
pub struct TCPBackend {
    // Process group this backend serves; its `world_size` is passed to the gather fallbacks.
    process_group: ProcessGroup,
}
/// Placeholder, presumably for a single TCP peer connection; it currently has no fields and
/// is not referenced by `TCPBackend`.
///
/// `Debug` and `Default` are derived so it can be logged and constructed generically, like
/// other public types.
#[derive(Debug, Default)]
pub struct TCPConnection {}
impl TCPBackend {
pub fn new(process_group: ProcessGroup) -> RusTorchResult<Self> {
Ok(Self { process_group })
}
}
/// Collectives for the TCP backend.
///
/// Each method validates its input before delegating to the shared `CommonOps` fallbacks;
/// `scatter` instead only checks that it received at least one tensor.
impl<T: Float + Send + Sync + 'static> DistributedOps<T> for TCPBackend {
    /// Validates `tensor`, then broadcasts it from `root`.
    fn broadcast(&self, tensor: &mut Tensor<T>, root: usize) -> RusTorchResult<()> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_broadcast(tensor, root)
    }

    /// Validates `tensor`, then all-reduces it in place with `op`.
    fn all_reduce(&self, tensor: &mut Tensor<T>, op: ReduceOp) -> RusTorchResult<()> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_all_reduce(tensor, op)
    }

    /// Reduces via the all-reduce fallback; `root` is not consulted yet.
    fn reduce(&self, tensor: &mut Tensor<T>, _root: usize, op: ReduceOp) -> RusTorchResult<()> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_all_reduce(tensor, op)
    }

    /// Validates `tensor`, then gathers it to `root` across this group's `world_size` ranks.
    fn gather(&self, tensor: &Tensor<T>, root: usize) -> RusTorchResult<Vec<Tensor<T>>> {
        CommonOps::validate_tensor(tensor)?;
        let ranks = self.process_group.world_size;
        CommonOps::default_gather(tensor, ranks, root)
    }

    /// Validates `tensor`, then gathers it across this group's `world_size` ranks.
    fn all_gather(&self, tensor: &Tensor<T>) -> RusTorchResult<Vec<Tensor<T>>> {
        CommonOps::validate_tensor(tensor)?;
        let ranks = self.process_group.world_size;
        CommonOps::default_all_gather(tensor, ranks)
    }

    /// Returns a clone of the first tensor in `tensors`; `root` is not consulted.
    ///
    /// # Errors
    /// Returns a tensor-op error when `tensors` is empty.
    fn scatter(&self, tensors: &[Tensor<T>], _root: usize) -> RusTorchResult<Tensor<T>> {
        if let Some(first) = tensors.first() {
            return Ok(first.clone());
        }
        Err(RusTorchError::tensor_op("Empty tensor array for scatter operation"))
    }
}

// No TCP-specific optimizations: the trait defaults apply.
impl<T: Float + Send + Sync + 'static> BackendOptimizations<T> for TCPBackend {}
/// Builds the [`DistributedOps`] implementation that matches a process group's backend.
pub struct BackendFactory;

impl BackendFactory {
    /// Creates a boxed backend for `process_group.backend`, taking ownership of the group.
    ///
    /// # Errors
    /// - Propagates any error returned by the selected backend's constructor.
    /// - Returns a `backend_unavailable` error for NCCL when the `nccl` feature is disabled,
    ///   and for MPI, which has no implementation.
    pub fn create_backend<T: Float + Send + Sync + 'static>(
        process_group: ProcessGroup,
    ) -> RusTorchResult<Box<dyn DistributedOps<T> + Send + Sync>> {
        use super::DistributedBackend as Kind;

        let backend: Box<dyn DistributedOps<T> + Send + Sync> = match process_group.backend {
            Kind::Gloo => Box::new(GlooBackend::new(process_group)?),
            Kind::TCP => Box::new(TCPBackend::new(process_group)?),
            // Exactly one of the two NCCL arms is compiled, depending on the feature flag.
            #[cfg(feature = "nccl")]
            Kind::NCCL => Box::new(NCCLBackend::new(process_group)?),
            #[cfg(not(feature = "nccl"))]
            Kind::NCCL => return Err(RusTorchError::backend_unavailable("NCCL not compiled")),
            Kind::MPI => return Err(RusTorchError::backend_unavailable("MPI not implemented")),
        };
        Ok(backend)
    }
}
#[cfg(test)]
mod tests {
    use super::super::DistributedBackend;
    use super::*;

    /// Builds a 4-rank process group (this process is rank 0) on localhost for `backend`.
    fn process_group(backend: DistributedBackend) -> ProcessGroup {
        ProcessGroup::new(0, 4, backend, "localhost".to_string(), 12345)
    }

    #[test]
    fn test_gloo_backend_creation() {
        let backend = GlooBackend::new(process_group(DistributedBackend::Gloo));
        assert!(backend.is_ok());
    }

    #[test]
    fn test_tcp_backend_creation() {
        let backend = TCPBackend::new(process_group(DistributedBackend::TCP));
        assert!(backend.is_ok());
    }

    #[test]
    fn test_factory_builds_gloo_and_tcp() {
        let gloo = BackendFactory::create_backend::<f32>(process_group(DistributedBackend::Gloo));
        assert!(gloo.is_ok());
        let tcp = BackendFactory::create_backend::<f32>(process_group(DistributedBackend::TCP));
        assert!(tcp.is_ok());
    }

    #[test]
    fn test_factory_rejects_mpi() {
        let mpi = BackendFactory::create_backend::<f32>(process_group(DistributedBackend::MPI));
        assert!(mpi.is_err());
    }
}