1use std::fmt::Debug;
2
3use libc::{c_void};
4
5use crate::cl_helpers::cl_get_info5;
6use crate::ffi::{
7 clCreateKernel, clGetKernelInfo, clSetKernelArg, cl_context, cl_kernel, cl_kernel_info, cl_mem,
8 cl_program, cl_uint,
9};
10use crate::{
11 build_output, strings, ClContext, ClMem, ClPointer, ClProgram,
12 CommandQueueOptions, Dims, KernelInfo, MemPtr, Output, ProgramPtr, Work,
13 ObjectWrapper
14};
15
/// A value that can be bound to an OpenCL kernel argument with `clSetKernelArg`.
///
/// OpenCL reads `kernel_arg_size()` bytes starting at the address returned by
/// `kernel_arg_ptr()` (scalars are passed by value; buffers by their `cl_mem`
/// handle).
///
/// # Safety
/// Implementors must report a size that exactly matches the pointed-to data, and
/// the returned pointers must be valid for reads of that many bytes while `self`
/// is alive and not moved.
pub unsafe trait KernelArg {
    /// Number of bytes OpenCL should read for this argument.
    fn kernel_arg_size(&self) -> usize;

    /// Address of the argument data (`arg_value` for `clSetKernelArg`).
    ///
    /// # Safety
    /// The pointer is only meaningful while `self` is alive and not moved.
    unsafe fn kernel_arg_ptr(&self) -> *const c_void;

    /// Mutable-receiver counterpart of [`KernelArg::kernel_arg_ptr`].
    ///
    /// # Safety
    /// Same validity rules as `kernel_arg_ptr`.
    unsafe fn kernel_arg_mut_ptr(&mut self) -> *mut c_void;
}
22
23unsafe impl KernelArg for ClMem {
24 fn kernel_arg_size(&self) -> usize {
25 std::mem::size_of::<cl_mem>()
26 }
27 unsafe fn kernel_arg_ptr(&self) -> *const c_void {
28 self.mem_ptr_ref() as *const _ as *const c_void
29 }
30
31 unsafe fn kernel_arg_mut_ptr(&mut self) -> *mut c_void {
32 self.mem_ptr_ref() as *const _ as *mut c_void
33 }
34}
35
36macro_rules! sized_scalar_kernel_arg {
37 ($scalar:ty) => {
38 unsafe impl KernelArg for $scalar {
39 fn kernel_arg_size(&self) -> usize {
40 std::mem::size_of::<$scalar>()
41 }
42
43 unsafe fn kernel_arg_ptr(&self) -> *const c_void {
44 (self as *const $scalar) as *const c_void
45 }
46
47 unsafe fn kernel_arg_mut_ptr(&mut self) -> *mut c_void {
48 (self as *mut $scalar) as *mut c_void
49 }
50 }
51 };
52}
53
54sized_scalar_kernel_arg!(isize);
55sized_scalar_kernel_arg!(i8);
56sized_scalar_kernel_arg!(i16);
57sized_scalar_kernel_arg!(i32);
58sized_scalar_kernel_arg!(i64);
59
60sized_scalar_kernel_arg!(usize);
61sized_scalar_kernel_arg!(u8);
62sized_scalar_kernel_arg!(u16);
63sized_scalar_kernel_arg!(u32);
64sized_scalar_kernel_arg!(u64);
65
66sized_scalar_kernel_arg!(f32);
67sized_scalar_kernel_arg!(f64);
68
69pub unsafe fn cl_set_kernel_arg<T: KernelArg>(
74 kernel: cl_kernel,
75 arg_index: usize,
76 arg: &T,
77) -> Output<()> {
78 cl_set_kernel_arg_raw(
79 kernel,
80 arg_index as cl_uint,
81 arg.kernel_arg_size(),
82 arg.kernel_arg_ptr()
83 )
84}
85
86pub unsafe fn cl_set_kernel_arg_raw(
87 kernel: cl_kernel,
88 arg_index: cl_uint,
89 arg_size: usize,
90 arg_ptr: *const c_void,
91) -> Output<()> {
92 let err_code = clSetKernelArg(
93 kernel,
94 arg_index as cl_uint,
95 arg_size,
96 arg_ptr,
97 );
98
99 build_output((), err_code)
100}
101
102
103pub unsafe fn cl_create_kernel(program: cl_program, name: &str) -> Output<cl_kernel> {
104 let c_name = strings::to_c_string(name)
105 .ok_or_else(|| KernelError::CStringInvalidKernelName(name.to_string()))?;
106 let mut err_code = 0;
107 let kernel: cl_kernel = clCreateKernel(program, c_name.as_ptr(), &mut err_code);
108 build_output(kernel, err_code)
109}
110
/// Runs a `clGetKernelInfo` query for `flag` on `kernel` through the shared
/// `cl_get_info5` helper, wrapping the result in a `ClPointer<T>`.
///
/// `T` is the element type the caller expects for this query (for example,
/// `KernelPtr::context` asks for `cl_context`).
///
/// # Safety
/// `kernel` must be a valid kernel handle, and `T` must match what OpenCL
/// returns for `flag`; this function does not check that.
pub unsafe fn cl_get_kernel_info<T: Copy>(
    kernel: cl_kernel,
    flag: cl_kernel_info,
) -> Output<ClPointer<T>> {
    cl_get_info5(kernel, flag, clGetKernelInfo)
}
117
/// Errors raised by the kernel helpers in this module. They are converted into
/// the crate's `Output` error type with `.into()` / `?`.
#[derive(Debug, Fail, PartialEq, Eq, Clone)]
pub enum KernelError {
    /// The kernel name could not be converted to a C string. `cl_create_kernel`
    /// returns this when `strings::to_c_string` yields `None`.
    #[fail(
        display = "The kernel name '{}' could not be represented as a CString.",
        _0
    )]
    CStringInvalidKernelName(String),

    /// `KernelOperation::work` was called before any work size was set.
    #[fail(display = "Work is required for kernel operation.")]
    WorkIsRequired,

    /// A returning-argument index was out of range. Per the display message, the
    /// fields are `(index, argc)`. This variant is not raised in the code visible
    /// here.
    #[fail(
        display = "Returning arg index was out of range for kernel operation - index: {:?}, argc: {:?}",
        _0, _1
    )]
    ReturningArgIndexOutOfRange(usize, usize),

    /// A `KernelOpArg` was expected to be a memory object but was not.
    #[fail(display = "The KernelOpArg was not a mem object type.")]
    KernelOpArgWasNotMem,

    /// A `KernelOpArg` was expected to be a numeric value but was not.
    #[fail(display = "The KernelOpArg was not a num type.")]
    KernelOpArgWasNotNum,
}
142
143pub unsafe trait KernelPtr: Sized {
144 unsafe fn kernel_ptr(&self) -> cl_kernel;
145
146 unsafe fn info<T: Copy>(&self, flag: KernelInfo) -> Output<ClPointer<T>> {
147 cl_get_kernel_info(self.kernel_ptr(), flag.into())
148 }
149
150 unsafe fn function_name(&self) -> Output<String> {
151 self.info(KernelInfo::FunctionName)
152 .map(|ret| ret.into_string())
153 }
154
155 unsafe fn num_args(&self) -> Output<u32> {
157 self.info(KernelInfo::NumArgs).map(|ret| ret.into_one())
158 }
159
160 unsafe fn reference_count(&self) -> Output<u32> {
162 self.info(KernelInfo::ReferenceCount)
163 .map(|ret| ret.into_one())
164 }
165
166 unsafe fn context(&self) -> Output<ClContext> {
167 self.info::<cl_context>(KernelInfo::Context)
168 .and_then(|cl_ptr| ClContext::retain_new(cl_ptr.into_one()))
169 }
170
171 unsafe fn program(&self) -> Output<ClProgram> {
172 self.info::<cl_program>(KernelInfo::Program)
173 .and_then(|cl_ptr| ClProgram::retain_new(cl_ptr.into_one()))
174 }
175
176 unsafe fn attributes(&self) -> Output<String> {
177 self.info(KernelInfo::Attributes)
178 .map(|ret| ret.into_string())
179 }
180
181 }
189
190pub type ClKernel = ObjectWrapper<cl_kernel>;
191
192impl ClKernel {
193 pub unsafe fn create(program: &ClProgram, name: &str) -> Output<ClKernel> {
198 cl_create_kernel(program.program_ptr(), name).and_then(|object| ClKernel::new(object))
199 }
200
201 pub unsafe fn set_arg<T: KernelArg>(&mut self, arg_index: usize, arg: &mut T) -> Output<()> {
206 cl_set_kernel_arg(self.kernel_ptr(), arg_index, arg)
207 }
208
209 pub unsafe fn set_arg_raw(&mut self, arg_index: u32, arg_size: usize, arg_ptr: *const c_void) -> Output<()> {
210 cl_set_kernel_arg_raw(self.kernel_ptr(), arg_index, arg_size, arg_ptr)
211 }
212}
213
/// Makes the `KernelPtr` queries available on `ClKernel`.
unsafe impl KernelPtr for ClKernel {
    /// Returns the raw handle held by the wrapper (via `cl_object`).
    unsafe fn kernel_ptr(&self) -> cl_kernel {
        self.cl_object()
    }
}
219
/// Builder-style description of a kernel launch: the kernel name, its arguments
/// as `(size, pointer)` pairs, the work size, and optional command-queue options.
///
/// NOTE(review): `_args` stores raw pointers captured by `add_arg` from a
/// short-lived `&mut` borrow. Nothing ties them to the argument's lifetime, so
/// callers must keep every argument alive (and unmoved) until the operation is
/// used.
pub struct KernelOperation {
    _name: String,
    // `(arg_size, arg_ptr)` pairs, in the order `add_arg` pushed them.
    _args: Vec<(usize, *const c_void)>,
    // `work()` returns `KernelError::WorkIsRequired` while this is `None`.
    _work: Option<Work>,
    pub command_queue_opts: Option<CommandQueueOptions>,
}
226
227impl KernelOperation{
228 pub fn new(name: &str) -> KernelOperation {
229 KernelOperation {
230 _name: name.to_owned(),
231 _args: vec![],
232 _work: None,
233 command_queue_opts: None,
234 }
235 }
236
237 pub fn name(&self) -> &str {
238 &self._name[..]
239 }
240
241 pub fn command_queue_opts(&self) -> Option<CommandQueueOptions> {
242 self.command_queue_opts.clone()
243 }
244
245 pub fn args(&self) -> &[(usize, *const c_void)] {
246 &self._args[..]
247 }
248
249 pub fn mut_args(&mut self) -> &mut [(usize, *const c_void)] {
250 &mut self._args[..]
251 }
252
253 pub fn with_dims<D: Into<Dims>>(mut self, dims: D) -> KernelOperation {
254 self._work = Some(Work::new(dims.into()));
255 self
256 }
257
258 pub fn with_work<W: Into<Work>>(mut self, work: W) -> KernelOperation {
259 self._work = Some(work.into());
260 self
261 }
262
263 pub fn add_arg<A>(mut self, arg: &mut A) -> KernelOperation where A: KernelArg {
264 self._args.push((arg.kernel_arg_size(), unsafe { arg.kernel_arg_ptr() }));
265 self
266 }
267
268 pub fn with_command_queue_options(mut self, opts: CommandQueueOptions) -> KernelOperation {
269 self.command_queue_opts = Some(opts);
270 self
271 }
272
273 pub fn argc(&self) -> usize {
274 self._args.len()
275 }
276
277 #[inline]
278 pub fn work(&self) -> Output<Work> {
279 self._work
280 .clone()
281 .ok_or_else(|| KernelError::WorkIsRequired.into())
282 }
283}
284
#[cfg(test)]
mod tests {
    use crate::ffi::*;
    use crate::*;
    use libc::c_void;

    const SRC: &str = "
    __kernel void test123(__global int *i) {
        *i += 1;
    }";

    const KERNEL_NAME: &str = "test123";

    #[test]
    fn kernel_can_be_created() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        let _kernel: ClKernel = unsafe { ClKernel::create(&program, KERNEL_NAME).unwrap() };
    }

    #[test]
    fn kernel_function_name_works() {
        let (_context, _devices, _program, kernel) = ll_testing::get_kernel(SRC, KERNEL_NAME);
        let function_name = unsafe { kernel.function_name().unwrap() };
        assert_eq!(function_name, KERNEL_NAME);
    }

    #[test]
    fn kernel_num_args_works() {
        let (_context, _devices, _program, kernel) = ll_testing::get_kernel(SRC, KERNEL_NAME);
        let num_args = unsafe { kernel.num_args().unwrap() };
        assert_eq!(num_args, 1);
    }

    #[test]
    fn kernel_reference_count_works() {
        let (_context, _devices, _program, kernel) = ll_testing::get_kernel(SRC, KERNEL_NAME);
        let ref_count = unsafe { kernel.reference_count().unwrap() };
        assert_eq!(ref_count, 1);
    }

    #[test]
    fn kernel_context_works() {
        let (orig_context, _devices, _program, kernel) = ll_testing::get_kernel(SRC, KERNEL_NAME);
        let context: ClContext = unsafe { kernel.context().unwrap() };
        assert_eq!(context, orig_context);
    }

    #[test]
    fn kernel_program_works() {
        let (_context, _devices, orig_program, kernel) = ll_testing::get_kernel(SRC, KERNEL_NAME);
        let program: ClProgram = unsafe { kernel.program().unwrap() };
        assert_eq!(program, orig_program);
    }

    #[test]
    fn kernel_attributes_works() {
        let (_context, _devices, _program, kernel) = ll_testing::get_kernel(SRC, KERNEL_NAME);
        let _attributes: String = unsafe { kernel.attributes().unwrap() };
    }

    #[test]
    fn kernel_set_args_works_for_u8_scalar() {
        let src: &str = "
        __kernel void test123(uchar i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1u8 as cl_uchar;
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_i8_scalar() {
        let src: &str = "
        __kernel void test123(char i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1i8 as cl_char;
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_u16_scalar() {
        let src: &str = "
        __kernel void test123(ushort i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1u16 as cl_ushort;
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_i16_scalar() {
        let src: &str = "
        __kernel void test123(short i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        // `short` is signed: use the signed alias so the i16 impl is exercised.
        let mut arg1 = 1i16 as cl_short;
        assert_eq!(std::mem::size_of::<i16>(), std::mem::size_of::<cl_short>());
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_u32_scalar() {
        let src: &str = "
        __kernel void test123(uint i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1u32 as cl_uint;
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_i32_scalar() {
        let src: &str = "
        __kernel void test123(int i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        // `int` is signed: use the signed alias so the i32 impl is exercised.
        let mut arg1 = 1i32 as cl_int;
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_f32_scalar() {
        let src: &str = "
        __kernel void test123(float i) {
            i + 1.0f;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1.0f32 as cl_float;
        assert_eq!(std::mem::size_of::<cl_float>(), 4);
        assert_eq!(std::mem::size_of::<f32>(), std::mem::size_of::<cl_float>());
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_u64_scalar() {
        // Integer literal: a `1.0` here would be a double and drags fp64 into an
        // integer-only kernel.
        let src: &str = "
        __kernel void test123(ulong i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1u64 as cl_ulong;
        assert_eq!(std::mem::size_of::<u64>(), std::mem::size_of::<cl_ulong>());
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_args_works_for_i64_scalar() {
        let src: &str = "
        __kernel void test123(long i) {
            i + 1;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1i64 as cl_long;
        assert_eq!(std::mem::size_of::<i64>(), std::mem::size_of::<cl_long>());
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    #[test]
    fn kernel_set_arg_works_for_f64_scalar() {
        let src: &str = "
        __kernel void test123(double i) {
            i + 1.0;
        }";
        let (_context, _devices, _program, mut kernel) = ll_testing::get_kernel(src, KERNEL_NAME);
        let mut arg1 = 1.0f64 as cl_double;
        assert_eq!(std::mem::size_of::<f64>(), std::mem::size_of::<cl_double>());
        let () = unsafe { kernel.set_arg(0, &mut arg1) }.unwrap();
    }

    fn build_session(src: &str) -> Session {
        unsafe { SessionBuilder::new().with_program_src(src).build().unwrap() }
    }

    #[test]
    fn kernel_set_arg_works_for_ffi_call() {
        unsafe {
            let src: &str = "
            __kernel void test123(__global uchar *i) {
                *i += 1;
            }";

            let session = build_session(src);
            let kernel = session.create_kernel("test123").unwrap();

            let data = vec![0u8, 0u8];
            let mem1 = session.create_mem(&data[..]).unwrap();
            // Bind the handle to a named local so the pointer clearly refers to a
            // value that outlives the `clSetKernelArg` call.
            let raw_mem = mem1.mem_ptr();
            let mem_ptr = &raw_mem as *const _ as *const c_void;
            let err = clSetKernelArg(
                kernel.kernel_ptr(),
                0,
                std::mem::size_of::<cl_mem>(),
                mem_ptr,
            );
            assert_eq!(err, 0);
        }
    }

    #[test]
    fn kernel_set_arg_works_for_buffer_u8() {
        unsafe {
            let src: &str = "
            __kernel void test123(__global uchar *i) {
                *i += 1;
            }";

            let session = build_session(src);
            let mut kernel = session.create_kernel("test123").unwrap();

            let data = vec![0u8, 0u8];
            let mut mem1 = session.create_mem(&data[..]).unwrap();
            assert_eq!(mem1.len().unwrap(), 2);
            let () = kernel.set_arg(0, &mut mem1).unwrap();
        }
    }
}