use crate::core::Status;
use crate::error::{Error as OclError, Result as OclResult};
use crate::ffi::{c_void, cl_event};
use crate::flags::{CommandQueueProperties, MemFlags};
use crate::prm::Int4;
use crate::traits::IntoRawEventArray;
use crate::{Buffer, Context, Device, Event, Kernel, Platform, Program, Queue, RwVec};
use futures::Future;
use std::thread;
/// Number of `Int4` elements in each device buffer and in the shared host
/// vector; also used as the kernel's global work size.
const WORK_SIZE: usize = 1 << 12;
/// Base value written into the source buffer; each task iteration adds its
/// own iteration index to this.
const INIT_VAL: i32 = 50;
/// Value passed to the kernel as its `addend` argument.
const SCALAR_ADDEND: i32 = 100;
/// Number of task iterations enqueued per device.
const TASK_ITERS: i32 = 16;
/// When `true`, progress/status messages are printed.
const PRINT: bool = false;
/// OpenCL C source for the `add_slowly` kernel.
///
/// For each work item the kernel writes `in[idx] + sum` to `out[idx]`, where
/// `sum` is accumulated over `addend` loop passes. Each pass adds
/// `convert_int4((addend * 255.0 / 255.0) / addend)`, which is intended to be
/// 1 per lane, so the total is expected to equal `addend` (the test in this
/// file checks for `in + addend`).
pub static KERN_SRC: &'static str = r#"
__kernel void add_slowly(
__global int4* in,
__private int addend,
__global int4* out)
{
uint const idx = get_global_id(0);
float4 const inflated_val = (float4)(addend) * (float4)(255.0);
int4 sum = (int4)(0);
for (int i = 0; i < addend; i++) {
sum += convert_int4((inflated_val / (float4)(255.0)) / (float4)(addend));
}
out[idx] = in[idx] + sum;
}
"#;
pub fn fill_junk(
src_buf: &Buffer<Int4>,
common_queue: &Queue,
kernel_event: Option<&Event>,
verify_init_event: Option<&Event>,
fill_event: &mut Option<Event>,
task_iter: i32,
) {
extern "C" fn _print_starting(_: cl_event, _: i32, task_iter: *mut c_void) {
if PRINT {
println!(
"* Fill starting \t(iter: {}) ...",
task_iter as usize
);
}
}
extern "C" fn _print_complete(_: cl_event, _: i32, task_iter: *mut c_void) {
if PRINT {
println!("* Fill complete \t(iter: {})", task_iter as usize);
}
}
let wait_list = [kernel_event, verify_init_event].into_raw_array();
let fill_wait_marker = wait_list.to_marker(&common_queue).unwrap();
if let Some(ref marker) = fill_wait_marker {
unsafe {
marker
.set_callback(_print_starting, task_iter as *mut c_void)
.unwrap();
}
} else {
_print_starting(0 as cl_event, 0, task_iter as *mut c_void);
}
*fill_event = Some(Event::empty());
src_buf
.cmd()
.fill(Int4::new(-999, -999, -999, -999), None)
.queue(common_queue)
.ewait(&wait_list)
.enew(fill_event.as_mut())
.enq()
.unwrap();
unsafe {
fill_event
.as_ref()
.unwrap()
.set_callback(_print_complete, task_iter as *mut c_void)
.unwrap();
}
}
/// Sets every element of the shared host vector to `write_val` and then
/// enqueues a write of that vector into `src_buf`.
///
/// * The host-side lock on `rw_vec` is given `verify_init_event` and
///   `fill_event` as its lock wait events.
/// * The buffer write on `common_queue` waits on the release event of that
///   host-side lock.
/// * `write_init_event` is overwritten with the release event of the buffer
///   write; both release events are created on `write_init_release_queue_0`.
///
/// Returns a boxed future that resolves to `task_iter` once both the
/// host-side future and the buffer-write future have resolved.
///
/// # Panics
///
/// Panics if creating a release event, enqueuing the write, or registering
/// the callback returns an error.
pub fn write_init(
    src_buf: &Buffer<Int4>,
    rw_vec: &RwVec<Int4>,
    common_queue: &Queue,
    write_init_release_queue_0: &Queue,
    fill_event: Option<&Event>,
    verify_init_event: Option<&Event>,
    write_init_event: &mut Option<Event>,
    write_val: i32,
    task_iter: i32,
) -> Box<dyn Future<Item = i32, Error = OclError> + Send> {
    // Status callback; the user-data pointer is the iteration number itself.
    extern "C" fn on_write_done(_: cl_event, _: i32, iter_tag: *mut c_void) {
        if PRINT {
            println!("* Write init complete \t(iter: {})", iter_tag as usize);
        }
    }

    // Host-side stage: request the write lock on the shared vector.
    let mut lock_future = rw_vec.clone().write();
    lock_future.set_lock_wait_events([verify_init_event, fill_event]);
    let host_release_event = lock_future
        .create_release_event(write_init_release_queue_0)
        .unwrap()
        .clone();

    let fill_host_vec = lock_future.and_then(move |mut data| {
        if PRINT {
            println!("* Write init starting \t(iter: {}) ...", task_iter);
        }
        data.iter_mut()
            .for_each(|slot| *slot = Int4::splat(write_val));
        Ok(())
    });

    // Device-side stage: copy the vector into the source buffer once the
    // host-side lock has been released.
    let mut buffer_write = src_buf
        .cmd()
        .write(rw_vec)
        .queue(common_queue)
        .ewait(&host_release_event)
        .enq_async()
        .unwrap();

    // Publish the buffer write's release event for later stages to wait on.
    *write_init_event = Some(
        buffer_write
            .create_release_event(write_init_release_queue_0)
            .unwrap()
            .clone(),
    );

    let write_done = write_init_event.as_ref().unwrap();
    // SAFETY: the callback only casts the user-data pointer back to an
    // integer and never dereferences it.
    unsafe {
        write_done
            .set_callback(on_write_done, task_iter as *mut c_void)
            .unwrap();
    }

    // Resolving this future consumes (and thereby drops) the write's guard.
    let drop_write_guard = buffer_write.and_then(move |_| Ok(()));

    Box::new(
        fill_host_vec
            .join(drop_write_guard)
            .map(move |_| task_iter),
    )
}
/// Reads `src_buf` back into the shared host vector and checks that every
/// element equals `correct_val` in all four lanes.
///
/// The read on `common_queue` waits on the event currently stored in
/// `verify_init_event` (if any) and on `write_init_event`. Afterwards
/// `verify_init_event` is overwritten with the read's release event, created
/// on `verify_init_queue`.
///
/// Returns a boxed future resolving to the number of elements checked, or to
/// an error describing the first mismatching element.
///
/// # Panics
///
/// Panics if enqueuing the read, obtaining the lock event, registering the
/// callback, or creating the release event fails.
pub fn verify_init(
    src_buf: &Buffer<Int4>,
    rw_vec: &RwVec<Int4>,
    common_queue: &Queue,
    verify_init_queue: &Queue,
    write_init_event: Option<&Event>,
    verify_init_event: &mut Option<Event>,
    correct_val: i32,
    task_iter: i32,
) -> Box<dyn Future<Item = i32, Error = OclError> + Send> {
    // Status callback; the user-data pointer is the iteration number itself.
    extern "C" fn on_verify_start(_: cl_event, _: i32, iter_tag: *mut c_void) {
        if PRINT {
            println!(
                "* Verify init starting \t(iter: {}) ...",
                iter_tag as usize
            );
        }
    }

    let deps = [&verify_init_event.as_ref(), &write_init_event].into_raw_array();

    let mut read_future = src_buf
        .cmd()
        .read(rw_vec)
        .queue(common_queue)
        .ewait(&deps)
        .enq_async()
        .unwrap();

    // SAFETY: the callback only casts the user-data pointer back to an
    // integer and never dereferences it.
    unsafe {
        read_future
            .lock_event()
            .unwrap()
            .set_callback(on_verify_start, task_iter as *mut c_void)
            .unwrap();
    }

    // Publish the read's release event, replacing the previous one.
    *verify_init_event = Some(
        read_future
            .create_release_event(verify_init_queue)
            .unwrap()
            .clone(),
    );

    Box::new(read_future.and_then(move |data| {
        let expected = Int4::new(correct_val, correct_val, correct_val, correct_val);
        let mut checked = 0;

        for (idx, actual) in data.iter().enumerate() {
            if *actual != expected {
                return Err(format!(
                    "Verify init: Result value mismatch: {:?} != {:?} @ [{}] \
                    for task iter: [{}].",
                    actual, expected, idx, task_iter
                )
                .into());
            }
            checked += 1;
        }

        if PRINT {
            println!("* Verify init complete \t(iter: {})", task_iter);
        }
        Ok(checked)
    }))
}
pub fn kernel_add(
kern: &Kernel,
common_queue: &Queue,
verify_add_event: Option<&Event>,
write_init_event: Option<&Event>,
kernel_event: &mut Option<Event>,
task_iter: i32,
) {
extern "C" fn _print_starting(_: cl_event, _: i32, task_iter: *mut c_void) {
if PRINT {
println!(
"* Kernel starting \t(iter: {}) ...",
task_iter as usize
);
}
}
extern "C" fn _print_complete(_: cl_event, _: i32, task_iter: *mut c_void) {
if PRINT {
println!("* Kernel complete \t(iter: {})", task_iter as usize);
}
}
let wait_list = [&verify_add_event, &write_init_event].into_raw_array();
let kernel_wait_marker = wait_list.to_marker(&common_queue).unwrap();
unsafe {
kernel_wait_marker
.as_ref()
.unwrap()
.set_callback(_print_starting, task_iter as *mut c_void)
.unwrap();
}
*kernel_event = Some(Event::empty());
unsafe {
kern.cmd()
.queue(common_queue)
.ewait(&wait_list)
.enew(kernel_event.as_mut())
.enq()
.unwrap();
}
unsafe {
kernel_event
.as_ref()
.unwrap()
.set_callback(_print_complete, task_iter as *mut c_void)
.unwrap();
}
}
/// Reads `dst_buf` back into the shared host vector and checks that every
/// element equals `correct_val` in all lanes; on success the vector is then
/// reset to zero.
///
/// The read on `common_queue` waits on `wait_event`. `verify_add_event` is
/// overwritten with the read's release event, created on
/// `verify_add_unmap_queue`.
///
/// Returns a boxed future resolving to the number of elements checked, or to
/// an error describing the first mismatching element.
///
/// # Panics
///
/// Panics if enqueuing the read, obtaining the lock event, registering the
/// callback, or creating the release event fails.
pub fn verify_add(
    dst_buf: &Buffer<Int4>,
    rw_vec: &RwVec<Int4>,
    common_queue: &Queue,
    verify_add_unmap_queue: &Queue,
    wait_event: Option<&Event>,
    verify_add_event: &mut Option<Event>,
    correct_val: i32,
    task_iter: i32,
) -> Box<dyn Future<Item = i32, Error = OclError> + Send> {
    // Status callback; the user-data pointer is the iteration number itself.
    extern "C" fn on_verify_start(_: cl_event, _: i32, iter_tag: *mut c_void) {
        if PRINT {
            println!(
                "* Verify add starting \t(iter: {}) ...",
                iter_tag as usize
            );
        }
    }

    let mut read_future = dst_buf
        .cmd()
        .read(rw_vec)
        .queue(common_queue)
        .ewait(wait_event)
        .enq_async()
        .unwrap();

    // SAFETY: the callback only casts the user-data pointer back to an
    // integer and never dereferences it.
    unsafe {
        read_future
            .lock_event()
            .unwrap()
            .set_callback(on_verify_start, task_iter as *mut c_void)
            .unwrap();
    }

    // Publish the read's release event, replacing the previous one.
    *verify_add_event = Some(
        read_future
            .create_release_event(verify_add_unmap_queue)
            .unwrap()
            .clone(),
    );

    Box::new(read_future.and_then(move |mut data| {
        let expected = Int4::splat(correct_val);
        let mut checked = 0;

        for (idx, actual) in data.iter().enumerate() {
            if *actual != expected {
                return Err(format!(
                    "Verify add: Result value mismatch: {:?} != {:?} @ [{}] \
                    for task iter: [{}].",
                    actual, expected, idx, task_iter
                )
                .into());
            }
            checked += 1;
        }

        // Blank the vector so stale results cannot satisfy a later check.
        data.iter_mut().for_each(|slot| *slot = Int4::splat(0));

        if PRINT {
            println!("* Verify add complete \t(iter: {})", task_iter);
        }
        Ok(checked)
    }))
}
/// Creates a command queue for `device` in `context` with the given `flags`.
///
/// # Errors
///
/// An error whose API status is `CL_INVALID_VALUE` is replaced with a message
/// saying the device does not support out-of-order queues; any other error is
/// passed through.
fn create_queue(
    context: &Context,
    device: Device,
    flags: Option<CommandQueueProperties>,
) -> OclResult<Queue> {
    Queue::new(context, device, flags).map_err(|err| match err.api_status() {
        Some(Status::CL_INVALID_VALUE) => {
            "Device does not support out of order queues.".into()
        }
        _ => err.into(),
    })
}
/// Runs `TASK_ITERS` pipelined iterations on every device of every platform.
///
/// Each iteration enqueues, in order: a junk fill of the source buffer, a
/// write of the initial value, a verification of that value, the `add_slowly`
/// kernel, and a verification of the kernel's output. The three futures of an
/// iteration are joined and waited on in a dedicated thread.
///
/// The test fails if any iteration's future resolves to an error (for example
/// a value mismatch), if any worker thread panics, or if any setup step
/// returns an error.
#[test]
pub fn rw_vec() {
    for platform in Platform::list() {
        for device in Device::list_all(platform).unwrap() {
            println!(
                "Device: {} {}",
                device.vendor().unwrap(),
                device.name().unwrap()
            );

            let context = Context::builder()
                .platform(platform)
                .devices(device)
                .build()
                .unwrap();

            // Out-of-order queues are requested unless the `async_block`
            // feature is enabled.
            let queue_flags = if cfg!(feature = "async_block") {
                None
            } else {
                Some(CommandQueueProperties::new().out_of_order())
            };
            let common_queue = create_queue(&context, device, queue_flags).unwrap();
            let write_init_unmap_queue_0 = create_queue(&context, device, queue_flags).unwrap();
            let verify_init_queue = create_queue(&context, device, queue_flags).unwrap();
            let verify_add_unmap_queue = create_queue(&context, device, queue_flags).unwrap();

            let src_buf_flags = MemFlags::new().alloc_host_ptr().read_only();
            let dst_buf_flags = MemFlags::new()
                .alloc_host_ptr()
                .write_only()
                .host_read_only();

            let src_buf: Buffer<Int4> = Buffer::builder()
                .context(&context)
                .flags(src_buf_flags)
                .len(WORK_SIZE)
                .build()
                .unwrap();
            let dst_buf: Buffer<Int4> = Buffer::builder()
                .context(&context)
                .flags(dst_buf_flags)
                .len(WORK_SIZE)
                .build()
                .unwrap();

            let program = Program::builder()
                .devices(device)
                .src(KERN_SRC)
                .build(&context)
                .unwrap();
            let kern = Kernel::builder()
                .program(&program)
                .name("add_slowly")
                .global_work_size(WORK_SIZE)
                .arg(&src_buf)
                .arg(SCALAR_ADDEND)
                .arg(&dst_buf)
                .build()
                .unwrap();

            // Host-side vector shared by every stage that reads or writes
            // buffer contents.
            let host_vec: RwVec<Int4> = RwVec::from(vec![Default::default(); WORK_SIZE]);

            let mut threads = Vec::with_capacity(TASK_ITERS as usize);

            // Events carried from one stage (and iteration) to the next.
            let mut fill_event = None;
            let mut write_init_event = None;
            let mut verify_init_event = None;
            let mut kernel_event = None;
            let mut verify_add_event = None;

            if PRINT {
                println!("Starting cycles ...");
            }

            for task_iter in 0..TASK_ITERS {
                // Value written to the source buffer, and the value the
                // kernel is expected to produce from it.
                let ival = INIT_VAL + task_iter;
                let tval = ival + SCALAR_ADDEND;

                // Arguments are passed in the order the parameters are named
                // (`kernel_event`, then `verify_init_event`).
                fill_junk(
                    &src_buf,
                    &common_queue,
                    kernel_event.as_ref(),
                    verify_init_event.as_ref(),
                    &mut fill_event,
                    task_iter,
                );

                let write_init_future = write_init(
                    &src_buf,
                    &host_vec,
                    &common_queue,
                    &write_init_unmap_queue_0,
                    fill_event.as_ref(),
                    verify_init_event.as_ref(),
                    &mut write_init_event,
                    ival,
                    task_iter,
                );

                let verify_init_future = verify_init(
                    &src_buf,
                    &host_vec,
                    &common_queue,
                    &verify_init_queue,
                    write_init_event.as_ref(),
                    &mut verify_init_event,
                    ival,
                    task_iter,
                );

                kernel_add(
                    &kern,
                    &common_queue,
                    verify_add_event.as_ref(),
                    write_init_event.as_ref(),
                    &mut kernel_event,
                    task_iter,
                );

                let verify_add_future = verify_add(
                    &dst_buf,
                    &host_vec,
                    &common_queue,
                    &verify_add_unmap_queue,
                    kernel_event.as_ref(),
                    &mut verify_add_event,
                    tval,
                    task_iter,
                );

                if PRINT {
                    println!("All commands for iteration {} enqueued", task_iter);
                }

                let task = write_init_future.join3(verify_init_future, verify_add_future);

                threads.push(
                    thread::Builder::new()
                        .name(format!("task_iter_[{}]", task_iter))
                        .spawn(move || {
                            if PRINT {
                                println!("Waiting on task iter [{}]...", task_iter);
                            }
                            match task.wait() {
                                Ok(res) => {
                                    if PRINT {
                                        println!(
                                            "Task iter [{}] complete with result: {:?}",
                                            task_iter, res
                                        );
                                    }
                                    true
                                }
                                Err(err) => {
                                    // Always report failures: this output is
                                    // the only diagnostic when the test fails.
                                    println!(
                                        "\n############## ERROR (task iter: [{}]) \
                                        ############## \n{:?}\n",
                                        task_iter, err
                                    );
                                    false
                                }
                            }
                        })
                        .unwrap(),
                );
            }

            // Join every thread before asserting so that all failures get
            // reported, not just the first.
            let mut all_correct = true;
            for thread in threads {
                match thread.join() {
                    Ok(res) => {
                        if PRINT {
                            println!("Thread result: {:?}", res);
                        }
                        all_correct &= res;
                    }
                    Err(err) => panic!("{:?}", err),
                }
            }

            assert!(all_correct, "Errors found!");
            println!("All result values are correct.");
        }
    }
}