mlx_native/buffer_pool.rs
1//! [`MlxBufferPool`] — arena-style GPU buffer allocator with reuse.
2//!
3//! Buffers are bucketed by power-of-two sizes. When a buffer is released back
4//! to the pool, it is added to the free list for its size bucket. A subsequent
5//! `alloc` call will reuse a free buffer of compatible (>= requested) size
6//! rather than allocating new Metal memory.
7//!
//! The pool keeps no record of buffers that are currently handed out: each one
//! must be given back explicitly with [`release`](MlxBufferPool::release), and
//! [`clear`](MlxBufferPool::clear) drops every buffer held on the free list.
11
12use std::collections::HashMap;
13
14use crate::buffer::MlxBuffer;
15use crate::device::MlxDevice;
16use crate::dtypes::DType;
17use crate::error::Result;
18
19/// Arena-style buffer pool that reuses Metal buffer allocations.
20///
21/// # Design
22///
23/// * Buffers are bucketed by their allocated size rounded up to the nearest
24/// power of two. This reduces fragmentation at the cost of occasionally
25/// over-allocating by up to 2x.
26/// * `release()` returns a single buffer; `reset()` returns all outstanding
27/// buffers.
28/// * The pool holds a reference to the `MlxDevice` so it can allocate fresh
29/// buffers when the free list is empty.
30pub struct MlxBufferPool<'d> {
31 device: &'d MlxDevice,
32 /// Free buffers keyed by their power-of-two bucket size.
33 free: HashMap<usize, Vec<metal::Buffer>>,
34}
35
36impl<'d> MlxBufferPool<'d> {
37 /// Create a new empty buffer pool backed by the given device.
38 pub fn new(device: &'d MlxDevice) -> Self {
39 Self {
40 device,
41 free: HashMap::new(),
42 }
43 }
44
45 /// Allocate a buffer from the pool.
46 ///
47 /// If a free buffer of compatible size exists in the pool, it is reused
48 /// (with updated dtype/shape metadata). Otherwise a new Metal buffer is
49 /// allocated from the device.
50 ///
51 /// The actual Metal buffer size will be rounded up to the nearest power of
52 /// two for bucketing purposes.
53 pub fn alloc(
54 &mut self,
55 byte_len: usize,
56 dtype: DType,
57 shape: Vec<usize>,
58 ) -> Result<MlxBuffer> {
59 let bucket = bucket_size(byte_len);
60
61 // Try to reuse a free buffer from this bucket.
62 if let Some(free_list) = self.free.get_mut(&bucket) {
63 if let Some(metal_buf) = free_list.pop() {
64 let mut buf = MlxBuffer::from_raw(metal_buf, dtype, shape);
65 // The reused buffer may have stale metadata; reshape it.
66 // byte_len is <= bucket, so the Metal buffer is large enough.
67 let _ = &mut buf; // reshape is handled by from_raw above
68 return Ok(buf);
69 }
70 }
71
72 // No free buffer available — allocate a fresh one at the bucket size
73 // (so future reuse is possible for any request up to this bucket).
74 self.device.alloc_buffer(bucket, dtype, shape)
75 }
76
77 /// Return a buffer to the pool's free list for future reuse.
78 ///
79 /// The Metal memory is **not** deallocated — it stays resident on the GPU
80 /// for fast reuse.
81 pub fn release(&mut self, buffer: MlxBuffer) {
82 let bucket = bucket_size(buffer.byte_len());
83 let metal_buf = buffer.into_inner();
84 self.free.entry(bucket).or_default().push(metal_buf);
85 }
86
87 /// Return all free buffers' count (for diagnostics).
88 pub fn free_count(&self) -> usize {
89 self.free.values().map(|v| v.len()).sum()
90 }
91
92 /// Total number of bytes held in the free list.
93 pub fn free_bytes(&self) -> usize {
94 self.free
95 .iter()
96 .map(|(&bucket, bufs)| bucket * bufs.len())
97 .sum()
98 }
99
100 /// Clear all free buffers, releasing Metal memory.
101 pub fn clear(&mut self) {
102 self.free.clear();
103 }
104}
105
/// Round `n` up to the nearest power of two.
///
/// Returns 1 for `n == 0` (though callers should never request 0 bytes).
///
/// If the next power of two does not fit in a `usize`, `n` itself is returned.
/// This keeps the result `>= n`; a plain `next_power_of_two` would panic in
/// debug builds and wrap to 0 in release builds, turning an oversized request
/// into an empty bucket.
fn bucket_size(n: usize) -> usize {
    // `checked_next_power_of_two` already maps both 0 and 1 to 1.
    n.checked_next_power_of_two().unwrap_or(n)
}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `bucket_size` at zero, at exact powers of two, and one step
    /// either side of them.
    #[test]
    fn test_bucket_size_powers() {
        // (requested bytes, expected bucket)
        const CASES: [(usize, usize); 9] = [
            (0, 1),
            (1, 1),
            (2, 2),
            (3, 4),
            (4, 4),
            (5, 8),
            (1023, 1024),
            (1024, 1024),
            (1025, 2048),
        ];

        for &(requested, expected) in &CASES {
            assert_eq!(bucket_size(requested), expected);
        }
    }
}
132}