use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::convert::TryInto;
use crate::vec_or_slice::VecOrSlice;
use crate::*;
/// Errors produced by [`Session`] operations.
#[derive(Debug, Fail, PartialEq, Eq, Clone)]
pub enum SessionError {
    /// The given queue index does not refer to one of the session's command queues.
    /// The payload is the offending index.
    #[fail(display = "The given queue index {} was out of range", _0)]
    QueueIndexOutOfRange(usize),
}
/// An OpenCL session bundling a set of devices, a context, a program built for those devices,
/// and one command queue per device.
///
/// Every field is wrapped in `ManuallyDrop` so that `Drop for Session` can release the parts in
/// an explicit order, and so that `Session::decompose` can move them out without the session's
/// own `Drop` running on them afterwards.
#[derive(Debug)]
pub struct Session {
    // Devices the session was created for; `queues[i]` was created for `devices[i]`.
    devices: ManuallyDrop<Vec<ClDeviceID>>,
    // Context the program and command queues were created in.
    context: ManuallyDrop<ClContext>,
    // Program built for `devices`; kernels are created from it by name.
    program: ManuallyDrop<ClProgram>,
    // One command queue per entry of `devices`, in the same order.
    queues: ManuallyDrop<Vec<ClCommandQueue>>,
}
impl Session {
    /// Creates a session over `devices`, compiling `src` and building it for all of them.
    ///
    /// A context is created from the device list, the program is built for every device, and
    /// one command queue (with [`CommandQueueProperties::default`]) is created per device, in
    /// the same order as `devices`. Queue `i` therefore belongs to device `i`.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while creating the context, creating or building the
    /// program, or creating any of the command queues.
    pub fn create_with_devices<'a, D>(devices: D, src: &str) -> Output<Session>
    where
        D: Into<VecOrSlice<'a, ClDeviceID>>,
    {
        // NOTE(review): this safe fn wraps unsafe OpenCL constructors; it relies on those
        // wrappers upholding their own invariants for valid device ids — confirm.
        unsafe {
            let devices = devices.into();
            let context = ClContext::create(devices.as_slice())?;
            let mut program = ClProgram::create_with_source(&context, src)?;
            program.build(devices.as_slice())?;
            let props = CommandQueueProperties::default();
            let maybe_queues: Result<Vec<ClCommandQueue>, Error> = devices
                .iter()
                .map(|dev| ClCommandQueue::create(&context, dev, Some(props)))
                .collect();
            let queues = maybe_queues?;
            let sess = Session {
                devices: ManuallyDrop::new(devices.to_vec()),
                context: ManuallyDrop::new(context),
                program: ManuallyDrop::new(program),
                queues: ManuallyDrop::new(queues),
            };
            Ok(sess)
        }
    }

    /// Creates a session over every device (of [`DeviceType::ALL`]) on every listed platform.
    ///
    /// # Errors
    ///
    /// Fails if listing platforms or any platform's devices fails. Also fails for any reason
    /// [`Session::create_with_devices`] fails.
    pub fn create(src: &str) -> Output<Session> {
        let platforms = list_platforms()?;
        let mut devices = Vec::new();
        for platform in platforms.iter() {
            let platform_devices = list_devices_by_type(platform, DeviceType::ALL)?;
            devices.extend(platform_devices);
        }
        Session::create_with_devices(devices, src)
    }

    /// Consumes the session and returns its parts `(devices, context, program, queues)`
    /// without releasing them.
    ///
    /// The session itself is `mem::forget`-ed, so its `Drop` impl does not run. Ownership of
    /// every part, and responsibility for dropping it, passes to the caller.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not stated in this file. The session's own `Drop`
    /// releases the queues, then the program, then the context, then the devices. Callers
    /// should presumably keep that order when dropping the returned parts — confirm.
    pub unsafe fn decompose(
        mut self,
    ) -> (Vec<ClDeviceID>, ClContext, ClProgram, Vec<ClCommandQueue>) {
        let devices: Vec<ClDeviceID> = utils::take_manually_drop(&mut self.devices);
        let context: ClContext = utils::take_manually_drop(&mut self.context);
        let program: ClProgram = utils::take_manually_drop(&mut self.program);
        let queues: Vec<ClCommandQueue> = utils::take_manually_drop(&mut self.queues);
        // The fields have been moved out above; skip `Drop` so they are not released twice.
        std::mem::forget(self);
        (devices, context, program, queues)
    }

    /// The devices this session was created for, in queue order.
    pub fn devices(&self) -> &[ClDeviceID] {
        &(*self.devices)[..]
    }

    /// The context the session's program and queues were created in.
    pub fn context(&self) -> &ClContext {
        &(*self.context)
    }

    /// The program built for the session's devices.
    pub fn program(&self) -> &ClProgram {
        &(*self.program)
    }

    /// The session's command queues, one per device, in device order.
    pub fn queues(&self) -> &[ClCommandQueue] {
        &(*self.queues)[..]
    }

    /// Creates the kernel named `kernel_name` from the session's program.
    ///
    /// # Safety
    ///
    /// Forwards to `ClKernel::create`; the caller must uphold that function's contract.
    pub unsafe fn create_kernel(&self, kernel_name: &str) -> Output<ClKernel> {
        ClKernel::create(self.program(), kernel_name)
    }

    /// Creates a device buffer in the session's context, using the configuration reported by
    /// `buffer_creator.mem_config()`.
    ///
    /// # Safety
    ///
    /// Forwards to `ClMem::create_with_config`; the caller must uphold that function's contract.
    pub unsafe fn create_mem<T: ClNumber, B: BufferCreator<T>>(
        &self,
        buffer_creator: B,
    ) -> Output<ClMem> {
        // Read the config before `buffer_creator` is moved into the call below.
        let cfg = buffer_creator.mem_config();
        ClMem::create_with_config(self.context(), buffer_creator, cfg)
    }

    /// Creates a device buffer in the session's context using an explicit `mem_config`.
    ///
    /// # Safety
    ///
    /// Forwards to `ClMem::create_with_config`; the caller must uphold that function's contract.
    pub unsafe fn create_mem_with_config<T: ClNumber, B: BufferCreator<T>>(
        &self,
        buffer_creator: B,
        mem_config: MemConfig,
    ) -> Output<ClMem> {
        ClMem::create_with_config(self.context(), buffer_creator, mem_config)
    }

    /// Returns the command queue at `index`.
    ///
    /// # Errors
    ///
    /// Returns [`SessionError::QueueIndexOutOfRange`] when `index` is not a valid queue index.
    #[inline]
    fn get_queue_by_index(&mut self, index: usize) -> Output<&mut ClCommandQueue> {
        self.queues
            .get_mut(index)
            .ok_or_else(|| SessionError::QueueIndexOutOfRange(index).into())
    }

    /// Enqueues a write of `host_buffer` into `mem` on queue `queue_index` and returns the
    /// write's event.
    ///
    /// # Panics
    ///
    /// Calls `match_or_panic` comparing `mem`'s number type with `T`'s. As the name indicates,
    /// this presumably panics on a mismatch. The check runs before the queue lookup.
    ///
    /// # Errors
    ///
    /// Returns [`SessionError::QueueIndexOutOfRange`] for a bad `queue_index`. Otherwise returns
    /// any error from the queue's `write_buffer`.
    ///
    /// # Safety
    ///
    /// Forwards to `ClCommandQueue::write_buffer`; the caller must uphold that function's
    /// contract.
    pub unsafe fn write_buffer<'a, T: ClNumber, H: Into<VecOrSlice<'a, T>>>(
        &mut self,
        queue_index: usize,
        mem: &mut ClMem,
        host_buffer: H,
        opts: Option<CommandQueueOptions>,
    ) -> Output<ClEvent> {
        mem.number_type().match_or_panic(T::number_type());
        let queue: &mut ClCommandQueue = self.get_queue_by_index(queue_index)?;
        queue.write_buffer(mem, host_buffer, opts)
    }

    /// Enqueues a read of `mem` into `host_buffer` on queue `queue_index` and returns the
    /// read's event.
    ///
    /// # Panics
    ///
    /// Like [`Session::write_buffer`], calls `match_or_panic` comparing `mem`'s number type with
    /// `T`'s before touching the queue. This is presumably a panic on mismatch, so a buffer
    /// cannot be read back as the wrong element type.
    ///
    /// # Errors
    ///
    /// Returns [`SessionError::QueueIndexOutOfRange`] for a bad `queue_index`. Otherwise returns
    /// any error from the queue's `read_buffer`.
    ///
    /// # Safety
    ///
    /// Forwards to `ClCommandQueue::read_buffer`; the caller must uphold that function's
    /// contract.
    pub unsafe fn read_buffer<'a, T: ClNumber, H: Into<MutVecOrSlice<'a, T>>>(
        &mut self,
        queue_index: usize,
        mem: &mut ClMem,
        host_buffer: H,
        opts: Option<CommandQueueOptions>,
    ) -> Output<BufferReadEvent<T>> {
        // Same element-type guard as `write_buffer`, which this method previously lacked.
        mem.number_type().match_or_panic(T::number_type());
        let queue: &mut ClCommandQueue = self.get_queue_by_index(queue_index)?;
        queue.read_buffer(mem, host_buffer, opts)
    }

    /// Enqueues `kernel` over `work` on queue `queue_index` and returns its event without
    /// waiting for completion.
    ///
    /// `opts` is converted into `CommandQueueOptions`, and its waitlist is passed to the
    /// enqueue call.
    ///
    /// # Errors
    ///
    /// Returns [`SessionError::QueueIndexOutOfRange`] for a bad `queue_index`. Otherwise returns
    /// any error from enqueueing the kernel or wrapping the returned event.
    ///
    /// # Safety
    ///
    /// Calls `cl_enqueue_nd_range_kernel` with raw queue/kernel pointers. The caller must ensure
    /// the kernel's arguments are set appropriately for `work`.
    pub unsafe fn enqueue_kernel(
        &mut self,
        queue_index: usize,
        kernel: &mut ClKernel,
        work: &Work,
        opts: Option<CommandQueueOptions>,
    ) -> Output<ClEvent> {
        let queue: &mut ClCommandQueue = self.get_queue_by_index(queue_index)?;
        let cq_opts: CommandQueueOptions = opts.into();
        let event = cl_enqueue_nd_range_kernel(
            queue.command_queue_ptr(),
            kernel.kernel_ptr(),
            work,
            &cq_opts.waitlist[..],
        )?;
        ClEvent::new(event)
    }

    /// Runs `kernel_op` to completion on queue `queue_index`.
    ///
    /// The steps are:
    /// 1. Create the kernel named by the operation.
    /// 2. Bind each of its arguments, where an argument's position is its index.
    /// 3. Enqueue the kernel with the operation's work and queue options.
    /// 4. Block until the resulting event completes.
    ///
    /// # Errors
    ///
    /// Returns the first error from kernel creation, the queue lookup, argument binding, work
    /// construction, enqueueing, or waiting.
    pub fn execute_sync_kernel_operation(
        &mut self,
        queue_index: usize,
        mut kernel_op: KernelOperation,
    ) -> Output<()> {
        unsafe {
            let mut kernel = self.create_kernel(kernel_op.name())?;
            let queue: &mut ClCommandQueue = self.get_queue_by_index(queue_index)?;
            for (arg_index, (arg_size, arg_ptr)) in kernel_op.mut_args().iter_mut().enumerate() {
                kernel.set_arg_raw(
                    arg_index
                        .try_into()
                        .expect("kernel argument index does not fit the index type of set_arg_raw"),
                    *arg_size,
                    *arg_ptr,
                )?;
            }
            let work = kernel_op.work()?;
            let event = queue.enqueue_kernel(&mut kernel, &work, kernel_op.command_queue_opts())?;
            event.wait()?;
            Ok(())
        }
    }
}
// SAFETY: NOTE(review): this asserts that a `Session` (device ids, context, program and
// command queues) may be moved to another thread. The justification is not visible in this
// file — confirm the wrapped OpenCL handle types are safe to transfer across threads.
unsafe impl Send for Session {}
/// Releases the session's parts in a fixed order: command queues, then the program, then the
/// context, and finally the device list.
impl Drop for Session {
    fn drop(&mut self) {
        // SAFETY: each `ManuallyDrop` field is dropped exactly once, here. `decompose` moves the
        // fields out and then `mem::forget`s the session, so this never runs on taken fields.
        unsafe {
            // Queues and the program were created from the context, so release them before it.
            ManuallyDrop::drop(&mut self.queues);
            ManuallyDrop::drop(&mut self.program);
            ManuallyDrop::drop(&mut self.context);
            ManuallyDrop::drop(&mut self.devices);
        }
    }
}
/// A lightweight, copyable handle naming one of a session's command queues by index.
///
/// The lifetime `'a` is carried only through `PhantomData`; no queue reference is stored.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SessionQueue<'a> {
    phantom: PhantomData<&'a ClCommandQueue>,
    index: usize,
}

impl<'a> SessionQueue<'a> {
    /// Creates a handle referring to the queue at position `index`.
    pub fn new(index: usize) -> SessionQueue<'a> {
        let phantom = PhantomData;
        SessionQueue { phantom, index }
    }
}
/// Errors raised while validating or executing a [`SessionBuilder`].
#[derive(Debug, Fail, PartialEq, Eq, Clone)]
pub enum SessionBuilderError {
    /// A `ClMem` had no associated `cl_mem` object.
    #[fail(display = "Given ClMem has no associated cl_mem object")]
    NoAssociatedMemObject,
    /// Both platforms and devices were given; only one of them may be specified.
    #[fail(
        display = "For session building platforms AND devices cannot be specified together; they are mutually exclusive."
    )]
    CannotSpecifyPlatformsAndDevices,
    /// Both program source and program binaries were given; only one of them may be specified.
    #[fail(
        display = "For session building program src AND binaries cannot be specified together; they are mutually exclusive."
    )]
    CannotSpecifyProgramSrcAndProgramBinaries,
    /// Neither program source nor program binaries were given.
    #[fail(
        display = "For session building either program src or program binaries must be specified."
    )]
    MustSpecifyProgramSrcOrProgramBinaries,
    /// Program binaries were given but the session did not resolve to exactly one device.
    /// The payload is the number of devices found.
    #[fail(
        display = "Building a session with program binaries requires exactly 1 device: Got {:?} devices",
        _0
    )]
    BinaryProgramRequiresExactlyOneDevice(usize),
}
/// Error returned by `SessionBuilder` validation when both program source and program
/// binaries are set.
const CANNOT_SPECIFY_SRC_AND_BINARIES: Error =
    Error::SessionBuilderError(SessionBuilderError::CannotSpecifyProgramSrcAndProgramBinaries);
/// Error returned by `SessionBuilder` validation when neither program source nor program
/// binaries are set.
const MUST_SPECIFY_SRC_OR_BINARIES: Error =
    Error::SessionBuilderError(SessionBuilderError::MustSpecifyProgramSrcOrProgramBinaries);
/// Step-by-step configuration for creating a [`Session`].
///
/// Exactly one of `program_src` / `program_binaries` must be set before building.
#[derive(Default)]
pub struct SessionBuilder<'a> {
    /// OpenCL C source for the session's program (exclusive with `program_binaries`).
    pub program_src: Option<&'a str>,
    /// Pre-built program binary (exclusive with `program_src`); building from binaries
    /// requires exactly one device.
    pub program_binaries: Option<&'a [u8]>,
    /// Device type forwarded to the context builder.
    pub device_type: Option<DeviceType>,
    /// Platforms forwarded to the context builder.
    pub platforms: Option<&'a [ClPlatformID]>,
    /// Explicit devices forwarded to the context builder.
    pub devices: Option<&'a [ClDeviceID]>,
    /// Command queue properties, set via `with_command_queue_properties`.
    pub command_queue_properties: Option<CommandQueueProperties>,
}
impl<'a> SessionBuilder<'a> {
pub fn new() -> SessionBuilder<'a> {
SessionBuilder {
program_src: None,
program_binaries: None,
device_type: None,
platforms: None,
devices: None,
command_queue_properties: None,
}
}
pub fn with_program_src(mut self, src: &'a str) -> SessionBuilder<'a> {
self.program_src = Some(src);
self
}
pub fn with_program_binaries(mut self, bins: &'a [u8]) -> SessionBuilder<'a> {
self.program_binaries = Some(bins);
self
}
pub fn with_platforms(mut self, platforms: &'a [ClPlatformID]) -> SessionBuilder<'a> {
self.platforms = Some(platforms);
self
}
pub fn with_devices(mut self, devices: &'a [ClDeviceID]) -> SessionBuilder<'a> {
self.devices = Some(devices);
self
}
pub fn with_device_type(mut self, device_type: DeviceType) -> SessionBuilder<'a> {
self.device_type = Some(device_type);
self
}
pub fn with_command_queue_properties(
mut self,
props: CommandQueueProperties,
) -> SessionBuilder<'a> {
self.command_queue_properties = Some(props);
self
}
fn check_for_error_state(&self) -> Output<()> {
match self {
Self {
program_src: Some(_),
program_binaries: Some(_),
..
} => return Err(CANNOT_SPECIFY_SRC_AND_BINARIES),
Self {
program_src: None,
program_binaries: None,
..
} => return Err(MUST_SPECIFY_SRC_OR_BINARIES),
_ => Ok(()),
}
}
pub unsafe fn build(self) -> Output<Session> {
self.check_for_error_state()?;
let context_builder = ClContextBuilder {
devices: self.devices,
device_type: self.device_type,
platforms: self.platforms,
};
let built_context = context_builder.build()?;
let (context, devices): (ClContext, Vec<ClDeviceID>) = match built_context {
BuiltClContext::Context(ctx) => (ctx, self.devices.unwrap().to_vec()),
BuiltClContext::ContextWithDevices(ctx, owned_devices) => (ctx, owned_devices),
};
let program: ClProgram = match (&self, devices.len()) {
(
Self {
program_src: Some(src),
..
},
_,
) => {
let mut prog: ClProgram = ClProgram::create_with_source(&context, src)?;
prog.build(&devices[..])?;
Ok(prog)
}
(
Self {
program_binaries: Some(bins),
..
},
1,
) => {
let mut prog: ClProgram =
ClProgram::create_with_binary(&context, &devices[0], *bins)?;
prog.build(&devices[..])?;
Ok(prog)
}
(
Self {
program_binaries: Some(_),
..
},
n_devices,
) => {
let e = SessionBuilderError::BinaryProgramRequiresExactlyOneDevice(n_devices);
Err(Error::SessionBuilderError(e))
}
_ => unreachable!(),
}?;
let props = CommandQueueProperties::default();
let maybe_queues: Result<Vec<ClCommandQueue>, Error> = devices
.iter()
.map(|dev| ClCommandQueue::create(&context, dev, Some(props)))
.collect();
let queues = maybe_queues?;
let sess = Session {
devices: ManuallyDrop::new(devices),
context: ManuallyDrop::new(context),
program: ManuallyDrop::new(program),
queues: ManuallyDrop::new(queues),
};
Ok(sess)
}
}
#[cfg(test)]
mod tests {
    use crate::{BufferReadEvent, KernelOperation, Session};

    // Kernel that increments every element of `data` by one. The escapes reproduce the
    // original multi-line literal byte-for-byte.
    const SRC: &str = "__kernel void test(__global int *data) {\n\
                       data[get_global_id(0)] += 1;\n\
                       }";

    fn get_session(src: &str) -> Session {
        Session::create(src).unwrap_or_else(|e| panic!("Failed to get_session {:?}", e))
    }

    #[test]
    fn session_execute_sync_kernel_operation_works() {
        let mut session = get_session(SRC);
        let data: Vec<i32> = vec![1, 2, 3, 4, 5];
        let dims = data.len();
        let mut buff = unsafe { session.create_mem(&data[..]) }.unwrap();
        let kernel_op = KernelOperation::new("test")
            .with_dims(dims)
            .add_arg(&mut buff);
        session
            .execute_sync_kernel_operation(0, kernel_op)
            .unwrap_or_else(|e| {
                panic!("Failed to execute sync kernel operation: {:?}", e);
            });
        // Size the read-back buffer and the expectation from the input instead of
        // hard-coding them, so changing `data` keeps the test consistent.
        let data3 = vec![0i32; dims];
        let expected: Vec<i32> = data.iter().map(|x| x + 1).collect();
        unsafe {
            let mut read_event: BufferReadEvent<i32> = session
                .read_buffer(0, &mut buff, data3, None)
                .unwrap_or_else(|e| {
                    panic!("Failed to read buffer: {:?}", e);
                });
            let data4 = read_event
                .wait()
                .unwrap_or_else(|e| panic!("Failed to wait for read event: {:?}", e))
                .unwrap();
            assert_eq!(data4, expected);
        }
    }
}