use std::sync::{Arc, OnceLock};
use cudarc::driver::{CudaContext, CudaStream, DeviceRepr, ValidAsZeroBits};
use crate::buffer::GpuBuffer;
use crate::error::Result;
/// Process-wide latch: initialised the first time the debug-build note is
/// printed, so the note is emitted at most once per process.
static DEBUG_WARNED: OnceLock<()> = OnceLock::new();
/// Text written to stderr when a device is created from a debug build.
/// It is deliberately framed as a performance note (not a correctness one) and
/// names the environment variable that silences it.
const DEBUG_WARNING_MESSAGE: &str = "[kaio] Note: debug build — GPU kernel performance is ~10-20x slower than --release. Use `cargo run --release` / `cargo test --release` for representative performance numbers. Correctness is unaffected. Set KAIO_SUPPRESS_DEBUG_WARNING=1 to silence.";
/// Decides whether the debug-build performance note should be printed.
///
/// Returns `true` only when the crate was compiled with `debug_assertions`
/// enabled and `KAIO_SUPPRESS_DEBUG_WARNING` is absent from the environment.
/// Any value counts as "set" and suppresses the note, including one that is
/// not valid Unicode.
fn should_emit_debug_warning() -> bool {
    // `var_os` is a pure presence check. `var` would report a set-but-non-Unicode
    // value as an error, which is indistinguishable from "unset" via `is_err()`.
    cfg!(debug_assertions) && std::env::var_os("KAIO_SUPPRESS_DEBUG_WARNING").is_none()
}
/// Prints the debug-build performance note to stderr, at most once per process.
///
/// Does nothing when `should_emit_debug_warning` says the note is not wanted.
fn maybe_warn_debug_build() {
    if !should_emit_debug_warning() {
        return;
    }
    // The latch runs its initialiser only on the first call, which is what
    // limits the note to a single emission.
    DEBUG_WARNED.get_or_init(|| eprintln!("{DEBUG_WARNING_MESSAGE}"));
}
/// Handle to one CUDA device: its context together with the stream that this
/// type uses for allocations and host-to-device copies.
pub struct KaioDevice {
    /// Shared CUDA context created for the requested device ordinal.
    ctx: Arc<CudaContext>,
    /// Stream obtained from the context's `default_stream()` at construction.
    stream: Arc<CudaStream>,
}
/// Debug output shows only the device ordinal; the context and stream handles
/// themselves are not printed.
impl std::fmt::Debug for KaioDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ordinal = self.ctx.ordinal();
        let mut repr = f.debug_struct("KaioDevice");
        repr.field("ordinal", &ordinal);
        repr.finish()
    }
}
impl KaioDevice {
pub fn new(ordinal: usize) -> Result<Self> {
maybe_warn_debug_build();
let ctx = CudaContext::new(ordinal)?;
let stream = ctx.default_stream();
Ok(Self { ctx, stream })
}
pub fn info(&self) -> Result<DeviceInfo> {
DeviceInfo::from_context(&self.ctx)
}
pub fn alloc_from<T: DeviceRepr>(&self, data: &[T]) -> Result<GpuBuffer<T>> {
let slice = self.stream.clone_htod(data)?;
Ok(GpuBuffer::from_raw(slice))
}
pub fn alloc_zeros<T: DeviceRepr + ValidAsZeroBits>(&self, len: usize) -> Result<GpuBuffer<T>> {
let slice = self.stream.alloc_zeros::<T>(len)?;
Ok(GpuBuffer::from_raw(slice))
}
pub fn stream(&self) -> &Arc<CudaStream> {
&self.stream
}
#[deprecated(
since = "0.2.1",
note = "use load_module(&PtxModule) — runs PtxModule::validate() for readable SM-mismatch errors"
)]
pub fn load_ptx(&self, ptx_text: &str) -> Result<crate::module::KaioModule> {
let ptx = cudarc::nvrtc::Ptx::from_src(ptx_text);
let module = self.ctx.load_module(ptx)?;
Ok(crate::module::KaioModule::from_raw(module))
}
pub fn load_module(
&self,
module: &kaio_core::ir::PtxModule,
) -> Result<crate::module::KaioModule> {
use kaio_core::emit::{Emit, PtxWriter};
module.validate()?;
let mut w = PtxWriter::new();
module
.emit(&mut w)
.map_err(|e| crate::error::KaioError::PtxLoad(format!("emit failed: {e}")))?;
let ptx_text = w.finish();
#[allow(deprecated)]
self.load_ptx(&ptx_text)
}
}
/// Snapshot of device properties gathered by `DeviceInfo::from_context`.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Device name as returned by the driver's `get_name` query.
    pub name: String,
    /// Compute capability as `(major, minor)`, read from the driver's
    /// compute-capability major/minor device attributes.
    pub compute_capability: (u32, u32),
    /// Total device memory as returned by the driver's `total_mem` query
    /// (presumably in bytes — confirm against the cudarc documentation).
    pub total_memory: usize,
}
impl DeviceInfo {
    /// Queries the driver for the properties of the device behind `ctx`.
    ///
    /// The queries run in this order: device handle for the context's ordinal,
    /// device name, total memory, then the compute-capability major and minor
    /// attributes. The first query that fails is returned as the error.
    fn from_context(ctx: &Arc<CudaContext>) -> Result<Self> {
        use cudarc::driver::result::device;
        use cudarc::driver::sys::CUdevice_attribute as Attr;

        // NOTE(review): the `as i32` cast assumes the ordinal of an existing
        // context always fits in an i32 — confirm against the driver API.
        let dev = device::get(ctx.ordinal() as i32)?;
        let name = device::get_name(dev)?;

        // SAFETY(review): `dev` was just returned by `device::get`; presumably a
        // valid device handle is all these driver queries require — confirm
        // against the safety notes in cudarc's `result::device` module.
        let total_memory = unsafe { device::total_mem(dev)? };
        let read_attribute = |attribute| unsafe { device::get_attribute(dev, attribute) };

        let major = read_attribute(Attr::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)?;
        let minor = read_attribute(Attr::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)?;
        let compute_capability = (major as u32, minor as u32);

        Ok(Self {
            name,
            compute_capability,
            total_memory,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::OnceLock;

    /// Name of the opt-out environment variable exercised by the tests below.
    const SUPPRESS_VAR: &str = "KAIO_SUPPRESS_DEBUG_WARNING";

    /// Device shared by the GPU-backed tests so it is created only once.
    static DEVICE: OnceLock<KaioDevice> = OnceLock::new();

    /// Returns the shared test device, creating device 0 on first use.
    fn device() -> &'static KaioDevice {
        DEVICE.get_or_init(|| KaioDevice::new(0).expect("GPU required for tests"))
    }

    /// Puts `KAIO_SUPPRESS_DEBUG_WARNING` back to its captured state on drop,
    /// so a failing assertion cannot leak a modified environment into the rest
    /// of the test process.
    struct EnvRestore {
        saved: Option<OsString>,
    }

    impl EnvRestore {
        /// Captures the current value. `var_os` is used so that a non-Unicode
        /// value is preserved as well.
        fn capture() -> Self {
            Self {
                saved: std::env::var_os(SUPPRESS_VAR),
            }
        }
    }

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            // SAFETY(review): mutating the environment is only sound while no
            // foreign code reads it concurrently. The tests that are not
            // `#[ignore]`d touch it through `std::env` only; run the GPU tests
            // single-threaded if they are enabled together with this one.
            unsafe {
                match self.saved.take() {
                    Some(value) => std::env::set_var(SUPPRESS_VAR, value),
                    None => std::env::remove_var(SUPPRESS_VAR),
                }
            }
        }
    }

    #[test]
    fn debug_warning_message_is_performance_framed_not_correctness_framed() {
        let msg = DEBUG_WARNING_MESSAGE;
        let lowered = msg.to_lowercase();
        assert!(
            msg.contains("performance"),
            "debug warning must mention performance: {msg}"
        );
        assert!(
            msg.contains("Correctness is unaffected") || msg.contains("correctness is unaffected"),
            "debug warning must explicitly state correctness is unaffected: {msg}"
        );
        assert!(
            !lowered.contains("not meaningful") && !lowered.contains("invalid"),
            "debug warning must NOT imply results are invalid/not meaningful: {msg}"
        );
        assert!(
            msg.contains("KAIO_SUPPRESS_DEBUG_WARNING"),
            "debug warning must document the opt-out env var: {msg}"
        );
    }

    #[test]
    fn debug_warning_opt_out_env_var_suppresses() {
        // Restores the original value even if an assertion below panics.
        let _restore = EnvRestore::capture();

        // SAFETY(review): same reasoning as in `EnvRestore::drop`.
        unsafe {
            std::env::set_var(SUPPRESS_VAR, "1");
        }
        assert!(
            !should_emit_debug_warning(),
            "KAIO_SUPPRESS_DEBUG_WARNING=1 must suppress the warning"
        );

        // SAFETY(review): same reasoning as in `EnvRestore::drop`.
        unsafe {
            std::env::remove_var(SUPPRESS_VAR);
        }
        assert_eq!(should_emit_debug_warning(), cfg!(debug_assertions));
    }

    #[test]
    #[ignore = "requires a CUDA-capable GPU"]
    fn device_creation() {
        let dev = KaioDevice::new(0);
        assert!(dev.is_ok(), "KaioDevice::new(0) failed: {dev:?}");
    }

    #[test]
    #[ignore = "requires a CUDA-capable GPU"]
    fn device_info_name() {
        let info = device().info().expect("info() failed");
        assert!(!info.name.is_empty(), "device name should not be empty");
        eprintln!("GPU name: {}", info.name);
    }

    #[test]
    #[ignore = "requires a CUDA-capable GPU"]
    fn device_info_compute_capability() {
        let info = device().info().expect("info() failed");
        let (major, minor) = info.compute_capability;
        assert!(
            major >= 7,
            "expected SM 7.0+ GPU, got SM {}.{}",
            major,
            minor,
        );
        eprintln!("GPU compute capability: SM {}.{}", major, minor);
    }

    #[test]
    #[ignore = "requires a CUDA-capable GPU"]
    fn buffer_roundtrip_f32() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let buf = device().alloc_from(&data).expect("alloc_from failed");
        let result = buf.to_host(device()).expect("to_host failed");
        assert_eq!(result, data, "roundtrip data mismatch");
    }

    #[test]
    #[ignore = "requires a CUDA-capable GPU"]
    fn buffer_alloc_zeros() {
        let buf = device()
            .alloc_zeros::<f32>(100)
            .expect("alloc_zeros failed");
        let result = buf.to_host(device()).expect("to_host failed");
        assert_eq!(result, vec![0.0f32; 100]);
    }

    #[test]
    #[ignore = "requires a CUDA-capable GPU"]
    fn buffer_len() {
        let buf = device()
            .alloc_from(&[1.0f32, 2.0, 3.0])
            .expect("alloc_from failed");
        assert_eq!(buf.len(), 3);
        assert!(!buf.is_empty());
    }

    #[test]
    #[ignore = "requires a CUDA-capable GPU"]
    fn invalid_device_ordinal() {
        let result = KaioDevice::new(999);
        assert!(result.is_err(), "expected error for ordinal 999");
    }
}