1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9 pub message: String,
10 pub severity: DiagnosticSeverity,
11 pub span: Option<Span>,
12}
13
/// Severity attached to a [`TypeDiagnostic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// A definite type mismatch.
    Error,
    /// A likely problem (e.g. suspicious operator use, non-exhaustive match).
    Warning,
}
19
20type InferredType = Option<TypeExpr>;
22
23#[derive(Debug, Clone)]
25struct TypeScope {
26 vars: BTreeMap<String, InferredType>,
28 functions: BTreeMap<String, FnSignature>,
30 type_aliases: BTreeMap<String, TypeExpr>,
32 enums: BTreeMap<String, Vec<String>>,
34 interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36 structs: BTreeMap<String, Vec<(String, InferredType)>>,
38 parent: Option<Box<TypeScope>>,
39}
40
41#[derive(Debug, Clone)]
42struct FnSignature {
43 params: Vec<(String, InferredType)>,
44 return_type: InferredType,
45}
46
47impl TypeScope {
48 fn new() -> Self {
49 Self {
50 vars: BTreeMap::new(),
51 functions: BTreeMap::new(),
52 type_aliases: BTreeMap::new(),
53 enums: BTreeMap::new(),
54 interfaces: BTreeMap::new(),
55 structs: BTreeMap::new(),
56 parent: None,
57 }
58 }
59
60 fn child(&self) -> Self {
61 Self {
62 vars: BTreeMap::new(),
63 functions: BTreeMap::new(),
64 type_aliases: BTreeMap::new(),
65 enums: BTreeMap::new(),
66 interfaces: BTreeMap::new(),
67 structs: BTreeMap::new(),
68 parent: Some(Box::new(self.clone())),
69 }
70 }
71
72 fn get_var(&self, name: &str) -> Option<&InferredType> {
73 self.vars
74 .get(name)
75 .or_else(|| self.parent.as_ref()?.get_var(name))
76 }
77
78 fn get_fn(&self, name: &str) -> Option<&FnSignature> {
79 self.functions
80 .get(name)
81 .or_else(|| self.parent.as_ref()?.get_fn(name))
82 }
83
84 fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
85 self.type_aliases
86 .get(name)
87 .or_else(|| self.parent.as_ref()?.resolve_type(name))
88 }
89
90 fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
91 self.enums
92 .get(name)
93 .or_else(|| self.parent.as_ref()?.get_enum(name))
94 }
95
96 #[allow(dead_code)]
97 fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
98 self.interfaces
99 .get(name)
100 .or_else(|| self.parent.as_ref()?.get_interface(name))
101 }
102
103 fn define_var(&mut self, name: &str, ty: InferredType) {
104 self.vars.insert(name.to_string(), ty);
105 }
106
107 fn define_fn(&mut self, name: &str, sig: FnSignature) {
108 self.functions.insert(name.to_string(), sig);
109 }
110}
111
112fn builtin_return_type(name: &str) -> InferredType {
114 match name {
115 "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
116 | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
117 Some(TypeExpr::Named("nil".into()))
118 }
119 "type_of" | "to_string" | "json_stringify" | "read_file" | "http_get" | "http_post"
120 | "llm_call" | "agent_loop" | "regex_replace" | "path_join" | "temp_dir"
121 | "date_format" | "format" => Some(TypeExpr::Named("string".into())),
122 "to_int" => Some(TypeExpr::Named("int".into())),
123 "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
124 "file_exists" | "json_validate" => Some(TypeExpr::Named("bool".into())),
125 "list_dir" => Some(TypeExpr::Named("list".into())),
126 "stat" | "exec" | "shell" | "date_now" => Some(TypeExpr::Named("dict".into())),
127 "env" | "regex_match" => Some(TypeExpr::Union(vec![
128 TypeExpr::Named("string".into()),
129 TypeExpr::Named("nil".into()),
130 ])),
131 "json_parse" | "json_extract" => None, _ => None,
133 }
134}
135
/// Returns `true` when `name` is one of the language's builtin functions.
///
/// The comparison is exact and case-sensitive.
fn is_builtin(name: &str) -> bool {
    matches!(
        name,
        "log"
            | "print"
            | "println"
            | "type_of"
            | "to_string"
            | "to_int"
            | "to_float"
            | "json_stringify"
            | "json_parse"
            | "env"
            | "timestamp"
            | "sleep"
            | "read_file"
            | "write_file"
            | "exit"
            | "regex_match"
            | "regex_replace"
            | "http_get"
            | "http_post"
            | "llm_call"
            | "agent_loop"
            | "await"
            | "cancel"
            | "file_exists"
            | "delete_file"
            | "list_dir"
            | "mkdir"
            | "path_join"
            | "copy_file"
            | "append_file"
            | "temp_dir"
            | "stat"
            | "exec"
            | "shell"
            | "date_now"
            | "date_format"
            | "date_parse"
            | "format"
            | "json_validate"
            | "json_extract"
            | "trim"
            | "lowercase"
            | "uppercase"
            | "split"
            | "starts_with"
            | "ends_with"
            | "contains"
            | "replace"
            | "join"
            | "len"
            | "substring"
            | "dirname"
            | "basename"
            | "extname"
    )
}
196
197pub struct TypeChecker {
199 diagnostics: Vec<TypeDiagnostic>,
200 scope: TypeScope,
201}
202
203impl TypeChecker {
204 pub fn new() -> Self {
205 Self {
206 diagnostics: Vec::new(),
207 scope: TypeScope::new(),
208 }
209 }
210
211 pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
213 Self::register_declarations_into(&mut self.scope, program);
215
216 for snode in program {
218 if let Node::Pipeline { body, .. } = &snode.node {
219 Self::register_declarations_into(&mut self.scope, body);
220 }
221 }
222
223 for snode in program {
225 match &snode.node {
226 Node::Pipeline { params, body, .. } => {
227 let mut child = self.scope.child();
228 for p in params {
229 child.define_var(p, None);
230 }
231 self.check_block(body, &mut child);
232 }
233 Node::FnDecl {
234 name,
235 params,
236 return_type,
237 body,
238 ..
239 } => {
240 let sig = FnSignature {
241 params: params
242 .iter()
243 .map(|p| (p.name.clone(), p.type_expr.clone()))
244 .collect(),
245 return_type: return_type.clone(),
246 };
247 self.scope.define_fn(name, sig);
248 self.check_fn_body(params, return_type, body);
249 }
250 _ => {
251 let mut scope = self.scope.clone();
252 self.check_node(snode, &mut scope);
253 for (name, ty) in scope.vars {
255 self.scope.vars.entry(name).or_insert(ty);
256 }
257 }
258 }
259 }
260
261 self.diagnostics
262 }
263
264 fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
266 for snode in nodes {
267 match &snode.node {
268 Node::TypeDecl { name, type_expr } => {
269 scope.type_aliases.insert(name.clone(), type_expr.clone());
270 }
271 Node::EnumDecl { name, variants } => {
272 let variant_names: Vec<String> =
273 variants.iter().map(|v| v.name.clone()).collect();
274 scope.enums.insert(name.clone(), variant_names);
275 }
276 Node::InterfaceDecl { name, methods } => {
277 scope.interfaces.insert(name.clone(), methods.clone());
278 }
279 Node::StructDecl { name, fields } => {
280 let field_types: Vec<(String, InferredType)> = fields
281 .iter()
282 .map(|f| (f.name.clone(), f.type_expr.clone()))
283 .collect();
284 scope.structs.insert(name.clone(), field_types);
285 }
286 _ => {}
287 }
288 }
289 }
290
291 fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
292 for stmt in stmts {
293 self.check_node(stmt, scope);
294 }
295 }
296
297 fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
298 let span = snode.span;
299 match &snode.node {
300 Node::LetBinding {
301 name,
302 type_ann,
303 value,
304 } => {
305 let inferred = self.infer_type(value, scope);
306 if let Some(expected) = type_ann {
307 if let Some(actual) = &inferred {
308 if !self.types_compatible(expected, actual, scope) {
309 self.error_at(
310 format!(
311 "Type mismatch: '{}' declared as {}, but assigned {}",
312 name,
313 format_type(expected),
314 format_type(actual)
315 ),
316 span,
317 );
318 }
319 }
320 }
321 let ty = type_ann.clone().or(inferred);
322 scope.define_var(name, ty);
323 }
324
325 Node::VarBinding {
326 name,
327 type_ann,
328 value,
329 } => {
330 let inferred = self.infer_type(value, scope);
331 if let Some(expected) = type_ann {
332 if let Some(actual) = &inferred {
333 if !self.types_compatible(expected, actual, scope) {
334 self.error_at(
335 format!(
336 "Type mismatch: '{}' declared as {}, but assigned {}",
337 name,
338 format_type(expected),
339 format_type(actual)
340 ),
341 span,
342 );
343 }
344 }
345 }
346 let ty = type_ann.clone().or(inferred);
347 scope.define_var(name, ty);
348 }
349
350 Node::FnDecl {
351 name,
352 params,
353 return_type,
354 body,
355 ..
356 } => {
357 let sig = FnSignature {
358 params: params
359 .iter()
360 .map(|p| (p.name.clone(), p.type_expr.clone()))
361 .collect(),
362 return_type: return_type.clone(),
363 };
364 scope.define_fn(name, sig.clone());
365 scope.define_var(name, None);
366 self.check_fn_body(params, return_type, body);
367 }
368
369 Node::FunctionCall { name, args } => {
370 self.check_call(name, args, scope, span);
371 }
372
373 Node::IfElse {
374 condition,
375 then_body,
376 else_body,
377 } => {
378 self.check_node(condition, scope);
379 let mut then_scope = scope.child();
380 self.check_block(then_body, &mut then_scope);
381 if let Some(else_body) = else_body {
382 let mut else_scope = scope.child();
383 self.check_block(else_body, &mut else_scope);
384 }
385 }
386
387 Node::ForIn {
388 variable,
389 iterable,
390 body,
391 } => {
392 self.check_node(iterable, scope);
393 let mut loop_scope = scope.child();
394 let elem_type = match self.infer_type(iterable, scope) {
396 Some(TypeExpr::List(inner)) => Some(*inner),
397 Some(TypeExpr::Named(n)) if n == "string" => {
398 Some(TypeExpr::Named("string".into()))
399 }
400 _ => None,
401 };
402 loop_scope.define_var(variable, elem_type);
403 self.check_block(body, &mut loop_scope);
404 }
405
406 Node::WhileLoop { condition, body } => {
407 self.check_node(condition, scope);
408 let mut loop_scope = scope.child();
409 self.check_block(body, &mut loop_scope);
410 }
411
412 Node::TryCatch {
413 body,
414 error_var,
415 catch_body,
416 ..
417 } => {
418 let mut try_scope = scope.child();
419 self.check_block(body, &mut try_scope);
420 let mut catch_scope = scope.child();
421 if let Some(var) = error_var {
422 catch_scope.define_var(var, None);
423 }
424 self.check_block(catch_body, &mut catch_scope);
425 }
426
427 Node::ReturnStmt {
428 value: Some(val), ..
429 } => {
430 self.check_node(val, scope);
431 }
432
433 Node::Assignment {
434 target, value, op, ..
435 } => {
436 self.check_node(value, scope);
437 if let Node::Identifier(name) = &target.node {
438 if let Some(Some(var_type)) = scope.get_var(name) {
439 let value_type = self.infer_type(value, scope);
440 let assigned = if let Some(op) = op {
441 let var_inferred = scope.get_var(name).cloned().flatten();
442 infer_binary_op_type(op, &var_inferred, &value_type)
443 } else {
444 value_type
445 };
446 if let Some(actual) = &assigned {
447 if !self.types_compatible(var_type, actual, scope) {
448 self.error_at(
449 format!(
450 "Type mismatch: cannot assign {} to '{}' (declared as {})",
451 format_type(actual),
452 name,
453 format_type(var_type)
454 ),
455 span,
456 );
457 }
458 }
459 }
460 }
461 }
462
463 Node::TypeDecl { name, type_expr } => {
464 scope.type_aliases.insert(name.clone(), type_expr.clone());
465 }
466
467 Node::EnumDecl { name, variants } => {
468 let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
469 scope.enums.insert(name.clone(), variant_names);
470 }
471
472 Node::StructDecl { name, fields } => {
473 let field_types: Vec<(String, InferredType)> = fields
474 .iter()
475 .map(|f| (f.name.clone(), f.type_expr.clone()))
476 .collect();
477 scope.structs.insert(name.clone(), field_types);
478 }
479
480 Node::InterfaceDecl { name, methods } => {
481 scope.interfaces.insert(name.clone(), methods.clone());
482 }
483
484 Node::MatchExpr { value, arms } => {
485 self.check_node(value, scope);
486 for arm in arms {
487 self.check_node(&arm.pattern, scope);
488 let mut arm_scope = scope.child();
489 self.check_block(&arm.body, &mut arm_scope);
490 }
491 self.check_match_exhaustiveness(value, arms, scope, span);
492 }
493
494 Node::BinaryOp { op, left, right } => {
496 self.check_node(left, scope);
497 self.check_node(right, scope);
498 let lt = self.infer_type(left, scope);
500 let rt = self.infer_type(right, scope);
501 if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (<, &rt) {
502 match op.as_str() {
503 "-" | "*" | "/" | "%" => {
504 let numeric = ["int", "float"];
505 if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
506 self.warning_at(
507 format!(
508 "Operator '{op}' may not be valid for types {} and {}",
509 l, r
510 ),
511 span,
512 );
513 }
514 }
515 "+" => {
516 let valid = ["int", "float", "string", "list", "dict"];
518 if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
519 self.warning_at(
520 format!(
521 "Operator '+' may not be valid for types {} and {}",
522 l, r
523 ),
524 span,
525 );
526 }
527 }
528 _ => {}
529 }
530 }
531 }
532 Node::UnaryOp { operand, .. } => {
533 self.check_node(operand, scope);
534 }
535 Node::MethodCall { object, args, .. }
536 | Node::OptionalMethodCall { object, args, .. } => {
537 self.check_node(object, scope);
538 for arg in args {
539 self.check_node(arg, scope);
540 }
541 }
542 Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
543 self.check_node(object, scope);
544 }
545 Node::SubscriptAccess { object, index } => {
546 self.check_node(object, scope);
547 self.check_node(index, scope);
548 }
549 Node::SliceAccess { object, start, end } => {
550 self.check_node(object, scope);
551 if let Some(s) = start {
552 self.check_node(s, scope);
553 }
554 if let Some(e) = end {
555 self.check_node(e, scope);
556 }
557 }
558
559 _ => {}
561 }
562 }
563
564 fn check_fn_body(
565 &mut self,
566 params: &[TypedParam],
567 return_type: &Option<TypeExpr>,
568 body: &[SNode],
569 ) {
570 let mut fn_scope = self.scope.child();
571 for param in params {
572 fn_scope.define_var(¶m.name, param.type_expr.clone());
573 }
574 self.check_block(body, &mut fn_scope);
575
576 if let Some(ret_type) = return_type {
578 for stmt in body {
579 self.check_return_type(stmt, ret_type, &fn_scope);
580 }
581 }
582 }
583
584 fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
585 let span = snode.span;
586 match &snode.node {
587 Node::ReturnStmt { value: Some(val) } => {
588 let inferred = self.infer_type(val, scope);
589 if let Some(actual) = &inferred {
590 if !self.types_compatible(expected, actual, scope) {
591 self.error_at(
592 format!(
593 "Return type mismatch: expected {}, got {}",
594 format_type(expected),
595 format_type(actual)
596 ),
597 span,
598 );
599 }
600 }
601 }
602 Node::IfElse {
603 then_body,
604 else_body,
605 ..
606 } => {
607 for stmt in then_body {
608 self.check_return_type(stmt, expected, scope);
609 }
610 if let Some(else_body) = else_body {
611 for stmt in else_body {
612 self.check_return_type(stmt, expected, scope);
613 }
614 }
615 }
616 _ => {}
617 }
618 }
619
620 fn check_match_exhaustiveness(
622 &mut self,
623 value: &SNode,
624 arms: &[MatchArm],
625 scope: &TypeScope,
626 span: Span,
627 ) {
628 let enum_name = match &value.node {
630 Node::PropertyAccess { object, property } if property == "variant" => {
631 match self.infer_type(object, scope) {
633 Some(TypeExpr::Named(name)) => {
634 if scope.get_enum(&name).is_some() {
635 Some(name)
636 } else {
637 None
638 }
639 }
640 _ => None,
641 }
642 }
643 _ => {
644 match self.infer_type(value, scope) {
646 Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
647 _ => None,
648 }
649 }
650 };
651
652 let Some(enum_name) = enum_name else {
653 return;
654 };
655 let Some(variants) = scope.get_enum(&enum_name) else {
656 return;
657 };
658
659 let mut covered: Vec<String> = Vec::new();
661 let mut has_wildcard = false;
662
663 for arm in arms {
664 match &arm.pattern.node {
665 Node::StringLiteral(s) => covered.push(s.clone()),
667 Node::Identifier(name) if name == "_" || !variants.contains(name) => {
669 has_wildcard = true;
670 }
671 Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
673 Node::PropertyAccess { property, .. } => covered.push(property.clone()),
675 _ => {
676 has_wildcard = true;
678 }
679 }
680 }
681
682 if has_wildcard {
683 return;
684 }
685
686 let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
687 if !missing.is_empty() {
688 let missing_str = missing
689 .iter()
690 .map(|s| format!("\"{}\"", s))
691 .collect::<Vec<_>>()
692 .join(", ");
693 self.warning_at(
694 format!(
695 "Non-exhaustive match on enum {}: missing variants {}",
696 enum_name, missing_str
697 ),
698 span,
699 );
700 }
701 }
702
703 fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
704 if let Some(sig) = scope.get_fn(name).cloned() {
706 if args.len() != sig.params.len() && !is_builtin(name) {
707 self.warning_at(
708 format!(
709 "Function '{}' expects {} arguments, got {}",
710 name,
711 sig.params.len(),
712 args.len()
713 ),
714 span,
715 );
716 }
717 for (i, (arg, (param_name, param_type))) in
718 args.iter().zip(sig.params.iter()).enumerate()
719 {
720 if let Some(expected) = param_type {
721 let actual = self.infer_type(arg, scope);
722 if let Some(actual) = &actual {
723 if !self.types_compatible(expected, actual, scope) {
724 self.error_at(
725 format!(
726 "Argument {} ('{}'): expected {}, got {}",
727 i + 1,
728 param_name,
729 format_type(expected),
730 format_type(actual)
731 ),
732 arg.span,
733 );
734 }
735 }
736 }
737 }
738 }
739 for arg in args {
741 self.check_node(arg, scope);
742 }
743 }
744
745 fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
747 match &snode.node {
748 Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
749 Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
750 Node::StringLiteral(_) | Node::InterpolatedString(_) => {
751 Some(TypeExpr::Named("string".into()))
752 }
753 Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
754 Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
755 Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
756 Node::DictLiteral(entries) => {
757 let mut fields = Vec::new();
759 let mut all_string_keys = true;
760 for entry in entries {
761 if let Node::StringLiteral(key) = &entry.key.node {
762 let val_type = self
763 .infer_type(&entry.value, scope)
764 .unwrap_or(TypeExpr::Named("nil".into()));
765 fields.push(ShapeField {
766 name: key.clone(),
767 type_expr: val_type,
768 optional: false,
769 });
770 } else {
771 all_string_keys = false;
772 break;
773 }
774 }
775 if all_string_keys && !fields.is_empty() {
776 Some(TypeExpr::Shape(fields))
777 } else {
778 Some(TypeExpr::Named("dict".into()))
779 }
780 }
781 Node::Closure { .. } => Some(TypeExpr::Named("closure".into())),
782
783 Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
784
785 Node::FunctionCall { name, .. } => {
786 if let Some(sig) = scope.get_fn(name) {
788 return sig.return_type.clone();
789 }
790 builtin_return_type(name)
792 }
793
794 Node::BinaryOp { op, left, right } => {
795 let lt = self.infer_type(left, scope);
796 let rt = self.infer_type(right, scope);
797 infer_binary_op_type(op, <, &rt)
798 }
799
800 Node::UnaryOp { op, operand } => {
801 let t = self.infer_type(operand, scope);
802 match op.as_str() {
803 "!" => Some(TypeExpr::Named("bool".into())),
804 "-" => t, _ => None,
806 }
807 }
808
809 Node::Ternary {
810 true_expr,
811 false_expr,
812 ..
813 } => {
814 let tt = self.infer_type(true_expr, scope);
815 let ft = self.infer_type(false_expr, scope);
816 match (&tt, &ft) {
817 (Some(a), Some(b)) if a == b => tt,
818 (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
819 (Some(_), None) => tt,
820 (None, Some(_)) => ft,
821 (None, None) => None,
822 }
823 }
824
825 Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
826
827 Node::PropertyAccess { object, property } => {
828 if let Node::Identifier(name) = &object.node {
830 if scope.get_enum(name).is_some() {
831 return Some(TypeExpr::Named(name.clone()));
832 }
833 }
834 if property == "variant" {
836 let obj_type = self.infer_type(object, scope);
837 if let Some(TypeExpr::Named(name)) = &obj_type {
838 if scope.get_enum(name).is_some() {
839 return Some(TypeExpr::Named("string".into()));
840 }
841 }
842 }
843 let obj_type = self.infer_type(object, scope);
845 if let Some(TypeExpr::Shape(fields)) = &obj_type {
846 if let Some(field) = fields.iter().find(|f| f.name == *property) {
847 return Some(field.type_expr.clone());
848 }
849 }
850 None
851 }
852
853 Node::SubscriptAccess { object, index } => {
854 let obj_type = self.infer_type(object, scope);
855 match &obj_type {
856 Some(TypeExpr::List(inner)) => Some(*inner.clone()),
857 Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
858 Some(TypeExpr::Shape(fields)) => {
859 if let Node::StringLiteral(key) = &index.node {
861 fields
862 .iter()
863 .find(|f| &f.name == key)
864 .map(|f| f.type_expr.clone())
865 } else {
866 None
867 }
868 }
869 Some(TypeExpr::Named(n)) if n == "list" => None,
870 Some(TypeExpr::Named(n)) if n == "dict" => None,
871 Some(TypeExpr::Named(n)) if n == "string" => {
872 Some(TypeExpr::Named("string".into()))
873 }
874 _ => None,
875 }
876 }
877 Node::SliceAccess { object, .. } => {
878 let obj_type = self.infer_type(object, scope);
880 match &obj_type {
881 Some(TypeExpr::List(_)) => obj_type,
882 Some(TypeExpr::Named(n)) if n == "list" => obj_type,
883 Some(TypeExpr::Named(n)) if n == "string" => {
884 Some(TypeExpr::Named("string".into()))
885 }
886 _ => None,
887 }
888 }
889 Node::MethodCall { object, method, .. }
890 | Node::OptionalMethodCall { object, method, .. } => {
891 let obj_type = self.infer_type(object, scope);
892 let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
893 || matches!(&obj_type, Some(TypeExpr::DictType(..)));
894 match method.as_str() {
895 "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
897 Some(TypeExpr::Named("bool".into()))
898 }
899 "count" | "index_of" => Some(TypeExpr::Named("int".into())),
901 "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
903 | "pad_left" | "pad_right" | "repeat" | "join" => {
904 Some(TypeExpr::Named("string".into()))
905 }
906 "split" | "chars" => Some(TypeExpr::Named("list".into())),
907 "filter" => {
909 if is_dict {
910 Some(TypeExpr::Named("dict".into()))
911 } else {
912 Some(TypeExpr::Named("list".into()))
913 }
914 }
915 "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
917 "reduce" | "find" | "first" | "last" => None,
918 "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
920 "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
921 "to_string" => Some(TypeExpr::Named("string".into())),
923 "to_int" => Some(TypeExpr::Named("int".into())),
924 "to_float" => Some(TypeExpr::Named("float".into())),
925 _ => None,
926 }
927 }
928
929 _ => None,
930 }
931 }
932
933 fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
935 let expected = self.resolve_alias(expected, scope);
936 let actual = self.resolve_alias(actual, scope);
937
938 match (&expected, &actual) {
939 (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
940 (TypeExpr::Union(members), actual_type) => members
941 .iter()
942 .any(|m| self.types_compatible(m, actual_type, scope)),
943 (expected_type, TypeExpr::Union(members)) => members
944 .iter()
945 .all(|m| self.types_compatible(expected_type, m, scope)),
946 (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
947 (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
948 (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
949 if expected_field.optional {
950 return true;
951 }
952 af.iter().any(|actual_field| {
953 actual_field.name == expected_field.name
954 && self.types_compatible(
955 &expected_field.type_expr,
956 &actual_field.type_expr,
957 scope,
958 )
959 })
960 }),
961 (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
963 let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
964 keys_ok
965 && af
966 .iter()
967 .all(|f| self.types_compatible(ev, &f.type_expr, scope))
968 }
969 (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
971 (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
972 self.types_compatible(expected_inner, actual_inner, scope)
973 }
974 (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
975 (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
976 (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
977 self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
978 }
979 (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
980 (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
981 _ => false,
982 }
983 }
984
985 fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
986 if let TypeExpr::Named(name) = ty {
987 if let Some(resolved) = scope.resolve_type(name) {
988 return resolved.clone();
989 }
990 }
991 ty.clone()
992 }
993
994 fn error_at(&mut self, message: String, span: Span) {
995 self.diagnostics.push(TypeDiagnostic {
996 message,
997 severity: DiagnosticSeverity::Error,
998 span: Some(span),
999 });
1000 }
1001
1002 fn warning_at(&mut self, message: String, span: Span) {
1003 self.diagnostics.push(TypeDiagnostic {
1004 message,
1005 severity: DiagnosticSeverity::Warning,
1006 span: Some(span),
1007 });
1008 }
1009}
1010
1011impl Default for TypeChecker {
1012 fn default() -> Self {
1013 Self::new()
1014 }
1015}
1016
1017fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1019 match op {
1020 "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1021 "+" => match (left, right) {
1022 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1023 match (l.as_str(), r.as_str()) {
1024 ("int", "int") => Some(TypeExpr::Named("int".into())),
1025 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1026 ("string", _) => Some(TypeExpr::Named("string".into())),
1027 ("list", "list") => Some(TypeExpr::Named("list".into())),
1028 ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1029 _ => Some(TypeExpr::Named("string".into())),
1030 }
1031 }
1032 _ => None,
1033 },
1034 "-" | "*" | "/" | "%" => match (left, right) {
1035 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1036 match (l.as_str(), r.as_str()) {
1037 ("int", "int") => Some(TypeExpr::Named("int".into())),
1038 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1039 _ => None,
1040 }
1041 }
1042 _ => None,
1043 },
1044 "??" => match (left, right) {
1045 (Some(TypeExpr::Union(members)), _) => {
1046 let non_nil: Vec<_> = members
1047 .iter()
1048 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1049 .cloned()
1050 .collect();
1051 if non_nil.len() == 1 {
1052 Some(non_nil[0].clone())
1053 } else if non_nil.is_empty() {
1054 right.clone()
1055 } else {
1056 Some(TypeExpr::Union(non_nil))
1057 }
1058 }
1059 _ => right.clone(),
1060 },
1061 "|>" => None,
1062 _ => None,
1063 }
1064}
1065
1066pub fn format_type(ty: &TypeExpr) -> String {
1068 match ty {
1069 TypeExpr::Named(n) => n.clone(),
1070 TypeExpr::Union(types) => types
1071 .iter()
1072 .map(format_type)
1073 .collect::<Vec<_>>()
1074 .join(" | "),
1075 TypeExpr::Shape(fields) => {
1076 let inner: Vec<String> = fields
1077 .iter()
1078 .map(|f| {
1079 let opt = if f.optional { "?" } else { "" };
1080 format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1081 })
1082 .collect();
1083 format!("{{{}}}", inner.join(", "))
1084 }
1085 TypeExpr::List(inner) => format!("list[{}]", format_type(inner)),
1086 TypeExpr::DictType(k, v) => format!("dict[{}, {}]", format_type(k), format_type(v)),
1087 }
1088}
1089
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    /// Lexes, parses and type-checks `source`, returning all diagnostics.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut lexer = Lexer::new(source);
        let tokens = lexer.tokenize().unwrap();
        let mut parser = Parser::new(tokens);
        let program = parser.parse().unwrap();
        TypeChecker::new().check(&program)
    }

    /// Messages of all error-severity diagnostics for `source`.
    fn errors(source: &str) -> Vec<String> {
        check_source(source)
            .into_iter()
            .filter(|d| d.severity == DiagnosticSeverity::Error)
            .map(|d| d.message)
            .collect()
    }

    #[test]
    fn test_no_errors_for_untyped_code() {
        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_correct_typed_let() {
        let errs = errors("pipeline t(task) { let x: int = 42 }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_type_mismatch_let() {
        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Type mismatch"));
        assert!(errs[0].contains("int"));
        assert!(errs[0].contains("string"));
    }

    #[test]
    fn test_correct_typed_fn() {
        let errs = errors(
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let errs = errors(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Argument 1"));
        assert!(errs[0].contains("expected int"));
    }

    #[test]
    fn test_return_type_mismatch() {
        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Return type mismatch"));
    }

    #[test]
    fn test_union_type_compatible() {
        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
        assert!(errs.is_empty());
    }

    #[test]
    fn test_union_type_mismatch() {
        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Type mismatch"));
    }

    #[test]
    fn test_type_inference_propagation() {
        let errs = errors(
            r#"pipeline t(task) {
  fn add(a: int, b: int) -> int { return a + b }
  let result: string = add(1, 2)
}"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Type mismatch"));
        assert!(errs[0].contains("string"));
        assert!(errs[0].contains("int"));
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("string"));
        assert!(errs[0].contains("int"));
    }

    #[test]
    fn test_binary_op_type_inference() {
        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
        assert_eq!(errs.len(), 1);
    }

    #[test]
    fn test_comparison_returns_bool() {
        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_int_float_promotion() {
        let errs = errors("pipeline t(task) { let x: float = 42 }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_untyped_code_no_errors() {
        let errs = errors(
            r#"pipeline t(task) {
  fn process(data) {
    let result = data + " processed"
    return result
  }
  log(process("hello"))
}"#,
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_type_alias() {
        let errs = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = "hello"
}"#,
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_type_alias_mismatch() {
        let errs = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = 42
}"#,
        );
        assert_eq!(errs.len(), 1);
    }

    #[test]
    fn test_assignment_type_check() {
        let errs = errors(
            r#"pipeline t(task) {
  var x: int = 0
  x = "hello"
}"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("cannot assign string"));
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        let errs = errors(
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_covariance_return_type() {
        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
        assert_eq!(errs.len(), 1);
    }

    /// Messages of all warning-severity diagnostics for `source`.
    fn warnings(source: &str) -> Vec<String> {
        check_source(source)
            .into_iter()
            .filter(|d| d.severity == DiagnosticSeverity::Warning)
            .map(|d| d.message)
            .collect()
    }

    #[test]
    fn test_exhaustive_match_no_warning() {
        let warns = warnings(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
    "Blue" -> { log("b") }
  }
}"#,
        );
        let exhaustive_warns: Vec<_> = warns
            .iter()
            .filter(|w| w.contains("Non-exhaustive"))
            .collect();
        assert!(exhaustive_warns.is_empty());
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let warns = warnings(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
  }
}"#,
        );
        let exhaustive_warns: Vec<_> = warns
            .iter()
            .filter(|w| w.contains("Non-exhaustive"))
            .collect();
        assert_eq!(exhaustive_warns.len(), 1);
        assert!(exhaustive_warns[0].contains("Blue"));
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let warns = warnings(
            r#"pipeline t(task) {
  enum Status { Active, Inactive, Pending }
  let s = Status.Active
  match s.variant {
    "Active" -> { log("a") }
  }
}"#,
        );
        let exhaustive_warns: Vec<_> = warns
            .iter()
            .filter(|w| w.contains("Non-exhaustive"))
            .collect();
        assert_eq!(exhaustive_warns.len(), 1);
        assert!(exhaustive_warns[0].contains("Inactive"));
        assert!(exhaustive_warns[0].contains("Pending"));
    }

    #[test]
    fn test_enum_construct_type_inference() {
        let errs = errors(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c: Color = Color.Red
}"#,
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_nil_coalescing_strips_nil() {
        let errs = errors(
            r#"pipeline t(task) {
  let x: string | nil = nil
  let y: string = x ?? "default"
}"#,
        );
        assert!(errs.is_empty());
    }
}