mlx_native/buffer.rs
1//! [`MlxBuffer`] — typed wrapper around a Metal GPU buffer.
2//!
3//! Buffers are allocated with `StorageModeShared` so that CPU and GPU share
4//! the same physical memory on Apple Silicon (zero-copy access via
5//! [`as_slice`](MlxBuffer::as_slice) / [`as_mut_slice`](MlxBuffer::as_mut_slice)).
6
7use std::fmt;
8use std::sync::Arc;
9
10use metal::Buffer as MetalBuffer;
11
12use crate::dtypes::DType;
13use crate::error::{MlxError, Result};
14use crate::residency::ResidencySet;
15
/// A Metal GPU buffer annotated with element dtype and tensor shape.
///
/// On Apple Silicon the underlying memory is unified — `contents_ptr()` gives
/// direct CPU access without any copy or transfer.
///
/// # Thread Safety
///
/// `MlxBuffer` is `Send + Sync` because the inner `metal::Buffer` is.
///
/// # Residency-set lifecycle
///
/// Buffers produced by [`MlxDevice::alloc_buffer`](crate::MlxDevice::alloc_buffer)
/// on a residency-enabled device carry a shared
/// [`Arc<MlxBufferStorage>`](MlxBufferStorage) that owns the residency-set
/// reference and runs `removeAllocation:` (deferred — flushed at the next
/// `CommandEncoder::commit*` boundary) when the last clone is dropped.
/// Mirrors llama.cpp's `ggml-metal-device.m:1378-1382` pattern: batch
/// `addAllocation:` calls in a loop, commit ONCE.
pub struct MlxBuffer {
    /// The underlying Metal buffer (StorageModeShared) plus optional
    /// residency-set membership guard. Shared via `Arc` so clones and slice
    /// views reference the same allocation and registration.
    storage: Arc<MlxBufferStorage>,
    /// Element data type.
    dtype: DType,
    /// Tensor shape (e.g. `[2, 3, 4]` for a rank-3 tensor).
    shape: Vec<usize>,
    /// Byte offset into the underlying Metal buffer (for slice views).
    /// Zero for normally-allocated buffers. Consumers binding this buffer to
    /// a kernel pass it to Metal's `setBuffer:offset:atIndex:`.
    byte_offset: u64,
}
46
/// Owns a single Metal buffer allocation plus an optional residency-set
/// membership guard.
///
/// Wrapped in [`Arc`] inside [`MlxBuffer`] so that [`Clone`] /
/// [`MlxBuffer::slice_view`] share both the underlying Metal allocation and
/// the residency-set registration. The Drop fires `removeAllocation:` only
/// when the LAST clone goes out of scope — matching llama.cpp's
/// `addAllocation:` / `removeAllocation:` lifecycle in
/// `ggml-metal-device.m:1378-1382` and `ggml-metal-device.m:1397-1399`.
///
/// Drop is **deferred**: it calls `set.remove_allocation(buffer)` which marks
/// the residency set's pending flag but does NOT call `[set commit]`. The
/// commit is flushed at the next `CommandEncoder::commit*` boundary via
/// [`ResidencySet::flush_pending`]. This collapses the per-allocation commit
/// storm (~880 commits/decode-token in iter8d/8e claude+codex variants) into
/// at most one commit per CB submission.
pub(crate) struct MlxBufferStorage {
    /// The Metal allocation itself (StorageModeShared).
    inner: MetalBuffer,
    /// `Some` when this allocation was registered with a residency set at
    /// creation time; `None` for raw / external / pool-internal buffers.
    residency_set: Option<ResidencySet>,
}
67
68impl Drop for MlxBufferStorage {
69 fn drop(&mut self) {
70 if let Some(set) = self.residency_set.as_ref() {
71 // Mirror ggml-metal-device.m:1397-1399 free-path semantics, but
72 // deferred — the actual `[set commit]` is issued at the next
73 // CommandEncoder::commit* boundary by flush_pending().
74 set.remove_allocation(&self.inner);
75 }
76 }
77}
78
// Compile-time assertion that MlxBuffer is Send + Sync: the inner
// metal::Buffer is, and the remaining fields (Arc<MlxBufferStorage>,
// DType, Vec<usize>, u64) are too.
crate::static_assertions_send_sync!(MlxBuffer);
81
82impl Clone for MlxBuffer {
83 /// Increment the storage's `Arc` ref-count and wrap it in a new
84 /// `MlxBuffer`. Both the original and the clone refer to the same
85 /// underlying GPU allocation AND share the residency-set membership
86 /// guard — no data is copied, no double-registration occurs.
87 ///
88 /// This is safe because `metal::Buffer` wraps an `MTLBuffer` Objective-C
89 /// object whose lifetime is managed by ARC; `Arc::clone` increments the
90 /// Rust-side refcount, and the inner `MlxBufferStorage` Drop runs once
91 /// when the last clone is released.
92 fn clone(&self) -> Self {
93 Self {
94 storage: self.storage.clone(),
95 dtype: self.dtype,
96 shape: self.shape.clone(),
97 byte_offset: self.byte_offset,
98 }
99 }
100}
101
102impl MlxBuffer {
103 /// Create a new `MlxBuffer` wrapping an already-allocated Metal buffer.
104 ///
105 /// # When to use
106 ///
107 /// Use this to wrap Metal buffers obtained from external frameworks (e.g.
108 /// candle's `MetalStorage::buffer()`) for zero-copy interop on Apple
109 /// Silicon unified memory. Both frameworks see the same physical memory.
110 ///
111 /// # Safety contract
112 ///
113 /// The caller must ensure that `inner` remains valid for the lifetime of
114 /// the returned `MlxBuffer`. If the buffer was obtained from another
115 /// framework, the caller must ensure that framework does not deallocate
116 /// the buffer while this `MlxBuffer` exists.
117 ///
118 /// The returned buffer carries no residency-set guard — pool / external
119 /// callers that want residency tracking should go through
120 /// [`MlxDevice::alloc_buffer`](crate::MlxDevice::alloc_buffer) or
121 /// [`MlxBufferPool::register_existing`](crate::MlxBufferPool::register_existing).
122 pub fn from_raw(inner: MetalBuffer, dtype: DType, shape: Vec<usize>) -> Self {
123 Self {
124 storage: Arc::new(MlxBufferStorage {
125 inner,
126 residency_set: None,
127 }),
128 dtype,
129 shape,
130 byte_offset: 0,
131 }
132 }
133
134 /// Create a new buffer and stage its Metal allocation for inclusion in
135 /// the given residency set.
136 ///
137 /// Calls `set.add_allocation(buffer)` (deferred — no `[set commit]` until
138 /// the next [`flush_pending`](ResidencySet::flush_pending) at a
139 /// `CommandEncoder::commit*` boundary). The buffer's residency-set guard
140 /// is dropped when the last clone of the returned `MlxBuffer` (and any
141 /// slice views) goes out of scope, which fires the matching
142 /// `removeAllocation:` (also deferred).
143 ///
144 /// Crate-private — external callers should go through
145 /// [`MlxDevice::alloc_buffer`](crate::MlxDevice::alloc_buffer).
146 pub(crate) fn with_residency(
147 inner: MetalBuffer,
148 dtype: DType,
149 shape: Vec<usize>,
150 residency_set: ResidencySet,
151 ) -> Self {
152 // Stage the addAllocation; the actual `[set commit]` is deferred to
153 // the next encoder.commit* boundary via flush_pending. This is the
154 // structural fix for the per-allocation commit storm; mirrors
155 // llama.cpp's ggml-metal-device.m:1378-1382 pattern.
156 residency_set.add_allocation(&inner);
157
158 Self {
159 storage: Arc::new(MlxBufferStorage {
160 inner,
161 residency_set: Some(residency_set),
162 }),
163 dtype,
164 shape,
165 byte_offset: 0,
166 }
167 }
168
169 /// Create a zero-copy slice view of this buffer.
170 ///
171 /// Returns a new `MlxBuffer` that shares the same underlying Metal buffer
172 /// but starts at `byte_offset` bytes from the beginning and contains
173 /// `n_elements` elements of type `dtype`. No data is copied.
174 ///
175 /// The slice view shares the parent's residency-set guard via the
176 /// `Arc<MlxBufferStorage>`, so it does NOT trigger a second
177 /// `addAllocation:` and does NOT deregister the parent on drop.
178 ///
179 /// When this view is bound to a kernel, the encoder passes the byte offset
180 /// to Metal's `setBuffer:offset:atIndex:`, so the kernel sees only the
181 /// slice region.
182 ///
183 /// # Panics
184 ///
185 /// Panics if `byte_offset + n_elements * dtype.size_of() > self.inner.length()`.
186 #[inline]
187 pub fn slice_view(&self, byte_offset: u64, n_elements: usize) -> Self {
188 let end = byte_offset as usize + n_elements * self.dtype.size_of();
189 assert!(
190 end <= self.storage.inner.length() as usize,
191 "slice_view: out of bounds (byte_offset={}, n_elements={}, dtype_size={}, buf_len={})",
192 byte_offset,
193 n_elements,
194 self.dtype.size_of(),
195 self.storage.inner.length()
196 );
197 Self {
198 storage: self.storage.clone(),
199 dtype: self.dtype,
200 shape: vec![n_elements],
201 byte_offset,
202 }
203 }
204
205 // ---- accessors ----
206
207 /// Element data type.
208 #[inline]
209 pub fn dtype(&self) -> DType {
210 self.dtype
211 }
212
213 /// Tensor shape (dimensions).
214 #[inline]
215 pub fn shape(&self) -> &[usize] {
216 &self.shape
217 }
218
219 /// Total byte length of the Metal buffer.
220 #[inline]
221 pub fn byte_len(&self) -> usize {
222 self.storage.inner.length() as usize
223 }
224
225 /// Number of elements (product of shape dimensions, or `byte_len / dtype.size_of()`).
226 #[inline]
227 pub fn element_count(&self) -> usize {
228 self.shape.iter().copied().product()
229 }
230
231 /// Raw pointer to the buffer contents (CPU-accessible on Apple Silicon).
232 ///
233 /// # Safety
234 ///
235 /// The caller must ensure proper synchronization — do not read while a GPU
236 /// command buffer that writes this buffer is in flight.
237 #[inline]
238 pub fn contents_ptr(&self) -> *mut std::ffi::c_void {
239 self.storage.inner.contents()
240 }
241
242 /// Reference to the underlying `metal::Buffer` for passing to the encoder.
243 #[inline]
244 pub fn metal_buffer(&self) -> &MetalBuffer {
245 &self.storage.inner
246 }
247
248 /// Byte offset into the underlying Metal buffer (zero for non-slice buffers).
249 ///
250 /// When passing this buffer to a Metal kernel via `setBuffer:offset:atIndex:`,
251 /// use this offset so the kernel sees only the intended sub-region.
252 #[inline]
253 pub fn byte_offset(&self) -> u64 {
254 self.byte_offset
255 }
256
257 /// Consume self and return the inner `metal::Buffer` (used by buffer pool).
258 ///
259 /// If this is the last clone of the underlying `Arc<MlxBufferStorage>`,
260 /// the storage Drop fires after this returns — staging a deferred
261 /// `removeAllocation:` if the buffer carried a residency-set guard.
262 /// Pool-internal buffers do not carry guards, so this is a no-op for
263 /// the pool's `release` path.
264 #[inline]
265 pub(crate) fn into_inner(self) -> MetalBuffer {
266 self.storage.inner.clone()
267 }
268
269 /// Borrow the residency set that this buffer was registered with, if any.
270 ///
271 /// Used by [`MlxBufferPool::register_existing`](crate::MlxBufferPool::register_existing)
272 /// to short-circuit re-registration: a buffer created via
273 /// [`MlxDevice::alloc_buffer`](crate::MlxDevice::alloc_buffer) on a
274 /// residency-enabled device already owns its registration via the
275 /// `Arc<MlxBufferStorage>`, so the pool path is a no-op (modulo
276 /// validation that the device matches).
277 #[inline]
278 pub(crate) fn residency_set(&self) -> Option<&ResidencySet> {
279 self.storage.residency_set.as_ref()
280 }
281
282 // ---- typed CPU access (zero-copy on unified memory) ----
283
284 /// View the buffer contents as a typed slice.
285 ///
286 /// Returns an error if the buffer byte length is not an exact multiple of
287 /// `size_of::<T>()`.
288 ///
289 /// # Safety contract
290 ///
291 /// The caller must ensure:
292 /// 1. `T` matches the actual element type stored in the buffer.
293 /// 2. No GPU command buffer that writes this buffer is currently in flight.
294 pub fn as_slice<T: bytemuck::Pod>(&self) -> Result<&[T]> {
295 let elem_size = std::mem::size_of::<T>();
296 if elem_size == 0 {
297 return Err(MlxError::InvalidArgument(
298 "Cannot view buffer as zero-sized type".into(),
299 ));
300 }
301 let byte_len = self.byte_len();
302 if byte_len % elem_size != 0 {
303 return Err(MlxError::InvalidArgument(format!(
304 "Buffer byte length {byte_len} is not a multiple of element size {elem_size}"
305 )));
306 }
307 let ptr = self.contents_ptr();
308 if ptr.is_null() {
309 return Err(MlxError::BufferAllocationError { bytes: byte_len });
310 }
311 let count = byte_len / elem_size;
312 // SAFETY: Metal guarantees the pointer is valid for `byte_len` bytes and
313 // properly aligned for any type on Apple Silicon shared memory. The
314 // caller upholds the type-match and no-concurrent-GPU-write contract.
315 let slice = unsafe { std::slice::from_raw_parts(ptr as *const T, count) };
316 Ok(slice)
317 }
318
319 /// View the buffer contents as a mutable typed slice.
320 ///
321 /// Same safety contract as [`as_slice`](Self::as_slice), plus: the caller
322 /// must ensure exclusive access (no other references to this buffer's memory
323 /// exist).
324 pub fn as_mut_slice<T: bytemuck::Pod>(&mut self) -> Result<&mut [T]> {
325 let elem_size = std::mem::size_of::<T>();
326 if elem_size == 0 {
327 return Err(MlxError::InvalidArgument(
328 "Cannot view buffer as zero-sized type".into(),
329 ));
330 }
331 let byte_len = self.byte_len();
332 if byte_len % elem_size != 0 {
333 return Err(MlxError::InvalidArgument(format!(
334 "Buffer byte length {byte_len} is not a multiple of element size {elem_size}"
335 )));
336 }
337 let ptr = self.contents_ptr();
338 if ptr.is_null() {
339 return Err(MlxError::BufferAllocationError { bytes: byte_len });
340 }
341 let count = byte_len / elem_size;
342 // SAFETY: same as as_slice, plus caller ensures exclusive mutable access.
343 let slice = unsafe { std::slice::from_raw_parts_mut(ptr as *mut T, count) };
344 Ok(slice)
345 }
346
347 /// Overwrite the dtype and shape metadata.
348 ///
349 /// This does **not** re-allocate the Metal buffer — it only changes the
350 /// logical interpretation. The caller must ensure the new shape is
351 /// consistent with the buffer's byte length.
352 #[allow(dead_code)]
353 pub(crate) fn reshape(&mut self, dtype: DType, shape: Vec<usize>) {
354 self.dtype = dtype;
355 self.shape = shape;
356 }
357}
358
359impl fmt::Debug for MlxBuffer {
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 f.debug_struct("MlxBuffer")
362 .field("dtype", &self.dtype)
363 .field("shape", &self.shape)
364 .field("byte_len", &self.byte_len())
365 .finish()
366 }
367}