// pipeline_script/llvm/orc/lljit.rs

use llvm_sys::error::{LLVMDisposeErrorMessage, LLVMErrorRef, LLVMGetErrorMessage};
use llvm_sys::orc2::lljit::{
    LLVMOrcLLJITAddLLVMIRModule, LLVMOrcLLJITGetMainJITDylib, LLVMOrcLLJITLookup,
    LLVMOrcLLJITMangleAndIntern, LLVMOrcLLJITRef,
};
use llvm_sys::orc2::{
    LLVMJITEvaluatedSymbol, LLVMJITSymbolFlags, LLVMOrcAbsoluteSymbols, LLVMOrcCSymbolMapPair,
    LLVMOrcCSymbolMapPairs, LLVMOrcDisposeMaterializationUnit, LLVMOrcExecutorAddress,
    LLVMOrcJITDylibDefine, LLVMOrcJITDylibRef, LLVMOrcSymbolStringPoolEntryRef,
    LLVMOrcThreadSafeModuleRef,
};
use std::ffi::{CStr, CString};

13pub struct LLJIT {
14    reference: LLVMOrcLLJITRef,
15    main_dylib: LLVMOrcJITDylibRef,
16}
17
18impl LLJIT {
19    pub(crate) fn new(refer: LLVMOrcLLJITRef) -> Self {
20        unsafe {
21            Self {
22                reference: refer,
23                main_dylib: LLVMOrcLLJITGetMainJITDylib(refer),
24            }
25        }
26    }
27    #[allow(unused)]
28    pub fn get_main_jitdylib(&self) -> LLVMOrcJITDylibRef {
29        self.main_dylib
30    }
31    pub(crate) unsafe fn add_main_dylib_llvm_ir_module(
32        &self,
33        safe_module: LLVMOrcThreadSafeModuleRef,
34    ) -> Result<(), String> {
35        let error =
36            unsafe { LLVMOrcLLJITAddLLVMIRModule(self.reference, self.main_dylib, safe_module) };
37        unsafe {
38            if !error.is_null() {
39                let message = LLVMGetErrorMessage(error);
40                let s = CString::from_raw(message);
41                return Err(s.to_str().unwrap().to_string());
42            }
43        }
44        Ok(())
45    }
46    pub fn define_main_dylib_symbol(&self, symbol_name: &str, sym_addr: u64) {
47        let symbol_name = CString::new(symbol_name).unwrap(); // IR 中声明的符号名
48        let pair = LLVMOrcCSymbolMapPair {
49            Name: self.mangle_and_intern(symbol_name.to_str().unwrap()),
50            Sym: LLVMJITEvaluatedSymbol {
51                Address: sym_addr,
52                Flags: LLVMJITSymbolFlags {
53                    GenericFlags: 0,
54                    TargetFlags: 0,
55                },
56            },
57        };
58        let f = unsafe { LLVMOrcAbsoluteSymbols([pair].as_mut_ptr() as LLVMOrcCSymbolMapPairs, 1) };
59        unsafe { LLVMOrcJITDylibDefine(self.main_dylib, f) };
60    }
61    pub fn mangle_and_intern(&self, symbol_name: &str) -> LLVMOrcSymbolStringPoolEntryRef {
62        let symbol_name = CString::new(symbol_name).unwrap(); // IR 中声明的符号名
63        unsafe { LLVMOrcLLJITMangleAndIntern(self.reference, symbol_name.as_ptr()) }
64    }
65    pub fn lookup(&self, symbol_name: &str) -> Result<u64, String> {
66        let mut address: LLVMOrcExecutorAddress = 0;
67        let symbol_name = CString::new(symbol_name).unwrap(); // IR 中声明的符号名
68        let error =
69            unsafe { LLVMOrcLLJITLookup(self.reference, &mut address, symbol_name.as_ptr()) };
70        unsafe {
71            if !error.is_null() {
72                let message = LLVMGetErrorMessage(error);
73                let s = CString::from_raw(message);
74                return Err(s.to_str().unwrap().to_string());
75            }
76        }
77        Ok(address as u64)
78    }
79}