1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9 pub message: String,
10 pub severity: DiagnosticSeverity,
11 pub span: Option<Span>,
12}
13
/// How serious a [`TypeDiagnostic`] is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// A definite type error.
    Error,
    /// A suspicious construct that may still be valid at runtime.
    Warning,
}
19
20type InferredType = Option<TypeExpr>;
22
23#[derive(Debug, Clone)]
25struct TypeScope {
26 vars: BTreeMap<String, InferredType>,
28 functions: BTreeMap<String, FnSignature>,
30 type_aliases: BTreeMap<String, TypeExpr>,
32 enums: BTreeMap<String, Vec<String>>,
34 interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36 structs: BTreeMap<String, Vec<(String, InferredType)>>,
38 parent: Option<Box<TypeScope>>,
39}
40
41#[derive(Debug, Clone)]
42struct FnSignature {
43 params: Vec<(String, InferredType)>,
44 return_type: InferredType,
45}
46
47impl TypeScope {
48 fn new() -> Self {
49 Self {
50 vars: BTreeMap::new(),
51 functions: BTreeMap::new(),
52 type_aliases: BTreeMap::new(),
53 enums: BTreeMap::new(),
54 interfaces: BTreeMap::new(),
55 structs: BTreeMap::new(),
56 parent: None,
57 }
58 }
59
60 fn child(&self) -> Self {
61 Self {
62 vars: BTreeMap::new(),
63 functions: BTreeMap::new(),
64 type_aliases: BTreeMap::new(),
65 enums: BTreeMap::new(),
66 interfaces: BTreeMap::new(),
67 structs: BTreeMap::new(),
68 parent: Some(Box::new(self.clone())),
69 }
70 }
71
72 fn get_var(&self, name: &str) -> Option<&InferredType> {
73 self.vars
74 .get(name)
75 .or_else(|| self.parent.as_ref()?.get_var(name))
76 }
77
78 fn get_fn(&self, name: &str) -> Option<&FnSignature> {
79 self.functions
80 .get(name)
81 .or_else(|| self.parent.as_ref()?.get_fn(name))
82 }
83
84 fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
85 self.type_aliases
86 .get(name)
87 .or_else(|| self.parent.as_ref()?.resolve_type(name))
88 }
89
90 fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
91 self.enums
92 .get(name)
93 .or_else(|| self.parent.as_ref()?.get_enum(name))
94 }
95
96 #[allow(dead_code)]
97 fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
98 self.interfaces
99 .get(name)
100 .or_else(|| self.parent.as_ref()?.get_interface(name))
101 }
102
103 fn define_var(&mut self, name: &str, ty: InferredType) {
104 self.vars.insert(name.to_string(), ty);
105 }
106
107 fn define_fn(&mut self, name: &str, sig: FnSignature) {
108 self.functions.insert(name.to_string(), sig);
109 }
110}
111
112fn builtin_return_type(name: &str) -> InferredType {
114 match name {
115 "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
116 | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
117 Some(TypeExpr::Named("nil".into()))
118 }
119 "type_of" | "to_string" | "json_stringify" | "read_file" | "http_get" | "http_post"
120 | "llm_call" | "agent_loop" | "regex_replace" | "path_join" | "temp_dir"
121 | "date_format" | "format" => Some(TypeExpr::Named("string".into())),
122 "to_int" => Some(TypeExpr::Named("int".into())),
123 "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
124 "file_exists" | "json_validate" => Some(TypeExpr::Named("bool".into())),
125 "list_dir" => Some(TypeExpr::Named("list".into())),
126 "stat" | "exec" | "shell" | "date_now" => Some(TypeExpr::Named("dict".into())),
127 "env" | "regex_match" => Some(TypeExpr::Union(vec![
128 TypeExpr::Named("string".into()),
129 TypeExpr::Named("nil".into()),
130 ])),
131 "json_parse" | "json_extract" => None, _ => None,
133 }
134}
135
/// Whether `name` refers to a function provided by the runtime rather than
/// declared by the program.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        "log", "print", "println", "type_of", "to_string", "to_int", "to_float",
        "json_stringify", "json_parse", "env", "timestamp", "sleep", "read_file",
        "write_file", "exit", "regex_match", "regex_replace", "http_get", "http_post",
        "llm_call", "agent_loop", "await", "cancel", "file_exists", "delete_file",
        "list_dir", "mkdir", "path_join", "copy_file", "append_file", "temp_dir", "stat",
        "exec", "shell", "date_now", "date_format", "date_parse", "format",
        "json_validate", "json_extract", "trim", "lowercase", "uppercase", "split",
        "starts_with", "ends_with", "contains", "replace", "join", "len", "substring",
        "dirname", "basename", "extname",
    ];
    BUILTINS.contains(&name)
}
196
197pub struct TypeChecker {
199 diagnostics: Vec<TypeDiagnostic>,
200 scope: TypeScope,
201}
202
203impl TypeChecker {
204 pub fn new() -> Self {
205 Self {
206 diagnostics: Vec::new(),
207 scope: TypeScope::new(),
208 }
209 }
210
211 pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
213 Self::register_declarations_into(&mut self.scope, program);
215
216 for snode in program {
218 if let Node::Pipeline { body, .. } = &snode.node {
219 Self::register_declarations_into(&mut self.scope, body);
220 }
221 }
222
223 for snode in program {
225 match &snode.node {
226 Node::Pipeline { params, body, .. } => {
227 let mut child = self.scope.child();
228 for p in params {
229 child.define_var(p, None);
230 }
231 self.check_block(body, &mut child);
232 }
233 Node::FnDecl {
234 name,
235 params,
236 return_type,
237 body,
238 ..
239 } => {
240 let sig = FnSignature {
241 params: params
242 .iter()
243 .map(|p| (p.name.clone(), p.type_expr.clone()))
244 .collect(),
245 return_type: return_type.clone(),
246 };
247 self.scope.define_fn(name, sig);
248 self.check_fn_body(params, return_type, body);
249 }
250 _ => {
251 let mut scope = self.scope.clone();
252 self.check_node(snode, &mut scope);
253 for (name, ty) in scope.vars {
255 self.scope.vars.entry(name).or_insert(ty);
256 }
257 }
258 }
259 }
260
261 self.diagnostics
262 }
263
264 fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
266 for snode in nodes {
267 match &snode.node {
268 Node::TypeDecl { name, type_expr } => {
269 scope.type_aliases.insert(name.clone(), type_expr.clone());
270 }
271 Node::EnumDecl { name, variants } => {
272 let variant_names: Vec<String> =
273 variants.iter().map(|v| v.name.clone()).collect();
274 scope.enums.insert(name.clone(), variant_names);
275 }
276 Node::InterfaceDecl { name, methods } => {
277 scope.interfaces.insert(name.clone(), methods.clone());
278 }
279 Node::StructDecl { name, fields } => {
280 let field_types: Vec<(String, InferredType)> = fields
281 .iter()
282 .map(|f| (f.name.clone(), f.type_expr.clone()))
283 .collect();
284 scope.structs.insert(name.clone(), field_types);
285 }
286 _ => {}
287 }
288 }
289 }
290
291 fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
292 for stmt in stmts {
293 self.check_node(stmt, scope);
294 }
295 }
296
297 fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
299 match pattern {
300 BindingPattern::Identifier(name) => {
301 scope.define_var(name, None);
302 }
303 BindingPattern::Dict(fields) => {
304 for field in fields {
305 let name = field.alias.as_deref().unwrap_or(&field.key);
306 scope.define_var(name, None);
307 }
308 }
309 BindingPattern::List(elements) => {
310 for elem in elements {
311 scope.define_var(&elem.name, None);
312 }
313 }
314 }
315 }
316
317 fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
318 let span = snode.span;
319 match &snode.node {
320 Node::LetBinding {
321 pattern,
322 type_ann,
323 value,
324 } => {
325 let inferred = self.infer_type(value, scope);
326 if let BindingPattern::Identifier(name) = pattern {
327 if let Some(expected) = type_ann {
328 if let Some(actual) = &inferred {
329 if !self.types_compatible(expected, actual, scope) {
330 self.error_at(
331 format!(
332 "Type mismatch: '{}' declared as {}, but assigned {}",
333 name,
334 format_type(expected),
335 format_type(actual)
336 ),
337 span,
338 );
339 }
340 }
341 }
342 let ty = type_ann.clone().or(inferred);
343 scope.define_var(name, ty);
344 } else {
345 Self::define_pattern_vars(pattern, scope);
346 }
347 }
348
349 Node::VarBinding {
350 pattern,
351 type_ann,
352 value,
353 } => {
354 let inferred = self.infer_type(value, scope);
355 if let BindingPattern::Identifier(name) = pattern {
356 if let Some(expected) = type_ann {
357 if let Some(actual) = &inferred {
358 if !self.types_compatible(expected, actual, scope) {
359 self.error_at(
360 format!(
361 "Type mismatch: '{}' declared as {}, but assigned {}",
362 name,
363 format_type(expected),
364 format_type(actual)
365 ),
366 span,
367 );
368 }
369 }
370 }
371 let ty = type_ann.clone().or(inferred);
372 scope.define_var(name, ty);
373 } else {
374 Self::define_pattern_vars(pattern, scope);
375 }
376 }
377
378 Node::FnDecl {
379 name,
380 params,
381 return_type,
382 body,
383 ..
384 } => {
385 let sig = FnSignature {
386 params: params
387 .iter()
388 .map(|p| (p.name.clone(), p.type_expr.clone()))
389 .collect(),
390 return_type: return_type.clone(),
391 };
392 scope.define_fn(name, sig.clone());
393 scope.define_var(name, None);
394 self.check_fn_body(params, return_type, body);
395 }
396
397 Node::FunctionCall { name, args } => {
398 self.check_call(name, args, scope, span);
399 }
400
401 Node::IfElse {
402 condition,
403 then_body,
404 else_body,
405 } => {
406 self.check_node(condition, scope);
407 let mut then_scope = scope.child();
408 self.check_block(then_body, &mut then_scope);
409 if let Some(else_body) = else_body {
410 let mut else_scope = scope.child();
411 self.check_block(else_body, &mut else_scope);
412 }
413 }
414
415 Node::ForIn {
416 pattern,
417 iterable,
418 body,
419 } => {
420 self.check_node(iterable, scope);
421 let mut loop_scope = scope.child();
422 if let BindingPattern::Identifier(variable) = pattern {
423 let elem_type = match self.infer_type(iterable, scope) {
425 Some(TypeExpr::List(inner)) => Some(*inner),
426 Some(TypeExpr::Named(n)) if n == "string" => {
427 Some(TypeExpr::Named("string".into()))
428 }
429 _ => None,
430 };
431 loop_scope.define_var(variable, elem_type);
432 } else {
433 Self::define_pattern_vars(pattern, &mut loop_scope);
434 }
435 self.check_block(body, &mut loop_scope);
436 }
437
438 Node::WhileLoop { condition, body } => {
439 self.check_node(condition, scope);
440 let mut loop_scope = scope.child();
441 self.check_block(body, &mut loop_scope);
442 }
443
444 Node::TryCatch {
445 body,
446 error_var,
447 catch_body,
448 ..
449 } => {
450 let mut try_scope = scope.child();
451 self.check_block(body, &mut try_scope);
452 let mut catch_scope = scope.child();
453 if let Some(var) = error_var {
454 catch_scope.define_var(var, None);
455 }
456 self.check_block(catch_body, &mut catch_scope);
457 }
458
459 Node::ReturnStmt {
460 value: Some(val), ..
461 } => {
462 self.check_node(val, scope);
463 }
464
465 Node::Assignment {
466 target, value, op, ..
467 } => {
468 self.check_node(value, scope);
469 if let Node::Identifier(name) = &target.node {
470 if let Some(Some(var_type)) = scope.get_var(name) {
471 let value_type = self.infer_type(value, scope);
472 let assigned = if let Some(op) = op {
473 let var_inferred = scope.get_var(name).cloned().flatten();
474 infer_binary_op_type(op, &var_inferred, &value_type)
475 } else {
476 value_type
477 };
478 if let Some(actual) = &assigned {
479 if !self.types_compatible(var_type, actual, scope) {
480 self.error_at(
481 format!(
482 "Type mismatch: cannot assign {} to '{}' (declared as {})",
483 format_type(actual),
484 name,
485 format_type(var_type)
486 ),
487 span,
488 );
489 }
490 }
491 }
492 }
493 }
494
495 Node::TypeDecl { name, type_expr } => {
496 scope.type_aliases.insert(name.clone(), type_expr.clone());
497 }
498
499 Node::EnumDecl { name, variants } => {
500 let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
501 scope.enums.insert(name.clone(), variant_names);
502 }
503
504 Node::StructDecl { name, fields } => {
505 let field_types: Vec<(String, InferredType)> = fields
506 .iter()
507 .map(|f| (f.name.clone(), f.type_expr.clone()))
508 .collect();
509 scope.structs.insert(name.clone(), field_types);
510 }
511
512 Node::InterfaceDecl { name, methods } => {
513 scope.interfaces.insert(name.clone(), methods.clone());
514 }
515
516 Node::MatchExpr { value, arms } => {
517 self.check_node(value, scope);
518 for arm in arms {
519 self.check_node(&arm.pattern, scope);
520 let mut arm_scope = scope.child();
521 self.check_block(&arm.body, &mut arm_scope);
522 }
523 self.check_match_exhaustiveness(value, arms, scope, span);
524 }
525
526 Node::BinaryOp { op, left, right } => {
528 self.check_node(left, scope);
529 self.check_node(right, scope);
530 let lt = self.infer_type(left, scope);
532 let rt = self.infer_type(right, scope);
533 if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (<, &rt) {
534 match op.as_str() {
535 "-" | "*" | "/" | "%" => {
536 let numeric = ["int", "float"];
537 if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
538 self.warning_at(
539 format!(
540 "Operator '{op}' may not be valid for types {} and {}",
541 l, r
542 ),
543 span,
544 );
545 }
546 }
547 "+" => {
548 let valid = ["int", "float", "string", "list", "dict"];
550 if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
551 self.warning_at(
552 format!(
553 "Operator '+' may not be valid for types {} and {}",
554 l, r
555 ),
556 span,
557 );
558 }
559 }
560 _ => {}
561 }
562 }
563 }
564 Node::UnaryOp { operand, .. } => {
565 self.check_node(operand, scope);
566 }
567 Node::MethodCall { object, args, .. }
568 | Node::OptionalMethodCall { object, args, .. } => {
569 self.check_node(object, scope);
570 for arg in args {
571 self.check_node(arg, scope);
572 }
573 }
574 Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
575 self.check_node(object, scope);
576 }
577 Node::SubscriptAccess { object, index } => {
578 self.check_node(object, scope);
579 self.check_node(index, scope);
580 }
581 Node::SliceAccess { object, start, end } => {
582 self.check_node(object, scope);
583 if let Some(s) = start {
584 self.check_node(s, scope);
585 }
586 if let Some(e) = end {
587 self.check_node(e, scope);
588 }
589 }
590
591 _ => {}
593 }
594 }
595
596 fn check_fn_body(
597 &mut self,
598 params: &[TypedParam],
599 return_type: &Option<TypeExpr>,
600 body: &[SNode],
601 ) {
602 let mut fn_scope = self.scope.child();
603 for param in params {
604 fn_scope.define_var(¶m.name, param.type_expr.clone());
605 }
606 self.check_block(body, &mut fn_scope);
607
608 if let Some(ret_type) = return_type {
610 for stmt in body {
611 self.check_return_type(stmt, ret_type, &fn_scope);
612 }
613 }
614 }
615
616 fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
617 let span = snode.span;
618 match &snode.node {
619 Node::ReturnStmt { value: Some(val) } => {
620 let inferred = self.infer_type(val, scope);
621 if let Some(actual) = &inferred {
622 if !self.types_compatible(expected, actual, scope) {
623 self.error_at(
624 format!(
625 "Return type mismatch: expected {}, got {}",
626 format_type(expected),
627 format_type(actual)
628 ),
629 span,
630 );
631 }
632 }
633 }
634 Node::IfElse {
635 then_body,
636 else_body,
637 ..
638 } => {
639 for stmt in then_body {
640 self.check_return_type(stmt, expected, scope);
641 }
642 if let Some(else_body) = else_body {
643 for stmt in else_body {
644 self.check_return_type(stmt, expected, scope);
645 }
646 }
647 }
648 _ => {}
649 }
650 }
651
652 fn check_match_exhaustiveness(
654 &mut self,
655 value: &SNode,
656 arms: &[MatchArm],
657 scope: &TypeScope,
658 span: Span,
659 ) {
660 let enum_name = match &value.node {
662 Node::PropertyAccess { object, property } if property == "variant" => {
663 match self.infer_type(object, scope) {
665 Some(TypeExpr::Named(name)) => {
666 if scope.get_enum(&name).is_some() {
667 Some(name)
668 } else {
669 None
670 }
671 }
672 _ => None,
673 }
674 }
675 _ => {
676 match self.infer_type(value, scope) {
678 Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
679 _ => None,
680 }
681 }
682 };
683
684 let Some(enum_name) = enum_name else {
685 return;
686 };
687 let Some(variants) = scope.get_enum(&enum_name) else {
688 return;
689 };
690
691 let mut covered: Vec<String> = Vec::new();
693 let mut has_wildcard = false;
694
695 for arm in arms {
696 match &arm.pattern.node {
697 Node::StringLiteral(s) => covered.push(s.clone()),
699 Node::Identifier(name) if name == "_" || !variants.contains(name) => {
701 has_wildcard = true;
702 }
703 Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
705 Node::PropertyAccess { property, .. } => covered.push(property.clone()),
707 _ => {
708 has_wildcard = true;
710 }
711 }
712 }
713
714 if has_wildcard {
715 return;
716 }
717
718 let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
719 if !missing.is_empty() {
720 let missing_str = missing
721 .iter()
722 .map(|s| format!("\"{}\"", s))
723 .collect::<Vec<_>>()
724 .join(", ");
725 self.warning_at(
726 format!(
727 "Non-exhaustive match on enum {}: missing variants {}",
728 enum_name, missing_str
729 ),
730 span,
731 );
732 }
733 }
734
735 fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
736 if let Some(sig) = scope.get_fn(name).cloned() {
738 if args.len() != sig.params.len() && !is_builtin(name) {
739 self.warning_at(
740 format!(
741 "Function '{}' expects {} arguments, got {}",
742 name,
743 sig.params.len(),
744 args.len()
745 ),
746 span,
747 );
748 }
749 for (i, (arg, (param_name, param_type))) in
750 args.iter().zip(sig.params.iter()).enumerate()
751 {
752 if let Some(expected) = param_type {
753 let actual = self.infer_type(arg, scope);
754 if let Some(actual) = &actual {
755 if !self.types_compatible(expected, actual, scope) {
756 self.error_at(
757 format!(
758 "Argument {} ('{}'): expected {}, got {}",
759 i + 1,
760 param_name,
761 format_type(expected),
762 format_type(actual)
763 ),
764 arg.span,
765 );
766 }
767 }
768 }
769 }
770 }
771 for arg in args {
773 self.check_node(arg, scope);
774 }
775 }
776
777 fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
779 match &snode.node {
780 Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
781 Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
782 Node::StringLiteral(_) | Node::InterpolatedString(_) => {
783 Some(TypeExpr::Named("string".into()))
784 }
785 Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
786 Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
787 Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
788 Node::DictLiteral(entries) => {
789 let mut fields = Vec::new();
791 let mut all_string_keys = true;
792 for entry in entries {
793 if let Node::StringLiteral(key) = &entry.key.node {
794 let val_type = self
795 .infer_type(&entry.value, scope)
796 .unwrap_or(TypeExpr::Named("nil".into()));
797 fields.push(ShapeField {
798 name: key.clone(),
799 type_expr: val_type,
800 optional: false,
801 });
802 } else {
803 all_string_keys = false;
804 break;
805 }
806 }
807 if all_string_keys && !fields.is_empty() {
808 Some(TypeExpr::Shape(fields))
809 } else {
810 Some(TypeExpr::Named("dict".into()))
811 }
812 }
813 Node::Closure { .. } => Some(TypeExpr::Named("closure".into())),
814
815 Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
816
817 Node::FunctionCall { name, .. } => {
818 if let Some(sig) = scope.get_fn(name) {
820 return sig.return_type.clone();
821 }
822 builtin_return_type(name)
824 }
825
826 Node::BinaryOp { op, left, right } => {
827 let lt = self.infer_type(left, scope);
828 let rt = self.infer_type(right, scope);
829 infer_binary_op_type(op, <, &rt)
830 }
831
832 Node::UnaryOp { op, operand } => {
833 let t = self.infer_type(operand, scope);
834 match op.as_str() {
835 "!" => Some(TypeExpr::Named("bool".into())),
836 "-" => t, _ => None,
838 }
839 }
840
841 Node::Ternary {
842 true_expr,
843 false_expr,
844 ..
845 } => {
846 let tt = self.infer_type(true_expr, scope);
847 let ft = self.infer_type(false_expr, scope);
848 match (&tt, &ft) {
849 (Some(a), Some(b)) if a == b => tt,
850 (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
851 (Some(_), None) => tt,
852 (None, Some(_)) => ft,
853 (None, None) => None,
854 }
855 }
856
857 Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
858
859 Node::PropertyAccess { object, property } => {
860 if let Node::Identifier(name) = &object.node {
862 if scope.get_enum(name).is_some() {
863 return Some(TypeExpr::Named(name.clone()));
864 }
865 }
866 if property == "variant" {
868 let obj_type = self.infer_type(object, scope);
869 if let Some(TypeExpr::Named(name)) = &obj_type {
870 if scope.get_enum(name).is_some() {
871 return Some(TypeExpr::Named("string".into()));
872 }
873 }
874 }
875 let obj_type = self.infer_type(object, scope);
877 if let Some(TypeExpr::Shape(fields)) = &obj_type {
878 if let Some(field) = fields.iter().find(|f| f.name == *property) {
879 return Some(field.type_expr.clone());
880 }
881 }
882 None
883 }
884
885 Node::SubscriptAccess { object, index } => {
886 let obj_type = self.infer_type(object, scope);
887 match &obj_type {
888 Some(TypeExpr::List(inner)) => Some(*inner.clone()),
889 Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
890 Some(TypeExpr::Shape(fields)) => {
891 if let Node::StringLiteral(key) = &index.node {
893 fields
894 .iter()
895 .find(|f| &f.name == key)
896 .map(|f| f.type_expr.clone())
897 } else {
898 None
899 }
900 }
901 Some(TypeExpr::Named(n)) if n == "list" => None,
902 Some(TypeExpr::Named(n)) if n == "dict" => None,
903 Some(TypeExpr::Named(n)) if n == "string" => {
904 Some(TypeExpr::Named("string".into()))
905 }
906 _ => None,
907 }
908 }
909 Node::SliceAccess { object, .. } => {
910 let obj_type = self.infer_type(object, scope);
912 match &obj_type {
913 Some(TypeExpr::List(_)) => obj_type,
914 Some(TypeExpr::Named(n)) if n == "list" => obj_type,
915 Some(TypeExpr::Named(n)) if n == "string" => {
916 Some(TypeExpr::Named("string".into()))
917 }
918 _ => None,
919 }
920 }
921 Node::MethodCall { object, method, .. }
922 | Node::OptionalMethodCall { object, method, .. } => {
923 let obj_type = self.infer_type(object, scope);
924 let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
925 || matches!(&obj_type, Some(TypeExpr::DictType(..)));
926 match method.as_str() {
927 "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
929 Some(TypeExpr::Named("bool".into()))
930 }
931 "count" | "index_of" => Some(TypeExpr::Named("int".into())),
933 "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
935 | "pad_left" | "pad_right" | "repeat" | "join" => {
936 Some(TypeExpr::Named("string".into()))
937 }
938 "split" | "chars" => Some(TypeExpr::Named("list".into())),
939 "filter" => {
941 if is_dict {
942 Some(TypeExpr::Named("dict".into()))
943 } else {
944 Some(TypeExpr::Named("list".into()))
945 }
946 }
947 "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
949 "reduce" | "find" | "first" | "last" => None,
950 "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
952 "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
953 "to_string" => Some(TypeExpr::Named("string".into())),
955 "to_int" => Some(TypeExpr::Named("int".into())),
956 "to_float" => Some(TypeExpr::Named("float".into())),
957 _ => None,
958 }
959 }
960
961 _ => None,
962 }
963 }
964
965 fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
967 let expected = self.resolve_alias(expected, scope);
968 let actual = self.resolve_alias(actual, scope);
969
970 match (&expected, &actual) {
971 (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
972 (TypeExpr::Union(members), actual_type) => members
973 .iter()
974 .any(|m| self.types_compatible(m, actual_type, scope)),
975 (expected_type, TypeExpr::Union(members)) => members
976 .iter()
977 .all(|m| self.types_compatible(expected_type, m, scope)),
978 (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
979 (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
980 (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
981 if expected_field.optional {
982 return true;
983 }
984 af.iter().any(|actual_field| {
985 actual_field.name == expected_field.name
986 && self.types_compatible(
987 &expected_field.type_expr,
988 &actual_field.type_expr,
989 scope,
990 )
991 })
992 }),
993 (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
995 let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
996 keys_ok
997 && af
998 .iter()
999 .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1000 }
1001 (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1003 (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1004 self.types_compatible(expected_inner, actual_inner, scope)
1005 }
1006 (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1007 (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1008 (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1009 self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1010 }
1011 (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1012 (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1013 _ => false,
1014 }
1015 }
1016
1017 fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1018 if let TypeExpr::Named(name) = ty {
1019 if let Some(resolved) = scope.resolve_type(name) {
1020 return resolved.clone();
1021 }
1022 }
1023 ty.clone()
1024 }
1025
1026 fn error_at(&mut self, message: String, span: Span) {
1027 self.diagnostics.push(TypeDiagnostic {
1028 message,
1029 severity: DiagnosticSeverity::Error,
1030 span: Some(span),
1031 });
1032 }
1033
1034 fn warning_at(&mut self, message: String, span: Span) {
1035 self.diagnostics.push(TypeDiagnostic {
1036 message,
1037 severity: DiagnosticSeverity::Warning,
1038 span: Some(span),
1039 });
1040 }
1041}
1042
1043impl Default for TypeChecker {
1044 fn default() -> Self {
1045 Self::new()
1046 }
1047}
1048
1049fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1051 match op {
1052 "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1053 "+" => match (left, right) {
1054 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1055 match (l.as_str(), r.as_str()) {
1056 ("int", "int") => Some(TypeExpr::Named("int".into())),
1057 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1058 ("string", _) => Some(TypeExpr::Named("string".into())),
1059 ("list", "list") => Some(TypeExpr::Named("list".into())),
1060 ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1061 _ => Some(TypeExpr::Named("string".into())),
1062 }
1063 }
1064 _ => None,
1065 },
1066 "-" | "*" | "/" | "%" => match (left, right) {
1067 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1068 match (l.as_str(), r.as_str()) {
1069 ("int", "int") => Some(TypeExpr::Named("int".into())),
1070 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1071 _ => None,
1072 }
1073 }
1074 _ => None,
1075 },
1076 "??" => match (left, right) {
1077 (Some(TypeExpr::Union(members)), _) => {
1078 let non_nil: Vec<_> = members
1079 .iter()
1080 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1081 .cloned()
1082 .collect();
1083 if non_nil.len() == 1 {
1084 Some(non_nil[0].clone())
1085 } else if non_nil.is_empty() {
1086 right.clone()
1087 } else {
1088 Some(TypeExpr::Union(non_nil))
1089 }
1090 }
1091 _ => right.clone(),
1092 },
1093 "|>" => None,
1094 _ => None,
1095 }
1096}
1097
1098pub fn format_type(ty: &TypeExpr) -> String {
1100 match ty {
1101 TypeExpr::Named(n) => n.clone(),
1102 TypeExpr::Union(types) => types
1103 .iter()
1104 .map(format_type)
1105 .collect::<Vec<_>>()
1106 .join(" | "),
1107 TypeExpr::Shape(fields) => {
1108 let inner: Vec<String> = fields
1109 .iter()
1110 .map(|f| {
1111 let opt = if f.optional { "?" } else { "" };
1112 format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1113 })
1114 .collect();
1115 format!("{{{}}}", inner.join(", "))
1116 }
1117 TypeExpr::List(inner) => format!("list[{}]", format_type(inner)),
1118 TypeExpr::DictType(k, v) => format!("dict[{}, {}]", format_type(k), format_type(v)),
1119 }
1120}
1121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    /// Lexes, parses and type-checks `source`; panics on syntax errors.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut lexer = Lexer::new(source);
        let tokens = lexer.tokenize().unwrap();
        let mut parser = Parser::new(tokens);
        let program = parser.parse().unwrap();
        TypeChecker::new().check(&program)
    }

    /// Messages of all diagnostics with the given severity.
    fn messages(source: &str, severity: DiagnosticSeverity) -> Vec<String> {
        check_source(source)
            .into_iter()
            .filter_map(|d| (d.severity == severity).then_some(d.message))
            .collect()
    }

    fn errors(source: &str) -> Vec<String> {
        messages(source, DiagnosticSeverity::Error)
    }

    fn warnings(source: &str) -> Vec<String> {
        messages(source, DiagnosticSeverity::Warning)
    }

    /// Warnings about non-exhaustive `match` expressions only.
    fn non_exhaustive_warnings(source: &str) -> Vec<String> {
        warnings(source)
            .into_iter()
            .filter(|w| w.contains("Non-exhaustive"))
            .collect()
    }

    fn assert_no_errors(source: &str) {
        let errs = errors(source);
        assert!(errs.is_empty(), "unexpected errors: {errs:?}");
    }

    fn assert_error_count(source: &str, expected: usize) -> Vec<String> {
        let errs = errors(source);
        assert_eq!(errs.len(), expected, "errors: {errs:?}");
        errs
    }

    #[test]
    fn test_no_errors_for_untyped_code() {
        assert_no_errors("pipeline t(task) { let x = 42\nlog(x) }");
    }

    #[test]
    fn test_correct_typed_let() {
        assert_no_errors("pipeline t(task) { let x: int = 42 }");
    }

    #[test]
    fn test_type_mismatch_let() {
        let errs = assert_error_count(r#"pipeline t(task) { let x: int = "hello" }"#, 1);
        for needle in ["Type mismatch", "int", "string"] {
            assert!(errs[0].contains(needle));
        }
    }

    #[test]
    fn test_correct_typed_fn() {
        assert_no_errors(
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
        );
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let errs = assert_error_count(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
            1,
        );
        assert!(errs[0].contains("Argument 1"));
        assert!(errs[0].contains("expected int"));
    }

    #[test]
    fn test_return_type_mismatch() {
        let errs =
            assert_error_count(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#, 1);
        assert!(errs[0].contains("Return type mismatch"));
    }

    #[test]
    fn test_union_type_compatible() {
        assert_no_errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
    }

    #[test]
    fn test_union_type_mismatch() {
        let errs = assert_error_count(r#"pipeline t(task) { let x: string | nil = 42 }"#, 1);
        assert!(errs[0].contains("Type mismatch"));
    }

    #[test]
    fn test_type_inference_propagation() {
        let errs = assert_error_count(
            r#"pipeline t(task) {
    fn add(a: int, b: int) -> int { return a + b }
    let result: string = add(1, 2)
}"#,
            1,
        );
        for needle in ["Type mismatch", "string", "int"] {
            assert!(errs[0].contains(needle));
        }
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let errs = assert_error_count(r#"pipeline t(task) { let x: string = to_int("42") }"#, 1);
        assert!(errs[0].contains("string"));
        assert!(errs[0].contains("int"));
    }

    #[test]
    fn test_binary_op_type_inference() {
        assert_error_count("pipeline t(task) { let x: string = 1 + 2 }", 1);
    }

    #[test]
    fn test_comparison_returns_bool() {
        assert_no_errors("pipeline t(task) { let x: bool = 1 < 2 }");
    }

    #[test]
    fn test_int_float_promotion() {
        assert_no_errors("pipeline t(task) { let x: float = 42 }");
    }

    #[test]
    fn test_untyped_code_no_errors() {
        assert_no_errors(
            r#"pipeline t(task) {
    fn process(data) {
        let result = data + " processed"
        return result
    }
    log(process("hello"))
}"#,
        );
    }

    #[test]
    fn test_type_alias() {
        assert_no_errors(
            r#"pipeline t(task) {
    type Name = string
    let x: Name = "hello"
}"#,
        );
    }

    #[test]
    fn test_type_alias_mismatch() {
        assert_error_count(
            r#"pipeline t(task) {
    type Name = string
    let x: Name = 42
}"#,
            1,
        );
    }

    #[test]
    fn test_assignment_type_check() {
        let errs = assert_error_count(
            r#"pipeline t(task) {
    var x: int = 0
    x = "hello"
}"#,
            1,
        );
        assert!(errs[0].contains("cannot assign string"));
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        assert_no_errors(
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
        );
    }

    #[test]
    fn test_covariance_return_type() {
        assert_no_errors("pipeline t(task) { fn get() -> float { return 42 } }");
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        assert_error_count(
            "pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }",
            1,
        );
    }

    #[test]
    fn test_exhaustive_match_no_warning() {
        let warns = non_exhaustive_warnings(
            r#"pipeline t(task) {
    enum Color { Red, Green, Blue }
    let c = Color.Red
    match c.variant {
        "Red" -> { log("r") }
        "Green" -> { log("g") }
        "Blue" -> { log("b") }
    }
}"#,
        );
        assert!(warns.is_empty());
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let warns = non_exhaustive_warnings(
            r#"pipeline t(task) {
    enum Color { Red, Green, Blue }
    let c = Color.Red
    match c.variant {
        "Red" -> { log("r") }
        "Green" -> { log("g") }
    }
}"#,
        );
        assert_eq!(warns.len(), 1);
        assert!(warns[0].contains("Blue"));
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let warns = non_exhaustive_warnings(
            r#"pipeline t(task) {
    enum Status { Active, Inactive, Pending }
    let s = Status.Active
    match s.variant {
        "Active" -> { log("a") }
    }
}"#,
        );
        assert_eq!(warns.len(), 1);
        assert!(warns[0].contains("Inactive"));
        assert!(warns[0].contains("Pending"));
    }

    #[test]
    fn test_enum_construct_type_inference() {
        assert_no_errors(
            r#"pipeline t(task) {
    enum Color { Red, Green, Blue }
    let c: Color = Color.Red
}"#,
        );
    }

    #[test]
    fn test_nil_coalescing_strips_nil() {
        assert_no_errors(
            r#"pipeline t(task) {
    let x: string | nil = nil
    let y: string = x ?? "default"
}"#,
        );
    }
}