Skip to main content

oxicuda_driver/
ffi_graph.rs

1//! CUDA Graph and stream-ordered memory-pool FFI types.
2//!
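//! Opaque graph handles (`CUgraph`, `CUgraphExec`, `CUgraphNode`), the node
//! parameter descriptors consumed by `cuGraphAdd*Node` (`CUDA_KERNEL_NODE_PARAMS`,
//! `CUDA_MEMCPY3D`, `CUDA_MEMSET_NODE_PARAMS`, `CUDA_HOST_NODE_PARAMS`), the
//! `CUmemPoolAttribute` discriminant used by `cuMemPoolSetAttribute` /
//! `cuMemPoolGetAttribute`, plus graph-instantiation flags, the
//! `CUgraphNodeType` enum and the `CU_LAUNCH_PARAM_*` sentinels.
//!
//! The handles are `#[repr(transparent)]` wrappers around a raw pointer; the
//! descriptor structs are `#[repr(C)]` and mirror the layout of the
//! corresponding types in `cuda.h`. Reserved fields (such as `reserved0` /
//! `reserved1` in `CUDA_MEMCPY3D`) are part of the published ABI and must be
//! left null.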
10
11use std::ffi::{c_char, c_void};
12
13use super::{CUarray, CUdeviceptr, CUfunction};
14
15// =========================================================================
16// CUgraph / CUgraphExec / CUgraphNode — opaque graph handles
17// =========================================================================
18
/// Opaque handle to a CUDA graph (`CUgraph`).
///
/// Lifecycle: `cuGraphCreate` produces the mutable DAG, the `cuGraphAdd*Node`
/// family appends operations to it, and `cuGraphInstantiate` compiles it into
/// an executable graph.
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct CUgraph(pub *mut c_void);

impl Default for CUgraph {
    /// Yields the null handle, i.e. a graph that has not been created yet.
    #[inline]
    fn default() -> Self {
        CUgraph(std::ptr::null_mut::<c_void>())
    }
}

impl CUgraph {
    /// Reports whether the handle is null (uninitialised).
    #[inline]
    pub fn is_null(self) -> bool {
        let CUgraph(raw) = self;
        raw.is_null()
    }
}

impl std::fmt::Debug for CUgraph {
    /// Prints the handle as `CUgraph(<address>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let CUgraph(raw) = *self;
        write!(f, "CUgraph({:p})", raw)
    }
}

// SAFETY: the wrapped pointer is an opaque driver identifier that Rust never
// dereferences, so moving or sharing the value between threads is harmless in
// itself. Every real operation on it is an `unsafe` driver call. The CUDA
// driver documentation states that graph objects are not internally
// synchronised, so callers must serialise concurrent API calls that touch the
// same graph.
unsafe impl Send for CUgraph {}
unsafe impl Sync for CUgraph {}
54
/// Opaque handle to an instantiated, executable CUDA graph (`CUgraphExec`).
///
/// Created by `cuGraphInstantiate`, enqueued on a stream with
/// `cuGraphLaunch`, and released with `cuGraphExecDestroy`.
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct CUgraphExec(pub *mut c_void);

impl Default for CUgraphExec {
    /// The null handle: no executable graph has been instantiated.
    #[inline]
    fn default() -> Self {
        let null: *mut c_void = std::ptr::null_mut();
        CUgraphExec(null)
    }
}

impl CUgraphExec {
    /// Returns `true` when the handle is null (uninitialised).
    #[inline]
    pub fn is_null(self) -> bool {
        <*mut c_void>::is_null(self.0)
    }
}

impl std::fmt::Debug for CUgraphExec {
    /// Prints the handle as `CUgraphExec(<address>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("CUgraphExec({:p})", self.0))
    }
}

// SAFETY: the handle is an opaque driver identifier that Rust never
// dereferences; copying it across threads performs no access. All uses go
// through `unsafe` driver calls, whose synchronisation requirements remain the
// caller's responsibility.
unsafe impl Send for CUgraphExec {}
unsafe impl Sync for CUgraphExec {}
87
/// Opaque handle to one node of a [`CUgraph`] (`CUgraphNode`).
///
/// Every `cuGraphAdd*Node` call hands one back; it then serves as a
/// dependency endpoint when wiring edges between nodes.
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct CUgraphNode(pub *mut c_void);

// SAFETY: the handle is an opaque driver identifier that Rust never
// dereferences. Node APIs operate on the owning graph through `unsafe` driver
// calls, so the graph's external-serialisation requirement applies to them.
unsafe impl Send for CUgraphNode {}
unsafe impl Sync for CUgraphNode {}

impl std::fmt::Debug for CUgraphNode {
    /// Prints the handle as `CUgraphNode(<address>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw = self.0;
        write!(f, "CUgraphNode({:p})", raw)
    }
}

impl Default for CUgraphNode {
    /// The null node handle.
    #[inline]
    fn default() -> Self {
        Self(std::ptr::null_mut())
    }
}

impl CUgraphNode {
    /// Returns `true` if this is the null handle (no node assigned yet).
    #[inline]
    pub fn is_null(self) -> bool {
        let Self(handle) = self;
        handle.is_null()
    }
}
120
121// =========================================================================
122// CUDA_KERNEL_NODE_PARAMS — kernel-launch node descriptor
123// =========================================================================
124
125/// Parameters for a kernel-launch graph node, consumed by
126/// `cuGraphAddKernelNode`.
127///
128/// Mirrors `CUDA_KERNEL_NODE_PARAMS` (the pre-12.0 layout, which remains
129/// ABI-stable and accepted by the driver). `kernel_params` points to an
130/// array of pointers to the individual kernel arguments; `extra` is the
131/// alternative `CU_LAUNCH_PARAM_*` packing mechanism and is normally null.
132#[repr(C)]
133#[derive(Debug, Clone, Copy)]
134pub struct CUDA_KERNEL_NODE_PARAMS {
135    /// Kernel function to launch.
136    pub func: CUfunction,
137    /// Grid dimension X (number of blocks).
138    pub grid_dim_x: u32,
139    /// Grid dimension Y (number of blocks).
140    pub grid_dim_y: u32,
141    /// Grid dimension Z (number of blocks).
142    pub grid_dim_z: u32,
143    /// Block dimension X (threads per block).
144    pub block_dim_x: u32,
145    /// Block dimension Y (threads per block).
146    pub block_dim_y: u32,
147    /// Block dimension Z (threads per block).
148    pub block_dim_z: u32,
149    /// Dynamic shared-memory size in bytes.
150    pub shared_mem_bytes: u32,
151    /// Array of pointers to kernel arguments; null when `extra` is used.
152    pub kernel_params: *mut *mut c_void,
153    /// Alternative argument-packing buffer (`CU_LAUNCH_PARAM_*`); usually null.
154    pub extra: *mut *mut c_void,
155}
156
157// SAFETY: the struct carries raw pointers to caller-owned argument buffers;
158// the driver treats them as opaque. Mirroring the C struct, it is logically
159// Send+Sync.
160unsafe impl Send for CUDA_KERNEL_NODE_PARAMS {}
161unsafe impl Sync for CUDA_KERNEL_NODE_PARAMS {}
162
163impl Default for CUDA_KERNEL_NODE_PARAMS {
164    fn default() -> Self {
165        Self {
166            func: CUfunction::default(),
167            grid_dim_x: 0,
168            grid_dim_y: 0,
169            grid_dim_z: 0,
170            block_dim_x: 0,
171            block_dim_y: 0,
172            block_dim_z: 0,
173            shared_mem_bytes: 0,
174            kernel_params: std::ptr::null_mut(),
175            extra: std::ptr::null_mut(),
176        }
177    }
178}
179
180// =========================================================================
181// CUDA_MEMCPY3D — descriptor for `cuGraphAddMemcpyNode` / `cuMemcpy3D`
182// =========================================================================
183
184/// Descriptor for a 3-D memory copy, consumed by `cuGraphAddMemcpyNode`
185/// (and `cuMemcpy3D`).
186///
187/// Mirrors `CUDA_MEMCPY3D` in `cuda.h`. The driver inspects only the fields
188/// appropriate for the chosen source / destination memory types; the rest
189/// **must** be zeroed. Use [`CUDA_MEMCPY3D::default`] to obtain a
190/// zero-initialised descriptor and set only the fields you need.
191#[repr(C)]
192#[derive(Debug, Clone, Copy)]
193pub struct CUDA_MEMCPY3D {
194    /// Source X offset in bytes.
195    pub src_x_in_bytes: usize,
196    /// Source Y offset in rows.
197    pub src_y: usize,
198    /// Source Z offset in slices.
199    pub src_z: usize,
200    /// Source LOD (level of detail).
201    pub src_lod: usize,
202    /// Source memory type; see [`super::CUmemorytype`].
203    pub src_memory_type: u32,
204    /// Source host pointer (valid when `src_memory_type == Host`).
205    pub src_host: *const c_void,
206    /// Source device pointer (valid when `src_memory_type == Device`).
207    pub src_device: CUdeviceptr,
208    /// Source CUDA array (valid when `src_memory_type == Array`).
209    pub src_array: CUarray,
210    /// Reserved; must be null.
211    pub reserved0: *mut c_void,
212    /// Source pitch in bytes (`0` selects a tightly-packed layout).
213    pub src_pitch: usize,
214    /// Source height in rows (`0` selects a tightly-packed layout).
215    pub src_height: usize,
216    /// Destination X offset in bytes.
217    pub dst_x_in_bytes: usize,
218    /// Destination Y offset in rows.
219    pub dst_y: usize,
220    /// Destination Z offset in slices.
221    pub dst_z: usize,
222    /// Destination LOD (level of detail).
223    pub dst_lod: usize,
224    /// Destination memory type; see [`super::CUmemorytype`].
225    pub dst_memory_type: u32,
226    /// Destination host pointer (valid when `dst_memory_type == Host`).
227    pub dst_host: *mut c_void,
228    /// Destination device pointer (valid when `dst_memory_type == Device`).
229    pub dst_device: CUdeviceptr,
230    /// Destination CUDA array (valid when `dst_memory_type == Array`).
231    pub dst_array: CUarray,
232    /// Reserved; must be null.
233    pub reserved1: *mut c_void,
234    /// Destination pitch in bytes (`0` selects a tightly-packed layout).
235    pub dst_pitch: usize,
236    /// Destination height in rows (`0` selects a tightly-packed layout).
237    pub dst_height: usize,
238    /// Width of the copied region in bytes.
239    pub width_in_bytes: usize,
240    /// Height of the copied region in rows.
241    pub height: usize,
242    /// Depth of the copied region in slices.
243    pub depth: usize,
244}
245
246// SAFETY: carries raw pointers / array handles to caller-owned memory; the
247// driver treats them as opaque. Mirroring the C struct, it is Send+Sync.
248unsafe impl Send for CUDA_MEMCPY3D {}
249unsafe impl Sync for CUDA_MEMCPY3D {}
250
251impl Default for CUDA_MEMCPY3D {
252    fn default() -> Self {
253        Self {
254            src_x_in_bytes: 0,
255            src_y: 0,
256            src_z: 0,
257            src_lod: 0,
258            src_memory_type: 0,
259            src_host: std::ptr::null(),
260            src_device: 0,
261            src_array: CUarray::default(),
262            reserved0: std::ptr::null_mut(),
263            src_pitch: 0,
264            src_height: 0,
265            dst_x_in_bytes: 0,
266            dst_y: 0,
267            dst_z: 0,
268            dst_lod: 0,
269            dst_memory_type: 0,
270            dst_host: std::ptr::null_mut(),
271            dst_device: 0,
272            dst_array: CUarray::default(),
273            reserved1: std::ptr::null_mut(),
274            dst_pitch: 0,
275            dst_height: 0,
276            width_in_bytes: 0,
277            height: 0,
278            depth: 0,
279        }
280    }
281}
282
283// =========================================================================
284// CUDA_MEMSET_NODE_PARAMS — descriptor for `cuGraphAddMemsetNode`
285// =========================================================================
286
287/// Parameters for a memset graph node, consumed by `cuGraphAddMemsetNode`.
288///
289/// Mirrors `CUDA_MEMSET_NODE_PARAMS` in `cuda.h`. For a 1-D (linear) memset
290/// set `height = 1` and `pitch = 0`; `element_size` is `1`, `2`, or `4`
291/// bytes and `width` is the number of elements per row.
292#[repr(C)]
293#[derive(Debug, Clone, Copy, PartialEq, Eq)]
294pub struct CUDA_MEMSET_NODE_PARAMS {
295    /// Destination device pointer.
296    pub dst: CUdeviceptr,
297    /// Destination pitch in bytes (`0` for a tightly-packed 1-D memset).
298    pub pitch: usize,
299    /// Value to write, interpreted according to `element_size`.
300    pub value: u32,
301    /// Size of each element in bytes (`1`, `2`, or `4`).
302    pub element_size: u32,
303    /// Width of the region in elements.
304    pub width: usize,
305    /// Height of the region in rows (`1` for a 1-D memset).
306    pub height: usize,
307}
308
309impl Default for CUDA_MEMSET_NODE_PARAMS {
310    fn default() -> Self {
311        Self {
312            dst: 0,
313            pitch: 0,
314            value: 0,
315            element_size: 1,
316            width: 0,
317            height: 1,
318        }
319    }
320}
321
322// =========================================================================
323// CUDA_HOST_NODE_PARAMS — descriptor for `cuGraphAddHostNode`
324// =========================================================================
325
/// Host-callback graph-node descriptor passed to `cuGraphAddHostNode`.
///
/// Mirrors `CUDA_HOST_NODE_PARAMS` from `cuda.h`: when the node runs, the
/// driver calls `fn_ptr(user_data)`. `None` is represented as a null function
/// pointer, so the field has the same size as the C `CUhostFn`.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct CUDA_HOST_NODE_PARAMS {
    /// Callback to run, with C signature `void (*)(void *userData)`.
    pub fn_ptr: Option<unsafe extern "C" fn(user_data: *mut c_void)>,
    /// Opaque argument passed to `fn_ptr`.
    pub user_data: *mut c_void,
}

// SAFETY: the descriptor only carries a function pointer and an opaque,
// caller-managed user-data pointer that Rust never dereferences here; the
// driver treats both as opaque values.
unsafe impl Send for CUDA_HOST_NODE_PARAMS {}
unsafe impl Sync for CUDA_HOST_NODE_PARAMS {}

impl Default for CUDA_HOST_NODE_PARAMS {
    /// No callback and a null user-data pointer.
    fn default() -> Self {
        let fn_ptr = None;
        let user_data = std::ptr::null_mut();
        Self { fn_ptr, user_data }
    }
}
352
353// =========================================================================
354// CUmemPoolAttribute — `cuMemPoolSetAttribute` / `cuMemPoolGetAttribute`
355// =========================================================================
356
/// Attribute selector for `cuMemPoolSetAttribute` / `cuMemPoolGetAttribute`.
///
/// Mirrors the `CUmemPool_attribute` enum in `cuda.h` (the
/// `CU_MEMPOOL_ATTR_*` constants). The discriminants equal the C values, so a
/// variant can be passed straight to the driver. The value pointer handed to
/// the driver must point at the type named in each variant's documentation
/// (`int` or `cuuint64_t`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
#[non_exhaustive]
pub enum CUmemPoolAttribute {
    /// `CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES` `(value type = int)`:
    /// allow an allocation to reuse memory freed asynchronously in another
    /// stream, provided the allocating stream has a stream-ordering dependency
    /// (e.g. established through an event) on that free.
    ReuseFollowEventDependencies = 1,
    /// `CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC` `(value type = int)`: allow
    /// reuse of already-completed frees when there is no dependency between
    /// the free and the allocation.
    ReuseAllowOpportunistic = 2,
    /// `CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES` `(value type = int)`:
    /// allow the driver to insert new stream dependencies so that memory
    /// released by an asynchronous free can be reused.
    ReuseAllowInternalDependencies = 3,
    /// `CU_MEMPOOL_ATTR_RELEASE_THRESHOLD` `(value type = cuuint64_t)`: bytes of
    /// reserved memory to hold onto before trying to release memory back to
    /// the OS.
    ReleaseThreshold = 4,
    /// `CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT` `(value type = cuuint64_t,
    /// read-only)`: backing memory currently allocated for the pool.
    ReservedMemCurrent = 5,
    /// `CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH` `(value type = cuuint64_t)`:
    /// high-water mark of backing memory allocated for the pool since the last
    /// reset. It can only be reset to zero.
    ReservedMemHigh = 6,
    /// `CU_MEMPOOL_ATTR_USED_MEM_CURRENT` `(value type = cuuint64_t,
    /// read-only)`: memory from the pool currently in use by the application.
    UsedMemCurrent = 7,
    /// `CU_MEMPOOL_ATTR_USED_MEM_HIGH` `(value type = cuuint64_t)`: high-water
    /// mark of memory in use from the pool since the last reset. It can only
    /// be reset to zero.
    UsedMemHigh = 8,
}
391
392// =========================================================================
393// CUgraphInstantiate_flags — flags for `cuGraphInstantiateWithFlags`
394// =========================================================================
395
/// `CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH`: automatically free memory
/// that the graph allocated and did not free before the graph is relaunched.
pub const CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH: u64 = 1;
/// `CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD`: upload the graph right after
/// instantiation. Per the CUDA docs this flag is honoured only by
/// `cuGraphInstantiateWithParams`, which supplies the upload stream.
pub const CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD: u64 = 2;
/// `CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH`: make the graph launchable from
/// device code. Cannot be combined with
/// [`CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH`].
pub const CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH: u64 = 4;
/// `CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY`: run the graph using the
/// per-node priority attributes instead of the priority of the stream it is
/// launched into.
pub const CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY: u64 = 8;
405
406// =========================================================================
407// CUgraphNodeType — node-type discriminant (informational)
408// =========================================================================
409
/// Type of a graph node, as reported by `cuGraphNodeGetType`.
///
/// Mirrors `CUgraphNodeType` in `cuda.h`, covering every node kind defined by
/// the CUDA 12.x headers (discriminants `0..=13`). Provided for node-type
/// queries; `cuGraphAdd*Node` calls do not require it.
///
/// A Rust enum must never hold a discriminant it does not declare, so the
/// driver should not write directly into a `CUgraphNodeType`. Instead,
/// receive the raw `u32` from `cuGraphNodeGetType` and convert it with
/// [`TryFrom`](std::convert::TryFrom), which returns `Err(raw)` for node kinds
/// introduced by newer drivers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
#[non_exhaustive]
pub enum CUgraphNodeType {
    /// GPU kernel-launch node.
    Kernel = 0,
    /// Memory-copy node.
    Memcpy = 1,
    /// Memory-set node.
    Memset = 2,
    /// Host (CPU) callback node.
    Host = 3,
    /// Node that executes an embedded child graph.
    Graph = 4,
    /// Empty (no-op) node used as a synchronisation barrier.
    Empty = 5,
    /// Event-wait node (`CU_GRAPH_NODE_TYPE_WAIT_EVENT`).
    WaitEvent = 6,
    /// Event-record node (`CU_GRAPH_NODE_TYPE_EVENT_RECORD`).
    EventRecord = 7,
    /// External-semaphore signal node (`CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL`).
    ExtSemasSignal = 8,
    /// External-semaphore wait node (`CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT`).
    ExtSemasWait = 9,
    /// Memory-allocation node (`CU_GRAPH_NODE_TYPE_MEM_ALLOC`).
    MemAlloc = 10,
    /// Memory-free node (`CU_GRAPH_NODE_TYPE_MEM_FREE`).
    MemFree = 11,
    /// Batched memory-operation node (`CU_GRAPH_NODE_TYPE_BATCH_MEM_OP`).
    BatchMemOp = 12,
    /// Conditional node (`CU_GRAPH_NODE_TYPE_CONDITIONAL`).
    Conditional = 13,
}

impl std::convert::TryFrom<u32> for CUgraphNodeType {
    /// The unrecognised raw value, handed back unchanged.
    type Error = u32;

    /// Converts a raw node-type value reported by the driver.
    ///
    /// # Errors
    ///
    /// Returns `Err(raw)` when `raw` is not one of the declared discriminants.
    fn try_from(raw: u32) -> Result<Self, Self::Error> {
        let node_type = match raw {
            0 => Self::Kernel,
            1 => Self::Memcpy,
            2 => Self::Memset,
            3 => Self::Host,
            4 => Self::Graph,
            5 => Self::Empty,
            6 => Self::WaitEvent,
            7 => Self::EventRecord,
            8 => Self::ExtSemasSignal,
            9 => Self::ExtSemasWait,
            10 => Self::MemAlloc,
            11 => Self::MemFree,
            12 => Self::BatchMemOp,
            13 => Self::Conditional,
            unknown => return Err(unknown),
        };
        Ok(node_type)
    }
}
431
432// =========================================================================
433// CU_LAUNCH_PARAM_* sentinels (for the `extra` kernel-arg packing buffer)
434// =========================================================================
435
/// Terminator for the `extra` kernel-argument buffer
/// (`CU_LAUNCH_PARAM_END`, `((void*)0x00)` in `cuda.h`).
pub const CU_LAUNCH_PARAM_END: *mut c_void = std::ptr::null_mut();

/// Tag in the `extra` buffer announcing that the next entry points to a
/// packed argument buffer (`CU_LAUNCH_PARAM_BUFFER_POINTER`, `((void*)0x01)`
/// in `cuda.h`).
///
/// A complete `extra` array reads
/// `[BUFFER_POINTER, buf, BUFFER_SIZE, &buf_size, END]`.
pub const CU_LAUNCH_PARAM_BUFFER_POINTER: *mut c_void = 0x01_usize as *mut c_void;

/// Tag in the `extra` buffer announcing that the next entry points to a
/// `size_t` holding the byte size of the packed argument buffer
/// (`CU_LAUNCH_PARAM_BUFFER_SIZE`, `((void*)0x02)` in `cuda.h`).
pub const CU_LAUNCH_PARAM_BUFFER_SIZE: *mut c_void = 0x02_usize as *mut c_void;

// Nothing else in this module names `c_char`; this private alias only keeps the
// `c_char` import referenced so the module compiles without an unused-import
// warning. NOTE(review): if no graph FFI signature ends up needing it, drop
// `c_char` from the `use` line and delete this alias.
#[allow(dead_code)]
type GraphNodeName = *const c_char;
443
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn graph_handles_default_to_null() {
        assert!(CUgraph::default().is_null());
        assert!(CUgraphExec::default().is_null());
        assert!(CUgraphNode::default().is_null());
    }

    #[test]
    fn graph_handles_debug_print_their_pointer() {
        assert_eq!(format!("{:?}", CUgraph::default()), "CUgraph(0x0)");
        assert_eq!(format!("{:?}", CUgraphExec::default()), "CUgraphExec(0x0)");
        assert_eq!(format!("{:?}", CUgraphNode::default()), "CUgraphNode(0x0)");
    }

    #[test]
    fn graph_handles_are_pointer_sized() {
        // `#[repr(transparent)]` must keep each handle ABI-identical to `void *`.
        let ptr_size = std::mem::size_of::<*mut std::ffi::c_void>();
        assert_eq!(std::mem::size_of::<CUgraph>(), ptr_size);
        assert_eq!(std::mem::size_of::<CUgraphExec>(), ptr_size);
        assert_eq!(std::mem::size_of::<CUgraphNode>(), ptr_size);
    }

    #[test]
    fn ffi_types_are_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<CUgraph>();
        assert_send_sync::<CUgraphExec>();
        assert_send_sync::<CUgraphNode>();
        assert_send_sync::<CUDA_KERNEL_NODE_PARAMS>();
        assert_send_sync::<CUDA_MEMCPY3D>();
        assert_send_sync::<CUDA_MEMSET_NODE_PARAMS>();
        assert_send_sync::<CUDA_HOST_NODE_PARAMS>();
    }

    #[test]
    fn kernel_node_params_default_is_zeroed() {
        let p = CUDA_KERNEL_NODE_PARAMS::default();
        assert!(p.func.is_null());
        assert_eq!((p.grid_dim_x, p.grid_dim_y, p.grid_dim_z), (0, 0, 0));
        assert_eq!((p.block_dim_x, p.block_dim_y, p.block_dim_z), (0, 0, 0));
        assert_eq!(p.shared_mem_bytes, 0);
        assert!(p.kernel_params.is_null());
        assert!(p.extra.is_null());
    }

    #[test]
    fn memset_node_params_default_is_linear() {
        let p = CUDA_MEMSET_NODE_PARAMS::default();
        assert_eq!(p.dst, 0);
        assert_eq!(p.pitch, 0);
        assert_eq!(p.value, 0);
        assert_eq!(p.element_size, 1);
        assert_eq!(p.width, 0);
        assert_eq!(p.height, 1);
    }

    #[test]
    fn memcpy3d_default_is_zeroed() {
        let m = CUDA_MEMCPY3D::default();
        assert_eq!(m.src_memory_type, 0);
        assert_eq!(m.dst_memory_type, 0);
        assert_eq!(m.src_device, 0);
        assert_eq!(m.dst_device, 0);
        assert_eq!(m.width_in_bytes, 0);
        assert_eq!(m.height, 0);
        assert_eq!(m.depth, 0);
        assert!(m.src_host.is_null());
        assert!(m.dst_host.is_null());
        assert!(m.reserved0.is_null());
        assert!(m.reserved1.is_null());
    }

    #[test]
    fn host_node_params_default_is_empty() {
        let p = CUDA_HOST_NODE_PARAMS::default();
        assert!(p.fn_ptr.is_none());
        assert!(p.user_data.is_null());
    }

    /// The `#[repr(C)]` descriptors must have the same size as their `cuda.h`
    /// counterparts on 64-bit hosts (LP64 and LLP64 alike); otherwise the
    /// driver reads fields at the wrong offsets. Assumes `CUdeviceptr` is
    /// 64 bits wide, as in the 64-bit CUDA headers.
    #[test]
    #[cfg(target_pointer_width = "64")]
    fn descriptor_sizes_match_cuda_abi() {
        use std::mem::size_of;
        assert_eq!(size_of::<CUDA_KERNEL_NODE_PARAMS>(), 56);
        assert_eq!(size_of::<CUDA_MEMCPY3D>(), 200);
        assert_eq!(size_of::<CUDA_MEMSET_NODE_PARAMS>(), 40);
        assert_eq!(size_of::<CUDA_HOST_NODE_PARAMS>(), 16);
    }

    #[test]
    fn mem_pool_attribute_discriminants_match_cuda() {
        let expected = [
            (CUmemPoolAttribute::ReuseFollowEventDependencies, 1),
            (CUmemPoolAttribute::ReuseAllowOpportunistic, 2),
            (CUmemPoolAttribute::ReuseAllowInternalDependencies, 3),
            (CUmemPoolAttribute::ReleaseThreshold, 4),
            (CUmemPoolAttribute::ReservedMemCurrent, 5),
            (CUmemPoolAttribute::ReservedMemHigh, 6),
            (CUmemPoolAttribute::UsedMemCurrent, 7),
            (CUmemPoolAttribute::UsedMemHigh, 8),
        ];
        for &(attr, value) in &expected {
            assert_eq!(attr as u32, value);
        }
    }

    #[test]
    fn graph_node_type_discriminants_match_cuda() {
        assert_eq!(CUgraphNodeType::Kernel as u32, 0);
        assert_eq!(CUgraphNodeType::Memcpy as u32, 1);
        assert_eq!(CUgraphNodeType::Memset as u32, 2);
        assert_eq!(CUgraphNodeType::Host as u32, 3);
        assert_eq!(CUgraphNodeType::Graph as u32, 4);
        assert_eq!(CUgraphNodeType::Empty as u32, 5);
    }

    #[test]
    fn instantiate_flags_match_cuda() {
        let flags = [
            CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH,
            CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD,
            CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH,
            CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY,
        ];
        assert_eq!(flags, [1, 2, 4, 8]);
    }

    #[test]
    fn launch_param_end_is_null() {
        assert!(CU_LAUNCH_PARAM_END.is_null());
    }
}