pipeline_script/llvm/
module.rs

1use crate::llvm::executor::JITExecutor;
2use crate::llvm::function::Function;
3use crate::llvm::global::Global;
4use crate::llvm::types::LLVMType;
5use crate::llvm::value::reference::ReferenceValue;
6use crate::llvm::value::LLVMValue;
7use llvm_sys::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
8use llvm_sys::core::{
9    LLVMAddFunction, LLVMAddGlobal, LLVMDisposeMemoryBuffer, LLVMDisposeMessage, LLVMDumpModule,
10    LLVMGetBufferSize, LLVMGetBufferStart, LLVMGetNamedFunction, LLVMModuleCreateWithName,
11    LLVMSetInitializer, LLVMSetTarget,
12};
13use llvm_sys::execution_engine::{LLVMCreateJITCompilerForModule, LLVMExecutionEngineRef};
14use llvm_sys::orc2::{
15    LLVMOrcCreateNewThreadSafeContext, LLVMOrcCreateNewThreadSafeModule, LLVMOrcThreadSafeModuleRef,
16};
17use llvm_sys::prelude::{LLVMModuleRef, LLVMValueRef};
18use llvm_sys::target_machine::{
19    LLVMCodeGenFileType, LLVMCodeGenOptLevel, LLVMCodeModel, LLVMCreateTargetMachine,
20    LLVMDisposeTargetMachine, LLVMGetTargetFromTriple, LLVMRelocMode, LLVMTarget,
21    LLVMTargetMachineEmitToMemoryBuffer,
22};
23use llvm_sys::LLVMMemoryBuffer;
24use std::collections::HashMap;
25use std::ffi::CString;
26use std::ptr;
27use std::rc::Rc;
28
/// Wrapper around a raw LLVM module, plus name-keyed registries of the
/// functions, struct types and globals that were added through this wrapper.
#[derive(Debug)]
pub struct LLVMModule {
    /// Raw module handle. It is reference-counted so that the `JITExecutor`
    /// built by `create_executor` can hold a clone of the same handle.
    module_ref: Rc<LLVMModuleRef>,
    /// Functions added via `register_function` / `register_extern_function`, keyed by name.
    function_map: HashMap<String, Function>,
    /// Struct name -> (field-name lookup table, struct type), filled by `register_struct`.
    /// NOTE(review): the `usize` is presumably the field's index within the struct — confirm.
    struct_map: HashMap<String, (HashMap<String, usize>, LLVMType)>,
    /// Globals added via `add_global`: name -> (declared type, global value).
    global_map: HashMap<String, (LLVMType, LLVMValue)>,
}
36
37impl LLVMModule {
38    pub(crate) fn new(name: impl AsRef<str>) -> Self {
39        let name = name.as_ref();
40        let name = CString::new(name).unwrap();
41        let module_ref = unsafe { LLVMModuleCreateWithName(name.as_ptr()) };
42        LLVMModule {
43            module_ref: Rc::new(module_ref),
44            function_map: HashMap::new(),
45            global_map: HashMap::new(),
46            struct_map: HashMap::new(),
47        }
48    }
49    pub fn create_thread_safe_module(&self) -> LLVMOrcThreadSafeModuleRef {
50        unsafe {
51            let ctx = LLVMOrcCreateNewThreadSafeContext();
52            LLVMOrcCreateNewThreadSafeModule(*self.module_ref, ctx)
53        }
54    }
55    pub fn from_raw(module_ref: LLVMModuleRef) -> Self {
56        LLVMModule {
57            module_ref: Rc::new(module_ref),
58            function_map: HashMap::new(),
59            global_map: HashMap::new(),
60            struct_map: HashMap::new(),
61        }
62    }
63    pub fn to_assembly(&self, target_triple: &str) -> Result<String, String> {
64        fn compile_to_assembly(
65            module: *mut llvm_sys::LLVMModule,
66            target_triple: &str,
67        ) -> Result<String, String> {
68            unsafe {
69                // 2. 创建目标机器
70                let mut target: *mut LLVMTarget = std::ptr::null_mut();
71                let mut error_msg = std::ptr::null_mut();
72                let target_triple_cstr = std::ffi::CString::new(target_triple).unwrap();
73
74                if LLVMGetTargetFromTriple(target_triple_cstr.as_ptr(), &mut target, &mut error_msg)
75                    != 0
76                {
77                    return Err(std::ffi::CStr::from_ptr(error_msg)
78                        .to_string_lossy()
79                        .into_owned());
80                }
81                if target.is_null() {
82                    return Err("Failed to get target".to_string());
83                }
84                let cpu = CString::new("generic").unwrap();
85                let features = CString::new("").unwrap();
86                let opt_level = LLVMCodeGenOptLevel::LLVMCodeGenLevelDefault;
87
88                let target_machine = LLVMCreateTargetMachine(
89                    target,
90                    target_triple_cstr.as_ptr(),
91                    cpu.as_ptr(),
92                    features.as_ptr(),
93                    opt_level,
94                    LLVMRelocMode::LLVMRelocDefault,
95                    LLVMCodeModel::LLVMCodeModelDefault,
96                );
97
98                if target_machine.is_null() {
99                    return Err("Failed to create target machine".to_string());
100                }
101                // 3. 配置模块目标
102                LLVMSetTarget(module, target_triple_cstr.as_ptr());
103                // 4. 创建内存缓冲输出
104                let mut output_buffer: *mut LLVMMemoryBuffer = std::ptr::null_mut();
105                let output_type = LLVMCodeGenFileType::LLVMAssemblyFile;
106
107                let result = LLVMTargetMachineEmitToMemoryBuffer(
108                    target_machine,
109                    module,
110                    output_type,
111                    &mut error_msg,
112                    &mut output_buffer,
113                );
114
115                if result != 0 {
116                    return Err(std::ffi::CStr::from_ptr(error_msg)
117                        .to_string_lossy()
118                        .into_owned());
119                }
120
121                // 5. 提取汇编代码
122                let data_ptr = LLVMGetBufferStart(output_buffer);
123                let data_size = LLVMGetBufferSize(output_buffer);
124                let assembly = std::slice::from_raw_parts(data_ptr as *const u8, data_size);
125                let assembly_str = std::str::from_utf8(assembly).unwrap().to_owned();
126
127                // 清理资源
128                LLVMDisposeMemoryBuffer(output_buffer);
129                LLVMDisposeTargetMachine(target_machine);
130
131                Ok(assembly_str)
132            }
133        }
134        compile_to_assembly(*self.module_ref, target_triple)
135    }
136    pub fn register_struct(
137        &mut self,
138        name: impl AsRef<str>,
139        field_map: HashMap<String, usize>,
140        t: LLVMType,
141    ) {
142        self.struct_map.insert(name.as_ref().into(), (field_map, t));
143    }
144    pub fn register_extern_function(
145        &mut self,
146        name: impl AsRef<str>,
147        function_type: LLVMType,
148    ) -> Function {
149        let module = self.module_ref.as_ref();
150        let name0 = CString::new(name.as_ref()).unwrap();
151        let f =
152            unsafe { LLVMAddFunction(*module, name0.as_ptr(), function_type.as_llvm_type_ref()) };
153        let f = Function::new(function_type, f, vec![]);
154        self.function_map.insert(name.as_ref().into(), f.clone());
155        f
156    }
157    pub fn register_function(
158        &mut self,
159        name: impl AsRef<str>,
160        function_type: LLVMType,
161        param_names: Vec<String>,
162    ) -> Function {
163        let module = self.module_ref.as_ref();
164        let name0 = CString::new(name.as_ref()).unwrap();
165        let f =
166            unsafe { LLVMAddFunction(*module, name0.as_ptr(), function_type.as_llvm_type_ref()) };
167        let f = Function::new(function_type.clone(), f, param_names.clone());
168        self.function_map.insert(name.as_ref().into(), f.clone());
169        f
170    }
171    pub fn verify_with_debug_info(&self) {
172        unsafe {
173            // 即时打印错误到控制台
174            LLVMVerifyModule(
175                *self.module_ref,
176                LLVMVerifierFailureAction::LLVMPrintMessageAction,
177                std::ptr::null_mut(),
178            );
179
180            // 同时获取错误信息副本
181            let mut error_copy = std::ptr::null_mut();
182            let status = LLVMVerifyModule(
183                *self.module_ref,
184                LLVMVerifierFailureAction::LLVMReturnStatusAction,
185                &mut error_copy,
186            );
187
188            if status != 0 {
189                // let error_str = CStr::from_ptr(error_copy).to_string_lossy();
190                // let debug_str =error_str.to_string();
191                // println!("Additional error details: {}", debug_str);
192                LLVMDisposeMessage(error_copy);
193            }
194        }
195    }
196    pub fn add_global(
197        &mut self,
198        name: impl AsRef<str>,
199        v0: &LLVMValue,
200        ty: &LLVMType,
201    ) -> LLVMValue {
202        let module = self.module_ref.as_ref();
203        let name0 = name.as_ref();
204        let name1 = CString::new(name0).unwrap();
205        let v = v0.as_llvm_value_ref();
206        let global_val = unsafe {
207            let global_val = LLVMAddGlobal(*module, ty.as_llvm_type_ref(), name1.as_ptr());
208            LLVMSetInitializer(global_val, v);
209            global_val
210        };
211        self.global_map
212            .insert(name.as_ref().into(), (ty.clone(), global_val.into()));
213
214        LLVMValue::Reference(ReferenceValue::new(global_val, v0.clone()))
215    }
216
217    pub fn create_executor(&self) -> Result<JITExecutor, String> {
218        let mut engine: LLVMExecutionEngineRef = ptr::null_mut();
219        let mut error: *mut i8 = ptr::null_mut();
220        let module = self.module_ref.as_ref();
221        unsafe {
222            if LLVMCreateJITCompilerForModule(&mut engine, *module, 0, &mut error) != 0 {
223                let error_msg = CString::from_raw(error).into_string().unwrap();
224                let e = format!("Error creating JIT: {}", error_msg);
225                return Err(e);
226            }
227        }
228
229        Ok(JITExecutor::new(
230            self.module_ref.clone(),
231            engine,
232            self.function_map.clone(),
233        ))
234    }
235    pub fn get_function(&self, name: impl AsRef<str>) -> Option<&Function> {
236        self.function_map.get(name.as_ref())
237    }
238    pub fn get_function_ref(&self, name: impl AsRef<str>) -> LLVMValueRef {
239        let name = CString::new(name.as_ref()).unwrap();
240        let f = unsafe { LLVMGetNamedFunction(*self.module_ref.as_ref(), name.as_ptr()) };
241        f
242    }
243    pub fn get_struct(&self, name: impl AsRef<str>) -> Option<&(HashMap<String, usize>, LLVMType)> {
244        self.struct_map.get(name.as_ref())
245    }
246    pub fn get_struct_by_type(
247        &self,
248        t: &LLVMType,
249    ) -> Option<(&String, &HashMap<String, usize>, &LLVMType)> {
250        for (name, (map, ty)) in self.struct_map.iter() {
251            // if &Global::pointer_type(ty.clone()) == t{
252            //     return Some((name,map,ty));
253            // }
254            if ty == t {
255                return Some((name, map, ty));
256            }
257        }
258        None
259    }
260    pub fn get_struct_by_pointer_type(
261        &self,
262        t: &LLVMType,
263    ) -> Option<(&String, &HashMap<String, usize>, &LLVMType)> {
264        for (name, (map, ty)) in self.struct_map.iter() {
265            dbg!(name, map, ty, Global::pointer_type(ty.clone()));
266            if &Global::pointer_type(ty.clone()) == t {
267                return Some((name, map, ty));
268            }
269        }
270        None
271    }
272    pub fn get_module_ref(&self) -> LLVMModuleRef {
273        *self.module_ref
274    }
275    pub fn dump(&self) {
276        let module = self.module_ref.as_ref();
277        if module.is_null() {
278            return;
279        }
280        unsafe { LLVMDumpModule(*module) }
281    }
282    pub fn has_global(&self, name: impl AsRef<str>) -> bool {
283        self.global_map.contains_key(name.as_ref())
284    }
285}
286
287impl Drop for LLVMModule {
288    fn drop(&mut self) {
289        if Rc::strong_count(&self.module_ref) == 1 {
290            // 只有一个引用,安全释放 LLVMModuleRef
291            // unsafe {
292            // 这里需要调用相应的 LLVM API 函数释放 LLVMModuleRef
293            // 例如:LLVMDisposeModule
294            // LLVMDisposeModule(self.module);
295            // 释放 Arc 的引用
296            // 注意:Arc 会自动处理引用计数和内存释放
297            let module_ptr = Rc::get_mut(&mut self.module_ref).unwrap();
298            if !module_ptr.is_null() {
299                // 释放模块
300                // LLVMDisposeModule(*module_ptr);
301            }
302            // }
303        }
304    }
305}