1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
//! CUDA device management.
//!
//! [`GpuDevice`] wraps a `cudarc::driver::CudaContext` and its default stream,
//! providing a safe, ergonomic entry point for all GPU operations.
#[cfg(feature = "cuda")]
use std::sync::Arc;
#[cfg(feature = "cuda")]
use cudarc::cublas::CudaBlas;
#[cfg(feature = "cuda")]
use cudarc::driver::{CudaContext, CudaStream};
#[cfg(not(feature = "cuda"))]
use crate::error::GpuError;
use crate::error::GpuResult;
/// Handle to a single CUDA GPU device.
///
/// Bundles a CUDA context, the stream the device was constructed with, and a
/// **cached cuBLAS handle**. The handle is built once at construction and
/// handed out by reference, so matmul/bmm ops reuse it instead of paying the
/// ~1.7ms `cuModuleLoadData` overhead that occurs when a new `CudaBlas` is
/// created per operation.
#[cfg(feature = "cuda")]
pub struct GpuDevice {
    /// CUDA context; reference-counted so forked and cloned devices share it.
    ctx: Arc<CudaContext>,
    /// Stream chosen at construction: the context's default stream, or a
    /// stream forked from a parent device.
    stream: Arc<CudaStream>,
    /// cuBLAS handle created on `stream` when the device was built.
    blas: CudaBlas,
    /// Device ordinal supplied to the constructor (or copied from the parent).
    ordinal: usize,
}
#[cfg(feature = "cuda")]
impl GpuDevice {
    /// Open the CUDA device with the given ordinal.
    ///
    /// Obtains a context through `CudaContext::new`, takes that context's
    /// default stream, and builds the cached `CudaBlas` handle on that
    /// stream so later matmul/bmm ops can reuse it through
    /// [`blas`](Self::blas) instead of paying the `cuModuleLoadData` cost
    /// per call.
    ///
    /// # Errors
    ///
    /// Propagates the failure if the context or the cuBLAS handle cannot be
    /// created.
    pub fn new(ordinal: usize) -> GpuResult<Self> {
        let context = CudaContext::new(ordinal)?;
        let default_stream = context.default_stream();
        // The handle takes its own reference to the stream; the original
        // reference is moved into the struct below.
        let handle = CudaBlas::new(Arc::clone(&default_stream))?;

        Ok(Self {
            ctx: context,
            stream: default_stream,
            blas: handle,
            ordinal,
        })
    }

    /// Build a sibling `GpuDevice` running on a non-blocking stream forked
    /// from `parent`'s stream.
    ///
    /// The new device shares `parent`'s context and ordinal, but owns a
    /// separate `CudaBlas` handle bound to the forked stream. The fork
    /// exists to allow CUDA graph capture, which the legacy default stream
    /// does not support.
    ///
    /// # Errors
    ///
    /// Propagates the failure if forking the stream or creating the cuBLAS
    /// handle fails.
    pub fn fork_for_capture(parent: &GpuDevice) -> GpuResult<Self> {
        let forked = parent.stream.fork()?;
        let handle = CudaBlas::new(Arc::clone(&forked))?;

        Ok(Self {
            ctx: Arc::clone(&parent.ctx),
            stream: forked,
            blas: handle,
            ordinal: parent.ordinal,
        })
    }

    /// The shared `CudaContext` underlying this device.
    ///
    /// Needed by `cudarc::driver::CudaModule` loaders and other low-level
    /// APIs that take a context handle rather than a stream.
    #[inline]
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// The stream this device was constructed with.
    ///
    /// For a device built by [`new`](Self::new) this is the context's
    /// default (legacy) stream; for one built by
    /// [`fork_for_capture`](Self::fork_for_capture) it is the forked stream.
    ///
    /// Prefer [`stream`](Self::stream), which respects the thread-local
    /// stream override set by [`crate::stream::StreamGuard`].
    #[inline]
    pub fn default_stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }

    /// The active stream for this device on the current thread.
    ///
    /// Delegates to `crate::stream::current_stream_or_default`: the
    /// thread-local stream installed by [`crate::stream::StreamGuard`] when
    /// one is active, otherwise this device's own stream. All kernel
    /// launches and memory operations should go through this.
    #[inline]
    pub fn stream(&self) -> Arc<CudaStream> {
        crate::stream::current_stream_or_default(self)
    }

    /// The cached cuBLAS handle, reused for all matmul/bmm operations.
    #[inline]
    pub fn blas(&self) -> &CudaBlas {
        &self.blas
    }

    /// The 0-based ordinal this device was opened with.
    #[inline]
    pub fn ordinal(&self) -> usize {
        self.ordinal
    }
}
#[cfg(feature = "cuda")]
impl Clone for GpuDevice {
    /// Duplicate the device handle.
    ///
    /// The context and stream are shared with the original through their
    /// `Arc`s; the cuBLAS handle is not shared — a new `CudaBlas` is created
    /// on the same stream for the clone.
    ///
    /// # Panics
    ///
    /// Panics if creating the new `CudaBlas` handle fails, because `Clone`
    /// has no way to report an error.
    fn clone(&self) -> Self {
        let Self {
            ctx,
            stream,
            ordinal,
            ..
        } = self;

        // Create the handle first so a failure panics before anything else
        // is duplicated.
        let handle = CudaBlas::new(Arc::clone(stream))
            .expect("CudaBlas::new failed in GpuDevice::clone");

        Self {
            ctx: Arc::clone(ctx),
            stream: Arc::clone(stream),
            blas: handle,
            ordinal: *ordinal,
        }
    }
}
#[cfg(feature = "cuda")]
impl std::fmt::Debug for GpuDevice {
    /// Format as `GpuDevice { ordinal: N, .. }`.
    ///
    /// Only the ordinal is printed; the context, stream, and cuBLAS handle
    /// are left out, and the output is marked non-exhaustive to show that.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("GpuDevice");
        builder.field("ordinal", &self.ordinal);
        builder.finish_non_exhaustive()
    }
}
// ---------------------------------------------------------------------------
// Stub when `cuda` feature is disabled
// ---------------------------------------------------------------------------
/// Placeholder `GpuDevice` used when the `cuda` feature is not enabled.
///
/// Its constructor always fails with [`GpuError::NoCudaFeature`], so code
/// outside this module cannot obtain an instance; the type exists so that
/// callers compile the same way with or without the feature.
#[cfg(not(feature = "cuda"))]
#[derive(Debug, Clone)]
pub struct GpuDevice {
    /// Device ordinal, kept so the accessor has the same shape as the
    /// CUDA-enabled type.
    ordinal: usize,
}
#[cfg(not(feature = "cuda"))]
impl GpuDevice {
    /// Always fails — rebuild with `features = ["cuda"]` to get a real device.
    ///
    /// # Errors
    ///
    /// Unconditionally returns [`GpuError::NoCudaFeature`]; `ordinal` is
    /// accepted only to match the CUDA-enabled signature.
    pub fn new(ordinal: usize) -> GpuResult<Self> {
        // Explicitly discard the argument so it does not trigger an
        // unused-variable warning.
        let _ = ordinal;
        Err(GpuError::NoCudaFeature)
    }

    /// The device ordinal stored in this handle.
    #[inline]
    pub fn ordinal(&self) -> usize {
        self.ordinal
    }
}