use std::collections::BTreeMap;
/// A fixed allocation bucket; requests are rounded up to one of a small
/// set of sizes so freed buffers can be reused for later requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct SizeClass(usize);

impl SizeClass {
    /// Supported bucket sizes in bytes, ascending from 4 KiB to 256 MiB.
    pub const CLASSES: [usize; 9] = [
        4096,
        16384,
        65536,
        262_144,
        1_048_576,
        4_194_304,
        16_777_216,
        67_108_864,
        268_435_456,
    ];

    /// Smallest class that can hold `size` bytes, or `None` when `size`
    /// exceeds the largest class.
    #[must_use]
    pub fn for_size(size: usize) -> Option<Self> {
        Self::CLASSES
            .iter()
            .copied()
            .find(|&class| class >= size)
            .map(SizeClass)
    }

    /// Size of this class in bytes.
    #[must_use]
    pub fn bytes(&self) -> usize {
        self.0
    }
}
/// Size-class-bucketed free list of GPU buffer handles with byte accounting.
///
/// The pool does not own device memory itself: callers report allocations
/// and frees via `record_allocation` / `record_deallocation`.
#[derive(Debug)]
pub struct GpuMemoryPool {
// Free (not in use) handles, keyed by size-class bytes.
free_buffers: BTreeMap<usize, Vec<GpuBufferHandle>>,
// Bytes currently allocated, as reported by callers.
total_allocated: usize,
// High-water mark of `total_allocated`.
peak_usage: usize,
// `try_get` requests served from the free list.
pool_hits: usize,
// `try_get` requests that found nothing reusable.
pool_misses: usize,
// Byte budget consulted by `has_capacity` (defaults to 2 GiB in `new`).
max_size: usize,
}
/// Handle describing a pooled GPU buffer.
///
/// NOTE(review): nothing here owns device memory; presumably the actual
/// buffer lives behind an allocator elsewhere — confirm with callers.
#[derive(Debug)]
pub struct GpuBufferHandle {
// Buffer size in bytes. NOTE(review): `return_buffer` rounds this up to a
// size class for bucketing, which suggests it is the requested (not
// class-rounded) size — confirm with the code that creates handles.
pub size: usize,
// Toggled by `try_get` (true) and `return_buffer` (false).
pub in_use: bool,
}
impl Default for GpuMemoryPool {
fn default() -> Self {
Self::new()
}
}
impl GpuMemoryPool {
    /// Creates an empty pool with the default 2 GiB byte budget.
    #[must_use]
    pub fn new() -> Self {
        Self {
            free_buffers: BTreeMap::new(),
            total_allocated: 0,
            peak_usage: 0,
            pool_hits: 0,
            pool_misses: 0,
            // Default budget: 2 GiB.
            max_size: 2 * 1024 * 1024 * 1024,
        }
    }

    /// Creates an empty pool with a caller-supplied byte budget.
    #[must_use]
    pub fn with_max_size(max_size: usize) -> Self {
        Self {
            max_size,
            ..Self::new()
        }
    }

    /// Attempts to reuse a free buffer from the size class covering `size`.
    ///
    /// Returns `None` — and counts a miss — when `size` exceeds the largest
    /// size class or that class has no free buffer; the caller is then
    /// expected to allocate fresh memory and `record_allocation` it.
    pub fn try_get(&mut self, size: usize) -> Option<GpuBufferHandle> {
        let size_class = SizeClass::for_size(size)?;
        let class_size = size_class.bytes();
        if let Some(buffers) = self.free_buffers.get_mut(&class_size) {
            if let Some(mut handle) = buffers.pop() {
                handle.in_use = true;
                self.pool_hits += 1;
                return Some(handle);
            }
        }
        self.pool_misses += 1;
        None
    }

    /// Marks `handle` as free and files it under its size class (or its
    /// exact byte size when it exceeds the largest class).
    pub fn return_buffer(&mut self, mut handle: GpuBufferHandle) {
        handle.in_use = false;
        let size_class = SizeClass::for_size(handle.size).map_or(handle.size, |s| s.bytes());
        self.free_buffers
            .entry(size_class)
            .or_default()
            .push(handle);
    }

    /// Records `size` newly allocated bytes and updates the high-water mark.
    pub fn record_allocation(&mut self, size: usize) {
        // saturating_add: a bogus huge `size` must not panic in debug builds.
        self.total_allocated = self.total_allocated.saturating_add(size);
        self.peak_usage = self.peak_usage.max(self.total_allocated);
    }

    /// Records `size` freed bytes, clamping at zero instead of underflowing.
    pub fn record_deallocation(&mut self, size: usize) {
        self.total_allocated = self.total_allocated.saturating_sub(size);
    }

    /// Whether allocating `size` more bytes stays within `max_size`.
    #[must_use]
    pub fn has_capacity(&self, size: usize) -> bool {
        // Fix: the previous `total_allocated + size` could overflow — a
        // panic in debug builds, and in release a silent wraparound that
        // wrongly answers `true`. Saturating keeps the comparison sound.
        self.total_allocated.saturating_add(size) <= self.max_size
    }

    /// Byte budget for this pool.
    #[must_use]
    pub fn max_size(&self) -> usize {
        self.max_size
    }

    /// Snapshot of the pool's counters.
    #[must_use]
    pub fn stats(&self) -> PoolStats {
        let attempts = self.pool_hits + self.pool_misses;
        PoolStats {
            total_allocated: self.total_allocated,
            peak_usage: self.peak_usage,
            pool_hits: self.pool_hits,
            pool_misses: self.pool_misses,
            hit_rate: if attempts > 0 {
                self.pool_hits as f64 / attempts as f64
            } else {
                0.0
            },
            free_buffers: self.free_buffers.values().map(Vec::len).sum(),
        }
    }

    /// Drops all free-list handles.
    ///
    /// NOTE(review): byte accounting is caller-driven, so `total_allocated`
    /// is deliberately left untouched — callers that free the underlying
    /// device memory must still call `record_deallocation`.
    pub fn clear(&mut self) {
        self.free_buffers.clear();
    }
}
/// Snapshot of pool counters returned by `GpuMemoryPool::stats`.
#[derive(Debug, Clone)]
pub struct PoolStats {
    pub total_allocated: usize,
    pub peak_usage: usize,
    pub pool_hits: usize,
    pub pool_misses: usize,
    pub hit_rate: f64,
    pub free_buffers: usize,
}

impl PoolStats {
    /// Rough estimate of bytes saved by reuse, assuming ~1 MiB of
    /// allocation avoided per pool hit.
    #[must_use]
    pub fn estimated_savings_bytes(&self) -> usize {
        // Zero hits trivially yields zero, so no branch is needed.
        self.pool_hits * 1024 * 1024
    }
}
/// Host-side staging buffer for device transfers.
///
/// NOTE(review): `is_pinned` is never set to `true` anywhere in this file;
/// presumably page-locking happens in platform-specific code — confirm.
#[derive(Debug)]
pub struct PinnedHostBuffer<T> {
    // Backing storage; length is fixed at construction.
    data: Vec<T>,
    // Whether the memory is page-locked (always `false` here).
    is_pinned: bool,
}

impl<T: Copy + Default> PinnedHostBuffer<T> {
    /// Allocates a buffer of `len` elements initialized to `T::default()`.
    #[must_use]
    pub fn new(len: usize) -> Self {
        Self {
            data: vec![T::default(); len],
            is_pinned: false,
        }
    }

    /// Read-only view of the full buffer.
    #[must_use]
    pub fn as_slice(&self) -> &[T] {
        &self.data
    }

    /// Mutable view of the full buffer.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        &mut self.data
    }

    /// Number of elements.
    #[must_use]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` when the buffer holds zero elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Whether the memory is page-locked.
    #[must_use]
    pub fn is_pinned(&self) -> bool {
        self.is_pinned
    }

    /// Total size in bytes (`len * size_of::<T>()`).
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Copies `src` into the front of the buffer.
    ///
    /// Fix: the old body forwarded to `<[T]>::copy_from_slice`, which
    /// panics unless `src.len()` equals the buffer length — but pooled
    /// buffers are rounded up to a size class and are therefore usually
    /// longer than the caller's payload, making the common call panic.
    /// Copying into the prefix keeps exact-length calls working and makes
    /// shorter payloads valid.
    ///
    /// # Panics
    /// Panics if `src.len()` exceeds the buffer length.
    pub fn copy_from_slice(&mut self, src: &[T]) {
        self.data[..src.len()].copy_from_slice(src);
    }
}
/// Pool of reusable `f32` host staging buffers, bucketed by size class.
///
/// Unlike `GpuMemoryPool`, this pool owns its buffers and tracks their
/// bytes internally inside `get` / `put`.
#[derive(Debug)]
pub struct StagingBufferPool {
// Cached buffers keyed by byte size (a size class, or the exact size for
// requests larger than the biggest class).
free_buffers: BTreeMap<usize, Vec<PinnedHostBuffer<f32>>>,
// Bytes held by buffers this pool has created and not yet dropped.
total_allocated: usize,
// High-water mark of `total_allocated`.
peak_usage: usize,
// `get` requests served from the cache.
pool_hits: usize,
// `get` requests that required a fresh allocation.
pool_misses: usize,
// Byte budget; `put` drops buffers while `total_allocated` exceeds it
// (defaults to 512 MiB in `new`).
max_size: usize,
}
impl Default for StagingBufferPool {
fn default() -> Self {
Self::new()
}
}
impl StagingBufferPool {
    /// Creates an empty staging pool with the default 512 MiB byte budget.
    #[must_use]
    pub fn new() -> Self {
        Self {
            free_buffers: BTreeMap::new(),
            total_allocated: 0,
            peak_usage: 0,
            pool_hits: 0,
            pool_misses: 0,
            // Default budget: 512 MiB.
            max_size: 512 * 1024 * 1024,
        }
    }

    /// Creates an empty staging pool with a caller-supplied byte budget.
    #[must_use]
    pub fn with_max_size(max_size: usize) -> Self {
        Self {
            max_size,
            ..Self::new()
        }
    }

    /// Returns a buffer with room for at least `size` `f32` elements,
    /// reusing a cached buffer from the matching size class when possible.
    ///
    /// The returned buffer's length is the size class rounded up from the
    /// request (or exactly `size` elements when the request exceeds the
    /// largest class), so it may be longer than asked for.
    pub fn get(&mut self, size: usize) -> PinnedHostBuffer<f32> {
        let size_bytes = size * std::mem::size_of::<f32>();
        let size_class = SizeClass::for_size(size_bytes).map_or(size_bytes, |c| c.bytes());
        let elements = size_class / std::mem::size_of::<f32>();
        if let Some(buffers) = self.free_buffers.get_mut(&size_class) {
            if let Some(buf) = buffers.pop() {
                self.pool_hits += 1;
                return buf;
            }
        }
        self.pool_misses += 1;
        let buf = PinnedHostBuffer::new(elements);
        // Fresh allocation: count its bytes (saturating to avoid a debug
        // overflow panic on pathological sizes) and bump the peak.
        self.total_allocated = self.total_allocated.saturating_add(buf.size_bytes());
        self.peak_usage = self.peak_usage.max(self.total_allocated);
        buf
    }

    /// Returns a buffer to the cache, or — when the pool is over budget —
    /// drops it and stops counting its bytes.
    pub fn put(&mut self, buf: PinnedHostBuffer<f32>) {
        let size_class = buf.size_bytes();
        if self.total_allocated > self.max_size {
            // Over budget: let `buf` drop instead of caching it.
            self.total_allocated = self.total_allocated.saturating_sub(size_class);
            return;
        }
        self.free_buffers.entry(size_class).or_default().push(buf);
    }

    /// Snapshot of the pool's counters.
    #[must_use]
    pub fn stats(&self) -> StagingPoolStats {
        let free_count: usize = self.free_buffers.values().map(Vec::len).sum();
        let attempts = self.pool_hits + self.pool_misses;
        StagingPoolStats {
            total_allocated: self.total_allocated,
            peak_usage: self.peak_usage,
            pool_hits: self.pool_hits,
            pool_misses: self.pool_misses,
            free_buffers: free_count,
            hit_rate: if attempts > 0 {
                self.pool_hits as f64 / attempts as f64
            } else {
                0.0
            },
        }
    }

    /// Drops all cached buffers and deducts exactly their bytes from the
    /// running total.
    ///
    /// Fix: the old body reset `total_allocated` to zero, which was wrong
    /// whenever buffers handed out by `get` were still outstanding — those
    /// bytes are still allocated and must keep counting until the buffers
    /// are dropped via an over-budget `put`.
    pub fn clear(&mut self) {
        let freed: usize = self
            .free_buffers
            .values()
            .flat_map(|bucket| bucket.iter())
            .map(PinnedHostBuffer::size_bytes)
            .sum();
        self.free_buffers.clear();
        self.total_allocated = self.total_allocated.saturating_sub(freed);
    }
}
include!("transfer.rs");
include!("memory_transfer_mode.rs");