1use std::sync::Arc;
4
5use cudarc::driver::{CudaContext, CudaStream, DeviceRepr, ValidAsZeroBits};
6
7use crate::buffer::GpuBuffer;
8use crate::error::Result;
9
/// Handle to one CUDA device: a driver context plus the stream that this
/// handle's allocations and host-to-device copies are issued on.
///
/// Construct with [`KaioDevice::new`], which binds the handle to the
/// context's default stream.
pub struct KaioDevice {
    /// Driver context for the device ordinal passed to `new`; also used to load modules.
    ctx: Arc<CudaContext>,
    /// Stream used by `alloc_from` / `alloc_zeros` and returned by `stream()`.
    stream: Arc<CudaStream>,
}
26
27impl std::fmt::Debug for KaioDevice {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("KaioDevice")
30 .field("ordinal", &self.ctx.ordinal())
31 .finish()
32 }
33}
34
35impl KaioDevice {
36 pub fn new(ordinal: usize) -> Result<Self> {
41 let ctx = CudaContext::new(ordinal)?;
42 let stream = ctx.default_stream();
43 Ok(Self { ctx, stream })
44 }
45
46 pub fn info(&self) -> Result<DeviceInfo> {
48 DeviceInfo::from_context(&self.ctx)
49 }
50
51 pub fn alloc_from<T: DeviceRepr>(&self, data: &[T]) -> Result<GpuBuffer<T>> {
53 let slice = self.stream.clone_htod(data)?;
54 Ok(GpuBuffer::from_raw(slice))
55 }
56
57 pub fn alloc_zeros<T: DeviceRepr + ValidAsZeroBits>(&self, len: usize) -> Result<GpuBuffer<T>> {
59 let slice = self.stream.alloc_zeros::<T>(len)?;
60 Ok(GpuBuffer::from_raw(slice))
61 }
62
63 pub fn stream(&self) -> &Arc<CudaStream> {
68 &self.stream
69 }
70
71 pub fn load_ptx(&self, ptx_text: &str) -> Result<crate::module::KaioModule> {
84 let ptx = cudarc::nvrtc::Ptx::from_src(ptx_text);
85 let module = self.ctx.load_module(ptx)?;
86 Ok(crate::module::KaioModule::from_raw(module))
87 }
88
89 pub fn load_module(
103 &self,
104 module: &kaio_core::ir::PtxModule,
105 ) -> Result<crate::module::KaioModule> {
106 use kaio_core::emit::{Emit, PtxWriter};
107
108 module.validate()?;
109
110 let mut w = PtxWriter::new();
111 module
112 .emit(&mut w)
113 .map_err(|e| crate::error::KaioError::PtxLoad(format!("emit failed: {e}")))?;
114 let ptx_text = w.finish();
115
116 self.load_ptx(&ptx_text)
117 }
118}
119
/// Properties of a CUDA device, as returned by [`KaioDevice::info`].
///
/// `PartialEq`/`Eq` are derived so infos can be compared directly
/// (e.g. checking two handles refer to identical hardware, or in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceInfo {
    /// Device name as reported by the CUDA driver.
    pub name: String,
    /// Compute capability as `(major, minor)`, e.g. `(8, 9)` for SM 8.9.
    pub compute_capability: (u32, u32),
    /// Total device memory as reported by the driver's total-memory query
    /// (CUDA reports this in bytes).
    pub total_memory: usize,
}
135
136impl DeviceInfo {
137 fn from_context(ctx: &Arc<CudaContext>) -> Result<Self> {
139 use cudarc::driver::result::device;
140
141 let ordinal = ctx.ordinal();
142 let dev = device::get(ordinal as i32)?;
143 let name = device::get_name(dev)?;
144 let total_memory = unsafe { device::total_mem(dev)? };
145
146 let major = unsafe {
149 device::get_attribute(
150 dev,
151 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
152 )?
153 };
154 let minor = unsafe {
155 device::get_attribute(
156 dev,
157 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
158 )?
159 };
160
161 Ok(Self {
162 name,
163 compute_capability: (major as u32, minor as u32),
164 total_memory,
165 })
166 }
167}
168
#[cfg(test)]
mod tests {
    //! Hardware tests. They need the CUDA driver (and, except for
    //! `invalid_device_ordinal`, a device at ordinal 0), so they are ignored by
    //! default; run them with `cargo test -- --ignored`.

    use super::*;
    use std::sync::OnceLock;

    /// One shared device for tests that just need *a* working handle, so the
    /// CUDA context is created once rather than once per test.
    static DEVICE: OnceLock<KaioDevice> = OnceLock::new();

    fn device() -> &'static KaioDevice {
        DEVICE.get_or_init(|| KaioDevice::new(0).expect("GPU required for tests"))
    }

    #[test]
    #[ignore = "requires a CUDA GPU"]
    fn device_creation() {
        let dev = KaioDevice::new(0);
        assert!(dev.is_ok(), "KaioDevice::new(0) failed: {dev:?}");
    }

    #[test]
    #[ignore = "requires a CUDA GPU"]
    fn device_info_name() {
        let info = device().info().expect("info() failed");
        assert!(!info.name.is_empty(), "device name should not be empty");
        eprintln!("GPU name: {}", info.name);
    }

    #[test]
    #[ignore = "requires a CUDA GPU"]
    fn device_info_compute_capability() {
        let info = device().info().expect("info() failed");
        let (major, minor) = info.compute_capability;
        assert!(major >= 7, "expected SM 7.0+ GPU, got SM {major}.{minor}");
        eprintln!("GPU compute capability: SM {major}.{minor}");
    }

    #[test]
    #[ignore = "requires a CUDA GPU"]
    fn buffer_roundtrip_f32() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let buf = device().alloc_from(&data).expect("alloc_from failed");
        let result = buf.to_host(device()).expect("to_host failed");
        assert_eq!(result, data, "roundtrip data mismatch");
    }

    #[test]
    #[ignore = "requires a CUDA GPU"]
    fn buffer_alloc_zeros() {
        let buf = device()
            .alloc_zeros::<f32>(100)
            .expect("alloc_zeros failed");
        let result = buf.to_host(device()).expect("to_host failed");
        assert_eq!(result, vec![0.0f32; 100]);
    }

    #[test]
    #[ignore = "requires a CUDA GPU"]
    fn buffer_len() {
        let buf = device()
            .alloc_from(&[1.0f32, 2.0, 3.0])
            .expect("alloc_from failed");
        assert_eq!(buf.len(), 3);
        assert!(!buf.is_empty());
    }

    #[test]
    #[ignore = "requires the CUDA driver"]
    fn invalid_device_ordinal() {
        let result = KaioDevice::new(999);
        assert!(result.is_err(), "expected error for ordinal 999");
    }
}