pipeline_script/llvm/
module.rs

1use crate::llvm::executor::JITExecutor;
2use crate::llvm::function::Function;
3use crate::llvm::global::Global;
4use crate::llvm::types::LLVMType;
5use crate::llvm::value::LLVMValue;
6use llvm_sys::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
7use llvm_sys::core::{
8    LLVMAddFunction, LLVMAddGlobal, LLVMDisposeMemoryBuffer, LLVMDisposeMessage, LLVMDumpModule,
9    LLVMGetBufferSize, LLVMGetBufferStart, LLVMGetNamedFunction, LLVMModuleCreateWithName,
10    LLVMSetInitializer, LLVMSetTarget, LLVMTypeOf,
11};
12use llvm_sys::execution_engine::{LLVMCreateJITCompilerForModule, LLVMExecutionEngineRef};
13use llvm_sys::orc2::{
14    LLVMOrcCreateNewThreadSafeContext, LLVMOrcCreateNewThreadSafeModule, LLVMOrcThreadSafeModuleRef,
15};
16use llvm_sys::prelude::{LLVMModuleRef, LLVMValueRef};
17use llvm_sys::target_machine::{
18    LLVMCodeGenFileType, LLVMCodeGenOptLevel, LLVMCodeModel, LLVMCreateTargetMachine,
19    LLVMDisposeTargetMachine, LLVMGetTargetFromTriple, LLVMRelocMode, LLVMTarget,
20    LLVMTargetMachineEmitToMemoryBuffer,
21};
22use llvm_sys::LLVMMemoryBuffer;
23use std::collections::HashMap;
24use std::ffi::CString;
25use std::ptr;
26use std::rc::Rc;
27
28#[derive(Debug)]
29pub struct LLVMModule {
30    module_ref: Rc<LLVMModuleRef>,
31    function_map: HashMap<String, Function>,
32    struct_map: HashMap<String, (HashMap<String, usize>, LLVMType)>,
33    global_map: HashMap<String, (LLVMType, LLVMValue)>,
34}
35
36impl LLVMModule {
37    pub(crate) fn new(name: impl AsRef<str>) -> Self {
38        let name = name.as_ref();
39        let name = CString::new(name).unwrap();
40        let module_ref = unsafe { LLVMModuleCreateWithName(name.as_ptr()) };
41        LLVMModule {
42            module_ref: Rc::new(module_ref),
43            function_map: HashMap::new(),
44            global_map: HashMap::new(),
45            struct_map: HashMap::new(),
46        }
47    }
48    pub fn create_thread_safe_module(&self) -> LLVMOrcThreadSafeModuleRef {
49        unsafe {
50            let ctx = LLVMOrcCreateNewThreadSafeContext();
51            LLVMOrcCreateNewThreadSafeModule(*self.module_ref, ctx)
52        }
53    }
54    pub fn from_raw(module_ref: LLVMModuleRef) -> Self {
55        LLVMModule {
56            module_ref: Rc::new(module_ref),
57            function_map: HashMap::new(),
58            global_map: HashMap::new(),
59            struct_map: HashMap::new(),
60        }
61    }
62    pub fn to_assembly(&self, target_triple: &str) -> Result<String, String> {
63        fn compile_to_assembly(
64            module: *mut llvm_sys::LLVMModule,
65            target_triple: &str,
66        ) -> Result<String, String> {
67            unsafe {
68                // 2. 创建目标机器
69                let mut target: *mut LLVMTarget = std::ptr::null_mut();
70                let mut error_msg = std::ptr::null_mut();
71                let target_triple_cstr = std::ffi::CString::new(target_triple).unwrap();
72
73                if LLVMGetTargetFromTriple(target_triple_cstr.as_ptr(), &mut target, &mut error_msg)
74                    != 0
75                {
76                    return Err(std::ffi::CStr::from_ptr(error_msg)
77                        .to_string_lossy()
78                        .into_owned());
79                }
80                if target.is_null() {
81                    return Err("Failed to get target".to_string());
82                }
83                let cpu = CString::new("generic").unwrap();
84                let features = CString::new("").unwrap();
85                let opt_level = LLVMCodeGenOptLevel::LLVMCodeGenLevelDefault;
86
87                let target_machine = LLVMCreateTargetMachine(
88                    target,
89                    target_triple_cstr.as_ptr(),
90                    cpu.as_ptr(),
91                    features.as_ptr(),
92                    opt_level,
93                    LLVMRelocMode::LLVMRelocDefault,
94                    LLVMCodeModel::LLVMCodeModelDefault,
95                );
96
97                if target_machine.is_null() {
98                    return Err("Failed to create target machine".to_string());
99                }
100                // 3. 配置模块目标
101                LLVMSetTarget(module, target_triple_cstr.as_ptr());
102                // 4. 创建内存缓冲输出
103                let mut output_buffer: *mut LLVMMemoryBuffer = std::ptr::null_mut();
104                let output_type = LLVMCodeGenFileType::LLVMAssemblyFile;
105
106                let result = LLVMTargetMachineEmitToMemoryBuffer(
107                    target_machine,
108                    module,
109                    output_type,
110                    &mut error_msg,
111                    &mut output_buffer,
112                );
113
114                if result != 0 {
115                    return Err(std::ffi::CStr::from_ptr(error_msg)
116                        .to_string_lossy()
117                        .into_owned());
118                }
119
120                // 5. 提取汇编代码
121                let data_ptr = LLVMGetBufferStart(output_buffer);
122                let data_size = LLVMGetBufferSize(output_buffer);
123                let assembly = std::slice::from_raw_parts(data_ptr as *const u8, data_size);
124                let assembly_str = std::str::from_utf8(assembly).unwrap().to_owned();
125
126                // 清理资源
127                LLVMDisposeMemoryBuffer(output_buffer);
128                LLVMDisposeTargetMachine(target_machine);
129
130                Ok(assembly_str)
131            }
132        }
133        compile_to_assembly(*self.module_ref, target_triple)
134    }
135    pub fn register_struct(
136        &mut self,
137        name: impl AsRef<str>,
138        field_map: HashMap<String, usize>,
139        t: LLVMType,
140    ) {
141        self.struct_map.insert(name.as_ref().into(), (field_map, t));
142    }
143    pub fn register_extern_function(
144        &mut self,
145        name: impl AsRef<str>,
146        function_type: LLVMType,
147    ) -> Function {
148        let module = self.module_ref.as_ref();
149        let name0 = CString::new(name.as_ref()).unwrap();
150        let f =
151            unsafe { LLVMAddFunction(*module, name0.as_ptr(), function_type.as_llvm_type_ref()) };
152        let f = Function::new(function_type, f, vec![]);
153        self.function_map.insert(name.as_ref().into(), f.clone());
154        f
155    }
156    pub fn register_function(
157        &mut self,
158        name: impl AsRef<str>,
159        function_type: LLVMType,
160        param_names: Vec<String>,
161    ) -> Function {
162        let module = self.module_ref.as_ref();
163        let name0 = CString::new(name.as_ref()).unwrap();
164        let f =
165            unsafe { LLVMAddFunction(*module, name0.as_ptr(), function_type.as_llvm_type_ref()) };
166        let f = Function::new(function_type.clone(), f, param_names.clone());
167        self.function_map.insert(name.as_ref().into(), f.clone());
168        f
169    }
170    pub fn verify_with_debug_info(&self) {
171        unsafe {
172            // 即时打印错误到控制台
173            LLVMVerifyModule(
174                *self.module_ref,
175                LLVMVerifierFailureAction::LLVMPrintMessageAction,
176                std::ptr::null_mut(),
177            );
178
179            // 同时获取错误信息副本
180            let mut error_copy = std::ptr::null_mut();
181            let status = LLVMVerifyModule(
182                *self.module_ref,
183                LLVMVerifierFailureAction::LLVMReturnStatusAction,
184                &mut error_copy,
185            );
186
187            if status != 0 {
188                // let error_str = CStr::from_ptr(error_copy).to_string_lossy();
189                // let debug_str =error_str.to_string();
190                // println!("Additional error details: {}", debug_str);
191                LLVMDisposeMessage(error_copy);
192            }
193        }
194    }
195    pub fn add_global(&mut self, name: impl AsRef<str>, v: LLVMValue) {
196        let module = self.module_ref.as_ref();
197        let name0 = name.as_ref();
198        let name1 = CString::new(name0).unwrap();
199        let v = v.as_llvm_value_ref();
200        let ty = unsafe { LLVMTypeOf(v) };
201        let global_val = unsafe {
202            let global_val = LLVMAddGlobal(*module, ty, name1.as_ptr());
203            LLVMSetInitializer(global_val, v);
204            global_val
205        };
206        self.global_map
207            .insert(name.as_ref().into(), (ty.into(), global_val.into()));
208    }
209
210    pub fn create_executor(&self) -> Result<JITExecutor, String> {
211        let mut engine: LLVMExecutionEngineRef = ptr::null_mut();
212        let mut error: *mut i8 = ptr::null_mut();
213        let module = self.module_ref.as_ref();
214        unsafe {
215            if LLVMCreateJITCompilerForModule(&mut engine, *module, 0, &mut error) != 0 {
216                let error_msg = CString::from_raw(error).into_string().unwrap();
217                let e = format!("Error creating JIT: {}", error_msg);
218                return Err(e);
219            }
220        }
221
222        Ok(JITExecutor::new(
223            self.module_ref.clone(),
224            engine,
225            self.function_map.clone(),
226        ))
227    }
228    pub fn get_function(&self, name: impl AsRef<str>) -> Option<&Function> {
229        self.function_map.get(name.as_ref())
230    }
231    pub fn get_function_ref(&self, name: impl AsRef<str>) -> LLVMValueRef {
232        let name = CString::new(name.as_ref()).unwrap();
233        let f = unsafe { LLVMGetNamedFunction(*self.module_ref.as_ref(), name.as_ptr()) };
234        f
235    }
236    pub fn get_struct(&self, name: impl AsRef<str>) -> Option<&(HashMap<String, usize>, LLVMType)> {
237        self.struct_map.get(name.as_ref())
238    }
239    pub fn get_struct_by_type(
240        &self,
241        t: &LLVMType,
242    ) -> Option<(&String, &HashMap<String, usize>, &LLVMType)> {
243        for (name, (map, ty)) in self.struct_map.iter() {
244            // if &Global::pointer_type(ty.clone()) == t{
245            //     return Some((name,map,ty));
246            // }
247            if ty == t {
248                return Some((name, map, ty));
249            }
250        }
251        None
252    }
253    pub fn get_struct_by_pointer_type(
254        &self,
255        t: &LLVMType,
256    ) -> Option<(&String, &HashMap<String, usize>, &LLVMType)> {
257        dbg!(&self.struct_map);
258        for (name, (map, ty)) in self.struct_map.iter() {
259            dbg!(name, map, ty, Global::pointer_type(ty.clone()));
260            if &Global::pointer_type(ty.clone()) == t {
261                return Some((name, map, ty));
262            }
263        }
264        None
265    }
266    pub fn dump(&self) {
267        let module = self.module_ref.as_ref();
268        if module.is_null() {
269            return;
270        }
271        unsafe { LLVMDumpModule(*module) }
272    }
273}
274
275impl Drop for LLVMModule {
276    fn drop(&mut self) {
277        if Rc::strong_count(&self.module_ref) == 1 {
278            // 只有一个引用,安全释放 LLVMModuleRef
279            // unsafe {
280            // 这里需要调用相应的 LLVM API 函数释放 LLVMModuleRef
281            // 例如:LLVMDisposeModule
282            // LLVMDisposeModule(self.module);
283            // 释放 Arc 的引用
284            // 注意:Arc 会自动处理引用计数和内存释放
285            let module_ptr = Rc::get_mut(&mut self.module_ref).unwrap();
286            if !module_ptr.is_null() {
287                // 释放模块
288                // LLVMDisposeModule(*module_ptr);
289            }
290            // }
291        }
292    }
293}