// pipeline_script/compiler/mod.rs

1mod expr;
2mod stmt;
3mod r#type;
4
5use crate::context::Context;
6use crate::llvm::context::LLVMContext;
7use crate::llvm::global::Global;
8use crate::llvm::module::LLVMModule;
9use std::collections::HashMap;
10use std::rc::Rc;
11use std::sync::RwLock;
12
13use crate::ast::module::Module;
14use crate::context::key::ContextKey;
15use crate::context::value::ContextValue;
16use crate::llvm::types::LLVMType;
17use crate::llvm::value::fucntion::FunctionValue;
18use crate::llvm::value::LLVMValue;
19
20pub struct Compiler {
21    module: Module,
22    ctx: LLVMContext,
23    builtin_symbol: HashMap<String, LLVMValue>,
24    llvm_module: Rc<RwLock<LLVMModule>>,
25}
26
27impl Compiler {
28    pub fn new(module: Module) -> Self {
29        let llvm_ctx = LLVMContext::new();
30        let llvm_module = llvm_ctx.create_module(module.get_name());
31        Self {
32            ctx: llvm_ctx,
33            module,
34            llvm_module: Rc::new(RwLock::new(llvm_module)),
35            builtin_symbol: HashMap::new(),
36        }
37    }
38    pub fn register_builtin_symbol(&mut self, name: &str, ty: LLVMValue) {
39        self.builtin_symbol.insert(name.to_string(), ty);
40    }
41    pub fn register_llvm_function(&mut self, name: &str, ty: LLVMType, arg_names: Vec<String>) {
42        self.llvm_module
43            .write()
44            .unwrap()
45            .register_function(name, ty, arg_names);
46    }
47    pub fn compile(&mut self, ctx: &Context) -> Rc<RwLock<LLVMModule>> {
48        let ctx = Context::with_type_table(ctx, HashMap::new());
49        // 顶层作用域
50        let ctx = Context::with_scope(&ctx);
51        let ctx = Context::with_value(
52            &ctx,
53            ContextKey::LLVMModule,
54            ContextValue::LLVMModule(self.llvm_module.clone()),
55        );
56
57        // 编译内建符号进入scope
58        for (name, ty) in self.builtin_symbol.iter() {
59            ctx.set_symbol(name.clone(), ty.clone());
60        }
61
62        // 编译结构体
63        for (name, item) in self.module.get_structs() {
64            if !item.generics.is_empty() {
65                continue;
66            }
67            let fields = item.get_fields();
68            let mut field_index: HashMap<String, usize> = HashMap::new();
69            for (i, field) in fields.iter().enumerate() {
70                field_index.insert(field.name.clone(), i);
71            }
72            let t = item.get_type();
73            let mut t = self.get_type(&ctx, &t);
74            if t.is_function() {
75                t = Global::pointer_type(t);
76            }
77            self.llvm_module
78                .write()
79                .unwrap()
80                .register_struct(name, field_index, t);
81        }
82        // 编译枚举
83        for (name, item) in self.module.get_type_aliases().iter() {
84            let t = self.compile_type(item.get_type());
85            self.llvm_module
86                .write()
87                .unwrap()
88                .register_struct(name, HashMap::new(), t);
89        }
90        let builder = Global::create_builder();
91        let ctx = Context::with_builder(&ctx, builder);
92        let ctx = Context::with_default_expr(&ctx);
93        // 编译函数声明
94        for (name, item) in self.module.get_functions().iter() {
95            if item.is_template {
96                continue;
97            }
98            let args = item.args();
99            let mut arg_types = vec![];
100            let mut is_var_arg = false;
101            let mut is_array_vararg = false;
102            for arg in args {
103                if arg.is_var_arg() {
104                    is_var_arg = true;
105                    continue;
106                }
107                if arg.is_array_vararg() {
108                    is_array_vararg = true;
109                }
110                let ty = arg.r#type().unwrap();
111                let t = self.get_type(&ctx, &ty);
112                // 函数值作为参数时,统一转换成闭包类型
113                if t.is_function() {
114                    arg_types.push((
115                        arg.name().to_string(),
116                        Global::struct_type(
117                            "Closure".into(),
118                            vec![
119                                ("ptr".to_string(), Global::pointer_type(t)),
120                                ("env".to_string(), Global::pointer_type(Global::unit_type())),
121                            ],
122                        ),
123                    ));
124                    continue;
125                }
126                arg_types.push((arg.name().to_string(), t));
127            }
128            let return_type0 = item.return_type();
129            let return_type = self.get_type(&ctx, return_type0);
130            let t = if is_var_arg || item.is_vararg() {
131                Global::function_type_with_var_arg(return_type.clone(), arg_types)
132            } else {
133                Global::function_type(return_type.clone(), arg_types)
134            };
135            let f = if item.is_extern {
136                self.llvm_module
137                    .write()
138                    .unwrap()
139                    .register_extern_function(name, t)
140            } else {
141                let args = item.args();
142                let param_names: Vec<String> =
143                    args.iter().map(|arg| arg.name()).collect::<Vec<_>>();
144                self.llvm_module
145                    .write()
146                    .unwrap()
147                    .register_function(name, t, param_names)
148            };
149            let mut function_value = FunctionValue::new(
150                f.get_function_ref(),
151                name.clone(),
152                Box::new(return_type.get_undef()),
153                args.iter()
154                    .map(|arg| {
155                        let name = arg.name().clone();
156                        let undef = self.compile_type(&arg.r#type().unwrap()).get_undef();
157
158                        if arg.is_env() {
159                            if let Some(default_expr) = arg.get_default() {
160                                ctx.set_default_expr(name.clone(), Box::new(default_expr.clone()));
161                            }
162                            (name, undef)
163                        } else if let Some(default_expr) = arg.get_default() {
164                            let r = self.compile_expr(default_expr, &ctx);
165                            (
166                                name,
167                                if undef.is_inject() {
168                                    LLVMValue::Inject(Box::new(r))
169                                } else {
170                                    r
171                                },
172                            )
173                        } else {
174                            (name, undef)
175                        }
176                    })
177                    .collect(),
178            );
179            if item.is_vararg() || is_var_arg {
180                function_value.set_vararg();
181            }
182            if is_array_vararg {
183                function_value.set_array_vararg();
184            }
185            ctx.set_symbol(name.clone(), function_value.into());
186        }
187        // 先编译全局块中的static声明
188        let block = self.module.get_global_block();
189        for stmt in block.iter() {
190            if let crate::ast::stmt::Stmt::StaticDecl(_) = stmt.get_stmt() {
191                self.compile_stmt(stmt, &ctx); // 使用顶层作用域
192            }
193        }
194        // 编译函数实现
195        for (_, item) in self.module.get_functions().iter() {
196            if item.is_extern || item.is_template {
197                continue;
198            }
199            let ctx = self.prepare_function(&ctx, item);
200            for stmt in item.body() {
201                self.compile_stmt(stmt, &ctx);
202            }
203            let flag = ctx.get_flag("return").unwrap();
204            if !flag {
205                let builder = ctx.get_builder();
206                builder.build_return_void();
207            }
208        }
209
210        // 编译主函数
211        let main = self.llvm_module.write().unwrap().register_function(
212            "$Module.main",
213            Global::function_type(Global::unit_type(), vec![]),
214            vec![],
215        );
216        let block = self.module.get_global_block();
217        let entry = main.append_basic_block("entry");
218        let builder = ctx.get_builder();
219        builder.position_at_end(entry);
220        let function_value = FunctionValue::new(
221            main.get_function_ref(),
222            "$Module.main".into(),
223            Box::new(Global::unit_type().get_undef()),
224            vec![],
225        );
226        let ctx = Context::with_function(&ctx, function_value);
227        let ctx = Context::with_scope(&ctx);
228        let ctx = Context::with_flag(&ctx, "return", false);
229
230        // 然后编译其他语句
231        for stmt in block.iter() {
232            if !matches!(stmt.get_stmt(), crate::ast::stmt::Stmt::StaticDecl(_)) {
233                self.compile_stmt(stmt, &ctx);
234            }
235        }
236
237        let flag = ctx.get_flag("return").unwrap();
238        if !flag {
239            builder.build_return_void();
240        }
241        self.llvm_module.clone()
242    }
243    pub fn prepare_function(
244        &self,
245        ctx: &Context,
246        function: &crate::ast::function::Function,
247    ) -> Context {
248        let function_value = ctx
249            .get_symbol(function.name())
250            .unwrap()
251            .as_function()
252            .unwrap();
253        let entry = function_value.append_basic_block("entry");
254        let builder = ctx.get_builder();
255        builder.position_at_end(entry);
256        let ctx = Context::with_function(ctx, function_value.clone());
257        let ctx = Context::with_type(&ctx, "current_function".into(), function.get_type());
258        let ctx = Context::with_scope(&ctx);
259        let ctx = Context::with_flag(&ctx, "return", false);
260        // 注册形参进入作用域
261        for arg in function.args() {
262            let arg_name = arg.name();
263            let mut arg_value = function_value.get_param(arg_name.clone()).unwrap();
264            if let LLVMValue::Function(f) = &mut arg_value {
265                f.set_closure()
266            }
267            ctx.set_symbol(arg_name, arg_value);
268        }
269        ctx
270    }
271}