use crate::error::{LinalgError, LinalgResult};
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::fmt::Debug;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
/// Strategy used to choose a free block when servicing an allocation request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllocationStrategy {
    /// Smallest free block that fits (minimizes leftover waste per block).
    BestFit,
    /// Lowest-offset free block that fits.
    FirstFit,
    /// First fitting block at or after the most recently used offset, wrapping around.
    NextFit,
    /// Round the request up to the next power of two, then best-fit on that size.
    Buddy,
}
/// Configuration for a [`GpuMemoryPool`].
#[derive(Debug, Clone)]
pub struct MemoryPoolConfig {
    /// Total pool capacity in bytes.
    pub pool_size: usize,
    /// Minimum block size in bytes.
    /// NOTE(review): not consulted anywhere in this file — confirm it is used elsewhere.
    pub min_block_size: usize,
    /// Maximum size of a single allocation; larger requests are rejected.
    pub max_block_size: usize,
    /// Allocation sizes are rounded up to a multiple of this value.
    pub alignment: usize,
    /// Placement strategy used when searching the free list.
    pub strategy: AllocationStrategy,
    /// Whether `maybe_defragment` is allowed to run.
    pub enable_defrag: bool,
    /// Fragmentation ratio (0..1) above which defragmentation is triggered.
    pub defrag_threshold: f64,
    /// Memory-pressure ratio (0..1).
    /// NOTE(review): not consulted anywhere in this file — confirm intended use.
    pub pressure_threshold: f64,
    /// Maximum age for cached free blocks.
    /// NOTE(review): cached entries carry no timestamp, so this cannot currently
    /// be enforced — see `evict_old_caches`.
    pub max_cache_age: Duration,
}
impl Default for MemoryPoolConfig {
    /// Conservative defaults: a 1 GiB pool on 64-bit targets, 256 MiB on
    /// 32-bit targets, 256-byte alignment, and best-fit placement.
    fn default() -> Self {
        #[cfg(target_pointer_width = "64")]
        const DEFAULT_POOL_SIZE: usize = 1024 * 1024 * 1024;
        #[cfg(target_pointer_width = "32")]
        const DEFAULT_POOL_SIZE: usize = 256 * 1024 * 1024;
        Self {
            pool_size: DEFAULT_POOL_SIZE,
            min_block_size: 256,
            max_block_size: 64 * 1024 * 1024,
            alignment: 256,
            strategy: AllocationStrategy::BestFit,
            enable_defrag: true,
            defrag_threshold: 0.3,
            pressure_threshold: 0.9,
            max_cache_age: Duration::from_secs(60),
        }
    }
}
/// Internal bookkeeping record for one contiguous region of the pool.
/// Keyed in the `blocks` map by its `offset`; the `offset` field mirrors
/// that key.
#[derive(Debug, Clone)]
struct MemoryBlock {
    // Byte offset of this block from the start of the pool.
    offset: usize,
    // Current size of the block in bytes (changes on split/coalesce).
    size: usize,
    // True while the block backs a live allocation.
    in_use: bool,
    // Id of the owning allocation, if any.
    allocation_id: Option<usize>,
    // Last time the block was allocated, freed, or touched by bookkeeping.
    last_access: Instant,
}
/// Handle returned by [`GpuMemoryPool::allocate`]; pass it back to
/// `deallocate` to release the region.
#[derive(Debug, Clone)]
pub struct AllocationHandle {
    /// Unique, monotonically increasing allocation id.
    pub id: usize,
    /// Byte offset of the allocation within the pool.
    pub offset: usize,
    /// Aligned size of the allocation in bytes.
    pub size: usize,
    // Creation timestamp.
    // NOTE(review): never read within this file — presumably for age-based
    // diagnostics; confirm before removing.
    created_at: Instant,
}
/// Aggregate pool statistics, readable via [`GpuMemoryPool::stats`].
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    /// Total pool capacity in bytes.
    pub total_size: usize,
    /// Bytes currently handed out to live allocations.
    pub allocated_bytes: usize,
    /// Bytes currently free.
    pub free_bytes: usize,
    /// Number of live allocations.
    pub active_allocations: usize,
    /// Cumulative count of allocations over the pool's lifetime.
    pub total_allocations: usize,
    /// Cumulative count of deallocations over the pool's lifetime.
    pub total_deallocations: usize,
    /// High-water mark of `allocated_bytes`.
    pub peak_usage: usize,
    /// Number of free blocks at the last defragment pass.
    pub fragmented_blocks: usize,
    /// Fragmentation ratio recorded at the last defragment pass.
    pub fragmentation_ratio: f64,
    /// cache_hits / (cache_hits + cache_misses), updated on each lookup.
    pub cache_hit_rate: f64,
    /// Allocations served from the freed-block cache.
    pub cache_hits: usize,
    /// Allocations that fell through to a free-list search.
    pub cache_misses: usize,
}
/// Offset-based memory pool with a size-indexed free list, a same-size reuse
/// cache, neighbor coalescing, and running statistics. All interior state is
/// behind locks so methods take `&self`.
pub struct GpuMemoryPool {
    // Immutable configuration fixed at construction.
    config: MemoryPoolConfig,
    // Free-list index: block size -> offsets of free blocks of that size.
    free_blocks: RwLock<BTreeMap<usize, Vec<usize>>>,
    // All blocks (free and in-use), keyed by their offset.
    blocks: RwLock<HashMap<usize, MemoryBlock>>,
    // Live allocations keyed by allocation id.
    allocations: RwLock<HashMap<usize, AllocationHandle>>,
    // Recently freed offsets per size class, for fast same-size reuse.
    block_cache: Mutex<HashMap<usize, VecDeque<usize>>>,
    // Next allocation id to hand out.
    next_id: Mutex<usize>,
    // Aggregate statistics snapshot.
    stats: RwLock<MemoryStats>,
    // Cursor for the next-fit strategy.
    last_offset: Mutex<usize>,
}
impl GpuMemoryPool {
pub fn new() -> Self {
Self::with_config(MemoryPoolConfig::default())
}
pub fn with_config(config: MemoryPoolConfig) -> Self {
let pool_size = config.pool_size;
let mut blocks = HashMap::new();
blocks.insert(0, MemoryBlock {
offset: 0,
size: pool_size,
in_use: false,
allocation_id: None,
last_access: Instant::now(),
});
let mut free_blocks = BTreeMap::new();
free_blocks.insert(pool_size, vec![0]);
let stats = MemoryStats {
total_size: pool_size,
free_bytes: pool_size,
..Default::default()
};
Self {
config,
free_blocks: RwLock::new(free_blocks),
blocks: RwLock::new(blocks),
allocations: RwLock::new(HashMap::new()),
block_cache: Mutex::new(HashMap::new()),
next_id: Mutex::new(1),
stats: RwLock::new(stats),
last_offset: Mutex::new(0),
}
}
    /// Allocates `size` bytes (rounded up to the configured alignment) and
    /// returns a handle describing the reserved region.
    ///
    /// The recently-freed block cache is consulted first; on a miss, the
    /// configured placement strategy searches the free list.
    ///
    /// # Errors
    /// Fails if the aligned size exceeds `max_block_size`, a lock is
    /// poisoned, or no free block can satisfy the request.
    ///
    /// NOTE(review): the block search and `complete_allocation` run under
    /// separate lock acquisitions, so two concurrent callers can be handed
    /// the same free block before either marks it in use — confirm external
    /// serialization or restructure to hold the locks across both steps.
    pub fn allocate(&self, size: usize) -> LinalgResult<AllocationHandle> {
        let aligned_size = self.align_size(size);
        if aligned_size > self.config.max_block_size {
            return Err(LinalgError::ComputationError(format!(
                "Allocation size {} exceeds maximum block size {}",
                aligned_size, self.config.max_block_size
            )));
        }
        // Fast path: reuse an identically-sized block freed earlier.
        if let Some(offset) = self.try_cache(aligned_size) {
            return self.complete_allocation(offset, aligned_size);
        }
        let block_offset = match self.config.strategy {
            AllocationStrategy::BestFit => self.find_best_fit(aligned_size)?,
            AllocationStrategy::FirstFit => self.find_first_fit(aligned_size)?,
            AllocationStrategy::NextFit => self.find_next_fit(aligned_size)?,
            AllocationStrategy::Buddy => self.find_buddy_block(aligned_size)?,
        };
        // Cache-miss bookkeeping (the hit case is recorded inside try_cache).
        if let Ok(mut stats) = self.stats.write() {
            stats.cache_misses += 1;
            self.update_cache_hit_rate(&mut stats);
        }
        self.complete_allocation(block_offset, aligned_size)
    }
fn try_cache(&self, size: usize) -> Option<usize> {
if let Ok(mut cache) = self.block_cache.lock() {
if let Some(offsets) = cache.get_mut(&size) {
if let Some(offset) = offsets.pop_front() {
if let Ok(mut stats) = self.stats.write() {
stats.cache_hits += 1;
self.update_cache_hit_rate(&mut stats);
}
return Some(offset);
}
}
}
None
}
    /// Finalizes an allocation at `offset`: assigns a fresh id, splits off
    /// any unused tail of the chosen block, flips the block to in-use, and
    /// updates the free list, the handle table, and the statistics.
    ///
    /// NOTE(review): if `offset` is absent from the block map (possible when
    /// a cached offset was invalidated by coalescing), this still fabricates
    /// and returns a handle without reserving anything — consider returning
    /// an error in that case.
    fn complete_allocation(&self, offset: usize, size: usize) -> LinalgResult<AllocationHandle> {
        // Hand out monotonically increasing allocation ids.
        let id = {
            let mut id_guard = self.next_id.lock()
                .map_err(|_| LinalgError::ComputationError("Lock poisoned".to_string()))?;
            let id = *id_guard;
            *id_guard += 1;
            id
        };
        {
            let mut blocks = self.blocks.write()
                .map_err(|_| LinalgError::ComputationError("Lock poisoned".to_string()))?;
            let mut free_blocks = self.free_blocks.write()
                .map_err(|_| LinalgError::ComputationError("Lock poisoned".to_string()))?;
            if let Some(block) = blocks.get_mut(&offset) {
                let original_size = block.size;
                // Split: keep `size` bytes here and re-list the remainder
                // as a new free block.
                if original_size > size {
                    let remaining_offset = offset + size;
                    let remaining_size = original_size - size;
                    blocks.insert(remaining_offset, MemoryBlock {
                        offset: remaining_offset,
                        size: remaining_size,
                        in_use: false,
                        allocation_id: None,
                        last_access: Instant::now(),
                    });
                    free_blocks.entry(remaining_size)
                        .or_default()
                        .push(remaining_offset);
                }
                block.size = size;
                block.in_use = true;
                block.allocation_id = Some(id);
                block.last_access = Instant::now();
                // The block was indexed under its pre-split size; unlink it.
                if let Some(offsets) = free_blocks.get_mut(&original_size) {
                    offsets.retain(|&o| o != offset);
                    if offsets.is_empty() {
                        free_blocks.remove(&original_size);
                    }
                }
            }
        }
        let handle = AllocationHandle {
            id,
            offset,
            size,
            created_at: Instant::now(),
        };
        if let Ok(mut allocs) = self.allocations.write() {
            allocs.insert(id, handle.clone());
        }
        if let Ok(mut stats) = self.stats.write() {
            stats.allocated_bytes += size;
            stats.free_bytes = stats.free_bytes.saturating_sub(size);
            stats.active_allocations += 1;
            stats.total_allocations += 1;
            stats.peak_usage = stats.peak_usage.max(stats.allocated_bytes);
        }
        Ok(handle)
    }
fn find_best_fit(&self, size: usize) -> LinalgResult<usize> {
let free_blocks = self.free_blocks.read()
.map_err(|_| LinalgError::ComputationError("Lock poisoned".to_string()))?;
for (&block_size, offsets) in free_blocks.range(size..) {
if let Some(&offset) = offsets.first() {
return Ok(offset);
}
}
Err(LinalgError::ComputationError(
"No suitable free block found".to_string()
))
}
fn find_first_fit(&self, size: usize) -> LinalgResult<usize> {
let blocks = self.blocks.read()
.map_err(|_| LinalgError::ComputationError("Lock poisoned".to_string()))?;
let mut offsets: Vec<_> = blocks.iter()
.filter(|(_, b)| !b.in_use && b.size >= size)
.map(|(&o, _)| o)
.collect();
offsets.sort();
offsets.into_iter().next()
.ok_or_else(|| LinalgError::ComputationError(
"No suitable free block found".to_string()
))
}
fn find_next_fit(&self, size: usize) -> LinalgResult<usize> {
let last_offset = *self.last_offset.lock()
.map_err(|_| LinalgError::ComputationError("Lock poisoned".to_string()))?;
let blocks = self.blocks.read()
.map_err(|_| LinalgError::ComputationError("Lock poisoned".to_string()))?;
let mut offsets: Vec<_> = blocks.iter()
.filter(|(_, b)| !b.in_use && b.size >= size)
.map(|(&o, _)| o)
.collect();
offsets.sort();
for &offset in &offsets {
if offset >= last_offset {
if let Ok(mut last) = self.last_offset.lock() {
*last = offset;
}
return Ok(offset);
}
}
for &offset in &offsets {
if let Ok(mut last) = self.last_offset.lock() {
*last = offset;
}
return Ok(offset);
}
Err(LinalgError::ComputationError(
"No suitable free block found".to_string()
))
}
    /// Buddy-style search: rounds the request up to the next power of two and
    /// delegates to best-fit on the rounded size.
    ///
    /// NOTE(review): only the search size is rounded; the actual reservation
    /// in `complete_allocation` still uses the caller's aligned size, so this
    /// is buddy-flavored placement rather than a full buddy allocator.
    fn find_buddy_block(&self, size: usize) -> LinalgResult<usize> {
        let buddy_size = size.next_power_of_two();
        self.find_best_fit(buddy_size)
    }
    /// Returns an allocation to the pool: marks the block free, re-lists it,
    /// coalesces it with free neighbors, caches the offset for same-size
    /// reuse, and updates the statistics.
    ///
    /// # Errors
    /// Fails if a lock is poisoned.
    ///
    /// NOTE(review): there is no double-free guard — deallocating the same
    /// handle twice corrupts the statistics and the free list.
    ///
    /// NOTE(review): the offset is placed in BOTH the free list and the
    /// same-size cache, and coalescing may merge the block away; a later
    /// cache hit can then reference a block that no longer exists — verify
    /// this interplay before relying on the cache path.
    pub fn deallocate(&self, handle: &AllocationHandle) -> LinalgResult<()> {
        let offset = handle.offset;
        let size = handle.size;
        if let Ok(mut allocs) = self.allocations.write() {
            allocs.remove(&handle.id);
        }
        {
            let mut blocks = self.blocks.write()
                .map_err(|_| LinalgError::ComputationError("Lock poisoned".to_string()))?;
            let mut free_blocks = self.free_blocks.write()
                .map_err(|_| LinalgError::ComputationError("Lock poisoned".to_string()))?;
            if let Some(block) = blocks.get_mut(&offset) {
                block.in_use = false;
                block.allocation_id = None;
                block.last_access = Instant::now();
                free_blocks.entry(size)
                    .or_default()
                    .push(offset);
            }
        }
        self.try_coalesce(offset)?;
        // Keep at most 16 cached offsets per size class.
        if let Ok(mut cache) = self.block_cache.lock() {
            let offsets = cache.entry(size).or_default();
            if offsets.len() < 16 {
                offsets.push_back(offset);
            }
        }
        if let Ok(mut stats) = self.stats.write() {
            stats.allocated_bytes = stats.allocated_bytes.saturating_sub(size);
            stats.free_bytes += size;
            stats.active_allocations = stats.active_allocations.saturating_sub(1);
            stats.total_deallocations += 1;
        }
        Ok(())
    }
fn try_coalesce(&self, offset: usize) -> LinalgResult<()> {
let mut blocks = self.blocks.write()
.map_err(|_| LinalgError::ComputationError("Lock poisoned".to_string()))?;
let mut free_blocks = self.free_blocks.write()
.map_err(|_| LinalgError::ComputationError("Lock poisoned".to_string()))?;
let (current_size, current_end) = {
if let Some(block) = blocks.get(&offset) {
if block.in_use {
return Ok(()); }
(block.size, offset + block.size)
} else {
return Ok(());
}
};
if let Some(next_block) = blocks.get(¤t_end).cloned() {
if !next_block.in_use {
blocks.remove(¤t_end);
if let Some(offsets) = free_blocks.get_mut(&next_block.size) {
offsets.retain(|&o| o != current_end);
if offsets.is_empty() {
free_blocks.remove(&next_block.size);
}
}
if let Some(block) = blocks.get_mut(&offset) {
if let Some(offsets) = free_blocks.get_mut(&block.size) {
offsets.retain(|&o| o != offset);
if offsets.is_empty() {
free_blocks.remove(&block.size);
}
}
block.size += next_block.size;
free_blocks.entry(block.size)
.or_default()
.push(offset);
}
}
}
let prev_info: Option<(usize, usize)> = blocks.iter()
.filter(|(_, b)| !b.in_use && b.offset + b.size == offset)
.map(|(&o, b)| (o, b.size))
.next();
if let Some((prev_offset, prev_size)) = prev_info {
let current_size_now = blocks.get(&offset).map(|b| b.size).unwrap_or(0);
blocks.remove(&offset);
if let Some(offsets) = free_blocks.get_mut(¤t_size_now) {
offsets.retain(|&o| o != offset);
if offsets.is_empty() {
free_blocks.remove(¤t_size_now);
}
}
if let Some(offsets) = free_blocks.get_mut(&prev_size) {
offsets.retain(|&o| o != prev_offset);
if offsets.is_empty() {
free_blocks.remove(&prev_size);
}
}
if let Some(prev_block) = blocks.get_mut(&prev_offset) {
prev_block.size += current_size_now;
free_blocks.entry(prev_block.size)
.or_default()
.push(prev_offset);
}
}
Ok(())
}
pub fn stats(&self) -> MemoryStats {
self.stats.read()
.map(|s| s.clone())
.unwrap_or_default()
}
pub fn fragmentation_ratio(&self) -> f64 {
if let Ok(blocks) = self.blocks.read() {
let free_blocks: Vec<_> = blocks.values()
.filter(|b| !b.in_use)
.collect();
if free_blocks.is_empty() {
return 0.0;
}
let total_free: usize = free_blocks.iter().map(|b| b.size).sum();
let largest_free = free_blocks.iter().map(|b| b.size).max().unwrap_or(0);
if total_free == 0 {
return 0.0;
}
1.0 - (largest_free as f64 / total_free as f64)
} else {
0.0
}
}
pub fn maybe_defragment(&self) -> LinalgResult<bool> {
let frag_ratio = self.fragmentation_ratio();
if frag_ratio > self.config.defrag_threshold && self.config.enable_defrag {
self.defragment()?;
return Ok(true);
}
Ok(false)
}
    /// Refreshes the fragmentation statistics.
    ///
    /// NOTE(review): currently a bookkeeping stub — no blocks are moved or
    /// repacked; only `fragmentation_ratio` and `fragmented_blocks` in the
    /// stats are updated.
    pub fn defragment(&self) -> LinalgResult<()> {
        if let Ok(mut stats) = self.stats.write() {
            // Reads the blocks lock, not the stats lock, so no deadlock here.
            stats.fragmentation_ratio = self.fragmentation_ratio();
            stats.fragmented_blocks = self.count_fragmented_blocks();
        }
        Ok(())
    }
fn count_fragmented_blocks(&self) -> usize {
self.blocks.read()
.map(|blocks| blocks.values().filter(|b| !b.in_use).count())
.unwrap_or(0)
}
pub fn evict_old_caches(&self) -> LinalgResult<usize> {
let now = Instant::now();
let mut evicted = 0;
if let Ok(mut cache) = self.block_cache.lock() {
for offsets in cache.values_mut() {
let initial_len = offsets.len();
while offsets.len() > 8 {
offsets.pop_front();
evicted += 1;
}
evicted += initial_len.saturating_sub(offsets.len());
}
}
Ok(evicted)
}
pub fn reset(&self) -> LinalgResult<()> {
let pool_size = self.config.pool_size;
if let Ok(mut blocks) = self.blocks.write() {
blocks.clear();
blocks.insert(0, MemoryBlock {
offset: 0,
size: pool_size,
in_use: false,
allocation_id: None,
last_access: Instant::now(),
});
}
if let Ok(mut free_blocks) = self.free_blocks.write() {
free_blocks.clear();
free_blocks.insert(pool_size, vec![0]);
}
if let Ok(mut allocs) = self.allocations.write() {
allocs.clear();
}
if let Ok(mut cache) = self.block_cache.lock() {
cache.clear();
}
if let Ok(mut stats) = self.stats.write() {
*stats = MemoryStats {
total_size: pool_size,
free_bytes: pool_size,
..Default::default()
};
}
Ok(())
}
fn align_size(&self, size: usize) -> usize {
let alignment = self.config.alignment;
((size + alignment - 1) / alignment) * alignment
}
fn update_cache_hit_rate(&self, stats: &mut MemoryStats) {
let total = stats.cache_hits + stats.cache_misses;
if total > 0 {
stats.cache_hit_rate = stats.cache_hits as f64 / total as f64;
}
}
}
impl Default for GpuMemoryPool {
    /// Equivalent to [`GpuMemoryPool::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Cheaply-cloneable wrapper sharing one [`GpuMemoryPool`] across owners
/// via `Arc`.
pub struct SharedMemoryPool {
    // Shared underlying pool; clones bump the refcount only.
    pool: Arc<GpuMemoryPool>,
}
impl SharedMemoryPool {
    /// Creates a shared pool with the default configuration.
    pub fn new() -> Self {
        Self {
            pool: Arc::new(GpuMemoryPool::default()),
        }
    }
    /// Creates a shared pool governed by an explicit configuration.
    pub fn with_config(config: MemoryPoolConfig) -> Self {
        let pool = GpuMemoryPool::with_config(config);
        Self { pool: Arc::new(pool) }
    }
    /// Allocates `size` bytes from the shared pool.
    pub fn allocate(&self, size: usize) -> LinalgResult<AllocationHandle> {
        self.pool.allocate(size)
    }
    /// Returns an allocation to the shared pool.
    pub fn deallocate(&self, handle: &AllocationHandle) -> LinalgResult<()> {
        self.pool.deallocate(handle)
    }
    /// Snapshot of the shared pool's statistics.
    pub fn stats(&self) -> MemoryStats {
        self.pool.stats()
    }
    /// Returns a new handle that shares the same underlying pool.
    pub fn clone_ref(&self) -> Self {
        let pool = Arc::clone(&self.pool);
        Self { pool }
    }
}
impl Default for SharedMemoryPool {
    /// Equivalent to [`SharedMemoryPool::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Clone for SharedMemoryPool {
    /// Clones share the same underlying pool (refcount bump, no deep copy).
    fn clone(&self) -> Self {
        self.clone_ref()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// One allocation is reflected in the handle size and pool stats.
    #[test]
    fn test_basic_allocation() {
        let pool = GpuMemoryPool::new();
        let handle1 = pool.allocate(1024).expect("Allocation failed");
        // 1024 is already a multiple of the default 256-byte alignment.
        assert_eq!(handle1.size, 1024);
        let stats = pool.stats();
        assert_eq!(stats.active_allocations, 1);
        assert!(stats.allocated_bytes >= 1024);
    }
    /// Repeated allocations all succeed and are all tracked.
    #[test]
    fn test_multiple_allocations() {
        let pool = GpuMemoryPool::new();
        let handles: Vec<_> = (0..10)
            .map(|_| pool.allocate(1024).expect("Allocation failed"))
            .collect();
        assert_eq!(handles.len(), 10);
        let stats = pool.stats();
        assert_eq!(stats.active_allocations, 10);
    }
    /// Deallocation drops the active-allocation count back to zero.
    #[test]
    fn test_allocation_and_deallocation() {
        let pool = GpuMemoryPool::new();
        let handle = pool.allocate(1024).expect("Allocation failed");
        assert_eq!(pool.stats().active_allocations, 1);
        pool.deallocate(&handle).expect("Deallocation failed");
        assert_eq!(pool.stats().active_allocations, 0);
    }
    /// A freed block is reused, either by offset or via the size cache.
    #[test]
    fn test_reuse_after_deallocation() {
        let pool = GpuMemoryPool::new();
        let handle1 = pool.allocate(1024).expect("Allocation failed");
        let offset1 = handle1.offset;
        pool.deallocate(&handle1).expect("Deallocation failed");
        let handle2 = pool.allocate(1024).expect("Allocation failed");
        assert!(handle2.offset == offset1 || pool.stats().cache_hits > 0);
    }
    /// Alternating frees produce a fragmentation ratio in [0, 1].
    #[test]
    fn test_fragmentation_calculation() {
        let config = MemoryPoolConfig {
            pool_size: 10240,
            min_block_size: 256,
            ..Default::default()
        };
        let pool = GpuMemoryPool::with_config(config);
        let handles: Vec<_> = (0..5)
            .map(|_| pool.allocate(1024).expect("Allocation failed"))
            .collect();
        // Free every other block to leave holes between live allocations.
        for (i, handle) in handles.iter().enumerate() {
            if i % 2 == 0 {
                pool.deallocate(handle).expect("Deallocation failed");
            }
        }
        let frag_ratio = pool.fragmentation_ratio();
        assert!(frag_ratio >= 0.0 && frag_ratio <= 1.0);
    }
    /// Reset returns the pool to its pristine, fully-free state.
    #[test]
    fn test_reset() {
        let pool = GpuMemoryPool::new();
        for _ in 0..10 {
            let _ = pool.allocate(1024);
        }
        assert!(pool.stats().active_allocations > 0);
        pool.reset().expect("Reset failed");
        let stats = pool.stats();
        assert_eq!(stats.active_allocations, 0);
        assert_eq!(stats.allocated_bytes, 0);
    }
    /// Cloned shared handles observe the same underlying pool state.
    #[test]
    fn test_shared_pool() {
        let pool = SharedMemoryPool::new();
        let pool2 = pool.clone_ref();
        let handle = pool.allocate(1024).expect("Allocation failed");
        assert_eq!(pool2.stats().active_allocations, 1);
        pool2.deallocate(&handle).expect("Deallocation failed");
        assert_eq!(pool.stats().active_allocations, 0);
    }
    /// Requests are rounded up to a multiple of the configured alignment.
    #[test]
    fn test_alignment() {
        let config = MemoryPoolConfig {
            alignment: 256,
            ..Default::default()
        };
        let pool = GpuMemoryPool::with_config(config);
        let handle = pool.allocate(100).expect("Allocation failed");
        assert!(handle.size >= 100);
        assert!(handle.size % 256 == 0 || handle.size == 256);
    }
}