#[cfg(feature = "cuda")]
use std::sync::Arc;
#[cfg(feature = "cuda")]
use std::sync::atomic::{AtomicU64, Ordering};
#[cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, CudaStream, DeviceRepr, ValidAsZeroBits};
use crate::error::{GpuError, GpuResult};
/// How a CUDA stream capture interacts with work issued by other threads.
///
/// Defaults to [`CaptureMode::ThreadLocal`], which only restricts the
/// capturing thread.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum CaptureMode {
    Global,
    #[default]
    ThreadLocal,
    Relaxed,
}
#[cfg(feature = "cuda")]
impl CaptureMode {
    /// Converts this mode into the raw cudarc driver enum.
    #[inline]
    pub fn to_cuda(self) -> cudarc::driver::sys::CUstreamCaptureMode {
        use cudarc::driver::sys::CUstreamCaptureMode as Raw;
        match self {
            CaptureMode::Global => Raw::CU_STREAM_CAPTURE_MODE_GLOBAL,
            CaptureMode::ThreadLocal => Raw::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
            CaptureMode::Relaxed => Raw::CU_STREAM_CAPTURE_MODE_RELAXED,
        }
    }
}
/// Observed capture state of a CUDA stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CaptureStatus {
    /// No capture is in progress.
    None,
    /// A capture is currently recording work on the stream.
    Active,
    /// The capture was invalidated and must be restarted.
    Invalidated,
}
#[cfg(feature = "cuda")]
impl CaptureStatus {
    /// Maps the raw cudarc capture-status enum onto this crate's type.
    fn from_cuda(raw: cudarc::driver::sys::CUstreamCaptureStatus) -> Self {
        use cudarc::driver::sys::CUstreamCaptureStatus as Raw;
        match raw {
            Raw::CU_STREAM_CAPTURE_STATUS_NONE => Self::None,
            Raw::CU_STREAM_CAPTURE_STATUS_ACTIVE => Self::Active,
            Raw::CU_STREAM_CAPTURE_STATUS_INVALIDATED => Self::Invalidated,
        }
    }
}
impl CaptureStatus {
    /// True exactly when a capture is in progress.
    #[inline]
    pub fn is_capturing(&self) -> bool {
        *self == Self::Active
    }
    /// True exactly when the capture has been invalidated.
    #[inline]
    pub fn is_invalidated(&self) -> bool {
        *self == Self::Invalidated
    }
}
/// A single `T` value resident in device memory, updatable from the host.
#[cfg(feature = "cuda")]
pub struct DeviceScalar<T: DeviceRepr + ValidAsZeroBits + Copy> {
    buf: CudaSlice<T>,
    stream: Arc<CudaStream>,
}
#[cfg(feature = "cuda")]
impl<T: DeviceRepr + ValidAsZeroBits + Copy> DeviceScalar<T> {
    /// Allocates a one-element device buffer on `stream` holding `initial`.
    pub fn new(stream: &Arc<CudaStream>, initial: T) -> GpuResult<Self> {
        let stream = Arc::clone(stream);
        let buf = stream.clone_htod(&[initial])?;
        Ok(Self { buf, stream })
    }
    /// Overwrites the device value via a host-to-device copy on the owning
    /// stream.
    pub fn update(&mut self, value: T) -> GpuResult<()> {
        let staged = [value];
        self.stream.memcpy_htod(&staged, &mut self.buf)?;
        Ok(())
    }
    /// Borrows the underlying one-element device buffer.
    #[inline]
    pub fn inner(&self) -> &CudaSlice<T> {
        &self.buf
    }
}
/// A CUDA graph recorded via stream capture, ready for upload and replay.
#[cfg(feature = "cuda")]
pub struct CapturedGraph {
    // Instantiated graph handle from cudarc.
    graph: cudarc::driver::CudaGraph,
    // Pool keeping capture-time buffers alive across replays; `None` when
    // the capture was not associated with a pool.
    pool: Option<Arc<CapturePool>>,
    // Number of successful `launch` calls.
    replay_count: AtomicU64,
    // Whether `upload` has already pushed the graph to the device.
    uploaded: std::sync::atomic::AtomicBool,
}
#[cfg(feature = "cuda")]
impl CapturedGraph {
    /// Replays the captured graph and bumps the replay counter.
    pub fn launch(&self) -> GpuResult<()> {
        self.graph.launch()?;
        self.replay_count.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }
    /// Uploads the instantiated graph ahead of the first launch. Idempotent.
    ///
    /// Fix: the previous load-then-store sequence was a check-then-act race
    /// that let two concurrent callers both invoke `graph.upload()`. A
    /// `compare_exchange` now elects exactly one uploader, and the flag is
    /// rolled back on failure so a later call can retry.
    pub fn upload(&self) -> GpuResult<()> {
        if self
            .uploaded
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            // Another caller has already uploaded (or is uploading).
            return Ok(());
        }
        if let Err(e) = self.graph.upload() {
            // Roll back so the upload can be retried after a failure.
            self.uploaded.store(false, Ordering::Release);
            return Err(e.into());
        }
        Ok(())
    }
    /// Number of successful `launch` calls so far.
    #[inline]
    pub fn num_replays(&self) -> u64 {
        self.replay_count.load(Ordering::Relaxed)
    }
    /// Whether `upload` has completed successfully.
    #[inline]
    pub fn is_uploaded(&self) -> bool {
        self.uploaded.load(Ordering::Acquire)
    }
    /// Buffers retained by the associated pool (0 when there is no pool).
    pub fn pool_buffer_count(&self) -> usize {
        self.pool.as_ref().map_or(0, |p| p.buffer_count())
    }
    /// Whether a capture pool is attached to this graph.
    pub fn has_pool(&self) -> bool {
        self.pool.is_some()
    }
    /// The attached capture pool, if any.
    pub fn pool(&self) -> Option<&Arc<CapturePool>> {
        self.pool.as_ref()
    }
}
/// Begins graph capture on `stream` using the default (thread-local) mode.
#[cfg(feature = "cuda")]
pub fn begin_capture(stream: &Arc<CudaStream>) -> GpuResult<()> {
    begin_capture_with_mode(stream, CaptureMode::default())
}
/// Begins graph capture on `stream` with an explicit capture `mode`.
///
/// # Errors
/// Propagates the CUDA driver error if capture cannot start.
#[cfg(feature = "cuda")]
pub fn begin_capture_with_mode(stream: &Arc<CudaStream>, mode: CaptureMode) -> GpuResult<()> {
    stream.begin_capture(mode.to_cuda())?;
    Ok(())
}
/// Queries the current capture status of `stream`.
#[cfg(feature = "cuda")]
pub fn capture_status(stream: &Arc<CudaStream>) -> GpuResult<CaptureStatus> {
    let raw = stream.capture_status()?;
    Ok(CaptureStatus::from_cuda(raw))
}
/// Convenience wrapper: true if `stream` is currently capturing.
#[cfg(feature = "cuda")]
pub fn is_stream_capturing(stream: &Arc<CudaStream>) -> GpuResult<bool> {
    capture_status(stream).map(|status| status.is_capturing())
}
/// Ends capture on `stream` and instantiates the recorded graph.
///
/// # Errors
/// Returns [`GpuError::InvalidState`] if the driver reports a completed
/// capture but hands back no graph, or propagates any driver error.
#[cfg(feature = "cuda")]
pub fn end_capture(stream: &Arc<CudaStream>) -> GpuResult<CapturedGraph> {
    // NOTE(review): AUTO_FREE_ON_LAUNCH frees graph-owned allocations before
    // relaunch — confirm this matches the intended replay semantics.
    let flags = cudarc::driver::sys::CUgraphInstantiate_flags_enum::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH;
    // Fix: `ok_or_else` defers building the error String; the previous
    // `ok_or` allocated it on every successful capture (clippy or_fun_call).
    let graph = stream
        .end_capture(flags)?
        .ok_or_else(|| GpuError::InvalidState {
            message: "CUDA graph capture returned null".to_string(),
        })?;
    Ok(CapturedGraph {
        graph,
        pool: None,
        replay_count: AtomicU64::new(0),
        uploaded: std::sync::atomic::AtomicBool::new(false),
    })
}
/// Ends capture on `stream` and attaches `pool` to the resulting graph so
/// the pool's buffers outlive every replay.
#[cfg(feature = "cuda")]
pub fn end_capture_with_pool(
    stream: &Arc<CudaStream>,
    pool: Arc<CapturePool>,
) -> GpuResult<CapturedGraph> {
    end_capture(stream).map(|mut graph| {
        graph.pool = Some(pool);
        graph
    })
}
/// RAII guard for a stream capture: if `finish` is never called, `Drop`
/// ends the capture so an early return cannot leave the stream capturing.
#[cfg(feature = "cuda")]
pub struct GraphCaptureGuard {
    stream: Arc<CudaStream>,
    // Pool to attach to the resulting graph, if capture began with one.
    pool: Option<Arc<CapturePool>>,
    // Set to false once `finish` has taken over ending the capture.
    active: bool,
}
#[cfg(feature = "cuda")]
impl GraphCaptureGuard {
    /// Starts a capture in the default (thread-local) mode.
    pub fn begin(stream: &Arc<CudaStream>) -> GpuResult<Self> {
        Self::begin_with_mode(stream, CaptureMode::default())
    }
    /// Starts a capture in the given mode; the guard ends it on drop.
    pub fn begin_with_mode(stream: &Arc<CudaStream>, mode: CaptureMode) -> GpuResult<Self> {
        begin_capture_with_mode(stream, mode)?;
        Ok(Self {
            stream: Arc::clone(stream),
            pool: None,
            active: true,
        })
    }
    /// Starts a capture tied to `pool`; fails if the pool is sealed.
    pub fn begin_with_pool(stream: &Arc<CudaStream>, pool: Arc<CapturePool>) -> GpuResult<Self> {
        begin_capture_with_pool(&pool, stream)?;
        Ok(Self {
            stream: Arc::clone(stream),
            pool: Some(pool),
            active: true,
        })
    }
    /// Ends the capture and returns the recorded graph, transferring the
    /// pool (if any) onto it. Disarms the drop-time cleanup first.
    pub fn finish(mut self) -> GpuResult<CapturedGraph> {
        self.active = false;
        match self.pool.take() {
            Some(pool) => end_capture_with_pool(&self.stream, pool),
            None => end_capture(&self.stream),
        }
    }
    /// Queries the stream's current capture status.
    pub fn status(&self) -> GpuResult<CaptureStatus> {
        capture_status(&self.stream)
    }
}
#[cfg(feature = "cuda")]
impl Drop for GraphCaptureGuard {
    /// Best-effort cleanup: if `finish` was never called, end the capture
    /// and discard both the resulting graph and any error (Drop can't fail).
    fn drop(&mut self) {
        if self.active {
            let _ = end_capture(&self.stream);
        }
    }
}
/// Opaque identifier for a globally registered capture pool.
///
/// Handles are allocated by `graph_pool_handle` and resolved via
/// `capture_pool_for_handle`; `0` is the sentinel used by non-cuda builds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GraphPoolHandle(pub u64);
// Monotonic source of handle ids; starts at 1 so 0 can remain the no-cuda
// sentinel value.
#[cfg(feature = "cuda")]
static NEXT_POOL_HANDLE: AtomicU64 = AtomicU64::new(1);
// Global id -> pool map, created lazily on first use.
#[cfg(feature = "cuda")]
static POOL_REGISTRY: std::sync::OnceLock<
    std::sync::Mutex<std::collections::HashMap<u64, Arc<CapturePool>>>,
> = std::sync::OnceLock::new();
/// Returns the global pool registry, initializing it on first access.
#[cfg(feature = "cuda")]
fn pool_registry() -> &'static std::sync::Mutex<std::collections::HashMap<u64, Arc<CapturePool>>> {
    POOL_REGISTRY.get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new()))
}
/// Allocates a fresh handle, registers a new empty pool under it, and
/// returns the handle. Recovers from a poisoned registry lock.
#[cfg(feature = "cuda")]
pub fn graph_pool_handle() -> GraphPoolHandle {
    let id = NEXT_POOL_HANDLE.fetch_add(1, Ordering::Relaxed);
    pool_registry()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .insert(id, Arc::new(CapturePool::new()));
    GraphPoolHandle(id)
}
/// Looks up the pool registered under `handle`, if it is still registered.
#[cfg(feature = "cuda")]
pub fn capture_pool_for_handle(handle: GraphPoolHandle) -> Option<Arc<CapturePool>> {
    pool_registry()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .get(&handle.0)
        .cloned()
}
/// Unregisters `handle`, dropping the registry's reference to its pool.
/// Idempotent: releasing an unknown or already-released handle is a no-op.
/// Outstanding `Arc`s obtained earlier keep the pool alive.
#[cfg(feature = "cuda")]
pub fn release_graph_pool_handle(handle: GraphPoolHandle) {
    pool_registry()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .remove(&handle.0);
}
/// Records the GPU work issued by `f` into a graph: begins capture, runs
/// `f`, then ends capture. If `f` fails, the guard's `Drop` ends the
/// capture and the error is propagated.
#[cfg(feature = "cuda")]
pub fn make_graphed_callable<F>(
    stream: &Arc<CudaStream>,
    mode: CaptureMode,
    f: F,
) -> GpuResult<CapturedGraph>
where
    F: FnOnce() -> GpuResult<()>,
{
    let guard = GraphCaptureGuard::begin_with_mode(stream, mode)?;
    if let Err(err) = f() {
        // Explicit drop ends the now-useless capture before returning.
        drop(guard);
        return Err(err);
    }
    guard.finish()
}
/// Type-erased registry of buffers that must stay alive for as long as
/// graphs captured against this pool may be replayed.
#[cfg(feature = "cuda")]
pub struct CapturePool {
    // When true, new captures against this pool are rejected.
    sealed: std::sync::atomic::AtomicBool,
    // Buffers are boxed as `Any` so heterogeneous types can be retained.
    buffers: std::sync::Mutex<Vec<Box<dyn std::any::Any + Send + Sync + 'static>>>,
}
#[cfg(feature = "cuda")]
impl CapturePool {
    /// Creates an empty, unsealed pool.
    pub fn new() -> Self {
        Self {
            sealed: std::sync::atomic::AtomicBool::new(false),
            buffers: std::sync::Mutex::new(Vec::new()),
        }
    }
    /// Marks the pool sealed; `begin_capture_with_pool` then refuses to
    /// start new captures against it.
    pub fn seal(&self) {
        self.sealed
            .store(true, std::sync::atomic::Ordering::Release);
    }
    /// Re-opens a sealed pool for new captures.
    pub fn unseal(&self) {
        self.sealed
            .store(false, std::sync::atomic::Ordering::Release);
    }
    /// Whether the pool currently rejects new captures.
    pub fn is_capture_pool_sealed(&self) -> bool {
        self.sealed.load(std::sync::atomic::Ordering::Acquire)
    }
    /// Retains `buffer` for the pool's lifetime; returns its slot index.
    pub fn record_buffer<B>(&self, buffer: B) -> usize
    where
        B: Send + Sync + 'static,
    {
        // Recover from a poisoned lock: the Vec itself is still valid.
        let mut guard = self.buffers.lock().unwrap_or_else(|p| p.into_inner());
        let idx = guard.len();
        guard.push(Box::new(buffer));
        idx
    }
    /// Number of buffers currently retained.
    pub fn buffer_count(&self) -> usize {
        // Fix: recover from lock poison like every other accessor instead
        // of silently reporting 0, which undercounted after a panic while
        // the lock was held elsewhere.
        self.buffers
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .len()
    }
    /// Drops every retained buffer; the pool itself remains usable.
    pub fn clear_buffers(&self) {
        let mut guard = self.buffers.lock().unwrap_or_else(|p| p.into_inner());
        guard.clear();
    }
}
#[cfg(feature = "cuda")]
impl Default for CapturePool {
    /// Equivalent to [`CapturePool::new`]: empty and unsealed.
    fn default() -> Self {
        Self::new()
    }
}
/// Begins capture on `stream` after verifying that `pool` still accepts
/// new captures.
///
/// # Errors
/// Returns [`GpuError::InvalidState`] when the pool is sealed; otherwise
/// propagates any driver error from starting the capture.
#[cfg(feature = "cuda")]
pub fn begin_capture_with_pool(pool: &CapturePool, stream: &Arc<CudaStream>) -> GpuResult<()> {
    if !pool.is_capture_pool_sealed() {
        return begin_capture(stream);
    }
    Err(GpuError::InvalidState {
        message: "cannot begin graph capture: capture pool is sealed".into(),
    })
}
/// Stub pool for builds without the `cuda` feature: always empty, never
/// sealed.
#[cfg(not(feature = "cuda"))]
pub struct CapturePool;
#[cfg(not(feature = "cuda"))]
impl CapturePool {
    pub fn new() -> Self {
        Self
    }
    /// No-op: a stub pool cannot be sealed.
    pub fn seal(&self) {}
    /// No-op counterpart of `seal`.
    pub fn unseal(&self) {}
    /// Always false — stub pools never seal.
    pub fn is_capture_pool_sealed(&self) -> bool {
        false
    }
    /// Always zero — stub pools retain nothing.
    pub fn buffer_count(&self) -> usize {
        0
    }
    /// Added for API parity with the cuda build: accepts and immediately
    /// drops `_buffer`, reporting slot 0.
    pub fn record_buffer<B>(&self, _buffer: B) -> usize
    where
        B: Send + Sync + 'static,
    {
        0
    }
    /// Added for API parity with the cuda build: no-op.
    pub fn clear_buffers(&self) {}
}
#[cfg(not(feature = "cuda"))]
impl Default for CapturePool {
    fn default() -> Self {
        Self::new()
    }
}
/// Stub: graph capture requires the `cuda` feature; always fails.
#[cfg(not(feature = "cuda"))]
pub fn begin_capture_with_pool<T>(_pool: &CapturePool, _stream: &T) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub struct DeviceScalar<T: Copy> {
_phantom: std::marker::PhantomData<T>,
}
/// Stub captured graph for builds without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub struct CapturedGraph;
#[cfg(not(feature = "cuda"))]
impl CapturedGraph {
    /// Stub: always fails — there is no graph to launch.
    pub fn launch(&self) -> GpuResult<()> {
        Err(GpuError::NoCudaFeature)
    }
    /// Stub: always fails — there is no graph to upload.
    pub fn upload(&self) -> GpuResult<()> {
        Err(GpuError::NoCudaFeature)
    }
    pub fn num_replays(&self) -> u64 {
        0
    }
    pub fn is_uploaded(&self) -> bool {
        false
    }
    /// Added for parity with the cuda build: a stub graph never has a pool.
    pub fn pool_buffer_count(&self) -> usize {
        0
    }
    /// Added for parity with the cuda build: always false.
    pub fn has_pool(&self) -> bool {
        false
    }
}
// The following free functions mirror the cuda-feature API surface but
// always fail with `GpuError::NoCudaFeature`. They are generic over the
// stream type because `CudaStream` does not exist in this build.
#[cfg(not(feature = "cuda"))]
pub fn begin_capture<T>(_stream: &T) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn begin_capture_with_mode<T>(_stream: &T, _mode: CaptureMode) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn capture_status<T>(_stream: &T) -> GpuResult<CaptureStatus> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn is_stream_capturing<T>(_stream: &T) -> GpuResult<bool> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn end_capture<T>(_stream: &T) -> GpuResult<CapturedGraph> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn end_capture_with_pool<T>(
    _stream: &T,
    _pool: std::sync::Arc<CapturePool>,
) -> GpuResult<CapturedGraph> {
    Err(GpuError::NoCudaFeature)
}
/// Stub capture guard. The `Infallible` field makes this type impossible to
/// construct — every `begin*` constructor fails without the `cuda` feature,
/// so methods taking `self` are statically unreachable.
#[cfg(not(feature = "cuda"))]
pub struct GraphCaptureGuard {
    _never: core::convert::Infallible,
}
#[cfg(not(feature = "cuda"))]
impl GraphCaptureGuard {
    /// Stub: always fails without the `cuda` feature.
    pub fn begin<T>(_stream: &T) -> GpuResult<Self> {
        Err(GpuError::NoCudaFeature)
    }
    /// Stub: always fails without the `cuda` feature.
    pub fn begin_with_mode<T>(_stream: &T, _mode: CaptureMode) -> GpuResult<Self> {
        Err(GpuError::NoCudaFeature)
    }
    /// Stub: always fails without the `cuda` feature.
    pub fn begin_with_pool<T>(_stream: &T, _pool: std::sync::Arc<CapturePool>) -> GpuResult<Self> {
        Err(GpuError::NoCudaFeature)
    }
    /// Unreachable: matching on `Infallible` proves no instance can exist.
    pub fn finish(self) -> GpuResult<CapturedGraph> {
        match self._never {}
    }
    /// Unreachable: matching on `Infallible` proves no instance can exist.
    pub fn status(&self) -> GpuResult<CaptureStatus> {
        match self._never {}
    }
}
/// Stub: without the `cuda` feature every handle is the sentinel `0` and no
/// pool is ever registered.
#[cfg(not(feature = "cuda"))]
pub fn graph_pool_handle() -> GraphPoolHandle {
    GraphPoolHandle(0)
}
/// Stub: always `None` — there is no registry without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn capture_pool_for_handle(_handle: GraphPoolHandle) -> Option<std::sync::Arc<CapturePool>> {
    None
}
/// Stub: releasing is a no-op without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn release_graph_pool_handle(_handle: GraphPoolHandle) {
}
/// Stub: graph recording requires the `cuda` feature; `_f` is never run.
#[cfg(not(feature = "cuda"))]
pub fn make_graphed_callable<T, F>(
    _stream: &T,
    _mode: CaptureMode,
    _f: F,
) -> GpuResult<CapturedGraph>
where
    F: FnOnce() -> GpuResult<()>,
{
    Err(GpuError::NoCudaFeature)
}
// Host-side unit tests for the cuda-feature build. None of these touch a
// real GPU: they exercise pool bookkeeping, handle registry, and enum
// conversions only.
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::*;
    #[test]
    fn capture_pool_buffer_count_starts_at_zero() {
        let pool = CapturePool::new();
        assert_eq!(pool.buffer_count(), 0);
    }
    #[test]
    fn capture_pool_record_buffer_increments_count() {
        let pool = CapturePool::new();
        let buf_a: Vec<f32> = vec![0.0; 10];
        let idx = pool.record_buffer(buf_a);
        assert_eq!(idx, 0);
        assert_eq!(pool.buffer_count(), 1);
        // Slot indices keep growing across different element types.
        let buf_b: Vec<f64> = vec![0.0; 5];
        let idx = pool.record_buffer(buf_b);
        assert_eq!(idx, 1);
        assert_eq!(pool.buffer_count(), 2);
    }
    #[test]
    fn capture_pool_clear_buffers_resets_count_but_keeps_pool() {
        let pool = CapturePool::new();
        pool.record_buffer(vec![0u8; 16]);
        pool.record_buffer(vec![0u8; 32]);
        assert_eq!(pool.buffer_count(), 2);
        pool.clear_buffers();
        assert_eq!(pool.buffer_count(), 0);
        // The pool stays usable after clearing.
        pool.record_buffer(vec![0u8; 8]);
        assert_eq!(pool.buffer_count(), 1);
    }
    #[test]
    fn capture_pool_drop_releases_registered_buffers() {
        let buf = Arc::new(vec![1.0f32, 2.0, 3.0]);
        let pool = CapturePool::new();
        pool.record_buffer(Arc::clone(&buf));
        // The pool holds one extra strong reference...
        assert_eq!(Arc::strong_count(&buf), 2);
        drop(pool);
        // ...which is released when the pool is dropped.
        assert_eq!(Arc::strong_count(&buf), 1);
    }
    #[test]
    fn capture_pool_records_heterogeneous_types() {
        let pool = CapturePool::new();
        pool.record_buffer(vec![0.0f32; 4]);
        pool.record_buffer(vec![0.0f64; 4]);
        pool.record_buffer(vec![0u8; 4]);
        pool.record_buffer(Arc::new(42i32));
        assert_eq!(pool.buffer_count(), 4);
    }
    #[test]
    fn capture_pool_seal_unseal() {
        let pool = CapturePool::new();
        assert!(!pool.is_capture_pool_sealed());
        pool.seal();
        assert!(pool.is_capture_pool_sealed());
        pool.unseal();
        assert!(!pool.is_capture_pool_sealed());
    }
    #[test]
    fn capture_mode_default_is_thread_local() {
        assert_eq!(CaptureMode::default(), CaptureMode::ThreadLocal);
    }
    #[test]
    fn capture_mode_to_cuda_round_trip() {
        use cudarc::driver::sys::CUstreamCaptureMode::*;
        assert_eq!(CaptureMode::Global.to_cuda(), CU_STREAM_CAPTURE_MODE_GLOBAL);
        assert_eq!(
            CaptureMode::ThreadLocal.to_cuda(),
            CU_STREAM_CAPTURE_MODE_THREAD_LOCAL
        );
        assert_eq!(
            CaptureMode::Relaxed.to_cuda(),
            CU_STREAM_CAPTURE_MODE_RELAXED
        );
    }
    #[test]
    fn capture_status_is_capturing_only_when_active() {
        assert!(!CaptureStatus::None.is_capturing());
        assert!(CaptureStatus::Active.is_capturing());
        assert!(!CaptureStatus::Invalidated.is_capturing());
    }
    #[test]
    fn capture_status_is_invalidated_only_when_broken() {
        assert!(!CaptureStatus::None.is_invalidated());
        assert!(!CaptureStatus::Active.is_invalidated());
        assert!(CaptureStatus::Invalidated.is_invalidated());
    }
    #[test]
    fn capture_status_from_cuda_maps_all_variants() {
        use cudarc::driver::sys::CUstreamCaptureStatus::*;
        assert_eq!(
            CaptureStatus::from_cuda(CU_STREAM_CAPTURE_STATUS_NONE),
            CaptureStatus::None
        );
        assert_eq!(
            CaptureStatus::from_cuda(CU_STREAM_CAPTURE_STATUS_ACTIVE),
            CaptureStatus::Active
        );
        assert_eq!(
            CaptureStatus::from_cuda(CU_STREAM_CAPTURE_STATUS_INVALIDATED),
            CaptureStatus::Invalidated
        );
    }
    #[test]
    fn graph_pool_handle_allocates_unique_ids() {
        let h1 = graph_pool_handle();
        let h2 = graph_pool_handle();
        assert_ne!(h1, h2, "each call should return a fresh id");
        assert!(capture_pool_for_handle(h1).is_some());
        assert!(capture_pool_for_handle(h2).is_some());
        release_graph_pool_handle(h1);
        release_graph_pool_handle(h2);
    }
    #[test]
    fn graph_pool_handle_shares_single_pool_across_lookups() {
        let h = graph_pool_handle();
        let a = capture_pool_for_handle(h).expect("handle registered");
        let b = capture_pool_for_handle(h).expect("handle still registered");
        assert!(
            Arc::ptr_eq(&a, &b),
            "both lookups should return the same pool Arc"
        );
        a.record_buffer(vec![1.0f32, 2.0]);
        assert_eq!(b.buffer_count(), 1);
        release_graph_pool_handle(h);
        assert!(capture_pool_for_handle(h).is_none());
        // Outstanding Arcs keep the pool (and its buffers) alive after release.
        assert_eq!(a.buffer_count(), 1);
    }
    #[test]
    fn graph_pool_handle_release_is_idempotent() {
        let h = graph_pool_handle();
        assert!(capture_pool_for_handle(h).is_some());
        release_graph_pool_handle(h);
        // A second release of the same handle must not panic.
        release_graph_pool_handle(h);
        assert!(capture_pool_for_handle(h).is_none());
    }
    #[test]
    fn graph_pool_handle_unknown_id_returns_none() {
        let fake = GraphPoolHandle(u64::MAX);
        assert!(capture_pool_for_handle(fake).is_none());
    }
}
// Tests for the stub API surface that exists when the `cuda` feature is off.
#[cfg(all(test, not(feature = "cuda")))]
mod no_cuda_tests {
    use super::*;
    #[test]
    fn capture_mode_and_status_exist_without_cuda_feature() {
        let _ = CaptureMode::default();
        assert!(!CaptureStatus::None.is_capturing());
        assert!(CaptureStatus::Active.is_capturing());
        assert!(CaptureStatus::Invalidated.is_invalidated());
    }
    #[test]
    fn graph_pool_handle_without_cuda_returns_sentinel() {
        let h = graph_pool_handle();
        assert_eq!(h.0, 0, "stub handle is always zero without cuda feature");
        assert!(capture_pool_for_handle(h).is_none());
        // Releasing the sentinel handle must be a harmless no-op.
        release_graph_pool_handle(h);
    }
    #[test]
    fn captured_graph_stub_num_replays_and_is_uploaded_are_zero() {
        let g = CapturedGraph;
        assert_eq!(g.num_replays(), 0);
        assert!(!g.is_uploaded());
    }
}