1use super::{Compiler, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, PatternKind, Span, Stmt, StmtKind};
5
6impl Compiler {
7 fn merge_return_type(left: Option<Type>, right: Type) -> Type {
8 match left {
9 Some(left) if left == right => left,
10 Some(left) => left + right,
11 None => right,
12 }
13 }
14
15 fn infer_return_type(&mut self, stmt: &Stmt) -> Result<Option<Type>> {
16 match &stmt.kind {
17 StmtKind::Return(Some(expr)) => Ok(Some(self.infer_expr(expr)?)),
18 StmtKind::Return(None) => Ok(Some(Type::Void)),
19 StmtKind::Block(stmts) => {
20 let mut ret = None;
21 for stmt in stmts {
22 if let Some(ty) = self.infer_return_type(stmt)? {
23 ret = Some(Self::merge_return_type(ret, ty));
24 }
25 }
26 Ok(ret)
27 }
28 StmtKind::If { cond, then_body, else_body } => {
29 let cond_ty = self.infer_expr(cond)?;
30 if cond_ty != Type::Bool {
31 return Err(Self::semantic_error(cond.span, "条件表达式必须是布尔类型"));
32 }
33 let mut ret = self.infer_return_type(then_body)?;
34 if let Some(body) = else_body {
35 if let Some(ty) = self.infer_return_type(body)? {
36 ret = Some(Self::merge_return_type(ret, ty));
37 }
38 }
39 Ok(ret)
40 }
41 StmtKind::While { cond, body } => {
42 let cond_ty = self.infer_expr(cond)?;
43 if cond_ty != Type::Bool {
44 return Err(Self::semantic_error(cond.span, "条件表达式必须是布尔类型"));
45 }
46 self.infer_return_type(body)
47 }
48 StmtKind::Loop(body) => self.infer_return_type(body),
49 StmtKind::For { pat, range, body } => {
50 if let PatternKind::Var { idx, .. } = &pat.kind {
51 let ty = self.infer_expr(range)?;
52 self.set_ty(*idx, ty);
53 } else if let PatternKind::Tuple(pats) = &pat.kind {
54 let ty = self.infer_expr(range)?;
55 assert!(ty.is_any());
56 for pat in pats {
57 if let Some(idx) = pat.var() {
58 self.set_ty(idx, Type::Any);
59 }
60 }
61 }
62 self.infer_return_type(body)
63 }
64 StmtKind::Let { .. } => {
65 self.infer_stmt(stmt)?;
66 Ok(None)
67 }
68 StmtKind::Expr(expr, close) => {
69 let ty = self.infer_expr(expr)?;
70 Ok(if *close { None } else { Some(ty) })
71 }
72 _ => {
73 self.infer_stmt(stmt)?;
74 Ok(None)
75 }
76 }
77 }
78
    /// Infers the static type of `expr`.
    ///
    /// Expression kinds without a dedicated arm are reported as `Type::Any`. Inference has
    /// side effects: an assignment whose left side is a plain variable records the inferred
    /// type through `set_ty`, and call expressions go through `infer_fn` /
    /// `infer_fn_with_params`.
    ///
    /// # Errors
    ///
    /// Returns a semantic error when an identifier resolves to a symbol kind that has no
    /// type here, when a literal struct field index is out of bounds, or when the method
    /// name of an indexed call resolves to a non-function symbol. Errors from symbol-table
    /// lookups and nested inference are propagated.
    pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
        match &expr.kind {
            // A null literal is typed as `Any`, not via `get_type()`.
            ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
            ExprKind::Value(v) => Ok(v.get_type()),
            ExprKind::Var(idx) => {
                // The stored index is offset by `self.top()` (presumably the base of the
                // current frame in `self.tys` — confirm). Slots beyond the table are `Any`.
                let idx = self.top() + (*idx as usize);
                if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
            }
            ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
                Symbol::Const { ty, .. } => Ok(ty.clone()),
                Symbol::Static { ty, .. } => Ok(ty.clone()),
                Symbol::Struct(ty, _) => Ok(ty.clone()),
                // A function is represented by a symbol reference with no generic parameters.
                Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
                Symbol::Native(ty) => Ok(ty.clone()),
                s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
            },
            ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
            // A unary expression is given the type of its operand.
            ExprKind::Unary { value, .. } => self.infer_expr(value.as_ref()),
            ExprKind::Binary { left, op, right } => {
                // Remember the target slot when this is an assignment to a plain variable,
                // so the inferred type can be recorded once it is known.
                let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
                let ty = if op.is_logic() {
                    // Only the left operand is inferred. `And` / `Or` with an `Any` left
                    // operand yield `Any`; every other logic operator yields `Bool`.
                    let left_ty = self.infer_expr(left)?;
                    if matches!(op, BinaryOp::And | BinaryOp::Or) && left_ty.is_any() { Type::Any } else { Type::Bool }
                } else if op == &BinaryOp::Idx {
                    let left_ty = self.infer_expr(left)?;
                    if let Type::Array(elem_ty, _) = left_ty {
                        (*elem_ty).clone()
                    } else if let Type::Vec(elem_ty, _) = left_ty {
                        (*elem_ty).clone()
                    } else {
                        // Not an array/vec: resolve the container type through the symbol
                        // table and inspect the index expression.
                        // NOTE: every `return` below leaves the function directly and so
                        // bypasses the `set_ty` call after this `if` chain.
                        let left_ty = self.symbols.get_type(&left_ty)?;
                        let right_ty = if right.is_value() || right.is_const() {
                            // The index is a literal or a pooled constant; fetch its value.
                            let right_value = if let ExprKind::Const(c) = &right.kind { self.consts[*c].clone() } else { right.clone().value()? };
                            if right_value.is_str() {
                                if left_ty.is_any() {
                                    return Ok(Type::Any);
                                }
                                // String key: a field lookup. A function-typed field yields
                                // its return type. A failed lookup is not an error here and
                                // falls through to the type of the key itself.
                                if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
                                    return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
                                }
                            } else if let Type::Struct { fields, .. } = &left_ty
                                && let Some(idx) = right_value.as_int()
                            {
                                // Integer literal into a struct: positional field access.
                                return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
                            }
                            right_value.get_type()
                        } else {
                            self.infer_expr(right)?
                        };
                        if right_ty.is_int() || right_ty.is_uint() {
                            if left_ty.is_any() {
                                return Ok(Type::Any);
                            }
                            // Integer index on a non-`Any` container: use the return type of
                            // its `get_idx` field (lookup failure is propagated).
                            let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
                            let fn_ty = self.symbols.get_type(&s)?;
                            return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
                        }
                        // Both remaining paths produce `Any`; the early return only differs
                        // in skipping the `set_ty` step below.
                        if left_ty.is_any() {
                            return Ok(Type::Any);
                        }
                        Type::Any
                    }
                } else {
                    // The right operand is inferred before the left one. Plain assignment
                    // takes the right type; other operators combine both with `Type`'s `+`.
                    let right_ty = self.infer_expr(right)?;
                    if op == &BinaryOp::Assign { right_ty } else { self.infer_expr(left)? + right_ty }
                };
                // Record the result type for the assigned variable, if any.
                assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
                Ok(ty)
            }
            ExprKind::Call { obj, params } => {
                if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
                    // Call through an associated identifier carrying explicit generic arguments.
                    let mut args = Vec::new();
                    for p in params {
                        args.push(self.infer_expr(p)?);
                    }
                    self.infer_fn_with_params(*id, &args, generic_args)
                } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
                    // Method qualified by a type: look up the symbol named `<base>::<name>`.
                    // The receiver's type becomes the first argument.
                    let base_name = match ty {
                        Type::Ident { name, .. } => name.clone(),
                        Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
                        _ => return Ok(Type::Any),
                    };
                    let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
                    let mut args = vec![self.infer_expr(target)?];
                    for p in params {
                        args.push(self.infer_expr(p)?);
                    }
                    self.infer_fn(id, &args)
                } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
                    // Call of a named symbol; an attached object expression, when present,
                    // is passed as the first argument.
                    let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
                    for p in params {
                        args.push(self.infer_expr(p)?);
                    }
                    self.infer_fn(*id, &args)
                } else if obj.is_idx() {
                    // Callee is an index expression `target[method]`.
                    // The `unwrap` relies on `is_idx()` implying `binary()` succeeds.
                    let (target, _, method) = obj.clone().binary().unwrap();
                    let ty = self.infer_expr(&target)?;
                    if let Some(method) = self.get_value(&method) {
                        let method = method.as_str();
                        // Prefer a field of the receiver type; otherwise fall back to a
                        // symbol with the same name, which must be a function.
                        let fn_ty = match self.get_field(&ty, method) {
                            Ok((_, fn_ty)) => fn_ty,
                            Err(_) => {
                                let id = self.symbols.get_id(method)?;
                                if self.symbols.get_symbol(id)?.1.is_fn() {
                                    Type::Symbol { id, params: Vec::new() }
                                } else {
                                    return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
                                }
                            }
                        };
                        if let Type::Symbol { id, .. } = fn_ty {
                            // The receiver's type is passed as the first argument.
                            let mut args = vec![ty];
                            for p in params {
                                args.push(self.infer_expr(p)?);
                            }
                            self.infer_fn(id, &args)
                        } else {
                            // Anything that is not a symbol reference is returned unchanged.
                            Ok(fn_ty)
                        }
                    } else {
                        Ok(Type::Any)
                    }
                } else if let ExprKind::Var(idx) = &obj.kind {
                    // Call through a local variable: only resolvable when the slot currently
                    // holds a symbol reference.
                    let idx = self.top() + (*idx as usize);
                    if idx < self.tys.len()
                        && let Type::Symbol { id, .. } = self.tys[idx]
                    {
                        let mut args = Vec::new();
                        for p in params {
                            args.push(self.infer_expr(p)?);
                        }
                        self.infer_fn(id, &args)
                    } else {
                        Ok(Type::Any)
                    }
                } else if obj.is_value() {
                    // NOTE(review): calling a literal value is typed as `Void` — confirm intent.
                    Ok(Type::Void)
                } else {
                    Ok(Type::Any)
                }
            }
            ExprKind::Typed { ty, .. } => Ok(ty.clone()),
            ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
            ExprKind::Range { start, stop, .. } => {
                // Both bounds are always inferred. The result is the stop type, except that
                // the start type is used when only the stop type is `Any`.
                let start_ty = self.infer_expr(start)?;
                let stop_ty = self.infer_expr(stop)?;
                Ok(if start_ty.is_any() {
                    stop_ty
                } else if stop_ty.is_any() {
                    start_ty
                } else {
                    stop_ty
                })
            }
            _ => Ok(Type::Any),
        }
    }
236
237 fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
238 let mut fn_tys = Vec::new();
239 for (i, ty) in tys.iter().enumerate() {
240 if !ty.is_any() {
241 fn_tys.push(ty.clone());
242 } else if let Some(arg_ty) = arg_tys.get(i) {
243 fn_tys.push(self.symbols.get_type(arg_ty)?);
244 } else {
245 fn_tys.push(Type::Any);
246 }
247 }
248 Ok(fn_tys)
249 }
250
251 pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
252 self.infer_fn_with_params(id, arg_tys, &[])
253 }
254
255 pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
256 let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
257 if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
258 if let Type::Fn { tys, ret: _ } = ty {
259 let inferred_generic_args = if generic_args.is_empty() { crate::infer_generic_args_from_types(&generic_params, &tys, arg_tys) } else { generic_args.to_vec() };
260 let generic_args = if generic_params.is_empty() { &[] } else { inferred_generic_args.as_slice() };
261 let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
262 let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
263 let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
264 let body = if generic_params.is_empty() {
265 body
266 } else {
267 let mut compile_tys = tys.clone();
268 let mut compile_cap = cap.clone();
269 let saved_state = self.take_local_state();
270 let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
271 self.restore_local_state(saved_state);
272 Stmt::new(StmtKind::Block(compiled?), Span::default())
273 };
274 if let Some(fns) = self.fns.get_mut(&id) {
275 for f in fns.iter() {
276 if f.0 == generic_args && f.1 == fn_tys {
277 return Ok(f.2.clone());
278 }
279 }
280 fns.push((generic_args.to_vec(), fn_tys.clone(), Type::Any));
281 } else {
282 self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), Type::Any)]);
283 }
284 let top = self.tys.len();
285 self.tys.append(&mut fn_tys.clone());
286 for c in cap.vars.iter() {
287 self.tys.push(self.tys[self.top() + *c].clone());
288 }
289 self.frames.push(top);
290 let ret_ty = self.infer_return_type(&body).map(|ty| ty.unwrap_or(Type::Void));
291 if let Some(top) = self.frames.pop() {
292 self.tys.truncate(top);
293 }
294 let ret_ty = match ret_ty {
295 Ok(ret_ty) => ret_ty,
296 Err(err) => {
297 log::error!("infer_fn {} failed: {:?}", name, err);
298 let should_remove = self
299 .fns
300 .get_mut(&id)
301 .map(|fns| {
302 fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || item.2 != Type::Any);
303 fns.is_empty()
304 })
305 .unwrap_or(false);
306 if should_remove {
307 self.fns.remove(&id);
308 }
309 return Err(err);
310 }
311 };
312 self.fns.get_mut(&id).map(|f| {
313 f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = ret_ty.clone());
314 });
315 Ok(ret_ty)
316 } else {
317 Ok(Type::Any)
318 }
319 } else if let Symbol::Native(f) = s {
320 if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
321 } else if matches!(s, Symbol::Null) {
322 Ok(Type::Any)
323 } else {
324 Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
325 }
326 }
327
328 pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
329 match &stmt.kind {
330 StmtKind::Expr(expr, close) => {
331 if !close {
332 self.infer_expr(expr)
333 } else {
334 self.infer_expr(expr)?;
335 Ok(Type::Void)
336 }
337 }
338 StmtKind::Return(expr) => {
339 if let Some(e) = expr {
340 self.infer_expr(e)
341 } else {
342 Ok(Type::Void)
343 }
344 }
345 StmtKind::Block(stmts) => {
346 for (idx, stmt) in stmts.iter().enumerate() {
347 let ty = self.infer_stmt(stmt)?;
348 if stmt.is_return() || idx == stmts.len() - 1 {
349 return Ok(ty);
350 }
351 }
352 Ok(Type::Void)
353 }
354 StmtKind::If { then_body, else_body, .. } => {
355 let then_ty = self.infer_stmt(then_body)?;
356 if let Some(e) = else_body {
357 let else_ty = self.infer_stmt(e)?;
358 if then_ty != else_ty {
359 log::info!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
360 return Ok(if then_ty.is_any() { else_ty } else { then_ty });
361 }
362 }
363 if else_body.is_none() {
364 return Ok(Type::Void);
365 }
366 Ok(then_ty)
367 }
368 StmtKind::While { cond, body } => {
369 let cond_ty = self.infer_expr(cond)?;
370 if cond_ty != Type::Bool {
371 return Err(Self::semantic_error(cond.span, "条件表达式必须是布尔类型"));
372 }
373 self.infer_stmt(body)
374 }
375 StmtKind::For { pat, range, body } => {
376 if let PatternKind::Var { idx, .. } = &pat.kind {
377 let ty = self.infer_expr(range)?;
378 self.set_ty(*idx, ty);
379 } else if let PatternKind::Tuple(pats) = &pat.kind {
380 let ty = self.infer_expr(range)?;
381 assert!(ty.is_any());
382 for pat in pats {
383 if let Some(idx) = pat.var() {
384 self.set_ty(idx, Type::Any);
385 }
386 }
387 }
388 self.infer_stmt(body)
389 }
390 StmtKind::Let { pat, value } => {
391 let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
392 if let PatternKind::Ident { ty, .. } = &pat.kind {
393 let annotated_ty = self.symbols.get_type(ty)?;
394 if annotated_ty.is_any() {
395 self.add_ty(expr_ty);
396 } else {
397 self.add_ty(annotated_ty);
398 }
399 } else if let PatternKind::Var { idx, .. } = &pat.kind {
400 self.set_ty(*idx, expr_ty);
401 } else if matches!(pat.kind, PatternKind::Wildcard) {
402 self.add_ty(expr_ty);
403 }
404 Ok(Type::Void)
405 }
406 _ => Ok(Type::Void),
407 }
408 }
409}