use std::ops::Deref;
use ocl::Buffer;
use crate::buffer::{BufferConverter, BufferInstance, BufferMut};
use crate::opencl::OpenCL;
use crate::{Error, Number};
impl<T: Number> BufferInstance<T> for Buffer<T> {
fn read(&self) -> BufferConverter<'_, T> {
BufferConverter::CL(self.into())
}
fn read_value(&self, offset: usize) -> Result<T, Error> {
if offset < self.len() {
let slice = self.map().offset(offset).len(1).read();
let value = unsafe { slice.enq()? };
let value = value.first().copied().expect("value");
Ok(value)
} else {
Err(Error::bounds(format!(
"invalid offset {offset} for a buffer of length {}",
self.len()
)))
}
}
fn len(&self) -> usize {
self.len()
}
}
/// Write access to an OpenCL [`Buffer`] held by value.
impl<T: Number> BufferMut<T> for Buffer<T> {
    /// Return this buffer itself as the mutable OpenCL buffer.
    fn cl(&mut self) -> Result<&mut Buffer<T>, Error> {
        Ok(self)
    }

    /// Overwrite the whole contents of this buffer with `data`.
    ///
    /// # Errors
    /// Returns a bounds error if `data.len()` differs from this buffer's length.
    /// Also returns an error if converting `data` to an OpenCL buffer fails, or if
    /// enqueuing the device-to-device copy fails.
    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        if data.len() == Buffer::len(self) {
            let data = data.to_cl()?;
            data.copy(self, None, None).enq().map_err(Error::from)
        } else {
            Err(Error::bounds(format!(
                "cannot overwrite a buffer of size {} with one of size {}",
                Buffer::len(self),
                data.len()
            )))
        }
    }

    /// Replace this buffer with a new buffer of the same length filled with `value`.
    ///
    /// NOTE(review): this builds a fresh buffer from `OpenCL::context()` and assigns it
    /// to `*self` instead of filling in place. Other clones of the old `Buffer` handle
    /// therefore keep the old contents. The new buffer is also built from a context
    /// rather than this buffer's default queue. Confirm that both effects are intended.
    fn write_value(&mut self, value: T) -> Result<(), Error> {
        let buf = Buffer::builder()
            .context(OpenCL::context())
            .len(Buffer::len(self))
            .fill_val(value)
            .build()?;

        *self = buf;
        Ok(())
    }

    /// Set the element at `offset` to `value`.
    ///
    /// # Errors
    /// Returns a bounds error if `offset` is not less than the buffer length.
    /// Also returns an error if mapping or unmapping the element fails.
    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        if offset < Buffer::len(self) {
            // Map for *writing*. Host writes made through a read-only (CL_MAP_READ)
            // mapping are not guaranteed to be written back to the device.
            let slice = self.map().offset(offset).len(1).write();
            // SAFETY: ocl marks map enqueues `unsafe`. This assumes no kernel accesses
            // this buffer while the host mapping is alive — TODO confirm with callers.
            let mut slice = unsafe { slice.enq()? };
            slice.as_mut()[0] = value;
            // Unmap explicitly so the host write is handed back to the device and any
            // unmap error is reported instead of being dropped.
            slice.unmap().enq()?;
            Ok(())
        } else {
            Err(Error::bounds(format!(
                "invalid offset {offset} for a buffer of length {}",
                Buffer::len(self)
            )))
        }
    }
}
/// Read access through a shared reference. Every method forwards to the
/// `Buffer<T>` implementation.
impl<T: Number> BufferInstance<T> for &Buffer<T> {
    /// Borrow the referenced buffer as a [`BufferConverter`].
    fn read(&self) -> BufferConverter<'_, T> {
        let buffer: &Buffer<T> = self;
        BufferConverter::CL(buffer.into())
    }

    /// Delegate to the by-value `Buffer<T>` implementation.
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        <Buffer<T> as BufferInstance<T>>::read_value(self, offset)
    }

    /// Number of elements in the referenced buffer.
    fn len(&self) -> usize {
        Buffer::len(*self)
    }
}
/// Read access through a mutable reference. Every method forwards to the
/// `Buffer<T>` implementation via a shared reborrow.
impl<T: Number> BufferInstance<T> for &mut Buffer<T> {
    /// Borrow the referenced buffer (immutably) as a [`BufferConverter`].
    fn read(&self) -> BufferConverter<'_, T> {
        let buffer: &Buffer<T> = self;
        BufferConverter::CL(buffer.into())
    }

    /// Delegate to the by-value `Buffer<T>` implementation.
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        <Buffer<T> as BufferInstance<T>>::read_value(self, offset)
    }

    /// Number of elements in the referenced buffer.
    fn len(&self) -> usize {
        let buffer: &Buffer<T> = self;
        Buffer::len(buffer)
    }
}
/// Write access through a mutable reference. Every method reborrows the target
/// and forwards to the `Buffer<T>` implementation.
impl<T: Number> BufferMut<T> for &mut Buffer<T> {
    /// Reborrow the referenced buffer mutably.
    fn cl(&mut self) -> Result<&mut Buffer<T>, Error> {
        let buffer: &mut Buffer<T> = self;
        Ok(buffer)
    }

    /// Delegate to the by-value `Buffer<T>` implementation.
    fn write<'b>(&mut self, data: BufferConverter<'b, T>) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write(self, data)
    }

    /// Delegate to the by-value `Buffer<T>` implementation.
    fn write_value(&mut self, value: T) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write_value(self, value)
    }

    /// Delegate to the by-value `Buffer<T>` implementation.
    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write_value_at(self, offset, value)
    }
}
/// An OpenCL [`Buffer`] that is either owned or borrowed for lifetime `'a`.
///
/// Both variants deref to the underlying `Buffer<T>`. Use `into_buffer` to obtain
/// an owned buffer; a borrowed buffer is copied to produce one.
#[derive(Clone)]
pub enum CLConverter<'a, T: Number> {
/// A buffer owned by this converter.
Owned(Buffer<T>),
/// A buffer borrowed from elsewhere.
Borrowed(&'a Buffer<T>),
}
#[cfg(feature = "opencl")]
impl<'a, T: Number> CLConverter<'a, T> {
    /// Convert this converter into an owned [`Buffer`].
    ///
    /// An `Owned` buffer is returned unchanged. For a `Borrowed` buffer, the method
    /// does three things:
    /// - allocates a new buffer of the same length on the source buffer's default queue;
    /// - enqueues a device-to-device copy into the new buffer;
    /// - returns the new buffer.
    ///
    /// # Errors
    /// Returns an error if allocating the new buffer or enqueuing the copy fails.
    ///
    /// # Panics
    /// Panics if a borrowed buffer has no default queue.
    pub fn into_buffer(self) -> Result<Buffer<T>, Error> {
        match self {
            Self::Owned(buffer) => Ok(buffer),
            Self::Borrowed(buffer) => {
                let cl_queue = buffer.default_queue().expect("OpenCL queue");

                // Named `duplicate`, not `copy`: text tooling can mangle the borrow
                // `&copy` into the HTML entity `©`, which does not compile.
                let duplicate = Buffer::builder()
                    .queue(cl_queue.clone())
                    .len(Buffer::len(buffer))
                    .build()?;

                buffer.copy(&duplicate, None, None).enq()?;
                Ok(duplicate)
            }
        }
    }

    /// Number of elements in the underlying buffer.
    pub fn len(&self) -> usize {
        match self {
            Self::Owned(buffer) => Buffer::len(buffer),
            Self::Borrowed(buffer) => Buffer::len(buffer),
        }
    }

    /// Returns `true` if the underlying buffer has no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(feature = "opencl")]
impl<'a, T: Number> Deref for CLConverter<'a, T> {
    type Target = Buffer<T>;

    /// Borrow the underlying buffer, whether it is owned or borrowed.
    fn deref(&self) -> &Buffer<T> {
        match *self {
            Self::Owned(ref owned) => owned,
            Self::Borrowed(borrowed) => borrowed,
        }
    }
}
impl<T: Number> From<Buffer<T>> for CLConverter<'static, T> {
fn from(buf: Buffer<T>) -> Self {
Self::Owned(buf)
}
}
impl<'a, T: Number> From<&'a Buffer<T>> for CLConverter<'a, T> {
fn from(buf: &'a Buffer<T>) -> Self {
Self::Borrowed(buf)
}
}