Skip to main content

compiler/
infer.rs

1use super::{Compiler, FnInferRet, ListElemState, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, Pattern, PatternKind, Span, Stmt, StmtKind, UnaryOp};
5use smol_str::SmolStr;
6
/// Inferred information about one returned value (a `return` statement or a
/// tail expression).
///
/// `ty` is the value type used when merging return types. `shape` optionally
/// records the structural shape of the value (struct / map / list / void). When
/// two returns are merged, their shapes are compared separately from `ty`.
#[derive(Clone)]
struct ReturnInfo {
    /// Inferred value type; `infer_return_expr` widens it to `Type::Any` when the
    /// shape is a map or list.
    ty: Type,
    /// Structural shape of the returned value, when one could be determined.
    shape: Option<Type>,
}
12
/// Hard upper bound on the depth of the type-inference recursion chain.
///
/// Re-inferring the same instantiation is stopped by the memoization in
/// `self.fns`. Mutually recursive generic functions, however, may produce a fresh
/// `(generic_args, fn_tys)` instantiation at every step. Memoization then never
/// hits, and inference would recurse until the stack overflows. Beyond this
/// depth the inferred result falls back to [`Type::Any`], which turns a
/// hang/crash into a conservative but safe type. Normal code never gets anywhere
/// near this depth.
const MAX_INFER_DEPTH: usize = 64;
18
19impl Compiler {
20    fn current_infer_key(&self) -> Option<(u32, Vec<Type>, Vec<Type>)> {
21        self.infer_stack.last().cloned()
22    }
23
24    fn pending_return_seed(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Option<Type> {
25        self.fns.get(&id).and_then(|fns| {
26            fns.iter().find_map(|item| {
27                if item.0 == generic_args
28                    && item.1 == fn_tys
29                    && let FnInferRet::Pending(seed) = &item.2
30                {
31                    seed.clone()
32                } else {
33                    None
34                }
35            })
36        })
37    }
38
39    fn update_pending_return_seed(&mut self, ty: &Type) {
40        if ty.is_any() {
41            return;
42        }
43        let Some((id, generic_args, fn_tys)) = self.current_infer_key() else {
44            return;
45        };
46        let Some(fns) = self.fns.get_mut(&id) else {
47            return;
48        };
49        if let Some(item) = fns.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys)
50            && let FnInferRet::Pending(seed) = &mut item.2
51        {
52            let next = seed.take().map(|prev| prev + ty.clone()).unwrap_or_else(|| ty.clone());
53            *seed = Some(next);
54        }
55    }
56
57    /// 扫描函数体,查找第一个非递归路径上的返回值类型(仅处理字面量)。
58    fn try_find_base_return_ty(&self, body: &Stmt) -> Option<Type> {
59        match &body.kind {
60            StmtKind::Block(stmts) => stmts.iter().find_map(|s| self.try_find_base_return_ty(s)),
61            StmtKind::If { then_body, else_body, .. } => self.try_find_base_return_ty(then_body).or_else(|| else_body.as_ref().and_then(|b| self.try_find_base_return_ty(b))),
62            StmtKind::Return(Some(expr)) => Self::try_literal_type(expr),
63            StmtKind::Expr(expr, false) => Self::try_literal_type(expr),
64            _ => None,
65        }
66    }
67
68    /// 带作用域的 base case 返回类型查找
69    fn try_find_base_return_ty_with_scope(&mut self, body: &Stmt, fn_id: u32, fn_name: &str, args: &[SmolStr], fn_tys: &[Type]) -> Option<Type> {
70        let saved_state = self.take_local_state();
71        self.frames.push(0);
72        for (arg, ty) in args.iter().zip(fn_tys.iter()) {
73            self.add_name(arg.clone());
74            self.add_ty(ty.clone());
75        }
76        let result = self.try_find_base_return_ty_with_scope_inner(body, fn_id, fn_name);
77        self.restore_local_state(saved_state);
78        result
79    }
80
81    fn try_find_base_return_ty_with_scope_inner(&mut self, body: &Stmt, fn_id: u32, fn_name: &str) -> Option<Type> {
82        match &body.kind {
83            StmtKind::Block(stmts) => stmts.iter().find_map(|s| self.try_find_base_return_ty_with_scope_inner(s, fn_id, fn_name)),
84            StmtKind::If { then_body, else_body, .. } => {
85                self.try_find_base_return_ty_with_scope_inner(then_body, fn_id, fn_name).or_else(|| else_body.as_ref().and_then(|b| self.try_find_base_return_ty_with_scope_inner(b, fn_id, fn_name)))
86            }
87            StmtKind::Return(Some(expr)) => {
88                if Self::expr_calls_fn(expr, fn_id, fn_name) {
89                    None
90                } else {
91                    self.infer_return_expr(expr).ok().map(|info| info.ty)
92                }
93            }
94            StmtKind::Expr(expr, false) => {
95                if Self::expr_calls_fn(expr, fn_id, fn_name) {
96                    None
97                } else {
98                    self.infer_return_expr(expr).ok().map(|info| info.ty)
99                }
100            }
101            _ => None,
102        }
103    }
104
105    fn expr_calls_fn(expr: &Expr, fn_id: u32, fn_name: &str) -> bool {
106        match &expr.kind {
107            ExprKind::Call { obj, params } => {
108                if let ExprKind::Id(id, _) = &obj.kind {
109                    return *id == fn_id;
110                }
111                if let ExprKind::Ident(name) = &obj.kind {
112                    if name.as_str() == fn_name || fn_name.ends_with(&format!("::{}", name)) {
113                        return true;
114                    }
115                }
116                params.iter().any(|p| Self::expr_calls_fn(p, fn_id, fn_name))
117            }
118            ExprKind::Binary { left, op: _, right } => Self::expr_calls_fn(left, fn_id, fn_name) || Self::expr_calls_fn(right, fn_id, fn_name),
119            ExprKind::Unary { op: _, value } => Self::expr_calls_fn(value, fn_id, fn_name),
120            ExprKind::Typed { value, ty: _ } => Self::expr_calls_fn(value, fn_id, fn_name),
121            _ => false,
122        }
123    }
124
125    fn try_literal_type(expr: &Expr) -> Option<Type> {
126        match &expr.kind {
127            ExprKind::Value(v) => Some(v.get_type()),
128            ExprKind::Unary { op: UnaryOp::Neg, value } => Self::try_literal_type(value),
129            _ => None,
130        }
131    }
132
133    fn add_pattern_bindings_for_infer(&mut self, pat: &Pattern, expr_ty: Type) -> Result<()> {
134        match &pat.kind {
135            PatternKind::Ident { name, ty } => {
136                let annotated_ty = self.symbols.get_type(ty)?;
137                self.add_name(name.clone());
138                self.add_ty(if annotated_ty.is_any() { expr_ty } else { annotated_ty });
139            }
140            PatternKind::Var { idx, .. } => self.set_ty(*idx, expr_ty),
141            PatternKind::Tuple(pats) => {
142                if let Type::Tuple(tys) = expr_ty {
143                    for (pat, ty) in pats.iter().zip(tys) {
144                        self.add_pattern_bindings_for_infer(pat, ty)?;
145                    }
146                } else {
147                    for pat in pats {
148                        self.add_pattern_bindings_for_infer(pat, Type::Any)?;
149                    }
150                }
151            }
152            PatternKind::List { elems, .. } => {
153                for pat in elems {
154                    self.add_pattern_bindings_for_infer(pat, Type::Any)?;
155                }
156            }
157            PatternKind::Wildcard => {
158                self.add_name("".into());
159                self.add_ty(expr_ty);
160            }
161            PatternKind::Literal(_) | PatternKind::Member(_, _) | PatternKind::Idx(_, _) => {}
162        }
163        Ok(())
164    }
165
166    fn for_pattern_ty(&mut self, range: &Expr) -> Result<Type> {
167        if matches!(range.kind, ExprKind::Range { .. }) {
168            return self.infer_range_expr(range);
169        }
170        Ok(match self.infer_expr(range)? {
171            Type::Array(elem_ty, _) | Type::Vec(elem_ty, _) | Type::List(elem_ty) => elem_ty.as_ref().clone(),
172            _ => Type::Any,
173        })
174    }
175
176    fn infer_range_expr(&mut self, range: &Expr) -> Result<Type> {
177        let ExprKind::Range { start, stop, .. } = &range.kind else {
178            return self.infer_expr(range);
179        };
180        let start_ty = self.infer_expr(start)?;
181        let stop_ty = self.infer_expr(stop)?;
182        Ok(Self::merge_range_bound_types(start_ty, stop_ty))
183    }
184
185    fn merge_range_bound_types(start_ty: Type, stop_ty: Type) -> Type {
186        if start_ty.is_any() {
187            stop_ty
188        } else if stop_ty.is_any() {
189            start_ty
190        // 无后缀整数字面量(默认 I32/I64)在 range 里向另一端的具体无符号类型靠拢,
191        // 这样 `0..n`(n: u32)仍是 u32 range,而不是被默认 I64 拖宽成 i64(会拆穿 GPU 后端)。
192        } else if matches!(start_ty, Type::I32 | Type::I64) && stop_ty.is_uint() {
193            stop_ty
194        } else if matches!(stop_ty, Type::I32 | Type::I64) && start_ty.is_uint() {
195            start_ty
196        } else {
197            start_ty + stop_ty
198        }
199    }
200
201    fn merge_return_type(span: Span, left: Option<Type>, right: Type) -> Result<Type> {
202        match left {
203            Some(left) if left == right => Ok(left),
204            Some(left) if left.is_void() || right.is_void() => Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left, right))),
205            Some(left) if left.is_any() || right.is_any() => Ok(Type::Any),
206            Some(left) => Ok(left + right),
207            None => Ok(right),
208        }
209    }
210
211    fn return_shape(&self, expr: &Expr, ty: &Type) -> Option<Type> {
212        if !ty.is_any() {
213            return match ty {
214                Type::Struct { .. } => Some(ty.clone()),
215                Type::Map => Some(Type::Map),
216                Type::List(elem) | Type::Array(elem, _) => Some(Type::List(elem.clone())),
217                _ => None,
218            };
219        }
220        match &expr.kind {
221            ExprKind::List(_) | ExprKind::Tuple(_) => Some(Type::list_any()),
222            ExprKind::Dict(_) => Some(Type::Map),
223            ExprKind::Value(value) => Self::dynamic_return_shape(value.get_type()),
224            ExprKind::Const(idx) => self.consts.get_index(*idx).and_then(|(_, value)| Self::dynamic_return_shape(value.get_type())),
225            ExprKind::Typed { ty, .. } => Some(ty.clone()),
226            _ => None,
227        }
228    }
229
230    fn dynamic_return_shape(ty: Type) -> Option<Type> {
231        match ty {
232            Type::Map => Some(Type::Map),
233            Type::List(elem) => Some(Type::List(elem)),
234            Type::Array(elem, _) => Some(Type::List(elem)),
235            _ => None,
236        }
237    }
238
239    fn local_var_idx_for_expr(&self, expr: &Expr) -> Option<u32> {
240        match &expr.kind {
241            ExprKind::Var(idx) => Some(*idx),
242            ExprKind::Ident(name) => (self.top()..self.names.len()).rev().find(|idx| self.names[*idx].eq(name)).map(|idx| (idx - self.top()) as u32),
243            _ => None,
244        }
245    }
246
247    fn infer_list_method(&mut self, target: &Expr, elem_ty: &Type, method: &str, params: &[Expr]) -> Result<Option<Type>> {
248        match method {
249            "get_idx" | "pop" => Ok(Some(match self.local_var_idx_for_expr(target).and_then(|idx| self.list_elem_state(idx)) {
250                Some(ListElemState::Known(ty)) => ty,
251                Some(ListElemState::Unknown | ListElemState::Mixed) => Type::Any,
252                None => elem_ty.clone(),
253            })),
254            "push" => {
255                let pushed_ty = params
256                    .first()
257                    .map(|param| {
258                        if let Some(value) = self.get_value(param)
259                            && (value.is_str() || value.is_native())
260                        {
261                            Ok(value.get_type())
262                        } else {
263                            self.infer_expr(param)
264                        }
265                    })
266                    .transpose()?
267                    .unwrap_or(Type::Any);
268                if let Some(idx) = self.local_var_idx_for_expr(target) {
269                    let state = self.list_elem_state(idx).unwrap_or_else(|| if elem_ty.is_any() { ListElemState::Unknown } else { ListElemState::Known(elem_ty.clone()) });
270                    let next_state = match state {
271                        ListElemState::Unknown if pushed_ty.is_any() => ListElemState::Mixed,
272                        ListElemState::Unknown => ListElemState::Known(pushed_ty),
273                        ListElemState::Known(_) if pushed_ty.is_any() => ListElemState::Mixed,
274                        ListElemState::Known(prev) => {
275                            let merged = if prev == pushed_ty {
276                                prev
277                            } else if (prev.is_int() || prev.is_uint() || prev.is_float()) && (pushed_ty.is_int() || pushed_ty.is_uint() || pushed_ty.is_float()) {
278                                prev + pushed_ty
279                            } else {
280                                Type::Any
281                            };
282                            if merged.is_any() { ListElemState::Mixed } else { ListElemState::Known(merged) }
283                        }
284                        ListElemState::Mixed => ListElemState::Mixed,
285                    };
286                    let next_elem = if let ListElemState::Known(ty) = &next_state { ty.clone() } else { Type::Any };
287                    self.set_ty(idx, Type::List(std::rc::Rc::new(next_elem)));
288                    self.set_list_elem_state(idx, Some(next_state));
289                }
290                Ok(Some(Type::Void))
291            }
292            "len" => Ok(Some(Type::I32)),
293            "is_list" | "is_null" => Ok(Some(Type::Bool)),
294            _ => Ok(None),
295        }
296    }
297
298    fn infer_return_expr(&mut self, expr: &Expr) -> Result<ReturnInfo> {
299        let ty = self.infer_expr(expr)?;
300        let shape = self.return_shape(expr, &ty);
301        let ty = if matches!(shape, Some(Type::Map | Type::List(_))) { Type::Any } else { ty };
302        Ok(ReturnInfo { ty, shape })
303    }
304
305    fn merge_return_info(span: Span, left: Option<ReturnInfo>, right: ReturnInfo) -> Result<ReturnInfo> {
306        let Some(left) = left else {
307            return Ok(right);
308        };
309        if let (Some(left_shape), Some(right_shape)) = (&left.shape, &right.shape)
310            && left_shape != right_shape
311        {
312            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, right_shape)));
313        }
314        if let Some(left_shape) = &left.shape
315            && left_shape.is_struct()
316            && right.ty.is_any()
317            && right.shape.is_none()
318        {
319            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, Type::Any)));
320        }
321        if let Some(right_shape) = &right.shape
322            && right_shape.is_struct()
323            && left.ty.is_any()
324            && left.shape.is_none()
325        {
326            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", Type::Any, right_shape)));
327        }
328        let ty = Self::merge_return_type(span, Some(left.ty), right.ty)?;
329        Ok(ReturnInfo { ty, shape: left.shape.or(right.shape) })
330    }
331
332    fn infer_return_type(&mut self, stmt: &Stmt) -> Result<Option<Type>> {
333        self.infer_returns(stmt, true).map(|(info, _)| info.map(|info| info.ty))
334    }
335
336    pub(crate) fn check_return_type(&mut self, stmt: &Stmt) -> Result<()> {
337        self.infer_returns(stmt, true).map(|_| ())
338    }
339
    /// Walks `stmt`, inferring every `return` and tail expression it contains, and
    /// merges them into a single [`ReturnInfo`].
    ///
    /// Returns `(info, always_returns)`:
    /// - `info` is the merged info of every return found (`None` if there is none).
    /// - `always_returns` is `true` when every path through `stmt` ends in a
    ///   return, or in a tail expression when `tail` is set.
    ///
    /// `tail` says whether `stmt` is in tail position of the function body. Only
    /// then does an unterminated expression statement count as the implicit return
    /// value. While walking, the pending-return seed of the instantiation being
    /// inferred is widened via `update_pending_return_seed`. Other statements are
    /// passed to `infer_stmt`, presumably so their bindings are visible to later
    /// expressions.
    ///
    /// # Errors
    /// - An `if` / `while` condition whose type is not `Type::Bool`.
    /// - Inconsistent return types (from `merge_return_info`).
    /// - Any error from inferring sub-expressions or statements.
    fn infer_returns(&mut self, stmt: &Stmt, tail: bool) -> Result<(Option<ReturnInfo>, bool)> {
        match &stmt.kind {
            StmtKind::Return(Some(expr)) => Ok((Some(self.infer_return_expr(expr)?), true)),
            // A bare `return` is recorded with a `Void` shape so that mixing it with a
            // value return is caught by the shape comparison in `merge_return_info`.
            StmtKind::Return(None) => Ok((Some(ReturnInfo { ty: Type::Void, shape: Some(Type::Void) }), true)),
            StmtKind::Block(stmts) => {
                let mut ret = None;
                for (idx, stmt) in stmts.iter().enumerate() {
                    // Only the last statement of a tail block is itself in tail position.
                    let (info, always_returns) = self.infer_returns(stmt, tail && idx == stmts.len().saturating_sub(1))?;
                    if let Some(info) = info {
                        // Widen the seed with this return's type, then with the merged type so far.
                        self.update_pending_return_seed(&info.ty);
                        ret = Some(Self::merge_return_info(stmt.span, ret, info)?);
                        if let Some(ret) = &ret {
                            self.update_pending_return_seed(&ret.ty);
                        }
                    }
                    // Statements after one that always returns are unreachable and are not inspected.
                    if always_returns {
                        return Ok((ret, true));
                    }
                }
                Ok((ret, false))
            }
            StmtKind::If { cond, then_body, else_body } => {
                let cond_ty = self.infer_expr(cond)?;
                // NOTE(review): strict inequality, so presumably an `Any`-typed condition is rejected too.
                if cond_ty != Type::Bool {
                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
                }
                // Both branches inherit this statement's tail position.
                let (mut ret, then_returns) = self.infer_returns(then_body, tail)?;
                if let Some(ret) = &ret {
                    self.update_pending_return_seed(&ret.ty);
                }
                let else_returns = if let Some(body) = else_body {
                    let (else_ty, else_returns) = self.infer_returns(body, tail)?;
                    if let Some(info) = else_ty {
                        self.update_pending_return_seed(&info.ty);
                        ret = Some(Self::merge_return_info(body.span, ret, info)?);
                        if let Some(ret) = &ret {
                            self.update_pending_return_seed(&ret.ty);
                        }
                    }
                    else_returns
                } else {
                    // Without an `else`, control can always fall through.
                    false
                };
                Ok((ret, then_returns && else_returns))
            }
            StmtKind::While { cond, body } => {
                let cond_ty = self.infer_expr(cond)?;
                if cond_ty != Type::Bool {
                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
                }
                // The body may run zero times, so a `while` never guarantees a return.
                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
            }
            // NOTE(review): `always_returns` is taken from the body unchanged; `break` is not
            // accounted for. `loop { if c { break; } return x; }` is therefore reported as
            // always returning, and an enclosing block then skips the statements after the
            // loop. Confirm this is intended.
            StmtKind::Loop(body) => self.infer_returns(body, false),
            StmtKind::For { pat, range, body } => {
                // Bind the loop pattern first so the body can refer to it.
                let ty = self.for_pattern_ty(range)?;
                self.add_pattern_bindings_for_infer(pat, ty)?;
                // Like `while`, the body may run zero times.
                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
            }
            StmtKind::Let { .. } => {
                self.infer_stmt(stmt)?;
                Ok((None, false))
            }
            StmtKind::Expr(expr, close) => {
                // Always inferred (for side effects on local types); it is the return
                // value only when unterminated (`close` presumably marks a trailing `;`)
                // and in tail position.
                let info = self.infer_return_expr(expr)?;
                Ok(if *close || !tail { (None, false) } else { (Some(info), true) })
            }
            _ => {
                self.infer_stmt(stmt)?;
                Ok((None, false))
            }
        }
    }
412
413    pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
414        match &expr.kind {
415            ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
416            ExprKind::Value(v) if v.is_list() => Ok(v.get_type()),
417            ExprKind::Value(v) if v.is_map() => Ok(Type::Any),
418            ExprKind::Value(v) => Ok(v.get_type()),
419            ExprKind::Const(idx) => Ok(match self.consts.get_index(*idx) {
420                Some((_, value)) if value.is_str() => Type::Str,
421                Some((_, value)) if value.is_list() && value.len() == 0 => Type::list_any(),
422                _ => Type::Any,
423            }),
424            ExprKind::Var(idx) => {
425                let idx = self.top() + (*idx as usize);
426                if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
427            }
428            ExprKind::Ident(ident) => {
429                for idx in (self.top()..self.names.len()).rev() {
430                    if self.names[idx].eq(ident) && idx < self.tys.len() {
431                        return self.symbols.get_type(&self.tys[idx]);
432                    }
433                }
434                let id = self.symbols.get_id(ident).map_err(|_| Self::semantic_error(expr.span, format!("未找到标识符 {}", ident)))?;
435                match self.symbols.get_symbol(id)?.1 {
436                    Symbol::Const { ty, .. } => Ok(ty.clone()),
437                    Symbol::Static { ty, .. } => Ok(ty.clone()),
438                    Symbol::Struct(ty, _) => Ok(ty.clone()),
439                    Symbol::Fn { .. } => Ok(Type::Symbol { id, params: Vec::new() }),
440                    Symbol::Native(ty) => Ok(ty.clone()),
441                    s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
442                }
443            }
444            ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
445                Symbol::Const { ty, .. } => Ok(ty.clone()),
446                Symbol::Static { ty, .. } => Ok(ty.clone()),
447                Symbol::Struct(ty, _) => Ok(ty.clone()),
448                Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
449                Symbol::Native(ty) => Ok(ty.clone()),
450                s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
451            },
452            ExprKind::Generic { obj, params } => {
453                let params = params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect();
454                match self.infer_expr(obj)? {
455                    Type::Symbol { id, .. } => Ok(Type::Symbol { id, params }),
456                    _ => Ok(Type::Any),
457                }
458            }
459            ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
460            ExprKind::Unary { op, value } => match op {
461                UnaryOp::Not => {
462                    let ty = self.infer_expr(value.as_ref())?;
463                    if ty.is_int() || ty.is_uint() { Ok(ty) } else { Ok(Type::Bool) }
464                }
465                UnaryOp::Neg => self.infer_expr(value.as_ref()),
466                UnaryOp::Unknow => Ok(Type::Any),
467            },
468            ExprKind::Binary { left, op, right } => {
469                if op == &BinaryOp::Assign
470                    && let ExprKind::Tuple(left_items) | ExprKind::List(left_items) = &left.kind
471                {
472                    if let ExprKind::Tuple(right_items) | ExprKind::List(right_items) = &right.kind {
473                        if left_items.len() != right_items.len() {
474                            return Err(Self::semantic_error(expr.span, format!("多重赋值数量不匹配: 左侧 {} 个,右侧 {} 个", left_items.len(), right_items.len())));
475                        }
476                        for item in right_items {
477                            let _ = self.infer_expr(item)?;
478                        }
479                    } else {
480                        let _ = self.infer_expr(right)?;
481                    }
482                    return Ok(Type::Void);
483                }
484                let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
485                let ty = if op.is_logic() {
486                    Type::Bool
487                } else if op == &BinaryOp::Idx {
488                    let left_ty = self.infer_expr(left)?;
489                    if let Type::Array(elem_ty, _) = left_ty {
490                        (*elem_ty).clone()
491                    } else if let Type::Vec(elem_ty, _) = left_ty {
492                        (*elem_ty).clone()
493                    } else if let Type::List(elem_ty) = left_ty {
494                        (*elem_ty).clone()
495                    } else {
496                        let left_ty = self.symbols.get_type(&left_ty)?;
497                        let right_ty = if right.is_value() || right.is_const() {
498                            let right_value = if let ExprKind::Const(c) = &right.kind {
499                                match self.consts.get_index(*c) {
500                                    Some((_, v)) => v.clone(),
501                                    None => right.clone().value()?,
502                                }
503                            } else {
504                                right.clone().value()?
505                            };
506                            if right_value.is_str() {
507                                if left_ty.is_any() {
508                                    return Ok(Type::Any);
509                                }
510                                if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
511                                    return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
512                                }
513                            } else if let Type::Struct { fields, .. } = &left_ty
514                                && let Some(idx) = right_value.as_int()
515                            {
516                                return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
517                            }
518                            right_value.get_type()
519                        } else {
520                            self.infer_expr(right)?
521                        };
522                        if right_ty.is_int() || right_ty.is_uint() {
523                            if left_ty.is_any() {
524                                return Ok(Type::Any);
525                            }
526                            let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
527                            let fn_ty = self.symbols.get_type(&s)?;
528                            return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
529                        }
530                        if left_ty.is_any() {
531                            return Ok(Type::Any);
532                        }
533                        Type::Any
534                    }
535                } else {
536                    let left_ty = self.infer_expr(left)?;
537                    let right_ty = self.infer_expr(right)?;
538                    if op == &BinaryOp::Assign {
539                        if !left_ty.is_any() && right_ty.is_any() { left_ty } else { right_ty }
540                    } else if op.is_assign() && !left_ty.is_any() && right_ty.is_any() {
541                        left_ty
542                    } else {
543                        left_ty + right_ty
544                    }
545                };
546                assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
547                Ok(ty)
548            }
549            ExprKind::Call { obj, params } => {
550                if let ExprKind::Assoc { ty, name } = &obj.kind {
551                    let base_name = match ty {
552                        Type::Ident { name, .. } => name.clone(),
553                        Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
554                        _ => return Ok(Type::Any),
555                    };
556                    let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
557                    let generic_args = match ty {
558                        Type::Ident { params, .. } | Type::Symbol { params, .. } => params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect::<Vec<_>>(),
559                        _ => Vec::new(),
560                    };
561                    let mut args = Vec::new();
562                    for p in params {
563                        args.push(self.infer_expr(p)?);
564                    }
565                    self.infer_fn_with_params(id, &args, &generic_args)
566                } else if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
567                    let mut args = Vec::new();
568                    for p in params {
569                        args.push(self.infer_expr(p)?);
570                    }
571                    self.infer_fn_with_params(*id, &args, generic_args)
572                } else if let ExprKind::Generic { obj, params: generic_args } = &obj.kind {
573                    let Type::Symbol { id, .. } = self.infer_expr(obj)? else {
574                        return Ok(Type::Any);
575                    };
576                    let generic_args = generic_args.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect::<Vec<_>>();
577                    let mut args = Vec::new();
578                    for p in params {
579                        args.push(self.infer_expr(p)?);
580                    }
581                    self.infer_fn_with_params(id, &args, &generic_args)
582                } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
583                    let base_name = match ty {
584                        Type::Ident { name, .. } => name.clone(),
585                        Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
586                        _ => return Ok(Type::Any),
587                    };
588                    let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
589                    let mut args = vec![self.infer_expr(target)?];
590                    for p in params {
591                        args.push(self.infer_expr(p)?);
592                    }
593                    self.infer_fn(id, &args)
594                } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
595                    let method = self.symbols.get_symbol(*id).ok().and_then(|(name, _)| name.rsplit_once("::").map(|(_, method)| method.to_string()));
596                    if let Some(target) = obj_expr
597                        && let Some(method) = method
598                    {
599                        let target_ty = self.infer_expr(target)?;
600                        if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &target_ty
601                            && let Some(ret_ty) = self.infer_list_method(target, elem_ty, method.as_str(), params)?
602                        {
603                            return Ok(ret_ty);
604                        }
605                    }
606                    let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
607                    for p in params {
608                        args.push(self.infer_expr(p)?);
609                    }
610                    self.infer_fn(*id, &args)
611                } else if let ExprKind::Ident(name) = &obj.kind {
612                    for idx in (self.top()..self.names.len()).rev() {
613                        if self.names[idx].eq(name) && idx < self.tys.len() {
614                            return if let Type::Symbol { id, .. } = &self.tys[idx] {
615                                let id = *id;
616                                let mut args = Vec::new();
617                                for p in params {
618                                    args.push(self.infer_expr(p)?);
619                                }
620                                self.infer_fn(id, &args)
621                            } else {
622                                Ok(Type::Any)
623                            };
624                        }
625                    }
626                    let Ok(id) = self.symbols.get_id(name) else {
627                        return Ok(Type::Any);
628                    };
629                    if !self.symbols.get_symbol(id)?.1.is_fn() {
630                        return Err(Self::semantic_error(obj.span, format!("符号 {} 不是函数", name)));
631                    }
632                    let mut args = Vec::new();
633                    for p in params {
634                        args.push(self.infer_expr(p)?);
635                    }
636                    self.infer_fn(id, &args)
637                } else if obj.is_idx() {
638                    let (target, _, method) = obj.clone().binary().unwrap();
639                    let ty = self.infer_expr(&target)?;
640                    if let Some(method) = self.get_value(&method) {
641                        let method = method.as_str();
642                        if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &ty
643                            && let Some(ret_ty) = self.infer_list_method(&target, elem_ty, method, params)?
644                        {
645                            return Ok(ret_ty);
646                        }
647                        let fn_ty = match self.get_field(&ty, method) {
648                            Ok((_, fn_ty)) => fn_ty,
649                            Err(_) => {
650                                let id = self.symbols.get_id(method)?;
651                                if self.symbols.get_symbol(id)?.1.is_fn() {
652                                    Type::Symbol { id, params: Vec::new() }
653                                } else {
654                                    return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
655                                }
656                            }
657                        };
658                        if let Type::Symbol { id, .. } = fn_ty {
659                            let mut args = vec![ty];
660                            for p in params {
661                                args.push(self.infer_expr(p)?);
662                            }
663                            self.infer_fn(id, &args)
664                        } else {
665                            Ok(fn_ty)
666                        }
667                    } else {
668                        Ok(Type::Any)
669                    }
670                } else if let ExprKind::Var(idx) = &obj.kind {
671                    let idx = self.top() + (*idx as usize);
672                    if idx < self.tys.len()
673                        && let Type::Symbol { id, .. } = self.tys[idx]
674                    {
675                        let mut args = Vec::new();
676                        for p in params {
677                            args.push(self.infer_expr(p)?);
678                        }
679                        self.infer_fn(id, &args)
680                    } else {
681                        Ok(Type::Any)
682                    }
683                } else if obj.is_value() {
684                    Ok(Type::Void)
685                } else {
686                    Ok(Type::Any)
687                }
688            }
689            ExprKind::Typed { ty, .. } => self.symbols.get_type(ty),
690            ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
691            ExprKind::Repeat { value, len } => {
692                let value_ty = self.infer_expr(value)?;
693                let len = self.symbols.get_type(len).unwrap_or_else(|_| len.clone());
694                if let Type::ConstInt(len) = len {
695                    let len = u32::try_from(len).map_err(|_| Self::semantic_error(expr.span, "重复数组长度必须是非负 u32"))?;
696                    Ok(Type::Array(std::rc::Rc::new(value_ty), len))
697                } else {
698                    Ok(Type::ArrayParam(std::rc::Rc::new(value_ty), std::rc::Rc::new(len)))
699                }
700            }
701            ExprKind::List(items) => {
702                if items.is_empty() {
703                    return Ok(Type::list_any());
704                }
705                let mut elem_ty = Type::Any;
706                for item in items {
707                    let item_ty = self.infer_expr(item)?;
708                    elem_ty = if elem_ty.is_any() { item_ty } else { elem_ty + item_ty };
709                }
710                Ok(Type::Array(std::rc::Rc::new(elem_ty), items.len() as u32))
711            }
712            ExprKind::Range { start, stop, .. } => {
713                let start_ty = self.infer_expr(start)?;
714                let stop_ty = self.infer_expr(stop)?;
715                Ok(Self::merge_range_bound_types(start_ty, stop_ty))
716            }
717            _ => Ok(Type::Any),
718        }
719    }
720
721    fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
722        let mut fn_tys = Vec::new();
723        for (i, ty) in tys.iter().enumerate() {
724            if !ty.is_any() {
725                fn_tys.push(ty.clone());
726            } else if let Some(arg_ty) = arg_tys.get(i) {
727                fn_tys.push(self.symbols.get_type(arg_ty)?);
728            } else {
729                fn_tys.push(Type::Any);
730            }
731        }
732        Ok(fn_tys)
733    }
734
735    fn is_optimizable_local_ty(ty: &Type) -> bool {
736        ty.is_bool() || ty.is_native()
737    }
738
739    fn is_optimizable_list_elem_ty(ty: &Type) -> bool {
740        matches!(ty, Type::Bool | Type::U8 | Type::I8 | Type::U16 | Type::I16 | Type::U32 | Type::I32 | Type::F32 | Type::U64 | Type::I64 | Type::F64 | Type::Str)
741    }
742
743    fn local_type_hint_at(&self, pos: usize) -> Option<Type> {
744        let ty = self.tys.get(pos)?;
745        match ty {
746            Type::List(_) => self.list_elem_states.get(pos).cloned().flatten().and_then(|state| {
747                if let ListElemState::Known(elem_ty) = state
748                    && Self::is_optimizable_list_elem_ty(&elem_ty)
749                {
750                    Some(Type::List(std::rc::Rc::new(elem_ty)))
751                } else {
752                    None
753                }
754            }),
755            ty if Self::is_optimizable_local_ty(ty) => Some(ty.clone()),
756            _ => None,
757        }
758    }
759
760    fn collect_local_type_hints(&self) -> Vec<Option<Type>> {
761        (self.top()..self.tys.len()).map(|pos| self.local_type_hint_at(pos)).collect()
762    }
763
764    fn set_local_type_hints(&mut self, id: u32, generic_args: &[Type], fn_tys: &[Type], hints: Vec<Option<Type>>) {
765        let items = self.local_type_hints.entry(id).or_default();
766        if let Some(item) = items.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys) {
767            item.2 = hints;
768        } else {
769            items.push((generic_args.to_vec(), fn_tys.to_vec(), hints));
770        }
771    }
772
773    pub fn inferred_local_type_hints(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Vec<Option<Type>> {
774        self.local_type_hints.get(&id).and_then(|items| items.iter().find(|item| item.0 == generic_args && item.1 == fn_tys)).map(|item| item.2.clone()).unwrap_or_default()
775    }
776
777    pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
778        self.infer_fn_with_params(id, arg_tys, &[])
779    }
780
781    pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
782        // 病态(互)递归泛型推断会不断产生新实例化、绕过记忆化;到达深度上限即回退 Any,
783        // 避免推断阶段栈溢出崩溃。
784        if self.infer_stack.len() > MAX_INFER_DEPTH {
785            return Ok(Type::Any);
786        }
787        let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
788        if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
789            if let Type::Fn { tys, ret: _ } = ty {
790                let resolved_generic_args = crate::resolve_generic_args_from_types(&generic_params, &tys, arg_tys, generic_args)?;
791                let generic_args = resolved_generic_args.as_slice();
792                let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
793                let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
794                let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
795                let body = if generic_params.is_empty() {
796                    body
797                } else {
798                    let mut compile_tys = tys.clone();
799                    let mut compile_cap = cap.clone();
800                    let saved_state = self.take_local_state();
801                    if let Some((module, _)) = name.split_once("::") {
802                        self.symbols.push_module_scope(module.into());
803                    }
804                    let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
805                    if name.contains("::") {
806                        self.symbols.pop_module_scope();
807                    }
808                    self.restore_local_state(saved_state);
809                    Stmt::new(StmtKind::Block(compiled?), Span::default())
810                };
811                if let Some(fns) = self.fns.get_mut(&id) {
812                    for f in fns.iter() {
813                        if f.0 == generic_args && f.1 == fn_tys {
814                            return match &f.2 {
815                                FnInferRet::Done(ret_ty) => self.symbols.get_type(ret_ty),
816                                FnInferRet::Pending(seed) => seed.as_ref().map(|ty| self.symbols.get_type(ty)).unwrap_or_else(|| {
817                                    // 递归自调用且种子为空:尝试从函数体 base case 查找返回类型
818                                    if self.infer_stack.iter().any(|(sid, sargs, _)| *sid == id && sargs == generic_args) {
819                                        if let Some(base_ty) = self.try_find_base_return_ty(&body) {
820                                            return self.symbols.get_type(&base_ty);
821                                        }
822                                    }
823                                    Ok(Type::Any)
824                                }),
825                            };
826                        }
827                    }
828                    fns.push((generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None)));
829                } else {
830                    self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None))]);
831                }
832                // 递归函数:预扫描 base case 返回类型作为种子
833                if self.pending_return_seed(id, generic_args, &fn_tys).is_none() {
834                    if let Some(base_ty) = self.try_find_base_return_ty_with_scope(&body, id, &name, &args, &fn_tys) {
835                        if let Some(fns) = self.fns.get_mut(&id) {
836                            if let Some(item) = fns.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys)
837                                && let FnInferRet::Pending(seed) = &mut item.2
838                                && seed.is_none()
839                            {
840                                *seed = Some(base_ty);
841                            }
842                        }
843                    }
844                }
845                let mut ret_ty = None;
846                let mut local_type_hints = Vec::new();
847                for _ in 0..4 {
848                    let before_seed = self.pending_return_seed(id, generic_args, &fn_tys);
849                    let saved_state = self.take_local_state();
850                    self.frames.push(0);
851                    for (arg, ty) in args.iter().zip(fn_tys.iter()) {
852                        self.add_name(arg.clone());
853                        self.add_ty(ty.clone());
854                    }
855                    for c in cap.vars.iter() {
856                        if let Some((name, ty)) = cap.names.get(*c) {
857                            self.add_name(name.clone());
858                            self.add_ty(ty.clone());
859                        } else {
860                            self.add_name("".into());
861                            self.add_ty(Type::Any);
862                        }
863                    }
864                    self.infer_stack.push((id, generic_args.to_vec(), fn_tys.clone()));
865                    let pass_ret_ty = self.infer_return_type(&body).map(|ty| ty.unwrap_or(Type::Void));
866                    self.infer_stack.pop();
867                    let pass_local_type_hints = self.collect_local_type_hints();
868                    self.restore_local_state(saved_state);
869                    let pass_ret_ty = match pass_ret_ty {
870                        Ok(pass_ret_ty) => self.symbols.get_type(&pass_ret_ty).unwrap_or(pass_ret_ty),
871                        Err(err) => {
872                            log::error!("infer_fn {} failed: {:?}", name, err);
873                            let should_remove = self
874                                .fns
875                                .get_mut(&id)
876                                .map(|fns| {
877                                    fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || !matches!(item.2, FnInferRet::Pending(_)));
878                                    fns.is_empty()
879                                })
880                                .unwrap_or(false);
881                            if should_remove {
882                                self.fns.remove(&id);
883                            }
884                            return Err(err);
885                        }
886                    };
887                    if !pass_ret_ty.is_any() {
888                        self.update_pending_return_seed(&pass_ret_ty);
889                        ret_ty = Some(pass_ret_ty.clone());
890                    } else if ret_ty.is_none() {
891                        ret_ty = Some(pass_ret_ty);
892                    }
893                    local_type_hints = pass_local_type_hints;
894                    let after_seed = self.pending_return_seed(id, generic_args, &fn_tys);
895                    if before_seed == after_seed {
896                        break;
897                    }
898                }
899                let ret_ty = ret_ty.unwrap_or(Type::Any);
900                self.fns.get_mut(&id).map(|f| {
901                    f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = FnInferRet::Done(ret_ty.clone()));
902                });
903                self.set_local_type_hints(id, generic_args, &fn_tys, local_type_hints);
904                if generic_args.is_empty()
905                    && let Some((_, Symbol::Fn { ty: Type::Fn { ret, .. }, .. })) = self.symbols.get_symbol_mut(id)
906                    && ret.is_any()
907                {
908                    *ret = std::rc::Rc::new(ret_ty.clone());
909                }
910                Ok(ret_ty)
911            } else {
912                Ok(Type::Any)
913            }
914        } else if let Symbol::Native(f) = s {
915            if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
916        } else if matches!(s, Symbol::Null) {
917            Ok(Type::Any)
918        } else {
919            Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
920        }
921    }
922
923    pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
924        match &stmt.kind {
925            StmtKind::Expr(expr, close) => {
926                if !close {
927                    self.infer_expr(expr)
928                } else {
929                    self.infer_expr(expr)?;
930                    Ok(Type::Void)
931                }
932            }
933            StmtKind::Return(expr) => {
934                if let Some(e) = expr {
935                    self.infer_expr(e)
936                } else {
937                    Ok(Type::Void)
938                }
939            }
940            StmtKind::Block(stmts) => {
941                for (idx, stmt) in stmts.iter().enumerate() {
942                    let ty = self.infer_stmt(stmt)?;
943                    if stmt.is_return() || idx == stmts.len() - 1 {
944                        return Ok(ty);
945                    }
946                }
947                Ok(Type::Void)
948            }
949            StmtKind::If { then_body, else_body, .. } => {
950                let then_ty = self.infer_stmt(then_body)?;
951                if let Some(e) = else_body {
952                    let else_ty = self.infer_stmt(e)?;
953                    if then_ty != else_ty {
954                        log::debug!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
955                        return Self::merge_return_type(stmt.span, Some(then_ty), else_ty);
956                    }
957                }
958                if else_body.is_none() {
959                    return Ok(Type::Void);
960                }
961                Ok(then_ty)
962            }
963            StmtKind::While { cond, body } => {
964                let cond_ty = self.infer_expr(cond)?;
965                if cond_ty != Type::Bool {
966                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
967                }
968                self.infer_stmt(body)
969            }
970            StmtKind::For { pat, range, body } => {
971                let ty = self.for_pattern_ty(range)?;
972                self.add_pattern_bindings_for_infer(pat, ty)?;
973                self.infer_stmt(body)
974            }
975            StmtKind::Let { pat, value } => {
976                let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
977                self.add_pattern_bindings_for_infer(pat, expr_ty)?;
978                Ok(Type::Void)
979            }
980            _ => Ok(Type::Void),
981        }
982    }
983}