use std::sync::Mutex;
use super::resource::{
Access, AllocTag, BlockId, DeviceBlock, DeviceMemoryResource, ResourceError, ResourceResult,
StreamId,
};
/// Mutable accounting guarded by the `state` mutex of `GlobalDeviceBudget`.
struct BudgetState {
/// Bytes currently counted against the limit. Raised before an inner
/// allocation is attempted, and lowered again when that allocation fails or
/// when the inner resource's `bytes_outstanding()` is observed to drop.
reserved: usize,
}
/// A `DeviceMemoryResource` wrapper that refuses allocations which would push
/// the tracked reservation past a fixed byte limit.
pub struct GlobalDeviceBudget {
/// The wrapped resource that performs the actual allocations and frees.
inner: Box<dyn DeviceMemoryResource + Send + Sync>,
/// Maximum number of bytes that may be reserved at any one time.
limit: usize,
/// Reservation counter; every read and write goes through this mutex.
state: Mutex<BudgetState>,
}
impl GlobalDeviceBudget {
    /// Wraps `inner` so that at most `limit` bytes may be reserved through it.
    ///
    /// The reservation counter is seeded with whatever `inner` already reports
    /// as outstanding at construction time.
    pub fn new(inner: Box<dyn DeviceMemoryResource + Send + Sync>, limit: usize) -> Self {
        // Query the inner resource before it is moved into the struct.
        let state = Mutex::new(BudgetState {
            reserved: inner.bytes_outstanding(),
        });
        Self {
            inner,
            limit,
            state,
        }
    }

    /// The configured byte limit.
    pub fn limit(&self) -> usize {
        self.limit
    }

    /// Bytes currently counted against the limit.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    pub fn reserved_bytes(&self) -> usize {
        let guard = self.state.lock().expect("GlobalDeviceBudget poisoned");
        guard.reserved
    }

    /// Bytes still available under the limit; zero when the reservation
    /// already meets or exceeds it.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    pub fn remaining(&self) -> usize {
        self.limit.saturating_sub(self.reserved_bytes())
    }
}
// Private accounting helpers shared by the `DeviceMemoryResource` methods below.
impl GlobalDeviceBudget {
    /// Tries to add `bytes` to the reservation without exceeding the limit.
    ///
    /// Returns `Ok(())` once the reservation has been recorded, or
    /// `Err(remaining)` carrying the budget that was left at the time of the
    /// check. The state lock is released before this returns.
    fn reserve_if_fits(&self, bytes: usize) -> Result<(), usize> {
        let mut state = self.state.lock().expect("GlobalDeviceBudget poisoned");
        let remaining = self.limit.saturating_sub(state.reserved);
        if bytes > remaining {
            return Err(remaining);
        }
        state.reserved = state.reserved.saturating_add(bytes);
        Ok(())
    }

    /// Performs the inner allocation for bytes that are already reserved,
    /// handing the reservation back if the inner resource reports an error.
    ///
    /// The inner call runs without the state lock held.
    fn allocate_reserved(
        &self,
        bytes: usize,
        stream: StreamId,
        tag: AllocTag,
    ) -> ResourceResult<DeviceBlock> {
        let outcome = self.inner.allocate(bytes, stream, tag);
        if outcome.is_err() {
            let mut state = self.state.lock().expect("GlobalDeviceBudget poisoned");
            state.reserved = state.reserved.saturating_sub(bytes);
        }
        outcome
    }

    /// Runs `op` with the state lock held and releases from the reservation
    /// however many bytes the inner resource's `bytes_outstanding()` dropped
    /// by while `op` ran. The result of `op` is returned unchanged.
    ///
    /// NOTE(review): `allocate_reserved` calls the inner resource without the
    /// state lock, so an allocation landing between the two
    /// `bytes_outstanding()` reads here would shrink the observed drop.
    /// Confirm whether callers can race in that way.
    fn release_freed_by(
        &self,
        op: impl FnOnce() -> ResourceResult<()>,
    ) -> ResourceResult<()> {
        let mut state = self.state.lock().expect("GlobalDeviceBudget poisoned");
        let before = self.inner.bytes_outstanding();
        let outcome = op();
        let after = self.inner.bytes_outstanding();
        // A drop of zero leaves the reservation untouched.
        state.reserved = state.reserved.saturating_sub(before.saturating_sub(after));
        outcome
    }
}

impl DeviceMemoryResource for GlobalDeviceBudget {
    /// Reserves `bytes` against the limit and forwards the allocation to the
    /// inner resource.
    ///
    /// If the request does not fit right away but is no larger than the limit
    /// itself, pending frees are reaped once and the reservation is retried.
    ///
    /// # Errors
    ///
    /// Returns `ResourceError::OutOfBudget` (with the remaining budget seen at
    /// the failing check) when the request cannot be reserved, or whatever
    /// error the inner resource returns; in the latter case the reservation
    /// is rolled back.
    fn allocate(
        &self,
        bytes: usize,
        stream: StreamId,
        tag: AllocTag,
    ) -> ResourceResult<DeviceBlock> {
        let remaining = match self.reserve_if_fits(bytes) {
            Ok(()) => return self.allocate_reserved(bytes, stream, tag),
            Err(remaining) => remaining,
        };
        // Reaping cannot help a request that is larger than the whole limit.
        if bytes > self.limit {
            return Err(ResourceError::OutOfBudget {
                requested: bytes,
                remaining,
            });
        }
        // Best effort: a failed reap is ignored and the retry below decides.
        let _ = self.reap_pending();
        match self.reserve_if_fits(bytes) {
            Ok(()) => self.allocate_reserved(bytes, stream, tag),
            Err(remaining) => Err(ResourceError::OutOfBudget {
                requested: bytes,
                remaining,
            }),
        }
    }

    /// Forwards the free to the inner resource and releases from the budget
    /// only the bytes the inner resource reports as no longer outstanding.
    fn deallocate(&self, block: DeviceBlock) -> ResourceResult<()> {
        self.release_freed_by(|| self.inner.deallocate(block))
    }

    /// Device ordinal reported by the inner resource.
    fn device_ordinal(&self) -> u32 {
        self.inner.device_ordinal()
    }

    /// Outstanding bytes as reported by the inner resource (not the
    /// reservation counter).
    fn bytes_outstanding(&self) -> usize {
        self.inner.bytes_outstanding()
    }

    /// Asks the inner resource to reap pending work and releases from the
    /// budget whatever that reduced `bytes_outstanding()` by.
    fn reap_pending(&self) -> ResourceResult<()> {
        self.release_freed_by(|| self.inner.reap_pending())
    }

    /// Pass-through to the inner resource.
    fn record_block_use(&self, block: &DeviceBlock, use_stream: StreamId) -> ResourceResult<()> {
        self.inner.record_block_use(block, use_stream)
    }

    /// Pass-through to the inner resource.
    fn supports_block_use_tracking(&self) -> bool {
        self.inner.supports_block_use_tracking()
    }

    /// Pass-through to the inner resource.
    fn prepare_block_use(
        &self,
        block: BlockId,
        use_stream: StreamId,
        access: Access,
    ) -> ResourceResult<()> {
        self.inner.prepare_block_use(block, use_stream, access)
    }

    /// Pass-through to the inner resource.
    fn finish_block_use(
        &self,
        block: BlockId,
        use_stream: StreamId,
        access: Access,
    ) -> ResourceResult<()> {
        self.inner.finish_block_use(block, use_stream, access)
    }
}
#[cfg(test)]
mod tests {
    use super::super::async_resource::AsyncCudaResource;
    use super::super::direct::DirectCudaResource;
    use super::super::resource::{BlockState, Generation};
    use super::super::stream_pool::StreamPool;
    use super::*;
    use crate::CudaDevice;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Opens device 0, or yields `None` so device-backed tests can bail out.
    fn try_device() -> Option<Arc<CudaDevice>> {
        match CudaDevice::new(0) {
            Ok(dev) => Some(Arc::new(dev)),
            Err(_) => None,
        }
    }

    /// Builds a budget of `limit` bytes over a `DirectCudaResource` on `device`.
    fn direct_budget(device: &Arc<CudaDevice>, limit: usize) -> GlobalDeviceBudget {
        let inner = Box::new(DirectCudaResource::new(Arc::clone(device), 0));
        GlobalDeviceBudget::new(inner, limit)
    }

    /// Inner resource whose `allocate` always errors; needs no device.
    struct AlwaysFailAllocResource {
        ord: u32,
        outstanding: AtomicUsize,
    }

    impl AlwaysFailAllocResource {
        fn new(ord: u32) -> Self {
            let outstanding = AtomicUsize::new(0);
            Self { ord, outstanding }
        }
    }

    impl DeviceMemoryResource for AlwaysFailAllocResource {
        fn allocate(
            &self,
            _bytes: usize,
            _stream: StreamId,
            _tag: AllocTag,
        ) -> ResourceResult<DeviceBlock> {
            Err(ResourceError::Driver("inner always fails".into()))
        }

        fn deallocate(&self, _block: DeviceBlock) -> ResourceResult<()> {
            Ok(())
        }

        fn device_ordinal(&self) -> u32 {
            self.ord
        }

        fn bytes_outstanding(&self) -> usize {
            self.outstanding.load(Ordering::Relaxed)
        }
    }

    #[test]
    fn allocate_within_limit_succeeds_and_updates_reserved() {
        let device = match try_device() {
            Some(dev) => dev,
            None => return,
        };
        const LIMIT: usize = 64 * 1024;
        const REQUEST: usize = 2048;
        let capped = direct_budget(&device, LIMIT);

        let blk = capped
            .allocate(REQUEST, StreamId::DEFAULT, AllocTag("budget-success"))
            .expect("alloc within limit");
        assert_eq!(capped.reserved_bytes(), REQUEST);
        assert_eq!(capped.remaining(), LIMIT - REQUEST);
        assert_eq!(capped.bytes_outstanding(), REQUEST);

        capped.deallocate(blk).expect("dealloc");
        assert_eq!(capped.reserved_bytes(), 0);
        assert_eq!(capped.bytes_outstanding(), 0);
    }

    #[test]
    fn allocate_at_exact_limit_succeeds_then_next_byte_rejected() {
        let device = match try_device() {
            Some(dev) => dev,
            None => return,
        };
        const LIMIT: usize = 4096;
        let capped = direct_budget(&device, LIMIT);

        let blk = capped
            .allocate(LIMIT, StreamId::DEFAULT, AllocTag::UNTAGGED)
            .expect("alloc at exact limit");
        assert_eq!(capped.reserved_bytes(), LIMIT);
        assert_eq!(capped.remaining(), 0);

        let outcome = capped.allocate(1, StreamId::DEFAULT, AllocTag::UNTAGGED);
        let rejected = matches!(
            outcome,
            Err(ResourceError::OutOfBudget {
                requested: 1,
                remaining: 0
            })
        );
        assert!(rejected, "expected OutOfBudget {{1,0}}, got {:?}", outcome);
        assert_eq!(capped.reserved_bytes(), LIMIT);

        capped.deallocate(blk).expect("dealloc");
        assert_eq!(capped.reserved_bytes(), 0);
    }

    #[test]
    fn over_limit_alloc_returns_out_of_budget_with_correct_remaining() {
        let device = match try_device() {
            Some(dev) => dev,
            None => return,
        };
        let capped = direct_budget(&device, 1024);

        let blk = capped
            .allocate(768, StreamId::DEFAULT, AllocTag::UNTAGGED)
            .expect("first alloc");
        assert_eq!(capped.remaining(), 256);

        let outcome = capped.allocate(512, StreamId::DEFAULT, AllocTag::UNTAGGED);
        let rejected = matches!(
            outcome,
            Err(ResourceError::OutOfBudget {
                requested: 512,
                remaining: 256
            })
        );
        assert!(
            rejected,
            "expected OutOfBudget {{512,256}}, got {:?}",
            outcome
        );

        capped.deallocate(blk).expect("dealloc");
    }

    #[test]
    fn failed_inner_allocation_rolls_back_reservation() {
        const LIMIT: usize = 1024 * 1024;
        let capped = GlobalDeviceBudget::new(Box::new(AlwaysFailAllocResource::new(0)), LIMIT);
        assert_eq!(capped.reserved_bytes(), 0);

        let outcome = capped.allocate(2048, StreamId::DEFAULT, AllocTag::UNTAGGED);
        assert!(matches!(outcome, Err(ResourceError::Driver(_))));
        assert_eq!(capped.reserved_bytes(), 0);
        assert_eq!(capped.remaining(), LIMIT);
    }

    #[test]
    fn deallocate_releases_budget_immediately_for_synchronous_inner() {
        let device = match try_device() {
            Some(dev) => dev,
            None => return,
        };
        const REQUEST: usize = 8 * 1024;
        let capped = direct_budget(&device, 16 * 1024);

        let blk = capped
            .allocate(REQUEST, StreamId::DEFAULT, AllocTag::UNTAGGED)
            .expect("alloc");
        assert_eq!(capped.reserved_bytes(), REQUEST);

        capped.deallocate(blk).expect("dealloc");
        assert_eq!(
            capped.reserved_bytes(),
            0,
            "synchronous inner releases budget at deallocate"
        );

        capped.reap_pending().expect("reap noop");
        assert_eq!(capped.reserved_bytes(), 0);
    }

    #[test]
    fn deallocate_holds_budget_for_async_inner_until_reap_pending() {
        let device = match try_device() {
            Some(dev) => dev,
            None => return,
        };
        let pool = Arc::new(StreamPool::with_defaults(Arc::clone(&device)));
        let async_inner = AsyncCudaResource::new(Arc::clone(&device), 0, Arc::clone(&pool));
        let capped = GlobalDeviceBudget::new(Box::new(async_inner), 32 * 1024);

        let blk = capped
            .allocate(4096, StreamId::DEFAULT, AllocTag("budget-async"))
            .expect("alloc");
        assert_eq!(capped.reserved_bytes(), 4096);

        capped.deallocate(blk).expect("dealloc");
        assert_eq!(
            capped.reserved_bytes(),
            4096,
            "async inner: budget must stay reserved until reap_pending drains pending free"
        );
        assert_eq!(capped.bytes_outstanding(), 4096);

        capped.reap_pending().expect("reap");
        assert_eq!(
            capped.reserved_bytes(),
            0,
            "async inner: reap_pending releases the pending bytes"
        );
        assert_eq!(capped.bytes_outstanding(), 0);
    }

    #[test]
    fn deallocate_unknown_block_does_not_release_budget() {
        let device = match try_device() {
            Some(dev) => dev,
            None => return,
        };
        let capped = direct_budget(&device, 16 * 1024);

        let blk = capped
            .allocate(2048, StreamId::DEFAULT, AllocTag::UNTAGGED)
            .expect("alloc");
        assert_eq!(capped.reserved_bytes(), 2048);

        // A block the budget never handed out.
        let bogus = DeviceBlock {
            ptr: 0xfeed_face,
            device_ordinal: 0,
            alloc_stream: StreamId::DEFAULT,
            bytes: 1024,
            align: 1,
            tag: AllocTag::UNTAGGED,
            generation: Generation::next(),
            state: BlockState::Live,
        };
        let outcome = capped.deallocate(bogus);
        assert!(matches!(outcome, Err(ResourceError::UseAfterFree { .. })));
        assert_eq!(
            capped.reserved_bytes(),
            2048,
            "bogus dealloc must not release budget"
        );

        capped.deallocate(blk).expect("real dealloc");
        assert_eq!(capped.reserved_bytes(), 0);
    }

    #[test]
    fn forwards_device_ordinal() {
        let capped = GlobalDeviceBudget::new(Box::new(AlwaysFailAllocResource::new(7)), 1024);
        assert_eq!(capped.device_ordinal(), 7);
    }
}