use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, OnceLock};
use morok_dtype::DType;
use smallvec::{SmallVec, smallvec};
use morok_dtype::ext::HasDType;
use snafu::ResultExt;
use crate::allocator::{Allocator, BufferOptions, RawBuffer};
use crate::error::{
InvalidViewSnafu, NdarrayShapeSnafu, NotCpuAccessibleSnafu, Result, SizeMismatchSnafu, TypeMismatchSnafu,
};
/// Process-wide counter from which every `BufferId` value is drawn.
static BUFFER_ID_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Hands out the next buffer id.
///
/// Each call returns a distinct value (until the `u64` counter would wrap). `Relaxed`
/// ordering is enough because the value is only used as an identity; it does not publish
/// any other memory.
fn next_buffer_id() -> u64 {
    AtomicU64::fetch_add(&BUFFER_ID_COUNTER, 1, Ordering::Relaxed)
}
/// Opaque identity of a buffer's backing storage.
///
/// The id is stored on the shared storage rather than on the handle, so every clone and
/// every view of the same allocation reports the same `BufferId`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct BufferId(pub u64);
#[cfg(feature = "cuda")]
use crate::error::CudaSnafu;
#[cfg(feature = "cuda")]
use snafu::ResultExt;
/// Backing storage shared (through an `Arc`) by one or more `Buffer` handles.
///
/// Allocation is lazy: `raw` stays empty until `ensure_allocated` first succeeds, and the
/// memory is handed back to `allocator` when this value is dropped.
#[derive(Debug)]
struct BufferData {
/// Identity reported by every handle and view onto this storage.
id: BufferId,
/// The allocated memory; set at most once.
raw: OnceLock<RawBuffer>,
/// Allocator used both to allocate `raw` and to free it.
allocator: Arc<dyn Allocator>,
/// Size in bytes of the whole allocation (individual views may cover only part of it).
total_size: usize,
/// Options passed to the allocator on both `alloc` and `free`.
options: BufferOptions,
}
impl BufferData {
    /// Creates unallocated storage of `size` bytes with a freshly drawn id.
    fn new(allocator: Arc<dyn Allocator>, size: usize, options: BufferOptions) -> Self {
        let id = BufferId(next_buffer_id());
        Self { id, raw: OnceLock::new(), allocator, total_size: size, options }
    }

    /// Allocates the backing memory unless that has already happened.
    ///
    /// Under contention several callers may each allocate. Only the first `set` is kept; every
    /// losing caller frees its own allocation right away, so nothing leaks.
    ///
    /// # Errors
    /// Propagates the allocator's error if `alloc` fails.
    fn ensure_allocated(&self) -> Result<()> {
        if self.is_allocated() {
            return Ok(());
        }
        let candidate = self.allocator.alloc(self.total_size, &self.options)?;
        match self.raw.set(candidate) {
            Ok(()) => {}
            Err(rejected) => {
                self.allocator.free(rejected, &self.options);
            }
        }
        Ok(())
    }

    /// Reports whether the backing memory exists yet.
    fn is_allocated(&self) -> bool {
        self.raw.get().is_some()
    }

    /// Borrows the backing memory.
    ///
    /// # Panics
    /// Panics with "buffer not allocated" when called before a successful `ensure_allocated`.
    fn raw(&self) -> &RawBuffer {
        match self.raw.get() {
            Some(buffer) => buffer,
            None => panic!("buffer not allocated"),
        }
    }
}
impl Drop for BufferData {
    /// Returns the backing memory, if it was ever allocated, to the allocator that produced it.
    fn drop(&mut self) {
        let Some(raw) = self.raw.take() else {
            return;
        };
        self.allocator.free(raw, &self.options);
    }
}
/// A typed, shaped handle onto a byte range of shared, lazily allocated storage.
///
/// Cloning only bumps the `Arc` refcount, so clones (and views made with `view`) share the
/// same underlying memory.
#[derive(Debug, Clone)]
pub struct Buffer {
/// Shared backing storage.
data: Arc<BufferData>,
/// Byte offset of this handle's range, measured from the start of the backing allocation.
offset: usize,
/// Length in bytes of this handle's range.
size: usize,
/// Element type stored in this range.
dtype: DType,
/// Logical element shape (one-dimensional for handles produced by `view`).
shape: SmallVec<[usize; 4]>,
}
impl Buffer {
/// Creates a buffer of `dtype` elements laid out with `shape`, without allocating memory.
///
/// Memory is allocated later by [`Buffer::ensure_allocated`] or by any accessor that calls it.
/// The byte size is `dtype.bytes() * product(shape)`. An empty `shape` describes a single
/// element.
///
/// # Panics
/// Panics if the byte size does not fit in `usize`. The size is computed with checked
/// multiplication so it can never silently wrap into an undersized allocation.
pub fn new(allocator: Arc<dyn Allocator>, dtype: DType, shape: Vec<usize>, options: BufferOptions) -> Self {
    let size = shape
        .iter()
        .try_fold(dtype.bytes(), |bytes, &dim| bytes.checked_mul(dim))
        .expect("buffer byte size overflows usize");
    Self {
        data: Arc::new(BufferData::new(allocator, size, options)),
        offset: 0,
        size,
        dtype,
        shape: SmallVec::from_vec(shape),
    }
}
/// Creates a buffer exactly like [`Buffer::new`], then allocates its backing memory eagerly.
///
/// # Errors
/// Propagates any error returned by the allocator.
pub fn allocate(
    allocator: Arc<dyn Allocator>,
    dtype: DType,
    shape: Vec<usize>,
    options: BufferOptions,
) -> Result<Self> {
    let buffer = Self::new(allocator, dtype, shape, options);
    buffer.ensure_allocated().map(|()| buffer)
}
/// Returns a view of `size` bytes that starts `offset` bytes into this buffer.
///
/// The view shares storage, and therefore [`BufferId`], with `self`. `offset` is relative to
/// this buffer, so views of views compose. The view is one-dimensional and holds
/// `size / dtype.bytes()` elements.
///
/// # Errors
/// Returns an invalid-view error when `offset + size` exceeds this buffer's size. This
/// includes the case where the sum overflows `usize`.
pub fn view(&self, offset: usize, size: usize) -> Result<Self> {
    // `checked_add` guards against wrap-around. In release builds a plain `offset + size`
    // could overflow to a small value, pass the bounds check, and yield a view whose range
    // lies outside the allocation.
    match offset.checked_add(size) {
        Some(end) if end <= self.size => {}
        _ => return InvalidViewSnafu { offset, size, buffer_size: self.size }.fail(),
    }
    Ok(Self {
        data: Arc::clone(&self.data),
        // Cannot overflow: offset <= self.size and self.offset + self.size fits the allocation.
        offset: self.offset + offset,
        size,
        dtype: self.dtype.clone(),
        shape: smallvec![size / self.dtype.bytes()],
    })
}
pub fn ensure_allocated(&self) -> Result<()> {
self.data.ensure_allocated()
}
pub fn is_allocated(&self) -> bool {
self.data.is_allocated()
}
pub fn size(&self) -> usize {
self.size
}
pub fn offset(&self) -> usize {
self.offset
}
pub fn dtype(&self) -> DType {
self.dtype.clone()
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
/// Borrows this buffer's byte range directly from host memory, allocating first if needed.
///
/// CPU and memory-mapped storage are readable. With the `cuda` feature enabled, CUDA
/// storage yields a not-CPU-accessible error.
///
/// # Errors
/// Allocation failures, or not-CPU-accessible for storage the host cannot read directly.
pub fn as_host_bytes(&self) -> Result<&[u8]> {
    self.ensure_allocated()?;
    let window = self.offset..self.offset + self.size;
    match self.data.raw() {
        RawBuffer::Cpu { data, .. } => {
            // SAFETY: only a shared borrow of the storage is created here. The type relies on
            // callers not mutating the same range while this slice is alive.
            let storage = unsafe { &*data.get() };
            Ok(&storage[window])
        }
        RawBuffer::Mmap { data, .. } => Ok(&data[window]),
        #[cfg(feature = "cuda")]
        _ => NotCpuAccessibleSnafu.fail(),
    }
}
/// Mutably borrows this buffer's byte range in host memory, allocating first if needed.
///
/// This takes `&self` but returns `&mut` (hence the `mut_from_ref` allow). While the returned
/// slice is in use, callers must not hold any other borrow of the same range.
///
/// # Errors
/// Allocation failures. Not-CPU-accessible for memory-mapped (read-only) storage, and for CUDA
/// storage when the `cuda` feature is enabled.
#[allow(clippy::mut_from_ref)]
pub fn as_host_bytes_mut(&self) -> Result<&mut [u8]> {
    self.ensure_allocated()?;
    let window = self.offset..self.offset + self.size;
    match self.data.raw() {
        RawBuffer::Cpu { data, .. } => {
            // SAFETY: exclusivity is the caller's obligation, as stated in the doc comment.
            let storage = unsafe { &mut *data.get() };
            Ok(&mut storage[window])
        }
        // Mapped storage is read-only, so it never hands out a mutable slice.
        RawBuffer::Mmap { .. } => NotCpuAccessibleSnafu.fail(),
        #[cfg(feature = "cuda")]
        _ => NotCpuAccessibleSnafu.fail(),
    }
}
/// Borrows the buffer as a read-only n-dimensional array of `T`, shaped like [`Buffer::shape`].
///
/// # Errors
/// - allocation failure while lazily allocating,
/// - type mismatch when `T::DTYPE` differs from the buffer's dtype,
/// - not-CPU-accessible when the storage cannot be read from the host,
/// - an ndarray shape error when the element count does not fit the shape.
pub fn as_array<T: HasDType>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
    self.ensure_allocated()?;
    if self.dtype != T::DTYPE {
        return TypeMismatchSnafu { expected: T::DTYPE, actual: self.dtype.clone() }.fail();
    }
    // Same storage dispatch as `as_host_bytes`: CPU and mmap are readable, the rest are not.
    let bytes = self.as_host_bytes()?;
    let elements = bytes.len() / T::DTYPE.bytes();
    // SAFETY: `bytes` lives as long as `&self`, and the dtype check above means the bytes
    // encode `T` values. This assumes `T::DTYPE.bytes() == size_of::<T>()`, so `elements`
    // values fit in `bytes`.
    // NOTE(review): alignment of `bytes` for `T` is assumed, not checked.
    let typed = unsafe { std::slice::from_raw_parts(bytes.as_ptr().cast::<T>(), elements) };
    ndarray::ArrayViewD::from_shape(ndarray::IxDyn(&self.shape), typed).context(NdarrayShapeSnafu)
}
/// Borrows the buffer as a mutable n-dimensional array of `T`, shaped like [`Buffer::shape`].
///
/// Only CPU storage flagged `cpu_accessible` qualifies. This takes `&self` (hence the
/// `mut_from_ref` allow), so callers must not keep overlapping views alive at the same time.
///
/// # Errors
/// Allocation failure, type mismatch, not-CPU-accessible, or an ndarray shape error.
#[allow(clippy::mut_from_ref)]
pub fn as_array_mut<T: HasDType>(&self) -> Result<ndarray::ArrayViewMutD<'_, T>> {
    self.ensure_allocated()?;
    if self.dtype != T::DTYPE {
        return TypeMismatchSnafu { expected: T::DTYPE, actual: self.dtype.clone() }.fail();
    }
    let data = match self.data.raw() {
        RawBuffer::Cpu { data, cpu_accessible: true } => data,
        _ => return NotCpuAccessibleSnafu.fail(),
    };
    // SAFETY: exclusivity of the range is the caller's obligation (see above).
    let storage = unsafe { &mut *data.get() };
    let bytes = &mut storage[self.offset..self.offset + self.size];
    let elements = bytes.len() / T::DTYPE.bytes();
    // SAFETY: the dtype check above means the bytes encode `T` values.
    // NOTE(review): alignment of `bytes` for `T` is assumed, not checked.
    let typed = unsafe { std::slice::from_raw_parts_mut(bytes.as_mut_ptr().cast::<T>(), elements) };
    ndarray::ArrayViewMutD::from_shape(ndarray::IxDyn(&self.shape), typed).context(NdarrayShapeSnafu)
}
/// Borrows the buffer as a flat, read-only slice of `T`, allocating first if needed.
///
/// Accepts CPU storage flagged `cpu_accessible` and memory-mapped storage. [`Buffer::as_array`]
/// already serves memory-mapped storage read-only, so this keeps the two readers consistent.
/// It also lets [`Buffer::item`] work on memory-mapped buffers.
///
/// # Errors
/// - allocation failure while lazily allocating,
/// - type mismatch when `T::DTYPE` differs from the buffer's dtype,
/// - not-CPU-accessible for any other storage.
pub fn as_slice<T: HasDType>(&self) -> Result<&[T]> {
    self.ensure_allocated()?;
    if self.dtype != T::DTYPE {
        return TypeMismatchSnafu { expected: T::DTYPE, actual: self.dtype.clone() }.fail();
    }
    let window = self.offset..self.offset + self.size;
    let bytes: &[u8] = match self.data.raw() {
        RawBuffer::Cpu { data, cpu_accessible } if *cpu_accessible => {
            // SAFETY: only a shared borrow of the storage is created here.
            unsafe { &(&*data.get())[window] }
        }
        // Mapped storage is read-only but directly readable from the host.
        RawBuffer::Mmap { data, .. } => &data[window],
        _ => return NotCpuAccessibleSnafu.fail(),
    };
    let count = bytes.len() / T::DTYPE.bytes();
    // SAFETY: the dtype check above means the bytes encode `T` values.
    // NOTE(review): alignment of `bytes` for `T` is assumed, not checked.
    Ok(unsafe { std::slice::from_raw_parts(bytes.as_ptr().cast::<T>(), count) })
}
/// Reads the only element of a single-element buffer.
///
/// # Errors
/// The same errors as [`Buffer::as_slice`].
///
/// # Panics
/// Panics unless the buffer holds exactly one `T`.
pub fn item<T: HasDType + Copy>(&self) -> Result<T> {
    let elements = self.as_slice::<T>()?;
    let count = elements.len();
    assert_eq!(count, 1, "item() requires exactly 1 element, got {}", count);
    Ok(elements[0])
}
pub fn allocator(&self) -> &dyn Allocator {
&*self.data.allocator
}
pub fn id(&self) -> BufferId {
self.data.id
}
pub fn copyin(&mut self, src: &[u8]) -> Result<()> {
self.ensure_allocated()?;
let expected = self.size;
let actual = src.len();
snafu::ensure!(expected == actual, SizeMismatchSnafu { expected, actual });
let raw = self.data.raw();
match raw {
RawBuffer::Cpu { data, .. } => {
let slice = unsafe {
let data_mut = &mut *data.get();
&mut data_mut[self.offset..self.offset + self.size]
};
slice.copy_from_slice(src);
Ok(())
}
RawBuffer::Mmap { .. } => panic!("DISK device is read-only: copyin not supported"),
#[cfg(feature = "cuda")]
RawBuffer::CudaDevice { data, device } => {
let cuda_data = unsafe { &mut *data.get() };
let mut view = cuda_data.slice_mut(self.offset..self.offset + self.size);
device.default_stream().memcpy_htod(src, &mut view).context(CudaSnafu)
}
#[cfg(feature = "cuda")]
RawBuffer::CudaUnified { data, .. } => {
let unified_data = unsafe { &mut *data.get() };
let slice = unified_data.as_mut_slice().context(CudaSnafu)?;
let target = &mut slice[self.offset..self.offset + self.size];
target.copy_from_slice(src);
Ok(())
}
}
}
pub fn copyout(&self, dst: &mut [u8]) -> Result<()> {
self.ensure_allocated()?;
let expected = self.size;
let actual = dst.len();
snafu::ensure!(expected == actual, SizeMismatchSnafu { expected, actual });
let raw = self.data.raw();
match raw {
RawBuffer::Cpu { data, .. } => {
let data_ref = unsafe { &*data.get() };
dst.copy_from_slice(&data_ref[self.offset..self.offset + self.size]);
Ok(())
}
RawBuffer::Mmap { data, .. } => {
dst.copy_from_slice(&data[self.offset..self.offset + self.size]);
Ok(())
}
#[cfg(feature = "cuda")]
RawBuffer::CudaDevice { data, device } => {
device.synchronize().context(CudaSnafu)?;
let cuda_data = unsafe { &*data.get() };
let view = cuda_data.slice(self.offset..self.offset + self.size);
device.default_stream().memcpy_dtoh(&view, dst).context(CudaSnafu)
}
#[cfg(feature = "cuda")]
RawBuffer::CudaUnified { data, .. } => {
let unified_data = unsafe { &*data.get() };
let slice = unified_data.as_slice().context(CudaSnafu)?;
let source = &slice[self.offset..self.offset + self.size];
dst.copy_from_slice(source);
Ok(())
}
}
}
/// Copies the contents of `src` into this buffer's byte range.
///
/// Both buffers are allocated first if necessary, and their sizes must match exactly.
/// `src` may share storage with `self`, because clones and views share it. For CPU storage
/// that case uses a single mutable borrow and an overlap-safe `copy_within`, instead of
/// creating aliasing `&mut` and `&` references.
///
/// # Errors
/// - allocation failure for either buffer,
/// - size mismatch when the two sizes differ,
/// - CUDA transfer errors (with the `cuda` feature).
///
/// # Panics
/// Panics when the destination is memory-mapped (DISK) storage, which is read-only.
pub fn copy_from(&mut self, src: &Buffer) -> Result<()> {
    self.ensure_allocated()?;
    src.ensure_allocated()?;
    let expected = self.size;
    let actual = src.size;
    snafu::ensure!(expected == actual, SizeMismatchSnafu { expected, actual });
    let dst_raw = self.data.raw();
    let src_raw = src.data.raw();
    match (dst_raw, src_raw) {
        (RawBuffer::Cpu { data: dst_data, .. }, RawBuffer::Cpu { data: src_data, .. }) => {
            if std::ptr::eq(dst_data, src_data) {
                // Same backing storage (e.g. `src` is a clone or view of `self`). Holding a
                // `&mut` and a `&` to it at the same time would be undefined behaviour, and the
                // ranges may overlap, so use one mutable borrow and memmove-style `copy_within`.
                // SAFETY: this arm creates only this one reference into the storage.
                let storage = unsafe { &mut *dst_data.get() };
                storage.copy_within(src.offset..src.offset + src.size, self.offset);
            } else {
                // SAFETY: distinct storages, so the mutable and shared borrows cannot alias.
                let dst_mut = unsafe { &mut *dst_data.get() };
                let src_ref = unsafe { &*src_data.get() };
                let dst_slice = &mut dst_mut[self.offset..self.offset + self.size];
                let src_slice = &src_ref[src.offset..src.offset + src.size];
                dst_slice.copy_from_slice(src_slice);
            }
            Ok(())
        }
        (RawBuffer::Cpu { data: dst_data, .. }, RawBuffer::Mmap { data: src_data, .. }) => {
            let dst_mut = unsafe { &mut *dst_data.get() };
            let dst_slice = &mut dst_mut[self.offset..self.offset + self.size];
            let src_slice = &src_data[src.offset..src.offset + src.size];
            dst_slice.copy_from_slice(src_slice);
            Ok(())
        }
        (RawBuffer::Mmap { .. }, _) => panic!("DISK device is read-only: copy_from not supported"),
        // NOTE(review): the CUDA arms below still form `&mut` and `&` to the same cell when
        // `src` shares storage with `self`; confirm whether that case needs special handling.
        #[cfg(feature = "cuda")]
        (
            RawBuffer::CudaDevice { data: dst_data, device: dst_device },
            RawBuffer::CudaDevice { data: src_data, .. },
        ) => {
            let dst_cuda = unsafe { &mut *dst_data.get() };
            let src_cuda = unsafe { &*src_data.get() };
            let mut dst_view = dst_cuda.slice_mut(self.offset..self.offset + self.size);
            let src_view = src_cuda.slice(src.offset..src.offset + src.size);
            dst_device.default_stream().memcpy_dtod(&src_view, &mut dst_view).context(CudaSnafu)
        }
        #[cfg(feature = "cuda")]
        (RawBuffer::CudaDevice { data: dst_data, device }, RawBuffer::Cpu { data: src_data, .. }) => {
            let dst_cuda = unsafe { &mut *dst_data.get() };
            let src_ref = unsafe { &*src_data.get() };
            let mut dst_view = dst_cuda.slice_mut(self.offset..self.offset + self.size);
            let src_slice = &src_ref[src.offset..src.offset + src.size];
            device.default_stream().memcpy_htod(src_slice, &mut dst_view).context(CudaSnafu)
        }
        #[cfg(feature = "cuda")]
        (RawBuffer::Cpu { data: dst_data, .. }, RawBuffer::CudaDevice { data: src_data, device }) => {
            let dst_mut = unsafe { &mut *dst_data.get() };
            let src_cuda = unsafe { &*src_data.get() };
            let dst_slice = &mut dst_mut[self.offset..self.offset + self.size];
            let src_view = src_cuda.slice(src.offset..src.offset + src.size);
            device.default_stream().memcpy_dtoh(&src_view, dst_slice).context(CudaSnafu)
        }
        #[cfg(feature = "cuda")]
        (RawBuffer::CudaUnified { data: dst_data, .. }, RawBuffer::CudaUnified { data: src_data, .. }) => {
            let dst_unified = unsafe { &mut *dst_data.get() };
            let src_unified = unsafe { &*src_data.get() };
            let dst_slice = dst_unified.as_mut_slice().context(CudaSnafu)?;
            let src_slice = src_unified.as_slice().context(CudaSnafu)?;
            let dst_target = &mut dst_slice[self.offset..self.offset + self.size];
            let src_source = &src_slice[src.offset..src.offset + src.size];
            dst_target.copy_from_slice(src_source);
            Ok(())
        }
        #[cfg(feature = "cuda")]
        (RawBuffer::CudaUnified { data: dst_data, .. }, RawBuffer::Cpu { data: src_data, .. }) => {
            let dst_unified = unsafe { &mut *dst_data.get() };
            let src_ref = unsafe { &*src_data.get() };
            let dst_slice = dst_unified.as_mut_slice().context(CudaSnafu)?;
            let dst_target = &mut dst_slice[self.offset..self.offset + self.size];
            let src_source = &src_ref[src.offset..src.offset + src.size];
            dst_target.copy_from_slice(src_source);
            Ok(())
        }
        #[cfg(feature = "cuda")]
        (RawBuffer::Cpu { data: dst_data, .. }, RawBuffer::CudaUnified { data: src_data, .. }) => {
            let dst_mut = unsafe { &mut *dst_data.get() };
            let src_unified = unsafe { &*src_data.get() };
            let src_slice = src_unified.as_slice().context(CudaSnafu)?;
            let dst_target = &mut dst_mut[self.offset..self.offset + self.size];
            let src_source = &src_slice[src.offset..src.offset + src.size];
            dst_target.copy_from_slice(src_source);
            Ok(())
        }
        #[cfg(feature = "cuda")]
        (
            RawBuffer::CudaUnified { data: dst_data, device: dst_device },
            RawBuffer::CudaDevice { data: src_data, .. },
        ) => {
            let src_cuda = unsafe { &*src_data.get() };
            let src_view = src_cuda.slice(src.offset..src.offset + src.size);
            let dst_unified = unsafe { &mut *dst_data.get() };
            let mut dst_target = dst_unified.slice_mut(self.offset..self.offset + self.size);
            dst_device.default_stream().memcpy_dtod(&src_view, &mut dst_target).context(CudaSnafu)
        }
        #[cfg(feature = "cuda")]
        (RawBuffer::CudaDevice { data: dst_data, device }, RawBuffer::CudaUnified { data: src_data, .. }) => {
            let dst_cuda = unsafe { &mut *dst_data.get() };
            let mut dst_view = dst_cuda.slice_mut(self.offset..self.offset + self.size);
            let src_unified = unsafe { &*src_data.get() };
            let src_source = src_unified.slice(src.offset..src.offset + src.size);
            device.default_stream().memcpy_htod(&src_source, &mut dst_view).context(CudaSnafu)
        }
        // A memory-mapped source is plain host memory, so stage it through `copyin`, which
        // already uploads host bytes to either kind of CUDA storage. Without this arm the
        // match is non-exhaustive when the `cuda` feature is enabled.
        #[cfg(feature = "cuda")]
        (RawBuffer::CudaDevice { .. } | RawBuffer::CudaUnified { .. }, RawBuffer::Mmap { data: src_data, .. }) => {
            let src_slice = &src_data[src.offset..src.offset + src.size];
            self.copyin(src_slice)
        }
    }
}
/// Delegates to the owning allocator's `synchronize`. Presumably this blocks until queued
/// device work has finished; confirm against the allocator implementations.
///
/// # Errors
/// Whatever the allocator's `synchronize` reports.
pub fn synchronize(&self) -> Result<()> {
    let allocator = &self.data.allocator;
    allocator.synchronize()
}
/// Returns a raw pointer to the first byte of this buffer's range, e.g. for kernel launches.
///
/// # Safety
/// - The pointer is valid for at most [`Buffer::size`] bytes.
/// - It is valid only while the shared backing storage is alive, i.e. while at least one
///   `Buffer` handle to it exists, because the storage is freed when the last handle drops.
/// - For memory-mapped (DISK) storage the memory is read-only. The pointer is typed `*mut`
///   only for uniformity and must never be written through.
/// - Writes through the pointer must not overlap any live slice or array view obtained from
///   this buffer or from any handle sharing its storage.
///
/// # Panics
/// Panics with "buffer not allocated" if the storage has not been allocated yet (call
/// [`Buffer::ensure_allocated`] first). CUDA storage is not supported yet and panics via
/// `unimplemented!`.
pub unsafe fn as_raw_ptr(&self) -> *mut u8 {
    let raw = self.data.raw();
    match raw {
        RawBuffer::Cpu { data, .. } => {
            // SAFETY: `self.offset` lies within the allocation (offset + size <= total size).
            unsafe { (&mut *data.get()).as_mut_ptr().add(self.offset) }
        }
        RawBuffer::Mmap { data, .. } => {
            // SAFETY: as above; the read-only requirement is documented under `# Safety`.
            unsafe { data.as_ptr().add(self.offset) as *mut u8 }
        }
        #[cfg(feature = "cuda")]
        RawBuffer::CudaDevice { .. } | RawBuffer::CudaUnified { .. } => {
            unimplemented!("CUDA buffer raw pointers not yet supported for kernel execution")
        }
    }
}
/// Test-only: address of the start of the whole backing allocation, ignoring this handle's
/// offset. This is handy for comparing storage identity in tests. For CUDA variants it is the
/// address of the host-side handle object, not of device memory.
///
/// # Panics
/// Panics with "buffer not allocated" if the storage has not been allocated.
#[cfg(test)]
pub(crate) fn raw_data_ptr(&self) -> usize {
    match self.data.raw() {
        RawBuffer::Cpu { data, .. } => {
            // SAFETY: only a shared borrow is created, to read the allocation's address.
            let storage = unsafe { &*data.get() };
            storage.as_ptr() as usize
        }
        RawBuffer::Mmap { data, .. } => data.as_ptr() as usize,
        #[cfg(feature = "cuda")]
        RawBuffer::CudaDevice { data, .. } => unsafe { &*data.get() as *const _ as usize },
        #[cfg(feature = "cuda")]
        RawBuffer::CudaUnified { data, .. } => unsafe { &*data.get() as *const _ as usize },
    }
}
}