use std::ptr;
use super::context::get_driver;
use super::sys::{CUgraph, CUgraphExec, CUstream, CudaDriver};
use crate::GpuError;
/// Owning wrapper around a raw CUDA driver graph handle (`CUgraph`).
///
/// The handle is released with `cuGraphDestroy` when the wrapper is dropped;
/// a null handle is skipped.
#[derive(Debug)]
pub struct CudaGraph {
    /// Raw driver handle owned by this wrapper (may be null).
    graph: CUgraph,
}

// SAFETY: the wrapper holds only an opaque driver handle and no thread-local
// state; moving it to another thread moves sole ownership of the handle.
// NOTE(review): relies on the driver accepting the handle from a thread other
// than the one that created it — confirm against the CUDA driver docs.
unsafe impl Send for CudaGraph {}

// SAFETY(review): `&CudaGraph` exposes `instantiate`, which passes the handle to
// the driver. The CUDA driver API's graph thread-safety notes describe graph
// objects as not internally synchronized, so concurrent `instantiate` calls on
// one graph may require external serialization. Confirm this impl is sound
// before relying on cross-thread shared access.
unsafe impl Sync for CudaGraph {}
impl CudaGraph {
pub fn new() -> Result<Self, GpuError> {
let driver = get_driver()?;
let mut graph: CUgraph = ptr::null_mut();
let result = unsafe { (driver.cuGraphCreate)(&mut graph, 0) };
CudaDriver::check(result).map_err(|e| GpuError::GraphCreate(e.to_string()))?;
Ok(Self { graph })
}
pub(crate) fn from_raw(graph: CUgraph) -> Self {
Self { graph }
}
#[must_use]
pub fn raw(&self) -> CUgraph {
self.graph
}
pub fn instantiate(&self) -> Result<CudaGraphExec, GpuError> {
let driver = get_driver()?;
let mut graph_exec: CUgraphExec = ptr::null_mut();
let result =
unsafe { (driver.cuGraphInstantiateWithFlags)(&mut graph_exec, self.graph, 0) };
CudaDriver::check(result).map_err(|e| GpuError::GraphInstantiate(e.to_string()))?;
Ok(CudaGraphExec::from_raw(graph_exec))
}
}
impl Default for CudaGraph {
fn default() -> Self {
Self::new().expect("Failed to create CUDA graph")
}
}
impl Drop for CudaGraph {
    /// Releases the owned graph handle with `cuGraphDestroy`, best effort.
    ///
    /// Does nothing when the handle is null or the driver cannot be obtained.
    /// The driver's return status is not inspected because `drop` cannot
    /// report errors.
    fn drop(&mut self) {
        if self.graph.is_null() {
            return;
        }
        if let Ok(driver) = get_driver() {
            // SAFETY: the handle is non-null (checked above) and `drop` runs at
            // most once per wrapper. NOTE(review): assumes a handle given to
            // `from_raw` is not also destroyed elsewhere.
            unsafe { (driver.cuGraphDestroy)(self.graph) };
        }
    }
}
/// Owning wrapper around an executable CUDA graph handle (`CUgraphExec`).
///
/// The handle is released with `cuGraphExecDestroy` when the wrapper is
/// dropped; a null handle is skipped.
#[derive(Debug)]
pub struct CudaGraphExec {
    /// Raw driver handle owned by this wrapper (may be null).
    exec: CUgraphExec,
}

// SAFETY: the wrapper holds only an opaque driver handle and no thread-local
// state; moving it to another thread moves sole ownership of the handle.
// NOTE(review): relies on the driver accepting the handle from a thread other
// than the one that created it — confirm against the CUDA driver docs.
unsafe impl Send for CudaGraphExec {}

// SAFETY(review): `&CudaGraphExec` exposes `launch`, so multiple threads may call
// `cuGraphLaunch` on the same executable graph concurrently. Confirm that the
// CUDA driver permits this without external serialization (see its graph
// thread-safety notes) before relying on cross-thread shared access.
unsafe impl Sync for CudaGraphExec {}
impl CudaGraphExec {
pub(crate) fn from_raw(exec: CUgraphExec) -> Self {
Self { exec }
}
#[must_use]
pub fn raw(&self) -> CUgraphExec {
self.exec
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub fn launch(&self, stream: CUstream) -> Result<(), GpuError> {
let driver = get_driver()?;
let result = unsafe { (driver.cuGraphLaunch)(self.exec, stream) };
CudaDriver::check(result).map_err(|e| GpuError::GraphLaunch(e.to_string()))
}
}
impl Drop for CudaGraphExec {
    /// Releases the owned executable-graph handle with `cuGraphExecDestroy`,
    /// best effort.
    ///
    /// Does nothing when the handle is null or the driver cannot be obtained.
    /// The driver's return status is not inspected because `drop` cannot
    /// report errors.
    fn drop(&mut self) {
        if self.exec.is_null() {
            return;
        }
        if let Ok(driver) = get_driver() {
            // SAFETY: the handle is non-null (checked above) and `drop` runs at
            // most once per wrapper. NOTE(review): assumes a handle given to
            // `from_raw` is not also destroyed elsewhere.
            unsafe { (driver.cuGraphExecDestroy)(self.exec) };
        }
    }
}
/// Stream-capture mode, convertible to the driver's raw integer value.
///
/// The integers produced by [`CaptureMode::to_cuda_mode`] match the
/// `CU_STREAM_CAPTURE_MODE_*` constants of the driver's `CUstreamCaptureMode`.
/// The default is [`CaptureMode::Global`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CaptureMode {
    /// Driver value `0` (`CU_STREAM_CAPTURE_MODE_GLOBAL`).
    #[default]
    Global,
    /// Driver value `1` (`CU_STREAM_CAPTURE_MODE_THREAD_LOCAL`).
    ThreadLocal,
    /// Driver value `2` (`CU_STREAM_CAPTURE_MODE_RELAXED`).
    Relaxed,
}

impl CaptureMode {
    /// Returns the raw integer the CUDA driver expects for this capture mode.
    ///
    /// Declared `const fn` so it can also be evaluated in constant contexts.
    #[must_use]
    pub const fn to_cuda_mode(self) -> u32 {
        match self {
            CaptureMode::Global => 0,
            CaptureMode::ThreadLocal => 1,
            CaptureMode::Relaxed => 2,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every capture mode maps to its expected driver integer.
    #[test]
    fn test_capture_mode_values() {
        let table: [(CaptureMode, u32); 3] = [
            (CaptureMode::Global, 0),
            (CaptureMode::ThreadLocal, 1),
            (CaptureMode::Relaxed, 2),
        ];
        for (mode, expected) in table {
            assert_eq!(mode.to_cuda_mode(), expected);
        }
    }

    /// The derived `Default` selects the global capture mode.
    #[test]
    fn test_capture_mode_default() {
        let fallback = CaptureMode::default();
        assert_eq!(fallback, CaptureMode::Global);
    }
}