extern crate chrono;
extern crate futures;
extern crate futures_cpupool;
extern crate ocl;
extern crate ocl_extras as extras;
extern crate rand;
#[macro_use]
extern crate colorify;
use crate::extras::{Command, CommandDetails, CommandGraph, KernelArgBuffer, SubBufferPool};
use futures::sync::mpsc::{self, Sender};
use futures::{stream, Future, Join, Sink, Stream};
use futures_cpupool::{CpuFuture, CpuPool};
use ocl::error::Error as OclError;
use ocl::flags::{CommandQueueProperties, MapFlags, MemFlags};
use ocl::prm::Float4;
use ocl::{
Context, Device, Event, EventList, FutureMemMap, Kernel, OclPrm, Platform, Program, Queue,
Result as OclResult,
};
use rand::{rngs::SmallRng, Rng, SeedableRng};
/// Length (in `Float4` elements) of the backing buffer that the
/// sub-buffer pool carves allocations out of.
const INITIAL_BUFFER_LEN: u32 = 1 << 24;
/// Minimum work size (in elements) a task's sub-buffers may have.
const SUB_BUF_MIN_LEN: u32 = 1 << 15;
/// Maximum work size (in elements) a task's sub-buffers may have.
const SUB_BUF_MAX_LEN: u32 = 1 << 19;
/// Which of the two example task topologies a `Task` was built as.
enum TaskKind {
    // One write -> one kernel -> one read (see `create_simple_task`).
    Simple,
    // Seven buffers: write, kernel, two copies, a fill, two more kernels,
    // then a read (see `create_complex_task`).
    Complex,
}
/// A chain of commands (writes, kernels, copies, fills, reads) described by
/// a `CommandGraph`, together with the kernels it runs and the result each
/// element is expected to hold after completion.
#[allow(dead_code)]
struct Task {
    /// Index of this task within the run (also used in log output).
    task_id: usize,
    /// Dependency graph of the task's commands; supplies requisite events
    /// for each enqueue and records each command's completion event.
    cmd_graph: CommandGraph,
    /// Kernels owned by this task, referenced by index from kernel commands.
    kernels: Vec<Kernel>,
    /// Per-element value expected after the task runs (complex tasks only;
    /// simple tasks hard-code 150.0 in their verify step).
    expected_result: Option<Float4>,
    /// Simple or complex topology.
    kind: TaskKind,
    /// Length (in elements) of each of this task's sub-buffers.
    work_size: u32,
    /// Scratch list populated by `get_finish_events`.
    finish_events: EventList,
}
impl Task {
    /// Creates an empty task; commands and kernels are added afterwards via
    /// the `add_*` methods.
    pub fn new(task_id: usize, kind: TaskKind, work_size: u32) -> Task {
        Task {
            task_id: task_id,
            cmd_graph: CommandGraph::new(),
            kernels: Vec::new(),
            expected_result: None,
            kind: kind,
            work_size: work_size,
            finish_events: EventList::new(),
        }
    }

    /// Adds a fill command targeting `target_buffer_id` to the graph and
    /// returns the command's index.
    pub fn add_fill_command(&mut self, target_buffer_id: usize) -> Result<usize, ()> {
        self.cmd_graph.add(Command::new(CommandDetails::Fill {
            target: target_buffer_id,
        }))
    }

    /// Adds a host-write command targeting `target_buffer_id` and returns
    /// the command's index.
    pub fn add_write_command(&mut self, target_buffer_id: usize) -> Result<usize, ()> {
        self.cmd_graph.add(Command::new(CommandDetails::Write {
            target: target_buffer_id,
        }))
    }

    /// Adds a host-read command sourcing `source_buffer_id` and returns the
    /// command's index.
    pub fn add_read_command(&mut self, source_buffer_id: usize) -> Result<usize, ()> {
        self.cmd_graph.add(Command::new(CommandDetails::Read {
            source: source_buffer_id,
        }))
    }

    /// Stores `kernel` and adds a kernel command that reads from
    /// `source_buffer_ids` and writes to `target_buffer_ids`; returns the
    /// command's index. The kernel is referenced by its position in
    /// `self.kernels`.
    pub fn add_kernel(
        &mut self,
        kernel: Kernel,
        source_buffer_ids: Vec<KernelArgBuffer>,
        target_buffer_ids: Vec<KernelArgBuffer>,
    ) -> Result<usize, ()> {
        self.kernels.push(kernel);
        self.cmd_graph.add(Command::new(CommandDetails::Kernel {
            id: self.kernels.len() - 1,
            sources: source_buffer_ids,
            targets: target_buffer_ids,
        }))
    }

    /// Adds a buffer-to-buffer copy command and returns the command's index.
    pub fn add_copy_command(
        &mut self,
        source_buffer_id: usize,
        target_buffer_id: usize,
    ) -> Result<usize, ()> {
        self.cmd_graph.add(Command::new(CommandDetails::Copy {
            source: source_buffer_id,
            target: target_buffer_id,
        }))
    }

    /// Records the per-element value every output element should equal once
    /// the task completes.
    pub fn set_expected_result(&mut self, expected_result: Float4) {
        self.expected_result = Some(expected_result)
    }

    /// Clears and repopulates `finish_events` with the graph's final
    /// completion events, then returns it.
    #[allow(dead_code)]
    pub fn get_finish_events(&mut self) -> &mut EventList {
        self.finish_events.clear();
        self.cmd_graph.get_finish_events(&mut self.finish_events);
        &mut self.finish_events
    }

    /// Enqueues the fill command at `cmd_idx`, waiting on the graph's
    /// requisite events, and registers the fill's completion event with the
    /// graph. Panics if the command at `cmd_idx` is not a fill.
    pub fn fill<T: OclPrm>(&self, pattern: T, cmd_idx: usize, buf_pool: &SubBufferPool<T>) {
        let buffer_id = match *self.cmd_graph.commands()[cmd_idx].details() {
            CommandDetails::Fill { target } => target,
            _ => panic!("Task::fill: Not a fill command."),
        };
        let mut ev = Event::empty();
        let buf = buf_pool.get(buffer_id).unwrap();
        buf.cmd()
            .fill(pattern, None)
            .ewait(self.cmd_graph.get_req_events(cmd_idx).unwrap())
            .enew(&mut ev)
            .enq()
            .unwrap();
        self.cmd_graph.set_cmd_event(cmd_idx, ev).unwrap();
    }

    /// Enqueues a map for the write or read command at `cmd_idx` and
    /// returns the future mapped memory. Panics if the command is neither.
    ///
    /// Write commands map with `write_invalidate_region`: the map itself
    /// needs no wait list, but the *unmap* waits on the command's requisite
    /// events. Read commands instead make the map wait on the requisite
    /// events. In both cases the unmap event is registered with the graph
    /// as this command's completion event.
    pub fn map<T: OclPrm>(&self, cmd_idx: usize, buf_pool: &SubBufferPool<T>) -> FutureMemMap<T> {
        let (buffer_id, flags, is_write) = match *self.cmd_graph.commands()[cmd_idx].details() {
            CommandDetails::Write { target } => {
                (target, MapFlags::new().write_invalidate_region(), true)
            }
            CommandDetails::Read { source } => (source, MapFlags::new().read(), false),
            _ => panic!("Task::map: Not a write or read command."),
        };
        let buf = buf_pool.get(buffer_id).unwrap();
        let (map_wait_list, unmap_wait_list) = if is_write {
            (
                None,
                Some(self.cmd_graph.get_req_events(cmd_idx).unwrap().clone()),
            )
        } else {
            (Some(self.cmd_graph.get_req_events(cmd_idx).unwrap()), None)
        };
        // SAFETY-ish (review): the unchecked map is ocl's `enq_async`
        // contract; the buffer outlives the future because the pool owns it.
        let mut future_data = unsafe {
            buf.cmd()
                .map()
                .flags(flags)
                .ewait(map_wait_list)
                .enq_async()
                .unwrap()
        };
        if is_write {
            future_data.set_unmap_wait_events(unmap_wait_list.unwrap());
        }
        let unmap_event_target = future_data.create_unmap_event().unwrap().clone();
        self.cmd_graph
            .set_cmd_event(cmd_idx, unmap_event_target.into())
            .unwrap();
        future_data
    }

    /// Enqueues the copy command at `cmd_idx` (source -> target, full
    /// length), waiting on the graph's requisite events, and registers its
    /// completion event. Panics if the command is not a copy.
    pub fn copy<T: OclPrm>(&self, cmd_idx: usize, buf_pool: &SubBufferPool<T>) {
        let (src_buf_id, tar_buf_id) = match *self.cmd_graph.commands()[cmd_idx].details() {
            CommandDetails::Copy { source, target } => (source, target),
            _ => panic!("Task::copy: Not a copy command."),
        };
        let mut ev = Event::empty();
        let src_buf = buf_pool.get(src_buf_id).unwrap();
        let tar_buf = buf_pool.get(tar_buf_id).unwrap();
        src_buf
            .cmd()
            .copy(tar_buf, None, None)
            .ewait(self.cmd_graph.get_req_events(cmd_idx).unwrap())
            .enew(&mut ev)
            .enq()
            .unwrap();
        self.cmd_graph.set_cmd_event(cmd_idx, ev).unwrap();
    }

    /// Enqueues the kernel command at `cmd_idx`, waiting on the graph's
    /// requisite events, and registers its completion event. Panics if the
    /// command is not a kernel.
    pub fn kernel(&self, cmd_idx: usize) {
        let kernel_id = match *self.cmd_graph.commands()[cmd_idx].details() {
            CommandDetails::Kernel { id, .. } => id,
            _ => panic!("Task::kernel: Not a kernel command."),
        };
        let mut ev = Event::empty();
        // Kernel enqueue is `unsafe` in ocl; the kernel and its buffer args
        // are kept alive by `self` and the pool for the enqueue's duration.
        unsafe {
            self.kernels[kernel_id]
                .cmd()
                .enew(&mut ev)
                .ewait(self.cmd_graph.get_req_events(cmd_idx).unwrap())
                .enq()
                .unwrap();
        }
        self.cmd_graph.set_cmd_event(cmd_idx, ev).unwrap();
    }
}
/// Host-side sign coefficient matching a kernel generated with `add`:
/// `1.0` when the kernel adds, `-1.0` when it subtracts.
fn coeff(add: bool) -> f32 {
    match add {
        true => 1.,
        false => -1.,
    }
}
/// Generates OpenCL C source for a kernel named `kernel_name` operating on
/// `type_str` elements.
///
/// `simple` selects one input buffer (`out = in ± values`) versus three
/// (`out = in_0 ± in_1 ± in_2 ± values`); `add` selects `+` or `-`.
fn gen_kern_src(kernel_name: &str, type_str: &str, simple: bool, add: bool) -> String {
    let op = match add {
        true => "+",
        false => "-",
    };
    // Three-input variant.
    if !simple {
        return format!(
            r#"__kernel void {kn}(
__global {ts}* in_0,
__global {ts}* in_1,
__global {ts}* in_2,
{ts} values,
__global {ts}* out)
{{
uint idx = get_global_id(0);
out[idx] = in_0[idx] {op} in_1[idx] {op} in_2[idx] {op} values;
}}"#,
            kn = kernel_name,
            op = op,
            ts = type_str
        );
    }
    // Single-input variant.
    format!(
        r#"__kernel void {kn}(
__global {ts}* in,
{ts} values,
__global {ts}* out)
{{
uint idx = get_global_id(0);
out[idx] = in[idx] {op} values;
}}"#,
        kn = kernel_name,
        op = op,
        ts = type_str
    )
}
/// Builds a simple task: host write (cmd 0) -> kernel adding 100.0 (cmd 1)
/// -> host read (cmd 2).
///
/// Allocates two `work_size`-element sub-buffers from `buf_pool`, assigning
/// `queues[0]`/`queues[1]` to them and `queues[2]` to the kernel. Returns
/// `Err(())` when the pool cannot satisfy an allocation (the first buffer
/// is freed again if the second allocation fails), which the caller treats
/// as "pool full".
fn create_simple_task(
    task_id: usize,
    device: Device,
    context: &Context,
    buf_pool: &mut SubBufferPool<Float4>,
    work_size: u32,
    queues: &[Queue],
) -> Result<Task, ()> {
    // Host only writes the input buffer and only reads the output buffer.
    let write_buf_flags = Some(MemFlags::new().read_only() | MemFlags::new().host_write_only());
    let read_buf_flags = Some(MemFlags::new().write_only() | MemFlags::new().host_read_only());
    let mut task = Task::new(task_id, TaskKind::Simple, work_size);
    let write_buf_id = match buf_pool.alloc(work_size, write_buf_flags) {
        Ok(buf_id) => buf_id,
        Err(_) => return Err(()),
    };
    let read_buf_id = match buf_pool.alloc(work_size, read_buf_flags) {
        Ok(buf_id) => buf_id,
        Err(_) => {
            // Roll back the first allocation so the pool stays consistent.
            buf_pool.free(write_buf_id).ok();
            return Err(());
        }
    };
    buf_pool
        .get_mut(write_buf_id)
        .unwrap()
        .set_default_queue(queues[0].clone());
    buf_pool
        .get_mut(read_buf_id)
        .unwrap()
        .set_default_queue(queues[1].clone());
    // Single-input kernel: out[idx] = in[idx] + 100.0 (per component).
    let program = Program::builder()
        .devices(device)
        .src(gen_kern_src("kern", "float4", true, true))
        .build(context)
        .unwrap();
    let kern = Kernel::builder()
        .program(&program)
        .name("kern")
        .queue(queues[2].clone())
        .global_work_size(work_size)
        .arg(buf_pool.get(write_buf_id).unwrap())
        .arg(Float4::new(100., 100., 100., 100.))
        .arg(buf_pool.get(read_buf_id).unwrap())
        .build()
        .unwrap();
    // Command indices are asserted because the enqueue functions refer to
    // commands by these fixed indices (0 = write, 1 = kernel, 2 = read).
    assert!(task.add_write_command(write_buf_id).unwrap() == 0);
    assert!(
        task.add_kernel(
            kern,
            vec![KernelArgBuffer::new(0, write_buf_id)],
            vec![KernelArgBuffer::new(2, read_buf_id)]
        )
        .unwrap()
        == 1
    );
    assert!(task.add_read_command(read_buf_id).unwrap() == 2);
    // Resolve inter-command dependencies now that the graph is complete.
    task.cmd_graph.populate_requisites();
    Ok(task)
}
/// Enqueues every command of a simple task and returns a joined future that
/// resolves when both host-side stages have run on the thread pool.
///
/// The write stage maps command 0 and fills each element with 50.0; the
/// kernel (command 1) is enqueued in between; the verify stage maps
/// command 2, checks every element equals 150.0, and sends the count of
/// correct values through `tx`.
fn enqueue_simple_task(
    task: &Task,
    buf_pool: &SubBufferPool<Float4>,
    thread_pool: &CpuPool,
    tx: Sender<usize>,
) -> Join<CpuFuture<usize, OclError>, CpuFuture<Sender<usize>, OclError>> {
    let task_id = task.task_id;
    // Stage 0: initialize the input buffer via a mapped write.
    let write = task.map(0, &buf_pool).and_then(move |mut data| {
        for val in data.iter_mut() {
            *val = Float4::new(50., 50., 50., 50.);
        }
        printlnc!(green: "Task [{}] (simple): Buffer initialized.", task_id);
        Ok(task_id)
    });
    let write_spawned = thread_pool.spawn(write);
    // Stage 1: the kernel; its requisite events order it after the unmap.
    task.kernel(1);
    // Stage 2: map the output and verify 50.0 + 100.0 == 150.0 everywhere.
    let verify = task
        .map(2, &buf_pool)
        .and_then(move |data| {
            let mut val_count = 0usize;
            for val in data.iter() {
                let correct_val = Float4::new(150., 150., 150., 150.);
                if *val != correct_val {
                    return Err(
                        format!("Result value mismatch: {:?} != {:?}", val, correct_val).into(),
                    );
                }
                val_count += 1;
            }
            printlnc!(yellow: "Task [{}] (simple): Verify successful: \
                {} values correct.", task_id, val_count);
            // Report the count to the aggregating channel; the send itself
            // is a future resolved by the following `and_then`.
            Ok(tx.send(val_count))
        })
        .and_then(|send| send.map_err(|e| OclError::from(e)));
    let verify_spawned = thread_pool.spawn(verify);
    write_spawned.join(verify_spawned)
}
/// Builds a complex task over seven sub-buffers:
///
/// ```text
/// write b0 (0) -> kernel_a b0->b1 (1) -> copy b1->b2 (2), copy b1->b3 (3)
/// fill b4 (4) -> kernel_b b2,b3,b4->b5 (5) -> kernel_c b5->b6 (6)
/// -> read b6 (7)
/// ```
///
/// Kernel signs and scalar operands are randomized via `rng`; the expected
/// per-element result is computed on the host and stored on the task.
/// Buffers use `queues[0..=6]`; all kernels use `queues[7]`. Returns
/// `Err(())` (with already-allocated buffers freed) when the pool is full.
fn create_complex_task(
    task_id: usize,
    device: Device,
    context: &Context,
    buf_pool: &mut SubBufferPool<Float4>,
    work_size: u32,
    queues: &[Queue],
    rng: &mut SmallRng,
) -> Result<Task, ()> {
    let mut task = Task::new(task_id, TaskKind::Complex, work_size);
    let buffer_count = 7;
    // b0: host-written input; b1..=b5: device-only intermediates;
    // b6: host-read output.
    let buffer_id_res: Vec<_> = (0..buffer_count)
        .map(|i| {
            let flags = match i {
                0 => Some(MemFlags::new().read_only().host_write_only()),
                1..=5 => Some(MemFlags::new().read_write().host_no_access()),
                6 => Some(MemFlags::new().write_only().host_read_only()),
                _ => panic!("Only 7 buffers are configured."),
            };
            buf_pool.alloc(work_size, flags)
        })
        .collect();
    let mut buffer_ids = Vec::with_capacity(buffer_count);
    for idx in 0..buffer_count {
        match buffer_id_res[idx] {
            Ok(buf_id) => {
                buf_pool
                    .get_mut(buf_id)
                    .unwrap()
                    .set_default_queue(queues[idx].clone());
                buffer_ids.push(buf_id)
            }
            Err(_) => {
                // Roll back every allocation that succeeded before the
                // failure so the pool stays consistent.
                for prev_idx in 0..idx {
                    buf_pool.free(buffer_id_res[prev_idx].unwrap()).ok();
                }
                return Err(());
            }
        }
    }
    // Random add/subtract sign and scalar operand for each kernel.
    let kern_a_sign = rng.gen();
    let kern_b_sign = rng.gen();
    let kern_c_sign = rng.gen();
    let kern_a_val = rng.gen_range(-1000. ..1000.);
    let kern_b_val = rng.gen_range(-500. ..500.);
    let kern_c_val = rng.gen_range(-2000. ..2000.);
    // kernel_a/kernel_c: single-input; kernel_b: three-input.
    let program = Program::builder()
        .devices(device)
        .src(gen_kern_src("kernel_a", "float4", true, kern_a_sign))
        .src(gen_kern_src("kernel_b", "float4", false, kern_b_sign))
        .src(gen_kern_src("kernel_c", "float4", true, kern_c_sign))
        .build(context)
        .unwrap();
    let kernel_a = Kernel::builder()
        .program(&program)
        .name("kernel_a")
        .queue(queues[7].clone())
        .global_work_size(work_size)
        .arg(buf_pool.get(buffer_ids[0]).unwrap())
        .arg(&Float4::new(kern_a_val, kern_a_val, kern_a_val, kern_a_val))
        .arg(buf_pool.get(buffer_ids[1]).unwrap())
        .build()
        .unwrap();
    let kernel_b = Kernel::builder()
        .program(&program)
        .name("kernel_b")
        .queue(queues[7].clone())
        .global_work_size(work_size)
        .arg(buf_pool.get(buffer_ids[2]).unwrap())
        .arg(buf_pool.get(buffer_ids[3]).unwrap())
        .arg(buf_pool.get(buffer_ids[4]).unwrap())
        .arg(&Float4::new(kern_b_val, kern_b_val, kern_b_val, kern_b_val))
        .arg(buf_pool.get(buffer_ids[5]).unwrap())
        .build()
        .unwrap();
    let kernel_c = Kernel::builder()
        .program(&program)
        .name("kernel_c")
        .queue(queues[7].clone())
        .global_work_size(work_size)
        .arg(buf_pool.get(buffer_ids[5]).unwrap())
        .arg(&Float4::new(kern_c_val, kern_c_val, kern_c_val, kern_c_val))
        .arg(buf_pool.get(buffer_ids[6]).unwrap())
        .build()
        .unwrap();
    // Command indices are asserted because `enqueue_complex_task` refers to
    // commands by these fixed indices (0..=7).
    assert!(task.add_write_command(buffer_ids[0]).unwrap() == 0);
    assert!(
        task.add_kernel(
            kernel_a,
            vec![KernelArgBuffer::new(0, buffer_ids[0])],
            vec![KernelArgBuffer::new(2, buffer_ids[1])]
        )
        .unwrap()
        == 1
    );
    assert!(task.add_copy_command(buffer_ids[1], buffer_ids[2]).unwrap() == 2);
    assert!(task.add_copy_command(buffer_ids[1], buffer_ids[3]).unwrap() == 3);
    assert!(task.add_fill_command(buffer_ids[4]).unwrap() == 4);
    assert!(
        task.add_kernel(
            kernel_b,
            vec![
                KernelArgBuffer::new(0, buffer_ids[2]),
                KernelArgBuffer::new(1, buffer_ids[3]),
                KernelArgBuffer::new(2, buffer_ids[4])
            ],
            vec![KernelArgBuffer::new(4, buffer_ids[5])]
        )
        .unwrap()
        == 5
    );
    assert!(
        task.add_kernel(
            kernel_c,
            vec![KernelArgBuffer::new(0, buffer_ids[5])],
            vec![KernelArgBuffer::new(2, buffer_ids[6])]
        )
        .unwrap()
        == 6
    );
    assert!(task.add_read_command(buffer_ids[6]).unwrap() == 7);
    // Host-side mirror of the device computation. The input is written as
    // 500.0, b2 and b3 are copies of kernel_a's output, and b4 is filled
    // with 50.0 (see `enqueue_complex_task`), so:
    //   a = 500 ± a_val
    //   b = a ± a ± 50 ± b_val
    //   c = b ± c_val
    let kern_a_out_val = 500. + (coeff(kern_a_sign) * kern_a_val);
    let kern_b_out_val = kern_a_out_val
        + (coeff(kern_b_sign) * kern_a_out_val)
        + (coeff(kern_b_sign) * 50.)
        + (coeff(kern_b_sign) * kern_b_val);
    let kern_c_out_val = kern_b_out_val + (coeff(kern_c_sign) * kern_c_val);
    task.set_expected_result(Float4::new(
        kern_c_out_val,
        kern_c_out_val,
        kern_c_out_val,
        kern_c_out_val,
    ));
    // Resolve inter-command dependencies now that the graph is complete.
    task.cmd_graph.populate_requisites();
    Ok(task)
}
/// Enqueues every command of a complex task (indices 0..=7 as laid out by
/// `create_complex_task`) and returns a joined future that resolves when
/// both host-side stages have run on the thread pool.
///
/// The write stage fills the input with 500.0; commands 1..=6 (kernels,
/// copies, fill of 50.0) are enqueued eagerly; the verify stage maps the
/// output (command 7), compares each element against the task's
/// precomputed `expected_result`, and sends the correct-value count
/// through `tx`.
fn enqueue_complex_task(
    task: &Task,
    buf_pool: &SubBufferPool<Float4>,
    thread_pool: &CpuPool,
    tx: Sender<usize>,
) -> Join<CpuFuture<usize, OclError>, CpuFuture<Sender<usize>, OclError>> {
    let task_id = task.task_id;
    // Stage 0: initialize the input buffer via a mapped write.
    let write = task.map(0, &buf_pool).and_then(move |mut data| {
        for val in data.iter_mut() {
            *val = Float4::new(500., 500., 500., 500.);
        }
        printlnc!(green_bold: "Task [{}] (complex): Buffer initialized.", task_id);
        Ok(task_id)
    });
    // Device-side middle of the graph; ordering between these enqueues is
    // enforced by each command's requisite events, not by this call order.
    task.kernel(1);
    task.copy(2, buf_pool);
    task.copy(3, buf_pool);
    task.fill(Float4::new(50., 50., 50., 50.), 4, buf_pool);
    task.kernel(5);
    task.kernel(6);
    // Set by `create_complex_task`; complex tasks always have one.
    let expected_result = task.expected_result.unwrap();
    let verify = task
        .map(7, &buf_pool)
        .and_then(move |data| {
            let mut val_count = 0usize;
            for val in data.iter() {
                let correct_val = expected_result;
                if *val != correct_val {
                    return Err(
                        format!("Result value mismatch: {:?} != {:?}", val, correct_val).into(),
                    );
                }
                val_count += 1;
            }
            printlnc!(yellow_bold: "Task [{}] (complex): Verify successful: \
                {} values correct.", task_id, val_count);
            // Report the count; the send future is resolved below.
            Ok(tx.send(val_count))
        })
        .and_then(|send| send.map_err(|e| OclError::from(e)));
    let write_spawned = thread_pool.spawn(write);
    let verify_spawned = thread_pool.spawn(verify);
    write_spawned.join(verify_spawned)
}
/// Formats a duration as fractional seconds, e.g. "1.042 seconds".
///
/// The millisecond remainder is zero-padded to three digits: without the
/// padding, 1 s + 5 ms would print as "1.5 seconds" (reading as one and a
/// half seconds) instead of the correct "1.005 seconds".
fn fmt_duration(duration: chrono::Duration) -> String {
    let el_sec = duration.num_seconds();
    // Milliseconds remaining after whole seconds are removed.
    let el_ms = duration.num_milliseconds() - (el_sec * 1000);
    format!("{}.{:03} seconds", el_sec, el_ms)
}
/// Creates and enqueues randomly-sized simple and complex tasks until the
/// sub-buffer pool is full, waits for all of them to complete and verify,
/// then prints the total correct-value count and timing breakdown.
pub fn async_menagerie() -> OclResult<()> {
    // Each task's buffer length is drawn from this range.
    let buffer_size_range = SUB_BUF_MIN_LEN..SUB_BUF_MAX_LEN;
    let mut rng = SmallRng::from_entropy();
    let platform = Platform::default();
    printlnc!(blue: "Platform: {}", platform.name()?);
    // Pick a random device; `wrapping_indices` wraps out-of-range indices
    // onto the actual device list, so 0..15 is always valid.
    let device_idx = rng.gen_range(0..15);
    let device = Device::specifier()
        .wrapping_indices(vec![device_idx])
        .to_device_list(Some(platform))?[0];
    printlnc!(teal: "Device: {} {}", device.vendor()?, device.name()?);
    let context = Context::builder()
        .platform(platform)
        .devices(device)
        .build()?;
    // Prefer out-of-order queues; fall back to a default (in-order) queue
    // on drivers that reject the flag.
    let queue_flags = Some(CommandQueueProperties::new().out_of_order());
    // Three queues per simple task (write buf, read buf, kernel).
    let queues_simple: Vec<_> = (0..3)
        .map(|_| {
            Queue::new(&context, device, queue_flags)
                .or_else(|_| Queue::new(&context, device, None))
                .unwrap()
        })
        .collect();
    // Eight queues per complex task (seven buffers + kernels).
    let queues_complex: Vec<_> = (0..8)
        .map(|_| {
            Queue::new(&context, device, queue_flags)
                .or_else(|_| Queue::new(&context, device, None))
                .unwrap()
        })
        .collect();
    let mut buf_pool: SubBufferPool<Float4> = SubBufferPool::new(
        INITIAL_BUFFER_LEN,
        Queue::new(&context, device, queue_flags)
            .or_else(|_| Queue::new(&context, device, None))?,
    );
    let mut tasks = Vec::with_capacity(256);
    let thread_pool = CpuPool::new_num_cpus();
    let mut correct_val_count = 0usize;
    // Channel through which each task's verify stage reports its count of
    // correct values.
    let (tx, mut rx) = mpsc::channel(1);
    let start_time = chrono::Local::now();
    printlnc!(white_bold: "Creating and enqueuing tasks...");
    // Create tasks (coin-flip simple vs. complex) until an allocation
    // fails, which signals the pool is exhausted.
    loop {
        let work_size = rng.gen_range(buffer_size_range.clone());
        let task_id = tasks.len();
        let task_res = if rng.gen() {
            create_simple_task(
                task_id,
                device,
                &context,
                &mut buf_pool,
                work_size,
                &queues_simple,
            )
        } else {
            create_complex_task(
                task_id,
                device,
                &context,
                &mut buf_pool,
                work_size,
                &queues_complex,
                &mut rng,
            )
        };
        let task = match task_res {
            Ok(task) => task,
            Err(_) => {
                println!("Buffer pool is now full.");
                break;
            }
        };
        match task.kind {
            TaskKind::Simple => tasks.push(enqueue_simple_task(
                &task,
                &buf_pool,
                &thread_pool,
                tx.clone(),
            )),
            TaskKind::Complex => tasks.push(enqueue_complex_task(
                &task,
                &buf_pool,
                &thread_pool,
                tx.clone(),
            )),
        }
    }
    let create_enqueue_duration = chrono::Local::now() - start_time;
    let task_count = tasks.len();
    printlnc!(white_bold: "Waiting on {} tasks to complete...", task_count);
    // Drive every task future to completion, in whatever order they finish.
    stream::futures_unordered(tasks)
        .for_each(|(task_id, _)| {
            printlnc!(orange: "Task [{}]: Complete.", task_id);
            Ok(())
        })
        .wait()?;
    // Close the receiver so draining terminates, then sum all reported
    // correct-value counts.
    rx.close();
    for count in rx.wait() {
        correct_val_count += count.unwrap();
    }
    let run_duration = chrono::Local::now() - start_time - create_enqueue_duration;
    let total_duration = chrono::Local::now() - start_time;
    printlnc!(white_bold: "\nAll {} (float4) result values from {} tasks are correct! \n\
        Durations => | Create/Enqueue: {} | Run: {} | Total: {}",
        correct_val_count, task_count, fmt_duration(create_enqueue_duration),
        fmt_duration(run_duration), fmt_duration(total_duration));
    Ok(())
}
/// Entry point: runs the menagerie and prints any top-level error.
pub fn main() {
    if let Err(err) = async_menagerie() {
        println!("{}", err);
    }
}