use super::index::SearchIndex;
use super::{MemoryPool, RingBuffer, Slice, SliceBinding, SliceHandle, SliceId};
use crate::memory_management::memory_pool::calculate_padding;
use crate::memory_management::{MemoryUsage, StorageExclude};
use crate::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use alloc::vec::Vec;
use hashbrown::HashMap;
/// A memory pool that carves fixed-size pages into variable-size slices.
///
/// Every page is requested from the storage with exactly `page_size` bytes.
/// Slices inside a page are tracked by their byte offset, and free slices are
/// located through a ring buffer search.
pub(crate) struct SlicedPool {
    /// Size in bytes of every page requested from the storage.
    page_size: u64,
    /// Largest allocation this pool reports it can serve.
    max_alloc_size: u64,
    /// Alignment in bytes: required of slice offsets and passed to `calculate_padding`.
    alignment: u64,
    /// Pages owned by this pool, keyed by their storage id.
    pages: HashMap<StorageId, MemoryPage>,
    /// Every slice living in any page, keyed by slice id.
    slices: HashMap<SliceId, Slice>,
    /// Search index over pages, populated with each page's size.
    storage_index: SearchIndex<StorageId>,
    /// Ring buffer of pages used to look for free slices.
    ring: RingBuffer,
    /// Storage ids of pages created by `alloc`, in creation order.
    recently_added_pages: Vec<StorageId>,
    /// Incremented by `page_size` every time `alloc` creates a page.
    recently_allocated_size: u64,
}
/// A single page of backing storage, subdivided into slices.
///
/// `MemoryPage::new(slices)` is generated by `#[derive(new)]`.
#[derive(new, Debug)]
pub(crate) struct MemoryPage {
    /// Slices in this page, keyed by the slice's storage offset (its start address).
    pub(crate) slices: HashMap<u64, SliceId>,
}
impl MemoryPage {
    /// Tries to coalesce the slice starting at `first_slice_address` with the
    /// slice that immediately follows it in this page.
    ///
    /// A merge only happens when a following slice exists and is free. On
    /// success the first slice grows to cover both effective sizes, its padding
    /// is reset to zero, and the following slice is removed from both this page
    /// and `slices`.
    ///
    /// Returns `true` when a merge happened, `false` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if no slice is registered at `first_slice_address`, or if a slice
    /// id found in this page is missing from `slices`.
    pub(crate) fn merge_with_next_slice(
        &mut self,
        first_slice_address: u64,
        slices: &mut HashMap<SliceId, Slice>,
    ) -> bool {
        let first_id = self.find_slice(first_slice_address).expect(
            "merge_with_next_slice shouldn't be called with a nonexistent first_slice address",
        );
        // The neighbour starts right where the first slice's footprint ends.
        let neighbour_address =
            first_slice_address + slices.get(&first_id).unwrap().effective_size();

        let neighbour_id = match self.find_slice(neighbour_address) {
            Some(id) => id,
            None => return false,
        };

        let neighbour = slices.get(&neighbour_id).unwrap();
        if !neighbour.is_free() {
            return false;
        }
        let neighbour_size = neighbour.effective_size();

        let first = slices.get_mut(&first_id).unwrap();
        first.storage.utilization = StorageUtilization {
            size: first.effective_size() + neighbour_size,
            offset: first.storage.offset(),
        };
        first.padding = 0;

        self.slices.remove(&neighbour_address);
        slices.remove(&neighbour_id);
        true
    }

    /// Returns the id of the slice starting at `address` in this page, if any.
    pub(crate) fn find_slice(&self, address: u64) -> Option<SliceId> {
        self.slices.get(&address).copied()
    }

    /// Registers `slice` as starting at `address`, replacing any previous entry.
    pub(crate) fn insert_slice(&mut self, address: u64, slice: SliceId) {
        self.slices.insert(address, slice);
    }
}
impl MemoryPool for SlicedPool {
    /// Largest allocation size this pool was configured to serve.
    fn max_alloc_size(&self) -> u64 {
        self.max_alloc_size
    }

    /// Returns the storage handle backing the slice bound to `binding`, if that
    /// slice is tracked by this pool.
    fn get(&self, binding: &SliceBinding) -> Option<&StorageHandle> {
        self.slices.get(binding.id()).map(|s| &s.storage)
    }

    /// Tries to reuse an existing free slice for an allocation of `size` bytes.
    ///
    /// The request is padded with `calculate_padding` before the ring buffer is
    /// searched. The slice that is found keeps its effective size: its storage
    /// size becomes `size` and the remainder is recorded as padding.
    ///
    /// Returns `None` when the ring buffer finds no free slice.
    fn try_reserve(&mut self, size: u64, exclude: Option<&StorageExclude>) -> Option<SliceHandle> {
        let padding = calculate_padding(size, self.alignment);
        let effective_size = size + padding;
        let slice_id = self.ring.find_free_slice(
            effective_size,
            &mut self.pages,
            &mut self.slices,
            exclude,
        )?;
        let slice = self.slices.get_mut(&slice_id).unwrap();
        let old_slice_size = slice.effective_size();
        let offset = slice.storage.utilization.offset;
        slice.storage.utilization = StorageUtilization { offset, size };
        // Keep the slice's footprint unchanged: whatever `size` does not use
        // becomes padding.
        slice.padding = old_slice_size - size;
        assert_eq!(
            slice.effective_size(),
            old_slice_size,
            "new and old slice should have the same size"
        );
        Some(slice.handle.clone())
    }

    /// Allocates a brand-new page and carves a slice of `size` bytes from its start.
    ///
    /// Any space left in the page after the padded slice is registered as a
    /// second, separate slice so that it can be reused later.
    ///
    /// # Panics
    ///
    /// Panics if the padded allocation does not fit in a single page, since such
    /// a slice would extend past the end of its backing storage.
    fn alloc<Storage: ComputeStorage>(&mut self, storage: &mut Storage, size: u64) -> SliceHandle {
        // Check before touching any state, using the same padding rule as `try_reserve`.
        let padded_size = size + calculate_padding(size, self.alignment);
        assert!(
            padded_size <= self.page_size,
            "allocation of {size} bytes (padded to {padded_size}) does not fit in a page of {} bytes",
            self.page_size
        );

        let storage_id = self.create_page(storage);
        self.recently_added_pages.push(storage_id);
        self.recently_allocated_size += self.page_size;

        let slice = self.create_slice(0, size, storage_id);
        let effective_size = slice.effective_size();
        let extra_slice = if effective_size < self.page_size {
            Some(self.create_slice(effective_size, self.page_size - effective_size, storage_id))
        } else {
            None
        };

        let handle = slice.handle.clone();
        let page = self
            .pages
            .get_mut(&storage_id)
            .expect("page was just created by create_page");
        for new_slice in core::iter::once(slice).chain(extra_slice) {
            let slice_id = new_slice.id();
            page.slices.insert(new_slice.storage.offset(), slice_id);
            self.slices.insert(slice_id, new_slice);
        }
        handle
    }

    /// Reports statistics over every non-free slice plus the total page bytes reserved.
    fn get_memory_usage(&self) -> MemoryUsage {
        // Single pass instead of collecting used slices and walking them three times.
        let (number_allocs, bytes_in_use, bytes_padding) = self
            .slices
            .values()
            .filter(|slice| !slice.is_free())
            .fold((0u64, 0u64, 0u64), |(count, in_use, padding), slice| {
                (count + 1, in_use + slice.storage.size(), padding + slice.padding)
            });
        MemoryUsage {
            number_allocs,
            bytes_in_use,
            bytes_padding,
            bytes_reserved: (self.pages.len() as u64) * self.page_size,
        }
    }

    /// This pool never releases pages, so cleanup is a no-op.
    fn cleanup<Storage: ComputeStorage>(
        &mut self,
        _storage: &mut Storage,
        _alloc_nr: u64,
        _explicit: bool,
    ) {
    }
}
impl SlicedPool {
    /// Creates an empty pool whose pages are `page_size` bytes long.
    ///
    /// # Panics
    ///
    /// Panics if `page_size` is not a multiple of `alignment`.
    pub(crate) fn new(page_size: u64, max_alloc_size: u64, alignment: u64) -> Self {
        assert_eq!(page_size % alignment, 0);
        let ring = RingBuffer::new(alignment);
        Self {
            page_size,
            max_alloc_size,
            alignment,
            pages: HashMap::new(),
            slices: HashMap::new(),
            storage_index: SearchIndex::new(),
            ring,
            recently_added_pages: Vec::new(),
            recently_allocated_size: 0,
        }
    }

    /// Builds a new slice of `size` bytes at `offset` inside the page `storage_id`.
    ///
    /// The slice gets a fresh handle and a padding of
    /// `calculate_padding(size, self.alignment)`. It is not registered anywhere;
    /// the caller must insert it into `slices` and into its page.
    ///
    /// # Panics
    ///
    /// Panics if `offset` is not a multiple of the pool alignment.
    fn create_slice(&self, offset: u64, size: u64, storage_id: StorageId) -> Slice {
        assert_eq!(
            offset % self.alignment,
            0,
            "slice with offset {offset} needs to be a multiple of {}",
            self.alignment
        );
        let padding = calculate_padding(size, self.alignment);
        let storage = StorageHandle {
            id: storage_id,
            utilization: StorageUtilization { offset, size },
        };
        Slice::new(storage, SliceHandle::new(), padding)
    }

    /// Allocates a `page_size`-byte page from `storage` and registers it.
    ///
    /// The page is pushed onto the ring buffer, inserted into `pages` with no
    /// slices yet, and added to the storage index. Returns the page's storage id.
    fn create_page<Storage: ComputeStorage>(&mut self, storage: &mut Storage) -> StorageId {
        let page_handle = storage.alloc(self.page_size);
        let page_id = page_handle.id;
        self.ring.push_page(page_id);
        self.pages.insert(page_id, MemoryPage::new(HashMap::new()));
        self.storage_index.insert(page_id, self.page_size);
        page_id
    }
}
impl Slice {
    /// Splits this slice `offset_slice` bytes after its start.
    ///
    /// After the call this slice keeps its offset and has a storage size of
    /// `offset_slice`. The remaining `effective_size() - offset_slice` bytes
    /// are handled in one of two ways:
    ///
    /// * When at least `buffer_alignment` bytes remain, they are returned as a
    ///   new slice with a fresh handle, and this slice's padding is reset to 0.
    /// * When fewer remain, they are absorbed as padding on this slice.
    ///   `None` is returned and this slice's effective size is unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the start of the remainder is not a multiple of `buffer_alignment`.
    pub(crate) fn split(&mut self, offset_slice: u64, buffer_alignment: u64) -> Option<Self> {
        let old_size = self.effective_size();
        let size_new = old_size - offset_slice;
        let offset_new = self.storage.offset() + offset_slice;

        // Validate before mutating so a panic never leaves a half-split slice.
        if offset_new % buffer_alignment != 0 {
            panic!("slice with offset {offset_new} needs to be a multiple of {buffer_alignment}");
        }

        self.storage.utilization = StorageUtilization {
            offset: self.storage.offset(),
            size: offset_slice,
        };

        if size_new < buffer_alignment {
            // The remainder is too small to stand on its own: keep it as padding
            // so this slice's footprint does not change.
            self.padding = old_size - offset_slice;
            assert_eq!(self.effective_size(), old_size);
            return None;
        }

        self.padding = 0;
        let storage_new = StorageHandle {
            id: self.storage.id,
            utilization: StorageUtilization {
                offset: offset_new,
                size: size_new,
            },
        };
        // NOTE(review): padding is computed on `size_new - buffer_alignment`; if
        // `calculate_padding` is modulo-based this equals
        // `calculate_padding(size_new, buffer_alignment)` — confirm and simplify.
        let padding = calculate_padding(size_new - buffer_alignment, buffer_alignment);
        // The handle is only created when a new slice is actually produced.
        Some(Slice::new(storage_new, SliceHandle::new(), padding))
    }

    /// Offset of the first byte after this slice's effective size (storage offset
    /// plus effective size), i.e. where an adjacent following slice would start.
    pub(crate) fn next_slice_position(&self) -> u64 {
        self.storage.offset() + self.effective_size()
    }
}