use std::ops::{Deref, DerefMut};
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use rten_gemm::{PackedAMatrix, PackedBMatrix};
use rten_tensor::storage::{Alloc, CowData};
use rten_tensor::{Contiguous, Layout, TensorBase};
/// Type-erased buffer that owns the allocation of a cleared `Vec<T>`.
///
/// Created via [`Buffer::from_vec`], which drops the vector's elements and
/// keeps only the raw allocation, so the buffer can later be reused as a
/// `Vec` of any type with a matching array layout.
pub struct Buffer {
    /// Pointer to the start of the allocation, taken from the source
    /// `Vec`'s `as_mut_ptr` (may be dangling for zero-capacity/ZST vecs,
    /// as `Vec` itself allows).
    ptr: *mut u8,
    /// Capacity of the source `Vec`, in elements of the original `T`.
    capacity: usize,
    /// Allocation layout, i.e. `Layout::array::<T>(capacity)` for the
    /// original `T`. Used to check layout compatibility on reuse.
    layout: std::alloc::Layout,
    /// Monomorphized release function (`Buffer::release::<T>`) invoked by
    /// `Drop` to free the allocation without knowing `T` here.
    drop: fn(&mut Buffer),
}
impl Buffer {
    /// Take ownership of `vec`'s allocation, dropping its elements.
    fn from_vec<T>(mut vec: Vec<T>) -> Buffer {
        // Record the layout before erasing the element type. `unwrap` is OK:
        // the Vec was successfully allocated with this exact layout.
        let layout = std::alloc::Layout::array::<T>(vec.capacity()).unwrap();
        // Drop the elements now; only the raw allocation is retained, so
        // `release`/`into_vec` can safely rebuild a length-0 Vec later.
        vec.clear();
        // Prevent the Vec from freeing the allocation we are about to steal.
        let mut vec_md = std::mem::ManuallyDrop::new(vec);
        Buffer {
            ptr: vec_md.as_mut_ptr() as *mut u8,
            capacity: vec_md.capacity(),
            layout,
            drop: Buffer::release::<T>,
        }
    }

    /// Return true if this buffer can be reused as a `Vec<T>` with at least
    /// `capacity` elements (layout-compatible and large enough).
    fn can_fit<T>(&self, capacity: usize) -> bool {
        self.layout_match::<T>() && self.capacity >= capacity
    }

    /// Convert back into an empty `Vec<T>`, or `None` (dropping the buffer
    /// normally) if `T`'s array layout doesn't match the allocation.
    fn into_vec<T>(self) -> Option<Vec<T>> {
        if !self.layout_match::<T>() {
            return None;
        }
        // SAFETY: `layout_match` verified that `Layout::array::<T>(capacity)`
        // equals the layout this allocation was created with, and a length of
        // 0 is always valid. Ownership of the allocation moves into `vec`.
        let vec = unsafe { Vec::from_raw_parts(self.ptr as *mut T, 0, self.capacity) };
        // Skip `Drop for Buffer`, which would otherwise free the allocation
        // we just handed to `vec` (double free).
        std::mem::forget(self);
        Some(vec)
    }

    /// Return true if an array of `self.capacity` `T`s has the same layout
    /// (size and alignment) as this buffer's allocation. The `Err` from
    /// `Layout::array` (overflow) is treated as a mismatch.
    fn layout_match<T>(&self) -> bool {
        std::alloc::Layout::array::<T>(self.capacity)
            .map(|layout| layout == self.layout)
            .unwrap_or(false)
    }

    /// Release function stored in the `drop` field: frees the allocation by
    /// rebuilding the original (empty) `Vec<T>` and dropping it.
    fn release<T>(this: &mut Buffer) {
        // SAFETY: `release::<T>` is only ever installed by `from_vec::<T>`,
        // so `ptr`/`capacity` describe a live allocation made for `T`, and
        // the length is 0 because `from_vec` cleared the vector.
        let vec = unsafe { Vec::<T>::from_raw_parts(this.ptr as *mut T, 0, this.capacity) };
        std::mem::drop(vec);
    }
}
// SAFETY: `Buffer` exclusively owns its allocation (the source `Vec` was
// consumed and its elements dropped in `from_vec`), and exposes no interior
// mutability, so moving or sharing it across threads is sound.
unsafe impl Send for Buffer {}
unsafe impl Sync for Buffer {}
impl Drop for Buffer {
    /// Free the allocation via the monomorphized release function that
    /// `Buffer::from_vec` stored for the original element type.
    fn drop(&mut self) {
        let release = self.drop;
        release(self);
    }
}
impl<T> From<Vec<T>> for Buffer {
fn from(val: Vec<T>) -> Buffer {
Self::from_vec(val)
}
}
/// Pool of reusable, type-erased buffers with allocation statistics.
pub struct BufferPool {
    /// Buffers currently available for reuse.
    buffers: Mutex<Vec<Buffer>>,
    /// Number of pool-eligible allocation requests (requests below
    /// `min_size` bypass the pool and are not counted — see `alloc`).
    alloc_count: AtomicUsize,
    /// Number of requests satisfied by reusing a pooled buffer.
    hit_count: AtomicUsize,
    /// Requests/buffers smaller than this many bytes bypass the pool.
    min_size: usize,
}
impl BufferPool {
    /// Create an empty pool with the default minimum buffer size (128 bytes).
    pub fn new() -> BufferPool {
        BufferPool {
            buffers: Mutex::new(Vec::new()),
            alloc_count: AtomicUsize::new(0),
            hit_count: AtomicUsize::new(0),
            min_size: 128,
        }
    }

    /// Set the minimum size in bytes below which allocation requests bypass
    /// the pool and returned buffers are discarded by [`BufferPool::add`].
    pub fn with_min_size(mut self, n_bytes: usize) -> Self {
        self.min_size = n_bytes;
        self
    }

    /// Allocate an empty `Vec<T>` with capacity of at least `capacity`,
    /// reusing a pooled buffer with a compatible layout when possible.
    ///
    /// Requests smaller than the pool's minimum size are served directly by
    /// the system allocator and do not affect `alloc_count`/`hit_count`.
    pub fn alloc<T>(&self, capacity: usize) -> Vec<T> {
        // Small allocations are cheap; skip pool bookkeeping entirely.
        // `saturating_mul` avoids overflow (debug panic / release wrap) for
        // pathological capacities; saturated requests simply go to the pool.
        if capacity.saturating_mul(size_of::<T>()) < self.min_size {
            return Vec::with_capacity(capacity);
        }
        self.alloc_count.fetch_add(1, Ordering::AcqRel);

        let mut buffers = self.buffers.lock().unwrap();

        // Best fit: the smallest pooled buffer whose layout is compatible
        // with `T` and whose capacity is sufficient. On ties the first
        // candidate wins (`min_by_key` returns the first minimum).
        let best_fit = buffers
            .iter()
            .enumerate()
            .filter(|(_, buffer)| buffer.can_fit::<T>(capacity))
            .min_by_key(|(_, buffer)| buffer.capacity)
            .map(|(idx, _)| idx);

        if let Some(best_fit) = best_fit {
            self.hit_count.fetch_add(1, Ordering::AcqRel);
            let item = buffers.remove(best_fit);
            return item.into_vec::<T>().expect("alignment should match");
        }

        // Release the lock before performing a fresh allocation.
        std::mem::drop(buffers);
        Vec::with_capacity(capacity)
    }

    /// Return a buffer to the pool, unless it is below the minimum size.
    pub fn add<B: Into<Buffer>>(&self, buf: B) {
        let buf: Buffer = buf.into();
        if buf.layout.size() >= self.min_size {
            self.buffers.lock().unwrap().push(buf);
        }
    }

    /// Total number of pool-eligible allocation requests so far.
    pub fn alloc_count(&self) -> usize {
        self.alloc_count.load(Ordering::Acquire)
    }

    /// Number of allocation requests that were satisfied from the pool.
    pub fn hit_count(&self) -> usize {
        self.hit_count.load(Ordering::Acquire)
    }

    /// Number of buffers currently held by the pool.
    pub fn len(&self) -> usize {
        self.buffers.lock().unwrap().len()
    }

    /// Return true if the pool currently holds no buffers.
    pub fn is_empty(&self) -> bool {
        self.buffers.lock().unwrap().is_empty()
    }
}
impl Alloc for BufferPool {
fn alloc<T>(&self, capacity: usize) -> Vec<T> {
self.alloc(capacity)
}
}
impl Default for BufferPool {
fn default() -> Self {
Self::new()
}
}
/// Types from which a [`Buffer`] can be recovered for returning to a pool.
pub trait ExtractBuffer {
    /// Consume `self` and return its backing buffer, if it owns one.
    fn extract_buffer(self) -> Option<Buffer>;
}
impl<T> ExtractBuffer for Vec<T> {
    fn extract_buffer(self) -> Option<Buffer> {
        // A zero-capacity vector owns no allocation worth pooling.
        if self.capacity() == 0 {
            None
        } else {
            Some(self.into())
        }
    }
}
impl<T, L: Layout + Clone> ExtractBuffer for TensorBase<Vec<T>, L> {
fn extract_buffer(self) -> Option<Buffer> {
Some(self.into_non_contiguous_data().into())
}
}
impl<T, L: Layout + Clone> ExtractBuffer for Contiguous<TensorBase<Vec<T>, L>> {
fn extract_buffer(self) -> Option<Buffer> {
Some(self.into_data().into())
}
}
impl<T, L: Layout + Clone> ExtractBuffer for TensorBase<CowData<'_, T>, L> {
    fn extract_buffer(self) -> Option<Buffer> {
        // `into_non_contiguous_data` yields `None` when there is no owned
        // `Vec` to recover — presumably the borrowed `CowData` case.
        let data = self.into_non_contiguous_data()?;
        Some(data.into())
    }
}
impl<T, L: Layout + Clone> ExtractBuffer for Contiguous<TensorBase<CowData<'_, T>, L>> {
    fn extract_buffer(self) -> Option<Buffer> {
        // `into_data` yields `None` when no owned `Vec` can be recovered —
        // presumably the borrowed `CowData` case.
        let data = self.into_data()?;
        Some(data.into())
    }
}
impl<T> ExtractBuffer for PackedAMatrix<T> {
    fn extract_buffer(self) -> Option<Buffer> {
        let data = self.into_vec();
        Some(data.into())
    }
}
impl<T> ExtractBuffer for PackedBMatrix<T> {
    fn extract_buffer(self) -> Option<Buffer> {
        let data = self.into_vec();
        Some(data.into())
    }
}
/// Extension trait for wrapping a value in a [`PoolRef`], so that its
/// backing buffer is returned to a [`BufferPool`] when dropped.
pub trait AutoReturn {
    /// Wrap `self` in a [`PoolRef`] tied to `pool`.
    fn auto_return(self, pool: &BufferPool) -> PoolRef<'_, Self>
    where
        Self: Sized + ExtractBuffer;
}
impl<EB: ExtractBuffer> AutoReturn for EB {
fn auto_return(self, pool: &BufferPool) -> PoolRef<'_, EB> {
PoolRef::new(pool, self)
}
}
/// Guard that returns the wrapped value's buffer to a [`BufferPool`] when
/// dropped, unless the value is first removed with [`PoolRef::take`].
pub struct PoolRef<'a, T: ExtractBuffer> {
    /// Pool that receives the extracted buffer on drop.
    pool: &'a BufferPool,
    /// The wrapped value. `None` only after `take` or during drop.
    container: Option<T>,
}
impl<'a, T: ExtractBuffer> PoolRef<'a, T> {
    /// Create a guard that returns `tensor`'s buffer to `pool` on drop.
    pub fn new(pool: &'a BufferPool, tensor: T) -> Self {
        let container = Some(tensor);
        PoolRef { pool, container }
    }

    /// Extract the wrapped value, skipping the return-to-pool on drop.
    pub fn take(mut self) -> T {
        let inner = self.container.take();
        inner.unwrap()
    }
}
impl<T: ExtractBuffer> Deref for PoolRef<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // `container` is always `Some` until `take` consumes the guard.
        let inner = self.container.as_ref();
        inner.unwrap()
    }
}
impl<T: ExtractBuffer> DerefMut for PoolRef<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // `container` is always `Some` until `take` consumes the guard.
        let inner = self.container.as_mut();
        inner.unwrap()
    }
}
impl<T: ExtractBuffer> Drop for PoolRef<'_, T> {
    /// Return the wrapped value's buffer to the pool, if it still owns one.
    fn drop(&mut self) {
        if let Some(container) = self.container.take() {
            if let Some(buffer) = container.extract_buffer() {
                self.pool.add(buffer);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{AutoReturn, Buffer, BufferPool, ExtractBuffer};
    use rten_tensor::NdTensor;
    use rten_tensor::prelude::*;

    #[test]
    fn test_buffer() {
        // Round-trip through `Buffer` to a different element type with the
        // same size and alignment (i32 -> f32) preserves the allocation and
        // yields an empty vec.
        let vec = vec![1i32, 2, 3];
        let cap = vec.capacity();
        let buf = Buffer::from_vec(vec);
        let new_vec: Vec<f32> = buf.into_vec().unwrap();
        assert_eq!(new_vec.capacity(), cap); assert_eq!(new_vec.len(), 0);
        // Conversion to a larger element type (f32 -> i64) must fail the
        // layout check.
        let buf = Buffer::from_vec(new_vec);
        let new_vec: Option<Vec<i64>> = buf.into_vec();
        assert!(new_vec.is_none());
        // Conversion to a smaller element type (i32 -> u8) must also fail,
        // since `Layout::array::<u8>(capacity)` differs from the original.
        let vec = vec![1i32, 2, 3];
        let buf = Buffer::from_vec(vec);
        let new_vec: Option<Vec<u8>> = buf.into_vec();
        assert!(new_vec.is_none());
    }

    #[test]
    fn test_empty_buffer() {
        // A buffer from a zero-capacity vec must drop cleanly.
        let buf = Buffer::from_vec(Vec::<i32>::new());
        std::mem::drop(buf);
    }

    #[test]
    fn test_zst_buffer() {
        // Zero-sized element types have no real allocation; round-trip and
        // drop must both work without touching memory.
        let vec = vec![(), ()];
        let cap = vec.capacity();
        let buf = Buffer::from_vec(vec);
        let new_vec: Vec<()> = buf.into_vec().unwrap();
        assert_eq!(new_vec.len(), 0);
        assert_eq!(new_vec.capacity(), cap);
        let buf = Buffer::from_vec(new_vec);
        std::mem::drop(buf);
    }

    #[test]
    fn test_pool_alloc_tensor() {
        let pool = BufferPool::new().with_min_size(0);

        // First allocation: pool is empty, so this is a miss.
        let tensor = NdTensor::<f32, 2>::zeros_in(&pool, [2, 2]);
        assert_eq!(tensor.shape(), [2, 2]);
        assert_eq!(pool.alloc_count(), 1);
        assert_eq!(pool.hit_count(), 0);
        let ptr = tensor.data().unwrap().as_ptr();
        pool.add(tensor.extract_buffer().unwrap());

        // Same shape and dtype: the returned buffer is reused (hit), and the
        // data pointer proves it is the same allocation.
        let tensor = NdTensor::<f32, 2>::zeros_in(&pool, [2, 2]);
        assert_eq!(tensor.shape(), [2, 2]);
        assert_eq!(pool.alloc_count(), 2);
        assert_eq!(pool.hit_count(), 1);
        assert_eq!(tensor.data().unwrap().as_ptr(), ptr);
        pool.add(tensor.extract_buffer().unwrap());

        // Smaller tensor of the same dtype: still a hit (buffer can_fit).
        let tensor = NdTensor::<f32, 2>::zeros_in(&pool, [2, 1]);
        assert_eq!(tensor.shape(), [2, 1]);
        assert_eq!(pool.alloc_count(), 3);
        assert_eq!(pool.hit_count(), 2);
        pool.add(tensor.extract_buffer().unwrap());

        // Larger tensor: pooled buffer is too small, so this is a miss.
        let tensor = NdTensor::<f32, 2>::zeros_in(&pool, [2, 3]);
        assert_eq!(tensor.shape(), [2, 3]);
        assert_eq!(pool.alloc_count(), 4);
        assert_eq!(pool.hit_count(), 2);

        // Different dtype (i8): layout mismatch with pooled f32 buffer,
        // so a miss...
        let int_tensor = NdTensor::<i8, 2>::zeros_in(&pool, [2, 2]);
        assert_eq!(int_tensor.shape(), [2, 2]);
        assert_eq!(pool.alloc_count(), 5);
        assert_eq!(pool.hit_count(), 2);
        pool.add(int_tensor.extract_buffer().unwrap());

        // ...but once an i8 buffer has been returned, it is reused (hit).
        let int_tensor = NdTensor::<i8, 2>::zeros_in(&pool, [2, 2]);
        assert_eq!(int_tensor.shape(), [2, 2]);
        assert_eq!(pool.alloc_count(), 6);
        assert_eq!(pool.hit_count(), 3);
    }

    #[test]
    fn test_pool_alloc() {
        let pool = BufferPool::new().with_min_size(0);

        // Fresh allocation: miss.
        let vec = pool.alloc::<f32>(128);
        assert_eq!(vec.capacity(), 128);
        assert_eq!(vec.len(), 0);
        assert_eq!(pool.alloc_count(), 1);
        assert_eq!(pool.hit_count(), 0);
        pool.add(vec);

        // Smaller request is satisfied by the larger pooled buffer; the
        // returned vec keeps the pooled capacity (128), not the requested 64.
        let vec = pool.alloc::<f32>(64);
        assert_eq!(vec.capacity(), 128);
        assert_eq!(vec.len(), 0);
        assert_eq!(pool.alloc_count(), 2);
        assert_eq!(pool.hit_count(), 1);
    }

    #[test]
    fn test_pool_alloc_small() {
        // min_size of 18 bytes: 4 f32s (16 bytes) bypass the pool entirely
        // (no counter updates), 5 f32s (20 bytes) go through the pool.
        let pool = BufferPool::new().with_min_size(18);
        let vec = pool.alloc::<f32>(4);
        assert_eq!(vec.capacity(), 4);
        assert_eq!(vec.len(), 0);
        assert_eq!(pool.alloc_count(), 0);
        assert_eq!(pool.hit_count(), 0);
        let vec = pool.alloc::<f32>(5);
        assert_eq!(vec.capacity(), 5);
        assert_eq!(vec.len(), 0);
        assert_eq!(pool.alloc_count(), 1);
        assert_eq!(pool.hit_count(), 0);
    }

    #[test]
    fn test_pool_add_small() {
        // Buffers below min_size are discarded by `add` rather than pooled.
        let pool = BufferPool::new().with_min_size(18);
        pool.add(vec![0.0f32; 4]);
        assert_eq!(pool.len(), 0);
        pool.add(vec![0.0f32; 5]);
        assert_eq!(pool.len(), 1);
    }

    #[test]
    fn test_pool_alloc_zst() {
        // Vecs of zero-sized types report capacity `usize::MAX`; pooling
        // must still round-trip them and record a hit on reuse.
        let pool = BufferPool::new().with_min_size(0);
        let vec = pool.alloc::<()>(128);
        assert_eq!(vec.capacity(), usize::MAX);
        pool.add(vec);
        let vec = pool.alloc::<()>(512);
        assert_eq!(vec.capacity(), usize::MAX);
        pool.add(vec);
        assert_eq!(pool.alloc_count(), 2);
        assert_eq!(pool.hit_count(), 1);
    }

    #[test]
    fn test_pool_alloc_non_copy_type() {
        // Elements with a `Drop` impl (String) are dropped when the vec is
        // added to the pool; reuse yields an empty vec with the same
        // allocation (same pointer and capacity).
        let pool = BufferPool::new().with_min_size(0);
        let mut vec = pool.alloc::<String>(5);
        vec.push("hello".into());
        vec.push("world".into());
        let ptr = vec.as_ptr();
        pool.add(vec);
        let vec = pool.alloc::<String>(3);
        assert_eq!(vec.as_ptr(), ptr);
        assert_eq!(vec.capacity(), 5);
        assert_eq!(vec.len(), 0);
    }

    #[test]
    fn test_pool_ref_auto_return() {
        let pool = BufferPool::new().with_min_size(0);
        assert_eq!(pool.len(), 0);

        // Wrapping in a PoolRef does not return the buffer yet.
        let tensor = NdTensor::<f32, 2>::zeros_in(&pool, [2, 2]).auto_return(&pool);
        assert_eq!(tensor.shape(), [2, 2]);
        assert_eq!(pool.alloc_count(), 1);
        assert_eq!(pool.len(), 0);

        // A copy of an already-contiguous tensor borrows the data (CowData),
        // so dropping it returns nothing to the pool.
        let copy = tensor.to_contiguous_in(&pool).auto_return(&pool);
        std::mem::drop(copy);
        assert_eq!(pool.alloc_count(), 1);
        assert_eq!(pool.len(), 0);

        // A transposed view forces an owned copy, which IS returned on drop.
        let copy = tensor
            .transposed()
            .to_contiguous_in(&pool)
            .auto_return(&pool);
        std::mem::drop(copy);
        assert_eq!(pool.alloc_count(), 2);
        assert_eq!(pool.len(), 1);

        // Dropping the original PoolRef returns its buffer too.
        std::mem::drop(tensor);
        assert_eq!(pool.len(), 2);

        // A plain Vec with capacity is returned on drop...
        let non_empty = Vec::<f32>::with_capacity(16).auto_return(&pool);
        std::mem::drop(non_empty);
        assert_eq!(pool.len(), 3);

        // ...but a zero-capacity Vec has no buffer to return.
        let empty = Vec::<f32>::new().auto_return(&pool);
        std::mem::drop(empty);
        assert_eq!(pool.len(), 3);
    }

    #[test]
    fn test_pool_ref_take() {
        // `take` extracts the value, so nothing is returned to the pool
        // when the guard would otherwise be dropped at end of scope.
        let pool = BufferPool::new().with_min_size(0);
        assert_eq!(pool.len(), 0);
        {
            let tensor = NdTensor::<f32, 2>::zeros_in(&pool, [2, 2]).auto_return(&pool);
            assert_eq!(tensor.shape(), [2, 2]);
            tensor.take(); }
        assert_eq!(pool.len(), 0);
    }
}