use crate::{
memory_management::{BytesFormat, MemoryUsage},
server::IoError,
storage::{ComputeStorage, StorageHandle, StorageUtilization},
};
use alloc::vec::Vec;
use cubecl_common::backtrace::BackTrace;
use super::{MemoryPool, Slice, SliceBinding, SliceHandle, calculate_padding};
/// A memory pool in which every slice gets its own dedicated storage
/// allocation ("page"): one slice per page, never sub-divided or shared.
pub struct ExclusiveMemoryPool {
// All pages owned by this pool; each holds exactly one slice.
pages: Vec<MemoryPage>,
// Alignment (in bytes) every page allocation is rounded up to.
alignment: u64,
// Budget (in allocation count) used to derive how often `cleanup` scans
// for stale free pages; see `cleanup`.
dealloc_period: u64,
// `alloc_nr` value at the last cleanup scan.
last_dealloc_check: u64,
// Largest request this pool will accept (multiple of `alignment`).
max_alloc_size: u64,
// Exponential moving average of requested sizes; used to size new pages
// so they can be reused by similar future requests.
cur_avg_size: f64,
}
impl core::fmt::Display for ExclusiveMemoryPool {
    /// Render a human-readable summary: the pool's size cap, one line per
    /// page (size and free state), and aggregate usage when non-empty.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        writeln!(
            f,
            " - Exclusive Pool max_alloc_size={}",
            BytesFormat::new(self.max_alloc_size)
        )?;
        for page in &self.pages {
            let size = BytesFormat::new(page.slice.effective_size());
            let is_free = page.slice.is_free();
            writeln!(f, " - Page {size} is_free={is_free}")?;
        }
        if !self.pages.is_empty() {
            writeln!(f, "\n{}", self.get_memory_usage())?;
        }
        Ok(())
    }
}
/// Decay factor for the exponential moving average of requested sizes
/// (new average = 99% old + 1% latest request).
const SIZE_AVG_DECAY: f64 = 0.01;
/// Number of times a page may be seen free by `cleanup` before it is
/// deallocated; also seeds `free_count` for freshly allocated pages.
const ALLOC_AFTER_FREE: u32 = 5;
/// One backing allocation of the pool, holding exactly one slice.
struct MemoryPage {
// The single slice served out of this page.
slice: Slice,
// Size (in bytes) of the underlying storage allocation.
alloc_size: u64,
// Cleanup-scan counter: incremented each time `cleanup` finds the slice
// free, decremented on reuse; page is freed once it reaches
// `ALLOC_AFTER_FREE`.
free_count: u32,
}
impl ExclusiveMemoryPool {
    /// Create a new exclusive pool.
    ///
    /// `max_alloc_size` must be a multiple of `alignment` (asserted).
    /// `dealloc_period` controls how often `cleanup` scans for stale pages.
    pub(crate) fn new(max_alloc_size: u64, alignment: u64, dealloc_period: u64) -> Self {
        assert_eq!(max_alloc_size % alignment, 0);
        Self {
            pages: Vec::new(),
            alignment,
            dealloc_period,
            last_dealloc_check: 0,
            max_alloc_size,
            // Seed the running average at half the maximum request size.
            cur_avg_size: max_alloc_size as f64 / 2.0,
        }
    }

    /// Find a free page large enough for `size`, preferring the page with the
    /// lowest free-scan count (i.e. the one closest to being reused recently).
    fn get_free_page(&mut self, size: u64) -> Option<&mut MemoryPage> {
        self.pages
            .iter_mut()
            .filter(|candidate| candidate.slice.is_free() && candidate.alloc_size >= size)
            .min_by_key(|candidate| candidate.free_count)
    }

    /// Allocate a brand-new page and return it.
    ///
    /// The page is sized to at least `size`, rounded up to the running
    /// average of request sizes and to the pool alignment, so it is likely
    /// reusable by similar future requests.
    fn alloc_page<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        size: u64,
    ) -> Result<&mut MemoryPage, IoError> {
        let alloc_size = size
            .max(self.cur_avg_size as u64)
            .next_multiple_of(self.alignment);
        let backing = storage.alloc(alloc_size)?;

        let padding = calculate_padding(size, self.alignment);
        let mut slice = Slice::new(backing, SliceHandle::new(), padding);
        // The slice only exposes the requested size; the rest of the page is
        // slack kept for reuse.
        slice.storage.utilization = StorageUtilization { offset: 0, size };
        slice.padding = padding;

        self.pages.push(MemoryPage {
            slice,
            alloc_size,
            // Start one step below the deallocation threshold.
            free_count: ALLOC_AFTER_FREE - 1,
        });
        Ok(self.pages.last_mut().expect("page was just pushed"))
    }
}
impl MemoryPool for ExclusiveMemoryPool {
    /// A request fits this pool iff it does not exceed the configured cap.
    fn accept(&self, size: u64) -> bool {
        size <= self.max_alloc_size
    }

    /// Look up the storage handle backing `binding`, if owned by this pool.
    fn get(&self, binding: &SliceBinding) -> Option<&StorageHandle> {
        let id = *binding.id();
        self.pages.iter().find_map(|page| {
            if page.slice.id() == id {
                Some(&page.slice.storage)
            } else {
                None
            }
        })
    }

    /// Try to serve `size` from an existing free page; `None` if no page fits.
    fn try_reserve(&mut self, size: u64) -> Option<SliceHandle> {
        // Fold the request into the moving average regardless of outcome, so
        // future page sizes track demand.
        self.cur_avg_size =
            (1.0 - SIZE_AVG_DECAY) * self.cur_avg_size + SIZE_AVG_DECAY * size as f64;
        let padding = calculate_padding(size, self.alignment);

        let page = self.get_free_page(size)?;
        page.slice.storage.utilization = StorageUtilization { offset: 0, size };
        page.slice.padding = padding;
        // Reuse pushes the page further from the deallocation threshold.
        page.free_count = page.free_count.saturating_sub(1);
        Some(page.slice.handle.clone())
    }

    /// Allocate a fresh page for `size`.
    ///
    /// # Errors
    /// Returns [`IoError::BufferTooBig`] when `size` exceeds the pool cap,
    /// or propagates the storage backend's allocation failure.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, storage))
    )]
    fn alloc<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        size: u64,
    ) -> Result<SliceHandle, IoError> {
        if size > self.max_alloc_size {
            return Err(IoError::BufferTooBig {
                size,
                backtrace: BackTrace::capture(),
            });
        }
        let page = self.alloc_page(storage, size)?;
        Ok(page.slice.handle.clone())
    }

    /// Aggregate usage statistics over all pages.
    fn get_memory_usage(&self) -> MemoryUsage {
        // Reserved bytes count every page; the other stats only in-use ones.
        let in_use = || self.pages.iter().filter(|page| !page.slice.is_free());
        MemoryUsage {
            number_allocs: in_use().count() as u64,
            bytes_in_use: in_use().map(|page| page.slice.storage.size()).sum(),
            bytes_padding: in_use().map(|page| page.slice.padding).sum(),
            bytes_reserved: self.pages.iter().map(|page| page.alloc_size).sum(),
        }
    }

    /// Deallocate pages that have stayed free long enough.
    ///
    /// Scans run every `dealloc_period / ALLOC_AFTER_FREE` allocations (or
    /// immediately when `explicit`); a free page is released once it has been
    /// seen free `ALLOC_AFTER_FREE` times, or unconditionally when `explicit`.
    fn cleanup<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        alloc_nr: u64,
        explicit: bool,
    ) {
        let check_period = self.dealloc_period / (ALLOC_AFTER_FREE as u64);
        let due = alloc_nr - self.last_dealloc_check >= check_period;
        if !(explicit || due) {
            return;
        }
        self.last_dealloc_check = alloc_nr;
        self.pages.retain_mut(|page| {
            if !page.slice.is_free() {
                return true;
            }
            page.free_count += 1;
            let expired = explicit || page.free_count >= ALLOC_AFTER_FREE;
            if expired {
                storage.dealloc(page.slice.storage.id);
            }
            !expired
        });
    }
}