kn_cuda_sys/wrapper/
graph.rs

1use std::ptr::null_mut;
2
3use crate::bindings::{
4    cudaGraphCreate, cudaGraphDestroy, cudaGraphExecDestroy, cudaGraphExec_t, cudaGraphInstantiate, cudaGraphLaunch,
5    cudaGraph_t,
6};
7use crate::wrapper::handle::CudaStream;
8use crate::wrapper::status::Status;
9
10#[derive(Debug)]
11pub struct CudaGraph {
12    inner: cudaGraph_t,
13}
14
15impl Drop for CudaGraph {
16    fn drop(&mut self) {
17        unsafe { cudaGraphDestroy(self.inner).unwrap_in_drop() }
18    }
19}
20
21impl CudaGraph {
22    pub fn new() -> Self {
23        unsafe {
24            let mut inner = null_mut();
25            cudaGraphCreate(&mut inner as *mut _, 0).unwrap();
26            CudaGraph { inner }
27        }
28    }
29
30    pub unsafe fn new_from_inner(inner: cudaGraph_t) -> CudaGraph {
31        CudaGraph { inner }
32    }
33
34    pub unsafe fn instantiate(&self) -> CudaGraphExec {
35        //TODO try printing error string for fun
36        let mut inner = null_mut();
37        cudaGraphInstantiate(&mut inner as *mut _, self.inner(), 0).unwrap();
38        CudaGraphExec { inner }
39    }
40
41    pub unsafe fn inner(&self) -> cudaGraph_t {
42        self.inner
43    }
44}
45
46#[derive(Debug)]
47pub struct CudaGraphExec {
48    inner: cudaGraphExec_t,
49}
50
51impl Drop for CudaGraphExec {
52    fn drop(&mut self) {
53        unsafe { cudaGraphExecDestroy(self.inner).unwrap_in_drop() }
54    }
55}
56
57impl CudaGraphExec {
58    pub unsafe fn inner(&self) -> cudaGraphExec_t {
59        self.inner
60    }
61
62    pub unsafe fn launch(&self, stream: &CudaStream) {
63        cudaGraphLaunch(self.inner(), stream.inner()).unwrap();
64    }
65}