Skip to main content

kaio_runtime/
device.rs

//! CUDA device management.
2
3use std::sync::Arc;
4
5use cudarc::driver::{CudaContext, CudaStream, DeviceRepr, ValidAsZeroBits};
6
7use crate::buffer::GpuBuffer;
8use crate::error::Result;
9
10/// A KAIO GPU device — wraps a CUDA context and its default stream.
11///
12/// Created via [`KaioDevice::new`] with a device ordinal (0 for the first GPU).
13/// All allocation and transfer operations go through the default stream.
14///
15/// # Example
16///
17/// ```ignore
18/// let device = KaioDevice::new(0)?;
19/// let buf = device.alloc_from(&[1.0f32, 2.0, 3.0])?;
20/// let host = buf.to_host(&device)?;
21/// ```
22pub struct KaioDevice {
23    ctx: Arc<CudaContext>,
24    stream: Arc<CudaStream>,
25}
26
27impl std::fmt::Debug for KaioDevice {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("KaioDevice")
30            .field("ordinal", &self.ctx.ordinal())
31            .finish()
32    }
33}
34
35impl KaioDevice {
36    /// Create a new device targeting the GPU at the given ordinal.
37    ///
38    /// Ordinal 0 is the first GPU. Returns an error if no GPU exists at
39    /// that ordinal or if the CUDA driver fails to initialize.
40    pub fn new(ordinal: usize) -> Result<Self> {
41        let ctx = CudaContext::new(ordinal)?;
42        let stream = ctx.default_stream();
43        Ok(Self { ctx, stream })
44    }
45
46    /// Query basic information about this device.
47    pub fn info(&self) -> Result<DeviceInfo> {
48        DeviceInfo::from_context(&self.ctx)
49    }
50
51    /// Allocate device memory and copy data from a host slice.
52    pub fn alloc_from<T: DeviceRepr>(&self, data: &[T]) -> Result<GpuBuffer<T>> {
53        let slice = self.stream.clone_htod(data)?;
54        Ok(GpuBuffer::from_raw(slice))
55    }
56
57    /// Allocate zero-initialized device memory.
58    pub fn alloc_zeros<T: DeviceRepr + ValidAsZeroBits>(&self, len: usize) -> Result<GpuBuffer<T>> {
59        let slice = self.stream.alloc_zeros::<T>(len)?;
60        Ok(GpuBuffer::from_raw(slice))
61    }
62
63    /// Access the underlying CUDA stream for kernel launch operations.
64    ///
65    /// Used with cudarc's `launch_builder` to launch kernels. In Phase 2,
66    /// the proc macro will generate typed wrappers that hide this.
67    pub fn stream(&self) -> &Arc<CudaStream> {
68        &self.stream
69    }
70
71    /// Load a PTX module from source text and return a [`crate::module::KaioModule`].
72    ///
73    /// The PTX text is passed to the CUDA driver's `cuModuleLoadData` —
74    /// no NVRTC compilation occurs. The driver JIT-compiles the PTX for
75    /// the current GPU.
76    ///
77    /// # Example
78    ///
79    /// ```ignore
80    /// let module = device.load_ptx(&ptx_text)?;
81    /// let func = module.function("vector_add")?;
82    /// ```
83    pub fn load_ptx(&self, ptx_text: &str) -> Result<crate::module::KaioModule> {
84        let ptx = cudarc::nvrtc::Ptx::from_src(ptx_text);
85        let module = self.ctx.load_module(ptx)?;
86        Ok(crate::module::KaioModule::from_raw(module))
87    }
88}
89
/// Summary of a CUDA device's identifying properties.
///
/// In Phase 1 this carries only the name, the compute capability, and the
/// total memory. Further fields (SM count, max threads per block, max shared
/// memory, warp size) are planned for Phase 3/4, once shared memory and
/// occupancy calculations become relevant.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Name of the GPU, e.g. "NVIDIA GeForce RTX 4090".
    pub name: String,
    /// Compute capability as a `(major, minor)` pair, e.g. `(8, 9)` for
    /// SM 8.9.
    pub compute_capability: (u32, u32),
    /// Total amount of device memory, in bytes.
    pub total_memory: usize,
}
105
106impl DeviceInfo {
107    /// Query device info from a CUDA context.
108    fn from_context(ctx: &Arc<CudaContext>) -> Result<Self> {
109        use cudarc::driver::result::device;
110
111        let ordinal = ctx.ordinal();
112        let dev = device::get(ordinal as i32)?;
113        let name = device::get_name(dev)?;
114        let total_memory = unsafe { device::total_mem(dev)? };
115
116        // SAFETY: dev is a valid device handle obtained from device::get().
117        // get_attribute reads a device property — no mutation, no aliasing.
118        let major = unsafe {
119            device::get_attribute(
120                dev,
121                cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
122            )?
123        };
124        let minor = unsafe {
125            device::get_attribute(
126                dev,
127                cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
128            )?
129        };
130
131        Ok(Self {
132            name,
133            compute_capability: (major as u32, minor as u32),
134            total_memory,
135        })
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::OnceLock;

    /// Device shared by the tests below; created on first use.
    static DEVICE: OnceLock<KaioDevice> = OnceLock::new();

    /// Return the shared device for GPU 0, creating it on the first call.
    /// Panics if the device cannot be created.
    fn device() -> &'static KaioDevice {
        DEVICE.get_or_init(|| KaioDevice::new(0).expect("GPU required for tests"))
    }

    /// Opening GPU 0 directly must succeed.
    #[test]
    #[ignore] // requires NVIDIA GPU
    fn device_creation() {
        let dev = KaioDevice::new(0);
        assert!(dev.is_ok(), "KaioDevice::new(0) failed: {dev:?}");
    }

    /// The reported device name must be non-empty.
    #[test]
    #[ignore]
    fn device_info_name() {
        let info = device().info().expect("info() failed");
        let name_is_empty = info.name.is_empty();
        assert!(!name_is_empty, "device name should not be empty");
        // The name depends on the hardware, so it is only printed for manual
        // inspection rather than compared against a fixed string.
        eprintln!("GPU name: {}", info.name);
    }

    /// The GPU must report compute capability 7.0 or newer.
    #[test]
    #[ignore]
    fn device_info_compute_capability() {
        let info = device().info().expect("info() failed");
        let (major, minor) = info.compute_capability;
        // Any SM 7.0+ GPU should work (Volta and newer).
        assert!(
            major >= 7,
            "expected SM 7.0+ GPU, got SM {}.{}",
            major,
            minor,
        );
        eprintln!("GPU compute capability: SM {}.{}", major, minor);
    }

    /// Data uploaded with `alloc_from` must read back unchanged.
    #[test]
    #[ignore]
    fn buffer_roundtrip_f32() {
        let gpu = device();
        let expected = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let uploaded = gpu.alloc_from(&expected).expect("alloc_from failed");
        let downloaded = uploaded.to_host(gpu).expect("to_host failed");
        assert_eq!(downloaded, expected, "roundtrip data mismatch");
    }

    /// A buffer from `alloc_zeros` must read back as all zeros.
    #[test]
    #[ignore]
    fn buffer_alloc_zeros() {
        let gpu = device();
        let zeroed = gpu.alloc_zeros::<f32>(100).expect("alloc_zeros failed");
        let downloaded = zeroed.to_host(gpu).expect("to_host failed");
        assert_eq!(downloaded, vec![0.0f32; 100]);
    }

    /// `len` and `is_empty` must reflect the number of uploaded elements.
    #[test]
    #[ignore]
    fn buffer_len() {
        let host = [1.0f32, 2.0, 3.0];
        let uploaded = device().alloc_from(&host).expect("alloc_from failed");
        assert_eq!(uploaded.len(), 3);
        assert!(!uploaded.is_empty());
    }

    /// Requesting a GPU ordinal that does not exist must yield an error.
    #[test]
    #[ignore]
    fn invalid_device_ordinal() {
        let outcome = KaioDevice::new(999);
        assert!(outcome.is_err(), "expected error for ordinal 999");
    }
}