use async_trait::async_trait;
use ferrum_interfaces::memory::{
DefragmentationStats, DeviceMemoryManager, MemoryHandle, MemoryHandleInfo, MemoryInfo,
MemoryPoolConfig as InterfaceMemoryPoolConfig, MemoryPressure, MemoryTransfer, MemoryType,
StreamHandle,
};
use ferrum_types::{Device, Result};
use parking_lot::Mutex;
use std::collections::{HashMap, VecDeque};
use tracing::{debug, warn};
/// Bookkeeping record for a single block tracked by [`MemoryPool`].
#[derive(Clone, Debug)]
struct MemoryBlock {
    /// Handle given to callers for this block.
    handle: MemoryHandle,
    /// Size of the block in bytes, as recorded by the pool after rounding the request up.
    size: usize,
    /// `true` once the block has been deallocated and may be handed out again.
    is_free: bool,
    /// Timestamp recorded by the allocator for this block; surfaced through `handle_info`.
    allocated_at: std::time::Instant,
}
/// Size-bucketed memory pool bound to a single [`Device`].
///
/// As far as this module shows, each block is a bookkeeping record (a handle id plus a byte
/// size); freed blocks are remembered in a per-size free list so they can be reused.
pub struct MemoryPool {
    /// Device this pool was created for.
    device: Device,
    /// Every block the pool currently tracks, whether live or free.
    blocks: Mutex<VecDeque<MemoryBlock>>,
    /// Free list: block size in bytes -> indices into `blocks` of free blocks of that size.
    free_blocks: Mutex<HashMap<usize, VecDeque<usize>>>,
    /// Byte total compared against `config.max_size` whenever a new block is created.
    total_allocated: Mutex<usize>,
    /// Highest value `total_allocated` has reached.
    peak_allocated: Mutex<usize>,
    /// Counter used to mint handle ids for newly created blocks.
    allocation_count: Mutex<u64>,
    /// Pool configuration supplied at construction.
    config: InternalMemoryPoolConfig,
}
/// Tuning parameters for [`MemoryPool`].
///
/// NOTE(review): within this module only `max_size` and `enable_defragmentation` are read by
/// `MemoryPool`; the other fields are carried along but not consulted here.
#[derive(Debug, Clone)]
pub struct InternalMemoryPoolConfig {
    /// Initial pool size in bytes (default: 256 MiB). Not read in this module.
    pub initial_size: usize,
    /// Upper bound in bytes on the running total of newly created blocks
    /// (default: 8 GiB, clamped to `usize::MAX` on narrower targets).
    pub max_size: usize,
    /// Presumably a growth multiplier for the pool (default: 1.5). Not read in this module.
    pub growth_factor: f32,
    /// Whether `MemoryPool::defragment` does any work (default: `true`).
    pub enable_defragmentation: bool,
    /// Presumably the smallest block size worth pooling, in bytes (default: 256).
    /// Not read in this module.
    pub min_pooled_size: usize,
    /// Presumably the largest block size worth pooling, in bytes (default: 128 MiB).
    /// Not read in this module.
    pub max_pooled_size: usize,
    /// Presumably the number of size buckets (default: 64). Not read in this module.
    pub size_buckets: usize,
}

impl Default for InternalMemoryPoolConfig {
    fn default() -> Self {
        // 8 GiB does not fit in a 32-bit `usize`: written as a `usize` product it is a
        // compile-time overflow error on such targets. Compute in u64 and clamp instead.
        let max_size = (8u64 * 1024 * 1024 * 1024).min(usize::MAX as u64) as usize;
        Self {
            initial_size: 256 * 1024 * 1024,
            max_size,
            growth_factor: 1.5,
            enable_defragmentation: true,
            min_pooled_size: 256,
            max_pooled_size: 128 * 1024 * 1024,
            size_buckets: 64,
        }
    }
}
// Locking discipline: whenever more than one lock is needed they are taken in the order
// `blocks` -> `free_blocks` -> `total_allocated` -> `peak_allocated`. `allocation_count` is a
// leaf lock: nothing else is acquired while it is held. A single global order is what keeps
// allocation, deallocation and defragmentation from deadlocking each other.
impl MemoryPool {
    /// Granularity in bytes to which every request passed to [`MemoryPool::allocate`] is
    /// rounded up.
    const ALIGNMENT: usize = 256;

    /// Creates an empty pool for `device`.
    ///
    /// No blocks are created up front; they are created lazily by [`MemoryPool::allocate`].
    pub fn new(device: Device, config: InternalMemoryPoolConfig) -> Self {
        Self {
            device,
            blocks: Mutex::new(VecDeque::new()),
            free_blocks: Mutex::new(HashMap::new()),
            total_allocated: Mutex::new(0),
            peak_allocated: Mutex::new(0),
            allocation_count: Mutex::new(0),
            config,
        }
    }

    /// Allocates `size` bytes, rounded up to a multiple of 256.
    ///
    /// A free block of sufficient size is reused when one exists. A reused block returns the
    /// handle it was created with, so handle ids are recycled. Otherwise a new block is created.
    ///
    /// # Errors
    /// Returns a backend error if rounding `size` up would overflow `usize`, or if creating a
    /// new block would push the pool past `config.max_size`.
    pub fn allocate(&self, size: usize) -> Result<MemoryHandle> {
        // Reject sizes whose round-up would overflow, rather than letting the addition wrap
        // (release builds) into a tiny block or panic (debug builds).
        if size > usize::MAX - (Self::ALIGNMENT - 1) {
            return Err(ferrum_types::FerrumError::backend(format!(
                "Allocation of {} bytes overflows when aligned to {} bytes",
                size,
                Self::ALIGNMENT
            )));
        }
        let aligned_size = align_size(size, Self::ALIGNMENT);
        if let Some(handle) = self.try_allocate_from_pool(aligned_size) {
            return Ok(handle);
        }
        self.allocate_new_block(aligned_size)
    }

    /// Marks the block owned by `handle` as free and adds it to the free list.
    ///
    /// Unknown handles and handles that are already free are logged and otherwise ignored.
    /// Never returns an error.
    pub fn deallocate(&self, handle: MemoryHandle) -> Result<()> {
        // `blocks` stays locked while the free list is updated, so a concurrent `defragment`
        // cannot re-index `blocks` in between and leave a stale index behind.
        let mut blocks = self.blocks.lock();
        let index = match blocks.iter().position(|b| b.handle.id() == handle.id()) {
            Some(index) => index,
            None => {
                warn!(
                    "Attempted to deallocate unknown memory handle: {:?}",
                    handle
                );
                return Ok(());
            }
        };
        let block = &mut blocks[index];
        if block.is_free {
            // Pushing the index a second time would let two callers receive the same block.
            warn!(
                "Attempted to deallocate already-free memory handle: {:?}",
                handle
            );
            return Ok(());
        }
        block.is_free = true;
        let size = block.size;
        self.free_blocks
            .lock()
            .entry(size)
            .or_default()
            .push_back(index);
        debug!("Deallocated block of size {} bytes", size);
        Ok(())
    }

    /// Returns a snapshot of the pool's usage.
    ///
    /// `fragmentation_ratio` is the fraction of tracked blocks that are free. It is 0.0 when
    /// no blocks are tracked.
    pub fn stats(&self) -> MemoryInfo {
        let blocks = self.blocks.lock();
        let total_allocated = *self.total_allocated.lock();

        let (mut used_bytes, mut free_bytes, mut free_count) = (0usize, 0usize, 0usize);
        for block in blocks.iter() {
            if block.is_free {
                free_bytes += block.size;
                free_count += 1;
            } else {
                used_bytes += block.size;
            }
        }

        // Guard on the block count itself: after `defragment` releases every block the list
        // can be empty while bytes were allocated, and `0 / 0` would produce NaN.
        let fragmentation_ratio = if blocks.is_empty() {
            0.0
        } else {
            free_count as f32 / blocks.len() as f32
        };

        MemoryInfo {
            total_bytes: total_allocated as u64,
            used_bytes: used_bytes as u64,
            free_bytes: free_bytes as u64,
            reserved_bytes: 0,
            active_allocations: blocks.len() - free_count,
            fragmentation_ratio,
            bandwidth_gbps: None,
        }
    }

    /// Releases every free block. This is a no-op when `config.enable_defragmentation` is
    /// `false`.
    ///
    /// Live blocks are kept, in order. The bytes of released blocks are subtracted from the
    /// running total, so they no longer count against `config.max_size`.
    pub fn defragment(&self) -> Result<()> {
        if !self.config.enable_defragmentation {
            return Ok(());
        }
        debug!(
            "Starting memory pool defragmentation for device {:?}",
            self.device
        );
        let mut blocks = self.blocks.lock();
        let mut free_blocks = self.free_blocks.lock();

        let released: usize = blocks.iter().filter(|b| b.is_free).map(|b| b.size).sum();
        blocks.retain(|b| !b.is_free);
        // Every free block was just discarded and the survivors were re-indexed, so no
        // free-list entry is valid any more.
        free_blocks.clear();

        let mut total = self.total_allocated.lock();
        *total = total.saturating_sub(released);

        debug!("Memory pool defragmentation completed");
        Ok(())
    }

    /// Reuses the smallest free block whose size is at least `size` (best fit).
    ///
    /// Returns `None` when no free block is large enough.
    fn try_allocate_from_pool(&self, size: usize) -> Option<MemoryHandle> {
        // `blocks` before `free_blocks`, matching `deallocate` and `defragment`.
        let mut blocks = self.blocks.lock();
        let mut free_blocks = self.free_blocks.lock();

        // A non-empty exact-size bucket is automatically the smallest candidate.
        let bucket_size = free_blocks
            .iter()
            .filter(|&(&bucket_size, indices)| bucket_size >= size && !indices.is_empty())
            .map(|(&bucket_size, _)| bucket_size)
            .min()?;

        let indices = free_blocks.get_mut(&bucket_size)?;
        let index = indices.pop_front()?;
        if indices.is_empty() {
            // Drop empty buckets so the map does not accumulate dead keys.
            free_blocks.remove(&bucket_size);
        }

        // Free-list entries are only mutated under the `blocks` lock, so `index` should name a
        // free block. The `is_free` filter is a defensive check against a stale entry.
        let block = blocks.get_mut(index).filter(|b| b.is_free)?;
        block.is_free = false;
        block.allocated_at = std::time::Instant::now();
        Some(block.handle)
    }

    /// Creates a new block of `size` bytes, enforcing `config.max_size`.
    fn allocate_new_block(&self, size: usize) -> Result<MemoryHandle> {
        let mut blocks = self.blocks.lock();
        // Check and update the running total under one lock, so concurrent callers cannot
        // both pass the limit check and jointly exceed `max_size`. `checked_add` also catches
        // arithmetic overflow.
        let mut total = self.total_allocated.lock();
        let new_total = match total.checked_add(size) {
            Some(new_total) if new_total <= self.config.max_size => new_total,
            _ => {
                return Err(ferrum_types::FerrumError::backend(format!(
                    "Memory pool size limit exceeded: {} + {} > {}",
                    *total, size, self.config.max_size
                )));
            }
        };

        let handle_id = {
            let mut count = self.allocation_count.lock();
            *count += 1;
            *count
        };
        let handle = MemoryHandle::new(handle_id);
        blocks.push_back(MemoryBlock {
            handle,
            size,
            is_free: false,
            allocated_at: std::time::Instant::now(),
        });

        *total = new_total;
        let mut peak = self.peak_allocated.lock();
        if new_total > *peak {
            *peak = new_total;
        }
        debug!("Allocated new memory block of size {} bytes", size);
        Ok(handle)
    }
}
#[async_trait]
impl DeviceMemoryManager for MemoryPool {
async fn allocate(&self, size: usize, _device: &Device) -> Result<MemoryHandle> {
self.allocate(size)
}
async fn allocate_aligned(
&self,
size: usize,
alignment: usize,
_device: &Device,
) -> Result<MemoryHandle> {
let aligned_size = align_size(size, alignment);
self.allocate(aligned_size)
}
async fn deallocate(&self, handle: MemoryHandle) -> Result<()> {
self.deallocate(handle)
}
async fn copy(
&self,
_src: MemoryHandle,
_dst: MemoryHandle,
_size: usize,
_src_offset: usize,
_dst_offset: usize,
) -> Result<()> {
Ok(())
}
async fn copy_async(
&self,
_transfer: MemoryTransfer,
_stream: Option<StreamHandle>,
) -> Result<()> {
Ok(())
}
async fn memory_info(&self, _device: &Device) -> Result<MemoryInfo> {
Ok(self.stats())
}
fn handle_info(&self, handle: MemoryHandle) -> Option<MemoryHandleInfo> {
let blocks = self.blocks.lock();
blocks
.iter()
.find(|b| b.handle.id() == handle.id())
.map(|block| {
MemoryHandleInfo {
handle: block.handle,
size: block.size,
device: self.device.clone(),
alignment: 256, allocated_at: block.allocated_at,
is_mapped: false,
memory_type: MemoryType::General,
}
})
}
async fn configure_pool(
&self,
_device: &Device,
_config: InterfaceMemoryPoolConfig,
) -> Result<()> {
Ok(())
}
async fn defragment(&self, _device: &Device) -> Result<DefragmentationStats> {
let before_fragmentation = self.stats().fragmentation_ratio;
self.defragment()?;
let after_fragmentation = self.stats().fragmentation_ratio;
Ok(DefragmentationStats {
memory_freed: 0, blocks_moved: 0,
time_taken_ms: 0,
fragmentation_before: before_fragmentation,
fragmentation_after: after_fragmentation,
})
}
fn set_pressure_callback(&self, _callback: Box<dyn Fn(MemoryPressure) + Send + Sync>) {
}
}
/// Rounds `size` up to the next multiple of `alignment`.
///
/// - An `alignment` of 0 or 1 leaves `size` unchanged, instead of underflowing on
///   `alignment - 1`.
/// - Power-of-two alignments use a bit mask; any other alignment falls back to division, so
///   the result is always a true multiple of `alignment`.
/// - If the rounded value would not fit in a `usize`, `usize::MAX` is returned. The result is
///   therefore never smaller than `size`, and callers comparing it against a limit still
///   reject it.
fn align_size(size: usize, alignment: usize) -> usize {
    if alignment <= 1 {
        return size;
    }
    let remainder = if alignment.is_power_of_two() {
        size & (alignment - 1)
    } else {
        size % alignment
    };
    if remainder == 0 {
        size
    } else {
        size.checked_add(alignment - remainder)
            .unwrap_or(usize::MAX)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferrum_interfaces::memory::DeviceMemoryManager;

    /// Builds a CPU-backed pool with the default configuration.
    fn default_pool() -> MemoryPool {
        MemoryPool::new(Device::CPU, InternalMemoryPoolConfig::default())
    }

    #[test]
    fn test_align_size() {
        // (size, alignment, expected)
        let cases = [
            (100, 256, 256),
            (256, 256, 256),
            (257, 256, 512),
            (500, 256, 512),
            (1, 64, 64),
            (64, 64, 64),
            (65, 64, 128),
        ];
        for &(size, alignment, expected) in cases.iter() {
            assert_eq!(align_size(size, alignment), expected);
        }
    }

    #[test]
    fn test_memory_pool_creation() {
        let stats = default_pool().stats();
        assert_eq!(stats.used_bytes, 0);
        assert_eq!(stats.active_allocations, 0);
    }

    #[test]
    fn test_memory_pool_allocation() {
        let pool = default_pool();

        let first = pool.allocate(1024).unwrap();
        let after_first = pool.stats();
        assert_eq!(after_first.active_allocations, 1);
        assert!(after_first.used_bytes > 0);

        let second = pool.allocate(2048).unwrap();
        assert_eq!(pool.stats().active_allocations, 2);
        assert_ne!(first.id(), second.id());
    }

    #[test]
    fn test_memory_pool_deallocation() {
        let pool = default_pool();
        let handle = pool.allocate(1024).unwrap();
        assert_eq!(pool.stats().active_allocations, 1);
        pool.deallocate(handle).unwrap();
        assert_eq!(pool.stats().active_allocations, 0);
    }

    #[test]
    fn test_memory_pool_reuse() {
        let pool = default_pool();
        let released = pool.allocate(1024).unwrap();
        pool.deallocate(released).unwrap();
        let _reused = pool.allocate(1024).unwrap();
        assert_eq!(pool.stats().active_allocations, 1);
    }

    #[test]
    fn test_memory_pool_size_limit() {
        let config = InternalMemoryPoolConfig {
            max_size: 1024,
            ..InternalMemoryPoolConfig::default()
        };
        let pool = MemoryPool::new(Device::CPU, config);
        assert!(pool.allocate(2048).is_err());
    }

    #[test]
    fn test_memory_pool_multiple_allocations() {
        let pool = default_pool();
        let handles: Vec<MemoryHandle> = (1..=5usize)
            .map(|n| pool.allocate(1024 * n).unwrap())
            .collect();
        assert_eq!(pool.stats().active_allocations, 5);

        handles
            .into_iter()
            .for_each(|handle| pool.deallocate(handle).unwrap());
        assert_eq!(pool.stats().active_allocations, 0);
    }

    #[test]
    fn test_memory_pool_stats() {
        let pool = default_pool();

        let empty = pool.stats();
        assert_eq!(empty.used_bytes, 0);
        assert_eq!(empty.active_allocations, 0);
        assert_eq!(empty.fragmentation_ratio, 0.0);

        let _small = pool.allocate(1024).unwrap();
        let _large = pool.allocate(2048).unwrap();
        let populated = pool.stats();
        assert!(populated.total_bytes >= 1024 + 2048);
        assert_eq!(populated.active_allocations, 2);
        assert!(populated.used_bytes > 0);
    }

    #[test]
    fn test_memory_pool_defragment() {
        let pool = default_pool();
        let first = pool.allocate(1024).unwrap();
        let middle = pool.allocate(2048).unwrap();
        let last = pool.allocate(512).unwrap();
        pool.deallocate(middle).unwrap();

        let live_before = pool.stats().active_allocations;
        pool.defragment().unwrap();
        let live_after = pool.stats().active_allocations;
        assert_eq!(live_before, live_after);

        let _ = pool.deallocate(first);
        let _ = pool.deallocate(last);
    }

    #[tokio::test]
    async fn test_device_memory_manager_trait() {
        let device = Device::CPU;
        let pool = MemoryPool::new(device.clone(), InternalMemoryPoolConfig::default());

        let plain = DeviceMemoryManager::allocate(&pool, 1024, &device)
            .await
            .unwrap();
        assert_ne!(plain.id(), 0);

        let aligned = DeviceMemoryManager::allocate_aligned(&pool, 1000, 256, &device)
            .await
            .unwrap();
        assert_ne!(aligned.id(), 0);

        let both_live = DeviceMemoryManager::memory_info(&pool, &device)
            .await
            .unwrap();
        assert_eq!(both_live.active_allocations, 2);

        DeviceMemoryManager::deallocate(&pool, plain)
            .await
            .unwrap();
        let one_live = DeviceMemoryManager::memory_info(&pool, &device)
            .await
            .unwrap();
        assert_eq!(one_live.active_allocations, 1);

        let _ = DeviceMemoryManager::deallocate(&pool, aligned).await;
    }

    #[tokio::test]
    async fn test_device_memory_manager_defragment() {
        let device = Device::CPU;
        let pool = MemoryPool::new(device.clone(), InternalMemoryPoolConfig::default());

        let _first = DeviceMemoryManager::allocate(&pool, 1024, &device)
            .await
            .unwrap();
        let _second = DeviceMemoryManager::allocate(&pool, 2048, &device)
            .await
            .unwrap();

        let report = DeviceMemoryManager::defragment(&pool, &device)
            .await
            .unwrap();
        assert!(report.fragmentation_before >= 0.0);
        assert!(report.fragmentation_after >= 0.0);
    }

    #[test]
    fn test_handle_info() {
        let pool = default_pool();
        let handle = pool.allocate(1024).unwrap();

        let info = pool
            .handle_info(handle)
            .expect("a freshly allocated handle must be known to the pool");
        assert_eq!(info.handle.id(), handle.id());
        assert!(info.size >= 1024);
        assert_eq!(info.alignment, 256);
        assert!(!info.is_mapped);

        assert!(pool.handle_info(MemoryHandle::new(99999)).is_none());
    }
}