use std::alloc::{self, Layout};
use std::ptr::NonNull;
use std::sync::Arc;
use crate::dtype::DType;
use crate::error::{Result, SapientError};
/// A contiguous run of bytes owned by some device.
///
/// The `Send + Sync` supertraits let implementors be shared between threads,
/// for example through the `Arc` inside `BufferHandle`.
pub trait Buffer: Send + Sync + std::fmt::Debug {
    /// Shared view of the buffer contents.
    fn as_bytes(&self) -> &[u8];

    /// Exclusive, writable view of the buffer contents.
    fn as_bytes_mut(&mut self) -> &mut [u8];

    /// Number of bytes exposed by [`as_bytes`](Buffer::as_bytes).
    fn len(&self) -> usize;

    /// Returns `true` exactly when [`len`](Buffer::len) reports zero bytes.
    fn is_empty(&self) -> bool {
        0 == self.len()
    }

    /// Alignment, in bytes, that the implementor reports for its storage.
    fn alignment(&self) -> usize;

    /// Name of the device that holds the storage (the CPU implementation reports `"cpu"`).
    fn device(&self) -> &str;
}
/// Reference-counted handle to any [`Buffer`] implementation.
///
/// Cloning a handle clones the inner `Arc`, so all clones refer to the same
/// underlying buffer rather than copying its bytes.
#[derive(Debug, Clone)]
pub struct BufferHandle(pub Arc<dyn Buffer>);

impl BufferHandle {
    /// Moves `buf` into a fresh reference-counted allocation.
    pub fn new(buf: impl Buffer + 'static) -> Self {
        let shared: Arc<dyn Buffer> = Arc::new(buf);
        BufferHandle(shared)
    }

    /// Shared byte view of the wrapped buffer.
    pub fn as_bytes(&self) -> &[u8] {
        self.inner().as_bytes()
    }

    /// Length in bytes of the wrapped buffer.
    pub fn len(&self) -> usize {
        self.inner().len()
    }

    /// Whether the wrapped buffer holds zero bytes.
    pub fn is_empty(&self) -> bool {
        self.inner().is_empty()
    }

    /// Borrows the trait object behind the `Arc`.
    fn inner(&self) -> &dyn Buffer {
        &*self.0
    }
}
/// A heap allocation obtained from the global allocator and freed in `Drop`.
pub struct CpuBuffer {
    /// Start of the allocation.
    ptr: NonNull<u8>,
    /// Layout handed to the allocator; `Drop` passes this same value to `dealloc`.
    layout: Layout,
    /// Logical length in bytes. This can be smaller than `layout.size()`: a
    /// zero-length buffer still owns a 1-byte allocation.
    len: usize,
    /// Alignment recorded at construction time and reported through `Buffer::alignment`.
    align: usize,
}

// SAFETY: `CpuBuffer` exclusively owns the allocation behind `ptr`; no other
// handle to it is retained. Every writable view (`as_bytes_mut`,
// `as_f32_slice_mut`, `as_mut_ptr`) requires `&mut self`, so moving the buffer to
// another thread or sharing `&CpuBuffer` between threads gives the same
// guarantees as `Vec<u8>`.
unsafe impl Send for CpuBuffer {}
unsafe impl Sync for CpuBuffer {}
impl std::fmt::Debug for CpuBuffer {
    /// Prints the byte length and alignment only; the raw pointer and layout are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CpuBuffer");
        out.field("len", &self.len);
        out.field("align", &self.align);
        out.finish()
    }
}
impl CpuBuffer {
    /// Allocates a zero-filled buffer sized for `numel` elements of `dtype`.
    ///
    /// The alignment is the larger of `dtype.alignment()` and 64 bytes.
    ///
    /// # Errors
    /// Propagates `SapientError::AllocationFailed` from [`CpuBuffer::with_capacity`].
    pub fn zeros(numel: usize, dtype: DType) -> Result<Self> {
        let bytes = dtype.byte_count(numel);
        let align = dtype.alignment().max(64);
        Self::with_capacity(bytes, align)
    }

    /// Allocates `bytes` zero-filled bytes aligned to `align`.
    ///
    /// The global allocator must not be called with a zero-sized layout, so a
    /// request for 0 bytes still allocates 1 byte. The buffer's reported length
    /// stays 0 in that case.
    ///
    /// # Errors
    /// Returns `SapientError::AllocationFailed` when any of the following holds:
    /// `align` is not a power of two, the size rounded up to `align` overflows
    /// `isize`, or the allocator returns null.
    pub fn with_capacity(bytes: usize, align: usize) -> Result<Self> {
        let err = || SapientError::AllocationFailed { bytes, align };
        let layout = Layout::from_size_align(bytes.max(1), align).map_err(|_| err())?;
        // SAFETY: `layout` has a non-zero size because of the `max(1)` above.
        let raw = unsafe { alloc::alloc_zeroed(layout) };
        let ptr = NonNull::new(raw).ok_or_else(err)?;
        Ok(Self {
            ptr,
            len: bytes,
            align,
            layout,
        })
    }

    /// Copies `data` into a new 64-byte-aligned buffer.
    ///
    /// # Errors
    /// Propagates allocation failures from [`CpuBuffer::with_capacity`].
    pub fn from_f32_slice(data: &[f32]) -> Result<Self> {
        let bytes = std::mem::size_of_val(data);
        let buf = Self::with_capacity(bytes, 64)?;
        // SAFETY: `buf` owns at least `bytes` freshly allocated bytes, which cannot
        // overlap the borrowed `data`. Writing through a `u8` pointer has no
        // alignment requirement.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr().cast::<u8>(), buf.ptr.as_ptr(), bytes);
        }
        Ok(buf)
    }

    /// Takes over `data`'s heap allocation instead of copying the elements.
    ///
    /// The result is only guaranteed to be aligned for `f32` (4 bytes), not to
    /// 64 bytes. An empty vector yields a zero-length buffer with 4-byte alignment.
    ///
    /// # Errors
    /// Returns `SapientError::AllocationFailed` if the array layout cannot be
    /// formed, or if allocating the zero-length buffer fails.
    pub fn from_f32_vec(data: Vec<f32>) -> Result<Self> {
        if data.is_empty() {
            return Self::with_capacity(0, 4);
        }
        // A `Vec` may have `capacity > len`. `Drop` deallocates with
        // `Layout::array::<f32>(len)`, and passing a layout whose size differs from
        // the real allocation to `dealloc` is UB. `into_boxed_slice` drops the spare
        // capacity, reallocating if necessary, so the allocation is exactly `len`
        // elements.
        let boxed: Box<[f32]> = data.into_boxed_slice();
        let count = boxed.len();
        let len = count * std::mem::size_of::<f32>();
        // Build the layout before leaking the box so the error path cannot leak.
        let layout = Layout::array::<f32>(count)
            .map_err(|_| SapientError::AllocationFailed { bytes: len, align: 4 })?;
        // Ownership passes to `self`. `Drop` frees the memory through the global
        // allocator with `layout`, which matches how `Box<[f32]>` allocated it.
        let ptr = NonNull::from(Box::leak(boxed)).cast::<u8>();
        Ok(Self {
            ptr,
            len,
            align: std::mem::align_of::<f32>(),
            layout,
        })
    }

    /// Copies `data` into a new 16-byte-aligned buffer.
    ///
    /// # Errors
    /// Propagates allocation failures from [`CpuBuffer::with_capacity`].
    pub fn from_bytes_slice(data: &[u8]) -> Result<Self> {
        let bytes = data.len();
        // `with_capacity` already handles `bytes == 0`, and copying 0 bytes is a no-op.
        let buf = Self::with_capacity(bytes, 16)?;
        // SAFETY: `buf` owns at least `bytes` freshly allocated bytes, which cannot
        // overlap the borrowed `data`.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), buf.ptr.as_ptr(), bytes);
        }
        Ok(buf)
    }

    /// Views the buffer contents as `f32` values.
    ///
    /// # Panics
    /// Panics if the byte length is not a multiple of 4. Also panics if the buffer
    /// is non-empty and its pointer is not 4-byte aligned, which can happen when it
    /// was created with `with_capacity(_, align)` and `align < 4`.
    pub fn as_f32_slice(&self) -> &[f32] {
        match self.f32_count() {
            0 => &[],
            // SAFETY: `f32_count` checked the length and alignment, and the
            // allocation holds `n * 4` initialised bytes borrowed for `&self`.
            n => unsafe { std::slice::from_raw_parts(self.ptr.as_ptr().cast::<f32>(), n) },
        }
    }

    /// Mutable counterpart of [`CpuBuffer::as_f32_slice`].
    ///
    /// # Panics
    /// Same conditions as [`CpuBuffer::as_f32_slice`].
    pub fn as_f32_slice_mut(&mut self) -> &mut [f32] {
        match self.f32_count() {
            0 => &mut [],
            // SAFETY: as in `as_f32_slice`; `&mut self` guarantees exclusive access.
            n => unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr().cast::<f32>(), n) },
        }
    }

    /// Returns how many `f32` values the buffer holds, after checking that
    /// viewing it as `f32` is sound.
    fn f32_count(&self) -> usize {
        assert_eq!(self.len % 4, 0, "buffer length not a multiple of 4");
        let count = self.len / 4;
        if count > 0 {
            // `slice::from_raw_parts` requires a pointer aligned for the element type.
            assert_eq!(
                self.ptr.as_ptr() as usize % std::mem::align_of::<f32>(),
                0,
                "buffer not aligned for f32"
            );
        }
        count
    }

    /// Raw read-only pointer to the start of the allocation.
    pub fn as_ptr(&self) -> *const u8 {
        self.ptr.as_ptr()
    }

    /// Raw writable pointer to the start of the allocation.
    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr.as_ptr()
    }
}
impl Buffer for CpuBuffer {
    /// Number of logical bytes; can be smaller than the underlying allocation.
    fn len(&self) -> usize {
        self.len
    }

    /// The alignment recorded when the buffer was constructed.
    fn alignment(&self) -> usize {
        self.align
    }

    /// CPU buffers always report the device name `"cpu"`.
    fn device(&self) -> &str {
        "cpu"
    }

    /// Shared view of the first `len` bytes of the allocation.
    fn as_bytes(&self) -> &[u8] {
        let start: *const u8 = self.ptr.as_ptr();
        let count = self.len;
        // SAFETY: the constructors give `ptr` at least `len` initialised bytes
        // owned by `self`, and the returned slice borrows `&self`.
        unsafe { std::slice::from_raw_parts(start, count) }
    }

    /// Exclusive view of the first `len` bytes of the allocation.
    fn as_bytes_mut(&mut self) -> &mut [u8] {
        let start: *mut u8 = self.ptr.as_ptr();
        let count = self.len;
        // SAFETY: as in `as_bytes`; `&mut self` guarantees no other live view.
        unsafe { std::slice::from_raw_parts_mut(start, count) }
    }
}
impl Drop for CpuBuffer {
    /// Returns the allocation to the global allocator.
    fn drop(&mut self) {
        let (block, layout) = (self.ptr.as_ptr(), self.layout);
        // SAFETY: relies on every constructor storing the exact layout that was used
        // to obtain `ptr` from the global allocator. `drop` runs at most once.
        unsafe { alloc::dealloc(block, layout) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zeros_and_read() {
        let buf = CpuBuffer::zeros(4, DType::F32).unwrap();
        assert_eq!(buf.len(), 16);
        assert!(buf.as_bytes().iter().all(|&b| b == 0));
    }

    #[test]
    fn from_f32_roundtrip() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let buf = CpuBuffer::from_f32_slice(&data).unwrap();
        assert_eq!(buf.as_f32_slice(), data.as_slice());
    }

    #[test]
    fn alignment_guarantee() {
        let buf = CpuBuffer::with_capacity(32, 64).unwrap();
        assert_eq!(buf.as_ptr() as usize % 64, 0);
    }

    /// A `Vec` with spare capacity must still be freed with a layout that matches
    /// its allocation. Run under `cargo miri test` to catch layout mismatches.
    #[test]
    fn from_f32_vec_with_spare_capacity() {
        let mut data: Vec<f32> = Vec::with_capacity(32);
        data.extend_from_slice(&[1.0, 2.0, 3.0]);
        let buf = CpuBuffer::from_f32_vec(data).unwrap();
        assert_eq!(buf.len(), 12);
        assert_eq!(buf.as_f32_slice(), &[1.0f32, 2.0, 3.0][..]);
    }

    #[test]
    fn zero_length_buffers() {
        let a = CpuBuffer::with_capacity(0, 64).unwrap();
        assert!(a.is_empty());
        assert!(a.as_f32_slice().is_empty());
        assert!(CpuBuffer::from_f32_vec(Vec::new()).unwrap().is_empty());
        assert!(CpuBuffer::from_bytes_slice(&[]).unwrap().is_empty());
        assert!(CpuBuffer::from_f32_slice(&[]).unwrap().is_empty());
    }

    #[test]
    fn from_bytes_roundtrip() {
        let buf = CpuBuffer::from_bytes_slice(&[9, 8, 7]).unwrap();
        assert_eq!(buf.as_bytes(), &[9u8, 8, 7][..]);
        assert_eq!(buf.as_ptr() as usize % 16, 0);
    }

    #[test]
    fn rejects_non_power_of_two_alignment() {
        assert!(CpuBuffer::with_capacity(8, 3).is_err());
    }

    #[test]
    fn handle_shares_buffer() {
        let h = BufferHandle::new(CpuBuffer::from_bytes_slice(&[7, 8]).unwrap());
        let h2 = h.clone();
        assert_eq!(h2.as_bytes(), &[7u8, 8][..]);
        assert_eq!(h.len(), 2);
        assert!(!h.is_empty());
    }
}