#![allow(clippy::result_large_err)]
use crate::{Device, Result, TensorError};
use std::sync::Arc;
/// Backing storage for tensor elements associated with a [`Device`].
///
/// Implementors must be `Send + Sync`; `SharedBuffer` keeps them behind an `Arc`
/// so several handles can share one buffer.
pub trait TensorBuffer: Send + Sync {
/// Element type stored in the buffer.
type Elem: Clone;
/// Device this buffer's storage is associated with.
fn device(&self) -> &Device;
/// Number of elements in the buffer.
fn len(&self) -> usize;
/// Returns `true` when the buffer holds no elements (`len() == 0`).
fn is_empty(&self) -> bool {
self.len() == 0
}
/// Size of the buffer's contents in bytes (`CpuBuffer` reports `len() * size_of::<Elem>()`).
fn size_bytes(&self) -> usize;
/// Returns a copy of this buffer as a new boxed buffer (`CpuBuffer` clones its elements).
fn clone_buffer(&self) -> Result<Box<dyn TensorBuffer<Elem = Self::Elem>>>;
/// Returns a buffer holding elements `offset..offset + len` of this one.
///
/// NOTE(review): whether the result shares or copies storage is implementation-defined;
/// `CpuBuffer` copies the selected elements into a new buffer.
fn view(&self, offset: usize, len: usize) -> Result<Box<dyn TensorBuffer<Elem = Self::Elem>>>;
/// Returns the buffer's elements as a host-side `Vec`.
fn to_cpu(&self) -> Result<Vec<Self::Elem>>;
/// Returns a raw pointer to the first element.
///
/// # Safety
///
/// NOTE(review): the original code states no contract. For `CpuBuffer` this is
/// `Vec::as_ptr`, which is valid for reads of `len()` elements only until the buffer is
/// mutated, reallocated, or dropped. Implementations for other devices may return a
/// pointer the host cannot dereference — confirm per implementation before use.
unsafe fn as_ptr(&self) -> *const Self::Elem;
/// Returns a raw mutable pointer to the first element.
///
/// # Safety
///
/// Same caveats as [`TensorBuffer::as_ptr`]. For `CpuBuffer` this is `Vec::as_mut_ptr`,
/// valid for `len()` elements while the buffer is not reallocated or dropped.
unsafe fn as_mut_ptr(&mut self) -> *mut Self::Elem;
}
/// A reference-counted handle onto a window of a [`TensorBuffer`].
///
/// Handles created by [`SharedBuffer::view`] share the same backing `Arc`; each one only
/// records its own `offset`/`len` window, so creating a view never copies element data.
pub struct SharedBuffer<T> {
    /// Shared backing storage.
    data: Arc<dyn TensorBuffer<Elem = T>>,
    /// Start of this handle's window, in elements from the start of `data`.
    offset: usize,
    /// Number of elements in this handle's window.
    len: usize,
}

impl<T: Clone + Send + Sync + 'static> SharedBuffer<T> {
    /// Wraps `buffer` in a new handle whose window covers all of its elements.
    pub fn new(buffer: Box<dyn TensorBuffer<Elem = T>>) -> Self {
        let len = buffer.len();
        Self {
            data: Arc::from(buffer),
            offset: 0,
            len,
        }
    }

    /// Returns a handle onto `len` elements starting at `offset` within this handle's window.
    ///
    /// `offset` is relative to this handle, so views of views compose. The new handle shares
    /// the backing storage (the reference count goes up by one).
    ///
    /// # Errors
    ///
    /// Returns `TensorError::invalid_argument` when `offset + len` exceeds this handle's
    /// length, including when `offset + len` overflows `usize`.
    pub fn view(&self, offset: usize, len: usize) -> Result<Self> {
        // `checked_add` keeps a huge `offset` from wrapping past the bounds check (release)
        // or panicking on overflow (debug).
        let fits = matches!(offset.checked_add(len), Some(end) if end <= self.len);
        if !fits {
            return Err(TensorError::invalid_argument(format!(
                "View out of bounds: offset={offset}, len={len}, buffer_len={}",
                self.len
            )));
        }
        // Cannot overflow: self.offset + self.len stays within the backing buffer and
        // offset <= self.len.
        Ok(Self {
            data: Arc::clone(&self.data),
            offset: self.offset + offset,
            len,
        })
    }

    /// Number of handles (this one included) currently sharing the backing storage.
    pub fn ref_count(&self) -> usize {
        Arc::strong_count(&self.data)
    }
}
/// Tensor storage that lives in host memory as a plain `Vec<T>`.
pub struct CpuBuffer<T> {
    /// The elements, in order.
    data: Vec<T>,
    /// Device tag reported through `TensorBuffer::device`; `new` and `zeros` set `Device::Cpu`.
    device: Device,
}

impl<T: Clone + Send + Sync> CpuBuffer<T> {
    /// Takes ownership of `data` and tags the buffer as living on `Device::Cpu`.
    pub fn new(data: Vec<T>) -> Self {
        let device = Device::Cpu;
        Self { data, device }
    }

    /// Builds a CPU buffer of `len` elements, each a clone of `T::default()`.
    pub fn zeros(len: usize) -> Self
    where
        T: Default,
    {
        let filled = vec![T::default(); len];
        Self::new(filled)
    }
}
/// Host-memory implementation of [`TensorBuffer`] backed directly by the inner `Vec<T>`.
impl<T: Clone + Send + Sync + 'static> TensorBuffer for CpuBuffer<T> {
    type Elem = T;

    /// The device recorded on this buffer (`Device::Cpu` for `new`/`zeros`).
    fn device(&self) -> &Device {
        &self.device
    }

    /// Number of elements in the backing `Vec`.
    fn len(&self) -> usize {
        self.data.len()
    }

    /// `len() * size_of::<T>()`; spare `Vec` capacity is not counted.
    fn size_bytes(&self) -> usize {
        self.data.len() * std::mem::size_of::<T>()
    }

    /// Clones every element into a new `CpuBuffer` with the same device tag.
    fn clone_buffer(&self) -> Result<Box<dyn TensorBuffer<Elem = Self::Elem>>> {
        Ok(Box::new(Self {
            data: self.data.clone(),
            device: self.device,
        }))
    }

    /// Copies elements `offset..offset + len` into a new, independent `CpuBuffer`.
    ///
    /// # Errors
    ///
    /// Returns `TensorError::invalid_argument` when the range extends past the end of the
    /// buffer, including when `offset + len` overflows `usize`.
    fn view(&self, offset: usize, len: usize) -> Result<Box<dyn TensorBuffer<Elem = Self::Elem>>> {
        // `checked_add` + `get` turn an overflowing `offset + len` into the normal
        // out-of-bounds error. Otherwise it would panic (debug overflow check) or wrap past
        // the bounds check and panic while slicing (release).
        let window = offset
            .checked_add(len)
            .and_then(|end| self.data.get(offset..end));
        match window {
            Some(slice) => Ok(Box::new(Self {
                data: slice.to_vec(),
                device: self.device,
            })),
            None => Err(TensorError::invalid_argument(format!(
                "View out of bounds: offset={offset}, len={len}, buffer_len={}",
                self.data.len()
            ))),
        }
    }

    /// Clones the elements into a new `Vec`.
    fn to_cpu(&self) -> Result<Vec<Self::Elem>> {
        Ok(self.data.clone())
    }

    /// Pointer to the first element of the backing `Vec` (`Vec::as_ptr`).
    unsafe fn as_ptr(&self) -> *const Self::Elem {
        self.data.as_ptr()
    }

    /// Mutable pointer to the first element of the backing `Vec` (`Vec::as_mut_ptr`).
    unsafe fn as_mut_ptr(&mut self) -> *mut Self::Elem {
        self.data.as_mut_ptr()
    }
}
/// Free-list key: the buffer's device plus the `size_of` of its element type.
type PoolKey = (Device, usize);
/// Type-erased buffers waiting to be reused under one key.
type PoolValue = Vec<Box<dyn std::any::Any + Send>>;
/// Every free list, indexed by [`PoolKey`].
type PoolMap = std::collections::HashMap<PoolKey, PoolValue>;

/// A mutex-guarded set of free lists for recycling tensor buffers.
///
/// Buffers are grouped by `(device, size_of::<T>())`. Only [`CpuBuffer`]s are handed back
/// out by [`MemoryPool::allocate`], so only those are retained by [`MemoryPool::deallocate`].
pub struct MemoryPool {
    /// Free lists; a single mutex guards all keys.
    pools: std::sync::Mutex<PoolMap>,
    /// Maximum number of buffers kept per key; further returns are dropped.
    max_pool_size: usize,
}

impl MemoryPool {
    /// Creates an empty pool that retains at most `max_pool_size` buffers per key.
    pub fn new(max_pool_size: usize) -> Self {
        Self {
            pools: std::sync::Mutex::new(std::collections::HashMap::new()),
            max_pool_size,
        }
    }

    /// Returns a buffer of exactly `len` elements for `device`.
    ///
    /// The first pooled `CpuBuffer<T>` under `(device, size_of::<T>())` with at least `len`
    /// elements is reused. Its contents are reset to `T::default()`, so the result matches a
    /// fresh `CpuBuffer::zeros(len)` and never exposes a previous user's data.
    ///
    /// Otherwise a new buffer is created:
    /// - `CpuBuffer::zeros` for `Device::Cpu`;
    /// - `GpuBuffer::zeros` for GPU/ROCm devices when those features are enabled, falling back
    ///   to `CpuBuffer::zeros` if the device allocation fails.
    ///
    /// # Panics
    ///
    /// Panics if the pool's mutex is poisoned.
    pub fn allocate<
        T: Clone + Send + Sync + Default + bytemuck::Pod + bytemuck::Zeroable + 'static,
    >(
        &self,
        device: Device,
        len: usize,
    ) -> Box<dyn TensorBuffer<Elem = T>> {
        let key = (device, std::mem::size_of::<T>());

        // Hold the lock only long enough to pop a reusable buffer; the fresh allocation
        // below (possibly large, possibly a device call) runs without it.
        let recycled = {
            let mut pools = self.pools.lock().expect("lock should not be poisoned");
            pools.get_mut(&key).and_then(|pool| {
                let idx = pool.iter().position(|entry| {
                    matches!(
                        entry.downcast_ref::<CpuBuffer<T>>(),
                        Some(buf) if buf.data.len() >= len
                    )
                })?;
                pool.swap_remove(idx).downcast::<CpuBuffer<T>>().ok()
            })
        };

        if let Some(mut buffer) = recycled {
            // `resize` alone would only truncate, leaving the previous user's values in the
            // first `len` slots; clearing first makes every element `T::default()`.
            buffer.data.clear();
            buffer.data.resize(len, T::default());
            return buffer;
        }

        match device {
            Device::Cpu => Box::new(CpuBuffer::zeros(len)),
            #[cfg(feature = "gpu")]
            Device::Gpu(id) => {
                use crate::gpu::buffer::GpuBuffer;
                match GpuBuffer::<T>::zeros(len, id) {
                    Ok(buf) => Box::new(buf),
                    // Best effort: fall back to host memory rather than failing the allocation.
                    Err(_) => Box::new(CpuBuffer::zeros(len)),
                }
            }
            #[cfg(feature = "rocm")]
            Device::Rocm(id) => {
                use crate::gpu::buffer::GpuBuffer;
                match GpuBuffer::<T>::zeros(len, id) {
                    Ok(buf) => Box::new(buf),
                    // Best effort: fall back to host memory rather than failing the allocation.
                    Err(_) => Box::new(CpuBuffer::zeros(len)),
                }
            }
        }
    }

    /// Returns `buffer` to the free list for `(device, size_of::<T>())`.
    ///
    /// `T` selects the key; the buffer itself is stored type-erased. Only a `CpuBuffer<T>`
    /// can be reused by [`MemoryPool::allocate`], so any other type is dropped immediately
    /// instead of occupying a pool slot forever. The buffer is also dropped when the free list
    /// already holds `max_pool_size` entries.
    ///
    /// # Panics
    ///
    /// Panics if the pool's mutex is poisoned.
    pub fn deallocate<T: 'static>(&self, device: Device, buffer: Box<dyn std::any::Any + Send>) {
        if !buffer.is::<CpuBuffer<T>>() {
            return;
        }
        let key = (device, std::mem::size_of::<T>());
        let mut pools = self.pools.lock().expect("lock should not be poisoned");
        let pool = pools.entry(key).or_default();
        if pool.len() < self.max_pool_size {
            pool.push(buffer);
        }
    }
}
/// Number of buffers the global pool retains per `(device, element size)` key.
const GLOBAL_POOL_CAPACITY: usize = 100;

lazy_static::lazy_static! {
    /// Process-wide [`MemoryPool`], constructed on first access.
    pub static ref MEMORY_POOL: MemoryPool = MemoryPool::new(GLOBAL_POOL_CAPACITY);
}