// mlx_native/buffer.rs
1//! [`MlxBuffer`] — typed wrapper around a Metal GPU buffer.
2//!
3//! Buffers are allocated with `StorageModeShared` so that CPU and GPU share
4//! the same physical memory on Apple Silicon (zero-copy access via
5//! [`as_slice`](MlxBuffer::as_slice) / [`as_mut_slice`](MlxBuffer::as_mut_slice)).
6
7use std::fmt;
8use std::sync::Arc;
9
10use metal::Buffer as MetalBuffer;
11
12use crate::dtypes::DType;
13use crate::error::{MlxError, Result};
14use crate::residency::ResidencySet;
15
/// A Metal GPU buffer annotated with element dtype and tensor shape.
///
/// On Apple Silicon the underlying memory is unified — `contents_ptr()` gives
/// direct CPU access without any copy or transfer.
///
/// # Thread Safety
///
/// `MlxBuffer` is `Send + Sync` because the inner `metal::Buffer` is.
///
/// # Residency-set lifecycle
///
/// Buffers produced by [`MlxDevice::alloc_buffer`](crate::MlxDevice::alloc_buffer)
/// on a residency-enabled device carry a shared
/// [`Arc<MlxBufferStorage>`](MlxBufferStorage) that owns the residency-set
/// reference and runs `removeAllocation:` (deferred — flushed at the next
/// `CommandEncoder::commit*` boundary) when the last clone is dropped.
/// Mirrors llama.cpp's `ggml-metal-device.m:1378-1382` pattern: batch
/// `addAllocation:` calls in a loop, commit ONCE.
pub struct MlxBuffer {
    /// The underlying Metal buffer (StorageModeShared) plus optional
    /// residency-set membership guard. Shared between clones and slice views
    /// via `Arc`, so residency registration/deregistration happens exactly once.
    storage: Arc<MlxBufferStorage>,
    /// Element data type.
    dtype: DType,
    /// Tensor shape (e.g. `[2, 3, 4]` for a rank-3 tensor).
    shape: Vec<usize>,
    /// Byte offset into the underlying Metal buffer (for slice views).
    /// Zero for normally-allocated buffers.
    byte_offset: u64,
}
46
/// Owns a single Metal buffer allocation plus an optional residency-set
/// membership guard.
///
/// Wrapped in [`Arc`] inside [`MlxBuffer`] so that [`Clone`] / [`slice_view`]
/// share both the underlying Metal allocation and the residency-set
/// registration. The Drop fires `removeAllocation:` only when the LAST clone
/// goes out of scope — matching llama.cpp's `addAllocation:` /
/// `removeAllocation:` lifecycle in `ggml-metal-device.m:1378-1382` and
/// `ggml-metal-device.m:1397-1399`.
///
/// Drop is **deferred**: it calls `set.remove_allocation(buffer)` which marks
/// the residency set's pending flag but does NOT call `[set commit]`. The
/// commit is flushed at the next [`CommandEncoder::commit*`] boundary via
/// [`ResidencySet::flush_pending`]. This collapses the per-allocation commit
/// storm (~880 commits/decode-token in iter8d/8e claude+codex variants) into
/// at most one commit per CB submission.
pub(crate) struct MlxBufferStorage {
    /// The Metal allocation itself (StorageModeShared).
    inner: MetalBuffer,
    /// Residency set this allocation was staged into, if any; `None` for
    /// buffers created via `from_raw`.
    residency_set: Option<ResidencySet>,
}
67
68impl Drop for MlxBufferStorage {
69    fn drop(&mut self) {
70        if let Some(set) = self.residency_set.as_ref() {
71            // Mirror ggml-metal-device.m:1397-1399 free-path semantics, but
72            // deferred — the actual `[set commit]` is issued at the next
73            // CommandEncoder::commit* boundary by flush_pending().
74            set.remove_allocation(&self.inner);
75        }
76    }
77}
78
// Compile-time guarantee that MlxBuffer stays Send + Sync:
// metal::Buffer is Send + Sync; our extra fields (DType, Vec<usize>) are too.
crate::static_assertions_send_sync!(MlxBuffer);
81
82impl Clone for MlxBuffer {
83    /// Increment the storage's `Arc` ref-count and wrap it in a new
84    /// `MlxBuffer`. Both the original and the clone refer to the same
85    /// underlying GPU allocation AND share the residency-set membership
86    /// guard — no data is copied, no double-registration occurs.
87    ///
88    /// This is safe because `metal::Buffer` wraps an `MTLBuffer` Objective-C
89    /// object whose lifetime is managed by ARC; `Arc::clone` increments the
90    /// Rust-side refcount, and the inner `MlxBufferStorage` Drop runs once
91    /// when the last clone is released.
92    fn clone(&self) -> Self {
93        Self {
94            storage: self.storage.clone(),
95            dtype: self.dtype,
96            shape: self.shape.clone(),
97            byte_offset: self.byte_offset,
98        }
99    }
100}
101
102impl MlxBuffer {
103    /// Create a new `MlxBuffer` wrapping an already-allocated Metal buffer.
104    ///
105    /// # When to use
106    ///
107    /// Use this to wrap Metal buffers obtained from external frameworks (e.g.
108    /// candle's `MetalStorage::buffer()`) for zero-copy interop on Apple
109    /// Silicon unified memory.  Both frameworks see the same physical memory.
110    ///
111    /// # Safety contract
112    ///
113    /// The caller must ensure that `inner` remains valid for the lifetime of
114    /// the returned `MlxBuffer`.  If the buffer was obtained from another
115    /// framework, the caller must ensure that framework does not deallocate
116    /// the buffer while this `MlxBuffer` exists.
117    ///
118    /// The returned buffer carries no residency-set guard — pool / external
119    /// callers that want residency tracking should go through
120    /// [`MlxDevice::alloc_buffer`](crate::MlxDevice::alloc_buffer) or
121    /// [`MlxBufferPool::register_existing`](crate::MlxBufferPool::register_existing).
122    pub fn from_raw(inner: MetalBuffer, dtype: DType, shape: Vec<usize>) -> Self {
123        Self {
124            storage: Arc::new(MlxBufferStorage {
125                inner,
126                residency_set: None,
127            }),
128            dtype,
129            shape,
130            byte_offset: 0,
131        }
132    }
133
134    /// Create a new buffer and stage its Metal allocation for inclusion in
135    /// the given residency set.
136    ///
137    /// Calls `set.add_allocation(buffer)` (deferred — no `[set commit]` until
138    /// the next [`flush_pending`](ResidencySet::flush_pending) at a
139    /// `CommandEncoder::commit*` boundary). The buffer's residency-set guard
140    /// is dropped when the last clone of the returned `MlxBuffer` (and any
141    /// slice views) goes out of scope, which fires the matching
142    /// `removeAllocation:` (also deferred).
143    ///
144    /// Crate-private — external callers should go through
145    /// [`MlxDevice::alloc_buffer`](crate::MlxDevice::alloc_buffer).
146    pub(crate) fn with_residency(
147        inner: MetalBuffer,
148        dtype: DType,
149        shape: Vec<usize>,
150        residency_set: ResidencySet,
151    ) -> Self {
152        // Stage the addAllocation; the actual `[set commit]` is deferred to
153        // the next encoder.commit* boundary via flush_pending. This is the
154        // structural fix for the per-allocation commit storm; mirrors
155        // llama.cpp's ggml-metal-device.m:1378-1382 pattern.
156        residency_set.add_allocation(&inner);
157
158        Self {
159            storage: Arc::new(MlxBufferStorage {
160                inner,
161                residency_set: Some(residency_set),
162            }),
163            dtype,
164            shape,
165            byte_offset: 0,
166        }
167    }
168
169    /// Create a zero-copy slice view of this buffer.
170    ///
171    /// Returns a new `MlxBuffer` that shares the same underlying Metal buffer
172    /// but starts at `byte_offset` bytes from the beginning and contains
173    /// `n_elements` elements of type `dtype`. No data is copied.
174    ///
175    /// The slice view shares the parent's residency-set guard via the
176    /// `Arc<MlxBufferStorage>`, so it does NOT trigger a second
177    /// `addAllocation:` and does NOT deregister the parent on drop.
178    ///
179    /// When this view is bound to a kernel, the encoder passes the byte offset
180    /// to Metal's `setBuffer:offset:atIndex:`, so the kernel sees only the
181    /// slice region.
182    ///
183    /// # Panics
184    ///
185    /// Panics if `byte_offset + n_elements * dtype.size_of() > self.inner.length()`.
186    #[inline]
187    pub fn slice_view(&self, byte_offset: u64, n_elements: usize) -> Self {
188        let end = byte_offset as usize + n_elements * self.dtype.size_of();
189        assert!(
190            end <= self.storage.inner.length() as usize,
191            "slice_view: out of bounds (byte_offset={}, n_elements={}, dtype_size={}, buf_len={})",
192            byte_offset,
193            n_elements,
194            self.dtype.size_of(),
195            self.storage.inner.length()
196        );
197        Self {
198            storage: self.storage.clone(),
199            dtype: self.dtype,
200            shape: vec![n_elements],
201            byte_offset,
202        }
203    }
204
    // ---- accessors ----

    /// Element data type.
    #[inline]
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Tensor shape (dimensions).
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Total byte length of the Metal buffer.
    ///
    /// NOTE(review): this reports the UNDERLYING allocation's length, not the
    /// logical slice region — for a `slice_view` it is still the full parent
    /// buffer length. Use `element_count() * dtype().size_of()` for a view's
    /// logical size — confirm callers expect this.
    #[inline]
    pub fn byte_len(&self) -> usize {
        self.storage.inner.length() as usize
    }

    /// Number of elements (product of shape dimensions).
    ///
    /// An empty shape yields 1 (the empty product), i.e. a scalar.
    #[inline]
    pub fn element_count(&self) -> usize {
        self.shape.iter().copied().product()
    }

    /// Raw pointer to the buffer contents (CPU-accessible on Apple Silicon).
    ///
    /// Returns the BASE pointer of the underlying allocation; a slice view's
    /// `byte_offset` is NOT applied here — callers must add it themselves.
    ///
    /// # Safety
    ///
    /// The caller must ensure proper synchronization — do not read while a GPU
    /// command buffer that writes this buffer is in flight.
    #[inline]
    pub fn contents_ptr(&self) -> *mut std::ffi::c_void {
        self.storage.inner.contents()
    }

    /// Reference to the underlying `metal::Buffer` for passing to the encoder.
    #[inline]
    pub fn metal_buffer(&self) -> &MetalBuffer {
        &self.storage.inner
    }

    /// Byte offset into the underlying Metal buffer (zero for non-slice buffers).
    ///
    /// When passing this buffer to a Metal kernel via `setBuffer:offset:atIndex:`,
    /// use this offset so the kernel sees only the intended sub-region.
    #[inline]
    pub fn byte_offset(&self) -> u64 {
        self.byte_offset
    }
256
    /// Consume self and return a handle to the inner `metal::Buffer` (used by
    /// the buffer pool).
    ///
    /// Note: despite the `into_` name this returns `inner.clone()` — a
    /// retained handle to the same `MTLBuffer` — because the storage `Arc`
    /// may still be shared with other clones / slice views.
    ///
    /// If this is the last clone of the underlying `Arc<MlxBufferStorage>`,
    /// the storage Drop fires after this returns — staging a deferred
    /// `removeAllocation:` if the buffer carried a residency-set guard.
    /// Pool-internal buffers do not carry guards, so this is a no-op for
    /// the pool's `release` path.
    #[inline]
    pub(crate) fn into_inner(self) -> MetalBuffer {
        self.storage.inner.clone()
    }

    /// Borrow the residency set that this buffer was registered with, if any.
    ///
    /// Used by [`MlxBufferPool::register_existing`](crate::MlxBufferPool::register_existing)
    /// to short-circuit re-registration: a buffer created via
    /// [`MlxDevice::alloc_buffer`](crate::MlxDevice::alloc_buffer) on a
    /// residency-enabled device already owns its registration via the
    /// `Arc<MlxBufferStorage>`, so the pool path is a no-op (modulo
    /// validation that the device matches).
    #[inline]
    pub(crate) fn residency_set(&self) -> Option<&ResidencySet> {
        self.storage.residency_set.as_ref()
    }
281
    // ---- typed CPU access (zero-copy on unified memory) ----

    /// View the buffer contents as a typed slice.
    ///
    /// Returns an error if the buffer byte length is not an exact multiple of
    /// `size_of::<T>()`.
    ///
    /// NOTE(review): this views the ENTIRE underlying allocation — neither
    /// `byte_offset` nor `shape` is applied, so on a `slice_view` the returned
    /// slice starts at the buffer base and spans the full byte length.
    /// Confirm callers never expect view-local semantics here.
    ///
    /// # Safety contract
    ///
    /// The caller must ensure:
    /// 1. `T` matches the actual element type stored in the buffer.
    /// 2. No GPU command buffer that writes this buffer is currently in flight.
    pub fn as_slice<T: bytemuck::Pod>(&self) -> Result<&[T]> {
        let elem_size = std::mem::size_of::<T>();
        // Reject ZSTs up front: they would make the modulo/division below
        // meaningless (and divide by zero).
        if elem_size == 0 {
            return Err(MlxError::InvalidArgument(
                "Cannot view buffer as zero-sized type".into(),
            ));
        }
        let byte_len = self.byte_len();
        if byte_len % elem_size != 0 {
            return Err(MlxError::InvalidArgument(format!(
                "Buffer byte length {byte_len} is not a multiple of element size {elem_size}"
            )));
        }
        let ptr = self.contents_ptr();
        if ptr.is_null() {
            return Err(MlxError::BufferAllocationError { bytes: byte_len });
        }
        let count = byte_len / elem_size;
        // SAFETY: Metal guarantees the pointer is valid for `byte_len` bytes and
        // properly aligned for any type on Apple Silicon shared memory.  The
        // caller upholds the type-match and no-concurrent-GPU-write contract.
        let slice = unsafe { std::slice::from_raw_parts(ptr as *const T, count) };
        Ok(slice)
    }
318
    /// View the buffer contents as a mutable typed slice.
    ///
    /// Same safety contract as [`as_slice`](Self::as_slice), plus: the caller
    /// must ensure exclusive access (no other references to this buffer's memory
    /// exist).
    ///
    /// NOTE(review): like `as_slice`, this covers the ENTIRE underlying
    /// allocation — `byte_offset`/`shape` are not applied for slice views.
    pub fn as_mut_slice<T: bytemuck::Pod>(&mut self) -> Result<&mut [T]> {
        let elem_size = std::mem::size_of::<T>();
        // Reject ZSTs up front: they would make the modulo/division below
        // meaningless (and divide by zero).
        if elem_size == 0 {
            return Err(MlxError::InvalidArgument(
                "Cannot view buffer as zero-sized type".into(),
            ));
        }
        let byte_len = self.byte_len();
        if byte_len % elem_size != 0 {
            return Err(MlxError::InvalidArgument(format!(
                "Buffer byte length {byte_len} is not a multiple of element size {elem_size}"
            )));
        }
        let ptr = self.contents_ptr();
        if ptr.is_null() {
            return Err(MlxError::BufferAllocationError { bytes: byte_len });
        }
        let count = byte_len / elem_size;
        // SAFETY: same as as_slice, plus caller ensures exclusive mutable access.
        let slice = unsafe { std::slice::from_raw_parts_mut(ptr as *mut T, count) };
        Ok(slice)
    }
346
347    /// Overwrite the dtype and shape metadata.
348    ///
349    /// This does **not** re-allocate the Metal buffer — it only changes the
350    /// logical interpretation.  The caller must ensure the new shape is
351    /// consistent with the buffer's byte length.
352    #[allow(dead_code)]
353    pub(crate) fn reshape(&mut self, dtype: DType, shape: Vec<usize>) {
354        self.dtype = dtype;
355        self.shape = shape;
356    }
357}
358
359impl fmt::Debug for MlxBuffer {
360    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361        f.debug_struct("MlxBuffer")
362            .field("dtype", &self.dtype)
363            .field("shape", &self.shape)
364            .field("byte_len", &self.byte_len())
365            .finish()
366    }
367}