use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use cudarc::driver::{CudaEvent, CudaSlice};
use super::resource::{
Access, AllocTag, BlockId, BlockState, DeviceBlock, DeviceMemoryResource, Generation,
ResourceError, ResourceResult, StreamId,
};
use super::stream_pool::StreamPool;
use crate::CudaDevice;
/// Bookkeeping kept for every allocation that is currently live.
///
/// Entries are stored in `AsyncCudaResource::live`, keyed by the raw device pointer.
struct LiveEntry {
/// Owning handle for the device memory; it is moved out and dropped by `deallocate`.
slice: CudaSlice<u8>,
/// Generation stamped at allocation time and compared against the caller's handle
/// to reject stale blocks.
generation: Generation,
/// Stream and event of the most recent write. `allocate` seeds this with an event
/// recorded right after the allocation; `finish_block_use` replaces it on writes.
last_write: Option<(StreamId, CudaEvent)>,
/// Stream and event of every read recorded since the last write
/// (cleared by `finish_block_use` when a write is recorded).
outstanding_reads: Vec<(StreamId, CudaEvent)>,
}
/// Device memory resource that allocates on streams resolved from a [`StreamPool`]
/// and tracks, per block, the events needed to order use across streams.
///
/// Byte accounting is split in two: "live" bytes (allocated and not yet passed to
/// `deallocate`) and "pending" bytes (deallocated but not yet confirmed by
/// `reap_pending`).
pub struct AsyncCudaResource {
/// Device handle; within this file it is only exposed again through `device()`.
device: Arc<CudaDevice>,
/// Ordinal stamped onto every block and checked on every block-taking call.
device_ordinal: u32,
/// Resolves `StreamId`s to the driver streams used for allocation, waits and events.
stream_pool: Arc<StreamPool>,
/// Live allocations keyed by raw device pointer.
live: Mutex<HashMap<u64, LiveEntry>>,
/// Running total of bytes held by live blocks.
live_bytes: AtomicUsize,
/// Running total of bytes deallocated but not yet reaped.
pending_bytes: AtomicUsize,
/// Breakdown of pending bytes by the stream the block was allocated on.
pending_per_stream: Mutex<HashMap<StreamId, usize>>,
}
impl AsyncCudaResource {
    /// Builds a resource with no live allocations and no pending frees.
    ///
    /// `device_ordinal` is the value later stamped onto, and checked against, every
    /// block handled by this resource.
    pub fn new(device: Arc<CudaDevice>, device_ordinal: u32, stream_pool: Arc<StreamPool>) -> Self {
        let live = Mutex::new(HashMap::new());
        let pending_per_stream = Mutex::new(HashMap::new());
        Self {
            device,
            device_ordinal,
            stream_pool,
            live,
            live_bytes: AtomicUsize::new(0),
            pending_bytes: AtomicUsize::new(0),
            pending_per_stream,
        }
    }

    /// Returns the device handle this resource was constructed with.
    pub fn device(&self) -> &Arc<CudaDevice> {
        &self.device
    }

    /// Returns the stream pool used to resolve `StreamId`s.
    pub fn stream_pool(&self) -> &Arc<StreamPool> {
        &self.stream_pool
    }

    /// Bytes currently held by live (not yet deallocated) blocks.
    pub fn live_bytes(&self) -> usize {
        self.live_bytes.load(Ordering::Relaxed)
    }

    /// Bytes that were deallocated but have not been reaped yet.
    pub fn pending_free_bytes(&self) -> usize {
        self.pending_bytes.load(Ordering::Relaxed)
    }

    /// Sum of the per-stream pending byte counts.
    ///
    /// # Panics
    /// Panics if the per-stream mutex is poisoned.
    pub fn pending_per_stream_total(&self) -> usize {
        self.pending_per_stream
            .lock()
            .expect("AsyncCudaResource pending_per_stream poisoned")
            .values()
            .sum()
    }

    /// Number of events currently tracked for the live block at `ptr`: one per
    /// outstanding read plus one if a last-write event is present.
    ///
    /// Returns `None` when no live block is registered under `ptr`.
    ///
    /// # Panics
    /// Panics if the live-map mutex is poisoned.
    pub fn pending_use_event_count(&self, ptr: u64) -> Option<usize> {
        let guard = self
            .live
            .lock()
            .expect("AsyncCudaResource live map poisoned");
        let entry = guard.get(&ptr)?;
        Some(entry.outstanding_reads.len() + usize::from(entry.last_write.is_some()))
    }
}
impl DeviceMemoryResource for AsyncCudaResource {
/// Allocates `bytes` of device memory on `stream` and registers it as live.
///
/// An event is recorded on the allocation stream right after the allocation and
/// stored as the block's initial `last_write`, so other streams can order
/// themselves after the allocation via `prepare_block_use`.
///
/// # Errors
/// * `ResourceError::Driver` for a zero-byte request, a failed allocation, a
///   failed event record, or a device pointer that is already registered as live.
/// * `ResourceError::StreamMisuse` when `stream` does not resolve in the pool.
///
/// # Panics
/// Panics if the live-map mutex is poisoned.
fn allocate(
    &self,
    bytes: usize,
    stream: StreamId,
    tag: AllocTag,
) -> ResourceResult<DeviceBlock> {
    if bytes == 0 {
        return Err(ResourceError::Driver(
            "AsyncCudaResource: zero-byte allocation not supported".to_string(),
        ));
    }
    let cu_stream = self.stream_pool.resolve(stream).ok_or_else(|| {
        ResourceError::StreamMisuse(format!(
            "AsyncCudaResource: unknown StreamId({})",
            stream.0
        ))
    })?;
    // SAFETY (review): `alloc` is presumably unsafe because the returned memory is
    // uninitialized; this function never reads it and `u8` has no invalid bit
    // patterns. Confirm against the cudarc documentation.
    let slice = unsafe { cu_stream.alloc::<u8>(bytes) }
        .map_err(|e| ResourceError::Driver(format!("cuMemAllocAsync({}): {}", bytes, e)))?;
    let alloc_event = cu_stream.record_event(None).map_err(|e| {
        ResourceError::Driver(format!(
            "AsyncCudaResource::allocate: record allocation-ready event failed: {}",
            e
        ))
    })?;
    let (ptr, sync) =
        <CudaSlice<u8> as cudarc::driver::DevicePtr<u8>>::device_ptr(&slice, slice.stream());
    // Only the raw address is needed; the guard is forgotten so that its drop
    // logic does not run for this pointer lookup.
    std::mem::forget(sync);

    let mut live = self
        .live
        .lock()
        .expect("AsyncCudaResource live map poisoned");
    // One hash lookup decides between "collision" and "insert".
    let slot = match live.entry(ptr) {
        std::collections::hash_map::Entry::Occupied(_) => {
            return Err(ResourceError::Driver(format!(
                "AsyncCudaResource: pointer collision on alloc ({:#x})",
                ptr
            )));
        }
        std::collections::hash_map::Entry::Vacant(slot) => slot,
    };
    // A generation is only consumed once the pointer is known to be free.
    let generation = Generation::next();
    slot.insert(LiveEntry {
        slice,
        generation,
        last_write: Some((stream, alloc_event)),
        outstanding_reads: Vec::new(),
    });
    self.live_bytes.fetch_add(bytes, Ordering::Relaxed);
    Ok(DeviceBlock {
        ptr,
        device_ordinal: self.device_ordinal,
        alloc_stream: stream,
        bytes,
        align: std::mem::align_of::<u8>(),
        tag,
        generation,
        state: BlockState::Live,
    })
}
/// Releases `block` and moves its bytes from the live to the pending counters.
///
/// Before the entry is removed, the allocation stream is made to wait on the
/// block's last-write event and on every outstanding read event that was
/// recorded on a different stream. Events from the allocation stream itself
/// are skipped.
///
/// The slice is dropped *before* the pending counters are published. If the
/// bytes were published first, a concurrent `reap_pending` could synchronize
/// the stream and subtract them while the release had not yet been handed to
/// the driver. (Assumes dropping the slice is what queues the release on the
/// allocation stream — confirm against cudarc.)
///
/// # Errors
/// * `ResourceError::Driver` when the block belongs to another device or a
///   stream wait fails (the block then stays registered as live).
/// * `ResourceError::StreamMisuse` when the block's allocation stream no longer
///   resolves.
/// * `ResourceError::UseAfterFree` when no live entry matches the block's
///   pointer and generation.
///
/// # Panics
/// Panics if either internal mutex is poisoned.
fn deallocate(&self, block: DeviceBlock) -> ResourceResult<()> {
    if block.device_ordinal != self.device_ordinal {
        return Err(ResourceError::Driver(format!(
            "AsyncCudaResource: deallocate on wrong device (block ord {} vs resource ord {})",
            block.device_ordinal, self.device_ordinal
        )));
    }
    let alloc_stream = self
        .stream_pool
        .resolve(block.alloc_stream)
        .ok_or_else(|| {
            ResourceError::StreamMisuse(format!(
                "AsyncCudaResource::deallocate: alloc_stream StreamId({}) does not resolve",
                block.alloc_stream.0
            ))
        })?;
    let (slice, last_write, outstanding_reads) = {
        let mut live = self
            .live
            .lock()
            .expect("AsyncCudaResource live map poisoned");
        match live.get(&block.ptr) {
            Some(entry) if entry.generation == block.generation => {
                if let Some((write_stream, event)) = &entry.last_write {
                    if *write_stream != block.alloc_stream {
                        alloc_stream.wait(event).map_err(|e| {
                            ResourceError::Driver(format!(
                                "AsyncCudaResource::deallocate: cuStreamWaitEvent on \
                                 last_write failed: {}",
                                e
                            ))
                        })?;
                    }
                }
                for (read_stream, event) in &entry.outstanding_reads {
                    if *read_stream != block.alloc_stream {
                        alloc_stream.wait(event).map_err(|e| {
                            ResourceError::Driver(format!(
                                "AsyncCudaResource::deallocate: cuStreamWaitEvent on \
                                 outstanding read failed: {}",
                                e
                            ))
                        })?;
                    }
                }
            }
            _ => {
                return Err(ResourceError::UseAfterFree {
                    generation: block.generation,
                });
            }
        }
        // All waits are queued; only now is the entry taken out of the map.
        let LiveEntry {
            slice,
            last_write,
            outstanding_reads,
            ..
        } = live
            .remove(&block.ptr)
            .expect("present under lock per get above");
        (slice, last_write, outstanding_reads)
    };
    self.live_bytes.fetch_sub(block.bytes, Ordering::Relaxed);
    // Hand the memory back first, then make the bytes visible to reapers.
    drop(slice);
    drop(last_write);
    drop(outstanding_reads);
    {
        let mut per_stream = self
            .pending_per_stream
            .lock()
            .expect("AsyncCudaResource pending_per_stream poisoned");
        *per_stream.entry(block.alloc_stream).or_insert(0) += block.bytes;
        self.pending_bytes.fetch_add(block.bytes, Ordering::Relaxed);
    }
    Ok(())
}
/// Returns the device ordinal this resource was constructed with.
fn device_ordinal(&self) -> u32 {
self.device_ordinal
}
/// Total bytes not yet confirmed as released: live bytes plus pending bytes.
///
/// The two counters are read separately, so the sum is not a single atomic snapshot.
fn bytes_outstanding(&self) -> usize {
    let live = self.live_bytes.load(Ordering::Relaxed);
    let pending = self.pending_bytes.load(Ordering::Relaxed);
    live + pending
}
fn reap_pending(&self) -> ResourceResult<()> {
self.reap_pending_with(|stream_id| match self.stream_pool.resolve(stream_id) {
Some(stream) => stream.synchronize().map_err(|e| {
ResourceError::Driver(format!(
"AsyncCudaResource::reap_pending: stream sync failed: {}",
e
))
}),
None => Ok(()),
})
}
/// Always `true`: this resource implements `prepare_block_use` / `finish_block_use`.
fn supports_block_use_tracking(&self) -> bool {
true
}
/// Records a use of `block` on `use_stream`, treating it as a read.
///
/// Equivalent to calling `finish_block_use` with `Access::Read` on the block's id.
fn record_block_use(&self, block: &DeviceBlock, use_stream: StreamId) -> ResourceResult<()> {
    let id = BlockId::from_block(block);
    self.finish_block_use(id, use_stream, Access::Read)
}
/// Makes `use_stream` wait on earlier work on `block` that conflicts with `access`.
///
/// * Any read or write waits on the last write, unless it was on `use_stream`.
/// * A write also waits on every outstanding read from another stream.
///
/// The live map stays locked while the waits are issued.
///
/// # Errors
/// * `ResourceError::Driver` on a device mismatch or a failed stream wait.
/// * `ResourceError::StreamMisuse` when `use_stream` does not resolve.
/// * `ResourceError::UseAfterFree` when no live entry matches the block's
///   pointer and generation.
///
/// # Panics
/// Panics if the live-map mutex is poisoned.
fn prepare_block_use(
    &self,
    block: BlockId,
    use_stream: StreamId,
    access: Access,
) -> ResourceResult<()> {
    if block.device_ordinal != self.device_ordinal {
        return Err(ResourceError::Driver(format!(
            "AsyncCudaResource::prepare_block_use: block device {} != resource device {}",
            block.device_ordinal, self.device_ordinal
        )));
    }
    let Some(use_cu_stream) = self.stream_pool.resolve(use_stream) else {
        return Err(ResourceError::StreamMisuse(format!(
            "AsyncCudaResource::prepare_block_use: unknown StreamId({})",
            use_stream.0
        )));
    };
    let live = self
        .live
        .lock()
        .expect("AsyncCudaResource live map poisoned");
    let entry = live
        .get(&block.ptr)
        .filter(|candidate| candidate.generation == block.generation)
        .ok_or_else(|| ResourceError::UseAfterFree {
            generation: block.generation,
        })?;

    // Every kind of access must observe the most recent write.
    if access.reads() || access.writes() {
        let foreign_write = entry
            .last_write
            .as_ref()
            .filter(|(writer, _)| *writer != use_stream);
        if let Some((_, event)) = foreign_write {
            use_cu_stream.wait(event).map_err(|e| {
                ResourceError::Driver(format!(
                    "AsyncCudaResource::prepare_block_use: wait on last_write failed: {}",
                    e
                ))
            })?;
        }
    }

    // Only a writer has to wait for readers on other streams to finish.
    if access.writes() {
        entry
            .outstanding_reads
            .iter()
            .filter(|(reader, _)| *reader != use_stream)
            .try_for_each(|(_, event)| {
                use_cu_stream.wait(event).map_err(|e| {
                    ResourceError::Driver(format!(
                        "AsyncCudaResource::prepare_block_use: wait on outstanding read \
                         failed: {}",
                        e
                    ))
                })
            })?;
    }
    Ok(())
}
/// Records that work touching `block` has been queued on `use_stream`.
///
/// An event is recorded on `use_stream` and attached to the block. A write
/// replaces `last_write` and clears the outstanding reads; anything else is
/// appended to the outstanding reads.
///
/// The block is validated twice: once before the event is recorded, and again
/// when the map is re-locked to store it. The lock is not held while the event
/// is recorded.
///
/// # Errors
/// * `ResourceError::Driver` on a device mismatch or a failed event record.
/// * `ResourceError::StreamMisuse` when `use_stream` does not resolve.
/// * `ResourceError::UseAfterFree` when no live entry matches the block's
///   pointer and generation at either check.
///
/// # Panics
/// Panics if the live-map mutex is poisoned.
fn finish_block_use(
    &self,
    block: BlockId,
    use_stream: StreamId,
    access: Access,
) -> ResourceResult<()> {
    if block.device_ordinal != self.device_ordinal {
        return Err(ResourceError::Driver(format!(
            "AsyncCudaResource::finish_block_use: block device {} != resource device {}",
            block.device_ordinal, self.device_ordinal
        )));
    }
    let Some(use_cu_stream) = self.stream_pool.resolve(use_stream) else {
        return Err(ResourceError::StreamMisuse(format!(
            "AsyncCudaResource::finish_block_use: unknown StreamId({})",
            use_stream.0
        )));
    };

    // First check: reject stale handles before an event is created.
    {
        let live = self
            .live
            .lock()
            .expect("AsyncCudaResource live map poisoned");
        let is_current = matches!(
            live.get(&block.ptr),
            Some(entry) if entry.generation == block.generation
        );
        if !is_current {
            return Err(ResourceError::UseAfterFree {
                generation: block.generation,
            });
        }
    }

    let event = use_cu_stream.record_event(None).map_err(|e| {
        ResourceError::Driver(format!(
            "AsyncCudaResource::finish_block_use: event record failed: {}",
            e
        ))
    })?;

    // Second check: the block may have been released while the lock was free.
    let mut live = self
        .live
        .lock()
        .expect("AsyncCudaResource live map poisoned");
    let Some(entry) = live
        .get_mut(&block.ptr)
        .filter(|candidate| candidate.generation == block.generation)
    else {
        drop(event);
        return Err(ResourceError::UseAfterFree {
            generation: block.generation,
        });
    };
    if access.writes() {
        entry.last_write = Some((use_stream, event));
        entry.outstanding_reads.clear();
    } else {
        debug_assert!(access.reads());
        entry.outstanding_reads.push((use_stream, event));
    }
    Ok(())
}
}
impl AsyncCudaResource {
    /// Drains the per-stream pending table, calling `sync_stream` once per stream.
    ///
    /// Bytes of every stream whose callback succeeds are subtracted from
    /// `pending_bytes`. On the first failure the failing stream and all streams
    /// not yet visited are put back into the per-stream table, and the error is
    /// returned after the bytes of the already-synced streams are subtracted.
    ///
    /// Streams are visited in `HashMap` iteration order, which is unspecified.
    ///
    /// # Errors
    /// Returns the first error produced by `sync_stream`.
    ///
    /// # Panics
    /// Panics if the per-stream mutex is poisoned.
    pub(crate) fn reap_pending_with<F>(&self, mut sync_stream: F) -> ResourceResult<()>
    where
        F: FnMut(StreamId) -> ResourceResult<()>,
    {
        // Take the whole table so the lock is not held while streams are synced.
        let snapshot: HashMap<StreamId, usize> = std::mem::take(
            &mut *self
                .pending_per_stream
                .lock()
                .expect("AsyncCudaResource pending_per_stream poisoned"),
        );
        if snapshot.is_empty() {
            return Ok(());
        }

        let mut remaining = snapshot.into_iter();
        let mut reclaimed: usize = 0;
        let mut outcome: ResourceResult<()> = Ok(());
        while let Some((stream_id, bytes)) = remaining.next() {
            if let Err(err) = sync_stream(stream_id) {
                // Put back the failing stream and everything not yet visited.
                let mut per_stream = self
                    .pending_per_stream
                    .lock()
                    .expect("AsyncCudaResource pending_per_stream poisoned");
                let leftovers = std::iter::once((stream_id, bytes)).chain(remaining.by_ref());
                for (id, owed) in leftovers {
                    *per_stream.entry(id).or_insert(0) += owed;
                }
                outcome = Err(err);
                break;
            }
            reclaimed = reclaimed.saturating_add(bytes);
        }

        if reclaimed > 0 {
            self.pending_bytes.fetch_sub(reclaimed, Ordering::Relaxed);
        }
        outcome
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// Creates device 0 and a default stream pool, or `None` when the device cannot
/// be created (tests then return early instead of failing).
fn try_setup() -> Option<(Arc<CudaDevice>, Arc<StreamPool>)> {
    let device = CudaDevice::new(0).ok().map(Arc::new)?;
    let pool = StreamPool::with_defaults(Arc::clone(&device));
    Some((device, Arc::new(pool)))
}
/// Follows one block through allocate, deallocate and reap on the default
/// stream, checking the live and pending counters at each stage.
#[test]
fn allocate_then_deallocate_round_trips_on_default_stream() {
    let Some((device, pool)) = try_setup() else {
        return;
    };
    let resource = AsyncCudaResource::new(device, 0, pool);

    // After allocation everything is live and nothing is pending.
    let block = resource
        .allocate(2048, StreamId::DEFAULT, AllocTag::UNTAGGED)
        .expect("alloc");
    assert_eq!(block.bytes, 2048);
    assert_eq!(block.alloc_stream, StreamId::DEFAULT);
    assert_eq!(resource.bytes_outstanding(), 2048);
    assert_eq!(resource.live_bytes(), 2048);
    assert_eq!(resource.pending_free_bytes(), 0);

    // Deallocation moves the bytes from live to pending.
    resource.deallocate(block).expect("dealloc");
    assert_eq!(resource.live_bytes(), 0);
    assert_eq!(resource.pending_free_bytes(), 2048);
    assert_eq!(resource.bytes_outstanding(), 2048);

    // Reaping clears the pending bytes.
    resource.reap_pending().expect("reap pending");
    assert_eq!(resource.bytes_outstanding(), 0);
    assert_eq!(resource.pending_free_bytes(), 0);
}
/// Allocation and release also work on a stream acquired from the pool.
#[test]
fn allocate_on_acquired_non_default_stream() {
    let Some((device, pool)) = try_setup() else {
        return;
    };
    let resource = AsyncCudaResource::new(device, 0, Arc::clone(&pool));
    let side_stream = pool.acquire().expect("acquire non-default stream");

    let block = resource
        .allocate(1024, side_stream, AllocTag("async-test"))
        .expect("alloc on non-default stream");
    assert_eq!(block.alloc_stream, side_stream);

    resource.deallocate(block).expect("dealloc");
    // The bytes stay outstanding until they are reaped.
    assert_eq!(resource.bytes_outstanding(), 1024);
    resource.reap_pending().expect("reap pending");
    assert_eq!(resource.bytes_outstanding(), 0);
}
/// Allocating on `StreamId(99)` is reported as stream misuse.
#[test]
fn allocate_unknown_stream_id_rejected() {
    let Some((device, pool)) = try_setup() else {
        return;
    };
    let resource = AsyncCudaResource::new(device, 0, pool);
    let outcome = resource.allocate(64, StreamId(99), AllocTag::UNTAGGED);
    assert!(matches!(outcome, Err(ResourceError::StreamMisuse(_))));
}
/// Deallocating a hand-built block that was never allocated by the resource
/// yields `UseAfterFree`.
#[test]
fn deallocate_unknown_block_returns_use_after_free() {
    let Some((device, pool)) = try_setup() else {
        return;
    };
    let resource = AsyncCudaResource::new(device, 0, pool);
    // Valid device ordinal and stream, but a pointer the resource never handed out.
    let never_allocated = DeviceBlock {
        ptr: 0xfeed_face,
        device_ordinal: 0,
        alloc_stream: StreamId::DEFAULT,
        bytes: 16,
        align: 1,
        tag: AllocTag::UNTAGGED,
        generation: Generation::next(),
        state: BlockState::Live,
    };
    let outcome = resource.deallocate(never_allocated);
    assert!(matches!(outcome, Err(ResourceError::UseAfterFree { .. })));
}
/// Reaping a freshly created resource succeeds and leaves nothing outstanding.
#[test]
fn reap_with_no_pending_is_noop() {
    let Some((device, pool)) = try_setup() else {
        return;
    };
    let resource = AsyncCudaResource::new(device, 0, pool);
    resource.reap_pending().expect("reap on empty");
    assert_eq!(resource.bytes_outstanding(), 0);
}
/// Test helper: injects pending bytes directly into both pending counters
/// without going through `deallocate`.
///
/// Each `(stream, bytes)` pair is added to the per-stream table and the
/// saturating sum of all `bytes` is added to `pending_bytes`.
fn install_pending(r: &AsyncCudaResource, entries: &[(StreamId, usize)]) {
    let mut total: usize = 0;
    {
        let mut per_stream = r
            .pending_per_stream
            .lock()
            .expect("AsyncCudaResource pending_per_stream poisoned");
        for &(id, bytes) in entries {
            *per_stream.entry(id).or_insert(0) += bytes;
            total = total.saturating_add(bytes);
        }
    }
    // The per-stream lock is released before the aggregate counter is bumped.
    r.pending_bytes.fetch_add(total, Ordering::Relaxed);
}
/// A failing sync must leave the un-synced bytes in both pending counters, and
/// a later successful reap must clear them.
///
/// The map is visited in unspecified order, so `StreamId(1)` may or may not be
/// synced before `StreamId(2)` fails; the expectations are derived from what
/// the callback actually saw.
#[test]
fn reap_pending_recovers_unsynced_streams_when_sync_fails() {
    let Some((device, pool)) = try_setup() else {
        return;
    };
    let r = AsyncCudaResource::new(Arc::clone(&device), 0, Arc::clone(&pool));
    install_pending(&r, &[(StreamId(1), 1024), (StreamId(2), 2048)]);
    assert_eq!(r.pending_free_bytes(), 3072);
    assert_eq!(r.pending_per_stream_total(), 3072);

    // The callback is `FnMut`, so it can record into a plain local vector.
    let mut synced: Vec<StreamId> = Vec::new();
    let result = r.reap_pending_with(|stream_id| {
        if stream_id == StreamId(2) {
            return Err(ResourceError::Driver(
                "simulated sync failure on StreamId(2)".into(),
            ));
        }
        synced.push(stream_id);
        Ok(())
    });
    assert!(matches!(result, Err(ResourceError::Driver(_))));

    let synced_bytes: usize = if synced.contains(&StreamId(1)) {
        1024
    } else {
        0
    };
    let expected_pending = 3072 - synced_bytes;
    assert_eq!(
        r.pending_free_bytes(),
        expected_pending,
        "synced={:?}; pending_bytes must reflect only un-synced bytes",
        synced
    );
    assert_eq!(
        r.pending_per_stream_total(),
        expected_pending,
        "synced={:?}; pending_per_stream_total must equal pending_free_bytes \
         (cross-counter invariant)",
        synced
    );

    // A retry with a callback that always succeeds drains what was restored.
    r.reap_pending_with(|_| Ok(())).expect("retry reap");
    assert_eq!(r.pending_free_bytes(), 0);
    assert_eq!(r.pending_per_stream_total(), 0);
}
/// When every sync succeeds, both pending counters end at zero.
#[test]
fn reap_pending_drains_normally_when_sync_always_succeeds() {
    let Some((device, pool)) = try_setup() else {
        return;
    };
    let resource = AsyncCudaResource::new(Arc::clone(&device), 0, Arc::clone(&pool));
    install_pending(&resource, &[(StreamId(1), 256), (StreamId(2), 512)]);

    resource.reap_pending_with(|_| Ok(())).expect("reap");

    assert_eq!(resource.pending_free_bytes(), 0);
    assert_eq!(resource.pending_per_stream_total(), 0);
}
}