use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use cudarc::driver::CudaSlice;
use super::resource::{
AllocTag, BlockState, DeviceBlock, DeviceMemoryResource, Generation, ResourceError,
ResourceResult, StreamId,
};
use crate::CudaDevice;
/// A [`DeviceMemoryResource`] that satisfies each request with its own cudarc
/// device allocation (no pooling or caching is implemented in this type).
///
/// Every allocation handed out is kept alive in `live` until `deallocate`
/// removes and drops it; `bytes_outstanding` tracks the running total.
pub struct DirectCudaResource {
    /// Device handle through which allocations are issued.
    device: Arc<CudaDevice>,
    /// Ordinal stamped onto every block this resource returns and checked on
    /// deallocation. NOTE(review): `new` stores it as given and does not verify
    /// that it matches `device` — confirm callers keep them consistent.
    device_ordinal: u32,
    /// Live allocations keyed by raw device pointer. Owning the `CudaSlice`
    /// here keeps the memory alive until `deallocate` removes and drops it.
    live: Mutex<HashMap<u64, CudaSlice<u8>>>,
    /// Total bytes allocated and not yet deallocated through this resource.
    bytes_outstanding: AtomicUsize,
}
impl DirectCudaResource {
    /// Creates a resource that allocates on `device` and tags every block it
    /// returns with `device_ordinal`.
    ///
    /// The resource starts with no live allocations and a zero byte counter.
    /// The ordinal is recorded verbatim; it is not cross-checked against `device`.
    pub fn new(device: Arc<CudaDevice>, device_ordinal: u32) -> Self {
        let live = Mutex::new(HashMap::default());
        let bytes_outstanding = AtomicUsize::default();
        Self {
            device,
            device_ordinal,
            live,
            bytes_outstanding,
        }
    }

    /// Returns the shared device handle used for allocations.
    pub fn device(&self) -> &Arc<CudaDevice> {
        &self.device
    }
}
impl DeviceMemoryResource for DirectCudaResource {
    /// Allocates `bytes` bytes of device memory through `self.device.inner()`.
    ///
    /// The returned block records the caller-supplied `stream` as its
    /// `alloc_stream` and a fresh `Generation`. The allocation is tracked in
    /// the live map until [`DeviceMemoryResource::deallocate`] releases it.
    ///
    /// # Errors
    ///
    /// Returns [`ResourceError::Driver`] when `bytes == 0`, when the cudarc
    /// allocation fails, or when the returned device pointer is already
    /// tracked as live (in which case the fresh allocation is freed).
    fn allocate(
        &self,
        bytes: usize,
        stream: StreamId,
        tag: AllocTag,
    ) -> ResourceResult<DeviceBlock> {
        use std::collections::hash_map::Entry;

        if bytes == 0 {
            return Err(ResourceError::Driver(
                "DirectCudaResource: zero-byte allocation not supported".to_string(),
            ));
        }
        // SAFETY: `alloc` hands back uninitialized device memory. This resource
        // never reads it; it only exposes the raw pointer.
        // NOTE(review): assumes the trait contract permits returning memory with
        // unspecified contents — confirm against `DeviceMemoryResource` docs.
        let slice = unsafe {
            self.device.inner().alloc::<u8>(bytes).map_err(|e| {
                ResourceError::Driver(format!("cudarc alloc::<u8>({}): {}", bytes, e))
            })?
        };
        // The guard returned with the raw address is forgotten instead of
        // dropped, presumably to skip its drop-time stream bookkeeping.
        // Forgetting also ends its borrow of `slice` so the slice can be moved
        // into the live map below.
        let (ptr, sync) =
            <CudaSlice<u8> as cudarc::driver::DevicePtr<u8>>::device_ptr(&slice, slice.stream());
        std::mem::forget(sync);
        {
            let mut live = self.live.lock().expect("live map poisoned");
            // Single lookup: reject a pointer that is already tracked, else track it.
            match live.entry(ptr) {
                Entry::Occupied(_) => {
                    // `slice` is dropped (freed) on return, after the lock guard.
                    return Err(ResourceError::Driver(format!(
                        "DirectCudaResource: pointer collision on alloc ({:#x})",
                        ptr
                    )));
                }
                Entry::Vacant(slot) => {
                    slot.insert(slice);
                }
            }
        }
        self.bytes_outstanding.fetch_add(bytes, Ordering::Relaxed);
        Ok(DeviceBlock {
            ptr,
            device_ordinal: self.device_ordinal,
            alloc_stream: stream,
            bytes,
            // Conservatively reports byte alignment for the block.
            align: std::mem::align_of::<u8>(),
            tag,
            generation: Generation::next(),
            state: BlockState::Live,
        })
    }

    /// Releases a block previously returned by [`DeviceMemoryResource::allocate`].
    ///
    /// The live allocation is removed from the map and dropped (freeing it)
    /// after the map lock is released. The outstanding-byte counter is
    /// decremented by the size actually tracked for that pointer, so a
    /// malformed block can never skew or underflow it.
    ///
    /// # Errors
    ///
    /// * [`ResourceError::Driver`] if the block belongs to another device
    ///   ordinal, or if its `bytes` disagree with the live allocation at that
    ///   pointer (e.g. a stale block whose address was reused). In the
    ///   size-mismatch case the live allocation is left untouched.
    /// * [`ResourceError::UseAfterFree`] if no live allocation exists at
    ///   `block.ptr`.
    fn deallocate(&self, block: DeviceBlock) -> ResourceResult<()> {
        use std::collections::hash_map::Entry;

        if block.device_ordinal != self.device_ordinal {
            return Err(ResourceError::Driver(format!(
                "DirectCudaResource: deallocate on wrong device (block ord {} vs resource ord {})",
                block.device_ordinal, self.device_ordinal
            )));
        }
        let removed = {
            let mut live = self.live.lock().expect("live map poisoned");
            match live.entry(block.ptr) {
                Entry::Vacant(_) => None,
                // Refuse to free an allocation the block does not describe.
                Entry::Occupied(slot) if slot.get().len() != block.bytes => {
                    return Err(ResourceError::Driver(format!(
                        "DirectCudaResource: deallocate size mismatch at {:#x} \
                         (block {} bytes vs live {} bytes)",
                        block.ptr,
                        block.bytes,
                        slot.get().len()
                    )));
                }
                Entry::Occupied(slot) => Some(slot.remove()),
            }
        };
        let slice = removed.ok_or(ResourceError::UseAfterFree {
            generation: block.generation,
        })?;
        // Use the tracked length, not the caller's claim, to keep the counter exact.
        self.bytes_outstanding
            .fetch_sub(slice.len(), Ordering::Relaxed);
        // Dropping the slice frees the device memory, outside the map lock.
        drop(slice);
        Ok(())
    }

    /// Device ordinal this resource was constructed with.
    fn device_ordinal(&self) -> u32 {
        self.device_ordinal
    }

    /// Bytes currently allocated through this resource and not yet released.
    /// The value is a relaxed snapshot and may race with concurrent calls.
    fn bytes_outstanding(&self) -> usize {
        self.bytes_outstanding.load(Ordering::Relaxed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens CUDA device 0. When no device is available it logs a skip
    /// notice and returns `None`, so every GPU-dependent test reports that it
    /// was skipped instead of passing silently.
    fn try_device() -> Option<Arc<CudaDevice>> {
        match CudaDevice::new(0) {
            Ok(device) => Some(Arc::new(device)),
            Err(_) => {
                eprintln!("Skipping: no CUDA device");
                None
            }
        }
    }

    #[test]
    fn allocate_then_deallocate_round_trips() {
        let Some(device) = try_device() else {
            return;
        };
        let r = DirectCudaResource::new(device, 0);
        assert_eq!(r.bytes_outstanding(), 0);
        let block = r
            .allocate(4096, StreamId::DEFAULT, AllocTag::UNTAGGED)
            .expect("alloc");
        assert_eq!(block.bytes, 4096);
        assert_eq!(block.state, BlockState::Live);
        assert_eq!(r.bytes_outstanding(), 4096);
        r.deallocate(block).expect("dealloc");
        assert_eq!(r.bytes_outstanding(), 0);
    }

    #[test]
    fn zero_byte_allocate_rejects() {
        let Some(device) = try_device() else {
            return;
        };
        let r = DirectCudaResource::new(device, 0);
        let err = r.allocate(0, StreamId::DEFAULT, AllocTag::UNTAGGED);
        assert!(matches!(err, Err(ResourceError::Driver(_))));
        assert_eq!(r.bytes_outstanding(), 0);
    }

    #[test]
    fn deallocate_unknown_block_returns_use_after_free() {
        let Some(device) = try_device() else {
            return;
        };
        let r = DirectCudaResource::new(device, 0);
        let bogus = DeviceBlock {
            ptr: 0xdead_beef,
            device_ordinal: 0,
            alloc_stream: StreamId::DEFAULT,
            bytes: 16,
            align: 1,
            tag: AllocTag::UNTAGGED,
            generation: Generation::next(),
            state: BlockState::Live,
        };
        assert!(matches!(
            r.deallocate(bogus),
            Err(ResourceError::UseAfterFree { .. })
        ));
    }

    #[test]
    fn record_block_use_rejected_with_stream_misuse() {
        let Some(device) = try_device() else {
            return;
        };
        let r = DirectCudaResource::new(device, 0);
        let block = r
            .allocate(64, StreamId::DEFAULT, AllocTag::UNTAGGED)
            .expect("alloc");
        let err = r.record_block_use(&block, StreamId::DEFAULT);
        match err {
            Err(ResourceError::StreamMisuse(msg)) => {
                assert!(
                    msg.contains("unsupported"),
                    "expected 'unsupported' in StreamMisuse message, got {:?}",
                    msg
                );
            }
            other => panic!(
                "DirectCudaResource::record_block_use must return StreamMisuse \
                 to surface unsupported cross-stream tracking; got {:?}",
                other
            ),
        }
        assert_eq!(r.bytes_outstanding(), 64);
        r.deallocate(block).expect("dealloc still works");
    }
}