pub use cl3::context::{CL_CONTEXT_INTEROP_USER_SYNC, CL_CONTEXT_PLATFORM};
use super::command_queue::CommandQueue;
use super::device::{Device, SubDevice};
use super::kernel::Kernel;
use super::memory::get_supported_image_formats;
use super::program::Program;
use cl3::context;
use cl3::types::{
cl_command_queue_properties, cl_context, cl_context_properties, cl_device_id,
cl_device_partition_property, cl_device_svm_capabilities, cl_image_format, cl_int,
cl_mem_flags, cl_mem_object_type, cl_uint,
};
use libc::{c_char, c_void, intptr_t, size_t};
use std::collections::HashMap;
use std::ffi::{CStr, CString};
use std::ptr;
/// An OpenCL context together with the objects created for it.
///
/// The raw `cl_context` handle is released in this type's `Drop` impl, after
/// the kernels, programs, command queues and sub-devices it holds have been
/// cleared.
pub struct Context {
// Raw OpenCL context handle; released with `context::release_context` on drop.
context: cl_context,
// Device ids the context was created for (see `from_devices`).
devices: Vec<cl_device_id>,
// Sub-devices created through `create_sub_devices`.
sub_devices: Vec<SubDevice>,
// Command queues added via `add_command_queue` / `create_command_queues*`.
queues: Vec<CommandQueue>,
// Programs stored by `add_program` (and the `build_program_*` helpers).
programs: Vec<Program>,
// Kernels keyed by their function name, populated by `add_program`.
kernels: HashMap<CString, Kernel>,
}
impl Drop for Context {
fn drop(&mut self) {
self.kernels.clear();
self.programs.clear();
self.queues.clear();
self.sub_devices.clear();
self.devices.clear();
context::release_context(self.context).unwrap();
}
}
impl Context {
    /// Wraps an already-created OpenCL context and the devices it was created
    /// for. Sub-devices, queues, programs and kernels start out empty.
    fn new(context: cl_context, devices: Vec<cl_device_id>) -> Context {
        Context {
            context,
            devices,
            sub_devices: Vec::new(),
            queues: Vec::new(),
            programs: Vec::new(),
            kernels: HashMap::new(),
        }
    }

    /// Returns the raw OpenCL context handle.
    ///
    /// The handle remains owned by this `Context` and is released when the
    /// `Context` is dropped.
    pub fn get(&self) -> cl_context {
        self.context
    }

    /// Creates a context for `devices`.
    ///
    /// `properties`, `pfn_notify` and `user_data` are passed unchanged to
    /// `cl3::context::create_context`. [`Context::from_device`] passes a null
    /// property list, no callback and null user data.
    ///
    /// # Errors
    ///
    /// Returns the OpenCL error code reported by `create_context`.
    pub fn from_devices(
        devices: Vec<cl_device_id>,
        properties: *const cl_context_properties,
        pfn_notify: Option<extern "C" fn(*const c_char, *const c_void, size_t, *mut c_void)>,
        user_data: *mut c_void,
    ) -> Result<Context, cl_int> {
        let context = context::create_context(&devices, properties, pfn_notify, user_data)?;
        Ok(Context::new(context, devices))
    }

    /// Creates a context for a single `device`, with default properties and
    /// no notification callback.
    ///
    /// # Errors
    ///
    /// Returns the OpenCL error code reported by context creation.
    pub fn from_device(device: Device) -> Result<Context, cl_int> {
        Context::from_devices(vec![device.id()], ptr::null(), None, ptr::null_mut())
    }

    /// Takes ownership of `queue` and stores it with this context.
    pub fn add_command_queue(&mut self, queue: CommandQueue) {
        self.queues.push(queue);
    }

    /// Creates one command queue per context device with `properties` and
    /// stores them in device order.
    ///
    /// # Errors
    ///
    /// Stops at the first failing device and returns its error code. Queues
    /// created for earlier devices remain stored.
    #[cfg(feature = "CL_VERSION_1_2")]
    pub fn create_command_queues(
        &mut self,
        properties: cl_command_queue_properties,
    ) -> Result<(), cl_int> {
        // `devices` and `queues` are disjoint fields, so pushing while
        // iterating is allowed by the borrow checker.
        for &device in &self.devices {
            let queue = CommandQueue::create(self.context, device, properties)?;
            self.queues.push(queue);
        }
        Ok(())
    }

    /// Creates one command queue per context device using
    /// `CommandQueue::create_with_properties` and stores them in device order.
    ///
    /// # Errors
    ///
    /// Stops at the first failing device and returns its error code. Queues
    /// created for earlier devices remain stored.
    #[cfg(feature = "CL_VERSION_2_0")]
    pub fn create_command_queues_with_properties(
        &mut self,
        properties: cl_command_queue_properties,
        queue_size: cl_uint,
    ) -> Result<(), cl_int> {
        for &device in &self.devices {
            let queue =
                CommandQueue::create_with_properties(self.context, device, properties, queue_size)?;
            self.queues.push(queue);
        }
        Ok(())
    }

    /// Creates a [`Kernel`] for every kernel in `program` and registers each one
    /// under its function name. The program is then stored.
    ///
    /// Returns the number of kernels the program produced. A kernel whose name
    /// is already registered replaces the previous entry, so the kernel map may
    /// grow by less than the returned count.
    ///
    /// # Errors
    ///
    /// Returns the first OpenCL error encountered. Kernels registered before
    /// the failure stay in the map, and the program is not stored.
    pub fn add_program(&mut self, program: Program) -> Result<usize, cl_int> {
        let kernels = program.create_kernels_in_program()?;
        let count = kernels.len();
        for raw_kernel in kernels {
            let kernel = Kernel::new(raw_kernel)?;
            let name = kernel.function_name()?;
            self.kernels.insert(name, kernel);
        }
        self.programs.push(program);
        Ok(count)
    }

    /// Creates a program from `src` and builds it for all context devices
    /// with `options`. Then registers it via [`Context::add_program`].
    ///
    /// Returns the number of kernels in the program.
    ///
    /// # Errors
    ///
    /// Returns the OpenCL error code from creation, build or kernel creation.
    pub fn build_program_from_source(
        &mut self,
        src: &CStr,
        options: &CStr,
    ) -> Result<usize, cl_int> {
        let sources = [src];
        let program = Program::create_from_source(self.context, &sources)?;
        program.build(&self.devices, options)?;
        self.add_program(program)
    }

    /// Creates a program from per-device `binaries` and builds it for all
    /// context devices with `options`. Then registers it via
    /// [`Context::add_program`].
    ///
    /// Returns the number of kernels in the program.
    ///
    /// # Errors
    ///
    /// Returns the OpenCL error code from creation, build or kernel creation.
    pub fn build_program_from_binary(
        &mut self,
        binaries: &[&[u8]],
        options: &CStr,
    ) -> Result<usize, cl_int> {
        let program = Program::create_from_binary(self.context, &self.devices, binaries)?;
        program.build(&self.devices, options)?;
        self.add_program(program)
    }

    /// Partitions `device` according to `properties` and stores the resulting
    /// sub-devices.
    ///
    /// Returns the number of sub-devices created.
    ///
    /// # Errors
    ///
    /// Returns the OpenCL error code from the partition call.
    pub fn create_sub_devices(
        &mut self,
        device: cl_device_id,
        properties: &[cl_device_partition_property],
    ) -> Result<usize, cl_int> {
        let sub_device_ids = Device::new(device).create_sub_devices(properties)?;
        let count = sub_device_ids.len();
        self.sub_devices
            .extend(sub_device_ids.into_iter().map(SubDevice::new));
        Ok(count)
    }

    /// Looks up a registered kernel by its function name.
    pub fn get_kernel(&self, kernel_name: &CStr) -> Option<&Kernel> {
        self.kernels.get(kernel_name)
    }

    /// Returns the SVM capabilities shared by *all* context devices, i.e. the
    /// bitwise AND of every device's capability flags.
    ///
    /// Returns `0` (no capabilities) if the context has no devices.
    pub fn get_svm_mem_capability(&self) -> cl_device_svm_capabilities {
        match self.devices.split_first() {
            Some((&first, rest)) => rest
                .iter()
                .fold(Device::new(first).svm_mem_capability(), |common, &id| {
                    common & Device::new(id).svm_mem_capability()
                }),
            None => 0,
        }
    }

    /// Returns the image formats this context supports for the given memory
    /// `flags` and `image_type`.
    ///
    /// # Errors
    ///
    /// Returns the OpenCL error code from the query.
    pub fn get_supported_image_formats(
        &self,
        flags: cl_mem_flags,
        image_type: cl_mem_object_type,
    ) -> Result<Vec<cl_image_format>, cl_int> {
        // Resolves to the free function imported from `super::memory`, not to
        // this method.
        get_supported_image_formats(self.context, flags, image_type)
    }

    /// Makes `queue` the default device-side command queue for `device` in
    /// this context.
    ///
    /// # Errors
    ///
    /// Returns the OpenCL error code from the call.
    #[cfg(feature = "CL_VERSION_2_1")]
    #[inline]
    pub fn set_default_device_command_queue(
        &self,
        device: cl_device_id,
        queue: &CommandQueue,
    ) -> Result<(), cl_int> {
        // The free function is not imported into this module, so qualify it.
        // An unqualified name here does not resolve to this method and fails
        // to compile.
        cl3::device::set_default_device_command_queue(self.context, device, queue.get())
    }

    /// The device ids this context was created for.
    pub fn devices(&self) -> &[cl_device_id] {
        &self.devices
    }

    /// Sub-devices created through [`Context::create_sub_devices`].
    pub fn sub_devices(&self) -> &[SubDevice] {
        &self.sub_devices
    }

    /// All stored command queues, in insertion order.
    pub fn queues(&self) -> &[CommandQueue] {
        &self.queues
    }

    /// The first stored command queue.
    ///
    /// # Panics
    ///
    /// Panics if no command queue has been added yet.
    pub fn default_queue(&self) -> &CommandQueue {
        self.queues
            .first()
            .expect("Context::default_queue: no command queues have been added")
    }

    /// All stored programs, in insertion order.
    pub fn programs(&self) -> &[Program] {
        &self.programs
    }

    /// All registered kernels, keyed by function name.
    pub fn kernels(&self) -> &HashMap<CString, Kernel> {
        &self.kernels
    }

    /// Registers `pfn_notify` (with `user_data`) as a destructor callback
    /// for this context.
    ///
    /// # Errors
    ///
    /// Returns the OpenCL error code from the call.
    #[cfg(feature = "CL_VERSION_3_0")]
    #[inline]
    pub fn set_destructor_callback(
        &self,
        pfn_notify: extern "C" fn(cl_context, *const c_void),
        user_data: *mut c_void,
    ) -> Result<(), cl_int> {
        // Qualified through the imported `cl3::context` module; the bare name
        // is not in scope.
        context::set_context_destructor_callback(self.context, pfn_notify, user_data)
    }

    /// Queries `CL_CONTEXT_REFERENCE_COUNT`.
    ///
    /// # Errors
    ///
    /// Returns the OpenCL error code from `get_context_info`.
    pub fn reference_count(&self) -> Result<cl_uint, cl_int> {
        Ok(context::get_context_info(
            self.context,
            context::ContextInfo::CL_CONTEXT_REFERENCE_COUNT,
        )?
        .to_uint())
    }

    /// Queries `CL_CONTEXT_PROPERTIES` as a list of `intptr_t` values.
    ///
    /// # Errors
    ///
    /// Returns the OpenCL error code from `get_context_info`.
    pub fn properties(&self) -> Result<Vec<intptr_t>, cl_int> {
        Ok(
            context::get_context_info(self.context, context::ContextInfo::CL_CONTEXT_PROPERTIES)?
                .to_vec_intptr(),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::Device;
    use crate::platform::get_platforms;
    use cl3::device::CL_DEVICE_TYPE_GPU;

    /// Smoke test: creates a context on the first GPU of the first platform
    /// and queries its reference count. Requires an OpenCL GPU device.
    #[test]
    fn test_context() {
        let platforms = get_platforms().unwrap();
        assert!(!platforms.is_empty(), "no OpenCL platforms found");

        let platform = &platforms[0];
        let devices = platform.get_devices(CL_DEVICE_TYPE_GPU).unwrap();
        assert!(
            !devices.is_empty(),
            "no OpenCL GPU devices found on the first platform"
        );

        let device = Device::new(devices[0]);
        let context = Context::from_device(device).unwrap();

        let count = context.reference_count().unwrap();
        println!("CL_CONTEXT_REFERENCE_COUNT: {}", count);
        // We still hold `context`, so its reference count cannot be zero.
        assert!(count > 0, "a live context must have a non-zero reference count");
    }
}