// mlx_native/buffer.rs

//! [`MlxBuffer`] — typed wrapper around a Metal GPU buffer.
//!
//! Buffers are allocated with `StorageModeShared` so that CPU and GPU share
//! the same physical memory on Apple Silicon (zero-copy access via
//! [`as_slice`](MlxBuffer::as_slice) / [`as_mut_slice`](MlxBuffer::as_mut_slice)).

7use std::fmt;
8
9use metal::Buffer as MetalBuffer;
10
11use crate::dtypes::DType;
12use crate::error::{MlxError, Result};
13
14/// A Metal GPU buffer annotated with element dtype and tensor shape.
15///
16/// On Apple Silicon the underlying memory is unified — `contents_ptr()` gives
17/// direct CPU access without any copy or transfer.
18///
19/// # Thread Safety
20///
21/// `MlxBuffer` is `Send + Sync` because the inner `metal::Buffer` is.
22pub struct MlxBuffer {
23    /// The underlying Metal buffer (StorageModeShared).
24    inner: MetalBuffer,
25    /// Element data type.
26    dtype: DType,
27    /// Tensor shape (e.g. `[2, 3, 4]` for a rank-3 tensor).
28    shape: Vec<usize>,
29    /// Byte offset into the underlying Metal buffer (for slice views).
30    /// Zero for normally-allocated buffers.
31    byte_offset: u64,
32}
33
34// metal::Buffer is Send + Sync; our extra fields (DType, Vec<usize>) are too.
35crate::static_assertions_send_sync!(MlxBuffer);
36
37impl Clone for MlxBuffer {
38    /// Increment the Metal buffer's ARC retain count and wrap it in a new
39    /// `MlxBuffer`.  Both the original and the clone refer to the same
40    /// underlying GPU allocation — no data is copied.
41    ///
42    /// This is safe because `metal::Buffer` wraps an `MTLBuffer` Objective-C
43    /// object whose lifetime is managed by ARC; `Clone` calls `retain` and
44    /// `drop` calls `release`.
45    fn clone(&self) -> Self {
46        Self {
47            inner: self.inner.clone(),
48            dtype: self.dtype,
49            shape: self.shape.clone(),
50            byte_offset: self.byte_offset,
51        }
52    }
53}
54
55impl MlxBuffer {
56    /// Create a new `MlxBuffer` wrapping an already-allocated Metal buffer.
57    ///
58    /// # When to use
59    ///
60    /// Use this to wrap Metal buffers obtained from external frameworks (e.g.
61    /// candle's `MetalStorage::buffer()`) for zero-copy interop on Apple
62    /// Silicon unified memory.  Both frameworks see the same physical memory.
63    ///
64    /// # Safety contract
65    ///
66    /// The caller must ensure that `inner` remains valid for the lifetime of
67    /// the returned `MlxBuffer`.  If the buffer was obtained from another
68    /// framework, the caller must ensure that framework does not deallocate
69    /// the buffer while this `MlxBuffer` exists.
70    pub fn from_raw(inner: MetalBuffer, dtype: DType, shape: Vec<usize>) -> Self {
71        Self {
72            inner,
73            dtype,
74            shape,
75            byte_offset: 0,
76        }
77    }
78
79    /// Create a zero-copy slice view of this buffer.
80    ///
81    /// Returns a new `MlxBuffer` that shares the same underlying Metal buffer
82    /// but starts at `byte_offset` bytes from the beginning and contains
83    /// `n_elements` elements of type `dtype`. No data is copied.
84    ///
85    /// When this view is bound to a kernel, the encoder passes the byte offset
86    /// to Metal's `setBuffer:offset:atIndex:`, so the kernel sees only the
87    /// slice region.
88    ///
89    /// # Panics
90    ///
91    /// Panics if `byte_offset + n_elements * dtype.size_of() > self.inner.length()`.
92    #[inline]
93    pub fn slice_view(&self, byte_offset: u64, n_elements: usize) -> Self {
94        let end = byte_offset as usize + n_elements * self.dtype.size_of();
95        assert!(
96            end <= self.inner.length() as usize,
97            "slice_view: out of bounds (byte_offset={}, n_elements={}, dtype_size={}, buf_len={})",
98            byte_offset, n_elements, self.dtype.size_of(), self.inner.length()
99        );
100        Self {
101            inner: self.inner.clone(),
102            dtype: self.dtype,
103            shape: vec![n_elements],
104            byte_offset,
105        }
106    }
107
108    // ---- accessors ----
109
110    /// Element data type.
111    #[inline]
112    pub fn dtype(&self) -> DType {
113        self.dtype
114    }
115
116    /// Tensor shape (dimensions).
117    #[inline]
118    pub fn shape(&self) -> &[usize] {
119        &self.shape
120    }
121
122    /// Total byte length of the Metal buffer.
123    #[inline]
124    pub fn byte_len(&self) -> usize {
125        self.inner.length() as usize
126    }
127
128    /// Number of elements (product of shape dimensions, or `byte_len / dtype.size_of()`).
129    #[inline]
130    pub fn element_count(&self) -> usize {
131        self.shape.iter().copied().product()
132    }
133
134    /// Raw pointer to the buffer contents (CPU-accessible on Apple Silicon).
135    ///
136    /// # Safety
137    ///
138    /// The caller must ensure proper synchronization — do not read while a GPU
139    /// command buffer that writes this buffer is in flight.
140    #[inline]
141    pub fn contents_ptr(&self) -> *mut std::ffi::c_void {
142        self.inner.contents()
143    }
144
145    /// Reference to the underlying `metal::Buffer` for passing to the encoder.
146    #[inline]
147    pub fn metal_buffer(&self) -> &MetalBuffer {
148        &self.inner
149    }
150
151    /// Byte offset into the underlying Metal buffer (zero for non-slice buffers).
152    ///
153    /// When passing this buffer to a Metal kernel via `setBuffer:offset:atIndex:`,
154    /// use this offset so the kernel sees only the intended sub-region.
155    #[inline]
156    pub fn byte_offset(&self) -> u64 {
157        self.byte_offset
158    }
159
160    /// Consume self and return the inner `metal::Buffer` (used by buffer pool).
161    #[inline]
162    pub(crate) fn into_inner(self) -> MetalBuffer {
163        self.inner
164    }
165
166    // ---- typed CPU access (zero-copy on unified memory) ----
167
168    /// View the buffer contents as a typed slice.
169    ///
170    /// Returns an error if the buffer byte length is not an exact multiple of
171    /// `size_of::<T>()`.
172    ///
173    /// # Safety contract
174    ///
175    /// The caller must ensure:
176    /// 1. `T` matches the actual element type stored in the buffer.
177    /// 2. No GPU command buffer that writes this buffer is currently in flight.
178    pub fn as_slice<T: bytemuck::Pod>(&self) -> Result<&[T]> {
179        let elem_size = std::mem::size_of::<T>();
180        if elem_size == 0 {
181            return Err(MlxError::InvalidArgument(
182                "Cannot view buffer as zero-sized type".into(),
183            ));
184        }
185        let byte_len = self.byte_len();
186        if byte_len % elem_size != 0 {
187            return Err(MlxError::InvalidArgument(format!(
188                "Buffer byte length {byte_len} is not a multiple of element size {elem_size}"
189            )));
190        }
191        let ptr = self.contents_ptr();
192        if ptr.is_null() {
193            return Err(MlxError::BufferAllocationError { bytes: byte_len });
194        }
195        let count = byte_len / elem_size;
196        // SAFETY: Metal guarantees the pointer is valid for `byte_len` bytes and
197        // properly aligned for any type on Apple Silicon shared memory.  The
198        // caller upholds the type-match and no-concurrent-GPU-write contract.
199        let slice = unsafe { std::slice::from_raw_parts(ptr as *const T, count) };
200        Ok(slice)
201    }
202
203    /// View the buffer contents as a mutable typed slice.
204    ///
205    /// Same safety contract as [`as_slice`](Self::as_slice), plus: the caller
206    /// must ensure exclusive access (no other references to this buffer's memory
207    /// exist).
208    pub fn as_mut_slice<T: bytemuck::Pod>(&mut self) -> Result<&mut [T]> {
209        let elem_size = std::mem::size_of::<T>();
210        if elem_size == 0 {
211            return Err(MlxError::InvalidArgument(
212                "Cannot view buffer as zero-sized type".into(),
213            ));
214        }
215        let byte_len = self.byte_len();
216        if byte_len % elem_size != 0 {
217            return Err(MlxError::InvalidArgument(format!(
218                "Buffer byte length {byte_len} is not a multiple of element size {elem_size}"
219            )));
220        }
221        let ptr = self.contents_ptr();
222        if ptr.is_null() {
223            return Err(MlxError::BufferAllocationError { bytes: byte_len });
224        }
225        let count = byte_len / elem_size;
226        // SAFETY: same as as_slice, plus caller ensures exclusive mutable access.
227        let slice = unsafe { std::slice::from_raw_parts_mut(ptr as *mut T, count) };
228        Ok(slice)
229    }
230
231    /// Overwrite the dtype and shape metadata.
232    ///
233    /// This does **not** re-allocate the Metal buffer — it only changes the
234    /// logical interpretation.  The caller must ensure the new shape is
235    /// consistent with the buffer's byte length.
236    #[allow(dead_code)]
237    pub(crate) fn reshape(&mut self, dtype: DType, shape: Vec<usize>) {
238        self.dtype = dtype;
239        self.shape = shape;
240    }
241}
242
243impl fmt::Debug for MlxBuffer {
244    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245        f.debug_struct("MlxBuffer")
246            .field("dtype", &self.dtype)
247            .field("shape", &self.shape)
248            .field("byte_len", &self.byte_len())
249            .finish()
250    }
251}