use bitflags::bitflags;
use std::{mem::MaybeUninit, ptr};
use llvm_sys::orc2::{
LLVMJITEvaluatedSymbol, LLVMJITSymbolFlags, LLVMJITSymbolGenericFlags, LLVMOrcAbsoluteSymbols,
LLVMOrcCSymbolMapPair, LLVMOrcCreateNewThreadSafeContext, LLVMOrcCreateNewThreadSafeModule,
LLVMOrcDisposeThreadSafeContext, LLVMOrcDisposeThreadSafeModule, lljit,
};
use crate::llvm_sys::{
core::{LLVMModule, handle_err},
cstr_to_string, to_c_str,
};
bitflags! {
    /// Rust-side bit set of generic JIT symbol flags.
    ///
    /// Each flag corresponds to the like-named `LLVMJITSymbolGenericFlags` variant; conversion
    /// to and from the LLVM representation is done by explicit `From` impls.
    #[derive(PartialEq, Eq, Clone, Debug, Hash, Copy)]
    pub struct JITSymbolGenericFlags: u8 {
        /// The empty (zero) flag value.
        const JITSymbolGenericFlagsNone = 0;
        /// Symbol is exported.
        const JITSymbolGenericFlagsExported = 1 << 0;
        /// Symbol is weak.
        const JITSymbolGenericFlagsWeak = 1 << 1;
        /// Symbol is callable.
        const JITSymbolGenericFlagsCallable = 1 << 2;
        /// Symbol exists only for its materialization side effects.
        const JITSymbolGenericFlagsMaterializationSideEffectsOnly = 1 << 3;
    }
}
impl From<LLVMJITSymbolGenericFlags> for JITSymbolGenericFlags {
fn from(value: LLVMJITSymbolGenericFlags) -> Self {
let mut flags = JITSymbolGenericFlags::empty();
if (value as u8) & (LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsExported as u8) != 0
{
flags |= JITSymbolGenericFlags::JITSymbolGenericFlagsExported;
}
if (value as u8) & (LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsWeak as u8) != 0 {
flags |= JITSymbolGenericFlags::JITSymbolGenericFlagsWeak;
}
if (value as u8) & (LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsCallable as u8) != 0
{
flags |= JITSymbolGenericFlags::JITSymbolGenericFlagsCallable;
}
if (value as u8)
& (LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsMaterializationSideEffectsOnly
as u8)
!= 0
{
flags |= JITSymbolGenericFlags::JITSymbolGenericFlagsMaterializationSideEffectsOnly;
}
flags
}
}
impl From<JITSymbolGenericFlags> for u8 {
fn from(value: JITSymbolGenericFlags) -> Self {
let mut flags = LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsNone as u8;
if value.contains(JITSymbolGenericFlags::JITSymbolGenericFlagsExported) {
flags |= LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsExported as u8;
}
if value.contains(JITSymbolGenericFlags::JITSymbolGenericFlagsWeak) {
flags |= LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsWeak as u8;
}
if value.contains(JITSymbolGenericFlags::JITSymbolGenericFlagsCallable) {
flags |= LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsCallable as u8;
}
if value
.contains(JITSymbolGenericFlags::JITSymbolGenericFlagsMaterializationSideEffectsOnly)
{
flags |=
LLVMJITSymbolGenericFlags::LLVMJITSymbolGenericFlagsMaterializationSideEffectsOnly
as u8;
}
flags
}
}
/// Owning wrapper around an LLVM ORC `LLJIT` instance (`LLVMOrcLLJITRef`).
///
/// The wrapped handle is released with `LLVMOrcDisposeLLJIT` when this value is dropped.
pub struct LLVMLLJIT(lljit::LLVMOrcLLJITRef);
impl LLVMLLJIT {
pub fn new_with_default_builder() -> Result<Self, String> {
unsafe {
let mut jit = MaybeUninit::uninit();
let err = lljit::LLVMOrcCreateLLJIT(jit.as_mut_ptr(), ptr::null_mut());
handle_err(err)?;
Ok(LLVMLLJIT(jit.assume_init()))
}
}
pub fn add_module(&self, module: LLVMModule) -> Result<(), String> {
unsafe {
let tsctx = LLVMOrcCreateNewThreadSafeContext();
let tsm = LLVMOrcCreateNewThreadSafeModule(module.inner_ref(), tsctx);
let main_jd = lljit::LLVMOrcLLJITGetMainJITDylib(self.0);
let err = lljit::LLVMOrcLLJITAddLLVMIRModule(self.0, main_jd, tsm);
LLVMOrcDisposeThreadSafeContext(tsctx);
std::mem::forget(module);
handle_err(err).inspect_err(|_| {
LLVMOrcDisposeThreadSafeModule(tsm);
})
}
}
pub fn lookup_symbol(&self, name: &str) -> Result<u64, String> {
unsafe {
let mut addr = MaybeUninit::uninit();
let err = lljit::LLVMOrcLLJITLookup(self.0, addr.as_mut_ptr(), to_c_str(name).as_ptr());
handle_err(err)?;
Ok(addr.assume_init())
}
}
pub fn get_triple_string(&self) -> String {
unsafe {
let triple_ptr = lljit::LLVMOrcLLJITGetTripleString(self.0);
cstr_to_string(triple_ptr).unwrap()
}
}
pub fn add_symbol_mapping(
&self,
name: &str,
addr: u64,
flags: JITSymbolGenericFlags,
) -> Result<(), String> {
let symbol_pool_ref =
unsafe { lljit::LLVMOrcLLJITMangleAndIntern(self.0, to_c_str(name).as_ptr()) };
let jit_evaluated_symbol = LLVMJITEvaluatedSymbol {
Address: addr,
Flags: LLVMJITSymbolFlags {
GenericFlags: flags.into(),
TargetFlags: 0,
},
};
let mut symbol_pair = LLVMOrcCSymbolMapPair {
Name: symbol_pool_ref,
Sym: jit_evaluated_symbol,
};
let materialization_unit = unsafe { LLVMOrcAbsoluteSymbols(&mut symbol_pair as *mut _, 1) };
let main_dylib = unsafe { lljit::LLVMOrcLLJITGetMainJITDylib(self.0) };
let res =
unsafe { llvm_sys::orc2::LLVMOrcJITDylibDefine(main_dylib, materialization_unit) };
handle_err(res)
}
}
/// Disposes the LLJIT instance via `LLVMOrcDisposeLLJIT`.
///
/// A disposal error panics, unless the thread is already unwinding. In that case the
/// error is only written to stderr, because a second panic inside `drop` would abort the
/// process.
impl Drop for LLVMLLJIT {
    fn drop(&mut self) {
        // SAFETY: `self.0` was created by `LLVMOrcCreateLLJIT` and is disposed only here.
        let err = unsafe { lljit::LLVMOrcDisposeLLJIT(self.0) };
        if let Err(err) = handle_err(err) {
            if std::thread::panicking() {
                eprintln!("Error disposing LLJIT: {}", err);
            } else {
                panic!("Error disposing LLJIT: {}", err);
            }
        }
    }
}