1use std::mem::MaybeUninit;
5
/// Raw FFI bindings to the nvPTXCompiler C API, wrapped by the safe types in this module.
// NOTE(review): every lint is silenced for the whole module, presumably because the bindings
// are machine-generated (e.g. by bindgen) and follow C naming — TODO confirm and narrow the
// allow list if possible.
#[allow(warnings)]
pub mod sys;
8
/// Internal helper for turning raw nvPTXCompiler status codes into Rust `Result`s, so FFI calls
/// can be chained with `?`.
trait ToResult {
    /// Returns `Ok(())` for a success code, or the matching [`NvptxError`] otherwise.
    fn to_result(self) -> Result<(), NvptxError>;
}
12
13impl ToResult for sys::nvPTXCompileResult {
14 fn to_result(self) -> Result<(), NvptxError> {
15 match self {
16 sys::nvPTXCompileResult_NVPTXCOMPILE_SUCCESS => Ok(()),
17 sys::nvPTXCompileResult_NVPTXCOMPILE_ERROR_INVALID_INPUT => {
18 Err(NvptxError::InvalidInput)
19 }
20 sys::nvPTXCompileResult_NVPTXCOMPILE_ERROR_COMPILATION_FAILURE => {
21 Err(NvptxError::CompilationFailure)
22 }
23 sys::nvPTXCompileResult_NVPTXCOMPILE_ERROR_INTERNAL => Err(NvptxError::Internal),
24 sys::nvPTXCompileResult_NVPTXCOMPILE_ERROR_OUT_OF_MEMORY => {
25 Err(NvptxError::OutOfMemory)
26 }
27 sys::nvPTXCompileResult_NVPTXCOMPILE_ERROR_UNSUPPORTED_PTX_VERSION => {
28 Err(NvptxError::UnsupportedPtxVersion)
29 }
30 sys::nvPTXCompileResult_NVPTXCOMPILE_ERROR_COMPILER_INVOCATION_INCOMPLETE => {
32 unreachable!("nvptx yielded an incomplete invocation error")
33 }
34 sys::nvPTXCompileResult_NVPTXCOMPILE_ERROR_INVALID_COMPILER_HANDLE => {
35 unreachable!("nvptx yielded an invalid handle err")
36 }
37 _ => unreachable!(),
38 }
39 }
40}
41
42pub type NvptxResult<T> = Result<T, NvptxError>;
43
/// Recoverable error conditions reported by the nvPTXCompiler library.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum NvptxError {
    /// Corresponds to `NVPTXCOMPILE_ERROR_INVALID_INPUT`.
    InvalidInput,
    /// Corresponds to `NVPTXCOMPILE_ERROR_COMPILATION_FAILURE`.
    CompilationFailure,
    /// Corresponds to `NVPTXCOMPILE_ERROR_INTERNAL`.
    Internal,
    /// Corresponds to `NVPTXCOMPILE_ERROR_OUT_OF_MEMORY`.
    OutOfMemory,
    /// Corresponds to `NVPTXCOMPILE_ERROR_UNSUPPORTED_PTX_VERSION`.
    UnsupportedPtxVersion,
}
52
53#[repr(transparent)]
54#[derive(Debug)]
55pub struct NvptxCompiler {
56 raw: sys::nvPTXCompilerHandle,
57}
58
59impl NvptxCompiler {
60 pub fn new(ptx: impl AsRef<str>) -> NvptxResult<Self> {
62 let ptx = ptx.as_ref();
63 let mut raw = MaybeUninit::uninit();
64
65 unsafe {
66 sys::nvPTXCompilerCreate(raw.as_mut_ptr(), ptx.len() as u64, ptx.as_ptr().cast())
67 .to_result()?;
68 let raw = raw.assume_init();
69 Ok(Self { raw })
70 }
71 }
72}
73
74impl Drop for NvptxCompiler {
75 fn drop(&mut self) {
76 unsafe {
77 sys::nvPTXCompilerDestroy(&mut self.raw as *mut _)
78 .to_result()
79 .expect("failed to destroy nvptx compiler");
80 }
81 }
82}
83
84#[derive(Debug)]
85pub struct CompilerFailure {
86 pub error: NvptxError,
87 handle: sys::nvPTXCompilerHandle,
88}
89
90impl Drop for CompilerFailure {
91 fn drop(&mut self) {
92 unsafe {
93 sys::nvPTXCompilerDestroy(&mut self.handle as *mut _)
94 .to_result()
95 .expect("failed to destroy nvptx compiler failure");
96 }
97 }
98}
99
100impl CompilerFailure {
101 pub fn error_log(&self) -> NvptxResult<String> {
102 let mut size = MaybeUninit::uninit();
103 unsafe {
104 sys::nvPTXCompilerGetErrorLogSize(self.handle, size.as_mut_ptr()).to_result()?;
105 let size = size.assume_init() as usize;
106 let mut vec = Vec::with_capacity(size);
107 sys::nvPTXCompilerGetErrorLog(self.handle, vec.as_mut_ptr() as *mut i8).to_result()?;
108 vec.set_len(size);
109 Ok(String::from_utf8_lossy(&vec).to_string())
110 }
111 }
112}
113
114#[derive(Debug)]
116pub struct CompiledProgram {
117 pub cubin: Vec<u8>,
118 handle: sys::nvPTXCompilerHandle,
119}
120
121impl Drop for CompiledProgram {
122 fn drop(&mut self) {
123 unsafe {
124 sys::nvPTXCompilerDestroy(&mut self.handle as *mut _)
125 .to_result()
126 .expect("failed to destroy nvptx compiled program");
127 }
128 }
129}
130
131impl CompiledProgram {
132 pub fn info_log(&self) -> NvptxResult<String> {
133 let mut size = MaybeUninit::uninit();
134 unsafe {
135 sys::nvPTXCompilerGetInfoLogSize(self.handle, size.as_mut_ptr()).to_result()?;
136 let size = size.assume_init() as usize;
137 let mut vec = Vec::with_capacity(size);
138 sys::nvPTXCompilerGetInfoLog(self.handle, vec.as_mut_ptr() as *mut i8).to_result()?;
139 vec.set_len(size);
140 Ok(String::from_utf8_lossy(&vec).to_string())
141 }
142 }
143}