Skip to main content

vm/
fns.rs

1use crate::{JITRunTime, context::BuildContext, rt::PendingFn};
2use anyhow::{Context, Result, anyhow};
3use compiler::{Symbol, infer_generic_args_from_types, substitute_stmt, substitute_type};
4use cranelift::{codegen::ir::FuncRef, prelude::*};
5use cranelift_module::{FuncId, Module};
6use dynamic::Type;
7
8#[derive(Debug)]
9pub enum FnVariant {
10    Native { ty: Type, fn_id: FuncId },                                                                                //没有变体 直接调用的原生函数
11    Inline { fn_ptr: fn(Option<&mut BuildContext>, Vec<Value>) -> Result<(Option<Value>, Type)>, arg_tys: Vec<Type> }, //inline 函数 直接生成代码
12    Compiled(Vec<(Type, FuncId)>),
13}
14
15impl FnVariant {
16    pub fn is_compiled(&self) -> bool {
17        if let Self::Compiled(_) = self { true } else { false }
18    }
19}
20
21use crate::get_type;
22use cranelift_module::Linkage;
23use parser::{Expr, ExprKind, Span, Stmt, StmtKind};
24use smol_str::SmolStr;
25
26#[derive(Debug)]
27pub enum FnInfo {
28    //用来调用的函数信息
29    Call { fn_id: FuncId, arg_tys: Vec<Type>, caps: Vec<usize>, ret: Type },
30    Inline { fn_ptr: fn(Option<&mut BuildContext>, Vec<Value>) -> Result<(Option<Value>, Type)>, arg_tys: Vec<Type> },
31}
32
33impl FnInfo {
34    pub fn get_id(&self) -> Result<FuncId> {
35        if let Self::Call { fn_id, arg_tys: _, caps: _, ret: _ } = self { Ok(*fn_id) } else { Err(anyhow!("Inline 函数没有 FuncId")) }
36    }
37
38    pub fn arg_tys(&self) -> Result<&[Type]> {
39        match self {
40            Self::Call { fn_id: _, arg_tys, caps: _, ret: _ } => Ok(arg_tys),
41            Self::Inline { fn_ptr: _, arg_tys } => Ok(arg_tys),
42        }
43    }
44
45    pub fn get_type(&self) -> Result<Type> {
46        match self {
47            Self::Call { fn_id: _, arg_tys: _, caps: _, ret } => Ok(ret.clone()),
48            Self::Inline { fn_ptr, arg_tys: _ } => fn_ptr(None, vec![]).map(|(_, t)| t),
49        }
50    }
51}
52
53impl JITRunTime {
54    fn coerce_returns(stmt: &Stmt, ret_ty: &Type) -> Stmt {
55        let kind = match &stmt.kind {
56            StmtKind::Return(Some(expr)) if ret_ty.is_void() => StmtKind::Return(None),
57            StmtKind::Return(Some(expr)) => StmtKind::Return(Some(Expr::new(ExprKind::Typed { value: Box::new(expr.clone()), ty: ret_ty.clone() }, expr.span))),
58            StmtKind::Block(stmts) => StmtKind::Block(stmts.iter().map(|stmt| Self::coerce_returns(stmt, ret_ty)).collect()),
59            StmtKind::If { cond, then_body, else_body } => {
60                StmtKind::If { cond: cond.clone(), then_body: Box::new(Self::coerce_returns(then_body, ret_ty)), else_body: else_body.as_ref().map(|body| Box::new(Self::coerce_returns(body, ret_ty))) }
61            }
62            StmtKind::While { cond, body } => StmtKind::While { cond: cond.clone(), body: Box::new(Self::coerce_returns(body, ret_ty)) },
63            StmtKind::Loop(body) => StmtKind::Loop(Box::new(Self::coerce_returns(body, ret_ty))),
64            StmtKind::For { pat, range, body } => StmtKind::For { pat: pat.clone(), range: range.clone(), body: Box::new(Self::coerce_returns(body, ret_ty)) },
65            _ => stmt.kind.clone(),
66        };
67        Stmt::new(kind, stmt.span)
68    }
69
70    pub fn add_inline(&mut self, name: &str, args: Vec<Type>, ret: Type, f: fn(Option<&mut BuildContext>, Vec<Value>) -> Result<(Option<Value>, Type)>) -> Result<u32> {
71        let id = self.compiler.add_symbol(name, Symbol::native(args.clone(), ret));
72        self.fns.insert(id, FnVariant::Inline { fn_ptr: f.into(), arg_tys: args });
73        if let Some((def, method)) = name.split_once("::") {
74            let def_id = self.get_id(def)?;
75            if let Some((_, define)) = self.compiler.symbols.get_symbol_mut(def_id) {
76                if let Symbol::Struct(Type::Struct { params, fields }, _) = define {
77                    fields.push((method.into(), Type::Symbol { id, params: params.clone() }));
78                }
79            }
80        }
81        Ok(id)
82    }
83
84    pub fn get_fn_ref(&mut self, ctx: &mut BuildContext, fn_id: FuncId) -> FuncRef {
85        ctx.get_fn_ref(fn_id).unwrap_or_else(|| {
86            let fn_ref = self.module.declare_func_in_func(fn_id, &mut ctx.builder.func);
87            ctx.fn_refs.push((fn_id, fn_ref));
88            fn_ref
89        })
90    }
91
92    pub fn adjust_args(&mut self, ctx: &mut BuildContext, args: Vec<(Value, Type)>, arg_tys: &[Type]) -> Result<Vec<Value>> {
93        let mut results = Vec::<Value>::new();
94        for ((arg, ty), arg_ty) in args.into_iter().zip(arg_tys.iter()) {
95            if ty != *arg_ty {
96                results.push(self.convert(ctx, (arg, ty), arg_ty.clone())?);
97            } else {
98                results.push(arg);
99            }
100        }
101        Ok(results)
102    }
103
104    pub fn get_fn(&self, id: u32, want_tys: &[Type]) -> Result<FnInfo> {
105        if let Some(fn_info) = self.fns.get(&id) {
106            match fn_info {
107                FnVariant::Compiled(fns) => {
108                    for (ty, fn_id) in fns.iter() {
109                        if let Type::Fn { tys, ret } = ty.clone() {
110                            let mut real_types = Vec::new();
111                            for (ty1, ty2) in tys.iter().zip(want_tys.iter()) {
112                                if ty1 != ty2 {
113                                    if ty1.is_any() || ty2.is_any() {
114                                        real_types.push(ty1.clone());
115                                    }
116                                    //ty1 是目的类型
117                                    else {
118                                        break;
119                                    }
120                                } else {
121                                    real_types.push(ty1.clone());
122                                }
123                            }
124                            if real_types.len() == want_tys.len() {
125                                return Ok(FnInfo::Call { fn_id: *fn_id, arg_tys: real_types, caps: Vec::new(), ret: ret.as_ref().clone() });
126                            }
127                        }
128                    }
129                }
130                FnVariant::Inline { fn_ptr, arg_tys } => {
131                    return Ok(FnInfo::Inline { fn_ptr: fn_ptr.clone(), arg_tys: arg_tys.clone() });
132                }
133                FnVariant::Native { ty, fn_id } => {
134                    if let Type::Fn { tys, ret } = ty.clone() {
135                        return Ok(FnInfo::Call { fn_id: *fn_id, arg_tys: tys, caps: Vec::new(), ret: ret.as_ref().clone() });
136                    }
137                }
138            }
139        }
140        Err(anyhow!("未发现函数 {}", id))
141    }
142
143    pub fn get_sig(&mut self, arg_tys: &[Type], ret: Type) -> Result<Signature> {
144        if let Some(st) = self.sigs.iter().find_map(|s| if s.0 == arg_tys && ret == s.2 { Some(s.1.clone()) } else { None }) {
145            return Ok(st);
146        }
147        let mut sig = self.module.make_signature();
148        for arg in arg_tys.iter() {
149            sig.params.push(AbiParam::new(get_type(arg)?));
150        }
151        if !ret.is_void() {
152            sig.returns.push(AbiParam::new(get_type(&ret)?));
153        }
154        self.sigs.push((arg_tys.to_vec(), sig.clone(), ret.clone()));
155        Ok(sig)
156    }
157
158    fn declare_compiled_fn(&mut self, name_id: Option<&(SmolStr, u32)>, arg_tys: &[Type], ret_ty: Type) -> Result<FuncId> {
159        let sig = self.get_sig(arg_tys, ret_ty.clone())?;
160        log::info!("{:?} {:?}", name_id, sig);
161        if let Some((name, id)) = name_id {
162            let fn_id = self.module.declare_function(&name, Linkage::Local, &sig)?;
163            let variant = (Type::Fn { tys: arg_tys.to_vec(), ret: std::rc::Rc::new(ret_ty.clone()) }, fn_id);
164            if let Some(FnVariant::Compiled(fns)) = self.fns.get_mut(id) {
165                fns.push(variant);
166            } else {
167                self.fns.insert(*id, FnVariant::Compiled(vec![variant]));
168            }
169            Ok(fn_id)
170        } else {
171            Ok(self.module.declare_anonymous_function(&sig)?)
172        }
173    }
174
175    fn define_compiled_fn(&mut self, fn_id: FuncId, name_id: Option<&(SmolStr, u32)>, arg_tys: &[Type], ret_ty: Type, stmt: &Stmt) -> Result<()> {
176        let sig = self.get_sig(arg_tys, ret_ty.clone())?;
177        #[cfg(feature = "ir-disassembly")]
178        let fn_name = name_id.map(|(name, _)| name.clone());
179        let mut ctx = self.module.make_context();
180        ctx.func.signature = sig.clone();
181
182        let mut func_ctx = FunctionBuilderContext::new();
183        let builder = FunctionBuilder::new(&mut ctx.func, &mut func_ctx);
184
185        let mut build_ctx = BuildContext::new(builder, &arg_tys)?;
186        self.compile_depth += 1;
187        let stmt = Self::coerce_returns(stmt, &ret_ty);
188        let gen_result = self.gen_stmt(&mut build_ctx, &stmt, None, None);
189        self.compile_depth -= 1;
190        gen_result?;
191
192        build_ctx.builder.seal_all_blocks();
193        #[cfg(feature = "ir-disassembly")]
194        {
195            let ir = format!("{}", ctx.func.display());
196            if let Some(name) = fn_name {
197                self.ir_disassembly.insert(name, ir);
198            }
199        }
200        self.module.define_function(fn_id, &mut ctx).with_context(|| name_id.map(|(name, _)| format!("define function {}", name)).unwrap_or_else(|| "define anonymous function".to_string()))?;
201        log::info!("{:?}", ctx.func);
202        Ok(())
203    }
204
205    pub(crate) fn compile_fn(&mut self, name_id: Option<(SmolStr, u32)>, arg_tys: &[Type], ret_ty: Type, stmt: &Stmt) -> Result<FuncId> {
206        let drain_pending = self.compile_depth == 0;
207        let fn_id = self.declare_compiled_fn(name_id.as_ref(), arg_tys, ret_ty.clone())?;
208        self.define_compiled_fn(fn_id, name_id.as_ref(), arg_tys, ret_ty, stmt)?;
209        if drain_pending {
210            self.drain_pending_fns()?;
211        }
212        Ok(fn_id)
213    }
214
215    fn drain_pending_fns(&mut self) -> Result<()> {
216        while let Some(pending) = self.pending_fns.pop_front() {
217            let name_id = (pending.name, pending.symbol_id);
218            self.define_compiled_fn(pending.fn_id, Some(&name_id), &pending.arg_tys, pending.ret_ty, &pending.body)?;
219        }
220        Ok(())
221    }
222
223    pub fn gen_fn(&mut self, ctx: Option<&BuildContext>, id: u32, arg_tys: &[Type]) -> Result<FnInfo> {
224        self.gen_fn_with_params(ctx, id, arg_tys, &[])
225    }
226
227    pub fn gen_fn_with_params(&mut self, ctx: Option<&BuildContext>, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<FnInfo> {
228        let mut arg_tys: Vec<Type> = arg_tys.iter().map(|ty| self.compiler.symbols.get_type(ty).unwrap()).collect();
229        if generic_args.is_empty()
230            && let Ok(info) = self.get_fn(id, &arg_tys)
231        {
232            return Ok(info);
233        }
234        let (name, s) = self.compiler.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
235        if let Symbol::Fn { ty, args, generic_params, cap, body, is_pub: _ } = s.clone() {
236            if let Type::Fn { tys: decl_tys, ret: _ } = ty {
237                let inferred_generic_args = if generic_args.is_empty() { infer_generic_args_from_types(&generic_params, &decl_tys, &arg_tys) } else { generic_args.to_vec() };
238                let generic_args = if generic_params.is_empty() { &[] } else { inferred_generic_args.as_slice() };
239                let decl_tys = if generic_params.is_empty() { decl_tys } else { decl_tys.iter().map(|ty| substitute_type(ty, &generic_params, generic_args)).collect() };
240                while arg_tys.len() < decl_tys.len() {
241                    arg_tys.push(self.compiler.symbols.get_type(&decl_tys[arg_tys.len()]).unwrap_or(Type::Any));
242                }
243                let ret_ty = self.compiler.infer_fn_with_params(id, &arg_tys, generic_args)?;
244                if let Some(FnVariant::Compiled(fns)) = self.fns.get(&id) {
245                    for (ty, fn_id) in fns {
246                        if let Type::Fn { tys, ret } = ty
247                            && tys == &arg_tys
248                            && ret.as_ref() == &ret_ty
249                        {
250                            return Ok(FnInfo::Call { fn_id: *fn_id, arg_tys: arg_tys.to_vec(), caps: Vec::new(), ret: ret_ty });
251                        }
252                    }
253                }
254                let mut compile_cap = cap.clone();
255                let body = if generic_params.is_empty() {
256                    body.as_ref().clone()
257                } else {
258                    let mut compile_tys = decl_tys.clone();
259                    let substituted = substitute_stmt(body.as_ref(), &generic_params, generic_args);
260                    let saved_state = self.compiler.take_local_state();
261                    let compiled_body = self.compiler.compile_fn(&args, &mut compile_tys, substituted, &mut compile_cap);
262                    self.compiler.restore_local_state(saved_state);
263                    Stmt::new(StmtKind::Block(compiled_body?), Span::default())
264                };
265                for v in compile_cap.vars.iter() {
266                    ctx.as_ref().map(|ctx| arg_tys.push(ctx.vars[*v].get_ty()));
267                }
268                let fn_id = if self.compile_depth > 0 {
269                    let fn_id = self.declare_compiled_fn(Some(&(name.clone(), id)), &arg_tys, ret_ty.clone())?;
270                    self.pending_fns.push_back(PendingFn { name: name.clone(), symbol_id: id, fn_id, arg_tys: arg_tys.clone(), ret_ty: ret_ty.clone(), body });
271                    fn_id
272                } else {
273                    let fn_id = self.compile_fn(Some((name.clone(), id)), &arg_tys, ret_ty.clone(), &body)?;
274                    self.drain_pending_fns()?;
275                    self.module.finalize_definitions()?;
276                    fn_id
277                };
278                return Ok(FnInfo::Call { fn_id, arg_tys: arg_tys.to_vec(), caps: compile_cap.vars.clone(), ret: ret_ty });
279            }
280            let ret_ty = self.compiler.infer_fn_with_params(id, &arg_tys, generic_args)?;
281            for v in cap.vars.iter() {
282                ctx.as_ref().map(|ctx| arg_tys.push(ctx.vars[*v].get_ty()));
283            }
284            let body = if generic_params.is_empty() { body.as_ref().clone() } else { substitute_stmt(body.as_ref(), &generic_params, generic_args) };
285            let fn_id = if self.compile_depth > 0 {
286                let fn_id = self.declare_compiled_fn(Some(&(name.clone(), id)), &arg_tys, ret_ty.clone())?;
287                self.pending_fns.push_back(PendingFn { name: name.clone(), symbol_id: id, fn_id, arg_tys: arg_tys.clone(), ret_ty: ret_ty.clone(), body });
288                fn_id
289            } else {
290                let fn_id = self.compile_fn(Some((name.clone(), id)), &arg_tys, ret_ty.clone(), &body)?;
291                self.drain_pending_fns()?;
292                self.module.finalize_definitions()?;
293                fn_id
294            };
295            return Ok(FnInfo::Call { fn_id, arg_tys: arg_tys.to_vec(), caps: cap.vars.clone(), ret: ret_ty });
296        }
297        Err(anyhow!("生成函数 {} 失败", id))
298    }
299}