use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::GpuResult;
/// Tracks device-memory usage for buffers allocated on a single CUDA device.
///
/// The byte counters are atomics, so a shared `CudaAllocator` can be used from
/// multiple threads; they are statistics only and do not gate allocations.
pub struct CudaAllocator {
    // Device all buffers from this allocator are created on.
    device: Arc<GpuDevice>,
    // Bytes currently live: incremented on alloc, decremented on free.
    allocated_bytes: AtomicUsize,
    // High-water mark of `allocated_bytes`; lowered only by `reset_peak_stats`.
    peak_bytes: AtomicUsize,
}
impl CudaAllocator {
    /// Creates an allocator bound to `device` with both byte counters at zero.
    pub fn new(device: Arc<GpuDevice>) -> Self {
        Self {
            device,
            allocated_bytes: AtomicUsize::new(0),
            peak_bytes: AtomicUsize::new(0),
        }
    }

    /// Size in bytes of `count` elements of `T`.
    ///
    /// Uses `saturating_mul` so an (unrealistic) overflow clamps instead of
    /// wrapping, and the SAME formula is used on both the alloc and free
    /// paths so the amounts added and subtracted always cancel. Previously
    /// alloc fell back to `usize::MAX` (wrapping the counter via `fetch_add`)
    /// while free fell back to `0`, which could permanently skew accounting.
    #[inline]
    fn byte_size<T>(count: usize) -> usize {
        count.saturating_mul(std::mem::size_of::<T>())
    }

    /// Records a successful allocation of `bytes` and folds the new live
    /// total into the peak statistic.
    #[cfg(feature = "cuda")]
    fn record_alloc(&self, bytes: usize) {
        let prev = self.allocated_bytes.fetch_add(bytes, Ordering::Relaxed);
        // `saturating_add` keeps the peak update from wrapping even if the
        // live counter is near `usize::MAX`.
        self.peak_bytes
            .fetch_max(prev.saturating_add(bytes), Ordering::Relaxed);
    }

    /// Allocates a zero-initialized device buffer holding `count` elements.
    ///
    /// # Errors
    /// Propagates any driver error from the underlying stream allocation;
    /// in that case the byte counters are left untouched.
    #[cfg(feature = "cuda")]
    pub fn alloc_zeros<T>(&self, count: usize) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
    {
        let slice = self.device.stream().alloc_zeros::<T>(count)?;
        // Account only after the device allocation succeeded.
        self.record_alloc(Self::byte_size::<T>(count));
        Ok(CudaBuffer {
            data: Some(slice),
            len: count,
            device_ordinal: self.device.ordinal(),
            pool_fn: None,
        })
    }

    /// Allocates a device buffer and copies `data` into it (host-to-device).
    ///
    /// # Errors
    /// Propagates any driver error from the allocation/copy; in that case
    /// the byte counters are left untouched.
    #[cfg(feature = "cuda")]
    pub fn alloc_copy<T>(&self, data: &[T]) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr,
    {
        let slice = self.device.stream().clone_htod(data)?;
        self.record_alloc(Self::byte_size::<T>(data.len()));
        Ok(CudaBuffer {
            data: Some(slice),
            len: data.len(),
            device_ordinal: self.device.ordinal(),
            pool_fn: None,
        })
    }

    /// Consumes `buffer` and subtracts its size from the live counter.
    ///
    /// NOTE(review): `buffer.device_ordinal` is not checked against this
    /// allocator's device, so freeing a buffer obtained from a different
    /// allocator would skew the statistics — confirm callers never mix them.
    pub fn free<T>(&self, buffer: CudaBuffer<T>) {
        // Mirrors `byte_size` on the alloc path so the subtraction exactly
        // cancels the earlier addition.
        let bytes = Self::byte_size::<T>(buffer.len());
        self.allocated_bytes.fetch_sub(bytes, Ordering::Relaxed);
        drop(buffer);
    }

    /// Bytes currently allocated through this allocator.
    #[inline]
    pub fn memory_allocated(&self) -> usize {
        self.allocated_bytes.load(Ordering::Relaxed)
    }

    /// High-water mark of allocated bytes since creation or the last call
    /// to [`Self::reset_peak_stats`].
    #[inline]
    pub fn max_memory_allocated(&self) -> usize {
        self.peak_bytes.load(Ordering::Relaxed)
    }

    /// Lowers the peak statistic to the current live allocation size.
    pub fn reset_peak_stats(&self) {
        let current = self.allocated_bytes.load(Ordering::Relaxed);
        self.peak_bytes.store(current, Ordering::Relaxed);
    }

    /// No-op hook kept for API parity with caching allocators; this
    /// allocator holds no cache to release.
    pub fn empty_cache(&self) {}

    /// The device this allocator allocates on.
    #[inline]
    pub fn device(&self) -> &GpuDevice {
        &self.device
    }
}
impl std::fmt::Debug for CudaAllocator {
    /// Renders the allocator's device ordinal and both byte counters.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Snapshot the counters once so the two fields are formatted from
        // a single relaxed load each, matching the original output shape.
        let allocated = self.allocated_bytes.load(Ordering::Relaxed);
        let peak = self.peak_bytes.load(Ordering::Relaxed);
        f.debug_struct("CudaAllocator")
            .field("device_ordinal", &self.device.ordinal())
            .field("allocated_bytes", &allocated)
            .field("peak_bytes", &peak)
            .finish()
    }
}
/// Stub allocation methods used when the crate is built without the `cuda`
/// feature: every allocation attempt fails with `GpuError::NoCudaFeature`.
#[cfg(not(feature = "cuda"))]
impl CudaAllocator {
    /// Always fails — zero-initialized device allocation needs the `cuda` feature.
    pub fn alloc_zeros<T>(&self, _count: usize) -> GpuResult<CudaBuffer<T>> {
        Err(crate::error::GpuError::NoCudaFeature)
    }
    /// Always fails — host-to-device copy needs the `cuda` feature.
    pub fn alloc_copy<T>(&self, _data: &[T]) -> GpuResult<CudaBuffer<T>> {
        Err(crate::error::GpuError::NoCudaFeature)
    }
}
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;

    /// Bytes occupied by `n` elements of `T`.
    fn bytes_of<T>(n: usize) -> usize {
        n * std::mem::size_of::<T>()
    }

    /// Builds an allocator on CUDA device 0; panics if no device is present.
    fn make_allocator() -> CudaAllocator {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        CudaAllocator::new(Arc::new(dev))
    }

    #[test]
    fn new_allocator_starts_at_zero() {
        let allocator = make_allocator();
        assert_eq!(allocator.memory_allocated(), 0);
        assert_eq!(allocator.max_memory_allocated(), 0);
    }

    #[test]
    fn empty_cache_is_harmless() {
        make_allocator().empty_cache();
    }

    #[test]
    fn debug_impl() {
        let rendered = format!("{:?}", make_allocator());
        assert!(rendered.contains("CudaAllocator"));
        assert!(rendered.contains("allocated_bytes"));
    }

    #[test]
    fn alloc_increases_allocated_bytes() {
        let allocator = make_allocator();
        let buffer = allocator.alloc_zeros::<f32>(256).expect("alloc_zeros");
        assert_eq!(allocator.memory_allocated(), bytes_of::<f32>(256));
        assert_eq!(allocator.max_memory_allocated(), bytes_of::<f32>(256));
        allocator.free(buffer);
    }

    #[test]
    fn free_decreases_allocated_bytes() {
        let allocator = make_allocator();
        let buffer = allocator.alloc_zeros::<f32>(128).expect("alloc_zeros");
        assert_eq!(allocator.memory_allocated(), bytes_of::<f32>(128));
        allocator.free(buffer);
        assert_eq!(allocator.memory_allocated(), 0);
    }

    #[test]
    fn peak_tracks_maximum() {
        let allocator = make_allocator();
        let first = allocator.alloc_zeros::<f32>(100).expect("alloc 1");
        let second = allocator.alloc_zeros::<f32>(200).expect("alloc 2");
        let high_water = allocator.max_memory_allocated();
        allocator.free(first);
        // Peak is sticky: freeing must not lower it.
        assert_eq!(allocator.max_memory_allocated(), high_water);
        assert!(allocator.memory_allocated() < high_water);
        allocator.free(second);
        assert_eq!(allocator.memory_allocated(), 0);
        assert_eq!(allocator.max_memory_allocated(), high_water);
    }

    #[test]
    fn reset_peak_stats_lowers_peak() {
        let allocator = make_allocator();
        let buffer = allocator.alloc_zeros::<f32>(512).expect("alloc");
        let high_water = allocator.max_memory_allocated();
        allocator.free(buffer);
        assert_eq!(allocator.max_memory_allocated(), high_water);
        allocator.reset_peak_stats();
        // Live usage is zero after the free, so the peak collapses to zero.
        assert_eq!(allocator.max_memory_allocated(), 0);
    }

    #[test]
    fn alloc_copy_tracks_bytes() {
        let allocator = make_allocator();
        let host: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let buffer = allocator.alloc_copy(&host).expect("alloc_copy");
        assert_eq!(allocator.memory_allocated(), bytes_of::<f64>(4));
        allocator.free(buffer);
        assert_eq!(allocator.memory_allocated(), 0);
    }

    #[test]
    fn zero_element_alloc() {
        let allocator = make_allocator();
        let buffer = allocator.alloc_zeros::<f32>(0).expect("alloc_zeros empty");
        assert_eq!(allocator.memory_allocated(), 0);
        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());
        allocator.free(buffer);
        assert_eq!(allocator.memory_allocated(), 0);
    }
}