cust/
graph.rs

1//! CUDA Graph management.
2
3use std::{
4    ffi::c_void,
5    mem::{ManuallyDrop, MaybeUninit},
6    os::raw::{c_char, c_uint},
7    path::Path,
8    ptr,
9};
10
11use crate::{
12    error::{CudaResult, ToResult},
13    function::{BlockSize, GridSize},
14    sys as cuda,
15};
16
17/// Creates a kernel invocation using the same syntax as [`launch`] to be used to insert kernel launches inside graphs.
18/// This returns a Result of a kernel invocation object you can then pass to a graph.
19#[macro_export]
20macro_rules! kernel_invocation {
21    ($module:ident . $function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* )) => {
22        {
23            let name = std::ffi::CString::new(stringify!($function)).unwrap();
24            let function = $module.get_function(&name);
25            match function {
26                Ok(f) => kernel_invocation!(f<<<$grid, $block, $shared, $stream>>>( $($arg),* ) ),
27                Err(e) => Err(e),
28            }
29        }
30    };
31    ($function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* )) => {
32        {
33            fn assert_impl_devicecopy<T: $crate::memory::DeviceCopy>(_val: T) {}
34            if false {
35                $(
36                    assert_impl_devicecopy($arg);
37                )*
38            };
39
40            let boxed = vec![$(&$arg as *const _ as *mut ::std::ffi::c_void),*].into_boxed_slice();
41
42            Ok($crate::graph::KernelInvocation::_new_internal(
43                $crate::function::BlockSize::from($block),
44                $crate::function::GridSize::from($grid),
45                $shared,
46                $function.to_raw(),
47                vec![].into_boxed_slice(),
48            ))
49        }
50    };
51}
52
53/// A prepared kernel invocation to be added to a graph.
54#[derive(Debug, Clone, PartialEq, Eq)]
55pub struct KernelInvocation {
56    pub block_dim: BlockSize,
57    pub grid_dim: GridSize,
58    pub shared_mem_bytes: u32,
59    func: cuda::CUfunction,
60    params: Box<*mut c_void>,
61    params_len: Option<usize>,
62}
63
64impl KernelInvocation {
65    #[doc(hidden)]
66    pub fn _new_internal(
67        block_dim: BlockSize,
68        grid_dim: GridSize,
69        shared_mem_bytes: u32,
70        func: cuda::CUfunction,
71        params: Box<*mut c_void>,
72        params_len: usize,
73    ) -> Self {
74        Self {
75            block_dim,
76            grid_dim,
77            shared_mem_bytes,
78            func,
79            params,
80            params_len: Some(params_len),
81        }
82    }
83
84    pub fn to_raw(self) -> cuda::CUDA_KERNEL_NODE_PARAMS {
85        cuda::CUDA_KERNEL_NODE_PARAMS {
86            func: self.func,
87            gridDimX: self.grid_dim.x,
88            gridDimY: self.grid_dim.y,
89            gridDimZ: self.grid_dim.z,
90            blockDimX: self.block_dim.x,
91            blockDimY: self.block_dim.y,
92            blockDimZ: self.block_dim.z,
93            kernelParams: Box::into_raw(self.params),
94            sharedMemBytes: self.shared_mem_bytes,
95            extra: ptr::null_mut(),
96        }
97    }
98
99    /// Makes a new invocation from its raw counterpart.
100    ///
101    /// # Safety
102    ///
103    /// The function pointer must be a valid CUfunction pointer and
104    /// params' "ownership" must be able to be transferred to the invocation
105    /// (it will be turned into a Box).
106    pub unsafe fn from_raw(raw: cuda::CUDA_KERNEL_NODE_PARAMS) -> Self {
107        Self {
108            func: raw.func,
109            grid_dim: GridSize::xyz(raw.gridDimX, raw.gridDimY, raw.gridDimZ),
110            block_dim: BlockSize::xyz(raw.blockDimX, raw.gridDimY, raw.gridDimZ),
111            params: Box::from_raw(raw.kernelParams),
112            shared_mem_bytes: raw.sharedMemBytes,
113            params_len: None,
114        }
115    }
116}
117
118/// An opaque handle to a node in a graph. There are no methods on [`GraphNode`], they
119/// are just handles for identifying nodes to be used on [`Graph`] functions.
120#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
121#[repr(transparent)]
122pub struct GraphNode {
123    raw: cuda::CUgraphNode,
124}
125
126unsafe impl Send for GraphNode {}
127unsafe impl Sync for GraphNode {}
128
129impl GraphNode {
130    /// Creates a new node from a raw handle. This is safe because node checks
131    /// happen on the graph when functions are called.
132    pub fn from_raw(raw: cuda::CUgraphNode) -> Self {
133        Self { raw }
134    }
135
136    /// Converts this node into a raw handle.
137    pub fn to_raw(self) -> cuda::CUgraphNode {
138        self.raw
139    }
140}
141
142/// The different types that a node can be.
143#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
144#[repr(u8)]
145pub enum GraphNodeType {
146    /// Invokes a GPU kernel.
147    KernelInvocation,
148    /// Copies memory from one location to another (CPU to GPU/GPU to CPU/GPU to GPU).
149    Memcpy,
150    /// Sets some memory to some value.
151    Memset,
152    /// Executes a function on the host (CPU).
153    HostExecute,
154    /// Executes a child graph.
155    ChildGraph,
156    /// Does nothing.
157    Empty,
158    /// Waits for an event.
159    WaitEvent,
160    /// Record an event.
161    EventRecord,
162    /// Performs a signal operation on external semaphore objects.
163    SemaphoreSignal,
164    /// Performs a wait operation on external semaphore objects.
165    SemaphoreWait,
166    /// Allocates some memory.
167    MemoryAllocation,
168    /// Frees some memory.
169    MemoryFree,
170}
171
172impl GraphNodeType {
173    /// Converts a raw type to a [`GraphNodeType`].
174    pub fn from_raw(raw: cuda::CUgraphNodeType) -> Self {
175        match raw {
176            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_KERNEL => GraphNodeType::KernelInvocation,
177            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMCPY => GraphNodeType::Memcpy,
178            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMSET => GraphNodeType::Memset,
179            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_HOST => GraphNodeType::HostExecute,
180            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_GRAPH => GraphNodeType::ChildGraph,
181            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EMPTY => GraphNodeType::Empty,
182            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_WAIT_EVENT => GraphNodeType::WaitEvent,
183            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EVENT_RECORD => GraphNodeType::EventRecord,
184            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL => {
185                GraphNodeType::SemaphoreSignal
186            }
187            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT => {
188                GraphNodeType::SemaphoreWait
189            }
190            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_ALLOC => GraphNodeType::MemoryAllocation,
191            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_FREE => GraphNodeType::MemoryFree,
192        }
193    }
194
195    /// Converts this type to its raw counterpart.
196    pub fn to_raw(self) -> cuda::CUgraphNodeType {
197        match self {
198            Self::KernelInvocation => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_KERNEL,
199            Self::Memcpy => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMCPY,
200            Self::Memset => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMSET,
201            Self::HostExecute => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_HOST,
202            Self::ChildGraph => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_GRAPH,
203            Self::Empty => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EMPTY,
204            Self::WaitEvent => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_WAIT_EVENT,
205            Self::EventRecord => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EVENT_RECORD,
206            Self::SemaphoreSignal => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL,
207            Self::SemaphoreWait => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT,
208            Self::MemoryAllocation => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_ALLOC,
209            Self::MemoryFree => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_FREE,
210        }
211    }
212}
213
214/// A graph object used for building a hierarchy of kernels to launch at once.
215/// Graphs are used to control jobs that have multiple kernels that need to be launched back to back.
216/// They reduce the overhead of launching kernels and cpu/gpu transfer by launching everything at once.
217///
218/// CUDA Graphs are inherently extremely unsafe, it is very easy to cause UB by passing a dropped node,
219/// an invalid node, a node from another graph, etc. To mostly solve this we query the nodes inside the
220/// graph every time a node is used to check if it is valid. This sounds expensive, but in practice graphs
221/// are not large enough where checking makes a big difference. Additionally, internally we cache the nodes
222/// that are known to be up-to-date
223///
224/// These safety measures should account for most safety pitfalls, if you encounter a way to bypass them
225/// please file an issue and we will try to fix it ASAP.
226///
227/// However, it is inherently impossible for us to validate graph usage, just like launching kernels.
228/// Therefore, launching graphs is unsafe and always will be, the user must validate that:
229/// - All kernel launches are safe (same invariants as launching a normal kernel)
230/// - Memory structures used inside the graph must not be dropped before the graph is executed (this will likely
231/// throw an error if you try doing it).
232///
233/// These problems can easily be avoided by launching the graph as soon as or right after it is instantiated,
234/// instead of holding onto it long-term, which can cause problems if data is dropped before the graph is executed.
235///
236/// Graphs are **not** threadsafe, therefore it is not possible to modify them from multiple threads at the
237/// same time. This is statically prevented by taking mutable references for all functions. You can however
238/// send graphs between threads.
239#[derive(Debug)]
240pub struct Graph {
241    raw: cuda::CUgraph,
242    // a cache of nodes, this cache is None when the node cache is out of date,
243    // it will get refreshed when get_nodes is called.
244    node_cache: Option<Vec<GraphNode>>,
245}
246
247// SAFETY: the cuda driver API docs say that any operations on the same graph object are not
248// thredsafe and must be serialized, but passing graphs to and from threads should be fine.
249// The fact that methods on Graph take `&mut self` statically prevents this from happening (thanks rustc <3)
250unsafe impl Send for Graph {}
251unsafe impl Sync for Graph {}
252
253bitflags::bitflags! {
254    /// Flags for creating a graph. This is currently empty but reserved for
255    /// any flags which may be added in the future.
256    #[derive(Default)]
257    pub struct GraphCreationFlags: u32 {
258        /// No flags, currently the only option available.
259        const NONE = 0b00000000;
260    }
261}
262
263impl Graph {
264    fn check_deps_are_valid(&mut self, func_name: &str, nodes: &[GraphNode]) -> CudaResult<()> {
265        // per the docs, nodes must be valid AND not duplicate.
266        for (idx, node) in nodes.iter().enumerate() {
267            if let Some(pos) = nodes
268                .iter()
269                .enumerate()
270                .position(|(cur_idx, x)| x == node && cur_idx != idx)
271            {
272                panic!("Duplicate dependency found in call to `{}`, the first instance is at index {}, the second instance is at index {}", func_name, idx, pos);
273            }
274
275            assert!(
276                self.is_valid_node(*node)?,
277                "Invalid (dropped or from another graph) node was given to `{}`",
278                func_name
279            );
280        }
281        Ok(())
282    }
283
284    /// Check if a node is valid in this graph.
285    pub fn is_valid_node(&mut self, node: GraphNode) -> CudaResult<bool> {
286        let nodes = self.nodes()?;
287        Ok(nodes.contains(&node))
288    }
289
290    /// Get the number of nodes in this graph.
291    pub fn num_nodes(&mut self) -> CudaResult<usize> {
292        unsafe {
293            let mut len = MaybeUninit::uninit();
294            cuda::cuGraphGetNodes(self.raw, ptr::null_mut(), len.as_mut_ptr()).to_result()?;
295            Ok(len.assume_init())
296        }
297    }
298
299    /// Get all of the nodes in this graph.
300    pub fn nodes(&mut self) -> CudaResult<&[GraphNode]> {
301        if self.node_cache.is_none() {
302            unsafe {
303                let mut len = self.num_nodes()?;
304                let mut vec = Vec::with_capacity(len);
305                cuda::cuGraphGetNodes(
306                    self.raw,
307                    vec.as_mut_ptr() as *mut cuda::CUgraphNode,
308                    &mut len as *mut usize,
309                )
310                .to_result()?;
311                vec.set_len(len);
312                self.node_cache = Some(vec);
313            }
314        }
315        Ok(self.node_cache.as_ref().unwrap())
316    }
317
318    /// Creates a new graph from some flags.
319    pub fn new(flags: GraphCreationFlags) -> CudaResult<Self> {
320        let mut raw = MaybeUninit::uninit();
321
322        unsafe {
323            cuda::cuGraphCreate(raw.as_mut_ptr(), flags.bits).to_result()?;
324
325            Ok(Self {
326                raw: raw.assume_init(),
327                node_cache: Some(vec![]),
328            })
329        }
330    }
331
332    /// Dumps a dotfile to a path which contains a visual representation of the graph for debugging.
333    /// This dotfile can be turned into an image with graphviz.
334    #[cfg(any(windows, unix))]
335    pub fn dump_debug_dotfile<P: AsRef<Path>>(&mut self, path: P) -> CudaResult<()> {
336        // not currently present in cuda-driver-sys for some reason
337        extern "C" {
338            fn cuGraphDebugDotPrint(
339                hGraph: cuda::CUgraph,
340                path: *const c_char,
341                flags: c_uint,
342            ) -> cuda::CUresult;
343        }
344
345        let path = path.as_ref();
346        let mut buf = Vec::new();
347        #[cfg(unix)]
348        {
349            use std::os::unix::ffi::OsStrExt;
350            buf.extend(path.as_os_str().as_bytes());
351            buf.push(0);
352        }
353
354        #[cfg(windows)]
355        {
356            use std::os::windows::ffi::OsStrExt;
357            buf.extend(
358                path.as_os_str()
359                    .encode_wide()
360                    .chain(Some(0))
361                    .map(|b| {
362                        let b = b.to_ne_bytes();
363                        b.get(0).copied().into_iter().chain(b.get(1).copied())
364                    })
365                    .flatten(),
366            );
367        }
368
369        unsafe { cuGraphDebugDotPrint(self.raw, "./out.dot\0".as_ptr().cast(), 1 << 0).to_result() }
370    }
371
372    /// Adds a kernel invocation node to this graph, [`KernelInvocation`] can be created using
373    /// [`kernel_invocation`] which uses the same syntax as [`launch`](crate::launch). This will
374    /// place the node after its dependencies (which will execute before it).
375    pub fn add_kernel_node(
376        &mut self,
377        invocation: KernelInvocation,
378        dependencies: impl AsRef<[GraphNode]>,
379    ) -> CudaResult<GraphNode> {
380        let deps = dependencies.as_ref();
381        self.check_deps_are_valid("add_kernel_node", deps)?;
382        // invalidate cache because it will change.
383        self.node_cache = None;
384        unsafe {
385            let deps_ptr = deps.as_ptr().cast();
386            let mut node = MaybeUninit::<GraphNode>::uninit();
387            let params = invocation.to_raw();
388            cuda::cuGraphAddKernelNode(
389                node.as_mut_ptr().cast(),
390                self.raw,
391                deps_ptr,
392                deps.len(),
393                &params as *const _,
394            )
395            .to_result()?;
396            Ok(node.assume_init())
397        }
398    }
399
400    /// The number of edges (dependency edges) inside this graph.
401    pub fn num_edges(&mut self) -> CudaResult<usize> {
402        unsafe {
403            let mut size = MaybeUninit::uninit();
404            cuda::cuGraphGetEdges(
405                self.raw,
406                ptr::null_mut(),
407                ptr::null_mut(),
408                size.as_mut_ptr(),
409            )
410            .to_result()?;
411            Ok(size.assume_init())
412        }
413    }
414
415    /// Returns a list of the dependency edges of this graph.
416    ///
417    /// # Returns
418    ///
419    /// Returns a vector of the edge from one node to another. There may be multiples
420    /// of the same node in the vector, since a node can have multiple edges coming out of it.
421    /// `(A, B)` means that `B` has a dependency on `A`, that is, `A` will execute before `B`.
422    pub fn edges(&mut self) -> CudaResult<Vec<(GraphNode, GraphNode)>> {
423        unsafe {
424            let num_edges = self.num_edges()?;
425            let mut from = vec![ptr::null_mut(); num_edges].into_boxed_slice();
426            let mut to = vec![ptr::null_mut(); num_edges].into_boxed_slice();
427
428            cuda::cuGraphGetEdges(
429                self.raw,
430                from.as_mut_ptr(),
431                to.as_mut_ptr(),
432                &num_edges as *const _ as *mut usize,
433            )
434            .to_result()?;
435
436            let mut out = Vec::with_capacity(num_edges);
437            for (from, to) in from.iter().zip(to.iter()) {
438                out.push((GraphNode::from_raw(*from), GraphNode::from_raw(*to)))
439            }
440            Ok(out)
441        }
442    }
443
444    /// Retrieves the type of a node.
445    pub fn node_type(&mut self, node: GraphNode) -> CudaResult<GraphNodeType> {
446        self.check_deps_are_valid("node_type", &[node])?;
447        unsafe {
448            let mut ty = MaybeUninit::uninit();
449            cuda::cuGraphNodeGetType(node.to_raw(), ty.as_mut_ptr()).to_result()?;
450            let raw = ty.assume_init();
451            Ok(GraphNodeType::from_raw(raw))
452        }
453    }
454
455    /// Retrieves the invocation parameters for a kernel invocation node.
456    ///
457    /// # Panics
458    ///
459    /// Panics if the node is invalid or if the node is not a kernel invocation node.
460    pub fn kernel_node_params(&mut self, node: GraphNode) -> CudaResult<KernelInvocation> {
461        self.check_deps_are_valid("kernel_node_params", &[node])?;
462        assert_eq!(
463            self.node_type(node)?,
464            GraphNodeType::KernelInvocation,
465            "Node given to `kernel_node_params` was not a kernel invocation node"
466        );
467        unsafe {
468            let mut params = MaybeUninit::uninit();
469            cuda::cuGraphKernelNodeGetParams(node.to_raw(), params.as_mut_ptr());
470            Ok(KernelInvocation::from_raw(params.assume_init()))
471        }
472    }
473
474    /// Creates a new [`Graph`] from a raw handle.
475    ///
476    /// # Safety
477    ///
478    /// This assumes a couple of things:
479    /// - This handle is exclusive, nothing else can use it in any way, including trying to drop it.
480    /// - It must be a valid handle. This invariant must be upheld, the library is allowed to rely on
481    /// the fact that the handle is valid in terms of safety, therefore failure to uphold this invariant is UB.
482    pub unsafe fn from_raw(raw: cuda::CUgraph) -> Self {
483        Self {
484            raw,
485            node_cache: None,
486        }
487    }
488
489    /// Consumes this [`Graph`], turning it into a raw handle. The handle will not be dropped,
490    /// it is up to the caller to ensure the graph is destroyed.
491    pub fn into_raw(self) -> cuda::CUgraph {
492        let me = ManuallyDrop::new(self);
493        me.raw
494    }
495}
496
497impl Drop for Graph {
498    fn drop(&mut self) {
499        unsafe {
500            cuda::cuGraphDestroy(self.raw);
501        }
502    }
503}