open_cl_low_level/
kernel.rs

1use std::fmt::Debug;
2
3use libc::{c_void};
4
5use crate::cl_helpers::cl_get_info5;
6use crate::ffi::{
7    clCreateKernel, clGetKernelInfo, clSetKernelArg, cl_context, cl_kernel, cl_kernel_info, cl_mem,
8    cl_program, cl_uint,
9};
10use crate::{
11    build_output, strings, ClContext, ClMem, ClPointer, ClProgram,
12    CommandQueueOptions, Dims, KernelInfo, MemPtr, Output, ProgramPtr, Work,
13    ObjectWrapper
14};
15
/// A value that can be handed to `clSetKernelArg` as a kernel argument.
///
/// # Safety
/// Implementors must guarantee that the pointers returned by `kernel_arg_ptr` and
/// `kernel_arg_mut_ptr` address at least `kernel_arg_size()` readable bytes, laid out
/// exactly as the matching OpenCL kernel parameter expects. The pointers borrow from
/// `self` and are only valid while `self` is alive and not moved.
pub unsafe trait KernelArg {
    /// Size in bytes of the argument: `size_of::<T>()` for a scalar, or
    /// `size_of::<T>() * len` for an array-like value.
    fn kernel_arg_size(&self) -> usize;

    /// Returns a pointer to the argument's bytes.
    ///
    /// # Safety
    /// The pointer borrows from `self`; it must not be used after `self` is moved or dropped.
    unsafe fn kernel_arg_ptr(&self) -> *const c_void;

    /// Returns a mutable pointer to the argument's bytes.
    ///
    /// # Safety
    /// The pointer borrows from `self`; it must not be used after `self` is moved or
    /// dropped, and it may only be written through if the implementor says so.
    unsafe fn kernel_arg_mut_ptr(&mut self) -> *mut c_void;
}
22
23unsafe impl KernelArg for ClMem {
24    fn kernel_arg_size(&self) -> usize {
25        std::mem::size_of::<cl_mem>()
26    }
27    unsafe fn kernel_arg_ptr(&self) -> *const c_void {
28        self.mem_ptr_ref() as *const _ as *const c_void
29    }
30
31    unsafe fn kernel_arg_mut_ptr(&mut self) -> *mut c_void {
32        self.mem_ptr_ref() as *const _ as *mut c_void
33    }
34}
35
// Implements `KernelArg` for a plain scalar type. The argument's bytes are the value
// itself, so the size is `size_of::<Self>()` and both pointers are the address of `self`.
macro_rules! sized_scalar_kernel_arg {
    ($scalar:ty) => {
        unsafe impl KernelArg for $scalar {
            fn kernel_arg_size(&self) -> usize {
                std::mem::size_of::<Self>()
            }

            unsafe fn kernel_arg_ptr(&self) -> *const c_void {
                let this: *const Self = self;
                this as *const c_void
            }

            unsafe fn kernel_arg_mut_ptr(&mut self) -> *mut c_void {
                let this: *mut Self = self;
                this as *mut c_void
            }
        }
    };
}
53
54sized_scalar_kernel_arg!(isize);
55sized_scalar_kernel_arg!(i8);
56sized_scalar_kernel_arg!(i16);
57sized_scalar_kernel_arg!(i32);
58sized_scalar_kernel_arg!(i64);
59
60sized_scalar_kernel_arg!(usize);
61sized_scalar_kernel_arg!(u8);
62sized_scalar_kernel_arg!(u16);
63sized_scalar_kernel_arg!(u32);
64sized_scalar_kernel_arg!(u64);
65
66sized_scalar_kernel_arg!(f32);
67sized_scalar_kernel_arg!(f64);
68
69// pub use kernel_arg::{KernelArg, KernelArgSizeAndPointer};
70// use super::kernel_arg::{KernelArg, KernelArgSizeAndPointer};
71// use super::{Kernel, KernelError, KernelLock, KernelPtr, KernelRefCountWithProgram};
72
73pub unsafe fn cl_set_kernel_arg<T: KernelArg>(
74    kernel: cl_kernel,
75    arg_index: usize,
76    arg: &T,
77) -> Output<()> {
78    cl_set_kernel_arg_raw(
79        kernel,
80        arg_index as cl_uint,
81        arg.kernel_arg_size(),
82        arg.kernel_arg_ptr()
83    )
84}
85
86pub unsafe fn cl_set_kernel_arg_raw(
87    kernel: cl_kernel,
88    arg_index: cl_uint,
89    arg_size: usize,
90    arg_ptr: *const c_void,
91) -> Output<()> {
92    let err_code = clSetKernelArg(
93        kernel,
94        arg_index as cl_uint,
95        arg_size,
96        arg_ptr,
97    );
98
99    build_output((), err_code)
100}
101
102
103pub unsafe fn cl_create_kernel(program: cl_program, name: &str) -> Output<cl_kernel> {
104    let c_name = strings::to_c_string(name)
105        .ok_or_else(|| KernelError::CStringInvalidKernelName(name.to_string()))?;
106    let mut err_code = 0;
107    let kernel: cl_kernel = clCreateKernel(program, c_name.as_ptr(), &mut err_code);
108    build_output(kernel, err_code)
109}
110
111pub unsafe fn cl_get_kernel_info<T: Copy>(
112    kernel: cl_kernel,
113    flag: cl_kernel_info,
114) -> Output<ClPointer<T>> {
115    cl_get_info5(kernel, flag, clGetKernelInfo)
116}
117
/// An error related to a `Kernel`.
#[derive(Debug, Fail, PartialEq, Eq, Clone)]
pub enum KernelError {
    /// The kernel name could not be converted to a C string; carries the name.
    /// Returned by `cl_create_kernel` when `strings::to_c_string` rejects it.
    #[fail(
        display = "The kernel name '{}' could not be represented as a CString.",
        _0
    )]
    CStringInvalidKernelName(String),

    /// A `KernelOperation` has no `Work` set; returned by `KernelOperation::work`.
    #[fail(display = "Work is required for kernel operation.")]
    WorkIsRequired,

    /// An argument index was out of range; the fields are `(index, argc)`.
    /// NOTE(review): not constructed in this file — confirm where it is raised.
    #[fail(
        display = "Returning arg index was out of range for kernel operation - index: {:?}, argc: {:?}",
        _0, _1
    )]
    ReturningArgIndexOutOfRange(usize, usize),

    /// NOTE(review): not constructed in this file. Presumably raised when an operation
    /// argument expected to be a memory object is not one.
    #[fail(display = "The KernelOpArg was not a mem object type.")]
    KernelOpArgWasNotMem,

    /// NOTE(review): not constructed in this file. Presumably raised when an operation
    /// argument expected to be numeric is not one.
    #[fail(display = "The KernelOpArg was not a num type.")]
    KernelOpArgWasNotNum,
}
142
143pub unsafe trait KernelPtr: Sized {
144    unsafe fn kernel_ptr(&self) -> cl_kernel;
145
146    unsafe fn info<T: Copy>(&self, flag: KernelInfo) -> Output<ClPointer<T>> {
147        cl_get_kernel_info(self.kernel_ptr(), flag.into())
148    }
149
150    unsafe fn function_name(&self) -> Output<String> {
151        self.info(KernelInfo::FunctionName)
152            .map(|ret| ret.into_string())
153    }
154
155    /// Returns the number of args for a kernel.
156    unsafe fn num_args(&self) -> Output<u32> {
157        self.info(KernelInfo::NumArgs).map(|ret| ret.into_one())
158    }
159
160    /// Returns the OpenCL reference count of the kernel.
161    unsafe fn reference_count(&self) -> Output<u32> {
162        self.info(KernelInfo::ReferenceCount)
163            .map(|ret| ret.into_one())
164    }
165
166    unsafe fn context(&self) -> Output<ClContext> {
167        self.info::<cl_context>(KernelInfo::Context)
168            .and_then(|cl_ptr| ClContext::retain_new(cl_ptr.into_one()))
169    }
170
171    unsafe fn program(&self) -> Output<ClProgram> {
172        self.info::<cl_program>(KernelInfo::Program)
173            .and_then(|cl_ptr| ClProgram::retain_new(cl_ptr.into_one()))
174    }
175
176    unsafe fn attributes(&self) -> Output<String> {
177        self.info(KernelInfo::Attributes)
178            .map(|ret| ret.into_string())
179    }
180
181    // // OpenCL v2.0
182    // fn max_num_sub_groups(&self) -> Output<String> {
183    //     self.info(KernelInfo::MaxNumSubGroups).map(|ret| ret.to_string())
184    // }
185    // fn compile_num_sub_groups(&self) -> Output<String> {
186    //     self.info(KernelInfo::CompileNumSubGroups).map(|ret| ret.to_string())
187    // }
188}
189
/// A `cl_kernel` handle held in the crate's generic `ObjectWrapper`.
pub type ClKernel = ObjectWrapper<cl_kernel>;
191
192impl ClKernel {
193    /// Creates a wrapped cl_kernel object.
194    ///
195    /// # Safety
196    /// Calling this function with an invalid ClProgram is undefined behavior.
197    pub unsafe fn create(program: &ClProgram, name: &str) -> Output<ClKernel> {
198        cl_create_kernel(program.program_ptr(), name).and_then(|object| ClKernel::new(object))
199    }
200
201    /// Set adds and arg to a kernel at a given index.
202    ///
203    /// # Safety
204    /// Calling this function on invalid kernel or with invalid `arg` is undefined behavior.
205    pub unsafe fn set_arg<T: KernelArg>(&mut self, arg_index: usize, arg: &mut T) -> Output<()> {
206        cl_set_kernel_arg(self.kernel_ptr(), arg_index, arg)
207    }
208
209    pub unsafe fn set_arg_raw(&mut self, arg_index: u32, arg_size: usize, arg_ptr: *const c_void) -> Output<()> {
210        cl_set_kernel_arg_raw(self.kernel_ptr(), arg_index, arg_size, arg_ptr)
211    }
212}
213
// SAFETY: `kernel_ptr` returns the `cl_kernel` stored in this wrapper via `cl_object`.
// NOTE(review): soundness relies on `ObjectWrapper` keeping that handle valid for the
// wrapper's whole lifetime — confirm in `ObjectWrapper`.
unsafe impl KernelPtr for ClKernel {
    unsafe fn kernel_ptr(&self) -> cl_kernel {
        self.cl_object()
    }
}
219
/// Builder-style description of a kernel operation: the kernel name, its recorded
/// arguments, its work size, and optional command-queue options.
///
/// NOTE(review): `add_arg` records raw `(size, pointer)` pairs and nothing ties those
/// pointers' lifetimes to the original values. Every added argument must outlive any
/// use of this operation's `args()`.
pub struct KernelOperation {
    /// Name of the `__kernel` function.
    _name: String,
    /// `(byte size, pointer)` for each argument, in the order they were added.
    _args: Vec<(usize, *const c_void)>,
    /// Work size; `work()` fails with `KernelError::WorkIsRequired` while this is `None`.
    _work: Option<Work>,
    /// Optional command-queue options for this operation.
    pub command_queue_opts: Option<CommandQueueOptions>,
}
226
227impl KernelOperation{
228    pub fn new(name: &str) -> KernelOperation {
229        KernelOperation {
230            _name: name.to_owned(),
231            _args: vec![],
232            _work: None,
233            command_queue_opts: None,
234        }
235    }
236
237    pub fn name(&self) -> &str {
238        &self._name[..]
239    }
240
241    pub fn command_queue_opts(&self) -> Option<CommandQueueOptions> {
242        self.command_queue_opts.clone()
243    }
244
245    pub fn args(&self) -> &[(usize, *const c_void)] {
246        &self._args[..]
247    }
248
249    pub fn mut_args(&mut self) -> &mut [(usize, *const c_void)] {
250        &mut self._args[..]
251    }
252
253    pub fn with_dims<D: Into<Dims>>(mut self, dims: D) -> KernelOperation {
254        self._work = Some(Work::new(dims.into()));
255        self
256    }
257
258    pub fn with_work<W: Into<Work>>(mut self, work: W) -> KernelOperation {
259        self._work = Some(work.into());
260        self
261    }
262
263    pub fn add_arg<A>(mut self, arg: &mut A) -> KernelOperation where A: KernelArg {
264        self._args.push((arg.kernel_arg_size(), unsafe { arg.kernel_arg_ptr() }));
265        self
266    }
267
268    pub fn with_command_queue_options(mut self, opts: CommandQueueOptions) -> KernelOperation {
269        self.command_queue_opts = Some(opts);
270        self
271    }
272
273    pub fn argc(&self) -> usize {
274        self._args.len()
275    }
276
277    #[inline]
278    pub fn work(&self) -> Output<Work> {
279        self._work
280            .clone()
281            .ok_or_else(|| KernelError::WorkIsRequired.into())
282    }
283}
284
#[cfg(test)]
mod tests {
    // These tests build real OpenCL programs through `ll_testing` / `SessionBuilder`,
    // so they appear to need an OpenCL platform and device at runtime.
    use crate::ffi::*;
    use crate::*;
    use libc::c_void;

    const SRC: &str = "
    __kernel void test123(__global int *i) {
        *i += 1;
    }";

    const KERNEL_NAME: &str = "test123";

    #[test]
    fn kernel_can_be_created() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        let _kernel: ClKernel = unsafe { ClKernel::create(&program, KERNEL_NAME).unwrap() };
    }

    #[test]
    fn kernel_function_name_works() {
        let (_context, _devices, _program, kernel) = ll_testing::get_kernel(SRC, KERNEL_NAME);
        let function_name = unsafe { kernel.function_name().unwrap() };
        assert_eq!(function_name, KERNEL_NAME);
    }

    #[test]
    fn kernel_num_args_works() {
        let (_context, _devices, _program, kernel) = ll_testing::get_kernel(SRC, KERNEL_NAME);
        let num_args = unsafe { kernel.num_args().unwrap() };
        assert_eq!(num_args, 1);
    }

    #[test]
    fn kernel_reference_count_works() {
        let (_context, _devices, _program, kernel) = ll_testing::get_kernel(SRC, KERNEL_NAME);
        let ref_count = unsafe { kernel.reference_count().unwrap() };
        assert_eq!(ref_count, 1);
    }

    #[test]
    fn kernel_context_works() {
        let (orig_context, _devices, _program, kernel) = ll_testing::get_kernel(SRC, KERNEL_NAME);
        let context: ClContext = unsafe { kernel.context().unwrap() };
        assert_eq!(context, orig_context);
    }

    #[test]
    fn kernel_program_works() {
        let (_context, _devices, orig_program, kernel) = ll_testing::get_kernel(SRC, KERNEL_NAME);
        let program: ClProgram = unsafe { kernel.program().unwrap() };
        assert_eq!(program, orig_program);
    }

    #[test]
    fn kernel_attributes_works() {
        let (_context, _devices, _program, kernel) = ll_testing::get_kernel(SRC, KERNEL_NAME);
        let _attributes: String = unsafe { kernel.attributes().unwrap() };
    }

    #[test]
    fn kernel_set_args_works_for_u8_scalar() {
        let src: &str = "
        __kernel void test123(uchar i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1u8 as cl_uchar;
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_i8_scalar() {
        let src: &str = "
        __kernel void test123(char i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1i8 as cl_char;
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_u16_scalar() {
        let src: &str = "
        __kernel void test123(ushort i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1u16 as cl_ushort;
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_i16_scalar() {
        let src: &str = "
        __kernel void test123(short i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        // Signed type, so this exercises the `i16` impl rather than the `u16` one.
        let mut arg1 = 1i16 as cl_short;
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_u32_scalar() {
        let src: &str = "
        __kernel void test123(uint i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1u32 as cl_uint;
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_i32_scalar() {
        let src: &str = "
        __kernel void test123(int i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        // Signed type, so this exercises the `i32` impl rather than the `u32` one.
        let mut arg1 = 1i32 as cl_int;
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_f32_scalar() {
        let src: &str = "
        __kernel void test123(float i) {
            i + 1.0;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1.0f32 as cl_float;
        assert_eq!(std::mem::size_of::<cl_float>(), 4);
        assert_eq!(std::mem::size_of::<f32>(), std::mem::size_of::<cl_float>());
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_u64_scalar() {
        // Integer literal: `ulong + 1.0` would promote to a floating-point type.
        let src: &str = "
        __kernel void test123(ulong i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1u64 as cl_ulong;
        assert_eq!(std::mem::size_of::<u64>(), std::mem::size_of::<cl_ulong>());
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_i64_scalar() {
        // Integer literal: `long + 1.0` would promote to a floating-point type.
        let src: &str = "
        __kernel void test123(long i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1i64 as cl_long;
        assert_eq!(std::mem::size_of::<i64>(), std::mem::size_of::<cl_long>());
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    // NOTE(review): `double` kernel parameters require a device with fp64 support.
    #[test]
    fn kernel_set_arg_works_for_f64_scalar() {
        let src: &str = "
        __kernel void test123(double i) {
            i + 1.0;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1.0f64 as cl_double;
        assert_eq!(std::mem::size_of::<f64>(), std::mem::size_of::<cl_double>());
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    fn build_session(src: &str) -> Session {
        unsafe { SessionBuilder::new().with_program_src(src).build().unwrap() }
    }

    #[test]
    fn kernel_set_arg_works_for_ffi_call() {
        unsafe {
            let src: &str = "
            __kernel void test123(__global uchar *i) {
                *i += 1;
            }";

            let session = build_session(src);
            let kernel = session.create_kernel("test123").unwrap();

            let data = vec![0u8, 0u8];
            let mem1 = session.create_mem(&data[..]).unwrap();
            // Pointer to the `cl_mem` handle itself; the temporary lives for the whole block.
            let handle_ptr = &mem1.mem_ptr() as *const _ as *const c_void;
            let err = clSetKernelArg(
                kernel.kernel_ptr(),
                0,
                std::mem::size_of::<cl_mem>(),
                handle_ptr,
            );
            assert_eq!(err, 0);
        }
    }

    #[test]
    fn kernel_set_arg_works_for_buffer_u8() {
        unsafe {
            let src: &str = "
            __kernel void test123(__global uchar *i) {
                *i += 1;
            }";

            let session = build_session(src);
            let mut kernel = session.create_kernel("test123").unwrap();

            let data = vec![0u8, 0u8];
            let mut mem1 = session.create_mem(&data[..]).unwrap();
            assert_eq!(mem1.len().unwrap(), 2);
            let () = kernel.set_arg(0, &mut mem1).unwrap();
        }
    }
}