/// Snapshot of an allocator's bookkeeping counters.
///
/// Returned by [`Allocator::stats`]. `Default` yields an all-zero, unfrozen
/// snapshot, matching the trait's default `stats` implementation.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AllocationStats {
/// Cumulative count of successful allocations; never decremented, only
/// cleared by `reset`.
pub total_allocations: usize,
/// Cumulative sum of bytes ever requested; never decremented, only
/// cleared by `reset`.
pub total_bytes: usize,
/// Allocations that have not yet been deallocated.
pub active_allocations: usize,
/// Whether the allocator currently rejects new allocations.
pub is_frozen: bool,
/// High-water mark of concurrently outstanding bytes.
pub peak_usage: usize,
}
/// Minimal interface for memory allocators.
///
/// Implementors must supply [`Allocator::allocate`] and
/// [`Allocator::deallocate`]; every other method has a conservative default
/// (no freeze support, zero/empty statistics) so simple allocators only
/// override what they actually track.
pub trait Allocator: Clone + Send + Sync {
/// Allocates `size_bytes` and returns an opaque pointer/handle value.
///
/// # Errors
/// Returns an error when the allocation cannot be satisfied (e.g. frozen
/// or out-of-memory, depending on the implementation).
fn allocate(&self, size_bytes: usize) -> crate::error::Result<u64>;
/// Releases an allocation previously returned by [`Allocator::allocate`].
/// Callers are expected to pass the same `size_bytes` used at allocation
/// time (implementations such as `TrackingAllocator` rely on it).
fn deallocate(&self, ptr: u64, size_bytes: usize);
/// Attempts to freeze the allocator, returning `true` when freezing took
/// effect. Default: freezing unsupported, returns `false`.
fn freeze(&self) -> bool {
false
}
/// Reverses a prior [`Allocator::freeze`]. Default: no-op.
fn unfreeze(&self) {
}
/// Reports whether new allocations are currently rejected. Default: `false`.
fn is_frozen(&self) -> bool {
false
}
/// Bytes currently outstanding. Default: `0` (no tracking).
fn allocated_bytes(&self) -> usize {
0 }
/// Point-in-time counters. Default: all-zero [`AllocationStats`].
fn stats(&self) -> AllocationStats {
AllocationStats::default()
}
/// Clears accumulated statistics. Default: no-op success.
///
/// # Errors
/// Implementations may refuse to reset while allocations are outstanding.
fn reset(&self) -> crate::error::Result<()> {
Ok(())
}
}
/// Function-pointer-backed [`Allocator`] that delegates to two free
/// functions paired with a device handle `D`.
///
/// Implements only the required `allocate`/`deallocate`; all trait defaults
/// (no freeze support, zero stats) apply.
#[derive(Clone, Debug)]
pub struct DefaultAllocator<D> {
/// Device handle passed by reference to both callbacks.
device: D,
/// Callback performing the allocation for this device.
allocate_fn: fn(usize, &D) -> crate::error::Result<u64>,
/// Callback releasing a prior allocation on this device.
deallocate_fn: fn(u64, usize, &D),
}
impl<D: Clone + Send + Sync> DefaultAllocator<D> {
pub fn new(
device: D,
allocate_fn: fn(usize, &D) -> crate::error::Result<u64>,
deallocate_fn: fn(u64, usize, &D),
) -> Self {
Self {
device,
allocate_fn,
deallocate_fn,
}
}
pub fn device(&self) -> &D {
&self.device
}
}
impl<D: Clone + Send + Sync> Allocator for DefaultAllocator<D> {
    /// Forwards the request to the configured allocation callback.
    fn allocate(&self, size_bytes: usize) -> crate::error::Result<u64> {
        let alloc = self.allocate_fn;
        alloc(size_bytes, &self.device)
    }

    /// Forwards the pointer to the configured deallocation callback.
    fn deallocate(&self, ptr: u64, size_bytes: usize) {
        let dealloc = self.deallocate_fn;
        dealloc(ptr, size_bytes, &self.device);
    }
}
/// Mutex-protected interior of [`TrackingAllocator`]: the wrapped allocator
/// together with its bookkeeping counters.
#[derive(Debug)]
struct TrackingState<A: Allocator> {
/// Allocator all requests are forwarded to.
inner: A,
/// Cumulative successful allocations (cleared by `reset`).
total_allocations: usize,
/// Cumulative bytes requested (cleared by `reset`).
total_bytes: usize,
/// Count of allocations not yet deallocated.
active_allocations: usize,
/// Outstanding bytes; feeds `peak_usage`.
active_bytes: usize,
/// High-water mark of `active_bytes` (cleared by `reset`).
peak_usage: usize,
/// When set, `allocate` fails with `AllocatorFrozen`; survives `reset`.
frozen: bool,
}
/// Decorator that wraps any [`Allocator`] and records allocation statistics.
///
/// State lives behind `Arc<Mutex<..>>`, so clones share the same counters
/// and the wrapper can be handed across threads cheaply.
#[derive(Debug)]
pub struct TrackingAllocator<A: Allocator> {
/// Shared, lock-protected counters plus the inner allocator.
state: std::sync::Arc<std::sync::Mutex<TrackingState<A>>>,
}
impl<A: Allocator> Clone for TrackingAllocator<A> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
}
}
}
impl<A: Allocator> TrackingAllocator<A> {
    /// Wraps `inner` with zeroed counters and an unfrozen state.
    pub fn new(inner: A) -> Self {
        let initial = TrackingState {
            inner,
            total_allocations: 0,
            total_bytes: 0,
            active_allocations: 0,
            active_bytes: 0,
            peak_usage: 0,
            frozen: false,
        };
        Self {
            state: std::sync::Arc::new(std::sync::Mutex::new(initial)),
        }
    }

    /// Bytes currently outstanding (allocated but not yet deallocated).
    pub fn active_bytes(&self) -> usize {
        self.lock().active_bytes
    }

    /// Acquires the shared state, recovering from a poisoned mutex so a
    /// panicking peer thread cannot permanently wedge the allocator.
    fn lock(&self) -> std::sync::MutexGuard<'_, TrackingState<A>> {
        match self.state.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }
}
impl<A: Allocator> Allocator for TrackingAllocator<A> {
fn allocate(&self, size_bytes: usize) -> crate::error::Result<u64> {
let mut s = self.lock();
if s.frozen {
return Err(crate::error::Error::AllocatorFrozen);
}
let ptr = s.inner.allocate(size_bytes)?;
s.total_allocations += 1;
s.total_bytes += size_bytes;
s.active_allocations += 1;
s.active_bytes += size_bytes;
if s.active_bytes > s.peak_usage {
s.peak_usage = s.active_bytes;
}
Ok(ptr)
}
fn deallocate(&self, ptr: u64, size_bytes: usize) {
let mut s = self.lock();
s.inner.deallocate(ptr, size_bytes);
s.active_allocations = s.active_allocations.saturating_sub(1);
s.active_bytes = s.active_bytes.saturating_sub(size_bytes);
}
fn freeze(&self) -> bool {
let mut s = self.lock();
s.frozen = true;
true
}
fn unfreeze(&self) {
let mut s = self.lock();
s.frozen = false;
}
fn is_frozen(&self) -> bool {
let s = self.lock();
s.frozen
}
fn allocated_bytes(&self) -> usize {
let s = self.lock();
s.active_bytes
}
fn stats(&self) -> AllocationStats {
let s = self.lock();
AllocationStats {
total_allocations: s.total_allocations,
total_bytes: s.total_bytes,
active_allocations: s.active_allocations,
is_frozen: s.frozen,
peak_usage: s.peak_usage,
}
}
fn reset(&self) -> crate::error::Result<()> {
let mut s = self.lock();
if s.active_allocations > 0 {
return Err(crate::error::Error::AllocatorBusy {
active_allocations: s.active_allocations,
});
}
s.total_allocations = 0;
s.total_bytes = 0;
s.active_bytes = 0;
s.peak_usage = 0;
Ok(())
}
}
/// RAII wrapper around a single allocation: deallocates on drop unless
/// [`AllocGuard::release`] transferred ownership of the pointer first.
#[cfg(any(feature = "cuda", feature = "wgpu"))]
#[must_use = "dropping an AllocGuard immediately frees its allocation"]
pub struct AllocGuard<'a, A: Allocator> {
/// Allocator the memory came from and will be returned to on drop.
allocator: &'a A,
/// Raw pointer/handle from `allocate`; 0 is treated as "nothing to free".
ptr: u64,
/// Size passed to `allocate`, replayed on `deallocate`.
size: usize,
/// Set by `release()` to disarm the drop-time deallocation.
released: bool,
}
#[cfg(any(feature = "cuda", feature = "wgpu"))]
impl<'a, A: Allocator> AllocGuard<'a, A> {
    /// Allocates `size_bytes` from `allocator` and guards the result so it
    /// is deallocated automatically when the guard goes out of scope.
    ///
    /// # Errors
    /// Propagates any error from [`Allocator::allocate`].
    pub fn new(allocator: &'a A, size_bytes: usize) -> crate::error::Result<Self> {
        let ptr = allocator.allocate(size_bytes)?;
        Ok(Self {
            allocator,
            ptr,
            size: size_bytes,
            released: false,
        })
    }

    /// Raw pointer/handle to the guarded allocation (still owned by the guard).
    #[inline]
    #[must_use]
    pub fn ptr(&self) -> u64 {
        self.ptr
    }

    /// Consumes the guard and hands ownership of the pointer to the caller,
    /// who becomes responsible for deallocating it.
    // `released` is set before `self` drops at the end of this function,
    // which disarms the `Drop` impl so the memory is NOT freed here.
    #[inline]
    #[must_use = "discarding the released pointer leaks the allocation"]
    pub fn release(mut self) -> u64 {
        self.released = true;
        self.ptr
    }
}
#[cfg(any(feature = "cuda", feature = "wgpu"))]
impl<A: Allocator> Drop for AllocGuard<'_, A> {
    /// Returns the allocation to its allocator unless ownership was handed
    /// off via `release()` or the guarded pointer is the null sentinel (0).
    fn drop(&mut self) {
        let still_owned = !self.released;
        let non_null = self.ptr != 0;
        if still_owned && non_null {
            self.allocator.deallocate(self.ptr, self.size);
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Compile-time check: DefaultAllocator<()> satisfies the Allocator bounds
// (Clone + Send + Sync plus the trait impl).
#[test]
fn test_default_allocator_trait_bounds() {
fn assert_allocator<A: Allocator>() {}
assert_allocator::<DefaultAllocator<()>>();
}
// Test-only Allocator backed by the global heap; size 0 maps to the null
// sentinel pointer (0) and is never passed to the system allocator.
#[derive(Clone)]
struct TestAllocator;
impl Allocator for TestAllocator {
fn allocate(&self, size_bytes: usize) -> crate::error::Result<u64> {
if size_bytes == 0 {
return Ok(0);
}
// 8-byte alignment; from_size_align only fails on invalid align/overflow.
let layout = std::alloc::Layout::from_size_align(size_bytes, 8).unwrap();
let ptr = unsafe { std::alloc::alloc(layout) };
if ptr.is_null() {
return Err(crate::error::Error::OutOfMemory { size: size_bytes });
}
Ok(ptr as u64)
}
fn deallocate(&self, ptr: u64, size_bytes: usize) {
// Ignore the null sentinel / zero-size case produced by allocate(0).
if ptr == 0 || size_bytes == 0 {
return;
}
// Layout must match the one used in allocate for dealloc to be sound.
let layout = std::alloc::Layout::from_size_align(size_bytes, 8).unwrap();
unsafe { std::alloc::dealloc(ptr as *mut u8, layout) };
}
}
// Counters start at zero, grow with each allocation, and active counts
// shrink on deallocate while totals and peak are monotone.
#[test]
fn test_tracking_allocator_basic_stats() {
let tracking = TrackingAllocator::new(TestAllocator);
let stats = tracking.stats();
assert_eq!(stats.total_allocations, 0);
assert_eq!(stats.total_bytes, 0);
assert_eq!(stats.active_allocations, 0);
assert_eq!(stats.peak_usage, 0);
assert!(!stats.is_frozen);
let ptr1 = tracking.allocate(1024).unwrap();
let stats = tracking.stats();
assert_eq!(stats.total_allocations, 1);
assert_eq!(stats.total_bytes, 1024);
assert_eq!(stats.active_allocations, 1);
assert_eq!(stats.peak_usage, 1024);
let ptr2 = tracking.allocate(2048).unwrap();
let stats = tracking.stats();
assert_eq!(stats.total_allocations, 2);
assert_eq!(stats.total_bytes, 3072);
assert_eq!(stats.active_allocations, 2);
assert_eq!(stats.peak_usage, 3072);
tracking.deallocate(ptr1, 1024);
let stats = tracking.stats();
assert_eq!(stats.active_allocations, 1);
// Peak is a high-water mark: deallocation must not lower it.
assert_eq!(stats.peak_usage, 3072);
tracking.deallocate(ptr2, 2048);
let stats = tracking.stats();
assert_eq!(stats.active_allocations, 0);
assert_eq!(stats.peak_usage, 3072); }
// allocated_bytes() (trait method) and active_bytes() (inherent method)
// agree and track outstanding bytes.
#[test]
fn test_tracking_allocator_allocated_bytes() {
let tracking = TrackingAllocator::new(TestAllocator);
assert_eq!(tracking.allocated_bytes(), 0);
let ptr = tracking.allocate(512).unwrap();
assert_eq!(tracking.allocated_bytes(), 512);
assert_eq!(tracking.active_bytes(), 512);
tracking.deallocate(ptr, 512);
assert_eq!(tracking.allocated_bytes(), 0);
}
// freeze() rejects allocations with AllocatorFrozen; unfreeze() restores
// normal operation.
#[test]
fn test_tracking_allocator_freeze() {
let tracking = TrackingAllocator::new(TestAllocator);
assert!(!tracking.is_frozen());
assert!(tracking.freeze());
assert!(tracking.is_frozen());
let result = tracking.allocate(128);
assert!(result.is_err());
match result.unwrap_err() {
crate::error::Error::AllocatorFrozen => {}
other => panic!("expected AllocatorFrozen, got: {other}"),
}
tracking.unfreeze();
assert!(!tracking.is_frozen());
let ptr = tracking.allocate(128).unwrap();
tracking.deallocate(ptr, 128);
}
// reset() succeeds once nothing is outstanding and zeroes all counters.
#[test]
fn test_tracking_allocator_reset_success() {
let tracking = TrackingAllocator::new(TestAllocator);
let ptr = tracking.allocate(1024).unwrap();
tracking.deallocate(ptr, 1024);
tracking.reset().unwrap();
let stats = tracking.stats();
assert_eq!(stats.total_allocations, 0);
assert_eq!(stats.total_bytes, 0);
assert_eq!(stats.active_allocations, 0);
assert_eq!(stats.peak_usage, 0);
}
// reset() with a live allocation fails with AllocatorBusy carrying the
// outstanding allocation count.
#[test]
fn test_tracking_allocator_reset_busy() {
let tracking = TrackingAllocator::new(TestAllocator);
let ptr = tracking.allocate(1024).unwrap();
let result = tracking.reset();
assert!(result.is_err());
match result.unwrap_err() {
crate::error::Error::AllocatorBusy {
active_allocations: 1,
} => {}
other => panic!("expected AllocatorBusy(1), got: {other}"),
}
tracking.deallocate(ptr, 1024);
}
// Peak survives deallocation, is cleared by reset(), and rebuilds from
// zero in the next allocation cycle.
#[test]
fn test_tracking_allocator_peak_across_cycles() {
let tracking = TrackingAllocator::new(TestAllocator);
let p1 = tracking.allocate(2048).unwrap();
let p2 = tracking.allocate(2048).unwrap();
assert_eq!(tracking.stats().peak_usage, 4096);
tracking.deallocate(p1, 2048);
tracking.deallocate(p2, 2048);
assert_eq!(tracking.stats().peak_usage, 4096);
tracking.reset().unwrap();
assert_eq!(tracking.stats().peak_usage, 0);
let p3 = tracking.allocate(512).unwrap();
assert_eq!(tracking.stats().peak_usage, 512);
tracking.deallocate(p3, 512);
}
// Clones observe each other's allocations: state is shared via Arc,
// not copied per clone.
#[test]
fn test_tracking_allocator_clone_shares_state() {
let tracking = TrackingAllocator::new(TestAllocator);
let clone = tracking.clone();
let ptr = tracking.allocate(256).unwrap();
assert_eq!(clone.stats().active_allocations, 1);
clone.deallocate(ptr, 256);
assert_eq!(tracking.stats().active_allocations, 0);
}
// reset() clears counters but deliberately leaves the frozen flag set.
#[test]
fn test_tracking_allocator_freeze_preserved_on_reset() {
let tracking = TrackingAllocator::new(TestAllocator);
tracking.freeze();
tracking.reset().unwrap();
assert!(tracking.is_frozen());
}
// Default-derived stats are all-zero and unfrozen.
#[test]
fn test_allocation_stats_default() {
let stats = AllocationStats::default();
assert_eq!(stats.total_allocations, 0);
assert_eq!(stats.total_bytes, 0);
assert_eq!(stats.active_allocations, 0);
assert!(!stats.is_frozen);
assert_eq!(stats.peak_usage, 0);
}
}