use crate::{BackendType, Device, MemoryPool, MemoryPoolConfig, MemoryStats, PoolStats};
use parking_lot::{Mutex, RwLock};
use std::collections::HashMap;
use std::sync::Arc;
use torsh_core::{
device::DeviceType,
error::{Result, TorshError},
};
impl From<DeviceType> for BackendType {
fn from(device_type: DeviceType) -> Self {
match device_type {
DeviceType::Cpu => BackendType::Cpu,
DeviceType::Cuda(_) => BackendType::Cuda,
DeviceType::Metal(_) => BackendType::Metal,
DeviceType::Wgpu(_) => BackendType::WebGpu,
}
}
}
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, format, vec::Vec};
pub struct UnifiedMemoryPool {
backend_pools: RwLock<HashMap<BackendType, Arc<Mutex<Box<dyn MemoryPool>>>>>,
backend_configs: RwLock<HashMap<BackendType, MemoryPoolConfig>>,
global_stats: RwLock<MemoryStats>,
pool_stats: RwLock<HashMap<BackendType, PoolStats>>,
default_config: MemoryPoolConfig,
enable_cross_backend_sharing: bool,
pressure_threshold: f32,
enable_auto_gc: bool,
}
impl UnifiedMemoryPool {
pub fn new(default_config: MemoryPoolConfig) -> Self {
Self {
backend_pools: RwLock::new(HashMap::new()),
backend_configs: RwLock::new(HashMap::new()),
global_stats: RwLock::new(MemoryStats::default()),
pool_stats: RwLock::new(HashMap::new()),
default_config,
enable_cross_backend_sharing: false,
pressure_threshold: 0.85,
enable_auto_gc: true,
}
}
pub fn with_default_config() -> Self {
Self::new(MemoryPoolConfig::default())
}
pub fn register_backend_pool(
&self,
backend_type: BackendType,
pool: Box<dyn MemoryPool>,
config: Option<MemoryPoolConfig>,
) -> Result<()> {
let config = config.unwrap_or_else(|| self.default_config.clone());
{
let mut pools = self.backend_pools.write();
pools.insert(backend_type, Arc::new(Mutex::new(pool)));
}
{
let mut configs = self.backend_configs.write();
configs.insert(backend_type, config);
}
self.update_global_stats()?;
Ok(())
}
pub fn unregister_backend_pool(&self, backend_type: BackendType) -> Result<()> {
{
let mut pools = self.backend_pools.write();
pools.remove(&backend_type);
}
{
let mut configs = self.backend_configs.write();
configs.remove(&backend_type);
}
{
let mut stats = self.pool_stats.write();
stats.remove(&backend_type);
}
self.update_global_stats()?;
Ok(())
}
pub fn get_or_create_pool(
&self,
backend_type: BackendType,
device: &Device,
) -> Result<Arc<Mutex<Box<dyn MemoryPool>>>> {
{
let pools = self.backend_pools.read();
if let Some(pool) = pools.get(&backend_type) {
return Ok(Arc::clone(pool));
}
}
let config = self
.backend_configs
.read()
.get(&backend_type)
.cloned()
.unwrap_or_else(|| self.default_config.clone());
let pool = self.create_backend_pool(backend_type, device, config)?;
self.register_backend_pool(backend_type, pool, None)?;
let pools = self.backend_pools.read();
pools
.get(&backend_type)
.cloned()
.ok_or_else(|| TorshError::BackendError("Failed to create backend pool".to_string()))
}
fn create_backend_pool(
&self,
backend_type: BackendType,
device: &Device,
config: MemoryPoolConfig,
) -> Result<Box<dyn MemoryPool>> {
match backend_type {
BackendType::Cpu => Ok(Box::new(CpuMemoryPool::new(device.clone(), config))),
BackendType::Cuda => Ok(Box::new(CudaMemoryPool::new(device.clone(), config))),
BackendType::Metal => Ok(Box::new(MetalMemoryPool::new(device.clone(), config))),
BackendType::WebGpu => Ok(Box::new(WebGpuMemoryPool::new(device.clone(), config))),
BackendType::Rocm => Ok(Box::new(RocmMemoryPool::new(device.clone(), config))),
BackendType::Auto => Err(TorshError::BackendError(
"Cannot create pool for Auto backend type".to_string(),
)),
}
}
pub fn allocate(&self, device: &Device, size: usize, alignment: usize) -> Result<*mut u8> {
let backend_type = device.device_type().into();
let pool = self.get_or_create_pool(backend_type, device)?;
if self.enable_auto_gc && self.is_under_pressure(backend_type)? {
self.garbage_collect_backend(backend_type)?;
}
let result = {
let mut pool = pool.lock();
pool.allocate(size, alignment)
};
if result.is_ok() {
self.update_global_stats()?;
}
result
}
pub fn deallocate(&self, device: &Device, ptr: *mut u8, size: usize) -> Result<()> {
let backend_type = device.device_type().into();
let pool = self.get_or_create_pool(backend_type, device)?;
let result = {
let mut pool = pool.lock();
pool.deallocate(ptr, size)
};
if result.is_ok() {
self.update_global_stats()?;
}
result
}
pub fn is_under_pressure(&self, backend_type: BackendType) -> Result<bool> {
let pools = self.backend_pools.read();
if let Some(pool) = pools.get(&backend_type) {
let pool = pool.lock();
let stats = pool.stats();
let utilization = stats.allocated as f32 / stats.capacity as f32;
Ok(utilization > self.pressure_threshold)
} else {
Ok(false)
}
}
pub fn garbage_collect_backend(&self, backend_type: BackendType) -> Result<usize> {
let pools = self.backend_pools.read();
let available = if let Some(pool) = pools.get(&backend_type) {
let mut pool = pool.lock();
pool.reset()?;
let available = pool.available();
drop(pool);
available
} else {
0
};
drop(pools);
self.update_global_stats()?;
Ok(available)
}
pub fn garbage_collect_all(&self) -> Result<usize> {
let mut total_freed = 0;
let backend_types: Vec<BackendType> =
{ self.backend_pools.read().keys().cloned().collect() };
for backend_type in backend_types {
total_freed += self.garbage_collect_backend(backend_type)?;
}
Ok(total_freed)
}
pub fn defragment_backend(
&self,
backend_type: BackendType,
) -> Result<crate::memory::DefragmentationResult> {
use std::time::Instant;
let pools = self.backend_pools.read();
if let Some(pool) = pools.get(&backend_type) {
let mut pool = pool.lock();
let _start_time = Instant::now();
let _frag_before = pool.fragmentation_info();
let result = pool.defragment()?;
self.update_global_stats()?;
Ok(result)
} else {
Err(TorshError::BackendError(format!(
"Backend {:?} not found for defragmentation",
backend_type
)))
}
}
pub fn defragment_all(
&self,
) -> Result<Vec<(BackendType, crate::memory::DefragmentationResult)>> {
let mut results = Vec::new();
let backend_types: Vec<BackendType> =
{ self.backend_pools.read().keys().cloned().collect() };
for backend_type in backend_types {
if self.backend_needs_defragmentation(backend_type)? {
let result = self.defragment_backend(backend_type)?;
results.push((backend_type, result));
}
}
Ok(results)
}
pub fn backend_needs_defragmentation(&self, backend_type: BackendType) -> Result<bool> {
let pools = self.backend_pools.read();
if let Some(pool) = pools.get(&backend_type) {
let pool = pool.lock();
Ok(pool.needs_defragmentation())
} else {
Ok(false)
}
}
pub fn get_backend_fragmentation_info(
&self,
backend_type: BackendType,
) -> Result<crate::memory::FragmentationInfo> {
let pools = self.backend_pools.read();
if let Some(pool) = pools.get(&backend_type) {
let pool = pool.lock();
Ok(pool.fragmentation_info())
} else {
Err(TorshError::BackendError(format!(
"Backend {:?} not found",
backend_type
)))
}
}
pub fn compact_backend(
&self,
backend_type: BackendType,
) -> Result<crate::memory::CompactionResult> {
let pools = self.backend_pools.read();
if let Some(pool) = pools.get(&backend_type) {
let mut pool = pool.lock();
let result = pool.compact()?;
self.update_global_stats()?;
Ok(result)
} else {
Err(TorshError::BackendError(format!(
"Backend {:?} not found for compaction",
backend_type
)))
}
}
pub fn get_overall_fragmentation(&self) -> Result<crate::memory::FragmentationInfo> {
let pools = self.backend_pools.read();
let mut total_free_blocks = 0;
let mut total_allocated_blocks = 0;
let mut total_free_memory = 0;
let mut total_allocated_memory = 0;
let mut total_fragmentation = 0.0;
let mut backend_count = 0;
let mut largest_free_block = 0;
let mut smallest_free_block = usize::MAX;
for pool in pools.values() {
let pool = pool.lock();
let frag_info = pool.fragmentation_info();
total_free_blocks += frag_info.free_blocks;
total_allocated_blocks += frag_info.allocated_blocks;
total_free_memory += frag_info.total_free_memory;
total_allocated_memory += frag_info.total_allocated_memory;
total_fragmentation += frag_info.overall_fragmentation;
backend_count += 1;
largest_free_block = largest_free_block.max(frag_info.largest_free_block);
if frag_info.smallest_free_block > 0 {
smallest_free_block = smallest_free_block.min(frag_info.smallest_free_block);
}
}
if smallest_free_block == usize::MAX {
smallest_free_block = 0;
}
let average_free_block = if total_free_blocks > 0 {
total_free_memory / total_free_blocks
} else {
0
};
let overall_fragmentation = if backend_count > 0 {
total_fragmentation / backend_count as f32
} else {
0.0
};
let utilization_efficiency = if total_free_memory + total_allocated_memory > 0 {
total_allocated_memory as f32 / (total_free_memory + total_allocated_memory) as f32
} else {
0.0
};
Ok(crate::memory::FragmentationInfo {
overall_fragmentation,
external_fragmentation: overall_fragmentation * 0.8, internal_fragmentation: overall_fragmentation * 0.2, free_blocks: total_free_blocks,
allocated_blocks: total_allocated_blocks,
largest_free_block,
smallest_free_block,
average_free_block,
total_free_memory,
total_allocated_memory,
utilization_efficiency,
allocation_efficiency: utilization_efficiency, })
}
pub fn set_auto_defragmentation(&mut self, enabled: bool) {
self.enable_auto_gc = enabled;
}
pub fn set_pressure_threshold(&mut self, threshold: f32) {
self.pressure_threshold = threshold.clamp(0.0, 1.0);
}
pub fn perform_automatic_maintenance(&self) -> Result<()> {
if !self.enable_auto_gc {
return Ok(());
}
let backend_types: Vec<BackendType> =
{ self.backend_pools.read().keys().cloned().collect() };
for backend_type in backend_types {
if self.is_under_pressure(backend_type)? {
self.garbage_collect_backend(backend_type)?;
if self.is_under_pressure(backend_type)? {
self.defragment_backend(backend_type)?;
}
} else if self.backend_needs_defragmentation(backend_type)? {
let frag_info = self.get_backend_fragmentation_info(backend_type)?;
if frag_info.is_severely_fragmented() {
self.defragment_backend(backend_type)?;
}
}
}
Ok(())
}
fn update_global_stats(&self) -> Result<()> {
let mut total_capacity = 0;
let mut total_allocated = 0;
let mut total_available = 0;
let mut backend_stats = HashMap::new();
{
let pools = self.backend_pools.read();
for (backend_type, pool) in pools.iter() {
let pool = pool.lock();
let stats = pool.stats();
total_capacity += stats.capacity;
total_allocated += stats.allocated;
total_available += stats.available;
backend_stats.insert(*backend_type, stats);
}
}
{
let mut pool_stats = self.pool_stats.write();
*pool_stats = backend_stats;
}
{
let mut global_stats = self.global_stats.write();
global_stats.total_memory = total_capacity;
global_stats.allocated_memory = total_allocated;
global_stats.available_memory = total_available;
global_stats.efficiency = if total_capacity > 0 {
total_allocated as f32 / total_capacity as f32
} else {
0.0
};
}
Ok(())
}
pub fn global_stats(&self) -> MemoryStats {
self.global_stats.read().clone()
}
pub fn backend_stats(&self, backend_type: BackendType) -> Option<PoolStats> {
self.pool_stats.read().get(&backend_type).cloned()
}
pub fn all_backend_stats(&self) -> HashMap<BackendType, PoolStats> {
self.pool_stats.read().clone()
}
pub fn set_auto_gc(&mut self, enable: bool) {
self.enable_auto_gc = enable;
}
pub fn set_cross_backend_sharing(&mut self, enable: bool) {
self.enable_cross_backend_sharing = enable;
}
pub fn registered_backends(&self) -> Vec<BackendType> {
self.backend_pools.read().keys().cloned().collect()
}
pub fn reset_all(&self) -> Result<()> {
let backend_types: Vec<BackendType> =
{ self.backend_pools.read().keys().cloned().collect() };
for backend_type in backend_types {
let pools = self.backend_pools.read();
if let Some(pool) = pools.get(&backend_type) {
let mut pool = pool.lock();
pool.reset()?;
}
}
self.update_global_stats()?;
Ok(())
}
pub fn backend_config(&self, backend_type: BackendType) -> Option<MemoryPoolConfig> {
self.backend_configs.read().get(&backend_type).cloned()
}
pub fn set_backend_config(
&self,
backend_type: BackendType,
config: MemoryPoolConfig,
) -> Result<()> {
{
let mut configs = self.backend_configs.write();
configs.insert(backend_type, config);
}
Ok(())
}
}
mod backend_pools {
use super::*;
#[derive(Debug)]
pub struct CpuMemoryPool {
#[allow(dead_code)]
device: Device,
config: MemoryPoolConfig,
allocated_blocks: RwLock<Vec<(usize, usize)>>, stats: RwLock<PoolStats>,
}
impl CpuMemoryPool {
pub fn new(device: Device, config: MemoryPoolConfig) -> Self {
let stats = PoolStats {
capacity: config.initial_size,
available: config.initial_size,
..Default::default()
};
Self {
device,
config,
allocated_blocks: RwLock::new(Vec::new()),
stats: RwLock::new(stats),
}
}
}
impl MemoryPool for CpuMemoryPool {
fn allocate(&mut self, size: usize, alignment: usize) -> Result<*mut u8> {
use std::alloc::{alloc, Layout};
let layout = Layout::from_size_align(size, alignment)
.map_err(|e| TorshError::AllocationError(format!("Invalid layout: {}", e)))?;
let ptr = unsafe { alloc(layout) };
if ptr.is_null() {
return Err(TorshError::AllocationError(
"Failed to allocate CPU memory".to_string(),
));
}
{
let mut blocks = self.allocated_blocks.write();
blocks.push((ptr as usize, size));
}
{
let mut stats = self.stats.write();
stats.allocated += size;
stats.available = stats.capacity.saturating_sub(stats.allocated);
stats.allocated_blocks += 1;
}
Ok(ptr)
}
fn deallocate(&mut self, ptr: *mut u8, size: usize) -> Result<()> {
use std::alloc::{dealloc, Layout};
{
let mut blocks = self.allocated_blocks.write();
if let Some(pos) = blocks
.iter()
.position(|(addr, s)| *addr == ptr as usize && *s == size)
{
blocks.remove(pos);
} else {
return Err(TorshError::InvalidArgument(
"Block not found for deallocation".to_string(),
));
}
}
let layout = Layout::from_size_align(size, self.config.alignment)
.map_err(|e| TorshError::AllocationError(format!("Invalid layout: {}", e)))?;
unsafe {
dealloc(ptr, layout);
}
{
let mut stats = self.stats.write();
stats.allocated = stats.allocated.saturating_sub(size);
stats.available = stats.capacity.saturating_sub(stats.allocated);
stats.allocated_blocks = stats.allocated_blocks.saturating_sub(1);
}
Ok(())
}
fn stats(&self) -> PoolStats {
self.stats.read().clone()
}
fn reset(&mut self) -> Result<()> {
use std::alloc::{dealloc, Layout};
{
let blocks = self.allocated_blocks.read().clone();
for (addr, size) in blocks {
let layout =
Layout::from_size_align(size, self.config.alignment).map_err(|e| {
TorshError::AllocationError(format!("Invalid layout: {}", e))
})?;
unsafe {
dealloc(addr as *mut u8, layout);
}
}
}
{
let mut blocks = self.allocated_blocks.write();
blocks.clear();
}
{
let mut stats = self.stats.write();
stats.allocated = 0;
stats.available = stats.capacity;
stats.allocated_blocks = 0;
}
Ok(())
}
fn capacity(&self) -> usize {
self.config.initial_size
}
fn available(&self) -> usize {
self.stats.read().available
}
fn defragment(&mut self) -> Result<crate::memory::DefragmentationResult> {
Ok(crate::memory::DefragmentationResult {
blocks_moved: 0,
memory_compacted: 0,
duration_ms: 0.0,
fragmentation_before: 0.0,
fragmentation_after: 0.0,
efficiency_improvement: 0.0,
success: true,
})
}
fn needs_defragmentation(&self) -> bool {
let stats = self.stats.read();
stats.allocated_blocks > 10 && stats.allocated < stats.capacity / 2
}
fn fragmentation_info(&self) -> crate::memory::FragmentationInfo {
let stats = self.stats.read();
let fragmentation = if stats.capacity > 0 {
1.0 - (stats.available as f32 / stats.capacity as f32)
} else {
0.0
};
crate::memory::FragmentationInfo {
overall_fragmentation: fragmentation,
external_fragmentation: fragmentation * 0.8,
internal_fragmentation: fragmentation * 0.2,
free_blocks: if stats.available > 0 { 1 } else { 0 },
allocated_blocks: stats.allocated_blocks,
largest_free_block: stats.available,
smallest_free_block: if stats.available > 0 {
stats.available
} else {
0
},
average_free_block: stats.available,
total_free_memory: stats.available,
total_allocated_memory: stats.allocated,
utilization_efficiency: if stats.capacity > 0 {
stats.allocated as f32 / stats.capacity as f32
} else {
0.0
},
allocation_efficiency: if stats.capacity > 0 {
stats.allocated as f32 / stats.capacity as f32
} else {
0.0
},
}
}
fn compact(&mut self) -> Result<crate::memory::CompactionResult> {
Ok(crate::memory::CompactionResult {
allocations_moved: 0,
bytes_moved: 0,
duration_ms: 0.0,
largest_free_before: self.available(),
largest_free_after: self.available(),
free_blocks_before: 1,
free_blocks_after: 1,
success: true,
})
}
}
#[derive(Debug)]
pub struct CudaMemoryPool {
#[allow(dead_code)]
device: Device,
config: MemoryPoolConfig,
stats: RwLock<PoolStats>,
}
impl CudaMemoryPool {
pub fn new(device: Device, config: MemoryPoolConfig) -> Self {
let stats = PoolStats {
capacity: config.initial_size,
available: config.initial_size,
..Default::default()
};
Self {
device,
config,
stats: RwLock::new(stats),
}
}
}
impl MemoryPool for CudaMemoryPool {
fn allocate(&mut self, _size: usize, _alignment: usize) -> Result<*mut u8> {
Ok(std::ptr::null_mut())
}
fn deallocate(&mut self, _ptr: *mut u8, _size: usize) -> Result<()> {
Ok(())
}
fn stats(&self) -> PoolStats {
self.stats.read().clone()
}
fn reset(&mut self) -> Result<()> {
Ok(())
}
fn capacity(&self) -> usize {
self.config.initial_size
}
fn available(&self) -> usize {
self.stats.read().available
}
fn defragment(&mut self) -> Result<crate::memory::DefragmentationResult> {
Ok(crate::memory::DefragmentationResult {
blocks_moved: 0,
memory_compacted: 0,
duration_ms: 0.0,
fragmentation_before: 0.0,
fragmentation_after: 0.0,
efficiency_improvement: 0.0,
success: true,
})
}
fn needs_defragmentation(&self) -> bool {
false
}
fn fragmentation_info(&self) -> crate::memory::FragmentationInfo {
let stats = self.stats.read();
crate::memory::FragmentationInfo {
overall_fragmentation: 0.0,
external_fragmentation: 0.0,
internal_fragmentation: 0.0,
free_blocks: 1,
allocated_blocks: stats.allocated_blocks,
largest_free_block: stats.available,
smallest_free_block: stats.available,
average_free_block: stats.available,
total_free_memory: stats.available,
total_allocated_memory: stats.allocated,
utilization_efficiency: 0.0,
allocation_efficiency: 0.0,
}
}
fn compact(&mut self) -> Result<crate::memory::CompactionResult> {
Ok(crate::memory::CompactionResult {
allocations_moved: 0,
bytes_moved: 0,
duration_ms: 0.0,
largest_free_before: self.available(),
largest_free_after: self.available(),
free_blocks_before: 1,
free_blocks_after: 1,
success: true,
})
}
}
#[derive(Debug)]
pub struct MetalMemoryPool {
#[allow(dead_code)]
device: Device,
config: MemoryPoolConfig,
stats: RwLock<PoolStats>,
}
impl MetalMemoryPool {
pub fn new(device: Device, config: MemoryPoolConfig) -> Self {
let stats = PoolStats {
capacity: config.initial_size,
available: config.initial_size,
..Default::default()
};
Self {
device,
config,
stats: RwLock::new(stats),
}
}
}
impl MemoryPool for MetalMemoryPool {
fn allocate(&mut self, _size: usize, _alignment: usize) -> Result<*mut u8> {
Ok(std::ptr::null_mut())
}
fn deallocate(&mut self, _ptr: *mut u8, _size: usize) -> Result<()> {
Ok(())
}
fn stats(&self) -> PoolStats {
self.stats.read().clone()
}
fn reset(&mut self) -> Result<()> {
Ok(())
}
fn capacity(&self) -> usize {
self.config.initial_size
}
fn available(&self) -> usize {
self.stats.read().available
}
fn defragment(&mut self) -> Result<crate::memory::DefragmentationResult> {
Ok(crate::memory::DefragmentationResult {
blocks_moved: 0,
memory_compacted: 0,
duration_ms: 0.0,
fragmentation_before: 0.0,
fragmentation_after: 0.0,
efficiency_improvement: 0.0,
success: true,
})
}
fn needs_defragmentation(&self) -> bool {
false
}
fn fragmentation_info(&self) -> crate::memory::FragmentationInfo {
let stats = self.stats.read();
crate::memory::FragmentationInfo {
overall_fragmentation: 0.0,
external_fragmentation: 0.0,
internal_fragmentation: 0.0,
free_blocks: 1,
allocated_blocks: stats.allocated_blocks,
largest_free_block: stats.available,
smallest_free_block: stats.available,
average_free_block: stats.available,
total_free_memory: stats.available,
total_allocated_memory: stats.allocated,
utilization_efficiency: 0.0,
allocation_efficiency: 0.0,
}
}
fn compact(&mut self) -> Result<crate::memory::CompactionResult> {
Ok(crate::memory::CompactionResult {
allocations_moved: 0,
bytes_moved: 0,
duration_ms: 0.0,
largest_free_before: self.available(),
largest_free_after: self.available(),
free_blocks_before: 1,
free_blocks_after: 1,
success: true,
})
}
}
#[derive(Debug)]
pub struct WebGpuMemoryPool {
#[allow(dead_code)]
device: Device,
config: MemoryPoolConfig,
stats: RwLock<PoolStats>,
}
impl WebGpuMemoryPool {
pub fn new(device: Device, config: MemoryPoolConfig) -> Self {
let stats = PoolStats {
capacity: config.initial_size,
available: config.initial_size,
..Default::default()
};
Self {
device,
config,
stats: RwLock::new(stats),
}
}
}
impl MemoryPool for WebGpuMemoryPool {
fn allocate(&mut self, _size: usize, _alignment: usize) -> Result<*mut u8> {
Ok(std::ptr::null_mut())
}
fn deallocate(&mut self, _ptr: *mut u8, _size: usize) -> Result<()> {
Ok(())
}
fn stats(&self) -> PoolStats {
self.stats.read().clone()
}
fn reset(&mut self) -> Result<()> {
Ok(())
}
fn capacity(&self) -> usize {
self.config.initial_size
}
fn available(&self) -> usize {
self.stats.read().available
}
fn defragment(&mut self) -> Result<crate::memory::DefragmentationResult> {
Ok(crate::memory::DefragmentationResult {
blocks_moved: 0,
memory_compacted: 0,
duration_ms: 0.0,
fragmentation_before: 0.0,
fragmentation_after: 0.0,
efficiency_improvement: 0.0,
success: true,
})
}
fn needs_defragmentation(&self) -> bool {
false
}
fn fragmentation_info(&self) -> crate::memory::FragmentationInfo {
let stats = self.stats.read();
crate::memory::FragmentationInfo {
overall_fragmentation: 0.0,
external_fragmentation: 0.0,
internal_fragmentation: 0.0,
free_blocks: 1,
allocated_blocks: stats.allocated_blocks,
largest_free_block: stats.available,
smallest_free_block: stats.available,
average_free_block: stats.available,
total_free_memory: stats.available,
total_allocated_memory: stats.allocated,
utilization_efficiency: 0.0,
allocation_efficiency: 0.0,
}
}
fn compact(&mut self) -> Result<crate::memory::CompactionResult> {
Ok(crate::memory::CompactionResult {
allocations_moved: 0,
bytes_moved: 0,
duration_ms: 0.0,
largest_free_before: self.available(),
largest_free_after: self.available(),
free_blocks_before: 1,
free_blocks_after: 1,
success: true,
})
}
}
#[derive(Debug)]
pub struct RocmMemoryPool {
#[allow(dead_code)]
device: Device,
config: MemoryPoolConfig,
stats: RwLock<PoolStats>,
}
impl RocmMemoryPool {
pub fn new(device: Device, config: MemoryPoolConfig) -> Self {
let stats = PoolStats {
capacity: config.initial_size,
available: config.initial_size,
..Default::default()
};
Self {
device,
config,
stats: RwLock::new(stats),
}
}
}
impl MemoryPool for RocmMemoryPool {
fn allocate(&mut self, _size: usize, _alignment: usize) -> Result<*mut u8> {
Ok(std::ptr::null_mut())
}
fn deallocate(&mut self, _ptr: *mut u8, _size: usize) -> Result<()> {
Ok(())
}
fn stats(&self) -> PoolStats {
self.stats.read().clone()
}
fn reset(&mut self) -> Result<()> {
Ok(())
}
fn capacity(&self) -> usize {
self.config.initial_size
}
fn available(&self) -> usize {
self.stats.read().available
}
fn defragment(&mut self) -> Result<crate::memory::DefragmentationResult> {
Ok(crate::memory::DefragmentationResult {
blocks_moved: 0,
memory_compacted: 0,
duration_ms: 0.0,
fragmentation_before: 0.0,
fragmentation_after: 0.0,
efficiency_improvement: 0.0,
success: true,
})
}
fn needs_defragmentation(&self) -> bool {
false
}
fn fragmentation_info(&self) -> crate::memory::FragmentationInfo {
let stats = self.stats.read();
crate::memory::FragmentationInfo {
overall_fragmentation: 0.0,
external_fragmentation: 0.0,
internal_fragmentation: 0.0,
free_blocks: 1,
allocated_blocks: stats.allocated_blocks,
largest_free_block: stats.available,
smallest_free_block: stats.available,
average_free_block: stats.available,
total_free_memory: stats.available,
total_allocated_memory: stats.allocated,
utilization_efficiency: 0.0,
allocation_efficiency: 0.0,
}
}
fn compact(&mut self) -> Result<crate::memory::CompactionResult> {
Ok(crate::memory::CompactionResult {
allocations_moved: 0,
bytes_moved: 0,
duration_ms: 0.0,
largest_free_before: self.available(),
largest_free_after: self.available(),
free_blocks_before: 1,
free_blocks_after: 1,
success: true,
})
}
}
}
pub use backend_pools::*;
#[cfg(test)]
mod tests {
use super::*;
use crate::device::{Device, DeviceInfo};
use torsh_core::device::DeviceType;
fn create_test_device(device_type: DeviceType) -> Device {
let info = DeviceInfo::default();
Device::new(0, device_type, "Test Device".to_string(), info)
}
#[test]
fn test_unified_memory_pool_creation() {
let config = MemoryPoolConfig::default();
let pool = UnifiedMemoryPool::new(config);
assert_eq!(pool.registered_backends().len(), 0);
assert_eq!(pool.global_stats().total_memory, 0);
}
#[test]
fn test_backend_registration() {
let pool = UnifiedMemoryPool::with_default_config();
let device = create_test_device(DeviceType::Cpu);
let cpu_pool = Box::new(CpuMemoryPool::new(device, MemoryPoolConfig::default()));
let result = pool.register_backend_pool(BackendType::Cpu, cpu_pool, None);
assert!(result.is_ok());
let backends = pool.registered_backends();
assert_eq!(backends.len(), 1);
assert!(backends.contains(&BackendType::Cpu));
}
#[test]
fn test_backend_unregistration() {
let pool = UnifiedMemoryPool::with_default_config();
let device = create_test_device(DeviceType::Cpu);
let cpu_pool = Box::new(CpuMemoryPool::new(device, MemoryPoolConfig::default()));
pool.register_backend_pool(BackendType::Cpu, cpu_pool, None)
.unwrap();
assert_eq!(pool.registered_backends().len(), 1);
let result = pool.unregister_backend_pool(BackendType::Cpu);
assert!(result.is_ok());
assert_eq!(pool.registered_backends().len(), 0);
}
#[test]
fn test_memory_allocation() {
let pool = UnifiedMemoryPool::with_default_config();
let device = create_test_device(DeviceType::Cpu);
let result = pool.allocate(&device, 1024, 16);
assert!(result.is_ok());
if let Ok(ptr) = result {
let dealloc_result = pool.deallocate(&device, ptr, 1024);
assert!(dealloc_result.is_ok());
}
}
#[test]
fn test_global_statistics() {
let pool = UnifiedMemoryPool::with_default_config();
let device = create_test_device(DeviceType::Cpu);
let cpu_pool = Box::new(CpuMemoryPool::new(device, MemoryPoolConfig::default()));
pool.register_backend_pool(BackendType::Cpu, cpu_pool, None)
.unwrap();
let stats = pool.global_stats();
assert!(stats.total_memory > 0);
}
#[test]
fn test_memory_pressure() {
let pool = UnifiedMemoryPool::with_default_config();
let device = create_test_device(DeviceType::Cpu);
let cpu_pool = Box::new(CpuMemoryPool::new(device, MemoryPoolConfig::default()));
pool.register_backend_pool(BackendType::Cpu, cpu_pool, None)
.unwrap();
let is_under_pressure = pool.is_under_pressure(BackendType::Cpu);
assert!(is_under_pressure.is_ok());
assert!(!is_under_pressure.unwrap()); }
#[test]
fn test_garbage_collection() {
let pool = UnifiedMemoryPool::with_default_config();
let device = create_test_device(DeviceType::Cpu);
let cpu_pool = Box::new(CpuMemoryPool::new(device, MemoryPoolConfig::default()));
pool.register_backend_pool(BackendType::Cpu, cpu_pool, None)
.unwrap();
let result = pool.garbage_collect_backend(BackendType::Cpu);
assert!(result.is_ok());
let result = pool.garbage_collect_all();
assert!(result.is_ok());
}
}