1use std::sync::Arc;
4
5use cudarc::driver::{CudaContext, CudaStream, DeviceRepr, ValidAsZeroBits};
6
7use crate::buffer::GpuBuffer;
8use crate::error::Result;
9
10pub struct KaioDevice {
23 ctx: Arc<CudaContext>,
24 stream: Arc<CudaStream>,
25}
26
27impl std::fmt::Debug for KaioDevice {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("KaioDevice")
30 .field("ordinal", &self.ctx.ordinal())
31 .finish()
32 }
33}
34
35impl KaioDevice {
36 pub fn new(ordinal: usize) -> Result<Self> {
41 let ctx = CudaContext::new(ordinal)?;
42 let stream = ctx.default_stream();
43 Ok(Self { ctx, stream })
44 }
45
46 pub fn info(&self) -> Result<DeviceInfo> {
48 DeviceInfo::from_context(&self.ctx)
49 }
50
51 pub fn alloc_from<T: DeviceRepr>(&self, data: &[T]) -> Result<GpuBuffer<T>> {
53 let slice = self.stream.clone_htod(data)?;
54 Ok(GpuBuffer::from_raw(slice))
55 }
56
57 pub fn alloc_zeros<T: DeviceRepr + ValidAsZeroBits>(&self, len: usize) -> Result<GpuBuffer<T>> {
59 let slice = self.stream.alloc_zeros::<T>(len)?;
60 Ok(GpuBuffer::from_raw(slice))
61 }
62
63 pub fn stream(&self) -> &Arc<CudaStream> {
68 &self.stream
69 }
70
71 pub fn load_ptx(&self, ptx_text: &str) -> Result<crate::module::KaioModule> {
84 let ptx = cudarc::nvrtc::Ptx::from_src(ptx_text);
85 let module = self.ctx.load_module(ptx)?;
86 Ok(crate::module::KaioModule::from_raw(module))
87 }
88}
89
/// Static properties of a CUDA device, as returned by `KaioDevice::info`.
///
/// Plain data: cheap to clone and comparable, so callers can cache it or check
/// it against expectations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceInfo {
    /// Device name as reported by the driver.
    pub name: String,
    /// Compute capability as `(major, minor)`, e.g. `(8, 9)` for SM 8.9.
    pub compute_capability: (u32, u32),
    /// Total device memory as reported by the driver's `cuDeviceTotalMem` (bytes).
    pub total_memory: usize,
}
105
106impl DeviceInfo {
107 fn from_context(ctx: &Arc<CudaContext>) -> Result<Self> {
109 use cudarc::driver::result::device;
110
111 let ordinal = ctx.ordinal();
112 let dev = device::get(ordinal as i32)?;
113 let name = device::get_name(dev)?;
114 let total_memory = unsafe { device::total_mem(dev)? };
115
116 let major = unsafe {
119 device::get_attribute(
120 dev,
121 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
122 )?
123 };
124 let minor = unsafe {
125 device::get_attribute(
126 dev,
127 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
128 )?
129 };
130
131 Ok(Self {
132 name,
133 compute_capability: (major as u32, minor as u32),
134 total_memory,
135 })
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::OnceLock;

    /// Device shared by all tests so the CUDA context is created only once.
    static DEVICE: OnceLock<KaioDevice> = OnceLock::new();

    fn device() -> &'static KaioDevice {
        DEVICE.get_or_init(|| KaioDevice::new(0).expect("GPU required for tests"))
    }

    #[test]
    #[ignore = "requires a CUDA-capable GPU"]
    fn device_creation() {
        let dev = KaioDevice::new(0);
        assert!(dev.is_ok(), "KaioDevice::new(0) failed: {dev:?}");
    }

    #[test]
    #[ignore = "requires a CUDA-capable GPU"]
    fn device_info_name() {
        let info = device().info().expect("info() failed");
        assert!(!info.name.is_empty(), "device name should not be empty");
        eprintln!("GPU name: {}", info.name);
    }

    #[test]
    #[ignore = "requires a CUDA-capable GPU"]
    fn device_info_compute_capability() {
        let info = device().info().expect("info() failed");
        let (major, minor) = info.compute_capability;
        assert!(major >= 7, "expected SM 7.0+ GPU, got SM {major}.{minor}");
        eprintln!("GPU compute capability: SM {major}.{minor}");
    }

    #[test]
    #[ignore = "requires a CUDA-capable GPU"]
    fn buffer_roundtrip_f32() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let buf = device().alloc_from(&data).expect("alloc_from failed");
        let result = buf.to_host(device()).expect("to_host failed");
        assert_eq!(result, data, "roundtrip data mismatch");
    }

    #[test]
    #[ignore = "requires a CUDA-capable GPU"]
    fn buffer_alloc_zeros() {
        let buf = device()
            .alloc_zeros::<f32>(100)
            .expect("alloc_zeros failed");
        let result = buf.to_host(device()).expect("to_host failed");
        assert_eq!(result, vec![0.0f32; 100]);
    }

    #[test]
    #[ignore = "requires a CUDA-capable GPU"]
    fn buffer_len() {
        let buf = device()
            .alloc_from(&[1.0f32, 2.0, 3.0])
            .expect("alloc_from failed");
        assert_eq!(buf.len(), 3);
        assert!(!buf.is_empty());
    }

    #[test]
    #[ignore = "requires the CUDA driver"]
    fn invalid_device_ordinal() {
        let result = KaioDevice::new(999);
        assert!(result.is_err(), "expected error for ordinal 999");
    }
}