1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9 pub message: String,
10 pub severity: DiagnosticSeverity,
11 pub span: Option<Span>,
12}
13
/// Severity level attached to a type-checker diagnostic.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// A definite type error.
    Error,
    /// A suspicious construct that is reported but not treated as an error.
    Warning,
}
19
20type InferredType = Option<TypeExpr>;
22
23#[derive(Debug, Clone)]
25struct TypeScope {
26 vars: BTreeMap<String, InferredType>,
28 functions: BTreeMap<String, FnSignature>,
30 type_aliases: BTreeMap<String, TypeExpr>,
32 enums: BTreeMap<String, Vec<String>>,
34 interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36 structs: BTreeMap<String, Vec<(String, InferredType)>>,
38 generic_type_params: std::collections::BTreeSet<String>,
40 parent: Option<Box<TypeScope>>,
41}
42
43#[derive(Debug, Clone)]
44struct FnSignature {
45 params: Vec<(String, InferredType)>,
46 return_type: InferredType,
47 type_param_names: Vec<String>,
49 required_params: usize,
51}
52
53impl TypeScope {
54 fn new() -> Self {
55 Self {
56 vars: BTreeMap::new(),
57 functions: BTreeMap::new(),
58 type_aliases: BTreeMap::new(),
59 enums: BTreeMap::new(),
60 interfaces: BTreeMap::new(),
61 structs: BTreeMap::new(),
62 generic_type_params: std::collections::BTreeSet::new(),
63 parent: None,
64 }
65 }
66
67 fn child(&self) -> Self {
68 Self {
69 vars: BTreeMap::new(),
70 functions: BTreeMap::new(),
71 type_aliases: BTreeMap::new(),
72 enums: BTreeMap::new(),
73 interfaces: BTreeMap::new(),
74 structs: BTreeMap::new(),
75 generic_type_params: std::collections::BTreeSet::new(),
76 parent: Some(Box::new(self.clone())),
77 }
78 }
79
80 fn get_var(&self, name: &str) -> Option<&InferredType> {
81 self.vars
82 .get(name)
83 .or_else(|| self.parent.as_ref()?.get_var(name))
84 }
85
86 fn get_fn(&self, name: &str) -> Option<&FnSignature> {
87 self.functions
88 .get(name)
89 .or_else(|| self.parent.as_ref()?.get_fn(name))
90 }
91
92 fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
93 self.type_aliases
94 .get(name)
95 .or_else(|| self.parent.as_ref()?.resolve_type(name))
96 }
97
98 fn is_generic_type_param(&self, name: &str) -> bool {
99 self.generic_type_params.contains(name)
100 || self
101 .parent
102 .as_ref()
103 .is_some_and(|p| p.is_generic_type_param(name))
104 }
105
106 fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
107 self.enums
108 .get(name)
109 .or_else(|| self.parent.as_ref()?.get_enum(name))
110 }
111
112 #[allow(dead_code)]
113 fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
114 self.interfaces
115 .get(name)
116 .or_else(|| self.parent.as_ref()?.get_interface(name))
117 }
118
119 fn define_var(&mut self, name: &str, ty: InferredType) {
120 self.vars.insert(name.to_string(), ty);
121 }
122
123 fn define_fn(&mut self, name: &str, sig: FnSignature) {
124 self.functions.insert(name.to_string(), sig);
125 }
126}
127
128fn builtin_return_type(name: &str) -> InferredType {
130 match name {
131 "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
132 | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
133 Some(TypeExpr::Named("nil".into()))
134 }
135 "type_of"
136 | "to_string"
137 | "json_stringify"
138 | "read_file"
139 | "http_get"
140 | "http_post"
141 | "llm_call"
142 | "regex_replace"
143 | "path_join"
144 | "temp_dir"
145 | "date_format"
146 | "format"
147 | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
148 "to_int" | "timer_end" | "elapsed" | "sign" => Some(TypeExpr::Named("int".into())),
149 "to_float" | "timestamp" | "date_parse" | "sin" | "cos" | "tan" | "asin" | "acos"
150 | "atan" | "atan2" | "log2" | "log10" | "ln" | "exp" | "pi" | "e" => {
151 Some(TypeExpr::Named("float".into()))
152 }
153 "file_exists" | "json_validate" | "is_nan" | "is_infinite" | "set_contains" => {
154 Some(TypeExpr::Named("bool".into()))
155 }
156 "list_dir" | "mcp_list_tools" | "mcp_list_resources" | "mcp_list_prompts" | "to_list" => {
157 Some(TypeExpr::Named("list".into()))
158 }
159 "stat" | "exec" | "shell" | "date_now" | "agent_loop" | "llm_info" | "llm_usage"
160 | "timer_start" | "metadata_get" | "mcp_server_info" | "mcp_get_prompt" => {
161 Some(TypeExpr::Named("dict".into()))
162 }
163 "metadata_set"
164 | "metadata_save"
165 | "metadata_refresh_hashes"
166 | "invalidate_facts"
167 | "log_json"
168 | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
169 "env" | "regex_match" => Some(TypeExpr::Union(vec![
170 TypeExpr::Named("string".into()),
171 TypeExpr::Named("nil".into()),
172 ])),
173 "json_parse" | "json_extract" => None, _ => None,
175 }
176}
177
/// Returns `true` if `name` is one of the builtin function names listed below.
///
/// The comparison is exact and case-sensitive.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        // core / conversion / I/O
        "log",
        "print",
        "println",
        "type_of",
        "to_string",
        "to_int",
        "to_float",
        "json_stringify",
        "json_parse",
        "env",
        "timestamp",
        "sleep",
        "read_file",
        "write_file",
        "exit",
        "regex_match",
        "regex_replace",
        "http_get",
        "http_post",
        "llm_call",
        "agent_loop",
        "await",
        "cancel",
        // filesystem / process
        "file_exists",
        "delete_file",
        "list_dir",
        "mkdir",
        "path_join",
        "copy_file",
        "append_file",
        "temp_dir",
        "stat",
        "exec",
        "shell",
        // dates / formatting / JSON helpers
        "date_now",
        "date_format",
        "date_parse",
        "format",
        "json_validate",
        "json_extract",
        // strings / paths
        "trim",
        "lowercase",
        "uppercase",
        "split",
        "starts_with",
        "ends_with",
        "contains",
        "replace",
        "join",
        "len",
        "substring",
        "dirname",
        "basename",
        "extname",
        // math
        "sin",
        "cos",
        "tan",
        "asin",
        "acos",
        "atan",
        "atan2",
        "log2",
        "log10",
        "ln",
        "exp",
        "pi",
        "e",
        "sign",
        "is_nan",
        "is_infinite",
        // sets
        "set",
        "set_add",
        "set_remove",
        "set_contains",
        "set_union",
        "set_intersect",
        "set_difference",
        "to_list",
    ];
    BUILTINS.contains(&name)
}
262
263pub struct TypeChecker {
265 diagnostics: Vec<TypeDiagnostic>,
266 scope: TypeScope,
267}
268
269impl TypeChecker {
270 pub fn new() -> Self {
271 Self {
272 diagnostics: Vec::new(),
273 scope: TypeScope::new(),
274 }
275 }
276
277 pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
279 Self::register_declarations_into(&mut self.scope, program);
281
282 for snode in program {
284 if let Node::Pipeline { body, .. } = &snode.node {
285 Self::register_declarations_into(&mut self.scope, body);
286 }
287 }
288
289 for snode in program {
291 match &snode.node {
292 Node::Pipeline { params, body, .. } => {
293 let mut child = self.scope.child();
294 for p in params {
295 child.define_var(p, None);
296 }
297 self.check_block(body, &mut child);
298 }
299 Node::FnDecl {
300 name,
301 type_params,
302 params,
303 return_type,
304 body,
305 ..
306 } => {
307 let required_params =
308 params.iter().filter(|p| p.default_value.is_none()).count();
309 let sig = FnSignature {
310 params: params
311 .iter()
312 .map(|p| (p.name.clone(), p.type_expr.clone()))
313 .collect(),
314 return_type: return_type.clone(),
315 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
316 required_params,
317 };
318 self.scope.define_fn(name, sig);
319 self.check_fn_body(type_params, params, return_type, body);
320 }
321 _ => {
322 let mut scope = self.scope.clone();
323 self.check_node(snode, &mut scope);
324 for (name, ty) in scope.vars {
326 self.scope.vars.entry(name).or_insert(ty);
327 }
328 }
329 }
330 }
331
332 self.diagnostics
333 }
334
335 fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
337 for snode in nodes {
338 match &snode.node {
339 Node::TypeDecl { name, type_expr } => {
340 scope.type_aliases.insert(name.clone(), type_expr.clone());
341 }
342 Node::EnumDecl { name, variants } => {
343 let variant_names: Vec<String> =
344 variants.iter().map(|v| v.name.clone()).collect();
345 scope.enums.insert(name.clone(), variant_names);
346 }
347 Node::InterfaceDecl { name, methods } => {
348 scope.interfaces.insert(name.clone(), methods.clone());
349 }
350 Node::StructDecl { name, fields } => {
351 let field_types: Vec<(String, InferredType)> = fields
352 .iter()
353 .map(|f| (f.name.clone(), f.type_expr.clone()))
354 .collect();
355 scope.structs.insert(name.clone(), field_types);
356 }
357 _ => {}
358 }
359 }
360 }
361
362 fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
363 for stmt in stmts {
364 self.check_node(stmt, scope);
365 }
366 }
367
368 fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
370 match pattern {
371 BindingPattern::Identifier(name) => {
372 scope.define_var(name, None);
373 }
374 BindingPattern::Dict(fields) => {
375 for field in fields {
376 let name = field.alias.as_deref().unwrap_or(&field.key);
377 scope.define_var(name, None);
378 }
379 }
380 BindingPattern::List(elements) => {
381 for elem in elements {
382 scope.define_var(&elem.name, None);
383 }
384 }
385 }
386 }
387
388 fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
389 let span = snode.span;
390 match &snode.node {
391 Node::LetBinding {
392 pattern,
393 type_ann,
394 value,
395 } => {
396 let inferred = self.infer_type(value, scope);
397 if let BindingPattern::Identifier(name) = pattern {
398 if let Some(expected) = type_ann {
399 if let Some(actual) = &inferred {
400 if !self.types_compatible(expected, actual, scope) {
401 let mut msg = format!(
402 "Type mismatch: '{}' declared as {}, but assigned {}",
403 name,
404 format_type(expected),
405 format_type(actual)
406 );
407 if let Some(detail) = shape_mismatch_detail(expected, actual) {
408 msg.push_str(&format!(" ({})", detail));
409 }
410 self.error_at(msg, span);
411 }
412 }
413 }
414 let ty = type_ann.clone().or(inferred);
415 scope.define_var(name, ty);
416 } else {
417 Self::define_pattern_vars(pattern, scope);
418 }
419 }
420
421 Node::VarBinding {
422 pattern,
423 type_ann,
424 value,
425 } => {
426 let inferred = self.infer_type(value, scope);
427 if let BindingPattern::Identifier(name) = pattern {
428 if let Some(expected) = type_ann {
429 if let Some(actual) = &inferred {
430 if !self.types_compatible(expected, actual, scope) {
431 let mut msg = format!(
432 "Type mismatch: '{}' declared as {}, but assigned {}",
433 name,
434 format_type(expected),
435 format_type(actual)
436 );
437 if let Some(detail) = shape_mismatch_detail(expected, actual) {
438 msg.push_str(&format!(" ({})", detail));
439 }
440 self.error_at(msg, span);
441 }
442 }
443 }
444 let ty = type_ann.clone().or(inferred);
445 scope.define_var(name, ty);
446 } else {
447 Self::define_pattern_vars(pattern, scope);
448 }
449 }
450
451 Node::FnDecl {
452 name,
453 type_params,
454 params,
455 return_type,
456 body,
457 ..
458 } => {
459 let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
460 let sig = FnSignature {
461 params: params
462 .iter()
463 .map(|p| (p.name.clone(), p.type_expr.clone()))
464 .collect(),
465 return_type: return_type.clone(),
466 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
467 required_params,
468 };
469 scope.define_fn(name, sig.clone());
470 scope.define_var(name, None);
471 self.check_fn_body(type_params, params, return_type, body);
472 }
473
474 Node::FunctionCall { name, args } => {
475 self.check_call(name, args, scope, span);
476 }
477
478 Node::IfElse {
479 condition,
480 then_body,
481 else_body,
482 } => {
483 self.check_node(condition, scope);
484 let mut then_scope = scope.child();
485 self.check_block(then_body, &mut then_scope);
486 if let Some(else_body) = else_body {
487 let mut else_scope = scope.child();
488 self.check_block(else_body, &mut else_scope);
489 }
490 }
491
492 Node::ForIn {
493 pattern,
494 iterable,
495 body,
496 } => {
497 self.check_node(iterable, scope);
498 let mut loop_scope = scope.child();
499 if let BindingPattern::Identifier(variable) = pattern {
500 let elem_type = match self.infer_type(iterable, scope) {
502 Some(TypeExpr::List(inner)) => Some(*inner),
503 Some(TypeExpr::Named(n)) if n == "string" => {
504 Some(TypeExpr::Named("string".into()))
505 }
506 _ => None,
507 };
508 loop_scope.define_var(variable, elem_type);
509 } else {
510 Self::define_pattern_vars(pattern, &mut loop_scope);
511 }
512 self.check_block(body, &mut loop_scope);
513 }
514
515 Node::WhileLoop { condition, body } => {
516 self.check_node(condition, scope);
517 let mut loop_scope = scope.child();
518 self.check_block(body, &mut loop_scope);
519 }
520
521 Node::TryCatch {
522 body,
523 error_var,
524 catch_body,
525 finally_body,
526 ..
527 } => {
528 let mut try_scope = scope.child();
529 self.check_block(body, &mut try_scope);
530 let mut catch_scope = scope.child();
531 if let Some(var) = error_var {
532 catch_scope.define_var(var, None);
533 }
534 self.check_block(catch_body, &mut catch_scope);
535 if let Some(fb) = finally_body {
536 let mut finally_scope = scope.child();
537 self.check_block(fb, &mut finally_scope);
538 }
539 }
540
541 Node::ReturnStmt {
542 value: Some(val), ..
543 } => {
544 self.check_node(val, scope);
545 }
546
547 Node::Assignment {
548 target, value, op, ..
549 } => {
550 self.check_node(value, scope);
551 if let Node::Identifier(name) = &target.node {
552 if let Some(Some(var_type)) = scope.get_var(name) {
553 let value_type = self.infer_type(value, scope);
554 let assigned = if let Some(op) = op {
555 let var_inferred = scope.get_var(name).cloned().flatten();
556 infer_binary_op_type(op, &var_inferred, &value_type)
557 } else {
558 value_type
559 };
560 if let Some(actual) = &assigned {
561 if !self.types_compatible(var_type, actual, scope) {
562 self.error_at(
563 format!(
564 "Type mismatch: cannot assign {} to '{}' (declared as {})",
565 format_type(actual),
566 name,
567 format_type(var_type)
568 ),
569 span,
570 );
571 }
572 }
573 }
574 }
575 }
576
577 Node::TypeDecl { name, type_expr } => {
578 scope.type_aliases.insert(name.clone(), type_expr.clone());
579 }
580
581 Node::EnumDecl { name, variants } => {
582 let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
583 scope.enums.insert(name.clone(), variant_names);
584 }
585
586 Node::StructDecl { name, fields } => {
587 let field_types: Vec<(String, InferredType)> = fields
588 .iter()
589 .map(|f| (f.name.clone(), f.type_expr.clone()))
590 .collect();
591 scope.structs.insert(name.clone(), field_types);
592 }
593
594 Node::InterfaceDecl { name, methods } => {
595 scope.interfaces.insert(name.clone(), methods.clone());
596 }
597
598 Node::MatchExpr { value, arms } => {
599 self.check_node(value, scope);
600 let value_type = self.infer_type(value, scope);
601 for arm in arms {
602 self.check_node(&arm.pattern, scope);
603 if let Some(ref vt) = value_type {
605 let value_type_name = format_type(vt);
606 let mismatch = match &arm.pattern.node {
607 Node::StringLiteral(_) => {
608 !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
609 }
610 Node::IntLiteral(_) => {
611 !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
612 && !self.types_compatible(
613 vt,
614 &TypeExpr::Named("float".into()),
615 scope,
616 )
617 }
618 Node::FloatLiteral(_) => {
619 !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
620 && !self.types_compatible(
621 vt,
622 &TypeExpr::Named("int".into()),
623 scope,
624 )
625 }
626 Node::BoolLiteral(_) => {
627 !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
628 }
629 _ => false,
630 };
631 if mismatch {
632 let pattern_type = match &arm.pattern.node {
633 Node::StringLiteral(_) => "string",
634 Node::IntLiteral(_) => "int",
635 Node::FloatLiteral(_) => "float",
636 Node::BoolLiteral(_) => "bool",
637 _ => unreachable!(),
638 };
639 self.warning_at(
640 format!(
641 "Match pattern type mismatch: matching {} against {} literal",
642 value_type_name, pattern_type
643 ),
644 arm.pattern.span,
645 );
646 }
647 }
648 let mut arm_scope = scope.child();
649 self.check_block(&arm.body, &mut arm_scope);
650 }
651 self.check_match_exhaustiveness(value, arms, scope, span);
652 }
653
654 Node::BinaryOp { op, left, right } => {
656 self.check_node(left, scope);
657 self.check_node(right, scope);
658 let lt = self.infer_type(left, scope);
660 let rt = self.infer_type(right, scope);
661 if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (<, &rt) {
662 match op.as_str() {
663 "-" | "*" | "/" | "%" => {
664 let numeric = ["int", "float"];
665 if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
666 self.warning_at(
667 format!(
668 "Operator '{op}' may not be valid for types {} and {}",
669 l, r
670 ),
671 span,
672 );
673 }
674 }
675 "+" => {
676 let valid = ["int", "float", "string", "list", "dict"];
678 if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
679 self.warning_at(
680 format!(
681 "Operator '+' may not be valid for types {} and {}",
682 l, r
683 ),
684 span,
685 );
686 }
687 }
688 _ => {}
689 }
690 }
691 }
692 Node::UnaryOp { operand, .. } => {
693 self.check_node(operand, scope);
694 }
695 Node::MethodCall { object, args, .. }
696 | Node::OptionalMethodCall { object, args, .. } => {
697 self.check_node(object, scope);
698 for arg in args {
699 self.check_node(arg, scope);
700 }
701 }
702 Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
703 self.check_node(object, scope);
704 }
705 Node::SubscriptAccess { object, index } => {
706 self.check_node(object, scope);
707 self.check_node(index, scope);
708 }
709 Node::SliceAccess { object, start, end } => {
710 self.check_node(object, scope);
711 if let Some(s) = start {
712 self.check_node(s, scope);
713 }
714 if let Some(e) = end {
715 self.check_node(e, scope);
716 }
717 }
718
719 _ => {}
721 }
722 }
723
724 fn check_fn_body(
725 &mut self,
726 type_params: &[TypeParam],
727 params: &[TypedParam],
728 return_type: &Option<TypeExpr>,
729 body: &[SNode],
730 ) {
731 let mut fn_scope = self.scope.child();
732 for tp in type_params {
735 fn_scope.generic_type_params.insert(tp.name.clone());
736 }
737 for param in params {
738 fn_scope.define_var(¶m.name, param.type_expr.clone());
739 if let Some(default) = ¶m.default_value {
740 self.check_node(default, &mut fn_scope);
741 }
742 }
743 self.check_block(body, &mut fn_scope);
744
745 if let Some(ret_type) = return_type {
747 for stmt in body {
748 self.check_return_type(stmt, ret_type, &fn_scope);
749 }
750 }
751 }
752
753 fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
754 let span = snode.span;
755 match &snode.node {
756 Node::ReturnStmt { value: Some(val) } => {
757 let inferred = self.infer_type(val, scope);
758 if let Some(actual) = &inferred {
759 if !self.types_compatible(expected, actual, scope) {
760 self.error_at(
761 format!(
762 "Return type mismatch: expected {}, got {}",
763 format_type(expected),
764 format_type(actual)
765 ),
766 span,
767 );
768 }
769 }
770 }
771 Node::IfElse {
772 then_body,
773 else_body,
774 ..
775 } => {
776 for stmt in then_body {
777 self.check_return_type(stmt, expected, scope);
778 }
779 if let Some(else_body) = else_body {
780 for stmt in else_body {
781 self.check_return_type(stmt, expected, scope);
782 }
783 }
784 }
785 _ => {}
786 }
787 }
788
789 fn check_match_exhaustiveness(
791 &mut self,
792 value: &SNode,
793 arms: &[MatchArm],
794 scope: &TypeScope,
795 span: Span,
796 ) {
797 let enum_name = match &value.node {
799 Node::PropertyAccess { object, property } if property == "variant" => {
800 match self.infer_type(object, scope) {
802 Some(TypeExpr::Named(name)) => {
803 if scope.get_enum(&name).is_some() {
804 Some(name)
805 } else {
806 None
807 }
808 }
809 _ => None,
810 }
811 }
812 _ => {
813 match self.infer_type(value, scope) {
815 Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
816 _ => None,
817 }
818 }
819 };
820
821 let Some(enum_name) = enum_name else {
822 return;
823 };
824 let Some(variants) = scope.get_enum(&enum_name) else {
825 return;
826 };
827
828 let mut covered: Vec<String> = Vec::new();
830 let mut has_wildcard = false;
831
832 for arm in arms {
833 match &arm.pattern.node {
834 Node::StringLiteral(s) => covered.push(s.clone()),
836 Node::Identifier(name) if name == "_" || !variants.contains(name) => {
838 has_wildcard = true;
839 }
840 Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
842 Node::PropertyAccess { property, .. } => covered.push(property.clone()),
844 _ => {
845 has_wildcard = true;
847 }
848 }
849 }
850
851 if has_wildcard {
852 return;
853 }
854
855 let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
856 if !missing.is_empty() {
857 let missing_str = missing
858 .iter()
859 .map(|s| format!("\"{}\"", s))
860 .collect::<Vec<_>>()
861 .join(", ");
862 self.warning_at(
863 format!(
864 "Non-exhaustive match on enum {}: missing variants {}",
865 enum_name, missing_str
866 ),
867 span,
868 );
869 }
870 }
871
872 fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
873 if let Some(sig) = scope.get_fn(name).cloned() {
875 if !is_builtin(name)
876 && (args.len() < sig.required_params || args.len() > sig.params.len())
877 {
878 let expected = if sig.required_params == sig.params.len() {
879 format!("{}", sig.params.len())
880 } else {
881 format!("{}-{}", sig.required_params, sig.params.len())
882 };
883 self.warning_at(
884 format!(
885 "Function '{}' expects {} arguments, got {}",
886 name,
887 expected,
888 args.len()
889 ),
890 span,
891 );
892 }
893 let call_scope = if sig.type_param_names.is_empty() {
896 scope.clone()
897 } else {
898 let mut s = scope.child();
899 for tp_name in &sig.type_param_names {
900 s.generic_type_params.insert(tp_name.clone());
901 }
902 s
903 };
904 for (i, (arg, (param_name, param_type))) in
905 args.iter().zip(sig.params.iter()).enumerate()
906 {
907 if let Some(expected) = param_type {
908 let actual = self.infer_type(arg, scope);
909 if let Some(actual) = &actual {
910 if !self.types_compatible(expected, actual, &call_scope) {
911 self.error_at(
912 format!(
913 "Argument {} ('{}'): expected {}, got {}",
914 i + 1,
915 param_name,
916 format_type(expected),
917 format_type(actual)
918 ),
919 arg.span,
920 );
921 }
922 }
923 }
924 }
925 }
926 for arg in args {
928 self.check_node(arg, scope);
929 }
930 }
931
932 fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
934 match &snode.node {
935 Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
936 Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
937 Node::StringLiteral(_) | Node::InterpolatedString(_) => {
938 Some(TypeExpr::Named("string".into()))
939 }
940 Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
941 Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
942 Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
943 Node::DictLiteral(entries) => {
944 let mut fields = Vec::new();
946 let mut all_string_keys = true;
947 for entry in entries {
948 if let Node::StringLiteral(key) = &entry.key.node {
949 let val_type = self
950 .infer_type(&entry.value, scope)
951 .unwrap_or(TypeExpr::Named("nil".into()));
952 fields.push(ShapeField {
953 name: key.clone(),
954 type_expr: val_type,
955 optional: false,
956 });
957 } else {
958 all_string_keys = false;
959 break;
960 }
961 }
962 if all_string_keys && !fields.is_empty() {
963 Some(TypeExpr::Shape(fields))
964 } else {
965 Some(TypeExpr::Named("dict".into()))
966 }
967 }
968 Node::Closure { params, body } => {
969 let all_typed = params.iter().all(|p| p.type_expr.is_some());
971 if all_typed && !params.is_empty() {
972 let param_types: Vec<TypeExpr> =
973 params.iter().filter_map(|p| p.type_expr.clone()).collect();
974 let ret = body.last().and_then(|last| self.infer_type(last, scope));
976 if let Some(ret_type) = ret {
977 return Some(TypeExpr::FnType {
978 params: param_types,
979 return_type: Box::new(ret_type),
980 });
981 }
982 }
983 Some(TypeExpr::Named("closure".into()))
984 }
985
986 Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
987
988 Node::FunctionCall { name, .. } => {
989 if let Some(sig) = scope.get_fn(name) {
991 return sig.return_type.clone();
992 }
993 builtin_return_type(name)
995 }
996
997 Node::BinaryOp { op, left, right } => {
998 let lt = self.infer_type(left, scope);
999 let rt = self.infer_type(right, scope);
1000 infer_binary_op_type(op, <, &rt)
1001 }
1002
1003 Node::UnaryOp { op, operand } => {
1004 let t = self.infer_type(operand, scope);
1005 match op.as_str() {
1006 "!" => Some(TypeExpr::Named("bool".into())),
1007 "-" => t, _ => None,
1009 }
1010 }
1011
1012 Node::Ternary {
1013 true_expr,
1014 false_expr,
1015 ..
1016 } => {
1017 let tt = self.infer_type(true_expr, scope);
1018 let ft = self.infer_type(false_expr, scope);
1019 match (&tt, &ft) {
1020 (Some(a), Some(b)) if a == b => tt,
1021 (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
1022 (Some(_), None) => tt,
1023 (None, Some(_)) => ft,
1024 (None, None) => None,
1025 }
1026 }
1027
1028 Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1029
1030 Node::PropertyAccess { object, property } => {
1031 if let Node::Identifier(name) = &object.node {
1033 if scope.get_enum(name).is_some() {
1034 return Some(TypeExpr::Named(name.clone()));
1035 }
1036 }
1037 if property == "variant" {
1039 let obj_type = self.infer_type(object, scope);
1040 if let Some(TypeExpr::Named(name)) = &obj_type {
1041 if scope.get_enum(name).is_some() {
1042 return Some(TypeExpr::Named("string".into()));
1043 }
1044 }
1045 }
1046 let obj_type = self.infer_type(object, scope);
1048 if let Some(TypeExpr::Shape(fields)) = &obj_type {
1049 if let Some(field) = fields.iter().find(|f| f.name == *property) {
1050 return Some(field.type_expr.clone());
1051 }
1052 }
1053 None
1054 }
1055
1056 Node::SubscriptAccess { object, index } => {
1057 let obj_type = self.infer_type(object, scope);
1058 match &obj_type {
1059 Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1060 Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1061 Some(TypeExpr::Shape(fields)) => {
1062 if let Node::StringLiteral(key) = &index.node {
1064 fields
1065 .iter()
1066 .find(|f| &f.name == key)
1067 .map(|f| f.type_expr.clone())
1068 } else {
1069 None
1070 }
1071 }
1072 Some(TypeExpr::Named(n)) if n == "list" => None,
1073 Some(TypeExpr::Named(n)) if n == "dict" => None,
1074 Some(TypeExpr::Named(n)) if n == "string" => {
1075 Some(TypeExpr::Named("string".into()))
1076 }
1077 _ => None,
1078 }
1079 }
1080 Node::SliceAccess { object, .. } => {
1081 let obj_type = self.infer_type(object, scope);
1083 match &obj_type {
1084 Some(TypeExpr::List(_)) => obj_type,
1085 Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1086 Some(TypeExpr::Named(n)) if n == "string" => {
1087 Some(TypeExpr::Named("string".into()))
1088 }
1089 _ => None,
1090 }
1091 }
1092 Node::MethodCall { object, method, .. }
1093 | Node::OptionalMethodCall { object, method, .. } => {
1094 let obj_type = self.infer_type(object, scope);
1095 let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1096 || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1097 match method.as_str() {
1098 "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1100 Some(TypeExpr::Named("bool".into()))
1101 }
1102 "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1104 "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1106 | "pad_left" | "pad_right" | "repeat" | "join" => {
1107 Some(TypeExpr::Named("string".into()))
1108 }
1109 "split" | "chars" => Some(TypeExpr::Named("list".into())),
1110 "filter" => {
1112 if is_dict {
1113 Some(TypeExpr::Named("dict".into()))
1114 } else {
1115 Some(TypeExpr::Named("list".into()))
1116 }
1117 }
1118 "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1120 "reduce" | "find" | "first" | "last" => None,
1121 "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1123 "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1124 "to_string" => Some(TypeExpr::Named("string".into())),
1126 "to_int" => Some(TypeExpr::Named("int".into())),
1127 "to_float" => Some(TypeExpr::Named("float".into())),
1128 _ => None,
1129 }
1130 }
1131
1132 _ => None,
1133 }
1134 }
1135
1136 fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1138 if let TypeExpr::Named(name) = expected {
1140 if scope.is_generic_type_param(name) {
1141 return true;
1142 }
1143 }
1144 if let TypeExpr::Named(name) = actual {
1145 if scope.is_generic_type_param(name) {
1146 return true;
1147 }
1148 }
1149 let expected = self.resolve_alias(expected, scope);
1150 let actual = self.resolve_alias(actual, scope);
1151
1152 match (&expected, &actual) {
1153 (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1154 (TypeExpr::Union(members), actual_type) => members
1155 .iter()
1156 .any(|m| self.types_compatible(m, actual_type, scope)),
1157 (expected_type, TypeExpr::Union(members)) => members
1158 .iter()
1159 .all(|m| self.types_compatible(expected_type, m, scope)),
1160 (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1161 (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1162 (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1163 if expected_field.optional {
1164 return true;
1165 }
1166 af.iter().any(|actual_field| {
1167 actual_field.name == expected_field.name
1168 && self.types_compatible(
1169 &expected_field.type_expr,
1170 &actual_field.type_expr,
1171 scope,
1172 )
1173 })
1174 }),
1175 (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1177 let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1178 keys_ok
1179 && af
1180 .iter()
1181 .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1182 }
1183 (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1185 (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1186 self.types_compatible(expected_inner, actual_inner, scope)
1187 }
1188 (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1189 (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1190 (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1191 self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1192 }
1193 (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1194 (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1195 (
1197 TypeExpr::FnType {
1198 params: ep,
1199 return_type: er,
1200 },
1201 TypeExpr::FnType {
1202 params: ap,
1203 return_type: ar,
1204 },
1205 ) => {
1206 ep.len() == ap.len()
1207 && ep
1208 .iter()
1209 .zip(ap.iter())
1210 .all(|(e, a)| self.types_compatible(e, a, scope))
1211 && self.types_compatible(er, ar, scope)
1212 }
1213 (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1215 (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1216 _ => false,
1217 }
1218 }
1219
1220 fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1221 if let TypeExpr::Named(name) = ty {
1222 if let Some(resolved) = scope.resolve_type(name) {
1223 return resolved.clone();
1224 }
1225 }
1226 ty.clone()
1227 }
1228
1229 fn error_at(&mut self, message: String, span: Span) {
1230 self.diagnostics.push(TypeDiagnostic {
1231 message,
1232 severity: DiagnosticSeverity::Error,
1233 span: Some(span),
1234 });
1235 }
1236
1237 fn warning_at(&mut self, message: String, span: Span) {
1238 self.diagnostics.push(TypeDiagnostic {
1239 message,
1240 severity: DiagnosticSeverity::Warning,
1241 span: Some(span),
1242 });
1243 }
1244}
1245
1246impl Default for TypeChecker {
1247 fn default() -> Self {
1248 Self::new()
1249 }
1250}
1251
1252fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1254 match op {
1255 "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1256 "+" => match (left, right) {
1257 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1258 match (l.as_str(), r.as_str()) {
1259 ("int", "int") => Some(TypeExpr::Named("int".into())),
1260 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1261 ("string", _) => Some(TypeExpr::Named("string".into())),
1262 ("list", "list") => Some(TypeExpr::Named("list".into())),
1263 ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1264 _ => Some(TypeExpr::Named("string".into())),
1265 }
1266 }
1267 _ => None,
1268 },
1269 "-" | "*" | "/" | "%" => match (left, right) {
1270 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1271 match (l.as_str(), r.as_str()) {
1272 ("int", "int") => Some(TypeExpr::Named("int".into())),
1273 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1274 _ => None,
1275 }
1276 }
1277 _ => None,
1278 },
1279 "??" => match (left, right) {
1280 (Some(TypeExpr::Union(members)), _) => {
1281 let non_nil: Vec<_> = members
1282 .iter()
1283 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1284 .cloned()
1285 .collect();
1286 if non_nil.len() == 1 {
1287 Some(non_nil[0].clone())
1288 } else if non_nil.is_empty() {
1289 right.clone()
1290 } else {
1291 Some(TypeExpr::Union(non_nil))
1292 }
1293 }
1294 _ => right.clone(),
1295 },
1296 "|>" => None,
1297 _ => None,
1298 }
1299}
1300
1301pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1306 if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1307 let mut details = Vec::new();
1308 for field in ef {
1309 if field.optional {
1310 continue;
1311 }
1312 match af.iter().find(|f| f.name == field.name) {
1313 None => details.push(format!(
1314 "missing field '{}' ({})",
1315 field.name,
1316 format_type(&field.type_expr)
1317 )),
1318 Some(actual_field) => {
1319 let e_str = format_type(&field.type_expr);
1320 let a_str = format_type(&actual_field.type_expr);
1321 if e_str != a_str {
1322 details.push(format!(
1323 "field '{}' has type {}, expected {}",
1324 field.name, a_str, e_str
1325 ));
1326 }
1327 }
1328 }
1329 }
1330 if details.is_empty() {
1331 None
1332 } else {
1333 Some(details.join("; "))
1334 }
1335 } else {
1336 None
1337 }
1338}
1339
1340pub fn format_type(ty: &TypeExpr) -> String {
1341 match ty {
1342 TypeExpr::Named(n) => n.clone(),
1343 TypeExpr::Union(types) => types
1344 .iter()
1345 .map(format_type)
1346 .collect::<Vec<_>>()
1347 .join(" | "),
1348 TypeExpr::Shape(fields) => {
1349 let inner: Vec<String> = fields
1350 .iter()
1351 .map(|f| {
1352 let opt = if f.optional { "?" } else { "" };
1353 format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1354 })
1355 .collect();
1356 format!("{{{}}}", inner.join(", "))
1357 }
1358 TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1359 TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1360 TypeExpr::FnType {
1361 params,
1362 return_type,
1363 } => {
1364 let params_str = params
1365 .iter()
1366 .map(format_type)
1367 .collect::<Vec<_>>()
1368 .join(", ");
1369 format!("fn({}) -> {}", params_str, format_type(return_type))
1370 }
1371 }
1372}
1373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    // ---------------------------------------------------------------------
    // Helpers
    // ---------------------------------------------------------------------

    /// Lexes, parses and type-checks `source`, returning every diagnostic.
    /// Panics if the program fails to lex or parse.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let tokens = Lexer::new(source).tokenize().unwrap();
        let program = Parser::new(tokens).parse().unwrap();
        TypeChecker::new().check(&program)
    }

    /// Messages of all diagnostics of the given severity produced for `source`.
    fn messages(source: &str, severity: DiagnosticSeverity) -> Vec<String> {
        let mut collected = Vec::new();
        for diagnostic in check_source(source) {
            if diagnostic.severity == severity {
                collected.push(diagnostic.message);
            }
        }
        collected
    }

    /// Error messages produced for `source`.
    fn errors(source: &str) -> Vec<String> {
        messages(source, DiagnosticSeverity::Error)
    }

    /// Warning messages produced for `source`.
    fn warnings(source: &str) -> Vec<String> {
        messages(source, DiagnosticSeverity::Warning)
    }

    /// Warning messages for `source` that contain `needle`.
    fn warnings_containing(source: &str, needle: &str) -> Vec<String> {
        let mut matching = warnings(source);
        matching.retain(|warning| warning.contains(needle));
        matching
    }

    /// Warnings about non-exhaustive matches.
    fn non_exhaustive_warnings(source: &str) -> Vec<String> {
        warnings_containing(source, "Non-exhaustive")
    }

    /// Warnings about match patterns whose literal type disagrees with the
    /// matched value.
    fn pattern_mismatch_warnings(source: &str) -> Vec<String> {
        warnings_containing(source, "Match pattern type mismatch")
    }

    /// Asserts that exactly one message was produced and returns it.
    #[track_caller]
    fn sole(found: &[String]) -> &str {
        assert_eq!(found.len(), 1);
        &found[0]
    }

    // ---------------------------------------------------------------------
    // Basic annotations and inference
    // ---------------------------------------------------------------------

    #[test]
    fn test_no_errors_for_untyped_code() {
        assert!(errors("pipeline t(task) { let x = 42\nlog(x) }").is_empty());
    }

    #[test]
    fn test_correct_typed_let() {
        assert!(errors("pipeline t(task) { let x: int = 42 }").is_empty());
    }

    #[test]
    fn test_type_mismatch_let() {
        let found = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
        let message = sole(&found);
        for fragment in ["Type mismatch", "int", "string"] {
            assert!(message.contains(fragment));
        }
    }

    #[test]
    fn test_correct_typed_fn() {
        let source =
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }";
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let found = errors(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        let message = sole(&found);
        assert!(message.contains("Argument 1"));
        assert!(message.contains("expected int"));
    }

    #[test]
    fn test_return_type_mismatch() {
        let found = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert!(sole(&found).contains("Return type mismatch"));
    }

    #[test]
    fn test_union_type_compatible() {
        assert!(errors(r#"pipeline t(task) { let x: string | nil = nil }"#).is_empty());
    }

    #[test]
    fn test_union_type_mismatch() {
        let found = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert!(sole(&found).contains("Type mismatch"));
    }

    #[test]
    fn test_type_inference_propagation() {
        let found = errors(
            r#"pipeline t(task) {
  fn add(a: int, b: int) -> int { return a + b }
  let result: string = add(1, 2)
}"#,
        );
        let message = sole(&found);
        for fragment in ["Type mismatch", "string", "int"] {
            assert!(message.contains(fragment));
        }
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let found = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        let message = sole(&found);
        assert!(message.contains("string"));
        assert!(message.contains("int"));
    }

    #[test]
    fn test_binary_op_type_inference() {
        let found = errors("pipeline t(task) { let x: string = 1 + 2 }");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_comparison_returns_bool() {
        assert!(errors("pipeline t(task) { let x: bool = 1 < 2 }").is_empty());
    }

    #[test]
    fn test_int_float_promotion() {
        assert!(errors("pipeline t(task) { let x: float = 42 }").is_empty());
    }

    #[test]
    fn test_untyped_code_no_errors() {
        let found = errors(
            r#"pipeline t(task) {
  fn process(data) {
    let result = data + " processed"
    return result
  }
  log(process("hello"))
}"#,
        );
        assert!(found.is_empty());
    }

    // ---------------------------------------------------------------------
    // Aliases, assignment, numeric promotion
    // ---------------------------------------------------------------------

    #[test]
    fn test_type_alias() {
        let found = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = "hello"
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_type_alias_mismatch() {
        let found = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = 42
}"#,
        );
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_assignment_type_check() {
        let found = errors(
            r#"pipeline t(task) {
  var x: int = 0
  x = "hello"
}"#,
        );
        assert!(sole(&found).contains("cannot assign string"));
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        let source =
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }";
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_covariance_return_type() {
        assert!(errors("pipeline t(task) { fn get() -> float { return 42 } }").is_empty());
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        let found =
            errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
        assert_eq!(found.len(), 1);
    }

    // ---------------------------------------------------------------------
    // Enum match exhaustiveness
    // ---------------------------------------------------------------------

    #[test]
    fn test_exhaustive_match_no_warning() {
        let found = non_exhaustive_warnings(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
    "Blue" -> { log("b") }
  }
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let found = non_exhaustive_warnings(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
  }
}"#,
        );
        assert!(sole(&found).contains("Blue"));
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let found = non_exhaustive_warnings(
            r#"pipeline t(task) {
  enum Status { Active, Inactive, Pending }
  let s = Status.Active
  match s.variant {
    "Active" -> { log("a") }
  }
}"#,
        );
        let message = sole(&found);
        assert!(message.contains("Inactive"));
        assert!(message.contains("Pending"));
    }

    #[test]
    fn test_enum_construct_type_inference() {
        let found = errors(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c: Color = Color.Red
}"#,
        );
        assert!(found.is_empty());
    }

    // ---------------------------------------------------------------------
    // Nil coalescing and shape diagnostics
    // ---------------------------------------------------------------------

    #[test]
    fn test_nil_coalescing_strips_nil() {
        let found = errors(
            r#"pipeline t(task) {
  let x: string | nil = nil
  let y: string = x ?? "default"
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_shape_mismatch_detail_missing_field() {
        let found = errors(
            r#"pipeline t(task) {
  let x: {name: string, age: int} = {name: "hello"}
}"#,
        );
        let message = sole(&found);
        assert!(
            message.contains("missing field 'age'"),
            "expected detail about missing field, got: {}",
            message
        );
    }

    #[test]
    fn test_shape_mismatch_detail_wrong_type() {
        let found = errors(
            r#"pipeline t(task) {
  let x: {name: string, age: int} = {name: 42, age: 10}
}"#,
        );
        let message = sole(&found);
        assert!(
            message.contains("field 'name' has type int, expected string"),
            "expected detail about wrong type, got: {}",
            message
        );
    }

    // ---------------------------------------------------------------------
    // Match pattern literal types
    // ---------------------------------------------------------------------

    #[test]
    fn test_match_pattern_string_against_int() {
        let found = pattern_mismatch_warnings(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    "hello" -> { log("bad") }
    42 -> { log("ok") }
  }
}"#,
        );
        assert!(sole(&found).contains("matching int against string literal"));
    }

    #[test]
    fn test_match_pattern_int_against_string() {
        let found = pattern_mismatch_warnings(
            r#"pipeline t(task) {
  let x: string = "hello"
  match x {
    42 -> { log("bad") }
    "hello" -> { log("ok") }
  }
}"#,
        );
        assert!(sole(&found).contains("matching string against int literal"));
    }

    #[test]
    fn test_match_pattern_bool_against_int() {
        let found = pattern_mismatch_warnings(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    true -> { log("bad") }
    42 -> { log("ok") }
  }
}"#,
        );
        assert!(sole(&found).contains("matching int against bool literal"));
    }

    #[test]
    fn test_match_pattern_float_against_string() {
        let found = pattern_mismatch_warnings(
            r#"pipeline t(task) {
  let x: string = "hello"
  match x {
    3.14 -> { log("bad") }
    "hello" -> { log("ok") }
  }
}"#,
        );
        assert!(sole(&found).contains("matching string against float literal"));
    }

    // An int literal pattern against a float value is expected to pass
    // without a pattern-mismatch warning.
    #[test]
    fn test_match_pattern_int_against_float_ok() {
        let found = pattern_mismatch_warnings(
            r#"pipeline t(task) {
  let x: float = 3.14
  match x {
    42 -> { log("ok") }
    _ -> { log("default") }
  }
}"#,
        );
        assert!(found.is_empty());
    }

    // A float literal pattern against an int value is likewise expected to
    // pass without a pattern-mismatch warning.
    #[test]
    fn test_match_pattern_float_against_int_ok() {
        let found = pattern_mismatch_warnings(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    3.14 -> { log("close") }
    _ -> { log("default") }
  }
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_match_pattern_correct_types_no_warning() {
        let found = pattern_mismatch_warnings(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    1 -> { log("one") }
    2 -> { log("two") }
    _ -> { log("other") }
  }
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_match_pattern_wildcard_no_warning() {
        let found = pattern_mismatch_warnings(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    _ -> { log("catch all") }
  }
}"#,
        );
        assert!(found.is_empty());
    }

    // The matched value comes from an unknown function; no pattern-mismatch
    // warning is expected for either literal.
    #[test]
    fn test_match_pattern_untyped_no_warning() {
        let found = pattern_mismatch_warnings(
            r#"pipeline t(task) {
  let x = some_unknown_fn()
  match x {
    "hello" -> { log("string") }
    42 -> { log("int") }
  }
}"#,
        );
        assert!(found.is_empty());
    }
}