use crate::{
context::{Context, RawContext},
core::*,
event::{
consumer::{NoopEvent, PhantomEvent},
RawEvent,
},
non_null_const,
prelude::Scope,
wait_list, WaitList,
};
use blaze_proc::docfg;
use opencl_sys::*;
use std::{
borrow::Borrow,
ffi::c_void,
marker::PhantomData,
mem::MaybeUninit,
ptr::{addr_of_mut, NonNull},
};
/// Owning wrapper around a non-null OpenCL kernel handle (`cl_kernel`).
///
/// Equality and hashing are derived from the wrapped pointer. Cloning calls
/// `clRetainKernel` on the handle and dropping calls `clReleaseKernel`.
#[derive(Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct RawKernel(NonNull<c_void>);
impl RawKernel {
#[inline(always)]
pub const fn id(&self) -> cl_kernel {
self.0.as_ptr()
}
/// Wraps a raw `cl_kernel` handle without checking it for null.
///
/// # Safety
/// `id` must be non-null (it is passed to [`NonNull::new_unchecked`]). The returned value
/// releases the handle with `clReleaseKernel` when dropped, so the caller must be entitled
/// to hand over one reference to it.
#[inline(always)]
pub const unsafe fn from_id_unchecked(id: cl_kernel) -> Self {
    let handle = NonNull::new_unchecked(id);
    Self(handle)
}
/// Wraps a raw `cl_kernel` handle, returning `None` when `non_null_const` rejects it
/// (presumably when `id` is null — confirm against `non_null_const`).
///
/// # Safety
/// The returned value releases the handle with `clReleaseKernel` when dropped, so the
/// caller must be entitled to hand over one reference to it.
#[inline(always)]
pub const unsafe fn from_id(id: cl_kernel) -> Option<Self> {
    if let Some(handle) = non_null_const(id) {
        Some(Self(handle))
    } else {
        None
    }
}
/// Calls `clRetainKernel` on the wrapped handle.
///
/// The status code is handled by `tri!` (presumably converted into an `Err` on failure —
/// confirm against the macro definition).
///
/// # Safety
/// NOTE(review): every successful call adds a reference that nothing in this method
/// balances; the caller is presumably responsible for a matching release — confirm.
#[inline(always)]
pub unsafe fn retain(&self) -> Result<()> {
tri!(clRetainKernel(self.id()));
Ok(())
}
/// Sets kernel argument `idx` to the value borrowed from `v`.
///
/// The value is passed to `clSetKernelArg` as a pointer to the borrowed `T` together with
/// its size in bytes.
///
/// # Safety
/// Nothing here checks that `T` matches the type the kernel declares for argument `idx`.
/// NOTE(review): presumably the caller must guarantee that — confirm against the OpenCL
/// specification for `clSetKernelArg`.
#[inline(always)]
pub unsafe fn set_argument<T: Copy, R: Borrow<T>>(&mut self, idx: u32, v: R) -> Result<()> {
    let value: &T = v.borrow();
    let arg_size = core::mem::size_of_val(value);
    let arg_value = (value as *const T).cast();
    tri!(clSetKernelArg(self.id(), idx, arg_size, arg_value));
    Ok(())
}
/// Sets kernel argument `idx` by forwarding `idx`, `size` and `ptr` unchanged to
/// `clSetKernelArg`.
///
/// # Safety
/// No validation is performed here. NOTE(review): presumably `ptr` must reference at least
/// `size` readable bytes, or be null where OpenCL allows it — confirm against the OpenCL
/// specification for `clSetKernelArg`.
#[inline(always)]
pub unsafe fn set_ptr_argument(
&mut self,
idx: u32,
size: usize,
ptr: *const c_void,
) -> Result<()> {
tri!(clSetKernelArg(self.id(), idx, size, ptr));
Ok(())
}
#[inline(always)]
pub unsafe fn allocate_argument(&mut self, idx: u32, size: usize) -> Result<()> {
self.set_ptr_argument(idx, size, core::ptr::null())
}
/// Sets kernel argument `idx` to the pointer returned by `v.as_ptr()`, using
/// `clSetKernelArgSVMPointer`.
///
/// Gated by `#[docfg(feature = "svm")]`.
///
/// # Safety
/// NOTE(review): presumably the allocation behind `v` must stay alive until every enqueued
/// execution that uses it has finished — confirm against the SVM rules of the OpenCL
/// specification.
#[docfg(feature = "svm")]
pub unsafe fn set_svm_argument<T: ?Sized, S: crate::svm::SvmPointer<T>>(
&mut self,
idx: u32,
v: &S,
) -> Result<()> {
tri!(opencl_sys::clSetKernelArgSVMPointer(
self.id(),
idx,
v.as_ptr().cast()
));
Ok(())
}
/// Enqueues this kernel on `queue` as an `N`-dimensional range via
/// `clEnqueueNDRangeKernel` and returns the event reported by OpenCL.
///
/// * `global_work_dims` — global work size per dimension.
/// * `local_work_dims` — optional work-group size per dimension; `None` passes a null
///   pointer to OpenCL.
/// * `wait` — events to wait for, converted by `wait_list`.
///
/// The global work offset is always passed as null.
///
/// # Panics
/// Panics if `N` does not fit in a `u32`, or if `RawEvent::from_id` returns `None` for the
/// event written by OpenCL.
///
/// # Safety
/// The arguments currently set on the kernel are not validated here. NOTE(review):
/// presumably every argument must be set and valid for the whole execution — confirm
/// against the OpenCL specification.
#[inline(always)]
pub unsafe fn enqueue_unchecked<const N: usize>(
    &mut self,
    queue: &RawCommandQueue,
    global_work_dims: [usize; N],
    local_work_dims: impl Into<Option<[usize; N]>>,
    wait: WaitList,
) -> Result<RawEvent> {
    let work_dim = u32::try_from(N).expect("Integer overflow");

    // Keep the optional array alive in this frame and only *borrow* it to obtain the
    // pointer. Binding the array by value inside a `match` arm would copy it into an
    // arm-local temporary and leave the pointer dangling once the arm ends.
    let local_work_dims: Option<[usize; N]> = local_work_dims.into();
    let local_work_size = local_work_dims
        .as_ref()
        .map_or(core::ptr::null(), |dims| dims.as_ptr());

    let (num_events_in_wait_list, event_wait_list) = wait_list(wait)?;
    let mut event = core::ptr::null_mut();
    tri!(clEnqueueNDRangeKernel(
        queue.id(),
        self.id(),
        work_dim,
        core::ptr::null(),
        global_work_dims.as_ptr(),
        local_work_size,
        num_events_in_wait_list,
        event_wait_list,
        addr_of_mut!(event)
    ));
    Ok(RawEvent::from_id(event).unwrap())
}
/// Enqueues this kernel as an `N`-dimensional range through `scope.enqueue_noop`, which
/// supplies the command queue to the enqueueing closure.
///
/// * `global_work_dims` — global work size per dimension.
/// * `local_work_dims` — optional work-group size per dimension; `None` passes a null
///   pointer to OpenCL.
/// * `wait` — events to wait for, converted by `wait_list`.
///
/// The global work offset is always passed as null.
///
/// # Panics
/// Panics if `N` does not fit in a `u32`, or (inside the closure) if `RawEvent::from_id`
/// returns `None` for the event written by OpenCL.
///
/// # Safety
/// The arguments currently set on the kernel are not validated here. NOTE(review): the
/// closure captures raw pointers into this stack frame, so it relies on
/// `Scope::enqueue_noop` invoking it before this function returns — confirm.
#[inline(always)]
pub unsafe fn enqueue_with_scope<'scope, 'env, C: Context, const N: usize>(
    &mut self,
    scope: &'scope Scope<'scope, 'env, C>,
    global_work_dims: [usize; N],
    local_work_dims: impl Into<Option<[usize; N]>>,
    wait: WaitList,
) -> Result<NoopEvent> {
    let work_dim = u32::try_from(N).expect("Integer overflow");

    // Keep the optional array alive in this frame and only *borrow* it to obtain the
    // pointer. Binding the array by value inside a `match` arm would copy it into an
    // arm-local temporary and leave the pointer dangling once the arm ends.
    let local_work_dims: Option<[usize; N]> = local_work_dims.into();
    let local_work_size = local_work_dims
        .as_ref()
        .map_or(core::ptr::null(), |dims| dims.as_ptr());

    let (num_events_in_wait_list, event_wait_list) = wait_list(wait)?;
    scope.enqueue_noop(|queue| {
        let mut event = core::ptr::null_mut();
        tri!(clEnqueueNDRangeKernel(
            queue.id(),
            self.id(),
            work_dim,
            core::ptr::null(),
            global_work_dims.as_ptr(),
            local_work_size,
            num_events_in_wait_list,
            event_wait_list,
            addr_of_mut!(event)
        ));
        Ok(RawEvent::from_id(event).unwrap())
    })
}
/// Same as [`RawKernel::enqueue_with_scope`], but the resulting event has its consumer
/// replaced with a `PhantomData` marker, yielding a `PhantomEvent<T>`.
///
/// # Safety
/// Same contract as [`RawKernel::enqueue_with_scope`], to which all arguments are forwarded.
#[inline(always)]
pub unsafe fn enqueue_phantom_with_scope<
    'scope,
    'env,
    T: 'scope,
    C: Context,
    const N: usize,
>(
    &mut self,
    scope: &'scope Scope<'scope, 'env, C>,
    global_work_dims: [usize; N],
    local_work_dims: impl Into<Option<[usize; N]>>,
    wait: WaitList,
) -> Result<PhantomEvent<T>> {
    let noop_event = self.enqueue_with_scope(scope, global_work_dims, local_work_dims, wait)?;
    Ok(noop_event.set_consumer(PhantomData))
}
#[inline(always)]
pub fn name(&self) -> Result<String> {
self.get_info_string(CL_KERNEL_FUNCTION_NAME)
}
/// Queries `CL_KERNEL_NUM_ARGS` for this kernel.
#[inline(always)]
pub fn num_args(&self) -> Result<u32> {
    self.get_info::<u32>(CL_KERNEL_NUM_ARGS)
}
/// Queries `CL_KERNEL_REFERENCE_COUNT` for this kernel.
#[inline(always)]
pub fn reference_count(&self) -> Result<u32> {
    self.get_info::<u32>(CL_KERNEL_REFERENCE_COUNT)
}
/// Queries `CL_KERNEL_CONTEXT` and returns the context as a `RawContext`.
///
/// The context is retained with `clRetainContext` before it is wrapped.
#[inline(always)]
pub fn raw_context(&self) -> Result<RawContext> {
    let ctx: cl_context = self.get_info(CL_KERNEL_CONTEXT)?;
    let retained = unsafe {
        tri!(clRetainContext(ctx));
        RawContext::from_id_unchecked(ctx)
    };
    Ok(retained)
}
#[inline(always)]
pub fn program(&self) -> Result<RawProgram> {
let prog = self.get_info::<cl_context>(CL_KERNEL_PROGRAM)?;
unsafe {
tri!(clRetainProgram(prog));
Ok(RawProgram::from_id_unchecked(prog))
}
}
/// Reads a string-valued kernel property with two `clGetKernelInfo` calls: the first
/// obtains the size in bytes, the second fills a buffer of that size.
///
/// The final byte of the reported size is dropped from the result (presumably the
/// terminating NUL written by OpenCL). A reported size of zero yields an empty string.
/// Bytes that are not valid UTF-8 are replaced with U+FFFD instead of panicking.
#[inline]
fn get_info_string(&self, ty: cl_kernel_info) -> Result<String> {
    unsafe {
        let mut len = 0;
        tri!(clGetKernelInfo(
            self.id(),
            ty,
            0,
            core::ptr::null_mut(),
            &mut len
        ));

        // Nothing to read; also prevents `len - 1` from underflowing below.
        if len == 0 {
            return Ok(String::new());
        }

        let mut result = Vec::<u8>::with_capacity(len);
        tri!(clGetKernelInfo(
            self.id(),
            ty,
            len,
            result.as_mut_ptr().cast(),
            core::ptr::null_mut()
        ));
        // The second call was given room for `len` bytes; expose all but the last one.
        result.set_len(len - 1);

        Ok(String::from_utf8(result)
            .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned()))
    }
}
/// Reads a fixed-size kernel property of type `T` with `clGetKernelInfo`.
///
/// OpenCL is given a buffer of exactly `size_of::<T>()` bytes, which is then assumed to be
/// a fully initialised `T`; callers must choose a `T` that matches the queried property.
#[inline]
fn get_info<T: Copy>(&self, ty: cl_kernel_info) -> Result<T> {
    let mut slot = MaybeUninit::<T>::uninit();
    let slot_size = core::mem::size_of::<T>();
    unsafe {
        tri!(clGetKernelInfo(
            self.id(),
            ty,
            slot_size,
            slot.as_mut_ptr().cast(),
            core::ptr::null_mut()
        ));
        Ok(slot.assume_init())
    }
}
}
impl Clone for RawKernel {
#[inline(always)]
fn clone(&self) -> Self {
unsafe { self.retain().unwrap() }
Self(self.0)
}
}
impl Drop for RawKernel {
/// Calls `clReleaseKernel` on the wrapped handle.
///
/// NOTE(review): the status is handled by `tri_panic!`, which presumably panics on
/// failure — confirm against the macro definition.
#[inline(always)]
fn drop(&mut self) {
unsafe { tri_panic!(clReleaseKernel(self.id())) }
}
}
// SAFETY: NOTE(review) — assumption to confirm. The wrapped value is an opaque handle; the
// methods of `RawKernel` that set arguments or enqueue take `&mut self`, while the `&self`
// methods only query or retain. Verify this against the OpenCL thread-safety rules for
// kernel objects.
unsafe impl Send for RawKernel {}
unsafe impl Sync for RawKernel {}
#[cfg(feature = "cl1_2")]
use {
crate::buffer::flags::MemAccess,
opencl_sys::{
clGetKernelArgInfo, cl_kernel_arg_info, CL_KERNEL_ARG_ACCESS_QUALIFIER,
CL_KERNEL_ARG_ADDRESS_QUALIFIER, CL_KERNEL_ARG_NAME, CL_KERNEL_ARG_TYPE_NAME,
CL_KERNEL_ARG_TYPE_QUALIFIER,
},
};
#[docfg(feature = "cl1_2")]
impl RawKernel {
#[inline(always)]
pub fn arg_address_qualifier(&self, idx: u32) -> Result<AddrQualifier> {
self.get_arg_info(CL_KERNEL_ARG_ADDRESS_QUALIFIER, idx)
}
/// Queries `CL_KERNEL_ARG_ACCESS_QUALIFIER` for argument `idx` and maps it to the
/// corresponding `MemAccess` value.
///
/// # Panics
/// Panics if OpenCL reports a value other than the four `CL_KERNEL_ARG_ACCESS_*`
/// constants handled below.
#[inline(always)]
pub fn arg_access_qualifier(&self, idx: u32) -> Result<MemAccess> {
    let raw: opencl_sys::cl_kernel_arg_access_qualifier =
        self.get_arg_info(CL_KERNEL_ARG_ACCESS_QUALIFIER, idx)?;
    Ok(match raw {
        opencl_sys::CL_KERNEL_ARG_ACCESS_READ_ONLY => MemAccess::READ_ONLY,
        opencl_sys::CL_KERNEL_ARG_ACCESS_WRITE_ONLY => MemAccess::WRITE_ONLY,
        opencl_sys::CL_KERNEL_ARG_ACCESS_READ_WRITE => MemAccess::READ_WRITE,
        opencl_sys::CL_KERNEL_ARG_ACCESS_NONE => MemAccess::NONE,
        _ => unreachable!(),
    })
}
#[inline(always)]
pub fn arg_type_name(&self, idx: u32) -> Result<String> {
self.get_arg_info_string(CL_KERNEL_ARG_TYPE_NAME, idx)
}
#[inline(always)]
pub fn arg_qualifier(&self, idx: u32) -> Result<String> {
self.get_arg_info(CL_KERNEL_ARG_TYPE_QUALIFIER, idx)
}
#[inline(always)]
pub fn arg_name(&self, idx: u32) -> Result<String> {
self.get_arg_info_string(CL_KERNEL_ARG_NAME, idx)
}
/// Reads a string-valued property of argument `idx` with two `clGetKernelArgInfo` calls:
/// the first obtains the size in bytes, the second fills a buffer of that size.
///
/// The final byte of the reported size is dropped from the result (presumably the
/// terminating NUL written by OpenCL). A reported size of zero yields an empty string.
/// Bytes that are not valid UTF-8 are replaced with U+FFFD instead of panicking.
#[inline]
fn get_arg_info_string(&self, ty: cl_kernel_arg_info, idx: u32) -> Result<String> {
    unsafe {
        let mut len = 0;
        tri!(clGetKernelArgInfo(
            self.id(),
            idx,
            ty,
            0,
            core::ptr::null_mut(),
            &mut len
        ));

        // Nothing to read; also prevents `len - 1` from underflowing below.
        if len == 0 {
            return Ok(String::new());
        }

        let mut result = Vec::<u8>::with_capacity(len);
        tri!(clGetKernelArgInfo(
            self.id(),
            idx,
            ty,
            len,
            result.as_mut_ptr().cast(),
            core::ptr::null_mut()
        ));
        // The second call was given room for `len` bytes; expose all but the last one.
        result.set_len(len - 1);

        Ok(String::from_utf8(result)
            .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned()))
    }
}
/// Reads a fixed-size property of argument `idx` as a `T` with `clGetKernelArgInfo`.
///
/// OpenCL is given a buffer of exactly `size_of::<T>()` bytes, which is then assumed to be
/// a fully initialised `T`. `T` is not bounded, so callers must only use plain-data types
/// that match the queried property; anything else makes the `assume_init` below unsound.
#[inline]
fn get_arg_info<T>(&self, ty: cl_kernel_arg_info, idx: u32) -> Result<T> {
    let mut slot = MaybeUninit::<T>::uninit();
    let slot_size = core::mem::size_of::<T>();
    unsafe {
        tri!(clGetKernelArgInfo(
            self.id(),
            idx,
            ty,
            slot_size,
            slot.as_mut_ptr().cast(),
            core::ptr::null_mut()
        ));
        Ok(slot.assume_init())
    }
}
}
/// Address qualifier of a kernel argument.
///
/// The enum is `repr(u32)` and each variant's discriminant is the matching
/// `CL_KERNEL_ARG_ADDRESS_*` constant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum AddrQualifier {
/// `CL_KERNEL_ARG_ADDRESS_GLOBAL`
Global = CL_KERNEL_ARG_ADDRESS_GLOBAL,
/// `CL_KERNEL_ARG_ADDRESS_LOCAL`
Local = CL_KERNEL_ARG_ADDRESS_LOCAL,
/// `CL_KERNEL_ARG_ADDRESS_CONSTANT`
Constant = CL_KERNEL_ARG_ADDRESS_CONSTANT,
/// `CL_KERNEL_ARG_ADDRESS_PRIVATE`
Private = CL_KERNEL_ARG_ADDRESS_PRIVATE,
}
impl Default for AddrQualifier {
#[inline(always)]
fn default() -> Self {
Self::Private
}
}
bitflags::bitflags! {
/// Type qualifiers of a kernel argument, stored as a `cl_kernel_arg_type_qualifier`
/// bit set built from the `CL_KERNEL_ARG_TYPE_*` constants.
#[repr(transparent)]
pub struct TypeQualifier: cl_kernel_arg_type_qualifier {
/// `CL_KERNEL_ARG_TYPE_CONST`
const CONST = CL_KERNEL_ARG_TYPE_CONST;
/// `CL_KERNEL_ARG_TYPE_RESTRICT`
const RESTRICT = CL_KERNEL_ARG_TYPE_RESTRICT;
/// `CL_KERNEL_ARG_TYPE_VOLATILE`
const VOLATILE = CL_KERNEL_ARG_TYPE_VOLATILE;
}
}
impl Default for TypeQualifier {
#[inline(always)]
fn default() -> Self {
Self::empty()
}
}