use std::fmt;
use std::sync::Arc;
use metal::Buffer as MetalBuffer;
use crate::dtypes::DType;
use crate::error::{MlxError, Result};
use crate::residency::ResidencySet;
/// A typed handle onto a Metal buffer allocation that may be shared with other handles.
///
/// Clones share the same `Arc`-held storage and copy only the dtype / shape / offset
/// metadata, so several `MlxBuffer`s (e.g. views made by `slice_view`) can refer to
/// one allocation.
pub struct MlxBuffer {
    /// Shared backing allocation; dropped (and unregistered from its residency set, if
    /// any) when the last reference goes away.
    storage: Arc<MlxBufferStorage>,
    /// Element type used to interpret the buffer's bytes.
    dtype: DType,
    /// Logical shape of the elements.
    shape: Vec<usize>,
    /// Offset in bytes from the start of the allocation where this handle's data begins.
    /// Constructors set it to 0; `slice_view` sets it explicitly.
    byte_offset: u64,
}
/// Backing storage shared (via `Arc`) by every `MlxBuffer` handle onto one allocation.
pub(crate) struct MlxBufferStorage {
    /// The Metal buffer itself.
    inner: MetalBuffer,
    /// Residency set the buffer was added to at construction, if any; the `Drop` impl
    /// removes the allocation from it.
    residency_set: Option<ResidencySet>,
}
impl Drop for MlxBufferStorage {
    /// Unregisters the allocation from the residency set it was added to, if any.
    ///
    /// Runs when the last `Arc` reference to this storage is released.
    fn drop(&mut self) {
        if let Some(set) = &self.residency_set {
            set.remove_allocation(&self.inner);
        }
    }
}
// Presumably a compile-time assertion (macro defined elsewhere in this crate) that
// `MlxBuffer` is `Send + Sync` — NOTE(review): confirm against the macro definition.
crate::static_assertions_send_sync!(MlxBuffer);
impl Clone for MlxBuffer {
fn clone(&self) -> Self {
Self {
storage: self.storage.clone(),
dtype: self.dtype,
shape: self.shape.clone(),
byte_offset: self.byte_offset,
}
}
}
impl MlxBuffer {
    /// Wraps an existing Metal buffer without registering it in any residency set.
    ///
    /// The handle starts at byte 0 of the allocation. `dtype` and `shape` are stored as
    /// given; they are not checked against the buffer's length.
    pub fn from_raw(inner: MetalBuffer, dtype: DType, shape: Vec<usize>) -> Self {
        Self {
            storage: Arc::new(MlxBufferStorage {
                inner,
                residency_set: None,
            }),
            dtype,
            shape,
            byte_offset: 0,
        }
    }

    /// Wraps a Metal buffer and adds it to `residency_set`.
    ///
    /// The allocation is registered immediately and stays registered until the shared
    /// storage is dropped, at which point `MlxBufferStorage`'s `Drop` impl removes it.
    /// Like [`MlxBuffer::from_raw`], the handle starts at byte 0 and `dtype` / `shape`
    /// are not validated.
    pub(crate) fn with_residency(
        inner: MetalBuffer,
        dtype: DType,
        shape: Vec<usize>,
        residency_set: ResidencySet,
    ) -> Self {
        residency_set.add_allocation(&inner);
        Self {
            storage: Arc::new(MlxBufferStorage {
                inner,
                residency_set: Some(residency_set),
            }),
            dtype,
            shape,
            byte_offset: 0,
        }
    }

    /// Returns a 1-D view of `n_elements` elements of this buffer's dtype, starting
    /// `byte_offset` bytes into the underlying allocation.
    ///
    /// The view shares storage with `self` (nothing is copied). `byte_offset` is
    /// absolute: it is measured from the start of the allocation, not relative to
    /// `self.byte_offset()`.
    ///
    /// # Panics
    ///
    /// Panics if `[byte_offset, byte_offset + n_elements * dtype_size)` does not fit in
    /// the allocation, including when computing the end of that range would overflow.
    #[inline]
    pub fn slice_view(&self, byte_offset: u64, n_elements: usize) -> Self {
        let elem_size = self.dtype.size_of();
        let buf_len = self.storage.inner.length();
        // Checked arithmetic: an unchecked sum/product can wrap in release builds and
        // slip past the bounds check below, producing an out-of-bounds view.
        let end = usize::try_from(byte_offset).ok().and_then(|start| {
            n_elements
                .checked_mul(elem_size)
                .and_then(|len| start.checked_add(len))
        });
        assert!(
            matches!(end, Some(end) if end <= buf_len as usize),
            "slice_view: out of bounds (byte_offset={}, n_elements={}, dtype_size={}, buf_len={})",
            byte_offset,
            n_elements,
            elem_size,
            buf_len
        );
        Self {
            storage: self.storage.clone(),
            dtype: self.dtype,
            shape: vec![n_elements],
            byte_offset,
        }
    }

    /// Element type of this handle.
    #[inline]
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Logical shape of this handle.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Length in bytes of the whole underlying Metal allocation.
    ///
    /// Not adjusted for `byte_offset` or `shape`: on a view from
    /// [`slice_view`](Self::slice_view) this reports the full allocation, not the view.
    #[inline]
    pub fn byte_len(&self) -> usize {
        self.storage.inner.length() as usize
    }

    /// Number of elements implied by `shape` (product of its dimensions; 1 for an
    /// empty shape).
    #[inline]
    pub fn element_count(&self) -> usize {
        self.shape.iter().copied().product()
    }

    /// Base CPU pointer of the underlying allocation, as returned by the Metal buffer's
    /// `contents()`. `byte_offset` is not added; may be null.
    #[inline]
    pub fn contents_ptr(&self) -> *mut std::ffi::c_void {
        self.storage.inner.contents()
    }

    /// Borrows the underlying Metal buffer (the whole allocation, not offset-adjusted).
    #[inline]
    pub fn metal_buffer(&self) -> &MetalBuffer {
        &self.storage.inner
    }

    /// Offset in bytes from the start of the allocation where this handle's data begins.
    #[inline]
    pub fn byte_offset(&self) -> u64 {
        self.byte_offset
    }

    /// Consumes this handle and returns a clone of the underlying Metal buffer handle.
    ///
    /// Other `MlxBuffer`s sharing the storage are unaffected. NOTE: if `self` held the
    /// last reference to the storage, dropping it here runs `MlxBufferStorage`'s `Drop`
    /// impl, which removes the allocation from its residency set even though the
    /// returned handle keeps referring to the buffer.
    #[inline]
    pub(crate) fn into_inner(self) -> MetalBuffer {
        self.storage.inner.clone()
    }

    /// Residency set this buffer's storage was registered with, if any.
    #[inline]
    pub(crate) fn residency_set(&self) -> Option<&ResidencySet> {
        self.storage.residency_set.as_ref()
    }

    /// Views the entire underlying allocation as a slice of `T`.
    ///
    /// The slice starts at the allocation's base pointer and spans `byte_len()` bytes;
    /// `byte_offset`, `dtype` and `shape` are not applied, so on a view returned by
    /// `slice_view` this still exposes the whole buffer.
    ///
    /// NOTE(review): the slice is not synchronized with GPU work; callers presumably must
    /// ensure no in-flight command buffer writes this buffer while it is borrowed.
    ///
    /// # Errors
    ///
    /// - `MlxError::InvalidArgument` if `T` is zero-sized, if `byte_len()` is not a
    ///   multiple of `size_of::<T>()`, or if the contents pointer is not aligned for `T`.
    /// - `MlxError::BufferAllocationError` if `contents()` returns a null pointer.
    pub fn as_slice<T: bytemuck::Pod>(&self) -> Result<&[T]> {
        let (ptr, count) = self.pod_view_parts::<T>()?;
        // SAFETY: `ptr` is non-null and aligned for `T` (checked in `pod_view_parts`),
        // `count * size_of::<T>() == byte_len()` bytes belong to the Metal allocation,
        // `T: Pod` accepts any bit pattern, and the slice's lifetime is tied to `&self`,
        // which keeps the shared storage alive.
        let slice = unsafe { std::slice::from_raw_parts(ptr as *const T, count) };
        Ok(slice)
    }

    /// Mutably views the entire underlying allocation as a slice of `T`.
    ///
    /// Same extent and error conditions as [`as_slice`](Self::as_slice).
    ///
    /// NOTE(review): `&mut self` only makes *this handle* exclusive. Clones (via `Clone`
    /// or `slice_view`) share the same storage and could call `as_slice` /
    /// `as_mut_slice` concurrently, aliasing the returned `&mut [T]`. Callers must ensure
    /// no other handle touches the storage while the slice is alive.
    ///
    /// # Errors
    ///
    /// See [`as_slice`](Self::as_slice).
    pub fn as_mut_slice<T: bytemuck::Pod>(&mut self) -> Result<&mut [T]> {
        let (ptr, count) = self.pod_view_parts::<T>()?;
        // SAFETY: as in `as_slice`; exclusivity relies on the caller contract above.
        let slice = unsafe { std::slice::from_raw_parts_mut(ptr as *mut T, count) };
        Ok(slice)
    }

    /// Shared validation for `as_slice` / `as_mut_slice`: checks that the whole
    /// allocation can be viewed as `[T]` and returns its base pointer and element count.
    fn pod_view_parts<T: bytemuck::Pod>(&self) -> Result<(*mut std::ffi::c_void, usize)> {
        let elem_size = std::mem::size_of::<T>();
        if elem_size == 0 {
            return Err(MlxError::InvalidArgument(
                "Cannot view buffer as zero-sized type".into(),
            ));
        }
        let byte_len = self.byte_len();
        if byte_len % elem_size != 0 {
            return Err(MlxError::InvalidArgument(format!(
                "Buffer byte length {byte_len} is not a multiple of element size {elem_size}"
            )));
        }
        let ptr = self.contents_ptr();
        if ptr.is_null() {
            return Err(MlxError::BufferAllocationError { bytes: byte_len });
        }
        // `slice::from_raw_parts` requires alignment for `T`; `bytemuck::Pod` does not
        // guarantee it, so reject a misaligned pointer instead of risking UB.
        let align = std::mem::align_of::<T>();
        if (ptr as usize) % align != 0 {
            return Err(MlxError::InvalidArgument(format!(
                "Buffer contents pointer is not aligned to {align} bytes"
            )));
        }
        Ok((ptr, byte_len / elem_size))
    }

    /// Replaces the dtype and shape metadata in place (the storage is untouched and the
    /// new metadata is not validated against the allocation length).
    // `dead_code` allowed: this crate-internal helper may have no callers.
    #[allow(dead_code)]
    pub(crate) fn reshape(&mut self, dtype: DType, shape: Vec<usize>) {
        self.dtype = dtype;
        self.shape = shape;
    }
}
impl fmt::Debug for MlxBuffer {
    /// Formats the handle's metadata (the Metal buffer itself is not printed).
    ///
    /// `byte_offset` is included so that distinct views of one allocation (which share
    /// dtype, shape and `byte_len`) are distinguishable in debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MlxBuffer")
            .field("dtype", &self.dtype)
            .field("shape", &self.shape)
            .field("byte_offset", &self.byte_offset)
            .field("byte_len", &self.byte_len())
            .finish()
    }
}