use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use cudarc::driver::safe::CudaGraph as CudarcGraph;
/// Newtype around the cudarc graph. The only trait implemented on it here is `Send`,
/// which the orphan rule would not allow on the foreign `CudarcGraph` type directly.
struct CudaGraphInner(CudarcGraph);
// SAFETY: NOTE(review): this asserts that the cudarc graph may be moved to another thread.
// Within this file the value is only ever reached through `Mutex<CudaGraphInner>`, so
// concurrent access is serialized and `Sync` is deliberately NOT implemented (the `Mutex`
// only needs `Send` to be `Sync` itself). Confirm that cudarc / the CUDA driver allow the
// graph's handles to be used from a thread other than the one that created them
// (e.g. that the owning context is bound on launch) — this cannot be verified from here.
unsafe impl Send for CudaGraphInner {}
/// A captured CUDA graph that can be replayed via [`crate::runtime::Graph::launch`].
///
/// Cloning is cheap: every clone shares the same underlying graph and the same
/// launch counter (both are behind `Arc`).
#[derive(Clone)]
pub struct CudaGraph {
    /// The wrapped cudarc graph; the `Mutex` serializes launches across clones and threads.
    inner: Arc<Mutex<CudaGraphInner>>,
    /// Number of successful launches, shared by all clones of this handle.
    launch_count: Arc<AtomicUsize>,
}
impl std::fmt::Debug for CudaGraph {
    /// Formats the handle as `CudaGraph { launch_count: N }`; only the launch counter
    /// is shown, the wrapped graph is left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let launches = self.launch_count.load(Ordering::Relaxed);
        let mut builder = f.debug_struct("CudaGraph");
        builder.field("launch_count", &launches);
        builder.finish()
    }
}
impl CudaGraph {
    /// Wraps an already-built cudarc graph in a shareable handle whose launch
    /// counter starts at zero.
    pub(crate) fn new(graph: CudarcGraph) -> Self {
        let inner = Arc::new(Mutex::new(CudaGraphInner(graph)));
        let launch_count = Arc::new(AtomicUsize::new(0));
        Self {
            inner,
            launch_count,
        }
    }

    /// Returns how many times the graph has been launched successfully through this
    /// handle or any of its clones.
    pub fn launch_count(&self) -> usize {
        // Relaxed: the counter is only read as a statistic in this file and is not
        // used to order any other memory accesses.
        let counter: &AtomicUsize = &self.launch_count;
        counter.load(Ordering::Relaxed)
    }
}
impl crate::runtime::Graph for CudaGraph {
fn launch(&self) -> crate::error::Result<()> {
let guard = self.inner.lock().unwrap_or_else(|p| p.into_inner());
guard
.0
.launch()
.map_err(|e| crate::error::Error::Backend(format!("CUDA graph launch failed: {e}")))?;
self.launch_count.fetch_add(1, Ordering::Relaxed);
Ok(())
}
fn is_replay_capable(&self) -> bool {
true
}
}
unsafe impl Send for CudaGraph {}
unsafe impl Sync for CudaGraph {}