1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
mod error;
mod ffi;

use std::{
    ffi::CString,
    ptr::{null, null_mut}, os::raw::c_char,
};

pub use ffi::*;

use self::error::NvrtcResult;

/// Thin wrapper around a raw NVRTC program handle (`nvrtcProgram`).
///
/// The handle is a public tuple field so it can be handed straight to the raw
/// `ffi` functions. [`create_program`] starts it as a null pointer and lets
/// `nvrtcCreateProgram` fill it in.
///
/// NOTE(review): nothing in this file calls `nvrtcDestroyProgram`, so the handle
/// is never released here — confirm whether callers destroy it, or add a `Drop`
/// impl once that is settled.
pub struct NvrtcProgram(pub nvrtcProgram);

impl NvrtcProgram {
    /// Returns the PTX produced by the last successful compilation.
    ///
    /// Method form of [`get_ptx`]; see it for error behavior.
    pub fn ptx(&self) -> NvrtcResult<CString> {
        get_ptx(self)
    }

    /// Compiles this program with the given NVRTC options (`None` for none).
    ///
    /// Method form of [`compile_program`]; see it for error behavior.
    pub fn compile(&self, options: Option<Vec<CString>>) -> NvrtcResult<()> {
        compile_program(self, options)
    }
}

/// Creates an NVRTC program from source text.
///
/// `src` is the program source and `name` is passed to NVRTC as the program
/// name. No in-memory headers are registered: the header count is 0 and the
/// `headers` / `includeNames` arrays are null.
///
/// # Errors
///
/// Returns the converted NVRTC status if `nvrtcCreateProgram` fails.
///
/// # Panics
///
/// Panics if `src` or `name` contains an interior NUL byte, because neither can
/// then be represented as a C string.
pub fn create_program(src: &str, name: &str) -> NvrtcResult<NvrtcProgram> {
    let src = CString::new(src).expect("program source must not contain interior NUL bytes");
    let name = CString::new(name).expect("program name must not contain interior NUL bytes");

    let mut prog = NvrtcProgram(null_mut());
    // SAFETY: `&mut prog.0` is a valid out-location for the new handle, and
    // `src` / `name` are NUL-terminated strings that outlive the call. A header
    // count of 0 is paired with null `headers` / `includeNames` arrays.
    unsafe {
        nvrtcCreateProgram(
            &mut prog.0,
            src.as_ptr(),
            name.as_ptr(),
            0,
            null_mut(),
            null_mut(),
        )
    }
    .to_result()?;
    Ok(prog)
}
/// Compiles `prog` with the given NVRTC command-line options.
///
/// Each option is passed to `nvrtcCompileProgram` unchanged. `None` and an empty
/// list both compile with no options, which means a count of 0 and a null
/// option array.
///
/// # Errors
///
/// Returns the converted NVRTC status if `nvrtcCompileProgram` fails.
///
/// # Panics
///
/// Panics if more than `i32::MAX` options are supplied, because the option
/// count is passed to NVRTC as an `i32`.
pub fn compile_program(prog: &NvrtcProgram, options: Option<Vec<CString>>) -> NvrtcResult<()> {
    // Borrowed pointers into the strings owned by `options`. `options` is not
    // dropped until this function returns, so the pointers stay valid across
    // the FFI call below.
    let option_ptrs: Vec<*const c_char> = options
        .iter()
        .flatten()
        .map(|option| option.as_ptr())
        .collect();

    assert!(
        option_ptrs.len() <= i32::MAX as usize,
        "too many NVRTC options: the count must fit in an i32"
    );
    // Lossless: bounded by the assertion above.
    let num_options = option_ptrs.len() as i32;

    // With no options, pass a null array rather than a dangling pointer from an
    // empty Vec.
    let options_ptr = if option_ptrs.is_empty() {
        null()
    } else {
        option_ptrs.as_ptr()
    };

    // SAFETY: `options_ptr` is either null (with a count of 0) or points to
    // `num_options` valid, NUL-terminated C strings owned by `options`, all of
    // which outlive this call.
    unsafe { nvrtcCompileProgram(prog.0, num_options, options_ptr) }.to_result()
}

/// Retrieves the PTX generated for `prog`, typically after [`compile_program`].
///
/// Queries the PTX size with `nvrtcGetPTXSize`, allocates a zeroed buffer of
/// that many bytes, and has `nvrtcGetPTX` fill it in. The returned `CString`
/// contains the PTX text up to (not including) the first NUL byte in the buffer.
///
/// # Errors
///
/// Returns the converted NVRTC status if `nvrtcGetPTXSize` or `nvrtcGetPTX`
/// fails.
pub fn get_ptx(prog: &NvrtcProgram) -> NvrtcResult<CString> {
    let mut ptx_size = 0;
    // SAFETY: `ptx_size` is a valid, writable location for the duration of the call.
    unsafe { nvrtcGetPTXSize(prog.0, &mut ptx_size) }.to_result()?;

    let mut ptx: Vec<u8> = vec![0; ptx_size as usize];
    // SAFETY: `ptx` has exactly the number of bytes NVRTC just reported it needs.
    unsafe { nvrtcGetPTX(prog.0, ptx.as_mut_ptr() as *mut c_char) }.to_result()?;

    // NVRTC documents the reported size as including the trailing NUL. Rather
    // than trusting that with an unchecked conversion, which would be UB for a
    // zero size or a misplaced terminator, cut the buffer at its first NUL (or
    // keep it whole if there is none) and build the CString safely.
    let text_len = ptx.iter().position(|&byte| byte == 0).unwrap_or(ptx.len());
    ptx.truncate(text_len);
    Ok(CString::new(ptx).expect("buffer was truncated at its first NUL byte"))
}