use std::{
ffi::{CStr, CString},
fmt::Display,
mem::MaybeUninit,
ptr::null_mut,
str::FromStr,
};
// Raw FFI bindings to libNVVM, used by the safe wrappers in this file.
// Every warning is allowed here, presumably because the bindings are
// machine-generated (the `nvvmResult_*` constant naming looks like bindgen
// output) — TODO confirm.
#[allow(warnings, clippy::warnings)]
pub mod sys;
pub fn ir_version() -> (i32, i32) {
unsafe {
let mut major_ir = MaybeUninit::uninit();
let mut minor_ir = MaybeUninit::uninit();
let mut major_dbg = MaybeUninit::uninit();
let mut minor_dbg = MaybeUninit::uninit();
sys::nvvmIRVersion(
major_ir.as_mut_ptr(),
minor_ir.as_mut_ptr(),
major_dbg.as_mut_ptr(),
minor_dbg.as_mut_ptr(),
);
(major_ir.assume_init(), minor_ir.assume_init())
}
}
pub fn dbg_version() -> (i32, i32) {
unsafe {
let mut major_ir = MaybeUninit::uninit();
let mut minor_ir = MaybeUninit::uninit();
let mut major_dbg = MaybeUninit::uninit();
let mut minor_dbg = MaybeUninit::uninit();
sys::nvvmIRVersion(
major_ir.as_mut_ptr(),
minor_ir.as_mut_ptr(),
major_dbg.as_mut_ptr(),
minor_dbg.as_mut_ptr(),
);
(major_dbg.assume_init(), minor_dbg.assume_init())
}
}
pub fn nvvm_version() -> (i32, i32) {
unsafe {
let mut major = MaybeUninit::uninit();
let mut minor = MaybeUninit::uninit();
sys::nvvmVersion(major.as_mut_ptr(), minor.as_mut_ptr());
(major.assume_init(), minor.assume_init())
}
}
/// An error reported by libNVVM.
///
/// Each variant corresponds to one non-success `nvvmResult` code; the
/// [`Display`] impl prints NVVM's own description of that code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NvvmError {
    /// `NVVM_ERROR_OUT_OF_MEMORY`.
    OutOfMemory,
    /// `NVVM_ERROR_PROGRAM_CREATION_FAILURE`.
    ProgramCreationFailure,
    /// `NVVM_ERROR_IR_VERSION_MISMATCH`.
    IrVersionMismatch,
    /// `NVVM_ERROR_INVALID_INPUT`.
    InvalidInput,
    /// `NVVM_ERROR_INVALID_IR`.
    InvalidIr,
    /// `NVVM_ERROR_INVALID_OPTION`.
    InvalidOption,
    /// `NVVM_ERROR_NO_MODULE_IN_PROGRAM`.
    NoModuleInProgram,
    /// `NVVM_ERROR_COMPILATION`.
    CompilationError,
}

impl Display for NvvmError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // SAFETY: `to_raw` always yields a valid NVVM result code, for which
        // `nvvmGetErrorString` returns a NUL-terminated string owned by libNVVM.
        let description = unsafe { CStr::from_ptr(sys::nvvmGetErrorString(self.to_raw())) };
        f.write_str(&description.to_string_lossy())
    }
}

// Lets `NvvmError` flow through `?` into `Box<dyn Error>` and interoperate with
// error-handling code built on `std::error::Error`.
impl std::error::Error for NvvmError {}
impl NvvmError {
    /// Converts this error into its raw `nvvmResult` code.
    fn to_raw(self) -> sys::nvvmResult {
        match self {
            NvvmError::CompilationError => sys::nvvmResult_NVVM_ERROR_COMPILATION,
            NvvmError::OutOfMemory => sys::nvvmResult_NVVM_ERROR_OUT_OF_MEMORY,
            NvvmError::ProgramCreationFailure => {
                sys::nvvmResult_NVVM_ERROR_PROGRAM_CREATION_FAILURE
            }
            NvvmError::IrVersionMismatch => sys::nvvmResult_NVVM_ERROR_IR_VERSION_MISMATCH,
            NvvmError::InvalidOption => sys::nvvmResult_NVVM_ERROR_INVALID_OPTION,
            NvvmError::InvalidInput => sys::nvvmResult_NVVM_ERROR_INVALID_INPUT,
            NvvmError::InvalidIr => sys::nvvmResult_NVVM_ERROR_INVALID_IR,
            NvvmError::NoModuleInProgram => sys::nvvmResult_NVVM_ERROR_NO_MODULE_IN_PROGRAM,
        }
    }

    /// Converts a raw, non-success `nvvmResult` code into an `NvvmError`.
    ///
    /// # Panics
    ///
    /// Panics if `result` is `NVVM_SUCCESS` (callers must check for success
    /// first), or if it is a code this enum has no variant for.
    fn from_raw(result: sys::nvvmResult) -> Self {
        use NvvmError::*;
        match result {
            sys::nvvmResult_NVVM_ERROR_COMPILATION => CompilationError,
            sys::nvvmResult_NVVM_ERROR_OUT_OF_MEMORY => OutOfMemory,
            sys::nvvmResult_NVVM_ERROR_PROGRAM_CREATION_FAILURE => ProgramCreationFailure,
            sys::nvvmResult_NVVM_ERROR_IR_VERSION_MISMATCH => IrVersionMismatch,
            sys::nvvmResult_NVVM_ERROR_INVALID_OPTION => InvalidOption,
            sys::nvvmResult_NVVM_ERROR_INVALID_INPUT => InvalidInput,
            sys::nvvmResult_NVVM_ERROR_INVALID_IR => InvalidIr,
            sys::nvvmResult_NVVM_ERROR_NO_MODULE_IN_PROGRAM => NoModuleInProgram,
            sys::nvvmResult_NVVM_SUCCESS => {
                panic!("NvvmError::from_raw called with NVVM_SUCCESS")
            }
            // Not truly unreachable: libNVVM also defines codes this enum does
            // not model (e.g. NVVM_ERROR_INVALID_PROGRAM), so report the raw
            // value instead of a misleading "unreachable" message.
            other => panic!("unrecognized nvvmResult code {:?}", other),
        }
    }
}
/// Conversion from a raw NVVM status code into a Rust `Result`.
trait ToNvvmResult {
    /// Yields `Ok(())` for `NVVM_SUCCESS` and the matching [`NvvmError`] otherwise.
    fn to_result(self) -> Result<(), NvvmError>;
}

impl ToNvvmResult for sys::nvvmResult {
    fn to_result(self) -> Result<(), NvvmError> {
        match self {
            sys::nvvmResult_NVVM_SUCCESS => Ok(()),
            code => Err(NvvmError::from_raw(code)),
        }
    }
}
/// A single option accepted by the NVVM compiler.
///
/// Options are rendered with [`Display`] (e.g. `-opt=0`) before being passed to
/// libNVVM, and can be parsed back from that exact spelling with [`FromStr`].
/// Settings that would only restate the NVVM default (such as `-opt=3`) have no
/// variant and are rejected by the parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NvvmOption {
    /// `-g`: generate debug information.
    GenDebugInfo,
    /// `-generate-line-info`: generate line-number information.
    GenLineInfo,
    /// `-opt=0`: disable optimizations.
    NoOpts,
    /// `-arch=compute_XX`: target the given virtual architecture.
    Arch(NvvmArch),
    /// `-ftz=1`: flush denormal values to zero.
    Ftz,
    /// `-prec-sqrt=0`: use a faster, less precise square root.
    FastSqrt,
    /// `-prec-div=0`: use a faster, less precise division.
    FastDiv,
    /// `-fma=0`: disable contraction of multiplies and adds into FMA.
    NoFmaContraction,
}

impl Display for NvvmOption {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let flag = match self {
            Self::GenDebugInfo => "-g",
            Self::GenLineInfo => "-generate-line-info",
            Self::NoOpts => "-opt=0",
            // Written straight into the formatter; no intermediate `String`.
            Self::Arch(arch) => return write!(f, "-arch={}", arch),
            Self::Ftz => "-ftz=1",
            Self::FastSqrt => "-prec-sqrt=0",
            Self::FastDiv => "-prec-div=0",
            Self::NoFmaContraction => "-fma=0",
        };
        f.write_str(flag)
    }
}

impl FromStr for NvvmOption {
    type Err = &'static str;

    /// Parses an option from its NVVM command-line spelling.
    ///
    /// Surrounding whitespace is ignored. Unknown options, unknown values, and
    /// values that merely restate an NVVM default return an `Err` describing
    /// the problem; malformed input never panics.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let s = s.trim();
        match s {
            "-g" => return Ok(Self::GenDebugInfo),
            "-generate-line-info" => return Ok(Self::GenLineInfo),
            _ => {}
        }
        if let Some(level) = s.strip_prefix("-opt=") {
            return match level {
                "0" => Ok(Self::NoOpts),
                "3" => Err("-opt=3 is default"),
                _ => Err("unknown optimization level"),
            };
        }
        if let Some(value) = s.strip_prefix("-ftz=") {
            return match value {
                "1" => Ok(Self::Ftz),
                "0" => Err("-ftz=0 is default"),
                _ => Err("unknown ftz option"),
            };
        }
        if let Some(value) = s.strip_prefix("-prec-sqrt=") {
            return match value {
                "0" => Ok(Self::FastSqrt),
                "1" => Err("-prec-sqrt=1 is default"),
                _ => Err("unknown prec-sqrt option"),
            };
        }
        if let Some(value) = s.strip_prefix("-prec-div=") {
            return match value {
                "0" => Ok(Self::FastDiv),
                "1" => Err("-prec-div=1 is default"),
                _ => Err("unknown prec-div option"),
            };
        }
        if let Some(value) = s.strip_prefix("-fma=") {
            return match value {
                "0" => Ok(Self::NoFmaContraction),
                "1" => Err("-fma=1 is default"),
                _ => Err("unknown fma option"),
            };
        }
        if let Some(arch) = s.strip_prefix("-arch=") {
            // Only the `compute_XX` spelling produced by `NvvmArch`'s `Display`
            // is accepted; anything else (e.g. `sm_35`, or an empty value) is an
            // error rather than an out-of-bounds slice panic.
            let suffix = arch.strip_prefix("compute_").ok_or("unknown arch")?;
            return NvvmArch::ALL
                .iter()
                .copied()
                .find(|candidate| candidate.suffix() == suffix)
                .map(Self::Arch)
                .ok_or("unknown arch");
        }
        Err("unknown option")
    }
}

/// A virtual GPU architecture that NVVM can target, displayed as `compute_XX`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NvvmArch {
    Compute35,
    Compute37,
    Compute50,
    Compute52,
    Compute53,
    Compute60,
    Compute61,
    Compute62,
    Compute70,
    Compute72,
    Compute75,
    Compute80,
}

impl NvvmArch {
    /// Every supported architecture, used when parsing `-arch=` options.
    const ALL: [NvvmArch; 12] = [
        Self::Compute35,
        Self::Compute37,
        Self::Compute50,
        Self::Compute52,
        Self::Compute53,
        Self::Compute60,
        Self::Compute61,
        Self::Compute62,
        Self::Compute70,
        Self::Compute72,
        Self::Compute75,
        Self::Compute80,
    ];

    /// The digits that follow `compute_` in this architecture's name, e.g. `"35"`.
    fn suffix(self) -> &'static str {
        match self {
            Self::Compute35 => "35",
            Self::Compute37 => "37",
            Self::Compute50 => "50",
            Self::Compute52 => "52",
            Self::Compute53 => "53",
            Self::Compute60 => "60",
            Self::Compute61 => "61",
            Self::Compute62 => "62",
            Self::Compute70 => "70",
            Self::Compute72 => "72",
            Self::Compute75 => "75",
            Self::Compute80 => "80",
        }
    }
}

impl Display for NvvmArch {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "compute_{}", self.suffix())
    }
}

impl Default for NvvmArch {
    /// Defaults to `compute_52`.
    fn default() -> Self {
        Self::Compute52
    }
}
/// An owned libNVVM program handle.
///
/// The handle is created by [`NvvmProgram::new`] and destroyed when this value
/// is dropped.
pub struct NvvmProgram {
    /// Handle returned by `nvvmCreateProgram`.
    raw: sys::nvvmProgram,
}

impl Drop for NvvmProgram {
    /// Destroys the underlying program with `nvvmDestroyProgram`.
    ///
    /// Panics if NVVM reports a failure while destroying the program.
    fn drop(&mut self) {
        // SAFETY: `raw` is private and only set by a successful
        // `nvvmCreateProgram`, and it is destroyed exactly once, here.
        let status = unsafe { sys::nvvmDestroyProgram(&mut self.raw) };
        status
            .to_result()
            .expect("failed to destroy nvvm program");
    }
}
impl NvvmProgram {
    /// Creates a new, empty NVVM program.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `nvvmCreateProgram`.
    pub fn new() -> Result<Self, NvvmError> {
        let mut raw = MaybeUninit::uninit();
        // SAFETY: `raw` is only read after `nvvmCreateProgram` reports success,
        // i.e. after it has written a program handle into it.
        unsafe {
            sys::nvvmCreateProgram(raw.as_mut_ptr()).to_result()?;
            Ok(Self {
                raw: raw.assume_init(),
            })
        }
    }

    /// Compiles every module in this program with the given options.
    ///
    /// Each option is rendered through its `Display` impl. The returned buffer
    /// is NVVM's compiled result with its final byte (the NUL terminator that
    /// NVVM counts in the reported size) removed.
    ///
    /// # Errors
    ///
    /// Returns the NVVM error if compilation or result retrieval fails; details
    /// are available from [`NvvmProgram::compiler_log`].
    pub fn compile(&self, options: &[NvvmOption]) -> Result<Vec<u8>, NvvmError> {
        // NVVM takes options as an array of NUL-terminated C strings.
        // `options_ptr` points into `options`, so `options` must stay alive
        // until `nvvmCompileProgram` returns.
        let options = options
            .iter()
            .map(|x| format!("{}\0", x))
            .collect::<Vec<_>>();
        let mut options_ptr = options
            .iter()
            .map(|x| x.as_ptr().cast())
            .collect::<Vec<_>>();
        // SAFETY: `options_ptr` holds `options.len()` valid C strings, and `buf`
        // is allocated with exactly the capacity NVVM reports before NVVM
        // writes into it; its length is set only after the write succeeds.
        unsafe {
            sys::nvvmCompileProgram(self.raw, options.len() as i32, options_ptr.as_mut_ptr())
                .to_result()?;
            let mut size = 0;
            sys::nvvmGetCompiledResultSize(self.raw, &mut size as *mut usize as *mut _)
                .to_result()?;
            let mut buf: Vec<u8> = Vec::with_capacity(size);
            sys::nvvmGetCompiledResult(self.raw, buf.as_mut_ptr().cast()).to_result()?;
            buf.set_len(size);
            buf.pop();
            Ok(buf)
        }
    }

    /// Adds a module, named `name`, whose contents are `bitcode`, via
    /// `nvvmAddModuleToProgram`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `nvvmAddModuleToProgram`.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains an interior NUL byte.
    pub fn add_module(&self, bitcode: &[u8], name: String) -> Result<(), NvvmError> {
        let cstring = CString::new(name).expect("module name with nul");
        // SAFETY: `bitcode` and `cstring` are live for the whole call and the
        // length passed matches the buffer.
        unsafe {
            sys::nvvmAddModuleToProgram(
                self.raw,
                bitcode.as_ptr().cast(),
                bitcode.len() as u64,
                cstring.as_ptr(),
            )
            .to_result()
        }
    }

    /// Adds a module, named `name`, whose contents are `bitcode`, via
    /// `nvvmLazyAddModuleToProgram` (per the libNVVM docs, only symbols used
    /// by other modules are loaded from it).
    ///
    /// # Errors
    ///
    /// Returns the error reported by `nvvmLazyAddModuleToProgram`.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains an interior NUL byte.
    pub fn add_lazy_module(&self, bitcode: &[u8], name: String) -> Result<(), NvvmError> {
        let cstring = CString::new(name).expect("module name with nul");
        // SAFETY: `bitcode` and `cstring` are live for the whole call and the
        // length passed matches the buffer.
        unsafe {
            sys::nvvmLazyAddModuleToProgram(
                self.raw,
                bitcode.as_ptr().cast(),
                bitcode.len() as u64,
                cstring.as_ptr(),
            )
            .to_result()
        }
    }

    /// Returns NVVM's program log (compiler/verifier messages), or `None` if
    /// the log is empty.
    ///
    /// Bytes that are not valid UTF-8 are replaced with U+FFFD instead of
    /// causing a panic.
    ///
    /// # Errors
    ///
    /// Returns the NVVM error if the log size or contents cannot be retrieved.
    pub fn compiler_log(&self) -> Result<Option<String>, NvvmError> {
        // SAFETY: `buf` is allocated with the capacity NVVM reports and its
        // length is set only after NVVM has filled that many bytes.
        let mut buf = unsafe {
            let mut size = MaybeUninit::uninit();
            sys::nvvmGetProgramLogSize(self.raw, size.as_mut_ptr()).to_result()?;
            let size = size.assume_init() as usize;
            let mut buf: Vec<u8> = Vec::with_capacity(size);
            sys::nvvmGetProgramLog(self.raw, buf.as_mut_ptr().cast()).to_result()?;
            buf.set_len(size);
            buf
        };
        // Drop the NUL terminator that NVVM includes in the reported size.
        buf.pop();
        // NVVM does not promise UTF-8, so decode lossily rather than panicking;
        // the common valid-UTF-8 case reuses the buffer without copying.
        let log = String::from_utf8(buf)
            .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned());
        Ok(Some(log).filter(|s| !s.is_empty()))
    }

    /// Verifies the program with `nvvmVerifyProgram`, passing no options.
    ///
    /// # Errors
    ///
    /// Returns the NVVM error if verification fails; details are available
    /// from [`NvvmProgram::compiler_log`].
    pub fn verify(&self) -> Result<(), NvvmError> {
        // SAFETY: an option count of zero is paired with a null option array.
        unsafe { sys::nvvmVerifyProgram(self.raw, 0, null_mut()).to_result() }
    }
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    /// Every supported option spelling parses into its matching `NvvmOption`.
    #[test]
    fn options_parse_correctly() {
        use crate::NvvmArch::*;
        use crate::NvvmOption::{self, *};

        let cases = [
            ("-g", GenDebugInfo),
            ("-generate-line-info", GenLineInfo),
            ("-opt=0", NoOpts),
            ("-arch=compute_35", Arch(Compute35)),
            ("-arch=compute_37", Arch(Compute37)),
            ("-arch=compute_50", Arch(Compute50)),
            ("-arch=compute_52", Arch(Compute52)),
            ("-arch=compute_53", Arch(Compute53)),
            ("-arch=compute_60", Arch(Compute60)),
            ("-arch=compute_61", Arch(Compute61)),
            ("-arch=compute_62", Arch(Compute62)),
            ("-arch=compute_70", Arch(Compute70)),
            ("-arch=compute_72", Arch(Compute72)),
            ("-arch=compute_75", Arch(Compute75)),
            ("-arch=compute_80", Arch(Compute80)),
            ("-ftz=1", Ftz),
            ("-prec-sqrt=0", FastSqrt),
            ("-prec-div=0", FastDiv),
            ("-fma=0", NoFmaContraction),
        ];

        for &(input, expected) in cases.iter() {
            let parsed = NvvmOption::from_str(input).unwrap();
            assert_eq!(parsed, expected, "parsing {:?}", input);
        }
    }
}