1use super::{Compiler, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, PatternKind, Span, Stmt, StmtKind, UnaryOp};
5
6impl Compiler {
7 fn merge_return_type(span: Span, left: Option<Type>, right: Type) -> Result<Type> {
8 match left {
9 Some(left) if left == right => Ok(left),
10 Some(left) if left.is_void() || right.is_void() => Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left, right))),
11 Some(left) => Ok(left + right),
12 None => Ok(right),
13 }
14 }
15
16 fn infer_return_type(&mut self, stmt: &Stmt) -> Result<Option<Type>> {
17 self.infer_returns(stmt, true).map(|(ty, _)| ty)
18 }
19
20 fn infer_returns(&mut self, stmt: &Stmt, tail: bool) -> Result<(Option<Type>, bool)> {
21 match &stmt.kind {
22 StmtKind::Return(Some(expr)) => Ok((Some(self.infer_expr(expr)?), true)),
23 StmtKind::Return(None) => Ok((Some(Type::Void), true)),
24 StmtKind::Block(stmts) => {
25 let mut ret = None;
26 for (idx, stmt) in stmts.iter().enumerate() {
27 let (ty, always_returns) = self.infer_returns(stmt, tail && idx == stmts.len().saturating_sub(1))?;
28 if let Some(ty) = ty {
29 ret = Some(Self::merge_return_type(stmt.span, ret, ty)?);
30 }
31 if always_returns {
32 return Ok((ret, true));
33 }
34 }
35 Ok((ret, false))
36 }
37 StmtKind::If { cond, then_body, else_body } => {
38 let cond_ty = self.infer_expr(cond)?;
39 if cond_ty != Type::Bool {
40 return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
41 }
42 let (mut ret, then_returns) = self.infer_returns(then_body, tail)?;
43 let else_returns = if let Some(body) = else_body {
44 let (else_ty, else_returns) = self.infer_returns(body, tail)?;
45 if let Some(ty) = else_ty {
46 ret = Some(Self::merge_return_type(body.span, ret, ty)?);
47 }
48 else_returns
49 } else {
50 false
51 };
52 Ok((ret, then_returns && else_returns))
53 }
54 StmtKind::While { cond, body } => {
55 let cond_ty = self.infer_expr(cond)?;
56 if cond_ty != Type::Bool {
57 return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
58 }
59 self.infer_returns(body, false).map(|(ty, _)| (ty, false))
60 }
61 StmtKind::Loop(body) => self.infer_returns(body, false),
62 StmtKind::For { pat, range, body } => {
63 if let PatternKind::Var { idx, .. } = &pat.kind {
64 let ty = self.infer_expr(range)?;
65 self.set_ty(*idx, ty);
66 } else if let PatternKind::Tuple(pats) = &pat.kind {
67 let ty = self.infer_expr(range)?;
68 assert!(ty.is_any());
69 for pat in pats {
70 if let Some(idx) = pat.var() {
71 self.set_ty(idx, Type::Any);
72 }
73 }
74 }
75 self.infer_returns(body, false).map(|(ty, _)| (ty, false))
76 }
77 StmtKind::Let { .. } => {
78 self.infer_stmt(stmt)?;
79 Ok((None, false))
80 }
81 StmtKind::Expr(expr, close) => {
82 let ty = self.infer_expr(expr)?;
83 Ok(if *close || !tail { (None, false) } else { (Some(ty), true) })
84 }
85 _ => {
86 self.infer_stmt(stmt)?;
87 Ok((None, false))
88 }
89 }
90 }
91
92 pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
93 match &expr.kind {
94 ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
95 ExprKind::Value(v) if v.is_list() || v.is_map() => Ok(Type::Any),
96 ExprKind::Value(v) => Ok(v.get_type()),
97 ExprKind::Const(_) => Ok(Type::Any),
98 ExprKind::Var(idx) => {
99 let idx = self.top() + (*idx as usize);
100 if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
101 }
102 ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
103 Symbol::Const { ty, .. } => Ok(ty.clone()),
104 Symbol::Static { ty, .. } => Ok(ty.clone()),
105 Symbol::Struct(ty, _) => Ok(ty.clone()),
106 Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
107 Symbol::Native(ty) => Ok(ty.clone()),
108 s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
109 },
110 ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
111 ExprKind::Unary { op, value } => match op {
112 UnaryOp::Not => {
113 self.infer_expr(value.as_ref())?;
114 Ok(Type::Bool)
115 }
116 UnaryOp::Neg => self.infer_expr(value.as_ref()),
117 UnaryOp::Unknow => Ok(Type::Any),
118 },
119 ExprKind::Binary { left, op, right } => {
120 let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
121 let ty = if op.is_logic() {
122 let left_ty = self.infer_expr(left)?;
123 if matches!(op, BinaryOp::And | BinaryOp::Or) && left_ty.is_any() { Type::Any } else { Type::Bool }
124 } else if op == &BinaryOp::Idx {
125 let left_ty = self.infer_expr(left)?;
126 if let Type::Array(elem_ty, _) = left_ty {
127 (*elem_ty).clone()
128 } else if let Type::Vec(elem_ty, _) = left_ty {
129 (*elem_ty).clone()
130 } else {
131 let left_ty = self.symbols.get_type(&left_ty)?;
132 let right_ty = if right.is_value() || right.is_const() {
133 let right_value = if let ExprKind::Const(c) = &right.kind { self.consts[*c].clone() } else { right.clone().value()? };
134 if right_value.is_str() {
135 if left_ty.is_any() {
136 return Ok(Type::Any);
137 }
138 if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
139 return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
140 }
141 } else if let Type::Struct { fields, .. } = &left_ty
142 && let Some(idx) = right_value.as_int()
143 {
144 return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
145 }
146 right_value.get_type()
147 } else {
148 self.infer_expr(right)?
149 };
150 if right_ty.is_int() || right_ty.is_uint() {
151 if left_ty.is_any() {
152 return Ok(Type::Any);
153 }
154 let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
155 let fn_ty = self.symbols.get_type(&s)?;
156 return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
157 }
158 if left_ty.is_any() {
159 return Ok(Type::Any);
160 }
161 Type::Any
162 }
163 } else {
164 let right_ty = self.infer_expr(right)?;
165 if op == &BinaryOp::Assign { right_ty } else { self.infer_expr(left)? + right_ty }
166 };
167 assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
168 Ok(ty)
169 }
170 ExprKind::Call { obj, params } => {
171 if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
172 let mut args = Vec::new();
173 for p in params {
174 args.push(self.infer_expr(p)?);
175 }
176 self.infer_fn_with_params(*id, &args, generic_args)
177 } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
178 let base_name = match ty {
179 Type::Ident { name, .. } => name.clone(),
180 Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
181 _ => return Ok(Type::Any),
182 };
183 let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
184 let mut args = vec![self.infer_expr(target)?];
185 for p in params {
186 args.push(self.infer_expr(p)?);
187 }
188 self.infer_fn(id, &args)
189 } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
190 let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
191 for p in params {
192 args.push(self.infer_expr(p)?);
193 }
194 self.infer_fn(*id, &args)
195 } else if obj.is_idx() {
196 let (target, _, method) = obj.clone().binary().unwrap();
197 let ty = self.infer_expr(&target)?;
198 if let Some(method) = self.get_value(&method) {
199 let method = method.as_str();
200 let fn_ty = match self.get_field(&ty, method) {
201 Ok((_, fn_ty)) => fn_ty,
202 Err(_) => {
203 let id = self.symbols.get_id(method)?;
204 if self.symbols.get_symbol(id)?.1.is_fn() {
205 Type::Symbol { id, params: Vec::new() }
206 } else {
207 return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
208 }
209 }
210 };
211 if let Type::Symbol { id, .. } = fn_ty {
212 let mut args = vec![ty];
213 for p in params {
214 args.push(self.infer_expr(p)?);
215 }
216 self.infer_fn(id, &args)
217 } else {
218 Ok(fn_ty)
219 }
220 } else {
221 Ok(Type::Any)
222 }
223 } else if let ExprKind::Var(idx) = &obj.kind {
224 let idx = self.top() + (*idx as usize);
225 if idx < self.tys.len()
226 && let Type::Symbol { id, .. } = self.tys[idx]
227 {
228 let mut args = Vec::new();
229 for p in params {
230 args.push(self.infer_expr(p)?);
231 }
232 self.infer_fn(id, &args)
233 } else {
234 Ok(Type::Any)
235 }
236 } else if obj.is_value() {
237 Ok(Type::Void)
238 } else {
239 Ok(Type::Any)
240 }
241 }
242 ExprKind::Typed { ty, .. } => self.symbols.get_type(ty),
243 ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
244 ExprKind::Range { start, stop, .. } => {
245 let start_ty = self.infer_expr(start)?;
246 let stop_ty = self.infer_expr(stop)?;
247 Ok(if start_ty.is_any() {
248 stop_ty
249 } else if stop_ty.is_any() {
250 start_ty
251 } else {
252 stop_ty
253 })
254 }
255 _ => Ok(Type::Any),
256 }
257 }
258
259 fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
260 let mut fn_tys = Vec::new();
261 for (i, ty) in tys.iter().enumerate() {
262 if !ty.is_any() {
263 fn_tys.push(ty.clone());
264 } else if let Some(arg_ty) = arg_tys.get(i) {
265 fn_tys.push(self.symbols.get_type(arg_ty)?);
266 } else {
267 fn_tys.push(Type::Any);
268 }
269 }
270 Ok(fn_tys)
271 }
272
273 pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
274 self.infer_fn_with_params(id, arg_tys, &[])
275 }
276
277 pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
278 let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
279 if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
280 if let Type::Fn { tys, ret: _ } = ty {
281 let inferred_generic_args = if generic_args.is_empty() { crate::infer_generic_args_from_types(&generic_params, &tys, arg_tys) } else { generic_args.to_vec() };
282 let generic_args = if generic_params.is_empty() { &[] } else { inferred_generic_args.as_slice() };
283 let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
284 let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
285 let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
286 let body = if generic_params.is_empty() {
287 body
288 } else {
289 let mut compile_tys = tys.clone();
290 let mut compile_cap = cap.clone();
291 let saved_state = self.take_local_state();
292 let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
293 self.restore_local_state(saved_state);
294 Stmt::new(StmtKind::Block(compiled?), Span::default())
295 };
296 if let Some(fns) = self.fns.get_mut(&id) {
297 for f in fns.iter() {
298 if f.0 == generic_args && f.1 == fn_tys {
299 return self.symbols.get_type(&f.2);
300 }
301 }
302 fns.push((generic_args.to_vec(), fn_tys.clone(), Type::Any));
303 } else {
304 self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), Type::Any)]);
305 }
306 let top = self.tys.len();
307 self.tys.append(&mut fn_tys.clone());
308 for c in cap.vars.iter() {
309 self.tys.push(self.tys[self.top() + *c].clone());
310 }
311 self.frames.push(top);
312 let ret_ty = self.infer_return_type(&body).map(|ty| ty.unwrap_or(Type::Void));
313 if let Some(top) = self.frames.pop() {
314 self.tys.truncate(top);
315 }
316 let ret_ty = match ret_ty {
317 Ok(ret_ty) => self.symbols.get_type(&ret_ty).unwrap_or(ret_ty),
318 Err(err) => {
319 log::error!("infer_fn {} failed: {:?}", name, err);
320 let should_remove = self
321 .fns
322 .get_mut(&id)
323 .map(|fns| {
324 fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || item.2 != Type::Any);
325 fns.is_empty()
326 })
327 .unwrap_or(false);
328 if should_remove {
329 self.fns.remove(&id);
330 }
331 return Err(err);
332 }
333 };
334 self.fns.get_mut(&id).map(|f| {
335 f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = ret_ty.clone());
336 });
337 if generic_args.is_empty()
338 && let Some((_, Symbol::Fn { ty: Type::Fn { ret, .. }, .. })) = self.symbols.get_symbol_mut(id)
339 && ret.is_any()
340 {
341 *ret = std::rc::Rc::new(ret_ty.clone());
342 }
343 Ok(ret_ty)
344 } else {
345 Ok(Type::Any)
346 }
347 } else if let Symbol::Native(f) = s {
348 if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
349 } else if matches!(s, Symbol::Null) {
350 Ok(Type::Any)
351 } else {
352 Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
353 }
354 }
355
356 pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
357 match &stmt.kind {
358 StmtKind::Expr(expr, close) => {
359 if !close {
360 self.infer_expr(expr)
361 } else {
362 self.infer_expr(expr)?;
363 Ok(Type::Void)
364 }
365 }
366 StmtKind::Return(expr) => {
367 if let Some(e) = expr {
368 self.infer_expr(e)
369 } else {
370 Ok(Type::Void)
371 }
372 }
373 StmtKind::Block(stmts) => {
374 for (idx, stmt) in stmts.iter().enumerate() {
375 let ty = self.infer_stmt(stmt)?;
376 if stmt.is_return() || idx == stmts.len() - 1 {
377 return Ok(ty);
378 }
379 }
380 Ok(Type::Void)
381 }
382 StmtKind::If { then_body, else_body, .. } => {
383 let then_ty = self.infer_stmt(then_body)?;
384 if let Some(e) = else_body {
385 let else_ty = self.infer_stmt(e)?;
386 if then_ty != else_ty {
387 log::info!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
388 return Ok(if then_ty.is_any() { else_ty } else { then_ty });
389 }
390 }
391 if else_body.is_none() {
392 return Ok(Type::Void);
393 }
394 Ok(then_ty)
395 }
396 StmtKind::While { cond, body } => {
397 let cond_ty = self.infer_expr(cond)?;
398 if cond_ty != Type::Bool {
399 return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
400 }
401 self.infer_stmt(body)
402 }
403 StmtKind::For { pat, range, body } => {
404 if let PatternKind::Var { idx, .. } = &pat.kind {
405 let ty = self.infer_expr(range)?;
406 self.set_ty(*idx, ty);
407 } else if let PatternKind::Tuple(pats) = &pat.kind {
408 let ty = self.infer_expr(range)?;
409 assert!(ty.is_any());
410 for pat in pats {
411 if let Some(idx) = pat.var() {
412 self.set_ty(idx, Type::Any);
413 }
414 }
415 }
416 self.infer_stmt(body)
417 }
418 StmtKind::Let { pat, value } => {
419 let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
420 if let PatternKind::Ident { ty, .. } = &pat.kind {
421 let annotated_ty = self.symbols.get_type(ty)?;
422 if annotated_ty.is_any() {
423 self.add_ty(expr_ty);
424 } else {
425 self.add_ty(annotated_ty);
426 }
427 } else if let PatternKind::Var { idx, .. } = &pat.kind {
428 self.set_ty(*idx, expr_ty);
429 } else if matches!(pat.kind, PatternKind::Wildcard) {
430 self.add_ty(expr_ty);
431 }
432 Ok(Type::Void)
433 }
434 _ => Ok(Type::Void),
435 }
436 }
437}