// pipeline_script/llvm/orc/lljit.rs

use llvm_sys::error::{LLVMDisposeErrorMessage, LLVMErrorRef, LLVMGetErrorMessage};
use llvm_sys::orc2::lljit::{
    LLVMOrcLLJITAddLLVMIRModule, LLVMOrcLLJITGetMainJITDylib, LLVMOrcLLJITLookup,
    LLVMOrcLLJITMangleAndIntern, LLVMOrcLLJITRef,
};
use llvm_sys::orc2::{
    LLVMJITEvaluatedSymbol, LLVMJITSymbolFlags, LLVMOrcAbsoluteSymbols, LLVMOrcCSymbolMapPair,
    LLVMOrcCSymbolMapPairs, LLVMOrcDisposeMaterializationUnit, LLVMOrcExecutorAddress,
    LLVMOrcJITDylibDefine, LLVMOrcJITDylibRef, LLVMOrcSymbolStringPoolEntryRef,
    LLVMOrcThreadSafeModuleRef,
};
use std::ffi::{CStr, CString};
8
9pub struct LLJIT {
10    reference: LLVMOrcLLJITRef,
11    main_dylib: LLVMOrcJITDylibRef,
12}
13
14impl LLJIT {
15    pub fn new(refer: LLVMOrcLLJITRef) -> Self {
16        unsafe {
17            Self {
18                reference: refer,
19                main_dylib: LLVMOrcLLJITGetMainJITDylib(refer),
20            }
21        }
22    }
23    #[allow(unused)]
24    pub fn get_main_jitdylib(&self) -> LLVMOrcJITDylibRef {
25        self.main_dylib
26    }
27    pub fn add_main_dylib_llvm_ir_module(
28        &self,
29        safe_module: LLVMOrcThreadSafeModuleRef,
30    ) -> Result<(), String> {
31        let error =
32            unsafe { LLVMOrcLLJITAddLLVMIRModule(self.reference, self.main_dylib, safe_module) };
33        unsafe {
34            if !error.is_null() {
35                let message = LLVMGetErrorMessage(error);
36                let s = CString::from_raw(message);
37                return Err(s.to_str().unwrap().to_string());
38            }
39        }
40        Ok(())
41    }
42    pub fn define_main_dylib_symbol(&self, symbol_name: &str, sym_addr: u64) {
43        let symbol_name = CString::new(symbol_name).unwrap(); // IR 中声明的符号名
44        let sym_addr = sym_addr as u64;
45        let pair = LLVMOrcCSymbolMapPair {
46            Name: self.mangle_and_intern(symbol_name.to_str().unwrap()),
47            Sym: LLVMJITEvaluatedSymbol {
48                Address: sym_addr,
49                Flags: LLVMJITSymbolFlags {
50                    GenericFlags: 0,
51                    TargetFlags: 0,
52                },
53            },
54        };
55        let f = unsafe { LLVMOrcAbsoluteSymbols([pair].as_mut_ptr() as LLVMOrcCSymbolMapPairs, 1) };
56        unsafe { LLVMOrcJITDylibDefine(self.main_dylib, f) };
57    }
58    pub fn mangle_and_intern(&self, symbol_name: &str) -> LLVMOrcSymbolStringPoolEntryRef {
59        let symbol_name = CString::new(symbol_name).unwrap(); // IR 中声明的符号名
60        unsafe { LLVMOrcLLJITMangleAndIntern(self.reference, symbol_name.as_ptr()) }
61    }
62    pub fn lookup(&self, symbol_name: &str) -> Result<u64, String> {
63        let mut address: LLVMOrcExecutorAddress = 0;
64        let symbol_name = CString::new(symbol_name).unwrap(); // IR 中声明的符号名
65        let error =
66            unsafe { LLVMOrcLLJITLookup(self.reference, &mut address, symbol_name.as_ptr()) };
67        unsafe {
68            if !error.is_null() {
69                let message = LLVMGetErrorMessage(error);
70                let s = CString::from_raw(message);
71                return Err(s.to_str().unwrap().to_string());
72            }
73        }
74        Ok(address as u64)
75    }
76}