1#[cfg(test)]
2extern crate std;
3
4use core::{mem::MaybeUninit, ptr::addr_of};
5use alloc::{string::{String}, vec::Vec};
6use opencl_sys::{cl_kernel, clReleaseKernel, clCreateKernel, clGetKernelInfo, cl_kernel_info, CL_KERNEL_FUNCTION_NAME, CL_KERNEL_NUM_ARGS, CL_KERNEL_REFERENCE_COUNT, CL_KERNEL_CONTEXT, CL_KERNEL_PROGRAM, clSetKernelArg, cl_kernel_arg_info, CL_KERNEL_ARG_ADDRESS_GLOBAL, CL_KERNEL_ARG_ADDRESS_LOCAL, CL_KERNEL_ARG_ADDRESS_CONSTANT, CL_KERNEL_ARG_ADDRESS_PRIVATE, CL_KERNEL_ARG_ADDRESS_QUALIFIER, CL_KERNEL_ARG_ACCESS_READ_ONLY, CL_KERNEL_ARG_ACCESS_WRITE_ONLY, CL_KERNEL_ARG_ACCESS_READ_WRITE, CL_KERNEL_ARG_ACCESS_NONE, CL_KERNEL_ARG_ACCESS_QUALIFIER, clGetKernelArgInfo, CL_KERNEL_ARG_NAME, CL_KERNEL_ARG_TYPE_NAME, CL_KERNEL_ARG_TYPE_CONST, CL_KERNEL_ARG_TYPE_RESTRICT, CL_KERNEL_ARG_TYPE_VOLATILE, CL_KERNEL_ARG_TYPE_QUALIFIER, clEnqueueNDRangeKernel, cl_mem, cl_kernel_arg_type_qualifier};
7use crate::{prelude::{Error, Program, Context, CommandQueue, BaseEvent}, error::Result, buffer::MemBuffer};
8
9#[cfg(feature = "error-stack")]
10use alloc::format;
11
12#[derive(PartialEq, Eq, Hash)]
13#[repr(transparent)]
14pub struct Kernel (pub(crate) cl_kernel);
15
16impl Kernel {
17 #[inline]
21 pub unsafe fn new_unchecked (program: &Program, name: &str) -> Result<Self> {
22 let mut name = name.as_bytes().to_vec();
23 name.push(0);
24
25 let mut err = 0;
26 let id = clCreateKernel(program.0, name.as_ptr().cast(), &mut err);
27 if err == 0 { return Ok(Self(id)); }
28
29 cfg_if::cfg_if! {
30 if #[cfg(feature = "error-stack")] {
31 let err = Error::from(err);
32 let report = error_stack::Report::new(err);
33
34 let report = match err {
35 Error::InvalidProgram => report.attach_printable("program is not a valid program object"),
36 Error::InvalidProgramExecutable => report.attach_printable("there is no successfully built executable for program"),
37 Error::InvalidKernelName => report.attach_printable("kernel name not found in program"),
38 Error::InvalidKernelDefinition => report.attach_printable("the function definition for __kernel function given by kernel name does not exist"),
39 Error::InvalidValue => report.attach_printable("kernel name is null"),
40 Error::OutOfHostMemory => report.attach_printable("failure to allocate resources required by the OpenCL implementation on the host"),
41 _ => report
42 };
43
44 Err(report)
45 } else {
46 Err(Error::from(err))
47 }
48 }
49 }
50
51 #[inline(always)]
52 pub fn set_arg<T: Copy> (&mut self, idx: u32, v: T) -> Result<()> {
53 let err = unsafe { clSetKernelArg(self.0, idx, core::mem::size_of::<T>(), addr_of!(v).cast()) };
54 self.parse_error_set_arg(err, idx, core::mem::size_of::<T>())
55 }
56
57 #[inline(always)]
58 pub fn set_mem_arg<T: Copy + Unpin> (&mut self, idx: u32, v: &MemBuffer<T>) -> Result<()> {
59 let err = unsafe { clSetKernelArg(self.0, idx, core::mem::size_of::<cl_mem>(), addr_of!(v.0).cast()) };
60 self.parse_error_set_arg(err, idx, core::mem::size_of::<cl_mem>())
61 }
62
63 #[inline(always)]
64 pub fn alloc_arg<T> (&mut self, idx: u32, len: usize) -> Result<()> {
65 let arg_size = len.checked_mul(core::mem::size_of::<T>()).expect("Kernel argument size overflow");
66 let err = unsafe { clSetKernelArg(self.0, idx, arg_size, core::ptr::null_mut()) };
67 self.parse_error_set_arg(err, idx, arg_size)
68 }
69
70 #[inline(always)]
72 pub fn name (&self) -> Result<String> {
73 self.get_info_string(CL_KERNEL_FUNCTION_NAME)
74 }
75
76 #[inline(always)]
78 pub fn num_args (&self) -> Result<u32> {
79 self.get_info(CL_KERNEL_NUM_ARGS)
80 }
81
82 #[inline(always)]
84 pub fn reference_count (&self) -> Result<u32> {
85 self.get_info(CL_KERNEL_REFERENCE_COUNT)
86 }
87
88 #[inline(always)]
90 pub fn context (&self) -> Result<Context> {
91 self.get_info(CL_KERNEL_CONTEXT)
92 }
93
94 #[inline(always)]
96 pub fn program (&self) -> Result<Program> {
97 self.get_info(CL_KERNEL_PROGRAM)
98 }
99
100 #[inline(always)]
102 pub fn arg_address_qualifier (&self, idx: u32) -> Result<AddrQualifier> {
103 self.get_arg_info(CL_KERNEL_ARG_ADDRESS_QUALIFIER, idx)
104 }
105
106 #[inline(always)]
108 pub fn arg_access_qualifier (&self, idx: u32) -> Result<AccessQualifier> {
109 self.get_arg_info(CL_KERNEL_ARG_ACCESS_QUALIFIER, idx)
110 }
111
112 #[inline(always)]
114 pub fn arg_type_name (&self, idx: u32) -> Result<String> {
115 self.get_arg_info_string(CL_KERNEL_ARG_TYPE_NAME, idx)
116 }
117
118 #[inline(always)]
120 pub fn arg_qualifier (&self, idx: u32) -> Result<String> {
121 self.get_arg_info(CL_KERNEL_ARG_TYPE_QUALIFIER, idx)
122 }
123
124 #[inline(always)]
126 pub fn arg_name (&self, idx: u32) -> Result<String> {
127 self.get_arg_info_string(CL_KERNEL_ARG_NAME, idx)
128 }
129
130 #[cfg(feature = "def")]
131 #[inline(always)]
132 pub fn enqueue<const N: usize> (&mut self, global_dims: &[usize; N], local_dims: Option<&[usize; N]>, wait: impl IntoIterator<Item = impl AsRef<BaseEvent>>) -> Result<BaseEvent> {
133 self.enqueue_with_queue(CommandQueue::default(), global_dims, local_dims, wait)
134 }
135
136 pub fn enqueue_with_queue<const N: usize> (&mut self, queue: &CommandQueue, global_dims: &[usize; N], local_dims: Option<&[usize; N]>, wait: impl IntoIterator<Item = impl AsRef<BaseEvent>>) -> Result<BaseEvent> {
137 let dim_len = u32::try_from(N).expect("Too many work dimensions");
138 let local_dims = match local_dims {
139 Some(x) => x.as_ptr(),
140 None => core::ptr::null()
141 };
142
143 let wait = wait.into_iter().map(|x| x.as_ref().0).collect::<Vec<_>>();
144 let wait_len = u32::try_from(wait.len()).unwrap();
145 let wait = match wait_len {
146 0 => core::ptr::null(),
147 _ => wait.as_ptr()
148 };
149
150 let mut event = core::ptr::null_mut();
151 let err = unsafe {
152 clEnqueueNDRangeKernel(queue.0, self.0, dim_len, core::ptr::null_mut(), global_dims.as_ptr(), local_dims, wait_len, wait, &mut event)
153 };
154
155 if err == 0 { return BaseEvent::new(event); }
156
157 cfg_if::cfg_if! {
158 if #[cfg(feature = "error-stack")] {
159 let err = Error::from(err);
160 let report = error_stack::Report::new(err);
161
162 let report = match err {
163 Error::InvalidProgramExecutable => report.attach_printable("there is no successfully built program executable available for device associated with the command queue"),
164 Error::InvalidCommandQueue => report.attach_printable("command queue is not a valid command-queue."),
165 Error::InvalidKernel => report.attach_printable("kernel is not a valid kernel object"),
166 Error::InvalidContext => report.attach_printable("context associated with the command queue and kernel is not the same or the context associated with command queue and events in the event wait list are not the same"),
167 Error::InvalidKernelArgs => report.attach_printable("the kernel argument values have not been specified"),
168 Error::InvalidWorkDimension => report.attach_printable("work-dimension is not a valid value (i.e. a value between 1 and 3)"),
169 Error::InvalidWorkGroupSize => report.attach_printable("local work size is specified and is invalid (i.e. specified values in local work size exceed the maximum size of workgroup for the device associated with queue)"),
170 Error::InvalidGlobalOffset => report.attach_printable("global work offset is not NULL"),
171 Error::OutOfResources => report.attach_printable("there is a failure to queue the execution instance of kernel on the command-queue because of insufficient resources needed to execute the kernel"),
172 Error::MemObjectAllocationFailure => report.attach_printable("there is a failure to allocate memory for data store associated with image or buffer objects specified as arguments to kernel"),
173 Error::InvalidEventWaitList => report.attach_printable("event objects in event wait list are not valid events"),
174 Error::OutOfHostMemory => report.attach_printable("failure to allocate resources required by the OpenCL implementation on the host"),
175 _ => report
176 };
177
178 Err(report)
179 } else {
180 Err(Error::from(err))
181 }
182 }
183 }
184
185 #[inline]
186 fn get_info_string (&self, ty: cl_kernel_info) -> Result<String> {
187 unsafe {
188 let mut len = 0;
189 let err = clGetKernelInfo(self.0, ty, 0, core::ptr::null_mut(), &mut len);
190 self.parse_error(err, ty, 0)?;
191
192 let mut result = Vec::<u8>::with_capacity(len);
193 let err = clGetKernelInfo(self.0, ty, len, result.as_mut_ptr().cast(), core::ptr::null_mut());
194 self.parse_error(err, ty, len)?;
195
196 result.set_len(len - 1);
197 Ok(String::from_utf8(result).unwrap())
198 }
199 }
200
201 #[inline]
202 fn get_info<T> (&self, ty: cl_kernel_info) -> Result<T> {
203 let mut value = MaybeUninit::<T>::uninit();
204
205 unsafe {
206 let err = clGetKernelInfo(self.0, ty, core::mem::size_of::<T>(), value.as_mut_ptr().cast(), core::ptr::null_mut());
207 self.parse_error(err, ty, core::mem::size_of::<T>())?;
208 Ok(value.assume_init())
209 }
210 }
211
212 #[inline]
213 fn get_arg_info_string (&self, ty: cl_kernel_arg_info, idx: u32) -> Result<String> {
214 unsafe {
215 let mut len = 0;
216 let err = clGetKernelArgInfo(self.0, idx, ty, 0, core::ptr::null_mut(), &mut len);
217 self.parse_error_arg(err, ty, 0)?;
218
219 let mut result = Vec::<u8>::with_capacity(len);
220 let err = clGetKernelArgInfo(self.0, idx, ty, len, result.as_mut_ptr().cast(), core::ptr::null_mut());
221 self.parse_error_arg(err, ty, len)?;
222
223 result.set_len(len - 1);
224 Ok(String::from_utf8(result).unwrap())
225 }
226 }
227
228 #[inline]
229 fn get_arg_info<T> (&self, ty: cl_kernel_arg_info, idx: u32) -> Result<T> {
230 let mut value = MaybeUninit::<T>::uninit();
231
232 unsafe {
233 let err = clGetKernelArgInfo(self.0, idx, ty, core::mem::size_of::<T>(), value.as_mut_ptr().cast(), core::ptr::null_mut());
234 self.parse_error_arg(err, ty, core::mem::size_of::<T>())?;
235 Ok(value.assume_init())
236 }
237 }
238
239 #[allow(unused_variables)]
240 fn parse_error (&self, err: i32, ty: cl_kernel_info, size: usize) -> Result<()> {
241 if err == 0 {
242 return Ok(());
243 }
244
245 cfg_if::cfg_if! {
246 if #[cfg(feature = "error-stack")] {
247 let err = Error::from(err);
248 let report = error_stack::Report::new(err);
249
250 let report = match err {
251 Error::InvalidKernel => report.attach_printable(format!("'{:?}' is not a valid kernel", self.0)),
252 Error::InvalidValue => report.attach_printable(format!("'{ty}' is not one of the supported values or size in bytes specified by {size} is < size of return type as specified in the table above and '{ty}' is not a NULL value")),
253 Error::OutOfResources => report.attach_printable("failure to allocate resources required by the OpenCL implementation on the device"),
254 Error::OutOfHostMemory => report.attach_printable("failure to allocate resources required by the OpenCL implementation on the host"),
255 _ => report
256 };
257
258 Err(report)
259 } else {
260 Err(Error::from(err))
261 }
262 }
263 }
264
265 #[allow(unused_variables)]
266 fn parse_error_arg (&self, err: i32, ty: cl_kernel_arg_info, size: usize) -> Result<()> {
267 if err == 0 {
268 return Ok(());
269 }
270
271 cfg_if::cfg_if! {
272 if #[cfg(feature = "error-stack")] {
273 let err = Error::from(err);
274 let report = error_stack::Report::new(err);
275
276 let report = match err {
277 Error::InvalidKernel => report.attach_printable(format!("'{:?}' is not a valid kernel", self.0)),
278 Error::InvalidValue => report.attach_printable(format!("'{ty}' is not one of the supported values or size in bytes specified by {size} is < size of return type as specified in the table above and '{ty}' is not a NULL value")),
279 Error::OutOfResources => report.attach_printable("failure to allocate resources required by the OpenCL implementation on the device"),
280 Error::OutOfHostMemory => report.attach_printable("failure to allocate resources required by the OpenCL implementation on the host"),
281 _ => report
282 };
283
284 Err(report)
285 } else {
286 Err(Error::from(err))
287 }
288 }
289 }
290
291 #[allow(unused_variables)]
292 fn parse_error_set_arg (&self, err: i32, idx: u32, size: usize) -> Result<()> {
293 if err == 0 {
294 return Ok(());
295 }
296
297 cfg_if::cfg_if! {
298 if #[cfg(feature = "error-stack")] {
299 let err = Error::from(err);
300 let report = error_stack::Report::new(err);
301
302 let report = match err {
303 Error::InvalidKernel => report.attach_printable(format!("'{:?}' is not a valid kernel", self.0)),
304 Error::InvalidArgIndex => report.attach_printable(format!("'{idx}' is not a valid argument index")),
305 Error::InvalidArgValue => report.attach_printable("arg value specified is NULL for an argument that is not declared with the __local qualifier or vice-versa"),
306 Error::InvalidMemObject => report.attach_printable("arg value is not a valid memory object for an argument declared to be a memory object"),
307 Error::InvalidSampler => report.attach_printable("arg value is not a valid sampler for an argument declared to be a sampler"),
308 Error::InvalidArgSize => report.attach_printable(format!("{size} != size of the declared data type")),
309 _ => report
310 };
311
312 Err(report)
313 } else {
314 Err(Error::from(err))
315 }
316 }
317 }
318}
319
impl Drop for Kernel {
    /// Releases this wrapper's reference to the kernel via `clReleaseKernel`.
    ///
    /// NOTE(review): `tri_panic!` is defined elsewhere. Judging by its name, it presumably panics
    /// when `clReleaseKernel` returns a non-zero status — confirm. A panic in `drop` while the
    /// thread is already unwinding aborts the process.
    #[inline(always)]
    fn drop(&mut self) {
        unsafe {
            tri_panic!(clReleaseKernel(self.0))
        }
    }
}
339
// SAFETY: `Kernel` only holds an OpenCL handle.
// NOTE(review): this relies on the OpenCL spec's guarantee that API calls are thread-safe,
// except those that modify kernel state (such as `clSetKernelArg`). The argument setters and
// `enqueue*` on `Kernel` take `&mut self`, so shared `&Kernel` access only reaches the query
// methods. Confirm that no other `&self` path mutates the kernel.
unsafe impl Send for Kernel {}
unsafe impl Sync for Kernel {}
342
343#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
344#[repr(u32)]
345pub enum AddrQualifier {
346 Global = CL_KERNEL_ARG_ADDRESS_GLOBAL,
347 Local = CL_KERNEL_ARG_ADDRESS_LOCAL,
348 Constant = CL_KERNEL_ARG_ADDRESS_CONSTANT,
349 Private = CL_KERNEL_ARG_ADDRESS_PRIVATE
350}
351
352impl Default for AddrQualifier {
353 #[inline(always)]
354 fn default() -> Self {
355 Self::Private
356 }
357}
358
359#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
360#[repr(u32)]
361pub enum AccessQualifier {
362 ReadOnly = CL_KERNEL_ARG_ACCESS_READ_ONLY,
363 WriteOnly = CL_KERNEL_ARG_ACCESS_WRITE_ONLY,
364 ReadWrite = CL_KERNEL_ARG_ACCESS_READ_WRITE,
365 None = CL_KERNEL_ARG_ACCESS_NONE
366}
367
bitflags::bitflags! {
    /// Bit flags over a `cl_kernel_arg_type_qualifier` value describing an argument's type qualifiers.
    #[repr(transparent)]
    pub struct TypeQualifier: cl_kernel_arg_type_qualifier {
        /// `CL_KERNEL_ARG_TYPE_CONST`
        const CONST = CL_KERNEL_ARG_TYPE_CONST;
        /// `CL_KERNEL_ARG_TYPE_RESTRICT`
        const RESTRICT = CL_KERNEL_ARG_TYPE_RESTRICT;
        /// `CL_KERNEL_ARG_TYPE_VOLATILE`
        const VOLATILE = CL_KERNEL_ARG_TYPE_VOLATILE;
    }
}
376
377impl Default for TypeQualifier {
378 #[inline(always)]
379 fn default() -> Self {
380 Self::empty()
381 }
382}