use std::collections::HashMap;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::error::{MlxError, Result};
/// A size-bucketed pool of Metal buffers with arena-style recycling.
///
/// Requests are rounded up to a power-of-two bucket (see `bucket_size`) and are
/// served from that bucket's free list when possible. Otherwise a new
/// `StorageModeShared` buffer of exactly the bucket size is allocated on the device.
/// Every buffer handed out by [`MlxBufferPool::alloc`] is also recorded as in use.
/// [`MlxBufferPool::reset`] moves all of them back to the free lists at once.
///
/// Invariant: a given Metal buffer appears at most once across the free lists.
/// Without it, a buffer returned through both [`MlxBufferPool::release`] and
/// [`MlxBufferPool::reset`] could later be handed to two callers simultaneously.
pub struct MlxBufferPool {
    /// Idle buffers, keyed by their power-of-two bucket size in bytes.
    free: HashMap<usize, Vec<metal::Buffer>>,
    /// `(bucket, buffer)` for every `alloc` since the last `reset`. Each entry keeps
    /// a clone of the buffer handle so `reset` can recycle it. `release` does not
    /// remove entries from here.
    in_use: Vec<(usize, metal::Buffer)>,
}

impl Default for MlxBufferPool {
    /// Equivalent to [`MlxBufferPool::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl MlxBufferPool {
    /// Create an empty pool with no idle and no in-use buffers.
    pub fn new() -> Self {
        Self {
            free: HashMap::new(),
            in_use: Vec::new(),
        }
    }

    /// Hand out a buffer of at least `byte_len` bytes, tagged with `dtype` and `shape`.
    ///
    /// The request is rounded up to its power-of-two bucket. An idle buffer from that
    /// bucket is reused if one exists; otherwise a new shared-storage buffer of the
    /// bucket size is allocated on `device`. Either way, the buffer is recorded as in
    /// use until the next [`reset`](Self::reset).
    ///
    /// # Errors
    ///
    /// Returns [`MlxError::BufferAllocationError`] in two cases:
    /// - `byte_len` is too large to round up to a power of two;
    /// - a newly created buffer reports a null contents pointer.
    pub fn alloc(
        &mut self,
        device: &MlxDevice,
        byte_len: usize,
        dtype: DType,
        shape: Vec<usize>,
    ) -> Result<MlxBuffer> {
        // `usize::next_power_of_two` (used by `bucket_size`) panics in debug builds
        // and wraps to 0 in release builds when the result does not fit. Reject
        // such requests instead of silently using a zero-sized bucket.
        if byte_len.checked_next_power_of_two().is_none() {
            return Err(MlxError::BufferAllocationError { bytes: byte_len });
        }
        let bucket = bucket_size(byte_len);

        let recycled = self.free.get_mut(&bucket).and_then(|list| list.pop());
        let metal_buf = match recycled {
            Some(buf) => buf,
            None => {
                let raw = device.metal_device().new_buffer(
                    bucket as u64,
                    metal::MTLResourceOptions::StorageModeShared,
                );
                // Treat a buffer without CPU-visible contents as a failed allocation.
                if raw.contents().is_null() {
                    return Err(MlxError::BufferAllocationError { bytes: bucket });
                }
                raw
            }
        };
        self.in_use.push((bucket, metal_buf.clone()));
        Ok(MlxBuffer::from_raw(metal_buf, dtype, shape))
    }

    /// Return `buffer` to the free list so a later [`alloc`](Self::alloc) can reuse it.
    ///
    /// Kept for compatibility with callers that release buffers individually. The
    /// buffer's `in_use` entry is left in place, so
    /// [`in_use_count`](Self::in_use_count) does not change; the next
    /// [`reset`](Self::reset) reconciles it without listing the buffer twice.
    ///
    /// If the buffer was handed out since the last reset, the bucket recorded at
    /// allocation time is used. Otherwise the bucket is derived from
    /// `buffer.byte_len()`. The in-use lookup is a linear scan.
    pub fn release(&mut self, buffer: MlxBuffer) {
        let fallback_bucket = bucket_size(buffer.byte_len());
        let metal_buf = buffer.into_inner();
        let bucket = self
            .in_use
            .iter()
            .find(|(_, tracked)| std::ptr::eq(&**tracked, &*metal_buf))
            .map_or(fallback_bucket, |&(recorded, _)| recorded);
        Self::push_free(&mut self.free, bucket, metal_buf);
    }

    /// Move every buffer handed out since the last reset back onto the free lists.
    ///
    /// Any buffer previously returned by [`alloc`](Self::alloc) may be handed out
    /// again afterwards, so callers should stop using them. Buffers already
    /// returned through [`release`](Self::release) are not listed a second time.
    /// With nothing in use this is a no-op.
    pub fn reset(&mut self) {
        for (bucket, metal_buf) in self.in_use.drain(..) {
            Self::push_free(&mut self.free, bucket, metal_buf);
        }
    }

    /// Append `metal_buf` to the free list for `bucket`, unless that same buffer
    /// object is already listed there.
    ///
    /// This upholds the "at most once on the free lists" invariant. Identity is the
    /// address of the underlying buffer object, and the check is a linear scan of
    /// the bucket.
    fn push_free(
        free: &mut HashMap<usize, Vec<metal::Buffer>>,
        bucket: usize,
        metal_buf: metal::Buffer,
    ) {
        let free_list = free.entry(bucket).or_default();
        let already_listed = free_list
            .iter()
            .any(|listed| std::ptr::eq(&**listed, &*metal_buf));
        if !already_listed {
            free_list.push(metal_buf);
        }
    }

    /// Number of idle buffers across all buckets.
    pub fn free_count(&self) -> usize {
        self.free.values().map(|list| list.len()).sum()
    }

    /// Total size in bytes of all idle buffers (bucket size × buffers in that bucket).
    pub fn free_bytes(&self) -> usize {
        self.free
            .iter()
            .map(|(&bucket, bufs)| bucket * bufs.len())
            .sum()
    }

    /// Number of [`alloc`](Self::alloc) calls since the last [`reset`](Self::reset).
    ///
    /// Buffers given back through [`release`](Self::release) are still counted
    /// until the next reset.
    pub fn in_use_count(&self) -> usize {
        self.in_use.len()
    }

    /// Discard every idle buffer held by the pool.
    ///
    /// In-use records are untouched. Those buffers return to the free lists on the
    /// next [`reset`](Self::reset).
    pub fn clear(&mut self) {
        self.free.clear();
    }
}
/// Round a byte count up to its pool bucket: the smallest power of two that is
/// `>= n`. Both 0 and 1 map to a 1-byte bucket.
fn bucket_size(n: usize) -> usize {
    // Clamping to 1 spells out the zero case explicitly before rounding up.
    n.max(1).next_power_of_two()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `bucket_size` rounds up to the next power of two, with 0 and 1 both mapping to 1.
    #[test]
    fn test_bucket_size_powers() {
        let expectations = [
            (0, 1),
            (1, 1),
            (2, 2),
            (3, 4),
            (4, 4),
            (5, 8),
            (1023, 1024),
            (1024, 1024),
            (1025, 2048),
        ];
        for (requested, bucket) in expectations {
            assert_eq!(bucket_size(requested), bucket);
        }
    }

    /// `reset` moves every outstanding allocation back to the free lists, and
    /// later same-bucket allocations are served from those recycled buffers.
    #[test]
    fn test_pool_arena_reset_recycles_in_use() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        // The handles drop at the end of this scope; only their contents
        // pointers are kept for the reuse checks further down.
        let (ptr_a, ptr_b, ptr_c) = {
            let first = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc a");
            let second = pool.alloc(&device, 2048, DType::F32, vec![512]).expect("alloc b");
            let third = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc c");
            (first.contents_ptr(), second.contents_ptr(), third.contents_ptr())
        };
        assert_eq!(pool.in_use_count(), 3);
        assert_eq!(pool.free_count(), 0);

        pool.reset();
        assert_eq!(pool.in_use_count(), 0);
        assert_eq!(pool.free_count(), 3);

        let small = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc d");
        let large = pool.alloc(&device, 2048, DType::F32, vec![512]).expect("alloc e");
        let (ptr_d, ptr_e) = (small.contents_ptr(), large.contents_ptr());

        let reused_small = ptr_d == ptr_a || ptr_d == ptr_c;
        assert!(
            reused_small,
            "buf_d {:?} must reuse one of a {:?} / c {:?}",
            ptr_d, ptr_a, ptr_c,
        );
        assert_eq!(ptr_e, ptr_b, "buf_e must reuse b (only 2048-bucket buffer)");
        assert_eq!(pool.in_use_count(), 2);
        assert_eq!(pool.free_count(), 1);
    }

    /// Resetting a pool that never allocated leaves it empty, however often it is done.
    #[test]
    fn test_pool_reset_with_no_alloc_is_idempotent() {
        // Unused by the calls below. The `_` prefix silences the lint while still
        // keeping the device alive for the whole test.
        let _device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        pool.reset();
        assert_eq!(pool.in_use_count(), 0);
        assert_eq!(pool.free_count(), 0);

        for _ in 0..2 {
            pool.reset();
        }
        assert_eq!(pool.in_use_count(), 0);
    }

    /// `release` still feeds the free list for older callers. It leaves the
    /// in-use count alone, because only `reset` clears that.
    #[test]
    fn test_pool_release_remains_supported_for_compat() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();

        let handle = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc");
        assert_eq!(pool.in_use_count(), 1);

        pool.release(handle);
        assert_eq!(pool.free_count(), 1);
        assert_eq!(pool.in_use_count(), 1);

        // The released buffer satisfies the next same-bucket request.
        let _second = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc 2");
        assert_eq!(pool.free_count(), 0);
        assert_eq!(pool.in_use_count(), 2);
    }
}