rustorch 0.6.29

Production-ready PyTorch-compatible deep learning library in Rust with special mathematical functions (gamma, Bessel, error functions), statistical distributions, Fourier transforms (FFT/RFFT), matrix decomposition (SVD/QR/LU/eigenvalue), automatic differentiation, neural networks, computer vision transforms, complete GPU acceleration (CUDA/Metal/OpenCL), SIMD optimizations, parallel processing, WebAssembly browser support, comprehensive distributed learning support, and performance validation
Documentation
//! Communication backends for distributed training
//! 分散学習用通信バックエンド

use super::common::{BackendOptimizations, CommonOps};
use super::{DistributedOps, ProcessGroup, ReduceOp};
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use num_traits::Float;

/// NCCL backend for NVIDIA GPU communication
/// NVIDIA GPU通信用NCCLバックエンド
#[cfg(feature = "nccl")]
pub struct NCCLBackend {
    process_group: ProcessGroup,
    comm: *mut std::ffi::c_void,
}

#[cfg(feature = "nccl")]
unsafe impl Send for NCCLBackend {}

#[cfg(feature = "nccl")]
unsafe impl Sync for NCCLBackend {}

#[cfg(feature = "nccl")]
impl NCCLBackend {
    pub fn new(process_group: ProcessGroup) -> RusTorchResult<Self> {
        Ok(Self {
            process_group,
            comm: std::ptr::null_mut(),
        })
    }
}

#[cfg(feature = "nccl")]
impl<T: Float + Send + Sync + 'static> DistributedOps<T> for NCCLBackend {
    fn all_reduce(&self, tensor: &mut Tensor<T>, op: ReduceOp) -> RusTorchResult<()> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_all_reduce(tensor, op)
    }

    fn all_gather(&self, tensor: &Tensor<T>) -> RusTorchResult<Vec<Tensor<T>>> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_all_gather(tensor, self.process_group.world_size)
    }

    fn broadcast(&self, tensor: &mut Tensor<T>, root: usize) -> RusTorchResult<()> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_broadcast(tensor, root)
    }

    fn gather(&self, tensor: &Tensor<T>, root: usize) -> RusTorchResult<Vec<Tensor<T>>> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_gather(tensor, self.process_group.world_size, root)
    }

    fn scatter(&self, tensors: &[Tensor<T>], _root: usize) -> RusTorchResult<Tensor<T>> {
        if tensors.is_empty() {
            return Err(RusTorchError::tensor_op(
                "Empty tensor array for scatter operation",
            ));
        }
        Ok(tensors[0].clone())
    }

    fn reduce(&self, _tensor: &mut Tensor<T>, _root: usize, _op: ReduceOp) -> RusTorchResult<()> {
        Ok(())
    }
}

#[cfg(feature = "nccl")]
impl<T: Float + Send + Sync + 'static> BackendOptimizations<T> for NCCLBackend {}

/// Gloo backend for CPU and GPU communication
/// CPUおよびGPU通信用Glooバックエンド
pub struct GlooBackend {
    process_group: ProcessGroup,
}

/// Gloo transport options
/// Gloo転送オプション
#[derive(Debug, Clone, Copy)]
pub enum GlooTransport {
    /// TCP transport for network communication
    /// ネットワーク通信用TCP転送
    TCP,
    /// InfiniBand transport for high-performance clusters
    /// 高性能クラスター用InfiniBand転送
    InfiniBand,
    /// Shared memory transport for single-node communication
    /// 単一ノード通信用共有メモリ転送
    SharedMemory,
}

/// Gloo communication context
/// Gloo通信コンテキスト
pub struct GlooContext {}

impl GlooBackend {
    /// Create new Gloo backend instance
    /// 新しいGlooバックエンドインスタンスを作成
    pub fn new(process_group: ProcessGroup) -> RusTorchResult<Self> {
        let _context = GlooContext {};

        Ok(Self { process_group })
    }
}

impl<T: Float + Send + Sync + 'static> DistributedOps<T> for GlooBackend {
    fn all_reduce(&self, tensor: &mut Tensor<T>, op: ReduceOp) -> RusTorchResult<()> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_all_reduce(tensor, op)
    }

    fn all_gather(&self, tensor: &Tensor<T>) -> RusTorchResult<Vec<Tensor<T>>> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_all_gather(tensor, self.process_group.world_size)
    }

    fn broadcast(&self, tensor: &mut Tensor<T>, root: usize) -> RusTorchResult<()> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_broadcast(tensor, root)
    }

    fn gather(&self, tensor: &Tensor<T>, root: usize) -> RusTorchResult<Vec<Tensor<T>>> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_gather(tensor, self.process_group.world_size, root)
    }

    fn scatter(&self, tensors: &[Tensor<T>], _root: usize) -> RusTorchResult<Tensor<T>> {
        if tensors.is_empty() {
            return Err(RusTorchError::tensor_op(
                "Empty tensor array for scatter operation",
            ));
        }
        Ok(tensors[0].clone())
    }

    fn reduce(&self, tensor: &mut Tensor<T>, root: usize, op: ReduceOp) -> RusTorchResult<()> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_all_reduce(tensor, op)?;
        let _ = root; // Suppress unused parameter warning
        Ok(())
    }
}

impl<T: Float + Send + Sync + 'static> BackendOptimizations<T> for GlooBackend {}

/// TCP backend for simple distributed training
/// シンプルな分散学習用TCPバックエンド
pub struct TCPBackend {
    process_group: ProcessGroup,
}

/// TCP connection to remote process
/// リモートプロセスへのTCP接続
pub struct TCPConnection {}

impl TCPBackend {
    /// Create new TCP backend instance
    /// 新しいTCPバックエンドインスタンスを作成
    pub fn new(process_group: ProcessGroup) -> RusTorchResult<Self> {
        Ok(Self { process_group })
    }
}

impl<T: Float + Send + Sync + 'static> DistributedOps<T> for TCPBackend {
    fn all_reduce(&self, tensor: &mut Tensor<T>, op: ReduceOp) -> RusTorchResult<()> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_all_reduce(tensor, op)
    }

    fn all_gather(&self, tensor: &Tensor<T>) -> RusTorchResult<Vec<Tensor<T>>> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_all_gather(tensor, self.process_group.world_size)
    }

    fn broadcast(&self, tensor: &mut Tensor<T>, root: usize) -> RusTorchResult<()> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_broadcast(tensor, root)
    }

    fn reduce(&self, tensor: &mut Tensor<T>, root: usize, op: ReduceOp) -> RusTorchResult<()> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_all_reduce(tensor, op)?;
        let _ = root; // Suppress unused parameter warning
        Ok(())
    }

    fn scatter(&self, tensors: &[Tensor<T>], _root: usize) -> RusTorchResult<Tensor<T>> {
        if tensors.is_empty() {
            return Err(RusTorchError::tensor_op(
                "Empty tensor array for scatter operation",
            ));
        }
        Ok(tensors[0].clone())
    }

    fn gather(&self, tensor: &Tensor<T>, root: usize) -> RusTorchResult<Vec<Tensor<T>>> {
        CommonOps::validate_tensor(tensor)?;
        CommonOps::default_gather(tensor, self.process_group.world_size, root)
    }
}

impl<T: Float + Send + Sync + 'static> BackendOptimizations<T> for TCPBackend {}

/// Backend factory
/// バックエンドファクトリ
pub struct BackendFactory;

impl BackendFactory {
    /// Create backend instance based on process group configuration
    /// プロセスグループ設定に基づいてバックエンドインスタンスを作成
    pub fn create_backend<T: Float + Send + Sync + 'static>(
        process_group: ProcessGroup,
    ) -> RusTorchResult<Box<dyn DistributedOps<T> + Send + Sync>> {
        match process_group.backend {
            #[cfg(feature = "nccl")]
            super::DistributedBackend::NCCL => {
                let backend = NCCLBackend::new(process_group)?;
                Ok(Box::new(backend))
            }
            super::DistributedBackend::Gloo => {
                let backend = GlooBackend::new(process_group)?;
                Ok(Box::new(backend))
            }
            super::DistributedBackend::TCP => {
                let backend = TCPBackend::new(process_group)?;
                Ok(Box::new(backend))
            }
            #[cfg(not(feature = "nccl"))]
            super::DistributedBackend::NCCL => {
                Err(RusTorchError::backend_unavailable("NCCL not compiled"))
            }
            super::DistributedBackend::MPI => {
                Err(RusTorchError::backend_unavailable("MPI not implemented"))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::super::DistributedBackend;
    use super::*;

    #[test]
    fn test_gloo_backend_creation() {
        let pg = ProcessGroup::new(
            0,
            4,
            DistributedBackend::Gloo,
            "localhost".to_string(),
            12345,
        );

        let backend = GlooBackend::new(pg);
        assert!(backend.is_ok());
    }
}