// kaio_runtime/device.rs
1//! CUDA device management.
2
3use std::sync::{Arc, OnceLock};
4
5use cudarc::driver::{CudaContext, CudaStream, DeviceRepr, ValidAsZeroBits};
6
7use crate::buffer::GpuBuffer;
8use crate::error::Result;
9
/// Process-wide latch for the debug-build performance note.
///
/// Sprint 7.0.5 A2: emit a one-time stderr note on first `KaioDevice::new`
/// when the binary is built in debug mode. Prevents the common "benchmarked
/// in debug, bounced" adoption failure where new users run a showcase example
/// with `cargo run` (defaulting to debug) and conclude KAIO is slow. The note
/// is performance-framed only — debug-mode does not affect correctness, and a
/// `cargo test`-in-debug user checking kernel output should not see their
/// correctness results cast into doubt.
///
/// `OnceLock<()>` is used purely as a set-once flag: `get_or_init` in
/// [`maybe_warn_debug_build`] guarantees the print closure runs at most
/// once per process, even under concurrent `KaioDevice::new` calls.
static DEBUG_WARNED: OnceLock<()> = OnceLock::new();

/// Performance-framed debug-mode note body. `const` so tests can assert on
/// its content without re-typing the message.
///
/// NOTE: the tests in this file assert on specific substrings of this
/// message ("performance", "Correctness is unaffected",
/// "KAIO_SUPPRESS_DEBUG_WARNING") — keep them in sync when editing.
const DEBUG_WARNING_MESSAGE: &str = "[kaio] Note: debug build — GPU kernel performance is ~10-20x slower than --release. Use `cargo run --release` / `cargo test --release` for representative performance numbers. Correctness is unaffected. Set KAIO_SUPPRESS_DEBUG_WARNING=1 to silence.";
24
/// Pure decision function: should the debug-mode note fire in this
/// process?
///
/// Kept separate from [`maybe_warn_debug_build`] so the env-var logic can
/// be unit-tested without the process-wide `OnceLock` latch getting in the
/// way. Returns `true` only in debug builds where the opt-out variable
/// `KAIO_SUPPRESS_DEBUG_WARNING` is not set (any value suppresses).
fn should_emit_debug_warning() -> bool {
    // Release builds: cfg!(debug_assertions) folds to false at compile time.
    if !cfg!(debug_assertions) {
        return false;
    }
    std::env::var("KAIO_SUPPRESS_DEBUG_WARNING").is_err()
}
31
32/// Emit the debug-mode performance note to stderr once per process, if
33/// [`should_emit_debug_warning`] returns true.
34///
35/// Called from [`KaioDevice::new`] — every KAIO program hits this path on
36/// first launch, so the note surfaces exactly when a user would first
37/// benefit from knowing. In release builds, `cfg!(debug_assertions)` folds
38/// to `false` and the whole body compiles out.
39fn maybe_warn_debug_build() {
40 if should_emit_debug_warning() {
41 DEBUG_WARNED.get_or_init(|| {
42 eprintln!("{DEBUG_WARNING_MESSAGE}");
43 });
44 }
45}
46
/// A KAIO GPU device — wraps a CUDA context and its default stream.
///
/// Created via [`KaioDevice::new`] with a device ordinal (0 for the first GPU).
/// All allocation and transfer operations go through the default stream.
///
/// # Example
///
/// ```ignore
/// let device = KaioDevice::new(0)?;
/// let buf = device.alloc_from(&[1.0f32, 2.0, 3.0])?;
/// let host = buf.to_host(&device)?;
/// ```
pub struct KaioDevice {
    // Shared handle to the CUDA driver context for the selected GPU;
    // module loading (`load_ptx`/`load_module`) and info queries go here.
    ctx: Arc<CudaContext>,
    // The context's default stream, captured at construction; all
    // alloc/transfer operations in this type are issued on it.
    stream: Arc<CudaStream>,
}
63
64impl std::fmt::Debug for KaioDevice {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("KaioDevice")
67 .field("ordinal", &self.ctx.ordinal())
68 .finish()
69 }
70}
71
72impl KaioDevice {
73 /// Create a new device targeting the GPU at the given ordinal.
74 ///
75 /// Ordinal 0 is the first GPU. Returns an error if no GPU exists at
76 /// that ordinal or if the CUDA driver fails to initialize.
77 pub fn new(ordinal: usize) -> Result<Self> {
78 maybe_warn_debug_build();
79 let ctx = CudaContext::new(ordinal)?;
80 let stream = ctx.default_stream();
81 Ok(Self { ctx, stream })
82 }
83
84 /// Query basic information about this device.
85 pub fn info(&self) -> Result<DeviceInfo> {
86 DeviceInfo::from_context(&self.ctx)
87 }
88
89 /// CUDA device ordinal (0-indexed) this device wraps.
90 ///
91 /// Used by bridge crates (e.g. `kaio-candle`) to cross-check that a
92 /// host-framework device and a `KaioDevice` refer to the same GPU.
93 pub fn ordinal(&self) -> usize {
94 self.ctx.ordinal()
95 }
96
97 /// Allocate device memory and copy data from a host slice.
98 pub fn alloc_from<T: DeviceRepr>(&self, data: &[T]) -> Result<GpuBuffer<T>> {
99 let slice = self.stream.clone_htod(data)?;
100 Ok(GpuBuffer::from_cuda_slice(slice))
101 }
102
103 /// Allocate zero-initialized device memory.
104 pub fn alloc_zeros<T: DeviceRepr + ValidAsZeroBits>(&self, len: usize) -> Result<GpuBuffer<T>> {
105 let slice = self.stream.alloc_zeros::<T>(len)?;
106 Ok(GpuBuffer::from_cuda_slice(slice))
107 }
108
109 /// Access the underlying CUDA stream for kernel launch operations.
110 ///
111 /// Used with cudarc's `launch_builder` to launch kernels. In Phase 2,
112 /// the proc macro will generate typed wrappers that hide this.
113 pub fn stream(&self) -> &Arc<CudaStream> {
114 &self.stream
115 }
116
117 /// Load a PTX module from source text and return a [`crate::module::KaioModule`].
118 ///
119 /// The PTX text is passed to the CUDA driver's `cuModuleLoadData` —
120 /// no NVRTC compilation occurs. The driver JIT-compiles the PTX for
121 /// the current GPU.
122 ///
123 /// # Deprecated — prefer [`load_module`](Self::load_module)
124 ///
125 /// The module path runs
126 /// [`PtxModule::validate`](kaio_core::ir::PtxModule::validate)
127 /// before the driver sees the PTX, catching SM mismatches (e.g.
128 /// `mma.sync` on sub-Ampere targets) with readable
129 /// [`KaioError::Validation`](crate::error::KaioError::Validation)
130 /// errors instead of cryptic `ptxas` failures deep in the driver.
131 ///
132 /// This function remains public for raw-PTX use cases (external PTX
133 /// files, hand-written PTX for research, bypassing validation
134 /// intentionally). It is not scheduled for removal in the 0.2.x line.
135 ///
136 /// # Migration
137 ///
138 /// Before:
139 /// ```ignore
140 /// let ptx_text: String = build_my_ptx();
141 /// let module = device.load_ptx(&ptx_text)?;
142 /// ```
143 ///
144 /// After:
145 /// ```ignore
146 /// use kaio_core::ir::PtxModule;
147 /// let ptx_module: PtxModule = build_my_module("sm_80");
148 /// let module = device.load_module(&ptx_module)?;
149 /// ```
150 #[deprecated(
151 since = "0.2.1",
152 note = "use load_module(&PtxModule) — runs PtxModule::validate() for readable SM-mismatch errors"
153 )]
154 pub fn load_ptx(&self, ptx_text: &str) -> Result<crate::module::KaioModule> {
155 let ptx = cudarc::nvrtc::Ptx::from_src(ptx_text);
156 let module = self.ctx.load_module(ptx)?;
157 Ok(crate::module::KaioModule::from_raw(module))
158 }
159
160 /// Validate, emit, and load a [`kaio_core::ir::PtxModule`] on the device.
161 ///
162 /// This is the preferred entrypoint when the caller has an in-memory
163 /// `PtxModule` (as opposed to raw PTX text). Before the PTX text is
164 /// handed to the driver, [`kaio_core::ir::PtxModule::validate`]
165 /// checks that the module's target SM supports every feature used by
166 /// its kernels — raising
167 /// [`KaioError::Validation`](crate::error::KaioError::Validation) if
168 /// e.g. a `mma.sync` op is present but the target is `sm_70`.
169 ///
170 /// Surfacing the error at this layer gives the user a readable
171 /// message ("`mma.sync.m16n8k16 requires sm_80+, target is sm_70`")
172 /// instead of a cryptic `ptxas` error from deep in the driver.
173 pub fn load_module(
174 &self,
175 module: &kaio_core::ir::PtxModule,
176 ) -> Result<crate::module::KaioModule> {
177 use kaio_core::emit::{Emit, PtxWriter};
178
179 module.validate()?;
180
181 let mut w = PtxWriter::new();
182 module
183 .emit(&mut w)
184 .map_err(|e| crate::error::KaioError::PtxLoad(format!("emit failed: {e}")))?;
185 let ptx_text = w.finish();
186
187 // `load_ptx` is #[deprecated] as a public API to steer users to the
188 // validated module path, but it's still the correct internal
189 // implementation detail after we've emitted the PTX text here.
190 #[allow(deprecated)]
191 self.load_ptx(&ptx_text)
192 }
193}
194
/// Basic information about a CUDA device.
///
/// Phase 1 includes name, compute capability, and total memory.
/// Additional fields (SM count, max threads per block, max shared memory,
/// warp size) are planned for Phase 3/4 when shared memory and occupancy
/// calculations matter.
///
/// Populated by [`KaioDevice::info`] via driver queries (see
/// `DeviceInfo::from_context`).
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// GPU device name (e.g. "NVIDIA GeForce RTX 4090").
    pub name: String,
    /// Compute capability as (major, minor) — e.g. (8, 9) for SM 8.9.
    pub compute_capability: (u32, u32),
    /// Total device memory in bytes, as reported by the driver's
    /// total-memory query.
    pub total_memory: usize,
}
210
211impl DeviceInfo {
212 /// Query device info from a CUDA context.
213 fn from_context(ctx: &Arc<CudaContext>) -> Result<Self> {
214 use cudarc::driver::result::device;
215
216 let ordinal = ctx.ordinal();
217 let dev = device::get(ordinal as i32)?;
218 let name = device::get_name(dev)?;
219 let total_memory = unsafe { device::total_mem(dev)? };
220
221 // SAFETY: dev is a valid device handle obtained from device::get().
222 // get_attribute reads a device property — no mutation, no aliasing.
223 let major = unsafe {
224 device::get_attribute(
225 dev,
226 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
227 )?
228 };
229 let minor = unsafe {
230 device::get_attribute(
231 dev,
232 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
233 )?
234 };
235
236 Ok(Self {
237 name,
238 compute_capability: (major as u32, minor as u32),
239 total_memory,
240 })
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::OnceLock;

    // Lazily-constructed shared device for the #[ignore]d GPU tests, so
    // the CUDA context is created once per test binary rather than per test.
    static DEVICE: OnceLock<KaioDevice> = OnceLock::new();
    fn device() -> &'static KaioDevice {
        DEVICE.get_or_init(|| KaioDevice::new(0).expect("GPU required for tests"))
    }

    // Sprint 7.0.5 A2: debug-mode performance note tests.
    //
    // These verify the pure-function half of the warning logic. The
    // once-per-process behavior mediated by the static `DEBUG_WARNED`
    // OnceLock is not testable in-process without restructuring (the
    // latch is set for the lifetime of the test binary); manual/subprocess
    // verification is in sprint_7_0_5.md.

    #[test]
    fn debug_warning_message_is_performance_framed_not_correctness_framed() {
        // Regression canary (Sprint 7.0.5 A2 message framing): if the
        // wording ever drifts to imply correctness is affected ("results
        // are not meaningful," "output is invalid," etc.) this test
        // fails. The whole point of the message is to prevent perf
        // misunderstandings WITHOUT scaring off correctness testing.
        let msg = DEBUG_WARNING_MESSAGE;
        assert!(
            msg.contains("performance"),
            "debug warning must mention performance: {msg}"
        );
        assert!(
            msg.contains("Correctness is unaffected") || msg.contains("correctness is unaffected"),
            "debug warning must explicitly state correctness is unaffected: {msg}"
        );
        assert!(
            !msg.to_lowercase().contains("not meaningful")
                && !msg.to_lowercase().contains("invalid"),
            "debug warning must NOT imply results are invalid/not meaningful: {msg}"
        );
        assert!(
            msg.contains("KAIO_SUPPRESS_DEBUG_WARNING"),
            "debug warning must document the opt-out env var: {msg}"
        );
    }

    #[test]
    fn debug_warning_opt_out_env_var_suppresses() {
        // SAFETY: `set_var`/`remove_var` are unsafe because the libtest
        // harness runs tests on multiple threads and env mutation can race
        // with concurrent env reads. No other non-#[ignore] test in this
        // binary reads KAIO_SUPPRESS_DEBUG_WARNING, so this is sound in
        // practice — revisit if another env-reading test is added. The
        // prior value (if any) is restored before returning so other tests
        // in the same binary don't observe stale state.
        // NOTE(review): a panic between set_var and the restore below would
        // leak the override for the rest of the binary; acceptable today
        // since no sibling test reads this variable.
        let prev = std::env::var("KAIO_SUPPRESS_DEBUG_WARNING").ok();
        unsafe {
            std::env::set_var("KAIO_SUPPRESS_DEBUG_WARNING", "1");
        }
        assert!(
            !should_emit_debug_warning(),
            "KAIO_SUPPRESS_DEBUG_WARNING=1 must suppress the warning"
        );
        unsafe {
            std::env::remove_var("KAIO_SUPPRESS_DEBUG_WARNING");
        }
        // In debug builds the warning should now be allowed; in release
        // builds cfg!(debug_assertions) is false so it's suppressed either
        // way. Assert the cfg-consistent expectation.
        assert_eq!(should_emit_debug_warning(), cfg!(debug_assertions));
        // Restore the pre-test value so later tests see the original env.
        if let Some(v) = prev {
            unsafe {
                std::env::set_var("KAIO_SUPPRESS_DEBUG_WARNING", v);
            }
        }
    }

    #[test]
    #[ignore] // requires NVIDIA GPU
    fn device_creation() {
        let dev = KaioDevice::new(0);
        assert!(dev.is_ok(), "KaioDevice::new(0) failed: {dev:?}");
    }

    #[test]
    #[ignore] // requires NVIDIA GPU
    fn device_info_name() {
        let info = device().info().expect("info() failed");
        assert!(!info.name.is_empty(), "device name should not be empty");
        // RTX 4090 should contain "4090" somewhere in the name
        eprintln!("GPU name: {}", info.name);
    }

    #[test]
    #[ignore] // requires NVIDIA GPU
    fn device_info_compute_capability() {
        let info = device().info().expect("info() failed");
        // Any SM 7.0+ GPU should work (Volta and newer)
        let (major, _minor) = info.compute_capability;
        assert!(
            major >= 7,
            "expected SM 7.0+ GPU, got SM {}.{}",
            info.compute_capability.0,
            info.compute_capability.1,
        );
        eprintln!(
            "GPU compute capability: SM {}.{}",
            info.compute_capability.0, info.compute_capability.1
        );
    }

    #[test]
    #[ignore] // requires NVIDIA GPU
    fn buffer_roundtrip_f32() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let buf = device().alloc_from(&data).expect("alloc_from failed");
        let result = buf.to_host(device()).expect("to_host failed");
        assert_eq!(result, data, "roundtrip data mismatch");
    }

    #[test]
    #[ignore] // requires NVIDIA GPU
    fn buffer_alloc_zeros() {
        let buf = device()
            .alloc_zeros::<f32>(100)
            .expect("alloc_zeros failed");
        let result = buf.to_host(device()).expect("to_host failed");
        assert_eq!(result, vec![0.0f32; 100]);
    }

    #[test]
    #[ignore] // requires NVIDIA GPU
    fn buffer_len() {
        let buf = device()
            .alloc_from(&[1.0f32, 2.0, 3.0])
            .expect("alloc_from failed");
        assert_eq!(buf.len(), 3);
        assert!(!buf.is_empty());
    }

    #[test]
    #[ignore] // requires NVIDIA GPU
    fn invalid_device_ordinal() {
        let result = KaioDevice::new(999);
        assert!(result.is_err(), "expected error for ordinal 999");
    }
}
386}