use crate::{GpuBuffer, GpuDevice, GpuError, Result};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Plain-data snapshot of memory accounting counters.
///
/// Every field is a raw `u64`; the struct is `Copy`, so a snapshot can be
/// passed around freely without affecting the counters it was read from.
#[derive(Debug, Clone, Copy, Default)]
pub struct MemoryStats {
    /// Cumulative number of bytes recorded as allocated.
    pub total_allocated: u64,
    /// Cumulative number of bytes recorded as freed.
    pub total_freed: u64,
    /// Bytes recorded as in use at the time of the snapshot.
    pub current_usage: u64,
    /// Largest `current_usage` value recorded so far.
    pub peak_usage: u64,
    /// Number of allocations recorded as outstanding.
    pub allocation_count: u64,
}

impl MemoryStats {
    /// Number of bytes in one megabyte (2^20), used as the float divisor.
    const BYTES_PER_MB: f64 = 1024.0 * 1024.0;

    /// Returns the current usage in bytes.
    #[must_use]
    pub fn current_bytes(&self) -> u64 {
        let Self { current_usage, .. } = *self;
        current_usage
    }

    /// Returns the current usage in megabytes (bytes / 2^20).
    #[must_use]
    pub fn current_mb(&self) -> f64 {
        Self::megabytes_of(self.current_usage)
    }

    /// Returns the peak usage in bytes.
    #[must_use]
    pub fn peak_bytes(&self) -> u64 {
        let Self { peak_usage, .. } = *self;
        peak_usage
    }

    /// Returns the peak usage in megabytes (bytes / 2^20).
    #[must_use]
    pub fn peak_mb(&self) -> f64 {
        Self::megabytes_of(self.peak_usage)
    }

    /// Converts a byte count into megabytes as a float.
    fn megabytes_of(bytes: u64) -> f64 {
        bytes as f64 / Self::BYTES_PER_MB
    }
}
/// Atomic bookkeeping of memory usage associated with one `wgpu::Device`.
///
/// The methods here only update counters; callers report allocations and
/// deallocations through [`MemoryAllocator::track_allocation`] and
/// [`MemoryAllocator::track_deallocation`]. All counters use
/// `Ordering::Relaxed`: each counter is individually consistent, but a
/// snapshot taken while other threads are updating may mix values from
/// different moments.
pub struct MemoryAllocator {
    /// Device handle returned by [`MemoryAllocator::device`].
    device: Arc<wgpu::Device>,
    /// Cumulative bytes reported as allocated.
    total_allocated: AtomicU64,
    /// Cumulative bytes reported as freed.
    total_freed: AtomicU64,
    /// Bytes currently reported as in use.
    current_usage: AtomicU64,
    /// Highest value `current_usage` has reached.
    peak_usage: AtomicU64,
    /// Number of allocations currently reported as outstanding.
    allocation_count: AtomicU64,
}

impl MemoryAllocator {
    /// Creates an allocator with all counters at zero, sharing `device`'s
    /// underlying `wgpu::Device` handle.
    #[must_use]
    pub fn new(device: &GpuDevice) -> Self {
        Self {
            device: Arc::clone(device.device()),
            total_allocated: AtomicU64::new(0),
            total_freed: AtomicU64::new(0),
            current_usage: AtomicU64::new(0),
            peak_usage: AtomicU64::new(0),
            allocation_count: AtomicU64::new(0),
        }
    }

    /// Records an allocation of `size` bytes and raises the peak if the new
    /// current usage exceeds it.
    pub fn track_allocation(&self, size: u64) {
        self.total_allocated.fetch_add(size, Ordering::Relaxed);
        // `fetch_add` wraps on overflow, so compute the post-add value with
        // `wrapping_add` to match what was stored (a plain `+` would panic in
        // debug builds on the same input).
        let current = self
            .current_usage
            .fetch_add(size, Ordering::Relaxed)
            .wrapping_add(size);
        self.allocation_count.fetch_add(1, Ordering::Relaxed);
        // Atomically keeps the larger of the stored peak and `current`.
        self.peak_usage.fetch_max(current, Ordering::Relaxed);
    }

    /// Records that `size` bytes belonging to one allocation were released.
    ///
    /// `current_usage` and `allocation_count` are clamped at zero instead of
    /// wrapping around, so releasing more than was recorded (for example
    /// after [`MemoryAllocator::reset_stats`] while buffers are still
    /// outstanding) cannot produce enormous bogus values.
    pub fn track_deallocation(&self, size: u64) {
        self.total_freed.fetch_add(size, Ordering::Relaxed);
        Self::saturating_decrement(&self.current_usage, size);
        Self::saturating_decrement(&self.allocation_count, 1);
    }

    /// Returns a snapshot of all counters.
    pub fn stats(&self) -> MemoryStats {
        MemoryStats {
            total_allocated: self.total_allocated.load(Ordering::Relaxed),
            total_freed: self.total_freed.load(Ordering::Relaxed),
            current_usage: self.current_usage.load(Ordering::Relaxed),
            peak_usage: self.peak_usage.load(Ordering::Relaxed),
            allocation_count: self.allocation_count.load(Ordering::Relaxed),
        }
    }

    /// Resets every counter, including the peak, to zero.
    pub fn reset_stats(&self) {
        self.total_allocated.store(0, Ordering::Relaxed);
        self.total_freed.store(0, Ordering::Relaxed);
        self.current_usage.store(0, Ordering::Relaxed);
        self.peak_usage.store(0, Ordering::Relaxed);
        self.allocation_count.store(0, Ordering::Relaxed);
    }

    /// Returns the shared device handle this allocator was created with.
    pub fn device(&self) -> &Arc<wgpu::Device> {
        &self.device
    }

    /// Atomically subtracts `amount` from `counter`, stopping at zero.
    fn saturating_decrement(counter: &AtomicU64, amount: u64) {
        // The closure always returns `Some`, so `fetch_update` cannot fail;
        // the returned previous value is not needed.
        let _ = counter.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |value| {
            Some(value.saturating_sub(amount))
        });
    }
}
/// Pool of reusable [`GpuBuffer`]s, bucketed by buffer size in bytes.
///
/// Buffers returned through [`MemoryPool::deallocate`] are kept for reuse by
/// later [`MemoryPool::allocate`] calls of the same size. Usage statistics
/// are kept by an internal [`MemoryAllocator`].
pub struct MemoryPool {
    // Stored at construction but not read by any method in this type.
    #[allow(dead_code)]
    device: Arc<wgpu::Device>,
    /// Counters for buffers created through this pool.
    allocator: Arc<MemoryAllocator>,
    /// Idle buffers keyed by their size in bytes.
    pools: RwLock<HashMap<u64, Vec<GpuBuffer>>>,
}

impl MemoryPool {
    /// Creates an empty pool with a fresh [`MemoryAllocator`] for `device`.
    #[must_use]
    pub fn new(device: &GpuDevice) -> Self {
        Self {
            device: Arc::clone(device.device()),
            allocator: Arc::new(MemoryAllocator::new(device)),
            pools: RwLock::new(HashMap::new()),
        }
    }

    /// Returns a buffer of `size` bytes, reusing an idle one when available.
    ///
    /// When no idle buffer of that exact size exists, a new one is created
    /// with `GpuBuffer::new` and recorded as an allocation.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `GpuBuffer::new`.
    pub fn allocate(
        &self,
        device: &GpuDevice,
        size: u64,
        buffer_type: crate::buffer::BufferType,
    ) -> Result<GpuBuffer> {
        // NOTE(review): idle buffers are keyed by size only, so a reused
        // buffer may have been created with a different `buffer_type` than
        // the one requested here — confirm whether callers rely on the type.
        //
        // The write guard is a temporary of this statement, so the lock is
        // released before a new buffer is created below.
        let reused = self.pools.write().get_mut(&size).and_then(Vec::pop);
        if let Some(buffer) = reused {
            return Ok(buffer);
        }

        let buffer = GpuBuffer::new(device, size, buffer_type)?;
        self.allocator.track_allocation(size);
        Ok(buffer)
    }

    /// Returns `buffer` to the pool, filing it under its reported size.
    ///
    /// The statistics are not changed: the buffer is kept for reuse.
    pub fn deallocate(&self, buffer: GpuBuffer) {
        let size = buffer.size();
        self.pools.write().entry(size).or_default().push(buffer);
    }

    /// Removes every idle buffer from the pool and records each one as a
    /// deallocation.
    ///
    /// One deallocation is recorded per buffer so that both the byte
    /// counters and the outstanding-allocation count stay in step with
    /// [`MemoryPool::allocate`], which records one allocation per buffer.
    pub fn clear(&self) {
        let mut pools = self.pools.write();
        for (size, buffers) in pools.drain() {
            // Recording per buffer (rather than `size * len` in one call)
            // also avoids overflowing the multiplication.
            for _ in &buffers {
                self.allocator.track_deallocation(size);
            }
        }
    }

    /// Returns the number of idle buffers currently held across all sizes.
    pub fn pool_size(&self) -> usize {
        self.pools.read().values().map(Vec::len).sum()
    }

    /// Returns a snapshot of the pool's allocation statistics.
    pub fn stats(&self) -> MemoryStats {
        self.allocator.stats()
    }

    /// Returns the allocator that tracks this pool's statistics.
    pub fn allocator(&self) -> &Arc<MemoryAllocator> {
        &self.allocator
    }
}
/// Owning handle that gives its [`GpuBuffer`] back to a [`MemoryPool`] when
/// it is dropped.
pub struct ManagedBuffer {
    /// `Some` while this handle owns the buffer; emptied by `take` and on drop.
    buffer: Option<GpuBuffer>,
    /// Pool that receives the buffer when the handle is dropped.
    pool: Arc<MemoryPool>,
}

impl ManagedBuffer {
    /// Wraps `buffer` so that it is handed to `pool` when the handle drops.
    pub fn new(buffer: GpuBuffer, pool: Arc<MemoryPool>) -> Self {
        let buffer = Some(buffer);
        Self { buffer, pool }
    }

    /// Borrows the wrapped buffer.
    ///
    /// # Panics
    ///
    /// Panics if the buffer has already been released.
    #[must_use]
    pub fn buffer(&self) -> &GpuBuffer {
        match self.buffer.as_ref() {
            Some(inner) => inner,
            None => unreachable!("ManagedBuffer accessed after buffer was released"),
        }
    }

    /// Borrows the wrapped buffer without panicking.
    ///
    /// # Errors
    ///
    /// Returns `GpuError::Internal` if the buffer has already been released.
    pub fn try_buffer(&self) -> Result<&GpuBuffer> {
        match self.buffer.as_ref() {
            Some(inner) => Ok(inner),
            None => Err(GpuError::Internal("Buffer already released".to_string())),
        }
    }

    /// Consumes the handle and returns the buffer to the caller instead of
    /// to the pool.
    ///
    /// # Errors
    ///
    /// Returns `GpuError::Internal` if the buffer has already been released.
    pub fn take(mut self) -> Result<GpuBuffer> {
        // Emptying the slot means the `Drop` impl that runs when `self`
        // goes out of scope has nothing left to hand to the pool.
        match self.buffer.take() {
            Some(inner) => Ok(inner),
            None => Err(GpuError::Internal("Buffer already released".to_string())),
        }
    }
}
impl Drop for ManagedBuffer {
    /// Hands the buffer to the pool, unless it has already been taken out.
    fn drop(&mut self) {
        let Some(released) = self.buffer.take() else {
            return;
        };
        self.pool.deallocate(released);
    }
}
impl std::ops::Deref for ManagedBuffer {
type Target = GpuBuffer;
fn deref(&self) -> &Self::Target {
self.buffer()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Bytes per megabyte, matching the divisor used by `MemoryStats`.
    const MB: u64 = 1024 * 1024;

    /// The byte and megabyte accessors report the values stored in the struct.
    #[test]
    fn test_memory_stats() {
        let snapshot = MemoryStats {
            total_allocated: 100 * MB,
            total_freed: 20 * MB,
            current_usage: 80 * MB,
            peak_usage: 90 * MB,
            allocation_count: 10,
        };

        assert_eq!(snapshot.current_bytes(), 80 * MB);
        assert!((snapshot.current_mb() - 80.0).abs() < 0.01);
        assert_eq!(snapshot.peak_bytes(), 90 * MB);
        assert!((snapshot.peak_mb() - 90.0).abs() < 0.01);
    }

    /// Exercises the allocator counters on a real device; returns early
    /// without asserting anything when no device can be created.
    #[test]
    #[ignore]
    fn test_memory_allocator_tracking() {
        let gpu_device = match crate::device::GpuDevice::new(None) {
            Ok(created) => created,
            Err(_) => return,
        };
        let tracker = MemoryAllocator::new(&gpu_device);

        tracker.track_allocation(1024);
        tracker.track_allocation(2048);
        let after_alloc = tracker.stats();
        assert_eq!(after_alloc.total_allocated, 3072);
        assert_eq!(after_alloc.current_usage, 3072);
        assert_eq!(after_alloc.allocation_count, 2);

        tracker.track_deallocation(1024);
        let after_free = tracker.stats();
        assert_eq!(after_free.total_freed, 1024);
        assert_eq!(after_free.current_usage, 2048);
        assert_eq!(after_free.allocation_count, 1);
    }

    /// Replays the same counter arithmetic on bare atomics, so it runs
    /// without any GPU device.
    #[test]
    fn test_memory_allocator_tracking_no_gpu() {
        let allocated = AtomicU64::new(0);
        let freed = AtomicU64::new(0);
        let in_use = AtomicU64::new(0);
        let live = AtomicU64::new(0);

        // Two allocations: 1024 bytes, then 2048 bytes.
        for bytes in [1024_u64, 2048] {
            allocated.fetch_add(bytes, Ordering::Relaxed);
            in_use.fetch_add(bytes, Ordering::Relaxed);
            live.fetch_add(1, Ordering::Relaxed);
        }
        assert_eq!(allocated.load(Ordering::Relaxed), 3072);
        assert_eq!(in_use.load(Ordering::Relaxed), 3072);
        assert_eq!(live.load(Ordering::Relaxed), 2);

        // One deallocation of 1024 bytes.
        freed.fetch_add(1024, Ordering::Relaxed);
        in_use.fetch_sub(1024, Ordering::Relaxed);
        live.fetch_sub(1, Ordering::Relaxed);
        assert_eq!(freed.load(Ordering::Relaxed), 1024);
        assert_eq!(in_use.load(Ordering::Relaxed), 2048);
        assert_eq!(live.load(Ordering::Relaxed), 1);
    }
}