use super::Tensor;
use crate::error::{RusTorchError, RusTorchResult};
/// Result alias used by the fallible operations in this module; identical to `RusTorchResult<T>`.
type ParallelResult<T> = RusTorchResult<T>;
use num_traits::Float;
use std::alloc::{alloc_zeroed, dealloc, Layout};
use std::collections::{HashMap, VecDeque};
use std::ptr::NonNull;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
/// Cache-line size in bytes assumed by the cache-aligned allocation strategies (64 B).
pub const CACHE_LINE_SIZE: usize = 1 << 6;
/// Base page size in bytes (4 KiB); also the stride used when prefaulting pages.
pub const PAGE_SIZE: usize = 1 << 12;
/// Huge page size in bytes (2 MiB).
pub const HUGE_PAGE_SIZE: usize = 1 << 21;
/// How `AdvancedMemoryPool::allocate` should align (and back) a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllocationStrategy {
/// Aligned to at least `align_of::<f64>()`.
Standard,
/// Aligned to at least [`CACHE_LINE_SIZE`].
CacheAligned,
/// Aligned to at least [`PAGE_SIZE`].
PageAligned,
/// Aligned to at least [`HUGE_PAGE_SIZE`]; fresh blocks request huge-page backing
/// and fall back to regular memory when that is unavailable.
HugePage,
/// Aligned to at least the pool's configured `PoolConfig::alignment`.
Pooled,
/// Aligned to at least [`CACHE_LINE_SIZE`].
/// NOTE(review): no NUMA-specific placement is visible in this file.
NumaAware,
}
/// Tuning knobs for `AdvancedMemoryPool`.
#[derive(Debug, Clone)]
pub struct PoolConfig {
/// Initial pool size in bytes. NOTE(review): not read by any code visible in this file.
pub initial_size: usize,
/// Byte budget consulted when deciding whether a released block is cached for reuse.
pub max_size: usize,
/// NOTE(review): not read by any code visible in this file.
pub growth_factor: f32,
/// NOTE(review): not read by any code visible in this file.
pub shrink_threshold: f32,
/// Alignment in bytes used by `AllocationStrategy::Pooled`.
pub alignment: usize,
/// When set, every page of a freshly allocated block is touched immediately.
pub enable_prefaulting: bool,
/// When set, fresh allocations of at least [`HUGE_PAGE_SIZE`] bytes request huge-page backing.
pub enable_huge_pages: bool,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
initial_size: 64 * 1024 * 1024, max_size: 1024 * 1024 * 1024, growth_factor: 1.5,
shrink_threshold: 0.25,
alignment: CACHE_LINE_SIZE,
enable_prefaulting: true,
enable_huge_pages: false,
}
}
}
/// One allocation owned by the pool; its memory is released when the value is dropped.
#[derive(Debug)]
struct MemoryBlock {
/// Start of the allocation.
ptr: NonNull<u8>,
/// Length of the allocation in bytes.
size: usize,
/// Alignment in bytes the allocation was requested with.
alignment: usize,
/// Timestamp used to compute `idle_time`; refreshed by `update_access`.
last_accessed: Instant,
/// Number of `update_access` calls; starts at 0.
access_count: u64,
/// Selects the release path in `Drop`: `munmap` on Linux when set, `dealloc` otherwise.
is_huge_page: bool,
}
impl MemoryBlock {
fn new(size: usize, alignment: usize, is_huge_page: bool) -> ParallelResult<Self> {
let layout = Layout::from_size_align(size, alignment)
.map_err(|e| RusTorchError::memory(format!("Invalid layout: {}", e)))?;
let ptr = unsafe {
let raw_ptr = if is_huge_page {
Self::alloc_huge_page(size, alignment)?
} else {
alloc_zeroed(layout)
};
if raw_ptr.is_null() {
return Err(RusTorchError::memory("Allocation failed"));
}
NonNull::new_unchecked(raw_ptr)
};
let now = Instant::now();
Ok(Self {
ptr,
size,
alignment,
last_accessed: now,
access_count: 0,
is_huge_page,
})
}
#[cfg(target_os = "linux")]
fn alloc_huge_page(size: usize, alignment: usize) -> ParallelResult<*mut u8> {
use std::fs::OpenOptions;
use std::os::unix::io::AsRawFd;
let fd = OpenOptions::new()
.read(true)
.write(true)
.open("/dev/zero")
.map_err(|e| RusTorchError::IO(e))?;
unsafe {
let ptr = libc::mmap(
std::ptr::null_mut(),
size,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_PRIVATE | libc::MAP_HUGETLB,
fd.as_raw_fd(),
0,
);
if ptr == libc::MAP_FAILED {
let layout = Layout::from_size_align(size, alignment)
.map_err(|e| RusTorchError::memory(format!("Invalid layout: {}", e)))?;
Ok(alloc_zeroed(layout))
} else {
Ok(ptr as *mut u8)
}
}
}
#[cfg(not(target_os = "linux"))]
fn alloc_huge_page(size: usize, alignment: usize) -> ParallelResult<*mut u8> {
let layout = Layout::from_size_align(size, alignment)
.map_err(|e| RusTorchError::memory(format!("Invalid layout: {}", e)))?;
Ok(unsafe { alloc_zeroed(layout) })
}
fn update_access(&mut self) {
self.last_accessed = Instant::now();
self.access_count += 1;
}
fn idle_time(&self) -> Duration {
self.last_accessed.elapsed()
}
}
impl Drop for MemoryBlock {
/// Returns the block's memory using the call that matches how it was obtained.
fn drop(&mut self) {
unsafe {
if self.is_huge_page {
#[cfg(target_os = "linux")]
{
// SAFETY: relies on the invariant that, on Linux, `is_huge_page` is only set for
// blocks obtained from `mmap` with length `self.size`.
// NOTE(review): the `munmap` return value is ignored.
libc::munmap(self.ptr.as_ptr() as *mut libc::c_void, self.size);
}
#[cfg(not(target_os = "linux"))]
{
// Off Linux, huge-page blocks come from the global allocator, so they are
// released exactly like regular blocks.
let layout = Layout::from_size_align_unchecked(self.size, self.alignment);
dealloc(self.ptr.as_ptr(), layout);
}
} else {
// SAFETY: regular blocks come from `alloc_zeroed` with this exact size/alignment
// pair, which `Layout::from_size_align` validated when the block was created.
let layout = Layout::from_size_align_unchecked(self.size, self.alignment);
dealloc(self.ptr.as_ptr(), layout);
}
}
}
}
/// Size-keyed memory pool with a free list of released blocks.
///
/// NOTE(review): the raw-pointer map keys and `MemoryBlock`'s `NonNull` make this type
/// neither `Send` nor `Sync` (unless implemented manually elsewhere), so the internal
/// locks do not by themselves allow sharing the pool across threads.
pub struct AdvancedMemoryPool {
/// Pool settings.
config: PoolConfig,
/// Released blocks kept for reuse, keyed by block size in bytes.
free_blocks: RwLock<HashMap<usize, VecDeque<MemoryBlock>>>,
/// Blocks currently handed out, keyed by their start address.
allocated_blocks: RwLock<HashMap<*mut u8, MemoryBlock>>,
/// Bytes currently handed out to callers (mirrored into `current_memory_usage`).
total_allocated: Arc<Mutex<usize>>,
/// Running counters reported by `get_stats`.
allocation_stats: Arc<Mutex<AllocationStats>>,
/// NUMA node detected at construction (`Some(0)` on Linux, `None` elsewhere).
numa_node: Option<u32>,
}
/// Snapshot of pool counters returned by `AdvancedMemoryPool::get_stats`.
#[derive(Debug, Default, Clone)]
pub struct AllocationStats {
/// Successful `allocate` calls.
pub total_allocations: u64,
/// Successful `deallocate` calls.
pub total_deallocations: u64,
/// Highest value `current_memory_usage` has reached, in bytes.
pub peak_memory_usage: usize,
/// Bytes currently handed out to callers.
pub current_memory_usage: usize,
/// Allocations served from a cached block.
pub cache_hits: u64,
/// Allocations that needed a fresh block.
pub cache_misses: u64,
/// Huge-page-backed allocations. NOTE(review): confirm where this counter is updated.
pub huge_page_allocations: u64,
/// Fragmentation estimate. NOTE(review): confirm where this is computed.
pub fragmentation_ratio: f32,
}
impl AdvancedMemoryPool {
pub fn new(config: PoolConfig) -> Self {
Self {
config,
free_blocks: RwLock::new(HashMap::new()),
allocated_blocks: RwLock::new(HashMap::new()),
total_allocated: Arc::new(Mutex::new(0)),
allocation_stats: Arc::new(Mutex::new(AllocationStats::default())),
numa_node: Self::detect_numa_node(),
}
}
pub fn allocate<T: Float + 'static>(
&self,
size: usize,
strategy: AllocationStrategy,
) -> ParallelResult<NonNull<T>> {
let alignment = self.get_alignment_for_strategy(strategy);
let actual_size = self.round_up_size(size * std::mem::size_of::<T>(), alignment);
if let Some(block) = self.try_reuse_block(actual_size, alignment)? {
let ptr = unsafe { NonNull::new_unchecked(block.ptr.as_ptr() as *mut T) };
self.update_stats_on_allocation(actual_size, true);
return Ok(ptr);
}
let use_huge_pages = strategy == AllocationStrategy::HugePage
|| (actual_size >= HUGE_PAGE_SIZE && self.config.enable_huge_pages);
let mut block = MemoryBlock::new(actual_size, alignment, use_huge_pages)?;
if self.config.enable_prefaulting {
self.prefault_pages(&mut block)?;
}
let ptr = unsafe { NonNull::new_unchecked(block.ptr.as_ptr() as *mut T) };
{
let mut allocated = self.allocated_blocks.write().unwrap();
allocated.insert(block.ptr.as_ptr(), block);
}
self.update_stats_on_allocation(actual_size, false);
Ok(ptr)
}
pub fn deallocate<T>(&self, ptr: NonNull<T>) -> ParallelResult<()> {
let raw_ptr = ptr.as_ptr() as *mut u8;
let block = {
let mut allocated = self.allocated_blocks.write().unwrap();
allocated
.remove(&raw_ptr)
.ok_or_else(|| RusTorchError::memory("Pointer not found in allocated blocks"))?
};
let size = block.size;
if self.should_keep_block(&block) {
let mut free_blocks = self.free_blocks.write().unwrap();
free_blocks.entry(size).or_default().push_back(block);
}
self.update_stats_on_deallocation(size);
Ok(())
}
pub fn get_stats(&self) -> AllocationStats {
let stats = self.allocation_stats.lock().unwrap();
(*stats).clone()
}
pub fn garbage_collect(&self) -> ParallelResult<usize> {
let mut freed_memory = 0;
let _now = Instant::now();
let max_idle_time = Duration::from_secs(300);
let mut free_blocks = self.free_blocks.write().unwrap();
for (size, blocks) in free_blocks.iter_mut() {
blocks.retain(|block| {
if block.idle_time() > max_idle_time {
freed_memory += size;
false
} else {
true
}
});
}
free_blocks.retain(|_, blocks| !blocks.is_empty());
Ok(freed_memory)
}
pub fn optimize_for_numa(&self) -> ParallelResult<()> {
if let Some(node) = self.numa_node {
#[cfg(target_os = "linux")]
{
self.set_numa_policy(node)?;
}
#[cfg(not(target_os = "linux"))]
{
let _ = node; }
}
Ok(())
}
fn get_alignment_for_strategy(&self, strategy: AllocationStrategy) -> usize {
match strategy {
AllocationStrategy::Standard => std::mem::align_of::<f64>(),
AllocationStrategy::CacheAligned => CACHE_LINE_SIZE,
AllocationStrategy::PageAligned => PAGE_SIZE,
AllocationStrategy::HugePage => HUGE_PAGE_SIZE,
AllocationStrategy::Pooled => self.config.alignment,
AllocationStrategy::NumaAware => CACHE_LINE_SIZE,
}
}
fn round_up_size(&self, size: usize, alignment: usize) -> usize {
(size + alignment - 1) & !(alignment - 1)
}
fn try_reuse_block(
&self,
size: usize,
alignment: usize,
) -> ParallelResult<Option<MemoryBlock>> {
let mut free_blocks = self.free_blocks.write().unwrap();
if let Some(blocks) = free_blocks.get_mut(&size) {
if let Some(mut block) = blocks.pop_front() {
if block.alignment >= alignment {
block.update_access();
return Ok(Some(block));
} else {
blocks.push_back(block);
}
}
}
for (&block_size, blocks) in free_blocks.iter_mut() {
if block_size >= size && !blocks.is_empty() {
if let Some(mut block) = blocks.pop_front() {
if block.alignment >= alignment {
block.update_access();
return Ok(Some(block));
} else {
blocks.push_back(block);
}
}
}
}
Ok(None)
}
fn should_keep_block(&self, block: &MemoryBlock) -> bool {
let current_total = *self.total_allocated.lock().unwrap();
let would_exceed_max = current_total + block.size > self.config.max_size;
!would_exceed_max && block.access_count > 1
}
fn prefault_pages(&self, block: &mut MemoryBlock) -> ParallelResult<()> {
unsafe {
let ptr = block.ptr.as_ptr();
let size = block.size;
for offset in (0..size).step_by(PAGE_SIZE) {
let page_ptr = ptr.add(offset);
std::ptr::write_volatile(page_ptr, 0);
}
}
Ok(())
}
fn update_stats_on_allocation(&self, size: usize, cache_hit: bool) {
let mut stats = self.allocation_stats.lock().unwrap();
let mut total = self.total_allocated.lock().unwrap();
stats.total_allocations += 1;
*total += size;
stats.current_memory_usage = *total;
if *total > stats.peak_memory_usage {
stats.peak_memory_usage = *total;
}
if cache_hit {
stats.cache_hits += 1;
} else {
stats.cache_misses += 1;
}
}
fn update_stats_on_deallocation(&self, size: usize) {
let mut stats = self.allocation_stats.lock().unwrap();
let mut total = self.total_allocated.lock().unwrap();
stats.total_deallocations += 1;
*total -= size;
stats.current_memory_usage = *total;
}
fn detect_numa_node() -> Option<u32> {
#[cfg(target_os = "linux")]
{
Some(0)
}
#[cfg(not(target_os = "linux"))]
{
None
}
}
#[cfg(target_os = "linux")]
fn set_numa_policy(&self, _node: u32) -> ParallelResult<()> {
Ok(())
}
}
/// Tensor helpers that allocate through an `AdvancedMemoryPool`.
pub struct OptimizedTensorOps {
/// Reference-counted pool used by `create_tensor`.
memory_pool: Arc<AdvancedMemoryPool>,
}
impl OptimizedTensorOps {
    /// Creates the helper with its own pool built from `pool_config`.
    pub fn new(pool_config: PoolConfig) -> Self {
        Self {
            memory_pool: Arc::new(AdvancedMemoryPool::new(pool_config)),
        }
    }

    /// Creates a zero-filled tensor of `shape`, staging its elements through the pool.
    ///
    /// The pool buffer is zeroed and handed to `Tensor::from_raw_parts`, then
    /// returned to the pool. `from_raw_parts` only borrows the slice for the
    /// duration of the call and returns an owned tensor, so the tensor cannot
    /// keep pointing into the buffer. Returning it avoids the buffer staying
    /// allocated until the pool is dropped.
    ///
    /// # Errors
    /// Returns a memory error if the element count overflows `usize`, or if
    /// the pool cannot allocate or release the staging buffer.
    pub fn create_tensor<T: Float + 'static>(
        &self,
        shape: &[usize],
        strategy: AllocationStrategy,
    ) -> ParallelResult<Tensor<T>> {
        let total_elements = shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .ok_or_else(|| RusTorchError::memory("Tensor shape overflows usize"))?;
        if total_elements == 0 {
            // Nothing to stage; a zero-byte request cannot be served by the allocator.
            return Ok(Tensor::zeros(shape));
        }
        let ptr = self.memory_pool.allocate::<T>(total_elements, strategy)?;
        let tensor = {
            // SAFETY: `ptr` starts a live pool block sized for `total_elements`
            // values of `T`. Nothing else references it until it is
            // deallocated below. Alignment for `T` assumes the strategy's
            // alignment is at least `align_of::<T>()`.
            let data = unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr(), total_elements) };
            // Pool blocks may be reused without being re-zeroed.
            data.fill(T::zero());
            Tensor::from_raw_parts(data, shape)
        };
        self.memory_pool.deallocate(ptr)?;
        Ok(tensor)
    }

    /// Adds `b` to `a` element-wise, in place.
    ///
    /// # Errors
    /// - Returns a shape-mismatch error if the shapes differ.
    /// - Returns a memory error if either tensor does not expose its elements
    ///   as a single slice (`as_slice`/`as_slice_mut` return `None`).
    pub fn add_inplace<T: Float + 'static>(
        &self,
        a: &mut Tensor<T>,
        b: &Tensor<T>,
    ) -> ParallelResult<()> {
        if a.shape() != b.shape() {
            return Err(RusTorchError::shape_mismatch(a.shape(), b.shape()));
        }
        let rhs = b.as_slice().ok_or_else(|| {
            RusTorchError::memory("add_inplace: right-hand tensor has no contiguous slice")
        })?;
        let lhs = a.as_slice_mut().ok_or_else(|| {
            RusTorchError::memory("add_inplace: left-hand tensor has no contiguous slice")
        })?;
        for (dst, &src) in lhs.iter_mut().zip(rhs) {
            *dst = *dst + src;
        }
        Ok(())
    }

    /// Returns a snapshot of the pool counters.
    pub fn get_memory_stats(&self) -> AllocationStats {
        self.memory_pool.get_stats()
    }

    /// Runs the pool's garbage collection and returns the bytes released.
    pub fn optimize_memory(&self) -> ParallelResult<usize> {
        self.memory_pool.garbage_collect()
    }
}
/// Memory-related helpers for `Tensor`.
pub trait TensorMemoryExt<T: Float> {
/// Builds a tensor of `shape` from an element buffer. `data` is only borrowed for the
/// duration of the call, so (in safe code) the result cannot keep referencing it.
fn from_raw_parts(data: &mut [T], shape: &[usize]) -> Tensor<T>;
/// Number of bytes occupied by the tensor's elements.
fn memory_usage(&self) -> usize;
/// Whether the tensor's element data starts at an address that is a multiple of `alignment`.
fn is_memory_aligned(&self, alignment: usize) -> bool;
}
impl<T: Float + 'static> TensorMemoryExt<T> for Tensor<T> {
    /// Builds a tensor of `shape` holding a copy of `data`.
    ///
    /// Starts from `Tensor::zeros(shape)` and copies `data` into its element
    /// slice, so the result never aliases `data`. If the fresh tensor does not
    /// expose a contiguous slice, it is returned zero-filled.
    ///
    /// # Panics
    /// Panics if `data.len()` differs from the number of elements of `shape`.
    fn from_raw_parts(data: &mut [T], shape: &[usize]) -> Tensor<T> {
        let mut tensor: Tensor<T> = Tensor::zeros(shape);
        if let Some(dst) = tensor.as_slice_mut() {
            assert_eq!(
                dst.len(),
                data.len(),
                "from_raw_parts: data length does not match shape"
            );
            dst.copy_from_slice(data);
        }
        tensor
    }

    /// Bytes occupied by the elements.
    ///
    /// Uses the contiguous slice length when available and falls back to the
    /// element count implied by `shape()` otherwise, instead of panicking.
    fn memory_usage(&self) -> usize {
        let elements = match self.as_slice() {
            Some(values) => values.len(),
            None => self.shape().iter().product::<usize>(),
        };
        elements * std::mem::size_of::<T>()
    }

    /// Whether the element data starts at a multiple of `alignment`.
    ///
    /// Returns `false` for `alignment == 0`, where the original division by
    /// zero would panic. Also returns `false` when the tensor exposes no
    /// contiguous slice to inspect.
    fn is_memory_aligned(&self, alignment: usize) -> bool {
        self.as_slice().map_or(false, |values| {
            (values.as_ptr() as usize).checked_rem(alignment) == Some(0)
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_advanced_memory_pool() {
        let config = PoolConfig::default();
        let pool = AdvancedMemoryPool::new(config);
        let ptr: NonNull<f32> = pool
            .allocate(1000, AllocationStrategy::CacheAligned)
            .unwrap();
        assert!(pool.deallocate(ptr).is_ok());
        let stats = pool.get_stats();
        assert_eq!(stats.total_allocations, 1);
        assert_eq!(stats.total_deallocations, 1);
    }

    #[test]
    fn test_optimized_tensor_ops() {
        let config = PoolConfig::default();
        let ops = OptimizedTensorOps::new(config);
        let tensor: Tensor<f32> = ops
            .create_tensor(&[100, 100], AllocationStrategy::CacheAligned)
            .unwrap();
        assert_eq!(tensor.shape(), &[100, 100]);
        let stats = ops.get_memory_stats();
        assert!(stats.total_allocations > 0);
    }

    #[test]
    fn test_memory_alignment() {
        let tensor: Tensor<f32> = Tensor::zeros(&[64]);
        assert!(tensor.is_memory_aligned(std::mem::align_of::<f32>()));
    }

    #[test]
    fn test_garbage_collection() {
        let config = PoolConfig::default();
        let pool = AdvancedMemoryPool::new(config);
        for _ in 0..10 {
            let ptr: NonNull<f32> = pool.allocate(1000, AllocationStrategy::Standard).unwrap();
            pool.deallocate(ptr).unwrap();
        }
        // Every block was released moments ago, far below the 300 s idle
        // threshold, so a collection pass has nothing to free.
        let freed = pool.garbage_collect().unwrap();
        assert_eq!(freed, 0);
        let stats = pool.get_stats();
        assert_eq!(stats.total_allocations, 10);
        assert_eq!(stats.total_deallocations, 10);
        assert_eq!(stats.current_memory_usage, 0);
        assert_eq!(stats.peak_memory_usage, 1000 * std::mem::size_of::<f32>());
    }

    #[test]
    fn test_deallocate_unknown_pointer_fails() {
        let pool = AdvancedMemoryPool::new(PoolConfig::default());
        let mut value = 0.0f32;
        let ptr = NonNull::from(&mut value);
        assert!(pool.deallocate(ptr).is_err());
        assert_eq!(pool.get_stats().total_deallocations, 0);
    }
}