// singe_cuda/kernel.rs

use std::{
    ffi::{CStr, c_char},
    ptr,
};

use singe_cuda_sys::driver;

use crate::{
    context::Context,
    error::{Error, Result},
    try_ffi,
    types::FunctionAttribute,
};

12/// Raw CUDA kernel handle adapter.
13///
14/// Implementors define how to query names, attributes, and mutable attributes
15/// for a specific CUDA kernel handle type.
16pub trait KernelHandle {
17    /// Raw CUDA handle type accepted by the corresponding driver calls.
18    type RawHandle: Copy;
19
20    /// Calls the CUDA API that returns the kernel name for `raw`.
21    ///
22    /// # Safety
23    ///
24    /// `raw` must be a valid kernel handle for `device_id`, and `name` must be
25    /// valid for CUDA to write a pointer-sized result.
26    unsafe fn raw_name(
27        raw: Self::RawHandle,
28        name: *mut *const i8,
29        device_id: i32,
30    ) -> driver::CUresult;
31
32    /// Calls the CUDA API that returns a kernel attribute for `raw`.
33    ///
34    /// # Safety
35    ///
36    /// `raw` must be a valid kernel handle for `device_id`, and `value` must be
37    /// valid for CUDA to write an `i32` result.
38    unsafe fn raw_attribute(
39        value: *mut i32,
40        attribute: driver::CUfunction_attribute,
41        raw: Self::RawHandle,
42        device_id: i32,
43    ) -> driver::CUresult;
44
45    /// Calls the CUDA API that sets a kernel attribute for `raw`.
46    ///
47    /// # Safety
48    ///
49    /// `raw` must be a valid kernel handle for `device_id`, and the selected
50    /// attribute/value pair must be accepted by CUDA for that handle.
51    unsafe fn set_attribute(
52        raw: Self::RawHandle,
53        attribute: driver::CUfunction_attribute,
54        value: i32,
55        device_id: i32,
56    ) -> driver::CUresult;
57}
58
59#[derive(Debug)]
60pub struct ModuleKernelHandle;
61
62impl KernelHandle for ModuleKernelHandle {
63    type RawHandle = driver::CUfunction;
64
65    unsafe fn raw_name(
66        raw: Self::RawHandle,
67        name: *mut *const i8,
68        _device_id: i32,
69    ) -> driver::CUresult {
70        unsafe { driver::cuFuncGetName(name, raw) }
71    }
72
73    unsafe fn raw_attribute(
74        value: *mut i32,
75        attribute: driver::CUfunction_attribute,
76        raw: Self::RawHandle,
77        _device_id: i32,
78    ) -> driver::CUresult {
79        unsafe { driver::cuFuncGetAttribute(value, attribute, raw) }
80    }
81
82    unsafe fn set_attribute(
83        raw: Self::RawHandle,
84        attribute: driver::CUfunction_attribute,
85        value: i32,
86        _device_id: i32,
87    ) -> driver::CUresult {
88        unsafe { driver::cuFuncSetAttribute(raw, attribute, value) }
89    }
90}
91
92#[derive(Debug)]
93pub struct LibraryKernelHandle;
94
95impl KernelHandle for LibraryKernelHandle {
96    type RawHandle = driver::CUkernel;
97
98    unsafe fn raw_name(
99        raw: Self::RawHandle,
100        name: *mut *const i8,
101        _device_id: i32,
102    ) -> driver::CUresult {
103        unsafe { driver::cuKernelGetName(name, raw) }
104    }
105
106    unsafe fn raw_attribute(
107        value: *mut i32,
108        attribute: driver::CUfunction_attribute,
109        raw: Self::RawHandle,
110        device_id: i32,
111    ) -> driver::CUresult {
112        unsafe { driver::cuKernelGetAttribute(value, attribute, raw, device_id) }
113    }
114
115    unsafe fn set_attribute(
116        raw: Self::RawHandle,
117        attribute: driver::CUfunction_attribute,
118        value: i32,
119        device_id: i32,
120    ) -> driver::CUresult {
121        unsafe { driver::cuKernelSetAttribute(attribute, value, raw, device_id) }
122    }
123}
124
125pub fn name<H: KernelHandle>(ctx: &Context, raw: H::RawHandle) -> Result<String> {
126    ctx.bind()?;
127    let mut name = ptr::null();
128    unsafe {
129        try_ffi!(H::raw_name(raw, &raw mut name, ctx.device().id()))?;
130        if name.is_null() {
131            return Err(Error::NullHandle);
132        }
133        Ok(CStr::from_ptr(name).to_string_lossy().into_owned())
134    }
135}
136
137pub fn attribute<H: KernelHandle>(
138    ctx: &Context,
139    raw: H::RawHandle,
140    attribute: FunctionAttribute,
141) -> Result<i32> {
142    ctx.bind()?;
143    let mut value = 0;
144    unsafe {
145        try_ffi!(H::raw_attribute(
146            &raw mut value,
147            attribute.into(),
148            raw,
149            ctx.device().id(),
150        ))?;
151    }
152    Ok(value)
153}
154
155pub fn set_attribute<H: KernelHandle>(
156    ctx: &Context,
157    raw: H::RawHandle,
158    attribute: FunctionAttribute,
159    value: i32,
160) -> Result<()> {
161    ctx.bind()?;
162    unsafe {
163        try_ffi!(H::set_attribute(
164            raw,
165            attribute.into(),
166            value,
167            ctx.device().id(),
168        ))?;
169    }
170    Ok(())
171}