use cudarc::driver::safe::CudaStream;
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::{Arc, Mutex};
use super::arena::CudaArena;
use crate::runtime::Allocator;
/// Maximum number of pointers kept in a single size bucket of the free list.
/// `deallocate` evicts the oldest entry of a bucket once it already holds this
/// many pointers.
const FREE_LIST_CAP: usize = 64;
/// Built-in total byte budget for the free list (1 GiB). It is the value handed
/// to the `NUMR_CUDA_FREE_LIST_CAP_MB` lookup in `resolve_free_list_cap_bytes`.
const DEFAULT_FREE_LIST_CAP_BYTES: usize = 1024 * 1024 * 1024;
/// Resolves the total byte budget of the free list.
///
/// Delegates to `env_config::env_mib_to_bytes` with the variable name
/// `NUMR_CUDA_FREE_LIST_CAP_MB` and `DEFAULT_FREE_LIST_CAP_BYTES` as its second
/// argument, then narrows the `u64` result to `usize`.
/// NOTE(review): presumably the second argument is the fallback used when the
/// variable is unset or unparsable — confirm against `env_mib_to_bytes`.
fn resolve_free_list_cap_bytes() -> usize {
    let default_bytes = DEFAULT_FREE_LIST_CAP_BYTES as u64;
    let resolved_bytes =
        super::env_config::env_mib_to_bytes("NUMR_CUDA_FREE_LIST_CAP_MB", default_bytes);
    resolved_bytes as usize
}
/// Cache of released pointers (stored as raw `u64` values), bucketed by their
/// exact size in bytes.
#[derive(Default)]
struct FreeList {
    /// One FIFO queue of cached pointers per exact byte size.
    map: HashMap<u64, VecDeque<u64>>,
    /// Sum of the sizes of all pointers currently held in `map`.
    total_bytes: usize,
}

impl FreeList {
    /// Removes and returns the oldest cached pointer of exactly `size_bytes`,
    /// or `None` when that bucket is missing or empty.
    fn pop(&mut self, size_bytes: u64) -> Option<u64> {
        let bucket = self.map.get_mut(&size_bytes)?;
        let ptr = bucket.pop_front()?;
        self.total_bytes -= size_bytes as usize;
        Some(ptr)
    }

    /// Appends `ptr` to the bucket for `size_bytes` and accounts for its size.
    fn push(&mut self, size_bytes: u64, ptr: u64) {
        let bucket = self.map.entry(size_bytes).or_default();
        bucket.push_back(ptr);
        self.total_bytes += size_bytes as usize;
    }

    /// Shrinks the cache until `total_bytes <= cap` or no pointer is left.
    ///
    /// Each step removes the oldest pointer of the largest non-empty size
    /// bucket. The removed pointers are returned in eviction order so the
    /// caller can release them.
    fn evict_to_cap(&mut self, cap: usize) -> Vec<u64> {
        let mut released = Vec::new();
        loop {
            if self.total_bytes <= cap {
                break;
            }
            // Largest size that still has at least one cached pointer.
            let biggest = self
                .map
                .iter()
                .filter_map(|(size, bucket)| if bucket.is_empty() { None } else { Some(*size) })
                .max();
            match biggest.and_then(|size| self.pop(size)) {
                Some(ptr) => released.push(ptr),
                None => break,
            }
        }
        released
    }
}
/// Device-memory allocator that issues `cuMemAllocAsync` / `cuMemFreeAsync` on a
/// single stream and keeps a size-bucketed cache of released pointers.
///
/// Cloning is shallow for the `Arc` fields: clones share the free list, the
/// frozen flag, the captured-pointer set and the arena slot.
#[derive(Clone)]
pub struct CudaAllocator {
/// Stream on which every driver allocation and free is issued.
stream: Arc<CudaStream>,
/// Cache of pointers released through `deallocate` while not frozen.
free_list: Arc<Mutex<FreeList>>,
/// Total byte budget passed to `FreeList::evict_to_cap` after each cached free.
free_list_cap_bytes: usize,
/// Set by `freeze`, cleared by `unfreeze`; while set, `allocate`/`deallocate`
/// bypass the free list.
frozen: Arc<std::sync::atomic::AtomicBool>,
/// Every pointer returned by `allocate` while frozen; emptied by `unfreeze`.
captured_ptrs: Arc<Mutex<HashSet<u64>>>,
/// Raw memory-pool handle, cast to `CUmemoryPool` for trimming; `0` disables
/// the trim step.
pool_handle: u64,
/// Optional arena installed via `install_arena`; when present it serves
/// allocations made while frozen.
arena: Arc<Mutex<Option<CudaArena>>>,
}
impl CudaAllocator {
pub(super) fn new(stream: Arc<CudaStream>, pool_handle: u64) -> Self {
Self {
stream,
free_list: Arc::new(Mutex::new(FreeList::default())),
free_list_cap_bytes: resolve_free_list_cap_bytes(),
frozen: Arc::new(std::sync::atomic::AtomicBool::new(false)),
captured_ptrs: Arc::new(Mutex::new(HashSet::new())),
pool_handle,
arena: Arc::new(Mutex::new(None)),
}
}
unsafe fn driver_alloc(&self, size_bytes: usize) -> crate::error::Result<u64> {
let mut ptr: u64 = 0;
let result = unsafe {
cudarc::driver::sys::cuMemAllocAsync(&mut ptr, size_bytes, self.stream.cu_stream())
};
if result == cudarc::driver::sys::CUresult::CUDA_SUCCESS {
return Ok(ptr);
}
let drained: Vec<u64> = {
let mut fl = self.free_list.lock().unwrap();
fl.total_bytes = 0;
fl.map
.drain()
.flat_map(|(_, bucket)| bucket.into_iter())
.collect()
};
for p in drained {
let _ = unsafe { cudarc::driver::sys::cuMemFreeAsync(p, self.stream.cu_stream()) };
}
let _ = self.stream.synchronize();
if self.pool_handle != 0 {
let pool = self.pool_handle as cudarc::driver::sys::CUmemoryPool;
let _ = unsafe { cudarc::driver::sys::cuMemPoolTrimTo(pool, 0) };
}
let result = unsafe {
cudarc::driver::sys::cuMemAllocAsync(&mut ptr, size_bytes, self.stream.cu_stream())
};
if result == cudarc::driver::sys::CUresult::CUDA_SUCCESS {
Ok(ptr)
} else {
Err(crate::error::Error::OutOfMemory { size: size_bytes })
}
}
unsafe fn driver_free(&self, ptr: u64) {
let _ = unsafe { cudarc::driver::sys::cuMemFreeAsync(ptr, self.stream.cu_stream()) };
}
pub fn install_arena(&self, base: u64, size: usize) -> crate::error::Result<()> {
let mut guard = self.arena.lock().unwrap_or_else(|p| p.into_inner());
if guard.is_some() {
return Err(crate::error::Error::Internal(
"CudaAllocator::install_arena: an arena is already installed; \
graph capture is not re-entrant on a single client"
.into(),
));
}
*guard = Some(CudaArena::new(base, size));
Ok(())
}
pub fn clear_arena(&self) {
let mut guard = self.arena.lock().unwrap_or_else(|p| p.into_inner());
*guard = None;
}
pub fn has_arena(&self) -> bool {
let guard = self.arena.lock().unwrap_or_else(|p| p.into_inner());
guard.is_some()
}
}
impl Allocator for CudaAllocator {
    /// Returns a pointer to `size_bytes` of memory.
    ///
    /// * `size_bytes == 0` yields the null pointer `0` without touching the
    ///   driver.
    /// * While frozen, the request is served by the installed arena, or by the
    ///   driver when no arena is installed, and the resulting pointer is
    ///   recorded in `captured_ptrs`. The free list is bypassed.
    /// * Otherwise a cached pointer of exactly this size is reused when
    ///   available, falling back to the driver.
    ///
    /// # Errors
    ///
    /// Propagates the arena's or the driver path's allocation error.
    fn allocate(&self, size_bytes: usize) -> crate::error::Result<u64> {
        if size_bytes == 0 {
            return Ok(0);
        }
        if self.frozen.load(std::sync::atomic::Ordering::Relaxed) {
            let ptr = {
                let mut arena_guard = self.arena.lock().unwrap_or_else(|p| p.into_inner());
                if let Some(ref mut arena) = *arena_guard {
                    arena.allocate(size_bytes)?
                } else {
                    // Do not hold the arena lock across the driver call.
                    drop(arena_guard);
                    // SAFETY: NOTE(review): relies on `driver_alloc`'s contract
                    // holding for this allocator's stream — confirm.
                    unsafe { self.driver_alloc(size_bytes) }?
                }
            };
            self.captured_ptrs
                .lock()
                .unwrap_or_else(|p| p.into_inner())
                .insert(ptr);
            return Ok(ptr);
        }

        // The guard is a temporary, so the lock is released before the driver
        // path below, which may lock the free list again.
        let cached = self
            .free_list
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .pop(size_bytes as u64);
        if let Some(ptr) = cached {
            return Ok(ptr);
        }
        // SAFETY: NOTE(review): relies on `driver_alloc`'s contract — confirm.
        unsafe { self.driver_alloc(size_bytes) }
    }

    /// Releases `ptr`, which was allocated with `size_bytes`.
    ///
    /// * A null `ptr` is ignored.
    /// * While frozen, the pointer goes back to the installed arena, or is
    ///   freed through the driver when no arena is installed.
    /// * Otherwise it is cached under `size_bytes` for reuse by `allocate`.
    ///   A bucket already holding `FREE_LIST_CAP` pointers gives up its oldest
    ///   entry, and the cache as a whole is trimmed to `free_list_cap_bytes`.
    ///   Evicted pointers are freed through the driver once the lock is gone.
    fn deallocate(&self, ptr: u64, size_bytes: usize) {
        if ptr == 0 {
            return;
        }
        if self.frozen.load(std::sync::atomic::Ordering::Relaxed) {
            let mut arena_guard = self.arena.lock().unwrap_or_else(|p| p.into_inner());
            if let Some(ref mut arena) = *arena_guard {
                arena.deallocate(ptr);
            } else {
                drop(arena_guard);
                // SAFETY: NOTE(review): assumes `ptr` came from this allocator.
                unsafe { self.driver_free(ptr) };
            }
            return;
        }

        let mut evict: Vec<u64> = Vec::new();
        {
            let mut fl = self.free_list.lock().unwrap_or_else(|p| p.into_inner());
            let size = size_bytes as u64;
            match fl.map.get_mut(&size) {
                Some(bucket) if bucket.len() >= FREE_LIST_CAP => {
                    // Same-size swap: `total_bytes` is unchanged.
                    if let Some(old) = bucket.pop_front() {
                        evict.push(old);
                    }
                    bucket.push_back(ptr);
                }
                _ => fl.push(size, ptr),
            }
            evict.extend(fl.evict_to_cap(self.free_list_cap_bytes));
        }
        for old_ptr in evict {
            // SAFETY: `old_ptr` was removed from the free list above, so it can
            // no longer be handed out by `allocate`.
            unsafe { self.driver_free(old_ptr) };
        }
    }

    /// Reports whether the allocator is currently frozen.
    fn is_frozen(&self) -> bool {
        self.frozen.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Marks the allocator as frozen; always returns `true`.
    fn freeze(&self) -> bool {
        self.frozen
            .store(true, std::sync::atomic::Ordering::Relaxed);
        true
    }

    /// Ends the freeze window: clears the frozen flag, drops the arena and
    /// empties `captured_ptrs`.
    ///
    /// In debug builds it also logs how many pointers were captured and checks
    /// that none of them sits in the free list.
    ///
    /// # Panics
    ///
    /// Debug builds only: panics if a pointer captured during the freeze
    /// window is found in the free list. The free-list lock is released before
    /// panicking so the mutex is not poisoned for later calls.
    fn unfreeze(&self) {
        self.frozen
            .store(false, std::sync::atomic::Ordering::Relaxed);
        self.clear_arena();
        let captured: HashSet<u64> = {
            let mut set = self.captured_ptrs.lock().unwrap_or_else(|p| p.into_inner());
            std::mem::take(&mut *set)
        };
        if captured.is_empty() {
            return;
        }
        #[cfg(debug_assertions)]
        eprintln!(
            "[numr::cuda] unfreeze: {} pointer(s) from the freeze window were \
             still live at unfreeze (will be freed by the driver on next graph \
             launch — expected for graph-internal scratch).",
            captured.len()
        );
        #[cfg(debug_assertions)]
        {
            // Search under the lock, but report only after it is released.
            let leaked: Option<u64> = {
                let fl = self.free_list.lock().unwrap_or_else(|p| p.into_inner());
                let found = fl
                    .map
                    .values()
                    .flatten()
                    .copied()
                    .find(|cached_ptr| captured.contains(cached_ptr));
                found
            };
            if let Some(cached_ptr) = leaked {
                panic!(
                    "[numr::cuda] GRAPH CORRUPTION DETECTED: pointer 0x{:x} was \
                     allocated during a freeze window (graph-capture) but was \
                     subsequently absorbed into the Rust free list via the \
                     un-frozen deallocate() path. On next allocation this \
                     address would be handed to a non-graph caller while the \
                     CUDA graph still holds a reference to it.",
                    cached_ptr
                );
            }
        }
    }

    /// Empties the free list and frees every cached pointer through the
    /// driver. Always returns `Ok(())`.
    fn reset(&self) -> crate::error::Result<()> {
        let drained: Vec<u64> = {
            let mut fl = self.free_list.lock().unwrap_or_else(|p| p.into_inner());
            fl.total_bytes = 0;
            fl.map
                .drain()
                .flat_map(|(_, bucket)| bucket.into_iter())
                .collect()
        };
        for ptr in drained {
            // SAFETY: `ptr` was just removed from the free list, so it cannot
            // be handed out again.
            unsafe { self.driver_free(ptr) };
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::super::client::CudaClient;
    use super::super::device::CudaDevice;
    use crate::runtime::Allocator;

    /// Plants a pointer allocated while frozen into the free list and checks
    /// that `unfreeze()` panics in debug builds, or simply empties
    /// `captured_ptrs` in release builds.
    #[cfg(feature = "cuda")]
    #[test]
    #[ignore = "requires a live CUDA GPU"]
    fn captured_ptrs_unfreeze_detects_corruption() {
        let client = CudaClient::new_uncached(CudaDevice { index: 0 })
            .expect("CudaClient creation requires a CUDA GPU");
        let allocator = &client.allocator;

        // Two ordinary allocations made before the freeze window opens.
        let _before_a = allocator.allocate(256).expect("alloc p1");
        let _before_b = allocator.allocate(512).expect("alloc p2");

        allocator.freeze();
        assert!(allocator.is_frozen(), "allocator should be frozen");

        let p3 = allocator.allocate(128).expect("alloc p3 during freeze");
        assert_ne!(p3, 0, "frozen alloc must return a non-null pointer");
        assert!(
            allocator.captured_ptrs.lock().unwrap().contains(&p3),
            "p3 must be present in captured_ptrs after frozen allocate"
        );

        // Simulate the bug: the captured pointer ends up in the free list.
        allocator.free_list.lock().unwrap().push(128, p3);

        #[cfg(debug_assertions)]
        {
            let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                allocator.unfreeze();
            }));
            assert!(
                outcome.is_err(),
                "unfreeze() must panic when a captured pointer leaks into the free list"
            );
        }
        #[cfg(not(debug_assertions))]
        {
            allocator.unfreeze();
            assert!(
                allocator.captured_ptrs.lock().unwrap().is_empty(),
                "captured_ptrs must be empty after unfreeze (release build)"
            );
        }
    }

    /// A freeze window with no corruption leaves `captured_ptrs` empty and the
    /// allocator unfrozen after `unfreeze()`.
    #[cfg(feature = "cuda")]
    #[test]
    #[ignore = "requires a live CUDA GPU"]
    fn captured_ptrs_cleared_after_clean_unfreeze() {
        let client = CudaClient::new_uncached(CudaDevice { index: 0 })
            .expect("CudaClient creation requires a CUDA GPU");
        let allocator = &client.allocator;

        allocator.freeze();
        let frozen_ptr = allocator.allocate(64).expect("alloc during freeze");
        assert_ne!(frozen_ptr, 0);
        assert!(
            allocator.captured_ptrs.lock().unwrap().contains(&frozen_ptr),
            "p must be in captured_ptrs"
        );

        allocator.unfreeze();
        assert!(
            allocator.captured_ptrs.lock().unwrap().is_empty(),
            "captured_ptrs must be empty after unfreeze"
        );
        assert!(!allocator.is_frozen(), "allocator must be unfrozen");
    }
}