1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9 pub message: String,
10 pub severity: DiagnosticSeverity,
11 pub span: Option<Span>,
12}
13
/// Severity level attached to a [`TypeDiagnostic`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticSeverity {
    /// Error-level finding.
    Error,
    /// Warning-level finding.
    Warning,
}
19
20type InferredType = Option<TypeExpr>;
22
23#[derive(Debug, Clone)]
25struct TypeScope {
26 vars: BTreeMap<String, InferredType>,
28 functions: BTreeMap<String, FnSignature>,
30 type_aliases: BTreeMap<String, TypeExpr>,
32 enums: BTreeMap<String, Vec<String>>,
34 interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36 structs: BTreeMap<String, Vec<(String, InferredType)>>,
38 impl_methods: BTreeMap<String, Vec<ImplMethodSig>>,
40 generic_type_params: std::collections::BTreeSet<String>,
42 parent: Option<Box<TypeScope>>,
43}
44
/// The parts of an `impl` method's signature used for interface conformance checks.
#[derive(Clone, Debug)]
struct ImplMethodSig {
    /// Method name.
    name: String,
    /// Number of parameters, excluding any parameter named `self`.
    param_count: usize,
}
52
53#[derive(Debug, Clone)]
54struct FnSignature {
55 params: Vec<(String, InferredType)>,
56 return_type: InferredType,
57 type_param_names: Vec<String>,
59 required_params: usize,
61 where_clauses: Vec<(String, String)>,
63}
64
65impl TypeScope {
66 fn new() -> Self {
67 Self {
68 vars: BTreeMap::new(),
69 functions: BTreeMap::new(),
70 type_aliases: BTreeMap::new(),
71 enums: BTreeMap::new(),
72 interfaces: BTreeMap::new(),
73 structs: BTreeMap::new(),
74 impl_methods: BTreeMap::new(),
75 generic_type_params: std::collections::BTreeSet::new(),
76 parent: None,
77 }
78 }
79
80 fn child(&self) -> Self {
81 Self {
82 vars: BTreeMap::new(),
83 functions: BTreeMap::new(),
84 type_aliases: BTreeMap::new(),
85 enums: BTreeMap::new(),
86 interfaces: BTreeMap::new(),
87 structs: BTreeMap::new(),
88 impl_methods: BTreeMap::new(),
89 generic_type_params: std::collections::BTreeSet::new(),
90 parent: Some(Box::new(self.clone())),
91 }
92 }
93
94 fn get_var(&self, name: &str) -> Option<&InferredType> {
95 self.vars
96 .get(name)
97 .or_else(|| self.parent.as_ref()?.get_var(name))
98 }
99
100 fn get_fn(&self, name: &str) -> Option<&FnSignature> {
101 self.functions
102 .get(name)
103 .or_else(|| self.parent.as_ref()?.get_fn(name))
104 }
105
106 fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
107 self.type_aliases
108 .get(name)
109 .or_else(|| self.parent.as_ref()?.resolve_type(name))
110 }
111
112 fn is_generic_type_param(&self, name: &str) -> bool {
113 self.generic_type_params.contains(name)
114 || self
115 .parent
116 .as_ref()
117 .is_some_and(|p| p.is_generic_type_param(name))
118 }
119
120 fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
121 self.enums
122 .get(name)
123 .or_else(|| self.parent.as_ref()?.get_enum(name))
124 }
125
126 fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
127 self.interfaces
128 .get(name)
129 .or_else(|| self.parent.as_ref()?.get_interface(name))
130 }
131
132 fn get_struct(&self, name: &str) -> Option<&Vec<(String, InferredType)>> {
133 self.structs
134 .get(name)
135 .or_else(|| self.parent.as_ref()?.get_struct(name))
136 }
137
138 fn get_impl_methods(&self, name: &str) -> Option<&Vec<ImplMethodSig>> {
139 self.impl_methods
140 .get(name)
141 .or_else(|| self.parent.as_ref()?.get_impl_methods(name))
142 }
143
144 fn define_var(&mut self, name: &str, ty: InferredType) {
145 self.vars.insert(name.to_string(), ty);
146 }
147
148 fn define_fn(&mut self, name: &str, sig: FnSignature) {
149 self.functions.insert(name.to_string(), sig);
150 }
151}
152
153fn builtin_return_type(name: &str) -> InferredType {
155 match name {
156 "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
157 | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
158 Some(TypeExpr::Named("nil".into()))
159 }
160 "type_of"
161 | "to_string"
162 | "json_stringify"
163 | "read_file"
164 | "http_get"
165 | "http_post"
166 | "llm_call"
167 | "regex_replace"
168 | "path_join"
169 | "temp_dir"
170 | "date_format"
171 | "format"
172 | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
173 "to_int" | "timer_end" | "elapsed" | "sign" => Some(TypeExpr::Named("int".into())),
174 "to_float" | "timestamp" | "date_parse" | "sin" | "cos" | "tan" | "asin" | "acos"
175 | "atan" | "atan2" | "log2" | "log10" | "ln" | "exp" | "pi" | "e" => {
176 Some(TypeExpr::Named("float".into()))
177 }
178 "file_exists" | "json_validate" | "is_nan" | "is_infinite" | "set_contains" => {
179 Some(TypeExpr::Named("bool".into()))
180 }
181 "list_dir" | "mcp_list_tools" | "mcp_list_resources" | "mcp_list_prompts" | "to_list"
182 | "regex_captures" => Some(TypeExpr::Named("list".into())),
183 "stat" | "exec" | "shell" | "date_now" | "agent_loop" | "llm_info" | "llm_usage"
184 | "timer_start" | "metadata_get" | "mcp_server_info" | "mcp_get_prompt" => {
185 Some(TypeExpr::Named("dict".into()))
186 }
187 "metadata_set"
188 | "metadata_save"
189 | "metadata_refresh_hashes"
190 | "invalidate_facts"
191 | "log_json"
192 | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
193 "env" | "regex_match" => Some(TypeExpr::Union(vec![
194 TypeExpr::Named("string".into()),
195 TypeExpr::Named("nil".into()),
196 ])),
197 "json_parse" | "json_extract" => None, _ => None,
199 }
200}
201
/// Returns `true` if `name` is a builtin whose call arity is not checked against a
/// user-visible signature.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        "log",
        "print",
        "println",
        "type_of",
        "to_string",
        "to_int",
        "to_float",
        "json_stringify",
        "json_parse",
        "env",
        "timestamp",
        "sleep",
        "read_file",
        "write_file",
        "exit",
        "regex_match",
        "regex_replace",
        "regex_captures",
        "http_get",
        "http_post",
        "llm_call",
        "agent_loop",
        "await",
        "cancel",
        "file_exists",
        "delete_file",
        "list_dir",
        "mkdir",
        "path_join",
        "copy_file",
        "append_file",
        "temp_dir",
        "stat",
        "exec",
        "shell",
        "date_now",
        "date_format",
        "date_parse",
        "format",
        "json_validate",
        "json_extract",
        "trim",
        "lowercase",
        "uppercase",
        "split",
        "starts_with",
        "ends_with",
        "contains",
        "replace",
        "join",
        "len",
        "substring",
        "dirname",
        "basename",
        "extname",
        "sin",
        "cos",
        "tan",
        "asin",
        "acos",
        "atan",
        "atan2",
        "log2",
        "log10",
        "ln",
        "exp",
        "pi",
        "e",
        "sign",
        "is_nan",
        "is_infinite",
        "set",
        "set_add",
        "set_remove",
        "set_contains",
        "set_union",
        "set_intersect",
        "set_difference",
        "to_list",
    ];
    BUILTINS.contains(&name)
}
287
288pub struct TypeChecker {
290 diagnostics: Vec<TypeDiagnostic>,
291 scope: TypeScope,
292}
293
294impl TypeChecker {
295 pub fn new() -> Self {
296 Self {
297 diagnostics: Vec::new(),
298 scope: TypeScope::new(),
299 }
300 }
301
302 pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
304 Self::register_declarations_into(&mut self.scope, program);
306
307 for snode in program {
309 if let Node::Pipeline { body, .. } = &snode.node {
310 Self::register_declarations_into(&mut self.scope, body);
311 }
312 }
313
314 for snode in program {
316 match &snode.node {
317 Node::Pipeline { params, body, .. } => {
318 let mut child = self.scope.child();
319 for p in params {
320 child.define_var(p, None);
321 }
322 self.check_block(body, &mut child);
323 }
324 Node::FnDecl {
325 name,
326 type_params,
327 params,
328 return_type,
329 where_clauses,
330 body,
331 ..
332 } => {
333 let required_params =
334 params.iter().filter(|p| p.default_value.is_none()).count();
335 let sig = FnSignature {
336 params: params
337 .iter()
338 .map(|p| (p.name.clone(), p.type_expr.clone()))
339 .collect(),
340 return_type: return_type.clone(),
341 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
342 required_params,
343 where_clauses: where_clauses
344 .iter()
345 .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
346 .collect(),
347 };
348 self.scope.define_fn(name, sig);
349 self.check_fn_body(type_params, params, return_type, body);
350 }
351 _ => {
352 let mut scope = self.scope.clone();
353 self.check_node(snode, &mut scope);
354 for (name, ty) in scope.vars {
356 self.scope.vars.entry(name).or_insert(ty);
357 }
358 }
359 }
360 }
361
362 self.diagnostics
363 }
364
365 fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
367 for snode in nodes {
368 match &snode.node {
369 Node::TypeDecl { name, type_expr } => {
370 scope.type_aliases.insert(name.clone(), type_expr.clone());
371 }
372 Node::EnumDecl { name, variants } => {
373 let variant_names: Vec<String> =
374 variants.iter().map(|v| v.name.clone()).collect();
375 scope.enums.insert(name.clone(), variant_names);
376 }
377 Node::InterfaceDecl { name, methods } => {
378 scope.interfaces.insert(name.clone(), methods.clone());
379 }
380 Node::StructDecl { name, fields } => {
381 let field_types: Vec<(String, InferredType)> = fields
382 .iter()
383 .map(|f| (f.name.clone(), f.type_expr.clone()))
384 .collect();
385 scope.structs.insert(name.clone(), field_types);
386 }
387 Node::ImplBlock {
388 type_name, methods, ..
389 } => {
390 let sigs: Vec<ImplMethodSig> = methods
391 .iter()
392 .filter_map(|m| {
393 if let Node::FnDecl { name, params, .. } = &m.node {
394 let param_count =
396 params.iter().filter(|p| p.name != "self").count();
397 Some(ImplMethodSig {
398 name: name.clone(),
399 param_count,
400 })
401 } else {
402 None
403 }
404 })
405 .collect();
406 scope.impl_methods.insert(type_name.clone(), sigs);
407 }
408 _ => {}
409 }
410 }
411 }
412
413 fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
414 for stmt in stmts {
415 self.check_node(stmt, scope);
416 }
417 }
418
419 fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
421 match pattern {
422 BindingPattern::Identifier(name) => {
423 scope.define_var(name, None);
424 }
425 BindingPattern::Dict(fields) => {
426 for field in fields {
427 let name = field.alias.as_deref().unwrap_or(&field.key);
428 scope.define_var(name, None);
429 }
430 }
431 BindingPattern::List(elements) => {
432 for elem in elements {
433 scope.define_var(&elem.name, None);
434 }
435 }
436 }
437 }
438
439 fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
440 let span = snode.span;
441 match &snode.node {
442 Node::LetBinding {
443 pattern,
444 type_ann,
445 value,
446 } => {
447 let inferred = self.infer_type(value, scope);
448 if let BindingPattern::Identifier(name) = pattern {
449 if let Some(expected) = type_ann {
450 if let Some(actual) = &inferred {
451 if !self.types_compatible(expected, actual, scope) {
452 let mut msg = format!(
453 "Type mismatch: '{}' declared as {}, but assigned {}",
454 name,
455 format_type(expected),
456 format_type(actual)
457 );
458 if let Some(detail) = shape_mismatch_detail(expected, actual) {
459 msg.push_str(&format!(" ({})", detail));
460 }
461 self.error_at(msg, span);
462 }
463 }
464 }
465 let ty = type_ann.clone().or(inferred);
466 scope.define_var(name, ty);
467 } else {
468 Self::define_pattern_vars(pattern, scope);
469 }
470 }
471
472 Node::VarBinding {
473 pattern,
474 type_ann,
475 value,
476 } => {
477 let inferred = self.infer_type(value, scope);
478 if let BindingPattern::Identifier(name) = pattern {
479 if let Some(expected) = type_ann {
480 if let Some(actual) = &inferred {
481 if !self.types_compatible(expected, actual, scope) {
482 let mut msg = format!(
483 "Type mismatch: '{}' declared as {}, but assigned {}",
484 name,
485 format_type(expected),
486 format_type(actual)
487 );
488 if let Some(detail) = shape_mismatch_detail(expected, actual) {
489 msg.push_str(&format!(" ({})", detail));
490 }
491 self.error_at(msg, span);
492 }
493 }
494 }
495 let ty = type_ann.clone().or(inferred);
496 scope.define_var(name, ty);
497 } else {
498 Self::define_pattern_vars(pattern, scope);
499 }
500 }
501
502 Node::FnDecl {
503 name,
504 type_params,
505 params,
506 return_type,
507 where_clauses,
508 body,
509 ..
510 } => {
511 let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
512 let sig = FnSignature {
513 params: params
514 .iter()
515 .map(|p| (p.name.clone(), p.type_expr.clone()))
516 .collect(),
517 return_type: return_type.clone(),
518 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
519 required_params,
520 where_clauses: where_clauses
521 .iter()
522 .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
523 .collect(),
524 };
525 scope.define_fn(name, sig.clone());
526 scope.define_var(name, None);
527 self.check_fn_body(type_params, params, return_type, body);
528 }
529
530 Node::FunctionCall { name, args } => {
531 self.check_call(name, args, scope, span);
532 }
533
534 Node::IfElse {
535 condition,
536 then_body,
537 else_body,
538 } => {
539 self.check_node(condition, scope);
540 let mut then_scope = scope.child();
541 if let Some((var_name, narrowed)) = Self::extract_nil_narrowing(condition, scope) {
543 then_scope.define_var(&var_name, narrowed);
544 }
545 self.check_block(then_body, &mut then_scope);
546 if let Some(else_body) = else_body {
547 let mut else_scope = scope.child();
548 self.check_block(else_body, &mut else_scope);
549 }
550 }
551
552 Node::ForIn {
553 pattern,
554 iterable,
555 body,
556 } => {
557 self.check_node(iterable, scope);
558 let mut loop_scope = scope.child();
559 if let BindingPattern::Identifier(variable) = pattern {
560 let elem_type = match self.infer_type(iterable, scope) {
562 Some(TypeExpr::List(inner)) => Some(*inner),
563 Some(TypeExpr::Named(n)) if n == "string" => {
564 Some(TypeExpr::Named("string".into()))
565 }
566 _ => None,
567 };
568 loop_scope.define_var(variable, elem_type);
569 } else {
570 Self::define_pattern_vars(pattern, &mut loop_scope);
571 }
572 self.check_block(body, &mut loop_scope);
573 }
574
575 Node::WhileLoop { condition, body } => {
576 self.check_node(condition, scope);
577 let mut loop_scope = scope.child();
578 self.check_block(body, &mut loop_scope);
579 }
580
581 Node::TryCatch {
582 body,
583 error_var,
584 catch_body,
585 finally_body,
586 ..
587 } => {
588 let mut try_scope = scope.child();
589 self.check_block(body, &mut try_scope);
590 let mut catch_scope = scope.child();
591 if let Some(var) = error_var {
592 catch_scope.define_var(var, None);
593 }
594 self.check_block(catch_body, &mut catch_scope);
595 if let Some(fb) = finally_body {
596 let mut finally_scope = scope.child();
597 self.check_block(fb, &mut finally_scope);
598 }
599 }
600
601 Node::TryExpr { body } => {
602 let mut try_scope = scope.child();
603 self.check_block(body, &mut try_scope);
604 }
605
606 Node::ReturnStmt {
607 value: Some(val), ..
608 } => {
609 self.check_node(val, scope);
610 }
611
612 Node::Assignment {
613 target, value, op, ..
614 } => {
615 self.check_node(value, scope);
616 if let Node::Identifier(name) = &target.node {
617 if let Some(Some(var_type)) = scope.get_var(name) {
618 let value_type = self.infer_type(value, scope);
619 let assigned = if let Some(op) = op {
620 let var_inferred = scope.get_var(name).cloned().flatten();
621 infer_binary_op_type(op, &var_inferred, &value_type)
622 } else {
623 value_type
624 };
625 if let Some(actual) = &assigned {
626 if !self.types_compatible(var_type, actual, scope) {
627 self.error_at(
628 format!(
629 "Type mismatch: cannot assign {} to '{}' (declared as {})",
630 format_type(actual),
631 name,
632 format_type(var_type)
633 ),
634 span,
635 );
636 }
637 }
638 }
639 }
640 }
641
642 Node::TypeDecl { name, type_expr } => {
643 scope.type_aliases.insert(name.clone(), type_expr.clone());
644 }
645
646 Node::EnumDecl { name, variants } => {
647 let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
648 scope.enums.insert(name.clone(), variant_names);
649 }
650
651 Node::StructDecl { name, fields } => {
652 let field_types: Vec<(String, InferredType)> = fields
653 .iter()
654 .map(|f| (f.name.clone(), f.type_expr.clone()))
655 .collect();
656 scope.structs.insert(name.clone(), field_types);
657 }
658
659 Node::InterfaceDecl { name, methods } => {
660 scope.interfaces.insert(name.clone(), methods.clone());
661 }
662
663 Node::ImplBlock {
664 type_name, methods, ..
665 } => {
666 let sigs: Vec<ImplMethodSig> = methods
668 .iter()
669 .filter_map(|m| {
670 if let Node::FnDecl { name, params, .. } = &m.node {
671 let param_count = params.iter().filter(|p| p.name != "self").count();
672 Some(ImplMethodSig {
673 name: name.clone(),
674 param_count,
675 })
676 } else {
677 None
678 }
679 })
680 .collect();
681 scope.impl_methods.insert(type_name.clone(), sigs);
682 for method_sn in methods {
683 self.check_node(method_sn, scope);
684 }
685 }
686
687 Node::TryOperator { operand } => {
688 self.check_node(operand, scope);
689 }
690
691 Node::MatchExpr { value, arms } => {
692 self.check_node(value, scope);
693 let value_type = self.infer_type(value, scope);
694 for arm in arms {
695 self.check_node(&arm.pattern, scope);
696 if let Some(ref vt) = value_type {
698 let value_type_name = format_type(vt);
699 let mismatch = match &arm.pattern.node {
700 Node::StringLiteral(_) => {
701 !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
702 }
703 Node::IntLiteral(_) => {
704 !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
705 && !self.types_compatible(
706 vt,
707 &TypeExpr::Named("float".into()),
708 scope,
709 )
710 }
711 Node::FloatLiteral(_) => {
712 !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
713 && !self.types_compatible(
714 vt,
715 &TypeExpr::Named("int".into()),
716 scope,
717 )
718 }
719 Node::BoolLiteral(_) => {
720 !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
721 }
722 _ => false,
723 };
724 if mismatch {
725 let pattern_type = match &arm.pattern.node {
726 Node::StringLiteral(_) => "string",
727 Node::IntLiteral(_) => "int",
728 Node::FloatLiteral(_) => "float",
729 Node::BoolLiteral(_) => "bool",
730 _ => unreachable!(),
731 };
732 self.warning_at(
733 format!(
734 "Match pattern type mismatch: matching {} against {} literal",
735 value_type_name, pattern_type
736 ),
737 arm.pattern.span,
738 );
739 }
740 }
741 let mut arm_scope = scope.child();
742 self.check_block(&arm.body, &mut arm_scope);
743 }
744 self.check_match_exhaustiveness(value, arms, scope, span);
745 }
746
747 Node::BinaryOp { op, left, right } => {
749 self.check_node(left, scope);
750 self.check_node(right, scope);
751 let lt = self.infer_type(left, scope);
753 let rt = self.infer_type(right, scope);
754 if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (<, &rt) {
755 match op.as_str() {
756 "-" | "*" | "/" | "%" => {
757 let numeric = ["int", "float"];
758 if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
759 self.warning_at(
760 format!(
761 "Operator '{op}' may not be valid for types {} and {}",
762 l, r
763 ),
764 span,
765 );
766 }
767 }
768 "+" => {
769 let valid = ["int", "float", "string", "list", "dict"];
771 if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
772 self.warning_at(
773 format!(
774 "Operator '+' may not be valid for types {} and {}",
775 l, r
776 ),
777 span,
778 );
779 }
780 }
781 _ => {}
782 }
783 }
784 }
785 Node::UnaryOp { operand, .. } => {
786 self.check_node(operand, scope);
787 }
788 Node::MethodCall { object, args, .. }
789 | Node::OptionalMethodCall { object, args, .. } => {
790 self.check_node(object, scope);
791 for arg in args {
792 self.check_node(arg, scope);
793 }
794 }
795 Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
796 self.check_node(object, scope);
797 }
798 Node::SubscriptAccess { object, index } => {
799 self.check_node(object, scope);
800 self.check_node(index, scope);
801 }
802 Node::SliceAccess { object, start, end } => {
803 self.check_node(object, scope);
804 if let Some(s) = start {
805 self.check_node(s, scope);
806 }
807 if let Some(e) = end {
808 self.check_node(e, scope);
809 }
810 }
811
812 _ => {}
814 }
815 }
816
817 fn check_fn_body(
818 &mut self,
819 type_params: &[TypeParam],
820 params: &[TypedParam],
821 return_type: &Option<TypeExpr>,
822 body: &[SNode],
823 ) {
824 let mut fn_scope = self.scope.child();
825 for tp in type_params {
828 fn_scope.generic_type_params.insert(tp.name.clone());
829 }
830 for param in params {
831 fn_scope.define_var(¶m.name, param.type_expr.clone());
832 if let Some(default) = ¶m.default_value {
833 self.check_node(default, &mut fn_scope);
834 }
835 }
836 self.check_block(body, &mut fn_scope);
837
838 if let Some(ret_type) = return_type {
840 for stmt in body {
841 self.check_return_type(stmt, ret_type, &fn_scope);
842 }
843 }
844 }
845
846 fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
847 let span = snode.span;
848 match &snode.node {
849 Node::ReturnStmt { value: Some(val) } => {
850 let inferred = self.infer_type(val, scope);
851 if let Some(actual) = &inferred {
852 if !self.types_compatible(expected, actual, scope) {
853 self.error_at(
854 format!(
855 "Return type mismatch: expected {}, got {}",
856 format_type(expected),
857 format_type(actual)
858 ),
859 span,
860 );
861 }
862 }
863 }
864 Node::IfElse {
865 then_body,
866 else_body,
867 ..
868 } => {
869 for stmt in then_body {
870 self.check_return_type(stmt, expected, scope);
871 }
872 if let Some(else_body) = else_body {
873 for stmt in else_body {
874 self.check_return_type(stmt, expected, scope);
875 }
876 }
877 }
878 _ => {}
879 }
880 }
881
882 fn satisfies_interface(
888 &self,
889 type_name: &str,
890 interface_name: &str,
891 scope: &TypeScope,
892 ) -> bool {
893 let interface_methods = match scope.get_interface(interface_name) {
894 Some(methods) => methods,
895 None => return false,
896 };
897 let impl_methods = match scope.get_impl_methods(type_name) {
898 Some(methods) => methods,
899 None => return interface_methods.is_empty(),
900 };
901 interface_methods.iter().all(|iface_method| {
902 let iface_param_count = iface_method
903 .params
904 .iter()
905 .filter(|p| p.name != "self")
906 .count();
907 impl_methods.iter().any(|impl_method| {
908 impl_method.name == iface_method.name
909 && impl_method.param_count == iface_param_count
910 })
911 })
912 }
913
914 fn extract_nil_narrowing(
915 condition: &SNode,
916 scope: &TypeScope,
917 ) -> Option<(String, InferredType)> {
918 if let Node::BinaryOp { op, left, right } = &condition.node {
919 if op == "!=" {
920 let (var_node, nil_node) = if matches!(right.node, Node::NilLiteral) {
922 (left, right)
923 } else if matches!(left.node, Node::NilLiteral) {
924 (right, left)
925 } else {
926 return None;
927 };
928 let _ = nil_node;
929 if let Node::Identifier(name) = &var_node.node {
930 if let Some(Some(TypeExpr::Union(members))) = scope.get_var(name) {
932 let narrowed: Vec<TypeExpr> = members
933 .iter()
934 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
935 .cloned()
936 .collect();
937 return if narrowed.len() == 1 {
938 Some((name.clone(), Some(narrowed.into_iter().next().unwrap())))
939 } else if narrowed.is_empty() {
940 None
941 } else {
942 Some((name.clone(), Some(TypeExpr::Union(narrowed))))
943 };
944 }
945 }
946 }
947 }
948 None
949 }
950
951 fn check_match_exhaustiveness(
952 &mut self,
953 value: &SNode,
954 arms: &[MatchArm],
955 scope: &TypeScope,
956 span: Span,
957 ) {
958 let enum_name = match &value.node {
960 Node::PropertyAccess { object, property } if property == "variant" => {
961 match self.infer_type(object, scope) {
963 Some(TypeExpr::Named(name)) => {
964 if scope.get_enum(&name).is_some() {
965 Some(name)
966 } else {
967 None
968 }
969 }
970 _ => None,
971 }
972 }
973 _ => {
974 match self.infer_type(value, scope) {
976 Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
977 _ => None,
978 }
979 }
980 };
981
982 let Some(enum_name) = enum_name else {
983 return;
984 };
985 let Some(variants) = scope.get_enum(&enum_name) else {
986 return;
987 };
988
989 let mut covered: Vec<String> = Vec::new();
991 let mut has_wildcard = false;
992
993 for arm in arms {
994 match &arm.pattern.node {
995 Node::StringLiteral(s) => covered.push(s.clone()),
997 Node::Identifier(name) if name == "_" || !variants.contains(name) => {
999 has_wildcard = true;
1000 }
1001 Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
1003 Node::PropertyAccess { property, .. } => covered.push(property.clone()),
1005 _ => {
1006 has_wildcard = true;
1008 }
1009 }
1010 }
1011
1012 if has_wildcard {
1013 return;
1014 }
1015
1016 let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
1017 if !missing.is_empty() {
1018 let missing_str = missing
1019 .iter()
1020 .map(|s| format!("\"{}\"", s))
1021 .collect::<Vec<_>>()
1022 .join(", ");
1023 self.warning_at(
1024 format!(
1025 "Non-exhaustive match on enum {}: missing variants {}",
1026 enum_name, missing_str
1027 ),
1028 span,
1029 );
1030 }
1031 }
1032
1033 fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
1034 let has_spread = args.iter().any(|a| matches!(&a.node, Node::Spread(_)));
1036 if let Some(sig) = scope.get_fn(name).cloned() {
1037 if !has_spread
1038 && !is_builtin(name)
1039 && (args.len() < sig.required_params || args.len() > sig.params.len())
1040 {
1041 let expected = if sig.required_params == sig.params.len() {
1042 format!("{}", sig.params.len())
1043 } else {
1044 format!("{}-{}", sig.required_params, sig.params.len())
1045 };
1046 self.warning_at(
1047 format!(
1048 "Function '{}' expects {} arguments, got {}",
1049 name,
1050 expected,
1051 args.len()
1052 ),
1053 span,
1054 );
1055 }
1056 let call_scope = if sig.type_param_names.is_empty() {
1059 scope.clone()
1060 } else {
1061 let mut s = scope.child();
1062 for tp_name in &sig.type_param_names {
1063 s.generic_type_params.insert(tp_name.clone());
1064 }
1065 s
1066 };
1067 for (i, (arg, (param_name, param_type))) in
1068 args.iter().zip(sig.params.iter()).enumerate()
1069 {
1070 if let Some(expected) = param_type {
1071 let actual = self.infer_type(arg, scope);
1072 if let Some(actual) = &actual {
1073 if !self.types_compatible(expected, actual, &call_scope) {
1074 self.error_at(
1075 format!(
1076 "Argument {} ('{}'): expected {}, got {}",
1077 i + 1,
1078 param_name,
1079 format_type(expected),
1080 format_type(actual)
1081 ),
1082 arg.span,
1083 );
1084 }
1085 }
1086 }
1087 }
1088 if !sig.where_clauses.is_empty() {
1090 let mut type_bindings: BTreeMap<String, String> = BTreeMap::new();
1092 for (arg, (_param_name, param_type)) in args.iter().zip(sig.params.iter()) {
1093 if let Some(TypeExpr::Named(param_ty)) = param_type {
1094 if sig.type_param_names.contains(param_ty) {
1095 if let Some(TypeExpr::Named(concrete)) = self.infer_type(arg, scope) {
1096 type_bindings.entry(param_ty.clone()).or_insert(concrete);
1097 }
1098 }
1099 }
1100 }
1101 for (type_param, bound) in &sig.where_clauses {
1102 if let Some(concrete_type) = type_bindings.get(type_param) {
1103 if !self.satisfies_interface(concrete_type, bound, scope) {
1104 self.warning_at(
1105 format!(
1106 "Type '{}' does not satisfy interface '{}': \
1107 required by constraint `where {}: {}`",
1108 concrete_type, bound, type_param, bound
1109 ),
1110 span,
1111 );
1112 }
1113 }
1114 }
1115 }
1116 }
1117 for arg in args {
1119 self.check_node(arg, scope);
1120 }
1121 }
1122
1123 fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
1125 match &snode.node {
1126 Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
1127 Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
1128 Node::StringLiteral(_) | Node::InterpolatedString(_) => {
1129 Some(TypeExpr::Named("string".into()))
1130 }
1131 Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
1132 Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
1133 Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
1134 Node::DictLiteral(entries) => {
1135 let mut fields = Vec::new();
1137 let mut all_string_keys = true;
1138 for entry in entries {
1139 if let Node::StringLiteral(key) = &entry.key.node {
1140 let val_type = self
1141 .infer_type(&entry.value, scope)
1142 .unwrap_or(TypeExpr::Named("nil".into()));
1143 fields.push(ShapeField {
1144 name: key.clone(),
1145 type_expr: val_type,
1146 optional: false,
1147 });
1148 } else {
1149 all_string_keys = false;
1150 break;
1151 }
1152 }
1153 if all_string_keys && !fields.is_empty() {
1154 Some(TypeExpr::Shape(fields))
1155 } else {
1156 Some(TypeExpr::Named("dict".into()))
1157 }
1158 }
1159 Node::Closure { params, body } => {
1160 let all_typed = params.iter().all(|p| p.type_expr.is_some());
1162 if all_typed && !params.is_empty() {
1163 let param_types: Vec<TypeExpr> =
1164 params.iter().filter_map(|p| p.type_expr.clone()).collect();
1165 let ret = body.last().and_then(|last| self.infer_type(last, scope));
1167 if let Some(ret_type) = ret {
1168 return Some(TypeExpr::FnType {
1169 params: param_types,
1170 return_type: Box::new(ret_type),
1171 });
1172 }
1173 }
1174 Some(TypeExpr::Named("closure".into()))
1175 }
1176
1177 Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
1178
1179 Node::FunctionCall { name, .. } => {
1180 if scope.get_struct(name).is_some() {
1182 return Some(TypeExpr::Named(name.clone()));
1183 }
1184 if let Some(sig) = scope.get_fn(name) {
1186 return sig.return_type.clone();
1187 }
1188 builtin_return_type(name)
1190 }
1191
1192 Node::BinaryOp { op, left, right } => {
1193 let lt = self.infer_type(left, scope);
1194 let rt = self.infer_type(right, scope);
1195 infer_binary_op_type(op, <, &rt)
1196 }
1197
1198 Node::UnaryOp { op, operand } => {
1199 let t = self.infer_type(operand, scope);
1200 match op.as_str() {
1201 "!" => Some(TypeExpr::Named("bool".into())),
1202 "-" => t, _ => None,
1204 }
1205 }
1206
1207 Node::Ternary {
1208 true_expr,
1209 false_expr,
1210 ..
1211 } => {
1212 let tt = self.infer_type(true_expr, scope);
1213 let ft = self.infer_type(false_expr, scope);
1214 match (&tt, &ft) {
1215 (Some(a), Some(b)) if a == b => tt,
1216 (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
1217 (Some(_), None) => tt,
1218 (None, Some(_)) => ft,
1219 (None, None) => None,
1220 }
1221 }
1222
1223 Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1224
1225 Node::PropertyAccess { object, property } => {
1226 if let Node::Identifier(name) = &object.node {
1228 if scope.get_enum(name).is_some() {
1229 return Some(TypeExpr::Named(name.clone()));
1230 }
1231 }
1232 if property == "variant" {
1234 let obj_type = self.infer_type(object, scope);
1235 if let Some(TypeExpr::Named(name)) = &obj_type {
1236 if scope.get_enum(name).is_some() {
1237 return Some(TypeExpr::Named("string".into()));
1238 }
1239 }
1240 }
1241 let obj_type = self.infer_type(object, scope);
1243 if let Some(TypeExpr::Shape(fields)) = &obj_type {
1244 if let Some(field) = fields.iter().find(|f| f.name == *property) {
1245 return Some(field.type_expr.clone());
1246 }
1247 }
1248 None
1249 }
1250
1251 Node::SubscriptAccess { object, index } => {
1252 let obj_type = self.infer_type(object, scope);
1253 match &obj_type {
1254 Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1255 Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1256 Some(TypeExpr::Shape(fields)) => {
1257 if let Node::StringLiteral(key) = &index.node {
1259 fields
1260 .iter()
1261 .find(|f| &f.name == key)
1262 .map(|f| f.type_expr.clone())
1263 } else {
1264 None
1265 }
1266 }
1267 Some(TypeExpr::Named(n)) if n == "list" => None,
1268 Some(TypeExpr::Named(n)) if n == "dict" => None,
1269 Some(TypeExpr::Named(n)) if n == "string" => {
1270 Some(TypeExpr::Named("string".into()))
1271 }
1272 _ => None,
1273 }
1274 }
1275 Node::SliceAccess { object, .. } => {
1276 let obj_type = self.infer_type(object, scope);
1278 match &obj_type {
1279 Some(TypeExpr::List(_)) => obj_type,
1280 Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1281 Some(TypeExpr::Named(n)) if n == "string" => {
1282 Some(TypeExpr::Named("string".into()))
1283 }
1284 _ => None,
1285 }
1286 }
1287 Node::MethodCall { object, method, .. }
1288 | Node::OptionalMethodCall { object, method, .. } => {
1289 let obj_type = self.infer_type(object, scope);
1290 let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1291 || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1292 match method.as_str() {
1293 "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1295 Some(TypeExpr::Named("bool".into()))
1296 }
1297 "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1299 "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1301 | "pad_left" | "pad_right" | "repeat" | "join" => {
1302 Some(TypeExpr::Named("string".into()))
1303 }
1304 "split" | "chars" => Some(TypeExpr::Named("list".into())),
1305 "filter" => {
1307 if is_dict {
1308 Some(TypeExpr::Named("dict".into()))
1309 } else {
1310 Some(TypeExpr::Named("list".into()))
1311 }
1312 }
1313 "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1315 "reduce" | "find" | "first" | "last" => None,
1316 "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1318 "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1319 "to_string" => Some(TypeExpr::Named("string".into())),
1321 "to_int" => Some(TypeExpr::Named("int".into())),
1322 "to_float" => Some(TypeExpr::Named("float".into())),
1323 _ => None,
1324 }
1325 }
1326
1327 Node::TryOperator { operand } => {
1329 match self.infer_type(operand, scope) {
1330 Some(TypeExpr::Named(name)) if name == "Result" => None, _ => None,
1332 }
1333 }
1334
1335 _ => None,
1336 }
1337 }
1338
1339 fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1341 if let TypeExpr::Named(name) = expected {
1343 if scope.is_generic_type_param(name) {
1344 return true;
1345 }
1346 }
1347 if let TypeExpr::Named(name) = actual {
1348 if scope.is_generic_type_param(name) {
1349 return true;
1350 }
1351 }
1352 let expected = self.resolve_alias(expected, scope);
1353 let actual = self.resolve_alias(actual, scope);
1354
1355 if let TypeExpr::Named(iface_name) = &expected {
1358 if scope.get_interface(iface_name).is_some() {
1359 if let TypeExpr::Named(type_name) = &actual {
1360 return self.satisfies_interface(type_name, iface_name, scope);
1361 }
1362 return false;
1363 }
1364 }
1365
1366 match (&expected, &actual) {
1367 (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1368 (TypeExpr::Union(members), actual_type) => members
1369 .iter()
1370 .any(|m| self.types_compatible(m, actual_type, scope)),
1371 (expected_type, TypeExpr::Union(members)) => members
1372 .iter()
1373 .all(|m| self.types_compatible(expected_type, m, scope)),
1374 (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1375 (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1376 (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1377 if expected_field.optional {
1378 return true;
1379 }
1380 af.iter().any(|actual_field| {
1381 actual_field.name == expected_field.name
1382 && self.types_compatible(
1383 &expected_field.type_expr,
1384 &actual_field.type_expr,
1385 scope,
1386 )
1387 })
1388 }),
1389 (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1391 let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1392 keys_ok
1393 && af
1394 .iter()
1395 .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1396 }
1397 (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1399 (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1400 self.types_compatible(expected_inner, actual_inner, scope)
1401 }
1402 (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1403 (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1404 (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1405 self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1406 }
1407 (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1408 (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1409 (
1411 TypeExpr::FnType {
1412 params: ep,
1413 return_type: er,
1414 },
1415 TypeExpr::FnType {
1416 params: ap,
1417 return_type: ar,
1418 },
1419 ) => {
1420 ep.len() == ap.len()
1421 && ep
1422 .iter()
1423 .zip(ap.iter())
1424 .all(|(e, a)| self.types_compatible(e, a, scope))
1425 && self.types_compatible(er, ar, scope)
1426 }
1427 (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1429 (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1430 _ => false,
1431 }
1432 }
1433
1434 fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1435 if let TypeExpr::Named(name) = ty {
1436 if let Some(resolved) = scope.resolve_type(name) {
1437 return resolved.clone();
1438 }
1439 }
1440 ty.clone()
1441 }
1442
1443 fn error_at(&mut self, message: String, span: Span) {
1444 self.diagnostics.push(TypeDiagnostic {
1445 message,
1446 severity: DiagnosticSeverity::Error,
1447 span: Some(span),
1448 });
1449 }
1450
1451 fn warning_at(&mut self, message: String, span: Span) {
1452 self.diagnostics.push(TypeDiagnostic {
1453 message,
1454 severity: DiagnosticSeverity::Warning,
1455 span: Some(span),
1456 });
1457 }
1458}
1459
1460impl Default for TypeChecker {
1461 fn default() -> Self {
1462 Self::new()
1463 }
1464}
1465
1466fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1468 match op {
1469 "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1470 "+" => match (left, right) {
1471 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1472 match (l.as_str(), r.as_str()) {
1473 ("int", "int") => Some(TypeExpr::Named("int".into())),
1474 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1475 ("string", _) => Some(TypeExpr::Named("string".into())),
1476 ("list", "list") => Some(TypeExpr::Named("list".into())),
1477 ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1478 _ => Some(TypeExpr::Named("string".into())),
1479 }
1480 }
1481 _ => None,
1482 },
1483 "-" | "*" | "/" | "%" => match (left, right) {
1484 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1485 match (l.as_str(), r.as_str()) {
1486 ("int", "int") => Some(TypeExpr::Named("int".into())),
1487 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1488 _ => None,
1489 }
1490 }
1491 _ => None,
1492 },
1493 "??" => match (left, right) {
1494 (Some(TypeExpr::Union(members)), _) => {
1495 let non_nil: Vec<_> = members
1496 .iter()
1497 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1498 .cloned()
1499 .collect();
1500 if non_nil.len() == 1 {
1501 Some(non_nil[0].clone())
1502 } else if non_nil.is_empty() {
1503 right.clone()
1504 } else {
1505 Some(TypeExpr::Union(non_nil))
1506 }
1507 }
1508 _ => right.clone(),
1509 },
1510 "|>" => None,
1511 _ => None,
1512 }
1513}
1514
1515pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1520 if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1521 let mut details = Vec::new();
1522 for field in ef {
1523 if field.optional {
1524 continue;
1525 }
1526 match af.iter().find(|f| f.name == field.name) {
1527 None => details.push(format!(
1528 "missing field '{}' ({})",
1529 field.name,
1530 format_type(&field.type_expr)
1531 )),
1532 Some(actual_field) => {
1533 let e_str = format_type(&field.type_expr);
1534 let a_str = format_type(&actual_field.type_expr);
1535 if e_str != a_str {
1536 details.push(format!(
1537 "field '{}' has type {}, expected {}",
1538 field.name, a_str, e_str
1539 ));
1540 }
1541 }
1542 }
1543 }
1544 if details.is_empty() {
1545 None
1546 } else {
1547 Some(details.join("; "))
1548 }
1549 } else {
1550 None
1551 }
1552}
1553
1554pub fn format_type(ty: &TypeExpr) -> String {
1555 match ty {
1556 TypeExpr::Named(n) => n.clone(),
1557 TypeExpr::Union(types) => types
1558 .iter()
1559 .map(format_type)
1560 .collect::<Vec<_>>()
1561 .join(" | "),
1562 TypeExpr::Shape(fields) => {
1563 let inner: Vec<String> = fields
1564 .iter()
1565 .map(|f| {
1566 let opt = if f.optional { "?" } else { "" };
1567 format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1568 })
1569 .collect();
1570 format!("{{{}}}", inner.join(", "))
1571 }
1572 TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1573 TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1574 TypeExpr::FnType {
1575 params,
1576 return_type,
1577 } => {
1578 let params_str = params
1579 .iter()
1580 .map(format_type)
1581 .collect::<Vec<_>>()
1582 .join(", ");
1583 format!("fn({}) -> {}", params_str, format_type(return_type))
1584 }
1585 }
1586}
1587
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    /// Substring identifying non-exhaustive-match warnings.
    const NON_EXHAUSTIVE: &str = "Non-exhaustive";
    /// Substring identifying match-pattern type warnings.
    const PATTERN_MISMATCH: &str = "Match pattern type mismatch";

    /// Lexes, parses and type-checks `source`, returning every diagnostic.
    /// Panics if lexing or parsing fails; every fixture here must be valid syntax.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut lexer = Lexer::new(source);
        let tokens = lexer.tokenize().unwrap();
        let mut parser = Parser::new(tokens);
        let program = parser.parse().unwrap();
        TypeChecker::new().check(&program)
    }

    /// Messages of the diagnostics for `source` that have the given severity.
    fn messages_with_severity(source: &str, severity: DiagnosticSeverity) -> Vec<String> {
        check_source(source)
            .into_iter()
            .filter(|d| d.severity == severity)
            .map(|d| d.message)
            .collect()
    }

    /// Error messages produced for `source`.
    fn errors(source: &str) -> Vec<String> {
        messages_with_severity(source, DiagnosticSeverity::Error)
    }

    /// Warning messages produced for `source`.
    fn warnings(source: &str) -> Vec<String> {
        messages_with_severity(source, DiagnosticSeverity::Warning)
    }

    /// Warning messages for `source` that contain `needle`.
    fn warnings_containing(source: &str, needle: &str) -> Vec<String> {
        warnings(source)
            .into_iter()
            .filter(|w| w.contains(needle))
            .collect()
    }

    // ---- let bindings, functions and inference -------------------------

    #[test]
    fn test_no_errors_for_untyped_code() {
        assert!(errors("pipeline t(task) { let x = 42\nlog(x) }").is_empty());
    }

    #[test]
    fn test_correct_typed_let() {
        assert!(errors("pipeline t(task) { let x: int = 42 }").is_empty());
    }

    #[test]
    fn test_type_mismatch_let() {
        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
        assert_eq!(errs.len(), 1);
        let msg = &errs[0];
        assert!(msg.contains("Type mismatch"));
        assert!(msg.contains("int"));
        assert!(msg.contains("string"));
    }

    #[test]
    fn test_correct_typed_fn() {
        let source =
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }";
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let errs = errors(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        assert_eq!(errs.len(), 1);
        let msg = &errs[0];
        assert!(msg.contains("Argument 1"));
        assert!(msg.contains("expected int"));
    }

    #[test]
    fn test_return_type_mismatch() {
        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Return type mismatch"));
    }

    #[test]
    fn test_union_type_compatible() {
        assert!(errors(r#"pipeline t(task) { let x: string | nil = nil }"#).is_empty());
    }

    #[test]
    fn test_union_type_mismatch() {
        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Type mismatch"));
    }

    #[test]
    fn test_type_inference_propagation() {
        let errs = errors(
            r#"pipeline t(task) {
    fn add(a: int, b: int) -> int { return a + b }
    let result: string = add(1, 2)
}"#,
        );
        assert_eq!(errs.len(), 1);
        let msg = &errs[0];
        assert!(msg.contains("Type mismatch"));
        assert!(msg.contains("string"));
        assert!(msg.contains("int"));
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        assert_eq!(errs.len(), 1);
        let msg = &errs[0];
        assert!(msg.contains("string"));
        assert!(msg.contains("int"));
    }

    #[test]
    fn test_binary_op_type_inference() {
        assert_eq!(errors("pipeline t(task) { let x: string = 1 + 2 }").len(), 1);
    }

    #[test]
    fn test_comparison_returns_bool() {
        assert!(errors("pipeline t(task) { let x: bool = 1 < 2 }").is_empty());
    }

    #[test]
    fn test_int_float_promotion() {
        assert!(errors("pipeline t(task) { let x: float = 42 }").is_empty());
    }

    #[test]
    fn test_untyped_code_no_errors() {
        let errs = errors(
            r#"pipeline t(task) {
    fn process(data) {
        let result = data + " processed"
        return result
    }
    log(process("hello"))
}"#,
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_type_alias() {
        let errs = errors(
            r#"pipeline t(task) {
    type Name = string
    let x: Name = "hello"
}"#,
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_type_alias_mismatch() {
        let errs = errors(
            r#"pipeline t(task) {
    type Name = string
    let x: Name = 42
}"#,
        );
        assert_eq!(errs.len(), 1);
    }

    #[test]
    fn test_assignment_type_check() {
        let errs = errors(
            r#"pipeline t(task) {
    var x: int = 0
    x = "hello"
}"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("cannot assign string"));
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        let source =
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }";
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_covariance_return_type() {
        assert!(errors("pipeline t(task) { fn get() -> float { return 42 } }").is_empty());
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        let source = "pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }";
        assert_eq!(errors(source).len(), 1);
    }

    // ---- enums and exhaustiveness -------------------------------------

    #[test]
    fn test_exhaustive_match_no_warning() {
        let exhaustive_warns = warnings_containing(
            r#"pipeline t(task) {
    enum Color { Red, Green, Blue }
    let c = Color.Red
    match c.variant {
        "Red" -> { log("r") }
        "Green" -> { log("g") }
        "Blue" -> { log("b") }
    }
}"#,
            NON_EXHAUSTIVE,
        );
        assert!(exhaustive_warns.is_empty());
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let exhaustive_warns = warnings_containing(
            r#"pipeline t(task) {
    enum Color { Red, Green, Blue }
    let c = Color.Red
    match c.variant {
        "Red" -> { log("r") }
        "Green" -> { log("g") }
    }
}"#,
            NON_EXHAUSTIVE,
        );
        assert_eq!(exhaustive_warns.len(), 1);
        assert!(exhaustive_warns[0].contains("Blue"));
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let exhaustive_warns = warnings_containing(
            r#"pipeline t(task) {
    enum Status { Active, Inactive, Pending }
    let s = Status.Active
    match s.variant {
        "Active" -> { log("a") }
    }
}"#,
            NON_EXHAUSTIVE,
        );
        assert_eq!(exhaustive_warns.len(), 1);
        let msg = &exhaustive_warns[0];
        assert!(msg.contains("Inactive"));
        assert!(msg.contains("Pending"));
    }

    #[test]
    fn test_enum_construct_type_inference() {
        let errs = errors(
            r#"pipeline t(task) {
    enum Color { Red, Green, Blue }
    let c: Color = Color.Red
}"#,
        );
        assert!(errs.is_empty());
    }

    // ---- nil coalescing and shapes ------------------------------------

    #[test]
    fn test_nil_coalescing_strips_nil() {
        let errs = errors(
            r#"pipeline t(task) {
    let x: string | nil = nil
    let y: string = x ?? "default"
}"#,
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_shape_mismatch_detail_missing_field() {
        let errs = errors(
            r#"pipeline t(task) {
    let x: {name: string, age: int} = {name: "hello"}
}"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(
            errs[0].contains("missing field 'age'"),
            "expected detail about missing field, got: {}",
            errs[0]
        );
    }

    #[test]
    fn test_shape_mismatch_detail_wrong_type() {
        let errs = errors(
            r#"pipeline t(task) {
    let x: {name: string, age: int} = {name: 42, age: 10}
}"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(
            errs[0].contains("field 'name' has type int, expected string"),
            "expected detail about wrong type, got: {}",
            errs[0]
        );
    }

    // ---- match pattern literal types ----------------------------------

    #[test]
    fn test_match_pattern_string_against_int() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        "hello" -> { log("bad") }
        42 -> { log("ok") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(pattern_warns.len(), 1);
        assert!(pattern_warns[0].contains("matching int against string literal"));
    }

    #[test]
    fn test_match_pattern_int_against_string() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
    let x: string = "hello"
    match x {
        42 -> { log("bad") }
        "hello" -> { log("ok") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(pattern_warns.len(), 1);
        assert!(pattern_warns[0].contains("matching string against int literal"));
    }

    #[test]
    fn test_match_pattern_bool_against_int() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        true -> { log("bad") }
        42 -> { log("ok") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(pattern_warns.len(), 1);
        assert!(pattern_warns[0].contains("matching int against bool literal"));
    }

    #[test]
    fn test_match_pattern_float_against_string() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
    let x: string = "hello"
    match x {
        3.14 -> { log("bad") }
        "hello" -> { log("ok") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(pattern_warns.len(), 1);
        assert!(pattern_warns[0].contains("matching string against float literal"));
    }

    #[test]
    fn test_match_pattern_int_against_float_ok() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
    let x: float = 3.14
    match x {
        42 -> { log("ok") }
        _ -> { log("default") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }

    #[test]
    fn test_match_pattern_float_against_int_ok() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        3.14 -> { log("close") }
        _ -> { log("default") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }

    #[test]
    fn test_match_pattern_correct_types_no_warning() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        1 -> { log("one") }
        2 -> { log("two") }
        _ -> { log("other") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }

    #[test]
    fn test_match_pattern_wildcard_no_warning() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        _ -> { log("catch all") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }

    #[test]
    fn test_match_pattern_untyped_no_warning() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
    let x = some_unknown_fn()
    match x {
        "hello" -> { log("string") }
        42 -> { log("int") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }
}
2073}