1use sipha::error::SemanticDiagnostic;
4use sipha::red::{SyntaxElement, SyntaxNode};
5use sipha::types::Span;
6use sipha::walk::{Visitor, WalkResult};
7use std::collections::HashMap;
8
/// Key used by the checker's type maps: the `(start, end)` pair taken from a
/// node's (or name's) text range.
pub type TypeMapKey = (u32, u32);
16
/// A pending narrowing of one variable, derived from a null-check condition.
///
/// An entry is pushed when the condition of an `if` statement or ternary
/// expression is left, and is consulted while the children of `closing_node`
/// are visited.
struct NullCheckNarrowing {
    /// Name of the variable the condition tests.
    var_name: String,
    /// Type the variable is given while the "then" branch is visited.
    then_ty: Type,
    /// Type the variable is given while the "else" branch is visited.
    else_ty: Type,
    /// The `if` statement or ternary expression that owns the condition; the
    /// entry is discarded when this node is left.
    closing_node: SyntaxNode,
}
25
26fn node_first_span(node: &SyntaxNode) -> Span {
28 node.first_token()
29 .map_or_else(|| Span::new(0, 0), |t| t.text_range())
30}
31
32use leekscript_core::syntax::Kind;
33use leekscript_core::{CastType, Type};
34use sipha::types::IntoSyntaxKind;
35
36use super::error::{invalid_cast_at, type_mismatch_at, wrong_arity_at};
37use super::node_helpers::{
38 call_argument_count, call_argument_node, class_decl_info, for_in_iterable_expr,
39 for_in_loop_vars, function_decl_info, is_ternary_expr, member_expr_member_name,
40 node_index_in_parent, null_check_from_condition, primary_expr_new_constructor,
41 primary_expr_resolvable_name, var_decl_info, VarDeclKind,
42};
43use super::scope::{ResolvedSymbol, Scope, ScopeId, ScopeStore};
44use super::type_expr::{anon_fn_types, find_type_expr_child, parse_type_expr, TypeExprResult};
45
/// Visitor state for type checking a syntax tree against a pre-built
/// [`ScopeStore`].
pub struct TypeChecker<'a> {
    /// Scope information consulted for variable, function and class lookups.
    pub store: &'a ScopeStore,
    /// Root of the tree being checked; passed to ancestor queries.
    root: &'a SyntaxNode,
    /// Stack of active scope ids; starts with `ScopeId(0)` and never becomes empty.
    stack: Vec<ScopeId>,
    /// Id handed to the next scope pushed by `push_scope`.
    next_scope_id: usize,
    /// Variable types recorded during the walk, one map per entry of `stack`.
    var_types: Vec<HashMap<String, Type>>,
    /// Types of already-visited expressions, consumed by their parent nodes.
    type_stack: Vec<Type>,
    /// Name resolved by the most recent primary expression; taken as the callee
    /// name by call expressions and cleared by member/index expressions.
    last_primary_ident: Option<String>,
    /// Declared return type of the function declaration currently being visited.
    current_function_return_type: Option<Type>,
    /// Name of the class declaration currently being visited, if any.
    current_class: Option<String>,
    /// Super class of the class declaration currently being visited, if any.
    current_super_class: Option<String>,
    /// Diagnostics collected so far.
    pub diagnostics: Vec<SemanticDiagnostic>,
    /// Type recorded for each expression (and declared variable name), keyed by span.
    pub type_map: HashMap<TypeMapKey, Type>,
    /// Stack of null-check narrowings whose `if`/ternary is still being visited.
    null_check_narrowing: Vec<NullCheckNarrowing>,
    /// Return types inferred from `return` statements and anonymous function
    /// bodies, keyed by the span of the function node.
    inferred_return_types: HashMap<TypeMapKey, Type>,
}
73
74impl<'a> TypeChecker<'a> {
75 #[must_use]
76 pub fn new(store: &'a ScopeStore, root: &'a SyntaxNode) -> Self {
77 Self {
78 store,
79 root,
80 stack: vec![ScopeId(0)],
81 next_scope_id: 1,
82 var_types: vec![HashMap::new()],
83 type_stack: Vec::new(),
84 last_primary_ident: None,
85 current_function_return_type: None,
86 current_class: None,
87 current_super_class: None,
88 diagnostics: Vec::new(),
89 type_map: HashMap::new(),
90 null_check_narrowing: Vec::new(),
91 inferred_return_types: HashMap::new(),
92 }
93 }
94
95 fn enclosing_function(&self, node: &SyntaxNode) -> Option<SyntaxNode> {
97 for anc in node.ancestors(self.root) {
98 if anc.kind_as::<Kind>() == Some(Kind::NodeAnonFn)
99 || anc.kind_as::<Kind>() == Some(Kind::NodeFunctionDecl)
100 {
101 return Some(anc);
102 }
103 }
104 None
105 }
106
107 fn inferred_return_type_for_function_name(&self, name: &str) -> Option<Type> {
109 for decl in self
110 .root
111 .find_all_nodes(Kind::NodeFunctionDecl.into_syntax_kind())
112 {
113 if function_decl_info(&decl).is_some_and(|info| info.name == name) {
114 let r = decl.text_range();
115 if let Some(ty) = self.inferred_return_types.get(&(r.start, r.end)) {
116 return Some(ty.clone());
117 }
118 }
119 }
120 None
121 }
122
123 fn current_scope(&self) -> ScopeId {
124 *self.stack.last().unwrap_or(&ScopeId(0))
125 }
126
127 fn find_in_scope_chain<T>(&self, mut f: impl FnMut(&Scope) -> Option<T>) -> Option<T> {
129 let mut id = Some(self.current_scope());
130 while let Some(scope_id) = id {
131 if let Some(scope) = self.store.get(scope_id) {
132 if let Some(t) = f(scope) {
133 return Some(t);
134 }
135 id = scope.parent;
136 } else {
137 break;
138 }
139 }
140 None
141 }
142
143 fn push_scope(&mut self) {
144 self.stack.push(ScopeId(self.next_scope_id));
145 self.next_scope_id += 1;
146 self.var_types.push(HashMap::new());
147 }
148
149 fn pop_scope(&mut self) {
150 if self.stack.len() > 1 {
151 self.stack.pop();
152 self.var_types.pop();
153 }
154 }
155
156 fn lookup_var_type(&self, name: &str) -> Type {
157 for map in self.var_types.iter().rev() {
158 if let Some(t) = map.get(name) {
159 return t.clone();
160 }
161 }
162 self.find_in_scope_chain(|scope| {
163 scope
164 .get_variable(name)
165 .and_then(|v| v.declared_type.clone())
166 .or_else(|| {
167 if scope.has_global(name) {
168 Some(scope.get_global_type(name).unwrap_or(Type::any()))
169 } else {
170 None
171 }
172 })
173 })
174 .unwrap_or(Type::any())
175 }
176
177 fn add_var_type(&mut self, name: String, ty: Type) {
178 if let Some(map) = self.var_types.last_mut() {
179 map.insert(name, ty);
180 }
181 }
182
183 fn function_accepts_arity(&self, name: &str, arity: usize) -> bool {
185 self.find_in_scope_chain(|scope| {
186 if scope.function_accepts_arity(name, arity) {
187 Some(())
188 } else {
189 None
190 }
191 })
192 .is_some()
193 }
194
195 fn get_function_type(&self, name: &str, arity: usize) -> Option<(Vec<Type>, Type)> {
197 self.find_in_scope_chain(|scope| scope.get_function_type(name, arity))
198 }
199
200 fn record_expression_type(&mut self, node: &SyntaxNode, ty: &Type) {
202 let span = node.text_range();
203 self.type_map.insert((span.start, span.end), ty.clone());
204 }
205
206 fn resolve_identifier_type(&self, name: &str) -> Type {
208 if name == "this" {
209 if let Some(ref c) = self.current_class {
210 return Type::instance(c.clone());
211 }
212 }
213 if name == "super" {
214 if let Some(ref s) = self.current_super_class {
215 return Type::instance(s.clone());
216 }
217 }
218 if let Some(sym) = self.store.resolve(self.current_scope(), name) {
219 match sym {
220 ResolvedSymbol::Class(class_name) => return Type::class(Some(class_name)),
221 ResolvedSymbol::Function(_, _) => {
222 if let Some(mut ty) = self
223 .store
224 .get_function_type_as_value(self.current_scope(), name)
225 {
226 if let Type::Function {
228 ref args,
229 return_type: ref ret,
230 } = ty
231 {
232 if **ret == Type::any() {
233 if let Some(inferred) =
234 self.inferred_return_type_for_function_name(name)
235 {
236 ty = Type::function(args.clone(), inferred);
237 }
238 }
239 }
240 return ty;
241 }
242 }
243 _ => {}
244 }
245 }
246 if self.store.root_has_class(name) {
248 return Type::class(Some(name.to_string()));
249 }
250 self.lookup_var_type(name)
251 }
252}
253
254impl Visitor for TypeChecker<'_> {
255 fn enter_node(&mut self, node: &SyntaxNode) -> WalkResult {
256 let kind = match node.kind_as::<Kind>() {
257 Some(k) => k,
258 None => return WalkResult::Continue(()),
259 };
260
261 match kind {
262 Kind::NodeBlock
263 | Kind::NodeWhileStmt
264 | Kind::NodeForStmt
265 | Kind::NodeForInStmt
266 | Kind::NodeDoWhileStmt => self.push_scope(),
267 Kind::NodeFunctionDecl | Kind::NodeClassDecl | Kind::NodeConstructorDecl => {
268 self.push_scope();
269 if kind == Kind::NodeFunctionDecl {
270 let mut saw_arrow = false;
273 for child in node.children() {
274 if let SyntaxElement::Token(t) = &child {
275 if t.text() == "->" {
276 saw_arrow = true;
277 continue;
278 }
279 }
280 if let SyntaxElement::Node(n) = &child {
281 if n.kind_as::<Kind>() == Some(Kind::NodeTypeExpr) {
282 if saw_arrow {
283 if let TypeExprResult::Ok(t) = parse_type_expr(n) {
284 self.current_function_return_type = Some(t);
285 }
286 break;
287 }
288 if let TypeExprResult::Ok(t) = parse_type_expr(n) {
290 let is_function_type = matches!(t, Type::Class(Some(ref n)) if n == "function")
291 || matches!(t, Type::Instance(ref n) if n == "function");
292 if !is_function_type {
293 self.current_function_return_type = Some(t);
294 }
295 break;
296 }
297 }
298 }
299 saw_arrow = false;
300 }
301 } else if kind == Kind::NodeClassDecl {
302 if let Some(info) = class_decl_info(node) {
303 self.current_class = Some(info.name.clone());
304 self.current_super_class = info.super_class;
305 }
306 }
307 }
309 _ => {}
310 }
311
312 let to_apply = self.null_check_narrowing.last().and_then(|narrow| {
314 let ancestors: Vec<SyntaxNode> = node.ancestors(self.root);
315 let parent = ancestors.first()?;
316 if narrow.closing_node.text_range() != parent.text_range() {
317 return None;
318 }
319 let idx = node_index_in_parent(node, parent);
320 let parent_kind = parent.kind_as::<Kind>();
321 let (then_ty, else_ty) = (narrow.then_ty.clone(), narrow.else_ty.clone());
322 let var_name = narrow.var_name.clone();
323 if parent_kind == Some(Kind::NodeIfStmt) {
324 if idx == Some(4) {
325 Some((var_name, then_ty))
326 } else if idx == Some(6) {
327 Some((var_name, else_ty))
328 } else {
329 None
330 }
331 } else if parent_kind == Some(Kind::NodeExpr) && is_ternary_expr(parent) {
332 if idx == Some(2) {
333 Some((var_name, then_ty))
334 } else if idx == Some(4) {
335 Some((var_name, else_ty))
336 } else {
337 None
338 }
339 } else {
340 None
341 }
342 });
343 if let Some((var_name, ty)) = to_apply {
344 self.push_scope();
345 self.add_var_type(var_name, ty);
346 }
347
348 WalkResult::Continue(())
349 }
350
351 fn leave_node(&mut self, node: &SyntaxNode) -> WalkResult {
352 let kind = match node.kind_as::<Kind>() {
353 Some(k) => k,
354 None => return WalkResult::Continue(()),
355 };
356
357 if let Some(for_in) = node.find_ancestor(self.root, Kind::NodeForInStmt.into_syntax_kind())
359 {
360 if let Some(iterable_node) = for_in_iterable_expr(&for_in) {
361 if iterable_node.text_range() == node.text_range() {
362 if let Some(iterable_ty) = self.type_stack.pop() {
363 let (key_ty, value_ty) = iterable_key_value_types(&iterable_ty);
364 for (i, (var_name, _)) in for_in_loop_vars(&for_in).into_iter().enumerate()
365 {
366 let ty = if i == 0 {
367 key_ty.clone()
368 } else {
369 value_ty.clone()
370 };
371 self.add_var_type(var_name, ty);
372 }
373 }
374 }
375 }
376 }
377
378 let parent = node.ancestors(self.root).into_iter().next();
380 let closing_range = self
381 .null_check_narrowing
382 .last()
383 .map(|narrow| narrow.closing_node.text_range());
384 let is_branch = parent.is_some_and(|ref p| {
385 closing_range == Some(p.text_range()) && {
386 let idx = node_index_in_parent(node, p);
387 let k = p.kind_as::<Kind>();
388 (k == Some(Kind::NodeIfStmt) && (idx == Some(4) || idx == Some(6)))
389 || (k == Some(Kind::NodeExpr)
390 && is_ternary_expr(p)
391 && (idx == Some(2) || idx == Some(4)))
392 }
393 });
394 if is_branch {
395 self.pop_scope();
396 }
397
398 if let Some(parent) = node.ancestors(self.root).into_iter().next() {
400 let idx = node_index_in_parent(node, &parent);
401 let is_if_condition =
402 parent.kind_as::<Kind>() == Some(Kind::NodeIfStmt) && idx == Some(2);
403 let is_ternary_condition = parent.kind_as::<Kind>() == Some(Kind::NodeExpr)
404 && is_ternary_expr(&parent)
405 && idx == Some(0);
406 if is_if_condition || is_ternary_condition {
407 if let Some((var_name, then_is_non_null)) =
408 null_check_from_condition(node, self.root)
409 {
410 let var_ty = self.lookup_var_type(&var_name);
411 let (then_ty, else_ty) = if then_is_non_null {
412 (Type::non_null(&var_ty), Type::null())
413 } else {
414 (Type::null(), Type::non_null(&var_ty))
415 };
416 self.null_check_narrowing.push(NullCheckNarrowing {
417 var_name,
418 then_ty,
419 else_ty,
420 closing_node: parent,
421 });
422 }
423 }
424 }
425
426 match kind {
427 Kind::NodeBlock
428 | Kind::NodeWhileStmt
429 | Kind::NodeForStmt
430 | Kind::NodeForInStmt
431 | Kind::NodeDoWhileStmt => self.pop_scope(),
432 Kind::NodeIfStmt => {
433 if self
434 .null_check_narrowing
435 .last()
436 .is_some_and(|narrow| narrow.closing_node.text_range() == node.text_range())
437 {
438 self.null_check_narrowing.pop();
439 }
440 }
441 Kind::NodeFunctionDecl | Kind::NodeClassDecl | Kind::NodeConstructorDecl => {
442 if kind == Kind::NodeFunctionDecl {
443 self.current_function_return_type = None;
444 } else if kind == Kind::NodeClassDecl {
445 self.current_class = None;
446 self.current_super_class = None;
447 }
448 self.pop_scope();
449 }
450 Kind::NodeAnonFn => {
451 let r = node.text_range();
453 let key = (r.start, r.end);
454 if !self.inferred_return_types.contains_key(&key) {
455 if let Some(ty) = self.type_stack.pop() {
456 self.inferred_return_types.insert(key, ty);
457 }
458 }
459 }
460 Kind::NodePrimaryExpr => {
461 if let Some((class_name, num_args)) = primary_expr_new_constructor(node) {
462 for _ in 0..num_args {
463 self.type_stack.pop();
464 }
465 let ty = Type::instance(class_name);
466 self.type_stack.push(ty.clone());
467 self.record_expression_type(node, &ty);
468 } else if let Some(name) = primary_expr_resolvable_name(node) {
469 self.last_primary_ident = Some(name.clone());
470 let ty = self.resolve_identifier_type(&name);
471 self.type_stack.push(ty.clone());
472 self.record_expression_type(node, &ty);
473 } else if let Some(anon_fn) = node
474 .child_nodes()
475 .find(|c| c.kind_as::<Kind>() == Some(Kind::NodeAnonFn))
476 {
477 let (param_types, _) = anon_fn_types(&anon_fn);
479 let r = anon_fn.text_range();
480 let return_type = self
481 .inferred_return_types
482 .get(&(r.start, r.end))
483 .cloned()
484 .unwrap_or(Type::any());
485 let ty = Type::function(param_types, return_type);
486 self.type_stack.push(ty.clone());
487 self.record_expression_type(node, &ty);
488 } else {
489 let ty = infer_primary_type(node);
491 self.type_stack.push(ty.clone());
492 self.record_expression_type(node, &ty);
493 }
494 }
495 Kind::NodeCallExpr => {
496 let actual_arity = call_argument_count(node);
497 let callee_name = self.last_primary_ident.take();
498 if let Some(ref name) = callee_name {
499 if !self.function_accepts_arity(name, actual_arity) {
500 if let Some(exp) =
501 self.find_in_scope_chain(|scope| scope.get_function_arity(name))
502 {
503 self.diagnostics.push(wrong_arity_at(
504 node_first_span(node),
505 exp,
506 actual_arity,
507 ));
508 }
509 }
510 }
511 let mut arg_types: Vec<Type> = (0..actual_arity)
512 .filter_map(|_| self.type_stack.pop())
513 .collect();
514 arg_types.reverse();
515 let callee_ty = self.type_stack.pop(); let result_type = if let Some(ref name) = callee_name {
517 if let Some((param_types, return_type)) =
518 self.get_function_type(name, actual_arity)
519 {
520 if param_types.len() == arg_types.len() {
521 for (i, (arg_ty, param_ty)) in
522 arg_types.iter().zip(param_types.iter()).enumerate()
523 {
524 if !param_ty.assignable_from(arg_ty) {
525 if let Some(arg_node) = call_argument_node(node, i) {
526 if let Some(tok) = arg_node.first_token() {
527 self.diagnostics.push(type_mismatch_at(
528 tok.text_range(),
529 ¶m_ty.to_string(),
530 &arg_ty.to_string(),
531 ));
532 }
533 }
534 }
535 }
536 }
537 return_type
538 } else if let Some(Type::Class(Some(ref class_name))) = callee_ty.as_ref() {
539 Type::instance(class_name.clone())
541 } else if let Some(Type::Function {
542 args: param_types,
543 return_type,
544 }) = callee_ty.as_ref()
545 {
546 if param_types.len() == arg_types.len() {
548 for (i, (arg_ty, param_ty)) in
549 arg_types.iter().zip(param_types.iter()).enumerate()
550 {
551 if !param_ty.assignable_from(arg_ty) {
552 if let Some(arg_node) = call_argument_node(node, i) {
553 if let Some(tok) = arg_node.first_token() {
554 self.diagnostics.push(type_mismatch_at(
555 tok.text_range(),
556 ¶m_ty.to_string(),
557 &arg_ty.to_string(),
558 ));
559 }
560 }
561 }
562 }
563 }
564 (**return_type).clone()
565 } else {
566 Type::any()
567 }
568 } else if let Some(Type::Function {
569 args: param_types,
570 return_type,
571 }) = callee_ty.as_ref()
572 {
573 if param_types.len() == arg_types.len() {
575 for (i, (arg_ty, param_ty)) in
576 arg_types.iter().zip(param_types.iter()).enumerate()
577 {
578 if !param_ty.assignable_from(arg_ty) {
579 if let Some(arg_node) = call_argument_node(node, i) {
580 if let Some(tok) = arg_node.first_token() {
581 self.diagnostics.push(type_mismatch_at(
582 tok.text_range(),
583 ¶m_ty.to_string(),
584 &arg_ty.to_string(),
585 ));
586 }
587 }
588 }
589 }
590 }
591 (**return_type).clone()
592 } else {
593 Type::any()
594 };
595 let result_type = narrow_return_by_first_arg(result_type, &arg_types);
596 self.type_stack.push(result_type.clone());
597 self.record_expression_type(node, &result_type);
598 }
599 Kind::NodeMemberExpr => {
600 self.last_primary_ident = None;
601 let base_ty = self.type_stack.pop().unwrap_or(Type::any());
602 let ty = if member_expr_member_name(node).as_deref() == Some("class") {
603 match &base_ty {
605 Type::Class(Some(c)) => Type::class(Some(c.clone())),
606 Type::Instance(c) => Type::class(Some(c.clone())),
607 _ => Type::class(None),
608 }
609 } else if let Type::Instance(class_name) = &base_ty {
610 member_expr_member_name(node)
612 .and_then(|name| self.store.get_class_member_type(class_name, &name))
613 .unwrap_or(Type::any())
614 } else if let Type::Class(Some(class_name)) = &base_ty {
615 member_expr_member_name(node)
617 .and_then(|name| self.store.get_class_static_member_type(class_name, &name))
618 .unwrap_or(Type::any())
619 } else {
620 Type::any()
621 };
622 self.type_stack.push(ty.clone());
623 self.record_expression_type(node, &ty);
624 }
625 Kind::NodeIndexExpr => {
626 self.last_primary_ident = None;
627 let _index_ty = self.type_stack.pop().unwrap_or(Type::any());
629 let receiver_ty = self.type_stack.pop().unwrap_or(Type::any());
630 let element_ty = match &receiver_ty {
631 Type::Array(elem) => *elem.clone(),
632 Type::Map(_, val) => *val.clone(),
633 _ => Type::any(),
634 };
635 self.type_stack.push(element_ty.clone());
636 self.record_expression_type(node, &element_ty);
637 }
638 Kind::NodeVarDecl => {
639 if let Some(info) = var_decl_info(node) {
640 let rhs_ty = self.type_stack.pop().unwrap_or(Type::any());
641 let declared = if info.kind == VarDeclKind::Var {
643 None
644 } else {
645 self.store
646 .get(self.current_scope())
647 .and_then(|s| s.get_variable(&info.name))
648 .and_then(|v| v.declared_type.clone())
649 };
650 let ty_to_store = if let Some(ref d) = declared {
651 if !d.assignable_from(&rhs_ty) {
652 self.diagnostics.push(type_mismatch_at(
653 node_first_span(node),
654 &d.to_string(),
655 &rhs_ty.to_string(),
656 ));
657 }
658 d.clone()
659 } else {
660 rhs_ty
661 };
662 self.add_var_type(info.name.clone(), ty_to_store.clone());
663 self.record_expression_type(node, &ty_to_store);
664 let r = info.name_span;
666 self.type_map.insert((r.start, r.end), ty_to_store);
667 }
668 }
669 Kind::NodeExpr => {
670 if is_ternary_expr(node) && self.type_stack.len() >= 3 {
672 let else_ty = self.type_stack.pop().unwrap();
673 let then_ty = self.type_stack.pop().unwrap();
674 let _cond_ty = self.type_stack.pop().unwrap();
675 let result_ty = Type::compound2(then_ty, else_ty);
676 self.type_stack.push(result_ty.clone());
677 self.record_expression_type(node, &result_ty);
678 if self
679 .null_check_narrowing
680 .last()
681 .is_some_and(|narrow| narrow.closing_node.text_range() == node.text_range())
682 {
683 self.null_check_narrowing.pop();
684 }
685 } else {
686 let is_assign = node
688 .children()
689 .any(|c| matches!(c, SyntaxElement::Token(t) if t.text() == "="));
690 if is_assign && self.type_stack.len() >= 2 {
691 let rhs = self.type_stack.pop().unwrap();
692 let lhs = self.type_stack.pop().unwrap();
693 if !lhs.assignable_from(&rhs) {
694 self.diagnostics.push(type_mismatch_at(
695 node_first_span(node),
696 &lhs.to_string(),
697 &rhs.to_string(),
698 ));
699 }
700 self.type_stack.push(lhs);
701 }
702 }
703 }
704 Kind::NodeBinaryExpr => {
705 let op = node
706 .children()
707 .find_map(|c| {
708 if let SyntaxElement::Token(t) = c {
709 if t.kind_as::<Kind>() == Some(Kind::TokOp) {
710 return Some(t.text().to_string());
711 }
712 if t.kind_as::<Kind>() == Some(Kind::KwInstanceof) {
713 return Some("instanceof".to_string());
714 }
715 if t.kind_as::<Kind>() == Some(Kind::KwIn) {
716 return Some("in".to_string());
717 }
718 }
719 None
720 })
721 .unwrap_or_default();
722 if self.type_stack.len() >= 2 {
723 let right = self.type_stack.pop().unwrap();
724 let left = self.type_stack.pop().unwrap();
725 let (result, err) = check_binary_op(&op, &left, &right);
726 if let Some((_, msg)) = err {
727 self.diagnostics.push(
728 SemanticDiagnostic::error(node_first_span(node), msg)
729 .with_code(super::error::AnalysisError::TypeMismatch.code()),
730 );
731 }
732 let result_ty = result.unwrap_or(Type::any());
733 self.type_stack.push(result_ty.clone());
734 self.record_expression_type(node, &result_ty);
735 }
736 }
737 Kind::NodeUnaryExpr => {
738 let op = node
739 .first_token()
740 .map(|t| t.text().to_string())
741 .unwrap_or_default();
742 if let Some(operand) = self.type_stack.pop() {
743 let (result, err) = check_unary_op(&op, &operand);
744 if let Some((_, msg)) = err {
745 self.diagnostics.push(
746 SemanticDiagnostic::error(node_first_span(node), msg)
747 .with_code(super::error::AnalysisError::TypeMismatch.code()),
748 );
749 }
750 let result_ty = result.unwrap_or(Type::any());
751 self.type_stack.push(result_ty.clone());
752 self.record_expression_type(node, &result_ty);
753 }
754 }
755 Kind::NodeReturnStmt => {
756 let expr_type = self.type_stack.pop().unwrap_or(Type::void());
757 if let Some(ref expected) = self.current_function_return_type {
758 if !expected.assignable_from(&expr_type) {
759 self.diagnostics.push(type_mismatch_at(
760 node_first_span(node),
761 &expected.to_string(),
762 &expr_type.to_string(),
763 ));
764 }
765 }
766 if let Some(func) = self.enclosing_function(node) {
768 let r = func.text_range();
769 self.inferred_return_types
770 .insert((r.start, r.end), expr_type);
771 }
772 }
773 Kind::NodeAsCast => {
774 if let Some(expr_ty) = self.type_stack.pop() {
775 let ty = if let Some(te) = find_type_expr_child(node) {
776 if let TypeExprResult::Ok(target_ty) = parse_type_expr(&te) {
777 let cast = Type::check_cast(&expr_ty, &target_ty);
778 if cast == CastType::Incompatible {
779 self.diagnostics.push(invalid_cast_at(
780 node_first_span(node),
781 &expr_ty.to_string(),
782 &target_ty.to_string(),
783 ));
784 }
785 target_ty
786 } else {
787 Type::any()
788 }
789 } else {
790 Type::any()
791 };
792 self.type_stack.push(ty.clone());
793 self.record_expression_type(node, &ty);
794 }
795 }
796 _ => {}
797 }
798
799 WalkResult::Continue(())
800 }
801}
802
803fn infer_primary_type(node: &SyntaxNode) -> Type {
805 let first = node.first_token();
806 let first = match first {
807 Some(t) => t,
808 None => return Type::any(),
809 };
810 match first.kind_as::<Kind>() {
811 Some(Kind::TokNumber) => {
812 let text = first.text();
813 if text.contains('.') || text.to_lowercase().contains('e') {
814 Type::real()
815 } else {
816 Type::int()
817 }
818 }
819 Some(Kind::TokString) => Type::string(),
820 Some(Kind::KwTrue | Kind::KwFalse) => Type::bool(),
821 Some(Kind::KwNull) => Type::null(),
822 _ => Type::any(),
823 }
824}
825
826fn iterable_key_value_types(iterable: &Type) -> (Type, Type) {
829 match iterable {
830 Type::Array(elem) => (Type::int(), *elem.clone()),
831 Type::Map(k, v) => (*k.clone(), *v.clone()),
832 Type::Set(elem) => (Type::int(), *elem.clone()),
833 Type::Interval(elem) => (Type::int(), *elem.clone()),
834 _ => (Type::any(), Type::any()),
835 }
836}
837
838fn narrow_return_by_first_arg(return_type: Type, arg_types: &[Type]) -> Type {
841 let Type::Compound(variants) = return_type else {
842 return return_type;
843 };
844 let Some(first_arg) = arg_types.first() else {
845 return Type::compound(variants);
846 };
847 if first_arg == &Type::Any {
848 return Type::compound(variants);
849 }
850 let matching: Vec<Type> = variants
851 .iter()
852 .filter(|v| v.assignable_from(first_arg))
853 .cloned()
854 .collect();
855 if matching.len() == 1 {
856 matching.into_iter().next().unwrap()
857 } else {
858 Type::compound(variants)
859 }
860}
861
862fn check_binary_op(op: &str, left: &Type, right: &Type) -> (Option<Type>, Option<(Span, String)>) {
864 let numeric_ops = ["+", "-", "*", "/", "\\", "%", "**"];
865 let compare_ops = ["<", "<=", ">", ">="];
866 let equality_ops = ["==", "!="];
867 let logical_ops = ["&&", "||", "and", "or", "xor"];
868
869 if let Type::Compound(ts) = left {
871 let mut result_types: Vec<Type> = Vec::new();
872 let mut first_err: Option<(Span, String)> = None;
873 for t in ts {
874 let (res, err) = check_binary_op(op, t, right);
875 if let Some(ty) = res {
876 if !result_types.iter().any(|r| r == &ty) {
877 result_types.push(ty);
878 }
879 }
880 if first_err.is_none() {
881 first_err = err;
882 }
883 }
884 if !result_types.is_empty() {
885 return (
886 Some(if result_types.len() == 1 {
887 result_types.into_iter().next().unwrap()
888 } else {
889 Type::compound(result_types)
890 }),
891 None,
892 );
893 }
894 return (Some(Type::any()), first_err);
895 }
896 if let Type::Compound(ts) = right {
898 let mut result_types: Vec<Type> = Vec::new();
899 let mut first_err: Option<(Span, String)> = None;
900 for t in ts {
901 let (res, err) = check_binary_op(op, left, t);
902 if let Some(ty) = res {
903 if !result_types.iter().any(|r| r == &ty) {
904 result_types.push(ty);
905 }
906 }
907 if first_err.is_none() {
908 first_err = err;
909 }
910 }
911 if !result_types.is_empty() {
912 return (
913 Some(if result_types.len() == 1 {
914 result_types.into_iter().next().unwrap()
915 } else {
916 Type::compound(result_types)
917 }),
918 None,
919 );
920 }
921 return (Some(Type::any()), first_err);
922 }
923
924 if op == "+"
926 && (left == &Type::String
927 || right == &Type::String
928 || left == &Type::Any
929 || right == &Type::Any)
930 {
931 return (Some(Type::string()), None);
932 }
933 if numeric_ops.contains(&op) {
934 if left != &Type::Any && right != &Type::Any && (!left.is_number() || !right.is_number()) {
935 return (
936 Some(Type::real()),
937 Some((
938 Span::new(0, 0),
939 format!("operator `{op}` requires number, got {left} and {right}"),
940 )),
941 );
942 }
943 let result = if left == &Type::Int && right == &Type::Int {
945 Type::int()
946 } else {
947 Type::real()
948 };
949 (Some(result), None)
950 } else if compare_ops.contains(&op) {
951 if left != &Type::Any && right != &Type::Any && (!left.is_number() || !right.is_number()) {
952 return (
953 Some(Type::bool()),
954 Some((
955 Span::new(0, 0),
956 format!("comparison requires number, got {left} and {right}"),
957 )),
958 );
959 }
960 (Some(Type::bool()), None)
961 } else if equality_ops.contains(&op) {
962 (Some(Type::bool()), None)
963 } else if op == "instanceof" {
964 (Some(Type::bool()), None)
966 } else if op == "in" {
967 (Some(Type::bool()), None)
969 } else if logical_ops.contains(&op) {
970 (Some(Type::bool()), None)
972 } else {
973 (Some(Type::any()), None)
974 }
975}
976
977fn check_unary_op(op: &str, operand: &Type) -> (Option<Type>, Option<(Span, String)>) {
979 match op {
980 "-" | "+" => {
981 if operand != &Type::Any && !operand.is_number() {
982 return (
983 Some(Type::real()),
984 Some((
985 Span::new(0, 0),
986 format!("unary `{op}` requires number, got {operand}"),
987 )),
988 );
989 }
990 let result_ty = match operand {
992 Type::Int => Type::int(),
993 _ if operand.is_number() => Type::real(),
994 _ => Type::real(),
995 };
996 (Some(result_ty), None)
997 }
998 "!" | "not" => {
999 (Some(Type::bool()), None)
1001 }
1002 _ => (Some(Type::any()), None),
1003 }
1004}