Skip to main content

mlx_native/
buffer_pool.rs

1//! [`MlxBufferPool`] — arena-style GPU buffer allocator with reuse.
2//!
3//! Buffers are bucketed by power-of-two sizes.  When a buffer is released back
4//! to the pool, it is added to the free list for its size bucket.  A subsequent
5//! `alloc` call will reuse a free buffer of compatible (>= requested) size
6//! rather than allocating new Metal memory.
7//!
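//! The pool does not track buffers it has handed out: each one must be passed
//! back through [`release`](MlxBufferPool::release) to become reusable.
//! Calling [`clear`](MlxBufferPool::clear) drops every buffer currently on the
//! free list; buffers still held by callers are unaffected.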
11
12use std::collections::HashMap;
13
14use crate::buffer::MlxBuffer;
15use crate::device::MlxDevice;
16use crate::dtypes::DType;
17use crate::error::Result;
18
19/// Arena-style buffer pool that reuses Metal buffer allocations.
20///
21/// # Design
22///
23/// * Buffers are bucketed by their allocated size rounded up to the nearest
24///   power of two.  This reduces fragmentation at the cost of occasionally
25///   over-allocating by up to 2x.
26/// * `release()` returns a single buffer; `reset()` returns all outstanding
27///   buffers.
28/// * The pool holds a reference to the `MlxDevice` so it can allocate fresh
29///   buffers when the free list is empty.
30pub struct MlxBufferPool<'d> {
31    device: &'d MlxDevice,
32    /// Free buffers keyed by their power-of-two bucket size.
33    free: HashMap<usize, Vec<metal::Buffer>>,
34}
35
36impl<'d> MlxBufferPool<'d> {
37    /// Create a new empty buffer pool backed by the given device.
38    pub fn new(device: &'d MlxDevice) -> Self {
39        Self {
40            device,
41            free: HashMap::new(),
42        }
43    }
44
45    /// Allocate a buffer from the pool.
46    ///
47    /// If a free buffer of compatible size exists in the pool, it is reused
48    /// (with updated dtype/shape metadata).  Otherwise a new Metal buffer is
49    /// allocated from the device.
50    ///
51    /// The actual Metal buffer size will be rounded up to the nearest power of
52    /// two for bucketing purposes.
53    pub fn alloc(
54        &mut self,
55        byte_len: usize,
56        dtype: DType,
57        shape: Vec<usize>,
58    ) -> Result<MlxBuffer> {
59        let bucket = bucket_size(byte_len);
60
61        // Try to reuse a free buffer from this bucket.
62        if let Some(free_list) = self.free.get_mut(&bucket) {
63            if let Some(metal_buf) = free_list.pop() {
64                let mut buf = MlxBuffer::from_raw(metal_buf, dtype, shape);
65                // The reused buffer may have stale metadata; reshape it.
66                // byte_len is <= bucket, so the Metal buffer is large enough.
67                let _ = &mut buf; // reshape is handled by from_raw above
68                return Ok(buf);
69            }
70        }
71
72        // No free buffer available — allocate a fresh one at the bucket size
73        // (so future reuse is possible for any request up to this bucket).
74        self.device.alloc_buffer(bucket, dtype, shape)
75    }
76
77    /// Return a buffer to the pool's free list for future reuse.
78    ///
79    /// The Metal memory is **not** deallocated — it stays resident on the GPU
80    /// for fast reuse.
81    pub fn release(&mut self, buffer: MlxBuffer) {
82        let bucket = bucket_size(buffer.byte_len());
83        let metal_buf = buffer.into_inner();
84        self.free.entry(bucket).or_default().push(metal_buf);
85    }
86
87    /// Return all free buffers' count (for diagnostics).
88    pub fn free_count(&self) -> usize {
89        self.free.values().map(|v| v.len()).sum()
90    }
91
92    /// Total number of bytes held in the free list.
93    pub fn free_bytes(&self) -> usize {
94        self.free
95            .iter()
96            .map(|(&bucket, bufs)| bucket * bufs.len())
97            .sum()
98    }
99
100    /// Clear all free buffers, releasing Metal memory.
101    pub fn clear(&mut self) {
102        self.free.clear();
103    }
104}
105
/// Round `n` up to the nearest power of two.
///
/// Returns 1 for `n == 0` (though callers should never request 0 bytes).
///
/// If no power of two `>= n` is representable in `usize`, `n` itself is
/// returned, so the result is never smaller than the request.  A plain
/// `next_power_of_two` would panic in debug builds and wrap to 0 in release
/// builds for such inputs.
fn bucket_size(n: usize) -> usize {
    // `checked_next_power_of_two` already maps 0 to 1, so no special case.
    n.checked_next_power_of_two().unwrap_or(n)
}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Each request must land in the smallest power-of-two bucket that holds
    /// it, with 0 and 1 both mapping to the 1-byte bucket.
    #[test]
    fn test_bucket_size_powers() {
        // (requested bytes, expected bucket)
        let cases: [(usize, usize); 9] = [
            (0, 1),
            (1, 1),
            (2, 2),
            (3, 4),
            (4, 4),
            (5, 8),
            (1023, 1024),
            (1024, 1024),
            (1025, 2048),
        ];

        for &(requested, expected) in cases.iter() {
            assert_eq!(bucket_size(requested), expected);
        }
    }
}