use std::ptr;
use super::context::get_driver;
use super::sys::{
CUgraph, CUgraphExec, CUgraphExecUpdateResult, CUgraphNode, CUstream, CudaDriver,
CU_GRAPH_EXEC_UPDATE_SUCCESS,
};
use crate::GpuError;
/// Owning wrapper around a raw `CUgraph` driver handle.
///
/// The `Drop` implementation in this file passes a non-null handle to
/// `cuGraphDestroy`, so a value of this type is responsible for the
/// handle it holds.
pub struct CudaGraph {
/// Raw driver handle; a null value is skipped when the wrapper is dropped.
graph: CUgraph,
}
// NOTE(review): these impls assert that the raw `CUgraph` handle may be moved
// to and shared between threads. Nothing in this file establishes that
// guarantee — confirm it against the CUDA driver's thread-safety rules for
// graph objects.
unsafe impl Send for CudaGraph {}
unsafe impl Sync for CudaGraph {}
impl CudaGraph {
    /// Creates a new graph by calling the driver's `cuGraphCreate`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `get_driver`, or returns
    /// [`GpuError::GraphCreate`] when the driver call reports a failure.
    pub fn new() -> Result<Self, GpuError> {
        let driver = get_driver()?;
        let mut graph: CUgraph = ptr::null_mut();
        // SAFETY: `graph` is a live local, so `&mut graph` is a valid
        // out-pointer for the duration of the call.
        let result = unsafe { (driver.cuGraphCreate)(&mut graph, 0) };
        CudaDriver::check(result).map_err(|e| GpuError::GraphCreate(e.to_string()))?;
        Ok(Self { graph })
    }

    /// Wraps an existing raw handle.
    ///
    /// The returned value destroys the handle when dropped, so the caller
    /// must not destroy it separately.
    pub(crate) fn from_raw(graph: CUgraph) -> Self {
        Self { graph }
    }

    /// Returns the raw driver handle without giving up ownership.
    #[must_use]
    pub fn raw(&self) -> CUgraph {
        self.graph
    }

    /// Adds a kernel node to this graph and returns the new node handle.
    ///
    /// # Parameters
    ///
    /// * `func` - kernel function handle stored in the node parameters.
    /// * `grid` - grid dimensions as `(x, y, z)`.
    /// * `block` - block dimensions as `(x, y, z)`.
    /// * `shared_mem` - value stored in the node's `shared_mem_bytes` field.
    /// * `args` - array of kernel argument pointers; its base pointer is
    ///   handed to the driver as `kernel_params`.
    ///   NOTE(review): whether the pointed-to values must outlive this call
    ///   is not visible here — confirm against the driver documentation.
    /// * `deps` - nodes the new node depends on; may be empty.
    ///
    /// # Errors
    ///
    /// Propagates the error from `get_driver`, or returns
    /// [`GpuError::GraphCreate`] (prefixed with `add_kernel_node:`) when the
    /// driver rejects the node.
    // The launch configuration is passed as separate values, hence the arity.
    #[allow(clippy::too_many_arguments)]
    pub fn add_kernel_node(
        &mut self,
        func: super::sys::CUfunction,
        grid: (u32, u32, u32),
        block: (u32, u32, u32),
        shared_mem: u32,
        args: &mut [*mut std::ffi::c_void],
        deps: &[super::sys::CUgraphNode],
    ) -> Result<super::sys::CUgraphNode, GpuError> {
        let driver = get_driver()?;
        let params = super::sys::CudaKernelNodeParams {
            func,
            grid_dim_x: grid.0,
            grid_dim_y: grid.1,
            grid_dim_z: grid.2,
            block_dim_x: block.0,
            block_dim_y: block.1,
            block_dim_z: block.2,
            shared_mem_bytes: shared_mem,
            kernel_params: args.as_mut_ptr(),
            extra: ptr::null_mut(),
        };
        // An empty dependency list is passed as a null pointer rather than
        // the slice's (dangling) base pointer.
        let deps_ptr = if deps.is_empty() {
            ptr::null()
        } else {
            deps.as_ptr()
        };
        let mut node: super::sys::CUgraphNode = ptr::null_mut();
        // SAFETY: `node` and `params` are live locals for the whole call, and
        // `deps_ptr`/`deps.len()` describe the caller's slice (or null/0).
        let result = unsafe {
            (driver.cuGraphAddKernelNode)(
                &mut node,
                self.graph,
                deps_ptr,
                deps.len(),
                &params,
            )
        };
        CudaDriver::check(result)
            .map_err(|e| GpuError::GraphCreate(format!("add_kernel_node: {e}")))?;
        Ok(node)
    }

    /// Instantiates this graph into an executable [`CudaGraphExec`].
    ///
    /// # Errors
    ///
    /// Propagates the error from `get_driver`, or returns
    /// [`GpuError::GraphInstantiate`] when the driver call fails.
    pub fn instantiate(&self) -> Result<CudaGraphExec, GpuError> {
        let driver = get_driver()?;
        let mut graph_exec: CUgraphExec = ptr::null_mut();
        // SAFETY: `graph_exec` is a live local used as the out-pointer;
        // `self.graph` is the handle owned by this wrapper.
        let result =
            unsafe { (driver.cuGraphInstantiateWithFlags)(&mut graph_exec, self.graph, 0) };
        CudaDriver::check(result).map_err(|e| GpuError::GraphInstantiate(e.to_string()))?;
        Ok(CudaGraphExec::from_raw(graph_exec))
    }
}
impl Default for CudaGraph {
    /// Builds a graph through [`CudaGraph::new`].
    ///
    /// # Panics
    ///
    /// Panics if `CudaGraph::new` returns an error, since `Default` has no
    /// way to report it.
    fn default() -> Self {
        let created = Self::new();
        created.expect("Failed to create CUDA graph")
    }
}
impl Drop for CudaGraph {
    /// Destroys the wrapped graph handle on a best-effort basis.
    ///
    /// A null handle is skipped, and if the driver cannot be obtained the
    /// handle is left alone. The driver's return status is ignored.
    fn drop(&mut self) {
        if self.graph.is_null() {
            return;
        }
        if let Ok(driver) = get_driver() {
            // SAFETY: the handle is non-null and is the one held by this wrapper.
            unsafe { (driver.cuGraphDestroy)(self.graph) };
        }
    }
}
/// Owning wrapper around a raw `CUgraphExec` driver handle.
///
/// The `Drop` implementation in this file passes a non-null handle to
/// `cuGraphExecDestroy`, so a value of this type is responsible for the
/// handle it holds.
pub struct CudaGraphExec {
/// Raw driver handle; a null value is skipped when the wrapper is dropped.
exec: CUgraphExec,
}
// NOTE(review): these impls assert that the raw `CUgraphExec` handle may be
// moved to and shared between threads. Nothing in this file establishes that
// guarantee — confirm it against the CUDA driver's thread-safety rules for
// executable graphs.
unsafe impl Send for CudaGraphExec {}
unsafe impl Sync for CudaGraphExec {}
impl CudaGraphExec {
    /// Wraps an existing raw executable-graph handle.
    ///
    /// The returned value destroys the handle when dropped, so the caller
    /// must not destroy it separately.
    pub(crate) fn from_raw(exec: CUgraphExec) -> Self {
        Self { exec }
    }

    /// Returns the raw driver handle without giving up ownership.
    #[must_use]
    pub fn raw(&self) -> CUgraphExec {
        self.exec
    }

    /// Attempts to update this executable graph from `new_graph`.
    ///
    /// Returns `Ok(true)` when the driver reports a successful update and
    /// `Ok(false)` when the update was not applied.
    ///
    /// # Errors
    ///
    /// Propagates the error from `get_driver`. Returns
    /// [`GpuError::GraphInstantiate`] (prefixed with `graph update:`) when the
    /// driver call returns a non-zero status while the reported update result
    /// still equals the success value.
    pub fn update(&self, new_graph: &CudaGraph) -> Result<bool, GpuError> {
        let driver = get_driver()?;
        let mut failing_node: CUgraphNode = ptr::null_mut();
        let mut outcome: CUgraphExecUpdateResult = 0;
        // SAFETY: both out-parameters are live locals for the whole call, and
        // the two handles are the ones held by their wrappers.
        let status = unsafe {
            (driver.cuGraphExecUpdate)(
                self.exec,
                new_graph.raw(),
                &mut failing_node,
                &mut outcome,
            )
        };
        let applied = outcome == CU_GRAPH_EXEC_UPDATE_SUCCESS;
        // A non-zero status together with a non-success update result is
        // reported as "not applied" rather than as an error.
        if status != 0 && !applied {
            return Ok(false);
        }
        CudaDriver::check(status)
            .map_err(|e| GpuError::GraphInstantiate(format!("graph update: {e}")))?;
        Ok(applied)
    }

    /// Launches this executable graph on `stream`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `get_driver`, or returns
    /// [`GpuError::GraphLaunch`] when the driver call fails.
    // `stream` is only forwarded to the driver, never dereferenced here.
    #[allow(clippy::not_unsafe_ptr_arg_deref)]
    pub fn launch(&self, stream: CUstream) -> Result<(), GpuError> {
        let driver = get_driver()?;
        // SAFETY: `self.exec` is the handle held by this wrapper; `stream` is
        // supplied by the caller and passed through unchanged.
        let status = unsafe { (driver.cuGraphLaunch)(self.exec, stream) };
        CudaDriver::check(status).map_err(|e| GpuError::GraphLaunch(e.to_string()))
    }
}
impl Drop for CudaGraphExec {
    /// Destroys the wrapped executable-graph handle on a best-effort basis.
    ///
    /// A null handle is skipped, and if the driver cannot be obtained the
    /// handle is left alone. The driver's return status is ignored.
    fn drop(&mut self) {
        if self.exec.is_null() {
            return;
        }
        if let Ok(driver) = get_driver() {
            // SAFETY: the handle is non-null and is the one held by this wrapper.
            unsafe { (driver.cuGraphExecDestroy)(self.exec) };
        }
    }
}
/// Capture mode selector; [`CaptureMode::Global`] is the default.
///
/// NOTE(review): the variants presumably mirror the driver's stream-capture
/// modes — confirm against the `sys` bindings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CaptureMode {
    /// Converts to raw value `0`.
    #[default]
    Global,
    /// Converts to raw value `1`.
    ThreadLocal,
    /// Converts to raw value `2`.
    Relaxed,
}

impl CaptureMode {
    /// Returns the raw `u32` value for this mode
    /// (`Global` = 0, `ThreadLocal` = 1, `Relaxed` = 2).
    #[must_use]
    pub fn to_cuda_mode(self) -> u32 {
        let raw: u32 = match self {
            Self::Global => 0,
            Self::ThreadLocal => 1,
            Self::Relaxed => 2,
        };
        raw
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every capture mode converts to its expected raw value.
    #[test]
    fn test_capture_mode_values() {
        let expected = [
            (CaptureMode::Global, 0),
            (CaptureMode::ThreadLocal, 1),
            (CaptureMode::Relaxed, 2),
        ];
        for (mode, raw) in expected {
            assert_eq!(mode.to_cuda_mode(), raw);
        }
    }

    /// The derived default is the `Global` mode.
    #[test]
    fn test_capture_mode_default() {
        let default_mode = CaptureMode::default();
        assert_eq!(default_mode, CaptureMode::Global);
    }
}