1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9 pub message: String,
10 pub severity: DiagnosticSeverity,
11 pub span: Option<Span>,
12}
13
/// Severity attached to a [`TypeDiagnostic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// A definite type error.
    Error,
    /// A suspicious construct that may still be valid.
    Warning,
}
19
20type InferredType = Option<TypeExpr>;
22
23#[derive(Debug, Clone)]
25struct TypeScope {
26 vars: BTreeMap<String, InferredType>,
28 functions: BTreeMap<String, FnSignature>,
30 type_aliases: BTreeMap<String, TypeExpr>,
32 enums: BTreeMap<String, Vec<String>>,
34 interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36 structs: BTreeMap<String, Vec<(String, InferredType)>>,
38 impl_methods: BTreeMap<String, Vec<ImplMethodSig>>,
40 generic_type_params: std::collections::BTreeSet<String>,
42 parent: Option<Box<TypeScope>>,
43}
44
45#[derive(Debug, Clone)]
47struct ImplMethodSig {
48 name: String,
49 param_count: usize,
51 param_types: Vec<Option<TypeExpr>>,
53 return_type: Option<TypeExpr>,
55}
56
57#[derive(Debug, Clone)]
58struct FnSignature {
59 params: Vec<(String, InferredType)>,
60 return_type: InferredType,
61 type_param_names: Vec<String>,
63 required_params: usize,
65 where_clauses: Vec<(String, String)>,
67}
68
69impl TypeScope {
70 fn new() -> Self {
71 Self {
72 vars: BTreeMap::new(),
73 functions: BTreeMap::new(),
74 type_aliases: BTreeMap::new(),
75 enums: BTreeMap::new(),
76 interfaces: BTreeMap::new(),
77 structs: BTreeMap::new(),
78 impl_methods: BTreeMap::new(),
79 generic_type_params: std::collections::BTreeSet::new(),
80 parent: None,
81 }
82 }
83
84 fn child(&self) -> Self {
85 Self {
86 vars: BTreeMap::new(),
87 functions: BTreeMap::new(),
88 type_aliases: BTreeMap::new(),
89 enums: BTreeMap::new(),
90 interfaces: BTreeMap::new(),
91 structs: BTreeMap::new(),
92 impl_methods: BTreeMap::new(),
93 generic_type_params: std::collections::BTreeSet::new(),
94 parent: Some(Box::new(self.clone())),
95 }
96 }
97
98 fn get_var(&self, name: &str) -> Option<&InferredType> {
99 self.vars
100 .get(name)
101 .or_else(|| self.parent.as_ref()?.get_var(name))
102 }
103
104 fn get_fn(&self, name: &str) -> Option<&FnSignature> {
105 self.functions
106 .get(name)
107 .or_else(|| self.parent.as_ref()?.get_fn(name))
108 }
109
110 fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
111 self.type_aliases
112 .get(name)
113 .or_else(|| self.parent.as_ref()?.resolve_type(name))
114 }
115
116 fn is_generic_type_param(&self, name: &str) -> bool {
117 self.generic_type_params.contains(name)
118 || self
119 .parent
120 .as_ref()
121 .is_some_and(|p| p.is_generic_type_param(name))
122 }
123
124 fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
125 self.enums
126 .get(name)
127 .or_else(|| self.parent.as_ref()?.get_enum(name))
128 }
129
130 fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
131 self.interfaces
132 .get(name)
133 .or_else(|| self.parent.as_ref()?.get_interface(name))
134 }
135
136 fn get_struct(&self, name: &str) -> Option<&Vec<(String, InferredType)>> {
137 self.structs
138 .get(name)
139 .or_else(|| self.parent.as_ref()?.get_struct(name))
140 }
141
142 fn get_impl_methods(&self, name: &str) -> Option<&Vec<ImplMethodSig>> {
143 self.impl_methods
144 .get(name)
145 .or_else(|| self.parent.as_ref()?.get_impl_methods(name))
146 }
147
148 fn define_var(&mut self, name: &str, ty: InferredType) {
149 self.vars.insert(name.to_string(), ty);
150 }
151
152 fn define_fn(&mut self, name: &str, sig: FnSignature) {
153 self.functions.insert(name.to_string(), sig);
154 }
155}
156
157fn builtin_return_type(name: &str) -> InferredType {
159 match name {
160 "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
161 | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
162 Some(TypeExpr::Named("nil".into()))
163 }
164 "type_of"
165 | "to_string"
166 | "json_stringify"
167 | "read_file"
168 | "http_get"
169 | "http_post"
170 | "llm_call"
171 | "regex_replace"
172 | "path_join"
173 | "temp_dir"
174 | "date_format"
175 | "format"
176 | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
177 "to_int" | "timer_end" | "elapsed" | "sign" => Some(TypeExpr::Named("int".into())),
178 "to_float" | "timestamp" | "date_parse" | "sin" | "cos" | "tan" | "asin" | "acos"
179 | "atan" | "atan2" | "log2" | "log10" | "ln" | "exp" | "pi" | "e" => {
180 Some(TypeExpr::Named("float".into()))
181 }
182 "file_exists" | "json_validate" | "is_nan" | "is_infinite" | "set_contains" => {
183 Some(TypeExpr::Named("bool".into()))
184 }
185 "list_dir" | "mcp_list_tools" | "mcp_list_resources" | "mcp_list_prompts" | "to_list"
186 | "regex_captures" => Some(TypeExpr::Named("list".into())),
187 "stat" | "exec" | "shell" | "date_now" | "agent_loop" | "llm_info" | "llm_usage"
188 | "timer_start" | "metadata_get" | "mcp_server_info" | "mcp_get_prompt" => {
189 Some(TypeExpr::Named("dict".into()))
190 }
191 "metadata_set"
192 | "metadata_save"
193 | "metadata_refresh_hashes"
194 | "invalidate_facts"
195 | "log_json"
196 | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
197 "env" | "regex_match" => Some(TypeExpr::Union(vec![
198 TypeExpr::Named("string".into()),
199 TypeExpr::Named("nil".into()),
200 ])),
201 "json_parse" | "json_extract" => None, _ => None,
203 }
204}
205
/// Whether `name` is one of the runtime's builtin functions.
///
/// `check_call` uses this to skip arity warnings for such names.
fn is_builtin(name: &str) -> bool {
    // Grouped by area; order is irrelevant to the lookup.
    const BUILTINS: &[&str] = &[
        // I/O, conversion and process control
        "log", "print", "println", "type_of", "to_string", "to_int", "to_float",
        "json_stringify", "json_parse", "env", "timestamp", "sleep", "read_file",
        "write_file", "exit",
        // regex, network, LLM, concurrency
        "regex_match", "regex_replace", "regex_captures", "http_get", "http_post",
        "llm_call", "agent_loop", "await", "cancel",
        // filesystem and processes
        "file_exists", "delete_file", "list_dir", "mkdir", "path_join", "copy_file",
        "append_file", "temp_dir", "stat", "exec", "shell",
        // dates, formatting, JSON helpers
        "date_now", "date_format", "date_parse", "format", "json_validate", "json_extract",
        // strings and paths
        "trim", "lowercase", "uppercase", "split", "starts_with", "ends_with", "contains",
        "replace", "join", "len", "substring", "dirname", "basename", "extname",
        // math
        "sin", "cos", "tan", "asin", "acos", "atan", "atan2", "log2", "log10", "ln", "exp",
        "pi", "e", "sign", "is_nan", "is_infinite",
        // sets
        "set", "set_add", "set_remove", "set_contains", "set_union", "set_intersect",
        "set_difference", "to_list",
    ];
    BUILTINS.contains(&name)
}
291
292pub struct TypeChecker {
294 diagnostics: Vec<TypeDiagnostic>,
295 scope: TypeScope,
296}
297
298impl TypeChecker {
299 pub fn new() -> Self {
300 Self {
301 diagnostics: Vec::new(),
302 scope: TypeScope::new(),
303 }
304 }
305
306 pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
308 Self::register_declarations_into(&mut self.scope, program);
310
311 for snode in program {
313 if let Node::Pipeline { body, .. } = &snode.node {
314 Self::register_declarations_into(&mut self.scope, body);
315 }
316 }
317
318 for snode in program {
320 match &snode.node {
321 Node::Pipeline { params, body, .. } => {
322 let mut child = self.scope.child();
323 for p in params {
324 child.define_var(p, None);
325 }
326 self.check_block(body, &mut child);
327 }
328 Node::FnDecl {
329 name,
330 type_params,
331 params,
332 return_type,
333 where_clauses,
334 body,
335 ..
336 } => {
337 let required_params =
338 params.iter().filter(|p| p.default_value.is_none()).count();
339 let sig = FnSignature {
340 params: params
341 .iter()
342 .map(|p| (p.name.clone(), p.type_expr.clone()))
343 .collect(),
344 return_type: return_type.clone(),
345 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
346 required_params,
347 where_clauses: where_clauses
348 .iter()
349 .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
350 .collect(),
351 };
352 self.scope.define_fn(name, sig);
353 self.check_fn_body(type_params, params, return_type, body);
354 }
355 _ => {
356 let mut scope = self.scope.clone();
357 self.check_node(snode, &mut scope);
358 for (name, ty) in scope.vars {
360 self.scope.vars.entry(name).or_insert(ty);
361 }
362 }
363 }
364 }
365
366 self.diagnostics
367 }
368
369 fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
371 for snode in nodes {
372 match &snode.node {
373 Node::TypeDecl { name, type_expr } => {
374 scope.type_aliases.insert(name.clone(), type_expr.clone());
375 }
376 Node::EnumDecl { name, variants } => {
377 let variant_names: Vec<String> =
378 variants.iter().map(|v| v.name.clone()).collect();
379 scope.enums.insert(name.clone(), variant_names);
380 }
381 Node::InterfaceDecl { name, methods } => {
382 scope.interfaces.insert(name.clone(), methods.clone());
383 }
384 Node::StructDecl { name, fields } => {
385 let field_types: Vec<(String, InferredType)> = fields
386 .iter()
387 .map(|f| (f.name.clone(), f.type_expr.clone()))
388 .collect();
389 scope.structs.insert(name.clone(), field_types);
390 }
391 Node::ImplBlock {
392 type_name, methods, ..
393 } => {
394 let sigs: Vec<ImplMethodSig> = methods
395 .iter()
396 .filter_map(|m| {
397 if let Node::FnDecl {
398 name,
399 params,
400 return_type,
401 ..
402 } = &m.node
403 {
404 let non_self: Vec<_> =
405 params.iter().filter(|p| p.name != "self").collect();
406 let param_count = non_self.len();
407 let param_types: Vec<Option<TypeExpr>> =
408 non_self.iter().map(|p| p.type_expr.clone()).collect();
409 Some(ImplMethodSig {
410 name: name.clone(),
411 param_count,
412 param_types,
413 return_type: return_type.clone(),
414 })
415 } else {
416 None
417 }
418 })
419 .collect();
420 scope.impl_methods.insert(type_name.clone(), sigs);
421 }
422 _ => {}
423 }
424 }
425 }
426
427 fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
428 for stmt in stmts {
429 self.check_node(stmt, scope);
430 }
431 }
432
433 fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
435 match pattern {
436 BindingPattern::Identifier(name) => {
437 scope.define_var(name, None);
438 }
439 BindingPattern::Dict(fields) => {
440 for field in fields {
441 let name = field.alias.as_deref().unwrap_or(&field.key);
442 scope.define_var(name, None);
443 }
444 }
445 BindingPattern::List(elements) => {
446 for elem in elements {
447 scope.define_var(&elem.name, None);
448 }
449 }
450 }
451 }
452
453 fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
454 let span = snode.span;
455 match &snode.node {
456 Node::LetBinding {
457 pattern,
458 type_ann,
459 value,
460 } => {
461 let inferred = self.infer_type(value, scope);
462 if let BindingPattern::Identifier(name) = pattern {
463 if let Some(expected) = type_ann {
464 if let Some(actual) = &inferred {
465 if !self.types_compatible(expected, actual, scope) {
466 let mut msg = format!(
467 "Type mismatch: '{}' declared as {}, but assigned {}",
468 name,
469 format_type(expected),
470 format_type(actual)
471 );
472 if let Some(detail) = shape_mismatch_detail(expected, actual) {
473 msg.push_str(&format!(" ({})", detail));
474 }
475 self.error_at(msg, span);
476 }
477 }
478 }
479 let ty = type_ann.clone().or(inferred);
480 scope.define_var(name, ty);
481 } else {
482 Self::define_pattern_vars(pattern, scope);
483 }
484 }
485
486 Node::VarBinding {
487 pattern,
488 type_ann,
489 value,
490 } => {
491 let inferred = self.infer_type(value, scope);
492 if let BindingPattern::Identifier(name) = pattern {
493 if let Some(expected) = type_ann {
494 if let Some(actual) = &inferred {
495 if !self.types_compatible(expected, actual, scope) {
496 let mut msg = format!(
497 "Type mismatch: '{}' declared as {}, but assigned {}",
498 name,
499 format_type(expected),
500 format_type(actual)
501 );
502 if let Some(detail) = shape_mismatch_detail(expected, actual) {
503 msg.push_str(&format!(" ({})", detail));
504 }
505 self.error_at(msg, span);
506 }
507 }
508 }
509 let ty = type_ann.clone().or(inferred);
510 scope.define_var(name, ty);
511 } else {
512 Self::define_pattern_vars(pattern, scope);
513 }
514 }
515
516 Node::FnDecl {
517 name,
518 type_params,
519 params,
520 return_type,
521 where_clauses,
522 body,
523 ..
524 } => {
525 let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
526 let sig = FnSignature {
527 params: params
528 .iter()
529 .map(|p| (p.name.clone(), p.type_expr.clone()))
530 .collect(),
531 return_type: return_type.clone(),
532 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
533 required_params,
534 where_clauses: where_clauses
535 .iter()
536 .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
537 .collect(),
538 };
539 scope.define_fn(name, sig.clone());
540 scope.define_var(name, None);
541 self.check_fn_body(type_params, params, return_type, body);
542 }
543
544 Node::FunctionCall { name, args } => {
545 self.check_call(name, args, scope, span);
546 }
547
548 Node::IfElse {
549 condition,
550 then_body,
551 else_body,
552 } => {
553 self.check_node(condition, scope);
554 let mut then_scope = scope.child();
555 if let Some((var_name, narrowed)) = Self::extract_nil_narrowing(condition, scope) {
557 then_scope.define_var(&var_name, narrowed);
558 }
559 self.check_block(then_body, &mut then_scope);
560 if let Some(else_body) = else_body {
561 let mut else_scope = scope.child();
562 self.check_block(else_body, &mut else_scope);
563 }
564 }
565
566 Node::ForIn {
567 pattern,
568 iterable,
569 body,
570 } => {
571 self.check_node(iterable, scope);
572 let mut loop_scope = scope.child();
573 if let BindingPattern::Identifier(variable) = pattern {
574 let elem_type = match self.infer_type(iterable, scope) {
576 Some(TypeExpr::List(inner)) => Some(*inner),
577 Some(TypeExpr::Named(n)) if n == "string" => {
578 Some(TypeExpr::Named("string".into()))
579 }
580 _ => None,
581 };
582 loop_scope.define_var(variable, elem_type);
583 } else {
584 Self::define_pattern_vars(pattern, &mut loop_scope);
585 }
586 self.check_block(body, &mut loop_scope);
587 }
588
589 Node::WhileLoop { condition, body } => {
590 self.check_node(condition, scope);
591 let mut loop_scope = scope.child();
592 self.check_block(body, &mut loop_scope);
593 }
594
595 Node::TryCatch {
596 body,
597 error_var,
598 catch_body,
599 finally_body,
600 ..
601 } => {
602 let mut try_scope = scope.child();
603 self.check_block(body, &mut try_scope);
604 let mut catch_scope = scope.child();
605 if let Some(var) = error_var {
606 catch_scope.define_var(var, None);
607 }
608 self.check_block(catch_body, &mut catch_scope);
609 if let Some(fb) = finally_body {
610 let mut finally_scope = scope.child();
611 self.check_block(fb, &mut finally_scope);
612 }
613 }
614
615 Node::TryExpr { body } => {
616 let mut try_scope = scope.child();
617 self.check_block(body, &mut try_scope);
618 }
619
620 Node::ReturnStmt {
621 value: Some(val), ..
622 } => {
623 self.check_node(val, scope);
624 }
625
626 Node::Assignment {
627 target, value, op, ..
628 } => {
629 self.check_node(value, scope);
630 if let Node::Identifier(name) = &target.node {
631 if let Some(Some(var_type)) = scope.get_var(name) {
632 let value_type = self.infer_type(value, scope);
633 let assigned = if let Some(op) = op {
634 let var_inferred = scope.get_var(name).cloned().flatten();
635 infer_binary_op_type(op, &var_inferred, &value_type)
636 } else {
637 value_type
638 };
639 if let Some(actual) = &assigned {
640 if !self.types_compatible(var_type, actual, scope) {
641 self.error_at(
642 format!(
643 "Type mismatch: cannot assign {} to '{}' (declared as {})",
644 format_type(actual),
645 name,
646 format_type(var_type)
647 ),
648 span,
649 );
650 }
651 }
652 }
653 }
654 }
655
656 Node::TypeDecl { name, type_expr } => {
657 scope.type_aliases.insert(name.clone(), type_expr.clone());
658 }
659
660 Node::EnumDecl { name, variants } => {
661 let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
662 scope.enums.insert(name.clone(), variant_names);
663 }
664
665 Node::StructDecl { name, fields } => {
666 let field_types: Vec<(String, InferredType)> = fields
667 .iter()
668 .map(|f| (f.name.clone(), f.type_expr.clone()))
669 .collect();
670 scope.structs.insert(name.clone(), field_types);
671 }
672
673 Node::InterfaceDecl { name, methods } => {
674 scope.interfaces.insert(name.clone(), methods.clone());
675 }
676
677 Node::ImplBlock {
678 type_name, methods, ..
679 } => {
680 let sigs: Vec<ImplMethodSig> = methods
682 .iter()
683 .filter_map(|m| {
684 if let Node::FnDecl {
685 name,
686 params,
687 return_type,
688 ..
689 } = &m.node
690 {
691 let non_self: Vec<_> =
692 params.iter().filter(|p| p.name != "self").collect();
693 let param_count = non_self.len();
694 let param_types: Vec<Option<TypeExpr>> =
695 non_self.iter().map(|p| p.type_expr.clone()).collect();
696 Some(ImplMethodSig {
697 name: name.clone(),
698 param_count,
699 param_types,
700 return_type: return_type.clone(),
701 })
702 } else {
703 None
704 }
705 })
706 .collect();
707 scope.impl_methods.insert(type_name.clone(), sigs);
708 for method_sn in methods {
709 self.check_node(method_sn, scope);
710 }
711 }
712
713 Node::TryOperator { operand } => {
714 self.check_node(operand, scope);
715 }
716
717 Node::MatchExpr { value, arms } => {
718 self.check_node(value, scope);
719 let value_type = self.infer_type(value, scope);
720 for arm in arms {
721 self.check_node(&arm.pattern, scope);
722 if let Some(ref vt) = value_type {
724 let value_type_name = format_type(vt);
725 let mismatch = match &arm.pattern.node {
726 Node::StringLiteral(_) => {
727 !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
728 }
729 Node::IntLiteral(_) => {
730 !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
731 && !self.types_compatible(
732 vt,
733 &TypeExpr::Named("float".into()),
734 scope,
735 )
736 }
737 Node::FloatLiteral(_) => {
738 !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
739 && !self.types_compatible(
740 vt,
741 &TypeExpr::Named("int".into()),
742 scope,
743 )
744 }
745 Node::BoolLiteral(_) => {
746 !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
747 }
748 _ => false,
749 };
750 if mismatch {
751 let pattern_type = match &arm.pattern.node {
752 Node::StringLiteral(_) => "string",
753 Node::IntLiteral(_) => "int",
754 Node::FloatLiteral(_) => "float",
755 Node::BoolLiteral(_) => "bool",
756 _ => unreachable!(),
757 };
758 self.warning_at(
759 format!(
760 "Match pattern type mismatch: matching {} against {} literal",
761 value_type_name, pattern_type
762 ),
763 arm.pattern.span,
764 );
765 }
766 }
767 let mut arm_scope = scope.child();
768 self.check_block(&arm.body, &mut arm_scope);
769 }
770 self.check_match_exhaustiveness(value, arms, scope, span);
771 }
772
773 Node::BinaryOp { op, left, right } => {
775 self.check_node(left, scope);
776 self.check_node(right, scope);
777 let lt = self.infer_type(left, scope);
779 let rt = self.infer_type(right, scope);
780 if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (<, &rt) {
781 match op.as_str() {
782 "-" | "*" | "/" | "%" => {
783 let numeric = ["int", "float"];
784 if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
785 self.warning_at(
786 format!(
787 "Operator '{op}' may not be valid for types {} and {}",
788 l, r
789 ),
790 span,
791 );
792 }
793 }
794 "+" => {
795 let valid = ["int", "float", "string", "list", "dict"];
797 if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
798 self.warning_at(
799 format!(
800 "Operator '+' may not be valid for types {} and {}",
801 l, r
802 ),
803 span,
804 );
805 }
806 }
807 _ => {}
808 }
809 }
810 }
811 Node::UnaryOp { operand, .. } => {
812 self.check_node(operand, scope);
813 }
814 Node::MethodCall { object, args, .. }
815 | Node::OptionalMethodCall { object, args, .. } => {
816 self.check_node(object, scope);
817 for arg in args {
818 self.check_node(arg, scope);
819 }
820 }
821 Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
822 self.check_node(object, scope);
823 }
824 Node::SubscriptAccess { object, index } => {
825 self.check_node(object, scope);
826 self.check_node(index, scope);
827 }
828 Node::SliceAccess { object, start, end } => {
829 self.check_node(object, scope);
830 if let Some(s) = start {
831 self.check_node(s, scope);
832 }
833 if let Some(e) = end {
834 self.check_node(e, scope);
835 }
836 }
837
838 _ => {}
840 }
841 }
842
843 fn check_fn_body(
844 &mut self,
845 type_params: &[TypeParam],
846 params: &[TypedParam],
847 return_type: &Option<TypeExpr>,
848 body: &[SNode],
849 ) {
850 let mut fn_scope = self.scope.child();
851 for tp in type_params {
854 fn_scope.generic_type_params.insert(tp.name.clone());
855 }
856 for param in params {
857 fn_scope.define_var(¶m.name, param.type_expr.clone());
858 if let Some(default) = ¶m.default_value {
859 self.check_node(default, &mut fn_scope);
860 }
861 }
862 self.check_block(body, &mut fn_scope);
863
864 if let Some(ret_type) = return_type {
866 for stmt in body {
867 self.check_return_type(stmt, ret_type, &fn_scope);
868 }
869 }
870 }
871
872 fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
873 let span = snode.span;
874 match &snode.node {
875 Node::ReturnStmt { value: Some(val) } => {
876 let inferred = self.infer_type(val, scope);
877 if let Some(actual) = &inferred {
878 if !self.types_compatible(expected, actual, scope) {
879 self.error_at(
880 format!(
881 "Return type mismatch: expected {}, got {}",
882 format_type(expected),
883 format_type(actual)
884 ),
885 span,
886 );
887 }
888 }
889 }
890 Node::IfElse {
891 then_body,
892 else_body,
893 ..
894 } => {
895 for stmt in then_body {
896 self.check_return_type(stmt, expected, scope);
897 }
898 if let Some(else_body) = else_body {
899 for stmt in else_body {
900 self.check_return_type(stmt, expected, scope);
901 }
902 }
903 }
904 _ => {}
905 }
906 }
907
908 fn satisfies_interface(
914 &self,
915 type_name: &str,
916 interface_name: &str,
917 scope: &TypeScope,
918 ) -> bool {
919 let interface_methods = match scope.get_interface(interface_name) {
920 Some(methods) => methods,
921 None => return false,
922 };
923 let impl_methods = match scope.get_impl_methods(type_name) {
924 Some(methods) => methods,
925 None => return interface_methods.is_empty(),
926 };
927 interface_methods.iter().all(|iface_method| {
928 let iface_params: Vec<_> = iface_method
929 .params
930 .iter()
931 .filter(|p| p.name != "self")
932 .collect();
933 let iface_param_count = iface_params.len();
934 impl_methods.iter().any(|impl_method| {
935 if impl_method.name != iface_method.name
936 || impl_method.param_count != iface_param_count
937 {
938 return false;
939 }
940 for (i, iface_param) in iface_params.iter().enumerate() {
942 if let (Some(expected), Some(actual)) = (
943 &iface_param.type_expr,
944 impl_method.param_types.get(i).and_then(|t| t.as_ref()),
945 ) {
946 if !self.types_compatible(expected, actual, scope) {
947 return false;
948 }
949 }
950 }
951 if let (Some(expected_ret), Some(actual_ret)) =
953 (&iface_method.return_type, &impl_method.return_type)
954 {
955 if !self.types_compatible(expected_ret, actual_ret, scope) {
956 return false;
957 }
958 }
959 true
960 })
961 })
962 }
963
964 fn extract_type_bindings(
967 param_type: &TypeExpr,
968 arg_type: &TypeExpr,
969 type_params: &std::collections::BTreeSet<String>,
970 bindings: &mut BTreeMap<String, String>,
971 ) {
972 match (param_type, arg_type) {
973 (TypeExpr::Named(param_name), TypeExpr::Named(concrete))
975 if type_params.contains(param_name) =>
976 {
977 bindings
978 .entry(param_name.clone())
979 .or_insert(concrete.clone());
980 }
981 (TypeExpr::List(p_inner), TypeExpr::List(a_inner)) => {
983 Self::extract_type_bindings(p_inner, a_inner, type_params, bindings);
984 }
985 (TypeExpr::DictType(pk, pv), TypeExpr::DictType(ak, av)) => {
987 Self::extract_type_bindings(pk, ak, type_params, bindings);
988 Self::extract_type_bindings(pv, av, type_params, bindings);
989 }
990 _ => {}
991 }
992 }
993
994 fn extract_nil_narrowing(
995 condition: &SNode,
996 scope: &TypeScope,
997 ) -> Option<(String, InferredType)> {
998 if let Node::BinaryOp { op, left, right } = &condition.node {
999 if op == "!=" {
1000 let (var_node, nil_node) = if matches!(right.node, Node::NilLiteral) {
1002 (left, right)
1003 } else if matches!(left.node, Node::NilLiteral) {
1004 (right, left)
1005 } else {
1006 return None;
1007 };
1008 let _ = nil_node;
1009 if let Node::Identifier(name) = &var_node.node {
1010 if let Some(Some(TypeExpr::Union(members))) = scope.get_var(name) {
1012 let narrowed: Vec<TypeExpr> = members
1013 .iter()
1014 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1015 .cloned()
1016 .collect();
1017 return if narrowed.len() == 1 {
1018 Some((name.clone(), Some(narrowed.into_iter().next().unwrap())))
1019 } else if narrowed.is_empty() {
1020 None
1021 } else {
1022 Some((name.clone(), Some(TypeExpr::Union(narrowed))))
1023 };
1024 }
1025 }
1026 }
1027 }
1028 None
1029 }
1030
1031 fn check_match_exhaustiveness(
1032 &mut self,
1033 value: &SNode,
1034 arms: &[MatchArm],
1035 scope: &TypeScope,
1036 span: Span,
1037 ) {
1038 let enum_name = match &value.node {
1040 Node::PropertyAccess { object, property } if property == "variant" => {
1041 match self.infer_type(object, scope) {
1043 Some(TypeExpr::Named(name)) => {
1044 if scope.get_enum(&name).is_some() {
1045 Some(name)
1046 } else {
1047 None
1048 }
1049 }
1050 _ => None,
1051 }
1052 }
1053 _ => {
1054 match self.infer_type(value, scope) {
1056 Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
1057 _ => None,
1058 }
1059 }
1060 };
1061
1062 let Some(enum_name) = enum_name else {
1063 return;
1064 };
1065 let Some(variants) = scope.get_enum(&enum_name) else {
1066 return;
1067 };
1068
1069 let mut covered: Vec<String> = Vec::new();
1071 let mut has_wildcard = false;
1072
1073 for arm in arms {
1074 match &arm.pattern.node {
1075 Node::StringLiteral(s) => covered.push(s.clone()),
1077 Node::Identifier(name) if name == "_" || !variants.contains(name) => {
1079 has_wildcard = true;
1080 }
1081 Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
1083 Node::PropertyAccess { property, .. } => covered.push(property.clone()),
1085 _ => {
1086 has_wildcard = true;
1088 }
1089 }
1090 }
1091
1092 if has_wildcard {
1093 return;
1094 }
1095
1096 let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
1097 if !missing.is_empty() {
1098 let missing_str = missing
1099 .iter()
1100 .map(|s| format!("\"{}\"", s))
1101 .collect::<Vec<_>>()
1102 .join(", ");
1103 self.warning_at(
1104 format!(
1105 "Non-exhaustive match on enum {}: missing variants {}",
1106 enum_name, missing_str
1107 ),
1108 span,
1109 );
1110 }
1111 }
1112
1113 fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
1114 let has_spread = args.iter().any(|a| matches!(&a.node, Node::Spread(_)));
1116 if let Some(sig) = scope.get_fn(name).cloned() {
1117 if !has_spread
1118 && !is_builtin(name)
1119 && (args.len() < sig.required_params || args.len() > sig.params.len())
1120 {
1121 let expected = if sig.required_params == sig.params.len() {
1122 format!("{}", sig.params.len())
1123 } else {
1124 format!("{}-{}", sig.required_params, sig.params.len())
1125 };
1126 self.warning_at(
1127 format!(
1128 "Function '{}' expects {} arguments, got {}",
1129 name,
1130 expected,
1131 args.len()
1132 ),
1133 span,
1134 );
1135 }
1136 let call_scope = if sig.type_param_names.is_empty() {
1139 scope.clone()
1140 } else {
1141 let mut s = scope.child();
1142 for tp_name in &sig.type_param_names {
1143 s.generic_type_params.insert(tp_name.clone());
1144 }
1145 s
1146 };
1147 for (i, (arg, (param_name, param_type))) in
1148 args.iter().zip(sig.params.iter()).enumerate()
1149 {
1150 if let Some(expected) = param_type {
1151 let actual = self.infer_type(arg, scope);
1152 if let Some(actual) = &actual {
1153 if !self.types_compatible(expected, actual, &call_scope) {
1154 self.error_at(
1155 format!(
1156 "Argument {} ('{}'): expected {}, got {}",
1157 i + 1,
1158 param_name,
1159 format_type(expected),
1160 format_type(actual)
1161 ),
1162 arg.span,
1163 );
1164 }
1165 }
1166 }
1167 }
1168 if !sig.where_clauses.is_empty() {
1170 let mut type_bindings: BTreeMap<String, String> = BTreeMap::new();
1173 let type_param_set: std::collections::BTreeSet<String> =
1174 sig.type_param_names.iter().cloned().collect();
1175 for (arg, (_param_name, param_type)) in args.iter().zip(sig.params.iter()) {
1176 if let Some(param_ty) = param_type {
1177 if let Some(arg_ty) = self.infer_type(arg, scope) {
1178 Self::extract_type_bindings(
1179 param_ty,
1180 &arg_ty,
1181 &type_param_set,
1182 &mut type_bindings,
1183 );
1184 }
1185 }
1186 }
1187 for (type_param, bound) in &sig.where_clauses {
1188 if let Some(concrete_type) = type_bindings.get(type_param) {
1189 if !self.satisfies_interface(concrete_type, bound, scope) {
1190 self.warning_at(
1191 format!(
1192 "Type '{}' does not satisfy interface '{}': \
1193 required by constraint `where {}: {}`",
1194 concrete_type, bound, type_param, bound
1195 ),
1196 span,
1197 );
1198 }
1199 }
1200 }
1201 }
1202 }
1203 for arg in args {
1205 self.check_node(arg, scope);
1206 }
1207 }
1208
1209 fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
1211 match &snode.node {
1212 Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
1213 Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
1214 Node::StringLiteral(_) | Node::InterpolatedString(_) => {
1215 Some(TypeExpr::Named("string".into()))
1216 }
1217 Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
1218 Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
1219 Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
1220 Node::DictLiteral(entries) => {
1221 let mut fields = Vec::new();
1223 let mut all_string_keys = true;
1224 for entry in entries {
1225 if let Node::StringLiteral(key) = &entry.key.node {
1226 let val_type = self
1227 .infer_type(&entry.value, scope)
1228 .unwrap_or(TypeExpr::Named("nil".into()));
1229 fields.push(ShapeField {
1230 name: key.clone(),
1231 type_expr: val_type,
1232 optional: false,
1233 });
1234 } else {
1235 all_string_keys = false;
1236 break;
1237 }
1238 }
1239 if all_string_keys && !fields.is_empty() {
1240 Some(TypeExpr::Shape(fields))
1241 } else {
1242 Some(TypeExpr::Named("dict".into()))
1243 }
1244 }
1245 Node::Closure { params, body } => {
1246 let all_typed = params.iter().all(|p| p.type_expr.is_some());
1248 if all_typed && !params.is_empty() {
1249 let param_types: Vec<TypeExpr> =
1250 params.iter().filter_map(|p| p.type_expr.clone()).collect();
1251 let ret = body.last().and_then(|last| self.infer_type(last, scope));
1253 if let Some(ret_type) = ret {
1254 return Some(TypeExpr::FnType {
1255 params: param_types,
1256 return_type: Box::new(ret_type),
1257 });
1258 }
1259 }
1260 Some(TypeExpr::Named("closure".into()))
1261 }
1262
1263 Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
1264
1265 Node::FunctionCall { name, .. } => {
1266 if scope.get_struct(name).is_some() {
1268 return Some(TypeExpr::Named(name.clone()));
1269 }
1270 if let Some(sig) = scope.get_fn(name) {
1272 return sig.return_type.clone();
1273 }
1274 builtin_return_type(name)
1276 }
1277
1278 Node::BinaryOp { op, left, right } => {
1279 let lt = self.infer_type(left, scope);
1280 let rt = self.infer_type(right, scope);
1281 infer_binary_op_type(op, <, &rt)
1282 }
1283
1284 Node::UnaryOp { op, operand } => {
1285 let t = self.infer_type(operand, scope);
1286 match op.as_str() {
1287 "!" => Some(TypeExpr::Named("bool".into())),
1288 "-" => t, _ => None,
1290 }
1291 }
1292
1293 Node::Ternary {
1294 true_expr,
1295 false_expr,
1296 ..
1297 } => {
1298 let tt = self.infer_type(true_expr, scope);
1299 let ft = self.infer_type(false_expr, scope);
1300 match (&tt, &ft) {
1301 (Some(a), Some(b)) if a == b => tt,
1302 (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
1303 (Some(_), None) => tt,
1304 (None, Some(_)) => ft,
1305 (None, None) => None,
1306 }
1307 }
1308
1309 Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1310
1311 Node::PropertyAccess { object, property } => {
1312 if let Node::Identifier(name) = &object.node {
1314 if scope.get_enum(name).is_some() {
1315 return Some(TypeExpr::Named(name.clone()));
1316 }
1317 }
1318 if property == "variant" {
1320 let obj_type = self.infer_type(object, scope);
1321 if let Some(TypeExpr::Named(name)) = &obj_type {
1322 if scope.get_enum(name).is_some() {
1323 return Some(TypeExpr::Named("string".into()));
1324 }
1325 }
1326 }
1327 let obj_type = self.infer_type(object, scope);
1329 if let Some(TypeExpr::Shape(fields)) = &obj_type {
1330 if let Some(field) = fields.iter().find(|f| f.name == *property) {
1331 return Some(field.type_expr.clone());
1332 }
1333 }
1334 None
1335 }
1336
1337 Node::SubscriptAccess { object, index } => {
1338 let obj_type = self.infer_type(object, scope);
1339 match &obj_type {
1340 Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1341 Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1342 Some(TypeExpr::Shape(fields)) => {
1343 if let Node::StringLiteral(key) = &index.node {
1345 fields
1346 .iter()
1347 .find(|f| &f.name == key)
1348 .map(|f| f.type_expr.clone())
1349 } else {
1350 None
1351 }
1352 }
1353 Some(TypeExpr::Named(n)) if n == "list" => None,
1354 Some(TypeExpr::Named(n)) if n == "dict" => None,
1355 Some(TypeExpr::Named(n)) if n == "string" => {
1356 Some(TypeExpr::Named("string".into()))
1357 }
1358 _ => None,
1359 }
1360 }
1361 Node::SliceAccess { object, .. } => {
1362 let obj_type = self.infer_type(object, scope);
1364 match &obj_type {
1365 Some(TypeExpr::List(_)) => obj_type,
1366 Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1367 Some(TypeExpr::Named(n)) if n == "string" => {
1368 Some(TypeExpr::Named("string".into()))
1369 }
1370 _ => None,
1371 }
1372 }
1373 Node::MethodCall { object, method, .. }
1374 | Node::OptionalMethodCall { object, method, .. } => {
1375 let obj_type = self.infer_type(object, scope);
1376 let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1377 || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1378 match method.as_str() {
1379 "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1381 Some(TypeExpr::Named("bool".into()))
1382 }
1383 "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1385 "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1387 | "pad_left" | "pad_right" | "repeat" | "join" => {
1388 Some(TypeExpr::Named("string".into()))
1389 }
1390 "split" | "chars" => Some(TypeExpr::Named("list".into())),
1391 "filter" => {
1393 if is_dict {
1394 Some(TypeExpr::Named("dict".into()))
1395 } else {
1396 Some(TypeExpr::Named("list".into()))
1397 }
1398 }
1399 "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1401 "reduce" | "find" | "first" | "last" => None,
1402 "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1404 "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1405 "to_string" => Some(TypeExpr::Named("string".into())),
1407 "to_int" => Some(TypeExpr::Named("int".into())),
1408 "to_float" => Some(TypeExpr::Named("float".into())),
1409 _ => None,
1410 }
1411 }
1412
1413 Node::TryOperator { operand } => {
1415 match self.infer_type(operand, scope) {
1416 Some(TypeExpr::Named(name)) if name == "Result" => None, _ => None,
1418 }
1419 }
1420
1421 _ => None,
1422 }
1423 }
1424
1425 fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1427 if let TypeExpr::Named(name) = expected {
1429 if scope.is_generic_type_param(name) {
1430 return true;
1431 }
1432 }
1433 if let TypeExpr::Named(name) = actual {
1434 if scope.is_generic_type_param(name) {
1435 return true;
1436 }
1437 }
1438 let expected = self.resolve_alias(expected, scope);
1439 let actual = self.resolve_alias(actual, scope);
1440
1441 if let TypeExpr::Named(iface_name) = &expected {
1444 if scope.get_interface(iface_name).is_some() {
1445 if let TypeExpr::Named(type_name) = &actual {
1446 return self.satisfies_interface(type_name, iface_name, scope);
1447 }
1448 return false;
1449 }
1450 }
1451
1452 match (&expected, &actual) {
1453 (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1454 (TypeExpr::Union(members), actual_type) => members
1455 .iter()
1456 .any(|m| self.types_compatible(m, actual_type, scope)),
1457 (expected_type, TypeExpr::Union(members)) => members
1458 .iter()
1459 .all(|m| self.types_compatible(expected_type, m, scope)),
1460 (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1461 (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1462 (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1463 if expected_field.optional {
1464 return true;
1465 }
1466 af.iter().any(|actual_field| {
1467 actual_field.name == expected_field.name
1468 && self.types_compatible(
1469 &expected_field.type_expr,
1470 &actual_field.type_expr,
1471 scope,
1472 )
1473 })
1474 }),
1475 (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1477 let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1478 keys_ok
1479 && af
1480 .iter()
1481 .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1482 }
1483 (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1485 (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1486 self.types_compatible(expected_inner, actual_inner, scope)
1487 }
1488 (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1489 (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1490 (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1491 self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1492 }
1493 (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1494 (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1495 (
1497 TypeExpr::FnType {
1498 params: ep,
1499 return_type: er,
1500 },
1501 TypeExpr::FnType {
1502 params: ap,
1503 return_type: ar,
1504 },
1505 ) => {
1506 ep.len() == ap.len()
1507 && ep
1508 .iter()
1509 .zip(ap.iter())
1510 .all(|(e, a)| self.types_compatible(e, a, scope))
1511 && self.types_compatible(er, ar, scope)
1512 }
1513 (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1515 (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1516 _ => false,
1517 }
1518 }
1519
1520 fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1521 if let TypeExpr::Named(name) = ty {
1522 if let Some(resolved) = scope.resolve_type(name) {
1523 return resolved.clone();
1524 }
1525 }
1526 ty.clone()
1527 }
1528
1529 fn error_at(&mut self, message: String, span: Span) {
1530 self.diagnostics.push(TypeDiagnostic {
1531 message,
1532 severity: DiagnosticSeverity::Error,
1533 span: Some(span),
1534 });
1535 }
1536
1537 fn warning_at(&mut self, message: String, span: Span) {
1538 self.diagnostics.push(TypeDiagnostic {
1539 message,
1540 severity: DiagnosticSeverity::Warning,
1541 span: Some(span),
1542 });
1543 }
1544}
1545
1546impl Default for TypeChecker {
1547 fn default() -> Self {
1548 Self::new()
1549 }
1550}
1551
1552fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1554 match op {
1555 "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1556 "+" => match (left, right) {
1557 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1558 match (l.as_str(), r.as_str()) {
1559 ("int", "int") => Some(TypeExpr::Named("int".into())),
1560 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1561 ("string", _) => Some(TypeExpr::Named("string".into())),
1562 ("list", "list") => Some(TypeExpr::Named("list".into())),
1563 ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1564 _ => Some(TypeExpr::Named("string".into())),
1565 }
1566 }
1567 _ => None,
1568 },
1569 "-" | "*" | "/" | "%" => match (left, right) {
1570 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1571 match (l.as_str(), r.as_str()) {
1572 ("int", "int") => Some(TypeExpr::Named("int".into())),
1573 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1574 _ => None,
1575 }
1576 }
1577 _ => None,
1578 },
1579 "??" => match (left, right) {
1580 (Some(TypeExpr::Union(members)), _) => {
1581 let non_nil: Vec<_> = members
1582 .iter()
1583 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1584 .cloned()
1585 .collect();
1586 if non_nil.len() == 1 {
1587 Some(non_nil[0].clone())
1588 } else if non_nil.is_empty() {
1589 right.clone()
1590 } else {
1591 Some(TypeExpr::Union(non_nil))
1592 }
1593 }
1594 _ => right.clone(),
1595 },
1596 "|>" => None,
1597 _ => None,
1598 }
1599}
1600
1601pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1606 if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1607 let mut details = Vec::new();
1608 for field in ef {
1609 if field.optional {
1610 continue;
1611 }
1612 match af.iter().find(|f| f.name == field.name) {
1613 None => details.push(format!(
1614 "missing field '{}' ({})",
1615 field.name,
1616 format_type(&field.type_expr)
1617 )),
1618 Some(actual_field) => {
1619 let e_str = format_type(&field.type_expr);
1620 let a_str = format_type(&actual_field.type_expr);
1621 if e_str != a_str {
1622 details.push(format!(
1623 "field '{}' has type {}, expected {}",
1624 field.name, a_str, e_str
1625 ));
1626 }
1627 }
1628 }
1629 }
1630 if details.is_empty() {
1631 None
1632 } else {
1633 Some(details.join("; "))
1634 }
1635 } else {
1636 None
1637 }
1638}
1639
1640pub fn format_type(ty: &TypeExpr) -> String {
1641 match ty {
1642 TypeExpr::Named(n) => n.clone(),
1643 TypeExpr::Union(types) => types
1644 .iter()
1645 .map(format_type)
1646 .collect::<Vec<_>>()
1647 .join(" | "),
1648 TypeExpr::Shape(fields) => {
1649 let inner: Vec<String> = fields
1650 .iter()
1651 .map(|f| {
1652 let opt = if f.optional { "?" } else { "" };
1653 format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1654 })
1655 .collect();
1656 format!("{{{}}}", inner.join(", "))
1657 }
1658 TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1659 TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1660 TypeExpr::FnType {
1661 params,
1662 return_type,
1663 } => {
1664 let params_str = params
1665 .iter()
1666 .map(format_type)
1667 .collect::<Vec<_>>()
1668 .join(", ");
1669 format!("fn({}) -> {}", params_str, format_type(return_type))
1670 }
1671 }
1672}
1673
#[cfg(test)]
mod tests {
    // Each test lexes, parses and type-checks a small Harn pipeline, then
    // inspects the diagnostics the checker produced.

    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    /// Lexes, parses and type-checks `source`, returning every diagnostic.
    /// Panics if lexing or parsing fails.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut scanner = Lexer::new(source);
        let token_stream = scanner.tokenize().unwrap();
        let mut ast_builder = Parser::new(token_stream);
        let ast = ast_builder.parse().unwrap();
        TypeChecker::new().check(&ast)
    }

    /// Messages of all diagnostics in `source` with the given severity.
    fn messages_with(source: &str, wanted: DiagnosticSeverity) -> Vec<String> {
        let mut out = Vec::new();
        for diag in check_source(source) {
            if diag.severity == wanted {
                out.push(diag.message);
            }
        }
        out
    }

    /// Error messages produced for `source`.
    fn errors(source: &str) -> Vec<String> {
        messages_with(source, DiagnosticSeverity::Error)
    }

    /// Warning messages produced for `source`.
    fn warnings(source: &str) -> Vec<String> {
        messages_with(source, DiagnosticSeverity::Warning)
    }

    /// Warning messages for `source` whose text contains `needle`.
    fn warnings_containing(source: &str, needle: &str) -> Vec<String> {
        warnings(source)
            .into_iter()
            .filter(|message| message.contains(needle))
            .collect()
    }

    /// Warnings about a `match` that does not cover every enum variant.
    fn non_exhaustive(source: &str) -> Vec<String> {
        warnings_containing(source, "Non-exhaustive")
    }

    /// Warnings about a match pattern whose literal type disagrees with the subject.
    fn pattern_mismatches(source: &str) -> Vec<String> {
        warnings_containing(source, "Match pattern type mismatch")
    }

    // ---- basic let / fn typing -------------------------------------------

    #[test]
    fn test_no_errors_for_untyped_code() {
        assert!(errors("pipeline t(task) { let x = 42\nlog(x) }").is_empty());
    }

    #[test]
    fn test_correct_typed_let() {
        assert!(errors("pipeline t(task) { let x: int = 42 }").is_empty());
    }

    #[test]
    fn test_type_mismatch_let() {
        let found = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
        assert_eq!(found.len(), 1);
        let only = &found[0];
        assert!(only.contains("Type mismatch"));
        assert!(only.contains("int"));
        assert!(only.contains("string"));
    }

    #[test]
    fn test_correct_typed_fn() {
        let found = errors(
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let found = errors(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        assert_eq!(found.len(), 1);
        let only = &found[0];
        assert!(only.contains("Argument 1"));
        assert!(only.contains("expected int"));
    }

    #[test]
    fn test_return_type_mismatch() {
        let found = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert_eq!(found.len(), 1);
        assert!(found[0].contains("Return type mismatch"));
    }

    #[test]
    fn test_union_type_compatible() {
        assert!(errors(r#"pipeline t(task) { let x: string | nil = nil }"#).is_empty());
    }

    #[test]
    fn test_union_type_mismatch() {
        let found = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert_eq!(found.len(), 1);
        assert!(found[0].contains("Type mismatch"));
    }

    // ---- inference -------------------------------------------------------

    #[test]
    fn test_type_inference_propagation() {
        let found = errors(
            r#"pipeline t(task) {
    fn add(a: int, b: int) -> int { return a + b }
    let result: string = add(1, 2)
}"#,
        );
        assert_eq!(found.len(), 1);
        let only = &found[0];
        assert!(only.contains("Type mismatch"));
        assert!(only.contains("string"));
        assert!(only.contains("int"));
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let found = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        assert_eq!(found.len(), 1);
        assert!(found[0].contains("string"));
        assert!(found[0].contains("int"));
    }

    #[test]
    fn test_binary_op_type_inference() {
        assert_eq!(errors("pipeline t(task) { let x: string = 1 + 2 }").len(), 1);
    }

    #[test]
    fn test_comparison_returns_bool() {
        assert!(errors("pipeline t(task) { let x: bool = 1 < 2 }").is_empty());
    }

    #[test]
    fn test_int_float_promotion() {
        assert!(errors("pipeline t(task) { let x: float = 42 }").is_empty());
    }

    #[test]
    fn test_untyped_code_no_errors() {
        let found = errors(
            r#"pipeline t(task) {
    fn process(data) {
        let result = data + " processed"
        return result
    }
    log(process("hello"))
}"#,
        );
        assert!(found.is_empty());
    }

    // ---- aliases and assignment ------------------------------------------

    #[test]
    fn test_type_alias() {
        let found = errors(
            r#"pipeline t(task) {
    type Name = string
    let x: Name = "hello"
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_type_alias_mismatch() {
        let found = errors(
            r#"pipeline t(task) {
    type Name = string
    let x: Name = 42
}"#,
        );
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_assignment_type_check() {
        let found = errors(
            r#"pipeline t(task) {
    var x: int = 0
    x = "hello"
}"#,
        );
        assert_eq!(found.len(), 1);
        assert!(found[0].contains("cannot assign string"));
    }

    // ---- int/float variance ----------------------------------------------

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        let found = errors(
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_covariance_return_type() {
        assert!(errors("pipeline t(task) { fn get() -> float { return 42 } }").is_empty());
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        let found =
            errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
        assert_eq!(found.len(), 1);
    }

    // ---- enum match exhaustiveness ---------------------------------------

    #[test]
    fn test_exhaustive_match_no_warning() {
        let flagged = non_exhaustive(
            r#"pipeline t(task) {
    enum Color { Red, Green, Blue }
    let c = Color.Red
    match c.variant {
        "Red" -> { log("r") }
        "Green" -> { log("g") }
        "Blue" -> { log("b") }
    }
}"#,
        );
        assert!(flagged.is_empty());
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let flagged = non_exhaustive(
            r#"pipeline t(task) {
    enum Color { Red, Green, Blue }
    let c = Color.Red
    match c.variant {
        "Red" -> { log("r") }
        "Green" -> { log("g") }
    }
}"#,
        );
        assert_eq!(flagged.len(), 1);
        assert!(flagged[0].contains("Blue"));
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let flagged = non_exhaustive(
            r#"pipeline t(task) {
    enum Status { Active, Inactive, Pending }
    let s = Status.Active
    match s.variant {
        "Active" -> { log("a") }
    }
}"#,
        );
        assert_eq!(flagged.len(), 1);
        assert!(flagged[0].contains("Inactive"));
        assert!(flagged[0].contains("Pending"));
    }

    #[test]
    fn test_enum_construct_type_inference() {
        let found = errors(
            r#"pipeline t(task) {
    enum Color { Red, Green, Blue }
    let c: Color = Color.Red
}"#,
        );
        assert!(found.is_empty());
    }

    // ---- nil coalescing and shapes ---------------------------------------

    #[test]
    fn test_nil_coalescing_strips_nil() {
        let found = errors(
            r#"pipeline t(task) {
    let x: string | nil = nil
    let y: string = x ?? "default"
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_shape_mismatch_detail_missing_field() {
        let found = errors(
            r#"pipeline t(task) {
    let x: {name: string, age: int} = {name: "hello"}
}"#,
        );
        assert_eq!(found.len(), 1);
        assert!(
            found[0].contains("missing field 'age'"),
            "expected detail about missing field, got: {}",
            found[0]
        );
    }

    #[test]
    fn test_shape_mismatch_detail_wrong_type() {
        let found = errors(
            r#"pipeline t(task) {
    let x: {name: string, age: int} = {name: 42, age: 10}
}"#,
        );
        assert_eq!(found.len(), 1);
        assert!(
            found[0].contains("field 'name' has type int, expected string"),
            "expected detail about wrong type, got: {}",
            found[0]
        );
    }

    // ---- match pattern literal types -------------------------------------

    #[test]
    fn test_match_pattern_string_against_int() {
        let flagged = pattern_mismatches(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        "hello" -> { log("bad") }
        42 -> { log("ok") }
    }
}"#,
        );
        assert_eq!(flagged.len(), 1);
        assert!(flagged[0].contains("matching int against string literal"));
    }

    #[test]
    fn test_match_pattern_int_against_string() {
        let flagged = pattern_mismatches(
            r#"pipeline t(task) {
    let x: string = "hello"
    match x {
        42 -> { log("bad") }
        "hello" -> { log("ok") }
    }
}"#,
        );
        assert_eq!(flagged.len(), 1);
        assert!(flagged[0].contains("matching string against int literal"));
    }

    #[test]
    fn test_match_pattern_bool_against_int() {
        let flagged = pattern_mismatches(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        true -> { log("bad") }
        42 -> { log("ok") }
    }
}"#,
        );
        assert_eq!(flagged.len(), 1);
        assert!(flagged[0].contains("matching int against bool literal"));
    }

    #[test]
    fn test_match_pattern_float_against_string() {
        let flagged = pattern_mismatches(
            r#"pipeline t(task) {
    let x: string = "hello"
    match x {
        3.14 -> { log("bad") }
        "hello" -> { log("ok") }
    }
}"#,
        );
        assert_eq!(flagged.len(), 1);
        assert!(flagged[0].contains("matching string against float literal"));
    }

    #[test]
    fn test_match_pattern_int_against_float_ok() {
        // An int literal against a float subject is not flagged.
        let flagged = pattern_mismatches(
            r#"pipeline t(task) {
    let x: float = 3.14
    match x {
        42 -> { log("ok") }
        _ -> { log("default") }
    }
}"#,
        );
        assert!(flagged.is_empty());
    }

    #[test]
    fn test_match_pattern_float_against_int_ok() {
        // A float literal against an int subject is not flagged either.
        let flagged = pattern_mismatches(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        3.14 -> { log("close") }
        _ -> { log("default") }
    }
}"#,
        );
        assert!(flagged.is_empty());
    }

    #[test]
    fn test_match_pattern_correct_types_no_warning() {
        let flagged = pattern_mismatches(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        1 -> { log("one") }
        2 -> { log("two") }
        _ -> { log("other") }
    }
}"#,
        );
        assert!(flagged.is_empty());
    }

    #[test]
    fn test_match_pattern_wildcard_no_warning() {
        let flagged = pattern_mismatches(
            r#"pipeline t(task) {
    let x: int = 42
    match x {
        _ -> { log("catch all") }
    }
}"#,
        );
        assert!(flagged.is_empty());
    }

    #[test]
    fn test_match_pattern_untyped_no_warning() {
        // A subject whose type is unknown produces no pattern warnings.
        let flagged = pattern_mismatches(
            r#"pipeline t(task) {
    let x = some_unknown_fn()
    match x {
        "hello" -> { log("string") }
        42 -> { log("int") }
    }
}"#,
        );
        assert!(flagged.is_empty());
    }
}