mlx_native/buffer_pool.rs
1//! [`MlxBufferPool`] — arena-style GPU buffer allocator with reuse.
2//!
3//! Buffers are bucketed by power-of-two sizes. When a buffer is released back
4//! to the pool, it is added to the free list for its size bucket. A subsequent
5//! `alloc` call will reuse a free buffer of compatible (>= requested) size
6//! rather than allocating new Metal memory.
7//!
8//! Two return-path patterns are supported and **must not be mixed within a
9//! single arena cycle**:
10//!
11//! * **Per-buffer** via [`release`](MlxBufferPool::release) — explicit return
12//! of a single buffer to the free list, suitable for ad-hoc patterns where
13//! the caller knows the precise lifetime of each buffer.
14//! * **Arena bulk** via [`reset`](MlxBufferPool::reset) — bulk-return of every
15//! buffer handed out by [`alloc`](MlxBufferPool::alloc) since the previous
16//! reset. Suitable for per-inference / per-decode-token arena patterns
17//! where no individual buffer's lifetime crosses the reset boundary.
18//!
19//! Internally, every `alloc` records an ARC-cloned `metal::Buffer` handle so
20//! that `reset` can bulk-recycle without requiring callers to enumerate every
21//! buffer individually. ARC retain on `metal::Buffer` is cheap (refcount inc).
22
23use std::collections::HashMap;
24
25use crate::buffer::MlxBuffer;
26use crate::device::MlxDevice;
27use crate::dtypes::DType;
28use crate::error::{MlxError, Result};
29
/// Arena-style buffer pool that reuses Metal buffer allocations.
///
/// # Design
///
/// * Buffers are bucketed by their allocated size rounded up to the nearest
///   power of two. This reduces fragmentation at the cost of occasionally
///   over-allocating by up to 2x.
/// * `release()` returns a single buffer; `reset()` returns all outstanding
///   buffers handed out since the last reset.
/// * The `MlxDevice` is passed in at every [`alloc`](MlxBufferPool::alloc)
///   call (rather than stored in the pool). This keeps the pool free of
///   lifetime parameters so it can be embedded in any owner struct (e.g. the
///   per-decode-token `DecodeBuffers` cache in hf2q's qwen35 forward path).
///
/// # Why an arena reset matters
///
/// In the per-decode-token hot path, each token allocates ~1750 Metal buffers
/// for scratch / intermediate / parameter storage across attention, FFN, and
/// linear-attention layers. Direct `MlxDevice::alloc_buffer()` calls hit
/// Metal's allocator each time (5-30 µs each); pooling reuses the underlying
/// `metal::Buffer` objects across token boundaries so steady-state allocation
/// cost amortizes to near zero. See ADR-012 §Optimize / Task #15 for the
/// MoE dwq46 0.90× parity gap that motivated this work.
pub struct MlxBufferPool {
    /// Free buffers keyed by their power-of-two bucket size. Fed by both
    /// [`release`](MlxBufferPool::release) (per-buffer) and
    /// [`reset`](MlxBufferPool::reset) (bulk); drained by
    /// [`alloc`](MlxBufferPool::alloc).
    free: HashMap<usize, Vec<metal::Buffer>>,
    /// Buffers handed out by [`alloc`](MlxBufferPool::alloc) since the last
    /// [`reset`](MlxBufferPool::reset). Each entry holds an ARC-cloned
    /// `metal::Buffer` so the pool's reference keeps the underlying GPU
    /// allocation alive even after the caller's `MlxBuffer` goes out of
    /// scope. [`reset`](MlxBufferPool::reset) drains this into `free`.
    /// Note: [`release`](MlxBufferPool::release) does NOT remove entries
    /// here — hence the "don't mix patterns" contract on that method.
    in_use: Vec<(usize, metal::Buffer)>,
}
62
63impl Default for MlxBufferPool {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl MlxBufferPool {
70 /// Create a new empty buffer pool. The Metal device is passed to
71 /// [`alloc`] at every call site, so the pool itself is lifetime-free.
72 pub fn new() -> Self {
73 Self {
74 free: HashMap::new(),
75 in_use: Vec::new(),
76 }
77 }
78
79 /// Allocate a buffer from the pool.
80 ///
81 /// If a free buffer of compatible size exists in the pool, it is reused
82 /// (with updated dtype/shape metadata). Otherwise a new Metal buffer is
83 /// allocated from `device` at the bucket size so future reuse is
84 /// possible for any request up to that bucket.
85 ///
86 /// Each successful `alloc` registers the buffer in the pool's in-use
87 /// list (ARC clone — cheap), so a subsequent [`reset`] returns it to
88 /// the free list automatically.
89 pub fn alloc(
90 &mut self,
91 device: &MlxDevice,
92 byte_len: usize,
93 dtype: DType,
94 shape: Vec<usize>,
95 ) -> Result<MlxBuffer> {
96 let bucket = bucket_size(byte_len);
97
98 // Try to reuse a free buffer from this bucket.
99 let metal_buf = self
100 .free
101 .get_mut(&bucket)
102 .and_then(|free_list| free_list.pop());
103
104 let metal_buf = match metal_buf {
105 Some(b) => b,
106 None => {
107 // Fresh allocation at bucket size.
108 let raw = device.metal_device().new_buffer(
109 bucket as u64,
110 metal::MTLResourceOptions::StorageModeShared,
111 );
112 if raw.contents().is_null() {
113 return Err(MlxError::BufferAllocationError { bytes: bucket });
114 }
115 raw
116 }
117 };
118
119 // Track the handout so reset() can recycle it. ARC clone is cheap.
120 self.in_use.push((bucket, metal_buf.clone()));
121
122 Ok(MlxBuffer::from_raw(metal_buf, dtype, shape))
123 }
124
125 /// Return a single buffer to the pool's free list for future reuse.
126 ///
127 /// The Metal memory is **not** deallocated — it stays resident on the GPU
128 /// for fast reuse. `release` is the per-buffer alternative to [`reset`];
129 /// see the module docs for guidance on which to use.
130 ///
131 /// **Mixing `release` and `reset` within the same arena cycle is not
132 /// supported** — the pool's in-use list does not deduplicate, so a buffer
133 /// returned via `release` and then bulk-returned via `reset` would land in
134 /// the free list twice (each entry holds an ARC clone of the same Metal
135 /// buffer; the duplication wastes a free-list slot but is not a memory
136 /// leak — both clones drop together once popped). Pick one pattern per
137 /// arena cycle.
138 pub fn release(&mut self, buffer: MlxBuffer) {
139 let bucket = bucket_size(buffer.byte_len());
140 let metal_buf = buffer.into_inner();
141 self.free.entry(bucket).or_default().push(metal_buf);
142 }
143
144 /// Bulk-return every buffer handed out by [`alloc`] since the last reset
145 /// to the pool's free list.
146 ///
147 /// # Caller contract
148 ///
149 /// All `MlxBuffer` values returned by `alloc` since the last reset must be
150 /// out-of-scope (dropped) at the time `reset` is called. Reset transfers
151 /// the pool's ARC clones to the free list, where they become available to
152 /// subsequent [`alloc`] calls. If a caller is still holding an `MlxBuffer`
153 /// and a later `alloc` re-issues the underlying buffer, the two callers
154 /// will share GPU memory (aliasing). The Metal ARC keeps the storage
155 /// alive in either case, but writes from the new caller will be visible
156 /// to the stale caller — a correctness bug, not a memory error.
157 ///
158 /// In Rust's ownership model, locally-bound `MlxBuffer` values fall out of
159 /// scope at the end of their lexical block, making the per-decode-token
160 /// arena pattern safe by construction:
161 ///
162 /// ```ignore
163 /// loop {
164 /// pool.reset(); // start of token — recycle previous token's buffers
165 /// forward_pass(&pool); // many alloc(), no explicit release
166 /// } // forward_pass returns; locals dropped
167 /// ```
168 pub fn reset(&mut self) {
169 for (bucket, metal_buf) in self.in_use.drain(..) {
170 self.free.entry(bucket).or_default().push(metal_buf);
171 }
172 }
173
174 /// Return all free buffers' count (for diagnostics).
175 pub fn free_count(&self) -> usize {
176 self.free.values().map(|v| v.len()).sum()
177 }
178
179 /// Total number of bytes held in the free list.
180 pub fn free_bytes(&self) -> usize {
181 self.free
182 .iter()
183 .map(|(&bucket, bufs)| bucket * bufs.len())
184 .sum()
185 }
186
187 /// Number of buffers currently in-use (alloc'd but not yet reset).
188 pub fn in_use_count(&self) -> usize {
189 self.in_use.len()
190 }
191
192 /// Clear all free buffers, releasing Metal memory. Does not affect
193 /// in-use tracking.
194 pub fn clear(&mut self) {
195 self.free.clear();
196 }
197}
198
/// Round `n` up to the nearest power of two.
///
/// Returns 1 for n == 0 (though callers should never request 0 bytes).
///
/// For pathologically large `n` (above `2^(usize::BITS - 1)`) the next power
/// of two does not fit in `usize`; the result saturates to `usize::MAX` so
/// the request surfaces downstream as an allocation failure. The previous
/// `next_power_of_two()` call would instead panic in debug builds and wrap
/// to a 0-byte bucket in release builds.
fn bucket_size(n: usize) -> usize {
    if n <= 1 {
        return 1;
    }
    n.checked_next_power_of_two().unwrap_or(usize::MAX)
}
208
209#[cfg(test)]
210mod tests {
211 use super::*;
212
213 #[test]
214 fn test_bucket_size_powers() {
215 assert_eq!(bucket_size(0), 1);
216 assert_eq!(bucket_size(1), 1);
217 assert_eq!(bucket_size(2), 2);
218 assert_eq!(bucket_size(3), 4);
219 assert_eq!(bucket_size(4), 4);
220 assert_eq!(bucket_size(5), 8);
221 assert_eq!(bucket_size(1023), 1024);
222 assert_eq!(bucket_size(1024), 1024);
223 assert_eq!(bucket_size(1025), 2048);
224 }
225
    #[test]
    fn test_pool_arena_reset_recycles_in_use() {
        // Per-decode-token arena pattern: alloc many, drop locals, reset, alloc again.
        // Subsequent allocs must reuse the same Metal buffers (verified by ARC-cloned
        // contents pointer).
        // NOTE(review): assumes contents pointers are stable per Metal buffer and
        // distinct across StorageModeShared buffers — holds for pool-allocated
        // buffers since each is a separate allocation.
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        // Cycle 1: allocate three buffers in different buckets, then drop them
        // (locals fall out of scope at the end of the block).
        // buf_a / buf_c land in the 1024 bucket, buf_b in the 2048 bucket.
        let (ptr_a, ptr_b, ptr_c) = {
            let buf_a = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc a");
            let buf_b = pool.alloc(&device, 2048, DType::F32, vec![512]).expect("alloc b");
            let buf_c = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc c");
            (buf_a.contents_ptr(), buf_b.contents_ptr(), buf_c.contents_ptr())
        };
        // All three are tracked as in-use; nothing is free yet.
        assert_eq!(pool.in_use_count(), 3);
        assert_eq!(pool.free_count(), 0);

        // Reset returns all three to free.
        pool.reset();
        assert_eq!(pool.in_use_count(), 0);
        assert_eq!(pool.free_count(), 3);

        // Cycle 2: allocate compatible-bucket buffers, must reuse the same
        // underlying Metal buffers (contents_ptr equal).
        let buf_d = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc d");
        let buf_e = pool.alloc(&device, 2048, DType::F32, vec![512]).expect("alloc e");
        let ptr_d = buf_d.contents_ptr();
        let ptr_e = buf_e.contents_ptr();

        // Pointers must come from {a, b, c} — bucket 1024 reuse for d (matches a or c),
        // bucket 2048 reuse for e (matches b). Which of a/c is reused depends on
        // free-list order, so either is accepted.
        assert!(
            ptr_d == ptr_a || ptr_d == ptr_c,
            "buf_d {:?} must reuse one of a {:?} / c {:?}",
            ptr_d, ptr_a, ptr_c,
        );
        assert_eq!(ptr_e, ptr_b, "buf_e must reuse b (only 2048-bucket buffer)");

        // After cycle-2 alloc, free has 1 (the unused 1024-bucket buffer) + in_use 2.
        assert_eq!(pool.in_use_count(), 2);
        assert_eq!(pool.free_count(), 1);
    }
270
    #[test]
    fn test_pool_reset_with_no_alloc_is_idempotent() {
        // Empty reset must be a no-op: draining an empty in-use list adds
        // nothing to the free list.
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();
        pool.reset();
        assert_eq!(pool.in_use_count(), 0);
        assert_eq!(pool.free_count(), 0);
        // Multiple resets without intervening alloc — still no-op.
        pool.reset();
        pool.reset();
        assert_eq!(pool.in_use_count(), 0);
    }
284
    #[test]
    fn test_pool_release_remains_supported_for_compat() {
        // The existing per-buffer release() pattern still works. Mixing
        // release+reset within the same arena cycle is documented as
        // unsupported but technically lands a duplicate clone in free —
        // verify the duplicate is harmless (alloc still picks up a buffer).
        // NOTE(review): "harmless" holds only for the slot waste — if a reset
        // were to duplicate this entry and two allocs popped both clones in
        // one cycle, the callers would alias the same GPU memory. This test
        // deliberately does not reset, so only the release path is exercised.
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        let buf = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc");
        assert_eq!(pool.in_use_count(), 1);
        pool.release(buf);
        // release() does NOT remove from in_use; that's acceptable per the
        // documented contract (don't mix patterns). Free has the released one.
        assert_eq!(pool.free_count(), 1);
        assert_eq!(pool.in_use_count(), 1);

        // Allocating again pulls from free first.
        let _buf2 = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc 2");
        assert_eq!(pool.free_count(), 0);
        assert_eq!(pool.in_use_count(), 2);
    }
307}