use std::sync::Arc;
use cudarc::driver::{CudaContext, CudaStream, DeviceRepr, ValidAsZeroBits};
use crate::buffer::GpuBuffer;
use crate::error::Result;
/// Handle to one CUDA device: the driver context for a device ordinal plus the
/// stream that this handle's allocation/copy methods issue work on.
pub struct KaioDevice {
    /// Driver context created for the selected device ordinal.
    ctx: Arc<CudaContext>,
    /// Stream used by `alloc_from` / `alloc_zeros`; `new` sets it to the
    /// context's default stream.
    stream: Arc<CudaStream>,
}
/// Debug output shows only the device ordinal; the context and stream handles
/// themselves are not printed.
impl std::fmt::Debug for KaioDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ordinal = self.ctx.ordinal();
        let mut builder = f.debug_struct("KaioDevice");
        builder.field("ordinal", &ordinal);
        builder.finish()
    }
}
impl KaioDevice {
pub fn new(ordinal: usize) -> Result<Self> {
let ctx = CudaContext::new(ordinal)?;
let stream = ctx.default_stream();
Ok(Self { ctx, stream })
}
pub fn info(&self) -> Result<DeviceInfo> {
DeviceInfo::from_context(&self.ctx)
}
pub fn alloc_from<T: DeviceRepr>(&self, data: &[T]) -> Result<GpuBuffer<T>> {
let slice = self.stream.clone_htod(data)?;
Ok(GpuBuffer::from_raw(slice))
}
pub fn alloc_zeros<T: DeviceRepr + ValidAsZeroBits>(&self, len: usize) -> Result<GpuBuffer<T>> {
let slice = self.stream.alloc_zeros::<T>(len)?;
Ok(GpuBuffer::from_raw(slice))
}
pub fn stream(&self) -> &Arc<CudaStream> {
&self.stream
}
pub fn load_ptx(&self, ptx_text: &str) -> Result<crate::module::KaioModule> {
let ptx = cudarc::nvrtc::Ptx::from_src(ptx_text);
let module = self.ctx.load_module(ptx)?;
Ok(crate::module::KaioModule::from_raw(module))
}
}
/// Static properties of a CUDA device, as returned by `KaioDevice::info`.
///
/// Derives `PartialEq`/`Eq` so callers can compare the devices they have
/// queried (e.g. to check that two handles refer to identical hardware).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceInfo {
    /// Device name as reported by the driver.
    pub name: String,
    /// Compute capability as `(major, minor)`, e.g. `(8, 9)` for SM 8.9.
    pub compute_capability: (u32, u32),
    /// Total device memory reported by the driver (bytes, per `cuDeviceTotalMem`).
    pub total_memory: usize,
}
impl DeviceInfo {
    /// Queries the name, compute capability and total memory of the device
    /// that `ctx` was created on.
    ///
    /// Takes `&CudaContext` rather than `&Arc<CudaContext>`: no ownership is
    /// needed, and callers holding an `Arc` pass it via deref coercion.
    ///
    /// # Errors
    /// Returns an error if any of the underlying driver queries fails.
    fn from_context(ctx: &CudaContext) -> Result<Self> {
        use cudarc::driver::result::device;
        use cudarc::driver::sys::CUdevice_attribute;

        // The ordinal belongs to a context that was already created
        // successfully, so it is a valid device index and fits in `i32`.
        let dev = device::get(ctx.ordinal() as i32)?;
        let name = device::get_name(dev)?;

        // SAFETY: `dev` was just returned by the driver for the ordinal of a
        // live `CudaContext`, so the driver is initialised and the handle is valid.
        let total_memory = unsafe { device::total_mem(dev)? };

        let attribute = |attrib: CUdevice_attribute| -> Result<i32> {
            // SAFETY: same invariant as for `total_mem` above — `dev` is a valid
            // handle obtained from an initialised driver.
            Ok(unsafe { device::get_attribute(dev, attrib)? })
        };
        let major = attribute(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)?;
        let minor = attribute(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)?;

        Ok(Self {
            name,
            // Compute-capability attributes are non-negative version numbers.
            compute_capability: (major as u32, minor as u32),
            total_memory,
        })
    }
}
#[cfg(test)]
mod tests {
    //! GPU integration tests. Every test is `#[ignore]`d because it needs a
    //! CUDA device; run them with `cargo test -- --ignored`.

    use super::*;
    use std::sync::OnceLock;

    /// Lazily created device shared by the tests that call `device()`.
    static DEVICE: OnceLock<KaioDevice> = OnceLock::new();

    fn device() -> &'static KaioDevice {
        DEVICE.get_or_init(|| KaioDevice::new(0).expect("GPU required for tests"))
    }

    #[test]
    #[ignore]
    fn device_creation() {
        let dev = KaioDevice::new(0);
        assert!(dev.is_ok(), "KaioDevice::new(0) failed: {dev:?}");
    }

    #[test]
    #[ignore]
    fn device_info_name() {
        let info = device().info().expect("info() failed");
        assert!(!info.name.is_empty(), "device name should not be empty");
        eprintln!("GPU name: {}", info.name);
    }

    /// Checks the reported compute capability.
    ///
    /// Set `KAIO_EXPECTED_SM` (e.g. `8.9`) to pin the exact value for the GPU
    /// under test. Without it, only a sanity check runs, so the test passes on
    /// any CUDA-capable device instead of only on one specific card.
    #[test]
    #[ignore]
    fn device_info_compute_capability() {
        let info = device().info().expect("info() failed");
        let (major, minor) = info.compute_capability;
        eprintln!("GPU compute capability: {major}.{minor}");
        match std::env::var("KAIO_EXPECTED_SM") {
            Ok(expected) => {
                let (exp_major, exp_minor) = expected
                    .trim()
                    .split_once('.')
                    .expect("KAIO_EXPECTED_SM must be formatted as MAJOR.MINOR, e.g. 8.9");
                let expected: (u32, u32) = (
                    exp_major
                        .parse()
                        .expect("KAIO_EXPECTED_SM: invalid major version"),
                    exp_minor
                        .parse()
                        .expect("KAIO_EXPECTED_SM: invalid minor version"),
                );
                assert_eq!(
                    info.compute_capability, expected,
                    "expected SM {expected:?}, got {:?}",
                    info.compute_capability
                );
            }
            // Every CUDA device reports a compute capability of at least 1.0.
            Err(_) => assert!(major >= 1, "implausible compute capability {major}.{minor}"),
        }
    }

    #[test]
    #[ignore]
    fn buffer_roundtrip_f32() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let buf = device().alloc_from(&data).expect("alloc_from failed");
        let result = buf.to_host(device()).expect("to_host failed");
        assert_eq!(result, data, "roundtrip data mismatch");
    }

    #[test]
    #[ignore]
    fn buffer_alloc_zeros() {
        let buf = device()
            .alloc_zeros::<f32>(100)
            .expect("alloc_zeros failed");
        let result = buf.to_host(device()).expect("to_host failed");
        assert_eq!(result, vec![0.0f32; 100]);
    }

    #[test]
    #[ignore]
    fn buffer_len() {
        let buf = device()
            .alloc_from(&[1.0f32, 2.0, 3.0])
            .expect("alloc_from failed");
        assert_eq!(buf.len(), 3);
        assert!(!buf.is_empty());
    }

    #[test]
    #[ignore]
    fn invalid_device_ordinal() {
        let result = KaioDevice::new(999);
        assert!(result.is_err(), "expected error for ordinal 999");
    }
}