use quantrs2_core::buffer_pool::BufferPool;
use scirs2_core::Complex64;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
static GLOBAL_BUFFER_MANAGER: OnceLock<Arc<Mutex<GlobalBufferManager>>> = OnceLock::new();
pub struct GlobalBufferManager {
f64_pool: BufferPool<f64>,
complex_pool: BufferPool<Complex64>,
vector_pools: HashMap<usize, Vec<Vec<f64>>>,
parameter_pool: BufferPool<f64>,
stats: MemoryStats,
}
#[derive(Debug, Default, Clone)]
pub struct MemoryStats {
pub total_allocated: usize,
pub peak_usage: usize,
pub pool_hits: usize,
pub pool_misses: usize,
pub fragmentation_ratio: f64,
}
impl GlobalBufferManager {
fn new() -> Self {
Self {
f64_pool: BufferPool::new(), complex_pool: BufferPool::new(),
vector_pools: HashMap::with_capacity(16),
parameter_pool: BufferPool::new(),
stats: MemoryStats::default(),
}
}
pub fn get_f64_buffer(&mut self, size: usize) -> Vec<f64> {
self.stats.total_allocated += size * std::mem::size_of::<f64>();
self.update_peak_usage();
self.stats.pool_hits += 1;
let mut buffer = self.f64_pool.get(size);
buffer.resize(size, 0.0);
buffer
}
pub fn return_f64_buffer(&mut self, buffer: Vec<f64>) {
if buffer.len() <= 10000 && buffer.capacity() <= 20000 {
self.f64_pool.put(buffer);
}
}
pub fn get_complex_buffer(&mut self, size: usize) -> Vec<Complex64> {
self.stats.total_allocated += size * std::mem::size_of::<Complex64>();
self.update_peak_usage();
self.stats.pool_hits += 1;
let mut buffer = self.complex_pool.get(size);
buffer.resize(size, Complex64::new(0.0, 0.0));
buffer
}
pub fn return_complex_buffer(&mut self, buffer: Vec<Complex64>) {
if buffer.len() <= 10000 && buffer.capacity() <= 20000 {
self.complex_pool.put(buffer);
}
}
pub fn get_sized_vector(&mut self, size: usize) -> Vec<f64> {
if let Some(pool) = self.vector_pools.get_mut(&size) {
if let Some(vec) = pool.pop() {
self.stats.pool_hits += 1;
return vec;
}
}
self.stats.pool_misses += 1;
vec![0.0; size]
}
pub fn return_sized_vector(&mut self, mut vector: Vec<f64>) {
let size = vector.len();
vector.clear();
if size <= 1024 {
let pool = self.vector_pools.entry(size).or_default();
if pool.len() < 10 {
pool.push(vector);
}
}
}
pub fn get_parameter_buffer(&mut self, size: usize) -> Vec<f64> {
self.stats.pool_hits += 1;
let mut buffer = self.parameter_pool.get(size);
buffer.resize(size, 0.0);
buffer
}
pub fn return_parameter_buffer(&mut self, buffer: Vec<f64>) {
if buffer.len() <= 100 {
self.parameter_pool.put(buffer);
}
}
pub fn collect_garbage(&mut self) {
self.vector_pools.retain(|&size, pool| {
pool.retain(|v| v.capacity() < size * 2);
size <= 1024 && !pool.is_empty()
});
let allocated = self.stats.total_allocated;
let peak = self.stats.peak_usage;
self.stats.fragmentation_ratio = if peak > 0 {
1.0 - (allocated as f64 / peak as f64)
} else {
0.0
};
}
pub const fn get_stats(&self) -> &MemoryStats {
&self.stats
}
pub fn reset_stats(&mut self) {
self.stats = MemoryStats::default();
}
const fn update_peak_usage(&mut self) {
if self.stats.total_allocated > self.stats.peak_usage {
self.stats.peak_usage = self.stats.total_allocated;
}
}
}
pub struct BufferManager;
impl BufferManager {
pub fn instance() -> Arc<Mutex<GlobalBufferManager>> {
GLOBAL_BUFFER_MANAGER
.get_or_init(|| Arc::new(Mutex::new(GlobalBufferManager::new())))
.clone()
}
#[must_use]
pub fn alloc_f64_buffer(size: usize) -> Vec<f64> {
Self::instance()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.get_f64_buffer(size)
}
pub fn free_f64_buffer(buffer: Vec<f64>) {
Self::instance()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.return_f64_buffer(buffer);
}
#[must_use]
pub fn alloc_complex_buffer(size: usize) -> Vec<Complex64> {
Self::instance()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.get_complex_buffer(size)
}
pub fn free_complex_buffer(buffer: Vec<Complex64>) {
Self::instance()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.return_complex_buffer(buffer);
}
#[must_use]
pub fn alloc_parameter_buffer(size: usize) -> Vec<f64> {
Self::instance()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.get_parameter_buffer(size)
}
pub fn free_parameter_buffer(buffer: Vec<f64>) {
Self::instance()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.return_parameter_buffer(buffer);
}
#[must_use]
pub fn get_memory_stats() -> MemoryStats {
Self::instance()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.get_stats()
.clone()
}
pub fn collect_garbage() {
Self::instance()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.collect_garbage();
}
pub fn reset_stats() {
Self::instance()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.reset_stats();
}
}
pub struct ManagedF64Buffer {
buffer: Option<Vec<f64>>,
}
impl ManagedF64Buffer {
#[must_use]
pub fn new(size: usize) -> Self {
Self {
buffer: Some(BufferManager::alloc_f64_buffer(size)),
}
}
pub const fn as_mut(&mut self) -> &mut Vec<f64> {
self.buffer
.as_mut()
.expect("buffer was already taken or not initialized")
}
#[must_use]
pub const fn as_ref(&self) -> &Vec<f64> {
self.buffer
.as_ref()
.expect("buffer was already taken or not initialized")
}
#[must_use]
pub fn take(mut self) -> Vec<f64> {
self.buffer
.take()
.expect("buffer was already taken or not initialized")
}
}
impl Drop for ManagedF64Buffer {
fn drop(&mut self) {
if let Some(buffer) = self.buffer.take() {
BufferManager::free_f64_buffer(buffer);
}
}
}
pub struct ManagedComplexBuffer {
buffer: Option<Vec<Complex64>>,
}
impl ManagedComplexBuffer {
#[must_use]
pub fn new(size: usize) -> Self {
Self {
buffer: Some(BufferManager::alloc_complex_buffer(size)),
}
}
pub const fn as_mut(&mut self) -> &mut Vec<Complex64> {
self.buffer
.as_mut()
.expect("buffer was already taken or not initialized")
}
#[must_use]
pub const fn as_ref(&self) -> &Vec<Complex64> {
self.buffer
.as_ref()
.expect("buffer was already taken or not initialized")
}
#[must_use]
pub fn take(mut self) -> Vec<Complex64> {
self.buffer
.take()
.expect("buffer was already taken or not initialized")
}
}
impl Drop for ManagedComplexBuffer {
fn drop(&mut self) {
if let Some(buffer) = self.buffer.take() {
BufferManager::free_complex_buffer(buffer);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_buffer_pooling() {
let buffer1 = BufferManager::alloc_f64_buffer(100);
assert_eq!(buffer1.len(), 100);
BufferManager::free_f64_buffer(buffer1);
let buffer2 = BufferManager::alloc_f64_buffer(100);
assert_eq!(buffer2.len(), 100);
BufferManager::free_f64_buffer(buffer2);
let stats = BufferManager::get_memory_stats();
assert!(stats.pool_hits > 0 || stats.pool_misses > 0);
}
#[test]
fn test_managed_buffer() {
{
let mut managed = ManagedF64Buffer::new(50);
managed.as_mut()[0] = 42.0;
assert_eq!(managed.as_ref()[0], 42.0);
}
let stats = BufferManager::get_memory_stats();
assert!(stats.total_allocated > 0);
}
#[test]
fn test_complex_buffer_pooling() {
let buffer1 = BufferManager::alloc_complex_buffer(50);
assert_eq!(buffer1.len(), 50);
BufferManager::free_complex_buffer(buffer1);
let buffer2 = BufferManager::alloc_complex_buffer(50);
assert_eq!(buffer2.len(), 50);
BufferManager::free_complex_buffer(buffer2);
}
#[test]
fn test_garbage_collection() {
for _ in 0..10 {
let buffer = BufferManager::alloc_f64_buffer(1000);
BufferManager::free_f64_buffer(buffer);
}
BufferManager::collect_garbage();
let stats = BufferManager::get_memory_stats();
assert!(stats.fragmentation_ratio >= 0.0);
}
}