1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
use std::ptr::null_mut;

use crate::bindings::{
    cudaGraphCreate, cudaGraphDestroy, cudaGraphExecDestroy, cudaGraphExec_t, cudaGraphInstantiate, cudaGraphLaunch,
    cudaGraph_t,
};
use crate::wrapper::handle::CudaStream;
use crate::wrapper::status::Status;

/// Owning wrapper around a raw CUDA graph handle (`cudaGraph_t`).
///
/// The wrapped handle is destroyed with `cudaGraphDestroy` when this value is dropped,
/// so each `CudaGraph` must be the sole owner of its handle.
#[derive(Debug)]
pub struct CudaGraph {
    inner: cudaGraph_t,
}

impl Drop for CudaGraph {
    fn drop(&mut self) {
        // SAFETY: `inner` is exclusively owned by this wrapper (guaranteed by `new`, and
        // by the caller of `new_from_inner`), so it is destroyed exactly once, here.
        unsafe { cudaGraphDestroy(self.inner).unwrap_in_drop() }
    }
}

impl CudaGraph {
    /// Creates a new, empty CUDA graph via `cudaGraphCreate` (flags = 0).
    ///
    /// # Panics
    ///
    /// The status returned by `cudaGraphCreate` is `unwrap`ped, so a CUDA error is
    /// expected to panic here.
    pub fn new() -> Self {
        let mut inner = null_mut();
        // SAFETY: `inner` is a live, writable local used as the out-parameter for the
        // duration of the call.
        unsafe { cudaGraphCreate(&mut inner as *mut _, 0) }.unwrap();
        CudaGraph { inner }
    }

    /// Wraps an existing raw graph handle, taking ownership of it.
    ///
    /// # Safety
    ///
    /// `inner` must be a valid `cudaGraph_t` that no other owner will destroy: the
    /// returned wrapper calls `cudaGraphDestroy` on it when dropped.
    pub unsafe fn new_from_inner(inner: cudaGraph_t) -> CudaGraph {
        CudaGraph { inner }
    }

    /// Instantiates this graph into an executable graph via `cudaGraphInstantiate`
    /// (flags = 0).
    ///
    /// # Safety
    ///
    /// The caller must uphold the requirements of `cudaGraphInstantiate` for the nodes
    /// contained in this graph.
    /// NOTE(review): the precise contract the author intended for this `unsafe fn` is not
    /// visible here; presumably it concerns the validity of the resources referenced by
    /// the graph's nodes. Confirm.
    ///
    /// # Panics
    ///
    /// The status returned by `cudaGraphInstantiate` is `unwrap`ped, so a CUDA error is
    /// expected to panic here.
    pub unsafe fn instantiate(&self) -> CudaGraphExec {
        // TODO: report a more descriptive error (e.g. the CUDA error string) on failure.
        let mut inner = null_mut();
        cudaGraphInstantiate(&mut inner as *mut _, self.inner(), 0).unwrap();
        CudaGraphExec { inner }
    }

    /// Returns the raw graph handle without transferring ownership.
    ///
    /// # Safety
    ///
    /// The handle remains owned by `self`: the caller must not destroy it, and must not
    /// use it after `self` has been dropped.
    pub unsafe fn inner(&self) -> cudaGraph_t {
        self.inner
    }
}

impl Default for CudaGraph {
    /// Equivalent to [`CudaGraph::new`], including its panic behavior on CUDA errors.
    fn default() -> Self {
        Self::new()
    }
}

/// Owning wrapper around a raw executable CUDA graph handle (`cudaGraphExec_t`),
/// as produced by [`CudaGraph::instantiate`].
///
/// The wrapped handle is destroyed with `cudaGraphExecDestroy` when this value is dropped.
#[derive(Debug)]
pub struct CudaGraphExec {
    inner: cudaGraphExec_t,
}

impl Drop for CudaGraphExec {
    fn drop(&mut self) {
        // SAFETY: `inner` is exclusively owned by this wrapper, so it is destroyed exactly
        // once, here.
        unsafe { cudaGraphExecDestroy(self.inner).unwrap_in_drop() }
    }
}

impl CudaGraphExec {
    /// Returns the raw executable-graph handle without transferring ownership.
    ///
    /// # Safety
    ///
    /// The handle remains owned by `self`: the caller must not destroy it, and must not
    /// use it after `self` has been dropped.
    pub unsafe fn inner(&self) -> cudaGraphExec_t {
        self.inner
    }

    /// Enqueues this executable graph on `stream` via `cudaGraphLaunch`.
    ///
    /// # Safety
    ///
    /// The caller must uphold the requirements of `cudaGraphLaunch` for the work recorded
    /// in this graph.
    /// NOTE(review): presumably this means any device memory referenced by the graph's
    /// nodes must still be valid and not concurrently misused while the launch executes
    /// on `stream`. Confirm the intended contract.
    ///
    /// # Panics
    ///
    /// The status returned by `cudaGraphLaunch` is `unwrap`ped, so a CUDA error is
    /// expected to panic here.
    pub unsafe fn launch(&self, stream: &CudaStream) {
        cudaGraphLaunch(self.inner(), stream.inner()).unwrap();
    }
}