1use std::{
4 ffi::c_void,
5 mem::{ManuallyDrop, MaybeUninit},
6 os::raw::{c_char, c_uint},
7 path::Path,
8 ptr,
9};
10
11use crate::{
12 error::{CudaResult, ToResult},
13 function::{BlockSize, GridSize},
14 sys as cuda,
15};
16
17#[macro_export]
20macro_rules! kernel_invocation {
21 ($module:ident . $function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* )) => {
22 {
23 let name = std::ffi::CString::new(stringify!($function)).unwrap();
24 let function = $module.get_function(&name);
25 match function {
26 Ok(f) => kernel_invocation!(f<<<$grid, $block, $shared, $stream>>>( $($arg),* ) ),
27 Err(e) => Err(e),
28 }
29 }
30 };
31 ($function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* )) => {
32 {
33 fn assert_impl_devicecopy<T: $crate::memory::DeviceCopy>(_val: T) {}
34 if false {
35 $(
36 assert_impl_devicecopy($arg);
37 )*
38 };
39
40 let boxed = vec![$(&$arg as *const _ as *mut ::std::ffi::c_void),*].into_boxed_slice();
41
42 Ok($crate::graph::KernelInvocation::_new_internal(
43 $crate::function::BlockSize::from($block),
44 $crate::function::GridSize::from($grid),
45 $shared,
46 $function.to_raw(),
47 vec![].into_boxed_slice(),
48 ))
49 }
50 };
51}
52
53#[derive(Debug, Clone, PartialEq, Eq)]
55pub struct KernelInvocation {
56 pub block_dim: BlockSize,
57 pub grid_dim: GridSize,
58 pub shared_mem_bytes: u32,
59 func: cuda::CUfunction,
60 params: Box<*mut c_void>,
61 params_len: Option<usize>,
62}
63
64impl KernelInvocation {
65 #[doc(hidden)]
66 pub fn _new_internal(
67 block_dim: BlockSize,
68 grid_dim: GridSize,
69 shared_mem_bytes: u32,
70 func: cuda::CUfunction,
71 params: Box<*mut c_void>,
72 params_len: usize,
73 ) -> Self {
74 Self {
75 block_dim,
76 grid_dim,
77 shared_mem_bytes,
78 func,
79 params,
80 params_len: Some(params_len),
81 }
82 }
83
84 pub fn to_raw(self) -> cuda::CUDA_KERNEL_NODE_PARAMS {
85 cuda::CUDA_KERNEL_NODE_PARAMS {
86 func: self.func,
87 gridDimX: self.grid_dim.x,
88 gridDimY: self.grid_dim.y,
89 gridDimZ: self.grid_dim.z,
90 blockDimX: self.block_dim.x,
91 blockDimY: self.block_dim.y,
92 blockDimZ: self.block_dim.z,
93 kernelParams: Box::into_raw(self.params),
94 sharedMemBytes: self.shared_mem_bytes,
95 extra: ptr::null_mut(),
96 }
97 }
98
99 pub unsafe fn from_raw(raw: cuda::CUDA_KERNEL_NODE_PARAMS) -> Self {
107 Self {
108 func: raw.func,
109 grid_dim: GridSize::xyz(raw.gridDimX, raw.gridDimY, raw.gridDimZ),
110 block_dim: BlockSize::xyz(raw.blockDimX, raw.gridDimY, raw.gridDimZ),
111 params: Box::from_raw(raw.kernelParams),
112 shared_mem_bytes: raw.sharedMemBytes,
113 params_len: None,
114 }
115 }
116}
117
118#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
121#[repr(transparent)]
122pub struct GraphNode {
123 raw: cuda::CUgraphNode,
124}
125
126unsafe impl Send for GraphNode {}
127unsafe impl Sync for GraphNode {}
128
129impl GraphNode {
130 pub fn from_raw(raw: cuda::CUgraphNode) -> Self {
133 Self { raw }
134 }
135
136 pub fn to_raw(self) -> cuda::CUgraphNode {
138 self.raw
139 }
140}
141
/// The kind of operation a graph node performs.
///
/// Each variant corresponds to one `CU_GRAPH_NODE_TYPE_*` value of the raw `CUgraphNodeType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum GraphNodeType {
    /// `CU_GRAPH_NODE_TYPE_KERNEL`
    KernelInvocation,
    /// `CU_GRAPH_NODE_TYPE_MEMCPY`
    Memcpy,
    /// `CU_GRAPH_NODE_TYPE_MEMSET`
    Memset,
    /// `CU_GRAPH_NODE_TYPE_HOST`
    HostExecute,
    /// `CU_GRAPH_NODE_TYPE_GRAPH`
    ChildGraph,
    /// `CU_GRAPH_NODE_TYPE_EMPTY`
    Empty,
    /// `CU_GRAPH_NODE_TYPE_WAIT_EVENT`
    WaitEvent,
    /// `CU_GRAPH_NODE_TYPE_EVENT_RECORD`
    EventRecord,
    /// `CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL`
    SemaphoreSignal,
    /// `CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT`
    SemaphoreWait,
    /// `CU_GRAPH_NODE_TYPE_MEM_ALLOC`
    MemoryAllocation,
    /// `CU_GRAPH_NODE_TYPE_MEM_FREE`
    MemoryFree,
}
171
172impl GraphNodeType {
173 pub fn from_raw(raw: cuda::CUgraphNodeType) -> Self {
175 match raw {
176 cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_KERNEL => GraphNodeType::KernelInvocation,
177 cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMCPY => GraphNodeType::Memcpy,
178 cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMSET => GraphNodeType::Memset,
179 cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_HOST => GraphNodeType::HostExecute,
180 cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_GRAPH => GraphNodeType::ChildGraph,
181 cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EMPTY => GraphNodeType::Empty,
182 cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_WAIT_EVENT => GraphNodeType::WaitEvent,
183 cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EVENT_RECORD => GraphNodeType::EventRecord,
184 cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL => {
185 GraphNodeType::SemaphoreSignal
186 }
187 cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT => {
188 GraphNodeType::SemaphoreWait
189 }
190 cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_ALLOC => GraphNodeType::MemoryAllocation,
191 cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_FREE => GraphNodeType::MemoryFree,
192 }
193 }
194
195 pub fn to_raw(self) -> cuda::CUgraphNodeType {
197 match self {
198 Self::KernelInvocation => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_KERNEL,
199 Self::Memcpy => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMCPY,
200 Self::Memset => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMSET,
201 Self::HostExecute => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_HOST,
202 Self::ChildGraph => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_GRAPH,
203 Self::Empty => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EMPTY,
204 Self::WaitEvent => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_WAIT_EVENT,
205 Self::EventRecord => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EVENT_RECORD,
206 Self::SemaphoreSignal => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL,
207 Self::SemaphoreWait => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT,
208 Self::MemoryAllocation => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_ALLOC,
209 Self::MemoryFree => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_FREE,
210 }
211 }
212}
213
214#[derive(Debug)]
240pub struct Graph {
241 raw: cuda::CUgraph,
242 node_cache: Option<Vec<GraphNode>>,
245}
246
247unsafe impl Send for Graph {}
251unsafe impl Sync for Graph {}
252
253bitflags::bitflags! {
254 #[derive(Default)]
257 pub struct GraphCreationFlags: u32 {
258 const NONE = 0b00000000;
260 }
261}
262
263impl Graph {
264 fn check_deps_are_valid(&mut self, func_name: &str, nodes: &[GraphNode]) -> CudaResult<()> {
265 for (idx, node) in nodes.iter().enumerate() {
267 if let Some(pos) = nodes
268 .iter()
269 .enumerate()
270 .position(|(cur_idx, x)| x == node && cur_idx != idx)
271 {
272 panic!("Duplicate dependency found in call to `{}`, the first instance is at index {}, the second instance is at index {}", func_name, idx, pos);
273 }
274
275 assert!(
276 self.is_valid_node(*node)?,
277 "Invalid (dropped or from another graph) node was given to `{}`",
278 func_name
279 );
280 }
281 Ok(())
282 }
283
284 pub fn is_valid_node(&mut self, node: GraphNode) -> CudaResult<bool> {
286 let nodes = self.nodes()?;
287 Ok(nodes.contains(&node))
288 }
289
290 pub fn num_nodes(&mut self) -> CudaResult<usize> {
292 unsafe {
293 let mut len = MaybeUninit::uninit();
294 cuda::cuGraphGetNodes(self.raw, ptr::null_mut(), len.as_mut_ptr()).to_result()?;
295 Ok(len.assume_init())
296 }
297 }
298
299 pub fn nodes(&mut self) -> CudaResult<&[GraphNode]> {
301 if self.node_cache.is_none() {
302 unsafe {
303 let mut len = self.num_nodes()?;
304 let mut vec = Vec::with_capacity(len);
305 cuda::cuGraphGetNodes(
306 self.raw,
307 vec.as_mut_ptr() as *mut cuda::CUgraphNode,
308 &mut len as *mut usize,
309 )
310 .to_result()?;
311 vec.set_len(len);
312 self.node_cache = Some(vec);
313 }
314 }
315 Ok(self.node_cache.as_ref().unwrap())
316 }
317
318 pub fn new(flags: GraphCreationFlags) -> CudaResult<Self> {
320 let mut raw = MaybeUninit::uninit();
321
322 unsafe {
323 cuda::cuGraphCreate(raw.as_mut_ptr(), flags.bits).to_result()?;
324
325 Ok(Self {
326 raw: raw.assume_init(),
327 node_cache: Some(vec![]),
328 })
329 }
330 }
331
332 #[cfg(any(windows, unix))]
335 pub fn dump_debug_dotfile<P: AsRef<Path>>(&mut self, path: P) -> CudaResult<()> {
336 extern "C" {
338 fn cuGraphDebugDotPrint(
339 hGraph: cuda::CUgraph,
340 path: *const c_char,
341 flags: c_uint,
342 ) -> cuda::CUresult;
343 }
344
345 let path = path.as_ref();
346 let mut buf = Vec::new();
347 #[cfg(unix)]
348 {
349 use std::os::unix::ffi::OsStrExt;
350 buf.extend(path.as_os_str().as_bytes());
351 buf.push(0);
352 }
353
354 #[cfg(windows)]
355 {
356 use std::os::windows::ffi::OsStrExt;
357 buf.extend(
358 path.as_os_str()
359 .encode_wide()
360 .chain(Some(0))
361 .map(|b| {
362 let b = b.to_ne_bytes();
363 b.get(0).copied().into_iter().chain(b.get(1).copied())
364 })
365 .flatten(),
366 );
367 }
368
369 unsafe { cuGraphDebugDotPrint(self.raw, "./out.dot\0".as_ptr().cast(), 1 << 0).to_result() }
370 }
371
372 pub fn add_kernel_node(
376 &mut self,
377 invocation: KernelInvocation,
378 dependencies: impl AsRef<[GraphNode]>,
379 ) -> CudaResult<GraphNode> {
380 let deps = dependencies.as_ref();
381 self.check_deps_are_valid("add_kernel_node", deps)?;
382 self.node_cache = None;
384 unsafe {
385 let deps_ptr = deps.as_ptr().cast();
386 let mut node = MaybeUninit::<GraphNode>::uninit();
387 let params = invocation.to_raw();
388 cuda::cuGraphAddKernelNode(
389 node.as_mut_ptr().cast(),
390 self.raw,
391 deps_ptr,
392 deps.len(),
393 ¶ms as *const _,
394 )
395 .to_result()?;
396 Ok(node.assume_init())
397 }
398 }
399
400 pub fn num_edges(&mut self) -> CudaResult<usize> {
402 unsafe {
403 let mut size = MaybeUninit::uninit();
404 cuda::cuGraphGetEdges(
405 self.raw,
406 ptr::null_mut(),
407 ptr::null_mut(),
408 size.as_mut_ptr(),
409 )
410 .to_result()?;
411 Ok(size.assume_init())
412 }
413 }
414
415 pub fn edges(&mut self) -> CudaResult<Vec<(GraphNode, GraphNode)>> {
423 unsafe {
424 let num_edges = self.num_edges()?;
425 let mut from = vec![ptr::null_mut(); num_edges].into_boxed_slice();
426 let mut to = vec![ptr::null_mut(); num_edges].into_boxed_slice();
427
428 cuda::cuGraphGetEdges(
429 self.raw,
430 from.as_mut_ptr(),
431 to.as_mut_ptr(),
432 &num_edges as *const _ as *mut usize,
433 )
434 .to_result()?;
435
436 let mut out = Vec::with_capacity(num_edges);
437 for (from, to) in from.iter().zip(to.iter()) {
438 out.push((GraphNode::from_raw(*from), GraphNode::from_raw(*to)))
439 }
440 Ok(out)
441 }
442 }
443
444 pub fn node_type(&mut self, node: GraphNode) -> CudaResult<GraphNodeType> {
446 self.check_deps_are_valid("node_type", &[node])?;
447 unsafe {
448 let mut ty = MaybeUninit::uninit();
449 cuda::cuGraphNodeGetType(node.to_raw(), ty.as_mut_ptr()).to_result()?;
450 let raw = ty.assume_init();
451 Ok(GraphNodeType::from_raw(raw))
452 }
453 }
454
455 pub fn kernel_node_params(&mut self, node: GraphNode) -> CudaResult<KernelInvocation> {
461 self.check_deps_are_valid("kernel_node_params", &[node])?;
462 assert_eq!(
463 self.node_type(node)?,
464 GraphNodeType::KernelInvocation,
465 "Node given to `kernel_node_params` was not a kernel invocation node"
466 );
467 unsafe {
468 let mut params = MaybeUninit::uninit();
469 cuda::cuGraphKernelNodeGetParams(node.to_raw(), params.as_mut_ptr());
470 Ok(KernelInvocation::from_raw(params.assume_init()))
471 }
472 }
473
474 pub unsafe fn from_raw(raw: cuda::CUgraph) -> Self {
483 Self {
484 raw,
485 node_cache: None,
486 }
487 }
488
489 pub fn into_raw(self) -> cuda::CUgraph {
492 let me = ManuallyDrop::new(self);
493 me.raw
494 }
495}
496
497impl Drop for Graph {
498 fn drop(&mut self) {
499 unsafe {
500 cuda::cuGraphDestroy(self.raw);
501 }
502 }
503}