use std::sync::Arc;
use std::sync::Mutex;
use std::sync::OnceLock;
use xlog_core::{Result, XlogError};
use super::direct::DirectCudaResource;
use super::resource::{
Access, AllocTag, BlockId, DeviceBlock, DeviceMemoryResource, ResourceResult, StreamId,
};
use super::stream_pool::StreamPool;
use crate::CudaDevice;
/// Capacity of the per-process runtime table.
///
/// [`XlogDeviceRuntime::try_get`] rejects any ordinal `>=` this value.
pub const MAX_DEVICE_ORDINALS: usize = 16;

/// A lazily-filled table slot; once set it holds a leaked runtime that lives
/// for the rest of the process.
type RuntimeSlot = OnceLock<&'static XlogDeviceRuntime>;

/// One lock per ordinal, held while the matching `RUNTIMES` slot is being
/// constructed for the first time.
static INIT_LOCKS: [Mutex<()>; MAX_DEVICE_ORDINALS] =
    [const { Mutex::new(()) }; MAX_DEVICE_ORDINALS];

/// Process-wide runtimes, indexed by device ordinal.
static RUNTIMES: [RuntimeSlot; MAX_DEVICE_ORDINALS] =
    [const { RuntimeSlot::new() }; MAX_DEVICE_ORDINALS];
/// Per-device runtime bundling a CUDA device handle, its stream pool, and the
/// memory resource through which this runtime's allocations are routed.
///
/// Instances are either process-wide singletons handed out by
/// [`XlogDeviceRuntime::try_get`] (leaked, `&'static`) or ordinary owned
/// values built with [`XlogDeviceRuntime::with_resource`].
pub struct XlogDeviceRuntime {
    /// Ordinal this runtime was created for. `with_resource` stores it as
    /// given without validating it.
    device_ordinal: u32,
    /// Shared handle to the CUDA device.
    device: Arc<CudaDevice>,
    /// Stream pool associated with `device`.
    stream_pool: Arc<StreamPool>,
    /// Memory backend. It sits behind a mutex because the runtime's methods
    /// take `&self`; presumably the backend's operations need `&mut self`
    /// (TODO confirm against the `DeviceMemoryResource` trait). Every backend
    /// call takes this lock.
    /// NOTE(review): fields drop in declaration order, so `resource` is
    /// currently dropped last. Review before reordering.
    resource: Mutex<Box<dyn DeviceMemoryResource + Send + Sync>>,
}
impl XlogDeviceRuntime {
    /// Builds a runtime around a caller-supplied memory resource.
    ///
    /// The result is an ordinary owned value. It is *not* registered in the
    /// per-ordinal singleton table used by [`XlogDeviceRuntime::try_get`], so
    /// it never aliases a singleton. `device_ordinal` is stored as given and
    /// is not checked against [`MAX_DEVICE_ORDINALS`] or against `device`.
    /// NOTE(review): callers presumably pass the ordinal `device` was opened
    /// with.
    pub fn with_resource(
        device: Arc<CudaDevice>,
        device_ordinal: u32,
        stream_pool: Arc<StreamPool>,
        resource: Box<dyn DeviceMemoryResource + Send + Sync>,
    ) -> Self {
        Self {
            device_ordinal,
            device,
            stream_pool,
            resource: Mutex::new(resource),
        }
    }

    /// Returns the process-wide runtime for device `ordinal`, creating it on
    /// first use.
    ///
    /// First-time construction opens the device, builds a default
    /// [`StreamPool`] and a [`DirectCudaResource`] backend, then leaks the
    /// runtime so it can be handed out as `&'static`. It is never dropped.
    /// Construction is serialised per ordinal (double-checked: a lock-free
    /// read, then a re-check under the per-ordinal init lock), so the device
    /// is opened at most once even when threads race.
    ///
    /// # Errors
    /// * `ordinal >= MAX_DEVICE_ORDINALS`.
    /// * The device cannot be opened. Failures are not cached, so a later
    ///   call retries.
    ///
    /// # Panics
    /// Only if the singleton slot is found already filled after the locked
    /// re-check, which would indicate a bug.
    pub fn try_get(ordinal: u32) -> Result<&'static XlogDeviceRuntime> {
        let idx = ordinal as usize;
        if idx >= MAX_DEVICE_ORDINALS {
            return Err(XlogError::Kernel(format!(
                "XlogDeviceRuntime: ordinal {} exceeds MAX_DEVICE_ORDINALS={}",
                ordinal, MAX_DEVICE_ORDINALS
            )));
        }

        // Fast path: already initialised, no locking needed.
        if let Some(rt) = RUNTIMES[idx].get() {
            return Ok(*rt);
        }

        // The init lock protects `()`, and `OnceLock::set` publishes
        // atomically. A panic by an earlier initialiser therefore cannot
        // leave a half-built slot behind. Recover the guard and retry rather
        // than making every future caller for this ordinal panic.
        let _guard = INIT_LOCKS[idx]
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // Another thread may have finished while we waited for the lock.
        if let Some(rt) = RUNTIMES[idx].get() {
            return Ok(*rt);
        }

        let device = Arc::new(CudaDevice::new(ordinal as usize).map_err(|e| {
            XlogError::Kernel(format!(
                "XlogDeviceRuntime: failed to open device {}: {}",
                ordinal, e
            ))
        })?);
        let stream_pool = Arc::new(StreamPool::with_defaults(Arc::clone(&device)));
        let resource: Box<dyn DeviceMemoryResource + Send + Sync> =
            Box::new(DirectCudaResource::new(Arc::clone(&device), ordinal));

        // Leaking gives the process-lifetime `&'static` the table stores.
        let leaked: &'static XlogDeviceRuntime = Box::leak(Box::new(Self::with_resource(
            device,
            ordinal,
            stream_pool,
            resource,
        )));
        RUNTIMES[idx]
            .set(leaked)
            .map_err(|_| ())
            .expect("XlogDeviceRuntime: OnceLock::set raced under INIT_LOCKS — bug");
        Ok(leaked)
    }

    /// Ordinal this runtime was created for.
    pub fn device_ordinal(&self) -> u32 {
        self.device_ordinal
    }

    /// Shared handle to the underlying CUDA device.
    pub fn device(&self) -> &Arc<CudaDevice> {
        &self.device
    }

    /// Stream pool associated with this runtime.
    pub fn stream_pool(&self) -> &Arc<StreamPool> {
        &self.stream_pool
    }

    /// Locks the memory resource for exactly one backend call.
    ///
    /// The returned guard is a temporary at every call site, so the lock is
    /// released at the end of that statement. All backend calls on this
    /// runtime are serialised through this one mutex.
    ///
    /// # Panics
    /// Panics if an earlier backend call panicked while holding the lock.
    /// The resource's bookkeeping may then be inconsistent, so it is
    /// deliberately not reused.
    fn resource_guard(
        &self,
    ) -> std::sync::MutexGuard<'_, Box<dyn DeviceMemoryResource + Send + Sync>> {
        self.resource
            .lock()
            .expect("device-runtime resource poisoned")
    }

    /// Allocates `bytes` bytes on `stream` through the memory resource,
    /// labelled with `tag`.
    pub fn allocate(
        &self,
        bytes: usize,
        stream: StreamId,
        tag: AllocTag,
    ) -> ResourceResult<DeviceBlock> {
        self.resource_guard().allocate(bytes, stream, tag)
    }

    /// Hands `block` back to the memory resource.
    ///
    /// NOTE(review): the backend may defer the actual release. Callers in
    /// this module run [`Self::reap_pending`] before expecting
    /// [`Self::bytes_outstanding`] to drop.
    pub fn deallocate(&self, block: DeviceBlock) -> ResourceResult<()> {
        self.resource_guard().deallocate(block)
    }

    /// Bytes the memory resource currently reports as outstanding.
    pub fn bytes_outstanding(&self) -> usize {
        self.resource_guard().bytes_outstanding()
    }

    /// Forwards to the resource's `reap_pending`. Presumably this completes
    /// deferred deallocations (TODO confirm per backend).
    pub fn reap_pending(&self) -> ResourceResult<()> {
        self.resource_guard().reap_pending()
    }

    /// Records that `block` is used on `use_stream`.
    ///
    /// # Errors
    /// Backend-defined. The `DirectCudaResource` backend installed by
    /// [`Self::try_get`] is expected to reject this with
    /// `ResourceError::StreamMisuse`.
    pub fn record_block_use(
        &self,
        block: &DeviceBlock,
        use_stream: StreamId,
    ) -> ResourceResult<()> {
        self.resource_guard().record_block_use(block, use_stream)
    }

    /// Whether the memory resource reports support for per-block use
    /// tracking.
    pub fn supports_block_use_tracking(&self) -> bool {
        self.resource_guard().supports_block_use_tracking()
    }

    /// Forwards a "about to use `block` on `use_stream` with `access`"
    /// notification to the memory resource. Semantics are backend-defined.
    pub fn prepare_block_use(
        &self,
        block: BlockId,
        use_stream: StreamId,
        access: Access,
    ) -> ResourceResult<()> {
        self.resource_guard()
            .prepare_block_use(block, use_stream, access)
    }

    /// Forwards a "finished using `block` on `use_stream` with `access`"
    /// notification to the memory resource. Semantics are backend-defined.
    pub fn finish_block_use(
        &self,
        block: BlockId,
        use_stream: StreamId,
        access: Access,
    ) -> ResourceResult<()> {
        self.resource_guard()
            .finish_block_use(block, use_stream, access)
    }

    /// [`Self::prepare_block_use`] for a tracked slice. It resolves the
    /// slice's runtime block and forwards.
    ///
    /// # Errors
    /// `ResourceError::StreamMisuse` if the slice is not runtime-backed
    /// (`runtime_block()` is `None`), plus any error from the resource.
    pub fn prepare_first_use<T: cudarc::driver::DeviceRepr>(
        &self,
        slice: &crate::memory::TrackedCudaSlice<T>,
        use_stream: StreamId,
        access: Access,
    ) -> ResourceResult<()> {
        let block = slice.runtime_block().ok_or_else(|| {
            super::resource::ResourceError::StreamMisuse(
                "prepare_first_use: slice is not runtime-backed (the helper's \
                 GpuMemoryManager must be built via with_runtime)"
                    .to_string(),
            )
        })?;
        self.prepare_block_use(BlockId::from_block(block), use_stream, access)
    }

    /// [`Self::finish_block_use`] for a tracked slice. It resolves the
    /// slice's runtime block and forwards.
    ///
    /// # Errors
    /// `ResourceError::StreamMisuse` if the slice is not runtime-backed,
    /// plus any error from the resource.
    pub fn finish_first_use<T: cudarc::driver::DeviceRepr>(
        &self,
        slice: &crate::memory::TrackedCudaSlice<T>,
        use_stream: StreamId,
        access: Access,
    ) -> ResourceResult<()> {
        let block = slice.runtime_block().ok_or_else(|| {
            super::resource::ResourceError::StreamMisuse(
                "finish_first_use: slice is not runtime-backed".to_string(),
            )
        })?;
        self.finish_block_use(BlockId::from_block(block), use_stream, access)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises tests that allocate through the shared ordinal-0 singleton.
    ///
    /// The default test harness runs tests on parallel threads. Those tests
    /// assert on `bytes_outstanding()` deltas, so an interleaved allocation
    /// from a sibling test would make them flaky.
    /// NOTE(review): tests in other modules that use the same singleton are
    /// not covered by this lock.
    static SINGLETON_ALLOC_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Takes [`SINGLETON_ALLOC_LOCK`]. Poisoning is ignored so that one
    /// failing test does not cascade into failures of the others.
    fn serialize_singleton_allocs() -> std::sync::MutexGuard<'static, ()> {
        SINGLETON_ALLOC_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns the ordinal-0 singleton, or `None` when it cannot be created
    /// (e.g. no usable CUDA device). Tests then return early.
    fn try_runtime() -> Option<&'static XlogDeviceRuntime> {
        XlogDeviceRuntime::try_get(0).ok()
    }

    #[test]
    fn try_get_returns_same_singleton() {
        let Some(a) = try_runtime() else {
            return;
        };
        let b = XlogDeviceRuntime::try_get(0).expect("re-get");
        assert!(std::ptr::eq(a, b), "singleton must be stable for ordinal 0");
        assert_eq!(a.device_ordinal(), 0);
    }

    #[test]
    fn allocate_then_deallocate_via_runtime() {
        let Some(rt) = try_runtime() else {
            return;
        };
        let _serial = serialize_singleton_allocs();
        let before = rt.bytes_outstanding();
        let block = rt
            .allocate(2048, StreamId::DEFAULT, AllocTag::UNTAGGED)
            .expect("alloc");
        assert_eq!(block.bytes, 2048);
        assert_eq!(rt.bytes_outstanding(), before + 2048);
        rt.deallocate(block).expect("dealloc");
        rt.reap_pending().expect("reap pending");
        assert_eq!(rt.bytes_outstanding(), before);
    }

    #[test]
    fn try_get_rejects_out_of_range_ordinal() {
        let err = XlogDeviceRuntime::try_get(MAX_DEVICE_ORDINALS as u32);
        assert!(err.is_err());
    }

    #[test]
    fn with_resource_composes_owned_runtime_outside_singleton() {
        use super::super::async_resource::AsyncCudaResource;
        let Some(rt) = try_runtime() else {
            return;
        };
        let device = Arc::clone(rt.device());
        let pool = Arc::new(StreamPool::with_defaults(Arc::clone(&device)));
        let resource = Box::new(AsyncCudaResource::new(
            Arc::clone(&device),
            0,
            Arc::clone(&pool),
        ));
        let owned = XlogDeviceRuntime::with_resource(device, 0, pool, resource);
        assert_eq!(owned.device_ordinal(), 0);
        let block = owned
            .allocate(1024, StreamId::DEFAULT, AllocTag::UNTAGGED)
            .expect("alloc through composed runtime");
        assert_eq!(block.bytes, 1024);
        assert_eq!(owned.bytes_outstanding(), 1024);
        owned.deallocate(block).expect("dealloc");
        owned.reap_pending().expect("reap");
        assert_eq!(owned.bytes_outstanding(), 0);
        let singleton = XlogDeviceRuntime::try_get(0).expect("singleton");
        assert!(
            !std::ptr::eq(&owned, singleton),
            "with_resource must not alias the singleton slot"
        );
    }

    #[test]
    fn try_get_runtime_record_block_use_rejected_with_stream_misuse() {
        let Some(rt) = try_runtime() else {
            return;
        };
        let _serial = serialize_singleton_allocs();
        let block = rt
            .allocate(64, StreamId::DEFAULT, AllocTag::UNTAGGED)
            .expect("alloc through runtime");
        let err = rt.record_block_use(&block, StreamId::DEFAULT);
        match err {
            Err(super::super::resource::ResourceError::StreamMisuse(msg)) => {
                assert!(
                    msg.contains("unsupported"),
                    "expected 'unsupported' in StreamMisuse message, got {:?}",
                    msg
                );
            }
            other => panic!(
                "XlogDeviceRuntime::try_get default (DirectCudaResource) must \
                 reject record_block_use with StreamMisuse; got {:?}",
                other
            ),
        }
        rt.deallocate(block).expect("dealloc still works");
        // Complete any deferred free so later tests start from a clean
        // `bytes_outstanding()` baseline.
        rt.reap_pending().expect("reap pending");
    }
}