use std::{
ffi::c_void,
mem::{ManuallyDrop, MaybeUninit},
os::raw::{c_char, c_uint},
path::Path,
ptr,
};
use crate::{
error::{CudaResult, ToResult},
function::{BlockSize, GridSize},
sys as cuda,
};
/// Builds a [`crate::graph::KernelInvocation`] from CUDA's familiar
/// `<<<grid, block, shared, stream>>>` launch syntax.
///
/// Two forms are accepted: `module.kernel<<<...>>>(args...)`, which resolves the
/// function by name from `module`, and `function<<<...>>>(args...)`, which takes an
/// already-resolved function. Both evaluate to a `Result`.
#[macro_export]
macro_rules! kernel_invocation {
    // Form 1: resolve `$function` by name from `$module`, then delegate to form 2.
    ($module:ident . $function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* )) => {
        {
            let name = std::ffi::CString::new(stringify!($function)).unwrap();
            let function = $module.get_function(&name);
            match function {
                Ok(f) => kernel_invocation!(f<<<$grid, $block, $shared, $stream>>>( $($arg),* ) ),
                Err(e) => Err(e),
            }
        }
    };
    // Form 2: package an already-resolved function into a KernelInvocation.
    ($function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* )) => {
        {
            // Compile-time check that every argument implements DeviceCopy. The
            // `if false` block is never executed; it exists only to force trait
            // resolution on each `$arg`.
            fn assert_impl_devicecopy<T: $crate::memory::DeviceCopy>(_val: T) {}
            if false {
                $(
                    assert_impl_devicecopy($arg);
                )*
            };
            // NOTE(review): `boxed` is the per-argument pointer array that CUDA's
            // `kernelParams` needs, but it is built and then never used — the call
            // below passes an empty list instead, silently dropping every argument.
            // TODO: pass `boxed` (and its length) through; this also requires
            // `KernelInvocation::params` to be able to hold more than one pointer.
            let boxed = vec![$(&$arg as *const _ as *mut ::std::ffi::c_void),*].into_boxed_slice();
            // NOTE(review): `_new_internal` takes six parameters (the last being
            // `params_len`), but only five arguments are supplied here, so this arm
            // fails to compile at its first expansion. Additionally,
            // `vec![].into_boxed_slice()` is a `Box<[_]>`, not the `Box<*mut c_void>`
            // the constructor expects. TODO confirm the intended signature and fix.
            Ok($crate::graph::KernelInvocation::_new_internal(
                $crate::function::BlockSize::from($block),
                $crate::function::GridSize::from($grid),
                $shared,
                $function.to_raw(),
                vec![].into_boxed_slice(),
            ))
        }
    };
}
/// A CPU-side description of a single kernel launch, used to create kernel nodes in a
/// [`Graph`] (see [`Graph::add_kernel_node`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KernelInvocation {
    /// Block dimensions (threads per block) for the launch.
    pub block_dim: BlockSize,
    /// Grid dimensions (blocks per grid) for the launch.
    pub grid_dim: GridSize,
    /// Dynamic shared memory to allocate per block, in bytes.
    pub shared_mem_bytes: u32,
    // Raw handle of the kernel function to launch; not owned by this struct.
    func: cuda::CUfunction,
    // NOTE(review): a `Box` of a *single* pointer can describe at most one kernel
    // parameter, while CUDA's `kernelParams` is a pointer to an array of per-argument
    // pointers. This probably wants to be `Box<[*mut c_void]>` — TODO confirm intent.
    params: Box<*mut c_void>,
    // Number of kernel parameters when known; `None` when reconstructed from raw node
    // params (see `from_raw`), where the length is unrecoverable.
    params_len: Option<usize>,
}
impl KernelInvocation {
    /// Internal constructor used by the [`kernel_invocation!`] macro; not part of the
    /// public API and subject to change.
    #[doc(hidden)]
    pub fn _new_internal(
        block_dim: BlockSize,
        grid_dim: GridSize,
        shared_mem_bytes: u32,
        func: cuda::CUfunction,
        params: Box<*mut c_void>,
        params_len: usize,
    ) -> Self {
        Self {
            block_dim,
            grid_dim,
            shared_mem_bytes,
            func,
            params,
            params_len: Some(params_len),
        }
    }

    /// Converts this invocation into the raw struct CUDA expects for kernel graph
    /// nodes.
    ///
    /// The `params` allocation is leaked into the raw struct via `Box::into_raw`;
    /// ownership can be reclaimed later with [`KernelInvocation::from_raw`].
    pub fn to_raw(self) -> cuda::CUDA_KERNEL_NODE_PARAMS {
        cuda::CUDA_KERNEL_NODE_PARAMS {
            func: self.func,
            gridDimX: self.grid_dim.x,
            gridDimY: self.grid_dim.y,
            gridDimZ: self.grid_dim.z,
            blockDimX: self.block_dim.x,
            blockDimY: self.block_dim.y,
            blockDimZ: self.block_dim.z,
            kernelParams: Box::into_raw(self.params),
            sharedMemBytes: self.shared_mem_bytes,
            // The `extra` launch-config mechanism is unused by this wrapper.
            extra: ptr::null_mut(),
        }
    }

    /// Rebuilds an invocation from raw kernel-node parameters.
    ///
    /// # Safety
    ///
    /// `raw.kernelParams` must be a pointer previously produced by
    /// [`KernelInvocation::to_raw`] (i.e. by `Box::into_raw`) and must not be used
    /// again afterwards — this call takes back ownership of that allocation.
    pub unsafe fn from_raw(raw: cuda::CUDA_KERNEL_NODE_PARAMS) -> Self {
        Self {
            func: raw.func,
            grid_dim: GridSize::xyz(raw.gridDimX, raw.gridDimY, raw.gridDimZ),
            // Fixed: previously this read `raw.gridDimY`/`raw.gridDimZ` (a copy-paste
            // from the line above), corrupting the block's y/z dimensions.
            block_dim: BlockSize::xyz(raw.blockDimX, raw.blockDimY, raw.blockDimZ),
            params: Box::from_raw(raw.kernelParams),
            shared_mem_bytes: raw.sharedMemBytes,
            // The raw struct carries no parameter count, so the length is unknown.
            params_len: None,
        }
    }
}
/// An opaque, copyable handle to a node inside a [`Graph`].
///
/// The handle does not own the underlying CUDA node; its validity is tied to the graph
/// that created it (see [`Graph::is_valid_node`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct GraphNode {
    // Raw driver handle; `repr(transparent)` lets slices of GraphNode be passed
    // directly to FFI expecting `*const CUgraphNode`.
    raw: cuda::CUgraphNode,
}
// SAFETY: the wrapper is just an opaque handle and all operations on it go through the
// owning `Graph`. NOTE(review): this assumes the driver permits node handles to be
// moved and shared across threads — TODO confirm against CUDA's threading guarantees.
unsafe impl Send for GraphNode {}
unsafe impl Sync for GraphNode {}
impl GraphNode {
    /// Wraps a raw CUDA graph-node handle.
    pub fn from_raw(raw: cuda::CUgraphNode) -> Self {
        Self { raw }
    }
    /// Returns the underlying raw CUDA handle.
    pub fn to_raw(self) -> cuda::CUgraphNode {
        self.raw
    }
}
/// The kind of operation a [`GraphNode`] performs, mirroring the driver's
/// `CUgraphNodeType` (see the `from_raw`/`to_raw` mapping below).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum GraphNodeType {
    /// Kernel launch (`CU_GRAPH_NODE_TYPE_KERNEL`).
    KernelInvocation,
    /// Memory copy (`CU_GRAPH_NODE_TYPE_MEMCPY`).
    Memcpy,
    /// Memory set (`CU_GRAPH_NODE_TYPE_MEMSET`).
    Memset,
    /// Host-side function execution (`CU_GRAPH_NODE_TYPE_HOST`).
    HostExecute,
    /// Embedded child graph (`CU_GRAPH_NODE_TYPE_GRAPH`).
    ChildGraph,
    /// Empty (no-op) node (`CU_GRAPH_NODE_TYPE_EMPTY`).
    Empty,
    /// Wait on an event (`CU_GRAPH_NODE_TYPE_WAIT_EVENT`).
    WaitEvent,
    /// Record an event (`CU_GRAPH_NODE_TYPE_EVENT_RECORD`).
    EventRecord,
    /// External semaphore signal (`CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL`).
    SemaphoreSignal,
    /// External semaphore wait (`CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT`).
    SemaphoreWait,
    /// Memory allocation (`CU_GRAPH_NODE_TYPE_MEM_ALLOC`).
    MemoryAllocation,
    /// Memory free (`CU_GRAPH_NODE_TYPE_MEM_FREE`).
    MemoryFree,
}
impl GraphNodeType {
pub fn from_raw(raw: cuda::CUgraphNodeType) -> Self {
match raw {
cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_KERNEL => GraphNodeType::KernelInvocation,
cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMCPY => GraphNodeType::Memcpy,
cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMSET => GraphNodeType::Memset,
cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_HOST => GraphNodeType::HostExecute,
cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_GRAPH => GraphNodeType::ChildGraph,
cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EMPTY => GraphNodeType::Empty,
cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_WAIT_EVENT => GraphNodeType::WaitEvent,
cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EVENT_RECORD => GraphNodeType::EventRecord,
cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL => {
GraphNodeType::SemaphoreSignal
}
cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT => {
GraphNodeType::SemaphoreWait
}
cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_ALLOC => GraphNodeType::MemoryAllocation,
cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_FREE => GraphNodeType::MemoryFree,
}
}
pub fn to_raw(self) -> cuda::CUgraphNodeType {
match self {
Self::KernelInvocation => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_KERNEL,
Self::Memcpy => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMCPY,
Self::Memset => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMSET,
Self::HostExecute => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_HOST,
Self::ChildGraph => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_GRAPH,
Self::Empty => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EMPTY,
Self::WaitEvent => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_WAIT_EVENT,
Self::EventRecord => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EVENT_RECORD,
Self::SemaphoreSignal => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL,
Self::SemaphoreWait => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT,
Self::MemoryAllocation => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_ALLOC,
Self::MemoryFree => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_FREE,
}
}
}
/// Safe, owning wrapper around a CUDA execution graph (`CUgraph`).
#[derive(Debug)]
pub struct Graph {
    // Owned raw handle; destroyed in Drop unless released via `into_raw`.
    raw: cuda::CUgraph,
    // Lazily populated cache of the graph's nodes; set to `None` by any method that
    // changes the node set (e.g. `add_kernel_node`), and refilled by `nodes()`.
    node_cache: Option<Vec<GraphNode>>,
}
// SAFETY: NOTE(review): this assumes CUDA graph handles may be used from any thread;
// TODO confirm against the driver API's threading guarantees.
unsafe impl Send for Graph {}
unsafe impl Sync for Graph {}
bitflags::bitflags! {
    /// Creation flags for [`Graph::new`].
    #[derive(Default)]
    pub struct GraphCreationFlags: u32 {
        /// No flags — the only value accepted at present.
        const NONE = 0b00000000;
    }
}
impl Graph {
    /// Checks that a user-supplied dependency list contains no duplicates and that
    /// every node belongs to this graph. `func_name` is used only in panic messages.
    ///
    /// # Panics
    ///
    /// Panics if `nodes` contains a duplicate entry or a node that is not (or is no
    /// longer) part of this graph.
    fn check_deps_are_valid(&mut self, func_name: &str, nodes: &[GraphNode]) -> CudaResult<()> {
        for (idx, node) in nodes.iter().enumerate() {
            if let Some(pos) = nodes
                .iter()
                .enumerate()
                .position(|(cur_idx, x)| x == node && cur_idx != idx)
            {
                panic!("Duplicate dependency found in call to `{}`, the first instance is at index {}, the second instance is at index {}", func_name, idx, pos);
            }
            assert!(
                self.is_valid_node(*node)?,
                "Invalid (dropped or from another graph) node was given to `{}`",
                func_name
            );
        }
        Ok(())
    }

    /// Returns whether `node` is currently one of this graph's nodes.
    pub fn is_valid_node(&mut self, node: GraphNode) -> CudaResult<bool> {
        let nodes = self.nodes()?;
        Ok(nodes.contains(&node))
    }

    /// Returns the number of nodes currently in the graph.
    pub fn num_nodes(&mut self) -> CudaResult<usize> {
        unsafe {
            let mut len = MaybeUninit::uninit();
            // A null `nodes` out-pointer asks the driver for the count only.
            cuda::cuGraphGetNodes(self.raw, ptr::null_mut(), len.as_mut_ptr()).to_result()?;
            Ok(len.assume_init())
        }
    }

    /// Returns every node in the graph, populating the internal cache on first use
    /// (and after any mutation invalidated it).
    pub fn nodes(&mut self) -> CudaResult<&[GraphNode]> {
        if self.node_cache.is_none() {
            unsafe {
                let mut len = self.num_nodes()?;
                let mut vec = Vec::with_capacity(len);
                // `GraphNode` is `#[repr(transparent)]` over `CUgraphNode`, so the
                // uninitialized buffer can be filled with raw handles directly.
                cuda::cuGraphGetNodes(
                    self.raw,
                    vec.as_mut_ptr() as *mut cuda::CUgraphNode,
                    &mut len as *mut usize,
                )
                .to_result()?;
                // `len` is updated by the driver to the number actually written.
                vec.set_len(len);
                self.node_cache = Some(vec);
            }
        }
        Ok(self.node_cache.as_ref().unwrap())
    }

    /// Creates a new, empty graph.
    pub fn new(flags: GraphCreationFlags) -> CudaResult<Self> {
        let mut raw = MaybeUninit::uninit();
        unsafe {
            cuda::cuGraphCreate(raw.as_mut_ptr(), flags.bits).to_result()?;
            Ok(Self {
                raw: raw.assume_init(),
                // A freshly created graph has no nodes, so the cache starts warm.
                node_cache: Some(vec![]),
            })
        }
    }

    /// Writes a graphviz (DOT) debug representation of the graph to `path`.
    #[cfg(any(windows, unix))]
    pub fn dump_debug_dotfile<P: AsRef<Path>>(&mut self, path: P) -> CudaResult<()> {
        // `cuGraphDebugDotPrint` is not exposed by the `sys` bindings, so it is
        // declared ad-hoc here.
        extern "C" {
            fn cuGraphDebugDotPrint(
                hGraph: cuda::CUgraph,
                path: *const c_char,
                flags: c_uint,
            ) -> cuda::CUresult;
        }
        let path = path.as_ref();
        // Encode `path` as a NUL-terminated byte string for the C API.
        let mut buf = Vec::new();
        #[cfg(unix)]
        {
            use std::os::unix::ffi::OsStrExt;
            buf.extend(path.as_os_str().as_bytes());
            buf.push(0);
        }
        #[cfg(windows)]
        {
            // Fixed: the previous code pushed the path's raw UTF-16 bytes, which a
            // `char*` API would truncate at the first interior NUL byte (i.e. after
            // one character for ASCII paths). The driver takes a narrow string, so a
            // lossy UTF-8 conversion is used instead; paths containing unpaired
            // surrogates will be mangled.
            buf.extend(path.as_os_str().to_string_lossy().as_bytes());
            buf.push(0);
        }
        // Fixed: the encoded path buffer was previously built and then ignored in
        // favor of a hard-coded "./out.dot" literal, so `path` had no effect.
        // `1 << 0` is CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE.
        unsafe { cuGraphDebugDotPrint(self.raw, buf.as_ptr().cast(), 1 << 0).to_result() }
    }

    /// Adds a kernel-launch node to the graph that runs after every node in
    /// `dependencies`, returning a handle to the new node.
    ///
    /// # Panics
    ///
    /// Panics if `dependencies` contains duplicates or nodes not in this graph.
    pub fn add_kernel_node(
        &mut self,
        invocation: KernelInvocation,
        dependencies: impl AsRef<[GraphNode]>,
    ) -> CudaResult<GraphNode> {
        let deps = dependencies.as_ref();
        self.check_deps_are_valid("add_kernel_node", deps)?;
        // The node set is about to change, so invalidate the cache.
        self.node_cache = None;
        unsafe {
            // `GraphNode` is `repr(transparent)` over `CUgraphNode`, so the slice can
            // be reinterpreted for the FFI call.
            let deps_ptr = deps.as_ptr().cast();
            let mut node = MaybeUninit::<GraphNode>::uninit();
            let params = invocation.to_raw();
            cuda::cuGraphAddKernelNode(
                node.as_mut_ptr().cast(),
                self.raw,
                deps_ptr,
                deps.len(),
                &params as *const _,
            )
            .to_result()?;
            Ok(node.assume_init())
        }
    }

    /// Returns the number of dependency edges in the graph.
    pub fn num_edges(&mut self) -> CudaResult<usize> {
        unsafe {
            let mut size = MaybeUninit::uninit();
            // Null `from`/`to` out-pointers ask the driver for the count only.
            cuda::cuGraphGetEdges(
                self.raw,
                ptr::null_mut(),
                ptr::null_mut(),
                size.as_mut_ptr(),
            )
            .to_result()?;
            Ok(size.assume_init())
        }
    }

    /// Returns every dependency edge in the graph as `(from, to)` pairs.
    pub fn edges(&mut self) -> CudaResult<Vec<(GraphNode, GraphNode)>> {
        unsafe {
            let mut num_edges = self.num_edges()?;
            let mut from = vec![ptr::null_mut(); num_edges].into_boxed_slice();
            let mut to = vec![ptr::null_mut(); num_edges].into_boxed_slice();
            // Fixed: the in/out edge count was previously passed by casting a shared
            // reference (`&num_edges as *const _ as *mut usize`), which is undefined
            // behavior when the driver writes through it. Use a mutable local instead.
            cuda::cuGraphGetEdges(
                self.raw,
                from.as_mut_ptr(),
                to.as_mut_ptr(),
                &mut num_edges as *mut usize,
            )
            .to_result()?;
            // The driver updates `num_edges` to the count actually written; only read
            // that many entries (the rest of the buffers are still null).
            let mut out = Vec::with_capacity(num_edges);
            for (&from, &to) in from.iter().zip(to.iter()).take(num_edges) {
                out.push((GraphNode::from_raw(from), GraphNode::from_raw(to)));
            }
            Ok(out)
        }
    }

    /// Queries the type of a node in this graph.
    ///
    /// # Panics
    ///
    /// Panics if `node` does not belong to this graph.
    pub fn node_type(&mut self, node: GraphNode) -> CudaResult<GraphNodeType> {
        self.check_deps_are_valid("node_type", &[node])?;
        unsafe {
            let mut ty = MaybeUninit::uninit();
            cuda::cuGraphNodeGetType(node.to_raw(), ty.as_mut_ptr()).to_result()?;
            let raw = ty.assume_init();
            Ok(GraphNodeType::from_raw(raw))
        }
    }

    /// Retrieves the launch parameters of a kernel node.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not part of this graph or is not a kernel node.
    pub fn kernel_node_params(&mut self, node: GraphNode) -> CudaResult<KernelInvocation> {
        self.check_deps_are_valid("kernel_node_params", &[node])?;
        assert_eq!(
            self.node_type(node)?,
            GraphNodeType::KernelInvocation,
            "Node given to `kernel_node_params` was not a kernel invocation node"
        );
        unsafe {
            let mut params = MaybeUninit::uninit();
            // Fixed: the driver status was previously discarded, which could hand an
            // uninitialized params struct to `from_raw` on failure.
            cuda::cuGraphKernelNodeGetParams(node.to_raw(), params.as_mut_ptr())
                .to_result()?;
            // NOTE(review): `from_raw` reclaims `kernelParams` via `Box::from_raw`;
            // if the driver still owns that allocation this is a latent double-free —
            // TODO audit ownership of the params array here.
            Ok(KernelInvocation::from_raw(params.assume_init()))
        }
    }

    /// Wraps a raw graph handle.
    ///
    /// # Safety
    ///
    /// `raw` must be a valid, live `CUgraph`. Ownership is transferred: the handle is
    /// destroyed when the returned `Graph` is dropped.
    pub unsafe fn from_raw(raw: cuda::CUgraph) -> Self {
        Self {
            raw,
            node_cache: None,
        }
    }

    /// Consumes the graph and returns the raw handle without destroying it; the caller
    /// becomes responsible for eventually destroying it.
    pub fn into_raw(self) -> cuda::CUgraph {
        // ManuallyDrop prevents our Drop impl from destroying the returned handle.
        let me = ManuallyDrop::new(self);
        me.raw
    }
}
impl Drop for Graph {
    fn drop(&mut self) {
        // SAFETY: `self.raw` is the handle this struct owns; `into_raw` skips Drop
        // entirely (via ManuallyDrop), so it cannot have been released already.
        unsafe {
            // Destruction errors cannot be surfaced from Drop, so the status is
            // intentionally ignored.
            cuda::cuGraphDestroy(self.raw);
        }
    }
}