// mlx_native/buffer_pool.rs
//! [`MlxBufferPool`] — arena-style GPU buffer allocator with reuse.
//!
//! Buffers are bucketed by power-of-two sizes.  When a buffer is released back
//! to the pool, it is added to the free list for its size bucket.  A subsequent
//! `alloc` call will reuse a free buffer of compatible (>= requested) size
//! rather than allocating new Metal memory.
//!
//! Two return-path patterns are supported and **must not be mixed within a
//! single arena cycle**:
//!
//! * **Per-buffer** via [`release`](MlxBufferPool::release) — explicit return
//!   of a single buffer to the free list, suitable for ad-hoc patterns where
//!   the caller knows the precise lifetime of each buffer.
//! * **Arena bulk** via [`reset`](MlxBufferPool::reset) — bulk-return of every
//!   buffer handed out by [`alloc`](MlxBufferPool::alloc) since the previous
//!   reset.  Suitable for per-inference / per-decode-token arena patterns
//!   where no individual buffer's lifetime crosses the reset boundary.
//!
//! Internally, every `alloc` records an ARC-cloned `metal::Buffer` handle so
//! that `reset` can bulk-recycle without requiring callers to enumerate every
//! buffer individually.  ARC retain on `metal::Buffer` is cheap (refcount inc).

use std::collections::HashMap;

use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::error::{MlxError, Result};

/// Arena-style buffer pool that reuses Metal buffer allocations.
///
/// # Design
///
/// * Buffers are bucketed by their allocated size rounded up to the nearest
///   power of two.  This reduces fragmentation at the cost of occasionally
///   over-allocating by up to 2x.
/// * `release()` returns a single buffer; `reset()` returns all outstanding
///   buffers handed out since the last reset.
/// * The `MlxDevice` is passed in at every [`alloc`] call (rather than stored
///   in the pool).  This keeps the pool free of lifetime parameters so it
///   can be embedded in any owner struct (e.g. the per-decode-token
///   `DecodeBuffers` cache in hf2q's qwen35 forward path).
///
/// # Why an arena reset matters
///
/// In the per-decode-token hot path, each token allocates ~1750 Metal buffers
/// for scratch / intermediate / parameter storage across attention, FFN, and
/// linear-attention layers.  Direct `MlxDevice::alloc_buffer()` calls hit
/// Metal's allocator each time (5-30 µs each); pooling reuses the underlying
/// `metal::Buffer` objects across token boundaries so steady-state allocation
/// cost amortizes to near zero.  See ADR-012 §Optimize / Task #15 for the
/// MoE dwq46 0.90× parity gap that motivated this work.
pub struct MlxBufferPool {
    /// Free buffers keyed by their power-of-two bucket size.  Every buffer
    /// stored under key `k` was originally allocated with exactly `k` bytes
    /// of Metal storage (alloc always allocates at bucket size), so any
    /// request whose bucket is `k` can safely reuse one of these.
    free: HashMap<usize, Vec<metal::Buffer>>,
    /// Buffers handed out by [`alloc`] since the last [`reset`].  Each entry
    /// holds an ARC-cloned `metal::Buffer` so the pool's reference keeps the
    /// underlying GPU allocation alive even after the caller's `MlxBuffer`
    /// goes out of scope.  [`reset`] drains this into [`free`].
    in_use: Vec<(usize, metal::Buffer)>,
}
63impl Default for MlxBufferPool {
64    fn default() -> Self {
65        Self::new()
66    }
67}
69impl MlxBufferPool {
70    /// Create a new empty buffer pool.  The Metal device is passed to
71    /// [`alloc`] at every call site, so the pool itself is lifetime-free.
72    pub fn new() -> Self {
73        Self {
74            free: HashMap::new(),
75            in_use: Vec::new(),
76        }
77    }
78
79    /// Allocate a buffer from the pool.
80    ///
81    /// If a free buffer of compatible size exists in the pool, it is reused
82    /// (with updated dtype/shape metadata).  Otherwise a new Metal buffer is
83    /// allocated from `device` at the bucket size so future reuse is
84    /// possible for any request up to that bucket.
85    ///
86    /// Each successful `alloc` registers the buffer in the pool's in-use
87    /// list (ARC clone — cheap), so a subsequent [`reset`] returns it to
88    /// the free list automatically.
89    pub fn alloc(
90        &mut self,
91        device: &MlxDevice,
92        byte_len: usize,
93        dtype: DType,
94        shape: Vec<usize>,
95    ) -> Result<MlxBuffer> {
96        let bucket = bucket_size(byte_len);
97
98        // Try to reuse a free buffer from this bucket.
99        let metal_buf = self
100            .free
101            .get_mut(&bucket)
102            .and_then(|free_list| free_list.pop());
103
104        let metal_buf = match metal_buf {
105            Some(b) => b,
106            None => {
107                // Fresh allocation at bucket size.
108                let raw = device.metal_device().new_buffer(
109                    bucket as u64,
110                    metal::MTLResourceOptions::StorageModeShared,
111                );
112                if raw.contents().is_null() {
113                    return Err(MlxError::BufferAllocationError { bytes: bucket });
114                }
115                raw
116            }
117        };
118
119        // Track the handout so reset() can recycle it.  ARC clone is cheap.
120        self.in_use.push((bucket, metal_buf.clone()));
121
122        Ok(MlxBuffer::from_raw(metal_buf, dtype, shape))
123    }
124
125    /// Return a single buffer to the pool's free list for future reuse.
126    ///
127    /// The Metal memory is **not** deallocated — it stays resident on the GPU
128    /// for fast reuse.  `release` is the per-buffer alternative to [`reset`];
129    /// see the module docs for guidance on which to use.
130    ///
131    /// **Mixing `release` and `reset` within the same arena cycle is not
132    /// supported** — the pool's in-use list does not deduplicate, so a buffer
133    /// returned via `release` and then bulk-returned via `reset` would land in
134    /// the free list twice (each entry holds an ARC clone of the same Metal
135    /// buffer; the duplication wastes a free-list slot but is not a memory
136    /// leak — both clones drop together once popped).  Pick one pattern per
137    /// arena cycle.
138    pub fn release(&mut self, buffer: MlxBuffer) {
139        let bucket = bucket_size(buffer.byte_len());
140        let metal_buf = buffer.into_inner();
141        self.free.entry(bucket).or_default().push(metal_buf);
142    }
143
144    /// Bulk-return every buffer handed out by [`alloc`] since the last reset
145    /// to the pool's free list.
146    ///
147    /// # Caller contract
148    ///
149    /// All `MlxBuffer` values returned by `alloc` since the last reset must be
150    /// out-of-scope (dropped) at the time `reset` is called.  Reset transfers
151    /// the pool's ARC clones to the free list, where they become available to
152    /// subsequent [`alloc`] calls.  If a caller is still holding an `MlxBuffer`
153    /// and a later `alloc` re-issues the underlying buffer, the two callers
154    /// will share GPU memory (aliasing).  The Metal ARC keeps the storage
155    /// alive in either case, but writes from the new caller will be visible
156    /// to the stale caller — a correctness bug, not a memory error.
157    ///
158    /// In Rust's ownership model, locally-bound `MlxBuffer` values fall out of
159    /// scope at the end of their lexical block, making the per-decode-token
160    /// arena pattern safe by construction:
161    ///
162    /// ```ignore
163    /// loop {
164    ///     pool.reset();          // start of token — recycle previous token's buffers
165    ///     forward_pass(&pool);   // many alloc(), no explicit release
166    /// }                          // forward_pass returns; locals dropped
167    /// ```
168    pub fn reset(&mut self) {
169        for (bucket, metal_buf) in self.in_use.drain(..) {
170            self.free.entry(bucket).or_default().push(metal_buf);
171        }
172    }
173
174    /// Return all free buffers' count (for diagnostics).
175    pub fn free_count(&self) -> usize {
176        self.free.values().map(|v| v.len()).sum()
177    }
178
179    /// Total number of bytes held in the free list.
180    pub fn free_bytes(&self) -> usize {
181        self.free
182            .iter()
183            .map(|(&bucket, bufs)| bucket * bufs.len())
184            .sum()
185    }
186
187    /// Number of buffers currently in-use (alloc'd but not yet reset).
188    pub fn in_use_count(&self) -> usize {
189        self.in_use.len()
190    }
191
192    /// Clear all free buffers, releasing Metal memory.  Does not affect
193    /// in-use tracking.
194    pub fn clear(&mut self) {
195        self.free.clear();
196    }
197}
/// Round `n` up to the nearest power of two.
///
/// Returns 1 for `n == 0` (though callers should never request 0 bytes).
///
/// # Panics
///
/// Panics if `n` exceeds the largest `usize` power of two
/// (`1 << (usize::BITS - 1)`).  The unchecked `next_power_of_two` panics
/// only in debug builds and returns 0 on overflow in release builds, which
/// would silently create a zero-byte bucket; the checked form makes the
/// failure loud in every build profile.
fn bucket_size(n: usize) -> usize {
    if n <= 1 {
        return 1;
    }
    n.checked_next_power_of_two()
        .expect("bucket_size overflow: requested byte length too large")
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bucket_size_powers() {
        // Table of (requested bytes, expected power-of-two bucket).
        let cases = [
            (0, 1),
            (1, 1),
            (2, 2),
            (3, 4),
            (4, 4),
            (5, 8),
            (1023, 1024),
            (1024, 1024),
            (1025, 2048),
        ];
        for &(n, expected) in cases.iter() {
            assert_eq!(bucket_size(n), expected, "bucket_size({})", n);
        }
    }

    #[test]
    fn test_pool_arena_reset_recycles_in_use() {
        // Arena pattern: alloc several buffers, let them drop, reset, then
        // alloc again and verify the same Metal storage comes back (compared
        // via contents pointers of the ARC-cloned buffers).
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        // Cycle 1: three allocations across two buckets; the MlxBuffers fall
        // out of scope at the end of the inner block.
        let (first, second, third) = {
            let a = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc a");
            let b = pool.alloc(&device, 2048, DType::F32, vec![512]).expect("alloc b");
            let c = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc c");
            (a.contents_ptr(), b.contents_ptr(), c.contents_ptr())
        };
        assert_eq!(pool.in_use_count(), 3);
        assert_eq!(pool.free_count(), 0);

        // Reset recycles all three into the free list.
        pool.reset();
        assert_eq!(pool.in_use_count(), 0);
        assert_eq!(pool.free_count(), 3);

        // Cycle 2: compatible-bucket requests must reuse cycle-1 storage.
        let reused_small = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc d");
        let reused_large = pool.alloc(&device, 2048, DType::F32, vec![512]).expect("alloc e");
        let small_ptr = reused_small.contents_ptr();
        let large_ptr = reused_large.contents_ptr();

        // The 1024-bucket request must come back as a or c; the 2048-bucket
        // request can only be b.
        assert!(
            small_ptr == first || small_ptr == third,
            "buf_d {:?} must reuse one of a {:?} / c {:?}",
            small_ptr, first, third,
        );
        assert_eq!(large_ptr, second, "buf_e must reuse b (only 2048-bucket buffer)");

        // One 1024-bucket buffer remains free; two buffers are in use.
        assert_eq!(pool.in_use_count(), 2);
        assert_eq!(pool.free_count(), 1);
    }

    #[test]
    fn test_pool_reset_with_no_alloc_is_idempotent() {
        // A reset on a fresh pool must change nothing...
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();
        pool.reset();
        assert_eq!(pool.in_use_count(), 0);
        assert_eq!(pool.free_count(), 0);
        // ...and repeated resets without intervening allocs stay no-ops.
        pool.reset();
        pool.reset();
        assert_eq!(pool.in_use_count(), 0);
    }

    #[test]
    fn test_pool_release_remains_supported_for_compat() {
        // Per-buffer release() still works alongside the arena API.  Mixing
        // release + reset within one cycle is documented as unsupported (a
        // duplicate ARC clone would land in free); here we only verify the
        // supported half: release feeds free, and alloc drains free first.
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        let handed_out = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc");
        assert_eq!(pool.in_use_count(), 1);
        pool.release(handed_out);
        // release() intentionally leaves the in_use entry behind; per the
        // documented contract the two return paths are never mixed.
        assert_eq!(pool.free_count(), 1);
        assert_eq!(pool.in_use_count(), 1);

        // A fresh alloc pulls from the free list before touching Metal.
        let _second = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc 2");
        assert_eq!(pool.free_count(), 0);
        assert_eq!(pool.in_use_count(), 2);
    }
}