1use super::{Compiler, FnInferRet, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, Pattern, PatternKind, Span, Stmt, StmtKind, UnaryOp};
5
6#[derive(Clone)]
7struct ReturnInfo {
8 ty: Type,
9 shape: Option<Type>,
10}
11
12impl Compiler {
13 fn add_pattern_bindings_for_infer(&mut self, pat: &Pattern, expr_ty: Type) -> Result<()> {
14 match &pat.kind {
15 PatternKind::Ident { name, ty } => {
16 let annotated_ty = self.symbols.get_type(ty)?;
17 self.add_name(name.clone());
18 self.add_ty(if annotated_ty.is_any() { expr_ty } else { annotated_ty });
19 }
20 PatternKind::Var { idx, .. } => self.set_ty(*idx, expr_ty),
21 PatternKind::Tuple(pats) => {
22 if let Type::Tuple(tys) = expr_ty {
23 for (pat, ty) in pats.iter().zip(tys) {
24 self.add_pattern_bindings_for_infer(pat, ty)?;
25 }
26 } else {
27 for pat in pats {
28 self.add_pattern_bindings_for_infer(pat, Type::Any)?;
29 }
30 }
31 }
32 PatternKind::List { elems, .. } => {
33 for pat in elems {
34 self.add_pattern_bindings_for_infer(pat, Type::Any)?;
35 }
36 }
37 PatternKind::Wildcard => {
38 self.add_name("".into());
39 self.add_ty(expr_ty);
40 }
41 PatternKind::Literal(_) | PatternKind::Member(_, _) | PatternKind::Idx(_, _) => {}
42 }
43 Ok(())
44 }
45
46 fn for_pattern_ty(&mut self, range: &Expr) -> Result<Type> {
47 if matches!(range.kind, ExprKind::Range { .. }) {
48 return self.infer_expr(range);
49 }
50 Ok(match self.infer_expr(range)? {
51 Type::Array(elem_ty, _) | Type::Vec(elem_ty, _) => elem_ty.as_ref().clone(),
52 _ => Type::Any,
53 })
54 }
55
56 fn merge_return_type(span: Span, left: Option<Type>, right: Type) -> Result<Type> {
57 match left {
58 Some(left) if left == right => Ok(left),
59 Some(left) if left.is_void() || right.is_void() => Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left, right))),
60 Some(left) => Ok(left + right),
61 None => Ok(right),
62 }
63 }
64
65 fn return_shape(&self, expr: &Expr, ty: &Type) -> Option<Type> {
66 if !ty.is_any() {
67 return if ty.is_struct() { Some(ty.clone()) } else { None };
68 }
69 match &expr.kind {
70 ExprKind::List(_) | ExprKind::Tuple(_) => Some(Type::List),
71 ExprKind::Dict(_) => Some(Type::Map),
72 ExprKind::Value(value) => Self::dynamic_return_shape(value.get_type()),
73 ExprKind::Const(idx) => self.consts.get(*idx).and_then(|value| Self::dynamic_return_shape(value.get_type())),
74 ExprKind::Typed { ty, .. } => Some(ty.clone()),
75 _ => None,
76 }
77 }
78
79 fn dynamic_return_shape(ty: Type) -> Option<Type> {
80 match ty {
81 Type::Map => Some(Type::Map),
82 Type::List | Type::Array(_, _) => Some(Type::List),
83 _ => None,
84 }
85 }
86
87 fn infer_return_expr(&mut self, expr: &Expr) -> Result<ReturnInfo> {
88 let ty = self.infer_expr(expr)?;
89 let shape = self.return_shape(expr, &ty);
90 Ok(ReturnInfo { ty, shape })
91 }
92
93 fn merge_return_info(span: Span, left: Option<ReturnInfo>, right: ReturnInfo) -> Result<ReturnInfo> {
94 let Some(left) = left else {
95 return Ok(right);
96 };
97 if let (Some(left_shape), Some(right_shape)) = (&left.shape, &right.shape)
98 && left_shape != right_shape
99 {
100 return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, right_shape)));
101 }
102 if let Some(left_shape) = &left.shape
103 && left_shape.is_struct()
104 && right.ty.is_any()
105 && right.shape.is_none()
106 {
107 return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, Type::Any)));
108 }
109 if let Some(right_shape) = &right.shape
110 && right_shape.is_struct()
111 && left.ty.is_any()
112 && left.shape.is_none()
113 {
114 return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", Type::Any, right_shape)));
115 }
116 let ty = Self::merge_return_type(span, Some(left.ty), right.ty)?;
117 Ok(ReturnInfo { ty, shape: left.shape.or(right.shape) })
118 }
119
120 fn infer_return_type(&mut self, stmt: &Stmt) -> Result<Option<Type>> {
121 self.infer_returns(stmt, true).map(|(info, _)| info.map(|info| info.ty))
122 }
123
124 pub(crate) fn check_return_type(&mut self, stmt: &Stmt) -> Result<()> {
125 self.infer_returns(stmt, true).map(|_| ())
126 }
127
128 fn infer_returns(&mut self, stmt: &Stmt, tail: bool) -> Result<(Option<ReturnInfo>, bool)> {
129 match &stmt.kind {
130 StmtKind::Return(Some(expr)) => Ok((Some(self.infer_return_expr(expr)?), true)),
131 StmtKind::Return(None) => Ok((Some(ReturnInfo { ty: Type::Void, shape: Some(Type::Void) }), true)),
132 StmtKind::Block(stmts) => {
133 let mut ret = None;
134 for (idx, stmt) in stmts.iter().enumerate() {
135 let (info, always_returns) = self.infer_returns(stmt, tail && idx == stmts.len().saturating_sub(1))?;
136 if let Some(info) = info {
137 ret = Some(Self::merge_return_info(stmt.span, ret, info)?);
138 }
139 if always_returns {
140 return Ok((ret, true));
141 }
142 }
143 Ok((ret, false))
144 }
145 StmtKind::If { cond, then_body, else_body } => {
146 let cond_ty = self.infer_expr(cond)?;
147 if cond_ty != Type::Bool {
148 return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
149 }
150 let (mut ret, then_returns) = self.infer_returns(then_body, tail)?;
151 let else_returns = if let Some(body) = else_body {
152 let (else_ty, else_returns) = self.infer_returns(body, tail)?;
153 if let Some(info) = else_ty {
154 ret = Some(Self::merge_return_info(body.span, ret, info)?);
155 }
156 else_returns
157 } else {
158 false
159 };
160 Ok((ret, then_returns && else_returns))
161 }
162 StmtKind::While { cond, body } => {
163 let cond_ty = self.infer_expr(cond)?;
164 if cond_ty != Type::Bool {
165 return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
166 }
167 self.infer_returns(body, false).map(|(ty, _)| (ty, false))
168 }
169 StmtKind::Loop(body) => self.infer_returns(body, false),
170 StmtKind::For { pat, range, body } => {
171 let ty = self.for_pattern_ty(range)?;
172 self.add_pattern_bindings_for_infer(pat, ty)?;
173 self.infer_returns(body, false).map(|(ty, _)| (ty, false))
174 }
175 StmtKind::Let { .. } => {
176 self.infer_stmt(stmt)?;
177 Ok((None, false))
178 }
179 StmtKind::Expr(expr, close) => {
180 let info = self.infer_return_expr(expr)?;
181 Ok(if *close || !tail { (None, false) } else { (Some(info), true) })
182 }
183 _ => {
184 self.infer_stmt(stmt)?;
185 Ok((None, false))
186 }
187 }
188 }
189
190 pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
191 match &expr.kind {
192 ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
193 ExprKind::Value(v) if v.is_list() || v.is_map() => Ok(Type::Any),
194 ExprKind::Value(v) => Ok(v.get_type()),
195 ExprKind::Const(_) => Ok(Type::Any),
196 ExprKind::Var(idx) => {
197 let idx = self.top() + (*idx as usize);
198 if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
199 }
200 ExprKind::Ident(ident) => {
201 for idx in (self.top()..self.names.len()).rev() {
202 if self.names[idx].eq(ident) && idx < self.tys.len() {
203 return self.symbols.get_type(&self.tys[idx]);
204 }
205 }
206 let id = self.symbols.get_id(ident).map_err(|_| Self::semantic_error(expr.span, format!("未找到标识符 {}", ident)))?;
207 match self.symbols.get_symbol(id)?.1 {
208 Symbol::Const { ty, .. } => Ok(ty.clone()),
209 Symbol::Static { ty, .. } => Ok(ty.clone()),
210 Symbol::Struct(ty, _) => Ok(ty.clone()),
211 Symbol::Fn { .. } => Ok(Type::Symbol { id, params: Vec::new() }),
212 Symbol::Native(ty) => Ok(ty.clone()),
213 s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
214 }
215 }
216 ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
217 Symbol::Const { ty, .. } => Ok(ty.clone()),
218 Symbol::Static { ty, .. } => Ok(ty.clone()),
219 Symbol::Struct(ty, _) => Ok(ty.clone()),
220 Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
221 Symbol::Native(ty) => Ok(ty.clone()),
222 s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
223 },
224 ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
225 ExprKind::Unary { op, value } => match op {
226 UnaryOp::Not => {
227 let ty = self.infer_expr(value.as_ref())?;
228 if ty.is_int() || ty.is_uint() { Ok(ty) } else { Ok(Type::Bool) }
229 }
230 UnaryOp::Neg => self.infer_expr(value.as_ref()),
231 UnaryOp::Unknow => Ok(Type::Any),
232 },
233 ExprKind::Binary { left, op, right } => {
234 let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
235 let ty = if op.is_logic() {
236 Type::Bool
237 } else if op == &BinaryOp::Idx {
238 let left_ty = self.infer_expr(left)?;
239 if let Type::Array(elem_ty, _) = left_ty {
240 (*elem_ty).clone()
241 } else if let Type::Vec(elem_ty, _) = left_ty {
242 (*elem_ty).clone()
243 } else {
244 let left_ty = self.symbols.get_type(&left_ty)?;
245 let right_ty = if right.is_value() || right.is_const() {
246 let right_value = if let ExprKind::Const(c) = &right.kind { self.consts[*c].clone() } else { right.clone().value()? };
247 if right_value.is_str() {
248 if left_ty.is_any() {
249 return Ok(Type::Any);
250 }
251 if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
252 return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
253 }
254 } else if let Type::Struct { fields, .. } = &left_ty
255 && let Some(idx) = right_value.as_int()
256 {
257 return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
258 }
259 right_value.get_type()
260 } else {
261 self.infer_expr(right)?
262 };
263 if right_ty.is_int() || right_ty.is_uint() {
264 if left_ty.is_any() {
265 return Ok(Type::Any);
266 }
267 let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
268 let fn_ty = self.symbols.get_type(&s)?;
269 return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
270 }
271 if left_ty.is_any() {
272 return Ok(Type::Any);
273 }
274 Type::Any
275 }
276 } else {
277 let right_ty = self.infer_expr(right)?;
278 if op == &BinaryOp::Assign { right_ty } else { self.infer_expr(left)? + right_ty }
279 };
280 assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
281 Ok(ty)
282 }
283 ExprKind::Call { obj, params } => {
284 if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
285 let mut args = Vec::new();
286 for p in params {
287 args.push(self.infer_expr(p)?);
288 }
289 self.infer_fn_with_params(*id, &args, generic_args)
290 } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
291 let base_name = match ty {
292 Type::Ident { name, .. } => name.clone(),
293 Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
294 _ => return Ok(Type::Any),
295 };
296 let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
297 let mut args = vec![self.infer_expr(target)?];
298 for p in params {
299 args.push(self.infer_expr(p)?);
300 }
301 self.infer_fn(id, &args)
302 } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
303 let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
304 for p in params {
305 args.push(self.infer_expr(p)?);
306 }
307 self.infer_fn(*id, &args)
308 } else if let ExprKind::Ident(name) = &obj.kind {
309 for idx in (self.top()..self.names.len()).rev() {
310 if self.names[idx].eq(name) && idx < self.tys.len() {
311 return if let Type::Symbol { id, .. } = &self.tys[idx] {
312 let id = *id;
313 let mut args = Vec::new();
314 for p in params {
315 args.push(self.infer_expr(p)?);
316 }
317 self.infer_fn(id, &args)
318 } else {
319 Ok(Type::Any)
320 };
321 }
322 }
323 let Ok(id) = self.symbols.get_id(name) else {
324 return Ok(Type::Any);
325 };
326 if !self.symbols.get_symbol(id)?.1.is_fn() {
327 return Err(Self::semantic_error(obj.span, format!("符号 {} 不是函数", name)));
328 }
329 let mut args = Vec::new();
330 for p in params {
331 args.push(self.infer_expr(p)?);
332 }
333 self.infer_fn(id, &args)
334 } else if obj.is_idx() {
335 let (target, _, method) = obj.clone().binary().unwrap();
336 let ty = self.infer_expr(&target)?;
337 if let Some(method) = self.get_value(&method) {
338 let method = method.as_str();
339 let fn_ty = match self.get_field(&ty, method) {
340 Ok((_, fn_ty)) => fn_ty,
341 Err(_) => {
342 let id = self.symbols.get_id(method)?;
343 if self.symbols.get_symbol(id)?.1.is_fn() {
344 Type::Symbol { id, params: Vec::new() }
345 } else {
346 return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
347 }
348 }
349 };
350 if let Type::Symbol { id, .. } = fn_ty {
351 let mut args = vec![ty];
352 for p in params {
353 args.push(self.infer_expr(p)?);
354 }
355 self.infer_fn(id, &args)
356 } else {
357 Ok(fn_ty)
358 }
359 } else {
360 Ok(Type::Any)
361 }
362 } else if let ExprKind::Var(idx) = &obj.kind {
363 let idx = self.top() + (*idx as usize);
364 if idx < self.tys.len()
365 && let Type::Symbol { id, .. } = self.tys[idx]
366 {
367 let mut args = Vec::new();
368 for p in params {
369 args.push(self.infer_expr(p)?);
370 }
371 self.infer_fn(id, &args)
372 } else {
373 Ok(Type::Any)
374 }
375 } else if obj.is_value() {
376 Ok(Type::Void)
377 } else {
378 Ok(Type::Any)
379 }
380 }
381 ExprKind::Typed { ty, .. } => self.symbols.get_type(ty),
382 ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
383 ExprKind::Range { start, stop, .. } => {
384 let start_ty = self.infer_expr(start)?;
385 let stop_ty = self.infer_expr(stop)?;
386 Ok(if start_ty.is_any() {
387 stop_ty
388 } else if stop_ty.is_any() {
389 start_ty
390 } else {
391 stop_ty
392 })
393 }
394 _ => Ok(Type::Any),
395 }
396 }
397
398 fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
399 let mut fn_tys = Vec::new();
400 for (i, ty) in tys.iter().enumerate() {
401 if !ty.is_any() {
402 fn_tys.push(ty.clone());
403 } else if let Some(arg_ty) = arg_tys.get(i) {
404 fn_tys.push(self.symbols.get_type(arg_ty)?);
405 } else {
406 fn_tys.push(Type::Any);
407 }
408 }
409 Ok(fn_tys)
410 }
411
412 pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
413 self.infer_fn_with_params(id, arg_tys, &[])
414 }
415
416 pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
417 let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
418 if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
419 if let Type::Fn { tys, ret: _ } = ty {
420 let inferred_generic_args = if generic_args.is_empty() { crate::infer_generic_args_from_types(&generic_params, &tys, arg_tys) } else { generic_args.to_vec() };
421 let generic_args = if generic_params.is_empty() { &[] } else { inferred_generic_args.as_slice() };
422 let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
423 let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
424 let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
425 let body = if generic_params.is_empty() {
426 body
427 } else {
428 let mut compile_tys = tys.clone();
429 let mut compile_cap = cap.clone();
430 let saved_state = self.take_local_state();
431 let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
432 self.restore_local_state(saved_state);
433 Stmt::new(StmtKind::Block(compiled?), Span::default())
434 };
435 if let Some(fns) = self.fns.get_mut(&id) {
436 for f in fns.iter() {
437 if f.0 == generic_args && f.1 == fn_tys {
438 return match &f.2 {
439 FnInferRet::Done(ret_ty) => self.symbols.get_type(ret_ty),
440 FnInferRet::Pending => Ok(Type::Any),
441 };
442 }
443 }
444 fns.push((generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending));
445 } else {
446 self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending)]);
447 }
448 let saved_state = self.take_local_state();
449 self.frames.push(0);
450 for (arg, ty) in args.iter().zip(fn_tys.iter()) {
451 self.add_name(arg.clone());
452 self.add_ty(ty.clone());
453 }
454 for c in cap.vars.iter() {
455 if let Some((name, ty)) = cap.names.get(*c) {
456 self.add_name(name.clone());
457 self.add_ty(ty.clone());
458 } else {
459 self.add_name("".into());
460 self.add_ty(Type::Any);
461 }
462 }
463 let ret_ty = self.infer_return_type(&body).map(|ty| ty.unwrap_or(Type::Void));
464 self.restore_local_state(saved_state);
465 let ret_ty = match ret_ty {
466 Ok(ret_ty) => self.symbols.get_type(&ret_ty).unwrap_or(ret_ty),
467 Err(err) => {
468 log::error!("infer_fn {} failed: {:?}", name, err);
469 let should_remove = self
470 .fns
471 .get_mut(&id)
472 .map(|fns| {
473 fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || !matches!(item.2, FnInferRet::Pending));
474 fns.is_empty()
475 })
476 .unwrap_or(false);
477 if should_remove {
478 self.fns.remove(&id);
479 }
480 return Err(err);
481 }
482 };
483 self.fns.get_mut(&id).map(|f| {
484 f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = FnInferRet::Done(ret_ty.clone()));
485 });
486 if generic_args.is_empty()
487 && let Some((_, Symbol::Fn { ty: Type::Fn { ret, .. }, .. })) = self.symbols.get_symbol_mut(id)
488 && ret.is_any()
489 {
490 *ret = std::rc::Rc::new(ret_ty.clone());
491 }
492 Ok(ret_ty)
493 } else {
494 Ok(Type::Any)
495 }
496 } else if let Symbol::Native(f) = s {
497 if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
498 } else if matches!(s, Symbol::Null) {
499 Ok(Type::Any)
500 } else {
501 Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
502 }
503 }
504
505 pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
506 match &stmt.kind {
507 StmtKind::Expr(expr, close) => {
508 if !close {
509 self.infer_expr(expr)
510 } else {
511 self.infer_expr(expr)?;
512 Ok(Type::Void)
513 }
514 }
515 StmtKind::Return(expr) => {
516 if let Some(e) = expr {
517 self.infer_expr(e)
518 } else {
519 Ok(Type::Void)
520 }
521 }
522 StmtKind::Block(stmts) => {
523 for (idx, stmt) in stmts.iter().enumerate() {
524 let ty = self.infer_stmt(stmt)?;
525 if stmt.is_return() || idx == stmts.len() - 1 {
526 return Ok(ty);
527 }
528 }
529 Ok(Type::Void)
530 }
531 StmtKind::If { then_body, else_body, .. } => {
532 let then_ty = self.infer_stmt(then_body)?;
533 if let Some(e) = else_body {
534 let else_ty = self.infer_stmt(e)?;
535 if then_ty != else_ty {
536 log::info!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
537 return Ok(if then_ty.is_any() { else_ty } else { then_ty });
538 }
539 }
540 if else_body.is_none() {
541 return Ok(Type::Void);
542 }
543 Ok(then_ty)
544 }
545 StmtKind::While { cond, body } => {
546 let cond_ty = self.infer_expr(cond)?;
547 if cond_ty != Type::Bool {
548 return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
549 }
550 self.infer_stmt(body)
551 }
552 StmtKind::For { pat, range, body } => {
553 let ty = self.for_pattern_ty(range)?;
554 self.add_pattern_bindings_for_infer(pat, ty)?;
555 self.infer_stmt(body)
556 }
557 StmtKind::Let { pat, value } => {
558 let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
559 self.add_pattern_bindings_for_infer(pat, expr_ty)?;
560 Ok(Type::Void)
561 }
562 _ => Ok(Type::Void),
563 }
564 }
565}