#![allow(dead_code)]
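//! Advanced memory management for tensor workloads: size-class pooling with
//! reuse, simulated compression for very large blocks, optional NUMA-aware
//! placement, and a simple allocation-pattern predictor.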
use std::alloc::{GlobalAlloc, Layout, System};
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::mem::{align_of, size_of};
use std::ptr::NonNull;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use torsh_core::{
dtype::TensorElement,
error::{Result, TorshError},
};
/// Tuning knobs for [`AdvancedMemoryPool`].
#[derive(Debug, Clone)]
pub struct MemoryConfig {
    /// Reuse freed blocks through per-size free lists.
    pub enable_pooling: bool,
    /// Total pool budget in bytes.
    pub pool_size: usize,
    /// Maximum number of cached blocks per size class.
    pub max_cached_per_size: usize,
    /// Route very large allocations through the compression manager.
    pub enable_compression: bool,
    /// Minimum allocation size (in elements) eligible for compression.
    pub compression_threshold: usize,
    /// Spread allocations across detected NUMA nodes.
    pub enable_numa_awareness: bool,
    /// Alignment and size-class granularity, in bytes.
    pub cache_line_size: usize,
    /// Speculatively pre-allocate sizes the predictor expects next.
    pub enable_predictive_allocation: bool,
    /// Pool hit-rate threshold in `[0, 1]` used by the pressure heuristic.
    pub memory_pressure_threshold: f64,
}
impl Default for MemoryConfig {
fn default() -> Self {
Self {
enable_pooling: true,
            pool_size: 1024 * 1024 * 1024, // 1 GiB
            max_cached_per_size: 64,
enable_compression: true,
            compression_threshold: 100 * 1024 * 1024, // 100 Mi elements
            enable_numa_awareness: false,
            cache_line_size: 64,
enable_predictive_allocation: true,
memory_pressure_threshold: 0.8,
}
}
}
/// A size-class pooled allocator for tensor buffers, with optional
/// compression for huge blocks, NUMA-aware placement, and predictive
/// pre-allocation.
pub struct AdvancedMemoryPool<T: TensorElement> {
    config: MemoryConfig,
    /// Free lists keyed by cache-line-aligned element count.
    size_class_pools: RwLock<BTreeMap<usize, VecDeque<NonNull<T>>>>,
    stats: RwLock<MemoryStats>,
    /// Ring buffer of recent allocations (bounded at 10_000 records).
    allocation_history: Mutex<VecDeque<AllocationRecord>>,
    /// Lazily initialized allocation-pattern predictor.
    predictor: Mutex<Option<AllocationPredictor>>,
    compression_manager: Arc<CompressionManager>,
    numa_allocators: Vec<Arc<Mutex<NumaAllocator>>>,
}
impl<T: TensorElement> AdvancedMemoryPool<T> {
pub fn new() -> Self {
Self::with_config(MemoryConfig::default())
}
pub fn with_config(config: MemoryConfig) -> Self {
let numa_nodes = if config.enable_numa_awareness {
detect_numa_nodes()
} else {
1
};
let numa_allocators = (0..numa_nodes)
.map(|node_id| Arc::new(Mutex::new(NumaAllocator::new(node_id))))
.collect();
Self {
config,
size_class_pools: RwLock::new(BTreeMap::new()),
stats: RwLock::new(MemoryStats::default()),
allocation_history: Mutex::new(VecDeque::with_capacity(10000)),
predictor: Mutex::new(None),
compression_manager: Arc::new(CompressionManager::new()),
numa_allocators,
}
}
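    /// Allocate space for `size` elements of `T`, preferring pooled blocks
    /// over fresh system allocations.
    ///
    /// A minimal usage sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let pool: AdvancedMemoryPool<f32> = AdvancedMemoryPool::new();
    /// let ptr = pool.allocate(1024)?;
    /// // ... fill the buffer ...
    /// pool.deallocate(ptr, 1024)?;
    /// ```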
pub fn allocate(&self, size: usize) -> Result<NonNull<T>> {
#[cfg(feature = "profiling")]
{
}
let aligned_size = self.align_size(size);
if self.config.enable_compression && size > self.config.compression_threshold {
return self.allocate_compressed(aligned_size);
}
if let Some(ptr) = self.try_reuse_from_pool(aligned_size)? {
self.record_allocation(aligned_size, true);
return Ok(ptr);
}
if self.config.enable_predictive_allocation {
self.maybe_predictive_allocate(aligned_size)?;
}
let ptr = self.allocate_new(aligned_size)?;
self.record_allocation(aligned_size, false);
Ok(ptr)
}
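    /// Return a block previously obtained from [`allocate`](Self::allocate),
    /// passing the same `size`; small blocks are cached for reuse, large or
    /// compressed ones are released immediately.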
pub fn deallocate(&self, ptr: NonNull<T>, size: usize) -> Result<()> {
#[cfg(feature = "profiling")]
{
}
let aligned_size = self.align_size(size);
if self.compression_manager.is_compressed(ptr) {
return self.compression_manager.deallocate(ptr);
}
if self.should_cache_allocation(aligned_size) {
let mut pools = self
.size_class_pools
.write()
.expect("lock should not be poisoned");
let pool = pools.entry(aligned_size).or_insert_with(VecDeque::new);
if pool.len() < self.config.max_cached_per_size {
pool.push_back(ptr);
self.update_stats(|stats| stats.pooled_allocations += 1);
return Ok(());
}
}
self.free_allocation(ptr, aligned_size)?;
Ok(())
}
    fn try_reuse_from_pool(&self, size: usize) -> Result<Option<NonNull<T>>> {
        if !self.config.enable_pooling {
            return Ok(None);
        }
        let mut pools = self
            .size_class_pools
            .write()
            .expect("lock should not be poisoned");
        if let Some(pool) = pools.get_mut(&size) {
            if let Some(ptr) = pool.pop_front() {
                self.update_stats(|stats| stats.pool_hits += 1);
                return Ok(Some(ptr));
            }
        }
        // Only exact size-class matches are reused. Handing out a block from
        // a *larger* class is unsound here, because `free_allocation` rebuilds
        // the layout from the caller's requested size and `System.dealloc`
        // must see the layout the block was actually allocated with.
        // Cross-size reuse (and the `oversized_reuse` counter) needs a
        // per-pointer size table before it can be enabled.
        self.update_stats(|stats| stats.pool_misses += 1);
        Ok(None)
    }
fn allocate_new(&self, size: usize) -> Result<NonNull<T>> {
let layout = Layout::from_size_align(
size * size_of::<T>(),
align_of::<T>().max(self.config.cache_line_size),
)
.map_err(|_| TorshError::InvalidArgument("Invalid memory layout".to_string()))?;
if self.config.enable_numa_awareness && !self.numa_allocators.is_empty() {
let numa_node = self.select_numa_node();
let allocator = &self.numa_allocators[numa_node];
let mut allocator = allocator.lock().expect("lock should not be poisoned");
return allocator.allocate(layout);
}
unsafe {
let ptr = System.alloc(layout);
if ptr.is_null() {
return Err(TorshError::AllocationError(
"Failed to allocate memory".to_string(),
));
}
self.prefault_pages(ptr, layout.size());
Ok(NonNull::new_unchecked(ptr as *mut T))
}
}
fn allocate_compressed(&self, size: usize) -> Result<NonNull<T>> {
self.compression_manager.allocate_compressed(size)
}
fn free_allocation(&self, ptr: NonNull<T>, size: usize) -> Result<()> {
let layout = Layout::from_size_align(
size * size_of::<T>(),
align_of::<T>().max(self.config.cache_line_size),
)
.map_err(|_| TorshError::InvalidArgument("Invalid memory layout".to_string()))?;
unsafe {
System.dealloc(ptr.as_ptr() as *mut u8, layout);
}
self.update_stats(|stats| stats.direct_deallocations += 1);
Ok(())
}
fn maybe_predictive_allocate(&self, size: usize) -> Result<()> {
let mut predictor_guard = self.predictor.lock().expect("lock should not be poisoned");
if predictor_guard.is_none() {
*predictor_guard = Some(AllocationPredictor::new());
}
        if let Some(predictor) = predictor_guard.as_mut() {
            if let Some(predicted_sizes) = predictor.predict_next_allocations(size) {
                for predicted_size in predicted_sizes {
                    if predicted_size != size && predicted_size > 0 {
                        // Park the speculative block in the pool so the
                        // predicted request becomes a pool hit; discarding
                        // the pointer here would leak the allocation.
                        if let Ok(ptr) = self.allocate_new(predicted_size) {
                            let mut pools = self
                                .size_class_pools
                                .write()
                                .expect("lock should not be poisoned");
                            pools.entry(predicted_size).or_default().push_back(ptr);
                        }
                    }
                }
            }
        }
Ok(())
}
    fn align_size(&self, size: usize) -> usize {
        // Round the element count up to the next multiple of the cache line
        // size; this also defines the size-class bucket width.
        let cache_line = self.config.cache_line_size;
        size.div_ceil(cache_line) * cache_line
    }
    fn should_cache_allocation(&self, size: usize) -> bool {
        // Cache only small blocks (at most 1% of the pool budget) and only
        // while the pool is not under memory pressure.
        self.config.enable_pooling
            && size <= self.config.pool_size / 100
            && !self.is_memory_pressure_high()
    }
    fn is_memory_pressure_high(&self) -> bool {
        let stats = self.stats.read().expect("lock should not be poisoned");
        let total_requests = stats.pool_hits + stats.pool_misses;
        // Warm-up guard (the 100-request window is an arbitrary choice): the
        // hit rate starts at zero, so without it the pool would report
        // pressure from the very first allocation and never cache anything.
        if total_requests < 100 {
            return false;
        }
        // Proxy heuristic: a low pool hit rate is treated as pressure. With
        // the default threshold of 0.8 this trips once the hit rate < 0.2.
        let cache_hit_rate = stats.pool_hits as f64 / total_requests as f64;
        cache_hit_rate < (1.0 - self.config.memory_pressure_threshold)
    }
    /// Touch the first byte of every page so the OS commits physical memory
    /// up front, avoiding page-fault stalls on first use. Assumes 4 KiB
    /// pages; larger-page platforms just get a few redundant writes.
    fn prefault_pages(&self, ptr: *mut u8, size: usize) {
        const PAGE_SIZE: usize = 4096;
        let page_count = size.div_ceil(PAGE_SIZE);
        unsafe {
            for i in 0..page_count {
                let page_ptr = ptr.add(i * PAGE_SIZE);
                std::ptr::write_volatile(page_ptr, 0);
            }
        }
    }
    fn select_numa_node(&self) -> usize {
        // Round-robin over NUMA nodes, keyed off the running allocation count.
        let stats = self.stats.read().expect("lock should not be poisoned");
        stats.total_allocations % self.numa_allocators.len()
    }
fn record_allocation(&self, size: usize, was_reused: bool) {
let record = AllocationRecord {
size,
timestamp: Instant::now(),
was_reused,
};
let mut history = self
.allocation_history
.lock()
.expect("lock should not be poisoned");
history.push_back(record);
if history.len() > 10000 {
history.pop_front();
}
self.update_stats(|stats| {
stats.total_allocations += 1;
if was_reused {
stats.reused_allocations += 1;
} else {
stats.direct_allocations += 1;
}
});
}
fn update_stats<F>(&self, f: F)
where
F: FnOnce(&mut MemoryStats),
{
let mut stats = self.stats.write().expect("lock should not be poisoned");
f(&mut *stats);
}
pub fn get_stats(&self) -> MemoryStats {
self.stats
.read()
.expect("lock should not be poisoned")
.clone()
}
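    /// Drop empty size-class free lists and, when compression is enabled,
    /// ask the compression manager to compact, returning a work summary.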
pub fn defragment(&self) -> Result<DefragmentationReport> {
#[cfg(feature = "profiling")]
{
}
let start_time = Instant::now();
let mut report = DefragmentationReport::default();
{
let mut pools = self
.size_class_pools
.write()
.expect("lock should not be poisoned");
let initial_pools = pools.len();
pools.retain(|_, pool| !pool.is_empty());
report.pools_cleaned = initial_pools - pools.len();
}
if self.config.enable_compression {
report.compression_stats = self.compression_manager.compress_fragmented()?;
}
report.duration = start_time.elapsed();
report.memory_freed = self.estimate_memory_freed();
Ok(report)
}
    fn estimate_memory_freed(&self) -> usize {
        // Rough heuristic: assume each non-reused allocation averaged ~1 KiB.
        let stats = self.stats.read().expect("lock should not be poisoned");
        stats
            .total_allocations
            .saturating_sub(stats.reused_allocations)
            * 1024
    }
}
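// A best-effort Drop impl (a sketch; not in the original): without it, blocks
// still parked in `size_class_pools` leak when the pool itself is dropped.
// It assumes every cached block was allocated with the cache-line-aligned
// layout that `free_allocation` reconstructs, which holds for both the
// direct and the predictive allocation paths above.
impl<T: TensorElement> Drop for AdvancedMemoryPool<T> {
    fn drop(&mut self) {
        // Drain the free lists first so the lock guard is released before
        // `free_allocation` (which takes other locks) is called.
        let cached: Vec<(usize, NonNull<T>)> = {
            let mut pools = self
                .size_class_pools
                .write()
                .expect("lock should not be poisoned");
            pools
                .iter_mut()
                .flat_map(|(&size, pool)| pool.drain(..).map(move |ptr| (size, ptr)))
                .collect()
        };
        for (size, ptr) in cached {
            // Best effort during teardown; errors have nowhere to go.
            let _ = self.free_allocation(ptr, size);
        }
    }
}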
impl<T: TensorElement> Default for AdvancedMemoryPool<T> {
fn default() -> Self {
Self::new()
}
}
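/// Monotonically increasing counters describing pool behavior.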
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
pub total_allocations: usize,
pub direct_allocations: usize,
pub reused_allocations: usize,
pub pooled_allocations: usize,
pub pool_hits: usize,
pub pool_misses: usize,
pub oversized_reuse: usize,
pub direct_deallocations: usize,
pub compression_saves: usize,
pub numa_allocations: usize,
}
impl MemoryStats {
pub fn hit_rate(&self) -> f64 {
let total_pool_requests = self.pool_hits + self.pool_misses;
if total_pool_requests == 0 {
0.0
} else {
self.pool_hits as f64 / total_pool_requests as f64
}
}
pub fn reuse_rate(&self) -> f64 {
if self.total_allocations == 0 {
0.0
} else {
self.reused_allocations as f64 / self.total_allocations as f64
}
}
}
#[derive(Debug, Clone)]
struct AllocationRecord {
size: usize,
timestamp: Instant,
was_reused: bool,
}
struct AllocationPredictor {
size_patterns: HashMap<usize, Vec<usize>>,
temporal_patterns: VecDeque<(Instant, usize)>,
max_history: usize,
}
impl AllocationPredictor {
fn new() -> Self {
Self {
size_patterns: HashMap::new(),
temporal_patterns: VecDeque::new(),
max_history: 1000,
}
}
    fn predict_next_allocations(&mut self, size: usize) -> Option<Vec<usize>> {
        // Record the previous-size -> current-size transition; without this
        // step `size_patterns` stays empty and no prediction is ever made.
        if let Some(&(_, prev_size)) = self.temporal_patterns.back() {
            self.size_patterns.entry(prev_size).or_default().push(size);
        }
        self.temporal_patterns.push_back((Instant::now(), size));
        if self.temporal_patterns.len() > self.max_history {
            self.temporal_patterns.pop_front();
        }
if let Some(following_sizes) = self.size_patterns.get(&size) {
let mut counts: HashMap<usize, usize> = HashMap::new();
for &following_size in following_sizes {
*counts.entry(following_size).or_insert(0) += 1;
}
let mut sorted: Vec<_> = counts.into_iter().collect();
sorted.sort_by(|a, b| b.1.cmp(&a.1));
Some(sorted.into_iter().take(3).map(|(size, _)| size).collect())
} else {
None
}
}
}
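/// Tracks compressed allocations by pointer address. Compression itself is
/// currently simulated; see `allocate_compressed` below.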
struct CompressionManager {
compressed_allocations: RwLock<HashMap<usize, CompressedAllocation>>,
}
impl CompressionManager {
fn new() -> Self {
Self {
compressed_allocations: RwLock::new(HashMap::new()),
}
}
    fn allocate_compressed<T: TensorElement>(&self, size: usize) -> Result<NonNull<T>> {
        // Placeholder: no real compression happens yet. Allocating only half
        // the bytes (the hard-coded 2:1 ratio) would make the returned
        // pointer unsound for `size` elements, so the full backing store is
        // allocated and the ratio is recorded as metadata only. Note that
        // `size` is an element count, not a byte count.
        let byte_size = size * size_of::<T>();
        let layout = Layout::from_size_align(byte_size, align_of::<T>())
            .map_err(|_| TorshError::InvalidArgument("Invalid layout".to_string()))?;
        unsafe {
            let ptr = System.alloc(layout);
            if ptr.is_null() {
                return Err(TorshError::AllocationError(
                    "Compression allocation failed".to_string(),
                ));
            }
            let allocation = CompressedAllocation {
                original_size: byte_size,
                compressed_size: byte_size,
                compression_ratio: 0.5, // simulated target, not yet applied
            };
self.compressed_allocations
.write()
.expect("rwlock should not be poisoned")
.insert(ptr as usize, allocation);
Ok(NonNull::new_unchecked(ptr as *mut T))
}
}
fn is_compressed<T: TensorElement>(&self, ptr: NonNull<T>) -> bool {
self.compressed_allocations
.read()
.expect("rwlock should not be poisoned")
.contains_key(&(ptr.as_ptr() as usize))
}
fn deallocate<T: TensorElement>(&self, ptr: NonNull<T>) -> Result<()> {
let ptr_key = ptr.as_ptr() as usize;
let mut allocations = self
.compressed_allocations
.write()
.expect("lock should not be poisoned");
if let Some(allocation) = allocations.remove(&ptr_key) {
let layout = Layout::from_size_align(allocation.compressed_size, align_of::<T>())
.map_err(|_| TorshError::InvalidArgument("Invalid layout".to_string()))?;
unsafe {
System.dealloc(ptr_key as *mut u8, layout);
}
Ok(())
} else {
Err(TorshError::InvalidArgument(
"Allocation not found".to_string(),
))
}
}
fn compress_fragmented(&self) -> Result<CompressionStats> {
Ok(CompressionStats {
allocations_compressed: 0,
memory_saved: 0,
average_compression_ratio: 0.0,
})
}
}
#[derive(Debug, Clone)]
struct CompressedAllocation {
original_size: usize,
compressed_size: usize,
compression_ratio: f64,
}
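/// Per-NUMA-node allocator. For now it falls back to the system allocator
/// and only counts allocations; real node pinning needs OS support
/// (e.g. `mbind`/`numa_alloc_onnode` on Linux).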
struct NumaAllocator {
node_id: usize,
allocations: usize,
}
impl NumaAllocator {
fn new(node_id: usize) -> Self {
Self {
node_id,
allocations: 0,
}
}
fn allocate<T: TensorElement>(&mut self, layout: Layout) -> Result<NonNull<T>> {
unsafe {
let ptr = System.alloc(layout);
if ptr.is_null() {
return Err(TorshError::AllocationError(
"NUMA allocation failed".to_string(),
));
}
self.allocations += 1;
Ok(NonNull::new_unchecked(ptr as *mut T))
}
}
}
#[derive(Debug, Default)]
pub struct DefragmentationReport {
pub duration: Duration,
pub pools_cleaned: usize,
pub memory_freed: usize,
pub compression_stats: CompressionStats,
}
#[derive(Debug, Default)]
pub struct CompressionStats {
pub allocations_compressed: usize,
pub memory_saved: usize,
pub average_compression_ratio: f64,
}
fn detect_numa_nodes() -> usize {
    // Placeholder: report a single node until platform detection (e.g. via
    // libnuma or hwloc) is wired in.
    1
}
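/// Typed memory pools for the common tensor dtypes behind one facade.
///
/// A minimal usage sketch (not compiled as a doctest):
///
/// ```ignore
/// let optimizer = GlobalMemoryOptimizer::new();
/// let _reports = optimizer.global_defragmentation()?;
/// let stats = optimizer.get_aggregate_stats();
/// println!("overall hit rate: {:.2}", stats.overall_hit_rate());
/// ```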
pub struct GlobalMemoryOptimizer {
f32_pool: AdvancedMemoryPool<f32>,
f64_pool: AdvancedMemoryPool<f64>,
i32_pool: AdvancedMemoryPool<i32>,
i64_pool: AdvancedMemoryPool<i64>,
config: MemoryConfig,
}
impl GlobalMemoryOptimizer {
pub fn new() -> Self {
let config = MemoryConfig::default();
Self::with_config(config)
}
pub fn with_config(config: MemoryConfig) -> Self {
Self {
f32_pool: AdvancedMemoryPool::with_config(config.clone()),
f64_pool: AdvancedMemoryPool::with_config(config.clone()),
i32_pool: AdvancedMemoryPool::with_config(config.clone()),
i64_pool: AdvancedMemoryPool::with_config(config.clone()),
config,
}
}
    pub fn get_pool<T: TensorElement>(&self) -> Option<&AdvancedMemoryPool<T>> {
        // Typed lookup needs a way to match `T` against the concrete pools
        // (e.g. `Any`-based downcasting); until then callers receive `None`
        // and should use the dtype-specific fields directly.
        None
    }
    pub fn global_defragmentation(&self) -> Result<Vec<DefragmentationReport>> {
        // Defragment every typed pool, not just the float pools.
        let mut reports = Vec::new();
        reports.push(self.f32_pool.defragment()?);
        reports.push(self.f64_pool.defragment()?);
        reports.push(self.i32_pool.defragment()?);
        reports.push(self.i64_pool.defragment()?);
        Ok(reports)
    }
pub fn get_aggregate_stats(&self) -> AggregateMemoryStats {
AggregateMemoryStats {
f32_stats: self.f32_pool.get_stats(),
f64_stats: self.f64_pool.get_stats(),
i32_stats: self.i32_pool.get_stats(),
i64_stats: self.i64_pool.get_stats(),
}
}
}
impl Default for GlobalMemoryOptimizer {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
pub struct AggregateMemoryStats {
pub f32_stats: MemoryStats,
pub f64_stats: MemoryStats,
pub i32_stats: MemoryStats,
pub i64_stats: MemoryStats,
}
impl AggregateMemoryStats {
pub fn overall_hit_rate(&self) -> f64 {
let total_hits = self.f32_stats.pool_hits
+ self.f64_stats.pool_hits
+ self.i32_stats.pool_hits
+ self.i64_stats.pool_hits;
let total_misses = self.f32_stats.pool_misses
+ self.f64_stats.pool_misses
+ self.i32_stats.pool_misses
+ self.i64_stats.pool_misses;
let total_requests = total_hits + total_misses;
if total_requests == 0 {
0.0
} else {
total_hits as f64 / total_requests as f64
}
}
pub fn total_allocations(&self) -> usize {
self.f32_stats.total_allocations
+ self.f64_stats.total_allocations
+ self.i32_stats.total_allocations
+ self.i64_stats.total_allocations
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_memory_config_default() {
let config = MemoryConfig::default();
assert!(config.enable_pooling);
assert!(config.pool_size > 0);
assert!(config.cache_line_size > 0);
}
#[test]
fn test_advanced_memory_pool_creation() {
let pool: AdvancedMemoryPool<f32> = AdvancedMemoryPool::new();
let stats = pool.get_stats();
assert_eq!(stats.total_allocations, 0);
assert_eq!(stats.pool_hits, 0);
assert_eq!(stats.pool_misses, 0);
}
#[test]
fn test_memory_allocation_and_deallocation() {
let pool: AdvancedMemoryPool<f32> = AdvancedMemoryPool::new();
let ptr = pool.allocate(1024).expect("allocation should succeed");
pool.deallocate(ptr, 1024)
.expect("deallocation should succeed");
let stats = pool.get_stats();
assert_eq!(stats.total_allocations, 1);
}
#[test]
fn test_memory_pool_reuse() {
let pool: AdvancedMemoryPool<f32> = AdvancedMemoryPool::new();
let ptr1 = pool.allocate(1024).expect("allocation should succeed");
pool.deallocate(ptr1, 1024)
.expect("deallocation should succeed");
let ptr2 = pool.allocate(1024).expect("allocation should succeed");
pool.deallocate(ptr2, 1024)
.expect("deallocation should succeed");
let stats = pool.get_stats();
assert_eq!(stats.total_allocations, 2);
}
#[test]
    fn test_memory_stats_calculations() {
        let stats = MemoryStats {
            pool_hits: 80,
            pool_misses: 20,
            total_allocations: 100,
            reused_allocations: 80,
            ..Default::default()
        };
        assert_eq!(stats.hit_rate(), 0.8);
        assert_eq!(stats.reuse_rate(), 0.8);
    }
#[test]
fn test_size_alignment() {
let pool: AdvancedMemoryPool<f32> = AdvancedMemoryPool::with_config(MemoryConfig {
cache_line_size: 64,
..Default::default()
});
assert_eq!(pool.align_size(1), 64);
assert_eq!(pool.align_size(65), 128);
assert_eq!(pool.align_size(128), 128);
}
#[test]
fn test_defragmentation() {
let pool: AdvancedMemoryPool<f32> = AdvancedMemoryPool::new();
for i in 0..10 {
let ptr = pool
.allocate(1024 * (i + 1))
.expect("allocation should succeed");
pool.deallocate(ptr, 1024 * (i + 1))
.expect("deallocation should succeed");
}
        let report = pool.defragment().expect("defragmentation should succeed");
        let _ = report.duration;
    }
#[test]
fn test_global_memory_optimizer() {
let optimizer = GlobalMemoryOptimizer::new();
let stats = optimizer.get_aggregate_stats();
assert_eq!(stats.total_allocations(), 0);
assert_eq!(stats.overall_hit_rate(), 0.0);
}
#[test]
fn test_compression_manager() {
let manager = CompressionManager::new();
        let ptr = NonNull::<f32>::dangling();
assert!(!manager.is_compressed(ptr));
}
#[test]
fn test_allocation_predictor() {
let mut predictor = AllocationPredictor::new();
let predictions = predictor.predict_next_allocations(1024);
assert!(predictions.is_none());
}
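    #[test]
    fn test_allocation_predictor_learns_transitions() {
        // Added test (a sketch) for the transition-recording fix in
        // `predict_next_allocations`: after observing 1024 -> 2048, a later
        // 1024 request should suggest 2048 as a likely follow-up.
        let mut predictor = AllocationPredictor::new();
        let _ = predictor.predict_next_allocations(1024);
        let _ = predictor.predict_next_allocations(2048);
        let predictions = predictor
            .predict_next_allocations(1024)
            .expect("pattern should have been recorded");
        assert!(predictions.contains(&2048));
    }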
#[test]
fn test_memory_pressure_detection() {
let pool: AdvancedMemoryPool<f32> = AdvancedMemoryPool::with_config(MemoryConfig {
memory_pressure_threshold: 0.5,
..Default::default()
});
assert!(!pool.is_memory_pressure_high());
}
}