1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
/// A single diagnostic produced by [`TypeChecker::check`].
#[derive(Debug, Clone)]
pub struct TypeDiagnostic {
    /// Human-readable description of the problem.
    pub message: String,
    /// Whether this is an error or a warning.
    pub severity: DiagnosticSeverity,
    /// Source location of the offending node, when known.
    pub span: Option<Span>,
}
13
/// How serious a [`TypeDiagnostic`] is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// A definite type mismatch (annotation, argument, return or assignment).
    Error,
    /// A likely problem the checker cannot prove (arity mismatches, suspicious
    /// operator operand types, non-exhaustive enum matches).
    Warning,
}
19
/// A statically inferred type. `None` means "unknown", and the checker skips
/// any compatibility check that would involve an unknown type.
type InferredType = Option<TypeExpr>;
22
/// One lexical scope of the checker, linked to its enclosing scope.
#[derive(Debug, Clone)]
struct TypeScope {
    /// Variable bindings and their types (`None` when unknown).
    vars: BTreeMap<String, InferredType>,
    /// Signatures of user-declared functions.
    functions: BTreeMap<String, FnSignature>,
    /// `type Name = ...` aliases.
    type_aliases: BTreeMap<String, TypeExpr>,
    /// Enum names mapped to their variant names.
    enums: BTreeMap<String, Vec<String>>,
    /// Interface names mapped to their method declarations.
    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
    /// Struct names mapped to their (field name, declared type) pairs.
    structs: BTreeMap<String, Vec<(String, InferredType)>>,
    /// Enclosing scope, consulted when a lookup misses here.
    parent: Option<Box<TypeScope>>,
}
40
/// The statically known shape of a user-declared function.
#[derive(Debug, Clone)]
struct FnSignature {
    /// Parameter names with their annotated types (`None` when unannotated).
    params: Vec<(String, InferredType)>,
    /// Annotated return type, if any.
    return_type: InferredType,
}
46
47impl TypeScope {
48 fn new() -> Self {
49 Self {
50 vars: BTreeMap::new(),
51 functions: BTreeMap::new(),
52 type_aliases: BTreeMap::new(),
53 enums: BTreeMap::new(),
54 interfaces: BTreeMap::new(),
55 structs: BTreeMap::new(),
56 parent: None,
57 }
58 }
59
60 fn child(&self) -> Self {
61 Self {
62 vars: BTreeMap::new(),
63 functions: BTreeMap::new(),
64 type_aliases: BTreeMap::new(),
65 enums: BTreeMap::new(),
66 interfaces: BTreeMap::new(),
67 structs: BTreeMap::new(),
68 parent: Some(Box::new(self.clone())),
69 }
70 }
71
72 fn get_var(&self, name: &str) -> Option<&InferredType> {
73 self.vars
74 .get(name)
75 .or_else(|| self.parent.as_ref()?.get_var(name))
76 }
77
78 fn get_fn(&self, name: &str) -> Option<&FnSignature> {
79 self.functions
80 .get(name)
81 .or_else(|| self.parent.as_ref()?.get_fn(name))
82 }
83
84 fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
85 self.type_aliases
86 .get(name)
87 .or_else(|| self.parent.as_ref()?.resolve_type(name))
88 }
89
90 fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
91 self.enums
92 .get(name)
93 .or_else(|| self.parent.as_ref()?.get_enum(name))
94 }
95
96 #[allow(dead_code)]
97 fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
98 self.interfaces
99 .get(name)
100 .or_else(|| self.parent.as_ref()?.get_interface(name))
101 }
102
103 fn define_var(&mut self, name: &str, ty: InferredType) {
104 self.vars.insert(name.to_string(), ty);
105 }
106
107 fn define_fn(&mut self, name: &str, sig: FnSignature) {
108 self.functions.insert(name.to_string(), sig);
109 }
110}
111
112fn builtin_return_type(name: &str) -> InferredType {
114 match name {
115 "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
116 | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
117 Some(TypeExpr::Named("nil".into()))
118 }
119 "type_of" | "to_string" | "json_stringify" | "read_file" | "http_get" | "http_post"
120 | "llm_call" | "agent_loop" | "regex_replace" | "path_join" | "temp_dir"
121 | "date_format" | "format" => Some(TypeExpr::Named("string".into())),
122 "to_int" => Some(TypeExpr::Named("int".into())),
123 "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
124 "file_exists" => Some(TypeExpr::Named("bool".into())),
125 "list_dir" => Some(TypeExpr::Named("list".into())),
126 "stat" | "exec" | "shell" | "date_now" => Some(TypeExpr::Named("dict".into())),
127 "env" | "regex_match" => Some(TypeExpr::Union(vec![
128 TypeExpr::Named("string".into()),
129 TypeExpr::Named("nil".into()),
130 ])),
131 "json_parse" => None, _ => None,
133 }
134}
135
/// Whether `name` is one of the runtime's built-in functions.
///
/// `check_call` uses this to skip arity warnings when a user-declared function
/// shares a built-in's name.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        "log",
        "print",
        "println",
        "type_of",
        "to_string",
        "to_int",
        "to_float",
        "json_stringify",
        "json_parse",
        "env",
        "timestamp",
        "sleep",
        "read_file",
        "write_file",
        "exit",
        "regex_match",
        "regex_replace",
        "http_get",
        "http_post",
        "llm_call",
        "agent_loop",
        "await",
        "cancel",
        "file_exists",
        "delete_file",
        "list_dir",
        "mkdir",
        "path_join",
        "copy_file",
        "append_file",
        "temp_dir",
        "stat",
        "exec",
        "shell",
        "date_now",
        "date_format",
        "date_parse",
        "format",
    ];
    BUILTINS.contains(&name)
}
180
/// Static type checker for a parsed program.
///
/// Construct with [`TypeChecker::new`], then call [`TypeChecker::check`].
pub struct TypeChecker {
    /// Diagnostics collected so far, in discovery order.
    diagnostics: Vec<TypeDiagnostic>,
    /// Global scope holding top-level declarations and bindings.
    scope: TypeScope,
}
186
187impl TypeChecker {
188 pub fn new() -> Self {
189 Self {
190 diagnostics: Vec::new(),
191 scope: TypeScope::new(),
192 }
193 }
194
195 pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
197 Self::register_declarations_into(&mut self.scope, program);
199
200 for snode in program {
202 if let Node::Pipeline { body, .. } = &snode.node {
203 Self::register_declarations_into(&mut self.scope, body);
204 }
205 }
206
207 for snode in program {
209 match &snode.node {
210 Node::Pipeline { params, body, .. } => {
211 let mut child = self.scope.child();
212 for p in params {
213 child.define_var(p, None);
214 }
215 self.check_block(body, &mut child);
216 }
217 Node::FnDecl {
218 name,
219 params,
220 return_type,
221 body,
222 ..
223 } => {
224 let sig = FnSignature {
225 params: params
226 .iter()
227 .map(|p| (p.name.clone(), p.type_expr.clone()))
228 .collect(),
229 return_type: return_type.clone(),
230 };
231 self.scope.define_fn(name, sig);
232 self.check_fn_body(params, return_type, body);
233 }
234 _ => {
235 let mut scope = self.scope.clone();
236 self.check_node(snode, &mut scope);
237 for (name, ty) in scope.vars {
239 self.scope.vars.entry(name).or_insert(ty);
240 }
241 }
242 }
243 }
244
245 self.diagnostics
246 }
247
248 fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
250 for snode in nodes {
251 match &snode.node {
252 Node::TypeDecl { name, type_expr } => {
253 scope.type_aliases.insert(name.clone(), type_expr.clone());
254 }
255 Node::EnumDecl { name, variants } => {
256 let variant_names: Vec<String> =
257 variants.iter().map(|v| v.name.clone()).collect();
258 scope.enums.insert(name.clone(), variant_names);
259 }
260 Node::InterfaceDecl { name, methods } => {
261 scope.interfaces.insert(name.clone(), methods.clone());
262 }
263 Node::StructDecl { name, fields } => {
264 let field_types: Vec<(String, InferredType)> = fields
265 .iter()
266 .map(|f| (f.name.clone(), f.type_expr.clone()))
267 .collect();
268 scope.structs.insert(name.clone(), field_types);
269 }
270 _ => {}
271 }
272 }
273 }
274
275 fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
276 for stmt in stmts {
277 self.check_node(stmt, scope);
278 }
279 }
280
281 fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
282 let span = snode.span;
283 match &snode.node {
284 Node::LetBinding {
285 name,
286 type_ann,
287 value,
288 } => {
289 let inferred = self.infer_type(value, scope);
290 if let Some(expected) = type_ann {
291 if let Some(actual) = &inferred {
292 if !self.types_compatible(expected, actual, scope) {
293 self.error_at(
294 format!(
295 "Type mismatch: '{}' declared as {}, but assigned {}",
296 name,
297 format_type(expected),
298 format_type(actual)
299 ),
300 span,
301 );
302 }
303 }
304 }
305 let ty = type_ann.clone().or(inferred);
306 scope.define_var(name, ty);
307 }
308
309 Node::VarBinding {
310 name,
311 type_ann,
312 value,
313 } => {
314 let inferred = self.infer_type(value, scope);
315 if let Some(expected) = type_ann {
316 if let Some(actual) = &inferred {
317 if !self.types_compatible(expected, actual, scope) {
318 self.error_at(
319 format!(
320 "Type mismatch: '{}' declared as {}, but assigned {}",
321 name,
322 format_type(expected),
323 format_type(actual)
324 ),
325 span,
326 );
327 }
328 }
329 }
330 let ty = type_ann.clone().or(inferred);
331 scope.define_var(name, ty);
332 }
333
334 Node::FnDecl {
335 name,
336 params,
337 return_type,
338 body,
339 ..
340 } => {
341 let sig = FnSignature {
342 params: params
343 .iter()
344 .map(|p| (p.name.clone(), p.type_expr.clone()))
345 .collect(),
346 return_type: return_type.clone(),
347 };
348 scope.define_fn(name, sig.clone());
349 scope.define_var(name, None);
350 self.check_fn_body(params, return_type, body);
351 }
352
353 Node::FunctionCall { name, args } => {
354 self.check_call(name, args, scope, span);
355 }
356
357 Node::IfElse {
358 condition,
359 then_body,
360 else_body,
361 } => {
362 self.check_node(condition, scope);
363 let mut then_scope = scope.child();
364 self.check_block(then_body, &mut then_scope);
365 if let Some(else_body) = else_body {
366 let mut else_scope = scope.child();
367 self.check_block(else_body, &mut else_scope);
368 }
369 }
370
371 Node::ForIn {
372 variable,
373 iterable,
374 body,
375 } => {
376 self.check_node(iterable, scope);
377 let mut loop_scope = scope.child();
378 let elem_type = match self.infer_type(iterable, scope) {
380 Some(TypeExpr::List(inner)) => Some(*inner),
381 Some(TypeExpr::Named(n)) if n == "string" => {
382 Some(TypeExpr::Named("string".into()))
383 }
384 _ => None,
385 };
386 loop_scope.define_var(variable, elem_type);
387 self.check_block(body, &mut loop_scope);
388 }
389
390 Node::WhileLoop { condition, body } => {
391 self.check_node(condition, scope);
392 let mut loop_scope = scope.child();
393 self.check_block(body, &mut loop_scope);
394 }
395
396 Node::TryCatch {
397 body,
398 error_var,
399 catch_body,
400 ..
401 } => {
402 let mut try_scope = scope.child();
403 self.check_block(body, &mut try_scope);
404 let mut catch_scope = scope.child();
405 if let Some(var) = error_var {
406 catch_scope.define_var(var, None);
407 }
408 self.check_block(catch_body, &mut catch_scope);
409 }
410
411 Node::ReturnStmt {
412 value: Some(val), ..
413 } => {
414 self.check_node(val, scope);
415 }
416
417 Node::Assignment {
418 target, value, op, ..
419 } => {
420 self.check_node(value, scope);
421 if let Node::Identifier(name) = &target.node {
422 if let Some(Some(var_type)) = scope.get_var(name) {
423 let value_type = self.infer_type(value, scope);
424 let assigned = if let Some(op) = op {
425 let var_inferred = scope.get_var(name).cloned().flatten();
426 infer_binary_op_type(op, &var_inferred, &value_type)
427 } else {
428 value_type
429 };
430 if let Some(actual) = &assigned {
431 if !self.types_compatible(var_type, actual, scope) {
432 self.error_at(
433 format!(
434 "Type mismatch: cannot assign {} to '{}' (declared as {})",
435 format_type(actual),
436 name,
437 format_type(var_type)
438 ),
439 span,
440 );
441 }
442 }
443 }
444 }
445 }
446
447 Node::TypeDecl { name, type_expr } => {
448 scope.type_aliases.insert(name.clone(), type_expr.clone());
449 }
450
451 Node::EnumDecl { name, variants } => {
452 let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
453 scope.enums.insert(name.clone(), variant_names);
454 }
455
456 Node::StructDecl { name, fields } => {
457 let field_types: Vec<(String, InferredType)> = fields
458 .iter()
459 .map(|f| (f.name.clone(), f.type_expr.clone()))
460 .collect();
461 scope.structs.insert(name.clone(), field_types);
462 }
463
464 Node::InterfaceDecl { name, methods } => {
465 scope.interfaces.insert(name.clone(), methods.clone());
466 }
467
468 Node::MatchExpr { value, arms } => {
469 self.check_node(value, scope);
470 for arm in arms {
471 self.check_node(&arm.pattern, scope);
472 let mut arm_scope = scope.child();
473 self.check_block(&arm.body, &mut arm_scope);
474 }
475 self.check_match_exhaustiveness(value, arms, scope, span);
476 }
477
478 Node::BinaryOp { op, left, right } => {
480 self.check_node(left, scope);
481 self.check_node(right, scope);
482 let lt = self.infer_type(left, scope);
484 let rt = self.infer_type(right, scope);
485 if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (<, &rt) {
486 match op.as_str() {
487 "-" | "*" | "/" | "%" => {
488 let numeric = ["int", "float"];
489 if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
490 self.warning_at(
491 format!(
492 "Operator '{op}' may not be valid for types {} and {}",
493 l, r
494 ),
495 span,
496 );
497 }
498 }
499 "+" => {
500 let valid = ["int", "float", "string", "list", "dict"];
502 if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
503 self.warning_at(
504 format!(
505 "Operator '+' may not be valid for types {} and {}",
506 l, r
507 ),
508 span,
509 );
510 }
511 }
512 _ => {}
513 }
514 }
515 }
516 Node::UnaryOp { operand, .. } => {
517 self.check_node(operand, scope);
518 }
519 Node::MethodCall { object, args, .. } => {
520 self.check_node(object, scope);
521 for arg in args {
522 self.check_node(arg, scope);
523 }
524 }
525 Node::PropertyAccess { object, .. } => {
526 self.check_node(object, scope);
527 }
528 Node::SubscriptAccess { object, index } => {
529 self.check_node(object, scope);
530 self.check_node(index, scope);
531 }
532
533 _ => {}
535 }
536 }
537
538 fn check_fn_body(
539 &mut self,
540 params: &[TypedParam],
541 return_type: &Option<TypeExpr>,
542 body: &[SNode],
543 ) {
544 let mut fn_scope = self.scope.child();
545 for param in params {
546 fn_scope.define_var(¶m.name, param.type_expr.clone());
547 }
548 self.check_block(body, &mut fn_scope);
549
550 if let Some(ret_type) = return_type {
552 for stmt in body {
553 self.check_return_type(stmt, ret_type, &fn_scope);
554 }
555 }
556 }
557
558 fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
559 let span = snode.span;
560 match &snode.node {
561 Node::ReturnStmt { value: Some(val) } => {
562 let inferred = self.infer_type(val, scope);
563 if let Some(actual) = &inferred {
564 if !self.types_compatible(expected, actual, scope) {
565 self.error_at(
566 format!(
567 "Return type mismatch: expected {}, got {}",
568 format_type(expected),
569 format_type(actual)
570 ),
571 span,
572 );
573 }
574 }
575 }
576 Node::IfElse {
577 then_body,
578 else_body,
579 ..
580 } => {
581 for stmt in then_body {
582 self.check_return_type(stmt, expected, scope);
583 }
584 if let Some(else_body) = else_body {
585 for stmt in else_body {
586 self.check_return_type(stmt, expected, scope);
587 }
588 }
589 }
590 _ => {}
591 }
592 }
593
594 fn check_match_exhaustiveness(
596 &mut self,
597 value: &SNode,
598 arms: &[MatchArm],
599 scope: &TypeScope,
600 span: Span,
601 ) {
602 let enum_name = match &value.node {
604 Node::PropertyAccess { object, property } if property == "variant" => {
605 match self.infer_type(object, scope) {
607 Some(TypeExpr::Named(name)) => {
608 if scope.get_enum(&name).is_some() {
609 Some(name)
610 } else {
611 None
612 }
613 }
614 _ => None,
615 }
616 }
617 _ => {
618 match self.infer_type(value, scope) {
620 Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
621 _ => None,
622 }
623 }
624 };
625
626 let Some(enum_name) = enum_name else {
627 return;
628 };
629 let Some(variants) = scope.get_enum(&enum_name) else {
630 return;
631 };
632
633 let mut covered: Vec<String> = Vec::new();
635 let mut has_wildcard = false;
636
637 for arm in arms {
638 match &arm.pattern.node {
639 Node::StringLiteral(s) => covered.push(s.clone()),
641 Node::Identifier(name) if name == "_" || !variants.contains(name) => {
643 has_wildcard = true;
644 }
645 Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
647 Node::PropertyAccess { property, .. } => covered.push(property.clone()),
649 _ => {
650 has_wildcard = true;
652 }
653 }
654 }
655
656 if has_wildcard {
657 return;
658 }
659
660 let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
661 if !missing.is_empty() {
662 let missing_str = missing
663 .iter()
664 .map(|s| format!("\"{}\"", s))
665 .collect::<Vec<_>>()
666 .join(", ");
667 self.warning_at(
668 format!(
669 "Non-exhaustive match on enum {}: missing variants {}",
670 enum_name, missing_str
671 ),
672 span,
673 );
674 }
675 }
676
677 fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
678 if let Some(sig) = scope.get_fn(name).cloned() {
680 if args.len() != sig.params.len() && !is_builtin(name) {
681 self.warning_at(
682 format!(
683 "Function '{}' expects {} arguments, got {}",
684 name,
685 sig.params.len(),
686 args.len()
687 ),
688 span,
689 );
690 }
691 for (i, (arg, (param_name, param_type))) in
692 args.iter().zip(sig.params.iter()).enumerate()
693 {
694 if let Some(expected) = param_type {
695 let actual = self.infer_type(arg, scope);
696 if let Some(actual) = &actual {
697 if !self.types_compatible(expected, actual, scope) {
698 self.error_at(
699 format!(
700 "Argument {} ('{}'): expected {}, got {}",
701 i + 1,
702 param_name,
703 format_type(expected),
704 format_type(actual)
705 ),
706 arg.span,
707 );
708 }
709 }
710 }
711 }
712 }
713 for arg in args {
715 self.check_node(arg, scope);
716 }
717 }
718
719 fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
721 match &snode.node {
722 Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
723 Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
724 Node::StringLiteral(_) | Node::InterpolatedString(_) => {
725 Some(TypeExpr::Named("string".into()))
726 }
727 Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
728 Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
729 Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
730 Node::DictLiteral(_) => Some(TypeExpr::Named("dict".into())),
731 Node::Closure { .. } => Some(TypeExpr::Named("closure".into())),
732
733 Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
734
735 Node::FunctionCall { name, .. } => {
736 if let Some(sig) = scope.get_fn(name) {
738 return sig.return_type.clone();
739 }
740 builtin_return_type(name)
742 }
743
744 Node::BinaryOp { op, left, right } => {
745 let lt = self.infer_type(left, scope);
746 let rt = self.infer_type(right, scope);
747 infer_binary_op_type(op, <, &rt)
748 }
749
750 Node::UnaryOp { op, operand } => {
751 let t = self.infer_type(operand, scope);
752 match op.as_str() {
753 "!" => Some(TypeExpr::Named("bool".into())),
754 "-" => t, _ => None,
756 }
757 }
758
759 Node::Ternary {
760 true_expr,
761 false_expr,
762 ..
763 } => {
764 let tt = self.infer_type(true_expr, scope);
765 let ft = self.infer_type(false_expr, scope);
766 match (&tt, &ft) {
767 (Some(a), Some(b)) if a == b => tt,
768 (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
769 (Some(_), None) => tt,
770 (None, Some(_)) => ft,
771 (None, None) => None,
772 }
773 }
774
775 Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
776
777 Node::PropertyAccess { object, property } => {
778 if let Node::Identifier(name) = &object.node {
780 if scope.get_enum(name).is_some() {
781 return Some(TypeExpr::Named(name.clone()));
782 }
783 }
784 if property == "variant" {
786 let obj_type = self.infer_type(object, scope);
787 if let Some(TypeExpr::Named(name)) = &obj_type {
788 if scope.get_enum(name).is_some() {
789 return Some(TypeExpr::Named("string".into()));
790 }
791 }
792 }
793 None
794 }
795
796 Node::SubscriptAccess { object, .. } => {
797 let obj_type = self.infer_type(object, scope);
798 match &obj_type {
799 Some(TypeExpr::List(inner)) => Some(*inner.clone()),
800 Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
801 Some(TypeExpr::Named(n)) if n == "list" => None,
802 Some(TypeExpr::Named(n)) if n == "dict" => None,
803 Some(TypeExpr::Named(n)) if n == "string" => {
804 Some(TypeExpr::Named("string".into()))
805 }
806 _ => None,
807 }
808 }
809 Node::MethodCall { object, method, .. } => {
810 let obj_type = self.infer_type(object, scope);
811 let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
812 || matches!(&obj_type, Some(TypeExpr::DictType(..)));
813 match method.as_str() {
814 "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
816 Some(TypeExpr::Named("bool".into()))
817 }
818 "count" | "index_of" => Some(TypeExpr::Named("int".into())),
820 "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
822 | "pad_left" | "pad_right" | "repeat" | "join" => {
823 Some(TypeExpr::Named("string".into()))
824 }
825 "split" | "chars" => Some(TypeExpr::Named("list".into())),
826 "filter" => {
828 if is_dict {
829 Some(TypeExpr::Named("dict".into()))
830 } else {
831 Some(TypeExpr::Named("list".into()))
832 }
833 }
834 "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
836 "reduce" | "find" | "first" | "last" => None,
837 "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
839 "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
840 "to_string" => Some(TypeExpr::Named("string".into())),
842 "to_int" => Some(TypeExpr::Named("int".into())),
843 "to_float" => Some(TypeExpr::Named("float".into())),
844 _ => None,
845 }
846 }
847
848 _ => None,
849 }
850 }
851
852 fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
854 let expected = self.resolve_alias(expected, scope);
855 let actual = self.resolve_alias(actual, scope);
856
857 match (&expected, &actual) {
858 (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
859 (TypeExpr::Union(members), actual_type) => members
860 .iter()
861 .any(|m| self.types_compatible(m, actual_type, scope)),
862 (expected_type, TypeExpr::Union(members)) => members
863 .iter()
864 .all(|m| self.types_compatible(expected_type, m, scope)),
865 (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
866 (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
867 if expected_field.optional {
868 return true;
869 }
870 af.iter().any(|actual_field| {
871 actual_field.name == expected_field.name
872 && self.types_compatible(
873 &expected_field.type_expr,
874 &actual_field.type_expr,
875 scope,
876 )
877 })
878 }),
879 (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
880 self.types_compatible(expected_inner, actual_inner, scope)
881 }
882 (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
883 (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
884 (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
885 self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
886 }
887 (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
888 (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
889 _ => false,
890 }
891 }
892
893 fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
894 if let TypeExpr::Named(name) = ty {
895 if let Some(resolved) = scope.resolve_type(name) {
896 return resolved.clone();
897 }
898 }
899 ty.clone()
900 }
901
902 fn error_at(&mut self, message: String, span: Span) {
903 self.diagnostics.push(TypeDiagnostic {
904 message,
905 severity: DiagnosticSeverity::Error,
906 span: Some(span),
907 });
908 }
909
910 fn warning_at(&mut self, message: String, span: Span) {
911 self.diagnostics.push(TypeDiagnostic {
912 message,
913 severity: DiagnosticSeverity::Warning,
914 span: Some(span),
915 });
916 }
917}
918
919impl Default for TypeChecker {
920 fn default() -> Self {
921 Self::new()
922 }
923}
924
925fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
927 match op {
928 "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
929 "+" => match (left, right) {
930 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
931 match (l.as_str(), r.as_str()) {
932 ("int", "int") => Some(TypeExpr::Named("int".into())),
933 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
934 ("string", _) => Some(TypeExpr::Named("string".into())),
935 ("list", "list") => Some(TypeExpr::Named("list".into())),
936 ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
937 _ => Some(TypeExpr::Named("string".into())),
938 }
939 }
940 _ => None,
941 },
942 "-" | "*" | "/" | "%" => match (left, right) {
943 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
944 match (l.as_str(), r.as_str()) {
945 ("int", "int") => Some(TypeExpr::Named("int".into())),
946 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
947 _ => None,
948 }
949 }
950 _ => None,
951 },
952 "??" => match (left, right) {
953 (Some(TypeExpr::Union(members)), _) => {
954 let non_nil: Vec<_> = members
955 .iter()
956 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
957 .cloned()
958 .collect();
959 if non_nil.len() == 1 {
960 Some(non_nil[0].clone())
961 } else if non_nil.is_empty() {
962 right.clone()
963 } else {
964 Some(TypeExpr::Union(non_nil))
965 }
966 }
967 _ => right.clone(),
968 },
969 "|>" => None,
970 _ => None,
971 }
972}
973
974pub fn format_type(ty: &TypeExpr) -> String {
976 match ty {
977 TypeExpr::Named(n) => n.clone(),
978 TypeExpr::Union(types) => types
979 .iter()
980 .map(format_type)
981 .collect::<Vec<_>>()
982 .join(" | "),
983 TypeExpr::Shape(fields) => {
984 let inner: Vec<String> = fields
985 .iter()
986 .map(|f| {
987 let opt = if f.optional { "?" } else { "" };
988 format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
989 })
990 .collect();
991 format!("{{{}}}", inner.join(", "))
992 }
993 TypeExpr::List(inner) => format!("list[{}]", format_type(inner)),
994 TypeExpr::DictType(k, v) => format!("dict[{}, {}]", format_type(k), format_type(v)),
995 }
996}
997
#[cfg(test)]
mod tests {
    // End-to-end checks: each test lexes and parses a small program, runs the
    // type checker on it, and inspects the reported diagnostics.
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    /// Lex, parse and type-check `source`, returning every diagnostic.
    /// Panics (via `unwrap`) if the source fails to lex or parse.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut lexer = Lexer::new(source);
        let tokens = lexer.tokenize().unwrap();
        let mut parser = Parser::new(tokens);
        let program = parser.parse().unwrap();
        TypeChecker::new().check(&program)
    }

    /// Messages of the error-severity diagnostics reported for `source`.
    fn errors(source: &str) -> Vec<String> {
        check_source(source)
            .into_iter()
            .filter(|d| d.severity == DiagnosticSeverity::Error)
            .map(|d| d.message)
            .collect()
    }

    #[test]
    fn test_no_errors_for_untyped_code() {
        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_correct_typed_let() {
        let errs = errors("pipeline t(task) { let x: int = 42 }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_type_mismatch_let() {
        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Type mismatch"));
        assert!(errs[0].contains("int"));
        assert!(errs[0].contains("string"));
    }

    #[test]
    fn test_correct_typed_fn() {
        let errs = errors(
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let errs = errors(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Argument 1"));
        assert!(errs[0].contains("expected int"));
    }

    #[test]
    fn test_return_type_mismatch() {
        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Return type mismatch"));
    }

    #[test]
    fn test_union_type_compatible() {
        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
        assert!(errs.is_empty());
    }

    #[test]
    fn test_union_type_mismatch() {
        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Type mismatch"));
    }

    #[test]
    fn test_type_inference_propagation() {
        let errs = errors(
            r#"pipeline t(task) {
  fn add(a: int, b: int) -> int { return a + b }
  let result: string = add(1, 2)
}"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Type mismatch"));
        assert!(errs[0].contains("string"));
        assert!(errs[0].contains("int"));
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("string"));
        assert!(errs[0].contains("int"));
    }

    #[test]
    fn test_binary_op_type_inference() {
        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
        assert_eq!(errs.len(), 1);
    }

    #[test]
    fn test_comparison_returns_bool() {
        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_int_float_promotion() {
        let errs = errors("pipeline t(task) { let x: float = 42 }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_untyped_code_no_errors() {
        let errs = errors(
            r#"pipeline t(task) {
  fn process(data) {
    let result = data + " processed"
    return result
  }
  log(process("hello"))
}"#,
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_type_alias() {
        let errs = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = "hello"
}"#,
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_type_alias_mismatch() {
        let errs = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = 42
}"#,
        );
        assert_eq!(errs.len(), 1);
    }

    #[test]
    fn test_assignment_type_check() {
        let errs = errors(
            r#"pipeline t(task) {
  var x: int = 0
  x = "hello"
}"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("cannot assign string"));
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        let errs = errors(
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_covariance_return_type() {
        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
        assert_eq!(errs.len(), 1);
    }

    /// Messages of the warning-severity diagnostics reported for `source`.
    fn warnings(source: &str) -> Vec<String> {
        check_source(source)
            .into_iter()
            .filter(|d| d.severity == DiagnosticSeverity::Warning)
            .map(|d| d.message)
            .collect()
    }

    #[test]
    fn test_exhaustive_match_no_warning() {
        let warns = warnings(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
    "Blue" -> { log("b") }
  }
}"#,
        );
        let exhaustive_warns: Vec<_> = warns
            .iter()
            .filter(|w| w.contains("Non-exhaustive"))
            .collect();
        assert!(exhaustive_warns.is_empty());
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let warns = warnings(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
  }
}"#,
        );
        let exhaustive_warns: Vec<_> = warns
            .iter()
            .filter(|w| w.contains("Non-exhaustive"))
            .collect();
        assert_eq!(exhaustive_warns.len(), 1);
        assert!(exhaustive_warns[0].contains("Blue"));
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let warns = warnings(
            r#"pipeline t(task) {
  enum Status { Active, Inactive, Pending }
  let s = Status.Active
  match s.variant {
    "Active" -> { log("a") }
  }
}"#,
        );
        let exhaustive_warns: Vec<_> = warns
            .iter()
            .filter(|w| w.contains("Non-exhaustive"))
            .collect();
        assert_eq!(exhaustive_warns.len(), 1);
        assert!(exhaustive_warns[0].contains("Inactive"));
        assert!(exhaustive_warns[0].contains("Pending"));
    }

    #[test]
    fn test_enum_construct_type_inference() {
        let errs = errors(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c: Color = Color.Red
}"#,
        );
        assert!(errs.is_empty());
    }

    // `??` on a `string | nil` operand should yield plain `string`.
    #[test]
    fn test_nil_coalescing_strips_nil() {
        let errs = errors(
            r#"pipeline t(task) {
  let x: string | nil = nil
  let y: string = x ?? "default"
}"#,
        );
        assert!(errs.is_empty());
    }
}