use std::fmt;
use metal::Buffer as MetalBuffer;
use crate::dtypes::DType;
use crate::error::{MlxError, Result};
/// A Metal buffer annotated with the element type and logical shape of the
/// data it holds.
///
/// The metadata is not validated against the allocation; see
/// [`MlxBuffer::from_raw`].
pub struct MlxBuffer {
    /// The underlying Metal allocation.
    inner: MetalBuffer,
    /// Element type the buffer is interpreted as.
    dtype: DType,
    /// Logical dimensions of the stored data.
    shape: Vec<usize>,
}

// Compile-time check that the wrapper can be shared and sent across threads.
crate::static_assertions_send_sync!(MlxBuffer);
impl MlxBuffer {
    /// Wraps an existing Metal buffer together with its element type and
    /// logical shape.
    ///
    /// No validation is performed: the caller is responsible for ensuring
    /// that `dtype` and `shape` describe the bytes actually stored in `inner`.
    pub fn from_raw(inner: MetalBuffer, dtype: DType, shape: Vec<usize>) -> Self {
        Self {
            inner,
            dtype,
            shape,
        }
    }

    /// Returns the element type recorded for this buffer.
    #[inline]
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Returns the logical shape recorded for this buffer.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the length of the underlying Metal allocation in bytes.
    #[inline]
    pub fn byte_len(&self) -> usize {
        // `length()` is reported as a 64-bit value; the cast is lossless on
        // 64-bit targets.
        self.inner.length() as usize
    }

    /// Returns the number of elements implied by [`shape`](Self::shape).
    ///
    /// This is the product of all dimensions (so `1` for an empty shape). It
    /// is derived purely from the recorded metadata, not from `byte_len()`.
    #[inline]
    pub fn element_count(&self) -> usize {
        self.shape.iter().copied().product()
    }

    /// Returns the raw contents pointer reported by the Metal buffer.
    ///
    /// The pointer may be null; [`as_slice`](Self::as_slice) and
    /// [`as_mut_slice`](Self::as_mut_slice) report that as an error.
    /// NOTE(review): a null pointer presumably means the buffer is not
    /// CPU-accessible (e.g. private storage mode) — confirm.
    #[inline]
    pub fn contents_ptr(&self) -> *mut std::ffi::c_void {
        self.inner.contents()
    }

    /// Borrows the underlying Metal buffer.
    #[inline]
    pub fn metal_buffer(&self) -> &MetalBuffer {
        &self.inner
    }

    /// Consumes the wrapper and returns the underlying Metal buffer,
    /// discarding the dtype/shape metadata.
    #[inline]
    pub(crate) fn into_inner(self) -> MetalBuffer {
        self.inner
    }

    /// Checks every precondition for viewing the buffer as `[T]`.
    ///
    /// On success, returns the base pointer and the element count.
    /// Shared by both slice views so they cannot drift apart.
    fn checked_view_parts<T: bytemuck::Pod>(&self) -> Result<(*mut std::ffi::c_void, usize)> {
        let elem_size = std::mem::size_of::<T>();
        if elem_size == 0 {
            return Err(MlxError::InvalidArgument(
                "Cannot view buffer as zero-sized type".into(),
            ));
        }
        let byte_len = self.byte_len();
        if byte_len % elem_size != 0 {
            return Err(MlxError::InvalidArgument(format!(
                "Buffer byte length {byte_len} is not a multiple of element size {elem_size}"
            )));
        }
        let ptr = self.contents_ptr();
        if ptr.is_null() {
            return Err(MlxError::BufferAllocationError { bytes: byte_len });
        }
        // `slice::from_raw_parts{,_mut}` require the data pointer to be
        // aligned for `T`. `from_raw` accepts any Metal buffer, so this is
        // not guaranteed by construction and must be checked here.
        let align = std::mem::align_of::<T>();
        if (ptr as usize) % align != 0 {
            return Err(MlxError::InvalidArgument(format!(
                "Buffer contents pointer {ptr:p} is not aligned to {align} bytes"
            )));
        }
        Ok((ptr, byte_len / elem_size))
    }

    /// Views the whole buffer as an immutable slice of `T`.
    ///
    /// # Errors
    ///
    /// Returns [`MlxError::InvalidArgument`] in these cases:
    /// - `T` is zero-sized.
    /// - `byte_len()` is not a multiple of `size_of::<T>()`.
    /// - The contents pointer is not aligned for `T`.
    ///
    /// Returns [`MlxError::BufferAllocationError`] if the contents pointer
    /// is null.
    pub fn as_slice<T: bytemuck::Pod>(&self) -> Result<&[T]> {
        let (ptr, count) = self.checked_view_parts::<T>()?;
        // SAFETY: `ptr` is non-null and aligned for `T` (checked above).
        // `count * size_of::<T>()` equals `byte_len()`, the length of the
        // Metal allocation. `T: Pod` accepts any bit pattern. The slice
        // borrows `self`, so the buffer outlives it.
        // NOTE(review): soundness additionally assumes the bytes are
        // initialized and that no GPU work writes the buffer while the
        // slice is alive; neither is enforced here.
        let slice = unsafe { std::slice::from_raw_parts(ptr as *const T, count) };
        Ok(slice)
    }

    /// Views the whole buffer as a mutable slice of `T`.
    ///
    /// # Errors
    ///
    /// Same conditions as [`as_slice`](Self::as_slice).
    pub fn as_mut_slice<T: bytemuck::Pod>(&mut self) -> Result<&mut [T]> {
        let (ptr, count) = self.checked_view_parts::<T>()?;
        // SAFETY: the same pointer, length and validity arguments as
        // `as_slice` apply. Exclusivity relies on the `&mut self` borrow.
        // NOTE(review): other handles to the same Metal buffer (e.g. clones
        // obtained via `metal_buffer()`) or in-flight GPU work could still
        // alias this memory; that is not enforced here.
        let slice = unsafe { std::slice::from_raw_parts_mut(ptr as *mut T, count) };
        Ok(slice)
    }

    /// Replaces the recorded dtype and shape without touching the data.
    ///
    /// No check is made that the new metadata matches `byte_len()`.
    // Currently unused inside the crate; kept for internal reinterpretation
    // of a buffer, so silence the dead-code lint rather than delete it.
    #[allow(dead_code)]
    pub(crate) fn reshape(&mut self, dtype: DType, shape: Vec<usize>) {
        self.dtype = dtype;
        self.shape = shape;
    }
}
impl fmt::Debug for MlxBuffer {
    /// Prints the buffer's metadata (dtype, shape and allocation size) only;
    /// the buffer contents are never read.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("MlxBuffer");
        builder.field("dtype", &self.dtype);
        builder.field("shape", &self.shape);
        builder.field("byte_len", &self.byte_len());
        builder.finish()
    }
}