// hlocl/kernel.rs

#[cfg(test)]
extern crate std;

use core::{mem::MaybeUninit, ptr::addr_of};
use alloc::{string::String, vec::Vec};
use opencl_sys::{cl_kernel, clReleaseKernel, clCreateKernel, clGetKernelInfo, cl_kernel_info, CL_KERNEL_FUNCTION_NAME, CL_KERNEL_NUM_ARGS, CL_KERNEL_REFERENCE_COUNT, CL_KERNEL_CONTEXT, CL_KERNEL_PROGRAM, clSetKernelArg, cl_kernel_arg_info, CL_KERNEL_ARG_ADDRESS_GLOBAL, CL_KERNEL_ARG_ADDRESS_LOCAL, CL_KERNEL_ARG_ADDRESS_CONSTANT, CL_KERNEL_ARG_ADDRESS_PRIVATE, CL_KERNEL_ARG_ADDRESS_QUALIFIER, CL_KERNEL_ARG_ACCESS_READ_ONLY, CL_KERNEL_ARG_ACCESS_WRITE_ONLY, CL_KERNEL_ARG_ACCESS_READ_WRITE, CL_KERNEL_ARG_ACCESS_NONE, CL_KERNEL_ARG_ACCESS_QUALIFIER, clGetKernelArgInfo, CL_KERNEL_ARG_NAME, CL_KERNEL_ARG_TYPE_NAME, CL_KERNEL_ARG_TYPE_CONST, CL_KERNEL_ARG_TYPE_RESTRICT, CL_KERNEL_ARG_TYPE_VOLATILE, CL_KERNEL_ARG_TYPE_QUALIFIER, clEnqueueNDRangeKernel, cl_mem, cl_kernel_arg_type_qualifier, clRetainContext, clRetainProgram};
use crate::{prelude::{Error, Program, Context, CommandQueue, BaseEvent}, error::Result, buffer::MemBuffer};

#[cfg(feature = "error-stack")]
use alloc::format;

12#[derive(PartialEq, Eq, Hash)]
13#[repr(transparent)]
14pub struct Kernel (pub(crate) cl_kernel);
15
16impl Kernel {
17    /// Creates a new kernel from a program and a name.
18    /// # Safety
19    /// It's up to the caller to ensure this is the only time the kernel is initialized
20    #[inline]
21    pub unsafe fn new_unchecked (program: &Program, name: &str) -> Result<Self> {
22        let mut name = name.as_bytes().to_vec();
23        name.push(0);
24        
25        let mut err = 0;
26        let id = clCreateKernel(program.0, name.as_ptr().cast(), &mut err);
27        if err == 0 { return Ok(Self(id)); }
28
29        cfg_if::cfg_if! {
30            if #[cfg(feature = "error-stack")] {
31                let err = Error::from(err);
32                let report = error_stack::Report::new(err);
33
34                let report = match err {
35                    Error::InvalidProgram => report.attach_printable("program is not a valid program object"),
36                    Error::InvalidProgramExecutable => report.attach_printable("there is no successfully built executable for program"),
37                    Error::InvalidKernelName => report.attach_printable("kernel name not found in program"),
38                    Error::InvalidKernelDefinition => report.attach_printable("the function definition for __kernel function given by kernel name does not exist"),
39                    Error::InvalidValue => report.attach_printable("kernel name is null"),
40                    Error::OutOfHostMemory => report.attach_printable("failure to allocate resources required by the OpenCL implementation on the host"),
41                    _ => report
42                };
43
44                Err(report)
45            } else {
46                Err(Error::from(err))
47            }
48        }
49    }
50
51    #[inline(always)]
52    pub fn set_arg<T: Copy> (&mut self, idx: u32, v: T) -> Result<()> {
53        let err = unsafe { clSetKernelArg(self.0, idx, core::mem::size_of::<T>(), addr_of!(v).cast()) };
54        self.parse_error_set_arg(err, idx, core::mem::size_of::<T>())
55    }
56
57    #[inline(always)]
58    pub fn set_mem_arg<T: Copy + Unpin> (&mut self, idx: u32, v: &MemBuffer<T>) -> Result<()> {
59        let err = unsafe { clSetKernelArg(self.0, idx, core::mem::size_of::<cl_mem>(), addr_of!(v.0).cast()) };
60        self.parse_error_set_arg(err, idx, core::mem::size_of::<cl_mem>())
61    }
62
63    #[inline(always)]
64    pub fn alloc_arg<T> (&mut self, idx: u32, len: usize) -> Result<()> {
65        let arg_size = len.checked_mul(core::mem::size_of::<T>()).expect("Kernel argument size overflow");
66        let err = unsafe { clSetKernelArg(self.0, idx, arg_size, core::ptr::null_mut()) };
67        self.parse_error_set_arg(err, idx, arg_size)
68    }
69
70    /// Return the kernel function name.
71    #[inline(always)]
72    pub fn name (&self) -> Result<String> {
73        self.get_info_string(CL_KERNEL_FUNCTION_NAME)
74    }
75
76    /// Return the number of arguments to _kernel_.
77    #[inline(always)]
78    pub fn num_args (&self) -> Result<u32> {
79        self.get_info(CL_KERNEL_NUM_ARGS)
80    }
81
82    /// Return the _kernel_ reference count.
83    #[inline(always)]
84    pub fn reference_count (&self) -> Result<u32> {
85        self.get_info(CL_KERNEL_REFERENCE_COUNT)
86    }
87
88    /// Return the context associated with _kernel_.
89    #[inline(always)]
90    pub fn context (&self) -> Result<Context> {
91        self.get_info(CL_KERNEL_CONTEXT)
92    }
93
94    /// Return the program object associated with _kernel_.
95    #[inline(always)]
96    pub fn program (&self) -> Result<Program> {
97        self.get_info(CL_KERNEL_PROGRAM)
98    }
99
100    /// Returns the address qualifier specified for the argument given by ```idx```.
101    #[inline(always)]
102    pub fn arg_address_qualifier (&self, idx: u32) -> Result<AddrQualifier> {
103        self.get_arg_info(CL_KERNEL_ARG_ADDRESS_QUALIFIER, idx)
104    }
105
106    /// Returns the access qualifier specified for the argument given by ```idx```.
107    #[inline(always)]
108    pub fn arg_access_qualifier (&self, idx: u32) -> Result<AccessQualifier> {
109        self.get_arg_info(CL_KERNEL_ARG_ACCESS_QUALIFIER, idx)
110    }
111
112    /// Returns the type name specified for the argument given by ```idx```.
113    #[inline(always)]
114    pub fn arg_type_name (&self, idx: u32) -> Result<String> {
115        self.get_arg_info_string(CL_KERNEL_ARG_TYPE_NAME, idx)
116    }
117
118    /// Returns the type qualifier specified for the argument given by ```idx```.
119    #[inline(always)]
120    pub fn arg_qualifier (&self, idx: u32) -> Result<String> {
121        self.get_arg_info(CL_KERNEL_ARG_TYPE_QUALIFIER, idx)
122    }
123
124    /// Returns the name specified for the argument given by ```idx```. 
125    #[inline(always)]
126    pub fn arg_name (&self, idx: u32) -> Result<String> {
127        self.get_arg_info_string(CL_KERNEL_ARG_NAME, idx)
128    }
129
130    #[cfg(feature = "def")]
131    #[inline(always)]
132    pub fn enqueue<const N: usize> (&mut self, global_dims: &[usize; N], local_dims: Option<&[usize; N]>, wait: impl IntoIterator<Item = impl AsRef<BaseEvent>>) -> Result<BaseEvent> {
133        self.enqueue_with_queue(CommandQueue::default(), global_dims, local_dims, wait)
134    }
135
136    pub fn enqueue_with_queue<const N: usize> (&mut self, queue: &CommandQueue, global_dims: &[usize; N], local_dims: Option<&[usize; N]>, wait: impl IntoIterator<Item = impl AsRef<BaseEvent>>) -> Result<BaseEvent> {        
137        let dim_len = u32::try_from(N).expect("Too many work dimensions");
138        let local_dims = match local_dims {
139            Some(x) => x.as_ptr(),
140            None => core::ptr::null()
141        };
142
143        let wait = wait.into_iter().map(|x| x.as_ref().0).collect::<Vec<_>>();
144        let wait_len = u32::try_from(wait.len()).unwrap();
145        let wait = match wait_len {
146            0 => core::ptr::null(),
147            _ => wait.as_ptr()
148        };
149
150        let mut event = core::ptr::null_mut();
151        let err = unsafe {
152            clEnqueueNDRangeKernel(queue.0, self.0, dim_len, core::ptr::null_mut(), global_dims.as_ptr(), local_dims, wait_len, wait, &mut event)
153        };
154
155        if err == 0 { return BaseEvent::new(event); }
156
157        cfg_if::cfg_if! {
158            if #[cfg(feature = "error-stack")] {
159                let err = Error::from(err);
160                let report = error_stack::Report::new(err);
161
162                let report = match err {
163                    Error::InvalidProgramExecutable => report.attach_printable("there is no successfully built program executable available for device associated with the command queue"),
164                    Error::InvalidCommandQueue => report.attach_printable("command queue is not a valid command-queue."),
165                    Error::InvalidKernel => report.attach_printable("kernel is not a valid kernel object"),
166                    Error::InvalidContext => report.attach_printable("context associated with the command queue and kernel is not the same or the context associated with command queue and events in the event wait list are not the same"),
167                    Error::InvalidKernelArgs => report.attach_printable("the kernel argument values have not been specified"),
168                    Error::InvalidWorkDimension => report.attach_printable("work-dimension is not a valid value (i.e. a value between 1 and 3)"),
169                    Error::InvalidWorkGroupSize => report.attach_printable("local work size is specified and is invalid (i.e. specified values in local work size exceed the maximum size of workgroup for the device associated with queue)"),
170                    Error::InvalidGlobalOffset => report.attach_printable("global work offset is not NULL"),
171                    Error::OutOfResources => report.attach_printable("there is a failure to queue the execution instance of kernel on the command-queue because of insufficient resources needed to execute the kernel"),
172                    Error::MemObjectAllocationFailure => report.attach_printable("there is a failure to allocate memory for data store associated with image or buffer objects specified as arguments to kernel"),
173                    Error::InvalidEventWaitList => report.attach_printable("event objects in event wait list are not valid events"),
174                    Error::OutOfHostMemory => report.attach_printable("failure to allocate resources required by the OpenCL implementation on the host"),
175                    _ => report
176                };
177
178                Err(report)
179            } else {
180                Err(Error::from(err))
181            }
182        }
183    }
184
185    #[inline]
186    fn get_info_string (&self, ty: cl_kernel_info) -> Result<String> {
187        unsafe {
188            let mut len = 0;
189            let err = clGetKernelInfo(self.0, ty, 0, core::ptr::null_mut(), &mut len);
190            self.parse_error(err, ty, 0)?;
191
192            let mut result = Vec::<u8>::with_capacity(len);
193            let err = clGetKernelInfo(self.0, ty, len, result.as_mut_ptr().cast(), core::ptr::null_mut());
194            self.parse_error(err, ty, len)?;
195
196            result.set_len(len - 1);
197            Ok(String::from_utf8(result).unwrap())
198        }
199    }
200
201    #[inline]
202    fn get_info<T> (&self, ty: cl_kernel_info) -> Result<T> {
203        let mut value = MaybeUninit::<T>::uninit();
204        
205        unsafe {
206            let err = clGetKernelInfo(self.0, ty, core::mem::size_of::<T>(), value.as_mut_ptr().cast(), core::ptr::null_mut());
207            self.parse_error(err, ty, core::mem::size_of::<T>())?;
208            Ok(value.assume_init())
209        }
210    }
211
212    #[inline]
213    fn get_arg_info_string (&self, ty: cl_kernel_arg_info, idx: u32) -> Result<String> {
214        unsafe {
215            let mut len = 0;
216            let err = clGetKernelArgInfo(self.0, idx, ty, 0, core::ptr::null_mut(), &mut len);
217            self.parse_error_arg(err, ty, 0)?;
218
219            let mut result = Vec::<u8>::with_capacity(len);
220            let err = clGetKernelArgInfo(self.0, idx, ty, len, result.as_mut_ptr().cast(), core::ptr::null_mut());
221            self.parse_error_arg(err, ty, len)?;
222            
223            result.set_len(len - 1);
224            Ok(String::from_utf8(result).unwrap())
225        }
226    }
227
228    #[inline]
229    fn get_arg_info<T> (&self, ty: cl_kernel_arg_info, idx: u32) -> Result<T> {
230        let mut value = MaybeUninit::<T>::uninit();
231        
232        unsafe {
233            let err = clGetKernelArgInfo(self.0, idx, ty, core::mem::size_of::<T>(), value.as_mut_ptr().cast(), core::ptr::null_mut());
234            self.parse_error_arg(err, ty, core::mem::size_of::<T>())?;
235            Ok(value.assume_init())
236        }
237    }
238
239    #[allow(unused_variables)]
240    fn parse_error (&self, err: i32, ty: cl_kernel_info, size: usize) -> Result<()> {
241        if err == 0 {
242            return Ok(());
243        }
244
245        cfg_if::cfg_if! {
246            if #[cfg(feature = "error-stack")] {
247                let err = Error::from(err);
248                let report = error_stack::Report::new(err);
249
250                let report = match err {
251                    Error::InvalidKernel => report.attach_printable(format!("'{:?}' is not a valid kernel", self.0)),
252                    Error::InvalidValue => report.attach_printable(format!("'{ty}' is not one of the supported values or size in bytes specified by {size} is < size of return type as specified in the table above and '{ty}' is not a NULL value")),
253                    Error::OutOfResources => report.attach_printable("failure to allocate resources required by the OpenCL implementation on the device"),
254                    Error::OutOfHostMemory => report.attach_printable("failure to allocate resources required by the OpenCL implementation on the host"),
255                    _ => report
256                };
257
258                Err(report)
259            } else {
260                Err(Error::from(err))
261            }
262        }
263    }
264
265    #[allow(unused_variables)]
266    fn parse_error_arg (&self, err: i32, ty: cl_kernel_arg_info, size: usize) -> Result<()> {
267        if err == 0 {
268            return Ok(());
269        }
270
271        cfg_if::cfg_if! {
272            if #[cfg(feature = "error-stack")] {
273                let err = Error::from(err);
274                let report = error_stack::Report::new(err);
275
276                let report = match err {
277                    Error::InvalidKernel => report.attach_printable(format!("'{:?}' is not a valid kernel", self.0)),
278                    Error::InvalidValue => report.attach_printable(format!("'{ty}' is not one of the supported values or size in bytes specified by {size} is < size of return type as specified in the table above and '{ty}' is not a NULL value")),
279                    Error::OutOfResources => report.attach_printable("failure to allocate resources required by the OpenCL implementation on the device"),
280                    Error::OutOfHostMemory => report.attach_printable("failure to allocate resources required by the OpenCL implementation on the host"),
281                    _ => report
282                };
283
284                Err(report)
285            } else {
286                Err(Error::from(err))
287            }
288        }
289    }
290
291    #[allow(unused_variables)]
292    fn parse_error_set_arg (&self, err: i32, idx: u32, size: usize) -> Result<()> {
293        if err == 0 {
294            return Ok(());
295        }
296
297        cfg_if::cfg_if! {
298            if #[cfg(feature = "error-stack")] {
299                let err = Error::from(err);
300                let report = error_stack::Report::new(err);
301
302                let report = match err {
303                    Error::InvalidKernel => report.attach_printable(format!("'{:?}' is not a valid kernel", self.0)),
304                    Error::InvalidArgIndex => report.attach_printable(format!("'{idx}' is not a valid argument index")),
305                    Error::InvalidArgValue => report.attach_printable("arg value specified is NULL for an argument that is not declared with the __local qualifier or vice-versa"),
306                    Error::InvalidMemObject => report.attach_printable("arg value is not a valid memory object for an argument declared to be a memory object"),
307                    Error::InvalidSampler => report.attach_printable("arg value is not a valid sampler for an argument declared to be a sampler"),
308                    Error::InvalidArgSize => report.attach_printable(format!("{size} != size of the declared data type")),
309                    _ => report
310                };
311
312                Err(report)
313            } else {
314                Err(Error::from(err))
315            }
316        }
317    }
318}
319
/*impl Clone for Kernel {
    #[inline(always)]
    fn clone(&self) -> Self {
        unsafe {
            tri_panic!(clRetainKernel(self.0))
        }
        
        Self(self.0)
    }
}*/

331impl Drop for Kernel {
332    #[inline(always)]
333    fn drop(&mut self) {
334        unsafe {
335            tri_panic!(clReleaseKernel(self.0))
336        }
337    }
338}
339
340unsafe impl Send for Kernel {}
341unsafe impl Sync for Kernel {}
342
343#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
344#[repr(u32)]
345pub enum AddrQualifier {
346    Global = CL_KERNEL_ARG_ADDRESS_GLOBAL,
347    Local = CL_KERNEL_ARG_ADDRESS_LOCAL,
348    Constant = CL_KERNEL_ARG_ADDRESS_CONSTANT,
349    Private = CL_KERNEL_ARG_ADDRESS_PRIVATE
350}
351
352impl Default for AddrQualifier {
353    #[inline(always)]
354    fn default() -> Self {
355        Self::Private
356    }
357}
358
359#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
360#[repr(u32)]
361pub enum AccessQualifier {
362    ReadOnly = CL_KERNEL_ARG_ACCESS_READ_ONLY,
363    WriteOnly = CL_KERNEL_ARG_ACCESS_WRITE_ONLY,
364    ReadWrite = CL_KERNEL_ARG_ACCESS_READ_WRITE,
365    None = CL_KERNEL_ARG_ACCESS_NONE
366}
367
368bitflags::bitflags! {
369    #[repr(transparent)]
370    pub struct TypeQualifier: cl_kernel_arg_type_qualifier {
371        const CONST = CL_KERNEL_ARG_TYPE_CONST;
372        const RESTRICT = CL_KERNEL_ARG_TYPE_RESTRICT;
373        const VOLATILE = CL_KERNEL_ARG_TYPE_VOLATILE;
374    }
375}
376
377impl Default for TypeQualifier {
378    #[inline(always)]
379    fn default() -> Self {
380        Self::empty()
381    }
382}