use std::collections::HashMap;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::error::{MlxError, Result};
pub struct MlxBufferPool {
free: HashMap<usize, Vec<metal::Buffer>>,
in_use: Vec<(usize, metal::Buffer)>,
residency_set: Option<crate::residency::ResidencySet>,
resident_buffers: HashMap<usize, metal::Buffer>,
}
impl Default for MlxBufferPool {
fn default() -> Self {
Self::new()
}
}
impl MlxBufferPool {
pub fn new() -> Self {
Self {
free: HashMap::new(),
in_use: Vec::new(),
residency_set: None,
resident_buffers: HashMap::new(),
}
}
pub fn alloc(
&mut self,
device: &MlxDevice,
byte_len: usize,
dtype: DType,
shape: Vec<usize>,
) -> Result<MlxBuffer> {
let (buffer, added_residency) = self.alloc_inner(device, byte_len, dtype, shape)?;
if added_residency {
if let Some(set) = self.residency_set.as_ref() {
set.commit();
}
}
Ok(buffer)
}
pub fn alloc_batch<I>(&mut self, device: &MlxDevice, requests: I) -> Result<Vec<MlxBuffer>>
where
I: IntoIterator<Item = (usize, DType, Vec<usize>)>,
{
let mut buffers = Vec::new();
let mut added_residency = false;
for (byte_len, dtype, shape) in requests {
let (buffer, added) = self.alloc_inner(device, byte_len, dtype, shape)?;
added_residency |= added;
buffers.push(buffer);
}
if added_residency {
if let Some(set) = self.residency_set.as_ref() {
set.commit();
}
}
Ok(buffers)
}
fn alloc_inner(
&mut self,
device: &MlxDevice,
byte_len: usize,
dtype: DType,
shape: Vec<usize>,
) -> Result<(MlxBuffer, bool)> {
let bucket = bucket_size(byte_len);
let mut added_residency = false;
let metal_buf = self
.free
.get_mut(&bucket)
.and_then(|free_list| free_list.pop());
let metal_buf = match metal_buf {
Some(b) => b,
None => {
let raw = device
.metal_device()
.new_buffer(bucket as u64, metal::MTLResourceOptions::StorageModeShared);
if raw.contents().is_null() {
return Err(MlxError::BufferAllocationError { bytes: bucket });
}
added_residency = self.register_residency_allocation(device, &raw)?;
raw
}
};
self.in_use.push((bucket, metal_buf.clone()));
Ok((MlxBuffer::from_raw(metal_buf, dtype, shape), added_residency))
}
pub fn release(&mut self, buffer: MlxBuffer) {
let bucket = bucket_size(buffer.byte_len());
let metal_buf = buffer.into_inner();
self.free.entry(bucket).or_default().push(metal_buf);
}
pub fn reset(&mut self) {
for (bucket, metal_buf) in self.in_use.drain(..) {
self.free.entry(bucket).or_default().push(metal_buf);
}
}
pub fn register_existing(
&mut self,
device: &MlxDevice,
buffer: &MlxBuffer,
) -> Result<()> {
if let Some(buffer_set) = buffer.residency_set() {
let Some(device_set) = device.residency_set() else {
return Err(MlxError::InvalidArgument(
"MlxBuffer is registered with a residency set, but device has none".into(),
));
};
if !buffer_set.same_owner(device_set) {
return Err(MlxError::InvalidArgument(
"MlxBufferPool cannot register a buffer from a different residency-enabled device"
.into(),
));
}
match self.residency_set.as_ref() {
Some(pool_set) if !pool_set.same_owner(device_set) => {
return Err(MlxError::InvalidArgument(
"MlxBufferPool cannot mix residency-enabled devices".into(),
));
}
Some(_) => {}
None => {
self.residency_set = Some(device_set.clone());
}
}
return Ok(());
}
let added = self.register_residency_allocation(device, buffer.metal_buffer())?;
if added {
if let Some(set) = self.residency_set.as_ref() {
set.commit();
}
}
Ok(())
}
pub fn free_count(&self) -> usize {
self.free.values().map(|v| v.len()).sum()
}
pub fn free_bytes(&self) -> usize {
self.free
.iter()
.map(|(&bucket, bufs)| bucket * bufs.len())
.sum()
}
pub fn in_use_count(&self) -> usize {
self.in_use.len()
}
pub fn clear(&mut self) {
let mut removed_any = false;
if let Some(set) = self.residency_set.as_ref() {
for metal_buf in self.free.values().flatten() {
let key = buffer_key(metal_buf);
if let Some(resident_buf) = self.resident_buffers.remove(&key) {
set.remove_allocation(&resident_buf);
removed_any = true;
}
}
if removed_any {
set.commit();
}
}
self.free.clear();
}
fn register_residency_allocation(
&mut self,
device: &MlxDevice,
buffer: &metal::Buffer,
) -> Result<bool> {
let Some(device_set) = device.residency_set() else {
return Ok(false);
};
match self.residency_set.as_ref() {
Some(pool_set) if !pool_set.same_owner(device_set) => {
return Err(MlxError::InvalidArgument(
"MlxBufferPool cannot mix residency-enabled devices".into(),
));
}
Some(_) => {}
None => {
self.residency_set = Some(device_set.clone());
}
}
let key = buffer_key(buffer);
if !self.resident_buffers.contains_key(&key) {
device_set.add_allocation(buffer);
self.resident_buffers.insert(key, buffer.clone());
return Ok(true);
}
Ok(false)
}
fn remove_all_residency_allocations(&mut self) {
let Some(set) = self.residency_set.as_ref() else {
return;
};
if self.resident_buffers.is_empty() {
return;
}
for buffer in self.resident_buffers.values() {
set.remove_allocation(buffer);
}
set.commit();
self.resident_buffers.clear();
}
}
impl Drop for MlxBufferPool {
    /// On teardown, remove every buffer the pool added to its residency set and
    /// commit that change. Buffers still held by callers stay alive through their
    /// own references.
    fn drop(&mut self) {
        self.remove_all_residency_allocations()
    }
}
/// Round a requested byte length up to the pool's bucket size: the next power of two,
/// with a minimum of one byte (so `0` and `1` both map to `1`).
///
/// A request too large to round up without overflowing `usize` keeps its exact size.
/// Such an allocation cannot succeed anyway; this avoids the debug-build panic and
/// release-build wrap to `0` of `usize::next_power_of_two`. Without it, the pool would
/// attempt a zero-byte allocation.
fn bucket_size(n: usize) -> usize {
    n.checked_next_power_of_two().unwrap_or(n)
}
/// Identity key for a Metal buffer: the address returned by its `contents()` pointer.
///
/// Only meaningful while the buffer is alive; the pool keeps clones of registered
/// buffers so their keys cannot be reused.
#[inline]
fn buffer_key(buffer: &metal::Buffer) -> usize {
    let base = buffer.contents();
    base as usize
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets an environment variable for the guard's lifetime.
    ///
    /// On drop, it restores the previous value, or removes the variable if it was unset.
    /// This also happens while unwinding from a failed assertion. It then resets the
    /// residency env cache so later tests observe the restored value.
    struct ResidencyEnvGuard {
        key: &'static str,
        prev: Option<std::ffi::OsString>,
    }

    impl ResidencyEnvGuard {
        fn set(key: &'static str, value: &str) -> Self {
            // `var_os` preserves a previous value even if it is not valid UTF-8.
            let prev = std::env::var_os(key);
            crate::residency::reset_residency_env_cache_for_test();
            std::env::set_var(key, value);
            Self { key, prev }
        }
    }

    impl Drop for ResidencyEnvGuard {
        fn drop(&mut self) {
            match self.prev.take() {
                Some(v) => std::env::set_var(self.key, v),
                None => std::env::remove_var(self.key),
            }
            crate::residency::reset_residency_env_cache_for_test();
        }
    }

    #[test]
    fn test_bucket_size_powers() {
        assert_eq!(bucket_size(0), 1);
        assert_eq!(bucket_size(1), 1);
        assert_eq!(bucket_size(2), 2);
        assert_eq!(bucket_size(3), 4);
        assert_eq!(bucket_size(4), 4);
        assert_eq!(bucket_size(5), 8);
        assert_eq!(bucket_size(1023), 1024);
        assert_eq!(bucket_size(1024), 1024);
        assert_eq!(bucket_size(1025), 2048);
    }

    #[test]
    fn test_pool_arena_reset_recycles_in_use() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();
        let (ptr_a, ptr_b, ptr_c) = {
            let buf_a = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc a");
            let buf_b = pool.alloc(&device, 2048, DType::F32, vec![512]).expect("alloc b");
            let buf_c = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc c");
            (buf_a.contents_ptr(), buf_b.contents_ptr(), buf_c.contents_ptr())
        };
        assert_eq!(pool.in_use_count(), 3);
        assert_eq!(pool.free_count(), 0);
        pool.reset();
        assert_eq!(pool.in_use_count(), 0);
        assert_eq!(pool.free_count(), 3);
        let buf_d = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc d");
        let buf_e = pool.alloc(&device, 2048, DType::F32, vec![512]).expect("alloc e");
        let ptr_d = buf_d.contents_ptr();
        let ptr_e = buf_e.contents_ptr();
        assert!(
            ptr_d == ptr_a || ptr_d == ptr_c,
            "buf_d {:?} must reuse one of a {:?} / c {:?}",
            ptr_d, ptr_a, ptr_c,
        );
        assert_eq!(ptr_e, ptr_b, "buf_e must reuse b (only 2048-bucket buffer)");
        assert_eq!(pool.in_use_count(), 2);
        assert_eq!(pool.free_count(), 1);
    }

    #[test]
    fn test_pool_reset_with_no_alloc_is_idempotent() {
        let mut pool = MlxBufferPool::new();
        pool.reset();
        assert_eq!(pool.in_use_count(), 0);
        assert_eq!(pool.free_count(), 0);
        pool.reset();
        pool.reset();
        assert_eq!(pool.in_use_count(), 0);
    }

    #[test]
    fn test_register_existing_does_not_recycle_on_reset() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();
        let external = device
            .alloc_buffer(4096, DType::U8, vec![4096])
            .expect("alloc external");
        let external_ptr = external.contents_ptr();
        pool.register_existing(&device, &external)
            .expect("register_existing");
        assert_eq!(pool.in_use_count(), 0);
        pool.reset();
        assert_eq!(pool.in_use_count(), 0);
        assert_eq!(pool.free_count(), 0);
        drop(pool);
        assert_eq!(external.contents_ptr(), external_ptr);
        let slice: &[u8] = external.as_slice().expect("slice still valid");
        assert_eq!(slice.len(), 4096);
    }

    #[test]
    fn test_register_existing_idempotent() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();
        let external = device
            .alloc_buffer(2048, DType::U8, vec![2048])
            .expect("alloc external");
        pool.register_existing(&device, &external)
            .expect("register 1");
        pool.register_existing(&device, &external)
            .expect("register 2 (idempotent)");
        drop(pool);
        let _slice: &[u8] = external.as_slice().expect("still valid");
    }

    #[test]
    fn test_register_existing_no_residency_env_is_noop() {
        // The guard restores HF2Q_NO_RESIDENCY and the env cache even if an
        // assertion below panics, so the override cannot leak into other tests.
        // NOTE(review): the variable is process-global, and tests creating devices in
        // parallel can still observe it while this test runs.
        let _env = ResidencyEnvGuard::set("HF2Q_NO_RESIDENCY", "1");
        let device = MlxDevice::new().expect("device");
        assert!(
            !device.residency_sets_enabled(),
            "device should boot without residency under HF2Q_NO_RESIDENCY=1",
        );
        let mut pool = MlxBufferPool::new();
        let external = device
            .alloc_buffer(1024, DType::U8, vec![1024])
            .expect("alloc external");
        pool.register_existing(&device, &external)
            .expect("register_existing under HF2Q_NO_RESIDENCY=1 should succeed");
        assert!(pool.residency_set.is_none());
        assert!(pool.resident_buffers.is_empty());
    }

    #[test]
    fn test_pool_release_remains_supported_for_compat() {
        let device = MlxDevice::new().expect("device");
        let mut pool = MlxBufferPool::new();
        let buf = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc");
        assert_eq!(pool.in_use_count(), 1);
        pool.release(buf);
        assert_eq!(pool.free_count(), 1);
        assert_eq!(pool.in_use_count(), 1);
        let _buf2 = pool.alloc(&device, 1024, DType::F32, vec![256]).expect("alloc 2");
        assert_eq!(pool.free_count(), 0);
        assert_eq!(pool.in_use_count(), 2);
    }
}