ferrotorch-gpu 0.1.4

CUDA GPU backend for ferrotorch
Documentation
//! CUDA device management.
//!
//! [`GpuDevice`] wraps a `cudarc::driver::CudaContext` and its default stream,
//! providing a safe, ergonomic entry point for all GPU operations.

#[cfg(feature = "cuda")]
use std::sync::Arc;

#[cfg(feature = "cuda")]
use cudarc::cublas::CudaBlas;
#[cfg(feature = "cuda")]
use cudarc::driver::{CudaContext, CudaStream};

use crate::error::GpuResult;
#[cfg(not(feature = "cuda"))]
use crate::error::GpuError;

/// Handle to a single CUDA GPU device.
///
/// Holds a CUDA context, a stream, and a **cached cuBLAS handle**.
/// The cuBLAS handle is created once and reused for all matmul/bmm ops,
/// eliminating the ~1.7ms `cuModuleLoadData` overhead that occurs when
/// creating a new `CudaBlas` per operation.
///
/// Cloning is cheap and infallible: every clone shares the same context,
/// stream and cuBLAS handle through reference counting, so no new handle is
/// created. Use [`GpuDevice::fork_for_capture`] to obtain a device with its
/// own stream (and its own cuBLAS handle bound to that stream).
#[cfg(feature = "cuda")]
#[derive(Clone)]
pub struct GpuDevice {
    ctx: Arc<CudaContext>,
    stream: Arc<CudaStream>,
    // Reference-counted so `clone()` never has to create (and possibly fail
    // to create) a fresh cuBLAS handle. The handle is bound to `stream`,
    // which clones share as well, so sharing it does not change where work
    // is enqueued. NOTE(review): relies on `CudaBlas: Send + Sync` (true in
    // cudarc) and on nobody reconfiguring the shared handle, as required by
    // cuBLAS' thread-safety rules for shared handles.
    blas: Arc<CudaBlas>,
    ordinal: usize,
}

#[cfg(feature = "cuda")]
impl GpuDevice {
    /// Open CUDA device `ordinal`, using its default stream and creating a
    /// cuBLAS handle bound to that stream.
    ///
    /// # Errors
    /// Returns an error if the CUDA context or the cuBLAS handle cannot be
    /// created.
    pub fn new(ordinal: usize) -> GpuResult<Self> {
        let ctx = CudaContext::new(ordinal)?;
        let stream = ctx.default_stream();
        let blas = Arc::new(CudaBlas::new(stream.clone())?);
        Ok(Self { ctx, stream, blas, ordinal })
    }

    /// Create a `GpuDevice` with a non-blocking stream forked from the
    /// given device's stream. The forked stream supports CUDA graph
    /// capture (which the legacy default stream does not).
    ///
    /// The new device shares `parent`'s context but gets its own cuBLAS
    /// handle bound to the forked stream.
    ///
    /// # Errors
    /// Returns an error if forking the stream or creating the cuBLAS handle
    /// fails.
    pub fn fork_for_capture(parent: &GpuDevice) -> GpuResult<Self> {
        let stream = parent.stream.fork()?;
        let blas = Arc::new(CudaBlas::new(stream.clone())?);
        Ok(Self {
            ctx: Arc::clone(&parent.ctx),
            stream,
            blas,
            ordinal: parent.ordinal,
        })
    }

    /// The CUDA context this device runs in.
    #[inline]
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// The stream all work for this device is enqueued on.
    #[inline]
    pub fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }

    /// The cached cuBLAS handle — reused for all matmul/bmm operations.
    #[inline]
    pub fn blas(&self) -> &CudaBlas {
        &self.blas
    }

    /// The device ordinal this handle was opened with.
    #[inline]
    pub fn ordinal(&self) -> usize {
        self.ordinal
    }
}

#[cfg(feature = "cuda")]
impl std::fmt::Debug for GpuDevice {
    /// Prints only the device ordinal. The context, stream and cuBLAS handle
    /// are omitted, so the output is marked non-exhaustive (`..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("GpuDevice");
        out.field("ordinal", &self.ordinal);
        out.finish_non_exhaustive()
    }
}

// ---------------------------------------------------------------------------
// Stub when `cuda` feature is disabled
// ---------------------------------------------------------------------------

/// Stub `GpuDevice` used when the `cuda` feature is not enabled.
///
/// It exists so code that names `GpuDevice` still compiles without CUDA.
/// [`GpuDevice::new`] always fails with [`GpuError::NoCudaFeature`], and the
/// field is private, so callers outside this module cannot obtain an
/// instance. CUDA-specific accessors (`context`, `stream`, `blas`) are not
/// available in this configuration.
#[cfg(not(feature = "cuda"))]
#[derive(Clone, Debug)]
pub struct GpuDevice {
    ordinal: usize,
}

#[cfg(not(feature = "cuda"))]
impl GpuDevice {
    /// Always returns an error — compile with `features = ["cuda"]`.
    ///
    /// # Errors
    /// Always fails with [`GpuError::NoCudaFeature`].
    pub fn new(ordinal: usize) -> GpuResult<Self> {
        // Accepted only for signature parity with the CUDA build.
        let _ = ordinal;
        Err(GpuError::NoCudaFeature)
    }

    /// Stub counterpart of the CUDA build's `fork_for_capture`, so callers
    /// compile in both configurations.
    ///
    /// # Errors
    /// Always fails with [`GpuError::NoCudaFeature`].
    pub fn fork_for_capture(parent: &GpuDevice) -> GpuResult<Self> {
        // Accepted only for signature parity with the CUDA build.
        let _ = parent;
        Err(GpuError::NoCudaFeature)
    }

    /// The device ordinal.
    #[inline]
    pub fn ordinal(&self) -> usize {
        self.ordinal
    }
}