use crate::error::{Result, ZiporaError};
use std::collections::HashMap;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
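/// Allocator that serves large allocations from anonymous `mmap` regions and
/// recycles unmapped regions through a small per-size cache.
///
/// Regions are page-aligned and created with `MAP_PRIVATE | MAP_ANONYMOUS`.
/// This type calls POSIX `mmap`/`munmap` through `libc` and is Unix-only.
///
/// # Examples
///
/// ```ignore
/// // Illustrative usage; the import path depends on where this module lives.
/// let allocator = MemoryMappedAllocator::new(16 * 1024);
/// let mut allocation = allocator.allocate(64 * 1024)?;
/// allocation.as_mut_slice()[0] = 1;
/// allocator.deallocate(allocation)?;
/// ```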
pub struct MemoryMappedAllocator {
min_mmap_size: usize,
region_cache: Arc<Mutex<HashMap<usize, Vec<*mut u8>>>>,
total_allocated: AtomicU64,
total_freed: AtomicU64,
mmap_calls: AtomicU64,
munmap_calls: AtomicU64,
cache_hits: AtomicU64,
cache_misses: AtomicU64,
}
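/// Handle to a single memory-mapped region.
///
/// `size` is the caller's requested size; `actual_size` is the page-rounded
/// size actually mapped. Dropping the handle does not unmap the region: return
/// it via `MemoryMappedAllocator::deallocate` to release or cache it.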
#[derive(Debug)]
pub struct MmapAllocation {
ptr: NonNull<u8>,
size: usize,
    actual_size: usize,
}
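/// Point-in-time counters describing allocator activity, as returned by
/// `MemoryMappedAllocator::stats`.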
#[derive(Debug, Clone)]
pub struct MmapStats {
pub total_allocated: u64,
pub total_freed: u64,
pub mmap_calls: u64,
pub munmap_calls: u64,
pub cache_hits: u64,
pub cache_misses: u64,
pub cached_regions: usize,
}
impl MemoryMappedAllocator {
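    /// Creates an allocator that only serves requests of at least
    /// `min_mmap_size` bytes.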
pub fn new(min_mmap_size: usize) -> Self {
Self {
min_mmap_size,
region_cache: Arc::new(Mutex::new(HashMap::new())),
total_allocated: AtomicU64::new(0),
total_freed: AtomicU64::new(0),
mmap_calls: AtomicU64::new(0),
munmap_calls: AtomicU64::new(0),
cache_hits: AtomicU64::new(0),
cache_misses: AtomicU64::new(0),
}
}
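    /// Maps a region of at least `size` bytes, reusing a cached region of the
    /// same page-rounded size when one is available.
    ///
    /// Returns an error if `size` is below the configured minimum or if the
    /// mapping fails.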
pub fn allocate(&self, size: usize) -> Result<MmapAllocation> {
if size < self.min_mmap_size {
return Err(ZiporaError::invalid_data(
"allocation too small for memory mapping",
));
}
        let page_size = Self::get_page_size();
        // Round the request up to the next page boundary (page sizes are
        // powers of two); checked_add guards against pathological overflow.
        let actual_size = size
            .checked_add(page_size - 1)
            .ok_or_else(|| ZiporaError::invalid_data("allocation size overflow"))?
            & !(page_size - 1);
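        // Fast path: reuse a cached region of the same rounded size. Recycled
        // memory is not zeroed, so it may contain stale data from earlier use.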
if let Ok(mut cache) = self.region_cache.try_lock() {
if let Some(regions) = cache.get_mut(&actual_size) {
if let Some(ptr) = regions.pop() {
self.cache_hits.fetch_add(1, Ordering::Relaxed);
self.total_allocated
.fetch_add(size as u64, Ordering::Relaxed);
return Ok(MmapAllocation {
ptr: unsafe { NonNull::new_unchecked(ptr) },
size,
actual_size,
});
}
}
}
        // Cache miss (or contended cache lock): map a fresh anonymous region.
        self.cache_misses.fetch_add(1, Ordering::Relaxed);
        self.mmap_calls.fetch_add(1, Ordering::Relaxed);
let ptr = unsafe {
libc::mmap(
std::ptr::null_mut(),
actual_size,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
-1,
0,
)
};
if ptr == libc::MAP_FAILED {
return Err(ZiporaError::out_of_memory(size));
}
        // Advise the kernel on the expected access pattern; these hints are
        // best-effort, so their return values are deliberately ignored.
        unsafe {
            libc::madvise(ptr, actual_size, libc::MADV_WILLNEED);
            libc::madvise(ptr, actual_size, libc::MADV_SEQUENTIAL);
        }
self.total_allocated
.fetch_add(size as u64, Ordering::Relaxed);
Ok(MmapAllocation {
ptr: unsafe { NonNull::new_unchecked(ptr as *mut u8) },
size,
actual_size,
})
}
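    /// Returns a region to the allocator, caching up to four regions per size
    /// class and unmapping anything beyond that.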
pub fn deallocate(&self, allocation: MmapAllocation) -> Result<()> {
self.total_freed
.fetch_add(allocation.size as u64, Ordering::Relaxed);
        // Prefer caching the region for reuse; fall through to munmap when the
        // size class is full or the cache lock is contended.
        if let Ok(mut cache) = self.region_cache.try_lock() {
            let regions = cache.entry(allocation.actual_size).or_default();
            const MAX_CACHED_REGIONS_PER_SIZE: usize = 4;
if regions.len() < MAX_CACHED_REGIONS_PER_SIZE {
regions.push(allocation.ptr.as_ptr());
return Ok(());
}
}
self.munmap_calls.fetch_add(1, Ordering::Relaxed);
unsafe {
if libc::munmap(
allocation.ptr.as_ptr() as *mut libc::c_void,
allocation.actual_size,
) != 0
{
return Err(ZiporaError::io_error("failed to unmap memory"));
}
}
Ok(())
}
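    /// Returns `true` when `size` meets the minimum threshold for mapping.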
pub fn should_use_mmap(&self, size: usize) -> bool {
size >= self.min_mmap_size
}
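    /// Returns a snapshot of the allocator's counters. `cached_regions` reads
    /// as 0 when the cache lock is contended.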
pub fn stats(&self) -> MmapStats {
let cached_regions = if let Ok(cache) = self.region_cache.try_lock() {
cache.values().map(|v| v.len()).sum()
} else {
0
};
MmapStats {
total_allocated: self.total_allocated.load(Ordering::Relaxed),
total_freed: self.total_freed.load(Ordering::Relaxed),
mmap_calls: self.mmap_calls.load(Ordering::Relaxed),
munmap_calls: self.munmap_calls.load(Ordering::Relaxed),
cache_hits: self.cache_hits.load(Ordering::Relaxed),
cache_misses: self.cache_misses.load(Ordering::Relaxed),
cached_regions,
}
}
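    /// Unmaps every cached region; individual unmap failures are logged and
    /// skipped rather than aborting the sweep.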
pub fn clear_cache(&self) -> Result<()> {
if let Ok(mut cache) = self.region_cache.lock() {
for (size, regions) in cache.drain() {
for ptr in regions {
self.munmap_calls.fetch_add(1, Ordering::Relaxed);
unsafe {
if libc::munmap(ptr as *mut libc::c_void, size) != 0 {
log::warn!("Failed to unmap cached region of size {}", size);
}
}
}
}
}
Ok(())
}
    fn get_page_size() -> usize {
        // sysconf returns -1 on failure; fall back to the common 4 KiB page.
        let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
        if page_size > 0 {
            page_size as usize
        } else {
            4096
        }
    }
}
impl Default for MemoryMappedAllocator {
    /// Creates an allocator with a 16 KiB minimum mapping threshold.
    fn default() -> Self {
        Self::new(16 * 1024)
    }
}
impl Drop for MemoryMappedAllocator {
fn drop(&mut self) {
        // Unmap any cached regions so they do not outlive the allocator.
        let _ = self.clear_cache();
}
}
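// SAFETY: the raw pointers in the region cache originate from mmap and are
// only accessed while holding the cache mutex, so sharing the allocator
// across threads is sound.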
unsafe impl Send for MemoryMappedAllocator {}
unsafe impl Sync for MemoryMappedAllocator {}
impl MmapAllocation {
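    /// Returns the allocation as an immutable byte slice of the requested size.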
#[inline]
pub fn as_slice(&self) -> &[u8] {
unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.size) }
}
#[inline]
pub fn as_mut_slice(&mut self) -> &mut [u8] {
unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.size) }
}
#[inline]
pub fn size(&self) -> usize {
self.size
}
pub fn actual_size(&self) -> usize {
self.actual_size
}
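    /// Returns the base pointer cast to `*mut T`. The caller must ensure `T`'s
    /// alignment and size fit the region; only page alignment is guaranteed.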
pub fn as_ptr<T>(&self) -> *mut T {
self.ptr.as_ptr() as *mut T
}
pub fn as_mut_ptr(&mut self) -> *mut u8 {
self.ptr.as_ptr()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mmap_allocator_creation() {
let allocator = MemoryMappedAllocator::new(16 * 1024);
assert!(allocator.should_use_mmap(20 * 1024));
assert!(!allocator.should_use_mmap(8 * 1024));
}
#[test]
fn test_mmap_allocation() {
let allocator = MemoryMappedAllocator::default();
let size = 64 * 1024;
let mut allocation = allocator.allocate(size).unwrap();
assert_eq!(allocation.size(), size);
assert!(allocation.actual_size() >= size);
let slice = allocation.as_mut_slice();
slice[0] = 42;
slice[size - 1] = 84;
let slice = allocation.as_slice();
assert_eq!(slice[0], 42);
assert_eq!(slice[size - 1], 84);
allocator.deallocate(allocation).unwrap();
let stats = allocator.stats();
assert_eq!(stats.total_allocated, size as u64);
assert_eq!(stats.total_freed, size as u64);
assert_eq!(stats.mmap_calls, 1);
}
#[test]
fn test_mmap_cache() {
let allocator = MemoryMappedAllocator::default();
let size = 64 * 1024;
let allocation1 = allocator.allocate(size).unwrap();
allocator.deallocate(allocation1).unwrap();
let stats_before = allocator.stats();
let allocation2 = allocator.allocate(size).unwrap();
allocator.deallocate(allocation2).unwrap();
let stats_after = allocator.stats();
assert_eq!(stats_after.cache_hits, stats_before.cache_hits + 1);
assert_eq!(stats_after.mmap_calls, stats_before.mmap_calls);
}
#[test]
fn test_mmap_different_sizes() {
let allocator = MemoryMappedAllocator::default();
let sizes = vec![16 * 1024, 32 * 1024, 64 * 1024, 128 * 1024];
let mut allocations = Vec::new();
for size in &sizes {
let allocation = allocator.allocate(*size).unwrap();
assert_eq!(allocation.size(), *size);
allocations.push(allocation);
}
for allocation in allocations {
allocator.deallocate(allocation).unwrap();
}
let stats = allocator.stats();
assert_eq!(stats.mmap_calls, sizes.len() as u64);
assert_eq!(stats.total_allocated, sizes.iter().sum::<usize>() as u64);
assert_eq!(stats.total_freed, sizes.iter().sum::<usize>() as u64);
}
#[test]
fn test_mmap_cache_limit() {
let allocator = MemoryMappedAllocator::default();
let size = 64 * 1024;
for _ in 0..10 {
let allocation = allocator.allocate(size).unwrap();
allocator.deallocate(allocation).unwrap();
}
let stats = allocator.stats();
        assert!(stats.cached_regions <= 4);
    }
#[test]
fn test_clear_cache() {
let allocator = MemoryMappedAllocator::default();
let size = 64 * 1024;
let allocation = allocator.allocate(size).unwrap();
allocator.deallocate(allocation).unwrap();
let stats_before = allocator.stats();
assert!(stats_before.cached_regions > 0);
allocator.clear_cache().unwrap();
let stats_after = allocator.stats();
assert_eq!(stats_after.cached_regions, 0);
assert!(stats_after.munmap_calls > stats_before.munmap_calls);
}
#[test]
fn test_invalid_allocation_size() {
let allocator = MemoryMappedAllocator::new(16 * 1024);
let result = allocator.allocate(8 * 1024);
assert!(result.is_err());
}
}