use crate::{
memory_management::{MemoryUsage, StorageExclude},
storage::{ComputeStorage, StorageHandle, StorageUtilization},
};
use alloc::vec::Vec;
use super::{MemoryPool, Slice, SliceBinding, SliceHandle, calculate_padding};
/// A memory pool in which each slice exclusively owns its whole storage
/// allocation (a "page"); pages are recycled for later requests of a
/// compatible size and lazily deallocated by `cleanup`.
pub struct ExclusiveMemoryPool {
    /// One page per storage allocation; each page holds exactly one slice.
    pages: Vec<MemoryPage>,
    /// Alignment (bytes); every page's allocation size is rounded up to a multiple of this.
    alignment: u64,
    /// Controls how often `cleanup` sweeps (sweep period = dealloc_period / ALLOC_AFTER_FREE).
    dealloc_period: u64,
    /// Value of `alloc_nr` at the most recent cleanup sweep.
    last_dealloc_check: u64,
    /// Upper bound on the size of a single allocation served by this pool.
    max_alloc_size: u64,
    /// Exponential moving average of requested sizes; used to size newly allocated pages.
    cur_avg_size: f64,
}
/// Decay factor for the exponential moving average of requested sizes
/// (`cur_avg_size`): each new request contributes 1% of its size.
const SIZE_AVG_DECAY: f64 = 0.01;
/// Number of cleanup sweeps a page must remain free before it is deallocated.
const ALLOC_AFTER_FREE: u32 = 5;
/// A single storage allocation together with the one slice that occupies it.
struct MemoryPage {
    /// The slice handed out to callers; `is_free()` tells whether it is currently unused.
    slice: Slice,
    /// Size of the backing storage allocation (>= the slice's requested size, aligned).
    alloc_size: u64,
    /// Sweeps spent free: incremented by `cleanup` while free, decremented on
    /// reuse in `try_reserve`; reaching ALLOC_AFTER_FREE triggers deallocation.
    free_count: u32,
}
impl ExclusiveMemoryPool {
    /// Builds an empty pool.
    ///
    /// `max_alloc_size` must be a multiple of `alignment`. `dealloc_period`
    /// governs how frequently `cleanup` scans for reclaimable pages.
    pub(crate) fn new(max_alloc_size: u64, alignment: u64, dealloc_period: u64) -> Self {
        assert_eq!(max_alloc_size % alignment, 0);
        Self {
            pages: Vec::new(),
            alignment,
            dealloc_period,
            last_dealloc_check: 0,
            max_alloc_size,
            // Seed the running average at the midpoint so early pages are
            // neither tiny nor maximal.
            cur_avg_size: max_alloc_size as f64 / 2.0,
        }
    }

    /// Finds a free page big enough for `size`, skipping excluded storages.
    /// Among candidates, the one with the lowest `free_count` wins; ties keep
    /// the earliest page (same choice as `Iterator::min_by_key`).
    fn get_free_page(
        &mut self,
        size: u64,
        exclude: Option<&StorageExclude>,
    ) -> Option<&mut MemoryPage> {
        let mut best: Option<&mut MemoryPage> = None;
        for page in &mut self.pages {
            if page.alloc_size < size || !page.slice.is_free() {
                continue;
            }
            if let Some(excluded) = exclude {
                if excluded.is_excluded(page.slice.storage.id) {
                    continue;
                }
            }
            best = match best {
                // Keep the current best on ties (strictly smaller replaces).
                Some(current) if current.free_count <= page.free_count => Some(current),
                _ => Some(page),
            };
        }
        best
    }

    /// Allocates a fresh page from `storage` able to hold `size` bytes and
    /// returns a mutable reference to it.
    fn alloc_page<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        size: u64,
    ) -> &mut MemoryPage {
        // Allocate at least the running-average size, so the page stays
        // reusable for later, similarly sized requests.
        let alloc_size = u64::max(self.cur_avg_size as u64, size)
            .next_multiple_of(self.alignment);
        let storage_handle = storage.alloc(alloc_size);
        let handle = SliceHandle::new();
        let padding = calculate_padding(size, self.alignment);
        let mut slice = Slice::new(storage_handle, handle, padding);
        slice.storage.utilization = StorageUtilization { offset: 0, size };
        slice.padding = padding;
        let page = MemoryPage {
            slice,
            alloc_size,
            // A page that is never reused gets reclaimed on the first sweep.
            free_count: ALLOC_AFTER_FREE - 1,
        };
        self.pages.push(page);
        self.pages.last_mut().expect("page was just pushed")
    }
}
impl MemoryPool for ExclusiveMemoryPool {
    /// Returns the storage handle backing `binding`, if it belongs to this pool.
    fn get(&self, binding: &SliceBinding) -> Option<&StorageHandle> {
        let wanted = *binding.id();
        for page in &self.pages {
            if page.slice.id() == wanted {
                return Some(&page.slice.storage);
            }
        }
        None
    }

    /// Tries to serve `size` bytes from an existing free page; `None` means
    /// the caller must `alloc` a new page.
    fn try_reserve(&mut self, size: u64, exclude: Option<&StorageExclude>) -> Option<SliceHandle> {
        // Track the request-size moving average regardless of outcome; it
        // drives the sizing of future pages.
        self.cur_avg_size =
            self.cur_avg_size * (1.0 - SIZE_AVG_DECAY) + size as f64 * SIZE_AVG_DECAY;
        let padding = calculate_padding(size, self.alignment);
        let page = self.get_free_page(size, exclude)?;
        page.slice.storage.utilization = StorageUtilization { offset: 0, size };
        page.slice.padding = padding;
        // Reuse pushes the page further from its deallocation threshold.
        page.free_count = page.free_count.saturating_sub(1);
        Some(page.slice.handle.clone())
    }

    /// Allocates a brand-new page for `size` bytes and returns its handle.
    fn alloc<Storage: ComputeStorage>(&mut self, storage: &mut Storage, size: u64) -> SliceHandle {
        assert!(
            size <= self.max_alloc_size,
            "Should allocate less than maximum size in pool!"
        );
        self.alloc_page(storage, size).slice.handle.clone()
    }

    /// Aggregates usage statistics over all pages in a single pass.
    fn get_memory_usage(&self) -> MemoryUsage {
        let mut usage = MemoryUsage {
            number_allocs: 0,
            bytes_in_use: 0,
            bytes_padding: 0,
            bytes_reserved: 0,
        };
        for page in &self.pages {
            // Reserved bytes count every page, in use or not.
            usage.bytes_reserved += page.alloc_size;
            if !page.slice.is_free() {
                usage.number_allocs += 1;
                usage.bytes_in_use += page.slice.storage.size();
                usage.bytes_padding += page.slice.padding;
            }
        }
        usage
    }

    fn max_alloc_size(&self) -> u64 {
        self.max_alloc_size
    }

    /// Periodically (or on `explicit` request) deallocates pages that have
    /// stayed free for enough sweeps.
    fn cleanup<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        alloc_nr: u64,
        explicit: bool,
    ) {
        // Sweep several times per dealloc_period so free pages accrue
        // free_count gradually rather than all at once.
        let check_period = self.dealloc_period / (ALLOC_AFTER_FREE as u64);
        // Short-circuit on `explicit` first, exactly as the periodic check is
        // skipped in that case.
        if !explicit && alloc_nr - self.last_dealloc_check < check_period {
            return;
        }
        self.last_dealloc_check = alloc_nr;
        self.pages.retain_mut(|page| {
            if !page.slice.is_free() {
                return true;
            }
            page.free_count += 1;
            if explicit || page.free_count >= ALLOC_AFTER_FREE {
                storage.dealloc(page.slice.storage.id);
                false
            } else {
                true
            }
        });
    }
}