1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
/// A single message produced by the type checker.
#[derive(Debug, Clone)]
pub struct TypeDiagnostic {
    /// Human-readable description of the problem.
    pub message: String,
    /// Whether this diagnostic is an error or a warning.
    pub severity: DiagnosticSeverity,
    /// Source location the diagnostic refers to, when one is known.
    pub span: Option<Span>,
}
13
/// Severity level attached to a [`TypeDiagnostic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// A definite type error (for example a declared/assigned type mismatch).
    Error,
    /// A suspicious construct that is reported but not treated as an error.
    Warning,
}
19
/// Result of type inference: `Some(ty)` when a type could be determined,
/// `None` when it is unknown (unknown types are skipped by compatibility checks).
type InferredType = Option<TypeExpr>;
22
/// One lexical scope of the type checker, linked to its enclosing scope.
#[derive(Debug, Clone)]
struct TypeScope {
    /// Variable name -> declared or inferred type (`None` when unknown).
    vars: BTreeMap<String, InferredType>,
    /// Function name -> signature.
    functions: BTreeMap<String, FnSignature>,
    /// Type alias name -> aliased type expression.
    type_aliases: BTreeMap<String, TypeExpr>,
    /// Enum name -> names of its variants.
    enums: BTreeMap<String, Vec<String>>,
    /// Interface name -> declared methods.
    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
    /// Struct name -> `(field name, field type)` pairs.
    structs: BTreeMap<String, Vec<(String, InferredType)>>,
    /// Names of generic type parameters visible in this scope.
    generic_type_params: std::collections::BTreeSet<String>,
    /// Enclosing scope; lookups fall back to it when a name is not found here.
    parent: Option<Box<TypeScope>>,
}
42
/// The checker's view of a declared function.
#[derive(Debug, Clone)]
struct FnSignature {
    /// `(parameter name, declared type)` pairs in declaration order.
    params: Vec<(String, InferredType)>,
    /// Declared return type, if any.
    return_type: InferredType,
    /// Names of the function's generic type parameters.
    type_param_names: Vec<String>,
    /// Number of parameters that have no default value.
    required_params: usize,
}
52
53impl TypeScope {
54 fn new() -> Self {
55 Self {
56 vars: BTreeMap::new(),
57 functions: BTreeMap::new(),
58 type_aliases: BTreeMap::new(),
59 enums: BTreeMap::new(),
60 interfaces: BTreeMap::new(),
61 structs: BTreeMap::new(),
62 generic_type_params: std::collections::BTreeSet::new(),
63 parent: None,
64 }
65 }
66
67 fn child(&self) -> Self {
68 Self {
69 vars: BTreeMap::new(),
70 functions: BTreeMap::new(),
71 type_aliases: BTreeMap::new(),
72 enums: BTreeMap::new(),
73 interfaces: BTreeMap::new(),
74 structs: BTreeMap::new(),
75 generic_type_params: std::collections::BTreeSet::new(),
76 parent: Some(Box::new(self.clone())),
77 }
78 }
79
80 fn get_var(&self, name: &str) -> Option<&InferredType> {
81 self.vars
82 .get(name)
83 .or_else(|| self.parent.as_ref()?.get_var(name))
84 }
85
86 fn get_fn(&self, name: &str) -> Option<&FnSignature> {
87 self.functions
88 .get(name)
89 .or_else(|| self.parent.as_ref()?.get_fn(name))
90 }
91
92 fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
93 self.type_aliases
94 .get(name)
95 .or_else(|| self.parent.as_ref()?.resolve_type(name))
96 }
97
98 fn is_generic_type_param(&self, name: &str) -> bool {
99 self.generic_type_params.contains(name)
100 || self
101 .parent
102 .as_ref()
103 .is_some_and(|p| p.is_generic_type_param(name))
104 }
105
106 fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
107 self.enums
108 .get(name)
109 .or_else(|| self.parent.as_ref()?.get_enum(name))
110 }
111
112 #[allow(dead_code)]
113 fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
114 self.interfaces
115 .get(name)
116 .or_else(|| self.parent.as_ref()?.get_interface(name))
117 }
118
119 fn define_var(&mut self, name: &str, ty: InferredType) {
120 self.vars.insert(name.to_string(), ty);
121 }
122
123 fn define_fn(&mut self, name: &str, sig: FnSignature) {
124 self.functions.insert(name.to_string(), sig);
125 }
126}
127
128fn builtin_return_type(name: &str) -> InferredType {
130 match name {
131 "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
132 | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
133 Some(TypeExpr::Named("nil".into()))
134 }
135 "type_of" | "to_string" | "json_stringify" | "read_file" | "http_get" | "http_post"
136 | "llm_call" | "agent_loop" | "regex_replace" | "path_join" | "temp_dir"
137 | "date_format" | "format" => Some(TypeExpr::Named("string".into())),
138 "to_int" => Some(TypeExpr::Named("int".into())),
139 "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
140 "file_exists" | "json_validate" => Some(TypeExpr::Named("bool".into())),
141 "list_dir" => Some(TypeExpr::Named("list".into())),
142 "stat" | "exec" | "shell" | "date_now" => Some(TypeExpr::Named("dict".into())),
143 "env" | "regex_match" => Some(TypeExpr::Union(vec![
144 TypeExpr::Named("string".into()),
145 TypeExpr::Named("nil".into()),
146 ])),
147 "json_parse" | "json_extract" => None, _ => None,
149 }
150}
151
/// Reports whether `name` is one of the language's builtin function names.
///
/// The comparison is exact and case-sensitive.
fn is_builtin(name: &str) -> bool {
    const BUILTIN_NAMES: &[&str] = &[
        "log",
        "print",
        "println",
        "type_of",
        "to_string",
        "to_int",
        "to_float",
        "json_stringify",
        "json_parse",
        "env",
        "timestamp",
        "sleep",
        "read_file",
        "write_file",
        "exit",
        "regex_match",
        "regex_replace",
        "http_get",
        "http_post",
        "llm_call",
        "agent_loop",
        "await",
        "cancel",
        "file_exists",
        "delete_file",
        "list_dir",
        "mkdir",
        "path_join",
        "copy_file",
        "append_file",
        "temp_dir",
        "stat",
        "exec",
        "shell",
        "date_now",
        "date_format",
        "date_parse",
        "format",
        "json_validate",
        "json_extract",
        "trim",
        "lowercase",
        "uppercase",
        "split",
        "starts_with",
        "ends_with",
        "contains",
        "replace",
        "join",
        "len",
        "substring",
        "dirname",
        "basename",
        "extname",
    ];

    BUILTIN_NAMES.contains(&name)
}
212
/// Walks a program's AST and collects type diagnostics.
pub struct TypeChecker {
    /// Diagnostics accumulated so far, in the order they were reported.
    diagnostics: Vec<TypeDiagnostic>,
    /// Root (top-level) scope of the program being checked.
    scope: TypeScope,
}
218
219impl TypeChecker {
220 pub fn new() -> Self {
221 Self {
222 diagnostics: Vec::new(),
223 scope: TypeScope::new(),
224 }
225 }
226
227 pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
229 Self::register_declarations_into(&mut self.scope, program);
231
232 for snode in program {
234 if let Node::Pipeline { body, .. } = &snode.node {
235 Self::register_declarations_into(&mut self.scope, body);
236 }
237 }
238
239 for snode in program {
241 match &snode.node {
242 Node::Pipeline { params, body, .. } => {
243 let mut child = self.scope.child();
244 for p in params {
245 child.define_var(p, None);
246 }
247 self.check_block(body, &mut child);
248 }
249 Node::FnDecl {
250 name,
251 type_params,
252 params,
253 return_type,
254 body,
255 ..
256 } => {
257 let required_params = params
258 .iter()
259 .filter(|p| p.default_value.is_none())
260 .count();
261 let sig = FnSignature {
262 params: params
263 .iter()
264 .map(|p| (p.name.clone(), p.type_expr.clone()))
265 .collect(),
266 return_type: return_type.clone(),
267 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
268 required_params,
269 };
270 self.scope.define_fn(name, sig);
271 self.check_fn_body(type_params, params, return_type, body);
272 }
273 _ => {
274 let mut scope = self.scope.clone();
275 self.check_node(snode, &mut scope);
276 for (name, ty) in scope.vars {
278 self.scope.vars.entry(name).or_insert(ty);
279 }
280 }
281 }
282 }
283
284 self.diagnostics
285 }
286
287 fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
289 for snode in nodes {
290 match &snode.node {
291 Node::TypeDecl { name, type_expr } => {
292 scope.type_aliases.insert(name.clone(), type_expr.clone());
293 }
294 Node::EnumDecl { name, variants } => {
295 let variant_names: Vec<String> =
296 variants.iter().map(|v| v.name.clone()).collect();
297 scope.enums.insert(name.clone(), variant_names);
298 }
299 Node::InterfaceDecl { name, methods } => {
300 scope.interfaces.insert(name.clone(), methods.clone());
301 }
302 Node::StructDecl { name, fields } => {
303 let field_types: Vec<(String, InferredType)> = fields
304 .iter()
305 .map(|f| (f.name.clone(), f.type_expr.clone()))
306 .collect();
307 scope.structs.insert(name.clone(), field_types);
308 }
309 _ => {}
310 }
311 }
312 }
313
314 fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
315 for stmt in stmts {
316 self.check_node(stmt, scope);
317 }
318 }
319
320 fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
322 match pattern {
323 BindingPattern::Identifier(name) => {
324 scope.define_var(name, None);
325 }
326 BindingPattern::Dict(fields) => {
327 for field in fields {
328 let name = field.alias.as_deref().unwrap_or(&field.key);
329 scope.define_var(name, None);
330 }
331 }
332 BindingPattern::List(elements) => {
333 for elem in elements {
334 scope.define_var(&elem.name, None);
335 }
336 }
337 }
338 }
339
340 fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
341 let span = snode.span;
342 match &snode.node {
343 Node::LetBinding {
344 pattern,
345 type_ann,
346 value,
347 } => {
348 let inferred = self.infer_type(value, scope);
349 if let BindingPattern::Identifier(name) = pattern {
350 if let Some(expected) = type_ann {
351 if let Some(actual) = &inferred {
352 if !self.types_compatible(expected, actual, scope) {
353 let mut msg = format!(
354 "Type mismatch: '{}' declared as {}, but assigned {}",
355 name,
356 format_type(expected),
357 format_type(actual)
358 );
359 if let Some(detail) = shape_mismatch_detail(expected, actual) {
360 msg.push_str(&format!(" ({})", detail));
361 }
362 self.error_at(msg, span);
363 }
364 }
365 }
366 let ty = type_ann.clone().or(inferred);
367 scope.define_var(name, ty);
368 } else {
369 Self::define_pattern_vars(pattern, scope);
370 }
371 }
372
373 Node::VarBinding {
374 pattern,
375 type_ann,
376 value,
377 } => {
378 let inferred = self.infer_type(value, scope);
379 if let BindingPattern::Identifier(name) = pattern {
380 if let Some(expected) = type_ann {
381 if let Some(actual) = &inferred {
382 if !self.types_compatible(expected, actual, scope) {
383 let mut msg = format!(
384 "Type mismatch: '{}' declared as {}, but assigned {}",
385 name,
386 format_type(expected),
387 format_type(actual)
388 );
389 if let Some(detail) = shape_mismatch_detail(expected, actual) {
390 msg.push_str(&format!(" ({})", detail));
391 }
392 self.error_at(msg, span);
393 }
394 }
395 }
396 let ty = type_ann.clone().or(inferred);
397 scope.define_var(name, ty);
398 } else {
399 Self::define_pattern_vars(pattern, scope);
400 }
401 }
402
403 Node::FnDecl {
404 name,
405 type_params,
406 params,
407 return_type,
408 body,
409 ..
410 } => {
411 let required_params = params
412 .iter()
413 .filter(|p| p.default_value.is_none())
414 .count();
415 let sig = FnSignature {
416 params: params
417 .iter()
418 .map(|p| (p.name.clone(), p.type_expr.clone()))
419 .collect(),
420 return_type: return_type.clone(),
421 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
422 required_params,
423 };
424 scope.define_fn(name, sig.clone());
425 scope.define_var(name, None);
426 self.check_fn_body(type_params, params, return_type, body);
427 }
428
429 Node::FunctionCall { name, args } => {
430 self.check_call(name, args, scope, span);
431 }
432
433 Node::IfElse {
434 condition,
435 then_body,
436 else_body,
437 } => {
438 self.check_node(condition, scope);
439 let mut then_scope = scope.child();
440 self.check_block(then_body, &mut then_scope);
441 if let Some(else_body) = else_body {
442 let mut else_scope = scope.child();
443 self.check_block(else_body, &mut else_scope);
444 }
445 }
446
447 Node::ForIn {
448 pattern,
449 iterable,
450 body,
451 } => {
452 self.check_node(iterable, scope);
453 let mut loop_scope = scope.child();
454 if let BindingPattern::Identifier(variable) = pattern {
455 let elem_type = match self.infer_type(iterable, scope) {
457 Some(TypeExpr::List(inner)) => Some(*inner),
458 Some(TypeExpr::Named(n)) if n == "string" => {
459 Some(TypeExpr::Named("string".into()))
460 }
461 _ => None,
462 };
463 loop_scope.define_var(variable, elem_type);
464 } else {
465 Self::define_pattern_vars(pattern, &mut loop_scope);
466 }
467 self.check_block(body, &mut loop_scope);
468 }
469
470 Node::WhileLoop { condition, body } => {
471 self.check_node(condition, scope);
472 let mut loop_scope = scope.child();
473 self.check_block(body, &mut loop_scope);
474 }
475
476 Node::TryCatch {
477 body,
478 error_var,
479 catch_body,
480 finally_body,
481 ..
482 } => {
483 let mut try_scope = scope.child();
484 self.check_block(body, &mut try_scope);
485 let mut catch_scope = scope.child();
486 if let Some(var) = error_var {
487 catch_scope.define_var(var, None);
488 }
489 self.check_block(catch_body, &mut catch_scope);
490 if let Some(fb) = finally_body {
491 let mut finally_scope = scope.child();
492 self.check_block(fb, &mut finally_scope);
493 }
494 }
495
496 Node::ReturnStmt {
497 value: Some(val), ..
498 } => {
499 self.check_node(val, scope);
500 }
501
502 Node::Assignment {
503 target, value, op, ..
504 } => {
505 self.check_node(value, scope);
506 if let Node::Identifier(name) = &target.node {
507 if let Some(Some(var_type)) = scope.get_var(name) {
508 let value_type = self.infer_type(value, scope);
509 let assigned = if let Some(op) = op {
510 let var_inferred = scope.get_var(name).cloned().flatten();
511 infer_binary_op_type(op, &var_inferred, &value_type)
512 } else {
513 value_type
514 };
515 if let Some(actual) = &assigned {
516 if !self.types_compatible(var_type, actual, scope) {
517 self.error_at(
518 format!(
519 "Type mismatch: cannot assign {} to '{}' (declared as {})",
520 format_type(actual),
521 name,
522 format_type(var_type)
523 ),
524 span,
525 );
526 }
527 }
528 }
529 }
530 }
531
532 Node::TypeDecl { name, type_expr } => {
533 scope.type_aliases.insert(name.clone(), type_expr.clone());
534 }
535
536 Node::EnumDecl { name, variants } => {
537 let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
538 scope.enums.insert(name.clone(), variant_names);
539 }
540
541 Node::StructDecl { name, fields } => {
542 let field_types: Vec<(String, InferredType)> = fields
543 .iter()
544 .map(|f| (f.name.clone(), f.type_expr.clone()))
545 .collect();
546 scope.structs.insert(name.clone(), field_types);
547 }
548
549 Node::InterfaceDecl { name, methods } => {
550 scope.interfaces.insert(name.clone(), methods.clone());
551 }
552
553 Node::MatchExpr { value, arms } => {
554 self.check_node(value, scope);
555 let value_type = self.infer_type(value, scope);
556 for arm in arms {
557 self.check_node(&arm.pattern, scope);
558 if let Some(ref vt) = value_type {
560 let value_type_name = format_type(vt);
561 let mismatch = match &arm.pattern.node {
562 Node::StringLiteral(_) => {
563 !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
564 }
565 Node::IntLiteral(_) => {
566 !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
567 && !self.types_compatible(
568 vt,
569 &TypeExpr::Named("float".into()),
570 scope,
571 )
572 }
573 Node::FloatLiteral(_) => {
574 !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
575 && !self.types_compatible(
576 vt,
577 &TypeExpr::Named("int".into()),
578 scope,
579 )
580 }
581 Node::BoolLiteral(_) => {
582 !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
583 }
584 _ => false,
585 };
586 if mismatch {
587 let pattern_type = match &arm.pattern.node {
588 Node::StringLiteral(_) => "string",
589 Node::IntLiteral(_) => "int",
590 Node::FloatLiteral(_) => "float",
591 Node::BoolLiteral(_) => "bool",
592 _ => unreachable!(),
593 };
594 self.warning_at(
595 format!(
596 "Match pattern type mismatch: matching {} against {} literal",
597 value_type_name, pattern_type
598 ),
599 arm.pattern.span,
600 );
601 }
602 }
603 let mut arm_scope = scope.child();
604 self.check_block(&arm.body, &mut arm_scope);
605 }
606 self.check_match_exhaustiveness(value, arms, scope, span);
607 }
608
609 Node::BinaryOp { op, left, right } => {
611 self.check_node(left, scope);
612 self.check_node(right, scope);
613 let lt = self.infer_type(left, scope);
615 let rt = self.infer_type(right, scope);
616 if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (<, &rt) {
617 match op.as_str() {
618 "-" | "*" | "/" | "%" => {
619 let numeric = ["int", "float"];
620 if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
621 self.warning_at(
622 format!(
623 "Operator '{op}' may not be valid for types {} and {}",
624 l, r
625 ),
626 span,
627 );
628 }
629 }
630 "+" => {
631 let valid = ["int", "float", "string", "list", "dict"];
633 if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
634 self.warning_at(
635 format!(
636 "Operator '+' may not be valid for types {} and {}",
637 l, r
638 ),
639 span,
640 );
641 }
642 }
643 _ => {}
644 }
645 }
646 }
647 Node::UnaryOp { operand, .. } => {
648 self.check_node(operand, scope);
649 }
650 Node::MethodCall { object, args, .. }
651 | Node::OptionalMethodCall { object, args, .. } => {
652 self.check_node(object, scope);
653 for arg in args {
654 self.check_node(arg, scope);
655 }
656 }
657 Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
658 self.check_node(object, scope);
659 }
660 Node::SubscriptAccess { object, index } => {
661 self.check_node(object, scope);
662 self.check_node(index, scope);
663 }
664 Node::SliceAccess { object, start, end } => {
665 self.check_node(object, scope);
666 if let Some(s) = start {
667 self.check_node(s, scope);
668 }
669 if let Some(e) = end {
670 self.check_node(e, scope);
671 }
672 }
673
674 _ => {}
676 }
677 }
678
679 fn check_fn_body(
680 &mut self,
681 type_params: &[TypeParam],
682 params: &[TypedParam],
683 return_type: &Option<TypeExpr>,
684 body: &[SNode],
685 ) {
686 let mut fn_scope = self.scope.child();
687 for tp in type_params {
690 fn_scope.generic_type_params.insert(tp.name.clone());
691 }
692 for param in params {
693 fn_scope.define_var(¶m.name, param.type_expr.clone());
694 if let Some(default) = ¶m.default_value {
695 self.check_node(default, &mut fn_scope);
696 }
697 }
698 self.check_block(body, &mut fn_scope);
699
700 if let Some(ret_type) = return_type {
702 for stmt in body {
703 self.check_return_type(stmt, ret_type, &fn_scope);
704 }
705 }
706 }
707
708 fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
709 let span = snode.span;
710 match &snode.node {
711 Node::ReturnStmt { value: Some(val) } => {
712 let inferred = self.infer_type(val, scope);
713 if let Some(actual) = &inferred {
714 if !self.types_compatible(expected, actual, scope) {
715 self.error_at(
716 format!(
717 "Return type mismatch: expected {}, got {}",
718 format_type(expected),
719 format_type(actual)
720 ),
721 span,
722 );
723 }
724 }
725 }
726 Node::IfElse {
727 then_body,
728 else_body,
729 ..
730 } => {
731 for stmt in then_body {
732 self.check_return_type(stmt, expected, scope);
733 }
734 if let Some(else_body) = else_body {
735 for stmt in else_body {
736 self.check_return_type(stmt, expected, scope);
737 }
738 }
739 }
740 _ => {}
741 }
742 }
743
744 fn check_match_exhaustiveness(
746 &mut self,
747 value: &SNode,
748 arms: &[MatchArm],
749 scope: &TypeScope,
750 span: Span,
751 ) {
752 let enum_name = match &value.node {
754 Node::PropertyAccess { object, property } if property == "variant" => {
755 match self.infer_type(object, scope) {
757 Some(TypeExpr::Named(name)) => {
758 if scope.get_enum(&name).is_some() {
759 Some(name)
760 } else {
761 None
762 }
763 }
764 _ => None,
765 }
766 }
767 _ => {
768 match self.infer_type(value, scope) {
770 Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
771 _ => None,
772 }
773 }
774 };
775
776 let Some(enum_name) = enum_name else {
777 return;
778 };
779 let Some(variants) = scope.get_enum(&enum_name) else {
780 return;
781 };
782
783 let mut covered: Vec<String> = Vec::new();
785 let mut has_wildcard = false;
786
787 for arm in arms {
788 match &arm.pattern.node {
789 Node::StringLiteral(s) => covered.push(s.clone()),
791 Node::Identifier(name) if name == "_" || !variants.contains(name) => {
793 has_wildcard = true;
794 }
795 Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
797 Node::PropertyAccess { property, .. } => covered.push(property.clone()),
799 _ => {
800 has_wildcard = true;
802 }
803 }
804 }
805
806 if has_wildcard {
807 return;
808 }
809
810 let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
811 if !missing.is_empty() {
812 let missing_str = missing
813 .iter()
814 .map(|s| format!("\"{}\"", s))
815 .collect::<Vec<_>>()
816 .join(", ");
817 self.warning_at(
818 format!(
819 "Non-exhaustive match on enum {}: missing variants {}",
820 enum_name, missing_str
821 ),
822 span,
823 );
824 }
825 }
826
827 fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
828 if let Some(sig) = scope.get_fn(name).cloned() {
830 if !is_builtin(name)
831 && (args.len() < sig.required_params || args.len() > sig.params.len())
832 {
833 let expected = if sig.required_params == sig.params.len() {
834 format!("{}", sig.params.len())
835 } else {
836 format!("{}-{}", sig.required_params, sig.params.len())
837 };
838 self.warning_at(
839 format!(
840 "Function '{}' expects {} arguments, got {}",
841 name,
842 expected,
843 args.len()
844 ),
845 span,
846 );
847 }
848 let call_scope = if sig.type_param_names.is_empty() {
851 scope.clone()
852 } else {
853 let mut s = scope.child();
854 for tp_name in &sig.type_param_names {
855 s.generic_type_params.insert(tp_name.clone());
856 }
857 s
858 };
859 for (i, (arg, (param_name, param_type))) in
860 args.iter().zip(sig.params.iter()).enumerate()
861 {
862 if let Some(expected) = param_type {
863 let actual = self.infer_type(arg, scope);
864 if let Some(actual) = &actual {
865 if !self.types_compatible(expected, actual, &call_scope) {
866 self.error_at(
867 format!(
868 "Argument {} ('{}'): expected {}, got {}",
869 i + 1,
870 param_name,
871 format_type(expected),
872 format_type(actual)
873 ),
874 arg.span,
875 );
876 }
877 }
878 }
879 }
880 }
881 for arg in args {
883 self.check_node(arg, scope);
884 }
885 }
886
887 fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
889 match &snode.node {
890 Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
891 Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
892 Node::StringLiteral(_) | Node::InterpolatedString(_) => {
893 Some(TypeExpr::Named("string".into()))
894 }
895 Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
896 Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
897 Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
898 Node::DictLiteral(entries) => {
899 let mut fields = Vec::new();
901 let mut all_string_keys = true;
902 for entry in entries {
903 if let Node::StringLiteral(key) = &entry.key.node {
904 let val_type = self
905 .infer_type(&entry.value, scope)
906 .unwrap_or(TypeExpr::Named("nil".into()));
907 fields.push(ShapeField {
908 name: key.clone(),
909 type_expr: val_type,
910 optional: false,
911 });
912 } else {
913 all_string_keys = false;
914 break;
915 }
916 }
917 if all_string_keys && !fields.is_empty() {
918 Some(TypeExpr::Shape(fields))
919 } else {
920 Some(TypeExpr::Named("dict".into()))
921 }
922 }
923 Node::Closure { params, body } => {
924 let all_typed = params.iter().all(|p| p.type_expr.is_some());
926 if all_typed && !params.is_empty() {
927 let param_types: Vec<TypeExpr> =
928 params.iter().filter_map(|p| p.type_expr.clone()).collect();
929 let ret = body.last().and_then(|last| self.infer_type(last, scope));
931 if let Some(ret_type) = ret {
932 return Some(TypeExpr::FnType {
933 params: param_types,
934 return_type: Box::new(ret_type),
935 });
936 }
937 }
938 Some(TypeExpr::Named("closure".into()))
939 }
940
941 Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
942
943 Node::FunctionCall { name, .. } => {
944 if let Some(sig) = scope.get_fn(name) {
946 return sig.return_type.clone();
947 }
948 builtin_return_type(name)
950 }
951
952 Node::BinaryOp { op, left, right } => {
953 let lt = self.infer_type(left, scope);
954 let rt = self.infer_type(right, scope);
955 infer_binary_op_type(op, <, &rt)
956 }
957
958 Node::UnaryOp { op, operand } => {
959 let t = self.infer_type(operand, scope);
960 match op.as_str() {
961 "!" => Some(TypeExpr::Named("bool".into())),
962 "-" => t, _ => None,
964 }
965 }
966
967 Node::Ternary {
968 true_expr,
969 false_expr,
970 ..
971 } => {
972 let tt = self.infer_type(true_expr, scope);
973 let ft = self.infer_type(false_expr, scope);
974 match (&tt, &ft) {
975 (Some(a), Some(b)) if a == b => tt,
976 (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
977 (Some(_), None) => tt,
978 (None, Some(_)) => ft,
979 (None, None) => None,
980 }
981 }
982
983 Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
984
985 Node::PropertyAccess { object, property } => {
986 if let Node::Identifier(name) = &object.node {
988 if scope.get_enum(name).is_some() {
989 return Some(TypeExpr::Named(name.clone()));
990 }
991 }
992 if property == "variant" {
994 let obj_type = self.infer_type(object, scope);
995 if let Some(TypeExpr::Named(name)) = &obj_type {
996 if scope.get_enum(name).is_some() {
997 return Some(TypeExpr::Named("string".into()));
998 }
999 }
1000 }
1001 let obj_type = self.infer_type(object, scope);
1003 if let Some(TypeExpr::Shape(fields)) = &obj_type {
1004 if let Some(field) = fields.iter().find(|f| f.name == *property) {
1005 return Some(field.type_expr.clone());
1006 }
1007 }
1008 None
1009 }
1010
1011 Node::SubscriptAccess { object, index } => {
1012 let obj_type = self.infer_type(object, scope);
1013 match &obj_type {
1014 Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1015 Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1016 Some(TypeExpr::Shape(fields)) => {
1017 if let Node::StringLiteral(key) = &index.node {
1019 fields
1020 .iter()
1021 .find(|f| &f.name == key)
1022 .map(|f| f.type_expr.clone())
1023 } else {
1024 None
1025 }
1026 }
1027 Some(TypeExpr::Named(n)) if n == "list" => None,
1028 Some(TypeExpr::Named(n)) if n == "dict" => None,
1029 Some(TypeExpr::Named(n)) if n == "string" => {
1030 Some(TypeExpr::Named("string".into()))
1031 }
1032 _ => None,
1033 }
1034 }
1035 Node::SliceAccess { object, .. } => {
1036 let obj_type = self.infer_type(object, scope);
1038 match &obj_type {
1039 Some(TypeExpr::List(_)) => obj_type,
1040 Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1041 Some(TypeExpr::Named(n)) if n == "string" => {
1042 Some(TypeExpr::Named("string".into()))
1043 }
1044 _ => None,
1045 }
1046 }
1047 Node::MethodCall { object, method, .. }
1048 | Node::OptionalMethodCall { object, method, .. } => {
1049 let obj_type = self.infer_type(object, scope);
1050 let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1051 || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1052 match method.as_str() {
1053 "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1055 Some(TypeExpr::Named("bool".into()))
1056 }
1057 "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1059 "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1061 | "pad_left" | "pad_right" | "repeat" | "join" => {
1062 Some(TypeExpr::Named("string".into()))
1063 }
1064 "split" | "chars" => Some(TypeExpr::Named("list".into())),
1065 "filter" => {
1067 if is_dict {
1068 Some(TypeExpr::Named("dict".into()))
1069 } else {
1070 Some(TypeExpr::Named("list".into()))
1071 }
1072 }
1073 "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1075 "reduce" | "find" | "first" | "last" => None,
1076 "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1078 "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1079 "to_string" => Some(TypeExpr::Named("string".into())),
1081 "to_int" => Some(TypeExpr::Named("int".into())),
1082 "to_float" => Some(TypeExpr::Named("float".into())),
1083 _ => None,
1084 }
1085 }
1086
1087 _ => None,
1088 }
1089 }
1090
1091 fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1093 if let TypeExpr::Named(name) = expected {
1095 if scope.is_generic_type_param(name) {
1096 return true;
1097 }
1098 }
1099 if let TypeExpr::Named(name) = actual {
1100 if scope.is_generic_type_param(name) {
1101 return true;
1102 }
1103 }
1104 let expected = self.resolve_alias(expected, scope);
1105 let actual = self.resolve_alias(actual, scope);
1106
1107 match (&expected, &actual) {
1108 (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1109 (TypeExpr::Union(members), actual_type) => members
1110 .iter()
1111 .any(|m| self.types_compatible(m, actual_type, scope)),
1112 (expected_type, TypeExpr::Union(members)) => members
1113 .iter()
1114 .all(|m| self.types_compatible(expected_type, m, scope)),
1115 (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1116 (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1117 (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1118 if expected_field.optional {
1119 return true;
1120 }
1121 af.iter().any(|actual_field| {
1122 actual_field.name == expected_field.name
1123 && self.types_compatible(
1124 &expected_field.type_expr,
1125 &actual_field.type_expr,
1126 scope,
1127 )
1128 })
1129 }),
1130 (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1132 let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1133 keys_ok
1134 && af
1135 .iter()
1136 .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1137 }
1138 (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1140 (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1141 self.types_compatible(expected_inner, actual_inner, scope)
1142 }
1143 (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1144 (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1145 (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1146 self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1147 }
1148 (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1149 (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1150 (
1152 TypeExpr::FnType {
1153 params: ep,
1154 return_type: er,
1155 },
1156 TypeExpr::FnType {
1157 params: ap,
1158 return_type: ar,
1159 },
1160 ) => {
1161 ep.len() == ap.len()
1162 && ep
1163 .iter()
1164 .zip(ap.iter())
1165 .all(|(e, a)| self.types_compatible(e, a, scope))
1166 && self.types_compatible(er, ar, scope)
1167 }
1168 (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1170 (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1171 _ => false,
1172 }
1173 }
1174
1175 fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1176 if let TypeExpr::Named(name) = ty {
1177 if let Some(resolved) = scope.resolve_type(name) {
1178 return resolved.clone();
1179 }
1180 }
1181 ty.clone()
1182 }
1183
1184 fn error_at(&mut self, message: String, span: Span) {
1185 self.diagnostics.push(TypeDiagnostic {
1186 message,
1187 severity: DiagnosticSeverity::Error,
1188 span: Some(span),
1189 });
1190 }
1191
1192 fn warning_at(&mut self, message: String, span: Span) {
1193 self.diagnostics.push(TypeDiagnostic {
1194 message,
1195 severity: DiagnosticSeverity::Warning,
1196 span: Some(span),
1197 });
1198 }
1199}
1200
1201impl Default for TypeChecker {
1202 fn default() -> Self {
1203 Self::new()
1204 }
1205}
1206
1207fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1209 match op {
1210 "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1211 "+" => match (left, right) {
1212 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1213 match (l.as_str(), r.as_str()) {
1214 ("int", "int") => Some(TypeExpr::Named("int".into())),
1215 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1216 ("string", _) => Some(TypeExpr::Named("string".into())),
1217 ("list", "list") => Some(TypeExpr::Named("list".into())),
1218 ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1219 _ => Some(TypeExpr::Named("string".into())),
1220 }
1221 }
1222 _ => None,
1223 },
1224 "-" | "*" | "/" | "%" => match (left, right) {
1225 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1226 match (l.as_str(), r.as_str()) {
1227 ("int", "int") => Some(TypeExpr::Named("int".into())),
1228 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1229 _ => None,
1230 }
1231 }
1232 _ => None,
1233 },
1234 "??" => match (left, right) {
1235 (Some(TypeExpr::Union(members)), _) => {
1236 let non_nil: Vec<_> = members
1237 .iter()
1238 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1239 .cloned()
1240 .collect();
1241 if non_nil.len() == 1 {
1242 Some(non_nil[0].clone())
1243 } else if non_nil.is_empty() {
1244 right.clone()
1245 } else {
1246 Some(TypeExpr::Union(non_nil))
1247 }
1248 }
1249 _ => right.clone(),
1250 },
1251 "|>" => None,
1252 _ => None,
1253 }
1254}
1255
1256pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1261 if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1262 let mut details = Vec::new();
1263 for field in ef {
1264 if field.optional {
1265 continue;
1266 }
1267 match af.iter().find(|f| f.name == field.name) {
1268 None => details.push(format!(
1269 "missing field '{}' ({})",
1270 field.name,
1271 format_type(&field.type_expr)
1272 )),
1273 Some(actual_field) => {
1274 let e_str = format_type(&field.type_expr);
1275 let a_str = format_type(&actual_field.type_expr);
1276 if e_str != a_str {
1277 details.push(format!(
1278 "field '{}' has type {}, expected {}",
1279 field.name, a_str, e_str
1280 ));
1281 }
1282 }
1283 }
1284 }
1285 if details.is_empty() {
1286 None
1287 } else {
1288 Some(details.join("; "))
1289 }
1290 } else {
1291 None
1292 }
1293}
1294
1295pub fn format_type(ty: &TypeExpr) -> String {
1296 match ty {
1297 TypeExpr::Named(n) => n.clone(),
1298 TypeExpr::Union(types) => types
1299 .iter()
1300 .map(format_type)
1301 .collect::<Vec<_>>()
1302 .join(" | "),
1303 TypeExpr::Shape(fields) => {
1304 let inner: Vec<String> = fields
1305 .iter()
1306 .map(|f| {
1307 let opt = if f.optional { "?" } else { "" };
1308 format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1309 })
1310 .collect();
1311 format!("{{{}}}", inner.join(", "))
1312 }
1313 TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1314 TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1315 TypeExpr::FnType {
1316 params,
1317 return_type,
1318 } => {
1319 let params_str = params
1320 .iter()
1321 .map(format_type)
1322 .collect::<Vec<_>>()
1323 .join(", ");
1324 format!("fn({}) -> {}", params_str, format_type(return_type))
1325 }
1326 }
1327}
1328
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    /// Substring identifying the warning emitted for non-exhaustive matches.
    const NON_EXHAUSTIVE: &str = "Non-exhaustive";
    /// Substring identifying the warning emitted for mistyped literal patterns.
    const PATTERN_MISMATCH: &str = "Match pattern type mismatch";

    /// Lexes, parses, and type-checks `source`, returning every diagnostic.
    ///
    /// Panics when the fixture does not lex or parse: the tests in this
    /// module only exercise the type checker.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut lexer = Lexer::new(source);
        let tokens = lexer.tokenize().expect("test source should lex");
        let mut parser = Parser::new(tokens);
        let program = parser.parse().expect("test source should parse");
        TypeChecker::new().check(&program)
    }

    /// Messages of all diagnostics for `source` that have the given severity.
    fn messages_with(source: &str, severity: DiagnosticSeverity) -> Vec<String> {
        check_source(source)
            .into_iter()
            .filter(|d| d.severity == severity)
            .map(|d| d.message)
            .collect()
    }

    /// Error messages produced for `source`.
    fn errors(source: &str) -> Vec<String> {
        messages_with(source, DiagnosticSeverity::Error)
    }

    /// Warning messages produced for `source`.
    fn warnings(source: &str) -> Vec<String> {
        messages_with(source, DiagnosticSeverity::Warning)
    }

    /// Warning messages produced for `source` that contain `needle`.
    fn warnings_containing(source: &str, needle: &str) -> Vec<String> {
        warnings(source)
            .into_iter()
            .filter(|w| w.contains(needle))
            .collect()
    }

    #[test]
    fn test_no_errors_for_untyped_code() {
        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_correct_typed_let() {
        let errs = errors("pipeline t(task) { let x: int = 42 }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_type_mismatch_let() {
        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Type mismatch"));
        assert!(errs[0].contains("int"));
        assert!(errs[0].contains("string"));
    }

    #[test]
    fn test_correct_typed_fn() {
        let errs = errors(
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let errs = errors(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Argument 1"));
        assert!(errs[0].contains("expected int"));
    }

    #[test]
    fn test_return_type_mismatch() {
        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Return type mismatch"));
    }

    #[test]
    fn test_union_type_compatible() {
        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
        assert!(errs.is_empty());
    }

    #[test]
    fn test_union_type_mismatch() {
        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Type mismatch"));
    }

    #[test]
    fn test_type_inference_propagation() {
        let errs = errors(
            r#"pipeline t(task) {
  fn add(a: int, b: int) -> int { return a + b }
  let result: string = add(1, 2)
}"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Type mismatch"));
        assert!(errs[0].contains("string"));
        assert!(errs[0].contains("int"));
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("string"));
        assert!(errs[0].contains("int"));
    }

    #[test]
    fn test_binary_op_type_inference() {
        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
        assert_eq!(errs.len(), 1);
    }

    #[test]
    fn test_comparison_returns_bool() {
        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_int_float_promotion() {
        let errs = errors("pipeline t(task) { let x: float = 42 }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_untyped_code_no_errors() {
        let errs = errors(
            r#"pipeline t(task) {
  fn process(data) {
    let result = data + " processed"
    return result
  }
  log(process("hello"))
}"#,
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_type_alias() {
        let errs = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = "hello"
}"#,
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_type_alias_mismatch() {
        let errs = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = 42
}"#,
        );
        assert_eq!(errs.len(), 1);
    }

    #[test]
    fn test_assignment_type_check() {
        let errs = errors(
            r#"pipeline t(task) {
  var x: int = 0
  x = "hello"
}"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("cannot assign string"));
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        let errs = errors(
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_covariance_return_type() {
        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
        assert_eq!(errs.len(), 1);
    }

    #[test]
    fn test_exhaustive_match_no_warning() {
        let exhaustive_warns = warnings_containing(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
    "Blue" -> { log("b") }
  }
}"#,
            NON_EXHAUSTIVE,
        );
        assert!(exhaustive_warns.is_empty());
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let exhaustive_warns = warnings_containing(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
  }
}"#,
            NON_EXHAUSTIVE,
        );
        assert_eq!(exhaustive_warns.len(), 1);
        assert!(exhaustive_warns[0].contains("Blue"));
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let exhaustive_warns = warnings_containing(
            r#"pipeline t(task) {
  enum Status { Active, Inactive, Pending }
  let s = Status.Active
  match s.variant {
    "Active" -> { log("a") }
  }
}"#,
            NON_EXHAUSTIVE,
        );
        assert_eq!(exhaustive_warns.len(), 1);
        assert!(exhaustive_warns[0].contains("Inactive"));
        assert!(exhaustive_warns[0].contains("Pending"));
    }

    #[test]
    fn test_enum_construct_type_inference() {
        let errs = errors(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c: Color = Color.Red
}"#,
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_nil_coalescing_strips_nil() {
        // `string | nil ?? string` must be assignable to a plain `string`.
        let errs = errors(
            r#"pipeline t(task) {
  let x: string | nil = nil
  let y: string = x ?? "default"
}"#,
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_shape_mismatch_detail_missing_field() {
        let errs = errors(
            r#"pipeline t(task) {
  let x: {name: string, age: int} = {name: "hello"}
}"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(
            errs[0].contains("missing field 'age'"),
            "expected detail about missing field, got: {}",
            errs[0]
        );
    }

    #[test]
    fn test_shape_mismatch_detail_wrong_type() {
        let errs = errors(
            r#"pipeline t(task) {
  let x: {name: string, age: int} = {name: 42, age: 10}
}"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(
            errs[0].contains("field 'name' has type int, expected string"),
            "expected detail about wrong type, got: {}",
            errs[0]
        );
    }

    #[test]
    fn test_match_pattern_string_against_int() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    "hello" -> { log("bad") }
    42 -> { log("ok") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(pattern_warns.len(), 1);
        assert!(pattern_warns[0].contains("matching int against string literal"));
    }

    #[test]
    fn test_match_pattern_int_against_string() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: string = "hello"
  match x {
    42 -> { log("bad") }
    "hello" -> { log("ok") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(pattern_warns.len(), 1);
        assert!(pattern_warns[0].contains("matching string against int literal"));
    }

    #[test]
    fn test_match_pattern_bool_against_int() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    true -> { log("bad") }
    42 -> { log("ok") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(pattern_warns.len(), 1);
        assert!(pattern_warns[0].contains("matching int against bool literal"));
    }

    #[test]
    fn test_match_pattern_float_against_string() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: string = "hello"
  match x {
    3.14 -> { log("bad") }
    "hello" -> { log("ok") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(pattern_warns.len(), 1);
        assert!(pattern_warns[0].contains("matching string against float literal"));
    }

    #[test]
    fn test_match_pattern_int_against_float_ok() {
        // An int literal pattern against a float subject must not warn.
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: float = 3.14
  match x {
    42 -> { log("ok") }
    _ -> { log("default") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }

    #[test]
    fn test_match_pattern_float_against_int_ok() {
        // A float literal pattern against an int subject must not warn.
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    3.14 -> { log("close") }
    _ -> { log("default") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }

    #[test]
    fn test_match_pattern_correct_types_no_warning() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    1 -> { log("one") }
    2 -> { log("two") }
    _ -> { log("other") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }

    #[test]
    fn test_match_pattern_wildcard_no_warning() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    _ -> { log("catch all") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }

    #[test]
    fn test_match_pattern_untyped_no_warning() {
        // The subject's type is unknown here, so no pattern can be flagged.
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x = some_unknown_fn()
  match x {
    "hello" -> { log("string") }
    42 -> { log("int") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }
}