use crate::{async_executor::{AsyncExecutor, kernel_arg::KernelArg}, cl_types::{cl_buffer::ClBuffer, cl_device::opencl_version::OpenCLVersion, cl_event::ClEvent, cl_image::ClImage, cl_kernel::ClKernel, cl_svm_buffer::ClSvmBuffer, cl_pipe::ClPipe}, error::ClError};
use std::os::raw::c_void;
use futures;
/// A device-to-host read that `TaskBuilder::run` issues after the kernel launches.
///
/// The destination is held as a raw pointer, so nothing in this type keeps the
/// host memory alive or exclusively borrowed.
/// NOTE(review): whoever constructs a value must keep `host_ptr` valid until the
/// read has completed — confirm every construction site does so.
#[cfg(feature = "CL_VERSION_1_1")]
pub enum OutputRead<'a> {
    /// Read from a buffer object into host memory.
    Buffer {
        /// Device buffer to read from.
        buffer: &'a ClBuffer,
        /// Destination address in host memory.
        host_ptr: *mut c_void,
        /// Size of the read; `TaskBuilder::read_buffer` fills this with the
        /// destination slice's length in bytes.
        size: usize,
    },
    /// Read a region of an image object into host memory.
    #[cfg(feature = "CL_VERSION_1_2")]
    Image {
        /// Device image to read from.
        image: &'a ClImage,
        /// Destination address in host memory.
        host_ptr: *mut c_void,
        /// Origin of the region, forwarded unchanged to the queue's image read.
        origin: [usize; 3],
        /// Extent of the region, forwarded unchanged to the queue's image read.
        region: [usize; 3],
    },
}
// `OutputRead` stores a raw `*mut c_void` (`host_ptr`), so the compiler does not
// derive `Send` for it; this impl opts in by hand.
// NOTE(review): this is only sound if the memory behind `host_ptr` stays valid and
// is not accessed elsewhere while the value is moved between threads — the type
// itself does not enforce that; confirm callers uphold it.
#[cfg(feature = "CL_VERSION_1_1")]
unsafe impl<'a> Send for OutputRead<'a> {}
// Same reasoning as for `Send`: the raw `host_ptr` field blocks the automatic
// `Sync` impl, and this opts in manually.
// NOTE(review): shared references expose a `*mut` destination pointer; confirm no
// code writes through it concurrently from several threads.
#[cfg(feature = "CL_VERSION_1_1")]
unsafe impl<'a> Sync for OutputRead<'a> {}
/// Events gathered while running a task.
///
/// `TaskBuilder::run` only fills these vectors when profiling is enabled on the
/// builder; otherwise both stay empty.
///
/// `Default` yields the same empty report as `TaskReport::new`.
#[cfg(feature = "CL_VERSION_1_1")]
#[derive(Default)]
pub struct TaskReport {
    /// One event per kernel launch that was enqueued.
    pub kernel_execution_events: Vec<ClEvent>,
    /// One event per buffer/image read that was enqueued.
    pub read_events: Vec<ClEvent>,
}
#[cfg(feature = "CL_VERSION_1_1")]
impl TaskReport {
    /// Creates a report with no recorded events.
    pub fn new() -> Self {
        Self {
            kernel_execution_events: vec![],
            read_events: vec![],
        }
    }

    /// Sum of `ClEvent::get_duration_nanos` over all kernel execution events.
    ///
    /// Events whose duration query returns an error are left out of the sum.
    pub fn total_kernel_duration_ns(&self) -> u64 {
        Self::sum_durations(&self.kernel_execution_events)
    }

    /// Sum of `ClEvent::get_duration_nanos` over all read events.
    ///
    /// Events whose duration query returns an error are left out of the sum.
    pub fn total_read_duration_ns(&self) -> u64 {
        Self::sum_durations(&self.read_events)
    }

    /// Adds up the durations of `events`, skipping any event whose query fails.
    fn sum_durations(events: &[ClEvent]) -> u64 {
        events
            .iter()
            .map(|event| event.get_duration_nanos())
            .filter_map(Result::ok)
            .sum()
    }
}
/// Builder that collects kernel arguments, the work range, output reads and a
/// wait list, and then launches the kernel through an `AsyncExecutor`.
#[cfg(feature = "CL_VERSION_1_1")]
pub struct TaskBuilder<'a> {
    /// Executor whose queues (and weights) are used by `run`.
    async_executor: &'a AsyncExecutor,
    /// Kernel to configure and launch.
    kernel: &'a ClKernel,
    /// Arguments recorded so far; applied to the kernel at the start of `run`.
    kernel_args: Vec<KernelArg<'a>>,
    /// Global work size; `run` falls back to `[1, 1, 1]` when unset.
    global_work_dims: Option<[usize; 3]>,
    /// Global work offset; `run` falls back to `[0, 0, 0]` when unset.
    global_work_offset: Option<[usize; 3]>,
    /// Explicit local work size; when unset `run` derives one per queue.
    local_work_dims: Option<[usize; 3]>,
    /// Reads to issue after the kernel launches.
    output_reads: Vec<OutputRead<'a>>,
    /// Events handed to every kernel launch as its wait list.
    wait_list: Option<Vec<ClEvent>>,
    /// Whether `run` records events in the returned `TaskReport`.
    profiling_enabled: bool,
}
// Gated like the `TaskBuilder` struct itself so the file still builds when the
// feature is disabled.
#[cfg(feature = "CL_VERSION_1_1")]
impl<'a> TaskBuilder<'a> {
    /// Creates a builder for `kernel` that will run on `async_executor`.
    ///
    /// Profiling starts out at whatever `async_executor.is_profiling_enabled()`
    /// reports; every other setting starts empty/unset.
    pub fn new(async_executor: &'a AsyncExecutor, kernel: &'a ClKernel) -> Self {
        Self {
            async_executor,
            kernel,
            kernel_args: Vec::new(),
            local_work_dims: None,
            global_work_dims: None,
            global_work_offset: None,
            output_reads: Vec::new(),
            wait_list: None,
            profiling_enabled: async_executor.is_profiling_enabled(),
        }
    }

    /// Overrides whether `run` records events in the returned `TaskReport`.
    pub fn with_profiling(mut self, enabled: bool) -> Self {
        self.profiling_enabled = enabled;
        self
    }

    /// Records `scalar` as kernel argument `arg_index`, using `size_of::<T>()` bytes.
    ///
    /// See [`Self::add_scalar`] for the requirements on `T`.
    pub fn arg_scalar<T>(self, arg_index: u32, scalar: T) -> Self {
        self.add_scalar(arg_index, scalar, std::mem::size_of::<T>())
    }

    /// Records `buffer` as kernel argument `arg_index`.
    pub fn arg_buffer(self, arg_index: u32, buffer: &'a ClBuffer) -> Self {
        self.add_buffer(arg_index, buffer)
    }

    /// Records `image` as kernel argument `arg_index`.
    // Gated because it forwards to `add_image_buffer`, which only exists with
    // `CL_VERSION_1_2`.
    #[cfg(feature = "CL_VERSION_1_2")]
    pub fn arg_image(self, arg_index: u32, image: &'a ClImage) -> Self {
        self.add_image_buffer(arg_index, image)
    }

    /// Records the SVM `buffer` as kernel argument `arg_index`.
    #[cfg(feature = "CL_VERSION_2_0")]
    pub fn arg_svm<T>(self, arg_index: u32, buffer: &'a ClSvmBuffer<T>) -> Self {
        self.add_svm_buffer(arg_index, buffer)
    }

    /// Records `pipe` as kernel argument `arg_index`.
    #[cfg(feature = "CL_VERSION_2_0")]
    pub fn arg_pipe(self, arg_index: u32, pipe: &'a ClPipe) -> Self {
        self.add_pipe(arg_index, pipe)
    }

    /// Schedules a read of `buffer` into `host_memory` after the kernel launches.
    ///
    /// The read covers `host_memory.len() * size_of::<T>()` bytes.
    ///
    /// Only the raw address of `host_memory` is stored; its borrow ends when this
    /// call returns.
    /// NOTE(review): the caller must keep `host_memory` alive and untouched until
    /// `run` has completed — the signature does not enforce this; confirm callers.
    pub fn read_buffer<T>(mut self, buffer: &'a ClBuffer, host_memory: &mut [T]) -> Self {
        self.output_reads.push(OutputRead::Buffer {
            buffer,
            host_ptr: host_memory.as_mut_ptr() as *mut c_void,
            size: host_memory.len() * std::mem::size_of::<T>(),
        });
        self
    }

    /// Schedules a read of the `region` of `image` starting at `origin` into
    /// `host_memory` after the kernel launches.
    ///
    /// Only the raw address of `host_memory` is stored, and its length is not
    /// recorded or checked against `region`.
    /// NOTE(review): the caller must keep `host_memory` alive until `run` has
    /// completed and must size it for the requested region — confirm callers.
    #[cfg(feature = "CL_VERSION_1_2")]
    pub fn read_image<T>(
        mut self,
        image: &'a ClImage,
        host_memory: &mut [T],
        origin: [usize; 3],
        region: [usize; 3],
    ) -> Self {
        self.output_reads.push(OutputRead::Image {
            image,
            host_ptr: host_memory.as_mut_ptr() as *mut c_void,
            origin,
            region,
        });
        self
    }

    /// Records the first `byte_size` bytes of `scalar` as kernel argument
    /// `arg_index`.
    ///
    /// The bytes are copied out immediately and `scalar` is dropped when this call
    /// returns, so `T` should be a plain value type: copied bytes of an owning
    /// type (one with a destructor) would refer to freed resources.
    /// NOTE(review): padding bytes inside `T` are copied as-is; prefer types
    /// without padding.
    ///
    /// # Panics
    ///
    /// Panics if `byte_size` is larger than `size_of::<T>()`; reading further
    /// would run past the end of `scalar`.
    pub fn add_scalar<T>(mut self, arg_index: u32, scalar: T, byte_size: usize) -> Self {
        assert!(
            byte_size <= std::mem::size_of::<T>(),
            "add_scalar: byte_size ({}) exceeds size_of::<T>() ({})",
            byte_size,
            std::mem::size_of::<T>()
        );
        // SAFETY: `scalar` is a live local for the whole expression and, per the
        // assertion above, at least `byte_size` bytes long; `u8` has alignment 1.
        let bytes = unsafe {
            std::slice::from_raw_parts(&scalar as *const T as *const u8, byte_size).to_vec()
        };
        self.kernel_args.push(KernelArg::Scalar { arg_index, arg: bytes });
        self
    }

    /// Records `buffer` as kernel argument `arg_index`.
    pub fn add_buffer(mut self, arg_index: u32, buffer: &'a ClBuffer) -> Self {
        self.kernel_args.push(KernelArg::Buffer { arg_index, arg: buffer });
        self
    }

    /// Records the SVM `buffer` (its pointer and `len`) as kernel argument
    /// `arg_index`.
    #[cfg(feature = "CL_VERSION_2_0")]
    pub fn add_svm_buffer<T>(mut self, arg_index: u32, buffer: &'a ClSvmBuffer<T>) -> Self {
        self.kernel_args.push(KernelArg::Svm {
            arg_index,
            arg: buffer.as_ptr(),
            len: buffer.len,
        });
        self
    }

    /// Records `image_buffer` as kernel argument `arg_index`.
    #[cfg(feature = "CL_VERSION_1_2")]
    pub fn add_image_buffer(mut self, arg_index: u32, image_buffer: &'a ClImage) -> Self {
        self.kernel_args.push(KernelArg::Image { arg_index, arg: image_buffer });
        self
    }

    /// Records `pipe` as kernel argument `arg_index`.
    #[cfg(feature = "CL_VERSION_2_0")]
    pub fn add_pipe(mut self, arg_index: u32, pipe: &'a ClPipe) -> Self {
        self.kernel_args.push(KernelArg::Pipe { arg_index, arg: pipe });
        self
    }

    /// Sets an explicit local work size; only the first `work_dim` components are
    /// used by `run`.
    pub fn local_work_dims(mut self, x: usize, y: usize, z: usize) -> Self {
        self.local_work_dims = Some([x, y, z]);
        self
    }

    /// Sets the global work offset.
    pub fn global_work_offset(mut self, x: usize, y: usize, z: usize) -> Self {
        self.global_work_offset = Some([x, y, z]);
        self
    }

    /// Sets the global work size.
    pub fn global_work_dims(mut self, x: usize, y: usize, z: usize) -> Self {
        self.global_work_dims = Some([x, y, z]);
        self
    }

    /// Sets the events every kernel launch waits on.
    ///
    /// Despite the name, this replaces any wait list set by an earlier call.
    pub fn add_wait_list(mut self, wait_list: Vec<ClEvent>) -> Self {
        self.wait_list = Some(wait_list);
        self
    }

    /// Applies the recorded arguments, launches the kernel on the executor's
    /// queues and then issues the scheduled reads.
    ///
    /// Dimension 0 of the global range is split across the queues in proportion
    /// to `async_executor.weights`; the last queue takes whatever remains.
    /// Queues whose share is zero are skipped, but at least one launch is always
    /// enqueued. All launches are awaited together, after which the reads are
    /// issued one by one on the first queue.
    ///
    /// NOTE(review): the reads are enqueued without a wait list; whether awaiting
    /// a launch implies the kernel has finished is not visible here — confirm
    /// that reads cannot overtake kernels running on the other queues.
    /// NOTE(review): an explicit local work size is applied to every chunk
    /// unchanged; chunk sizes are not rounded to a multiple of it.
    ///
    /// # Errors
    ///
    /// Returns `WrapperError::PlatformsNotFound` when the executor has no queues,
    /// and otherwise the first error from setting an argument, enqueuing a
    /// launch, or enqueuing a read.
    pub async fn run(self) -> Result<TaskReport, ClError> {
        let mut report = TaskReport::new();
        let num_queues = self.async_executor.queues.len();
        if num_queues == 0 {
            return Err(ClError::Wrapper(
                crate::error::wrapper_error::WrapperError::PlatformsNotFound,
            ));
        }

        self.bind_kernel_args()?;

        let global_work_dims = self.global_work_dims.unwrap_or([1, 1, 1]);
        let global_work_offset = self.global_work_offset.unwrap_or([0, 0, 0]);
        // Depends only on the builder's settings, so compute it once.
        let work_dim = Self::effective_work_dim(global_work_dims, global_work_offset);

        let total_work = global_work_dims[0];
        let total_weight: u64 = self.async_executor.weights.iter().sum();
        let last_queue = num_queues - 1;

        let mut launches = Vec::new();
        let mut current_offset = global_work_offset[0];
        for i in 0..num_queues {
            let chunk_size = if i == last_queue {
                // The last queue absorbs the rounding remainder.
                global_work_offset[0] + total_work - current_offset
            } else if total_weight == 0 {
                // All weights are zero: no proportional share can be computed,
                // so leave everything to the last queue instead of dividing by 0.
                0
            } else {
                let weight = self.async_executor.weights[i];
                // u128 keeps the product from overflowing before the division.
                ((total_work as u128 * weight as u128) / total_weight as u128) as usize
            };

            // A zero-sized range is rejected by OpenCL before 2.1, so skip empty
            // chunks. The last queue is still used when nothing else was
            // launched, which preserves the "at least one launch" behaviour.
            if chunk_size == 0 && (i != last_queue || !launches.is_empty()) {
                continue;
            }

            let g_offset =
                [current_offset, global_work_offset[1], global_work_offset[2]][..work_dim].to_vec();
            let g_dims = [chunk_size, global_work_dims[1], global_work_dims[2]][..work_dim].to_vec();
            let l_dims = match self.local_work_dims {
                Some(ld) => ld[..work_dim].to_vec(),
                None => self.default_local_dims(i, work_dim, chunk_size),
            };

            let kernel_ref = self.kernel;
            let queue = self.async_executor.queues[i].clone();
            let wait_list = self.wait_list.clone();
            launches.push(async move {
                let event = queue
                    .enqueue_nd_range_kernel(
                        kernel_ref,
                        work_dim as u32,
                        g_offset,
                        g_dims,
                        l_dims,
                        None,
                        wait_list,
                    )
                    .await?;
                Ok::<ClEvent, ClError>(event)
            });
            current_offset += chunk_size;
        }

        for result in futures::future::join_all(launches).await {
            let event = result?;
            if self.profiling_enabled {
                report.kernel_execution_events.push(event);
            }
        }

        if !self.output_reads.is_empty() {
            let queue = self.async_executor.queues.first().ok_or(ClError::Wrapper(
                crate::error::wrapper_error::WrapperError::PlatformsNotFound,
            ))?;
            for read in &self.output_reads {
                match read {
                    OutputRead::Buffer { buffer, host_ptr, size } => {
                        let event = queue
                            .enqueue_read_buffer_raw(*buffer, None, *host_ptr, *size, None)
                            .await?;
                        if self.profiling_enabled {
                            report.read_events.push(event);
                        }
                    }
                    // The variant only exists with `CL_VERSION_1_2`.
                    #[cfg(feature = "CL_VERSION_1_2")]
                    OutputRead::Image { image, host_ptr, origin, region } => {
                        let event = queue
                            .read_image_raw(*image, *origin, *region, 0, 0, *host_ptr, None)
                            .await?;
                        if self.profiling_enabled {
                            report.read_events.push(event);
                        }
                    }
                }
            }
        }
        Ok(report)
    }

    /// Applies every recorded argument to the kernel, stopping at the first error.
    // NOTE(review): the `Svm` and `Image` arms are not feature-gated although the
    // builder methods that create them are; check this against the `KernelArg`
    // definition.
    fn bind_kernel_args(&self) -> Result<(), ClError> {
        for arg in &self.kernel_args {
            match arg {
                KernelArg::Scalar { arg_index, arg } => {
                    // SAFETY: pointer and length describe the live `Vec<u8>` held
                    // by the builder. NOTE(review): assumes `set_args` copies the
                    // bytes before returning — confirm.
                    unsafe {
                        self.kernel.set_args(*arg_index, arg.len(), arg.as_ptr() as *const _)?;
                    }
                }
                KernelArg::Buffer { arg_index, arg } => {
                    let handle = arg.as_ptr();
                    // SAFETY: `handle` lives on the stack for the whole call.
                    // NOTE(review): assumes the handle is pointer-sized and is
                    // copied by `set_args` — confirm.
                    unsafe {
                        self.kernel.set_args(
                            *arg_index,
                            std::mem::size_of::<*mut c_void>(),
                            &handle as *const _ as *const _,
                        )?;
                    }
                }
                KernelArg::Svm { arg_index, arg, len } => {
                    // SAFETY: NOTE(review): relies on the SVM pointer/length pair
                    // recorded by `add_svm_buffer` still being valid — confirm.
                    unsafe {
                        self.kernel.set_svm_arg(*arg_index, *len, *arg)?;
                    }
                }
                KernelArg::Image { arg_index, arg } => {
                    let handle = arg.as_ptr();
                    // SAFETY: same reasoning as for the `Buffer` arm.
                    unsafe {
                        self.kernel.set_args(
                            *arg_index,
                            std::mem::size_of::<*mut c_void>(),
                            &handle as *const _ as *const _,
                        )?;
                    }
                }
                #[cfg(feature = "CL_VERSION_2_0")]
                KernelArg::Pipe { arg_index, arg } => {
                    let handle = arg.as_ptr();
                    // SAFETY: same reasoning as for the `Buffer` arm.
                    unsafe {
                        self.kernel.set_args(
                            *arg_index,
                            std::mem::size_of::<*mut c_void>(),
                            &handle as *const _ as *const _,
                        )?;
                    }
                }
            }
        }
        Ok(())
    }

    /// Number of leading dimensions actually in use: 3 if the z size or offset is
    /// non-trivial, else 2 if the y size or offset is, else 1.
    fn effective_work_dim(dims: [usize; 3], offset: [usize; 3]) -> usize {
        if dims[2] > 1 || offset[2] > 0 {
            3
        } else if dims[1] > 1 || offset[1] > 0 {
            2
        } else {
            1
        }
    }

    /// Local work size used for queue `queue_index` when none was set explicitly.
    ///
    /// An empty vector means "no explicit local size". The kernel's work-group
    /// size is only suggested for 1-D launches on devices that do not report
    /// non-uniform work-group support, and only when it evenly divides
    /// `chunk_size`, since such devices reject a global size that is not a
    /// multiple of the local size.
    fn default_local_dims(&self, queue_index: usize, work_dim: usize, chunk_size: usize) -> Vec<usize> {
        let version = self.async_executor.get_device_versions()[queue_index];
        if let Ok(devices) = self.async_executor.get_devices() {
            let device = devices[queue_index].clone();
            if version >= OpenCLVersion::V2_0
                && device.get_non_uniform_work_group_support().unwrap_or(false)
            {
                return Vec::new();
            }
            if let Ok(preferred) = self.kernel.get_work_group_size(device) {
                if work_dim == 1 && preferred != 0 && chunk_size % preferred == 0 {
                    return vec![preferred];
                }
            }
        }
        Vec::new()
    }
}
// `TaskBuilder` only exists with `CL_VERSION_1_1`, so these impls carry the same
// gate; without it the file does not build when the feature is disabled.
//
// NOTE(review): the builder stores raw host pointers (inside its `OutputRead`
// entries) and references to executor/kernel objects; these manual impls are only
// sound if that host memory outlives the builder and the referenced objects may
// be used from other threads — confirm.
#[cfg(feature = "CL_VERSION_1_1")]
unsafe impl Sync for TaskBuilder<'_> {}
#[cfg(feature = "CL_VERSION_1_1")]
unsafe impl Send for TaskBuilder<'_> {}