1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
use super::{
    cuCtxCreate_v2, cuCtxDestroy, cuDeviceGet, cuDeviceGetCount, cuInit, cuLaunchKernel,
    cuMemFree_v2, cuMemcpyDtoH_v2, cuMemcpyHtoD_v2, cuModuleGetFunction, cuModuleLoad,
    cuModuleLoadData, cuModuleUnload, cuStreamCreate, cuStreamSynchronize,
    error::{CudaErrorKind, CudaResult},
    ffi::cuMemAlloc_v2,
    CUcontext, CUdevice, CUfunction, CUmodule, CUstream,
};
use crate::CUdeviceptr;
use std::{
    ffi::{c_void, CString},
    ptr::null_mut,
};

pub fn cinit(flags: u32) -> CudaResult<()> {
    unsafe { cuInit(flags).into() }
}

#[derive(Debug)]
pub struct CudaIntDevice(pub CUdevice);

pub fn device_count() -> CudaResult<i32> {
    let mut count = 0;
    unsafe { cuDeviceGetCount(&mut count as *mut i32) }.to_result()?;
    Ok(count)
}

// TODO: cuda set device
pub fn device(ordinal: i32) -> CudaResult<CudaIntDevice> {
    if ordinal >= device_count()? {
        return Err(CudaErrorKind::InvalidDeviceIdx);
    }

    let mut device = CudaIntDevice(0);
    unsafe { cuDeviceGet(&mut device.0 as *mut i32, ordinal) };
    Ok(device)
}

#[derive(Debug)]
pub struct Context(pub CUcontext);

impl Drop for Context {
    fn drop(&mut self) {
        unsafe {
            cuCtxDestroy(self.0);
        }
    }
}

pub fn create_context(device: &CudaIntDevice) -> CudaResult<Context> {
    let mut context = Context(null_mut());
    unsafe {
        // TODO: Flags
        cuCtxCreate_v2(&mut context.0 as *mut CUcontext, 0, device.0).to_result()?;
    }
    Ok(context)
}

pub fn cumalloc<T>(len: usize) -> CudaResult<CUdeviceptr> {
    let bytes = len * core::mem::size_of::<T>();

    if bytes == 0 {
        return Err(CudaErrorKind::InvalidAllocSize);
    }

    let mut ptr: CUdeviceptr = 0;
    unsafe { cuMemAlloc_v2(&mut ptr, bytes).to_result()? };
    Ok(ptr)
}

/// Free CUDA GPU memory
/// # Safety
/// FFI, `ptr` must be a valid pointer.
pub unsafe fn cufree(ptr: CUdeviceptr) -> CudaResult<()> {
    cuMemFree_v2(ptr).into()
}

pub fn cu_write<T>(dst: CUdeviceptr, src_host: &[T]) -> CudaResult<()> {
    let bytes_to_copy = src_host.len() * std::mem::size_of::<T>();
    unsafe { cuMemcpyHtoD_v2(dst, src_host.as_ptr() as *const c_void, bytes_to_copy) }.into()
}

pub fn cu_read<T>(dst_host: &mut [T], src: CUdeviceptr) -> CudaResult<()> {
    let bytes_to_copy = dst_host.len() * std::mem::size_of::<T>();
    unsafe { cuMemcpyDtoH_v2(dst_host.as_mut_ptr() as *mut c_void, src, bytes_to_copy) }.into()
}

#[derive(Debug)]
pub struct Module(pub CUmodule);

impl Module {
    pub fn function(&self, fn_name: &str) -> CudaResult<FnHandle> {
        module_get_fn(self, fn_name)
    }
}

impl Drop for Module {
    fn drop(&mut self) {
        unsafe {
            cuModuleUnload(self.0).to_result().unwrap();
        }
    }
}

pub fn load_module(fname: &str) -> CudaResult<Module> {
    let fname = CString::new(fname).unwrap();

    let mut module = Module(null_mut());
    unsafe { cuModuleLoad(&mut module.0, fname.as_ptr()) }.to_result()?;
    Ok(module)
}

pub fn load_module_data(src: CString) -> CudaResult<Module> {
    let mut module = Module(null_mut());
    unsafe { cuModuleLoadData(&mut module.0, src.as_ptr() as *const c_void) }.to_result()?;
    Ok(module)
}

#[derive(Debug, Clone, Copy)]
pub struct FnHandle(pub CUfunction);

pub fn module_get_fn(module: &Module, fn_name: &str) -> CudaResult<FnHandle> {
    let fn_name = CString::new(fn_name).unwrap();

    let mut handle = FnHandle(null_mut());
    unsafe { cuModuleGetFunction(&mut handle.0, module.0, fn_name.as_ptr()) }.to_result()?;
    Ok(handle)
}

#[derive(Debug)]
pub struct Stream(pub CUstream);

impl Stream {
    pub fn sync(&self) -> CudaResult<()> {
        unsafe { cuStreamSynchronize(self.0) }.to_result()
    }
}

pub fn create_stream() -> CudaResult<Stream> {
    let mut ph_stream = Stream(null_mut());
    unsafe { cuStreamCreate(&mut ph_stream.0, 0) }.to_result()?;
    Ok(ph_stream)
}

pub fn culaunch_kernel(
    f: &FnHandle,
    grid: [u32; 3],
    blocks: [u32; 3],
    stream: &Stream,
    params: &[*mut c_void],
) -> CudaResult<()> {
    unsafe {
        cuLaunchKernel(
            f.0,
            grid[0],
            grid[1],
            grid[2],
            blocks[0],
            blocks[1],
            blocks[2],
            0,
            stream.0,
            params.as_ptr() as *mut _,
            std::ptr::null_mut(),
        )
    }
    .to_result()?;

    // TODO: sync here or elsewhere?
    //stream.sync()?;

    //    unsafe {cuCtxSynchronize().to_result()?};

    Ok(())
}