mlx_native/buffer.rs
1//! [`MlxBuffer`] — typed wrapper around a Metal GPU buffer.
2//!
3//! Buffers are allocated with `StorageModeShared` so that CPU and GPU share
4//! the same physical memory on Apple Silicon (zero-copy access via
5//! [`as_slice`](MlxBuffer::as_slice) / [`as_mut_slice`](MlxBuffer::as_mut_slice)).
6
7use std::fmt;
8
9use metal::Buffer as MetalBuffer;
10
11use crate::dtypes::DType;
12use crate::error::{MlxError, Result};
13
14/// A Metal GPU buffer annotated with element dtype and tensor shape.
15///
16/// On Apple Silicon the underlying memory is unified — `contents_ptr()` gives
17/// direct CPU access without any copy or transfer.
18///
19/// # Thread Safety
20///
21/// `MlxBuffer` is `Send + Sync` because the inner `metal::Buffer` is.
22pub struct MlxBuffer {
23 /// The underlying Metal buffer (StorageModeShared).
24 inner: MetalBuffer,
25 /// Element data type.
26 dtype: DType,
27 /// Tensor shape (e.g. `[2, 3, 4]` for a rank-3 tensor).
28 shape: Vec<usize>,
29}
30
31// metal::Buffer is Send + Sync; our extra fields (DType, Vec<usize>) are too.
32crate::static_assertions_send_sync!(MlxBuffer);
33
34impl MlxBuffer {
35 /// Create a new `MlxBuffer` wrapping an already-allocated Metal buffer.
36 ///
37 /// # When to use
38 ///
39 /// Use this to wrap Metal buffers obtained from external frameworks (e.g.
40 /// candle's `MetalStorage::buffer()`) for zero-copy interop on Apple
41 /// Silicon unified memory. Both frameworks see the same physical memory.
42 ///
43 /// # Safety contract
44 ///
45 /// The caller must ensure that `inner` remains valid for the lifetime of
46 /// the returned `MlxBuffer`. If the buffer was obtained from another
47 /// framework, the caller must ensure that framework does not deallocate
48 /// the buffer while this `MlxBuffer` exists.
49 pub fn from_raw(inner: MetalBuffer, dtype: DType, shape: Vec<usize>) -> Self {
50 Self {
51 inner,
52 dtype,
53 shape,
54 }
55 }
56
57 // ---- accessors ----
58
59 /// Element data type.
60 #[inline]
61 pub fn dtype(&self) -> DType {
62 self.dtype
63 }
64
65 /// Tensor shape (dimensions).
66 #[inline]
67 pub fn shape(&self) -> &[usize] {
68 &self.shape
69 }
70
71 /// Total byte length of the Metal buffer.
72 #[inline]
73 pub fn byte_len(&self) -> usize {
74 self.inner.length() as usize
75 }
76
77 /// Number of elements (product of shape dimensions, or `byte_len / dtype.size_of()`).
78 #[inline]
79 pub fn element_count(&self) -> usize {
80 self.shape.iter().copied().product()
81 }
82
83 /// Raw pointer to the buffer contents (CPU-accessible on Apple Silicon).
84 ///
85 /// # Safety
86 ///
87 /// The caller must ensure proper synchronization — do not read while a GPU
88 /// command buffer that writes this buffer is in flight.
89 #[inline]
90 pub fn contents_ptr(&self) -> *mut std::ffi::c_void {
91 self.inner.contents()
92 }
93
94 /// Reference to the underlying `metal::Buffer` for passing to the encoder.
95 #[inline]
96 pub fn metal_buffer(&self) -> &MetalBuffer {
97 &self.inner
98 }
99
100 /// Consume self and return the inner `metal::Buffer` (used by buffer pool).
101 #[inline]
102 pub(crate) fn into_inner(self) -> MetalBuffer {
103 self.inner
104 }
105
106 // ---- typed CPU access (zero-copy on unified memory) ----
107
108 /// View the buffer contents as a typed slice.
109 ///
110 /// Returns an error if the buffer byte length is not an exact multiple of
111 /// `size_of::<T>()`.
112 ///
113 /// # Safety contract
114 ///
115 /// The caller must ensure:
116 /// 1. `T` matches the actual element type stored in the buffer.
117 /// 2. No GPU command buffer that writes this buffer is currently in flight.
118 pub fn as_slice<T: bytemuck::Pod>(&self) -> Result<&[T]> {
119 let elem_size = std::mem::size_of::<T>();
120 if elem_size == 0 {
121 return Err(MlxError::InvalidArgument(
122 "Cannot view buffer as zero-sized type".into(),
123 ));
124 }
125 let byte_len = self.byte_len();
126 if byte_len % elem_size != 0 {
127 return Err(MlxError::InvalidArgument(format!(
128 "Buffer byte length {byte_len} is not a multiple of element size {elem_size}"
129 )));
130 }
131 let ptr = self.contents_ptr();
132 if ptr.is_null() {
133 return Err(MlxError::BufferAllocationError { bytes: byte_len });
134 }
135 let count = byte_len / elem_size;
136 // SAFETY: Metal guarantees the pointer is valid for `byte_len` bytes and
137 // properly aligned for any type on Apple Silicon shared memory. The
138 // caller upholds the type-match and no-concurrent-GPU-write contract.
139 let slice = unsafe { std::slice::from_raw_parts(ptr as *const T, count) };
140 Ok(slice)
141 }
142
143 /// View the buffer contents as a mutable typed slice.
144 ///
145 /// Same safety contract as [`as_slice`](Self::as_slice), plus: the caller
146 /// must ensure exclusive access (no other references to this buffer's memory
147 /// exist).
148 pub fn as_mut_slice<T: bytemuck::Pod>(&mut self) -> Result<&mut [T]> {
149 let elem_size = std::mem::size_of::<T>();
150 if elem_size == 0 {
151 return Err(MlxError::InvalidArgument(
152 "Cannot view buffer as zero-sized type".into(),
153 ));
154 }
155 let byte_len = self.byte_len();
156 if byte_len % elem_size != 0 {
157 return Err(MlxError::InvalidArgument(format!(
158 "Buffer byte length {byte_len} is not a multiple of element size {elem_size}"
159 )));
160 }
161 let ptr = self.contents_ptr();
162 if ptr.is_null() {
163 return Err(MlxError::BufferAllocationError { bytes: byte_len });
164 }
165 let count = byte_len / elem_size;
166 // SAFETY: same as as_slice, plus caller ensures exclusive mutable access.
167 let slice = unsafe { std::slice::from_raw_parts_mut(ptr as *mut T, count) };
168 Ok(slice)
169 }
170
171 /// Overwrite the dtype and shape metadata.
172 ///
173 /// This does **not** re-allocate the Metal buffer — it only changes the
174 /// logical interpretation. The caller must ensure the new shape is
175 /// consistent with the buffer's byte length.
176 #[allow(dead_code)]
177 pub(crate) fn reshape(&mut self, dtype: DType, shape: Vec<usize>) {
178 self.dtype = dtype;
179 self.shape = shape;
180 }
181}
182
183impl fmt::Debug for MlxBuffer {
184 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
185 f.debug_struct("MlxBuffer")
186 .field("dtype", &self.dtype)
187 .field("shape", &self.shape)
188 .field("byte_len", &self.byte_len())
189 .finish()
190 }
191}