extern crate chrono;
extern crate futures;
extern crate futures_cpupool;
use self::chrono::{DateTime, Duration, Local};
use self::futures::Future;
use self::futures_cpupool::{Builder as CpuPoolBuilder, CpuFuture};
use crate::error::Error as OclError;
use crate::ffi::{c_void, cl_event};
use crate::flags::{CommandQueueProperties, MemFlags};
use crate::prm::Int4;
use crate::r#async::{BufferSink, BufferStream};
use crate::traits::IntoRawEventArray;
use crate::{Buffer, Context, Device, Event, Kernel, Platform, Program, Queue, RwVec};
use std::fmt::Debug;
use std::sync::mpsc::{self, Receiver};
use std::thread::{self, JoinHandle};
/// Number of `Int4` elements in the source/destination buffers and the host vector; also
/// used as the kernel's global work size.
const WORK_SIZE: usize = 1 << 18;
/// Base value written into the source buffer; the test adds the iteration index to it.
const INIT_VAL: i32 = 50;
/// Scalar passed as the kernel's `addend` argument; the test expects the kernel output to be
/// the written value plus this amount.
const SCALAR_ADDEND: i32 = 100;
/// Number of task iterations the test enqueues.
const TASK_ITERS: i32 = 10;
/// When `true`, the helpers and the test print progress messages with timestamps.
const PRINT: bool = false;
/// Used to size the bounded channel feeding the completion thread (the channel capacity is
/// this value minus 2).
const MAX_CONCURRENT_TASK_COUNT: usize = 4;
/// Reference time for `timestamp()`. The test assigns it before starting the cycles;
/// `timestamp()` panics if it is still `None`.
static mut START_TIME: Option<DateTime<Local>> = None;
/// OpenCL C source for the `add_slowly` kernel.
///
/// Each work item accumulates `sum` over `addend` loop iterations and stores
/// `in[idx] + sum` into `out[idx]`. The test in this file expects the stored value to equal
/// the input plus `addend`.
pub static KERN_SRC: &'static str = r#"
__kernel void add_slowly(
__global int4* in,
__private int addend,
__global int4* out)
{
uint const idx = get_global_id(0);
float4 const inflated_val = (float4)(addend) * (float4)(255.0);
int4 sum = (int4)(0);
for (int i = 0; i < addend; i++) {
sum += convert_int4((inflated_val / (float4)(255.0)) / (float4)(addend));
}
out[idx] = in[idx] + sum;
}
"#;
/// Formats `duration` as decimal seconds with exactly six fractional digits
/// (for example `1.500000`).
///
/// A negative duration is rendered with a single leading minus sign (`-1.500000`) rather
/// than with a sign embedded in the fractional part.
///
/// # Panics
///
/// Panics if the duration's total microsecond count does not fit in an `i64`.
pub fn fmt_duration(duration: Duration) -> String {
    const MICROS_PER_SEC: u64 = 1_000_000;

    let total_micros = duration
        .num_microseconds()
        .expect("duration overflows i64 microseconds");

    // Split the magnitude instead of the signed value: subtracting whole seconds from a
    // negative microsecond count leaves a negative remainder, which `{:06}` would print with
    // its own minus sign.
    let sign = if total_micros < 0 { "-" } else { "" };
    let magnitude = total_micros.unsigned_abs();

    format!(
        "{}{}.{:06}",
        sign,
        magnitude / MICROS_PER_SEC,
        magnitude % MICROS_PER_SEC
    )
}
/// Returns the time elapsed since `START_TIME`, formatted by `fmt_duration`.
///
/// # Panics
///
/// Panics if `START_TIME` has not been set yet.
pub fn timestamp() -> String {
    let now = chrono::Local::now();
    // NOTE(review): `START_TIME` is a `static mut` read without synchronization; this
    // presumably relies on it being written once before any other thread or callback calls
    // `timestamp()` — confirm.
    let started_at = unsafe { START_TIME.unwrap() };
    fmt_duration(now - started_at)
}
/// Spawns a thread that receives futures from `rx` and blocks on each one in arrival order.
///
/// The thread exits when it receives `None`. It panics if the channel is closed before a
/// `None` arrives or if any received future resolves to an error.
pub fn completion_thread<T, E>(rx: Receiver<Option<CpuFuture<T, E>>>) -> JoinHandle<()>
where
    T: Send + 'static,
    E: Send + Debug + 'static,
{
    thread::spawn(move || {
        let mut completed = 0usize;

        // `None` is the shutdown sentinel; `unwrap` turns a disconnected channel into a panic.
        while let Some(pending) = rx.recv().unwrap() {
            pending.wait().unwrap();
            if PRINT {
                println!("Task {} complete (t: {}s)", completed, timestamp());
            }
            completed += 1;
        }

        if PRINT {
            println!("All {} futures complete.", completed);
        }
    })
}
pub fn fill_junk(
src_buf: &Buffer<Int4>,
common_queue: &Queue,
verify_init_event: Option<&Event>,
kernel_event: Option<&Event>,
fill_event: &mut Option<Event>,
task_iter: i32,
) {
extern "C" fn _print_starting(_: cl_event, _: i32, task_iter: *mut c_void) {
if PRINT {
println!(
"* Fill starting \t\t(iter: {}, t: {}s) ...",
task_iter as usize,
timestamp()
);
}
}
extern "C" fn _print_complete(_: cl_event, _: i32, task_iter: *mut c_void) {
if PRINT {
println!(
"* Fill complete \t\t(iter: {}, t: {}s)",
task_iter as usize,
timestamp()
);
}
}
let wait_list = [&kernel_event, &verify_init_event].into_raw_array();
let fill_wait_marker = wait_list.to_marker(&common_queue).unwrap();
if let Some(ref marker) = fill_wait_marker {
unsafe {
marker
.set_callback(_print_starting, task_iter as *mut c_void)
.unwrap();
}
} else {
_print_starting(0 as cl_event, 0, task_iter as *mut c_void);
}
*fill_event = Some(Event::empty());
src_buf
.cmd()
.fill(Int4::new(-999, -999, -999, -999), None)
.queue(common_queue)
.ewait(&wait_list)
.enew(fill_event.as_mut())
.enq()
.unwrap();
unsafe {
fill_event
.as_ref()
.unwrap()
.set_callback(_print_complete, task_iter as *mut c_void)
.unwrap();
}
}
/// Builds a future that writes `write_val` into every lane of every element of the sink's
/// host memory and flushes the sink to its buffer.
///
/// Acquiring the write lock waits on `fill_event` and `verify_init_event`. The flush's
/// completion event is stored in `write_init_event`, replacing any previous value. The
/// returned future resolves to `task_iter` once both the write and the flush have finished.
///
/// # Panics
///
/// Panics if the sink's buffer has no default queue, if enqueuing the flush fails, or if
/// registering a callback fails.
pub fn write_init(
    src_buf_sink: &BufferSink<Int4>,
    fill_event: Option<&Event>,
    verify_init_event: Option<&Event>,
    write_init_event: &mut Option<Event>,
    write_val: i32,
    task_iter: i32,
) -> Box<dyn Future<Item = i32, Error = OclError> + Send> {
    extern "C" fn on_write_done(_: cl_event, _: i32, user_data: *mut c_void) {
        if PRINT {
            println!(
                "* Write init (write) complete \t(iter: {}, t: {}s)",
                user_data as usize,
                timestamp()
            );
        }
    }

    extern "C" fn on_flush_done(_: cl_event, _: i32, user_data: *mut c_void) {
        if PRINT {
            println!(
                "* Write init (flush) complete \t(iter: {}, t: {}s)",
                user_data as usize,
                timestamp()
            );
        }
    }

    // Iteration number smuggled through the callback user-data pointer; never dereferenced.
    let iter_tag = task_iter as *mut c_void;

    // Host-side write: lock the sink's memory, with a release event for the write itself.
    let mut write_complete_event = Event::empty();
    let sink_queue = src_buf_sink.buffer().default_queue().unwrap();
    let pending_guard = src_buf_sink
        .clone()
        .write()
        .ewait_lock([&fill_event, &verify_init_event])
        .enew_release(sink_queue, &mut write_complete_event);
    unsafe {
        write_complete_event
            .set_callback(on_write_done, iter_tag)
            .unwrap();
    }

    // Flush; its event is what later stages wait on.
    *write_init_event = Some(Event::empty());
    let pending_flush = src_buf_sink
        .clone()
        .flush()
        .enew(write_init_event.as_mut())
        .enq()
        .unwrap();
    let flush_event = write_init_event.as_ref().unwrap();
    unsafe {
        flush_event.set_callback(on_flush_done, iter_tag).unwrap();
    }

    let pending_write = pending_guard.and_then(move |mut guard| {
        if PRINT {
            println!(
                "* Write init starting \t\t(iter: {}, t: {}s) ...",
                task_iter,
                timestamp()
            );
        }
        let fill = Int4::new(write_val, write_val, write_val, write_val);
        guard.iter_mut().for_each(|slot| *slot = fill);
        Ok(task_iter)
    });

    Box::new(
        pending_write
            .join(pending_flush)
            .map(|(iter_id, _)| iter_id),
    )
}
/// Reads `src_buf` into `dst_vec` and checks that every element equals `correct_val` in all
/// four lanes.
///
/// The read waits on `write_init_event` and on the previous value of `verify_init_event`.
/// `verify_init_event` is then replaced with the read's release event. The returned future
/// resolves to the number of elements checked, or to an error describing the first mismatch.
///
/// # Panics
///
/// Panics if enqueuing the read, obtaining its lock event, registering the callback, or
/// creating the release event fails.
pub fn verify_init(
    src_buf: &Buffer<Int4>,
    dst_vec: &RwVec<Int4>,
    common_queue: &Queue,
    write_init_event: Option<&Event>,
    verify_init_event: &mut Option<Event>,
    correct_val: i32,
    task_iter: i32,
) -> Box<dyn Future<Item = i32, Error = OclError> + Send> {
    extern "C" fn on_verify_start(_: cl_event, _: i32, user_data: *mut c_void) {
        if PRINT {
            println!(
                "* Verify init starting \t\t(iter: {}, t: {}s) ...",
                user_data as usize,
                timestamp()
            );
        }
    }

    // The wait list is built before `verify_init_event` is overwritten below, so it still
    // refers to the previous iteration's event.
    let wait_list = [&write_init_event, &verify_init_event.as_ref()].into_raw_array();

    let mut pending_read = src_buf
        .cmd()
        .read(dst_vec)
        .queue(common_queue)
        .ewait(&wait_list)
        .enq_async()
        .unwrap();

    unsafe {
        pending_read
            .lock_event()
            .unwrap()
            // The user-data pointer only carries the iteration number; never dereferenced.
            .set_callback(on_verify_start, task_iter as *mut c_void)
            .unwrap();
    }

    let release_event = pending_read
        .create_release_event(common_queue)
        .unwrap()
        .clone();
    *verify_init_event = Some(release_event);

    Box::new(pending_read.and_then(move |data| {
        let expected = Int4::new(correct_val, correct_val, correct_val, correct_val);

        // Count matching elements, stopping at the first mismatch.
        let outcome: Result<i32, OclError> =
            data.iter()
                .enumerate()
                .try_fold(0, |matched, (idx, val)| {
                    if *val != expected {
                        Err(format!(
                            "Verify init: Result value mismatch: {:?} != {:?} @ [{}]",
                            val, expected, idx
                        )
                        .into())
                    } else {
                        Ok(matched + 1)
                    }
                });

        outcome.map(|val_count| {
            if PRINT {
                println!(
                    "* Verify init complete \t\t(iter: {}, t: {}s)",
                    task_iter,
                    timestamp()
                );
            }
            val_count
        })
    }))
}
pub fn kernel_add(
kern: &Kernel,
common_queue: &Queue,
verify_add_event: Option<&Event>,
write_init_event: Option<&Event>,
kernel_event: &mut Option<Event>,
task_iter: i32,
) {
extern "C" fn _print_starting(_: cl_event, _: i32, task_iter: *mut c_void) {
if PRINT {
println!(
"* Kernel starting \t\t(iter: {}, t: {}s) ...",
task_iter as usize,
timestamp()
);
}
}
extern "C" fn _print_complete(_: cl_event, _: i32, task_iter: *mut c_void) {
if PRINT {
println!(
"* Kernel complete \t\t(iter: {}, t: {}s)",
task_iter as usize,
timestamp()
);
}
}
let wait_list = [&verify_add_event, &write_init_event].into_raw_array();
let kernel_wait_marker = wait_list.to_marker(&common_queue).unwrap();
unsafe {
kernel_wait_marker
.as_ref()
.unwrap()
.set_callback(_print_starting, task_iter as *mut c_void)
.unwrap();
}
*kernel_event = Some(Event::empty());
unsafe {
kern.cmd()
.queue(common_queue)
.ewait(&wait_list)
.enew(kernel_event.as_mut())
.enq()
.unwrap();
}
unsafe {
kernel_event
.as_ref()
.unwrap()
.set_callback(_print_complete, task_iter as *mut c_void)
.unwrap();
}
}
/// Floods the stream's host memory from its buffer and checks that every element equals
/// `correct_val` in all four lanes.
///
/// The flood waits on `kernel_event`. `verify_add_event` is replaced with a fresh event that
/// is used as the read's release event. The returned future resolves, once both the flood and
/// the read have finished, to the number of elements checked, or to an error describing the
/// first mismatch.
///
/// # Panics
///
/// Panics if `kernel_event` is `None`, if the stream's buffer has no default queue, or if
/// registering the callback or enqueuing the flood fails.
pub fn verify_add(
    dst_buf_stream: &BufferStream<Int4>,
    kernel_event: Option<&Event>,
    verify_add_event: &mut Option<Event>,
    correct_val: i32,
    task_iter: i32,
) -> Box<dyn Future<Item = i32, Error = OclError> + Send> {
    extern "C" fn on_verify_start(_: cl_event, _: i32, user_data: *mut c_void) {
        if PRINT {
            println!(
                "* Verify add starting \t\t(iter: {}, t: {}s) ...",
                user_data as usize,
                timestamp()
            );
        }
    }

    // `Option<&Event>` is `Copy`, so unwrapping here leaves `kernel_event` usable below.
    let kernel_done = kernel_event.unwrap();
    unsafe {
        kernel_done
            // The user-data pointer only carries the iteration number; never dereferenced.
            .set_callback(on_verify_start, task_iter as *mut c_void)
            .unwrap();
    }

    let pending_flood = dst_buf_stream
        .clone()
        .flood()
        .ewait(kernel_event)
        .enq()
        .unwrap();

    *verify_add_event = Some(Event::empty());

    let stream_queue = dst_buf_stream.buffer().default_queue().unwrap();
    let pending_guard = dst_buf_stream
        .clone()
        .read()
        .enew_release(stream_queue, verify_add_event.as_mut().unwrap());

    let pending_check = pending_guard.and_then(move |data| {
        let expected = Int4::splat(correct_val);

        // Count matching elements, stopping at the first mismatch.
        let outcome: Result<i32, OclError> =
            data.iter()
                .enumerate()
                .try_fold(0, |matched, (idx, val)| {
                    if *val != expected {
                        Err(format!(
                            "Verify add: Result value mismatch: {:?} != {:?} @ [{}]",
                            val, expected, idx
                        )
                        .into())
                    } else {
                        Ok(matched + 1)
                    }
                });

        outcome.map(|val_count| {
            if PRINT {
                println!(
                    "* Verify add complete \t\t(iter: {}, t: {}s)",
                    task_iter,
                    timestamp()
                );
            }
            val_count
        })
    });

    Box::new(
        pending_flood
            .join(pending_check)
            .map(|(_, val_count)| val_count),
    )
}
/// Runs `TASK_ITERS` cycles of fill -> write -> verify -> kernel -> verify against a single
/// device, chaining the stages with events and checking the results on a CPU pool.
///
/// Each iteration enqueues all of its commands up front and hands the joined future to a
/// thread pool; a separate completion thread waits on the spawned futures in order. The
/// test panics on any setup failure or when a verification future resolves to an error.
#[test]
pub fn buffer_sink_stream_cycles() {
// Pick the default platform and its first device.
let platform = Platform::default();
println!("Platform: {}", platform.name().unwrap());
let device = Device::first(platform).unwrap();
println!(
"Device: {} {}",
device.vendor().unwrap(),
device.name().unwrap()
);
let context = Context::builder()
.platform(platform)
.devices(device)
.build()
.unwrap();
// Request an out-of-order queue first; if that fails, retry with no queue properties.
let queue_flags = Some(CommandQueueProperties::new().out_of_order());
let common_queue = Queue::new(&context, device, queue_flags)
.or_else(|_| Queue::new(&context, device, None))
.unwrap();
// Source buffer is flagged read-only, destination write-only with host read-only access;
// both request host-pointer allocation.
let src_buf_flags = MemFlags::new().alloc_host_ptr().read_only();
let dst_buf_flags = MemFlags::new()
.alloc_host_ptr()
.write_only()
.host_read_only();
let src_buf: Buffer<Int4> = Buffer::builder()
.context(&context)
.flags(src_buf_flags)
.len(WORK_SIZE)
.build()
.unwrap();
// Sink covering the whole source buffer (offset 0, full length), sharing the common queue.
let src_buf_sink = unsafe {
BufferSink::from_buffer(
src_buf.clone(),
Some(common_queue.clone()),
0,
src_buf.len(),
)
.unwrap()
};
let dst_buf: Buffer<Int4> = Buffer::builder()
.context(&context)
.flags(dst_buf_flags)
.len(WORK_SIZE)
.build()
.unwrap();
// Stream covering the whole destination buffer (offset 0, full length), sharing the common
// queue.
let dst_buf_stream = unsafe {
BufferStream::from_buffer(
dst_buf.clone(),
Some(common_queue.clone()),
0,
dst_buf.len(),
)
.unwrap()
};
let program = Program::builder()
.devices(device)
.src(KERN_SRC)
.build(&context)
.unwrap();
// Kernel arguments follow the `add_slowly` signature: in, addend, out.
let kern = Kernel::builder()
.program(&program)
.name("add_slowly")
.global_work_size(WORK_SIZE)
.arg(&src_buf)
.arg(SCALAR_ADDEND)
.arg(&dst_buf)
.build()
.unwrap();
// Host-side vector that `verify_init` reads the source buffer into.
let rw_vec: RwVec<Int4> = RwVec::from(vec![Default::default(); WORK_SIZE]);
let thread_pool = CpuPoolBuilder::new().name_prefix("pool_th_").create();
// Bounded channel: `send` blocks once `MAX_CONCURRENT_TASK_COUNT - 2` futures are queued,
// which limits how far enqueuing can run ahead of the completion thread.
let (tx, rx) = mpsc::sync_channel::<Option<CpuFuture<_, _>>>(MAX_CONCURRENT_TASK_COUNT - 2);
let completion_thread = completion_thread(rx);
// Per-stage events; each iteration's stages wait on events left by earlier stages,
// including ones from the previous iteration.
let mut fill_event = None;
let mut write_init_event = None;
let mut verify_init_event: Option<Event> = None;
let mut kernel_event = None;
let mut verify_add_event = None;
// Set the reference time before anything calls `timestamp()`.
unsafe {
START_TIME = Some(Local::now());
}
if PRINT {
println!("Starting cycles (t: {}s) ...", timestamp());
}
for task_iter in 0..TASK_ITERS {
// `ival` is written to the source buffer; `tval` is the value expected after the kernel.
let ival = INIT_VAL + task_iter;
let tval = ival + SCALAR_ADDEND;
// Stage 1: overwrite the source buffer with junk.
fill_junk(
&src_buf,
&common_queue,
verify_init_event.as_ref(),
kernel_event.as_ref(),
&mut fill_event,
task_iter,
);
// Stage 2: write `ival` through the sink and flush it to the buffer.
let write_init = write_init(
&src_buf_sink,
fill_event.as_ref(),
verify_init_event.as_ref(),
&mut write_init_event,
ival,
task_iter,
);
// Stage 3: read the source buffer back and check it holds `ival`.
let verify_init = verify_init(
&src_buf,
&rw_vec,
&common_queue,
write_init_event.as_ref(),
&mut verify_init_event,
ival,
task_iter,
);
// Stage 4: run the kernel.
kernel_add(
&kern,
&common_queue,
verify_add_event.as_ref(),
write_init_event.as_ref(),
&mut kernel_event,
task_iter,
);
// Stage 5: flood the stream from the destination buffer and check it holds `tval`.
let verify_add = verify_add(
&dst_buf_stream,
kernel_event.as_ref(),
&mut verify_add_event,
tval,
task_iter,
);
if PRINT {
println!(
"All commands for iteration {} enqueued (t: {}s)",
task_iter,
timestamp()
);
}
// Drive the three host-side futures on the pool and hand the handle to the completion
// thread.
let join = write_init.join3(verify_init, verify_add);
let join_spawned = thread_pool.spawn(join);
tx.send(Some(join_spawned)).unwrap();
}
// `None` tells the completion thread to stop; joining it propagates any panic it raised.
tx.send(None).unwrap();
completion_thread.join().unwrap();
println!(
"All result values are correct! \n\
Duration => | Total: {} seconds |",
timestamp()
);
}