1use std::ffi::c_void;
15
16use oxicuda_driver::ffi::{CUfunction, CUmodule};
17use oxicuda_driver::loader::try_driver;
18
19use crate::error::{CudaRtError, CudaRtResult};
20use crate::stream::CudaStream;
21
/// Handle to a kernel function, as returned by `module_get_function` and
/// consumed by `launch_kernel`, `func_get_attributes` and `func_set_attribute`.
///
/// Alias of the driver's raw `CUfunction` type.
pub type CudaFunction = CUfunction;

/// Handle to a loaded module, as returned by `module_load_ptx`.
///
/// Alias of the driver's raw `CUmodule` type; release it explicitly with
/// `module_unload`.
pub type CudaModule = CUmodule;
29
/// Three-dimensional extent used for the grid and block sizes of a kernel launch.
///
/// Unused dimensions are conventionally `1`; the `one_d`/`two_d` constructors
/// and the `From<u32>` / `From<(u32, u32)>` conversions fill them in that way.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Dim3 {
    /// Extent along the x dimension.
    pub x: u32,
    /// Extent along the y dimension.
    pub y: u32,
    /// Extent along the z dimension.
    pub z: u32,
}

impl Dim3 {
    /// One-dimensional extent `(x, 1, 1)`.
    #[must_use]
    pub const fn one_d(x: u32) -> Self {
        Self { x, y: 1, z: 1 }
    }

    /// Two-dimensional extent `(x, y, 1)`.
    #[must_use]
    pub const fn two_d(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Fully specified extent `(x, y, z)`.
    #[must_use]
    pub const fn three_d(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Total number of elements covered, `x * y * z`, computed in `u64`.
    ///
    /// Any two `u32` factors always fit in a `u64`, but three large ones can
    /// exceed `u64::MAX` (e.g. all three equal to `u32::MAX`). In that case the
    /// result saturates at `u64::MAX` instead of panicking (debug builds) or
    /// silently wrapping to a small, misleading value (release builds).
    #[must_use]
    pub const fn volume(self) -> u64 {
        // `as` widens u32 -> u64 losslessly; `u64::from` is not callable in a `const fn`.
        (self.x as u64)
            .saturating_mul(self.y as u64)
            .saturating_mul(self.z as u64)
    }
}

/// A bare `u32` is a one-dimensional extent `(x, 1, 1)`.
impl From<u32> for Dim3 {
    fn from(x: u32) -> Self {
        Self::one_d(x)
    }
}

/// An `(x, y)` pair is a two-dimensional extent `(x, y, 1)`.
impl From<(u32, u32)> for Dim3 {
    fn from((x, y): (u32, u32)) -> Self {
        Self::two_d(x, y)
    }
}

/// An `(x, y, z)` triple maps directly onto the three dimensions.
impl From<(u32, u32, u32)> for Dim3 {
    fn from((x, y, z): (u32, u32, u32)) -> Self {
        Self::three_d(x, y, z)
    }
}
88
/// Snapshot of a kernel's attributes as reported by the driver.
///
/// Filled in by `func_get_attributes`, which issues one driver attribute query
/// per field. The derived `Default` sets every numeric field to `0` and
/// `cache_mode_ca` to `false`.
#[derive(Debug, Clone, Copy, Default)]
pub struct FuncAttributes {
    /// Shared memory size in bytes (driver attribute `SharedSizeBytes`).
    pub shared_size_bytes: usize,
    /// Constant memory size in bytes (driver attribute `ConstSizeBytes`).
    pub const_size_bytes: usize,
    /// Local memory size in bytes (driver attribute `LocalSizeBytes`).
    pub local_size_bytes: usize,
    /// Maximum threads per block for this function (driver attribute `MaxThreadsPerBlock`).
    pub max_threads_per_block: u32,
    /// Number of registers used (driver attribute `NumRegs`).
    pub num_regs: u32,
    /// PTX version (driver attribute `PtxVersion`).
    pub ptx_version: u32,
    /// Binary version (driver attribute `BinaryVersion`).
    pub binary_version: u32,
    /// `true` when the driver's `CacheModeCa` attribute is non-zero.
    pub cache_mode_ca: bool,
    /// Maximum dynamic shared memory size in bytes (driver attribute `MaxDynamicSharedSizeBytes`).
    pub max_dynamic_shared_size_bytes: usize,
    /// Raw, signed `PreferredSharedMemoryCarveout` value exactly as returned by the driver.
    pub preferred_shared_memory_carveout: i32,
}
117
/// Kernel attributes that can be changed with `func_set_attribute`.
///
/// `func_set_attribute` passes the discriminant to the driver as the raw
/// attribute code (`attr as c_int`). The values must therefore match the
/// driver's `CUfunction_attribute` numbering:
/// `CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES` = 8 and
/// `CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT` = 9.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FuncAttribute {
    /// Maximum dynamic shared memory size, in bytes (driver code 8).
    MaxDynamicSharedMemorySize = 8,
    /// Preferred shared-memory carveout hint (driver code 9).
    PreferredSharedMemoryCarveout = 9,
}
128
129pub fn module_load_ptx(ptx: &[u8]) -> CudaRtResult<CudaModule> {
139 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
140 let mut module = CUmodule::default();
141 let mut ptx_owned;
143 let ptx_ptr = if ptx.last().copied() == Some(0) {
144 ptx.as_ptr()
145 } else {
146 ptx_owned = ptx.to_vec();
147 ptx_owned.push(0);
148 ptx_owned.as_ptr()
149 };
150 let rc = unsafe {
152 (api.cu_module_load_data_ex)(
153 &raw mut module,
154 ptx_ptr as *const c_void,
155 0,
156 std::ptr::null_mut(),
157 std::ptr::null_mut(),
158 )
159 };
160 if rc != 0 {
161 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidPtx));
162 }
163 Ok(module)
164}
165
166pub fn module_get_function(module: CudaModule, name: &str) -> CudaRtResult<CudaFunction> {
174 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
175 let mut func = CUfunction::default();
176 let name_cstr = std::ffi::CString::new(name).map_err(|_| CudaRtError::InvalidSymbol)?;
177 let rc = unsafe { (api.cu_module_get_function)(&raw mut func, module, name_cstr.as_ptr()) };
179 if rc != 0 {
180 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::SymbolNotFound));
181 }
182 Ok(func)
183}
184
185pub fn module_unload(module: CudaModule) -> CudaRtResult<()> {
193 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
194 let rc = unsafe { (api.cu_module_unload)(module) };
196 if rc != 0 {
197 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidResourceHandle));
198 }
199 Ok(())
200}
201
202pub unsafe fn launch_kernel(
229 func: CudaFunction,
230 grid: Dim3,
231 block: Dim3,
232 args: &mut [*mut c_void],
233 shared_mem: u32,
234 stream: CudaStream,
235) -> CudaRtResult<()> {
236 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
237 let rc = unsafe {
239 (api.cu_launch_kernel)(
240 func,
241 grid.x,
242 grid.y,
243 grid.z,
244 block.x,
245 block.y,
246 block.z,
247 shared_mem,
248 stream.raw(),
249 args.as_mut_ptr(),
250 std::ptr::null_mut(), )
252 };
253 if rc != 0 {
254 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::LaunchFailure));
255 }
256 Ok(())
257}
258
259pub fn func_get_attributes(func: CudaFunction) -> CudaRtResult<FuncAttributes> {
267 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
268
269 let get_attr_fn = api.cu_func_get_attribute.ok_or(CudaRtError::NotSupported)?;
271 let attr = |a: oxicuda_driver::ffi::CUfunction_attribute| -> CudaRtResult<i32> {
272 let mut v: std::ffi::c_int = 0;
273 let rc = unsafe { get_attr_fn(&raw mut v, a as std::ffi::c_int, func) };
275 if rc != 0 {
276 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDeviceFunction));
277 }
278 Ok(v)
279 };
280
281 use oxicuda_driver::ffi::CUfunction_attribute as FA;
282 Ok(FuncAttributes {
283 shared_size_bytes: attr(FA::SharedSizeBytes)? as usize,
284 const_size_bytes: attr(FA::ConstSizeBytes)? as usize,
285 local_size_bytes: attr(FA::LocalSizeBytes)? as usize,
286 max_threads_per_block: attr(FA::MaxThreadsPerBlock)? as u32,
287 num_regs: attr(FA::NumRegs)? as u32,
288 ptx_version: attr(FA::PtxVersion)? as u32,
289 binary_version: attr(FA::BinaryVersion)? as u32,
290 cache_mode_ca: attr(FA::CacheModeCa)? != 0,
291 max_dynamic_shared_size_bytes: attr(FA::MaxDynamicSharedSizeBytes)? as usize,
292 preferred_shared_memory_carveout: attr(FA::PreferredSharedMemoryCarveout)?,
293 })
294}
295
296pub fn func_set_attribute(func: CudaFunction, attr: FuncAttribute, value: i32) -> CudaRtResult<()> {
304 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
305 let set_attr_fn = api.cu_func_set_attribute.ok_or(CudaRtError::NotSupported)?;
307 let rc = unsafe { set_attr_fn(func, attr as std::ffi::c_int, value) };
309 if rc != 0 {
310 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDeviceFunction));
311 }
312 Ok(())
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dim3_one_d() {
        let d = Dim3::one_d(128);
        assert_eq!(d.x, 128);
        assert_eq!(d.y, 1);
        assert_eq!(d.z, 1);
        assert_eq!(d.volume(), 128);
    }

    #[test]
    fn dim3_two_d_and_three_d() {
        let d = Dim3::two_d(32, 8);
        assert_eq!((d.x, d.y, d.z), (32, 8, 1));
        let d3 = Dim3::three_d(4, 5, 6);
        assert_eq!((d3.x, d3.y, d3.z), (4, 5, 6));
    }

    #[test]
    fn dim3_from_u32() {
        let d: Dim3 = 256u32.into();
        assert_eq!(d.x, 256);
        assert_eq!((d.y, d.z), (1, 1));
    }

    #[test]
    fn dim3_from_tuple() {
        let d: Dim3 = (32u32, 8u32).into();
        assert_eq!(d.volume(), 256);
        let d3: Dim3 = (4u32, 4u32, 4u32).into();
        assert_eq!(d3.volume(), 64);
    }

    #[test]
    fn dim3_volume() {
        assert_eq!(Dim3::three_d(2, 3, 4).volume(), 24);
    }

    #[test]
    fn module_load_ptx_without_gpu_errors() {
        // The image has no `.version`/`.target` directives, and this test never
        // makes a context current. Loading must therefore fail whether or not a
        // driver is installed. Only the `Err` is checked, not the variant,
        // because the variant depends on the machine.
        let ptx = b"// empty\n\0";
        assert!(
            module_load_ptx(ptx).is_err(),
            "loading an empty PTX image unexpectedly succeeded"
        );
    }

    #[test]
    fn module_load_ptx_unterminated_input_errors() {
        // Exercises the path that appends a NUL terminator before calling the driver.
        let ptx = b"// empty\n";
        assert!(
            module_load_ptx(ptx).is_err(),
            "loading an empty, unterminated PTX image unexpectedly succeeded"
        );
    }
}