1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
/// A single problem reported by the type checker.
#[derive(Debug, Clone)]
pub struct TypeDiagnostic {
    /// Human-readable description of the problem.
    pub message: String,
    /// Whether the problem is an error or only a warning.
    pub severity: DiagnosticSeverity,
    /// Source location the diagnostic refers to. Optional in the type; the checker's
    /// `error_at`/`warning_at` helpers always fill it in.
    pub span: Option<Span>,
}
13
/// How serious a [`TypeDiagnostic`] is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// A definite type error (e.g. an annotated binding assigned an incompatible type).
    Error,
    /// A suspicious construct that may still be valid (e.g. arity or operator hints).
    Warning,
}
19
/// A statically inferred type. `None` means "unknown": the checker skips any
/// compatibility check that would need it.
type InferredType = Option<TypeExpr>;
22
/// One lexical scope of the type checker, linked to its enclosing scope.
#[derive(Debug, Clone)]
struct TypeScope {
    /// Variables visible in this scope; a `None` type means the variable exists but its
    /// type is unknown.
    vars: BTreeMap<String, InferredType>,
    /// Signatures of functions declared in this scope.
    functions: BTreeMap<String, FnSignature>,
    /// `type Name = ...` aliases declared in this scope.
    type_aliases: BTreeMap<String, TypeExpr>,
    /// Enum name -> the names of its variants.
    enums: BTreeMap<String, Vec<String>>,
    /// Interface name -> its declared methods.
    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
    /// Struct name -> its `(field name, declared type)` pairs.
    structs: BTreeMap<String, Vec<(String, InferredType)>>,
    /// Names of generic type parameters in scope; the checker treats these as
    /// compatible with any type.
    generic_type_params: std::collections::BTreeSet<String>,
    /// Enclosing scope, if any. `TypeScope::child` stores a clone of the parent here.
    parent: Option<Box<TypeScope>>,
}
42
/// The call signature of a user-declared function, as recorded by the checker.
#[derive(Debug, Clone)]
struct FnSignature {
    /// Parameters in declaration order as `(name, declared type)`.
    params: Vec<(String, InferredType)>,
    /// Declared return type, if annotated.
    return_type: InferredType,
    /// Names of the function's generic type parameters.
    type_param_names: Vec<String>,
    /// Number of parameters that have no default value.
    required_params: usize,
}
52
53impl TypeScope {
54 fn new() -> Self {
55 Self {
56 vars: BTreeMap::new(),
57 functions: BTreeMap::new(),
58 type_aliases: BTreeMap::new(),
59 enums: BTreeMap::new(),
60 interfaces: BTreeMap::new(),
61 structs: BTreeMap::new(),
62 generic_type_params: std::collections::BTreeSet::new(),
63 parent: None,
64 }
65 }
66
67 fn child(&self) -> Self {
68 Self {
69 vars: BTreeMap::new(),
70 functions: BTreeMap::new(),
71 type_aliases: BTreeMap::new(),
72 enums: BTreeMap::new(),
73 interfaces: BTreeMap::new(),
74 structs: BTreeMap::new(),
75 generic_type_params: std::collections::BTreeSet::new(),
76 parent: Some(Box::new(self.clone())),
77 }
78 }
79
80 fn get_var(&self, name: &str) -> Option<&InferredType> {
81 self.vars
82 .get(name)
83 .or_else(|| self.parent.as_ref()?.get_var(name))
84 }
85
86 fn get_fn(&self, name: &str) -> Option<&FnSignature> {
87 self.functions
88 .get(name)
89 .or_else(|| self.parent.as_ref()?.get_fn(name))
90 }
91
92 fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
93 self.type_aliases
94 .get(name)
95 .or_else(|| self.parent.as_ref()?.resolve_type(name))
96 }
97
98 fn is_generic_type_param(&self, name: &str) -> bool {
99 self.generic_type_params.contains(name)
100 || self
101 .parent
102 .as_ref()
103 .is_some_and(|p| p.is_generic_type_param(name))
104 }
105
106 fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
107 self.enums
108 .get(name)
109 .or_else(|| self.parent.as_ref()?.get_enum(name))
110 }
111
112 #[allow(dead_code)]
113 fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
114 self.interfaces
115 .get(name)
116 .or_else(|| self.parent.as_ref()?.get_interface(name))
117 }
118
119 fn define_var(&mut self, name: &str, ty: InferredType) {
120 self.vars.insert(name.to_string(), ty);
121 }
122
123 fn define_fn(&mut self, name: &str, sig: FnSignature) {
124 self.functions.insert(name.to_string(), sig);
125 }
126}
127
128fn builtin_return_type(name: &str) -> InferredType {
130 match name {
131 "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
132 | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
133 Some(TypeExpr::Named("nil".into()))
134 }
135 "type_of"
136 | "to_string"
137 | "json_stringify"
138 | "read_file"
139 | "http_get"
140 | "http_post"
141 | "llm_call"
142 | "regex_replace"
143 | "path_join"
144 | "temp_dir"
145 | "date_format"
146 | "format"
147 | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
148 "to_int" | "timer_end" | "elapsed" => Some(TypeExpr::Named("int".into())),
149 "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
150 "file_exists" | "json_validate" => Some(TypeExpr::Named("bool".into())),
151 "list_dir" | "mcp_list_tools" | "mcp_list_resources" | "mcp_list_prompts" => {
152 Some(TypeExpr::Named("list".into()))
153 }
154 "stat" | "exec" | "shell" | "date_now" | "agent_loop" | "llm_info" | "llm_usage"
155 | "timer_start" | "metadata_get" | "mcp_server_info" | "mcp_get_prompt" => {
156 Some(TypeExpr::Named("dict".into()))
157 }
158 "metadata_set"
159 | "metadata_save"
160 | "metadata_refresh_hashes"
161 | "invalidate_facts"
162 | "log_json"
163 | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
164 "env" | "regex_match" => Some(TypeExpr::Union(vec![
165 TypeExpr::Named("string".into()),
166 TypeExpr::Named("nil".into()),
167 ])),
168 "json_parse" | "json_extract" => None, _ => None,
170 }
171}
172
/// Reports whether `name` is one of the runtime's builtin functions.
///
/// The checker skips arity warnings for these names even when a user function of the
/// same name is registered.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        "log", "print", "println", "type_of", "to_string", "to_int", "to_float",
        "json_stringify", "json_parse", "env", "timestamp", "sleep", "read_file",
        "write_file", "exit", "regex_match", "regex_replace", "http_get", "http_post",
        "llm_call", "agent_loop", "await", "cancel", "file_exists", "delete_file",
        "list_dir", "mkdir", "path_join", "copy_file", "append_file", "temp_dir", "stat",
        "exec", "shell", "date_now", "date_format", "date_parse", "format",
        "json_validate", "json_extract", "trim", "lowercase", "uppercase", "split",
        "starts_with", "ends_with", "contains", "replace", "join", "len", "substring",
        "dirname", "basename", "extname",
    ];
    BUILTINS.contains(&name)
}
233
/// Static type checker over a parsed program; see `TypeChecker::check`.
pub struct TypeChecker {
    /// Diagnostics collected so far, in the order they were reported.
    diagnostics: Vec<TypeDiagnostic>,
    /// The global (top-level) scope.
    scope: TypeScope,
}
239
240impl TypeChecker {
241 pub fn new() -> Self {
242 Self {
243 diagnostics: Vec::new(),
244 scope: TypeScope::new(),
245 }
246 }
247
248 pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
250 Self::register_declarations_into(&mut self.scope, program);
252
253 for snode in program {
255 if let Node::Pipeline { body, .. } = &snode.node {
256 Self::register_declarations_into(&mut self.scope, body);
257 }
258 }
259
260 for snode in program {
262 match &snode.node {
263 Node::Pipeline { params, body, .. } => {
264 let mut child = self.scope.child();
265 for p in params {
266 child.define_var(p, None);
267 }
268 self.check_block(body, &mut child);
269 }
270 Node::FnDecl {
271 name,
272 type_params,
273 params,
274 return_type,
275 body,
276 ..
277 } => {
278 let required_params =
279 params.iter().filter(|p| p.default_value.is_none()).count();
280 let sig = FnSignature {
281 params: params
282 .iter()
283 .map(|p| (p.name.clone(), p.type_expr.clone()))
284 .collect(),
285 return_type: return_type.clone(),
286 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
287 required_params,
288 };
289 self.scope.define_fn(name, sig);
290 self.check_fn_body(type_params, params, return_type, body);
291 }
292 _ => {
293 let mut scope = self.scope.clone();
294 self.check_node(snode, &mut scope);
295 for (name, ty) in scope.vars {
297 self.scope.vars.entry(name).or_insert(ty);
298 }
299 }
300 }
301 }
302
303 self.diagnostics
304 }
305
306 fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
308 for snode in nodes {
309 match &snode.node {
310 Node::TypeDecl { name, type_expr } => {
311 scope.type_aliases.insert(name.clone(), type_expr.clone());
312 }
313 Node::EnumDecl { name, variants } => {
314 let variant_names: Vec<String> =
315 variants.iter().map(|v| v.name.clone()).collect();
316 scope.enums.insert(name.clone(), variant_names);
317 }
318 Node::InterfaceDecl { name, methods } => {
319 scope.interfaces.insert(name.clone(), methods.clone());
320 }
321 Node::StructDecl { name, fields } => {
322 let field_types: Vec<(String, InferredType)> = fields
323 .iter()
324 .map(|f| (f.name.clone(), f.type_expr.clone()))
325 .collect();
326 scope.structs.insert(name.clone(), field_types);
327 }
328 _ => {}
329 }
330 }
331 }
332
333 fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
334 for stmt in stmts {
335 self.check_node(stmt, scope);
336 }
337 }
338
339 fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
341 match pattern {
342 BindingPattern::Identifier(name) => {
343 scope.define_var(name, None);
344 }
345 BindingPattern::Dict(fields) => {
346 for field in fields {
347 let name = field.alias.as_deref().unwrap_or(&field.key);
348 scope.define_var(name, None);
349 }
350 }
351 BindingPattern::List(elements) => {
352 for elem in elements {
353 scope.define_var(&elem.name, None);
354 }
355 }
356 }
357 }
358
359 fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
360 let span = snode.span;
361 match &snode.node {
362 Node::LetBinding {
363 pattern,
364 type_ann,
365 value,
366 } => {
367 let inferred = self.infer_type(value, scope);
368 if let BindingPattern::Identifier(name) = pattern {
369 if let Some(expected) = type_ann {
370 if let Some(actual) = &inferred {
371 if !self.types_compatible(expected, actual, scope) {
372 let mut msg = format!(
373 "Type mismatch: '{}' declared as {}, but assigned {}",
374 name,
375 format_type(expected),
376 format_type(actual)
377 );
378 if let Some(detail) = shape_mismatch_detail(expected, actual) {
379 msg.push_str(&format!(" ({})", detail));
380 }
381 self.error_at(msg, span);
382 }
383 }
384 }
385 let ty = type_ann.clone().or(inferred);
386 scope.define_var(name, ty);
387 } else {
388 Self::define_pattern_vars(pattern, scope);
389 }
390 }
391
392 Node::VarBinding {
393 pattern,
394 type_ann,
395 value,
396 } => {
397 let inferred = self.infer_type(value, scope);
398 if let BindingPattern::Identifier(name) = pattern {
399 if let Some(expected) = type_ann {
400 if let Some(actual) = &inferred {
401 if !self.types_compatible(expected, actual, scope) {
402 let mut msg = format!(
403 "Type mismatch: '{}' declared as {}, but assigned {}",
404 name,
405 format_type(expected),
406 format_type(actual)
407 );
408 if let Some(detail) = shape_mismatch_detail(expected, actual) {
409 msg.push_str(&format!(" ({})", detail));
410 }
411 self.error_at(msg, span);
412 }
413 }
414 }
415 let ty = type_ann.clone().or(inferred);
416 scope.define_var(name, ty);
417 } else {
418 Self::define_pattern_vars(pattern, scope);
419 }
420 }
421
422 Node::FnDecl {
423 name,
424 type_params,
425 params,
426 return_type,
427 body,
428 ..
429 } => {
430 let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
431 let sig = FnSignature {
432 params: params
433 .iter()
434 .map(|p| (p.name.clone(), p.type_expr.clone()))
435 .collect(),
436 return_type: return_type.clone(),
437 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
438 required_params,
439 };
440 scope.define_fn(name, sig.clone());
441 scope.define_var(name, None);
442 self.check_fn_body(type_params, params, return_type, body);
443 }
444
445 Node::FunctionCall { name, args } => {
446 self.check_call(name, args, scope, span);
447 }
448
449 Node::IfElse {
450 condition,
451 then_body,
452 else_body,
453 } => {
454 self.check_node(condition, scope);
455 let mut then_scope = scope.child();
456 self.check_block(then_body, &mut then_scope);
457 if let Some(else_body) = else_body {
458 let mut else_scope = scope.child();
459 self.check_block(else_body, &mut else_scope);
460 }
461 }
462
463 Node::ForIn {
464 pattern,
465 iterable,
466 body,
467 } => {
468 self.check_node(iterable, scope);
469 let mut loop_scope = scope.child();
470 if let BindingPattern::Identifier(variable) = pattern {
471 let elem_type = match self.infer_type(iterable, scope) {
473 Some(TypeExpr::List(inner)) => Some(*inner),
474 Some(TypeExpr::Named(n)) if n == "string" => {
475 Some(TypeExpr::Named("string".into()))
476 }
477 _ => None,
478 };
479 loop_scope.define_var(variable, elem_type);
480 } else {
481 Self::define_pattern_vars(pattern, &mut loop_scope);
482 }
483 self.check_block(body, &mut loop_scope);
484 }
485
486 Node::WhileLoop { condition, body } => {
487 self.check_node(condition, scope);
488 let mut loop_scope = scope.child();
489 self.check_block(body, &mut loop_scope);
490 }
491
492 Node::TryCatch {
493 body,
494 error_var,
495 catch_body,
496 finally_body,
497 ..
498 } => {
499 let mut try_scope = scope.child();
500 self.check_block(body, &mut try_scope);
501 let mut catch_scope = scope.child();
502 if let Some(var) = error_var {
503 catch_scope.define_var(var, None);
504 }
505 self.check_block(catch_body, &mut catch_scope);
506 if let Some(fb) = finally_body {
507 let mut finally_scope = scope.child();
508 self.check_block(fb, &mut finally_scope);
509 }
510 }
511
512 Node::ReturnStmt {
513 value: Some(val), ..
514 } => {
515 self.check_node(val, scope);
516 }
517
518 Node::Assignment {
519 target, value, op, ..
520 } => {
521 self.check_node(value, scope);
522 if let Node::Identifier(name) = &target.node {
523 if let Some(Some(var_type)) = scope.get_var(name) {
524 let value_type = self.infer_type(value, scope);
525 let assigned = if let Some(op) = op {
526 let var_inferred = scope.get_var(name).cloned().flatten();
527 infer_binary_op_type(op, &var_inferred, &value_type)
528 } else {
529 value_type
530 };
531 if let Some(actual) = &assigned {
532 if !self.types_compatible(var_type, actual, scope) {
533 self.error_at(
534 format!(
535 "Type mismatch: cannot assign {} to '{}' (declared as {})",
536 format_type(actual),
537 name,
538 format_type(var_type)
539 ),
540 span,
541 );
542 }
543 }
544 }
545 }
546 }
547
548 Node::TypeDecl { name, type_expr } => {
549 scope.type_aliases.insert(name.clone(), type_expr.clone());
550 }
551
552 Node::EnumDecl { name, variants } => {
553 let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
554 scope.enums.insert(name.clone(), variant_names);
555 }
556
557 Node::StructDecl { name, fields } => {
558 let field_types: Vec<(String, InferredType)> = fields
559 .iter()
560 .map(|f| (f.name.clone(), f.type_expr.clone()))
561 .collect();
562 scope.structs.insert(name.clone(), field_types);
563 }
564
565 Node::InterfaceDecl { name, methods } => {
566 scope.interfaces.insert(name.clone(), methods.clone());
567 }
568
569 Node::MatchExpr { value, arms } => {
570 self.check_node(value, scope);
571 let value_type = self.infer_type(value, scope);
572 for arm in arms {
573 self.check_node(&arm.pattern, scope);
574 if let Some(ref vt) = value_type {
576 let value_type_name = format_type(vt);
577 let mismatch = match &arm.pattern.node {
578 Node::StringLiteral(_) => {
579 !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
580 }
581 Node::IntLiteral(_) => {
582 !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
583 && !self.types_compatible(
584 vt,
585 &TypeExpr::Named("float".into()),
586 scope,
587 )
588 }
589 Node::FloatLiteral(_) => {
590 !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
591 && !self.types_compatible(
592 vt,
593 &TypeExpr::Named("int".into()),
594 scope,
595 )
596 }
597 Node::BoolLiteral(_) => {
598 !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
599 }
600 _ => false,
601 };
602 if mismatch {
603 let pattern_type = match &arm.pattern.node {
604 Node::StringLiteral(_) => "string",
605 Node::IntLiteral(_) => "int",
606 Node::FloatLiteral(_) => "float",
607 Node::BoolLiteral(_) => "bool",
608 _ => unreachable!(),
609 };
610 self.warning_at(
611 format!(
612 "Match pattern type mismatch: matching {} against {} literal",
613 value_type_name, pattern_type
614 ),
615 arm.pattern.span,
616 );
617 }
618 }
619 let mut arm_scope = scope.child();
620 self.check_block(&arm.body, &mut arm_scope);
621 }
622 self.check_match_exhaustiveness(value, arms, scope, span);
623 }
624
625 Node::BinaryOp { op, left, right } => {
627 self.check_node(left, scope);
628 self.check_node(right, scope);
629 let lt = self.infer_type(left, scope);
631 let rt = self.infer_type(right, scope);
632 if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (<, &rt) {
633 match op.as_str() {
634 "-" | "*" | "/" | "%" => {
635 let numeric = ["int", "float"];
636 if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
637 self.warning_at(
638 format!(
639 "Operator '{op}' may not be valid for types {} and {}",
640 l, r
641 ),
642 span,
643 );
644 }
645 }
646 "+" => {
647 let valid = ["int", "float", "string", "list", "dict"];
649 if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
650 self.warning_at(
651 format!(
652 "Operator '+' may not be valid for types {} and {}",
653 l, r
654 ),
655 span,
656 );
657 }
658 }
659 _ => {}
660 }
661 }
662 }
663 Node::UnaryOp { operand, .. } => {
664 self.check_node(operand, scope);
665 }
666 Node::MethodCall { object, args, .. }
667 | Node::OptionalMethodCall { object, args, .. } => {
668 self.check_node(object, scope);
669 for arg in args {
670 self.check_node(arg, scope);
671 }
672 }
673 Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
674 self.check_node(object, scope);
675 }
676 Node::SubscriptAccess { object, index } => {
677 self.check_node(object, scope);
678 self.check_node(index, scope);
679 }
680 Node::SliceAccess { object, start, end } => {
681 self.check_node(object, scope);
682 if let Some(s) = start {
683 self.check_node(s, scope);
684 }
685 if let Some(e) = end {
686 self.check_node(e, scope);
687 }
688 }
689
690 _ => {}
692 }
693 }
694
695 fn check_fn_body(
696 &mut self,
697 type_params: &[TypeParam],
698 params: &[TypedParam],
699 return_type: &Option<TypeExpr>,
700 body: &[SNode],
701 ) {
702 let mut fn_scope = self.scope.child();
703 for tp in type_params {
706 fn_scope.generic_type_params.insert(tp.name.clone());
707 }
708 for param in params {
709 fn_scope.define_var(¶m.name, param.type_expr.clone());
710 if let Some(default) = ¶m.default_value {
711 self.check_node(default, &mut fn_scope);
712 }
713 }
714 self.check_block(body, &mut fn_scope);
715
716 if let Some(ret_type) = return_type {
718 for stmt in body {
719 self.check_return_type(stmt, ret_type, &fn_scope);
720 }
721 }
722 }
723
724 fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
725 let span = snode.span;
726 match &snode.node {
727 Node::ReturnStmt { value: Some(val) } => {
728 let inferred = self.infer_type(val, scope);
729 if let Some(actual) = &inferred {
730 if !self.types_compatible(expected, actual, scope) {
731 self.error_at(
732 format!(
733 "Return type mismatch: expected {}, got {}",
734 format_type(expected),
735 format_type(actual)
736 ),
737 span,
738 );
739 }
740 }
741 }
742 Node::IfElse {
743 then_body,
744 else_body,
745 ..
746 } => {
747 for stmt in then_body {
748 self.check_return_type(stmt, expected, scope);
749 }
750 if let Some(else_body) = else_body {
751 for stmt in else_body {
752 self.check_return_type(stmt, expected, scope);
753 }
754 }
755 }
756 _ => {}
757 }
758 }
759
760 fn check_match_exhaustiveness(
762 &mut self,
763 value: &SNode,
764 arms: &[MatchArm],
765 scope: &TypeScope,
766 span: Span,
767 ) {
768 let enum_name = match &value.node {
770 Node::PropertyAccess { object, property } if property == "variant" => {
771 match self.infer_type(object, scope) {
773 Some(TypeExpr::Named(name)) => {
774 if scope.get_enum(&name).is_some() {
775 Some(name)
776 } else {
777 None
778 }
779 }
780 _ => None,
781 }
782 }
783 _ => {
784 match self.infer_type(value, scope) {
786 Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
787 _ => None,
788 }
789 }
790 };
791
792 let Some(enum_name) = enum_name else {
793 return;
794 };
795 let Some(variants) = scope.get_enum(&enum_name) else {
796 return;
797 };
798
799 let mut covered: Vec<String> = Vec::new();
801 let mut has_wildcard = false;
802
803 for arm in arms {
804 match &arm.pattern.node {
805 Node::StringLiteral(s) => covered.push(s.clone()),
807 Node::Identifier(name) if name == "_" || !variants.contains(name) => {
809 has_wildcard = true;
810 }
811 Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
813 Node::PropertyAccess { property, .. } => covered.push(property.clone()),
815 _ => {
816 has_wildcard = true;
818 }
819 }
820 }
821
822 if has_wildcard {
823 return;
824 }
825
826 let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
827 if !missing.is_empty() {
828 let missing_str = missing
829 .iter()
830 .map(|s| format!("\"{}\"", s))
831 .collect::<Vec<_>>()
832 .join(", ");
833 self.warning_at(
834 format!(
835 "Non-exhaustive match on enum {}: missing variants {}",
836 enum_name, missing_str
837 ),
838 span,
839 );
840 }
841 }
842
843 fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
844 if let Some(sig) = scope.get_fn(name).cloned() {
846 if !is_builtin(name)
847 && (args.len() < sig.required_params || args.len() > sig.params.len())
848 {
849 let expected = if sig.required_params == sig.params.len() {
850 format!("{}", sig.params.len())
851 } else {
852 format!("{}-{}", sig.required_params, sig.params.len())
853 };
854 self.warning_at(
855 format!(
856 "Function '{}' expects {} arguments, got {}",
857 name,
858 expected,
859 args.len()
860 ),
861 span,
862 );
863 }
864 let call_scope = if sig.type_param_names.is_empty() {
867 scope.clone()
868 } else {
869 let mut s = scope.child();
870 for tp_name in &sig.type_param_names {
871 s.generic_type_params.insert(tp_name.clone());
872 }
873 s
874 };
875 for (i, (arg, (param_name, param_type))) in
876 args.iter().zip(sig.params.iter()).enumerate()
877 {
878 if let Some(expected) = param_type {
879 let actual = self.infer_type(arg, scope);
880 if let Some(actual) = &actual {
881 if !self.types_compatible(expected, actual, &call_scope) {
882 self.error_at(
883 format!(
884 "Argument {} ('{}'): expected {}, got {}",
885 i + 1,
886 param_name,
887 format_type(expected),
888 format_type(actual)
889 ),
890 arg.span,
891 );
892 }
893 }
894 }
895 }
896 }
897 for arg in args {
899 self.check_node(arg, scope);
900 }
901 }
902
903 fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
905 match &snode.node {
906 Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
907 Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
908 Node::StringLiteral(_) | Node::InterpolatedString(_) => {
909 Some(TypeExpr::Named("string".into()))
910 }
911 Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
912 Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
913 Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
914 Node::DictLiteral(entries) => {
915 let mut fields = Vec::new();
917 let mut all_string_keys = true;
918 for entry in entries {
919 if let Node::StringLiteral(key) = &entry.key.node {
920 let val_type = self
921 .infer_type(&entry.value, scope)
922 .unwrap_or(TypeExpr::Named("nil".into()));
923 fields.push(ShapeField {
924 name: key.clone(),
925 type_expr: val_type,
926 optional: false,
927 });
928 } else {
929 all_string_keys = false;
930 break;
931 }
932 }
933 if all_string_keys && !fields.is_empty() {
934 Some(TypeExpr::Shape(fields))
935 } else {
936 Some(TypeExpr::Named("dict".into()))
937 }
938 }
939 Node::Closure { params, body } => {
940 let all_typed = params.iter().all(|p| p.type_expr.is_some());
942 if all_typed && !params.is_empty() {
943 let param_types: Vec<TypeExpr> =
944 params.iter().filter_map(|p| p.type_expr.clone()).collect();
945 let ret = body.last().and_then(|last| self.infer_type(last, scope));
947 if let Some(ret_type) = ret {
948 return Some(TypeExpr::FnType {
949 params: param_types,
950 return_type: Box::new(ret_type),
951 });
952 }
953 }
954 Some(TypeExpr::Named("closure".into()))
955 }
956
957 Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
958
959 Node::FunctionCall { name, .. } => {
960 if let Some(sig) = scope.get_fn(name) {
962 return sig.return_type.clone();
963 }
964 builtin_return_type(name)
966 }
967
968 Node::BinaryOp { op, left, right } => {
969 let lt = self.infer_type(left, scope);
970 let rt = self.infer_type(right, scope);
971 infer_binary_op_type(op, <, &rt)
972 }
973
974 Node::UnaryOp { op, operand } => {
975 let t = self.infer_type(operand, scope);
976 match op.as_str() {
977 "!" => Some(TypeExpr::Named("bool".into())),
978 "-" => t, _ => None,
980 }
981 }
982
983 Node::Ternary {
984 true_expr,
985 false_expr,
986 ..
987 } => {
988 let tt = self.infer_type(true_expr, scope);
989 let ft = self.infer_type(false_expr, scope);
990 match (&tt, &ft) {
991 (Some(a), Some(b)) if a == b => tt,
992 (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
993 (Some(_), None) => tt,
994 (None, Some(_)) => ft,
995 (None, None) => None,
996 }
997 }
998
999 Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1000
1001 Node::PropertyAccess { object, property } => {
1002 if let Node::Identifier(name) = &object.node {
1004 if scope.get_enum(name).is_some() {
1005 return Some(TypeExpr::Named(name.clone()));
1006 }
1007 }
1008 if property == "variant" {
1010 let obj_type = self.infer_type(object, scope);
1011 if let Some(TypeExpr::Named(name)) = &obj_type {
1012 if scope.get_enum(name).is_some() {
1013 return Some(TypeExpr::Named("string".into()));
1014 }
1015 }
1016 }
1017 let obj_type = self.infer_type(object, scope);
1019 if let Some(TypeExpr::Shape(fields)) = &obj_type {
1020 if let Some(field) = fields.iter().find(|f| f.name == *property) {
1021 return Some(field.type_expr.clone());
1022 }
1023 }
1024 None
1025 }
1026
1027 Node::SubscriptAccess { object, index } => {
1028 let obj_type = self.infer_type(object, scope);
1029 match &obj_type {
1030 Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1031 Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1032 Some(TypeExpr::Shape(fields)) => {
1033 if let Node::StringLiteral(key) = &index.node {
1035 fields
1036 .iter()
1037 .find(|f| &f.name == key)
1038 .map(|f| f.type_expr.clone())
1039 } else {
1040 None
1041 }
1042 }
1043 Some(TypeExpr::Named(n)) if n == "list" => None,
1044 Some(TypeExpr::Named(n)) if n == "dict" => None,
1045 Some(TypeExpr::Named(n)) if n == "string" => {
1046 Some(TypeExpr::Named("string".into()))
1047 }
1048 _ => None,
1049 }
1050 }
1051 Node::SliceAccess { object, .. } => {
1052 let obj_type = self.infer_type(object, scope);
1054 match &obj_type {
1055 Some(TypeExpr::List(_)) => obj_type,
1056 Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1057 Some(TypeExpr::Named(n)) if n == "string" => {
1058 Some(TypeExpr::Named("string".into()))
1059 }
1060 _ => None,
1061 }
1062 }
1063 Node::MethodCall { object, method, .. }
1064 | Node::OptionalMethodCall { object, method, .. } => {
1065 let obj_type = self.infer_type(object, scope);
1066 let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1067 || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1068 match method.as_str() {
1069 "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1071 Some(TypeExpr::Named("bool".into()))
1072 }
1073 "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1075 "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1077 | "pad_left" | "pad_right" | "repeat" | "join" => {
1078 Some(TypeExpr::Named("string".into()))
1079 }
1080 "split" | "chars" => Some(TypeExpr::Named("list".into())),
1081 "filter" => {
1083 if is_dict {
1084 Some(TypeExpr::Named("dict".into()))
1085 } else {
1086 Some(TypeExpr::Named("list".into()))
1087 }
1088 }
1089 "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1091 "reduce" | "find" | "first" | "last" => None,
1092 "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1094 "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1095 "to_string" => Some(TypeExpr::Named("string".into())),
1097 "to_int" => Some(TypeExpr::Named("int".into())),
1098 "to_float" => Some(TypeExpr::Named("float".into())),
1099 _ => None,
1100 }
1101 }
1102
1103 _ => None,
1104 }
1105 }
1106
1107 fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1109 if let TypeExpr::Named(name) = expected {
1111 if scope.is_generic_type_param(name) {
1112 return true;
1113 }
1114 }
1115 if let TypeExpr::Named(name) = actual {
1116 if scope.is_generic_type_param(name) {
1117 return true;
1118 }
1119 }
1120 let expected = self.resolve_alias(expected, scope);
1121 let actual = self.resolve_alias(actual, scope);
1122
1123 match (&expected, &actual) {
1124 (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1125 (TypeExpr::Union(members), actual_type) => members
1126 .iter()
1127 .any(|m| self.types_compatible(m, actual_type, scope)),
1128 (expected_type, TypeExpr::Union(members)) => members
1129 .iter()
1130 .all(|m| self.types_compatible(expected_type, m, scope)),
1131 (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1132 (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1133 (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1134 if expected_field.optional {
1135 return true;
1136 }
1137 af.iter().any(|actual_field| {
1138 actual_field.name == expected_field.name
1139 && self.types_compatible(
1140 &expected_field.type_expr,
1141 &actual_field.type_expr,
1142 scope,
1143 )
1144 })
1145 }),
1146 (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1148 let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1149 keys_ok
1150 && af
1151 .iter()
1152 .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1153 }
1154 (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1156 (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1157 self.types_compatible(expected_inner, actual_inner, scope)
1158 }
1159 (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1160 (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1161 (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1162 self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1163 }
1164 (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1165 (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1166 (
1168 TypeExpr::FnType {
1169 params: ep,
1170 return_type: er,
1171 },
1172 TypeExpr::FnType {
1173 params: ap,
1174 return_type: ar,
1175 },
1176 ) => {
1177 ep.len() == ap.len()
1178 && ep
1179 .iter()
1180 .zip(ap.iter())
1181 .all(|(e, a)| self.types_compatible(e, a, scope))
1182 && self.types_compatible(er, ar, scope)
1183 }
1184 (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1186 (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1187 _ => false,
1188 }
1189 }
1190
1191 fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1192 if let TypeExpr::Named(name) = ty {
1193 if let Some(resolved) = scope.resolve_type(name) {
1194 return resolved.clone();
1195 }
1196 }
1197 ty.clone()
1198 }
1199
1200 fn error_at(&mut self, message: String, span: Span) {
1201 self.diagnostics.push(TypeDiagnostic {
1202 message,
1203 severity: DiagnosticSeverity::Error,
1204 span: Some(span),
1205 });
1206 }
1207
1208 fn warning_at(&mut self, message: String, span: Span) {
1209 self.diagnostics.push(TypeDiagnostic {
1210 message,
1211 severity: DiagnosticSeverity::Warning,
1212 span: Some(span),
1213 });
1214 }
1215}
1216
1217impl Default for TypeChecker {
1218 fn default() -> Self {
1219 Self::new()
1220 }
1221}
1222
1223fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1225 match op {
1226 "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1227 "+" => match (left, right) {
1228 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1229 match (l.as_str(), r.as_str()) {
1230 ("int", "int") => Some(TypeExpr::Named("int".into())),
1231 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1232 ("string", _) => Some(TypeExpr::Named("string".into())),
1233 ("list", "list") => Some(TypeExpr::Named("list".into())),
1234 ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1235 _ => Some(TypeExpr::Named("string".into())),
1236 }
1237 }
1238 _ => None,
1239 },
1240 "-" | "*" | "/" | "%" => match (left, right) {
1241 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1242 match (l.as_str(), r.as_str()) {
1243 ("int", "int") => Some(TypeExpr::Named("int".into())),
1244 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1245 _ => None,
1246 }
1247 }
1248 _ => None,
1249 },
1250 "??" => match (left, right) {
1251 (Some(TypeExpr::Union(members)), _) => {
1252 let non_nil: Vec<_> = members
1253 .iter()
1254 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1255 .cloned()
1256 .collect();
1257 if non_nil.len() == 1 {
1258 Some(non_nil[0].clone())
1259 } else if non_nil.is_empty() {
1260 right.clone()
1261 } else {
1262 Some(TypeExpr::Union(non_nil))
1263 }
1264 }
1265 _ => right.clone(),
1266 },
1267 "|>" => None,
1268 _ => None,
1269 }
1270}
1271
1272pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1277 if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1278 let mut details = Vec::new();
1279 for field in ef {
1280 if field.optional {
1281 continue;
1282 }
1283 match af.iter().find(|f| f.name == field.name) {
1284 None => details.push(format!(
1285 "missing field '{}' ({})",
1286 field.name,
1287 format_type(&field.type_expr)
1288 )),
1289 Some(actual_field) => {
1290 let e_str = format_type(&field.type_expr);
1291 let a_str = format_type(&actual_field.type_expr);
1292 if e_str != a_str {
1293 details.push(format!(
1294 "field '{}' has type {}, expected {}",
1295 field.name, a_str, e_str
1296 ));
1297 }
1298 }
1299 }
1300 }
1301 if details.is_empty() {
1302 None
1303 } else {
1304 Some(details.join("; "))
1305 }
1306 } else {
1307 None
1308 }
1309}
1310
1311pub fn format_type(ty: &TypeExpr) -> String {
1312 match ty {
1313 TypeExpr::Named(n) => n.clone(),
1314 TypeExpr::Union(types) => types
1315 .iter()
1316 .map(format_type)
1317 .collect::<Vec<_>>()
1318 .join(" | "),
1319 TypeExpr::Shape(fields) => {
1320 let inner: Vec<String> = fields
1321 .iter()
1322 .map(|f| {
1323 let opt = if f.optional { "?" } else { "" };
1324 format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1325 })
1326 .collect();
1327 format!("{{{}}}", inner.join(", "))
1328 }
1329 TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1330 TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1331 TypeExpr::FnType {
1332 params,
1333 return_type,
1334 } => {
1335 let params_str = params
1336 .iter()
1337 .map(format_type)
1338 .collect::<Vec<_>>()
1339 .join(", ");
1340 format!("fn({}) -> {}", params_str, format_type(return_type))
1341 }
1342 }
1343}
1344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    /// Substring that identifies match-pattern type mismatch warnings.
    const PATTERN_MISMATCH: &str = "Match pattern type mismatch";

    /// Substring that identifies non-exhaustive match warnings.
    const NON_EXHAUSTIVE: &str = "Non-exhaustive";

    /// Lexes, parses and type-checks `source`, returning every diagnostic.
    ///
    /// Panics if lexing or parsing fails; every test input is expected to be
    /// syntactically valid.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut lx = Lexer::new(source);
        let toks = lx.tokenize().unwrap();
        let mut ps = Parser::new(toks);
        let ast = ps.parse().unwrap();
        TypeChecker::new().check(&ast)
    }

    /// Messages of the diagnostics for `source` that have the given severity.
    fn messages_with_severity(source: &str, severity: DiagnosticSeverity) -> Vec<String> {
        let mut found = Vec::new();
        for diag in check_source(source) {
            if diag.severity == severity {
                found.push(diag.message);
            }
        }
        found
    }

    /// Error messages produced for `source`.
    fn errors(source: &str) -> Vec<String> {
        messages_with_severity(source, DiagnosticSeverity::Error)
    }

    /// Warning messages produced for `source`.
    fn warnings(source: &str) -> Vec<String> {
        messages_with_severity(source, DiagnosticSeverity::Warning)
    }

    /// Warning messages for `source` that contain `needle`.
    fn warnings_containing(source: &str, needle: &str) -> Vec<String> {
        warnings(source)
            .into_iter()
            .filter(|w| w.contains(needle))
            .collect()
    }

    #[test]
    fn test_no_errors_for_untyped_code() {
        let found = errors("pipeline t(task) { let x = 42\nlog(x) }");
        assert!(found.is_empty(), "unexpected errors: {:?}", found);
    }

    #[test]
    fn test_correct_typed_let() {
        let found = errors("pipeline t(task) { let x: int = 42 }");
        assert!(found.is_empty(), "unexpected errors: {:?}", found);
    }

    #[test]
    fn test_type_mismatch_let() {
        let found = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
        assert_eq!(found.len(), 1);
        let msg = &found[0];
        assert!(msg.contains("Type mismatch"));
        assert!(msg.contains("int"));
        assert!(msg.contains("string"));
    }

    #[test]
    fn test_correct_typed_fn() {
        let found = errors(
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
        );
        assert!(found.is_empty(), "unexpected errors: {:?}", found);
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let found = errors(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        assert_eq!(found.len(), 1);
        let msg = &found[0];
        assert!(msg.contains("Argument 1"));
        assert!(msg.contains("expected int"));
    }

    #[test]
    fn test_return_type_mismatch() {
        let found = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert_eq!(found.len(), 1);
        assert!(found[0].contains("Return type mismatch"));
    }

    #[test]
    fn test_union_type_compatible() {
        let found = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
        assert!(found.is_empty(), "unexpected errors: {:?}", found);
    }

    #[test]
    fn test_union_type_mismatch() {
        let found = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert_eq!(found.len(), 1);
        assert!(found[0].contains("Type mismatch"));
    }

    #[test]
    fn test_type_inference_propagation() {
        let found = errors(
            r#"pipeline t(task) {
    fn add(a: int, b: int) -> int { return a + b }
    let result: string = add(1, 2)
}"#,
        );
        assert_eq!(found.len(), 1);
        let msg = &found[0];
        assert!(msg.contains("Type mismatch"));
        assert!(msg.contains("string"));
        assert!(msg.contains("int"));
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let found = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        assert_eq!(found.len(), 1);
        let msg = &found[0];
        assert!(msg.contains("string"));
        assert!(msg.contains("int"));
    }

    #[test]
    fn test_binary_op_type_inference() {
        let found = errors("pipeline t(task) { let x: string = 1 + 2 }");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_comparison_returns_bool() {
        let found = errors("pipeline t(task) { let x: bool = 1 < 2 }");
        assert!(found.is_empty(), "unexpected errors: {:?}", found);
    }

    #[test]
    fn test_int_float_promotion() {
        let found = errors("pipeline t(task) { let x: float = 42 }");
        assert!(found.is_empty(), "unexpected errors: {:?}", found);
    }

    #[test]
    fn test_untyped_code_no_errors() {
        let found = errors(
            r#"pipeline t(task) {
    fn process(data) {
        let result = data + " processed"
        return result
    }
    log(process("hello"))
}"#,
        );
        assert!(found.is_empty(), "unexpected errors: {:?}", found);
    }

    #[test]
    fn test_type_alias() {
        let found = errors(
            r#"pipeline t(task) {
    type Name = string
    let x: Name = "hello"
}"#,
        );
        assert!(found.is_empty(), "unexpected errors: {:?}", found);
    }

    #[test]
    fn test_type_alias_mismatch() {
        let found = errors(
            r#"pipeline t(task) {
    type Name = string
    let x: Name = 42
}"#,
        );
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_assignment_type_check() {
        let found = errors(
            r#"pipeline t(task) {
    var x: int = 0
    x = "hello"
}"#,
        );
        assert_eq!(found.len(), 1);
        assert!(found[0].contains("cannot assign string"));
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        let found = errors(
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
        );
        assert!(found.is_empty(), "unexpected errors: {:?}", found);
    }

    #[test]
    fn test_covariance_return_type() {
        let found = errors("pipeline t(task) { fn get() -> float { return 42 } }");
        assert!(found.is_empty(), "unexpected errors: {:?}", found);
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        let found = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_exhaustive_match_no_warning() {
        let flagged = warnings_containing(
            r#"pipeline t(task) {
    enum Color { Red, Green, Blue }
    let c = Color.Red
    match c.variant {
        "Red" -> { log("r") }
        "Green" -> { log("g") }
        "Blue" -> { log("b") }
    }
}"#,
            NON_EXHAUSTIVE,
        );
        assert!(flagged.is_empty(), "unexpected warnings: {:?}", flagged);
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let flagged = warnings_containing(
            r#"pipeline t(task) {
    enum Color { Red, Green, Blue }
    let c = Color.Red
    match c.variant {
        "Red" -> { log("r") }
        "Green" -> { log("g") }
    }
}"#,
            NON_EXHAUSTIVE,
        );
        assert_eq!(flagged.len(), 1);
        assert!(flagged[0].contains("Blue"));
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let flagged = warnings_containing(
            r#"pipeline t(task) {
    enum Status { Active, Inactive, Pending }
    let s = Status.Active
    match s.variant {
        "Active" -> { log("a") }
    }
}"#,
            NON_EXHAUSTIVE,
        );
        assert_eq!(flagged.len(), 1);
        let msg = &flagged[0];
        assert!(msg.contains("Inactive"));
        assert!(msg.contains("Pending"));
    }

    #[test]
    fn test_enum_construct_type_inference() {
        let found = errors(
            r#"pipeline t(task) {
    enum Color { Red, Green, Blue }
    let c: Color = Color.Red
}"#,
        );
        assert!(found.is_empty(), "unexpected errors: {:?}", found);
    }

    #[test]
    fn test_nil_coalescing_strips_nil() {
        // `x ?? "default"` on a `string | nil` value must type as `string`.
        let found = errors(
            r#"pipeline t(task) {
    let x: string | nil = nil
    let y: string = x ?? "default"
}"#,
        );
        assert!(found.is_empty(), "unexpected errors: {:?}", found);
    }

    #[test]
    fn test_shape_mismatch_detail_missing_field() {
        let found = errors(
            r#"pipeline t(task) {
    let x: {name: string, age: int} = {name: "hello"}
}"#,
        );
        assert_eq!(found.len(), 1);
        let msg = &found[0];
        assert!(
            msg.contains("missing field 'age'"),
            "expected detail about missing field, got: {}",
            msg
        );
    }

    #[test]
    fn test_shape_mismatch_detail_wrong_type() {
        let found = errors(
            r#"pipeline t(task) {
    let x: {name: string, age: int} = {name: 42, age: 10}
}"#,
        );
        assert_eq!(found.len(), 1);
        let msg = &found[0];
        assert!(
            msg.contains("field 'name' has type int, expected string"),
            "expected detail about wrong type, got: {}",
            msg
        );
    }

    #[test]
    fn test_match_pattern_string_against_int() {
        let flagged = warnings_containing(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        "hello" -> { log("bad") }
        42 -> { log("ok") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(flagged.len(), 1);
        assert!(flagged[0].contains("matching int against string literal"));
    }

    #[test]
    fn test_match_pattern_int_against_string() {
        let flagged = warnings_containing(
            r#"pipeline t(task) {
    let x: string = "hello"
    match x {
        42 -> { log("bad") }
        "hello" -> { log("ok") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(flagged.len(), 1);
        assert!(flagged[0].contains("matching string against int literal"));
    }

    #[test]
    fn test_match_pattern_bool_against_int() {
        let flagged = warnings_containing(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        true -> { log("bad") }
        42 -> { log("ok") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(flagged.len(), 1);
        assert!(flagged[0].contains("matching int against bool literal"));
    }

    #[test]
    fn test_match_pattern_float_against_string() {
        let flagged = warnings_containing(
            r#"pipeline t(task) {
    let x: string = "hello"
    match x {
        3.14 -> { log("bad") }
        "hello" -> { log("ok") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(flagged.len(), 1);
        assert!(flagged[0].contains("matching string against float literal"));
    }

    #[test]
    fn test_match_pattern_int_against_float_ok() {
        // Int literals are acceptable patterns for a float subject.
        let flagged = warnings_containing(
            r#"pipeline t(task) {
    let x: float = 3.14
    match x {
        42 -> { log("ok") }
        _ -> { log("default") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(flagged.is_empty(), "unexpected warnings: {:?}", flagged);
    }

    #[test]
    fn test_match_pattern_float_against_int_ok() {
        // Float literals are acceptable patterns for an int subject.
        let flagged = warnings_containing(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        3.14 -> { log("close") }
        _ -> { log("default") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(flagged.is_empty(), "unexpected warnings: {:?}", flagged);
    }

    #[test]
    fn test_match_pattern_correct_types_no_warning() {
        let flagged = warnings_containing(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        1 -> { log("one") }
        2 -> { log("two") }
        _ -> { log("other") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(flagged.is_empty(), "unexpected warnings: {:?}", flagged);
    }

    #[test]
    fn test_match_pattern_wildcard_no_warning() {
        let flagged = warnings_containing(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        _ -> { log("catch all") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(flagged.is_empty(), "unexpected warnings: {:?}", flagged);
    }

    #[test]
    fn test_match_pattern_untyped_no_warning() {
        // A subject of unknown type must not trigger pattern-type warnings.
        let flagged = warnings_containing(
            r#"pipeline t(task) {
    let x = some_unknown_fn()
    match x {
        "hello" -> { log("string") }
        42 -> { log("int") }
    }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(flagged.is_empty(), "unexpected warnings: {:?}", flagged);
    }
}