pipeline_script/llvm/orc/
lljit.rs

use llvm_sys::error::{LLVMDisposeErrorMessage, LLVMErrorRef, LLVMGetErrorMessage};
use llvm_sys::orc2::lljit::{
    LLVMOrcLLJITAddLLVMIRModule, LLVMOrcLLJITGetMainJITDylib, LLVMOrcLLJITLookup,
    LLVMOrcLLJITMangleAndIntern, LLVMOrcLLJITRef,
};
use llvm_sys::orc2::{
    LLVMJITEvaluatedSymbol, LLVMJITSymbolFlags, LLVMOrcAbsoluteSymbols, LLVMOrcCSymbolMapPair,
    LLVMOrcCSymbolMapPairs, LLVMOrcDisposeMaterializationUnit, LLVMOrcExecutorAddress,
    LLVMOrcJITDylibDefine, LLVMOrcJITDylibRef, LLVMOrcSymbolStringPoolEntryRef,
    LLVMOrcThreadSafeModuleRef,
};
use std::ffi::{CStr, CString};
12
13pub struct LLJIT {
14    reference: LLVMOrcLLJITRef,
15    main_dylib: LLVMOrcJITDylibRef,
16}
17
18impl LLJIT {
19    pub fn new(refer: LLVMOrcLLJITRef) -> Self {
20        unsafe {
21            Self {
22                reference: refer,
23                main_dylib: LLVMOrcLLJITGetMainJITDylib(refer),
24            }
25        }
26    }
27    #[allow(unused)]
28    pub fn get_main_jitdylib(&self) -> LLVMOrcJITDylibRef {
29        self.main_dylib
30    }
31    pub fn add_main_dylib_llvm_ir_module(
32        &self,
33        safe_module: LLVMOrcThreadSafeModuleRef,
34    ) -> Result<(), String> {
35        let error =
36            unsafe { LLVMOrcLLJITAddLLVMIRModule(self.reference, self.main_dylib, safe_module) };
37        unsafe {
38            if !error.is_null() {
39                let message = LLVMGetErrorMessage(error);
40                let s = CString::from_raw(message);
41                return Err(s.to_str().unwrap().to_string());
42            }
43        }
44        Ok(())
45    }
46    pub fn define_main_dylib_symbol(&self, symbol_name: &str, sym_addr: u64) {
47        let symbol_name = CString::new(symbol_name).unwrap(); // IR 中声明的符号名
48        let sym_addr = sym_addr as u64;
49        let pair = LLVMOrcCSymbolMapPair {
50            Name: self.mangle_and_intern(symbol_name.to_str().unwrap()),
51            Sym: LLVMJITEvaluatedSymbol {
52                Address: sym_addr,
53                Flags: LLVMJITSymbolFlags {
54                    GenericFlags: 0,
55                    TargetFlags: 0,
56                },
57            },
58        };
59        let f = unsafe { LLVMOrcAbsoluteSymbols([pair].as_mut_ptr() as LLVMOrcCSymbolMapPairs, 1) };
60        unsafe { LLVMOrcJITDylibDefine(self.main_dylib, f) };
61    }
62    pub fn mangle_and_intern(&self, symbol_name: &str) -> LLVMOrcSymbolStringPoolEntryRef {
63        let symbol_name = CString::new(symbol_name).unwrap(); // IR 中声明的符号名
64        unsafe { LLVMOrcLLJITMangleAndIntern(self.reference, symbol_name.as_ptr()) }
65    }
66    pub fn lookup(&self, symbol_name: &str) -> Result<u64, String> {
67        let mut address: LLVMOrcExecutorAddress = 0;
68        let symbol_name = CString::new(symbol_name).unwrap(); // IR 中声明的符号名
69        let error =
70            unsafe { LLVMOrcLLJITLookup(self.reference, &mut address, symbol_name.as_ptr()) };
71        unsafe {
72            if !error.is_null() {
73                let message = LLVMGetErrorMessage(error);
74                let s = CString::from_raw(message);
75                return Err(s.to_str().unwrap().to_string());
76            }
77        }
78        Ok(address as u64)
79    }
80}