use std::collections::{BTreeSet, HashSet};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::GpuResult;
/// Smallest unit of allocation; every request is rounded up to a multiple of
/// this by `round_size`. Must stay a power of two (the rounding uses a mask).
pub const MIN_BLOCK_SIZE: usize = 512;
/// Requests at or below this size (1 MiB) are served from the small pool.
pub const SMALL_SIZE: usize = 1 << 20;
/// Driver allocation size backing small-class requests (2 MiB).
pub const SMALL_BUFFER: usize = 2 << 20;
/// Requests of at least this size (10 MiB) get an exact driver allocation
/// (rounded to `ROUND_LARGE`) instead of a shared `LARGE_BUFFER`.
pub const MIN_LARGE_ALLOC: usize = 10 << 20;
/// Driver allocation size backing mid-sized requests (20 MiB).
pub const LARGE_BUFFER: usize = 20 << 20;
/// Rounding granularity (2 MiB) for large driver allocations.
// NOTE(review): these size classes appear to mirror PyTorch's CUDA caching
// allocator constants (kMinBlockSize/kSmallSize/...) — confirm intent.
pub const ROUND_LARGE: usize = 2 << 20;
/// Opaque identifier for a CUDA stream. Free-list lookups match on this id
/// (`BlockPool::find_free_block`), so cached blocks are only reused on the
/// stream they were freed on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct StreamId(pub usize);
/// Monotonic source of unique `Block::id`s. Relaxed ordering is sufficient:
/// only uniqueness is required, not any cross-thread ordering.
static NEXT_BLOCK_ID: AtomicUsize = AtomicUsize::new(0);
/// One contiguous region of device memory tracked by the allocator.
///
/// Blocks carved from the same driver allocation form a doubly-linked list of
/// physical neighbors via the `prev`/`next` indices into
/// `AllocatorState::blocks`.
#[derive(Debug)]
pub struct Block {
    /// Globally unique id (from `NEXT_BLOCK_ID`); used as the final ordering
    /// tiebreak in `BlockKey`.
    pub(crate) id: usize,
    /// Device ordinal this block lives on.
    pub device: usize,
    /// Size of the region in bytes.
    pub size: usize,
    /// Raw device address of the region, stored as a `usize` (pointer
    /// arithmetic in `split_block`/`try_merge` operates on it directly).
    pub ptr: usize,
    /// Stream the block belongs to; pool lookups only match this stream.
    pub stream: StreamId,
    /// Extra streams that have used the block; a non-empty set prevents the
    /// block from being coalesced in `try_merge`.
    pub stream_uses: HashSet<StreamId>,
    /// Whether the block is currently handed out to a caller.
    pub allocated: bool,
    /// Index of the physically preceding block from the same split, if any.
    pub prev: Option<usize>,
    /// Index of the physically following block from the same split, if any.
    pub next: Option<usize>,
    /// Size class: `true` if the block belongs to the small pool.
    pub in_small_pool: bool,
}
impl Block {
pub fn new(
device: usize,
size: usize,
ptr: usize,
stream: StreamId,
in_small_pool: bool,
) -> Self {
Self {
id: NEXT_BLOCK_ID.fetch_add(1, Ordering::Relaxed),
device,
size,
ptr,
stream,
stream_uses: HashSet::new(),
allocated: false,
prev: None,
next: None,
in_small_pool,
}
}
pub fn is_split(&self) -> bool {
self.prev.is_some() || self.next.is_some()
}
}
/// Ordering key for entries in a `BlockPool`.
///
/// `Ord` is derived, so comparisons follow field declaration order: stream
/// first (grouping each stream's blocks together), then size (giving best-fit
/// within a stream), with `ptr` and `id` as tiebreakers that keep every key
/// unique. Do NOT reorder these fields — `find_free_block`'s range query
/// depends on exactly this ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct BlockKey {
    stream: StreamId,
    size: usize,
    ptr: usize,
    id: usize,
}
impl BlockKey {
fn from_block(b: &Block) -> Self {
Self {
stream: b.stream,
size: b.size,
ptr: b.ptr,
id: b.id,
}
}
fn search(stream: StreamId, size: usize) -> Self {
Self {
stream,
size,
ptr: 0,
id: 0,
}
}
}
/// Ordered set of free blocks for one size class.
///
/// Entries are `(BlockKey, block_index)` pairs; the `BlockKey` ordering
/// (stream first, then size) makes a range query return the best-fit block
/// for a given stream.
pub(crate) struct BlockPool {
    /// Free entries, ordered by `(key, arena index)`.
    free_blocks: BTreeSet<(BlockKey, usize)>,
    /// `true` for the small-allocation pool, `false` for the large pool.
    pub is_small: bool,
}
impl BlockPool {
    /// Creates an empty pool for the given size class.
    pub fn new(is_small: bool) -> Self {
        Self {
            free_blocks: BTreeSet::new(),
            is_small,
        }
    }

    /// Test helper: inserts `block` (stored at arena index `block_idx`) under
    /// its derived key.
    #[cfg(test)]
    pub fn insert(&mut self, block_idx: usize, block: &Block) {
        self.insert_key(block_idx, BlockKey::from_block(block));
    }

    /// Adds a free entry under a pre-computed key.
    pub fn insert_key(&mut self, block_idx: usize, key: BlockKey) {
        self.free_blocks.insert((key, block_idx));
    }

    /// Removes the entry previously inserted under `key`.
    pub fn remove_key(&mut self, block_idx: usize, key: BlockKey) {
        self.free_blocks.remove(&(key, block_idx));
    }

    /// Best-fit lookup: index of the smallest free block on `stream` with at
    /// least `size` bytes, or `None` when the first candidate at or above the
    /// search key belongs to a different stream.
    pub fn find_free_block(&self, stream: StreamId, size: usize) -> Option<usize> {
        let lower = (BlockKey::search(stream, size), 0);
        self.free_blocks
            .range(lower..)
            .next()
            .and_then(|&(key, idx)| (key.stream == stream).then_some(idx))
    }

    /// Number of free entries currently in the pool.
    pub fn len(&self) -> usize {
        self.free_blocks.len()
    }

    /// Drops every free entry (bookkeeping only; no device memory is touched
    /// here).
    pub fn clear(&mut self) {
        self.free_blocks.clear();
    }
}
/// Interior allocator bookkeeping, guarded by the `Mutex` in `CudaAllocator`.
pub(crate) struct AllocatorState {
    /// Arena of all blocks ever created. Entries are only appended, so
    /// indices stay valid for the allocator's lifetime; blocks absorbed by a
    /// merge are left behind as zero-size orphans.
    pub(crate) blocks: Vec<Block>,
    /// Free list for small-class blocks.
    pub(crate) small_pool: BlockPool,
    /// Free list for large-class blocks.
    pub(crate) large_pool: BlockPool,
    /// Total bytes obtained from the driver (grows in `cache_insert`).
    pub(crate) reserved_bytes: usize,
    /// Bytes currently handed out to callers through the cache.
    pub(crate) allocated_bytes: usize,
    /// High-water mark of `allocated_bytes`.
    pub(crate) peak_bytes: usize,
    /// Cache hits recorded by `cache_find`.
    pub(crate) hits: usize,
    /// Cache misses recorded by `cache_insert`.
    pub(crate) misses: usize,
}
impl AllocatorState {
    /// Fresh state: empty arena, empty pools, all counters zero.
    fn new() -> Self {
        Self {
            blocks: Vec::new(),
            small_pool: BlockPool::new(true),
            large_pool: BlockPool::new(false),
            reserved_bytes: 0,
            allocated_bytes: 0,
            peak_bytes: 0,
            hits: 0,
            misses: 0,
        }
    }

    /// Selects the pool for a size class; the debug assertion catches any
    /// drift between the flag passed in and the pool's own `is_small`.
    pub(crate) fn get_pool_mut(&mut self, is_small: bool) -> &mut BlockPool {
        let pool = if is_small {
            &mut self.small_pool
        } else {
            &mut self.large_pool
        };
        debug_assert_eq!(pool.is_small, is_small, "pool size-class mismatch");
        pool
    }

    /// Appends `block` to the arena and returns its index. Blocks are never
    /// removed from `blocks`, so the index is a stable handle.
    pub(crate) fn add_block(&mut self, block: Block) -> usize {
        let idx = self.blocks.len();
        self.blocks.push(block);
        idx
    }

    /// Split policy: a small-pool block splits whenever the leftover is at
    /// least one minimum-size block; a large-pool block only when the
    /// leftover exceeds `SMALL_SIZE`, which keeps the large pool free of
    /// small fragments.
    ///
    /// NOTE(review): `block.size - size` underflows if `size > block.size`;
    /// callers in this file always pass a size no larger than the candidate
    /// block — confirm before adding new call sites.
    pub(crate) fn should_split(&self, block_idx: usize, size: usize) -> bool {
        let block = &self.blocks[block_idx];
        let remaining = block.size - size;
        if block.in_small_pool {
            remaining >= MIN_BLOCK_SIZE
        } else {
            remaining > SMALL_SIZE
        }
    }

    /// Keeps the first `size` bytes of `blocks[block_idx]` in place and
    /// carves the tail into a new free block, linked as the physical
    /// successor and inserted into the matching pool.
    ///
    /// `block_idx` must not be keyed in a pool at this point — its size (and
    /// therefore its key) changes here. `cache_find` removes the key first;
    /// `cache_insert` splits a block that was never pooled.
    pub(crate) fn split_block(&mut self, block_idx: usize, size: usize) {
        let remaining_size = self.blocks[block_idx].size - size;
        let remaining_ptr = self.blocks[block_idx].ptr + size;
        let stream = self.blocks[block_idx].stream;
        let device = self.blocks[block_idx].device;
        let is_small = self.blocks[block_idx].in_small_pool;
        let old_next = self.blocks[block_idx].next;
        // The remainder inherits device/stream/size-class and slots in
        // between `block_idx` and its old successor.
        let mut remainder = Block::new(device, remaining_size, remaining_ptr, stream, is_small);
        remainder.prev = Some(block_idx);
        remainder.next = old_next;
        let rem_idx = self.add_block(remainder);
        self.blocks[block_idx].size = size;
        self.blocks[block_idx].next = Some(rem_idx);
        if let Some(old_next_idx) = old_next {
            self.blocks[old_next_idx].prev = Some(rem_idx);
        }
        let rem_key = BlockKey::from_block(&self.blocks[rem_idx]);
        let pool = self.get_pool_mut(is_small);
        pool.insert_key(rem_idx, rem_key);
    }

    /// Tries to absorb the physical neighbor `neighbor_idx` into `block_idx`;
    /// returns the number of bytes absorbed (0 when no merge happened).
    ///
    /// A neighbor is only merged while it is free and has no pending
    /// cross-stream uses. The absorbed entry is removed from its pool and
    /// left in the arena as a zero-size, unlinked orphan (arena slots are
    /// never reclaimed).
    pub(crate) fn try_merge(&mut self, block_idx: usize, neighbor_idx: Option<usize>) -> usize {
        let Some(nbr_idx) = neighbor_idx else {
            return 0;
        };
        if self.blocks[nbr_idx].allocated || !self.blocks[nbr_idx].stream_uses.is_empty() {
            return 0;
        }
        let is_small = self.blocks[nbr_idx].in_small_pool;
        let subsumed_size = self.blocks[nbr_idx].size;
        let nbr_key = BlockKey::from_block(&self.blocks[nbr_idx]);
        {
            let pool = self.get_pool_mut(is_small);
            pool.remove_key(nbr_idx, nbr_key);
        }
        if self.blocks[block_idx].prev == Some(nbr_idx) {
            // Neighbor precedes us: grow leftwards, adopting its ptr and prev.
            let nbr_prev = self.blocks[nbr_idx].prev;
            self.blocks[block_idx].ptr = self.blocks[nbr_idx].ptr;
            self.blocks[block_idx].size += subsumed_size;
            self.blocks[block_idx].prev = nbr_prev;
            if let Some(pp) = nbr_prev {
                self.blocks[pp].next = Some(block_idx);
            }
        } else {
            // Neighbor follows us: grow rightwards, adopting its next.
            let nbr_next = self.blocks[nbr_idx].next;
            self.blocks[block_idx].size += subsumed_size;
            self.blocks[block_idx].next = nbr_next;
            if let Some(nn) = nbr_next {
                self.blocks[nn].prev = Some(block_idx);
            }
        }
        // Orphan the absorbed entry so it can never match a future lookup.
        self.blocks[nbr_idx].size = 0;
        self.blocks[nbr_idx].prev = None;
        self.blocks[nbr_idx].next = None;
        subsumed_size
    }

    /// Returns `block_idx` to the cache: clears the allocated flag and any
    /// recorded stream uses, subtracts the (pre-merge) size from
    /// `allocated_bytes`, coalesces with both physical neighbors, then
    /// re-inserts the merged block into its pool under its new key.
    pub(crate) fn free_block(&mut self, block_idx: usize) {
        self.blocks[block_idx].allocated = false;
        self.blocks[block_idx].stream_uses.clear();
        let size = self.blocks[block_idx].size;
        self.allocated_bytes = self.allocated_bytes.saturating_sub(size);
        let prev = self.blocks[block_idx].prev;
        let next = self.blocks[block_idx].next;
        self.try_merge(block_idx, prev);
        self.try_merge(block_idx, next);
        let is_small = self.blocks[block_idx].in_small_pool;
        let merged_key = BlockKey::from_block(&self.blocks[block_idx]);
        let pool = self.get_pool_mut(is_small);
        pool.insert_key(block_idx, merged_key);
    }

    /// Bytes held by the cache but not handed out (reserved minus allocated,
    /// saturating at zero).
    pub(crate) fn cached_bytes(&self) -> usize {
        self.reserved_bytes.saturating_sub(self.allocated_bytes)
    }
}
/// Rounds a requested size up to the allocator's block granularity.
///
/// Every request is clamped to at least `MIN_BLOCK_SIZE`, then rounded up to
/// the next multiple of `MIN_BLOCK_SIZE` (a power of two, so the
/// add-and-mask trick is exact).
pub fn round_size(size: usize) -> usize {
    let clamped = size.max(MIN_BLOCK_SIZE);
    (clamped + MIN_BLOCK_SIZE - 1) & !(MIN_BLOCK_SIZE - 1)
}
/// Maps a (rounded) request size to the number of bytes to request from the
/// driver.
///
/// Small requests are backed by 2 MiB buffers, mid-sized requests by 20 MiB
/// buffers, and requests of `MIN_LARGE_ALLOC` or more get an exact
/// allocation rounded up to the `ROUND_LARGE` granularity.
pub fn get_allocation_size(size: usize) -> usize {
    match size {
        s if s <= SMALL_SIZE => SMALL_BUFFER,
        s if s < MIN_LARGE_ALLOC => LARGE_BUFFER,
        s => (s + ROUND_LARGE - 1) & !(ROUND_LARGE - 1),
    }
}
/// Caching memory allocator for a single CUDA device.
///
/// Direct driver allocations (`alloc_zeros`/`alloc_copy`) are tracked with
/// lock-free atomic counters; the block cache (`cache_find`/`cache_insert`/
/// `cache_free`) lives behind the `state` mutex.
pub struct CudaAllocator {
    // Device this allocator serves.
    device: Arc<GpuDevice>,
    // Block cache and statistics; every cache operation takes this lock.
    pub(crate) state: Mutex<AllocatorState>,
    // Bytes live via the direct alloc paths.
    allocated_bytes_atomic: AtomicUsize,
    // High-water mark of `allocated_bytes_atomic`.
    peak_bytes_atomic: AtomicUsize,
}
impl CudaAllocator {
    /// Creates an allocator bound to `device` with zeroed statistics.
    pub fn new(device: Arc<GpuDevice>) -> Self {
        Self {
            device,
            state: Mutex::new(AllocatorState::new()),
            allocated_bytes_atomic: AtomicUsize::new(0),
            peak_bytes_atomic: AtomicUsize::new(0),
        }
    }

    /// Allocates `count` zero-initialized elements straight from the driver
    /// (this path bypasses the block cache) and records the bytes against
    /// the atomic allocated/peak counters.
    ///
    /// # Errors
    /// Propagates any driver error from the underlying allocation.
    #[cfg(feature = "cuda")]
    pub fn alloc_zeros<T>(&self, count: usize) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
    {
        // Saturate rather than panic on a (theoretical) byte-count overflow;
        // `free` applies the same rule so additions and subtractions match.
        let bytes = count.saturating_mul(std::mem::size_of::<T>());
        let slice = self.device.stream().alloc_zeros::<T>(count)?;
        let prev = self
            .allocated_bytes_atomic
            .fetch_add(bytes, Ordering::Relaxed);
        self.peak_bytes_atomic
            .fetch_max(prev + bytes, Ordering::Relaxed);
        Ok(CudaBuffer {
            data: Some(slice),
            len: count,
            alloc_len: count,
            device_ordinal: self.device.ordinal(),
            pool_fn: None,
        })
    }

    /// Copies `data` host-to-device into a fresh driver allocation and
    /// records the bytes against the atomic counters.
    ///
    /// # Errors
    /// Propagates any driver error from the copy.
    #[cfg(feature = "cuda")]
    pub fn alloc_copy<T>(&self, data: &[T]) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr,
    {
        let bytes = data.len().saturating_mul(std::mem::size_of::<T>());
        let slice = self.device.stream().clone_htod(data)?;
        let prev = self
            .allocated_bytes_atomic
            .fetch_add(bytes, Ordering::Relaxed);
        self.peak_bytes_atomic
            .fetch_max(prev + bytes, Ordering::Relaxed);
        Ok(CudaBuffer {
            data: Some(slice),
            len: data.len(),
            alloc_len: data.len(),
            device_ordinal: self.device.ordinal(),
            pool_fn: None,
        })
    }

    /// Drops `buffer` and subtracts its byte footprint from the allocated
    /// counter.
    pub fn free<T>(&self, buffer: CudaBuffer<T>) {
        // FIX: use saturating_mul to mirror the alloc paths. The previous
        // `checked_mul(..).unwrap_or(0)` subtracted 0 on overflow while
        // alloc_zeros/alloc_copy added usize::MAX for the same buffer,
        // permanently inflating `allocated_bytes_atomic`.
        let bytes = buffer.len().saturating_mul(std::mem::size_of::<T>());
        self.allocated_bytes_atomic
            .fetch_sub(bytes, Ordering::Relaxed);
        drop(buffer);
    }

    /// Bytes currently live through the direct (non-cache) alloc paths.
    #[inline]
    pub fn memory_allocated(&self) -> usize {
        self.allocated_bytes_atomic.load(Ordering::Relaxed)
    }

    /// High-water mark of `memory_allocated` since creation or the last
    /// `reset_peak_stats`.
    #[inline]
    pub fn max_memory_allocated(&self) -> usize {
        self.peak_bytes_atomic.load(Ordering::Relaxed)
    }

    /// Total bytes the cache has obtained from the driver (0 if the state
    /// lock is poisoned).
    pub fn memory_reserved(&self) -> usize {
        self.state.lock().map(|s| s.reserved_bytes).unwrap_or(0)
    }

    /// Resets the peak counter to the current allocation level.
    pub fn reset_peak_stats(&self) {
        let current = self.allocated_bytes_atomic.load(Ordering::Relaxed);
        self.peak_bytes_atomic.store(current, Ordering::Relaxed);
    }

    /// Forgets all cached free blocks and clamps `reserved` down to the
    /// bytes still allocated. This is bookkeeping only — no driver calls are
    /// made here, so the backing device memory is presumably released
    /// elsewhere (confirm against the callers of this method).
    pub fn empty_cache(&self) {
        let Ok(mut state) = self.state.lock() else {
            return;
        };
        state.small_pool.clear();
        state.large_pool.clear();
        state.reserved_bytes = state.allocated_bytes;
    }

    /// The device this allocator serves.
    #[inline]
    pub fn device(&self) -> &GpuDevice {
        &self.device
    }

    /// Records that `stream` has used the block at `block_idx`; a non-empty
    /// use set keeps the block from being coalesced (see
    /// `AllocatorState::try_merge`). Out-of-range indices are ignored.
    pub fn record_stream_on_block(&self, block_idx: usize, stream: StreamId) {
        let Ok(mut state) = self.state.lock() else {
            return;
        };
        if block_idx < state.blocks.len() {
            state.blocks[block_idx].stream_uses.insert(stream);
        }
    }

    /// Total blocks ever created (the arena is append-only, so this never
    /// decreases).
    pub fn block_count(&self) -> usize {
        self.state.lock().map(|s| s.blocks.len()).unwrap_or(0)
    }

    /// Number of free entries across both pools.
    pub fn free_block_count(&self) -> usize {
        self.state
            .lock()
            .map(|s| s.small_pool.len() + s.large_pool.len())
            .unwrap_or(0)
    }

    /// `(hits, misses)` recorded by `cache_find` / `cache_insert`.
    pub fn cache_stats(&self) -> (usize, usize) {
        self.state
            .lock()
            .map(|s| (s.hits, s.misses))
            .unwrap_or((0, 0))
    }

    /// Bytes reserved from the driver but not currently handed out.
    pub fn cached_bytes(&self) -> usize {
        self.state.lock().map(|s| s.cached_bytes()).unwrap_or(0)
    }

    /// Best-fit cache lookup for a `size`-byte request on `stream`.
    ///
    /// On a hit the block is removed from its pool, split down to the
    /// rounded request size when profitable, marked allocated and charged to
    /// the byte counters. Returns `(block_index, charged_bytes)`, or `None`
    /// on a miss or a poisoned lock.
    pub fn cache_find(&self, size: usize, stream: StreamId) -> Option<(usize, usize)> {
        let rounded = round_size(size);
        let is_small = rounded <= SMALL_SIZE;
        let Ok(mut state) = self.state.lock() else {
            return None;
        };
        let block_idx = {
            let pool = state.get_pool_mut(is_small);
            pool.find_free_block(stream, rounded)?
        };
        // Remove the pool entry before mutating the block: its key embeds
        // the size, which `split_block` is about to change.
        let key = BlockKey::from_block(&state.blocks[block_idx]);
        state.get_pool_mut(is_small).remove_key(block_idx, key);
        if state.should_split(block_idx, rounded) {
            state.split_block(block_idx, rounded);
        }
        state.blocks[block_idx].allocated = true;
        let actual_size = state.blocks[block_idx].size;
        state.allocated_bytes += actual_size;
        if state.allocated_bytes > state.peak_bytes {
            state.peak_bytes = state.allocated_bytes;
        }
        state.hits += 1;
        Some((block_idx, actual_size))
    }

    /// Registers a fresh driver allocation (`driver_alloc_size` bytes at
    /// `ptr`) after a cache miss: the block enters the arena, is split down
    /// to the rounded request when profitable, and is charged to the
    /// counters. Returns `(block_index, charged_bytes)`; on a poisoned lock
    /// returns `(0, driver_alloc_size)` without recording anything.
    pub fn cache_insert(
        &self,
        requested_size: usize,
        driver_alloc_size: usize,
        ptr: usize,
        stream: StreamId,
    ) -> (usize, usize) {
        let rounded = round_size(requested_size);
        let is_small = rounded <= SMALL_SIZE;
        let Ok(mut state) = self.state.lock() else {
            return (0, driver_alloc_size);
        };
        let mut block = Block::new(
            self.device.ordinal(),
            driver_alloc_size,
            ptr,
            stream,
            is_small,
        );
        block.allocated = true;
        let block_idx = state.add_block(block);
        state.reserved_bytes += driver_alloc_size;
        if state.should_split(block_idx, rounded) {
            state.split_block(block_idx, rounded);
        }
        let actual_size = state.blocks[block_idx].size;
        state.allocated_bytes += actual_size;
        if state.allocated_bytes > state.peak_bytes {
            state.peak_bytes = state.allocated_bytes;
        }
        state.misses += 1;
        (block_idx, actual_size)
    }

    /// Returns the block at `block_idx` to the cache. Indices that are out
    /// of range or already free are ignored.
    pub fn cache_free(&self, block_idx: usize) {
        let Ok(mut state) = self.state.lock() else {
            return;
        };
        if block_idx < state.blocks.len() && state.blocks[block_idx].allocated {
            state.free_block(block_idx);
        }
    }

    /// How many bytes to request from the driver for a user request of
    /// `size` bytes: round to the block granularity, then apply the
    /// size-class buffer policy.
    pub fn driver_alloc_size(size: usize) -> usize {
        get_allocation_size(round_size(size))
    }
}
impl std::fmt::Debug for CudaAllocator {
    /// Summarizes the allocator's counters. Field names and order are kept
    /// stable so existing log consumers keep working.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let allocated = self.allocated_bytes_atomic.load(Ordering::Relaxed);
        let peak = self.peak_bytes_atomic.load(Ordering::Relaxed);
        f.debug_struct("CudaAllocator")
            .field("device_ordinal", &self.device.ordinal())
            .field("allocated_bytes", &allocated)
            .field("peak_bytes", &peak)
            .field("cached_bytes", &self.cached_bytes())
            .finish()
    }
}
/// CPU-only fallback: without the `cuda` feature every allocation fails with
/// `NoCudaFeature`.
#[cfg(not(feature = "cuda"))]
impl CudaAllocator {
    /// Always fails: the crate was built without CUDA support.
    pub fn alloc_zeros<T>(&self, _count: usize) -> GpuResult<CudaBuffer<T>> {
        Err(crate::error::GpuError::NoCudaFeature)
    }
    /// Always fails: the crate was built without CUDA support.
    pub fn alloc_copy<T>(&self, _data: &[T]) -> GpuResult<CudaBuffer<T>> {
        Err(crate::error::GpuError::NoCudaFeature)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Every size below the minimum clamps to MIN_BLOCK_SIZE.
    #[test]
    fn round_size_minimum() {
        assert_eq!(round_size(0), MIN_BLOCK_SIZE);
        assert_eq!(round_size(1), MIN_BLOCK_SIZE);
        assert_eq!(round_size(511), MIN_BLOCK_SIZE);
        assert_eq!(round_size(512), MIN_BLOCK_SIZE);
    }

    // Sizes above the minimum round up to the next 512-byte multiple.
    #[test]
    fn round_size_multiples() {
        assert_eq!(round_size(513), 1024);
        assert_eq!(round_size(1024), 1024);
        assert_eq!(round_size(1025), 1536);
    }

    // Small-class requests are backed by SMALL_BUFFER driver allocations.
    #[test]
    fn alloc_size_small() {
        assert_eq!(get_allocation_size(512), SMALL_BUFFER);
        assert_eq!(get_allocation_size(SMALL_SIZE), SMALL_BUFFER);
    }

    // Mid-sized requests are backed by LARGE_BUFFER driver allocations.
    #[test]
    fn alloc_size_mid() {
        assert_eq!(get_allocation_size(SMALL_SIZE + 1), LARGE_BUFFER);
        assert_eq!(get_allocation_size(MIN_LARGE_ALLOC - 1), LARGE_BUFFER);
    }

    // Requests of MIN_LARGE_ALLOC or more round up to ROUND_LARGE multiples.
    #[test]
    fn alloc_size_large() {
        assert_eq!(get_allocation_size(MIN_LARGE_ALLOC), MIN_LARGE_ALLOC);
        assert_eq!(
            get_allocation_size(MIN_LARGE_ALLOC + 1),
            MIN_LARGE_ALLOC + ROUND_LARGE
        );
    }

    // Arbitrary stream id shared by the pool tests below.
    fn make_stream() -> StreamId {
        StreamId(42)
    }

    // A pooled block is found again for any request it can satisfy.
    #[test]
    fn block_pool_insert_find() {
        let mut state = AllocatorState::new();
        let stream = make_stream();
        let block = Block::new(0, 4096, 0x1000, stream, true);
        let idx = state.add_block(block);
        state.small_pool.insert(idx, &state.blocks[idx]);
        let found = state.small_pool.find_free_block(stream, 512);
        assert_eq!(found, Some(idx));
    }

    // Lookups only match blocks freed on the same stream.
    #[test]
    fn block_pool_respects_stream() {
        let mut state = AllocatorState::new();
        let stream_a = StreamId(1);
        let stream_b = StreamId(2);
        let block = Block::new(0, 4096, 0x1000, stream_a, true);
        let idx = state.add_block(block);
        state.small_pool.insert(idx, &state.blocks[idx]);
        assert!(state.small_pool.find_free_block(stream_b, 512).is_none());
        assert_eq!(state.small_pool.find_free_block(stream_a, 512), Some(idx));
    }

    // Best-fit: the 1024-byte block wins over the 4096-byte one for a
    // 768-byte request because pool keys are ordered by size within a stream.
    #[test]
    fn block_pool_finds_smallest_fit() {
        let mut state = AllocatorState::new();
        let stream = make_stream();
        let b1 = Block::new(0, 4096, 0x1000, stream, true);
        let i1 = state.add_block(b1);
        state.small_pool.insert(i1, &state.blocks[i1]);
        let b2 = Block::new(0, 1024, 0x2000, stream, true);
        let i2 = state.add_block(b2);
        state.small_pool.insert(i2, &state.blocks[i2]);
        let found = state.small_pool.find_free_block(stream, 768);
        assert_eq!(found, Some(i2));
    }

    // Splitting keeps the head in place and pools the linked remainder.
    #[test]
    fn split_block_creates_remainder() {
        let mut state = AllocatorState::new();
        let stream = make_stream();
        let block = Block::new(0, 8192, 0x1000, stream, true);
        let idx = state.add_block(block);
        state.split_block(idx, 1024);
        assert_eq!(state.blocks[idx].size, 1024);
        let rem_idx = state.blocks[idx].next.unwrap();
        assert_eq!(state.blocks[rem_idx].size, 8192 - 1024);
        assert_eq!(state.blocks[rem_idx].ptr, 0x1000 + 1024);
        assert_eq!(state.blocks[rem_idx].prev, Some(idx));
        let found = state.small_pool.find_free_block(stream, 1024);
        assert_eq!(found, Some(rem_idx));
    }

    // Freeing the middle block absorbs both free neighbors: b grows to cover
    // a + b + c and takes a's base pointer.
    #[test]
    fn coalesce_merges_adjacent_blocks() {
        let mut state = AllocatorState::new();
        let stream = make_stream();
        let a = Block::new(0, 2048, 0x1000, stream, true);
        let a_idx = state.add_block(a);
        let b = Block::new(0, 2048, 0x1000 + 2048, stream, true);
        let b_idx = state.add_block(b);
        let c = Block::new(0, 4096, 0x1000 + 4096, stream, true);
        let c_idx = state.add_block(c);
        state.blocks[a_idx].next = Some(b_idx);
        state.blocks[b_idx].prev = Some(a_idx);
        state.blocks[b_idx].next = Some(c_idx);
        state.blocks[c_idx].prev = Some(b_idx);
        state.blocks[b_idx].allocated = true;
        state.blocks[b_idx].size = 2048;
        state.allocated_bytes = 2048;
        state.small_pool.insert(a_idx, &state.blocks[a_idx]);
        state.small_pool.insert(c_idx, &state.blocks[c_idx]);
        state.free_block(b_idx);
        assert_eq!(state.blocks[b_idx].size, 2048 + 2048 + 4096);
        assert_eq!(state.blocks[b_idx].ptr, 0x1000);
        assert!(!state.blocks[b_idx].allocated);
    }

    // Small pool splits when the remainder is >= MIN_BLOCK_SIZE (1024 left
    // splits; 248 left does not).
    #[test]
    fn should_split_small_pool() {
        let mut state = AllocatorState::new();
        let stream = make_stream();
        let block = Block::new(0, 2048, 0x1000, stream, true);
        let idx = state.add_block(block);
        assert!(state.should_split(idx, 1024));
        assert!(!state.should_split(idx, 1800));
    }

    // Large pool splits only when the remainder exceeds SMALL_SIZE (2 MiB
    // left splits; 0.5 MiB left does not).
    #[test]
    fn should_split_large_pool() {
        let mut state = AllocatorState::new();
        let stream = make_stream();
        let block = Block::new(0, 4 * 1024 * 1024, 0x1000, stream, false);
        let idx = state.add_block(block);
        assert!(state.should_split(idx, 2 * 1024 * 1024));
        assert!(!state.should_split(idx, 3 * 1024 * 1024 + 512 * 1024));
    }

    // stream_uses starts empty and records cross-stream usage.
    #[test]
    fn stream_uses_prevent_reuse() {
        let stream = make_stream();
        let mut block = Block::new(0, 4096, 0x1000, stream, true);
        assert!(block.stream_uses.is_empty());
        block.stream_uses.insert(StreamId(99));
        assert!(!block.stream_uses.is_empty());
    }

    // A neighbor with pending cross-stream uses is never coalesced.
    #[test]
    fn stream_uses_prevent_merge() {
        let mut state = AllocatorState::new();
        let stream = make_stream();
        let a = Block::new(0, 2048, 0x1000, stream, true);
        let a_idx = state.add_block(a);
        let mut b = Block::new(0, 2048, 0x1000 + 2048, stream, true);
        b.stream_uses.insert(StreamId(99));
        let b_idx = state.add_block(b);
        state.blocks[a_idx].next = Some(b_idx);
        state.blocks[b_idx].prev = Some(a_idx);
        state.small_pool.insert(b_idx, &state.blocks[b_idx]);
        let merged = state.try_merge(a_idx, Some(b_idx));
        assert_eq!(merged, 0);
        assert_eq!(state.blocks[a_idx].size, 2048);
    }

    // Miss -> insert -> free -> hit cycle updates the hit/miss counters.
    // Skipped silently when no CUDA device is available.
    #[test]
    fn cache_find_and_insert_roundtrip() {
        let device = Arc::new(match GpuDevice::new(0) {
            Ok(d) => d,
            Err(_) => return,
        });
        let alloc = CudaAllocator::new(device);
        let stream = StreamId(1);
        let (idx, actual) = alloc.cache_insert(2048, 4096, 0x1000, stream);
        assert!(actual <= 4096);
        assert_eq!(alloc.cache_stats().1, 1);
        alloc.cache_free(idx);
        let found = alloc.cache_find(512, stream);
        assert!(found.is_some());
        assert_eq!(alloc.cache_stats().0, 1);
    }

    // empty_cache drops every pooled entry. Skipped without a CUDA device.
    #[test]
    fn empty_cache_clears_pools() {
        let device = Arc::new(match GpuDevice::new(0) {
            Ok(d) => d,
            Err(_) => return,
        });
        let alloc = CudaAllocator::new(device);
        let stream = StreamId(1);
        alloc.cache_insert(1024, 4096, 0x1000, stream);
        {
            let state = alloc.state.lock().unwrap();
            assert!(!state.blocks.is_empty());
        }
        alloc.cache_free(0);
        assert!(alloc.free_block_count() > 0);
        alloc.empty_cache();
        assert_eq!(alloc.free_block_count(), 0);
    }

    // Tests that exercise the real driver paths; only built with `cuda`.
    #[cfg(feature = "cuda")]
    mod cuda_tests {
        use super::*;

        // Panics (failing the test) if device 0 is unavailable.
        fn make_allocator() -> CudaAllocator {
            let device = GpuDevice::new(0).expect("CUDA device 0");
            CudaAllocator::new(Arc::new(device))
        }

        #[test]
        fn new_allocator_starts_at_zero() {
            let alloc = make_allocator();
            assert_eq!(alloc.memory_allocated(), 0);
            assert_eq!(alloc.max_memory_allocated(), 0);
        }

        #[test]
        fn empty_cache_is_harmless() {
            let alloc = make_allocator();
            alloc.empty_cache();
        }

        #[test]
        fn debug_impl() {
            let alloc = make_allocator();
            let s = format!("{alloc:?}");
            assert!(s.contains("CudaAllocator"));
            assert!(s.contains("allocated_bytes"));
        }

        #[test]
        fn alloc_increases_allocated_bytes() {
            let alloc = make_allocator();
            let buf = alloc.alloc_zeros::<f32>(256).expect("alloc_zeros");
            assert_eq!(alloc.memory_allocated(), 256 * std::mem::size_of::<f32>());
            assert_eq!(
                alloc.max_memory_allocated(),
                256 * std::mem::size_of::<f32>()
            );
            alloc.free(buf);
        }

        #[test]
        fn free_decreases_allocated_bytes() {
            let alloc = make_allocator();
            let buf = alloc.alloc_zeros::<f32>(128).expect("alloc_zeros");
            let expected = 128 * std::mem::size_of::<f32>();
            assert_eq!(alloc.memory_allocated(), expected);
            alloc.free(buf);
            assert_eq!(alloc.memory_allocated(), 0);
        }

        // Peak stays at the high-water mark even after frees.
        #[test]
        fn peak_tracks_maximum() {
            let alloc = make_allocator();
            let buf1 = alloc.alloc_zeros::<f32>(100).expect("alloc 1");
            let buf2 = alloc.alloc_zeros::<f32>(200).expect("alloc 2");
            let peak_after_two = alloc.max_memory_allocated();
            alloc.free(buf1);
            assert_eq!(alloc.max_memory_allocated(), peak_after_two);
            assert!(alloc.memory_allocated() < peak_after_two);
            alloc.free(buf2);
            assert_eq!(alloc.memory_allocated(), 0);
            assert_eq!(alloc.max_memory_allocated(), peak_after_two);
        }

        // reset_peak_stats snaps the peak down to the current level.
        #[test]
        fn reset_peak_stats_lowers_peak() {
            let alloc = make_allocator();
            let buf = alloc.alloc_zeros::<f32>(512).expect("alloc");
            let high = alloc.max_memory_allocated();
            alloc.free(buf);
            assert_eq!(alloc.max_memory_allocated(), high);
            alloc.reset_peak_stats();
            assert_eq!(alloc.max_memory_allocated(), 0);
        }

        #[test]
        fn alloc_copy_tracks_bytes() {
            let alloc = make_allocator();
            let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
            let buf = alloc.alloc_copy(&data).expect("alloc_copy");
            assert_eq!(alloc.memory_allocated(), 4 * std::mem::size_of::<f64>());
            alloc.free(buf);
            assert_eq!(alloc.memory_allocated(), 0);
        }

        // A zero-element allocation must not disturb the counters.
        #[test]
        fn zero_element_alloc() {
            let alloc = make_allocator();
            let buf = alloc.alloc_zeros::<f32>(0).expect("alloc_zeros empty");
            assert_eq!(alloc.memory_allocated(), 0);
            assert_eq!(buf.len(), 0);
            assert!(buf.is_empty());
            alloc.free(buf);
            assert_eq!(alloc.memory_allocated(), 0);
        }
    }
}