use crate::builder::unnamed;
use crate::llvm::*;
use crate::lto::ThinBuffer;
use find_cuda_helper::find_cuda_root;
use nvvm::*;
use rustc_codegen_ssa::traits::ThinBufferMethods;
use rustc_session::{config::DebugInfo, Session};
use std::ffi::OsStr;
use std::fmt::Display;
use std::marker::PhantomData;
use std::path::Path;
use std::{fs, ptr};
use tracing::debug;
/// Raw bytes of `../libintrinsics.bc`, embedded into this crate at compile time through
/// `include_bytes!`. `codegen_bitcode_modules` passes these bytes to libnvvm as the lazily
/// added module named "libintrinsics".
const LIBINTRINSICS: &[u8] = include_bytes!("../libintrinsics.bc");
/// Error type returned by `codegen_bitcode_modules`.
///
/// Both variants simply wrap an underlying error; `From` conversions exist for each so that
/// `?` can be used on either error kind.
pub enum CodegenErr {
/// Wraps an `NvvmError` coming from the `nvvm` crate.
Nvvm(NvvmError),
/// Wraps a standard library I/O error.
Io(std::io::Error),
}
impl From<std::io::Error> for CodegenErr {
fn from(v: std::io::Error) -> Self {
Self::Io(v)
}
}
impl From<NvvmError> for CodegenErr {
fn from(v: NvvmError) -> Self {
Self::Nvvm(v)
}
}
impl Display for CodegenErr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Nvvm(err) => std::fmt::Display::fmt(&err, f),
Self::Io(err) => std::fmt::Display::fmt(&err, f),
}
}
}
/// Merges the given bitcode modules into a single LLVM module, prepares it for libnvvm and
/// compiles it, returning the bytes produced by libnvvm (PTX, per the debug message below).
///
/// The steps performed are, in order:
/// 1. check that the libnvvm IR version is at least 1.6,
/// 2. link every entry of `modules` into one module (`merge_llvm_modules`),
/// 3. run `internalize_pass` and `dce_pass` over it, and `cleanup_dicompileunit` when debug
///    info is enabled in the session,
/// 4. attach the `nvvmir.version` named metadata built from the IR and debug versions,
/// 5. hand the merged module, libdevice and the embedded `LIBINTRINSICS` to libnvvm, verify
///    the program and compile it with `opts`.
///
/// # Errors
/// Returns a [`CodegenErr`] if creating the libnvvm program or adding any module to it fails.
///
/// # Panics
/// Panics (with the libnvvm log or error attached) if libnvvm rejects the program during
/// verification or if compilation fails afterwards. A too-old libnvvm or a missing libdevice
/// is reported through `sess.fatal`, which does not return.
pub fn codegen_bitcode_modules(
    opts: &[NvvmOption],
    sess: &Session,
    modules: Vec<Vec<u8>>,
    llcx: &Context,
) -> Result<Vec<u8>, CodegenErr> {
    debug!("Codegenning bitcode to PTX");

    let (major, minor) = nvvm::ir_version();
    // Compare the version lexicographically. Testing `minor < 6 || major < 1` would wrongly
    // reject any release with a higher major but a small minor number (e.g. 2.0).
    if (major, minor) < (1, 6) {
        sess.fatal("rustc_codegen_nvvm requires at least libnvvm 1.6 (CUDA 11.2)");
    }

    let prog = NvvmProgram::new()?;
    let module = merge_llvm_modules(modules, llcx);

    // SAFETY: `module` and `llcx` are live references for the whole block, and the operand
    // array handed to `LLVMMDNodeInContext` outlives that call.
    unsafe {
        internalize_pass(module, llcx);
        dce_pass(module);
        if sess.opts.debuginfo != DebugInfo::None {
            cleanup_dicompileunit(module);
        }

        // Record the IR and debug-info versions as `!nvvmir.version = !{ i32 x 4 }`.
        let (dbg_major, dbg_minor) = nvvm::dbg_version();
        let ty_i32 = LLVMInt32TypeInContext(llcx);
        let major = LLVMConstInt(ty_i32, major as u64, False);
        let minor = LLVMConstInt(ty_i32, minor as u64, False);
        let dbg_major = LLVMConstInt(ty_i32, dbg_major as u64, False);
        let dbg_minor = LLVMConstInt(ty_i32, dbg_minor as u64, False);
        let vals = [major, minor, dbg_major, dbg_minor];
        let node = LLVMMDNodeInContext(llcx, vals.as_ptr(), vals.len() as u32);
        LLVMAddNamedMetadataOperand(module, "nvvmir.version\0".as_ptr().cast(), node);
    }

    let buf = ThinBuffer::new(module);
    prog.add_module(buf.data(), "merged".to_string())?;

    let libdevice = match find_libdevice() {
        Some(bc) => bc,
        None => sess.fatal(
            "Could not find the libdevice library (libdevice.10.bc) in the CUDA directory",
        ),
    };
    prog.add_lazy_module(&libdevice, "libdevice".to_string())?;
    prog.add_lazy_module(LIBINTRINSICS, "libintrinsics".to_string())?;

    // Verify first so that malformed IR is reported together with the verifier log.
    if prog.verify().is_err() {
        let log = prog.compiler_log().unwrap().unwrap_or_default();
        let footer = "If you plan to submit a bug report please re-run the codegen with `RUSTFLAGS=\"--emit=llvm-ir\" and include the .ll file corresponding to the .o file mentioned in the log";
        panic!(
            "Malformed NVVM IR program rejected by libnvvm, dumping verifier log:\n\n{}\n\n{}",
            log, footer
        );
    }

    match prog.compile(opts) {
        Ok(compiled) => Ok(compiled),
        Err(error) => {
            panic!(
                "libnvvm returned an error that was not previously caught by the verifier: {:?}",
                error
            );
        }
    }
}
/// Locates and reads the libdevice bitcode shipped with the CUDA installation.
///
/// Looks inside `<cuda root>/nvvm/libdevice` and reads the first directory entry whose
/// extension is `bc`.
///
/// Returns `None` when the CUDA root cannot be found, the directory cannot be listed, no
/// `.bc` entry exists, or the file cannot be read.
pub fn find_libdevice() -> Option<Vec<u8>> {
    let cuda_root = find_cuda_root()?;
    let libdevice_dir = Path::new(&cuda_root).join("nvvm").join("libdevice");

    // Entries that fail to be read are skipped rather than aborting the search.
    let bitcode_path = fs::read_dir(libdevice_dir)
        .ok()?
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .find(|path| path.extension() == Some(OsStr::new("bc")))?;

    fs::read(bitcode_path).ok()
}
/// Queries the module's debug-info compile units and patches the module with the first one
/// reported.
///
/// NOTE(review): the exact effect of the two `LLVMRustThinLTO*DICompileUnit` helpers is not
/// visible in this file; presumably they collapse the merged module onto a single
/// `DICompileUnit` — confirm against their definitions.
///
/// # Safety
/// Calls LLVM FFI functions directly on `module`.
unsafe fn cleanup_dicompileunit(module: &Module) {
    // Out-parameters filled in by the query; only the first one is used afterwards.
    let mut cu_first = ptr::null_mut();
    let mut cu_second = ptr::null_mut();
    LLVMRustThinLTOGetDICompileUnit(module, &mut cu_first, &mut cu_second);
    LLVMRustThinLTOPatchDICompileUnit(module, cu_first);
}
/// Creates a fresh module named "merged_modules" in `llcx` and links every bitcode buffer in
/// `modules` into it, in order.
///
/// # Panics
/// Panics if any buffer fails to parse as bitcode.
fn merge_llvm_modules(modules: Vec<Vec<u8>>, llcx: &Context) -> &Module {
    let merged = unsafe { crate::create_module(llcx, "merged_modules") };

    for bitcode in modules {
        unsafe {
            let parsed =
                LLVMRustParseBitcodeForLTO(llcx, bitcode.as_ptr(), bitcode.len(), unnamed())
                    .expect("Failed to parse module bitcode");
            // NOTE(review): whatever status the link call reports is not inspected here.
            LLVMLinkModules2(merged, parsed);
        }
    }

    merged
}
/// Iterator over the functions of an LLVM module, yielding each function as a `&Value`.
///
/// Created with `FunctionIter::new`, which seeds it with the module's first function.
struct FunctionIter<'a, 'll> {
// Zero-sized marker that ties the iterator to the `&'a &'ll Module` borrow it was built from.
module: PhantomData<&'a &'ll Module>,
// The function the following `Iterator::next` call will yield; `None` once exhausted.
next: Option<&'ll Value>,
}
/// Iterator over the globals of an LLVM module, yielding each global as a `&Value`.
///
/// Created with `GlobalIter::new`, which seeds it with the module's first global.
struct GlobalIter<'a, 'll> {
// Zero-sized marker that ties the iterator to the `&'a &'ll Module` borrow it was built from.
module: PhantomData<&'a &'ll Module>,
// The global the following `Iterator::next` call will yield; `None` once exhausted.
next: Option<&'ll Value>,
}
impl<'a, 'll> FunctionIter<'a, 'll> {
    /// Builds an iterator positioned at the first function of `module` (if there is one).
    pub fn new(module: &'a &'ll Module) -> Self {
        let first = unsafe { LLVMGetFirstFunction(*module) };
        Self {
            module: PhantomData,
            next: first,
        }
    }
}
impl<'a, 'll> Iterator for FunctionIter<'a, 'll> {
    type Item = &'ll Value;

    /// Yields the stored function and advances to the one LLVM reports as following it.
    fn next(&mut self) -> Option<&'ll Value> {
        // Once exhausted the stored cursor is already `None`, so bail out without touching it.
        let current = self.next?;
        self.next = unsafe { LLVMGetNextFunction(current) };
        Some(current)
    }
}
impl<'a, 'll> GlobalIter<'a, 'll> {
    /// Builds an iterator positioned at the first global of `module` (if there is one).
    pub fn new(module: &'a &'ll Module) -> Self {
        let first = unsafe { LLVMGetFirstGlobal(*module) };
        Self {
            module: PhantomData,
            next: first,
        }
    }
}
impl<'a, 'll> Iterator for GlobalIter<'a, 'll> {
    type Item = &'ll Value;

    /// Yields the stored global and advances to the one LLVM reports as following it.
    fn next(&mut self) -> Option<&'ll Value> {
        // Once exhausted the stored cursor is already `None`, so bail out without touching it.
        let current = self.next?;
        self.next = unsafe { LLVMGetNextGlobal(current) };
        Some(current)
    }
}
/// Rewrites linkage so that only kernels and explicitly "used" functions stay externally
/// visible.
///
/// * Functions named by an `nvvm.annotations` entry whose second operand is the metadata
///   string `"kernel"` are treated as kernels and left untouched.
/// * Every other function that has a definition gets internal linkage and default visibility.
/// * Functions named by `cg_nvvm_used` entries are then given external linkage and default
///   visibility; this is applied last, so it overrides the internalization above.
/// * Every global that has a definition gets internal linkage and default visibility.
///
/// # Panics
/// Panics if a `cg_nvvm_used` entry has no operands.
///
/// # Safety
/// Calls LLVM FFI functions directly; the operand buffers are filled by LLVM through raw
/// pointers before their length is set.
unsafe fn internalize_pass(module: &Module, cx: &Context) {
    // --- Gather kernels from `nvvm.annotations` ---------------------------------------------
    let annotation_count =
        LLVMGetNamedMetadataNumOperands(module, "nvvm.annotations\0".as_ptr().cast()) as usize;
    let mut annotations = Vec::with_capacity(annotation_count);
    LLVMGetNamedMetadataOperands(
        module,
        "nvvm.annotations\0".as_ptr().cast(),
        annotations.as_mut_ptr(),
    );
    // LLVM has written `annotation_count` entries into the reserved capacity.
    annotations.set_len(annotation_count);

    let kernel_tag = LLVMMDStringInContext(cx, "kernel".as_ptr().cast(), 6);
    let mut kernels = Vec::with_capacity(annotation_count);
    for annotation in annotations {
        let field_count = LLVMGetMDNodeNumOperands(annotation) as usize;
        let mut fields = Vec::with_capacity(field_count);
        LLVMGetMDNodeOperands(annotation, fields.as_mut_ptr());
        fields.set_len(field_count);

        // An entry marks a kernel when its second operand is the "kernel" metadata string.
        if fields.get(1) == Some(&kernel_tag) {
            kernels.push(fields[0]);
        }
    }

    // --- Gather functions that must stay external from `cg_nvvm_used` -----------------------
    let used_count =
        LLVMGetNamedMetadataNumOperands(module, "cg_nvvm_used\0".as_ptr().cast()) as usize;
    let mut used_nodes = Vec::with_capacity(used_count);
    LLVMGetNamedMetadataOperands(
        module,
        "cg_nvvm_used\0".as_ptr().cast(),
        used_nodes.as_mut_ptr(),
    );
    used_nodes.set_len(used_count);

    let mut used_funcs = Vec::with_capacity(used_count);
    for used_node in used_nodes {
        let field_count = LLVMGetMDNodeNumOperands(used_node) as usize;
        let mut fields = Vec::with_capacity(field_count);
        LLVMGetMDNodeOperands(used_node, fields.as_mut_ptr());
        fields.set_len(field_count);
        used_funcs.push(fields[0]);
    }

    // --- Apply linkage to functions ---------------------------------------------------------
    for func in FunctionIter::new(&module) {
        let has_definition = LLVMIsDeclaration(func) != True;

        if has_definition && !kernels.contains(&func) {
            LLVMRustSetLinkage(func, Linkage::InternalLinkage);
            LLVMRustSetVisibility(func, Visibility::Default);
        }

        // Runs after the internalization above on purpose: "used" functions win.
        if used_funcs.contains(&func) {
            LLVMRustSetLinkage(func, Linkage::ExternalLinkage);
            LLVMRustSetVisibility(func, Visibility::Default);
        }
    }

    // --- Apply linkage to globals -----------------------------------------------------------
    for global in GlobalIter::new(&module) {
        if LLVMIsDeclaration(global) != True {
            LLVMRustSetLinkage(global, Linkage::InternalLinkage);
            LLVMRustSetVisibility(global, Visibility::Default);
        }
    }
}
/// Runs a pass manager containing only the global dead-code-elimination pass over `module`.
///
/// # Safety
/// Calls LLVM FFI functions directly on `module`.
unsafe fn dce_pass(module: &Module) {
    let pm = LLVMCreatePassManager();
    LLVMAddGlobalDCEPass(pm);
    LLVMRunPassManager(pm, module);
    // The pass manager was created above and must be released once it has run.
    LLVMDisposePassManager(pm);
}