1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9 pub message: String,
10 pub severity: DiagnosticSeverity,
11 pub span: Option<Span>,
12}
13
/// How serious a type diagnostic is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// Used for definite mismatches (bindings, arguments, return values).
    Error,
    /// Used for likely-but-uncertain problems (arity, operators, match coverage).
    Warning,
}
19
20type InferredType = Option<TypeExpr>;
22
23#[derive(Debug, Clone)]
25struct TypeScope {
26 vars: BTreeMap<String, InferredType>,
28 functions: BTreeMap<String, FnSignature>,
30 type_aliases: BTreeMap<String, TypeExpr>,
32 enums: BTreeMap<String, Vec<String>>,
34 interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36 structs: BTreeMap<String, Vec<(String, InferredType)>>,
38 generic_type_params: std::collections::BTreeSet<String>,
40 parent: Option<Box<TypeScope>>,
41}
42
43#[derive(Debug, Clone)]
44struct FnSignature {
45 params: Vec<(String, InferredType)>,
46 return_type: InferredType,
47 type_param_names: Vec<String>,
49 required_params: usize,
51}
52
53impl TypeScope {
54 fn new() -> Self {
55 Self {
56 vars: BTreeMap::new(),
57 functions: BTreeMap::new(),
58 type_aliases: BTreeMap::new(),
59 enums: BTreeMap::new(),
60 interfaces: BTreeMap::new(),
61 structs: BTreeMap::new(),
62 generic_type_params: std::collections::BTreeSet::new(),
63 parent: None,
64 }
65 }
66
67 fn child(&self) -> Self {
68 Self {
69 vars: BTreeMap::new(),
70 functions: BTreeMap::new(),
71 type_aliases: BTreeMap::new(),
72 enums: BTreeMap::new(),
73 interfaces: BTreeMap::new(),
74 structs: BTreeMap::new(),
75 generic_type_params: std::collections::BTreeSet::new(),
76 parent: Some(Box::new(self.clone())),
77 }
78 }
79
80 fn get_var(&self, name: &str) -> Option<&InferredType> {
81 self.vars
82 .get(name)
83 .or_else(|| self.parent.as_ref()?.get_var(name))
84 }
85
86 fn get_fn(&self, name: &str) -> Option<&FnSignature> {
87 self.functions
88 .get(name)
89 .or_else(|| self.parent.as_ref()?.get_fn(name))
90 }
91
92 fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
93 self.type_aliases
94 .get(name)
95 .or_else(|| self.parent.as_ref()?.resolve_type(name))
96 }
97
98 fn is_generic_type_param(&self, name: &str) -> bool {
99 self.generic_type_params.contains(name)
100 || self
101 .parent
102 .as_ref()
103 .is_some_and(|p| p.is_generic_type_param(name))
104 }
105
106 fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
107 self.enums
108 .get(name)
109 .or_else(|| self.parent.as_ref()?.get_enum(name))
110 }
111
112 #[allow(dead_code)]
113 fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
114 self.interfaces
115 .get(name)
116 .or_else(|| self.parent.as_ref()?.get_interface(name))
117 }
118
119 fn define_var(&mut self, name: &str, ty: InferredType) {
120 self.vars.insert(name.to_string(), ty);
121 }
122
123 fn define_fn(&mut self, name: &str, sig: FnSignature) {
124 self.functions.insert(name.to_string(), sig);
125 }
126}
127
128fn builtin_return_type(name: &str) -> InferredType {
130 match name {
131 "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
132 | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
133 Some(TypeExpr::Named("nil".into()))
134 }
135 "type_of"
136 | "to_string"
137 | "json_stringify"
138 | "read_file"
139 | "http_get"
140 | "http_post"
141 | "llm_call"
142 | "regex_replace"
143 | "path_join"
144 | "temp_dir"
145 | "date_format"
146 | "format"
147 | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
148 "to_int" | "timer_end" | "elapsed" | "sign" => Some(TypeExpr::Named("int".into())),
149 "to_float" | "timestamp" | "date_parse" | "sin" | "cos" | "tan" | "asin" | "acos"
150 | "atan" | "atan2" | "log2" | "log10" | "ln" | "exp" | "pi" | "e" => {
151 Some(TypeExpr::Named("float".into()))
152 }
153 "file_exists" | "json_validate" | "is_nan" | "is_infinite" | "set_contains" => {
154 Some(TypeExpr::Named("bool".into()))
155 }
156 "list_dir" | "mcp_list_tools" | "mcp_list_resources" | "mcp_list_prompts" | "to_list" => {
157 Some(TypeExpr::Named("list".into()))
158 }
159 "stat" | "exec" | "shell" | "date_now" | "agent_loop" | "llm_info" | "llm_usage"
160 | "timer_start" | "metadata_get" | "mcp_server_info" | "mcp_get_prompt" => {
161 Some(TypeExpr::Named("dict".into()))
162 }
163 "metadata_set"
164 | "metadata_save"
165 | "metadata_refresh_hashes"
166 | "invalidate_facts"
167 | "log_json"
168 | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
169 "env" | "regex_match" => Some(TypeExpr::Union(vec![
170 TypeExpr::Named("string".into()),
171 TypeExpr::Named("nil".into()),
172 ])),
173 "json_parse" | "json_extract" => None, _ => None,
175 }
176}
177
/// Whether `name` is in the builtin-function table below.
///
/// `check_call` uses this to skip arity warnings for these names.
/// NOTE(review): this table is narrower than the set of names that
/// `builtin_return_type` knows about (e.g. `compute_content_hash` is absent).
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        "log", "print", "println", "type_of", "to_string", "to_int", "to_float",
        "json_stringify", "json_parse", "env", "timestamp", "sleep", "read_file",
        "write_file", "exit", "regex_match", "regex_replace", "http_get", "http_post",
        "llm_call", "agent_loop", "await", "cancel", "file_exists", "delete_file",
        "list_dir", "mkdir", "path_join", "copy_file", "append_file", "temp_dir", "stat",
        "exec", "shell", "date_now", "date_format", "date_parse", "format",
        "json_validate", "json_extract", "trim", "lowercase", "uppercase", "split",
        "starts_with", "ends_with", "contains", "replace", "join", "len", "substring",
        "dirname", "basename", "extname", "sin", "cos", "tan", "asin", "acos", "atan",
        "atan2", "log2", "log10", "ln", "exp", "pi", "e", "sign", "is_nan",
        "is_infinite", "set", "set_add", "set_remove", "set_contains", "set_union",
        "set_intersect", "set_difference", "to_list",
    ];
    BUILTINS.iter().any(|&builtin| builtin == name)
}
262
263pub struct TypeChecker {
265 diagnostics: Vec<TypeDiagnostic>,
266 scope: TypeScope,
267}
268
269impl TypeChecker {
270 pub fn new() -> Self {
271 Self {
272 diagnostics: Vec::new(),
273 scope: TypeScope::new(),
274 }
275 }
276
277 pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
279 Self::register_declarations_into(&mut self.scope, program);
281
282 for snode in program {
284 if let Node::Pipeline { body, .. } = &snode.node {
285 Self::register_declarations_into(&mut self.scope, body);
286 }
287 }
288
289 for snode in program {
291 match &snode.node {
292 Node::Pipeline { params, body, .. } => {
293 let mut child = self.scope.child();
294 for p in params {
295 child.define_var(p, None);
296 }
297 self.check_block(body, &mut child);
298 }
299 Node::FnDecl {
300 name,
301 type_params,
302 params,
303 return_type,
304 body,
305 ..
306 } => {
307 let required_params =
308 params.iter().filter(|p| p.default_value.is_none()).count();
309 let sig = FnSignature {
310 params: params
311 .iter()
312 .map(|p| (p.name.clone(), p.type_expr.clone()))
313 .collect(),
314 return_type: return_type.clone(),
315 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
316 required_params,
317 };
318 self.scope.define_fn(name, sig);
319 self.check_fn_body(type_params, params, return_type, body);
320 }
321 _ => {
322 let mut scope = self.scope.clone();
323 self.check_node(snode, &mut scope);
324 for (name, ty) in scope.vars {
326 self.scope.vars.entry(name).or_insert(ty);
327 }
328 }
329 }
330 }
331
332 self.diagnostics
333 }
334
335 fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
337 for snode in nodes {
338 match &snode.node {
339 Node::TypeDecl { name, type_expr } => {
340 scope.type_aliases.insert(name.clone(), type_expr.clone());
341 }
342 Node::EnumDecl { name, variants } => {
343 let variant_names: Vec<String> =
344 variants.iter().map(|v| v.name.clone()).collect();
345 scope.enums.insert(name.clone(), variant_names);
346 }
347 Node::InterfaceDecl { name, methods } => {
348 scope.interfaces.insert(name.clone(), methods.clone());
349 }
350 Node::StructDecl { name, fields } => {
351 let field_types: Vec<(String, InferredType)> = fields
352 .iter()
353 .map(|f| (f.name.clone(), f.type_expr.clone()))
354 .collect();
355 scope.structs.insert(name.clone(), field_types);
356 }
357 _ => {}
358 }
359 }
360 }
361
362 fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
363 for stmt in stmts {
364 self.check_node(stmt, scope);
365 }
366 }
367
368 fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
370 match pattern {
371 BindingPattern::Identifier(name) => {
372 scope.define_var(name, None);
373 }
374 BindingPattern::Dict(fields) => {
375 for field in fields {
376 let name = field.alias.as_deref().unwrap_or(&field.key);
377 scope.define_var(name, None);
378 }
379 }
380 BindingPattern::List(elements) => {
381 for elem in elements {
382 scope.define_var(&elem.name, None);
383 }
384 }
385 }
386 }
387
388 fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
389 let span = snode.span;
390 match &snode.node {
391 Node::LetBinding {
392 pattern,
393 type_ann,
394 value,
395 } => {
396 let inferred = self.infer_type(value, scope);
397 if let BindingPattern::Identifier(name) = pattern {
398 if let Some(expected) = type_ann {
399 if let Some(actual) = &inferred {
400 if !self.types_compatible(expected, actual, scope) {
401 let mut msg = format!(
402 "Type mismatch: '{}' declared as {}, but assigned {}",
403 name,
404 format_type(expected),
405 format_type(actual)
406 );
407 if let Some(detail) = shape_mismatch_detail(expected, actual) {
408 msg.push_str(&format!(" ({})", detail));
409 }
410 self.error_at(msg, span);
411 }
412 }
413 }
414 let ty = type_ann.clone().or(inferred);
415 scope.define_var(name, ty);
416 } else {
417 Self::define_pattern_vars(pattern, scope);
418 }
419 }
420
421 Node::VarBinding {
422 pattern,
423 type_ann,
424 value,
425 } => {
426 let inferred = self.infer_type(value, scope);
427 if let BindingPattern::Identifier(name) = pattern {
428 if let Some(expected) = type_ann {
429 if let Some(actual) = &inferred {
430 if !self.types_compatible(expected, actual, scope) {
431 let mut msg = format!(
432 "Type mismatch: '{}' declared as {}, but assigned {}",
433 name,
434 format_type(expected),
435 format_type(actual)
436 );
437 if let Some(detail) = shape_mismatch_detail(expected, actual) {
438 msg.push_str(&format!(" ({})", detail));
439 }
440 self.error_at(msg, span);
441 }
442 }
443 }
444 let ty = type_ann.clone().or(inferred);
445 scope.define_var(name, ty);
446 } else {
447 Self::define_pattern_vars(pattern, scope);
448 }
449 }
450
451 Node::FnDecl {
452 name,
453 type_params,
454 params,
455 return_type,
456 body,
457 ..
458 } => {
459 let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
460 let sig = FnSignature {
461 params: params
462 .iter()
463 .map(|p| (p.name.clone(), p.type_expr.clone()))
464 .collect(),
465 return_type: return_type.clone(),
466 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
467 required_params,
468 };
469 scope.define_fn(name, sig.clone());
470 scope.define_var(name, None);
471 self.check_fn_body(type_params, params, return_type, body);
472 }
473
474 Node::FunctionCall { name, args } => {
475 self.check_call(name, args, scope, span);
476 }
477
478 Node::IfElse {
479 condition,
480 then_body,
481 else_body,
482 } => {
483 self.check_node(condition, scope);
484 let mut then_scope = scope.child();
485 if let Some((var_name, narrowed)) = Self::extract_nil_narrowing(condition, scope) {
487 then_scope.define_var(&var_name, narrowed);
488 }
489 self.check_block(then_body, &mut then_scope);
490 if let Some(else_body) = else_body {
491 let mut else_scope = scope.child();
492 self.check_block(else_body, &mut else_scope);
493 }
494 }
495
496 Node::ForIn {
497 pattern,
498 iterable,
499 body,
500 } => {
501 self.check_node(iterable, scope);
502 let mut loop_scope = scope.child();
503 if let BindingPattern::Identifier(variable) = pattern {
504 let elem_type = match self.infer_type(iterable, scope) {
506 Some(TypeExpr::List(inner)) => Some(*inner),
507 Some(TypeExpr::Named(n)) if n == "string" => {
508 Some(TypeExpr::Named("string".into()))
509 }
510 _ => None,
511 };
512 loop_scope.define_var(variable, elem_type);
513 } else {
514 Self::define_pattern_vars(pattern, &mut loop_scope);
515 }
516 self.check_block(body, &mut loop_scope);
517 }
518
519 Node::WhileLoop { condition, body } => {
520 self.check_node(condition, scope);
521 let mut loop_scope = scope.child();
522 self.check_block(body, &mut loop_scope);
523 }
524
525 Node::TryCatch {
526 body,
527 error_var,
528 catch_body,
529 finally_body,
530 ..
531 } => {
532 let mut try_scope = scope.child();
533 self.check_block(body, &mut try_scope);
534 let mut catch_scope = scope.child();
535 if let Some(var) = error_var {
536 catch_scope.define_var(var, None);
537 }
538 self.check_block(catch_body, &mut catch_scope);
539 if let Some(fb) = finally_body {
540 let mut finally_scope = scope.child();
541 self.check_block(fb, &mut finally_scope);
542 }
543 }
544
545 Node::ReturnStmt {
546 value: Some(val), ..
547 } => {
548 self.check_node(val, scope);
549 }
550
551 Node::Assignment {
552 target, value, op, ..
553 } => {
554 self.check_node(value, scope);
555 if let Node::Identifier(name) = &target.node {
556 if let Some(Some(var_type)) = scope.get_var(name) {
557 let value_type = self.infer_type(value, scope);
558 let assigned = if let Some(op) = op {
559 let var_inferred = scope.get_var(name).cloned().flatten();
560 infer_binary_op_type(op, &var_inferred, &value_type)
561 } else {
562 value_type
563 };
564 if let Some(actual) = &assigned {
565 if !self.types_compatible(var_type, actual, scope) {
566 self.error_at(
567 format!(
568 "Type mismatch: cannot assign {} to '{}' (declared as {})",
569 format_type(actual),
570 name,
571 format_type(var_type)
572 ),
573 span,
574 );
575 }
576 }
577 }
578 }
579 }
580
581 Node::TypeDecl { name, type_expr } => {
582 scope.type_aliases.insert(name.clone(), type_expr.clone());
583 }
584
585 Node::EnumDecl { name, variants } => {
586 let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
587 scope.enums.insert(name.clone(), variant_names);
588 }
589
590 Node::StructDecl { name, fields } => {
591 let field_types: Vec<(String, InferredType)> = fields
592 .iter()
593 .map(|f| (f.name.clone(), f.type_expr.clone()))
594 .collect();
595 scope.structs.insert(name.clone(), field_types);
596 }
597
598 Node::InterfaceDecl { name, methods } => {
599 scope.interfaces.insert(name.clone(), methods.clone());
600 }
601
602 Node::ImplBlock { methods, .. } => {
603 for method_sn in methods {
604 self.check_node(method_sn, scope);
605 }
606 }
607
608 Node::TryOperator { operand } => {
609 self.check_node(operand, scope);
610 }
611
612 Node::MatchExpr { value, arms } => {
613 self.check_node(value, scope);
614 let value_type = self.infer_type(value, scope);
615 for arm in arms {
616 self.check_node(&arm.pattern, scope);
617 if let Some(ref vt) = value_type {
619 let value_type_name = format_type(vt);
620 let mismatch = match &arm.pattern.node {
621 Node::StringLiteral(_) => {
622 !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
623 }
624 Node::IntLiteral(_) => {
625 !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
626 && !self.types_compatible(
627 vt,
628 &TypeExpr::Named("float".into()),
629 scope,
630 )
631 }
632 Node::FloatLiteral(_) => {
633 !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
634 && !self.types_compatible(
635 vt,
636 &TypeExpr::Named("int".into()),
637 scope,
638 )
639 }
640 Node::BoolLiteral(_) => {
641 !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
642 }
643 _ => false,
644 };
645 if mismatch {
646 let pattern_type = match &arm.pattern.node {
647 Node::StringLiteral(_) => "string",
648 Node::IntLiteral(_) => "int",
649 Node::FloatLiteral(_) => "float",
650 Node::BoolLiteral(_) => "bool",
651 _ => unreachable!(),
652 };
653 self.warning_at(
654 format!(
655 "Match pattern type mismatch: matching {} against {} literal",
656 value_type_name, pattern_type
657 ),
658 arm.pattern.span,
659 );
660 }
661 }
662 let mut arm_scope = scope.child();
663 self.check_block(&arm.body, &mut arm_scope);
664 }
665 self.check_match_exhaustiveness(value, arms, scope, span);
666 }
667
668 Node::BinaryOp { op, left, right } => {
670 self.check_node(left, scope);
671 self.check_node(right, scope);
672 let lt = self.infer_type(left, scope);
674 let rt = self.infer_type(right, scope);
675 if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (<, &rt) {
676 match op.as_str() {
677 "-" | "*" | "/" | "%" => {
678 let numeric = ["int", "float"];
679 if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
680 self.warning_at(
681 format!(
682 "Operator '{op}' may not be valid for types {} and {}",
683 l, r
684 ),
685 span,
686 );
687 }
688 }
689 "+" => {
690 let valid = ["int", "float", "string", "list", "dict"];
692 if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
693 self.warning_at(
694 format!(
695 "Operator '+' may not be valid for types {} and {}",
696 l, r
697 ),
698 span,
699 );
700 }
701 }
702 _ => {}
703 }
704 }
705 }
706 Node::UnaryOp { operand, .. } => {
707 self.check_node(operand, scope);
708 }
709 Node::MethodCall { object, args, .. }
710 | Node::OptionalMethodCall { object, args, .. } => {
711 self.check_node(object, scope);
712 for arg in args {
713 self.check_node(arg, scope);
714 }
715 }
716 Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
717 self.check_node(object, scope);
718 }
719 Node::SubscriptAccess { object, index } => {
720 self.check_node(object, scope);
721 self.check_node(index, scope);
722 }
723 Node::SliceAccess { object, start, end } => {
724 self.check_node(object, scope);
725 if let Some(s) = start {
726 self.check_node(s, scope);
727 }
728 if let Some(e) = end {
729 self.check_node(e, scope);
730 }
731 }
732
733 _ => {}
735 }
736 }
737
738 fn check_fn_body(
739 &mut self,
740 type_params: &[TypeParam],
741 params: &[TypedParam],
742 return_type: &Option<TypeExpr>,
743 body: &[SNode],
744 ) {
745 let mut fn_scope = self.scope.child();
746 for tp in type_params {
749 fn_scope.generic_type_params.insert(tp.name.clone());
750 }
751 for param in params {
752 fn_scope.define_var(¶m.name, param.type_expr.clone());
753 if let Some(default) = ¶m.default_value {
754 self.check_node(default, &mut fn_scope);
755 }
756 }
757 self.check_block(body, &mut fn_scope);
758
759 if let Some(ret_type) = return_type {
761 for stmt in body {
762 self.check_return_type(stmt, ret_type, &fn_scope);
763 }
764 }
765 }
766
767 fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
768 let span = snode.span;
769 match &snode.node {
770 Node::ReturnStmt { value: Some(val) } => {
771 let inferred = self.infer_type(val, scope);
772 if let Some(actual) = &inferred {
773 if !self.types_compatible(expected, actual, scope) {
774 self.error_at(
775 format!(
776 "Return type mismatch: expected {}, got {}",
777 format_type(expected),
778 format_type(actual)
779 ),
780 span,
781 );
782 }
783 }
784 }
785 Node::IfElse {
786 then_body,
787 else_body,
788 ..
789 } => {
790 for stmt in then_body {
791 self.check_return_type(stmt, expected, scope);
792 }
793 if let Some(else_body) = else_body {
794 for stmt in else_body {
795 self.check_return_type(stmt, expected, scope);
796 }
797 }
798 }
799 _ => {}
800 }
801 }
802
803 fn extract_nil_narrowing(
807 condition: &SNode,
808 scope: &TypeScope,
809 ) -> Option<(String, InferredType)> {
810 if let Node::BinaryOp { op, left, right } = &condition.node {
811 if op == "!=" {
812 let (var_node, nil_node) = if matches!(right.node, Node::NilLiteral) {
814 (left, right)
815 } else if matches!(left.node, Node::NilLiteral) {
816 (right, left)
817 } else {
818 return None;
819 };
820 let _ = nil_node;
821 if let Node::Identifier(name) = &var_node.node {
822 if let Some(Some(TypeExpr::Union(members))) = scope.get_var(name) {
824 let narrowed: Vec<TypeExpr> = members
825 .iter()
826 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
827 .cloned()
828 .collect();
829 return if narrowed.len() == 1 {
830 Some((name.clone(), Some(narrowed.into_iter().next().unwrap())))
831 } else if narrowed.is_empty() {
832 None
833 } else {
834 Some((name.clone(), Some(TypeExpr::Union(narrowed))))
835 };
836 }
837 }
838 }
839 }
840 None
841 }
842
843 fn check_match_exhaustiveness(
844 &mut self,
845 value: &SNode,
846 arms: &[MatchArm],
847 scope: &TypeScope,
848 span: Span,
849 ) {
850 let enum_name = match &value.node {
852 Node::PropertyAccess { object, property } if property == "variant" => {
853 match self.infer_type(object, scope) {
855 Some(TypeExpr::Named(name)) => {
856 if scope.get_enum(&name).is_some() {
857 Some(name)
858 } else {
859 None
860 }
861 }
862 _ => None,
863 }
864 }
865 _ => {
866 match self.infer_type(value, scope) {
868 Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
869 _ => None,
870 }
871 }
872 };
873
874 let Some(enum_name) = enum_name else {
875 return;
876 };
877 let Some(variants) = scope.get_enum(&enum_name) else {
878 return;
879 };
880
881 let mut covered: Vec<String> = Vec::new();
883 let mut has_wildcard = false;
884
885 for arm in arms {
886 match &arm.pattern.node {
887 Node::StringLiteral(s) => covered.push(s.clone()),
889 Node::Identifier(name) if name == "_" || !variants.contains(name) => {
891 has_wildcard = true;
892 }
893 Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
895 Node::PropertyAccess { property, .. } => covered.push(property.clone()),
897 _ => {
898 has_wildcard = true;
900 }
901 }
902 }
903
904 if has_wildcard {
905 return;
906 }
907
908 let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
909 if !missing.is_empty() {
910 let missing_str = missing
911 .iter()
912 .map(|s| format!("\"{}\"", s))
913 .collect::<Vec<_>>()
914 .join(", ");
915 self.warning_at(
916 format!(
917 "Non-exhaustive match on enum {}: missing variants {}",
918 enum_name, missing_str
919 ),
920 span,
921 );
922 }
923 }
924
925 fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
926 if let Some(sig) = scope.get_fn(name).cloned() {
928 if !is_builtin(name)
929 && (args.len() < sig.required_params || args.len() > sig.params.len())
930 {
931 let expected = if sig.required_params == sig.params.len() {
932 format!("{}", sig.params.len())
933 } else {
934 format!("{}-{}", sig.required_params, sig.params.len())
935 };
936 self.warning_at(
937 format!(
938 "Function '{}' expects {} arguments, got {}",
939 name,
940 expected,
941 args.len()
942 ),
943 span,
944 );
945 }
946 let call_scope = if sig.type_param_names.is_empty() {
949 scope.clone()
950 } else {
951 let mut s = scope.child();
952 for tp_name in &sig.type_param_names {
953 s.generic_type_params.insert(tp_name.clone());
954 }
955 s
956 };
957 for (i, (arg, (param_name, param_type))) in
958 args.iter().zip(sig.params.iter()).enumerate()
959 {
960 if let Some(expected) = param_type {
961 let actual = self.infer_type(arg, scope);
962 if let Some(actual) = &actual {
963 if !self.types_compatible(expected, actual, &call_scope) {
964 self.error_at(
965 format!(
966 "Argument {} ('{}'): expected {}, got {}",
967 i + 1,
968 param_name,
969 format_type(expected),
970 format_type(actual)
971 ),
972 arg.span,
973 );
974 }
975 }
976 }
977 }
978 }
979 for arg in args {
981 self.check_node(arg, scope);
982 }
983 }
984
985 fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
987 match &snode.node {
988 Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
989 Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
990 Node::StringLiteral(_) | Node::InterpolatedString(_) => {
991 Some(TypeExpr::Named("string".into()))
992 }
993 Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
994 Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
995 Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
996 Node::DictLiteral(entries) => {
997 let mut fields = Vec::new();
999 let mut all_string_keys = true;
1000 for entry in entries {
1001 if let Node::StringLiteral(key) = &entry.key.node {
1002 let val_type = self
1003 .infer_type(&entry.value, scope)
1004 .unwrap_or(TypeExpr::Named("nil".into()));
1005 fields.push(ShapeField {
1006 name: key.clone(),
1007 type_expr: val_type,
1008 optional: false,
1009 });
1010 } else {
1011 all_string_keys = false;
1012 break;
1013 }
1014 }
1015 if all_string_keys && !fields.is_empty() {
1016 Some(TypeExpr::Shape(fields))
1017 } else {
1018 Some(TypeExpr::Named("dict".into()))
1019 }
1020 }
1021 Node::Closure { params, body } => {
1022 let all_typed = params.iter().all(|p| p.type_expr.is_some());
1024 if all_typed && !params.is_empty() {
1025 let param_types: Vec<TypeExpr> =
1026 params.iter().filter_map(|p| p.type_expr.clone()).collect();
1027 let ret = body.last().and_then(|last| self.infer_type(last, scope));
1029 if let Some(ret_type) = ret {
1030 return Some(TypeExpr::FnType {
1031 params: param_types,
1032 return_type: Box::new(ret_type),
1033 });
1034 }
1035 }
1036 Some(TypeExpr::Named("closure".into()))
1037 }
1038
1039 Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
1040
1041 Node::FunctionCall { name, .. } => {
1042 if let Some(sig) = scope.get_fn(name) {
1044 return sig.return_type.clone();
1045 }
1046 builtin_return_type(name)
1048 }
1049
1050 Node::BinaryOp { op, left, right } => {
1051 let lt = self.infer_type(left, scope);
1052 let rt = self.infer_type(right, scope);
1053 infer_binary_op_type(op, <, &rt)
1054 }
1055
1056 Node::UnaryOp { op, operand } => {
1057 let t = self.infer_type(operand, scope);
1058 match op.as_str() {
1059 "!" => Some(TypeExpr::Named("bool".into())),
1060 "-" => t, _ => None,
1062 }
1063 }
1064
1065 Node::Ternary {
1066 true_expr,
1067 false_expr,
1068 ..
1069 } => {
1070 let tt = self.infer_type(true_expr, scope);
1071 let ft = self.infer_type(false_expr, scope);
1072 match (&tt, &ft) {
1073 (Some(a), Some(b)) if a == b => tt,
1074 (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
1075 (Some(_), None) => tt,
1076 (None, Some(_)) => ft,
1077 (None, None) => None,
1078 }
1079 }
1080
1081 Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1082
1083 Node::PropertyAccess { object, property } => {
1084 if let Node::Identifier(name) = &object.node {
1086 if scope.get_enum(name).is_some() {
1087 return Some(TypeExpr::Named(name.clone()));
1088 }
1089 }
1090 if property == "variant" {
1092 let obj_type = self.infer_type(object, scope);
1093 if let Some(TypeExpr::Named(name)) = &obj_type {
1094 if scope.get_enum(name).is_some() {
1095 return Some(TypeExpr::Named("string".into()));
1096 }
1097 }
1098 }
1099 let obj_type = self.infer_type(object, scope);
1101 if let Some(TypeExpr::Shape(fields)) = &obj_type {
1102 if let Some(field) = fields.iter().find(|f| f.name == *property) {
1103 return Some(field.type_expr.clone());
1104 }
1105 }
1106 None
1107 }
1108
1109 Node::SubscriptAccess { object, index } => {
1110 let obj_type = self.infer_type(object, scope);
1111 match &obj_type {
1112 Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1113 Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1114 Some(TypeExpr::Shape(fields)) => {
1115 if let Node::StringLiteral(key) = &index.node {
1117 fields
1118 .iter()
1119 .find(|f| &f.name == key)
1120 .map(|f| f.type_expr.clone())
1121 } else {
1122 None
1123 }
1124 }
1125 Some(TypeExpr::Named(n)) if n == "list" => None,
1126 Some(TypeExpr::Named(n)) if n == "dict" => None,
1127 Some(TypeExpr::Named(n)) if n == "string" => {
1128 Some(TypeExpr::Named("string".into()))
1129 }
1130 _ => None,
1131 }
1132 }
1133 Node::SliceAccess { object, .. } => {
1134 let obj_type = self.infer_type(object, scope);
1136 match &obj_type {
1137 Some(TypeExpr::List(_)) => obj_type,
1138 Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1139 Some(TypeExpr::Named(n)) if n == "string" => {
1140 Some(TypeExpr::Named("string".into()))
1141 }
1142 _ => None,
1143 }
1144 }
1145 Node::MethodCall { object, method, .. }
1146 | Node::OptionalMethodCall { object, method, .. } => {
1147 let obj_type = self.infer_type(object, scope);
1148 let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1149 || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1150 match method.as_str() {
1151 "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1153 Some(TypeExpr::Named("bool".into()))
1154 }
1155 "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1157 "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1159 | "pad_left" | "pad_right" | "repeat" | "join" => {
1160 Some(TypeExpr::Named("string".into()))
1161 }
1162 "split" | "chars" => Some(TypeExpr::Named("list".into())),
1163 "filter" => {
1165 if is_dict {
1166 Some(TypeExpr::Named("dict".into()))
1167 } else {
1168 Some(TypeExpr::Named("list".into()))
1169 }
1170 }
1171 "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1173 "reduce" | "find" | "first" | "last" => None,
1174 "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1176 "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1177 "to_string" => Some(TypeExpr::Named("string".into())),
1179 "to_int" => Some(TypeExpr::Named("int".into())),
1180 "to_float" => Some(TypeExpr::Named("float".into())),
1181 _ => None,
1182 }
1183 }
1184
1185 Node::TryOperator { operand } => {
1187 match self.infer_type(operand, scope) {
1188 Some(TypeExpr::Named(name)) if name == "Result" => None, _ => None,
1190 }
1191 }
1192
1193 _ => None,
1194 }
1195 }
1196
1197 fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1199 if let TypeExpr::Named(name) = expected {
1201 if scope.is_generic_type_param(name) {
1202 return true;
1203 }
1204 }
1205 if let TypeExpr::Named(name) = actual {
1206 if scope.is_generic_type_param(name) {
1207 return true;
1208 }
1209 }
1210 let expected = self.resolve_alias(expected, scope);
1211 let actual = self.resolve_alias(actual, scope);
1212
1213 match (&expected, &actual) {
1214 (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1215 (TypeExpr::Union(members), actual_type) => members
1216 .iter()
1217 .any(|m| self.types_compatible(m, actual_type, scope)),
1218 (expected_type, TypeExpr::Union(members)) => members
1219 .iter()
1220 .all(|m| self.types_compatible(expected_type, m, scope)),
1221 (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1222 (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1223 (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1224 if expected_field.optional {
1225 return true;
1226 }
1227 af.iter().any(|actual_field| {
1228 actual_field.name == expected_field.name
1229 && self.types_compatible(
1230 &expected_field.type_expr,
1231 &actual_field.type_expr,
1232 scope,
1233 )
1234 })
1235 }),
1236 (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1238 let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1239 keys_ok
1240 && af
1241 .iter()
1242 .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1243 }
1244 (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1246 (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1247 self.types_compatible(expected_inner, actual_inner, scope)
1248 }
1249 (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1250 (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1251 (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1252 self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1253 }
1254 (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1255 (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1256 (
1258 TypeExpr::FnType {
1259 params: ep,
1260 return_type: er,
1261 },
1262 TypeExpr::FnType {
1263 params: ap,
1264 return_type: ar,
1265 },
1266 ) => {
1267 ep.len() == ap.len()
1268 && ep
1269 .iter()
1270 .zip(ap.iter())
1271 .all(|(e, a)| self.types_compatible(e, a, scope))
1272 && self.types_compatible(er, ar, scope)
1273 }
1274 (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1276 (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1277 _ => false,
1278 }
1279 }
1280
1281 fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1282 if let TypeExpr::Named(name) = ty {
1283 if let Some(resolved) = scope.resolve_type(name) {
1284 return resolved.clone();
1285 }
1286 }
1287 ty.clone()
1288 }
1289
1290 fn error_at(&mut self, message: String, span: Span) {
1291 self.diagnostics.push(TypeDiagnostic {
1292 message,
1293 severity: DiagnosticSeverity::Error,
1294 span: Some(span),
1295 });
1296 }
1297
1298 fn warning_at(&mut self, message: String, span: Span) {
1299 self.diagnostics.push(TypeDiagnostic {
1300 message,
1301 severity: DiagnosticSeverity::Warning,
1302 span: Some(span),
1303 });
1304 }
1305}
1306
1307impl Default for TypeChecker {
1308 fn default() -> Self {
1309 Self::new()
1310 }
1311}
1312
1313fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1315 match op {
1316 "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1317 "+" => match (left, right) {
1318 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1319 match (l.as_str(), r.as_str()) {
1320 ("int", "int") => Some(TypeExpr::Named("int".into())),
1321 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1322 ("string", _) => Some(TypeExpr::Named("string".into())),
1323 ("list", "list") => Some(TypeExpr::Named("list".into())),
1324 ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1325 _ => Some(TypeExpr::Named("string".into())),
1326 }
1327 }
1328 _ => None,
1329 },
1330 "-" | "*" | "/" | "%" => match (left, right) {
1331 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1332 match (l.as_str(), r.as_str()) {
1333 ("int", "int") => Some(TypeExpr::Named("int".into())),
1334 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1335 _ => None,
1336 }
1337 }
1338 _ => None,
1339 },
1340 "??" => match (left, right) {
1341 (Some(TypeExpr::Union(members)), _) => {
1342 let non_nil: Vec<_> = members
1343 .iter()
1344 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1345 .cloned()
1346 .collect();
1347 if non_nil.len() == 1 {
1348 Some(non_nil[0].clone())
1349 } else if non_nil.is_empty() {
1350 right.clone()
1351 } else {
1352 Some(TypeExpr::Union(non_nil))
1353 }
1354 }
1355 _ => right.clone(),
1356 },
1357 "|>" => None,
1358 _ => None,
1359 }
1360}
1361
1362pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1367 if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1368 let mut details = Vec::new();
1369 for field in ef {
1370 if field.optional {
1371 continue;
1372 }
1373 match af.iter().find(|f| f.name == field.name) {
1374 None => details.push(format!(
1375 "missing field '{}' ({})",
1376 field.name,
1377 format_type(&field.type_expr)
1378 )),
1379 Some(actual_field) => {
1380 let e_str = format_type(&field.type_expr);
1381 let a_str = format_type(&actual_field.type_expr);
1382 if e_str != a_str {
1383 details.push(format!(
1384 "field '{}' has type {}, expected {}",
1385 field.name, a_str, e_str
1386 ));
1387 }
1388 }
1389 }
1390 }
1391 if details.is_empty() {
1392 None
1393 } else {
1394 Some(details.join("; "))
1395 }
1396 } else {
1397 None
1398 }
1399}
1400
1401pub fn format_type(ty: &TypeExpr) -> String {
1402 match ty {
1403 TypeExpr::Named(n) => n.clone(),
1404 TypeExpr::Union(types) => types
1405 .iter()
1406 .map(format_type)
1407 .collect::<Vec<_>>()
1408 .join(" | "),
1409 TypeExpr::Shape(fields) => {
1410 let inner: Vec<String> = fields
1411 .iter()
1412 .map(|f| {
1413 let opt = if f.optional { "?" } else { "" };
1414 format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1415 })
1416 .collect();
1417 format!("{{{}}}", inner.join(", "))
1418 }
1419 TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1420 TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1421 TypeExpr::FnType {
1422 params,
1423 return_type,
1424 } => {
1425 let params_str = params
1426 .iter()
1427 .map(format_type)
1428 .collect::<Vec<_>>()
1429 .join(", ");
1430 format!("fn({}) -> {}", params_str, format_type(return_type))
1431 }
1432 }
1433}
1434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    /// Substring identifying non-exhaustive `match` warnings.
    const NON_EXHAUSTIVE: &str = "Non-exhaustive";
    /// Substring identifying match-arm pattern/subject type mismatch warnings.
    const PATTERN_MISMATCH: &str = "Match pattern type mismatch";

    /// Runs the lexer, parser and type checker over `source` and returns all diagnostics.
    ///
    /// Panics if `source` fails to lex or parse; every test input must be valid syntax.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut scanner = Lexer::new(source);
        let token_stream = scanner.tokenize().unwrap();
        let mut reader = Parser::new(token_stream);
        let ast = reader.parse().unwrap();
        TypeChecker::new().check(&ast)
    }

    /// Messages of the diagnostics reported for `source` that carry `severity`.
    fn messages_of(source: &str, severity: DiagnosticSeverity) -> Vec<String> {
        let mut found = Vec::new();
        for diag in check_source(source) {
            if diag.severity == severity {
                found.push(diag.message);
            }
        }
        found
    }

    /// Error messages reported for `source`.
    fn errors(source: &str) -> Vec<String> {
        messages_of(source, DiagnosticSeverity::Error)
    }

    /// Warning messages reported for `source`.
    fn warnings(source: &str) -> Vec<String> {
        messages_of(source, DiagnosticSeverity::Warning)
    }

    /// Warning messages reported for `source` that contain `needle`.
    fn warnings_containing(source: &str, needle: &str) -> Vec<String> {
        warnings(source)
            .into_iter()
            .filter(|w| w.contains(needle))
            .collect()
    }

    #[test]
    fn test_no_errors_for_untyped_code() {
        assert!(errors("pipeline t(task) { let x = 42\nlog(x) }").is_empty());
    }

    #[test]
    fn test_correct_typed_let() {
        assert!(errors("pipeline t(task) { let x: int = 42 }").is_empty());
    }

    #[test]
    fn test_type_mismatch_let() {
        let found = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
        assert_eq!(found.len(), 1);
        let msg = &found[0];
        assert!(msg.contains("Type mismatch"));
        assert!(msg.contains("int"));
        assert!(msg.contains("string"));
    }

    #[test]
    fn test_correct_typed_fn() {
        let found = errors(
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let found = errors(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        assert_eq!(found.len(), 1);
        let msg = &found[0];
        assert!(msg.contains("Argument 1"));
        assert!(msg.contains("expected int"));
    }

    #[test]
    fn test_return_type_mismatch() {
        let found = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert_eq!(found.len(), 1);
        assert!(found[0].contains("Return type mismatch"));
    }

    #[test]
    fn test_union_type_compatible() {
        assert!(errors(r#"pipeline t(task) { let x: string | nil = nil }"#).is_empty());
    }

    #[test]
    fn test_union_type_mismatch() {
        let found = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert_eq!(found.len(), 1);
        assert!(found[0].contains("Type mismatch"));
    }

    #[test]
    fn test_type_inference_propagation() {
        let found = errors(
            r#"pipeline t(task) {
  fn add(a: int, b: int) -> int { return a + b }
  let result: string = add(1, 2)
}"#,
        );
        assert_eq!(found.len(), 1);
        let msg = &found[0];
        assert!(msg.contains("Type mismatch"));
        assert!(msg.contains("string"));
        assert!(msg.contains("int"));
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let found = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        assert_eq!(found.len(), 1);
        let msg = &found[0];
        assert!(msg.contains("string"));
        assert!(msg.contains("int"));
    }

    #[test]
    fn test_binary_op_type_inference() {
        assert_eq!(errors("pipeline t(task) { let x: string = 1 + 2 }").len(), 1);
    }

    #[test]
    fn test_comparison_returns_bool() {
        assert!(errors("pipeline t(task) { let x: bool = 1 < 2 }").is_empty());
    }

    #[test]
    fn test_int_float_promotion() {
        assert!(errors("pipeline t(task) { let x: float = 42 }").is_empty());
    }

    #[test]
    fn test_untyped_code_no_errors() {
        let found = errors(
            r#"pipeline t(task) {
  fn process(data) {
    let result = data + " processed"
    return result
  }
  log(process("hello"))
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_type_alias() {
        let found = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = "hello"
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_type_alias_mismatch() {
        let found = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = 42
}"#,
        );
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_assignment_type_check() {
        let found = errors(
            r#"pipeline t(task) {
  var x: int = 0
  x = "hello"
}"#,
        );
        assert_eq!(found.len(), 1);
        assert!(found[0].contains("cannot assign string"));
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        let found = errors(
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_covariance_return_type() {
        assert!(errors("pipeline t(task) { fn get() -> float { return 42 } }").is_empty());
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        let found =
            errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_exhaustive_match_no_warning() {
        let hits = warnings_containing(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
    "Blue" -> { log("b") }
  }
}"#,
            NON_EXHAUSTIVE,
        );
        assert!(hits.is_empty());
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let hits = warnings_containing(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
  }
}"#,
            NON_EXHAUSTIVE,
        );
        assert_eq!(hits.len(), 1);
        assert!(hits[0].contains("Blue"));
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let hits = warnings_containing(
            r#"pipeline t(task) {
  enum Status { Active, Inactive, Pending }
  let s = Status.Active
  match s.variant {
    "Active" -> { log("a") }
  }
}"#,
            NON_EXHAUSTIVE,
        );
        assert_eq!(hits.len(), 1);
        let msg = &hits[0];
        assert!(msg.contains("Inactive"));
        assert!(msg.contains("Pending"));
    }

    #[test]
    fn test_enum_construct_type_inference() {
        let found = errors(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c: Color = Color.Red
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_nil_coalescing_strips_nil() {
        let found = errors(
            r#"pipeline t(task) {
  let x: string | nil = nil
  let y: string = x ?? "default"
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_shape_mismatch_detail_missing_field() {
        let found = errors(
            r#"pipeline t(task) {
  let x: {name: string, age: int} = {name: "hello"}
}"#,
        );
        assert_eq!(found.len(), 1);
        assert!(
            found[0].contains("missing field 'age'"),
            "expected detail about missing field, got: {}",
            found[0]
        );
    }

    #[test]
    fn test_shape_mismatch_detail_wrong_type() {
        let found = errors(
            r#"pipeline t(task) {
  let x: {name: string, age: int} = {name: 42, age: 10}
}"#,
        );
        assert_eq!(found.len(), 1);
        assert!(
            found[0].contains("field 'name' has type int, expected string"),
            "expected detail about wrong type, got: {}",
            found[0]
        );
    }

    #[test]
    fn test_match_pattern_string_against_int() {
        let hits = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    "hello" -> { log("bad") }
    42 -> { log("ok") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(hits.len(), 1);
        assert!(hits[0].contains("matching int against string literal"));
    }

    #[test]
    fn test_match_pattern_int_against_string() {
        let hits = warnings_containing(
            r#"pipeline t(task) {
  let x: string = "hello"
  match x {
    42 -> { log("bad") }
    "hello" -> { log("ok") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(hits.len(), 1);
        assert!(hits[0].contains("matching string against int literal"));
    }

    #[test]
    fn test_match_pattern_bool_against_int() {
        let hits = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    true -> { log("bad") }
    42 -> { log("ok") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(hits.len(), 1);
        assert!(hits[0].contains("matching int against bool literal"));
    }

    #[test]
    fn test_match_pattern_float_against_string() {
        let hits = warnings_containing(
            r#"pipeline t(task) {
  let x: string = "hello"
  match x {
    3.14 -> { log("bad") }
    "hello" -> { log("ok") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(hits.len(), 1);
        assert!(hits[0].contains("matching string against float literal"));
    }

    #[test]
    fn test_match_pattern_int_against_float_ok() {
        let hits = warnings_containing(
            r#"pipeline t(task) {
  let x: float = 3.14
  match x {
    42 -> { log("ok") }
    _ -> { log("default") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(hits.is_empty());
    }

    #[test]
    fn test_match_pattern_float_against_int_ok() {
        let hits = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    3.14 -> { log("close") }
    _ -> { log("default") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(hits.is_empty());
    }

    #[test]
    fn test_match_pattern_correct_types_no_warning() {
        let hits = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    1 -> { log("one") }
    2 -> { log("two") }
    _ -> { log("other") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(hits.is_empty());
    }

    #[test]
    fn test_match_pattern_wildcard_no_warning() {
        let hits = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    _ -> { log("catch all") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(hits.is_empty());
    }

    #[test]
    fn test_match_pattern_untyped_no_warning() {
        let hits = warnings_containing(
            r#"pipeline t(task) {
  let x = some_unknown_fn()
  match x {
    "hello" -> { log("string") }
    42 -> { log("int") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(hits.is_empty());
    }
}