1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9 pub message: String,
10 pub severity: DiagnosticSeverity,
11 pub span: Option<Span>,
12}
13
/// How serious a type-checker diagnostic is.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticSeverity {
    /// A definite type mismatch (bindings, assignments, arguments, returns).
    Error,
    /// A suspicious construct (operator use, arity, non-exhaustive match).
    Warning,
}
19
20type InferredType = Option<TypeExpr>;
22
23#[derive(Debug, Clone)]
25struct TypeScope {
26 vars: BTreeMap<String, InferredType>,
28 functions: BTreeMap<String, FnSignature>,
30 type_aliases: BTreeMap<String, TypeExpr>,
32 enums: BTreeMap<String, Vec<String>>,
34 interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36 structs: BTreeMap<String, Vec<(String, InferredType)>>,
38 parent: Option<Box<TypeScope>>,
39}
40
41#[derive(Debug, Clone)]
42struct FnSignature {
43 params: Vec<(String, InferredType)>,
44 return_type: InferredType,
45}
46
47impl TypeScope {
48 fn new() -> Self {
49 Self {
50 vars: BTreeMap::new(),
51 functions: BTreeMap::new(),
52 type_aliases: BTreeMap::new(),
53 enums: BTreeMap::new(),
54 interfaces: BTreeMap::new(),
55 structs: BTreeMap::new(),
56 parent: None,
57 }
58 }
59
60 fn child(&self) -> Self {
61 Self {
62 vars: BTreeMap::new(),
63 functions: BTreeMap::new(),
64 type_aliases: BTreeMap::new(),
65 enums: BTreeMap::new(),
66 interfaces: BTreeMap::new(),
67 structs: BTreeMap::new(),
68 parent: Some(Box::new(self.clone())),
69 }
70 }
71
72 fn get_var(&self, name: &str) -> Option<&InferredType> {
73 self.vars
74 .get(name)
75 .or_else(|| self.parent.as_ref()?.get_var(name))
76 }
77
78 fn get_fn(&self, name: &str) -> Option<&FnSignature> {
79 self.functions
80 .get(name)
81 .or_else(|| self.parent.as_ref()?.get_fn(name))
82 }
83
84 fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
85 self.type_aliases
86 .get(name)
87 .or_else(|| self.parent.as_ref()?.resolve_type(name))
88 }
89
90 fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
91 self.enums
92 .get(name)
93 .or_else(|| self.parent.as_ref()?.get_enum(name))
94 }
95
96 #[allow(dead_code)]
97 fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
98 self.interfaces
99 .get(name)
100 .or_else(|| self.parent.as_ref()?.get_interface(name))
101 }
102
103 fn define_var(&mut self, name: &str, ty: InferredType) {
104 self.vars.insert(name.to_string(), ty);
105 }
106
107 fn define_fn(&mut self, name: &str, sig: FnSignature) {
108 self.functions.insert(name.to_string(), sig);
109 }
110}
111
112fn builtin_return_type(name: &str) -> InferredType {
114 match name {
115 "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
116 | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
117 Some(TypeExpr::Named("nil".into()))
118 }
119 "type_of" | "to_string" | "json_stringify" | "read_file" | "http_get" | "http_post"
120 | "llm_call" | "agent_loop" | "regex_replace" | "path_join" | "temp_dir"
121 | "date_format" | "format" => Some(TypeExpr::Named("string".into())),
122 "to_int" => Some(TypeExpr::Named("int".into())),
123 "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
124 "file_exists" | "json_validate" => Some(TypeExpr::Named("bool".into())),
125 "list_dir" => Some(TypeExpr::Named("list".into())),
126 "stat" | "exec" | "shell" | "date_now" => Some(TypeExpr::Named("dict".into())),
127 "env" | "regex_match" => Some(TypeExpr::Union(vec![
128 TypeExpr::Named("string".into()),
129 TypeExpr::Named("nil".into()),
130 ])),
131 "json_parse" | "json_extract" => None, _ => None,
133 }
134}
135
/// Reports whether `name` is one of the language's builtin functions.
///
/// The comparison is exact and case-sensitive.
fn is_builtin(name: &str) -> bool {
    const BUILTIN_NAMES: &[&str] = &[
        "log",
        "print",
        "println",
        "type_of",
        "to_string",
        "to_int",
        "to_float",
        "json_stringify",
        "json_parse",
        "env",
        "timestamp",
        "sleep",
        "read_file",
        "write_file",
        "exit",
        "regex_match",
        "regex_replace",
        "http_get",
        "http_post",
        "llm_call",
        "agent_loop",
        "await",
        "cancel",
        "file_exists",
        "delete_file",
        "list_dir",
        "mkdir",
        "path_join",
        "copy_file",
        "append_file",
        "temp_dir",
        "stat",
        "exec",
        "shell",
        "date_now",
        "date_format",
        "date_parse",
        "format",
        "json_validate",
        "json_extract",
        "trim",
        "lowercase",
        "uppercase",
        "split",
        "starts_with",
        "ends_with",
        "contains",
        "replace",
        "join",
        "len",
        "substring",
        "dirname",
        "basename",
        "extname",
    ];
    BUILTIN_NAMES.contains(&name)
}
196
197pub struct TypeChecker {
199 diagnostics: Vec<TypeDiagnostic>,
200 scope: TypeScope,
201}
202
203impl TypeChecker {
204 pub fn new() -> Self {
205 Self {
206 diagnostics: Vec::new(),
207 scope: TypeScope::new(),
208 }
209 }
210
211 pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
213 Self::register_declarations_into(&mut self.scope, program);
215
216 for snode in program {
218 if let Node::Pipeline { body, .. } = &snode.node {
219 Self::register_declarations_into(&mut self.scope, body);
220 }
221 }
222
223 for snode in program {
225 match &snode.node {
226 Node::Pipeline { params, body, .. } => {
227 let mut child = self.scope.child();
228 for p in params {
229 child.define_var(p, None);
230 }
231 self.check_block(body, &mut child);
232 }
233 Node::FnDecl {
234 name,
235 params,
236 return_type,
237 body,
238 ..
239 } => {
240 let sig = FnSignature {
241 params: params
242 .iter()
243 .map(|p| (p.name.clone(), p.type_expr.clone()))
244 .collect(),
245 return_type: return_type.clone(),
246 };
247 self.scope.define_fn(name, sig);
248 self.check_fn_body(params, return_type, body);
249 }
250 _ => {
251 let mut scope = self.scope.clone();
252 self.check_node(snode, &mut scope);
253 for (name, ty) in scope.vars {
255 self.scope.vars.entry(name).or_insert(ty);
256 }
257 }
258 }
259 }
260
261 self.diagnostics
262 }
263
264 fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
266 for snode in nodes {
267 match &snode.node {
268 Node::TypeDecl { name, type_expr } => {
269 scope.type_aliases.insert(name.clone(), type_expr.clone());
270 }
271 Node::EnumDecl { name, variants } => {
272 let variant_names: Vec<String> =
273 variants.iter().map(|v| v.name.clone()).collect();
274 scope.enums.insert(name.clone(), variant_names);
275 }
276 Node::InterfaceDecl { name, methods } => {
277 scope.interfaces.insert(name.clone(), methods.clone());
278 }
279 Node::StructDecl { name, fields } => {
280 let field_types: Vec<(String, InferredType)> = fields
281 .iter()
282 .map(|f| (f.name.clone(), f.type_expr.clone()))
283 .collect();
284 scope.structs.insert(name.clone(), field_types);
285 }
286 _ => {}
287 }
288 }
289 }
290
291 fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
292 for stmt in stmts {
293 self.check_node(stmt, scope);
294 }
295 }
296
297 fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
299 match pattern {
300 BindingPattern::Identifier(name) => {
301 scope.define_var(name, None);
302 }
303 BindingPattern::Dict(fields) => {
304 for field in fields {
305 let name = field.alias.as_deref().unwrap_or(&field.key);
306 scope.define_var(name, None);
307 }
308 }
309 BindingPattern::List(elements) => {
310 for elem in elements {
311 scope.define_var(&elem.name, None);
312 }
313 }
314 }
315 }
316
317 fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
318 let span = snode.span;
319 match &snode.node {
320 Node::LetBinding {
321 pattern,
322 type_ann,
323 value,
324 } => {
325 let inferred = self.infer_type(value, scope);
326 if let BindingPattern::Identifier(name) = pattern {
327 if let Some(expected) = type_ann {
328 if let Some(actual) = &inferred {
329 if !self.types_compatible(expected, actual, scope) {
330 let mut msg = format!(
331 "Type mismatch: '{}' declared as {}, but assigned {}",
332 name,
333 format_type(expected),
334 format_type(actual)
335 );
336 if let Some(detail) = shape_mismatch_detail(expected, actual) {
337 msg.push_str(&format!(" ({})", detail));
338 }
339 self.error_at(msg, span);
340 }
341 }
342 }
343 let ty = type_ann.clone().or(inferred);
344 scope.define_var(name, ty);
345 } else {
346 Self::define_pattern_vars(pattern, scope);
347 }
348 }
349
350 Node::VarBinding {
351 pattern,
352 type_ann,
353 value,
354 } => {
355 let inferred = self.infer_type(value, scope);
356 if let BindingPattern::Identifier(name) = pattern {
357 if let Some(expected) = type_ann {
358 if let Some(actual) = &inferred {
359 if !self.types_compatible(expected, actual, scope) {
360 let mut msg = format!(
361 "Type mismatch: '{}' declared as {}, but assigned {}",
362 name,
363 format_type(expected),
364 format_type(actual)
365 );
366 if let Some(detail) = shape_mismatch_detail(expected, actual) {
367 msg.push_str(&format!(" ({})", detail));
368 }
369 self.error_at(msg, span);
370 }
371 }
372 }
373 let ty = type_ann.clone().or(inferred);
374 scope.define_var(name, ty);
375 } else {
376 Self::define_pattern_vars(pattern, scope);
377 }
378 }
379
380 Node::FnDecl {
381 name,
382 params,
383 return_type,
384 body,
385 ..
386 } => {
387 let sig = FnSignature {
388 params: params
389 .iter()
390 .map(|p| (p.name.clone(), p.type_expr.clone()))
391 .collect(),
392 return_type: return_type.clone(),
393 };
394 scope.define_fn(name, sig.clone());
395 scope.define_var(name, None);
396 self.check_fn_body(params, return_type, body);
397 }
398
399 Node::FunctionCall { name, args } => {
400 self.check_call(name, args, scope, span);
401 }
402
403 Node::IfElse {
404 condition,
405 then_body,
406 else_body,
407 } => {
408 self.check_node(condition, scope);
409 let mut then_scope = scope.child();
410 self.check_block(then_body, &mut then_scope);
411 if let Some(else_body) = else_body {
412 let mut else_scope = scope.child();
413 self.check_block(else_body, &mut else_scope);
414 }
415 }
416
417 Node::ForIn {
418 pattern,
419 iterable,
420 body,
421 } => {
422 self.check_node(iterable, scope);
423 let mut loop_scope = scope.child();
424 if let BindingPattern::Identifier(variable) = pattern {
425 let elem_type = match self.infer_type(iterable, scope) {
427 Some(TypeExpr::List(inner)) => Some(*inner),
428 Some(TypeExpr::Named(n)) if n == "string" => {
429 Some(TypeExpr::Named("string".into()))
430 }
431 _ => None,
432 };
433 loop_scope.define_var(variable, elem_type);
434 } else {
435 Self::define_pattern_vars(pattern, &mut loop_scope);
436 }
437 self.check_block(body, &mut loop_scope);
438 }
439
440 Node::WhileLoop { condition, body } => {
441 self.check_node(condition, scope);
442 let mut loop_scope = scope.child();
443 self.check_block(body, &mut loop_scope);
444 }
445
446 Node::TryCatch {
447 body,
448 error_var,
449 catch_body,
450 ..
451 } => {
452 let mut try_scope = scope.child();
453 self.check_block(body, &mut try_scope);
454 let mut catch_scope = scope.child();
455 if let Some(var) = error_var {
456 catch_scope.define_var(var, None);
457 }
458 self.check_block(catch_body, &mut catch_scope);
459 }
460
461 Node::ReturnStmt {
462 value: Some(val), ..
463 } => {
464 self.check_node(val, scope);
465 }
466
467 Node::Assignment {
468 target, value, op, ..
469 } => {
470 self.check_node(value, scope);
471 if let Node::Identifier(name) = &target.node {
472 if let Some(Some(var_type)) = scope.get_var(name) {
473 let value_type = self.infer_type(value, scope);
474 let assigned = if let Some(op) = op {
475 let var_inferred = scope.get_var(name).cloned().flatten();
476 infer_binary_op_type(op, &var_inferred, &value_type)
477 } else {
478 value_type
479 };
480 if let Some(actual) = &assigned {
481 if !self.types_compatible(var_type, actual, scope) {
482 self.error_at(
483 format!(
484 "Type mismatch: cannot assign {} to '{}' (declared as {})",
485 format_type(actual),
486 name,
487 format_type(var_type)
488 ),
489 span,
490 );
491 }
492 }
493 }
494 }
495 }
496
497 Node::TypeDecl { name, type_expr } => {
498 scope.type_aliases.insert(name.clone(), type_expr.clone());
499 }
500
501 Node::EnumDecl { name, variants } => {
502 let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
503 scope.enums.insert(name.clone(), variant_names);
504 }
505
506 Node::StructDecl { name, fields } => {
507 let field_types: Vec<(String, InferredType)> = fields
508 .iter()
509 .map(|f| (f.name.clone(), f.type_expr.clone()))
510 .collect();
511 scope.structs.insert(name.clone(), field_types);
512 }
513
514 Node::InterfaceDecl { name, methods } => {
515 scope.interfaces.insert(name.clone(), methods.clone());
516 }
517
518 Node::MatchExpr { value, arms } => {
519 self.check_node(value, scope);
520 for arm in arms {
521 self.check_node(&arm.pattern, scope);
522 let mut arm_scope = scope.child();
523 self.check_block(&arm.body, &mut arm_scope);
524 }
525 self.check_match_exhaustiveness(value, arms, scope, span);
526 }
527
528 Node::BinaryOp { op, left, right } => {
530 self.check_node(left, scope);
531 self.check_node(right, scope);
532 let lt = self.infer_type(left, scope);
534 let rt = self.infer_type(right, scope);
535 if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (<, &rt) {
536 match op.as_str() {
537 "-" | "*" | "/" | "%" => {
538 let numeric = ["int", "float"];
539 if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
540 self.warning_at(
541 format!(
542 "Operator '{op}' may not be valid for types {} and {}",
543 l, r
544 ),
545 span,
546 );
547 }
548 }
549 "+" => {
550 let valid = ["int", "float", "string", "list", "dict"];
552 if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
553 self.warning_at(
554 format!(
555 "Operator '+' may not be valid for types {} and {}",
556 l, r
557 ),
558 span,
559 );
560 }
561 }
562 _ => {}
563 }
564 }
565 }
566 Node::UnaryOp { operand, .. } => {
567 self.check_node(operand, scope);
568 }
569 Node::MethodCall { object, args, .. }
570 | Node::OptionalMethodCall { object, args, .. } => {
571 self.check_node(object, scope);
572 for arg in args {
573 self.check_node(arg, scope);
574 }
575 }
576 Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
577 self.check_node(object, scope);
578 }
579 Node::SubscriptAccess { object, index } => {
580 self.check_node(object, scope);
581 self.check_node(index, scope);
582 }
583 Node::SliceAccess { object, start, end } => {
584 self.check_node(object, scope);
585 if let Some(s) = start {
586 self.check_node(s, scope);
587 }
588 if let Some(e) = end {
589 self.check_node(e, scope);
590 }
591 }
592
593 _ => {}
595 }
596 }
597
598 fn check_fn_body(
599 &mut self,
600 params: &[TypedParam],
601 return_type: &Option<TypeExpr>,
602 body: &[SNode],
603 ) {
604 let mut fn_scope = self.scope.child();
605 for param in params {
606 fn_scope.define_var(¶m.name, param.type_expr.clone());
607 }
608 self.check_block(body, &mut fn_scope);
609
610 if let Some(ret_type) = return_type {
612 for stmt in body {
613 self.check_return_type(stmt, ret_type, &fn_scope);
614 }
615 }
616 }
617
618 fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
619 let span = snode.span;
620 match &snode.node {
621 Node::ReturnStmt { value: Some(val) } => {
622 let inferred = self.infer_type(val, scope);
623 if let Some(actual) = &inferred {
624 if !self.types_compatible(expected, actual, scope) {
625 self.error_at(
626 format!(
627 "Return type mismatch: expected {}, got {}",
628 format_type(expected),
629 format_type(actual)
630 ),
631 span,
632 );
633 }
634 }
635 }
636 Node::IfElse {
637 then_body,
638 else_body,
639 ..
640 } => {
641 for stmt in then_body {
642 self.check_return_type(stmt, expected, scope);
643 }
644 if let Some(else_body) = else_body {
645 for stmt in else_body {
646 self.check_return_type(stmt, expected, scope);
647 }
648 }
649 }
650 _ => {}
651 }
652 }
653
654 fn check_match_exhaustiveness(
656 &mut self,
657 value: &SNode,
658 arms: &[MatchArm],
659 scope: &TypeScope,
660 span: Span,
661 ) {
662 let enum_name = match &value.node {
664 Node::PropertyAccess { object, property } if property == "variant" => {
665 match self.infer_type(object, scope) {
667 Some(TypeExpr::Named(name)) => {
668 if scope.get_enum(&name).is_some() {
669 Some(name)
670 } else {
671 None
672 }
673 }
674 _ => None,
675 }
676 }
677 _ => {
678 match self.infer_type(value, scope) {
680 Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
681 _ => None,
682 }
683 }
684 };
685
686 let Some(enum_name) = enum_name else {
687 return;
688 };
689 let Some(variants) = scope.get_enum(&enum_name) else {
690 return;
691 };
692
693 let mut covered: Vec<String> = Vec::new();
695 let mut has_wildcard = false;
696
697 for arm in arms {
698 match &arm.pattern.node {
699 Node::StringLiteral(s) => covered.push(s.clone()),
701 Node::Identifier(name) if name == "_" || !variants.contains(name) => {
703 has_wildcard = true;
704 }
705 Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
707 Node::PropertyAccess { property, .. } => covered.push(property.clone()),
709 _ => {
710 has_wildcard = true;
712 }
713 }
714 }
715
716 if has_wildcard {
717 return;
718 }
719
720 let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
721 if !missing.is_empty() {
722 let missing_str = missing
723 .iter()
724 .map(|s| format!("\"{}\"", s))
725 .collect::<Vec<_>>()
726 .join(", ");
727 self.warning_at(
728 format!(
729 "Non-exhaustive match on enum {}: missing variants {}",
730 enum_name, missing_str
731 ),
732 span,
733 );
734 }
735 }
736
737 fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
738 if let Some(sig) = scope.get_fn(name).cloned() {
740 if args.len() != sig.params.len() && !is_builtin(name) {
741 self.warning_at(
742 format!(
743 "Function '{}' expects {} arguments, got {}",
744 name,
745 sig.params.len(),
746 args.len()
747 ),
748 span,
749 );
750 }
751 for (i, (arg, (param_name, param_type))) in
752 args.iter().zip(sig.params.iter()).enumerate()
753 {
754 if let Some(expected) = param_type {
755 let actual = self.infer_type(arg, scope);
756 if let Some(actual) = &actual {
757 if !self.types_compatible(expected, actual, scope) {
758 self.error_at(
759 format!(
760 "Argument {} ('{}'): expected {}, got {}",
761 i + 1,
762 param_name,
763 format_type(expected),
764 format_type(actual)
765 ),
766 arg.span,
767 );
768 }
769 }
770 }
771 }
772 }
773 for arg in args {
775 self.check_node(arg, scope);
776 }
777 }
778
779 fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
781 match &snode.node {
782 Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
783 Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
784 Node::StringLiteral(_) | Node::InterpolatedString(_) => {
785 Some(TypeExpr::Named("string".into()))
786 }
787 Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
788 Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
789 Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
790 Node::DictLiteral(entries) => {
791 let mut fields = Vec::new();
793 let mut all_string_keys = true;
794 for entry in entries {
795 if let Node::StringLiteral(key) = &entry.key.node {
796 let val_type = self
797 .infer_type(&entry.value, scope)
798 .unwrap_or(TypeExpr::Named("nil".into()));
799 fields.push(ShapeField {
800 name: key.clone(),
801 type_expr: val_type,
802 optional: false,
803 });
804 } else {
805 all_string_keys = false;
806 break;
807 }
808 }
809 if all_string_keys && !fields.is_empty() {
810 Some(TypeExpr::Shape(fields))
811 } else {
812 Some(TypeExpr::Named("dict".into()))
813 }
814 }
815 Node::Closure { .. } => Some(TypeExpr::Named("closure".into())),
816
817 Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
818
819 Node::FunctionCall { name, .. } => {
820 if let Some(sig) = scope.get_fn(name) {
822 return sig.return_type.clone();
823 }
824 builtin_return_type(name)
826 }
827
828 Node::BinaryOp { op, left, right } => {
829 let lt = self.infer_type(left, scope);
830 let rt = self.infer_type(right, scope);
831 infer_binary_op_type(op, <, &rt)
832 }
833
834 Node::UnaryOp { op, operand } => {
835 let t = self.infer_type(operand, scope);
836 match op.as_str() {
837 "!" => Some(TypeExpr::Named("bool".into())),
838 "-" => t, _ => None,
840 }
841 }
842
843 Node::Ternary {
844 true_expr,
845 false_expr,
846 ..
847 } => {
848 let tt = self.infer_type(true_expr, scope);
849 let ft = self.infer_type(false_expr, scope);
850 match (&tt, &ft) {
851 (Some(a), Some(b)) if a == b => tt,
852 (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
853 (Some(_), None) => tt,
854 (None, Some(_)) => ft,
855 (None, None) => None,
856 }
857 }
858
859 Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
860
861 Node::PropertyAccess { object, property } => {
862 if let Node::Identifier(name) = &object.node {
864 if scope.get_enum(name).is_some() {
865 return Some(TypeExpr::Named(name.clone()));
866 }
867 }
868 if property == "variant" {
870 let obj_type = self.infer_type(object, scope);
871 if let Some(TypeExpr::Named(name)) = &obj_type {
872 if scope.get_enum(name).is_some() {
873 return Some(TypeExpr::Named("string".into()));
874 }
875 }
876 }
877 let obj_type = self.infer_type(object, scope);
879 if let Some(TypeExpr::Shape(fields)) = &obj_type {
880 if let Some(field) = fields.iter().find(|f| f.name == *property) {
881 return Some(field.type_expr.clone());
882 }
883 }
884 None
885 }
886
887 Node::SubscriptAccess { object, index } => {
888 let obj_type = self.infer_type(object, scope);
889 match &obj_type {
890 Some(TypeExpr::List(inner)) => Some(*inner.clone()),
891 Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
892 Some(TypeExpr::Shape(fields)) => {
893 if let Node::StringLiteral(key) = &index.node {
895 fields
896 .iter()
897 .find(|f| &f.name == key)
898 .map(|f| f.type_expr.clone())
899 } else {
900 None
901 }
902 }
903 Some(TypeExpr::Named(n)) if n == "list" => None,
904 Some(TypeExpr::Named(n)) if n == "dict" => None,
905 Some(TypeExpr::Named(n)) if n == "string" => {
906 Some(TypeExpr::Named("string".into()))
907 }
908 _ => None,
909 }
910 }
911 Node::SliceAccess { object, .. } => {
912 let obj_type = self.infer_type(object, scope);
914 match &obj_type {
915 Some(TypeExpr::List(_)) => obj_type,
916 Some(TypeExpr::Named(n)) if n == "list" => obj_type,
917 Some(TypeExpr::Named(n)) if n == "string" => {
918 Some(TypeExpr::Named("string".into()))
919 }
920 _ => None,
921 }
922 }
923 Node::MethodCall { object, method, .. }
924 | Node::OptionalMethodCall { object, method, .. } => {
925 let obj_type = self.infer_type(object, scope);
926 let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
927 || matches!(&obj_type, Some(TypeExpr::DictType(..)));
928 match method.as_str() {
929 "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
931 Some(TypeExpr::Named("bool".into()))
932 }
933 "count" | "index_of" => Some(TypeExpr::Named("int".into())),
935 "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
937 | "pad_left" | "pad_right" | "repeat" | "join" => {
938 Some(TypeExpr::Named("string".into()))
939 }
940 "split" | "chars" => Some(TypeExpr::Named("list".into())),
941 "filter" => {
943 if is_dict {
944 Some(TypeExpr::Named("dict".into()))
945 } else {
946 Some(TypeExpr::Named("list".into()))
947 }
948 }
949 "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
951 "reduce" | "find" | "first" | "last" => None,
952 "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
954 "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
955 "to_string" => Some(TypeExpr::Named("string".into())),
957 "to_int" => Some(TypeExpr::Named("int".into())),
958 "to_float" => Some(TypeExpr::Named("float".into())),
959 _ => None,
960 }
961 }
962
963 _ => None,
964 }
965 }
966
967 fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
969 let expected = self.resolve_alias(expected, scope);
970 let actual = self.resolve_alias(actual, scope);
971
972 match (&expected, &actual) {
973 (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
974 (TypeExpr::Union(members), actual_type) => members
975 .iter()
976 .any(|m| self.types_compatible(m, actual_type, scope)),
977 (expected_type, TypeExpr::Union(members)) => members
978 .iter()
979 .all(|m| self.types_compatible(expected_type, m, scope)),
980 (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
981 (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
982 (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
983 if expected_field.optional {
984 return true;
985 }
986 af.iter().any(|actual_field| {
987 actual_field.name == expected_field.name
988 && self.types_compatible(
989 &expected_field.type_expr,
990 &actual_field.type_expr,
991 scope,
992 )
993 })
994 }),
995 (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
997 let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
998 keys_ok
999 && af
1000 .iter()
1001 .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1002 }
1003 (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1005 (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1006 self.types_compatible(expected_inner, actual_inner, scope)
1007 }
1008 (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1009 (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1010 (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1011 self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1012 }
1013 (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1014 (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1015 _ => false,
1016 }
1017 }
1018
1019 fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1020 if let TypeExpr::Named(name) = ty {
1021 if let Some(resolved) = scope.resolve_type(name) {
1022 return resolved.clone();
1023 }
1024 }
1025 ty.clone()
1026 }
1027
1028 fn error_at(&mut self, message: String, span: Span) {
1029 self.diagnostics.push(TypeDiagnostic {
1030 message,
1031 severity: DiagnosticSeverity::Error,
1032 span: Some(span),
1033 });
1034 }
1035
1036 fn warning_at(&mut self, message: String, span: Span) {
1037 self.diagnostics.push(TypeDiagnostic {
1038 message,
1039 severity: DiagnosticSeverity::Warning,
1040 span: Some(span),
1041 });
1042 }
1043}
1044
1045impl Default for TypeChecker {
1046 fn default() -> Self {
1047 Self::new()
1048 }
1049}
1050
1051fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1053 match op {
1054 "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1055 "+" => match (left, right) {
1056 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1057 match (l.as_str(), r.as_str()) {
1058 ("int", "int") => Some(TypeExpr::Named("int".into())),
1059 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1060 ("string", _) => Some(TypeExpr::Named("string".into())),
1061 ("list", "list") => Some(TypeExpr::Named("list".into())),
1062 ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1063 _ => Some(TypeExpr::Named("string".into())),
1064 }
1065 }
1066 _ => None,
1067 },
1068 "-" | "*" | "/" | "%" => match (left, right) {
1069 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1070 match (l.as_str(), r.as_str()) {
1071 ("int", "int") => Some(TypeExpr::Named("int".into())),
1072 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1073 _ => None,
1074 }
1075 }
1076 _ => None,
1077 },
1078 "??" => match (left, right) {
1079 (Some(TypeExpr::Union(members)), _) => {
1080 let non_nil: Vec<_> = members
1081 .iter()
1082 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1083 .cloned()
1084 .collect();
1085 if non_nil.len() == 1 {
1086 Some(non_nil[0].clone())
1087 } else if non_nil.is_empty() {
1088 right.clone()
1089 } else {
1090 Some(TypeExpr::Union(non_nil))
1091 }
1092 }
1093 _ => right.clone(),
1094 },
1095 "|>" => None,
1096 _ => None,
1097 }
1098}
1099
1100pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1105 if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1106 let mut details = Vec::new();
1107 for field in ef {
1108 if field.optional {
1109 continue;
1110 }
1111 match af.iter().find(|f| f.name == field.name) {
1112 None => details.push(format!(
1113 "missing field '{}' ({})",
1114 field.name,
1115 format_type(&field.type_expr)
1116 )),
1117 Some(actual_field) => {
1118 let e_str = format_type(&field.type_expr);
1119 let a_str = format_type(&actual_field.type_expr);
1120 if e_str != a_str {
1121 details.push(format!(
1122 "field '{}' has type {}, expected {}",
1123 field.name, a_str, e_str
1124 ));
1125 }
1126 }
1127 }
1128 }
1129 if details.is_empty() {
1130 None
1131 } else {
1132 Some(details.join("; "))
1133 }
1134 } else {
1135 None
1136 }
1137}
1138
1139pub fn format_type(ty: &TypeExpr) -> String {
1140 match ty {
1141 TypeExpr::Named(n) => n.clone(),
1142 TypeExpr::Union(types) => types
1143 .iter()
1144 .map(format_type)
1145 .collect::<Vec<_>>()
1146 .join(" | "),
1147 TypeExpr::Shape(fields) => {
1148 let inner: Vec<String> = fields
1149 .iter()
1150 .map(|f| {
1151 let opt = if f.optional { "?" } else { "" };
1152 format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1153 })
1154 .collect();
1155 format!("{{{}}}", inner.join(", "))
1156 }
1157 TypeExpr::List(inner) => format!("list[{}]", format_type(inner)),
1158 TypeExpr::DictType(k, v) => format!("dict[{}, {}]", format_type(k), format_type(v)),
1159 }
1160}
1161
1162#[cfg(test)]
1163mod tests {
1164 use super::*;
1165 use crate::Parser;
1166 use harn_lexer::Lexer;
1167
1168 fn check_source(source: &str) -> Vec<TypeDiagnostic> {
1169 let mut lexer = Lexer::new(source);
1170 let tokens = lexer.tokenize().unwrap();
1171 let mut parser = Parser::new(tokens);
1172 let program = parser.parse().unwrap();
1173 TypeChecker::new().check(&program)
1174 }
1175
1176 fn errors(source: &str) -> Vec<String> {
1177 check_source(source)
1178 .into_iter()
1179 .filter(|d| d.severity == DiagnosticSeverity::Error)
1180 .map(|d| d.message)
1181 .collect()
1182 }
1183
/// Unannotated bindings never produce type errors.
#[test]
fn test_no_errors_for_untyped_code() {
    let source = "pipeline t(task) { let x = 42\nlog(x) }";
    let reported = errors(source);
    assert!(reported.is_empty());
}
1189
/// An annotation that agrees with the value is accepted.
#[test]
fn test_correct_typed_let() {
    let source = "pipeline t(task) { let x: int = 42 }";
    let reported = errors(source);
    assert!(reported.is_empty());
}
1195
/// Assigning a string to an `int` binding yields one error naming both types.
#[test]
fn test_type_mismatch_let() {
    let reported = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
    assert_eq!(reported.len(), 1);
    let message = &reported[0];
    for needle in ["Type mismatch", "int", "string"] {
        assert!(message.contains(needle));
    }
}
1204
/// A correctly typed function and call produce no errors.
#[test]
fn test_correct_typed_fn() {
    let source =
        "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }";
    let reported = errors(source);
    assert!(reported.is_empty());
}
1212
/// Passing a string where an `int` parameter is declared is reported against
/// the first argument.
#[test]
fn test_fn_arg_type_mismatch() {
    let source = r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#;
    let reported = errors(source);
    assert_eq!(reported.len(), 1);
    let message = &reported[0];
    assert!(message.contains("Argument 1"));
    assert!(message.contains("expected int"));
}
1223
/// Returning a string from a function declared `-> int` is an error.
#[test]
fn test_return_type_mismatch() {
    let source = r#"pipeline t(task) { fn get() -> int { return "hello" } }"#;
    let reported = errors(source);
    assert_eq!(reported.len(), 1);
    assert!(reported[0].contains("Return type mismatch"));
}
1230
/// `nil` is accepted by a `string | nil` annotation.
#[test]
fn test_union_type_compatible() {
    let source = r#"pipeline t(task) { let x: string | nil = nil }"#;
    let reported = errors(source);
    assert!(reported.is_empty());
}
1236
/// An `int` is rejected by a `string | nil` annotation.
#[test]
fn test_union_type_mismatch() {
    let source = r#"pipeline t(task) { let x: string | nil = 42 }"#;
    let reported = errors(source);
    assert_eq!(reported.len(), 1);
    assert!(reported[0].contains("Type mismatch"));
}
1243
1244 #[test]
1245 fn test_type_inference_propagation() {
1246 let errs = errors(
1247 r#"pipeline t(task) {
1248 fn add(a: int, b: int) -> int { return a + b }
1249 let result: string = add(1, 2)
1250}"#,
1251 );
1252 assert_eq!(errs.len(), 1);
1253 assert!(errs[0].contains("Type mismatch"));
1254 assert!(errs[0].contains("string"));
1255 assert!(errs[0].contains("int"));
1256 }
1257
1258 #[test]
1259 fn test_builtin_return_type_inference() {
1260 let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
1261 assert_eq!(errs.len(), 1);
1262 assert!(errs[0].contains("string"));
1263 assert!(errs[0].contains("int"));
1264 }
1265
1266 #[test]
1267 fn test_binary_op_type_inference() {
1268 let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
1269 assert_eq!(errs.len(), 1);
1270 }
1271
1272 #[test]
1273 fn test_comparison_returns_bool() {
1274 let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
1275 assert!(errs.is_empty());
1276 }
1277
1278 #[test]
1279 fn test_int_float_promotion() {
1280 let errs = errors("pipeline t(task) { let x: float = 42 }");
1281 assert!(errs.is_empty());
1282 }
1283
1284 #[test]
1285 fn test_untyped_code_no_errors() {
1286 let errs = errors(
1287 r#"pipeline t(task) {
1288 fn process(data) {
1289 let result = data + " processed"
1290 return result
1291 }
1292 log(process("hello"))
1293}"#,
1294 );
1295 assert!(errs.is_empty());
1296 }
1297
1298 #[test]
1299 fn test_type_alias() {
1300 let errs = errors(
1301 r#"pipeline t(task) {
1302 type Name = string
1303 let x: Name = "hello"
1304}"#,
1305 );
1306 assert!(errs.is_empty());
1307 }
1308
1309 #[test]
1310 fn test_type_alias_mismatch() {
1311 let errs = errors(
1312 r#"pipeline t(task) {
1313 type Name = string
1314 let x: Name = 42
1315}"#,
1316 );
1317 assert_eq!(errs.len(), 1);
1318 }
1319
1320 #[test]
1321 fn test_assignment_type_check() {
1322 let errs = errors(
1323 r#"pipeline t(task) {
1324 var x: int = 0
1325 x = "hello"
1326}"#,
1327 );
1328 assert_eq!(errs.len(), 1);
1329 assert!(errs[0].contains("cannot assign string"));
1330 }
1331
1332 #[test]
1333 fn test_covariance_int_to_float_in_fn() {
1334 let errs = errors(
1335 "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
1336 );
1337 assert!(errs.is_empty());
1338 }
1339
1340 #[test]
1341 fn test_covariance_return_type() {
1342 let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
1343 assert!(errs.is_empty());
1344 }
1345
1346 #[test]
1347 fn test_no_contravariance_float_to_int() {
1348 let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
1349 assert_eq!(errs.len(), 1);
1350 }
1351
1352 fn warnings(source: &str) -> Vec<String> {
1355 check_source(source)
1356 .into_iter()
1357 .filter(|d| d.severity == DiagnosticSeverity::Warning)
1358 .map(|d| d.message)
1359 .collect()
1360 }
1361
1362 #[test]
1363 fn test_exhaustive_match_no_warning() {
1364 let warns = warnings(
1365 r#"pipeline t(task) {
1366 enum Color { Red, Green, Blue }
1367 let c = Color.Red
1368 match c.variant {
1369 "Red" -> { log("r") }
1370 "Green" -> { log("g") }
1371 "Blue" -> { log("b") }
1372 }
1373}"#,
1374 );
1375 let exhaustive_warns: Vec<_> = warns
1376 .iter()
1377 .filter(|w| w.contains("Non-exhaustive"))
1378 .collect();
1379 assert!(exhaustive_warns.is_empty());
1380 }
1381
1382 #[test]
1383 fn test_non_exhaustive_match_warning() {
1384 let warns = warnings(
1385 r#"pipeline t(task) {
1386 enum Color { Red, Green, Blue }
1387 let c = Color.Red
1388 match c.variant {
1389 "Red" -> { log("r") }
1390 "Green" -> { log("g") }
1391 }
1392}"#,
1393 );
1394 let exhaustive_warns: Vec<_> = warns
1395 .iter()
1396 .filter(|w| w.contains("Non-exhaustive"))
1397 .collect();
1398 assert_eq!(exhaustive_warns.len(), 1);
1399 assert!(exhaustive_warns[0].contains("Blue"));
1400 }
1401
1402 #[test]
1403 fn test_non_exhaustive_multiple_missing() {
1404 let warns = warnings(
1405 r#"pipeline t(task) {
1406 enum Status { Active, Inactive, Pending }
1407 let s = Status.Active
1408 match s.variant {
1409 "Active" -> { log("a") }
1410 }
1411}"#,
1412 );
1413 let exhaustive_warns: Vec<_> = warns
1414 .iter()
1415 .filter(|w| w.contains("Non-exhaustive"))
1416 .collect();
1417 assert_eq!(exhaustive_warns.len(), 1);
1418 assert!(exhaustive_warns[0].contains("Inactive"));
1419 assert!(exhaustive_warns[0].contains("Pending"));
1420 }
1421
1422 #[test]
1423 fn test_enum_construct_type_inference() {
1424 let errs = errors(
1425 r#"pipeline t(task) {
1426 enum Color { Red, Green, Blue }
1427 let c: Color = Color.Red
1428}"#,
1429 );
1430 assert!(errs.is_empty());
1431 }
1432
1433 #[test]
1436 fn test_nil_coalescing_strips_nil() {
1437 let errs = errors(
1439 r#"pipeline t(task) {
1440 let x: string | nil = nil
1441 let y: string = x ?? "default"
1442}"#,
1443 );
1444 assert!(errs.is_empty());
1445 }
1446
1447 #[test]
1448 fn test_shape_mismatch_detail_missing_field() {
1449 let errs = errors(
1450 r#"pipeline t(task) {
1451 let x: {name: string, age: int} = {name: "hello"}
1452}"#,
1453 );
1454 assert_eq!(errs.len(), 1);
1455 assert!(
1456 errs[0].contains("missing field 'age'"),
1457 "expected detail about missing field, got: {}",
1458 errs[0]
1459 );
1460 }
1461
1462 #[test]
1463 fn test_shape_mismatch_detail_wrong_type() {
1464 let errs = errors(
1465 r#"pipeline t(task) {
1466 let x: {name: string, age: int} = {name: 42, age: 10}
1467}"#,
1468 );
1469 assert_eq!(errs.len(), 1);
1470 assert!(
1471 errs[0].contains("field 'name' has type int, expected string"),
1472 "expected detail about wrong type, got: {}",
1473 errs[0]
1474 );
1475 }
1476}