1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
/// A single problem reported by the type checker.
#[derive(Debug, Clone)]
pub struct TypeDiagnostic {
    /// Human-readable description of the problem.
    pub message: String,
    /// Whether the problem is an error or a warning.
    pub severity: DiagnosticSeverity,
    /// Source location the diagnostic refers to, when one is available.
    pub span: Option<Span>,
    /// Optional supplementary help text for the user.
    pub help: Option<String>,
}
14
/// How serious a [`TypeDiagnostic`] is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// A definite type error.
    Error,
    /// A possible problem that the checker cannot prove is wrong.
    Warning,
}
20
/// A possibly-unknown type: `None` means the checker could not determine the
/// type, and comparisons involving it are skipped rather than reported.
type InferredType = Option<TypeExpr>;
23
/// One lexical scope of the type checker's symbol tables.
///
/// Each table holds only the names declared directly in this scope; lookups
/// that miss here continue in `parent`.
#[derive(Debug, Clone)]
struct TypeScope {
    /// Variable name -> declared or inferred type (`None` when unknown).
    vars: BTreeMap<String, InferredType>,
    /// Function name -> signature used to check calls.
    functions: BTreeMap<String, FnSignature>,
    /// Type alias name -> aliased type expression.
    type_aliases: BTreeMap<String, TypeExpr>,
    /// Enum name -> names of its variants.
    enums: BTreeMap<String, Vec<String>>,
    /// Interface name -> methods the interface requires.
    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
    /// Struct name -> (field name, declared field type) pairs.
    structs: BTreeMap<String, Vec<(String, InferredType)>>,
    /// Type name -> signatures of the methods in its `impl` block.
    impl_methods: BTreeMap<String, Vec<ImplMethodSig>>,
    /// Names of generic type parameters in scope (e.g. inside a generic fn body).
    generic_type_params: std::collections::BTreeSet<String>,
    /// Generic type parameter name -> interface named in its `where` bound.
    where_constraints: BTreeMap<String, String>,
    /// Enclosing scope, consulted when a lookup misses in this one.
    parent: Option<Box<TypeScope>>,
}
48
/// Signature of a method declared in an `impl` block, as used for interface
/// conformance checks.
#[derive(Debug, Clone)]
struct ImplMethodSig {
    /// Method name.
    name: String,
    /// Number of parameters, not counting a `self` parameter.
    param_count: usize,
    /// Declared types of the non-`self` parameters, in order (`None` if unannotated).
    param_types: Vec<Option<TypeExpr>>,
    /// Declared return type, if annotated.
    return_type: Option<TypeExpr>,
}
60
/// Signature of a user-declared function, recorded so calls can be checked.
#[derive(Debug, Clone)]
struct FnSignature {
    /// (parameter name, declared type) pairs in declaration order.
    params: Vec<(String, InferredType)>,
    /// Declared return type, if annotated.
    return_type: InferredType,
    /// Names of the function's generic type parameters.
    type_param_names: Vec<String>,
    /// Number of parameters that have no default value.
    required_params: usize,
    /// (type parameter, interface bound) pairs from the `where` clause.
    where_clauses: Vec<(String, String)>,
}
72
73impl TypeScope {
74 fn new() -> Self {
75 Self {
76 vars: BTreeMap::new(),
77 functions: BTreeMap::new(),
78 type_aliases: BTreeMap::new(),
79 enums: BTreeMap::new(),
80 interfaces: BTreeMap::new(),
81 structs: BTreeMap::new(),
82 impl_methods: BTreeMap::new(),
83 generic_type_params: std::collections::BTreeSet::new(),
84 where_constraints: BTreeMap::new(),
85 parent: None,
86 }
87 }
88
89 fn child(&self) -> Self {
90 Self {
91 vars: BTreeMap::new(),
92 functions: BTreeMap::new(),
93 type_aliases: BTreeMap::new(),
94 enums: BTreeMap::new(),
95 interfaces: BTreeMap::new(),
96 structs: BTreeMap::new(),
97 impl_methods: BTreeMap::new(),
98 generic_type_params: std::collections::BTreeSet::new(),
99 where_constraints: BTreeMap::new(),
100 parent: Some(Box::new(self.clone())),
101 }
102 }
103
104 fn get_var(&self, name: &str) -> Option<&InferredType> {
105 self.vars
106 .get(name)
107 .or_else(|| self.parent.as_ref()?.get_var(name))
108 }
109
110 fn get_fn(&self, name: &str) -> Option<&FnSignature> {
111 self.functions
112 .get(name)
113 .or_else(|| self.parent.as_ref()?.get_fn(name))
114 }
115
116 fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
117 self.type_aliases
118 .get(name)
119 .or_else(|| self.parent.as_ref()?.resolve_type(name))
120 }
121
122 fn is_generic_type_param(&self, name: &str) -> bool {
123 self.generic_type_params.contains(name)
124 || self
125 .parent
126 .as_ref()
127 .is_some_and(|p| p.is_generic_type_param(name))
128 }
129
130 fn get_where_constraint(&self, type_param: &str) -> Option<&str> {
131 self.where_constraints
132 .get(type_param)
133 .map(|s| s.as_str())
134 .or_else(|| {
135 self.parent
136 .as_ref()
137 .and_then(|p| p.get_where_constraint(type_param))
138 })
139 }
140
141 fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
142 self.enums
143 .get(name)
144 .or_else(|| self.parent.as_ref()?.get_enum(name))
145 }
146
147 fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
148 self.interfaces
149 .get(name)
150 .or_else(|| self.parent.as_ref()?.get_interface(name))
151 }
152
153 fn get_struct(&self, name: &str) -> Option<&Vec<(String, InferredType)>> {
154 self.structs
155 .get(name)
156 .or_else(|| self.parent.as_ref()?.get_struct(name))
157 }
158
159 fn get_impl_methods(&self, name: &str) -> Option<&Vec<ImplMethodSig>> {
160 self.impl_methods
161 .get(name)
162 .or_else(|| self.parent.as_ref()?.get_impl_methods(name))
163 }
164
165 fn define_var(&mut self, name: &str, ty: InferredType) {
166 self.vars.insert(name.to_string(), ty);
167 }
168
169 fn define_fn(&mut self, name: &str, sig: FnSignature) {
170 self.functions.insert(name.to_string(), sig);
171 }
172}
173
174fn builtin_return_type(name: &str) -> InferredType {
176 match name {
177 "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
178 | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
179 Some(TypeExpr::Named("nil".into()))
180 }
181 "type_of"
182 | "to_string"
183 | "json_stringify"
184 | "read_file"
185 | "http_get"
186 | "http_post"
187 | "llm_call"
188 | "regex_replace"
189 | "path_join"
190 | "temp_dir"
191 | "date_format"
192 | "format"
193 | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
194 "to_int" | "timer_end" | "elapsed" | "sign" => Some(TypeExpr::Named("int".into())),
195 "to_float" | "timestamp" | "date_parse" | "sin" | "cos" | "tan" | "asin" | "acos"
196 | "atan" | "atan2" | "log2" | "log10" | "ln" | "exp" | "pi" | "e" => {
197 Some(TypeExpr::Named("float".into()))
198 }
199 "file_exists" | "json_validate" | "is_nan" | "is_infinite" | "set_contains" => {
200 Some(TypeExpr::Named("bool".into()))
201 }
202 "list_dir" | "mcp_list_tools" | "mcp_list_resources" | "mcp_list_prompts" | "to_list"
203 | "regex_captures" => Some(TypeExpr::Named("list".into())),
204 "stat" | "exec" | "shell" | "date_now" | "agent_loop" | "llm_info" | "llm_usage"
205 | "timer_start" | "metadata_get" | "mcp_server_info" | "mcp_get_prompt" => {
206 Some(TypeExpr::Named("dict".into()))
207 }
208 "metadata_set"
209 | "metadata_save"
210 | "metadata_refresh_hashes"
211 | "invalidate_facts"
212 | "log_json"
213 | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
214 "env" | "regex_match" => Some(TypeExpr::Union(vec![
215 TypeExpr::Named("string".into()),
216 TypeExpr::Named("nil".into()),
217 ])),
218 "json_parse" | "json_extract" => None, _ => None,
220 }
221}
222
/// Reports whether `name` is one of the builtin function names in this fixed list.
///
/// Note: this list differs from the names `builtin_return_type` knows about;
/// for example `timer_end` has a known return type but is not listed here.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        "log", "print", "println", "type_of", "to_string", "to_int", "to_float",
        "json_stringify", "json_parse", "env", "timestamp", "sleep", "read_file",
        "write_file", "exit", "regex_match", "regex_replace", "regex_captures",
        "http_get", "http_post", "llm_call", "agent_loop", "await", "cancel",
        "file_exists", "delete_file", "list_dir", "mkdir", "path_join", "copy_file",
        "append_file", "temp_dir", "stat", "exec", "shell", "date_now", "date_format",
        "date_parse", "format", "json_validate", "json_extract", "trim", "lowercase",
        "uppercase", "split", "starts_with", "ends_with", "contains", "replace", "join",
        "len", "substring", "dirname", "basename", "extname", "sin", "cos", "tan", "asin",
        "acos", "atan", "atan2", "log2", "log10", "ln", "exp", "pi", "e", "sign", "is_nan",
        "is_infinite", "set", "set_add", "set_remove", "set_contains", "set_union",
        "set_intersect", "set_difference", "to_list",
    ];
    BUILTINS.contains(&name)
}
308
/// Static type checker for a parsed program; consumed by [`TypeChecker::check`].
pub struct TypeChecker {
    /// Diagnostics accumulated so far, in the order they were reported.
    diagnostics: Vec<TypeDiagnostic>,
    /// Global (top-level) scope; nested scopes are built from it as needed.
    scope: TypeScope,
}
314
315impl TypeChecker {
316 pub fn new() -> Self {
317 Self {
318 diagnostics: Vec::new(),
319 scope: TypeScope::new(),
320 }
321 }
322
323 pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
325 Self::register_declarations_into(&mut self.scope, program);
327
328 for snode in program {
330 if let Node::Pipeline { body, .. } = &snode.node {
331 Self::register_declarations_into(&mut self.scope, body);
332 }
333 }
334
335 for snode in program {
337 match &snode.node {
338 Node::Pipeline { params, body, .. } => {
339 let mut child = self.scope.child();
340 for p in params {
341 child.define_var(p, None);
342 }
343 self.check_block(body, &mut child);
344 }
345 Node::FnDecl {
346 name,
347 type_params,
348 params,
349 return_type,
350 where_clauses,
351 body,
352 ..
353 } => {
354 let required_params =
355 params.iter().filter(|p| p.default_value.is_none()).count();
356 let sig = FnSignature {
357 params: params
358 .iter()
359 .map(|p| (p.name.clone(), p.type_expr.clone()))
360 .collect(),
361 return_type: return_type.clone(),
362 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
363 required_params,
364 where_clauses: where_clauses
365 .iter()
366 .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
367 .collect(),
368 };
369 self.scope.define_fn(name, sig);
370 self.check_fn_body(type_params, params, return_type, body, where_clauses);
371 }
372 _ => {
373 let mut scope = self.scope.clone();
374 self.check_node(snode, &mut scope);
375 for (name, ty) in scope.vars {
377 self.scope.vars.entry(name).or_insert(ty);
378 }
379 }
380 }
381 }
382
383 self.diagnostics
384 }
385
386 fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
388 for snode in nodes {
389 match &snode.node {
390 Node::TypeDecl { name, type_expr } => {
391 scope.type_aliases.insert(name.clone(), type_expr.clone());
392 }
393 Node::EnumDecl { name, variants } => {
394 let variant_names: Vec<String> =
395 variants.iter().map(|v| v.name.clone()).collect();
396 scope.enums.insert(name.clone(), variant_names);
397 }
398 Node::InterfaceDecl { name, methods } => {
399 scope.interfaces.insert(name.clone(), methods.clone());
400 }
401 Node::StructDecl { name, fields } => {
402 let field_types: Vec<(String, InferredType)> = fields
403 .iter()
404 .map(|f| (f.name.clone(), f.type_expr.clone()))
405 .collect();
406 scope.structs.insert(name.clone(), field_types);
407 }
408 Node::ImplBlock {
409 type_name, methods, ..
410 } => {
411 let sigs: Vec<ImplMethodSig> = methods
412 .iter()
413 .filter_map(|m| {
414 if let Node::FnDecl {
415 name,
416 params,
417 return_type,
418 ..
419 } = &m.node
420 {
421 let non_self: Vec<_> =
422 params.iter().filter(|p| p.name != "self").collect();
423 let param_count = non_self.len();
424 let param_types: Vec<Option<TypeExpr>> =
425 non_self.iter().map(|p| p.type_expr.clone()).collect();
426 Some(ImplMethodSig {
427 name: name.clone(),
428 param_count,
429 param_types,
430 return_type: return_type.clone(),
431 })
432 } else {
433 None
434 }
435 })
436 .collect();
437 scope.impl_methods.insert(type_name.clone(), sigs);
438 }
439 _ => {}
440 }
441 }
442 }
443
444 fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
445 for stmt in stmts {
446 self.check_node(stmt, scope);
447 }
448 }
449
450 fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
452 match pattern {
453 BindingPattern::Identifier(name) => {
454 scope.define_var(name, None);
455 }
456 BindingPattern::Dict(fields) => {
457 for field in fields {
458 let name = field.alias.as_deref().unwrap_or(&field.key);
459 scope.define_var(name, None);
460 }
461 }
462 BindingPattern::List(elements) => {
463 for elem in elements {
464 scope.define_var(&elem.name, None);
465 }
466 }
467 }
468 }
469
470 fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
471 let span = snode.span;
472 match &snode.node {
473 Node::LetBinding {
474 pattern,
475 type_ann,
476 value,
477 } => {
478 let inferred = self.infer_type(value, scope);
479 if let BindingPattern::Identifier(name) = pattern {
480 if let Some(expected) = type_ann {
481 if let Some(actual) = &inferred {
482 if !self.types_compatible(expected, actual, scope) {
483 let mut msg = format!(
484 "Type mismatch: '{}' declared as {}, but assigned {}",
485 name,
486 format_type(expected),
487 format_type(actual)
488 );
489 if let Some(detail) = shape_mismatch_detail(expected, actual) {
490 msg.push_str(&format!(" ({})", detail));
491 }
492 self.error_at(msg, span);
493 }
494 }
495 }
496 let ty = type_ann.clone().or(inferred);
497 scope.define_var(name, ty);
498 } else {
499 Self::define_pattern_vars(pattern, scope);
500 }
501 }
502
503 Node::VarBinding {
504 pattern,
505 type_ann,
506 value,
507 } => {
508 let inferred = self.infer_type(value, scope);
509 if let BindingPattern::Identifier(name) = pattern {
510 if let Some(expected) = type_ann {
511 if let Some(actual) = &inferred {
512 if !self.types_compatible(expected, actual, scope) {
513 let mut msg = format!(
514 "Type mismatch: '{}' declared as {}, but assigned {}",
515 name,
516 format_type(expected),
517 format_type(actual)
518 );
519 if let Some(detail) = shape_mismatch_detail(expected, actual) {
520 msg.push_str(&format!(" ({})", detail));
521 }
522 self.error_at(msg, span);
523 }
524 }
525 }
526 let ty = type_ann.clone().or(inferred);
527 scope.define_var(name, ty);
528 } else {
529 Self::define_pattern_vars(pattern, scope);
530 }
531 }
532
533 Node::FnDecl {
534 name,
535 type_params,
536 params,
537 return_type,
538 where_clauses,
539 body,
540 ..
541 } => {
542 let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
543 let sig = FnSignature {
544 params: params
545 .iter()
546 .map(|p| (p.name.clone(), p.type_expr.clone()))
547 .collect(),
548 return_type: return_type.clone(),
549 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
550 required_params,
551 where_clauses: where_clauses
552 .iter()
553 .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
554 .collect(),
555 };
556 scope.define_fn(name, sig.clone());
557 scope.define_var(name, None);
558 self.check_fn_body(type_params, params, return_type, body, where_clauses);
559 }
560
561 Node::FunctionCall { name, args } => {
562 self.check_call(name, args, scope, span);
563 }
564
565 Node::IfElse {
566 condition,
567 then_body,
568 else_body,
569 } => {
570 self.check_node(condition, scope);
571 let mut then_scope = scope.child();
572 if let Some((var_name, narrowed)) = Self::extract_nil_narrowing(condition, scope) {
574 then_scope.define_var(&var_name, narrowed);
575 }
576 self.check_block(then_body, &mut then_scope);
577 if let Some(else_body) = else_body {
578 let mut else_scope = scope.child();
579 self.check_block(else_body, &mut else_scope);
580 }
581 }
582
583 Node::ForIn {
584 pattern,
585 iterable,
586 body,
587 } => {
588 self.check_node(iterable, scope);
589 let mut loop_scope = scope.child();
590 if let BindingPattern::Identifier(variable) = pattern {
591 let elem_type = match self.infer_type(iterable, scope) {
593 Some(TypeExpr::List(inner)) => Some(*inner),
594 Some(TypeExpr::Named(n)) if n == "string" => {
595 Some(TypeExpr::Named("string".into()))
596 }
597 _ => None,
598 };
599 loop_scope.define_var(variable, elem_type);
600 } else {
601 Self::define_pattern_vars(pattern, &mut loop_scope);
602 }
603 self.check_block(body, &mut loop_scope);
604 }
605
606 Node::WhileLoop { condition, body } => {
607 self.check_node(condition, scope);
608 let mut loop_scope = scope.child();
609 self.check_block(body, &mut loop_scope);
610 }
611
612 Node::TryCatch {
613 body,
614 error_var,
615 catch_body,
616 finally_body,
617 ..
618 } => {
619 let mut try_scope = scope.child();
620 self.check_block(body, &mut try_scope);
621 let mut catch_scope = scope.child();
622 if let Some(var) = error_var {
623 catch_scope.define_var(var, None);
624 }
625 self.check_block(catch_body, &mut catch_scope);
626 if let Some(fb) = finally_body {
627 let mut finally_scope = scope.child();
628 self.check_block(fb, &mut finally_scope);
629 }
630 }
631
632 Node::TryExpr { body } => {
633 let mut try_scope = scope.child();
634 self.check_block(body, &mut try_scope);
635 }
636
637 Node::ReturnStmt {
638 value: Some(val), ..
639 } => {
640 self.check_node(val, scope);
641 }
642
643 Node::Assignment {
644 target, value, op, ..
645 } => {
646 self.check_node(value, scope);
647 if let Node::Identifier(name) = &target.node {
648 if let Some(Some(var_type)) = scope.get_var(name) {
649 let value_type = self.infer_type(value, scope);
650 let assigned = if let Some(op) = op {
651 let var_inferred = scope.get_var(name).cloned().flatten();
652 infer_binary_op_type(op, &var_inferred, &value_type)
653 } else {
654 value_type
655 };
656 if let Some(actual) = &assigned {
657 if !self.types_compatible(var_type, actual, scope) {
658 self.error_at(
659 format!(
660 "Type mismatch: cannot assign {} to '{}' (declared as {})",
661 format_type(actual),
662 name,
663 format_type(var_type)
664 ),
665 span,
666 );
667 }
668 }
669 }
670 }
671 }
672
673 Node::TypeDecl { name, type_expr } => {
674 scope.type_aliases.insert(name.clone(), type_expr.clone());
675 }
676
677 Node::EnumDecl { name, variants } => {
678 let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
679 scope.enums.insert(name.clone(), variant_names);
680 }
681
682 Node::StructDecl { name, fields } => {
683 let field_types: Vec<(String, InferredType)> = fields
684 .iter()
685 .map(|f| (f.name.clone(), f.type_expr.clone()))
686 .collect();
687 scope.structs.insert(name.clone(), field_types);
688 }
689
690 Node::InterfaceDecl { name, methods } => {
691 scope.interfaces.insert(name.clone(), methods.clone());
692 }
693
694 Node::ImplBlock {
695 type_name, methods, ..
696 } => {
697 let sigs: Vec<ImplMethodSig> = methods
699 .iter()
700 .filter_map(|m| {
701 if let Node::FnDecl {
702 name,
703 params,
704 return_type,
705 ..
706 } = &m.node
707 {
708 let non_self: Vec<_> =
709 params.iter().filter(|p| p.name != "self").collect();
710 let param_count = non_self.len();
711 let param_types: Vec<Option<TypeExpr>> =
712 non_self.iter().map(|p| p.type_expr.clone()).collect();
713 Some(ImplMethodSig {
714 name: name.clone(),
715 param_count,
716 param_types,
717 return_type: return_type.clone(),
718 })
719 } else {
720 None
721 }
722 })
723 .collect();
724 scope.impl_methods.insert(type_name.clone(), sigs);
725 for method_sn in methods {
726 self.check_node(method_sn, scope);
727 }
728 }
729
730 Node::TryOperator { operand } => {
731 self.check_node(operand, scope);
732 }
733
734 Node::MatchExpr { value, arms } => {
735 self.check_node(value, scope);
736 let value_type = self.infer_type(value, scope);
737 for arm in arms {
738 self.check_node(&arm.pattern, scope);
739 if let Some(ref vt) = value_type {
741 let value_type_name = format_type(vt);
742 let mismatch = match &arm.pattern.node {
743 Node::StringLiteral(_) => {
744 !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
745 }
746 Node::IntLiteral(_) => {
747 !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
748 && !self.types_compatible(
749 vt,
750 &TypeExpr::Named("float".into()),
751 scope,
752 )
753 }
754 Node::FloatLiteral(_) => {
755 !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
756 && !self.types_compatible(
757 vt,
758 &TypeExpr::Named("int".into()),
759 scope,
760 )
761 }
762 Node::BoolLiteral(_) => {
763 !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
764 }
765 _ => false,
766 };
767 if mismatch {
768 let pattern_type = match &arm.pattern.node {
769 Node::StringLiteral(_) => "string",
770 Node::IntLiteral(_) => "int",
771 Node::FloatLiteral(_) => "float",
772 Node::BoolLiteral(_) => "bool",
773 _ => unreachable!(),
774 };
775 self.warning_at(
776 format!(
777 "Match pattern type mismatch: matching {} against {} literal",
778 value_type_name, pattern_type
779 ),
780 arm.pattern.span,
781 );
782 }
783 }
784 let mut arm_scope = scope.child();
785 self.check_block(&arm.body, &mut arm_scope);
786 }
787 self.check_match_exhaustiveness(value, arms, scope, span);
788 }
789
790 Node::BinaryOp { op, left, right } => {
792 self.check_node(left, scope);
793 self.check_node(right, scope);
794 let lt = self.infer_type(left, scope);
796 let rt = self.infer_type(right, scope);
797 if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (<, &rt) {
798 match op.as_str() {
799 "-" | "*" | "/" | "%" => {
800 let numeric = ["int", "float"];
801 if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
802 self.warning_at(
803 format!(
804 "Operator '{op}' may not be valid for types {} and {}",
805 l, r
806 ),
807 span,
808 );
809 }
810 }
811 "+" => {
812 let valid = ["int", "float", "string", "list", "dict"];
814 if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
815 self.warning_at(
816 format!(
817 "Operator '+' may not be valid for types {} and {}",
818 l, r
819 ),
820 span,
821 );
822 }
823 }
824 _ => {}
825 }
826 }
827 }
828 Node::UnaryOp { operand, .. } => {
829 self.check_node(operand, scope);
830 }
831 Node::MethodCall {
832 object,
833 method,
834 args,
835 ..
836 }
837 | Node::OptionalMethodCall {
838 object,
839 method,
840 args,
841 ..
842 } => {
843 self.check_node(object, scope);
844 for arg in args {
845 self.check_node(arg, scope);
846 }
847 if let Some(TypeExpr::Named(type_name)) = self.infer_type(object, scope) {
851 if scope.is_generic_type_param(&type_name) {
852 if let Some(iface_name) = scope.get_where_constraint(&type_name) {
853 if let Some(iface_methods) = scope.get_interface(iface_name) {
854 let has_method = iface_methods.iter().any(|m| m.name == *method);
855 if !has_method {
856 self.warning_at(
857 format!(
858 "Method '{}' not found in interface '{}' (constraint on '{}')",
859 method, iface_name, type_name
860 ),
861 span,
862 );
863 }
864 }
865 }
866 }
867 }
868 }
869 Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
870 self.check_node(object, scope);
871 }
872 Node::SubscriptAccess { object, index } => {
873 self.check_node(object, scope);
874 self.check_node(index, scope);
875 }
876 Node::SliceAccess { object, start, end } => {
877 self.check_node(object, scope);
878 if let Some(s) = start {
879 self.check_node(s, scope);
880 }
881 if let Some(e) = end {
882 self.check_node(e, scope);
883 }
884 }
885
886 _ => {}
888 }
889 }
890
891 fn check_fn_body(
892 &mut self,
893 type_params: &[TypeParam],
894 params: &[TypedParam],
895 return_type: &Option<TypeExpr>,
896 body: &[SNode],
897 where_clauses: &[WhereClause],
898 ) {
899 let mut fn_scope = self.scope.child();
900 for tp in type_params {
903 fn_scope.generic_type_params.insert(tp.name.clone());
904 }
905 for wc in where_clauses {
907 fn_scope
908 .where_constraints
909 .insert(wc.type_name.clone(), wc.bound.clone());
910 }
911 for param in params {
912 fn_scope.define_var(¶m.name, param.type_expr.clone());
913 if let Some(default) = ¶m.default_value {
914 self.check_node(default, &mut fn_scope);
915 }
916 }
917 self.check_block(body, &mut fn_scope);
918
919 if let Some(ret_type) = return_type {
921 for stmt in body {
922 self.check_return_type(stmt, ret_type, &fn_scope);
923 }
924 }
925 }
926
927 fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
928 let span = snode.span;
929 match &snode.node {
930 Node::ReturnStmt { value: Some(val) } => {
931 let inferred = self.infer_type(val, scope);
932 if let Some(actual) = &inferred {
933 if !self.types_compatible(expected, actual, scope) {
934 self.error_at(
935 format!(
936 "Return type mismatch: expected {}, got {}",
937 format_type(expected),
938 format_type(actual)
939 ),
940 span,
941 );
942 }
943 }
944 }
945 Node::IfElse {
946 then_body,
947 else_body,
948 ..
949 } => {
950 for stmt in then_body {
951 self.check_return_type(stmt, expected, scope);
952 }
953 if let Some(else_body) = else_body {
954 for stmt in else_body {
955 self.check_return_type(stmt, expected, scope);
956 }
957 }
958 }
959 _ => {}
960 }
961 }
962
963 fn satisfies_interface(
969 &self,
970 type_name: &str,
971 interface_name: &str,
972 scope: &TypeScope,
973 ) -> bool {
974 let interface_methods = match scope.get_interface(interface_name) {
975 Some(methods) => methods,
976 None => return false,
977 };
978 let impl_methods = match scope.get_impl_methods(type_name) {
979 Some(methods) => methods,
980 None => return interface_methods.is_empty(),
981 };
982 interface_methods.iter().all(|iface_method| {
983 let iface_params: Vec<_> = iface_method
984 .params
985 .iter()
986 .filter(|p| p.name != "self")
987 .collect();
988 let iface_param_count = iface_params.len();
989 impl_methods.iter().any(|impl_method| {
990 if impl_method.name != iface_method.name
991 || impl_method.param_count != iface_param_count
992 {
993 return false;
994 }
995 for (i, iface_param) in iface_params.iter().enumerate() {
997 if let (Some(expected), Some(actual)) = (
998 &iface_param.type_expr,
999 impl_method.param_types.get(i).and_then(|t| t.as_ref()),
1000 ) {
1001 if !self.types_compatible(expected, actual, scope) {
1002 return false;
1003 }
1004 }
1005 }
1006 if let (Some(expected_ret), Some(actual_ret)) =
1008 (&iface_method.return_type, &impl_method.return_type)
1009 {
1010 if !self.types_compatible(expected_ret, actual_ret, scope) {
1011 return false;
1012 }
1013 }
1014 true
1015 })
1016 })
1017 }
1018
1019 fn extract_type_bindings(
1022 param_type: &TypeExpr,
1023 arg_type: &TypeExpr,
1024 type_params: &std::collections::BTreeSet<String>,
1025 bindings: &mut BTreeMap<String, String>,
1026 ) {
1027 match (param_type, arg_type) {
1028 (TypeExpr::Named(param_name), TypeExpr::Named(concrete))
1030 if type_params.contains(param_name) =>
1031 {
1032 bindings
1033 .entry(param_name.clone())
1034 .or_insert(concrete.clone());
1035 }
1036 (TypeExpr::List(p_inner), TypeExpr::List(a_inner)) => {
1038 Self::extract_type_bindings(p_inner, a_inner, type_params, bindings);
1039 }
1040 (TypeExpr::DictType(pk, pv), TypeExpr::DictType(ak, av)) => {
1042 Self::extract_type_bindings(pk, ak, type_params, bindings);
1043 Self::extract_type_bindings(pv, av, type_params, bindings);
1044 }
1045 _ => {}
1046 }
1047 }
1048
1049 fn extract_nil_narrowing(
1050 condition: &SNode,
1051 scope: &TypeScope,
1052 ) -> Option<(String, InferredType)> {
1053 if let Node::BinaryOp { op, left, right } = &condition.node {
1054 if op == "!=" {
1055 let (var_node, nil_node) = if matches!(right.node, Node::NilLiteral) {
1057 (left, right)
1058 } else if matches!(left.node, Node::NilLiteral) {
1059 (right, left)
1060 } else {
1061 return None;
1062 };
1063 let _ = nil_node;
1064 if let Node::Identifier(name) = &var_node.node {
1065 if let Some(Some(TypeExpr::Union(members))) = scope.get_var(name) {
1067 let narrowed: Vec<TypeExpr> = members
1068 .iter()
1069 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1070 .cloned()
1071 .collect();
1072 return if narrowed.len() == 1 {
1073 Some((name.clone(), Some(narrowed.into_iter().next().unwrap())))
1074 } else if narrowed.is_empty() {
1075 None
1076 } else {
1077 Some((name.clone(), Some(TypeExpr::Union(narrowed))))
1078 };
1079 }
1080 }
1081 }
1082 }
1083 None
1084 }
1085
1086 fn check_match_exhaustiveness(
1087 &mut self,
1088 value: &SNode,
1089 arms: &[MatchArm],
1090 scope: &TypeScope,
1091 span: Span,
1092 ) {
1093 let enum_name = match &value.node {
1095 Node::PropertyAccess { object, property } if property == "variant" => {
1096 match self.infer_type(object, scope) {
1098 Some(TypeExpr::Named(name)) => {
1099 if scope.get_enum(&name).is_some() {
1100 Some(name)
1101 } else {
1102 None
1103 }
1104 }
1105 _ => None,
1106 }
1107 }
1108 _ => {
1109 match self.infer_type(value, scope) {
1111 Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
1112 _ => None,
1113 }
1114 }
1115 };
1116
1117 let Some(enum_name) = enum_name else {
1118 return;
1119 };
1120 let Some(variants) = scope.get_enum(&enum_name) else {
1121 return;
1122 };
1123
1124 let mut covered: Vec<String> = Vec::new();
1126 let mut has_wildcard = false;
1127
1128 for arm in arms {
1129 match &arm.pattern.node {
1130 Node::StringLiteral(s) => covered.push(s.clone()),
1132 Node::Identifier(name) if name == "_" || !variants.contains(name) => {
1134 has_wildcard = true;
1135 }
1136 Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
1138 Node::PropertyAccess { property, .. } => covered.push(property.clone()),
1140 _ => {
1141 has_wildcard = true;
1143 }
1144 }
1145 }
1146
1147 if has_wildcard {
1148 return;
1149 }
1150
1151 let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
1152 if !missing.is_empty() {
1153 let missing_str = missing
1154 .iter()
1155 .map(|s| format!("\"{}\"", s))
1156 .collect::<Vec<_>>()
1157 .join(", ");
1158 self.warning_at(
1159 format!(
1160 "Non-exhaustive match on enum {}: missing variants {}",
1161 enum_name, missing_str
1162 ),
1163 span,
1164 );
1165 }
1166 }
1167
1168 fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
1169 let has_spread = args.iter().any(|a| matches!(&a.node, Node::Spread(_)));
1171 if let Some(sig) = scope.get_fn(name).cloned() {
1172 if !has_spread
1173 && !is_builtin(name)
1174 && (args.len() < sig.required_params || args.len() > sig.params.len())
1175 {
1176 let expected = if sig.required_params == sig.params.len() {
1177 format!("{}", sig.params.len())
1178 } else {
1179 format!("{}-{}", sig.required_params, sig.params.len())
1180 };
1181 self.warning_at(
1182 format!(
1183 "Function '{}' expects {} arguments, got {}",
1184 name,
1185 expected,
1186 args.len()
1187 ),
1188 span,
1189 );
1190 }
1191 let call_scope = if sig.type_param_names.is_empty() {
1194 scope.clone()
1195 } else {
1196 let mut s = scope.child();
1197 for tp_name in &sig.type_param_names {
1198 s.generic_type_params.insert(tp_name.clone());
1199 }
1200 s
1201 };
1202 for (i, (arg, (param_name, param_type))) in
1203 args.iter().zip(sig.params.iter()).enumerate()
1204 {
1205 if let Some(expected) = param_type {
1206 let actual = self.infer_type(arg, scope);
1207 if let Some(actual) = &actual {
1208 if !self.types_compatible(expected, actual, &call_scope) {
1209 self.error_at(
1210 format!(
1211 "Argument {} ('{}'): expected {}, got {}",
1212 i + 1,
1213 param_name,
1214 format_type(expected),
1215 format_type(actual)
1216 ),
1217 arg.span,
1218 );
1219 }
1220 }
1221 }
1222 }
1223 if !sig.where_clauses.is_empty() {
1225 let mut type_bindings: BTreeMap<String, String> = BTreeMap::new();
1228 let type_param_set: std::collections::BTreeSet<String> =
1229 sig.type_param_names.iter().cloned().collect();
1230 for (arg, (_param_name, param_type)) in args.iter().zip(sig.params.iter()) {
1231 if let Some(param_ty) = param_type {
1232 if let Some(arg_ty) = self.infer_type(arg, scope) {
1233 Self::extract_type_bindings(
1234 param_ty,
1235 &arg_ty,
1236 &type_param_set,
1237 &mut type_bindings,
1238 );
1239 }
1240 }
1241 }
1242 for (type_param, bound) in &sig.where_clauses {
1243 if let Some(concrete_type) = type_bindings.get(type_param) {
1244 if !self.satisfies_interface(concrete_type, bound, scope) {
1245 self.warning_at(
1246 format!(
1247 "Type '{}' does not satisfy interface '{}': \
1248 required by constraint `where {}: {}`",
1249 concrete_type, bound, type_param, bound
1250 ),
1251 span,
1252 );
1253 }
1254 }
1255 }
1256 }
1257 }
1258 for arg in args {
1260 self.check_node(arg, scope);
1261 }
1262 }
1263
1264 fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
1266 match &snode.node {
1267 Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
1268 Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
1269 Node::StringLiteral(_) | Node::InterpolatedString(_) => {
1270 Some(TypeExpr::Named("string".into()))
1271 }
1272 Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
1273 Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
1274 Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
1275 Node::DictLiteral(entries) => {
1276 let mut fields = Vec::new();
1278 let mut all_string_keys = true;
1279 for entry in entries {
1280 if let Node::StringLiteral(key) = &entry.key.node {
1281 let val_type = self
1282 .infer_type(&entry.value, scope)
1283 .unwrap_or(TypeExpr::Named("nil".into()));
1284 fields.push(ShapeField {
1285 name: key.clone(),
1286 type_expr: val_type,
1287 optional: false,
1288 });
1289 } else {
1290 all_string_keys = false;
1291 break;
1292 }
1293 }
1294 if all_string_keys && !fields.is_empty() {
1295 Some(TypeExpr::Shape(fields))
1296 } else {
1297 Some(TypeExpr::Named("dict".into()))
1298 }
1299 }
1300 Node::Closure { params, body, .. } => {
1301 let all_typed = params.iter().all(|p| p.type_expr.is_some());
1303 if all_typed && !params.is_empty() {
1304 let param_types: Vec<TypeExpr> =
1305 params.iter().filter_map(|p| p.type_expr.clone()).collect();
1306 let ret = body.last().and_then(|last| self.infer_type(last, scope));
1308 if let Some(ret_type) = ret {
1309 return Some(TypeExpr::FnType {
1310 params: param_types,
1311 return_type: Box::new(ret_type),
1312 });
1313 }
1314 }
1315 Some(TypeExpr::Named("closure".into()))
1316 }
1317
1318 Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
1319
1320 Node::FunctionCall { name, .. } => {
1321 if scope.get_struct(name).is_some() {
1323 return Some(TypeExpr::Named(name.clone()));
1324 }
1325 if let Some(sig) = scope.get_fn(name) {
1327 return sig.return_type.clone();
1328 }
1329 builtin_return_type(name)
1331 }
1332
1333 Node::BinaryOp { op, left, right } => {
1334 let lt = self.infer_type(left, scope);
1335 let rt = self.infer_type(right, scope);
1336 infer_binary_op_type(op, <, &rt)
1337 }
1338
1339 Node::UnaryOp { op, operand } => {
1340 let t = self.infer_type(operand, scope);
1341 match op.as_str() {
1342 "!" => Some(TypeExpr::Named("bool".into())),
1343 "-" => t, _ => None,
1345 }
1346 }
1347
1348 Node::Ternary {
1349 true_expr,
1350 false_expr,
1351 ..
1352 } => {
1353 let tt = self.infer_type(true_expr, scope);
1354 let ft = self.infer_type(false_expr, scope);
1355 match (&tt, &ft) {
1356 (Some(a), Some(b)) if a == b => tt,
1357 (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
1358 (Some(_), None) => tt,
1359 (None, Some(_)) => ft,
1360 (None, None) => None,
1361 }
1362 }
1363
1364 Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1365
1366 Node::PropertyAccess { object, property } => {
1367 if let Node::Identifier(name) = &object.node {
1369 if scope.get_enum(name).is_some() {
1370 return Some(TypeExpr::Named(name.clone()));
1371 }
1372 }
1373 if property == "variant" {
1375 let obj_type = self.infer_type(object, scope);
1376 if let Some(TypeExpr::Named(name)) = &obj_type {
1377 if scope.get_enum(name).is_some() {
1378 return Some(TypeExpr::Named("string".into()));
1379 }
1380 }
1381 }
1382 let obj_type = self.infer_type(object, scope);
1384 if let Some(TypeExpr::Shape(fields)) = &obj_type {
1385 if let Some(field) = fields.iter().find(|f| f.name == *property) {
1386 return Some(field.type_expr.clone());
1387 }
1388 }
1389 None
1390 }
1391
1392 Node::SubscriptAccess { object, index } => {
1393 let obj_type = self.infer_type(object, scope);
1394 match &obj_type {
1395 Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1396 Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1397 Some(TypeExpr::Shape(fields)) => {
1398 if let Node::StringLiteral(key) = &index.node {
1400 fields
1401 .iter()
1402 .find(|f| &f.name == key)
1403 .map(|f| f.type_expr.clone())
1404 } else {
1405 None
1406 }
1407 }
1408 Some(TypeExpr::Named(n)) if n == "list" => None,
1409 Some(TypeExpr::Named(n)) if n == "dict" => None,
1410 Some(TypeExpr::Named(n)) if n == "string" => {
1411 Some(TypeExpr::Named("string".into()))
1412 }
1413 _ => None,
1414 }
1415 }
1416 Node::SliceAccess { object, .. } => {
1417 let obj_type = self.infer_type(object, scope);
1419 match &obj_type {
1420 Some(TypeExpr::List(_)) => obj_type,
1421 Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1422 Some(TypeExpr::Named(n)) if n == "string" => {
1423 Some(TypeExpr::Named("string".into()))
1424 }
1425 _ => None,
1426 }
1427 }
1428 Node::MethodCall { object, method, .. }
1429 | Node::OptionalMethodCall { object, method, .. } => {
1430 let obj_type = self.infer_type(object, scope);
1431 let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1432 || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1433 match method.as_str() {
1434 "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1436 Some(TypeExpr::Named("bool".into()))
1437 }
1438 "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1440 "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1442 | "pad_left" | "pad_right" | "repeat" | "join" => {
1443 Some(TypeExpr::Named("string".into()))
1444 }
1445 "split" | "chars" => Some(TypeExpr::Named("list".into())),
1446 "filter" => {
1448 if is_dict {
1449 Some(TypeExpr::Named("dict".into()))
1450 } else {
1451 Some(TypeExpr::Named("list".into()))
1452 }
1453 }
1454 "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1456 "reduce" | "find" | "first" | "last" => None,
1457 "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1459 "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1460 "to_string" => Some(TypeExpr::Named("string".into())),
1462 "to_int" => Some(TypeExpr::Named("int".into())),
1463 "to_float" => Some(TypeExpr::Named("float".into())),
1464 _ => None,
1465 }
1466 }
1467
1468 Node::TryOperator { operand } => {
1470 match self.infer_type(operand, scope) {
1471 Some(TypeExpr::Named(name)) if name == "Result" => None, _ => None,
1473 }
1474 }
1475
1476 _ => None,
1477 }
1478 }
1479
1480 fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1482 if let TypeExpr::Named(name) = expected {
1484 if scope.is_generic_type_param(name) {
1485 return true;
1486 }
1487 }
1488 if let TypeExpr::Named(name) = actual {
1489 if scope.is_generic_type_param(name) {
1490 return true;
1491 }
1492 }
1493 let expected = self.resolve_alias(expected, scope);
1494 let actual = self.resolve_alias(actual, scope);
1495
1496 if let TypeExpr::Named(iface_name) = &expected {
1499 if scope.get_interface(iface_name).is_some() {
1500 if let TypeExpr::Named(type_name) = &actual {
1501 return self.satisfies_interface(type_name, iface_name, scope);
1502 }
1503 return false;
1504 }
1505 }
1506
1507 match (&expected, &actual) {
1508 (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1509 (TypeExpr::Union(members), actual_type) => members
1510 .iter()
1511 .any(|m| self.types_compatible(m, actual_type, scope)),
1512 (expected_type, TypeExpr::Union(members)) => members
1513 .iter()
1514 .all(|m| self.types_compatible(expected_type, m, scope)),
1515 (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1516 (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1517 (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1518 if expected_field.optional {
1519 return true;
1520 }
1521 af.iter().any(|actual_field| {
1522 actual_field.name == expected_field.name
1523 && self.types_compatible(
1524 &expected_field.type_expr,
1525 &actual_field.type_expr,
1526 scope,
1527 )
1528 })
1529 }),
1530 (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1532 let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1533 keys_ok
1534 && af
1535 .iter()
1536 .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1537 }
1538 (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1540 (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1541 self.types_compatible(expected_inner, actual_inner, scope)
1542 }
1543 (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1544 (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1545 (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1546 self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1547 }
1548 (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1549 (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1550 (
1552 TypeExpr::FnType {
1553 params: ep,
1554 return_type: er,
1555 },
1556 TypeExpr::FnType {
1557 params: ap,
1558 return_type: ar,
1559 },
1560 ) => {
1561 ep.len() == ap.len()
1562 && ep
1563 .iter()
1564 .zip(ap.iter())
1565 .all(|(e, a)| self.types_compatible(e, a, scope))
1566 && self.types_compatible(er, ar, scope)
1567 }
1568 (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1570 (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1571 _ => false,
1572 }
1573 }
1574
1575 fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1576 if let TypeExpr::Named(name) = ty {
1577 if let Some(resolved) = scope.resolve_type(name) {
1578 return resolved.clone();
1579 }
1580 }
1581 ty.clone()
1582 }
1583
1584 fn error_at(&mut self, message: String, span: Span) {
1585 self.diagnostics.push(TypeDiagnostic {
1586 message,
1587 severity: DiagnosticSeverity::Error,
1588 span: Some(span),
1589 help: None,
1590 });
1591 }
1592
1593 #[allow(dead_code)]
1594 fn error_at_with_help(&mut self, message: String, span: Span, help: String) {
1595 self.diagnostics.push(TypeDiagnostic {
1596 message,
1597 severity: DiagnosticSeverity::Error,
1598 span: Some(span),
1599 help: Some(help),
1600 });
1601 }
1602
1603 fn warning_at(&mut self, message: String, span: Span) {
1604 self.diagnostics.push(TypeDiagnostic {
1605 message,
1606 severity: DiagnosticSeverity::Warning,
1607 span: Some(span),
1608 help: None,
1609 });
1610 }
1611
1612 #[allow(dead_code)]
1613 fn warning_at_with_help(&mut self, message: String, span: Span, help: String) {
1614 self.diagnostics.push(TypeDiagnostic {
1615 message,
1616 severity: DiagnosticSeverity::Warning,
1617 span: Some(span),
1618 help: Some(help),
1619 });
1620 }
1621}
1622
1623impl Default for TypeChecker {
1624 fn default() -> Self {
1625 Self::new()
1626 }
1627}
1628
1629fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1631 match op {
1632 "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" | "in" | "not_in" => {
1633 Some(TypeExpr::Named("bool".into()))
1634 }
1635 "+" => match (left, right) {
1636 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1637 match (l.as_str(), r.as_str()) {
1638 ("int", "int") => Some(TypeExpr::Named("int".into())),
1639 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1640 ("string", _) => Some(TypeExpr::Named("string".into())),
1641 ("list", "list") => Some(TypeExpr::Named("list".into())),
1642 ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1643 _ => Some(TypeExpr::Named("string".into())),
1644 }
1645 }
1646 _ => None,
1647 },
1648 "-" | "*" | "/" | "%" => match (left, right) {
1649 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1650 match (l.as_str(), r.as_str()) {
1651 ("int", "int") => Some(TypeExpr::Named("int".into())),
1652 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1653 _ => None,
1654 }
1655 }
1656 _ => None,
1657 },
1658 "??" => match (left, right) {
1659 (Some(TypeExpr::Union(members)), _) => {
1660 let non_nil: Vec<_> = members
1661 .iter()
1662 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1663 .cloned()
1664 .collect();
1665 if non_nil.len() == 1 {
1666 Some(non_nil[0].clone())
1667 } else if non_nil.is_empty() {
1668 right.clone()
1669 } else {
1670 Some(TypeExpr::Union(non_nil))
1671 }
1672 }
1673 _ => right.clone(),
1674 },
1675 "|>" => None,
1676 _ => None,
1677 }
1678}
1679
1680pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1685 if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1686 let mut details = Vec::new();
1687 for field in ef {
1688 if field.optional {
1689 continue;
1690 }
1691 match af.iter().find(|f| f.name == field.name) {
1692 None => details.push(format!(
1693 "missing field '{}' ({})",
1694 field.name,
1695 format_type(&field.type_expr)
1696 )),
1697 Some(actual_field) => {
1698 let e_str = format_type(&field.type_expr);
1699 let a_str = format_type(&actual_field.type_expr);
1700 if e_str != a_str {
1701 details.push(format!(
1702 "field '{}' has type {}, expected {}",
1703 field.name, a_str, e_str
1704 ));
1705 }
1706 }
1707 }
1708 }
1709 if details.is_empty() {
1710 None
1711 } else {
1712 Some(details.join("; "))
1713 }
1714 } else {
1715 None
1716 }
1717}
1718
1719pub fn format_type(ty: &TypeExpr) -> String {
1720 match ty {
1721 TypeExpr::Named(n) => n.clone(),
1722 TypeExpr::Union(types) => types
1723 .iter()
1724 .map(format_type)
1725 .collect::<Vec<_>>()
1726 .join(" | "),
1727 TypeExpr::Shape(fields) => {
1728 let inner: Vec<String> = fields
1729 .iter()
1730 .map(|f| {
1731 let opt = if f.optional { "?" } else { "" };
1732 format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1733 })
1734 .collect();
1735 format!("{{{}}}", inner.join(", "))
1736 }
1737 TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1738 TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1739 TypeExpr::FnType {
1740 params,
1741 return_type,
1742 } => {
1743 let params_str = params
1744 .iter()
1745 .map(format_type)
1746 .collect::<Vec<_>>()
1747 .join(", ");
1748 format!("fn({}) -> {}", params_str, format_type(return_type))
1749 }
1750 }
1751}
1752
1753#[cfg(test)]
1754mod tests {
1755 use super::*;
1756 use crate::Parser;
1757 use harn_lexer::Lexer;
1758
1759 fn check_source(source: &str) -> Vec<TypeDiagnostic> {
1760 let mut lexer = Lexer::new(source);
1761 let tokens = lexer.tokenize().unwrap();
1762 let mut parser = Parser::new(tokens);
1763 let program = parser.parse().unwrap();
1764 TypeChecker::new().check(&program)
1765 }
1766
1767 fn errors(source: &str) -> Vec<String> {
1768 check_source(source)
1769 .into_iter()
1770 .filter(|d| d.severity == DiagnosticSeverity::Error)
1771 .map(|d| d.message)
1772 .collect()
1773 }
1774
1775 #[test]
1776 fn test_no_errors_for_untyped_code() {
1777 let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
1778 assert!(errs.is_empty());
1779 }
1780
1781 #[test]
1782 fn test_correct_typed_let() {
1783 let errs = errors("pipeline t(task) { let x: int = 42 }");
1784 assert!(errs.is_empty());
1785 }
1786
1787 #[test]
1788 fn test_type_mismatch_let() {
1789 let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
1790 assert_eq!(errs.len(), 1);
1791 assert!(errs[0].contains("Type mismatch"));
1792 assert!(errs[0].contains("int"));
1793 assert!(errs[0].contains("string"));
1794 }
1795
1796 #[test]
1797 fn test_correct_typed_fn() {
1798 let errs = errors(
1799 "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
1800 );
1801 assert!(errs.is_empty());
1802 }
1803
1804 #[test]
1805 fn test_fn_arg_type_mismatch() {
1806 let errs = errors(
1807 r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
1808add("hello", 2) }"#,
1809 );
1810 assert_eq!(errs.len(), 1);
1811 assert!(errs[0].contains("Argument 1"));
1812 assert!(errs[0].contains("expected int"));
1813 }
1814
1815 #[test]
1816 fn test_return_type_mismatch() {
1817 let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
1818 assert_eq!(errs.len(), 1);
1819 assert!(errs[0].contains("Return type mismatch"));
1820 }
1821
1822 #[test]
1823 fn test_union_type_compatible() {
1824 let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
1825 assert!(errs.is_empty());
1826 }
1827
1828 #[test]
1829 fn test_union_type_mismatch() {
1830 let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
1831 assert_eq!(errs.len(), 1);
1832 assert!(errs[0].contains("Type mismatch"));
1833 }
1834
1835 #[test]
1836 fn test_type_inference_propagation() {
1837 let errs = errors(
1838 r#"pipeline t(task) {
1839 fn add(a: int, b: int) -> int { return a + b }
1840 let result: string = add(1, 2)
1841}"#,
1842 );
1843 assert_eq!(errs.len(), 1);
1844 assert!(errs[0].contains("Type mismatch"));
1845 assert!(errs[0].contains("string"));
1846 assert!(errs[0].contains("int"));
1847 }
1848
1849 #[test]
1850 fn test_builtin_return_type_inference() {
1851 let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
1852 assert_eq!(errs.len(), 1);
1853 assert!(errs[0].contains("string"));
1854 assert!(errs[0].contains("int"));
1855 }
1856
1857 #[test]
1858 fn test_binary_op_type_inference() {
1859 let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
1860 assert_eq!(errs.len(), 1);
1861 }
1862
1863 #[test]
1864 fn test_comparison_returns_bool() {
1865 let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
1866 assert!(errs.is_empty());
1867 }
1868
1869 #[test]
1870 fn test_int_float_promotion() {
1871 let errs = errors("pipeline t(task) { let x: float = 42 }");
1872 assert!(errs.is_empty());
1873 }
1874
1875 #[test]
1876 fn test_untyped_code_no_errors() {
1877 let errs = errors(
1878 r#"pipeline t(task) {
1879 fn process(data) {
1880 let result = data + " processed"
1881 return result
1882 }
1883 log(process("hello"))
1884}"#,
1885 );
1886 assert!(errs.is_empty());
1887 }
1888
1889 #[test]
1890 fn test_type_alias() {
1891 let errs = errors(
1892 r#"pipeline t(task) {
1893 type Name = string
1894 let x: Name = "hello"
1895}"#,
1896 );
1897 assert!(errs.is_empty());
1898 }
1899
1900 #[test]
1901 fn test_type_alias_mismatch() {
1902 let errs = errors(
1903 r#"pipeline t(task) {
1904 type Name = string
1905 let x: Name = 42
1906}"#,
1907 );
1908 assert_eq!(errs.len(), 1);
1909 }
1910
1911 #[test]
1912 fn test_assignment_type_check() {
1913 let errs = errors(
1914 r#"pipeline t(task) {
1915 var x: int = 0
1916 x = "hello"
1917}"#,
1918 );
1919 assert_eq!(errs.len(), 1);
1920 assert!(errs[0].contains("cannot assign string"));
1921 }
1922
1923 #[test]
1924 fn test_covariance_int_to_float_in_fn() {
1925 let errs = errors(
1926 "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
1927 );
1928 assert!(errs.is_empty());
1929 }
1930
1931 #[test]
1932 fn test_covariance_return_type() {
1933 let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
1934 assert!(errs.is_empty());
1935 }
1936
1937 #[test]
1938 fn test_no_contravariance_float_to_int() {
1939 let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
1940 assert_eq!(errs.len(), 1);
1941 }
1942
1943 fn warnings(source: &str) -> Vec<String> {
1946 check_source(source)
1947 .into_iter()
1948 .filter(|d| d.severity == DiagnosticSeverity::Warning)
1949 .map(|d| d.message)
1950 .collect()
1951 }
1952
1953 #[test]
1954 fn test_exhaustive_match_no_warning() {
1955 let warns = warnings(
1956 r#"pipeline t(task) {
1957 enum Color { Red, Green, Blue }
1958 let c = Color.Red
1959 match c.variant {
1960 "Red" -> { log("r") }
1961 "Green" -> { log("g") }
1962 "Blue" -> { log("b") }
1963 }
1964}"#,
1965 );
1966 let exhaustive_warns: Vec<_> = warns
1967 .iter()
1968 .filter(|w| w.contains("Non-exhaustive"))
1969 .collect();
1970 assert!(exhaustive_warns.is_empty());
1971 }
1972
1973 #[test]
1974 fn test_non_exhaustive_match_warning() {
1975 let warns = warnings(
1976 r#"pipeline t(task) {
1977 enum Color { Red, Green, Blue }
1978 let c = Color.Red
1979 match c.variant {
1980 "Red" -> { log("r") }
1981 "Green" -> { log("g") }
1982 }
1983}"#,
1984 );
1985 let exhaustive_warns: Vec<_> = warns
1986 .iter()
1987 .filter(|w| w.contains("Non-exhaustive"))
1988 .collect();
1989 assert_eq!(exhaustive_warns.len(), 1);
1990 assert!(exhaustive_warns[0].contains("Blue"));
1991 }
1992
1993 #[test]
1994 fn test_non_exhaustive_multiple_missing() {
1995 let warns = warnings(
1996 r#"pipeline t(task) {
1997 enum Status { Active, Inactive, Pending }
1998 let s = Status.Active
1999 match s.variant {
2000 "Active" -> { log("a") }
2001 }
2002}"#,
2003 );
2004 let exhaustive_warns: Vec<_> = warns
2005 .iter()
2006 .filter(|w| w.contains("Non-exhaustive"))
2007 .collect();
2008 assert_eq!(exhaustive_warns.len(), 1);
2009 assert!(exhaustive_warns[0].contains("Inactive"));
2010 assert!(exhaustive_warns[0].contains("Pending"));
2011 }
2012
2013 #[test]
2014 fn test_enum_construct_type_inference() {
2015 let errs = errors(
2016 r#"pipeline t(task) {
2017 enum Color { Red, Green, Blue }
2018 let c: Color = Color.Red
2019}"#,
2020 );
2021 assert!(errs.is_empty());
2022 }
2023
2024 #[test]
2027 fn test_nil_coalescing_strips_nil() {
2028 let errs = errors(
2030 r#"pipeline t(task) {
2031 let x: string | nil = nil
2032 let y: string = x ?? "default"
2033}"#,
2034 );
2035 assert!(errs.is_empty());
2036 }
2037
2038 #[test]
2039 fn test_shape_mismatch_detail_missing_field() {
2040 let errs = errors(
2041 r#"pipeline t(task) {
2042 let x: {name: string, age: int} = {name: "hello"}
2043}"#,
2044 );
2045 assert_eq!(errs.len(), 1);
2046 assert!(
2047 errs[0].contains("missing field 'age'"),
2048 "expected detail about missing field, got: {}",
2049 errs[0]
2050 );
2051 }
2052
2053 #[test]
2054 fn test_shape_mismatch_detail_wrong_type() {
2055 let errs = errors(
2056 r#"pipeline t(task) {
2057 let x: {name: string, age: int} = {name: 42, age: 10}
2058}"#,
2059 );
2060 assert_eq!(errs.len(), 1);
2061 assert!(
2062 errs[0].contains("field 'name' has type int, expected string"),
2063 "expected detail about wrong type, got: {}",
2064 errs[0]
2065 );
2066 }
2067
2068 #[test]
2071 fn test_match_pattern_string_against_int() {
2072 let warns = warnings(
2073 r#"pipeline t(task) {
2074 let x: int = 42
2075 match x {
2076 "hello" -> { log("bad") }
2077 42 -> { log("ok") }
2078 }
2079}"#,
2080 );
2081 let pattern_warns: Vec<_> = warns
2082 .iter()
2083 .filter(|w| w.contains("Match pattern type mismatch"))
2084 .collect();
2085 assert_eq!(pattern_warns.len(), 1);
2086 assert!(pattern_warns[0].contains("matching int against string literal"));
2087 }
2088
2089 #[test]
2090 fn test_match_pattern_int_against_string() {
2091 let warns = warnings(
2092 r#"pipeline t(task) {
2093 let x: string = "hello"
2094 match x {
2095 42 -> { log("bad") }
2096 "hello" -> { log("ok") }
2097 }
2098}"#,
2099 );
2100 let pattern_warns: Vec<_> = warns
2101 .iter()
2102 .filter(|w| w.contains("Match pattern type mismatch"))
2103 .collect();
2104 assert_eq!(pattern_warns.len(), 1);
2105 assert!(pattern_warns[0].contains("matching string against int literal"));
2106 }
2107
2108 #[test]
2109 fn test_match_pattern_bool_against_int() {
2110 let warns = warnings(
2111 r#"pipeline t(task) {
2112 let x: int = 42
2113 match x {
2114 true -> { log("bad") }
2115 42 -> { log("ok") }
2116 }
2117}"#,
2118 );
2119 let pattern_warns: Vec<_> = warns
2120 .iter()
2121 .filter(|w| w.contains("Match pattern type mismatch"))
2122 .collect();
2123 assert_eq!(pattern_warns.len(), 1);
2124 assert!(pattern_warns[0].contains("matching int against bool literal"));
2125 }
2126
2127 #[test]
2128 fn test_match_pattern_float_against_string() {
2129 let warns = warnings(
2130 r#"pipeline t(task) {
2131 let x: string = "hello"
2132 match x {
2133 3.14 -> { log("bad") }
2134 "hello" -> { log("ok") }
2135 }
2136}"#,
2137 );
2138 let pattern_warns: Vec<_> = warns
2139 .iter()
2140 .filter(|w| w.contains("Match pattern type mismatch"))
2141 .collect();
2142 assert_eq!(pattern_warns.len(), 1);
2143 assert!(pattern_warns[0].contains("matching string against float literal"));
2144 }
2145
2146 #[test]
2147 fn test_match_pattern_int_against_float_ok() {
2148 let warns = warnings(
2150 r#"pipeline t(task) {
2151 let x: float = 3.14
2152 match x {
2153 42 -> { log("ok") }
2154 _ -> { log("default") }
2155 }
2156}"#,
2157 );
2158 let pattern_warns: Vec<_> = warns
2159 .iter()
2160 .filter(|w| w.contains("Match pattern type mismatch"))
2161 .collect();
2162 assert!(pattern_warns.is_empty());
2163 }
2164
2165 #[test]
2166 fn test_match_pattern_float_against_int_ok() {
2167 let warns = warnings(
2169 r#"pipeline t(task) {
2170 let x: int = 42
2171 match x {
2172 3.14 -> { log("close") }
2173 _ -> { log("default") }
2174 }
2175}"#,
2176 );
2177 let pattern_warns: Vec<_> = warns
2178 .iter()
2179 .filter(|w| w.contains("Match pattern type mismatch"))
2180 .collect();
2181 assert!(pattern_warns.is_empty());
2182 }
2183
2184 #[test]
2185 fn test_match_pattern_correct_types_no_warning() {
2186 let warns = warnings(
2187 r#"pipeline t(task) {
2188 let x: int = 42
2189 match x {
2190 1 -> { log("one") }
2191 2 -> { log("two") }
2192 _ -> { log("other") }
2193 }
2194}"#,
2195 );
2196 let pattern_warns: Vec<_> = warns
2197 .iter()
2198 .filter(|w| w.contains("Match pattern type mismatch"))
2199 .collect();
2200 assert!(pattern_warns.is_empty());
2201 }
2202
2203 #[test]
2204 fn test_match_pattern_wildcard_no_warning() {
2205 let warns = warnings(
2206 r#"pipeline t(task) {
2207 let x: int = 42
2208 match x {
2209 _ -> { log("catch all") }
2210 }
2211}"#,
2212 );
2213 let pattern_warns: Vec<_> = warns
2214 .iter()
2215 .filter(|w| w.contains("Match pattern type mismatch"))
2216 .collect();
2217 assert!(pattern_warns.is_empty());
2218 }
2219
2220 #[test]
2221 fn test_match_pattern_untyped_no_warning() {
2222 let warns = warnings(
2224 r#"pipeline t(task) {
2225 let x = some_unknown_fn()
2226 match x {
2227 "hello" -> { log("string") }
2228 42 -> { log("int") }
2229 }
2230}"#,
2231 );
2232 let pattern_warns: Vec<_> = warns
2233 .iter()
2234 .filter(|w| w.contains("Match pattern type mismatch"))
2235 .collect();
2236 assert!(pattern_warns.is_empty());
2237 }
2238}