ptx_compiler/
lib.rs

//! Raw and high-level bindings to the CUDA NVPTX compiler API, used to compile PTX
//! code into cubin files.
3
4use std::mem::MaybeUninit;
5
6#[allow(warnings)]
7pub mod sys;
8
9trait ToResult {
10    fn to_result(self) -> Result<(), NvptxError>;
11}
12
13impl ToResult for sys::nvPTXCompileResult {
14    fn to_result(self) -> Result<(), NvptxError> {
15        match self {
16            sys::nvPTXCompileResult_NVPTXCOMPILE_SUCCESS => Ok(()),
17            sys::nvPTXCompileResult_NVPTXCOMPILE_ERROR_INVALID_INPUT => {
18                Err(NvptxError::InvalidInput)
19            }
20            sys::nvPTXCompileResult_NVPTXCOMPILE_ERROR_COMPILATION_FAILURE => {
21                Err(NvptxError::CompilationFailure)
22            }
23            sys::nvPTXCompileResult_NVPTXCOMPILE_ERROR_INTERNAL => Err(NvptxError::Internal),
24            sys::nvPTXCompileResult_NVPTXCOMPILE_ERROR_OUT_OF_MEMORY => {
25                Err(NvptxError::OutOfMemory)
26            }
27            sys::nvPTXCompileResult_NVPTXCOMPILE_ERROR_UNSUPPORTED_PTX_VERSION => {
28                Err(NvptxError::UnsupportedPtxVersion)
29            }
30            // these two are statically prevented so they should never happen
31            sys::nvPTXCompileResult_NVPTXCOMPILE_ERROR_COMPILER_INVOCATION_INCOMPLETE => {
32                unreachable!("nvptx yielded an incomplete invocation error")
33            }
34            sys::nvPTXCompileResult_NVPTXCOMPILE_ERROR_INVALID_COMPILER_HANDLE => {
35                unreachable!("nvptx yielded an invalid handle err")
36            }
37            _ => unreachable!(),
38        }
39    }
40}
41
42pub type NvptxResult<T> = Result<T, NvptxError>;
43
/// Errors reported by the NVPTX compiler API.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NvptxError {
    /// `NVPTXCOMPILE_ERROR_INVALID_INPUT`: the compiler rejected its input.
    InvalidInput,
    /// `NVPTXCOMPILE_ERROR_COMPILATION_FAILURE`: compiling the PTX failed.
    CompilationFailure,
    /// `NVPTXCOMPILE_ERROR_INTERNAL`: the compiler hit an internal error.
    Internal,
    /// `NVPTXCOMPILE_ERROR_OUT_OF_MEMORY`: the compiler ran out of memory.
    OutOfMemory,
    /// `NVPTXCOMPILE_ERROR_UNSUPPORTED_PTX_VERSION`: the PTX version is not supported.
    UnsupportedPtxVersion,
}

impl std::fmt::Display for NvptxError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            NvptxError::InvalidInput => "invalid input passed to the PTX compiler",
            NvptxError::CompilationFailure => "PTX compilation failed",
            NvptxError::Internal => "internal PTX compiler error",
            NvptxError::OutOfMemory => "the PTX compiler ran out of memory",
            NvptxError::UnsupportedPtxVersion => {
                "the PTX version is not supported by the PTX compiler"
            }
        };
        f.write_str(msg)
    }
}

// Lets callers propagate `NvptxError` with `?` into `Box<dyn Error>` and other
// error-reporting types.
impl std::error::Error for NvptxError {}
52
53#[repr(transparent)]
54#[derive(Debug)]
55pub struct NvptxCompiler {
56    raw: sys::nvPTXCompilerHandle,
57}
58
59impl NvptxCompiler {
60    /// Create a new compiler from a ptx string.
61    pub fn new(ptx: impl AsRef<str>) -> NvptxResult<Self> {
62        let ptx = ptx.as_ref();
63        let mut raw = MaybeUninit::uninit();
64
65        unsafe {
66            sys::nvPTXCompilerCreate(raw.as_mut_ptr(), ptx.len() as u64, ptx.as_ptr().cast())
67                .to_result()?;
68            let raw = raw.assume_init();
69            Ok(Self { raw })
70        }
71    }
72}
73
74impl Drop for NvptxCompiler {
75    fn drop(&mut self) {
76        unsafe {
77            sys::nvPTXCompilerDestroy(&mut self.raw as *mut _)
78                .to_result()
79                .expect("failed to destroy nvptx compiler");
80        }
81    }
82}
83
84#[derive(Debug)]
85pub struct CompilerFailure {
86    pub error: NvptxError,
87    handle: sys::nvPTXCompilerHandle,
88}
89
90impl Drop for CompilerFailure {
91    fn drop(&mut self) {
92        unsafe {
93            sys::nvPTXCompilerDestroy(&mut self.handle as *mut _)
94                .to_result()
95                .expect("failed to destroy nvptx compiler failure");
96        }
97    }
98}
99
100impl CompilerFailure {
101    pub fn error_log(&self) -> NvptxResult<String> {
102        let mut size = MaybeUninit::uninit();
103        unsafe {
104            sys::nvPTXCompilerGetErrorLogSize(self.handle, size.as_mut_ptr()).to_result()?;
105            let size = size.assume_init() as usize;
106            let mut vec = Vec::with_capacity(size);
107            sys::nvPTXCompilerGetErrorLog(self.handle, vec.as_mut_ptr() as *mut i8).to_result()?;
108            vec.set_len(size);
109            Ok(String::from_utf8_lossy(&vec).to_string())
110        }
111    }
112}
113
114/// The result of a compiled program
115#[derive(Debug)]
116pub struct CompiledProgram {
117    pub cubin: Vec<u8>,
118    handle: sys::nvPTXCompilerHandle,
119}
120
121impl Drop for CompiledProgram {
122    fn drop(&mut self) {
123        unsafe {
124            sys::nvPTXCompilerDestroy(&mut self.handle as *mut _)
125                .to_result()
126                .expect("failed to destroy nvptx compiled program");
127        }
128    }
129}
130
131impl CompiledProgram {
132    pub fn info_log(&self) -> NvptxResult<String> {
133        let mut size = MaybeUninit::uninit();
134        unsafe {
135            sys::nvPTXCompilerGetInfoLogSize(self.handle, size.as_mut_ptr()).to_result()?;
136            let size = size.assume_init() as usize;
137            let mut vec = Vec::with_capacity(size);
138            sys::nvPTXCompilerGetInfoLog(self.handle, vec.as_mut_ptr() as *mut i8).to_result()?;
139            vec.set_len(size);
140            Ok(String::from_utf8_lossy(&vec).to_string())
141        }
142    }
143}