use crate::*;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use torsh_core::{DType, Result as TorshResult, Shape};
/// Lazily-initialized process-global pool, shared by `SparseMemoryManager`
/// and by `SparseMemoryHandle`'s `Drop` impl.
static MEMORY_POOL: std::sync::OnceLock<Arc<RwLock<SparseMemoryPool>>> = std::sync::OnceLock::new();

/// Returns the global pool, creating it with the default config on first use.
fn get_memory_pool() -> &'static Arc<RwLock<SparseMemoryPool>> {
    MEMORY_POOL.get_or_init(|| Arc::new(RwLock::new(SparseMemoryPool::new())))
}
/// Tunable capacity limits and timing knobs for [`SparseMemoryPool`].
#[derive(Debug, Clone)]
pub struct MemoryPoolConfig {
    /// Upper bound, in bytes, on memory handed out by the pool at once.
    pub max_total_memory: usize,
    /// Largest single allocation, in bytes, the pool will accept.
    pub max_allocation_size: usize,
    /// Maximum number of doubling size-class buckets to build.
    pub num_size_buckets: usize,
    /// Idle time after which a pooled (freed) block becomes GC-eligible.
    pub memory_timeout: Duration,
    /// Whether to track statistics.
    /// NOTE(review): not consulted anywhere in this file — stats are always
    /// recorded; confirm whether this flag is used elsewhere.
    pub enable_tracking: bool,
    /// Minimum elapsed time between automatic garbage collections.
    pub gc_interval: Duration,
}
impl Default for MemoryPoolConfig {
    fn default() -> Self {
        Self {
            max_total_memory: 1024 * 1024 * 1024,     // 1 GiB total budget
            max_allocation_size: 128 * 1024 * 1024,   // 128 MiB per allocation
            num_size_buckets: 16,
            memory_timeout: Duration::from_secs(300), // 5 min idle before GC-eligible
            enable_tracking: true,
            gc_interval: Duration::from_secs(60),     // at most one auto-GC per minute
        }
    }
}
/// One size class of the pool: keeps freed blocks of a given size range
/// around for reuse.
#[derive(Debug)]
struct MemoryBucket {
    /// Inclusive (min, max) request size this bucket serves.
    size_range: (usize, usize),
    /// Freed blocks awaiting reuse, in FIFO order.
    available_blocks: VecDeque<MemoryBlock>,
    /// Bytes of backing memory this bucket has allocated (live + pooled).
    total_allocated: usize,
    /// Number of blocks currently handed out from this bucket.
    active_allocations: usize,
}
/// A zero-initialized byte buffer plus bookkeeping metadata.
#[derive(Debug)]
struct MemoryBlock {
    /// Backing bytes; sized to the bucket's upper bound for bucketed
    /// allocations, or to the exact request for direct (unbucketed) ones.
    memory: Vec<u8>,
    /// When the buffer was first allocated.
    allocated_at: Instant,
    /// Last time the block was handed out or accessed through a handle.
    last_accessed: Instant,
    /// Reference count; set to 1 on (re)allocation and never incremented
    /// anywhere in this file.
    ref_count: usize,
    /// Allocation id, diagnostic only.
    #[allow(dead_code)]
    id: u64,
}
/// Bucketed pool of reusable byte blocks for sparse-tensor workloads.
#[derive(Debug)]
pub struct SparseMemoryPool {
    /// Capacity limits and GC timing.
    config: MemoryPoolConfig,
    /// Doubling size-class buckets (see `create_buckets`).
    buckets: Vec<MemoryBucket>,
    /// Bytes currently handed out to live handles.
    total_allocated: usize,
    /// Running counters exposed via `statistics()`.
    stats: MemoryStatistics,
    /// Time of the last completed garbage collection.
    last_gc: Instant,
    /// Monotonically increasing id assigned to each allocation.
    next_block_id: u64,
    /// Block ids grouped by allocation-type tag. Written on every
    /// allocation; never read back in this file.
    allocation_cache: HashMap<String, Vec<u64>>,
}
/// Cumulative and instantaneous counters for pool activity.
#[derive(Debug, Clone)]
pub struct MemoryStatistics {
    /// Cumulative bytes ever allocated (includes bucket round-up).
    pub total_allocated_bytes: usize,
    /// Cumulative bytes ever returned via `deallocate`.
    pub total_deallocated_bytes: usize,
    /// Handles currently outstanding.
    pub active_allocations: usize,
    /// Bytes currently handed out (mirrors the pool's `total_allocated`).
    pub current_memory_usage: usize,
    /// Highest value `current_memory_usage` has reached.
    pub peak_memory_usage: usize,
    /// Number of successful allocation calls.
    pub allocation_requests: usize,
    /// Blocks returned to a bucket for future reuse (see `deallocate`).
    pub pool_reuses: usize,
    /// Number of garbage-collection passes executed.
    pub garbage_collections: usize,
    /// Mean bytes per allocation request.
    pub average_allocation_size: f64,
}
impl Default for MemoryStatistics {
    /// All counters start at zero.
    // NOTE(review): every field is its type's own default, so this hand-written
    // impl could be replaced by `#[derive(Default)]` on the struct.
    fn default() -> Self {
        Self {
            total_allocated_bytes: 0,
            total_deallocated_bytes: 0,
            active_allocations: 0,
            current_memory_usage: 0,
            peak_memory_usage: 0,
            allocation_requests: 0,
            pool_reuses: 0,
            garbage_collections: 0,
            average_allocation_size: 0.0,
        }
    }
}
impl MemoryBucket {
    /// Creates an empty bucket serving sizes in the inclusive `size_range`.
    fn new(size_range: (usize, usize)) -> Self {
        Self {
            size_range,
            available_blocks: VecDeque::new(),
            total_allocated: 0,
            active_allocations: 0,
        }
    }

    /// Whether `size` falls inside this bucket's inclusive range.
    fn can_handle(&self, size: usize) -> bool {
        size >= self.size_range.0 && size <= self.size_range.1
    }

    /// Hands out a block for `size` bytes: reuses a pooled block when one is
    /// available, otherwise allocates a fresh zeroed block at the bucket's
    /// upper bound. Returns `None` when `size` is outside the bucket's range.
    fn allocate(&mut self, size: usize, block_id: u64) -> Option<MemoryBlock> {
        if !self.can_handle(size) {
            return None;
        }
        if let Some(mut block) = self.available_blocks.pop_front() {
            // Reuse path: refresh access time and rebind the block's identity
            // to the new logical allocation (previously the stale id was kept).
            block.last_accessed = Instant::now();
            block.ref_count = 1;
            block.id = block_id;
            // FIX: reused blocks were not counted as active, so
            // `active_allocations` drifted low after any pool reuse.
            self.active_allocations += 1;
            return Some(block);
        }
        // Fresh path: size the buffer at the bucket's maximum so it can later
        // satisfy any request this bucket handles.
        let actual_size = self.size_range.1;
        let memory = vec![0u8; actual_size];
        let now = Instant::now();
        self.total_allocated += actual_size;
        self.active_allocations += 1;
        Some(MemoryBlock {
            memory,
            allocated_at: now,
            last_accessed: now,
            ref_count: 1,
            id: block_id,
        })
    }

    /// Returns `block` to the bucket's free list for future reuse.
    fn deallocate(&mut self, block: MemoryBlock) {
        self.active_allocations = self.active_allocations.saturating_sub(1);
        self.available_blocks.push_back(block);
    }

    /// Drops pooled blocks idle for at least `timeout`.
    /// Returns the number of blocks freed.
    fn cleanup_expired(&mut self, timeout: Duration) -> usize {
        let now = Instant::now();
        let initial_count = self.available_blocks.len();
        // Accumulate freed bytes in a local instead of mutating
        // `self.total_allocated` from inside the `retain` closure; the old
        // form compiled only thanks to Rust 2021 disjoint closure captures.
        let mut freed_bytes = 0usize;
        self.available_blocks.retain(|block| {
            if now.duration_since(block.last_accessed) < timeout {
                true
            } else {
                freed_bytes += block.memory.len();
                false
            }
        });
        self.total_allocated = self.total_allocated.saturating_sub(freed_bytes);
        initial_count - self.available_blocks.len()
    }
}
impl Default for SparseMemoryPool {
fn default() -> Self {
Self::new()
}
}
impl SparseMemoryPool {
    /// Creates a pool with `MemoryPoolConfig::default()`.
    pub fn new() -> Self {
        Self::with_config(MemoryPoolConfig::default())
    }

    /// Creates a pool from `config`, pre-building its size buckets.
    pub fn with_config(config: MemoryPoolConfig) -> Self {
        let buckets = Self::create_buckets(&config);
        Self {
            config,
            buckets,
            total_allocated: 0,
            stats: MemoryStatistics::default(),
            last_gc: Instant::now(),
            next_block_id: 1,
            allocation_cache: HashMap::new(),
        }
    }

    /// Builds doubling size buckets starting at 1 KiB, capped at
    /// `max_allocation_size`; stops early once the cap is reached.
    fn create_buckets(config: &MemoryPoolConfig) -> Vec<MemoryBucket> {
        let mut buckets = Vec::new();
        let max_size = config.max_allocation_size;
        let mut size = 1024; // lower bound of the smallest bucket
        for _ in 0..config.num_size_buckets {
            let next_size = (size * 2).min(max_size);
            buckets.push(MemoryBucket::new((size, next_size)));
            size = next_size;
            if size >= max_size {
                break;
            }
        }
        buckets
    }

    /// Allocates `size` bytes tagged with `allocation_type`.
    ///
    /// Requests that no bucket covers (e.g. below 1 KiB) are served by a
    /// direct, exactly-sized allocation that is not pooled on release.
    ///
    /// # Errors
    /// Fails when `size` exceeds `max_allocation_size`, or when the pool is
    /// still over `max_total_memory` after a garbage-collection attempt.
    pub fn allocate(
        &mut self,
        size: usize,
        allocation_type: &str,
    ) -> TorshResult<SparseMemoryHandle> {
        if size > self.config.max_allocation_size {
            return Err(torsh_core::TorshError::Other(format!(
                "Allocation size {} exceeds maximum {}",
                size, self.config.max_allocation_size
            )));
        }
        if self.total_allocated + size > self.config.max_total_memory {
            // Try to reclaim idle pooled blocks before giving up. Note this
            // respects `gc_interval`, so a collection may not actually run.
            self.garbage_collect();
            if self.total_allocated + size > self.config.max_total_memory {
                return Err(torsh_core::TorshError::Other(
                    "Memory pool capacity exceeded".to_string(),
                ));
            }
        }
        let bucket_idx = self.find_bucket_for_size(size);
        let block_id = self.next_block_id;
        self.next_block_id += 1;
        let block = if let Some(idx) = bucket_idx {
            self.buckets[idx].allocate(size, block_id)
        } else {
            // Direct allocation outside any bucket: exactly `size` bytes.
            let memory = vec![0u8; size];
            let now = Instant::now();
            Some(MemoryBlock {
                memory,
                allocated_at: now,
                last_accessed: now,
                ref_count: 1,
                id: block_id,
            })
        };
        match block {
            Some(block) => {
                // Account using the actual backing size — bucket allocations
                // may be rounded up past the requested `size`.
                self.record_allocation(block.memory.len(), allocation_type, block_id);
                Ok(SparseMemoryHandle::new(block, bucket_idx))
            }
            None => Err(torsh_core::TorshError::Other(
                "Failed to allocate memory".to_string(),
            )),
        }
    }

    /// Updates counters and the per-type id cache after a successful allocation.
    fn record_allocation(&mut self, bytes: usize, allocation_type: &str, block_id: u64) {
        self.total_allocated += bytes;
        self.stats.total_allocated_bytes += bytes;
        self.stats.allocation_requests += 1;
        self.stats.active_allocations += 1;
        self.stats.current_memory_usage = self.total_allocated;
        self.stats.peak_memory_usage = self.stats.peak_memory_usage.max(self.total_allocated);
        self.stats.average_allocation_size =
            self.stats.total_allocated_bytes as f64 / self.stats.allocation_requests as f64;
        self.allocation_cache
            .entry(allocation_type.to_string())
            .or_default()
            .push(block_id);
    }

    /// Returns the index of the first bucket whose range covers `size`.
    fn find_bucket_for_size(&self, size: usize) -> Option<usize> {
        self.buckets
            .iter()
            .position(|bucket| bucket.can_handle(size))
    }

    /// Returns a handle's block to the pool. Bucketed blocks are pooled for
    /// reuse; direct (unbucketed) allocations are simply dropped.
    pub fn deallocate(&mut self, handle: SparseMemoryHandle) {
        let (block, bucket_idx) = handle.into_parts();
        self.total_allocated = self.total_allocated.saturating_sub(block.memory.len());
        self.stats.total_deallocated_bytes += block.memory.len();
        self.stats.active_allocations = self.stats.active_allocations.saturating_sub(1);
        self.stats.current_memory_usage = self.total_allocated;
        if let Some(idx) = bucket_idx {
            if idx < self.buckets.len() {
                self.buckets[idx].deallocate(block);
                // NOTE: counts blocks made *available* for reuse, not blocks
                // actually reused by a later allocation.
                self.stats.pool_reuses += 1;
            }
        }
    }

    /// Runs garbage collection if `gc_interval` has elapsed since the last run.
    pub fn garbage_collect(&mut self) {
        let now = Instant::now();
        if now.duration_since(self.last_gc) < self.config.gc_interval {
            return; // too soon since the last collection
        }
        self.run_gc(now);
    }

    /// Runs garbage collection immediately, ignoring `gc_interval`.
    pub fn force_garbage_collect(&mut self) {
        // FIX: calling the worker directly replaces the previous
        // `Instant::now() - gc_interval` back-dating trick, which can panic
        // if the result would precede the platform's time origin.
        self.run_gc(Instant::now());
    }

    /// GC worker: drops expired pooled blocks and prunes bookkeeping caches.
    fn run_gc(&mut self, now: Instant) {
        let mut _blocks_freed = 0;
        for bucket in &mut self.buckets {
            _blocks_freed += bucket.cleanup_expired(self.config.memory_timeout);
        }
        // FIX: the cache was previously "pruned" with a no-op
        // `ids.retain(|_| true)`, so it grew without bound. The ids are never
        // read after allocation in this file, so dropping them wholesale
        // during GC is safe and stops the leak.
        self.allocation_cache.clear();
        self.stats.garbage_collections += 1;
        self.last_gc = now;
    }

    /// Snapshot of the current statistics.
    pub fn statistics(&self) -> MemoryStatistics {
        self.stats.clone()
    }

    /// Coarse usage breakdown. Currently reports all live bytes under a
    /// single "sparse_matrices" key; per-type attribution is not tracked.
    pub fn usage_by_type(&self) -> HashMap<String, usize> {
        let mut usage = HashMap::new();
        usage.insert("sparse_matrices".to_string(), self.total_allocated);
        usage
    }

    /// Healthy while live usage is below 80% of `max_total_memory`.
    pub fn is_healthy(&self) -> bool {
        let usage_ratio = self.total_allocated as f64 / self.config.max_total_memory as f64;
        usage_ratio < 0.8
    }

    /// Score in [0, 1]: mean of the pool-reuse ratio and current/peak
    /// utilization. Returns 1.0 before any allocation has been requested.
    pub fn efficiency_score(&self) -> f64 {
        if self.stats.allocation_requests == 0 {
            return 1.0;
        }
        let reuse_ratio = self.stats.pool_reuses as f64 / self.stats.allocation_requests as f64;
        // `1 - fragmentation` simplifies to current/peak utilization; guard
        // peak == 0 (possible only with zero-byte allocations) to avoid NaN.
        let utilization = if self.stats.peak_memory_usage > 0 {
            self.stats.current_memory_usage as f64 / self.stats.peak_memory_usage as f64
        } else {
            1.0
        };
        (reuse_ratio + utilization) / 2.0
    }
}
/// Owning handle to a pool block; returns the block to the *global* pool
/// when dropped.
pub struct SparseMemoryHandle {
    /// The block, or `None` once consumed by `into_parts` / `drop`.
    block: Option<MemoryBlock>,
    /// Index of the bucket the block came from (`None` for direct allocations).
    bucket_idx: Option<usize>,
}
impl SparseMemoryHandle {
    /// Wraps a freshly allocated block together with its source bucket index.
    fn new(block: MemoryBlock, bucket_idx: Option<usize>) -> Self {
        Self {
            block: Some(block),
            bucket_idx,
        }
    }

    /// Mutable view of the backing bytes; refreshes the access timestamp.
    /// Yields an empty slice once the handle has been invalidated.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        match self.block {
            Some(ref mut block) => {
                block.last_accessed = Instant::now();
                block.memory.as_mut_slice()
            }
            None => &mut [],
        }
    }

    /// Read-only view of the backing bytes (empty once invalidated).
    pub fn as_slice(&self) -> &[u8] {
        self.block.as_ref().map_or(&[], |b| b.memory.as_slice())
    }

    /// Length of the backing buffer in bytes (0 once invalidated).
    pub fn size(&self) -> usize {
        self.block.as_ref().map(|b| b.memory.len()).unwrap_or(0)
    }

    /// Time elapsed since the block was first allocated.
    pub fn age(&self) -> Duration {
        match &self.block {
            Some(b) => b.allocated_at.elapsed(),
            None => Duration::ZERO,
        }
    }

    /// True while the handle still owns its block.
    pub fn is_valid(&self) -> bool {
        self.block.is_some()
    }

    /// Consumes the handle, yielding the block and its bucket index.
    /// Afterwards this handle's `Drop` is a no-op.
    fn into_parts(mut self) -> (MemoryBlock, Option<usize>) {
        let bucket_idx = self.bucket_idx;
        let block = self.block.take().expect("Handle should have a block");
        (block, bucket_idx)
    }
}
impl Drop for SparseMemoryHandle {
    fn drop(&mut self) {
        // Take the block so the temporary handle below becomes the sole owner
        // and this destructor cannot run the return-to-pool path twice.
        if let Some(block) = self.block.take() {
            // NOTE(review): memory is always returned to the process-global
            // pool, even for handles allocated from a locally constructed
            // `SparseMemoryPool` (as in the unit tests) — global statistics
            // can drift as a result; confirm whether that is intended.
            // A poisoned lock is silently ignored: leaking the block is
            // preferred over panicking inside `drop`.
            if let Ok(mut pool) = get_memory_pool().write() {
                // The inner handle's own Drop is a no-op afterwards, because
                // `deallocate` consumes the block via `into_parts`.
                pool.deallocate(SparseMemoryHandle {
                    block: Some(block),
                    bucket_idx: self.bucket_idx,
                });
            }
        }
    }
}
/// Stateless facade over the process-global sparse memory pool.
pub struct SparseMemoryManager;

impl SparseMemoryManager {
    /// Allocates `size` bytes from the global pool under `allocation_type`.
    pub fn allocate(size: usize, allocation_type: &str) -> TorshResult<SparseMemoryHandle> {
        let mut pool = get_memory_pool().write().expect("lock should not be poisoned");
        pool.allocate(size, allocation_type)
    }

    /// Snapshot of the global pool's statistics.
    pub fn global_statistics() -> MemoryStatistics {
        let pool = get_memory_pool().read().expect("lock should not be poisoned");
        pool.statistics()
    }

    /// Runs garbage collection on the global pool immediately.
    pub fn force_garbage_collect() {
        let mut pool = get_memory_pool().write().expect("lock should not be poisoned");
        pool.force_garbage_collect();
    }

    /// Whether the global pool's usage is below its health threshold.
    pub fn is_healthy() -> bool {
        let pool = get_memory_pool().read().expect("lock should not be poisoned");
        pool.is_healthy()
    }

    /// Efficiency score of the global pool in `[0, 1]`.
    pub fn efficiency_score() -> f64 {
        let pool = get_memory_pool().read().expect("lock should not be poisoned");
        pool.efficiency_score()
    }

    /// Replaces the global pool with a fresh one built from `config`.
    /// Existing statistics and pooled blocks are discarded.
    pub fn configure(config: MemoryPoolConfig) {
        let mut pool = get_memory_pool().write().expect("lock should not be poisoned");
        *pool = SparseMemoryPool::with_config(config);
    }

    /// Builds a full report (statistics, usage, health, advice) while holding
    /// the read lock once.
    pub fn generate_report() -> MemoryReport {
        let pool = get_memory_pool().read().expect("lock should not be poisoned");
        MemoryReport {
            statistics: pool.statistics(),
            usage_by_type: pool.usage_by_type(),
            is_healthy: pool.is_healthy(),
            efficiency_score: pool.efficiency_score(),
            recommendations: Self::generate_recommendations(&pool),
        }
    }

    /// Derives human-readable tuning advice from the pool's current state.
    fn generate_recommendations(pool: &SparseMemoryPool) -> Vec<String> {
        let stats = &pool.stats;
        let mut advice = Vec::new();

        // Capacity pressure.
        let usage_ratio = pool.total_allocated as f64 / pool.config.max_total_memory as f64;
        if usage_ratio > 0.9 {
            advice.push("Memory usage is very high. Consider increasing pool size or optimizing allocations.".to_string());
        } else if usage_ratio > 0.8 {
            advice.push(
                "Memory usage is high. Monitor closely and consider optimization.".to_string(),
            );
        }

        // Pool-reuse effectiveness.
        let reuse_ratio = match stats.allocation_requests {
            0 => 0.0,
            n => stats.pool_reuses as f64 / n as f64,
        };
        if reuse_ratio < 0.3 {
            advice.push("Low memory reuse detected. Consider adjusting bucket sizes or allocation patterns.".to_string());
        }

        // Garbage-collection activity.
        if stats.garbage_collections == 0 && stats.allocation_requests > 100 {
            advice.push(
                "No garbage collections performed. Consider enabling automatic GC.".to_string(),
            );
        }

        // Fragmentation: how far current usage sits below the peak.
        let fragmentation = match stats.peak_memory_usage {
            0 => 0.0,
            peak => 1.0 - (stats.current_memory_usage as f64 / peak as f64),
        };
        if fragmentation > 0.5 {
            advice.push(
                "High memory fragmentation detected. Consider more frequent garbage collection."
                    .to_string(),
            );
        }

        if advice.is_empty() {
            advice.push("Memory management appears optimal.".to_string());
        }
        advice
    }
}
/// Aggregated view of pool state produced by `SparseMemoryManager::generate_report`.
#[derive(Debug, Clone)]
pub struct MemoryReport {
    /// Counter snapshot at report time.
    pub statistics: MemoryStatistics,
    /// Bytes in use keyed by allocation-type tag.
    pub usage_by_type: HashMap<String, usize>,
    /// Result of the pool's health check.
    pub is_healthy: bool,
    /// Efficiency score in `[0, 1]`.
    pub efficiency_score: f64,
    /// Human-readable tuning advice; never empty.
    pub recommendations: Vec<String>,
}
/// Builder that pre-reserves pool memory before constructing a sparse tensor.
pub struct MemoryAwareSparseBuilder {
    /// Output format requested by the caller.
    format: SparseFormat,
    /// Expected number of non-zero entries, used for sizing the reservation.
    estimated_nnz: usize,
    /// Handles kept alive so the reservation persists until build time.
    memory_handles: Vec<SparseMemoryHandle>,
    /// Collected hints (never populated anywhere in this file).
    optimization_hints: Vec<String>,
}
impl MemoryAwareSparseBuilder {
pub fn new(format: SparseFormat, estimated_nnz: usize) -> Self {
Self {
format,
estimated_nnz,
memory_handles: Vec::new(),
optimization_hints: Vec::new(),
}
}
pub fn pre_allocate(&mut self) -> TorshResult<()> {
let memory_needed = self.estimate_memory_requirements();
let chunk_size = 1024 * 1024; let num_chunks = memory_needed.div_ceil(chunk_size);
for i in 0..num_chunks {
let size = if i == num_chunks - 1 {
memory_needed - i * chunk_size
} else {
chunk_size
};
let handle = SparseMemoryManager::allocate(size, "sparse_builder")?;
self.memory_handles.push(handle);
}
Ok(())
}
fn estimate_memory_requirements(&self) -> usize {
match self.format {
SparseFormat::Coo => {
self.estimated_nnz * (2 * std::mem::size_of::<usize>() + std::mem::size_of::<f32>())
}
SparseFormat::Csr | SparseFormat::Csc => {
self.estimated_nnz * (std::mem::size_of::<usize>() + std::mem::size_of::<f32>())
+ 1000 * std::mem::size_of::<usize>() }
_ => self.estimated_nnz * 3 * std::mem::size_of::<f32>(), }
}
pub fn build(
self,
data: &[(usize, usize, f32)],
shape: Shape,
) -> TorshResult<Box<dyn SparseTensor>> {
match self.format {
SparseFormat::Coo => {
let mut coo = CooTensor::empty(shape, DType::F32)?;
for &(row, col, val) in data {
coo.insert(row, col, val)?;
}
Ok(Box::new(coo))
}
SparseFormat::Csr => {
let coo = {
let mut coo = CooTensor::empty(shape, DType::F32)?;
for &(row, col, val) in data {
coo.insert(row, col, val)?;
}
coo
};
Ok(Box::new(coo.to_csr()?))
}
_ => {
let mut coo = CooTensor::empty(shape, DType::F32)?;
for &(row, col, val) in data {
coo.insert(row, col, val)?;
}
Ok(Box::new(coo))
}
}
}
pub fn optimization_hints(&self) -> &[String] {
&self.optimization_hints
}
}
/// Convenience wrapper: pre-reserves pool memory sized for `data`, then
/// builds a sparse tensor in `format` from the given triplets.
pub fn create_sparse_with_memory_management(
    data: &[(usize, usize, f32)],
    shape: Shape,
    format: SparseFormat,
) -> TorshResult<Box<dyn SparseTensor>> {
    let mut builder = MemoryAwareSparseBuilder::new(format, data.len());
    // Reserve first so allocation pressure surfaces before construction work.
    builder.pre_allocate()?;
    builder.build(data, shape)
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): handles deallocate into the *global* pool on drop (see
    // `impl Drop for SparseMemoryHandle`), even when allocated from a local
    // pool — several tests below implicitly touch global state.

    #[test]
    fn test_memory_pool_creation() {
        // A fresh local pool starts empty and healthy.
        let pool = SparseMemoryPool::new();
        assert!(pool.is_healthy());
        assert_eq!(pool.total_allocated, 0);
    }

    #[test]
    fn test_memory_allocation() {
        let mut pool = SparseMemoryPool::new();
        let handle = pool.allocate(1024, "test").unwrap();
        assert!(handle.is_valid());
        // Bucketed allocations may be rounded up to the bucket's max size.
        assert!(handle.size() >= 1024);
        assert!(pool.total_allocated > 0);
    }

    #[test]
    fn test_memory_statistics() {
        let mut pool = SparseMemoryPool::new();
        let _handle1 = pool.allocate(1024, "test").unwrap();
        let _handle2 = pool.allocate(2048, "test").unwrap();
        let stats = pool.statistics();
        assert_eq!(stats.allocation_requests, 2);
        assert_eq!(stats.active_allocations, 2);
        assert!(stats.current_memory_usage > 0);
    }

    #[test]
    fn test_garbage_collection() {
        // Short timeout so any pooled block expires almost immediately.
        let mut pool = SparseMemoryPool::with_config(MemoryPoolConfig {
            memory_timeout: Duration::from_millis(1),
            ..Default::default()
        });
        {
            // Dropped at end of scope (returned via the global pool's Drop path).
            let _handle = pool.allocate(1024, "test").unwrap();
        }
        std::thread::sleep(Duration::from_millis(10));
        pool.force_garbage_collect();
        let stats = pool.statistics();
        assert!(stats.garbage_collections > 0);
    }

    #[test]
    fn test_global_memory_manager() {
        let handle = SparseMemoryManager::allocate(1024, "test").unwrap();
        assert!(handle.is_valid());
        let stats = SparseMemoryManager::global_statistics();
        assert!(stats.allocation_requests > 0);
        assert!(SparseMemoryManager::is_healthy());
    }

    #[test]
    fn test_memory_report() {
        let _handle = SparseMemoryManager::allocate(1024, "test").unwrap();
        let report = SparseMemoryManager::generate_report();
        assert!(report.statistics.allocation_requests > 0);
        // `generate_recommendations` always pushes at least one entry.
        assert!(!report.recommendations.is_empty());
    }

    #[test]
    fn test_memory_aware_builder() {
        let mut builder = MemoryAwareSparseBuilder::new(SparseFormat::Coo, 10);
        builder.pre_allocate().unwrap();
        let data = vec![(0, 0, 1.0), (1, 1, 2.0), (2, 2, 3.0)];
        let shape = Shape::new(vec![3, 3]);
        let sparse = builder.build(&data, shape).unwrap();
        assert_eq!(sparse.nnz(), 3);
    }

    #[test]
    fn test_memory_handle_operations() {
        let mut handle = SparseMemoryManager::allocate(1024, "test").unwrap();
        {
            // Writes through the mutable view must persist.
            let slice = handle.as_mut_slice();
            slice[0] = 42;
            slice[1] = 24;
        }
        let slice = handle.as_slice();
        assert_eq!(slice[0], 42);
        assert_eq!(slice[1], 24);
        assert!(handle.age() >= Duration::ZERO);
        assert!(handle.size() >= 1024);
    }
}