Skip to main content

ferrotorch_gpu/
device.rs

1//! CUDA device management.
2//!
3//! [`GpuDevice`] wraps a `cudarc::driver::CudaContext` and its default stream,
4//! providing a safe, ergonomic entry point for all GPU operations.
5
6#[cfg(feature = "cuda")]
7use std::sync::Arc;
8
9#[cfg(feature = "cuda")]
10use cudarc::cublas::CudaBlas;
11#[cfg(feature = "cuda")]
12use cudarc::driver::{CudaContext, CudaStream};
13
14#[cfg(not(feature = "cuda"))]
15use crate::error::GpuError;
16use crate::error::GpuResult;
17
/// Handle to a single CUDA GPU device.
///
/// Holds a CUDA context, a stream, and a **cached cuBLAS handle**.
/// The cuBLAS handle is created once and reused for all matmul/bmm ops,
/// eliminating the ~1.7ms `cuModuleLoadData` overhead that occurs when
/// creating a new `CudaBlas` per operation.
///
/// Cloning is cheap and infallible: every field is reference-counted, so a
/// clone shares the same context, stream, and cuBLAS handle rather than
/// creating new ones.
#[cfg(feature = "cuda")]
#[derive(Clone)]
pub struct GpuDevice {
    ctx: Arc<CudaContext>,
    /// Stream this handle was built with: the context's default stream for
    /// [`GpuDevice::new`], or a forked stream for
    /// [`GpuDevice::fork_for_capture`].
    stream: Arc<CudaStream>,
    /// cuBLAS handle created against `stream`. Wrapped in `Arc` so that
    /// clones reuse it instead of paying for a fresh handle (and risking a
    /// panic on creation failure inside `Clone`).
    blas: Arc<CudaBlas>,
    ordinal: usize,
}

#[cfg(feature = "cuda")]
impl GpuDevice {
    /// Open CUDA device `ordinal` using its default stream, and create the
    /// cuBLAS handle for that stream.
    ///
    /// # Errors
    ///
    /// Returns an error if the CUDA context for `ordinal` cannot be created
    /// or if creating the cuBLAS handle fails.
    pub fn new(ordinal: usize) -> GpuResult<Self> {
        let ctx = CudaContext::new(ordinal)?;
        let stream = ctx.default_stream();
        let blas = Arc::new(CudaBlas::new(Arc::clone(&stream))?);
        Ok(Self {
            ctx,
            stream,
            blas,
            ordinal,
        })
    }

    /// Create a `GpuDevice` with a non-blocking stream forked from the
    /// given device's default stream. The forked stream supports CUDA graph
    /// capture (which the legacy default stream does not).
    ///
    /// The new handle shares `parent`'s context but gets its own cuBLAS
    /// handle bound to the forked stream.
    ///
    /// # Errors
    ///
    /// Returns an error if forking the stream or creating the cuBLAS handle
    /// fails.
    pub fn fork_for_capture(parent: &GpuDevice) -> GpuResult<Self> {
        let stream = parent.stream.fork()?;
        let blas = Arc::new(CudaBlas::new(Arc::clone(&stream))?);
        Ok(Self {
            ctx: Arc::clone(&parent.ctx),
            stream,
            blas,
            ordinal: parent.ordinal,
        })
    }

    /// The CUDA context this device handle belongs to.
    #[inline]
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// The stream this handle was constructed with. For handles built by
    /// [`GpuDevice::new`] this is the context's default (legacy) stream.
    ///
    /// Prefer [`stream`](Self::stream), which respects the thread-local
    /// stream override set by [`StreamGuard`].
    #[inline]
    pub fn default_stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }

    /// The active stream for this device on the current thread.
    ///
    /// Returns the thread-local stream set by [`StreamGuard`] if one is
    /// active, otherwise falls back to the device's default stream. All
    /// kernel launches and memory operations should use this.
    #[inline]
    pub fn stream(&self) -> Arc<CudaStream> {
        crate::stream::current_stream_or_default(self)
    }

    /// The cached cuBLAS handle — reused for all matmul/bmm operations and
    /// shared by every clone of this `GpuDevice`.
    #[inline]
    pub fn blas(&self) -> &CudaBlas {
        // `&Arc<CudaBlas>` deref-coerces to `&CudaBlas`.
        &self.blas
    }

    /// The CUDA device ordinal this handle was opened with.
    #[inline]
    pub fn ordinal(&self) -> usize {
        self.ordinal
    }
}
109
#[cfg(feature = "cuda")]
impl std::fmt::Debug for GpuDevice {
    /// Renders only the device ordinal. The remaining fields (context,
    /// stream, cuBLAS handle) are elided, which `finish_non_exhaustive`
    /// signals with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("GpuDevice");
        builder.field("ordinal", &self.ordinal);
        builder.finish_non_exhaustive()
    }
}
118
119// ---------------------------------------------------------------------------
120// Stub when `cuda` feature is disabled
121// ---------------------------------------------------------------------------
122
123/// Stub `GpuDevice` when the `cuda` feature is not enabled.
124///
125/// Every method returns [`GpuError::NoCudaFeature`].
126#[cfg(not(feature = "cuda"))]
127#[derive(Clone, Debug)]
128pub struct GpuDevice {
129    ordinal: usize,
130}
131
132#[cfg(not(feature = "cuda"))]
133impl GpuDevice {
134    /// Always returns an error — compile with `features = ["cuda"]`.
135    pub fn new(ordinal: usize) -> GpuResult<Self> {
136        let _ = ordinal;
137        Err(GpuError::NoCudaFeature)
138    }
139
140    /// The device ordinal.
141    #[inline]
142    pub fn ordinal(&self) -> usize {
143        self.ordinal
144    }
145}