use std::alloc::{alloc, dealloc, Layout, System};
use std::cell::RefCell;
use std::collections::HashMap;
use std::ptr::NonNull;
pub struct MemoryPool {
pools: HashMap<usize, Vec<NonNull<u8>>>,
allocated: Vec<(NonNull<u8>, Layout)>,
size_classes: Vec<usize>,
}
impl MemoryPool {
pub fn new() -> Self {
let size_classes = vec![
1024, 4096, 16384, 65536, 262144, 1048576, 4194304, ];
Self {
pools: HashMap::new(),
allocated: Vec::new(),
size_classes,
}
}
pub fn allocate(&mut self, size_bytes: usize) -> Result<NonNull<u8>, crate::error::HyperionError> {
if size_bytes > 4 * 1024 * 1024 { return self.allocate_large(size_bytes);
}
let size_class = self.find_size_class(size_bytes);
if let Some(pool) = self.pools.get_mut(&size_class) {
if let Some(block) = pool.pop() {
return Ok(block);
}
}
self.allocate_large(size_class)
}
pub fn deallocate(&mut self, ptr: NonNull<u8>, layout: Layout) {
let size = layout.size();
if size > 4 * 1024 * 1024 {
unsafe { dealloc(ptr.as_ptr(), layout); }
return;
}
let size_class = self.find_size_class(size);
if let Some(pool) = self.pools.get_mut(&size_class) {
unsafe {
std::ptr::write_bytes(ptr.as_ptr(), 0, size);
}
pool.push(ptr);
} else {
unsafe { dealloc(ptr.as_ptr(), layout); }
}
}
fn find_size_class(&self, size: usize) -> usize {
for &class in &self.size_classes {
if size <= class {
return class;
}
}
size.next_power_of_two()
}
fn allocate_large(&mut self, size: usize) -> Result<NonNull<u8>, crate::error::HyperionError> {
let layout = Layout::from_size_align(size, std::mem::align_of::<u8>())
.map_err(|_| crate::error::HyperionError::Internal(format!("Invalid layout for size {}", size)))?;
let ptr = unsafe { alloc(layout) };
if ptr.is_null() {
panic!("Memory allocation failed for size {}", size);
}
let ptr = unsafe { NonNull::new_unchecked(ptr) };
self.allocated.push((ptr, layout));
Ok(ptr)
}
pub fn cleanup(&mut self) {
for (ptr, layout) in self.allocated.drain(..) {
unsafe { dealloc(ptr.as_ptr(), layout); }
}
self.pools.clear();
}
}
impl Drop for MemoryPool {
fn drop(&mut self) {
self.cleanup();
}
}
thread_local! {
    /// Per-thread `MemoryPool`, initialized on each thread's first access.
    // NOTE(review): no visible code reads MEMORY_POOL — confirm it is still needed.
    static MEMORY_POOL: RefCell<MemoryPool> = RefCell::from(MemoryPool::new());
}
pub struct FriMemoryManager {
pool: MemoryPool,
}
impl FriMemoryManager {
pub fn new() -> Self {
Self {
pool: MemoryPool::new(),
}
}
pub fn allocate_goldilocks(&mut self, count: usize) -> Vec<crate::field::Goldilocks> {
let size_bytes = count * std::mem::size_of::<crate::field::Goldilocks>();
let ptr = self.pool.allocate(size_bytes).expect("Memory allocation should succeed");
unsafe {
Vec::from_raw_parts(
ptr.as_ptr() as *mut crate::field::Goldilocks,
count,
count,
)
}
}
pub fn deallocate_goldilocks(&mut self, vec: Vec<crate::field::Goldilocks>) -> Result<(), crate::error::HyperionError> {
let ptr = unsafe { NonNull::new_unchecked(vec.as_ptr() as *mut u8) };
let size_bytes = vec.len() * std::mem::size_of::<crate::field::Goldilocks>();
let layout = Layout::from_size_align(size_bytes, std::mem::align_of::<crate::field::Goldilocks>())
.map_err(|_| crate::error::HyperionError::Internal(format!("Invalid layout for Goldilocks vector size {}", size_bytes)))?;
std::mem::forget(vec);
self.pool.deallocate(ptr, layout);
Ok(())
}
}
impl Default for FriMemoryManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::Goldilocks;

    /// A pooled (1 KiB) and an oversized (5 MiB) request both succeed with distinct blocks.
    #[test]
    fn test_memory_pool_allocation() {
        let mut pool = MemoryPool::new();
        let small = pool.allocate(1024).expect("1 KiB allocation should succeed");
        let large = pool
            .allocate(5 * 1024 * 1024)
            .expect("5 MiB allocation should succeed");
        assert_ne!(small, large);
        pool.cleanup();
    }

    /// Values written to a managed vector read back unchanged, and returning the vector succeeds.
    #[test]
    fn test_fri_memory_manager() {
        let mut manager = FriMemoryManager::new();
        let mut vec = manager.allocate_goldilocks(1024);
        assert_eq!(vec.len(), 1024);
        for (i, slot) in vec.iter_mut().enumerate() {
            *slot = Goldilocks::from_i64(i as i64);
        }
        for (i, value) in vec.iter().enumerate() {
            assert_eq!(*value, Goldilocks::from_i64(i as i64));
        }
        manager
            .deallocate_goldilocks(vec)
            .expect("returning a managed vector should succeed");
    }
}