use std::collections::HashMap;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::error::Result;
/// Size-bucketed cache of Metal buffers that recycles released allocations
/// instead of requesting a fresh buffer from the device on every `alloc`.
///
/// Requests are rounded up to a power-of-two bucket (see `bucket_size`), and
/// new device allocations are made at the full bucket size so that a buffer
/// can be reused by any later request that falls in the same bucket.
pub struct MlxBufferPool<'d> {
    /// Device that services allocations when no cached buffer is available.
    device: &'d MlxDevice,
    /// Cached free buffers keyed by bucket size in bytes. `release` only files
    /// a buffer under a key that is no larger than its `byte_len()`.
    free: HashMap<usize, Vec<metal::Buffer>>,
}

impl<'d> MlxBufferPool<'d> {
    /// Creates an empty pool that allocates new buffers from `device`.
    pub fn new(device: &'d MlxDevice) -> Self {
        Self {
            device,
            free: HashMap::new(),
        }
    }

    /// Returns a buffer for a request of `byte_len` bytes, tagged with `dtype`
    /// and `shape`.
    ///
    /// The request is rounded up to its power-of-two bucket. If a cached
    /// buffer exists in that bucket it is reused. Otherwise a new buffer of
    /// the full bucket size is allocated from the device. The pool itself
    /// does not zero recycled buffers.
    ///
    /// # Errors
    ///
    /// Propagates any error from `MlxDevice::alloc_buffer` when a new
    /// allocation is needed. Reusing a cached buffer cannot fail.
    pub fn alloc(
        &mut self,
        byte_len: usize,
        dtype: DType,
        shape: Vec<usize>,
    ) -> Result<MlxBuffer> {
        let bucket = bucket_size(byte_len);
        if let Some(metal_buf) = self.free.get_mut(&bucket).and_then(Vec::pop) {
            return Ok(MlxBuffer::from_raw(metal_buf, dtype, shape));
        }
        // Cache miss: allocate the whole bucket, not just `byte_len`, so the
        // buffer can serve any request in this bucket once it is released.
        self.device.alloc_buffer(bucket, dtype, shape)
    }

    /// Returns `buffer` to the pool so a later `alloc` can reuse it.
    ///
    /// The buffer is filed under the largest power of two that does not
    /// exceed its `byte_len()`. As a result, a recycled buffer is never
    /// smaller than the bucket it will later be handed out for.
    ///
    /// For a buffer obtained from `alloc`, this is the bucket it came from,
    /// provided its `byte_len()` equals the bucket size it was allocated
    /// with (presumed; confirm against `MlxBuffer::byte_len`).
    ///
    /// Zero-length buffers are dropped rather than cached.
    pub fn release(&mut self, buffer: MlxBuffer) {
        let len = buffer.byte_len();
        if len == 0 {
            // An empty buffer cannot serve any bucket; just let it drop.
            return;
        }
        // Round *down* here. Rounding up (as `bucket_size` does for requests)
        // would file e.g. a 5-byte buffer under bucket 8, and a later 8-byte
        // `alloc` would be handed a buffer that is too small.
        let bucket = 1usize << (usize::BITS - 1 - len.leading_zeros());
        let metal_buf = buffer.into_inner();
        self.free.entry(bucket).or_default().push(metal_buf);
    }

    /// Number of buffers currently cached across all buckets.
    pub fn free_count(&self) -> usize {
        self.free.values().map(Vec::len).sum()
    }

    /// Total bucket capacity, in bytes, held by cached buffers.
    ///
    /// Each buffer is counted at its bucket size. A buffer that is larger
    /// than the bucket it was filed under is under-counted.
    pub fn free_bytes(&self) -> usize {
        self.free
            .iter()
            .map(|(&bucket, bufs)| bucket * bufs.len())
            .sum()
    }

    /// Drops every cached buffer, emptying the pool.
    pub fn clear(&mut self) {
        self.free.clear();
    }
}
/// Rounds a requested size of `n` bytes up to the pool's bucket size.
///
/// Buckets are powers of two. Requests of 0 or 1 bytes map to the 1-byte
/// bucket. If `n` exceeds the largest power of two representable in `usize`,
/// `n` itself is returned instead of overflowing.
fn bucket_size(n: usize) -> usize {
    // `next_power_of_two` panics on overflow in debug builds and wraps to 0
    // in release builds; the latter would make `alloc` request a zero-length
    // buffer for a huge request. Fall back to the exact size instead.
    // (`checked_next_power_of_two` already maps 0 to 1.)
    n.checked_next_power_of_two().unwrap_or(n)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `bucket_size` maps sizes to the expected power-of-two
    /// buckets, including the 0/1 edge cases and values around 1024.
    #[test]
    fn test_bucket_size_powers() {
        let cases: [(usize, usize); 9] = [
            (0, 1),
            (1, 1),
            (2, 2),
            (3, 4),
            (4, 4),
            (5, 8),
            (1023, 1024),
            (1024, 1024),
            (1025, 2048),
        ];
        for &(input, expected) in &cases {
            // Pair the input with the result so a failure shows which case broke.
            assert_eq!((input, bucket_size(input)), (input, expected));
        }
    }
}