use std::alloc::{alloc, dealloc, Layout};
use std::collections::VecDeque;
use std::ptr::NonNull;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
/// Tuning knobs shared by [`AudioBufferPool`] and [`TensorPool`].
#[derive(Debug, Clone)]
pub struct PoolConfig {
/// Number of buffers `AudioBufferPool::new` preallocates (`TensorPool` starts empty).
pub initial_size: usize,
/// Maximum number of idle entries a pool keeps; items handed back beyond this are freed.
pub max_size: usize,
/// Idle time after which `cleanup_expired` drops a pooled entry.
pub timeout: Duration,
/// Byte alignment for `TensorPool` allocations; must be a power of two
/// (`Layout::from_size_align` rejects anything else, making `allocate` return `None`).
pub alignment: usize,
/// When `false`, no `PoolStats` counters are updated.
pub enable_stats: bool,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
initial_size: 10,
max_size: 100,
timeout: Duration::from_secs(60),
alignment: 8,
enable_stats: true,
}
}
}
/// Minimal object-pool interface: hand items out, take them back, inspect usage.
pub trait MemoryPool<T> {
/// Takes an item from the pool, possibly creating a fresh one; `None` means no
/// item could be provided.
fn get(&self) -> Option<T>;
/// Hands an item back for reuse (an implementation may free it instead, e.g. when full).
fn put(&self, item: T);
/// Number of idle items currently held by the pool.
fn size(&self) -> usize;
/// Discards every idle item.
fn clear(&self);
/// Returns a snapshot of the pool's usage counters.
fn stats(&self) -> PoolStats;
}
/// Usage counters maintained by the pools while `PoolConfig::enable_stats` is set.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
/// Allocations counted by the pool (every miss counts one).
pub allocations: u64,
/// Idle items freed by the pool: dropped on return while full, expired by
/// `cleanup_expired`, or removed by `clear`.
pub deallocations: u64,
/// Requests served from an idle pooled item.
pub hits: u64,
/// Requests that needed a fresh allocation.
pub misses: u64,
/// Idle items in the pool as of the last stats update (not refreshed by every operation).
pub current_size: usize,
/// Largest `current_size` recorded so far.
pub peak_size: usize,
/// `hits / (hits + misses)` as of the last update.
/// NOTE(review): confirm every pool type keeps this up to date.
pub hit_rate: f64,
}
/// Pool of reusable `Vec<f32>` audio buffers.
pub struct AudioBufferPool {
/// Idle buffers waiting to be handed out again, guarded by a mutex.
buffers: Arc<Mutex<VecDeque<PooledBuffer>>>,
/// Limits and flags controlling preallocation, retention, expiry and stats.
config: PoolConfig,
/// Usage counters; only written when `config.enable_stats` is set.
stats: Arc<RwLock<PoolStats>>,
}
/// An idle buffer sitting in an [`AudioBufferPool`].
#[derive(Debug)]
struct PooledBuffer {
/// The buffer itself; its heap allocation is what gets reused.
data: Vec<f32>,
// Set whenever an entry is created but never read in this file, hence the lint allowance.
#[allow(dead_code)]
created_at: Instant,
/// Timestamp compared against `config.timeout` by `cleanup_expired`.
last_used: Instant,
}
impl AudioBufferPool {
    /// Capacity (in samples) of each buffer created by preallocation.
    const PREALLOCATED_CAPACITY: usize = 1024;

    /// Creates a pool and eagerly fills it with up to `config.initial_size`
    /// empty buffers (never more than `config.max_size`).
    pub fn new(config: PoolConfig) -> Self {
        let pool = Self {
            buffers: Arc::new(Mutex::new(VecDeque::new())),
            config,
            stats: Arc::new(RwLock::new(PoolStats::default())),
        };
        pool.preallocate();
        pool
    }

    /// Creates a pool using [`PoolConfig::default`].
    pub fn with_default_config() -> Self {
        Self::new(PoolConfig::default())
    }

    /// Returns a zero-filled buffer of exactly `size` samples.
    ///
    /// The first idle buffer whose *capacity* is at least `size` is reused, so
    /// no reallocation is needed; otherwise a fresh buffer is allocated. If the
    /// pool's mutex is poisoned, a fresh buffer is returned and no stats are
    /// recorded.
    pub fn get_buffer(&self, size: usize) -> Vec<f32> {
        let Ok(mut buffers) = self.buffers.lock() else {
            return vec![0.0; size];
        };
        // Match on capacity, not length: pooled buffers are stored empty
        // (preallocated ones via `with_capacity`, returned ones are cleared).
        let found = buffers.iter().position(|buf| buf.data.capacity() >= size);
        let reusable = found.and_then(|pos| buffers.remove(pos));

        let Some(pooled) = reusable else {
            drop(buffers);
            self.record_lookup(false, None);
            return vec![0.0; size];
        };
        self.record_lookup(true, Some(buffers.len()));
        // Zeroing is O(size); no need to hold the pool lock while doing it.
        drop(buffers);

        let mut data = pooled.data;
        data.clear();
        // Stays within the existing allocation because capacity >= size.
        data.resize(size, 0.0);
        data
    }

    /// Hands `buffer` back to the pool for reuse.
    ///
    /// The buffer is dropped (and counted as a deallocation) when the pool
    /// already holds `max_size` idle buffers. Pooled buffers are stored empty;
    /// `get_buffer` zero-fills whatever length it hands out.
    pub fn return_buffer(&self, mut buffer: Vec<f32>) {
        let Ok(mut buffers) = self.buffers.lock() else {
            return;
        };
        if buffers.len() >= self.config.max_size {
            if self.config.enable_stats {
                if let Ok(mut stats) = self.stats.write() {
                    stats.deallocations += 1;
                }
            }
            return;
        }
        // Keep the allocation, forget the samples: zero-filling here would be
        // repeated on checkout anyway.
        buffer.clear();
        let now = Instant::now();
        buffers.push_back(PooledBuffer {
            data: buffer,
            created_at: now,
            last_used: now,
        });
        if self.config.enable_stats {
            if let Ok(mut stats) = self.stats.write() {
                stats.current_size = buffers.len();
                stats.peak_size = stats.peak_size.max(stats.current_size);
            }
        }
    }

    /// Drops every idle buffer whose `last_used` is at least `config.timeout` ago.
    pub fn cleanup_expired(&self) {
        let Ok(mut buffers) = self.buffers.lock() else {
            return;
        };
        let now = Instant::now();
        let timeout = self.config.timeout;
        let before = buffers.len();
        buffers.retain(|buf| now.duration_since(buf.last_used) < timeout);
        if self.config.enable_stats {
            let removed = before - buffers.len();
            if let Ok(mut stats) = self.stats.write() {
                stats.deallocations += removed as u64;
                stats.current_size = buffers.len();
            }
        }
    }

    /// Fills the pool with empty buffers of `PREALLOCATED_CAPACITY` samples.
    fn preallocate(&self) {
        let Ok(mut buffers) = self.buffers.lock() else {
            return;
        };
        // Never start above the retention limit that `return_buffer` enforces.
        let count = self.config.initial_size.min(self.config.max_size);
        let now = Instant::now();
        buffers.extend((0..count).map(|_| PooledBuffer {
            data: Vec::with_capacity(Self::PREALLOCATED_CAPACITY),
            created_at: now,
            last_used: now,
        }));
        if self.config.enable_stats {
            if let Ok(mut stats) = self.stats.write() {
                // These are real allocations; counting them keeps `allocations`
                // balanced against the deallocations recorded when `clear` or
                // `cleanup_expired` later drop them.
                stats.allocations += count as u64;
                stats.current_size = buffers.len();
                stats.peak_size = stats.peak_size.max(buffers.len());
            }
        }
    }

    /// Records one `get_buffer` lookup (no-op when stats are disabled).
    ///
    /// `pooled_now` is the idle count after a hit; misses pass `None` because
    /// they leave the pool's contents untouched.
    fn record_lookup(&self, hit: bool, pooled_now: Option<usize>) {
        if !self.config.enable_stats {
            return;
        }
        let Ok(mut stats) = self.stats.write() else {
            return;
        };
        if hit {
            stats.hits += 1;
        } else {
            stats.misses += 1;
            stats.allocations += 1;
        }
        if let Some(len) = pooled_now {
            stats.current_size = len;
        }
        // A lookup has just been counted, so the divisor is at least 1.
        let total = stats.hits + stats.misses;
        stats.hit_rate = stats.hits as f64 / total as f64;
    }
}
impl MemoryPool<Vec<f32>> for AudioBufferPool {
fn get(&self) -> Option<Vec<f32>> {
Some(self.get_buffer(1024)) }
fn put(&self, item: Vec<f32>) {
self.return_buffer(item);
}
fn size(&self) -> usize {
self.buffers
.lock()
.map(|buffers| buffers.len())
.unwrap_or(0)
}
fn clear(&self) {
let Ok(mut buffers) = self.buffers.lock() else {
return;
};
let cleared_count = buffers.len();
buffers.clear();
if self.config.enable_stats {
if let Ok(mut stats) = self.stats.write() {
stats.deallocations += cleared_count as u64;
stats.current_size = 0;
}
}
}
fn stats(&self) -> PoolStats {
self.stats
.read()
.map(|stats| stats.clone())
.unwrap_or_default()
}
}
/// Pool of raw, aligned byte blocks obtained from the global allocator (`std::alloc::alloc`).
pub struct TensorPool {
/// Idle blocks available for reuse, guarded by a mutex.
tensors: Arc<Mutex<VecDeque<PooledTensor>>>,
/// Limits, alignment and flags for this pool.
config: PoolConfig,
/// Usage counters; only written when `config.enable_stats` is set.
stats: Arc<RwLock<PoolStats>>,
}
/// An idle raw block held by a [`TensorPool`].
#[derive(Debug)]
struct PooledTensor {
/// Start of the block.
ptr: NonNull<u8>,
/// Layout the block is freed with (by `cleanup_expired` and the pool's `Drop`).
layout: Layout,
/// Byte size recorded by `deallocate`; compared against requests in `allocate`.
size: usize,
// Set whenever an entry is created but never read in this file, hence the lint allowance.
#[allow(dead_code)]
created_at: Instant,
/// Timestamp compared against `config.timeout` by `cleanup_expired`.
last_used: Instant,
}
unsafe impl Send for PooledTensor {}
unsafe impl Sync for PooledTensor {}
impl TensorPool {
    /// Creates an empty pool; blocks only enter it through [`Self::deallocate`].
    pub fn new(config: PoolConfig) -> Self {
        Self {
            tensors: Arc::new(Mutex::new(VecDeque::new())),
            config,
            stats: Arc::new(RwLock::new(PoolStats::default())),
        }
    }

    /// Layout for a block of `size` bytes at the configured alignment, or
    /// `None` when `Layout::from_size_align` rejects the pair (non-power-of-two
    /// alignment, or a size that overflows once rounded up).
    fn layout_for(&self, size: usize) -> Option<Layout> {
        Layout::from_size_align(size, self.config.alignment).ok()
    }

    /// Returns a block of exactly `size` bytes aligned to `config.alignment`.
    ///
    /// Only an idle block of the *same* size is reused. `deallocate` frees
    /// with the size the caller passes back, so handing out a larger block
    /// would later be freed with the wrong layout (undefined behaviour).
    /// Contents are unspecified: fresh blocks are uninitialized and reused
    /// blocks keep their previous bytes.
    ///
    /// Returns `None` when `size` is 0 (zero-sized layouts must not reach the
    /// allocator), when the size/alignment pair is invalid, when the pool's
    /// mutex is poisoned, or when the allocator fails.
    pub fn allocate(&self, size: usize) -> Option<NonNull<u8>> {
        if size == 0 {
            return None;
        }
        let layout = self.layout_for(size)?;
        let mut tensors = self.tensors.lock().ok()?;

        if let Some(pos) = tensors.iter().position(|tensor| tensor.size == size) {
            let reused = tensors.remove(pos)?;
            self.record_lookup(true, Some(tensors.len()));
            return Some(reused.ptr);
        }
        // Nothing reusable; the fresh allocation does not need the pool lock.
        drop(tensors);
        // SAFETY: `layout` has a non-zero size (checked above) and a valid
        // alignment (`Layout::from_size_align` succeeded).
        let ptr = NonNull::new(unsafe { alloc(layout) })?;
        self.record_lookup(false, None);
        Some(ptr)
    }

    /// Returns a block obtained from [`Self::allocate`] to the pool.
    ///
    /// `size` must be the value that was passed to `allocate` for `ptr`: the
    /// block is (eventually) freed with the layout rebuilt from it.
    /// NOTE(review): any other pointer/size is undefined behaviour even though
    /// this is a safe fn; consider making it `unsafe fn`.
    ///
    /// The block is freed immediately when the pool already holds `max_size`
    /// idle blocks or its mutex is poisoned; otherwise it is kept for reuse.
    pub fn deallocate(&self, ptr: NonNull<u8>, size: usize) {
        if size == 0 {
            // `allocate` never hands out zero-sized blocks, and a zero-sized
            // layout must not be passed to `dealloc`.
            return;
        }
        let Some(layout) = self.layout_for(size) else {
            // `allocate` would have failed for this size/alignment, so `ptr`
            // cannot be ours; leak it rather than free it with a bogus layout.
            return;
        };
        let Ok(mut tensors) = self.tensors.lock() else {
            // SAFETY: per the contract above, `ptr` came from `allocate(size)`,
            // i.e. from `alloc(layout)` with this exact layout.
            unsafe { dealloc(ptr.as_ptr(), layout) };
            return;
        };
        if tensors.len() >= self.config.max_size {
            // SAFETY: as above, `ptr` was allocated with `layout`.
            unsafe { dealloc(ptr.as_ptr(), layout) };
            if self.config.enable_stats {
                if let Ok(mut stats) = self.stats.write() {
                    stats.deallocations += 1;
                }
            }
            return;
        }
        let now = Instant::now();
        tensors.push_back(PooledTensor {
            ptr,
            layout,
            size,
            created_at: now,
            last_used: now,
        });
        if self.config.enable_stats {
            if let Ok(mut stats) = self.stats.write() {
                stats.current_size = tensors.len();
                stats.peak_size = stats.peak_size.max(stats.current_size);
            }
        }
    }

    /// Frees every idle block whose `last_used` is at least `config.timeout` ago.
    pub fn cleanup_expired(&self) {
        let Ok(mut tensors) = self.tensors.lock() else {
            return;
        };
        let now = Instant::now();
        let timeout = self.config.timeout;
        // One O(n) pass instead of removing expired entries index by index.
        let (expired, live): (VecDeque<PooledTensor>, VecDeque<PooledTensor>) = tensors
            .drain(..)
            .partition(|tensor| now.duration_since(tensor.last_used) >= timeout);
        *tensors = live;
        for tensor in &expired {
            // SAFETY: each entry was pushed by `deallocate` with the layout its
            // block was allocated with, and it has just left the pool, so it is
            // freed exactly once.
            unsafe { dealloc(tensor.ptr.as_ptr(), tensor.layout) };
        }
        if self.config.enable_stats {
            if let Ok(mut stats) = self.stats.write() {
                stats.deallocations += expired.len() as u64;
                stats.current_size = tensors.len();
            }
        }
    }

    /// Snapshot of the counters; all-default if the stats lock is poisoned.
    pub fn stats(&self) -> PoolStats {
        self.stats
            .read()
            .map(|stats| stats.clone())
            .unwrap_or_default()
    }

    /// Records one `allocate` lookup (no-op when stats are disabled).
    ///
    /// `pooled_now` is the idle count after a hit; misses pass `None` because
    /// they leave the pool's contents untouched.
    fn record_lookup(&self, hit: bool, pooled_now: Option<usize>) {
        if !self.config.enable_stats {
            return;
        }
        let Ok(mut stats) = self.stats.write() else {
            return;
        };
        if hit {
            stats.hits += 1;
        } else {
            stats.misses += 1;
            stats.allocations += 1;
        }
        if let Some(len) = pooled_now {
            stats.current_size = len;
        }
        // A lookup has just been counted, so the divisor is at least 1.
        let total = stats.hits + stats.misses;
        stats.hit_rate = stats.hits as f64 / total as f64;
    }
}
impl Default for AudioBufferPool {
fn default() -> Self {
Self::with_default_config()
}
}
impl Drop for TensorPool {
    /// Frees every block still sitting idle in the pool.
    ///
    /// A poisoned mutex is recovered rather than skipped. Poisoning only means
    /// another thread panicked while holding the lock, and the entries still in
    /// the deque are not referenced anywhere else, so freeing them is still
    /// correct. Bailing out would leak them.
    fn drop(&mut self) {
        let mut tensors = self
            .tensors
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        for tensor in tensors.drain(..) {
            // SAFETY: each entry was pushed by `deallocate` together with its
            // layout (assuming callers passed back the size they allocated
            // with), and draining removes it, so it is freed exactly once.
            unsafe { dealloc(tensor.ptr.as_ptr(), tensor.layout) };
        }
    }
}
/// Lazily created per-thread `AudioBufferPool` and `TensorPool` instances.
pub struct ThreadLocalPools {
/// Per-thread audio pools, created on first access through `audio_pool()`.
audio_pool: thread_local::ThreadLocal<AudioBufferPool>,
/// Per-thread tensor pools, created on first access through `tensor_pool()`.
tensor_pool: thread_local::ThreadLocal<TensorPool>,
/// Configuration cloned into every pool that gets created.
config: PoolConfig,
}
impl ThreadLocalPools {
pub fn new(config: PoolConfig) -> Self {
Self {
audio_pool: thread_local::ThreadLocal::new(),
tensor_pool: thread_local::ThreadLocal::new(),
config,
}
}
pub fn audio_pool(&self) -> &AudioBufferPool {
self.audio_pool
.get_or(|| AudioBufferPool::new(self.config.clone()))
}
pub fn tensor_pool(&self) -> &TensorPool {
self.tensor_pool
.get_or(|| TensorPool::new(self.config.clone()))
}
pub fn cleanup_all(&self) {
if let Some(audio_pool) = self.audio_pool.get() {
audio_pool.cleanup_expired();
}
if let Some(tensor_pool) = self.tensor_pool.get() {
tensor_pool.cleanup_expired();
}
}
pub fn aggregated_stats(&self) -> (PoolStats, PoolStats) {
let audio_stats = self.audio_pool.get().map(|p| p.stats()).unwrap_or_default();
let tensor_stats = self
.tensor_pool
.get()
.map(|p| p.stats())
.unwrap_or_default();
(audio_stats, tensor_stats)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_audio_buffer_pool_basic() {
        let config = PoolConfig {
            initial_size: 5,
            max_size: 10,
            timeout: Duration::from_secs(1),
            ..Default::default()
        };
        let pool = AudioBufferPool::new(config);
        let buffer = pool.get_buffer(1024);
        assert_eq!(buffer.len(), 1024);
        assert!(buffer.iter().all(|&x| x == 0.0));
        pool.return_buffer(buffer);
        let buffer2 = pool.get_buffer(1024);
        assert_eq!(buffer2.len(), 1024);
        assert!(buffer2.iter().all(|&x| x == 0.0));
        // At the latest, the second request is served by the buffer returned above.
        let stats = pool.stats();
        assert!(stats.hits > 0);
    }

    #[test]
    fn test_pool_cleanup() {
        let config = PoolConfig {
            initial_size: 2,
            max_size: 10,
            timeout: Duration::from_millis(50),
            ..Default::default()
        };
        let pool = AudioBufferPool::new(config);
        for i in 0..3 {
            let buffer = vec![0.0; 512 * (i + 1)];
            pool.return_buffer(buffer);
        }
        let initial_size = pool.size();
        assert_eq!(initial_size, 5, "2 preallocated + 3 returned buffers");
        thread::sleep(Duration::from_millis(100));
        pool.cleanup_expired();
        // Every entry has now been idle for roughly twice the timeout.
        assert_eq!(pool.size(), 0);
        let stats = pool.stats();
        assert_eq!(stats.deallocations, initial_size as u64);
        assert_eq!(stats.current_size, 0);
    }

    #[test]
    fn test_tensor_pool_allocation() {
        let config = PoolConfig {
            alignment: 32,
            ..Default::default()
        };
        let pool = TensorPool::new(config);
        let ptr = pool.allocate(1024).unwrap();
        assert_eq!(ptr.as_ptr() as usize % 32, 0, "configured alignment must be honoured");
        pool.deallocate(ptr, 1024);
        let ptr2 = pool.allocate(1024).unwrap();
        // A same-size request is served from the pool.
        assert_eq!(ptr2, ptr);
        let stats = pool.stats();
        assert_eq!((stats.hits, stats.misses), (1, 1));
        pool.deallocate(ptr2, 1024);
    }

    #[test]
    fn test_thread_local_pools() {
        let config = PoolConfig::default();
        let pools = ThreadLocalPools::new(config);
        let audio_pool = pools.audio_pool();
        // Repeated calls on one thread return the same lazily created pool.
        assert!(std::ptr::eq(audio_pool, pools.audio_pool()));
        let buffer = audio_pool.get_buffer(512);
        assert_eq!(buffer.len(), 512);
        audio_pool.return_buffer(buffer);
        let tensor_pool = pools.tensor_pool();
        let ptr = tensor_pool.allocate(256).unwrap();
        tensor_pool.deallocate(ptr, 256);
        pools.cleanup_all();
    }

    #[test]
    fn test_pool_statistics() {
        let config = PoolConfig {
            enable_stats: true,
            ..Default::default()
        };
        let pool = AudioBufferPool::new(config);
        let buffer1 = pool.get_buffer(512);
        let buffer2 = pool.get_buffer(1024);
        pool.return_buffer(buffer1);
        pool.return_buffer(buffer2);
        let buffer3 = pool.get_buffer(512);
        let stats = pool.stats();
        assert!(stats.hits > 0);
        assert!(stats.allocations > 0);
        assert!(stats.current_size > 0);
        // `hit_rate` must agree with the raw counters after every lookup.
        let expected_rate = stats.hits as f64 / (stats.hits + stats.misses) as f64;
        assert!((stats.hit_rate - expected_rate).abs() < 1e-12);
        pool.return_buffer(buffer3);
    }
}