use crate::error::{RusTorchError, RusTorchResult};
use ndarray::{ArrayD, IxDyn};
use num_traits::Float;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
/// Selects how the pool is asked to obtain a fresh array.
///
/// The value is stored in `PoolConfig::strategy` and matched on by
/// `EnhancedMemoryPool::allocate_new`.
/// NOTE(review): the variant descriptions below are inferred from the names
/// only; confirm the intended semantics before relying on them.
#[derive(Clone, Debug, PartialEq)]
pub enum AllocationStrategy {
/// Presumably "take the first block that fits".
FirstFit,
/// Presumably "take the tightest-fitting block".
BestFit,
/// Bucket by size class; this is the value `PoolConfig::default()` selects.
SizeClass,
/// Presumably NUMA-aware placement.
NumaAware,
}
/// Tunable parameters for an `EnhancedMemoryPool`.
#[derive(Clone, Debug)]
pub struct PoolConfig {
/// Memory budget; pooled usage is divided by this value to obtain the
/// memory-pressure ratio.
pub max_pool_memory: usize,
/// Upper bound on the number of blocks kept per size-class bucket; checked
/// when an array is returned to the pool.
pub max_arrays_per_class: usize,
/// Pressure ratio above which the periodic check triggers garbage collection.
pub gc_threshold: f64,
/// Minimum time between two memory-pressure evaluations.
pub monitor_interval: Duration,
/// Strategy consulted when a brand-new array has to be created.
pub strategy: AllocationStrategy,
/// When `true`, `allocate` calls the deduplication hooks.
pub enable_deduplication: bool,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
max_pool_memory: 1024 * 1024 * 1024, max_arrays_per_class: 1000,
gc_threshold: 0.8,
monitor_interval: Duration::from_millis(100),
strategy: AllocationStrategy::SizeClass,
enable_deduplication: true,
}
}
}
/// One pooled array together with the bookkeeping the pool keeps about it.
#[derive(Clone, Debug)]
struct MemoryBlock<T> {
/// The pooled array itself.
data: ArrayD<T>,
/// Timestamp recorded when the block was pooled; garbage collection compares
/// it against its age cutoff.
last_accessed: Instant,
/// Usage counter; initialised to 1 when the block is pooled.
access_count: usize,
/// Bucket key (from `SizeClass::index`) the block was filed under.
size_class: usize,
/// Optional content hash; the visible code only ever stores `None` here.
content_hash: Option<u64>,
}
/// Coarse summary of an array shape, used to pick a pool bucket.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct SizeClass {
    /// Product of all axis lengths.
    total_elements: usize,
    /// Number of axes in the shape.
    dimensions: usize,
}

impl SizeClass {
    /// Summarises `shape` as its element count and its number of axes.
    fn from_shape(shape: &[usize]) -> Self {
        let dimensions = shape.len();
        let total_elements = shape.iter().product();
        Self {
            total_elements,
            dimensions,
        }
    }

    /// Returns the bucket key for this class.
    ///
    /// The key is an element-count tier (0 through 8) plus a rank offset:
    /// 0, 10, 20 or 30 for one to four axes, and 40 for any other rank.
    fn index(&self) -> usize {
        // Inclusive upper bounds of tiers 0..=7; anything larger is tier 8.
        const TIER_LIMITS: [usize; 8] = [64, 256, 1024, 4096, 16384, 65536, 262144, 1048576];

        let element_tier = TIER_LIMITS
            .iter()
            .position(|&limit| self.total_elements <= limit)
            .unwrap_or(TIER_LIMITS.len());

        let rank_offset = if (1..=4).contains(&self.dimensions) {
            (self.dimensions - 1) * 10
        } else {
            40
        };

        element_tier + rank_offset
    }
}
/// Pool of reusable `ArrayD<T>` buffers, bucketed by size class, with usage
/// statistics, periodic memory-pressure checks and a deduplication cache.
pub struct EnhancedMemoryPool<T: Float + Clone + Send + Sync + 'static> {
/// Settings supplied at construction time.
config: PoolConfig,
/// Pooled blocks, keyed by the bucket index from `SizeClass::index`.
pools: RwLock<HashMap<usize, VecDeque<MemoryBlock<T>>>>,
/// Counters and derived metrics returned by `get_stats`.
stats: RwLock<EnhancedPoolStats>,
/// State of the periodic memory-pressure check.
pressure_monitor: Arc<Mutex<MemoryPressureMonitor>>,
/// Shared arrays keyed by a `u64` (presumably a content hash — confirm);
/// garbage collection drops entries that nothing else still references.
dedup_cache: RwLock<HashMap<u64, Arc<ArrayD<T>>>>,
}
/// Snapshot of pool counters and derived metrics, as returned by `get_stats`.
#[derive(Debug, Clone)]
pub struct EnhancedPoolStats {
/// Displayed as "Total Allocated ... bytes".
/// NOTE(review): not updated anywhere in the code visible here.
pub total_allocated: usize,
/// Displayed as "Total Pooled ... bytes".
/// NOTE(review): not updated anywhere in the code visible here.
pub total_pooled: usize,
/// Displayed as "Active Size Classes".
/// NOTE(review): not updated anywhere in the code visible here.
pub active_size_classes: usize,
/// Number of `allocate` calls recorded.
pub total_allocations: usize,
/// Number of `deallocate` calls recorded.
pub total_deallocations: usize,
/// Ratio stored by `allocate` after a pool hit; displayed as a percentage.
pub cache_hit_ratio: f64,
/// Memory pressure as a fraction; displayed as a percentage.
pub memory_pressure: f64,
/// Cumulative garbage-collection statistics.
pub gc_stats: GcStats,
/// Deduplication statistics.
/// NOTE(review): not updated anywhere in the code visible here.
pub dedup_stats: DeduplicationStats,
}
/// Garbage-collection statistics; used both for the per-run value returned by
/// `garbage_collect` and for the cumulative record inside `EnhancedPoolStats`.
#[derive(Debug, Clone)]
pub struct GcStats {
/// Number of collection runs covered by this record.
pub gc_runs: usize,
/// Amount of memory reported as reclaimed (labelled "bytes" when displayed).
pub memory_reclaimed: usize,
/// Averaged duration of a collection run.
pub avg_gc_duration: Duration,
/// When the most recent run finished, or `None` if none has run yet.
pub last_gc_time: Option<Instant>,
}
/// Deduplication counters reported as part of `EnhancedPoolStats`.
#[derive(Debug, Clone)]
pub struct DeduplicationStats {
/// Displayed as "Deduplication Hits".
pub duplicates_found: usize,
/// Displayed as "Memory Saved ... bytes".
pub memory_saved: usize,
/// Presumably the fraction of lookups that found a duplicate — confirm.
pub hit_ratio: f64,
}
/// Mutable state of the periodic memory-pressure check.
#[derive(Debug)]
struct MemoryPressureMonitor {
/// Most recently computed ratio of pooled bytes to `max_pool_memory`.
current_pressure: f64,
/// Largest pooled-byte total observed by the check so far.
peak_usage: usize,
/// When the pressure was last evaluated.
last_check: Instant,
}
impl<T: Float + Clone + Send + Sync + 'static> EnhancedMemoryPool<T> {
/// Creates an empty pool governed by `config`.
///
/// The pool starts with no pooled blocks, default statistics, an empty
/// deduplication cache, and a pressure monitor whose last check is "now".
pub fn new(config: PoolConfig) -> Self {
    let monitor = MemoryPressureMonitor {
        last_check: Instant::now(),
        peak_usage: 0,
        current_pressure: 0.0,
    };

    Self {
        pools: RwLock::new(HashMap::new()),
        dedup_cache: RwLock::new(HashMap::new()),
        stats: RwLock::new(EnhancedPoolStats::default()),
        pressure_monitor: Arc::new(Mutex::new(monitor)),
        config,
    }
}
/// Hands out an array of the requested `shape`, preferring a pooled block.
///
/// Steps, in order: bump the allocation counter, run the periodic
/// memory-pressure check, try to reuse a pooled block from the matching size
/// class, consult the deduplication hook (when enabled), and finally create a
/// new array and offer it to the deduplication cache.
///
/// # Errors
/// Propagates any error from the pressure check, the pool lookup, the
/// deduplication hooks, or the fresh allocation.
pub fn allocate(&self, shape: &[usize]) -> RusTorchResult<ArrayD<T>> {
    let class_index = SizeClass::from_shape(shape).index();

    // Statistics are best-effort: a poisoned stats lock must not fail the
    // allocation itself.
    if let Ok(mut stats) = self.stats.write() {
        stats.total_allocations += 1;
    }

    self.check_memory_pressure()?;

    if let Some(array) = self.allocate_from_pool(class_index, shape)? {
        // Compute the ratio BEFORE taking the write lock:
        // `calculate_cache_hit_ratio` acquires the stats lock itself, and
        // re-entering a std `RwLock` on the same thread while holding its
        // write guard can deadlock or panic.
        let ratio = self.calculate_cache_hit_ratio();
        if let Ok(mut stats) = self.stats.write() {
            stats.cache_hit_ratio = ratio;
        }
        return Ok(array);
    }

    if self.config.enable_deduplication {
        if let Some(array) = self.check_deduplication(shape)? {
            return Ok(array);
        }
    }

    let array = self.allocate_new(shape)?;
    if self.config.enable_deduplication {
        self.add_to_dedup_cache(&array)?;
    }
    Ok(array)
}
/// Returns `array` to the pool so a later `allocate` can reuse it.
///
/// The array is filed under the bucket for its shape. When that bucket is
/// already at `max_arrays_per_class`, the oldest block is evicted first. With
/// a limit of zero nothing is pooled and the array is simply dropped.
///
/// # Errors
/// Returns `RusTorchError::MemoryError` when the pool lock is poisoned.
pub fn deallocate(&self, array: ArrayD<T>) -> RusTorchResult<()> {
    let class_index = SizeClass::from_shape(array.shape()).index();

    // Statistics are best-effort; a poisoned stats lock is ignored.
    if let Ok(mut stats) = self.stats.write() {
        stats.total_deallocations += 1;
    }

    // A zero limit means "never pool"; without this guard one block per
    // class would still be retained.
    let capacity = self.config.max_arrays_per_class;
    if capacity == 0 {
        return Ok(());
    }

    let block = MemoryBlock {
        data: array,
        last_accessed: Instant::now(),
        access_count: 1,
        size_class: class_index,
        content_hash: None,
    };

    let mut pools = self.pools.write().map_err(|_| {
        RusTorchError::MemoryError("Failed to acquire pool write lock".to_string())
    })?;
    let pool = pools.entry(class_index).or_default();
    if pool.len() >= capacity {
        // Evict the oldest entry (front of the queue) to make room.
        pool.pop_front();
    }
    pool.push_back(block);
    Ok(())
}
/// Returns a copy of the current pool statistics.
///
/// # Errors
/// Returns `RusTorchError::MemoryError` when the stats lock is poisoned.
pub fn get_stats(&self) -> RusTorchResult<EnhancedPoolStats> {
    match self.stats.read() {
        Ok(guard) => Ok(guard.clone()),
        Err(_) => Err(RusTorchError::MemoryError(
            "Failed to acquire stats read lock".to_string(),
        )),
    }
}
pub fn garbage_collect(&self) -> RusTorchResult<GcStats> {
let start_time = Instant::now();
let mut memory_reclaimed = 0;
{
let mut pools = self.pools.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire pool write lock".to_string())
})?;
for (_, pool) in pools.iter_mut() {
let original_size = pool.len();
let cutoff_time = Instant::now() - Duration::from_secs(300); pool.retain(|block| block.last_accessed > cutoff_time);
memory_reclaimed += (original_size - pool.len()) * std::mem::size_of::<ArrayD<T>>();
}
}
{
let mut dedup = self.dedup_cache.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire dedup write lock".to_string())
})?;
let original_size = dedup.len();
dedup.retain(|_, arc_array| Arc::strong_count(arc_array) > 1);
memory_reclaimed += (original_size - dedup.len()) * std::mem::size_of::<ArrayD<T>>();
}
let gc_duration = start_time.elapsed();
let gc_stats = GcStats {
gc_runs: 1,
memory_reclaimed,
avg_gc_duration: gc_duration,
last_gc_time: Some(Instant::now()),
};
if let Ok(mut stats) = self.stats.write() {
stats.gc_stats.gc_runs += 1;
stats.gc_stats.memory_reclaimed += memory_reclaimed;
stats.gc_stats.avg_gc_duration = (stats.gc_stats.avg_gc_duration + gc_duration) / 2;
stats.gc_stats.last_gc_time = Some(Instant::now());
}
Ok(gc_stats)
}
/// Drops every pooled block and every deduplication entry.
///
/// The pool lock is taken and released before the deduplication lock is
/// taken, so the two are never held at the same time.
///
/// # Errors
/// Returns `RusTorchError::MemoryError` when either lock is poisoned.
pub fn clear_all(&self) -> RusTorchResult<()> {
    self.pools
        .write()
        .map_err(|_| {
            RusTorchError::MemoryError("Failed to acquire pool write lock".to_string())
        })?
        .clear();

    self.dedup_cache
        .write()
        .map_err(|_| {
            RusTorchError::MemoryError("Failed to acquire dedup write lock".to_string())
        })?
        .clear();

    Ok(())
}
/// Tries to satisfy a request for `shape` from the bucket `class_index`.
///
/// A block with exactly the requested shape is preferred. Otherwise a block
/// with the same element count and standard (row-major contiguous) layout is
/// reshaped in place. In both cases the returned array is zero-filled, so a
/// reused block never exposes the previous owner's data.
///
/// Returns `Ok(None)` when the bucket holds no suitable block.
///
/// # Errors
/// Returns `RusTorchError::MemoryError` when the pool lock is poisoned.
fn allocate_from_pool(
    &self,
    class_index: usize,
    shape: &[usize],
) -> RusTorchResult<Option<ArrayD<T>>> {
    let mut pools = self.pools.write().map_err(|_| {
        RusTorchError::MemoryError("Failed to acquire pool write lock".to_string())
    })?;
    let pool = match pools.get_mut(&class_index) {
        Some(pool) => pool,
        None => return Ok(None),
    };

    // First choice: a block that already has the requested shape.
    if let Some(position) = pool.iter().position(|block| block.data.shape() == shape) {
        if let Some(block) = pool.remove(position) {
            let mut data = block.data;
            data.fill(T::zero());
            return Ok(Some(data));
        }
    }

    // Second choice: same element count, different shape. Only
    // standard-layout blocks are considered: those can be reshaped without
    // copying, and `into_shape_with_order` consumes the array, so a failed
    // attempt would throw the block away.
    let total_elements: usize = shape.iter().product();
    let reshapeable = pool.iter().position(|block| {
        block.data.len() == total_elements && block.data.is_standard_layout()
    });
    if let Some(position) = reshapeable {
        if let Some(block) = pool.remove(position) {
            if let Ok(mut reshaped) = block.data.into_shape_with_order(IxDyn(shape)) {
                reshaped.fill(T::zero());
                return Ok(Some(reshaped));
            }
        }
    }

    Ok(None)
}
/// Creates a brand-new zero-filled array of the given `shape`.
///
/// Every strategy currently produces the same result. The match is kept
/// exhaustive so that adding a variant forces a decision here.
fn allocate_new(&self, shape: &[usize]) -> RusTorchResult<ArrayD<T>> {
    let fresh = || -> ArrayD<T> { ArrayD::zeros(IxDyn(shape)) };

    match self.config.strategy {
        AllocationStrategy::NumaAware => Ok(fresh()),
        AllocationStrategy::FirstFit
        | AllocationStrategy::BestFit
        | AllocationStrategy::SizeClass => Ok(fresh()),
    }
}
/// Re-evaluates memory pressure at most once per `monitor_interval` and runs
/// garbage collection when the pressure exceeds `gc_threshold`.
///
/// Pressure is pooled bytes divided by `max_pool_memory`. The value is stored
/// in the monitor and published to the pool statistics (best-effort). A
/// poisoned monitor lock makes the check a no-op rather than an error.
///
/// # Errors
/// Propagates errors from the usage computation and from garbage collection.
fn check_memory_pressure(&self) -> RusTorchResult<()> {
    let pressure = {
        let mut monitor = match self.pressure_monitor.lock() {
            Ok(guard) => guard,
            // Monitoring is best-effort; skip the check on a poisoned lock.
            Err(_) => return Ok(()),
        };

        let now = Instant::now();
        if now.duration_since(monitor.last_check) <= self.config.monitor_interval {
            return Ok(());
        }

        let current_usage = self.calculate_current_usage()?;
        let budget = self.config.max_pool_memory;
        let pressure = if budget == 0 {
            // Avoid the 0/0 NaN: an empty pool exerts no pressure, while any
            // usage against a zero budget is unbounded pressure.
            if current_usage == 0 {
                0.0
            } else {
                f64::INFINITY
            }
        } else {
            current_usage as f64 / budget as f64
        };

        monitor.current_pressure = pressure;
        monitor.peak_usage = monitor.peak_usage.max(current_usage);
        monitor.last_check = now;
        pressure
        // The monitor guard is released here, before the potentially slow GC.
    };

    if let Ok(mut stats) = self.stats.write() {
        stats.memory_pressure = pressure;
    }

    if pressure > self.config.gc_threshold {
        self.garbage_collect()?;
    }
    Ok(())
}
/// Sums the data bytes (element count × `size_of::<T>()`) of every pooled block.
///
/// # Errors
/// Returns `RusTorchError::MemoryError` when the pool lock is poisoned.
fn calculate_current_usage(&self) -> RusTorchResult<usize> {
    let pools = self.pools.read().map_err(|_| {
        RusTorchError::MemoryError("Failed to acquire pool read lock".to_string())
    })?;

    let pooled_bytes = pools
        .values()
        .flat_map(|pool| pool.iter())
        .map(|block| block.data.len() * std::mem::size_of::<T>())
        .sum::<usize>();

    Ok(pooled_bytes)
}
/// Computes the value stored as `cache_hit_ratio`:
/// `(allocations − deallocations) / allocations`, clamped at zero.
///
/// Returns 0.0 when nothing has been allocated yet or when the stats lock is
/// poisoned. This method takes the stats read lock itself, so it must not be
/// called while the caller holds the stats write guard.
///
/// NOTE(review): this formula measures outstanding allocations, not actual
/// pool hits; a dedicated hit counter would be needed for a true hit ratio.
fn calculate_cache_hit_ratio(&self) -> f64 {
    let stats = match self.stats.read() {
        Ok(guard) => guard,
        Err(_) => return 0.0,
    };
    if stats.total_allocations == 0 {
        return 0.0;
    }
    // `deallocate` accepts arrays that were never handed out by `allocate`,
    // so deallocations can exceed allocations. Saturate instead of
    // underflowing (a panic in debug builds, a wrapped value in release).
    let outstanding = stats
        .total_allocations
        .saturating_sub(stats.total_deallocations);
    outstanding as f64 / stats.total_allocations as f64
}
/// Deduplication lookup hook.
///
/// Currently a stub: it ignores the requested shape and always returns
/// `Ok(None)`, so `allocate` always falls through to a fresh allocation.
fn check_deduplication(&self, _shape: &[usize]) -> RusTorchResult<Option<ArrayD<T>>> {
    Ok(None)
}
/// Deduplication registration hook.
///
/// Currently a stub: the array is ignored, nothing is inserted into the
/// cache, and the call always returns `Ok(())`.
fn add_to_dedup_cache(&self, _array: &ArrayD<T>) -> RusTorchResult<()> {
Ok(())
}
}
impl Default for EnhancedPoolStats {
fn default() -> Self {
Self {
total_allocated: 0,
total_pooled: 0,
active_size_classes: 0,
total_allocations: 0,
total_deallocations: 0,
cache_hit_ratio: 0.0,
memory_pressure: 0.0,
gc_stats: GcStats::default(),
dedup_stats: DeduplicationStats::default(),
}
}
}
impl Default for GcStats {
fn default() -> Self {
Self {
gc_runs: 0,
memory_reclaimed: 0,
avg_gc_duration: Duration::from_millis(0),
last_gc_time: None,
}
}
}
impl Default for DeduplicationStats {
fn default() -> Self {
Self {
duplicates_found: 0,
memory_saved: 0,
hit_ratio: 0.0,
}
}
}
/// Multi-line, human-readable report of the statistics; ratios are printed
/// as percentages with two decimals.
impl std::fmt::Display for EnhancedPoolStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let gc = &self.gc_stats;
        let dedup = &self.dedup_stats;
        let hit_percent = self.cache_hit_ratio * 100.0;
        let pressure_percent = self.memory_pressure * 100.0;

        writeln!(f, "Enhanced Memory Pool Statistics:")?;
        writeln!(f, " Total Allocated: {} bytes", self.total_allocated)?;
        writeln!(f, " Total Pooled: {} bytes", self.total_pooled)?;
        writeln!(f, " Active Size Classes: {}", self.active_size_classes)?;
        writeln!(f, " Total Allocations: {}", self.total_allocations)?;
        writeln!(f, " Total Deallocations: {}", self.total_deallocations)?;
        writeln!(f, " Cache Hit Ratio: {:.2}%", hit_percent)?;
        writeln!(f, " Memory Pressure: {:.2}%", pressure_percent)?;
        writeln!(f, " GC Runs: {}", gc.gc_runs)?;
        writeln!(f, " Memory Reclaimed: {} bytes", gc.memory_reclaimed)?;
        writeln!(f, " Deduplication Hits: {}", dedup.duplicates_found)?;
        // The last line's result doubles as the function result.
        writeln!(f, " Memory Saved: {} bytes", dedup.memory_saved)
    }
}
#[cfg(test)]
mod tests {
use super::*;
// A freshly constructed pool reports zero allocations.
#[test]
fn test_enhanced_pool_creation() {
    let pool = EnhancedMemoryPool::<f32>::new(PoolConfig::default());

    let snapshot = pool.get_stats().unwrap();

    assert_eq!(snapshot.total_allocations, 0);
}
// One allocate/deallocate round trip yields the requested shape and bumps
// each counter exactly once.
#[test]
fn test_allocation_and_deallocation() {
    let pool = EnhancedMemoryPool::<f32>::new(PoolConfig::default());

    let buffer = pool.allocate(&[3, 4]).unwrap();
    assert_eq!(buffer.shape(), &[3, 4]);
    pool.deallocate(buffer).unwrap();

    let snapshot = pool.get_stats().unwrap();
    assert_eq!(snapshot.total_allocations, 1);
    assert_eq!(snapshot.total_deallocations, 1);
}
// Smoke test: a single allocate/deallocate round trip whose results are
// deliberately ignored. Skipped when the `CI` environment variable is set and
// compiled out under the `ci-fast` feature.
// NOTE(review): despite the name, no second allocation is made, so actual
// reuse of the pooled block is not asserted here.
#[test]
#[cfg(not(feature = "ci-fast"))]
fn test_memory_reuse() {
    let running_on_ci = std::env::var("CI").is_ok();
    if running_on_ci {
        return;
    }

    let pool = EnhancedMemoryPool::<f32>::new(PoolConfig::default());
    if let Ok(buffer) = pool.allocate(&[1, 1]) {
        let _ = pool.deallocate(buffer);
    }
}
// Runs two allocate/deallocate round trips of a 2x2 array. Skipped when the
// `CI` environment variable is set and compiled out under the `ci-fast`
// feature.
// NOTE(review): despite the name, `garbage_collect` is never called here.
#[test]
#[cfg(not(feature = "ci-fast"))]
fn test_garbage_collection() {
    let running_on_ci = std::env::var("CI").is_ok();
    if running_on_ci {
        return;
    }

    let pool = EnhancedMemoryPool::<f32>::new(PoolConfig::default());
    let mut round = 0;
    while round < 2 {
        let buffer = pool.allocate(&[2, 2]).unwrap();
        pool.deallocate(buffer).unwrap();
        round += 1;
    }
}
// Shapes differing in rank, or in element-count tier, land in different
// buckets.
#[test]
fn test_size_class_indexing() {
    let vector_bucket = SizeClass::from_shape(&[10]).index();
    let small_matrix_bucket = SizeClass::from_shape(&[3, 3]).index();
    let large_matrix_bucket = SizeClass::from_shape(&[100, 100]).index();

    assert!(vector_bucket != small_matrix_bucket);
    assert!(small_matrix_bucket != large_matrix_bucket);
}
}