use std::fmt;
use libc::c_void;
use crate::ffi::{
clCreateBuffer, clGetMemObjectInfo, cl_context, cl_int, cl_mem, cl_mem_flags, cl_mem_info,
};
use crate::cl_helpers::cl_get_info5;
use crate::{
build_output, ClContext, ClNumber, ClPointer, ContextPtr, HostAccessMemFlags,
KernelAccessMemFlags, MemFlags, MemInfo, MemLocationMemFlags, Output, NumberType,
NumberTyped, ObjectWrapper,
};
pub unsafe fn cl_create_buffer_with_creator<T: ClNumber, B: BufferCreator<T>>(
context: cl_context,
mem_flags: cl_mem_flags,
buffer_creator: B,
) -> Output<cl_mem> {
cl_create_buffer(
context,
mem_flags,
buffer_creator.buffer_byte_size(),
buffer_creator.buffer_ptr()
)
}
pub unsafe fn cl_create_buffer(
context: cl_context,
mem_flags: cl_mem_flags,
size_in_bytes: usize,
ptr: *mut c_void,
) -> Output<cl_mem> {
let mut err_code: cl_int = 0;
let cl_mem_object: cl_mem =
clCreateBuffer(context, mem_flags, size_in_bytes, ptr, &mut err_code);
build_output(cl_mem_object, err_code)
}
pub fn cl_get_mem_object_info<T>(device_mem: cl_mem, flag: cl_mem_info) -> Output<ClPointer<T>>
where
T: Copy,
{
unsafe { cl_get_info5(device_mem, flag, clGetMemObjectInfo) }
}
pub trait BufferCreator<T: ClNumber>: Sized {
fn buffer_byte_size(&self) -> usize;
fn buffer_ptr(&self) -> *mut c_void;
fn mem_config(&self) -> MemConfig;
}
impl<T: ClNumber> BufferCreator<T> for &[T] {
fn buffer_byte_size(&self) -> usize {
std::mem::size_of::<T>() * self.len()
}
fn buffer_ptr(&self) -> *mut c_void {
self.as_ptr() as *const _ as *mut c_void
}
fn mem_config(&self) -> MemConfig {
MemConfig::for_data()
}
}
impl<T: ClNumber> BufferCreator<T> for &mut [T] {
fn buffer_byte_size(&self) -> usize {
std::mem::size_of::<T>() * self.len()
}
fn buffer_ptr(&self) -> *mut c_void {
self.as_ptr() as *const _ as *mut c_void
}
fn mem_config(&self) -> MemConfig {
MemConfig::for_data()
}
}
impl<T: ClNumber> BufferCreator<T> for usize {
fn buffer_byte_size(&self) -> usize {
std::mem::size_of::<T>() * *self
}
fn buffer_ptr(&self) -> *mut c_void {
std::ptr::null_mut()
}
fn mem_config(&self) -> MemConfig {
MemConfig::for_size()
}
}
pub unsafe trait MemPtr: NumberTyped {
unsafe fn mem_ptr(&self) -> cl_mem;
unsafe fn mem_ptr_ref(&self) -> &cl_mem;
unsafe fn get_info<I: Copy>(&self, flag: MemInfo) -> Output<ClPointer<I>> {
cl_get_mem_object_info::<I>(self.mem_ptr(), flag.into())
}
unsafe fn len(&self) -> Output<usize> {
let mem_size_in_bytes = self.size()?;
Ok(mem_size_in_bytes / self.number_type().size_of_t())
}
unsafe fn is_empty(&self) -> Output<bool> {
self.len().map(|l| l == 0)
}
unsafe fn context(&self) -> Output<ClContext> {
self.get_info::<cl_context>(MemInfo::Context)
.and_then(|cl_ptr| ClContext::retain_new(cl_ptr.into_one()))
}
unsafe fn reference_count(&self) -> Output<u32> {
self.get_info(MemInfo::ReferenceCount)
.map(|ret| ret.into_one())
}
unsafe fn size(&self) -> Output<usize> {
self.get_info(MemInfo::Size).map(|ret| ret.into_one())
}
unsafe fn offset(&self) -> Output<usize> {
self.get_info(MemInfo::Offset).map(|ret| ret.into_one())
}
unsafe fn flags(&self) -> Output<MemFlags> {
self.get_info(MemInfo::Flags).map(|ret| ret.into_one())
}
}
#[derive(Eq, PartialEq)]
pub struct ClMem {
inner: ObjectWrapper<cl_mem>,
t: NumberType,
}
impl NumberTyped for ClMem {
fn number_type(&self) -> NumberType {
self.t
}
}
impl ClMem {
pub unsafe fn new<T: ClNumber>(object: cl_mem) -> Output<ClMem> {
Ok(ClMem {
inner: ObjectWrapper::new(object)?,
t: T::number_type()
})
}
pub fn create<T: ClNumber, B: BufferCreator<T>>(
context: &ClContext,
buffer_creator: B,
host_access: HostAccess,
kernel_access: KernelAccess,
mem_location: MemLocation,
) -> Output<ClMem> {
unsafe {
let mem_object = cl_create_buffer_with_creator(
context.context_ptr(),
cl_mem_flags::from(host_access)
| cl_mem_flags::from(kernel_access)
| cl_mem_flags::from(mem_location),
buffer_creator,
)?;
ClMem::new::<T>(mem_object)
}
}
pub unsafe fn create_with_config<T: ClNumber, B: BufferCreator<T>>(
context: &ClContext,
buffer_creator: B,
mem_config: MemConfig,
) -> Output<ClMem> {
let mem_object = cl_create_buffer_with_creator(
context.context_ptr(),
mem_config.into(),
buffer_creator,
)?;
ClMem::new::<T>(mem_object)
}
}
unsafe impl MemPtr for ClMem {
unsafe fn mem_ptr(&self) -> cl_mem {
self.inner.cl_object()
}
unsafe fn mem_ptr_ref(&self) -> &cl_mem {
self.inner.cl_object_ref()
}
}
unsafe impl Send for ClMem {}
impl fmt::Debug for ClMem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", unsafe { self.mem_ptr() })
}
}
#[cfg(test)]
mod tests {
use crate::*;
#[test]
fn mem_can_be_created_with_len() {
let (context, _devices) = ll_testing::get_context();
let mem_config = MemConfig::default();
let _mem: ClMem =
unsafe { ClMem::create_with_config::<u32, usize>(&context, 10, mem_config).unwrap() };
}
#[test]
fn mem_can_be_created_with_slice() {
let (context, _devices) = ll_testing::get_context();
let data: Vec<u32> = vec![0, 1, 2, 3, 4];
let mem_config = MemConfig::for_data();
let _mem: ClMem =
unsafe { ClMem::create_with_config(&context, &data[..], mem_config).unwrap() };
}
mod mem_ptr_trait {
use crate::*;
#[test]
fn len_method_works() {
let (_devices, _context, ll_mem) = ll_testing::get_mem::<u32>(10);
let len = unsafe { ll_mem.len().unwrap() };
assert_eq!(len, 10);
}
#[test]
fn reference_count_method_works() {
let (_devices, _context, ll_mem) = ll_testing::get_mem::<u32>(10);
let ref_count = unsafe { ll_mem.reference_count().unwrap() };
assert_eq!(ref_count, 1);
}
#[test]
fn size_method_returns_size_in_bytes() {
let (_devices, _context, ll_mem) = ll_testing::get_mem::<u32>(10);
let bytes_size = unsafe { ll_mem.size().unwrap() };
assert_eq!(bytes_size, 10 * std::mem::size_of::<u32>());
}
#[test]
fn offset_method_works() {
let (_devices, _context, ll_mem) = ll_testing::get_mem::<u32>(10);
let offset = unsafe { ll_mem.offset().unwrap() };
assert_eq!(offset, 0);
}
#[test]
fn flags_method_works() {
let (_devices, _context, ll_mem) = ll_testing::get_mem::<u32>(10);
let flags = unsafe { ll_mem.flags().unwrap() };
assert_eq!(flags, MemFlags::READ_WRITE_ALLOC_HOST_PTR);
}
}
}
pub enum KernelAccess {
ReadOnly,
WriteOnly,
ReadWrite,
}
impl From<KernelAccess> for KernelAccessMemFlags {
fn from(kernel_access: KernelAccess) -> KernelAccessMemFlags {
match kernel_access {
KernelAccess::ReadOnly => KernelAccessMemFlags::READ_ONLY,
KernelAccess::WriteOnly => KernelAccessMemFlags::WRITE_ONLY,
KernelAccess::ReadWrite => KernelAccessMemFlags::READ_WRITE,
}
}
}
impl From<KernelAccess> for MemFlags {
fn from(kernel_access: KernelAccess) -> MemFlags {
MemFlags::from(KernelAccessMemFlags::from(kernel_access))
}
}
impl From<KernelAccess> for cl_mem_flags {
fn from(kernel_access: KernelAccess) -> cl_mem_flags {
cl_mem_flags::from(MemFlags::from(kernel_access))
}
}
pub enum HostAccess {
ReadOnly,
WriteOnly,
NoAccess,
ReadWrite,
}
impl From<HostAccess> for HostAccessMemFlags {
fn from(host_access: HostAccess) -> HostAccessMemFlags {
match host_access {
HostAccess::ReadOnly => HostAccessMemFlags::READ_ONLY,
HostAccess::WriteOnly => HostAccessMemFlags::WRITE_ONLY,
HostAccess::NoAccess => HostAccessMemFlags::NO_ACCESS,
HostAccess::ReadWrite => HostAccessMemFlags::READ_WRITE,
}
}
}
impl From<HostAccess> for MemFlags {
fn from(host_access: HostAccess) -> MemFlags {
MemFlags::from(HostAccessMemFlags::from(host_access))
}
}
impl From<HostAccess> for cl_mem_flags {
fn from(host_access: HostAccess) -> cl_mem_flags {
cl_mem_flags::from(MemFlags::from(host_access))
}
}
pub enum MemLocation {
KeepInPlace,
AllocOnDevice,
CopyToDevice,
ForceCopyToDevice,
}
impl From<MemLocation> for MemLocationMemFlags {
fn from(mem_location: MemLocation) -> MemLocationMemFlags {
match mem_location {
MemLocation::KeepInPlace => MemLocationMemFlags::KEEP_IN_PLACE,
MemLocation::AllocOnDevice => MemLocationMemFlags::ALLOC_ON_DEVICE,
MemLocation::CopyToDevice => MemLocationMemFlags::COPY_TO_DEVICE,
MemLocation::ForceCopyToDevice => MemLocationMemFlags::FORCE_COPY_TO_DEVICE,
}
}
}
impl From<MemLocation> for MemFlags {
fn from(mem_location: MemLocation) -> MemFlags {
MemFlags::from(MemLocationMemFlags::from(mem_location))
}
}
impl From<MemLocation> for cl_mem_flags {
fn from(mem_location: MemLocation) -> cl_mem_flags {
cl_mem_flags::from(MemFlags::from(mem_location))
}
}
pub struct MemConfig {
pub host_access: HostAccess,
pub kernel_access: KernelAccess,
pub mem_location: MemLocation,
}
impl MemConfig {
pub fn build(
host_access: HostAccess,
kernel_access: KernelAccess,
mem_location: MemLocation,
) -> MemConfig {
MemConfig {
host_access,
kernel_access,
mem_location,
}
}
}
impl From<MemConfig> for MemFlags {
fn from(mem_config: MemConfig) -> MemFlags {
unsafe { MemFlags::from_bits_unchecked(cl_mem_flags::from(mem_config)) }
}
}
impl From<MemConfig> for cl_mem_flags {
fn from(mem_config: MemConfig) -> cl_mem_flags {
cl_mem_flags::from(mem_config.host_access)
| cl_mem_flags::from(mem_config.kernel_access)
| cl_mem_flags::from(mem_config.mem_location)
}
}
impl Default for MemConfig {
fn default() -> MemConfig {
MemConfig {
host_access: HostAccess::ReadWrite,
kernel_access: KernelAccess::ReadWrite,
mem_location: MemLocation::AllocOnDevice,
}
}
}
impl MemConfig {
pub fn for_data() -> MemConfig {
MemConfig {
mem_location: MemLocation::CopyToDevice,
..MemConfig::default()
}
}
pub fn for_size() -> MemConfig {
MemConfig {
mem_location: MemLocation::AllocOnDevice,
..MemConfig::default()
}
}
}
#[cfg(test)]
mod mem_flags_tests {
use super::*;
use crate::KernelAccessMemFlags;
#[test]
fn kernel_access_read_only_conversion_into_kernel_access_mem_flag() {
let kernel_access = KernelAccess::ReadOnly;
assert_eq!(
KernelAccessMemFlags::from(kernel_access),
KernelAccessMemFlags::READ_ONLY
);
}
#[test]
fn kernel_access_write_only_conversion_into_kernel_access_mem_flag() {
let kernel_access = KernelAccess::WriteOnly;
assert_eq!(
KernelAccessMemFlags::from(kernel_access),
KernelAccessMemFlags::WRITE_ONLY
);
}
#[test]
fn kernel_access_convert_read_write_into_kernel_access_mem_flag() {
let kernel_access = KernelAccess::ReadWrite;
assert_eq!(
KernelAccessMemFlags::from(kernel_access),
KernelAccessMemFlags::READ_WRITE
);
}
#[test]
fn host_access_read_only_conversion_into_host_access_mem_flag() {
let host_access = HostAccess::ReadOnly;
assert_eq!(
HostAccessMemFlags::from(host_access),
HostAccessMemFlags::READ_ONLY
);
}
#[test]
fn host_access_write_only_conversion_into_host_access_mem_flag() {
let host_access = HostAccess::WriteOnly;
assert_eq!(
HostAccessMemFlags::from(host_access),
HostAccessMemFlags::WRITE_ONLY
);
}
#[test]
fn host_access_read_write_conversion_into_host_access_mem_flag() {
let host_access = HostAccess::ReadWrite;
assert_eq!(
HostAccessMemFlags::from(host_access),
HostAccessMemFlags::READ_WRITE
);
}
#[test]
fn host_access_no_access_conversion_into_host_access_mem_flag() {
let host_access = HostAccess::NoAccess;
assert_eq!(
HostAccessMemFlags::from(host_access),
HostAccessMemFlags::NO_ACCESS
);
}
#[test]
fn mem_location_keep_in_place_conversion_into_mem_location_mem_flag() {
let mem_location = MemLocation::KeepInPlace;
assert_eq!(
MemLocationMemFlags::from(mem_location),
MemLocationMemFlags::KEEP_IN_PLACE
);
}
#[test]
fn mem_location_alloc_on_device_conversion_into_mem_location_mem_flag() {
let mem_location = MemLocation::AllocOnDevice;
assert_eq!(
MemLocationMemFlags::from(mem_location),
MemLocationMemFlags::ALLOC_ON_DEVICE
);
}
#[test]
fn mem_location_copy_to_device_conversion_into_mem_location_mem_flag() {
let mem_location = MemLocation::CopyToDevice;
assert_eq!(
MemLocationMemFlags::from(mem_location),
MemLocationMemFlags::COPY_TO_DEVICE
);
}
#[test]
fn mem_location_force_copy_to_device_conversion_into_mem_location_mem_flag() {
let mem_location = MemLocation::ForceCopyToDevice;
assert_eq!(
MemLocationMemFlags::from(mem_location),
MemLocationMemFlags::FORCE_COPY_TO_DEVICE
);
}
#[test]
fn mem_config_conversion_into_cl_mem_flags() {
let mem_location = MemLocation::AllocOnDevice;
let host_access = HostAccess::ReadWrite;
let kernel_access = KernelAccess::ReadWrite;
let mem_config = MemConfig {
mem_location,
host_access,
kernel_access,
};
let expected = MemFlags::ALLOC_HOST_PTR.bits()
| MemFlags::HOST_READ_WRITE.bits()
| MemFlags::KERNEL_READ_WRITE.bits();
assert_eq!(cl_mem_flags::from(mem_config), expected);
}
}