// kaio_runtime/device.rs
1//! CUDA device management.
2
3use std::sync::Arc;
4
5use cudarc::driver::{CudaContext, CudaStream, DeviceRepr, ValidAsZeroBits};
6
7use crate::buffer::GpuBuffer;
8use crate::error::Result;
9
10/// A KAIO GPU device — wraps a CUDA context and its default stream.
11///
12/// Created via [`KaioDevice::new`] with a device ordinal (0 for the first GPU).
13/// All allocation and transfer operations go through the default stream.
14///
15/// # Example
16///
17/// ```ignore
18/// let device = KaioDevice::new(0)?;
19/// let buf = device.alloc_from(&[1.0f32, 2.0, 3.0])?;
20/// let host = buf.to_host(&device)?;
21/// ```
22pub struct KaioDevice {
23 ctx: Arc<CudaContext>,
24 stream: Arc<CudaStream>,
25}
26
27impl std::fmt::Debug for KaioDevice {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("KaioDevice")
30 .field("ordinal", &self.ctx.ordinal())
31 .finish()
32 }
33}
34
35impl KaioDevice {
36 /// Create a new device targeting the GPU at the given ordinal.
37 ///
38 /// Ordinal 0 is the first GPU. Returns an error if no GPU exists at
39 /// that ordinal or if the CUDA driver fails to initialize.
40 pub fn new(ordinal: usize) -> Result<Self> {
41 let ctx = CudaContext::new(ordinal)?;
42 let stream = ctx.default_stream();
43 Ok(Self { ctx, stream })
44 }
45
46 /// Query basic information about this device.
47 pub fn info(&self) -> Result<DeviceInfo> {
48 DeviceInfo::from_context(&self.ctx)
49 }
50
51 /// Allocate device memory and copy data from a host slice.
52 pub fn alloc_from<T: DeviceRepr>(&self, data: &[T]) -> Result<GpuBuffer<T>> {
53 let slice = self.stream.clone_htod(data)?;
54 Ok(GpuBuffer::from_raw(slice))
55 }
56
57 /// Allocate zero-initialized device memory.
58 pub fn alloc_zeros<T: DeviceRepr + ValidAsZeroBits>(&self, len: usize) -> Result<GpuBuffer<T>> {
59 let slice = self.stream.alloc_zeros::<T>(len)?;
60 Ok(GpuBuffer::from_raw(slice))
61 }
62
63 /// Access the underlying CUDA stream for kernel launch operations.
64 ///
65 /// Used with cudarc's `launch_builder` to launch kernels. In Phase 2,
66 /// the proc macro will generate typed wrappers that hide this.
67 pub fn stream(&self) -> &Arc<CudaStream> {
68 &self.stream
69 }
70
71 /// Load a PTX module from source text and return a [`crate::module::KaioModule`].
72 ///
73 /// The PTX text is passed to the CUDA driver's `cuModuleLoadData` —
74 /// no NVRTC compilation occurs. The driver JIT-compiles the PTX for
75 /// the current GPU.
76 ///
77 /// # Deprecated — prefer [`load_module`](Self::load_module)
78 ///
79 /// The module path runs
80 /// [`PtxModule::validate`](kaio_core::ir::PtxModule::validate)
81 /// before the driver sees the PTX, catching SM mismatches (e.g.
82 /// `mma.sync` on sub-Ampere targets) with readable
83 /// [`KaioError::Validation`](crate::error::KaioError::Validation)
84 /// errors instead of cryptic `ptxas` failures deep in the driver.
85 ///
86 /// This function remains public for raw-PTX use cases (external PTX
87 /// files, hand-written PTX for research, bypassing validation
88 /// intentionally). It is not scheduled for removal in the 0.2.x line.
89 ///
90 /// # Migration
91 ///
92 /// Before:
93 /// ```ignore
94 /// let ptx_text: String = build_my_ptx();
95 /// let module = device.load_ptx(&ptx_text)?;
96 /// ```
97 ///
98 /// After:
99 /// ```ignore
100 /// use kaio_core::ir::PtxModule;
101 /// let ptx_module: PtxModule = build_my_module("sm_80");
102 /// let module = device.load_module(&ptx_module)?;
103 /// ```
104 #[deprecated(
105 since = "0.2.1",
106 note = "use load_module(&PtxModule) — runs PtxModule::validate() for readable SM-mismatch errors"
107 )]
108 pub fn load_ptx(&self, ptx_text: &str) -> Result<crate::module::KaioModule> {
109 let ptx = cudarc::nvrtc::Ptx::from_src(ptx_text);
110 let module = self.ctx.load_module(ptx)?;
111 Ok(crate::module::KaioModule::from_raw(module))
112 }
113
114 /// Validate, emit, and load a [`kaio_core::ir::PtxModule`] on the device.
115 ///
116 /// This is the preferred entrypoint when the caller has an in-memory
117 /// `PtxModule` (as opposed to raw PTX text). Before the PTX text is
118 /// handed to the driver, [`kaio_core::ir::PtxModule::validate`]
119 /// checks that the module's target SM supports every feature used by
120 /// its kernels — raising
121 /// [`KaioError::Validation`](crate::error::KaioError::Validation) if
122 /// e.g. a `mma.sync` op is present but the target is `sm_70`.
123 ///
124 /// Surfacing the error at this layer gives the user a readable
125 /// message ("`mma.sync.m16n8k16 requires sm_80+, target is sm_70`")
126 /// instead of a cryptic `ptxas` error from deep in the driver.
127 pub fn load_module(
128 &self,
129 module: &kaio_core::ir::PtxModule,
130 ) -> Result<crate::module::KaioModule> {
131 use kaio_core::emit::{Emit, PtxWriter};
132
133 module.validate()?;
134
135 let mut w = PtxWriter::new();
136 module
137 .emit(&mut w)
138 .map_err(|e| crate::error::KaioError::PtxLoad(format!("emit failed: {e}")))?;
139 let ptx_text = w.finish();
140
141 // `load_ptx` is #[deprecated] as a public API to steer users to the
142 // validated module path, but it's still the correct internal
143 // implementation detail after we've emitted the PTX text here.
144 #[allow(deprecated)]
145 self.load_ptx(&ptx_text)
146 }
147}
148
/// Basic information about a CUDA device.
///
/// Phase 1 carries only name, compute capability, and total memory.
/// Occupancy-related fields (SM count, max threads per block, max
/// shared memory, warp size) are planned for Phase 3/4, when shared
/// memory and occupancy calculations matter.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Human-readable GPU name, e.g. "NVIDIA GeForce RTX 4090".
    pub name: String,
    /// Compute capability as a (major, minor) pair — (8, 9) means SM 8.9.
    pub compute_capability: (u32, u32),
    /// Total device memory, in bytes.
    pub total_memory: usize,
}
164
165impl DeviceInfo {
166 /// Query device info from a CUDA context.
167 fn from_context(ctx: &Arc<CudaContext>) -> Result<Self> {
168 use cudarc::driver::result::device;
169
170 let ordinal = ctx.ordinal();
171 let dev = device::get(ordinal as i32)?;
172 let name = device::get_name(dev)?;
173 let total_memory = unsafe { device::total_mem(dev)? };
174
175 // SAFETY: dev is a valid device handle obtained from device::get().
176 // get_attribute reads a device property — no mutation, no aliasing.
177 let major = unsafe {
178 device::get_attribute(
179 dev,
180 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
181 )?
182 };
183 let minor = unsafe {
184 device::get_attribute(
185 dev,
186 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
187 )?
188 };
189
190 Ok(Self {
191 name,
192 compute_capability: (major as u32, minor as u32),
193 total_memory,
194 })
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::OnceLock;

    // One shared device across all tests in this module; creating a
    // context per test is slow and unnecessary.
    static DEVICE: OnceLock<KaioDevice> = OnceLock::new();

    /// Lazily create (and cache) the device at ordinal 0.
    fn device() -> &'static KaioDevice {
        DEVICE.get_or_init(|| KaioDevice::new(0).expect("GPU required for tests"))
    }

    #[test]
    #[ignore] // requires NVIDIA GPU
    fn device_creation() {
        let dev = KaioDevice::new(0);
        assert!(dev.is_ok(), "KaioDevice::new(0) failed: {dev:?}");
    }

    #[test]
    #[ignore]
    fn device_info_name() {
        let info = device().info().expect("info() failed");
        assert!(!info.name.is_empty(), "device name should not be empty");
        // Print the name for manual inspection (e.g. "4090" on an RTX 4090).
        eprintln!("GPU name: {}", info.name);
    }

    #[test]
    #[ignore]
    fn device_info_compute_capability() {
        let info = device().info().expect("info() failed");
        // Volta (SM 7.0) and newer are the supported baseline.
        let (major, minor) = info.compute_capability;
        assert!(
            major >= 7,
            "expected SM 7.0+ GPU, got SM {}.{}",
            major,
            minor,
        );
        eprintln!("GPU compute capability: SM {}.{}", major, minor);
    }

    #[test]
    #[ignore]
    fn buffer_roundtrip_f32() {
        let host = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let gpu = device().alloc_from(&host).expect("alloc_from failed");
        let back = gpu.to_host(device()).expect("to_host failed");
        assert_eq!(back, host, "roundtrip data mismatch");
    }

    #[test]
    #[ignore]
    fn buffer_alloc_zeros() {
        let gpu = device()
            .alloc_zeros::<f32>(100)
            .expect("alloc_zeros failed");
        let back = gpu.to_host(device()).expect("to_host failed");
        assert_eq!(back, vec![0.0f32; 100]);
    }

    #[test]
    #[ignore]
    fn buffer_len() {
        let gpu = device()
            .alloc_from(&[1.0f32, 2.0, 3.0])
            .expect("alloc_from failed");
        assert_eq!(gpu.len(), 3);
        assert!(!gpu.is_empty());
    }

    #[test]
    #[ignore]
    fn invalid_device_ordinal() {
        let result = KaioDevice::new(999);
        assert!(result.is_err(), "expected error for ordinal 999");
    }
}
277}