Skip to main content

compiler/
infer.rs

1use super::{Compiler, FnInferRet, ListElemState, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, Pattern, PatternKind, Span, Stmt, StmtKind, UnaryOp};
5
6#[derive(Clone)]
7struct ReturnInfo {
8    ty: Type,
9    shape: Option<Type>,
10}
11
12impl Compiler {
13    fn current_infer_key(&self) -> Option<(u32, Vec<Type>, Vec<Type>)> {
14        self.infer_stack.last().cloned()
15    }
16
17    fn pending_return_seed(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Option<Type> {
18        self.fns.get(&id).and_then(|fns| {
19            fns.iter().find_map(|item| {
20                if item.0 == generic_args
21                    && item.1 == fn_tys
22                    && let FnInferRet::Pending(seed) = &item.2
23                {
24                    seed.clone()
25                } else {
26                    None
27                }
28            })
29        })
30    }
31
32    fn update_pending_return_seed(&mut self, ty: &Type) {
33        if ty.is_any() {
34            return;
35        }
36        let Some((id, generic_args, fn_tys)) = self.current_infer_key() else {
37            return;
38        };
39        let Some(fns) = self.fns.get_mut(&id) else {
40            return;
41        };
42        if let Some(item) = fns.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys)
43            && let FnInferRet::Pending(seed) = &mut item.2
44        {
45            let next = seed.take().map(|prev| prev + ty.clone()).unwrap_or_else(|| ty.clone());
46            *seed = Some(next);
47        }
48    }
49
50    fn add_pattern_bindings_for_infer(&mut self, pat: &Pattern, expr_ty: Type) -> Result<()> {
51        match &pat.kind {
52            PatternKind::Ident { name, ty } => {
53                let annotated_ty = self.symbols.get_type(ty)?;
54                self.add_name(name.clone());
55                self.add_ty(if annotated_ty.is_any() { expr_ty } else { annotated_ty });
56            }
57            PatternKind::Var { idx, .. } => self.set_ty(*idx, expr_ty),
58            PatternKind::Tuple(pats) => {
59                if let Type::Tuple(tys) = expr_ty {
60                    for (pat, ty) in pats.iter().zip(tys) {
61                        self.add_pattern_bindings_for_infer(pat, ty)?;
62                    }
63                } else {
64                    for pat in pats {
65                        self.add_pattern_bindings_for_infer(pat, Type::Any)?;
66                    }
67                }
68            }
69            PatternKind::List { elems, .. } => {
70                for pat in elems {
71                    self.add_pattern_bindings_for_infer(pat, Type::Any)?;
72                }
73            }
74            PatternKind::Wildcard => {
75                self.add_name("".into());
76                self.add_ty(expr_ty);
77            }
78            PatternKind::Literal(_) | PatternKind::Member(_, _) | PatternKind::Idx(_, _) => {}
79        }
80        Ok(())
81    }
82
83    fn for_pattern_ty(&mut self, range: &Expr) -> Result<Type> {
84        if matches!(range.kind, ExprKind::Range { .. }) {
85            return self.infer_expr(range);
86        }
87        Ok(match self.infer_expr(range)? {
88            Type::Array(elem_ty, _) | Type::Vec(elem_ty, _) | Type::List(elem_ty) => elem_ty.as_ref().clone(),
89            _ => Type::Any,
90        })
91    }
92
93    fn merge_return_type(span: Span, left: Option<Type>, right: Type) -> Result<Type> {
94        match left {
95            Some(left) if left == right => Ok(left),
96            Some(left) if left.is_void() || right.is_void() => Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left, right))),
97            Some(left) => Ok(left + right),
98            None => Ok(right),
99        }
100    }
101
102    fn return_shape(&self, expr: &Expr, ty: &Type) -> Option<Type> {
103        if !ty.is_any() {
104            return match ty {
105                Type::Struct { .. } => Some(ty.clone()),
106                Type::Map => Some(Type::Map),
107                Type::List(elem) | Type::Array(elem, _) => Some(Type::List(elem.clone())),
108                _ => None,
109            };
110        }
111        match &expr.kind {
112            ExprKind::List(_) | ExprKind::Tuple(_) => Some(Type::list_any()),
113            ExprKind::Dict(_) => Some(Type::Map),
114            ExprKind::Value(value) => Self::dynamic_return_shape(value.get_type()),
115            ExprKind::Const(idx) => self.consts.get(*idx).and_then(|value| Self::dynamic_return_shape(value.get_type())),
116            ExprKind::Typed { ty, .. } => Some(ty.clone()),
117            _ => None,
118        }
119    }
120
121    fn dynamic_return_shape(ty: Type) -> Option<Type> {
122        match ty {
123            Type::Map => Some(Type::Map),
124            Type::List(elem) => Some(Type::List(elem)),
125            Type::Array(elem, _) => Some(Type::List(elem)),
126            _ => None,
127        }
128    }
129
130    fn local_var_idx_for_expr(&self, expr: &Expr) -> Option<u32> {
131        match &expr.kind {
132            ExprKind::Var(idx) => Some(*idx),
133            ExprKind::Ident(name) => (self.top()..self.names.len()).rev().find(|idx| self.names[*idx].eq(name)).map(|idx| (idx - self.top()) as u32),
134            _ => None,
135        }
136    }
137
138    fn infer_list_method(&mut self, target: &Expr, elem_ty: &Type, method: &str, params: &[Expr]) -> Result<Option<Type>> {
139        match method {
140            "get_idx" | "pop" => Ok(Some(match self.local_var_idx_for_expr(target).and_then(|idx| self.list_elem_state(idx)) {
141                Some(ListElemState::Known(ty)) => ty,
142                Some(ListElemState::Unknown | ListElemState::Mixed) => Type::Any,
143                None => elem_ty.clone(),
144            })),
145            "push" => {
146                let pushed_ty = params
147                    .first()
148                    .map(|param| {
149                        if let Some(value) = self.get_value(param)
150                            && (value.is_str() || value.is_native())
151                        {
152                            Ok(value.get_type())
153                        } else {
154                            self.infer_expr(param)
155                        }
156                    })
157                    .transpose()?
158                    .unwrap_or(Type::Any);
159                if let Some(idx) = self.local_var_idx_for_expr(target) {
160                    let state = self.list_elem_state(idx).unwrap_or_else(|| if elem_ty.is_any() { ListElemState::Unknown } else { ListElemState::Known(elem_ty.clone()) });
161                    let next_state = match state {
162                        ListElemState::Unknown if pushed_ty.is_any() => ListElemState::Mixed,
163                        ListElemState::Unknown => ListElemState::Known(pushed_ty),
164                        ListElemState::Known(_) if pushed_ty.is_any() => ListElemState::Mixed,
165                        ListElemState::Known(prev) => {
166                            let merged = if prev == pushed_ty {
167                                prev
168                            } else if (prev.is_int() || prev.is_uint() || prev.is_float()) && (pushed_ty.is_int() || pushed_ty.is_uint() || pushed_ty.is_float()) {
169                                prev + pushed_ty
170                            } else {
171                                Type::Any
172                            };
173                            if merged.is_any() { ListElemState::Mixed } else { ListElemState::Known(merged) }
174                        }
175                        ListElemState::Mixed => ListElemState::Mixed,
176                    };
177                    let next_elem = if let ListElemState::Known(ty) = &next_state { ty.clone() } else { Type::Any };
178                    self.set_ty(idx, Type::List(std::rc::Rc::new(next_elem)));
179                    self.set_list_elem_state(idx, Some(next_state));
180                }
181                Ok(Some(Type::Void))
182            }
183            "len" => Ok(Some(Type::I32)),
184            "is_list" | "is_null" => Ok(Some(Type::Bool)),
185            _ => Ok(None),
186        }
187    }
188
189    fn infer_return_expr(&mut self, expr: &Expr) -> Result<ReturnInfo> {
190        let ty = self.infer_expr(expr)?;
191        let shape = self.return_shape(expr, &ty);
192        let ty = if matches!(shape, Some(Type::Map | Type::List(_))) { Type::Any } else { ty };
193        Ok(ReturnInfo { ty, shape })
194    }
195
196    fn merge_return_info(span: Span, left: Option<ReturnInfo>, right: ReturnInfo) -> Result<ReturnInfo> {
197        let Some(left) = left else {
198            return Ok(right);
199        };
200        if let (Some(left_shape), Some(right_shape)) = (&left.shape, &right.shape)
201            && left_shape != right_shape
202        {
203            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, right_shape)));
204        }
205        if let Some(left_shape) = &left.shape
206            && left_shape.is_struct()
207            && right.ty.is_any()
208            && right.shape.is_none()
209        {
210            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, Type::Any)));
211        }
212        if let Some(right_shape) = &right.shape
213            && right_shape.is_struct()
214            && left.ty.is_any()
215            && left.shape.is_none()
216        {
217            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", Type::Any, right_shape)));
218        }
219        let ty = Self::merge_return_type(span, Some(left.ty), right.ty)?;
220        Ok(ReturnInfo { ty, shape: left.shape.or(right.shape) })
221    }
222
223    fn infer_return_type(&mut self, stmt: &Stmt) -> Result<Option<Type>> {
224        self.infer_returns(stmt, true).map(|(info, _)| info.map(|info| info.ty))
225    }
226
227    pub(crate) fn check_return_type(&mut self, stmt: &Stmt) -> Result<()> {
228        self.infer_returns(stmt, true).map(|_| ())
229    }
230
    /// Walks `stmt` and collects the return types it can produce.
    ///
    /// Returns `(info, always_returns)`:
    /// - `info` merges every `return` reachable in `stmt`, plus the tail expression when
    ///   `tail` is true and the expression is not `close`d.
    /// - `always_returns` is true when every path through `stmt` ends in such a return. A
    ///   block stops inspecting its remaining statements once one always returns.
    ///
    /// Every return found also refreshes the pending return seed of the current
    /// instantiation (see `update_pending_return_seed`).
    /// NOTE(review): presumably so a recursive call to the same instantiation can observe
    /// the partial result — confirm against `infer_fn_with_params`.
    ///
    /// # Errors
    /// Fails if an `if`/`while` condition does not infer to exactly `Type::Bool`, if two
    /// return sites conflict (`merge_return_info`), or if any nested inference fails.
    fn infer_returns(&mut self, stmt: &Stmt, tail: bool) -> Result<(Option<ReturnInfo>, bool)> {
        match &stmt.kind {
            StmtKind::Return(Some(expr)) => Ok((Some(self.infer_return_expr(expr)?), true)),
            // A bare `return` is recorded as `Void` with a `Void` shape.
            StmtKind::Return(None) => Ok((Some(ReturnInfo { ty: Type::Void, shape: Some(Type::Void) }), true)),
            StmtKind::Block(stmts) => {
                let mut ret = None;
                for (idx, stmt) in stmts.iter().enumerate() {
                    // Only the last statement of a tail block is itself in tail position.
                    let (info, always_returns) = self.infer_returns(stmt, tail && idx == stmts.len().saturating_sub(1))?;
                    if let Some(info) = info {
                        // Seed with the new site first, then with the merged result.
                        self.update_pending_return_seed(&info.ty);
                        ret = Some(Self::merge_return_info(stmt.span, ret, info)?);
                        if let Some(ret) = &ret {
                            self.update_pending_return_seed(&ret.ty);
                        }
                    }
                    // Statements after an unconditional return are not inferred.
                    if always_returns {
                        return Ok((ret, true));
                    }
                }
                Ok((ret, false))
            }
            StmtKind::If { cond, then_body, else_body } => {
                let cond_ty = self.infer_expr(cond)?;
                if cond_ty != Type::Bool {
                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
                }
                let (mut ret, then_returns) = self.infer_returns(then_body, tail)?;
                if let Some(ret) = &ret {
                    self.update_pending_return_seed(&ret.ty);
                }
                // Without an `else`, the `if` can always fall through.
                let else_returns = if let Some(body) = else_body {
                    let (else_ty, else_returns) = self.infer_returns(body, tail)?;
                    if let Some(info) = else_ty {
                        self.update_pending_return_seed(&info.ty);
                        ret = Some(Self::merge_return_info(body.span, ret, info)?);
                        if let Some(ret) = &ret {
                            self.update_pending_return_seed(&ret.ty);
                        }
                    }
                    else_returns
                } else {
                    false
                };
                Ok((ret, then_returns && else_returns))
            }
            StmtKind::While { cond, body } => {
                let cond_ty = self.infer_expr(cond)?;
                if cond_ty != Type::Bool {
                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
                }
                // A `while` body may run zero times, so it never counts as always returning.
                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
            }
            // A `loop` body's always-returns flag is propagated unchanged.
            StmtKind::Loop(body) => self.infer_returns(body, false),
            StmtKind::For { pat, range, body } => {
                // Bind the loop pattern before inferring the body so it is in scope.
                let ty = self.for_pattern_ty(range)?;
                self.add_pattern_bindings_for_infer(pat, ty)?;
                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
            }
            StmtKind::Let { .. } => {
                self.infer_stmt(stmt)?;
                Ok((None, false))
            }
            StmtKind::Expr(expr, close) => {
                // NOTE(review): `close` presumably marks a `;`-terminated expression, which
                // never acts as an implicit return.
                let info = self.infer_return_expr(expr)?;
                Ok(if *close || !tail { (None, false) } else { (Some(info), true) })
            }
            _ => {
                self.infer_stmt(stmt)?;
                Ok((None, false))
            }
        }
    }
303
304    pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
305        match &expr.kind {
306            ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
307            ExprKind::Value(v) if v.is_list() => Ok(v.get_type()),
308            ExprKind::Value(v) if v.is_map() => Ok(Type::Any),
309            ExprKind::Value(v) => Ok(v.get_type()),
310            ExprKind::Const(idx) => Ok(if self.consts.get(*idx).is_some_and(|value| value.is_list() && value.len() == 0) { Type::list_any() } else { Type::Any }),
311            ExprKind::Var(idx) => {
312                let idx = self.top() + (*idx as usize);
313                if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
314            }
315            ExprKind::Ident(ident) => {
316                for idx in (self.top()..self.names.len()).rev() {
317                    if self.names[idx].eq(ident) && idx < self.tys.len() {
318                        return self.symbols.get_type(&self.tys[idx]);
319                    }
320                }
321                let id = self.symbols.get_id(ident).map_err(|_| Self::semantic_error(expr.span, format!("未找到标识符 {}", ident)))?;
322                match self.symbols.get_symbol(id)?.1 {
323                    Symbol::Const { ty, .. } => Ok(ty.clone()),
324                    Symbol::Static { ty, .. } => Ok(ty.clone()),
325                    Symbol::Struct(ty, _) => Ok(ty.clone()),
326                    Symbol::Fn { .. } => Ok(Type::Symbol { id, params: Vec::new() }),
327                    Symbol::Native(ty) => Ok(ty.clone()),
328                    s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
329                }
330            }
331            ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
332                Symbol::Const { ty, .. } => Ok(ty.clone()),
333                Symbol::Static { ty, .. } => Ok(ty.clone()),
334                Symbol::Struct(ty, _) => Ok(ty.clone()),
335                Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
336                Symbol::Native(ty) => Ok(ty.clone()),
337                s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
338            },
339            ExprKind::Generic { obj, params } => {
340                let params = params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect();
341                match self.infer_expr(obj)? {
342                    Type::Symbol { id, .. } => Ok(Type::Symbol { id, params }),
343                    _ => Ok(Type::Any),
344                }
345            }
346            ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
347            ExprKind::Unary { op, value } => match op {
348                UnaryOp::Not => {
349                    let ty = self.infer_expr(value.as_ref())?;
350                    if ty.is_int() || ty.is_uint() { Ok(ty) } else { Ok(Type::Bool) }
351                }
352                UnaryOp::Neg => self.infer_expr(value.as_ref()),
353                UnaryOp::Unknow => Ok(Type::Any),
354            },
355            ExprKind::Binary { left, op, right } => {
356                let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
357                let ty = if op.is_logic() {
358                    Type::Bool
359                } else if op == &BinaryOp::Idx {
360                    let left_ty = self.infer_expr(left)?;
361                    if let Type::Array(elem_ty, _) = left_ty {
362                        (*elem_ty).clone()
363                    } else if let Type::Vec(elem_ty, _) = left_ty {
364                        (*elem_ty).clone()
365                    } else if let Type::List(elem_ty) = left_ty {
366                        (*elem_ty).clone()
367                    } else {
368                        let left_ty = self.symbols.get_type(&left_ty)?;
369                        let right_ty = if right.is_value() || right.is_const() {
370                            let right_value = if let ExprKind::Const(c) = &right.kind { self.consts[*c].clone() } else { right.clone().value()? };
371                            if right_value.is_str() {
372                                if left_ty.is_any() {
373                                    return Ok(Type::Any);
374                                }
375                                if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
376                                    return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
377                                }
378                            } else if let Type::Struct { fields, .. } = &left_ty
379                                && let Some(idx) = right_value.as_int()
380                            {
381                                return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
382                            }
383                            right_value.get_type()
384                        } else {
385                            self.infer_expr(right)?
386                        };
387                        if right_ty.is_int() || right_ty.is_uint() {
388                            if left_ty.is_any() {
389                                return Ok(Type::Any);
390                            }
391                            let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
392                            let fn_ty = self.symbols.get_type(&s)?;
393                            return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
394                        }
395                        if left_ty.is_any() {
396                            return Ok(Type::Any);
397                        }
398                        Type::Any
399                    }
400                } else {
401                    let left_ty = self.infer_expr(left)?;
402                    let right_ty = self.infer_expr(right)?;
403                    if op == &BinaryOp::Assign {
404                        if !left_ty.is_any() && right_ty.is_any() { left_ty } else { right_ty }
405                    } else if op.is_assign() && !left_ty.is_any() && right_ty.is_any() {
406                        left_ty
407                    } else {
408                        left_ty + right_ty
409                    }
410                };
411                assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
412                Ok(ty)
413            }
414            ExprKind::Call { obj, params } => {
415                if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
416                    let mut args = Vec::new();
417                    for p in params {
418                        args.push(self.infer_expr(p)?);
419                    }
420                    self.infer_fn_with_params(*id, &args, generic_args)
421                } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
422                    let base_name = match ty {
423                        Type::Ident { name, .. } => name.clone(),
424                        Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
425                        _ => return Ok(Type::Any),
426                    };
427                    let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
428                    let mut args = vec![self.infer_expr(target)?];
429                    for p in params {
430                        args.push(self.infer_expr(p)?);
431                    }
432                    self.infer_fn(id, &args)
433                } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
434                    let method = self.symbols.get_symbol(*id).ok().and_then(|(name, _)| name.rsplit_once("::").map(|(_, method)| method.to_string()));
435                    if let Some(target) = obj_expr
436                        && let Some(method) = method
437                    {
438                        let target_ty = self.infer_expr(target)?;
439                        if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &target_ty
440                            && let Some(ret_ty) = self.infer_list_method(target, elem_ty, method.as_str(), params)?
441                        {
442                            return Ok(ret_ty);
443                        }
444                    }
445                    let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
446                    for p in params {
447                        args.push(self.infer_expr(p)?);
448                    }
449                    self.infer_fn(*id, &args)
450                } else if let ExprKind::Ident(name) = &obj.kind {
451                    for idx in (self.top()..self.names.len()).rev() {
452                        if self.names[idx].eq(name) && idx < self.tys.len() {
453                            return if let Type::Symbol { id, .. } = &self.tys[idx] {
454                                let id = *id;
455                                let mut args = Vec::new();
456                                for p in params {
457                                    args.push(self.infer_expr(p)?);
458                                }
459                                self.infer_fn(id, &args)
460                            } else {
461                                Ok(Type::Any)
462                            };
463                        }
464                    }
465                    let Ok(id) = self.symbols.get_id(name) else {
466                        return Ok(Type::Any);
467                    };
468                    if !self.symbols.get_symbol(id)?.1.is_fn() {
469                        return Err(Self::semantic_error(obj.span, format!("符号 {} 不是函数", name)));
470                    }
471                    let mut args = Vec::new();
472                    for p in params {
473                        args.push(self.infer_expr(p)?);
474                    }
475                    self.infer_fn(id, &args)
476                } else if obj.is_idx() {
477                    let (target, _, method) = obj.clone().binary().unwrap();
478                    let ty = self.infer_expr(&target)?;
479                    if let Some(method) = self.get_value(&method) {
480                        let method = method.as_str();
481                        if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &ty
482                            && let Some(ret_ty) = self.infer_list_method(&target, elem_ty, method, params)?
483                        {
484                            return Ok(ret_ty);
485                        }
486                        let fn_ty = match self.get_field(&ty, method) {
487                            Ok((_, fn_ty)) => fn_ty,
488                            Err(_) => {
489                                let id = self.symbols.get_id(method)?;
490                                if self.symbols.get_symbol(id)?.1.is_fn() {
491                                    Type::Symbol { id, params: Vec::new() }
492                                } else {
493                                    return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
494                                }
495                            }
496                        };
497                        if let Type::Symbol { id, .. } = fn_ty {
498                            let mut args = vec![ty];
499                            for p in params {
500                                args.push(self.infer_expr(p)?);
501                            }
502                            self.infer_fn(id, &args)
503                        } else {
504                            Ok(fn_ty)
505                        }
506                    } else {
507                        Ok(Type::Any)
508                    }
509                } else if let ExprKind::Var(idx) = &obj.kind {
510                    let idx = self.top() + (*idx as usize);
511                    if idx < self.tys.len()
512                        && let Type::Symbol { id, .. } = self.tys[idx]
513                    {
514                        let mut args = Vec::new();
515                        for p in params {
516                            args.push(self.infer_expr(p)?);
517                        }
518                        self.infer_fn(id, &args)
519                    } else {
520                        Ok(Type::Any)
521                    }
522                } else if obj.is_value() {
523                    Ok(Type::Void)
524                } else {
525                    Ok(Type::Any)
526                }
527            }
528            ExprKind::Typed { ty, .. } => self.symbols.get_type(ty),
529            ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
530            ExprKind::Repeat { value, len } => {
531                let value_ty = self.infer_expr(value)?;
532                let len = self.symbols.get_type(len).unwrap_or_else(|_| len.clone());
533                if let Type::ConstInt(len) = len {
534                    let len = u32::try_from(len).map_err(|_| Self::semantic_error(expr.span, "重复数组长度必须是非负 u32"))?;
535                    Ok(Type::Array(std::rc::Rc::new(value_ty), len))
536                } else {
537                    Ok(Type::ArrayParam(std::rc::Rc::new(value_ty), std::rc::Rc::new(len)))
538                }
539            }
540            ExprKind::List(items) => {
541                if items.is_empty() {
542                    return Ok(Type::list_any());
543                }
544                let mut elem_ty = Type::Any;
545                for item in items {
546                    let item_ty = self.infer_expr(item)?;
547                    elem_ty = if elem_ty.is_any() { item_ty } else { elem_ty + item_ty };
548                }
549                Ok(Type::Array(std::rc::Rc::new(elem_ty), items.len() as u32))
550            }
551            ExprKind::Range { start, stop, .. } => {
552                let start_ty = self.infer_expr(start)?;
553                let stop_ty = self.infer_expr(stop)?;
554                Ok(if start_ty.is_any() {
555                    stop_ty
556                } else if stop_ty.is_any() {
557                    start_ty
558                } else {
559                    start_ty + stop_ty
560                })
561            }
562            _ => Ok(Type::Any),
563        }
564    }
565
566    fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
567        let mut fn_tys = Vec::new();
568        for (i, ty) in tys.iter().enumerate() {
569            if !ty.is_any() {
570                fn_tys.push(ty.clone());
571            } else if let Some(arg_ty) = arg_tys.get(i) {
572                fn_tys.push(self.symbols.get_type(arg_ty)?);
573            } else {
574                fn_tys.push(Type::Any);
575            }
576        }
577        Ok(fn_tys)
578    }
579
580    fn is_optimizable_local_ty(ty: &Type) -> bool {
581        ty.is_bool() || ty.is_native()
582    }
583
584    fn is_optimizable_list_elem_ty(ty: &Type) -> bool {
585        matches!(ty, Type::Bool | Type::U8 | Type::I8 | Type::U16 | Type::I16 | Type::U32 | Type::I32 | Type::F32 | Type::U64 | Type::I64 | Type::F64 | Type::Str)
586    }
587
588    fn local_type_hint_at(&self, pos: usize) -> Option<Type> {
589        let ty = self.tys.get(pos)?;
590        match ty {
591            Type::List(_) => self.list_elem_states.get(pos).cloned().flatten().and_then(|state| {
592                if let ListElemState::Known(elem_ty) = state
593                    && Self::is_optimizable_list_elem_ty(&elem_ty)
594                {
595                    Some(Type::List(std::rc::Rc::new(elem_ty)))
596                } else {
597                    None
598                }
599            }),
600            ty if Self::is_optimizable_local_ty(ty) => Some(ty.clone()),
601            _ => None,
602        }
603    }
604
605    fn collect_local_type_hints(&self) -> Vec<Option<Type>> {
606        (self.top()..self.tys.len()).map(|pos| self.local_type_hint_at(pos)).collect()
607    }
608
609    fn set_local_type_hints(&mut self, id: u32, generic_args: &[Type], fn_tys: &[Type], hints: Vec<Option<Type>>) {
610        let items = self.local_type_hints.entry(id).or_default();
611        if let Some(item) = items.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys) {
612            item.2 = hints;
613        } else {
614            items.push((generic_args.to_vec(), fn_tys.to_vec(), hints));
615        }
616    }
617
618    pub fn inferred_local_type_hints(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Vec<Option<Type>> {
619        self.local_type_hints.get(&id).and_then(|items| items.iter().find(|item| item.0 == generic_args && item.1 == fn_tys)).map(|item| item.2.clone()).unwrap_or_default()
620    }
621
622    pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
623        self.infer_fn_with_params(id, arg_tys, &[])
624    }
625
626    pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
627        let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
628        if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
629            if let Type::Fn { tys, ret: _ } = ty {
630                let resolved_generic_args = crate::resolve_generic_args_from_types(&generic_params, &tys, arg_tys, generic_args)?;
631                let generic_args = resolved_generic_args.as_slice();
632                let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
633                let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
634                let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
635                let body = if generic_params.is_empty() {
636                    body
637                } else {
638                    let mut compile_tys = tys.clone();
639                    let mut compile_cap = cap.clone();
640                    let saved_state = self.take_local_state();
641                    if let Some((module, _)) = name.split_once("::") {
642                        self.symbols.push_module_scope(module.into());
643                    }
644                    let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
645                    if name.contains("::") {
646                        self.symbols.pop_module_scope();
647                    }
648                    self.restore_local_state(saved_state);
649                    Stmt::new(StmtKind::Block(compiled?), Span::default())
650                };
651                if let Some(fns) = self.fns.get_mut(&id) {
652                    for f in fns.iter() {
653                        if f.0 == generic_args && f.1 == fn_tys {
654                            return match &f.2 {
655                                FnInferRet::Done(ret_ty) => self.symbols.get_type(ret_ty),
656                                FnInferRet::Pending(seed) => seed.as_ref().map(|ty| self.symbols.get_type(ty)).unwrap_or(Ok(Type::Any)),
657                            };
658                        }
659                    }
660                    fns.push((generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None)));
661                } else {
662                    self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None))]);
663                }
664                let mut ret_ty = None;
665                let mut local_type_hints = Vec::new();
666                for _ in 0..4 {
667                    let before_seed = self.pending_return_seed(id, generic_args, &fn_tys);
668                    let saved_state = self.take_local_state();
669                    self.frames.push(0);
670                    for (arg, ty) in args.iter().zip(fn_tys.iter()) {
671                        self.add_name(arg.clone());
672                        self.add_ty(ty.clone());
673                    }
674                    for c in cap.vars.iter() {
675                        if let Some((name, ty)) = cap.names.get(*c) {
676                            self.add_name(name.clone());
677                            self.add_ty(ty.clone());
678                        } else {
679                            self.add_name("".into());
680                            self.add_ty(Type::Any);
681                        }
682                    }
683                    self.infer_stack.push((id, generic_args.to_vec(), fn_tys.clone()));
684                    let pass_ret_ty = self.infer_return_type(&body).map(|ty| ty.unwrap_or(Type::Void));
685                    self.infer_stack.pop();
686                    let pass_local_type_hints = self.collect_local_type_hints();
687                    self.restore_local_state(saved_state);
688                    let pass_ret_ty = match pass_ret_ty {
689                        Ok(pass_ret_ty) => self.symbols.get_type(&pass_ret_ty).unwrap_or(pass_ret_ty),
690                        Err(err) => {
691                            log::error!("infer_fn {} failed: {:?}", name, err);
692                            let should_remove = self
693                                .fns
694                                .get_mut(&id)
695                                .map(|fns| {
696                                    fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || !matches!(item.2, FnInferRet::Pending(_)));
697                                    fns.is_empty()
698                                })
699                                .unwrap_or(false);
700                            if should_remove {
701                                self.fns.remove(&id);
702                            }
703                            return Err(err);
704                        }
705                    };
706                    if !pass_ret_ty.is_any() {
707                        self.update_pending_return_seed(&pass_ret_ty);
708                        ret_ty = Some(pass_ret_ty.clone());
709                    } else if ret_ty.is_none() {
710                        ret_ty = Some(pass_ret_ty);
711                    }
712                    local_type_hints = pass_local_type_hints;
713                    let after_seed = self.pending_return_seed(id, generic_args, &fn_tys);
714                    if before_seed == after_seed {
715                        break;
716                    }
717                }
718                let ret_ty = ret_ty.unwrap_or(Type::Any);
719                self.fns.get_mut(&id).map(|f| {
720                    f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = FnInferRet::Done(ret_ty.clone()));
721                });
722                self.set_local_type_hints(id, generic_args, &fn_tys, local_type_hints);
723                if generic_args.is_empty()
724                    && let Some((_, Symbol::Fn { ty: Type::Fn { ret, .. }, .. })) = self.symbols.get_symbol_mut(id)
725                    && ret.is_any()
726                {
727                    *ret = std::rc::Rc::new(ret_ty.clone());
728                }
729                Ok(ret_ty)
730            } else {
731                Ok(Type::Any)
732            }
733        } else if let Symbol::Native(f) = s {
734            if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
735        } else if matches!(s, Symbol::Null) {
736            Ok(Type::Any)
737        } else {
738            Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
739        }
740    }
741
742    pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
743        match &stmt.kind {
744            StmtKind::Expr(expr, close) => {
745                if !close {
746                    self.infer_expr(expr)
747                } else {
748                    self.infer_expr(expr)?;
749                    Ok(Type::Void)
750                }
751            }
752            StmtKind::Return(expr) => {
753                if let Some(e) = expr {
754                    self.infer_expr(e)
755                } else {
756                    Ok(Type::Void)
757                }
758            }
759            StmtKind::Block(stmts) => {
760                for (idx, stmt) in stmts.iter().enumerate() {
761                    let ty = self.infer_stmt(stmt)?;
762                    if stmt.is_return() || idx == stmts.len() - 1 {
763                        return Ok(ty);
764                    }
765                }
766                Ok(Type::Void)
767            }
768            StmtKind::If { then_body, else_body, .. } => {
769                let then_ty = self.infer_stmt(then_body)?;
770                if let Some(e) = else_body {
771                    let else_ty = self.infer_stmt(e)?;
772                    if then_ty != else_ty {
773                        log::info!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
774                        return Ok(if then_ty.is_any() { else_ty } else { then_ty });
775                    }
776                }
777                if else_body.is_none() {
778                    return Ok(Type::Void);
779                }
780                Ok(then_ty)
781            }
782            StmtKind::While { cond, body } => {
783                let cond_ty = self.infer_expr(cond)?;
784                if cond_ty != Type::Bool {
785                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
786                }
787                self.infer_stmt(body)
788            }
789            StmtKind::For { pat, range, body } => {
790                let ty = self.for_pattern_ty(range)?;
791                self.add_pattern_bindings_for_infer(pat, ty)?;
792                self.infer_stmt(body)
793            }
794            StmtKind::Let { pat, value } => {
795                let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
796                self.add_pattern_bindings_for_infer(pat, expr_ty)?;
797                Ok(Type::Void)
798            }
799            _ => Ok(Type::Void),
800        }
801    }
802}