1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
//! CUDA Graph management.

use std::{
    ffi::c_void,
    mem::{ManuallyDrop, MaybeUninit},
    os::raw::{c_char, c_uint},
    path::Path,
    ptr,
};

use crate::{
    error::{CudaResult, ToResult},
    function::{BlockSize, GridSize},
    sys as cuda,
};

/// Creates a kernel invocation using the same syntax as [`launch`] to be used to insert kernel launches inside graphs.
/// This returns a Result of a kernel invocation object you can then pass to a graph.
///
/// Two forms are accepted:
/// - `module.function<<<grid, block, shared, stream>>>(args...)` looks the function up by name in
///   `module` and yields the lookup error if that fails.
/// - `function<<<grid, block, shared, stream>>>(args...)` uses an already resolved function.
///
/// The `stream` token is accepted only for syntactic parity with [`launch`]; the expansion does not use it.
///
/// Every argument must implement `DeviceCopy`. The invocation stores the *addresses* of the argument
/// expressions, so the arguments must outlive the returned invocation and any use the graph makes of it.
#[macro_export]
macro_rules! kernel_invocation {
    ($module:ident . $function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* )) => {
        {
            let name = ::std::ffi::CString::new(stringify!($function)).unwrap();
            let function = $module.get_function(&name);
            match function {
                // `$crate::` keeps the recursive call resolvable even when the caller invoked the
                // macro by path without importing it.
                Ok(f) => $crate::kernel_invocation!(f<<<$grid, $block, $shared, $stream>>>( $($arg),* ) ),
                Err(e) => Err(e),
            }
        }
    };
    ($function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* )) => {
        {
            // Compile-time-only check that every argument is `DeviceCopy`; the branch never runs.
            fn assert_impl_devicecopy<T: $crate::memory::DeviceCopy>(_val: T) {}
            if false {
                $(
                    assert_impl_devicecopy($arg);
                )*
            };

            // Type-erased pointers to each argument, in call order.
            let mut params: ::std::vec::Vec<*mut ::std::ffi::c_void> =
                vec![$(&$arg as *const _ as *mut ::std::ffi::c_void),*];
            let params_len = params.len();
            if params.is_empty() {
                // Keep the allocation non-empty so the thin pointer created below never dangles.
                params.push(::std::ptr::null_mut());
            }

            // `_new_internal` takes a thin `Box` pointing at the first entry plus the entry count.
            // SAFETY: the pointer comes straight from `Box::into_raw` on a non-empty boxed slice, so
            // it is non-null, aligned and points at an initialized `*mut c_void`.
            // NOTE(review): the resulting `Box<*mut c_void>` covers only the first entry; if the
            // invocation is dropped instead of being handed to a graph, the allocation is released
            // with a one-element layout. Confirm this matches the intended ownership model.
            #[allow(unused_unsafe)] // the macro may be expanded inside a caller's `unsafe` block
            let params = unsafe {
                ::std::boxed::Box::from_raw(
                    ::std::boxed::Box::into_raw(params.into_boxed_slice())
                        .cast::<*mut ::std::ffi::c_void>(),
                )
            };

            Ok($crate::graph::KernelInvocation::_new_internal(
                $crate::function::BlockSize::from($block),
                $crate::function::GridSize::from($grid),
                $shared,
                $function.to_raw(),
                params,
                params_len,
            ))
        }
    };
}

/// A prepared kernel invocation to be added to a graph.
///
/// NOTE(review): `params` is a thin `Box` pointing at the first kernel parameter pointer, so the
/// derived `Clone`/`PartialEq` only see that first entry — confirm this is acceptable for callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KernelInvocation {
    /// Dimensions of each thread block (forwarded as `blockDimX/Y/Z`).
    pub block_dim: BlockSize,
    /// Dimensions of the grid (forwarded as `gridDimX/Y/Z`).
    pub grid_dim: GridSize,
    /// Shared memory size in bytes (forwarded as `sharedMemBytes`).
    pub shared_mem_bytes: u32,
    // Raw handle of the kernel function to launch.
    func: cuda::CUfunction,
    // Pointer to the first entry of the kernel parameter pointer array.
    params: Box<*mut c_void>,
    // Number of entries behind `params`; `None` when built from a raw struct that carries no length.
    params_len: Option<usize>,
}

impl KernelInvocation {
    /// Internal constructor used by the `kernel_invocation!` macro; not part of the public API.
    #[doc(hidden)]
    pub fn _new_internal(
        block_dim: BlockSize,
        grid_dim: GridSize,
        shared_mem_bytes: u32,
        func: cuda::CUfunction,
        params: Box<*mut c_void>,
        params_len: usize,
    ) -> Self {
        Self {
            block_dim,
            grid_dim,
            shared_mem_bytes,
            func,
            params,
            params_len: Some(params_len),
        }
    }

    /// Converts this invocation into the raw driver parameter struct.
    ///
    /// Ownership of the parameter allocation is released through `Box::into_raw`; it is not freed
    /// by this function, and `extra` is always null.
    pub fn to_raw(self) -> cuda::CUDA_KERNEL_NODE_PARAMS {
        cuda::CUDA_KERNEL_NODE_PARAMS {
            func: self.func,
            gridDimX: self.grid_dim.x,
            gridDimY: self.grid_dim.y,
            gridDimZ: self.grid_dim.z,
            blockDimX: self.block_dim.x,
            blockDimY: self.block_dim.y,
            blockDimZ: self.block_dim.z,
            kernelParams: Box::into_raw(self.params),
            sharedMemBytes: self.shared_mem_bytes,
            extra: ptr::null_mut(),
        }
    }

    /// Makes a new invocation from its raw counterpart.
    ///
    /// The parameter count is unknown for raw structs, so `params_len` is left as `None`.
    ///
    /// # Safety
    ///
    /// The function pointer must be a valid CUfunction pointer and
    /// params' "ownership" must be able to be transferred to the invocation
    /// (it will be turned into a Box).
    pub unsafe fn from_raw(raw: cuda::CUDA_KERNEL_NODE_PARAMS) -> Self {
        Self {
            func: raw.func,
            grid_dim: GridSize::xyz(raw.gridDimX, raw.gridDimY, raw.gridDimZ),
            // Block dimensions come from the block fields, not the grid fields.
            block_dim: BlockSize::xyz(raw.blockDimX, raw.blockDimY, raw.blockDimZ),
            params: Box::from_raw(raw.kernelParams),
            shared_mem_bytes: raw.sharedMemBytes,
            params_len: None,
        }
    }
}

/// Opaque identifier for a single node of a graph.
///
/// A [`GraphNode`] carries no behaviour of its own beyond raw-handle conversion; it only names
/// a node so it can be passed to [`Graph`] functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct GraphNode {
    raw: cuda::CUgraphNode,
}

unsafe impl Send for GraphNode {}
unsafe impl Sync for GraphNode {}

impl GraphNode {
    /// Wraps a raw node handle. No validation happens here, which is why this is safe:
    /// the graph checks nodes itself whenever they are passed to its functions.
    pub fn from_raw(raw: cuda::CUgraphNode) -> Self {
        GraphNode { raw }
    }

    /// Unwraps this node back into its raw handle.
    pub fn to_raw(self) -> cuda::CUgraphNode {
        let GraphNode { raw } = self;
        raw
    }
}

/// The different types that a node can be.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum GraphNodeType {
    /// Invokes a GPU kernel.
    KernelInvocation,
    /// Copies memory from one location to another (CPU to GPU/GPU to CPU/GPU to GPU).
    Memcpy,
    /// Sets some memory to some value.
    Memset,
    /// Executes a function on the host (CPU).
    HostExecute,
    /// Executes a child graph.
    ChildGraph,
    /// Does nothing.
    Empty,
    /// Waits for an event.
    WaitEvent,
    /// Record an event.
    EventRecord,
    /// Performs a signal operation on external semaphore objects.
    SemaphoreSignal,
    /// Performs a wait operation on external semaphore objects.
    SemaphoreWait,
    /// Allocates some memory.
    MemoryAllocation,
    /// Frees some memory.
    MemoryFree,
}

impl GraphNodeType {
    /// Converts a raw type to a [`GraphNodeType`].
    pub fn from_raw(raw: cuda::CUgraphNodeType) -> Self {
        match raw {
            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_KERNEL => GraphNodeType::KernelInvocation,
            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMCPY => GraphNodeType::Memcpy,
            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMSET => GraphNodeType::Memset,
            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_HOST => GraphNodeType::HostExecute,
            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_GRAPH => GraphNodeType::ChildGraph,
            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EMPTY => GraphNodeType::Empty,
            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_WAIT_EVENT => GraphNodeType::WaitEvent,
            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EVENT_RECORD => GraphNodeType::EventRecord,
            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL => {
                GraphNodeType::SemaphoreSignal
            }
            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT => {
                GraphNodeType::SemaphoreWait
            }
            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_ALLOC => GraphNodeType::MemoryAllocation,
            cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_FREE => GraphNodeType::MemoryFree,
        }
    }

    /// Converts this type to its raw counterpart.
    pub fn to_raw(self) -> cuda::CUgraphNodeType {
        match self {
            Self::KernelInvocation => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_KERNEL,
            Self::Memcpy => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMCPY,
            Self::Memset => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMSET,
            Self::HostExecute => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_HOST,
            Self::ChildGraph => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_GRAPH,
            Self::Empty => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EMPTY,
            Self::WaitEvent => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_WAIT_EVENT,
            Self::EventRecord => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EVENT_RECORD,
            Self::SemaphoreSignal => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL,
            Self::SemaphoreWait => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT,
            Self::MemoryAllocation => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_ALLOC,
            Self::MemoryFree => cuda::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_FREE,
        }
    }
}

/// A graph object used for building a hierarchy of kernels to launch at once.
/// Graphs are used to control jobs that have multiple kernels that need to be launched back to back.
/// They reduce the overhead of launching kernels and CPU/GPU transfer by launching everything at once.
///
/// CUDA Graphs are inherently extremely unsafe: it is very easy to cause UB by passing a dropped node,
/// an invalid node, a node from another graph, etc. To mostly solve this, we query the nodes inside the
/// graph every time a node is used to check if it is valid. This sounds expensive, but in practice graphs
/// are not large enough for the check to make a big difference. Additionally, we internally cache the
/// nodes while that cache is known to be up to date.
///
/// These safety measures should account for most safety pitfalls; if you encounter a way to bypass them,
/// please file an issue and we will try to fix it ASAP.
///
/// However, it is inherently impossible for us to validate graph usage, just like launching kernels.
/// Therefore, launching graphs is unsafe and always will be; the user must validate that:
/// - All kernel launches are safe (same invariants as launching a normal kernel).
/// - Memory structures used inside the graph are not dropped before the graph is executed (this will
///   likely throw an error if you try doing it).
///
/// These problems can easily be avoided by launching the graph as soon as or right after it is instantiated,
/// instead of holding onto it long-term, which can cause problems if data is dropped before the graph is executed.
///
/// Graphs are **not** thread-safe, therefore it is not possible to modify them from multiple threads at the
/// same time. This is statically prevented by taking mutable references for all functions. You can however
/// send graphs between threads.
#[derive(Debug)]
pub struct Graph {
    raw: cuda::CUgraph,
    // A cache of nodes. It is `None` when the cache is out of date,
    // and it gets refreshed the next time `nodes` is called.
    node_cache: Option<Vec<GraphNode>>,
}

// SAFETY: the CUDA driver API docs say that any operations on the same graph object are not
// thread-safe and must be serialized, but passing graphs to and from threads should be fine.
// The fact that methods on Graph take `&mut self` statically prevents concurrent use (thanks rustc <3).
// NOTE(review): `Sync` relies on every graph operation requiring `&mut self`; a future `&self`
// method that touches the driver would make this impl unsound.
unsafe impl Send for Graph {}
unsafe impl Sync for Graph {}

bitflags::bitflags! {
    /// Options accepted when creating a graph. No flag is defined yet; the type
    /// exists so flags can be added later.
    #[derive(Default)]
    pub struct GraphCreationFlags: u32 {
        /// The empty flag set, which is the only value currently available.
        const NONE = 0;
    }
}

impl Graph {
    /// Validates a list of nodes handed to the graph function named `func_name`.
    ///
    /// Returns an error only if querying the graph's nodes fails.
    ///
    /// # Panics
    ///
    /// Panics if the same node appears twice in `nodes`, or if a node is not part of this graph.
    fn check_deps_are_valid(&mut self, func_name: &str, nodes: &[GraphNode]) -> CudaResult<()> {
        // Nodes are required to be both valid and free of duplicates.
        for (idx, candidate) in nodes.iter().enumerate() {
            // First other position holding the very same node, if any.
            let duplicate =
                (0..nodes.len()).find(|&other| other != idx && nodes[other] == *candidate);
            if let Some(pos) = duplicate {
                panic!("Duplicate dependency found in call to `{}`, the first instance is at index {}, the second instance is at index {}", func_name, idx, pos);
            }

            let belongs_to_graph = self.is_valid_node(*candidate)?;
            assert!(
                belongs_to_graph,
                "Invalid (dropped or from another graph) node was given to `{}`",
                func_name
            );
        }
        Ok(())
    }

    /// Reports whether `node` is one of the nodes currently in this graph.
    pub fn is_valid_node(&mut self, node: GraphNode) -> CudaResult<bool> {
        self.nodes().map(|known| known.contains(&node))
    }

    /// Returns how many nodes this graph contains.
    pub fn num_nodes(&mut self) -> CudaResult<usize> {
        let mut count = MaybeUninit::<usize>::uninit();
        unsafe {
            // No node buffer is supplied; the call is used only to obtain the count.
            let status = cuda::cuGraphGetNodes(self.raw, ptr::null_mut(), count.as_mut_ptr());
            status.to_result()?;
            Ok(count.assume_init())
        }
    }

    /// Returns every node in this graph, refreshing the internal cache first if it is stale.
    pub fn nodes(&mut self) -> CudaResult<&[GraphNode]> {
        if self.node_cache.is_none() {
            let mut count = self.num_nodes()?;
            let mut fetched: Vec<GraphNode> = Vec::with_capacity(count);
            unsafe {
                // `GraphNode` is `repr(transparent)` over the raw handle, so the buffer can be
                // filled in place by the driver.
                let buffer = fetched.as_mut_ptr() as *mut cuda::CUgraphNode;
                cuda::cuGraphGetNodes(self.raw, buffer, &mut count as *mut usize).to_result()?;
                // `count` now holds what the driver wrote back.
                fetched.set_len(count);
            }
            self.node_cache = Some(fetched);
        }
        Ok(self.node_cache.as_ref().unwrap())
    }

    /// Creates an empty graph configured with `flags`.
    pub fn new(flags: GraphCreationFlags) -> CudaResult<Self> {
        let mut handle = MaybeUninit::uninit();

        let raw = unsafe {
            cuda::cuGraphCreate(handle.as_mut_ptr(), flags.bits()).to_result()?;
            handle.assume_init()
        };

        // A brand-new graph has no nodes, so the cache starts out valid and empty.
        Ok(Self {
            raw,
            node_cache: Some(Vec::new()),
        })
    }

    /// Dumps a dotfile to `path` which contains a visual representation of the graph for debugging.
    /// This dotfile can be turned into an image with graphviz.
    ///
    /// The path is handed to the driver as a NUL-terminated byte string. On Unix the raw bytes of
    /// the path are used unchanged; on Windows the path is converted to UTF-8, replacing anything
    /// that is not valid Unicode. A path containing an interior NUL byte is cut off at that byte.
    ///
    /// # Errors
    ///
    /// Returns the driver error if `cuGraphDebugDotPrint` fails.
    #[cfg(any(windows, unix))]
    pub fn dump_debug_dotfile<P: AsRef<Path>>(&mut self, path: P) -> CudaResult<()> {
        // not currently present in cuda-driver-sys for some reason
        extern "C" {
            fn cuGraphDebugDotPrint(
                hGraph: cuda::CUgraph,
                path: *const c_char,
                flags: c_uint,
            ) -> cuda::CUresult;
        }

        let path = path.as_ref();
        let mut buf: Vec<u8> = Vec::new();
        #[cfg(unix)]
        {
            use std::os::unix::ffi::OsStrExt;
            buf.extend_from_slice(path.as_os_str().as_bytes());
        }

        #[cfg(windows)]
        {
            // The driver takes a narrow `const char*`, so UTF-16 code units cannot be passed through.
            buf.extend_from_slice(path.to_string_lossy().as_bytes());
        }
        buf.push(0);

        // `buf` stays alive until the call returns, so the pointer remains valid for the driver.
        // Flag bit 0 is forwarded exactly as before.
        unsafe { cuGraphDebugDotPrint(self.raw, buf.as_ptr().cast(), 1 << 0).to_result() }
    }

    /// Inserts a kernel invocation node into this graph and returns its handle.
    ///
    /// A [`KernelInvocation`] can be built with [`kernel_invocation`], which shares its syntax with
    /// [`launch`](crate::launch). The new node is placed after `dependencies`, all of which
    /// execute before it.
    ///
    /// # Panics
    ///
    /// Panics if `dependencies` contains a duplicate node or a node that is not part of this graph.
    pub fn add_kernel_node(
        &mut self,
        invocation: KernelInvocation,
        dependencies: impl AsRef<[GraphNode]>,
    ) -> CudaResult<GraphNode> {
        let deps = dependencies.as_ref();
        self.check_deps_are_valid("add_kernel_node", deps)?;

        // The node list is about to change, so the cached copy can no longer be trusted.
        self.node_cache = None;

        let mut created = MaybeUninit::<GraphNode>::uninit();
        let raw_params = invocation.to_raw();
        unsafe {
            // `GraphNode` is `repr(transparent)` over the raw handle, hence the pointer casts.
            cuda::cuGraphAddKernelNode(
                created.as_mut_ptr().cast(),
                self.raw,
                deps.as_ptr().cast(),
                deps.len(),
                &raw_params,
            )
            .to_result()?;
            Ok(created.assume_init())
        }
    }

    /// Returns how many dependency edges this graph contains.
    pub fn num_edges(&mut self) -> CudaResult<usize> {
        let mut count = MaybeUninit::<usize>::uninit();
        unsafe {
            // Neither edge buffer is supplied; the call is used only to obtain the count.
            let status = cuda::cuGraphGetEdges(
                self.raw,
                ptr::null_mut(),
                ptr::null_mut(),
                count.as_mut_ptr(),
            );
            status.to_result()?;
            Ok(count.assume_init())
        }
    }

    /// Returns a list of the dependency edges of this graph.
    ///
    /// # Returns
    ///
    /// Returns a vector of the edge from one node to another. There may be multiples
    /// of the same node in the vector, since a node can have multiple edges coming out of it.
    /// `(A, B)` means that `B` has a dependency on `A`, that is, `A` will execute before `B`.
    ///
    /// # Errors
    ///
    /// Returns the driver error if either `cuGraphGetEdges` query fails.
    pub fn edges(&mut self) -> CudaResult<Vec<(GraphNode, GraphNode)>> {
        // Mutable on purpose: the driver writes the number of edges it returned back into it.
        let mut num_edges = self.num_edges()?;
        let mut from = vec![ptr::null_mut(); num_edges];
        let mut to = vec![ptr::null_mut(); num_edges];

        unsafe {
            cuda::cuGraphGetEdges(
                self.raw,
                from.as_mut_ptr(),
                to.as_mut_ptr(),
                &mut num_edges,
            )
            .to_result()?;
        }

        // Only keep as many pairs as the driver reported; unused slots stay null.
        Ok(from
            .into_iter()
            .zip(to)
            .take(num_edges)
            .map(|(src, dst)| (GraphNode::from_raw(src), GraphNode::from_raw(dst)))
            .collect())
    }

    /// Looks up what kind of node `node` is.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not part of this graph.
    pub fn node_type(&mut self, node: GraphNode) -> CudaResult<GraphNodeType> {
        self.check_deps_are_valid("node_type", &[node])?;

        let mut slot = MaybeUninit::uninit();
        let raw_type = unsafe {
            cuda::cuGraphNodeGetType(node.to_raw(), slot.as_mut_ptr()).to_result()?;
            slot.assume_init()
        };
        Ok(GraphNodeType::from_raw(raw_type))
    }

    /// Retrieves the invocation parameters for a kernel invocation node.
    ///
    /// # Errors
    ///
    /// Returns the driver error if querying the graph or fetching the parameters fails.
    ///
    /// # Panics
    ///
    /// Panics if the node is invalid or if the node is not a kernel invocation node.
    pub fn kernel_node_params(&mut self, node: GraphNode) -> CudaResult<KernelInvocation> {
        // Validated here as well so a bad node is reported under this function's name.
        self.check_deps_are_valid("kernel_node_params", &[node])?;
        assert_eq!(
            self.node_type(node)?,
            GraphNodeType::KernelInvocation,
            "Node given to `kernel_node_params` was not a kernel invocation node"
        );
        unsafe {
            let mut params = MaybeUninit::uninit();
            // The status must be checked: on failure `params` is never written, and reading it
            // would be undefined behaviour.
            cuda::cuGraphKernelNodeGetParams(node.to_raw(), params.as_mut_ptr()).to_result()?;
            // NOTE(review): `from_raw` turns `kernelParams` into a `Box`; confirm the driver really
            // hands ownership of that array to the caller rather than keeping it owned by the node.
            Ok(KernelInvocation::from_raw(params.assume_init()))
        }
    }

    /// Creates a new [`Graph`] from a raw handle.
    ///
    /// # Safety
    ///
    /// This assumes a couple of things:
    /// - This handle is exclusive, nothing else can use it in any way, including trying to drop it.
    /// - It must be a valid handle. This invariant must be upheld, the library is allowed to rely on
    /// the fact that the handle is valid in terms of safety, therefore failure to uphold this invariant is UB.
    pub unsafe fn from_raw(raw: cuda::CUgraph) -> Self {
        Self {
            raw,
            node_cache: None,
        }
    }

    /// Gives up this [`Graph`] and returns its raw handle.
    ///
    /// The destructor is skipped, so the caller becomes responsible for destroying the graph.
    pub fn into_raw(self) -> cuda::CUgraph {
        // `ManuallyDrop` suppresses `Drop`, which would otherwise destroy the handle.
        ManuallyDrop::new(self).raw
    }
}

impl Drop for Graph {
    /// Destroys the underlying graph handle. The driver's status is discarded because
    /// `drop` has no way to report it.
    fn drop(&mut self) {
        let _ = unsafe { cuda::cuGraphDestroy(self.raw) };
    }
}