Skip to main content

compiler/
infer.rs

1use super::{Compiler, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, PatternKind, Span, Stmt, StmtKind};
5
6impl Compiler {
7    fn merge_return_type(span: Span, left: Option<Type>, right: Type) -> Result<Type> {
8        match left {
9            Some(left) if left == right => Ok(left),
10            Some(left) if left.is_void() || right.is_void() => Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left, right))),
11            Some(left) => Ok(left + right),
12            None => Ok(right),
13        }
14    }
15
16    fn infer_return_type(&mut self, stmt: &Stmt) -> Result<Option<Type>> {
17        self.infer_returns(stmt, true).map(|(ty, _)| ty)
18    }
19
20    fn infer_returns(&mut self, stmt: &Stmt, tail: bool) -> Result<(Option<Type>, bool)> {
21        match &stmt.kind {
22            StmtKind::Return(Some(expr)) => Ok((Some(self.infer_expr(expr)?), true)),
23            StmtKind::Return(None) => Ok((Some(Type::Void), true)),
24            StmtKind::Block(stmts) => {
25                let mut ret = None;
26                for (idx, stmt) in stmts.iter().enumerate() {
27                    let (ty, always_returns) = self.infer_returns(stmt, tail && idx == stmts.len().saturating_sub(1))?;
28                    if let Some(ty) = ty {
29                        ret = Some(Self::merge_return_type(stmt.span, ret, ty)?);
30                    }
31                    if always_returns {
32                        return Ok((ret, true));
33                    }
34                }
35                Ok((ret, false))
36            }
37            StmtKind::If { cond, then_body, else_body } => {
38                let cond_ty = self.infer_expr(cond)?;
39                if cond_ty != Type::Bool {
40                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
41                }
42                let (mut ret, then_returns) = self.infer_returns(then_body, tail)?;
43                let else_returns = if let Some(body) = else_body {
44                    let (else_ty, else_returns) = self.infer_returns(body, tail)?;
45                    if let Some(ty) = else_ty {
46                        ret = Some(Self::merge_return_type(body.span, ret, ty)?);
47                    }
48                    else_returns
49                } else {
50                    false
51                };
52                Ok((ret, then_returns && else_returns))
53            }
54            StmtKind::While { cond, body } => {
55                let cond_ty = self.infer_expr(cond)?;
56                if cond_ty != Type::Bool {
57                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
58                }
59                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
60            }
61            StmtKind::Loop(body) => self.infer_returns(body, false),
62            StmtKind::For { pat, range, body } => {
63                if let PatternKind::Var { idx, .. } = &pat.kind {
64                    let ty = self.infer_expr(range)?;
65                    self.set_ty(*idx, ty);
66                } else if let PatternKind::Tuple(pats) = &pat.kind {
67                    let ty = self.infer_expr(range)?;
68                    assert!(ty.is_any());
69                    for pat in pats {
70                        if let Some(idx) = pat.var() {
71                            self.set_ty(idx, Type::Any);
72                        }
73                    }
74                }
75                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
76            }
77            StmtKind::Let { .. } => {
78                self.infer_stmt(stmt)?;
79                Ok((None, false))
80            }
81            StmtKind::Expr(expr, close) => {
82                let ty = self.infer_expr(expr)?;
83                Ok(if *close || !tail { (None, false) } else { (Some(ty), true) })
84            }
85            _ => {
86                self.infer_stmt(stmt)?;
87                Ok((None, false))
88            }
89        }
90    }
91
92    pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
93        match &expr.kind {
94            ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
95            ExprKind::Value(v) if v.is_list() || v.is_map() => Ok(Type::Any),
96            ExprKind::Value(v) => Ok(v.get_type()),
97            ExprKind::Const(_) => Ok(Type::Any),
98            ExprKind::Var(idx) => {
99                let idx = self.top() + (*idx as usize);
100                if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
101            }
102            ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
103                Symbol::Const { ty, .. } => Ok(ty.clone()),
104                Symbol::Static { ty, .. } => Ok(ty.clone()),
105                Symbol::Struct(ty, _) => Ok(ty.clone()),
106                Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
107                Symbol::Native(ty) => Ok(ty.clone()),
108                s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
109            },
110            ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
111            ExprKind::Unary { value, .. } => self.infer_expr(value.as_ref()),
112            ExprKind::Binary { left, op, right } => {
113                let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
114                let ty = if op.is_logic() {
115                    let left_ty = self.infer_expr(left)?;
116                    if matches!(op, BinaryOp::And | BinaryOp::Or) && left_ty.is_any() { Type::Any } else { Type::Bool }
117                } else if op == &BinaryOp::Idx {
118                    let left_ty = self.infer_expr(left)?;
119                    if let Type::Array(elem_ty, _) = left_ty {
120                        (*elem_ty).clone()
121                    } else if let Type::Vec(elem_ty, _) = left_ty {
122                        (*elem_ty).clone()
123                    } else {
124                        let left_ty = self.symbols.get_type(&left_ty)?;
125                        let right_ty = if right.is_value() || right.is_const() {
126                            let right_value = if let ExprKind::Const(c) = &right.kind { self.consts[*c].clone() } else { right.clone().value()? };
127                            if right_value.is_str() {
128                                if left_ty.is_any() {
129                                    return Ok(Type::Any);
130                                }
131                                if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
132                                    return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
133                                }
134                            } else if let Type::Struct { fields, .. } = &left_ty
135                                && let Some(idx) = right_value.as_int()
136                            {
137                                return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
138                            }
139                            right_value.get_type()
140                        } else {
141                            self.infer_expr(right)?
142                        };
143                        if right_ty.is_int() || right_ty.is_uint() {
144                            if left_ty.is_any() {
145                                return Ok(Type::Any);
146                            }
147                            let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
148                            let fn_ty = self.symbols.get_type(&s)?;
149                            return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
150                        }
151                        if left_ty.is_any() {
152                            return Ok(Type::Any);
153                        }
154                        Type::Any
155                    }
156                } else {
157                    let right_ty = self.infer_expr(right)?;
158                    if op == &BinaryOp::Assign { right_ty } else { self.infer_expr(left)? + right_ty }
159                };
160                assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
161                Ok(ty)
162            }
163            ExprKind::Call { obj, params } => {
164                if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
165                    let mut args = Vec::new();
166                    for p in params {
167                        args.push(self.infer_expr(p)?);
168                    }
169                    self.infer_fn_with_params(*id, &args, generic_args)
170                } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
171                    let base_name = match ty {
172                        Type::Ident { name, .. } => name.clone(),
173                        Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
174                        _ => return Ok(Type::Any),
175                    };
176                    let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
177                    let mut args = vec![self.infer_expr(target)?];
178                    for p in params {
179                        args.push(self.infer_expr(p)?);
180                    }
181                    self.infer_fn(id, &args)
182                } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
183                    let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
184                    for p in params {
185                        args.push(self.infer_expr(p)?);
186                    }
187                    self.infer_fn(*id, &args)
188                } else if obj.is_idx() {
189                    let (target, _, method) = obj.clone().binary().unwrap();
190                    let ty = self.infer_expr(&target)?;
191                    if let Some(method) = self.get_value(&method) {
192                        let method = method.as_str();
193                        let fn_ty = match self.get_field(&ty, method) {
194                            Ok((_, fn_ty)) => fn_ty,
195                            Err(_) => {
196                                let id = self.symbols.get_id(method)?;
197                                if self.symbols.get_symbol(id)?.1.is_fn() {
198                                    Type::Symbol { id, params: Vec::new() }
199                                } else {
200                                    return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
201                                }
202                            }
203                        };
204                        if let Type::Symbol { id, .. } = fn_ty {
205                            let mut args = vec![ty];
206                            for p in params {
207                                args.push(self.infer_expr(p)?);
208                            }
209                            self.infer_fn(id, &args)
210                        } else {
211                            Ok(fn_ty)
212                        }
213                    } else {
214                        Ok(Type::Any)
215                    }
216                } else if let ExprKind::Var(idx) = &obj.kind {
217                    let idx = self.top() + (*idx as usize);
218                    if idx < self.tys.len()
219                        && let Type::Symbol { id, .. } = self.tys[idx]
220                    {
221                        let mut args = Vec::new();
222                        for p in params {
223                            args.push(self.infer_expr(p)?);
224                        }
225                        self.infer_fn(id, &args)
226                    } else {
227                        Ok(Type::Any)
228                    }
229                } else if obj.is_value() {
230                    Ok(Type::Void)
231                } else {
232                    Ok(Type::Any)
233                }
234            }
235            ExprKind::Typed { ty, .. } => Ok(ty.clone()),
236            ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
237            ExprKind::Range { start, stop, .. } => {
238                let start_ty = self.infer_expr(start)?;
239                let stop_ty = self.infer_expr(stop)?;
240                Ok(if start_ty.is_any() {
241                    stop_ty
242                } else if stop_ty.is_any() {
243                    start_ty
244                } else {
245                    stop_ty
246                })
247            }
248            _ => Ok(Type::Any),
249        }
250    }
251
252    fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
253        let mut fn_tys = Vec::new();
254        for (i, ty) in tys.iter().enumerate() {
255            if !ty.is_any() {
256                fn_tys.push(ty.clone());
257            } else if let Some(arg_ty) = arg_tys.get(i) {
258                fn_tys.push(self.symbols.get_type(arg_ty)?);
259            } else {
260                fn_tys.push(Type::Any);
261            }
262        }
263        Ok(fn_tys)
264    }
265
266    pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
267        self.infer_fn_with_params(id, arg_tys, &[])
268    }
269
270    pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
271        let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
272        if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
273            if let Type::Fn { tys, ret: _ } = ty {
274                let inferred_generic_args = if generic_args.is_empty() { crate::infer_generic_args_from_types(&generic_params, &tys, arg_tys) } else { generic_args.to_vec() };
275                let generic_args = if generic_params.is_empty() { &[] } else { inferred_generic_args.as_slice() };
276                let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
277                let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
278                let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
279                let body = if generic_params.is_empty() {
280                    body
281                } else {
282                    let mut compile_tys = tys.clone();
283                    let mut compile_cap = cap.clone();
284                    let saved_state = self.take_local_state();
285                    let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
286                    self.restore_local_state(saved_state);
287                    Stmt::new(StmtKind::Block(compiled?), Span::default())
288                };
289                if let Some(fns) = self.fns.get_mut(&id) {
290                    for f in fns.iter() {
291                        if f.0 == generic_args && f.1 == fn_tys {
292                            return Ok(f.2.clone());
293                        }
294                    }
295                    fns.push((generic_args.to_vec(), fn_tys.clone(), Type::Any));
296                } else {
297                    self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), Type::Any)]);
298                }
299                let top = self.tys.len();
300                self.tys.append(&mut fn_tys.clone());
301                for c in cap.vars.iter() {
302                    self.tys.push(self.tys[self.top() + *c].clone());
303                }
304                self.frames.push(top);
305                let ret_ty = self.infer_return_type(&body).map(|ty| ty.unwrap_or(Type::Void));
306                if let Some(top) = self.frames.pop() {
307                    self.tys.truncate(top);
308                }
309                let ret_ty = match ret_ty {
310                    Ok(ret_ty) => ret_ty,
311                    Err(err) => {
312                        log::error!("infer_fn {} failed: {:?}", name, err);
313                        let should_remove = self
314                            .fns
315                            .get_mut(&id)
316                            .map(|fns| {
317                                fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || item.2 != Type::Any);
318                                fns.is_empty()
319                            })
320                            .unwrap_or(false);
321                        if should_remove {
322                            self.fns.remove(&id);
323                        }
324                        return Err(err);
325                    }
326                };
327                self.fns.get_mut(&id).map(|f| {
328                    f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = ret_ty.clone());
329                });
330                Ok(ret_ty)
331            } else {
332                Ok(Type::Any)
333            }
334        } else if let Symbol::Native(f) = s {
335            if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
336        } else if matches!(s, Symbol::Null) {
337            Ok(Type::Any)
338        } else {
339            Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
340        }
341    }
342
343    pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
344        match &stmt.kind {
345            StmtKind::Expr(expr, close) => {
346                if !close {
347                    self.infer_expr(expr)
348                } else {
349                    self.infer_expr(expr)?;
350                    Ok(Type::Void)
351                }
352            }
353            StmtKind::Return(expr) => {
354                if let Some(e) = expr {
355                    self.infer_expr(e)
356                } else {
357                    Ok(Type::Void)
358                }
359            }
360            StmtKind::Block(stmts) => {
361                for (idx, stmt) in stmts.iter().enumerate() {
362                    let ty = self.infer_stmt(stmt)?;
363                    if stmt.is_return() || idx == stmts.len() - 1 {
364                        return Ok(ty);
365                    }
366                }
367                Ok(Type::Void)
368            }
369            StmtKind::If { then_body, else_body, .. } => {
370                let then_ty = self.infer_stmt(then_body)?;
371                if let Some(e) = else_body {
372                    let else_ty = self.infer_stmt(e)?;
373                    if then_ty != else_ty {
374                        log::info!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
375                        return Ok(if then_ty.is_any() { else_ty } else { then_ty });
376                    }
377                }
378                if else_body.is_none() {
379                    return Ok(Type::Void);
380                }
381                Ok(then_ty)
382            }
383            StmtKind::While { cond, body } => {
384                let cond_ty = self.infer_expr(cond)?;
385                if cond_ty != Type::Bool {
386                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
387                }
388                self.infer_stmt(body)
389            }
390            StmtKind::For { pat, range, body } => {
391                if let PatternKind::Var { idx, .. } = &pat.kind {
392                    let ty = self.infer_expr(range)?;
393                    self.set_ty(*idx, ty);
394                } else if let PatternKind::Tuple(pats) = &pat.kind {
395                    let ty = self.infer_expr(range)?;
396                    assert!(ty.is_any());
397                    for pat in pats {
398                        if let Some(idx) = pat.var() {
399                            self.set_ty(idx, Type::Any);
400                        }
401                    }
402                }
403                self.infer_stmt(body)
404            }
405            StmtKind::Let { pat, value } => {
406                let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
407                if let PatternKind::Ident { ty, .. } = &pat.kind {
408                    let annotated_ty = self.symbols.get_type(ty)?;
409                    if annotated_ty.is_any() {
410                        self.add_ty(expr_ty);
411                    } else {
412                        self.add_ty(annotated_ty);
413                    }
414                } else if let PatternKind::Var { idx, .. } = &pat.kind {
415                    self.set_ty(*idx, expr_ty);
416                } else if matches!(pat.kind, PatternKind::Wildcard) {
417                    self.add_ty(expr_ty);
418                }
419                Ok(Type::Void)
420            }
421            _ => Ok(Type::Void),
422        }
423    }
424}