use crate::VBuffer;
use super::CLDevice;
use anyhow::{Context, Result, bail};
use opencl3::{
command_queue::CommandQueue,
device::{self as cl_device},
memory::{self as cl_memory, Buffer, ClMem},
types,
};
use std::ptr;
use std::sync::RwLock;
/// Settings describing which OpenCL device to use and how large a buffer to
/// allocate on it.
#[derive(Debug, Clone)]
pub struct CLBufferConfig {
    /// Requested buffer size (presumably in bytes — confirm against the
    /// code that consumes this config).
    pub size: usize,
    /// Whether map/unmap transfers are requested instead of enqueue
    /// read/write transfers (presumably forwarded to `CLBuffer::new` —
    /// confirm against the caller).
    pub mmap: bool,
    /// Index of the device to pick (presumably within the selected
    /// platform's device list — confirm against the caller).
    pub device_index: usize,
    /// Index of the platform to pick (presumably within the list of
    /// available OpenCL platforms — confirm against the caller).
    pub platform_index: usize,
    /// OpenCL device-type mask; this file populates it with
    /// `CL_DEVICE_TYPE_*` constants.
    pub device: u64,
}
impl CLBufferConfig {
    /// Replaces the `device` mask with `CL_DEVICE_TYPE_CPU`, discarding
    /// whatever device types were set before.
    pub fn with_cpu(&mut self) {
        let cpu_only = cl_device::CL_DEVICE_TYPE_CPU;
        self.device = cpu_only;
    }
}
impl Default for CLBufferConfig {
    /// Builds the default configuration: a 2 GiB size, platform 0, device 0,
    /// `mmap` disabled, and a device mask accepting GPUs and accelerators.
    fn default() -> Self {
        // 2048 * 1024 * 1024 = 2 GiB.
        const DEFAULT_SIZE: usize = 2048 * 1024 * 1024;

        let gpu_or_accelerator =
            cl_device::CL_DEVICE_TYPE_GPU | cl_device::CL_DEVICE_TYPE_ACCELERATOR;

        Self {
            size: DEFAULT_SIZE,
            mmap: false,
            device_index: 0,
            platform_index: 0,
            device: gpu_or_accelerator,
        }
    }
}
/// A byte buffer held in an OpenCL memory object and addressed through a
/// window that starts at a configurable base offset.
pub struct CLBuffer {
    /// Command queue on which every transfer for this buffer is enqueued.
    queue: CommandQueue,
    /// The OpenCL memory object, guarded so transfers can be serialised.
    buffer: RwLock<Buffer<u8>>,
    /// Base offset of the window; caller-supplied offsets are translated to
    /// buffer-local offsets by subtracting this value.
    offset: u64,
    /// Length of the window in bytes.
    size: usize,
    /// When `true`, transfers map/unmap the buffer; otherwise they use
    /// blocking enqueue read/write calls.
    mmap: bool,
}
impl CLBuffer {
    /// Creates a buffer of `size` bytes on `device`.
    ///
    /// A command queue is created first and handed to
    /// `CLDevice::create_buffer`; the same queue is kept for later transfers.
    /// The window's base offset starts at `0`.
    ///
    /// `mmap` selects the transfer strategy used by `read`/`write`: `true`
    /// maps and unmaps the buffer, `false` uses blocking enqueue read/write.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `CLDevice::create_queue` or
    /// `CLDevice::create_buffer`.
    pub fn new(device: &CLDevice, size: usize, mmap: bool) -> Result<Self> {
        let queue = device.create_queue()?;
        let buffer = RwLock::new(device.create_buffer(&queue, size)?);
        Ok(Self {
            queue,
            buffer,
            offset: 0,
            size,
            mmap,
        })
    }

    /// Returns `true` when `offset` falls inside the half-open window
    /// `[self.offset, self.offset + self.size)`.
    #[inline]
    fn within(&self, offset: u64) -> bool {
        // Compare the distance from the base instead of computing
        // `self.offset + self.size`, which overflows when the base offset is
        // close to `u64::MAX`.
        offset >= self.offset && offset - self.offset < self.size as u64
    }
}
impl VBuffer for CLBuffer {
    /// Returns the number of bytes between `offset` and the end of the
    /// window, or `None` when `offset` lies outside the window.
    fn remaining(&self, offset: u64) -> Option<usize> {
        if !self.within(offset) {
            return None;
        }
        // `within` guarantees `offset >= self.offset` and that the distance is
        // smaller than `self.size`, so this can neither underflow nor exceed
        // `usize`.
        let consumed = offset - self.offset;
        Some((self.size as u64 - consumed) as usize)
    }

    /// Returns the window length in bytes.
    fn size(&self) -> usize {
        self.size
    }

    /// Moves the window so that it starts at `offset`.
    fn offset(&mut self, offset: u64) {
        self.offset = offset;
    }

    /// Copies `data.len()` bytes starting at `offset` from the device buffer
    /// into `data`.
    ///
    /// A zero-length request that passes the bounds checks succeeds without
    /// touching the device.
    ///
    /// # Errors
    ///
    /// Fails when `offset` is outside the window, when the requested range
    /// runs past the end of the window, when the lock is poisoned, or when
    /// any OpenCL call (map, unmap, wait, read) reports an error.
    fn read(&self, offset: u64, data: &mut [u8]) -> Result<()> {
        if !self.within(offset) {
            bail!("Attempted to read out of buffer");
        }
        // `within` bounds the distance by `self.size`, so it fits in `usize`.
        let local_offset = (offset - self.offset) as usize;
        let length = data.len();
        // Subtraction form: `local_offset + length` could overflow.
        if length > self.size - local_offset {
            bail!("Attempted to read past end of buffer");
        }
        if length == 0 {
            // Nothing to transfer; OpenCL rejects zero-sized map requests.
            return Ok(());
        }

        if self.mmap {
            // NOTE(review): the exclusive lock is kept from the original
            // code; a shared lock may be enough for a read-only mapping.
            let buffer_guard = self
                .buffer
                .write()
                .map_err(|_| anyhow::anyhow!("Failed to lock buffer RwLock for read"))?;
            let mut host_ptr = ptr::null_mut();
            // SAFETY: the guard keeps the buffer borrowed for the whole
            // map/copy/unmap sequence and `(local_offset, length)` was
            // bounds-checked against `self.size` above (assumes the device
            // buffer holds at least `self.size` bytes, as set up in `new`).
            // `CL_MAP_READ` is the map flag for reading; `CL_MEM_READ_ONLY`
            // is a buffer-creation flag whose value aliases
            // `CL_MAP_WRITE_INVALIDATE_REGION`, which leaves the mapped
            // contents undefined.
            let _ = unsafe {
                self.queue.enqueue_map_buffer(
                    &*buffer_guard,
                    types::CL_TRUE,
                    cl_memory::CL_MAP_READ,
                    local_offset,
                    length,
                    &mut host_ptr,
                    &[],
                )
            }
            .context("Failed to mmap from buffer")?;
            if host_ptr.is_null() {
                bail!("Mapping the buffer returned a null pointer");
            }
            // SAFETY: the blocking map succeeded and returned a non-null
            // pointer to `length` mapped bytes; `data` is exactly `length`
            // bytes long and is a distinct host allocation.
            unsafe {
                ptr::copy_nonoverlapping(host_ptr as *const u8, data.as_mut_ptr(), length);
            }
            // SAFETY: `host_ptr` was produced by the map call above on this
            // same memory object and has not been unmapped yet.
            let unmap_event = unsafe {
                self.queue
                    .enqueue_unmap_mem_object(buffer_guard.get(), host_ptr, &[])
            }
            .context("Failed to unmmap from buffer")?;
            // Surface a failed unmap instead of silently dropping it.
            unmap_event
                .wait()
                .context("Failed to wait for buffer unmap")?;
        } else {
            let buffer_guard = self
                .buffer
                .read()
                .map_err(|_| anyhow::anyhow!("Failed to lock buffer RwLock for read"))?;
            // SAFETY: the range was bounds-checked above and the call is
            // blocking, so `data` is not accessed by OpenCL after it returns.
            let _ = unsafe {
                self.queue.enqueue_read_buffer(
                    &*buffer_guard,
                    types::CL_TRUE,
                    local_offset,
                    data,
                    &[],
                )
            }
            .context("Failed to enqueue blocking read from buffer")?;
        }
        Ok(())
    }

    /// Copies `data` into the device buffer starting at `offset`.
    ///
    /// A zero-length request that passes the bounds checks succeeds without
    /// touching the device.
    ///
    /// # Errors
    ///
    /// Fails when `offset` is outside the window, when the requested range
    /// runs past the end of the window, when the lock is poisoned, or when
    /// any OpenCL call (map, unmap, wait, write) reports an error.
    fn write(&self, offset: u64, data: &[u8]) -> Result<()> {
        if !self.within(offset) {
            bail!("Attempted to write out of buffer");
        }
        // `within` bounds the distance by `self.size`, so it fits in `usize`.
        let local_offset = (offset - self.offset) as usize;
        let length = data.len();
        // Subtraction form: `local_offset + length` could overflow.
        if length > self.size - local_offset {
            bail!("Attempted to write past end of buffer");
        }
        if length == 0 {
            // Nothing to transfer; OpenCL rejects zero-sized map requests.
            return Ok(());
        }

        let mut buffer_guard = self
            .buffer
            .write()
            .map_err(|_| anyhow::anyhow!("Failed to lock buffer RwLock for write"))?;

        if self.mmap {
            let mut host_ptr = ptr::null_mut();
            // SAFETY: the guard keeps the buffer exclusively borrowed for the
            // whole map/copy/unmap sequence and `(local_offset, length)` was
            // bounds-checked against `self.size` above (assumes the device
            // buffer holds at least `self.size` bytes, as set up in `new`).
            // `CL_MAP_WRITE` is the map flag for writing; the previously used
            // `CL_MEM_WRITE_ONLY` only worked because it has the same value.
            let _ = unsafe {
                self.queue.enqueue_map_buffer(
                    &*buffer_guard,
                    types::CL_TRUE,
                    cl_memory::CL_MAP_WRITE,
                    local_offset,
                    length,
                    &mut host_ptr,
                    &[],
                )
            }
            .context("Failed to mmap from buffer")?;
            if host_ptr.is_null() {
                bail!("Mapping the buffer returned a null pointer");
            }
            // SAFETY: the blocking map succeeded and returned a non-null
            // pointer to `length` mapped bytes; `data` is exactly `length`
            // bytes long and is a distinct host allocation.
            unsafe {
                ptr::copy_nonoverlapping(data.as_ptr(), host_ptr as *mut u8, length);
            }
            // SAFETY: `host_ptr` was produced by the map call above on this
            // same memory object and has not been unmapped yet.
            let unmap_event = unsafe {
                self.queue
                    .enqueue_unmap_mem_object(buffer_guard.get(), host_ptr, &[])
            }
            .context("Failed to unmmap from buffer")?;
            // The written bytes are only committed once the unmap completes,
            // so a failure here must reach the caller.
            unmap_event
                .wait()
                .context("Failed to wait for buffer unmap")?;
        } else {
            // SAFETY: the range was bounds-checked above and the call is
            // blocking, so `data` is not accessed by OpenCL after it returns.
            let _ = unsafe {
                self.queue.enqueue_write_buffer(
                    &mut *buffer_guard,
                    types::CL_TRUE,
                    local_offset,
                    data,
                    &[],
                )
            }
            .context("Failed to enqueue blocking write to buffer")?;
        }
        Ok(())
    }
}
impl Drop for CLBuffer {
    /// Emits a debug-level log line when the buffer is dropped.
    ///
    /// No resources are released explicitly here; the struct's fields are
    /// dropped by Rust after this method returns.
    fn drop(&mut self) {
        log::debug!("Freeing OCL memory buffer");
    }
}