use crate::error::{Result, TrustformersError};
use std::alloc::{alloc, dealloc, Layout};
use std::collections::{HashMap, VecDeque};
use std::ptr::NonNull;
use std::sync::{Arc, Mutex, RwLock};
use std::time::Instant;
use tracing::{debug, info};
use trustformers_core::errors::TrustformersError as CoreTrustformersError;
#[cfg(target_os = "linux")]
use std::fs;
#[cfg(target_os = "linux")]
use std::io::Read;
#[derive(Debug, Clone)]
pub struct MemoryPoolConfig {
pub initial_size: usize,
pub max_size: usize,
pub alignment: usize,
pub num_buckets: usize,
pub enable_gc: bool,
pub gc_threshold: f64,
pub enable_tracking: bool,
pub preallocation_strategy: PreallocationStrategy,
pub enable_numa: bool,
pub preferred_numa_node: Option<u32>,
pub enable_numa_interleaving: bool,
pub numa_interleaving_threshold: usize,
pub enable_numa_balancing: bool,
pub enable_thread_affinity: bool,
}
impl Default for MemoryPoolConfig {
fn default() -> Self {
Self {
initial_size: 64 * 1024 * 1024, max_size: 2 * 1024 * 1024 * 1024, alignment: 64, num_buckets: 16,
enable_gc: true,
gc_threshold: 0.7,
enable_tracking: true,
preallocation_strategy: PreallocationStrategy::Exponential,
enable_numa: true,
preferred_numa_node: None, enable_numa_interleaving: true,
numa_interleaving_threshold: 1024 * 1024, enable_numa_balancing: true,
enable_thread_affinity: true,
}
}
}
#[derive(Debug, Clone, Copy)]
pub enum PreallocationStrategy {
None,
Linear,
Exponential,
Fibonacci,
}
#[derive(Debug, Clone)]
pub struct NumaTopology {
pub num_nodes: u32,
pub cpus_per_node: HashMap<u32, Vec<u32>>,
pub memory_per_node: HashMap<u32, u64>,
pub distance_matrix: Vec<Vec<u32>>,
pub current_node: Option<u32>,
}
impl NumaTopology {
pub fn detect() -> Result<Self> {
#[cfg(target_os = "linux")]
{
Self::detect_linux()
}
#[cfg(target_os = "windows")]
{
Self::detect_windows()
}
#[cfg(not(any(target_os = "linux", target_os = "windows")))]
{
Ok(Self {
num_nodes: 1,
cpus_per_node: [(0, vec![0])].iter().cloned().collect(),
memory_per_node: [(0, 8 * 1024 * 1024 * 1024)].iter().cloned().collect(), distance_matrix: vec![vec![10]],
current_node: Some(0),
})
}
}
#[cfg(target_os = "linux")]
fn detect_linux() -> Result<Self> {
let mut numa_topology = NumaTopology {
num_nodes: 0,
cpus_per_node: HashMap::new(),
memory_per_node: HashMap::new(),
distance_matrix: Vec::new(),
current_node: None,
};
let nodes_dir = "/sys/devices/system/node/";
if std::path::Path::new(nodes_dir).exists() {
let nodes = fs::read_dir(nodes_dir).map_err(|e| TrustformersError::Io {
message: format!("Failed to read NUMA nodes: {}", e),
path: Some(nodes_dir.to_string()),
suggestion: Some("Check system NUMA support and permissions".to_string()),
})?;
for entry in nodes {
let entry = entry.map_err(|e| TrustformersError::Io {
message: format!("Failed to read node entry: {}", e),
path: Some(nodes_dir.to_string()),
suggestion: Some("Check NUMA node directory permissions".to_string()),
})?;
let name = entry.file_name();
let name_str = name.to_string_lossy();
if name_str.starts_with("node") {
if let Some(node_id_str) = name_str.strip_prefix("node") {
if let Ok(node_id) = node_id_str.parse::<u32>() {
numa_topology.num_nodes = numa_topology.num_nodes.max(node_id + 1);
let cpulist_path = format!("{}/node{}/cpulist", nodes_dir, node_id);
if let Ok(cpulist) = fs::read_to_string(&cpulist_path) {
let cpus = Self::parse_cpu_list(cpulist.trim())?;
numa_topology.cpus_per_node.insert(node_id, cpus);
}
let meminfo_path = format!("{}/node{}/meminfo", nodes_dir, node_id);
if let Ok(meminfo) = fs::read_to_string(&meminfo_path) {
if let Ok(memory) = Self::parse_memory_info(&meminfo) {
numa_topology.memory_per_node.insert(node_id, memory);
}
}
}
}
}
}
numa_topology.distance_matrix = Self::read_distance_matrix(numa_topology.num_nodes)?;
numa_topology.current_node = Self::get_current_numa_node();
} else {
numa_topology.num_nodes = 1;
numa_topology.cpus_per_node.insert(0, vec![0]);
numa_topology.memory_per_node.insert(0, 8 * 1024 * 1024 * 1024); numa_topology.distance_matrix = vec![vec![10]];
numa_topology.current_node = Some(0);
}
Ok(numa_topology)
}
#[cfg(target_os = "windows")]
fn detect_windows() -> Result<Self> {
Ok(Self {
num_nodes: 1,
cpus_per_node: [(0, vec![0])].iter().cloned().collect(),
memory_per_node: [(0, 8 * 1024 * 1024 * 1024)].iter().cloned().collect(),
distance_matrix: vec![vec![10]],
current_node: Some(0),
})
}
#[cfg(target_os = "linux")]
fn parse_cpu_list(cpulist: &str) -> Result<Vec<u32>> {
let mut cpus = Vec::new();
for part in cpulist.split(',') {
if part.contains('-') {
let range: Vec<&str> = part.split('-').collect();
if range.len() == 2 {
let start: u32 = range[0].parse().map_err(|e| {
TrustformersError::invalid_input(
format!("Invalid CPU range start: {}", e),
Some("cpu_range_start"),
Some("valid integer"),
Some(range[0]),
)
})?;
let end: u32 = range[1].parse().map_err(|e| {
TrustformersError::invalid_input(
format!("Invalid CPU range end: {}", e),
Some("cpu_range_end"),
Some("valid integer"),
Some(range[1]),
)
})?;
for cpu in start..=end {
cpus.push(cpu);
}
}
} else {
let cpu: u32 = part.parse().map_err(|e| {
TrustformersError::invalid_input(
format!("Invalid CPU ID: {}", e),
Some("cpu_id"),
Some("valid integer"),
Some(part),
)
})?;
cpus.push(cpu);
}
}
Ok(cpus)
}
#[cfg(target_os = "linux")]
fn parse_memory_info(meminfo: &str) -> Result<u64> {
for line in meminfo.lines() {
if line.contains("MemTotal:") {
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 2 {
let memory_kb: u64 = parts[1].parse().map_err(|e| {
TrustformersError::invalid_input(
format!("Invalid memory size: {}", e),
Some("memory_size"),
Some("valid integer"),
Some(parts[1]),
)
})?;
return Ok(memory_kb * 1024); }
}
}
Err(TrustformersError::invalid_input(
"Could not parse memory info",
Some("meminfo"),
Some("MemTotal line with valid format"),
Some("missing or invalid format"),
))
}
#[cfg(target_os = "linux")]
fn read_distance_matrix(num_nodes: u32) -> Result<Vec<Vec<u32>>> {
let mut matrix = vec![vec![10; num_nodes as usize]; num_nodes as usize];
for i in 0..num_nodes {
let distance_path = format!("/sys/devices/system/node/node{}/distance", i);
if let Ok(distances) = fs::read_to_string(&distance_path) {
let distance_values: Vec<u32> =
distances.split_whitespace().filter_map(|s| s.parse().ok()).collect();
for (j, &distance) in distance_values.iter().enumerate() {
if j < num_nodes as usize {
matrix[i as usize][j] = distance;
}
}
}
}
Ok(matrix)
}
#[cfg(target_os = "linux")]
fn get_current_numa_node() -> Option<u32> {
if let Ok(status) = fs::read_to_string("/proc/self/status") {
for line in status.lines() {
if line.starts_with("Mems_allowed_list:") {
if let Some(nodes_str) = line.split(':').nth(1) {
let nodes_str = nodes_str.trim();
if let Ok(node) = nodes_str.parse::<u32>() {
return Some(node);
}
}
}
}
}
None
}
pub fn get_preferred_node(&self) -> u32 {
self.current_node.unwrap_or(0)
}
pub fn find_closest_node(&self, from_node: u32) -> u32 {
if from_node >= self.num_nodes {
return 0;
}
let mut min_distance = u32::MAX;
let mut closest_node = 0;
for (to_node, &distance) in self.distance_matrix[from_node as usize].iter().enumerate() {
if distance < min_distance {
min_distance = distance;
closest_node = to_node as u32;
}
}
closest_node
}
pub fn are_nodes_local(&self, node1: u32, node2: u32) -> bool {
if node1 >= self.num_nodes || node2 >= self.num_nodes {
return false;
}
self.distance_matrix[node1 as usize][node2 as usize] <= 20
}
}
#[derive(Debug, Clone, Copy)]
pub enum NumaAllocationStrategy {
Local,
Interleaved,
Preferred(u32),
MostFree,
}
#[derive(Debug, Clone, Default)]
pub struct NumaMemoryStats {
pub allocations_per_node: HashMap<u32, u64>,
pub memory_per_node: HashMap<u32, u64>,
pub cross_numa_accesses: u64,
pub numa_balancing_ops: u64,
pub affinity_changes: u64,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MemoryAccessPattern {
Sequential,
Random,
Temporal,
Spatial,
ReadOnly,
WriteHeavy,
Unknown,
}
#[derive(Debug, Clone)]
struct MemoryBlock {
ptr: NonNull<u8>,
size: usize,
layout: Layout,
allocated_at: Instant,
last_accessed: Instant,
is_free: bool,
bucket_index: usize,
numa_node: Option<u32>,
allocation_strategy: NumaAllocationStrategy,
access_pattern: MemoryAccessPattern,
}
impl MemoryBlock {
fn new(size: usize, alignment: usize, bucket_index: usize) -> Result<Self> {
let layout = Layout::from_size_align(size, alignment).map_err(|e| {
TrustformersError::invalid_input(
format!("Invalid memory layout: {}", e),
Some("layout"),
Some("valid size and alignment"),
Some(&format!("size: {}, alignment: {}", size, alignment)),
)
})?;
let ptr = unsafe { alloc(layout) };
if ptr.is_null() {
return Err(TrustformersError::Resource {
message: "Failed to allocate memory block".to_string(),
resource_type: "memory".to_string(),
current_usage: Some(format!("requested: {} bytes", size)),
suggestion: Some("Check available memory and reduce allocation size".to_string()),
recovery_actions: vec![],
});
}
let now = Instant::now();
Ok(Self {
ptr: NonNull::new(ptr).expect("allocated pointer should not be null"),
size,
layout,
allocated_at: now,
last_accessed: now,
is_free: true,
bucket_index,
numa_node: None,
allocation_strategy: NumaAllocationStrategy::Local,
access_pattern: MemoryAccessPattern::Sequential,
})
}
fn free(&mut self) {
if !self.is_free {
unsafe {
dealloc(self.ptr.as_ptr(), self.layout);
}
self.is_free = true;
}
}
fn mark_used(&mut self) {
self.is_free = false;
self.last_accessed = Instant::now();
}
fn mark_free(&mut self) {
self.is_free = true;
self.last_accessed = Instant::now();
}
}
impl Drop for MemoryBlock {
fn drop(&mut self) {
if !self.is_free {
self.free();
}
}
}
#[derive(Debug, Default, Clone)]
pub struct MemoryPoolStats {
pub total_allocated: usize,
pub total_freed: usize,
pub current_usage: usize,
pub peak_usage: usize,
pub total_requests: u64,
pub cache_hits: u64,
pub cache_misses: u64,
pub fragmentation_ratio: f64,
pub gc_runs: u64,
pub blocks_per_bucket: Vec<usize>,
pub avg_allocation_time_us: f64,
pub avg_deallocation_time_us: f64,
}
impl MemoryPoolStats {
pub fn hit_rate(&self) -> f64 {
if self.total_requests > 0 {
self.cache_hits as f64 / self.total_requests as f64
} else {
0.0
}
}
pub fn efficiency(&self) -> f64 {
if self.total_allocated > 0 {
(self.total_allocated - self.current_usage) as f64 / self.total_allocated as f64
} else {
1.0
}
}
}
#[derive(Debug)]
struct MemoryBucket {
size_range: (usize, usize),
free_blocks: VecDeque<usize>, allocated_blocks: Vec<usize>,
total_allocations: u64,
total_deallocations: u64,
}
impl MemoryBucket {
fn new(min_size: usize, max_size: usize) -> Self {
Self {
size_range: (min_size, max_size),
free_blocks: VecDeque::new(),
allocated_blocks: Vec::new(),
total_allocations: 0,
total_deallocations: 0,
}
}
fn can_fit(&self, size: usize) -> bool {
size >= self.size_range.0 && size <= self.size_range.1
}
fn allocate_block(&mut self) -> Option<usize> {
if let Some(block_index) = self.free_blocks.pop_front() {
self.allocated_blocks.push(block_index);
self.total_allocations += 1;
Some(block_index)
} else {
None
}
}
fn deallocate_block(&mut self, block_index: usize) {
if let Some(pos) = self.allocated_blocks.iter().position(|&x| x == block_index) {
self.allocated_blocks.remove(pos);
self.free_blocks.push_back(block_index);
self.total_deallocations += 1;
}
}
}
pub struct MemoryPool {
config: MemoryPoolConfig,
blocks: Arc<RwLock<Vec<MemoryBlock>>>,
buckets: Arc<RwLock<Vec<MemoryBucket>>>,
stats: Arc<Mutex<MemoryPoolStats>>,
allocation_map: Arc<RwLock<HashMap<*const u8, usize>>>, }
impl MemoryPool {
pub fn new(config: MemoryPoolConfig) -> Result<Self> {
let bucket_sizes = Self::calculate_bucket_sizes(&config);
let mut buckets = Vec::new();
for i in 0..bucket_sizes.len() {
let min_size = bucket_sizes[i];
let max_size = if i + 1 < bucket_sizes.len() {
bucket_sizes[i + 1] - 1
} else {
config.max_size
};
buckets.push(MemoryBucket::new(min_size, max_size));
}
let pool = Self {
config: config.clone(),
blocks: Arc::new(RwLock::new(Vec::new())),
buckets: Arc::new(RwLock::new(buckets)),
stats: Arc::new(Mutex::new(MemoryPoolStats::default())),
allocation_map: Arc::new(RwLock::new(HashMap::new())),
};
pool.preallocate_memory()?;
Ok(pool)
}
fn calculate_bucket_sizes(config: &MemoryPoolConfig) -> Vec<usize> {
let mut sizes = Vec::new();
let min_size = config.alignment;
let max_size = config.max_size;
match config.preallocation_strategy {
PreallocationStrategy::None => {
sizes.push(min_size);
},
PreallocationStrategy::Linear => {
let step = (max_size - min_size) / config.num_buckets;
for i in 0..config.num_buckets {
sizes.push(min_size + i * step);
}
},
PreallocationStrategy::Exponential => {
let mut size = min_size;
for _ in 0..config.num_buckets {
sizes.push(size);
size = (size * 2).min(max_size);
if size >= max_size {
break;
}
}
},
PreallocationStrategy::Fibonacci => {
let mut a = min_size;
let mut b = min_size * 2;
sizes.push(a);
for _ in 1..config.num_buckets {
sizes.push(b);
let next = a + b;
a = b;
b = next;
if b >= max_size {
break;
}
}
},
}
sizes
}
fn preallocate_memory(&self) -> Result<()> {
match self.config.preallocation_strategy {
PreallocationStrategy::None => Ok(()),
_ => {
let buckets = self.buckets.read().expect("lock should not be poisoned");
let mut blocks = self.blocks.write().expect("lock should not be poisoned");
for (bucket_index, bucket) in buckets.iter().enumerate() {
let size = bucket.size_range.1; for _ in 0..4 {
let block = MemoryBlock::new(size, self.config.alignment, bucket_index)?;
blocks.push(block);
}
}
info!("Preallocated {} blocks", blocks.len());
Ok(())
},
}
}
pub fn allocate(&self, size: usize) -> Result<NonNull<u8>> {
let start_time = Instant::now();
let aligned_size = self.align_size(size);
let bucket_index = self.find_bucket_for_size(aligned_size)?;
let block_index = {
let mut buckets = self.buckets.write().expect("lock should not be poisoned");
let bucket = &mut buckets[bucket_index];
if let Some(index) = bucket.allocate_block() {
index
} else {
drop(buckets); self.allocate_new_block(aligned_size, bucket_index)?
}
};
{
let mut blocks = self.blocks.write().expect("lock should not be poisoned");
blocks[block_index].mark_used();
let ptr = blocks[block_index].ptr.as_ptr();
let mut allocation_map =
self.allocation_map.write().expect("lock should not be poisoned");
allocation_map.insert(ptr, block_index);
}
{
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.total_requests += 1;
stats.current_usage += aligned_size;
stats.peak_usage = stats.peak_usage.max(stats.current_usage);
let allocation_time = start_time.elapsed().as_micros() as f64;
stats.avg_allocation_time_us = (stats.avg_allocation_time_us
* (stats.total_requests - 1) as f64
+ allocation_time)
/ stats.total_requests as f64;
if block_index < self.blocks.read().expect("lock should not be poisoned").len() {
stats.cache_hits += 1;
} else {
stats.cache_misses += 1;
}
}
let blocks = self.blocks.read().expect("lock should not be poisoned");
Ok(blocks[block_index].ptr)
}
pub fn deallocate(&self, ptr: NonNull<u8>) -> Result<()> {
let start_time = Instant::now();
let ptr_raw = ptr.as_ptr() as *const u8;
let block_index = {
let mut allocation_map =
self.allocation_map.write().expect("lock should not be poisoned");
allocation_map.remove(&ptr_raw).ok_or_else(|| {
TrustformersError::invalid_input(
"Invalid pointer for deallocation",
Some("pointer"),
Some("valid allocated pointer"),
Some(&format!("{:p}", ptr_raw)),
)
})?
};
let size = {
let mut blocks = self.blocks.write().expect("lock should not be poisoned");
blocks[block_index].mark_free();
blocks[block_index].size
};
{
let mut buckets = self.buckets.write().expect("lock should not be poisoned");
let blocks = self.blocks.read().expect("lock should not be poisoned");
let bucket_index = blocks[block_index].bucket_index;
buckets[bucket_index].deallocate_block(block_index);
}
{
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.total_freed += size;
stats.current_usage = stats.current_usage.saturating_sub(size);
let deallocation_time = start_time.elapsed().as_micros() as f64;
let total_deallocations = stats.total_freed / size.max(1);
stats.avg_deallocation_time_us = (stats.avg_deallocation_time_us
* (total_deallocations - 1) as f64
+ deallocation_time)
/ total_deallocations as f64;
}
if self.config.enable_gc && self.should_run_gc() {
self.run_garbage_collection()?;
}
Ok(())
}
fn allocate_new_block(&self, size: usize, bucket_index: usize) -> Result<usize> {
let mut blocks = self.blocks.write().expect("lock should not be poisoned");
let block = MemoryBlock::new(size, self.config.alignment, bucket_index)?;
let block_index = blocks.len();
blocks.push(block);
let mut buckets = self.buckets.write().expect("lock should not be poisoned");
buckets[bucket_index].free_blocks.push_back(block_index);
Ok(block_index)
}
fn find_bucket_for_size(&self, size: usize) -> Result<usize> {
let buckets = self.buckets.read().expect("lock should not be poisoned");
for (index, bucket) in buckets.iter().enumerate() {
if bucket.can_fit(size) {
return Ok(index);
}
}
Err(TrustformersError::Core(CoreTrustformersError::other(
format!("Size too large for any bucket: {} bytes", size),
)))
}
fn align_size(&self, size: usize) -> usize {
(size + self.config.alignment - 1) & !(self.config.alignment - 1)
}
fn should_run_gc(&self) -> bool {
let stats = self.stats.lock().expect("lock should not be poisoned");
stats.fragmentation_ratio > self.config.gc_threshold
}
fn run_garbage_collection(&self) -> Result<()> {
let start_time = Instant::now();
debug!("Starting garbage collection");
let mut blocks_freed = 0;
let mut memory_freed = 0;
{
let mut blocks = self.blocks.write().expect("lock should not be poisoned");
let mut buckets = self.buckets.write().expect("lock should not be poisoned");
let mut allocation_map =
self.allocation_map.write().expect("lock should not be poisoned");
let threshold = Instant::now() - std::time::Duration::from_secs(300); let mut blocks_to_remove = Vec::new();
for (index, block) in blocks.iter().enumerate() {
if block.is_free && block.last_accessed < threshold {
blocks_to_remove.push(index);
}
}
blocks_to_remove.reverse();
for &index in &blocks_to_remove {
let block = &mut blocks[index];
memory_freed += block.size;
block.free();
let bucket = &mut buckets[block.bucket_index];
bucket.free_blocks.retain(|&x| x != index);
allocation_map.retain(|_, &mut v| v != index);
blocks_freed += 1;
}
for &index in &blocks_to_remove {
blocks.remove(index);
for bucket in buckets.iter_mut() {
for block_idx in &mut bucket.free_blocks {
if *block_idx > index {
*block_idx -= 1;
}
}
for block_idx in &mut bucket.allocated_blocks {
if *block_idx > index {
*block_idx -= 1;
}
}
}
for (_, block_idx) in allocation_map.iter_mut() {
if *block_idx > index {
*block_idx -= 1;
}
}
}
}
{
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.gc_runs += 1;
stats.fragmentation_ratio = self.calculate_fragmentation();
}
let gc_time = start_time.elapsed();
info!(
"Garbage collection completed: freed {} blocks ({} bytes) in {:?}",
blocks_freed, memory_freed, gc_time
);
Ok(())
}
fn calculate_fragmentation(&self) -> f64 {
let blocks = self.blocks.read().expect("lock should not be poisoned");
if blocks.is_empty() {
return 0.0;
}
let total_blocks = blocks.len();
let free_blocks = blocks.iter().filter(|b| b.is_free).count();
if total_blocks > 0 {
free_blocks as f64 / total_blocks as f64
} else {
0.0
}
}
pub fn get_stats(&self) -> MemoryPoolStats {
let mut stats = self.stats.lock().expect("lock should not be poisoned").clone();
stats.fragmentation_ratio = self.calculate_fragmentation();
let buckets = self.buckets.read().expect("lock should not be poisoned");
stats.blocks_per_bucket =
buckets.iter().map(|b| b.free_blocks.len() + b.allocated_blocks.len()).collect();
stats
}
pub fn reset_stats(&self) {
let mut stats = self.stats.lock().expect("lock should not be poisoned");
*stats = MemoryPoolStats::default();
}
pub fn force_gc(&self) -> Result<()> {
self.run_garbage_collection()
}
pub fn memory_usage(&self) -> MemoryUsage {
let stats = self.get_stats();
let blocks = self.blocks.read().expect("lock should not be poisoned");
let total_capacity = blocks.iter().map(|b| b.size).sum();
let allocated_memory = blocks.iter().filter(|b| !b.is_free).map(|b| b.size).sum();
MemoryUsage {
total_capacity,
allocated_memory,
free_memory: total_capacity - allocated_memory,
utilization_ratio: if total_capacity > 0 {
allocated_memory as f64 / total_capacity as f64
} else {
0.0
},
num_blocks: blocks.len(),
num_free_blocks: blocks.iter().filter(|b| b.is_free).count(),
fragmentation_ratio: stats.fragmentation_ratio,
}
}
pub fn shrink(&self) -> Result<usize> {
let mut bytes_freed = 0;
let threshold = Instant::now() - std::time::Duration::from_secs(600);
{
let mut blocks = self.blocks.write().expect("lock should not be poisoned");
let mut buckets = self.buckets.write().expect("lock should not be poisoned");
blocks.retain(|block| {
if block.is_free && block.last_accessed < threshold {
bytes_freed += block.size;
false
} else {
true
}
});
for bucket in buckets.iter_mut() {
bucket.free_blocks.clear();
bucket.allocated_blocks.clear();
}
for (index, block) in blocks.iter().enumerate() {
let bucket = &mut buckets[block.bucket_index];
if block.is_free {
bucket.free_blocks.push_back(index);
} else {
bucket.allocated_blocks.push(index);
}
}
}
info!("Pool shrink freed {} bytes", bytes_freed);
Ok(bytes_freed)
}
}
impl Drop for MemoryPool {
fn drop(&mut self) {
let mut blocks = self.blocks.write().expect("lock should not be poisoned");
for block in blocks.iter_mut() {
block.free();
}
}
}
#[derive(Debug, Clone)]
pub struct MemoryUsage {
pub total_capacity: usize,
pub allocated_memory: usize,
pub free_memory: usize,
pub utilization_ratio: f64,
pub num_blocks: usize,
pub num_free_blocks: usize,
pub fragmentation_ratio: f64,
}
pub struct ThreadLocalMemoryPool {
pools: std::thread::LocalKey<std::cell::RefCell<Option<MemoryPool>>>,
config: MemoryPoolConfig,
}
impl ThreadLocalMemoryPool {
pub fn new(config: MemoryPoolConfig) -> Self {
std::thread_local! {
static POOL: std::cell::RefCell<Option<MemoryPool>> = const { std::cell::RefCell::new(None) };
}
Self {
pools: POOL,
config,
}
}
pub fn with_pool<F, R>(&self, f: F) -> Result<R>
where
F: FnOnce(&MemoryPool) -> Result<R>,
{
let pool = MemoryPool::new(self.config.clone())?;
f(&pool)
}
}
#[allow(static_mut_refs)] static mut GLOBAL_POOL: Option<MemoryPool> = None;
static GLOBAL_POOL_INIT: std::sync::Once = std::sync::Once::new();
pub fn init_global_pool(config: MemoryPoolConfig) -> Result<()> {
GLOBAL_POOL_INIT.call_once(|| unsafe {
GLOBAL_POOL =
Some(MemoryPool::new(config).expect("Failed to initialize global memory pool"));
});
Ok(())
}
#[allow(static_mut_refs)] pub fn global_pool() -> Result<&'static MemoryPool> {
unsafe {
GLOBAL_POOL.as_ref().ok_or_else(|| TrustformersError::Resource {
message: "Global memory pool not initialized".to_string(),
resource_type: "memory_pool".to_string(),
current_usage: Some("uninitialized".to_string()),
suggestion: Some(
"Call init_global_pool() before using global pool functions".to_string(),
),
recovery_actions: vec![],
})
}
}
pub fn allocate(size: usize) -> Result<NonNull<u8>> {
global_pool()?.allocate(size)
}
pub fn deallocate(ptr: NonNull<u8>) -> Result<()> {
global_pool()?.deallocate(ptr)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_memory_pool_creation() {
let config = MemoryPoolConfig::default();
let pool = MemoryPool::new(config);
assert!(pool.is_ok());
}
#[test]
fn test_basic_allocation_deallocation() {
let config = MemoryPoolConfig::default();
let pool = MemoryPool::new(config).expect("operation failed in test");
let ptr = pool.allocate(1024).expect("operation failed in test");
assert!(pool.deallocate(ptr).is_ok());
}
#[test]
fn test_multiple_allocations() {
let config = MemoryPoolConfig::default();
let pool = MemoryPool::new(config).expect("operation failed in test");
let mut ptrs = Vec::new();
for _ in 0..10 {
let ptr = pool.allocate(512).expect("operation failed in test");
ptrs.push(ptr);
}
for (i, ptr) in ptrs.into_iter().enumerate() {
if let Err(e) = pool.deallocate(ptr) {
eprintln!(
"Failed to deallocate pointer {:?} at index {}: {:?}",
ptr, i, e
);
}
}
}
#[test]
fn test_bucket_size_calculation() {
let config = MemoryPoolConfig {
num_buckets: 4,
preallocation_strategy: PreallocationStrategy::Exponential,
..Default::default()
};
let sizes = MemoryPool::calculate_bucket_sizes(&config);
assert!(!sizes.is_empty());
assert!(sizes.len() <= config.num_buckets);
for i in 1..sizes.len() {
assert!(sizes[i] >= sizes[i - 1]);
}
}
#[test]
fn test_size_alignment() {
let config = MemoryPoolConfig {
alignment: 64,
..Default::default()
};
let pool = MemoryPool::new(config).expect("operation failed in test");
assert_eq!(pool.align_size(1), 64);
assert_eq!(pool.align_size(64), 64);
assert_eq!(pool.align_size(65), 128);
assert_eq!(pool.align_size(100), 128);
}
#[test]
fn test_statistics_tracking() {
let config = MemoryPoolConfig {
enable_tracking: true,
..Default::default()
};
let pool = MemoryPool::new(config).expect("operation failed in test");
let initial_stats = pool.get_stats();
assert_eq!(initial_stats.total_requests, 0);
let ptr = pool.allocate(1024).expect("operation failed in test");
let stats_after_alloc = pool.get_stats();
assert_eq!(stats_after_alloc.total_requests, 1);
assert!(stats_after_alloc.current_usage > 0);
pool.deallocate(ptr).expect("operation failed in test");
let stats_after_dealloc = pool.get_stats();
assert!(stats_after_dealloc.total_freed > 0);
}
#[test]
fn test_memory_usage_calculation() {
let config = MemoryPoolConfig::default();
let pool = MemoryPool::new(config).expect("operation failed in test");
let usage_before = pool.memory_usage();
let ptr = pool.allocate(2048).expect("operation failed in test");
let usage_after = pool.memory_usage();
assert!(usage_after.allocated_memory > usage_before.allocated_memory);
assert!(usage_after.utilization_ratio > usage_before.utilization_ratio);
pool.deallocate(ptr).expect("operation failed in test");
}
}