1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
/// A single diagnostic reported by the type checker.
#[derive(Debug, Clone)]
pub struct TypeDiagnostic {
    /// Human-readable description of the problem.
    pub message: String,
    /// Whether the diagnostic is an error or a warning.
    pub severity: DiagnosticSeverity,
    /// Source location the diagnostic refers to, when one is available.
    pub span: Option<Span>,
}
13
/// Severity level attached to a [`TypeDiagnostic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// Reported for definite type mismatches (assignments, arguments, returns).
    Error,
    /// Reported for suspicious constructs such as arity mismatches or
    /// non-exhaustive matches.
    Warning,
}
19
/// Result of type inference: `Some(ty)` when a type is known, `None` when no
/// type could be determined.
type InferredType = Option<TypeExpr>;
22
/// One level of the checker's symbol tables, linked to an optional enclosing
/// scope through `parent`.
#[derive(Debug, Clone)]
struct TypeScope {
    /// Variable name -> declared or inferred type (`None` when unknown).
    vars: BTreeMap<String, InferredType>,
    /// Function name -> signature.
    functions: BTreeMap<String, FnSignature>,
    /// Type alias name -> aliased type expression.
    type_aliases: BTreeMap<String, TypeExpr>,
    /// Enum name -> names of its variants.
    enums: BTreeMap<String, Vec<String>>,
    /// Interface name -> methods it declares.
    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
    /// Struct name -> (field name, field type) pairs.
    structs: BTreeMap<String, Vec<(String, InferredType)>>,
    /// Type name -> signatures of the methods declared in its impl block.
    impl_methods: BTreeMap<String, Vec<ImplMethodSig>>,
    /// Names that are generic type parameters in this scope.
    generic_type_params: std::collections::BTreeSet<String>,
    /// Generic type parameter name -> bound named in a `where` clause.
    where_constraints: BTreeMap<String, String>,
    /// Enclosing scope, consulted when a lookup misses at this level.
    parent: Option<Box<TypeScope>>,
}
47
/// Signature of a method declared in an impl block, with any `self` parameter
/// excluded.
#[derive(Debug, Clone)]
struct ImplMethodSig {
    /// Method name.
    name: String,
    /// Number of parameters other than `self`.
    param_count: usize,
    /// Declared types of the non-`self` parameters (`None` when unannotated).
    param_types: Vec<Option<TypeExpr>>,
    /// Declared return type, if any.
    return_type: Option<TypeExpr>,
}
59
/// Signature recorded for a declared function.
#[derive(Debug, Clone)]
struct FnSignature {
    /// (parameter name, declared type) pairs in declaration order.
    params: Vec<(String, InferredType)>,
    /// Declared return type, if any.
    return_type: InferredType,
    /// Names of the function's generic type parameters.
    type_param_names: Vec<String>,
    /// Number of parameters that have no default value.
    required_params: usize,
    /// (type parameter, bound) pairs taken from the `where` clauses.
    where_clauses: Vec<(String, String)>,
}
71
72impl TypeScope {
73 fn new() -> Self {
74 Self {
75 vars: BTreeMap::new(),
76 functions: BTreeMap::new(),
77 type_aliases: BTreeMap::new(),
78 enums: BTreeMap::new(),
79 interfaces: BTreeMap::new(),
80 structs: BTreeMap::new(),
81 impl_methods: BTreeMap::new(),
82 generic_type_params: std::collections::BTreeSet::new(),
83 where_constraints: BTreeMap::new(),
84 parent: None,
85 }
86 }
87
88 fn child(&self) -> Self {
89 Self {
90 vars: BTreeMap::new(),
91 functions: BTreeMap::new(),
92 type_aliases: BTreeMap::new(),
93 enums: BTreeMap::new(),
94 interfaces: BTreeMap::new(),
95 structs: BTreeMap::new(),
96 impl_methods: BTreeMap::new(),
97 generic_type_params: std::collections::BTreeSet::new(),
98 where_constraints: BTreeMap::new(),
99 parent: Some(Box::new(self.clone())),
100 }
101 }
102
103 fn get_var(&self, name: &str) -> Option<&InferredType> {
104 self.vars
105 .get(name)
106 .or_else(|| self.parent.as_ref()?.get_var(name))
107 }
108
109 fn get_fn(&self, name: &str) -> Option<&FnSignature> {
110 self.functions
111 .get(name)
112 .or_else(|| self.parent.as_ref()?.get_fn(name))
113 }
114
115 fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
116 self.type_aliases
117 .get(name)
118 .or_else(|| self.parent.as_ref()?.resolve_type(name))
119 }
120
121 fn is_generic_type_param(&self, name: &str) -> bool {
122 self.generic_type_params.contains(name)
123 || self
124 .parent
125 .as_ref()
126 .is_some_and(|p| p.is_generic_type_param(name))
127 }
128
129 fn get_where_constraint(&self, type_param: &str) -> Option<&str> {
130 self.where_constraints
131 .get(type_param)
132 .map(|s| s.as_str())
133 .or_else(|| {
134 self.parent
135 .as_ref()
136 .and_then(|p| p.get_where_constraint(type_param))
137 })
138 }
139
140 fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
141 self.enums
142 .get(name)
143 .or_else(|| self.parent.as_ref()?.get_enum(name))
144 }
145
146 fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
147 self.interfaces
148 .get(name)
149 .or_else(|| self.parent.as_ref()?.get_interface(name))
150 }
151
152 fn get_struct(&self, name: &str) -> Option<&Vec<(String, InferredType)>> {
153 self.structs
154 .get(name)
155 .or_else(|| self.parent.as_ref()?.get_struct(name))
156 }
157
158 fn get_impl_methods(&self, name: &str) -> Option<&Vec<ImplMethodSig>> {
159 self.impl_methods
160 .get(name)
161 .or_else(|| self.parent.as_ref()?.get_impl_methods(name))
162 }
163
164 fn define_var(&mut self, name: &str, ty: InferredType) {
165 self.vars.insert(name.to_string(), ty);
166 }
167
168 fn define_fn(&mut self, name: &str, sig: FnSignature) {
169 self.functions.insert(name.to_string(), sig);
170 }
171}
172
173fn builtin_return_type(name: &str) -> InferredType {
175 match name {
176 "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
177 | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
178 Some(TypeExpr::Named("nil".into()))
179 }
180 "type_of"
181 | "to_string"
182 | "json_stringify"
183 | "read_file"
184 | "http_get"
185 | "http_post"
186 | "llm_call"
187 | "regex_replace"
188 | "path_join"
189 | "temp_dir"
190 | "date_format"
191 | "format"
192 | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
193 "to_int" | "timer_end" | "elapsed" | "sign" => Some(TypeExpr::Named("int".into())),
194 "to_float" | "timestamp" | "date_parse" | "sin" | "cos" | "tan" | "asin" | "acos"
195 | "atan" | "atan2" | "log2" | "log10" | "ln" | "exp" | "pi" | "e" => {
196 Some(TypeExpr::Named("float".into()))
197 }
198 "file_exists" | "json_validate" | "is_nan" | "is_infinite" | "set_contains" => {
199 Some(TypeExpr::Named("bool".into()))
200 }
201 "list_dir" | "mcp_list_tools" | "mcp_list_resources" | "mcp_list_prompts" | "to_list"
202 | "regex_captures" => Some(TypeExpr::Named("list".into())),
203 "stat" | "exec" | "shell" | "date_now" | "agent_loop" | "llm_info" | "llm_usage"
204 | "timer_start" | "metadata_get" | "mcp_server_info" | "mcp_get_prompt" => {
205 Some(TypeExpr::Named("dict".into()))
206 }
207 "metadata_set"
208 | "metadata_save"
209 | "metadata_refresh_hashes"
210 | "invalidate_facts"
211 | "log_json"
212 | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
213 "env" | "regex_match" => Some(TypeExpr::Union(vec![
214 TypeExpr::Named("string".into()),
215 TypeExpr::Named("nil".into()),
216 ])),
217 "json_parse" | "json_extract" => None, _ => None,
219 }
220}
221
/// Returns `true` when `name` is one of the builtin function names listed in
/// the table below.
fn is_builtin(name: &str) -> bool {
    const BUILTIN_NAMES: &[&str] = &[
        // Core, conversion, JSON, and process helpers.
        "log",
        "print",
        "println",
        "type_of",
        "to_string",
        "to_int",
        "to_float",
        "json_stringify",
        "json_parse",
        "env",
        "timestamp",
        "sleep",
        "read_file",
        "write_file",
        "exit",
        "regex_match",
        "regex_replace",
        "regex_captures",
        "http_get",
        "http_post",
        "llm_call",
        "agent_loop",
        "await",
        "cancel",
        // Filesystem, process, and date helpers.
        "file_exists",
        "delete_file",
        "list_dir",
        "mkdir",
        "path_join",
        "copy_file",
        "append_file",
        "temp_dir",
        "stat",
        "exec",
        "shell",
        "date_now",
        "date_format",
        "date_parse",
        "format",
        "json_validate",
        "json_extract",
        // String and path helpers.
        "trim",
        "lowercase",
        "uppercase",
        "split",
        "starts_with",
        "ends_with",
        "contains",
        "replace",
        "join",
        "len",
        "substring",
        "dirname",
        "basename",
        "extname",
        // Math helpers.
        "sin",
        "cos",
        "tan",
        "asin",
        "acos",
        "atan",
        "atan2",
        "log2",
        "log10",
        "ln",
        "exp",
        "pi",
        "e",
        "sign",
        "is_nan",
        "is_infinite",
        // Set helpers.
        "set",
        "set_add",
        "set_remove",
        "set_contains",
        "set_union",
        "set_intersect",
        "set_difference",
        "to_list",
    ];
    BUILTIN_NAMES.contains(&name)
}
307
/// Walks a program's AST and collects type diagnostics.
pub struct TypeChecker {
    /// Diagnostics accumulated so far; returned by `check`.
    diagnostics: Vec<TypeDiagnostic>,
    /// Top-level scope holding declarations registered during checking.
    scope: TypeScope,
}
313
314impl TypeChecker {
315 pub fn new() -> Self {
316 Self {
317 diagnostics: Vec::new(),
318 scope: TypeScope::new(),
319 }
320 }
321
322 pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
324 Self::register_declarations_into(&mut self.scope, program);
326
327 for snode in program {
329 if let Node::Pipeline { body, .. } = &snode.node {
330 Self::register_declarations_into(&mut self.scope, body);
331 }
332 }
333
334 for snode in program {
336 match &snode.node {
337 Node::Pipeline { params, body, .. } => {
338 let mut child = self.scope.child();
339 for p in params {
340 child.define_var(p, None);
341 }
342 self.check_block(body, &mut child);
343 }
344 Node::FnDecl {
345 name,
346 type_params,
347 params,
348 return_type,
349 where_clauses,
350 body,
351 ..
352 } => {
353 let required_params =
354 params.iter().filter(|p| p.default_value.is_none()).count();
355 let sig = FnSignature {
356 params: params
357 .iter()
358 .map(|p| (p.name.clone(), p.type_expr.clone()))
359 .collect(),
360 return_type: return_type.clone(),
361 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
362 required_params,
363 where_clauses: where_clauses
364 .iter()
365 .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
366 .collect(),
367 };
368 self.scope.define_fn(name, sig);
369 self.check_fn_body(type_params, params, return_type, body, where_clauses);
370 }
371 _ => {
372 let mut scope = self.scope.clone();
373 self.check_node(snode, &mut scope);
374 for (name, ty) in scope.vars {
376 self.scope.vars.entry(name).or_insert(ty);
377 }
378 }
379 }
380 }
381
382 self.diagnostics
383 }
384
385 fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
387 for snode in nodes {
388 match &snode.node {
389 Node::TypeDecl { name, type_expr } => {
390 scope.type_aliases.insert(name.clone(), type_expr.clone());
391 }
392 Node::EnumDecl { name, variants } => {
393 let variant_names: Vec<String> =
394 variants.iter().map(|v| v.name.clone()).collect();
395 scope.enums.insert(name.clone(), variant_names);
396 }
397 Node::InterfaceDecl { name, methods } => {
398 scope.interfaces.insert(name.clone(), methods.clone());
399 }
400 Node::StructDecl { name, fields } => {
401 let field_types: Vec<(String, InferredType)> = fields
402 .iter()
403 .map(|f| (f.name.clone(), f.type_expr.clone()))
404 .collect();
405 scope.structs.insert(name.clone(), field_types);
406 }
407 Node::ImplBlock {
408 type_name, methods, ..
409 } => {
410 let sigs: Vec<ImplMethodSig> = methods
411 .iter()
412 .filter_map(|m| {
413 if let Node::FnDecl {
414 name,
415 params,
416 return_type,
417 ..
418 } = &m.node
419 {
420 let non_self: Vec<_> =
421 params.iter().filter(|p| p.name != "self").collect();
422 let param_count = non_self.len();
423 let param_types: Vec<Option<TypeExpr>> =
424 non_self.iter().map(|p| p.type_expr.clone()).collect();
425 Some(ImplMethodSig {
426 name: name.clone(),
427 param_count,
428 param_types,
429 return_type: return_type.clone(),
430 })
431 } else {
432 None
433 }
434 })
435 .collect();
436 scope.impl_methods.insert(type_name.clone(), sigs);
437 }
438 _ => {}
439 }
440 }
441 }
442
443 fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
444 for stmt in stmts {
445 self.check_node(stmt, scope);
446 }
447 }
448
449 fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
451 match pattern {
452 BindingPattern::Identifier(name) => {
453 scope.define_var(name, None);
454 }
455 BindingPattern::Dict(fields) => {
456 for field in fields {
457 let name = field.alias.as_deref().unwrap_or(&field.key);
458 scope.define_var(name, None);
459 }
460 }
461 BindingPattern::List(elements) => {
462 for elem in elements {
463 scope.define_var(&elem.name, None);
464 }
465 }
466 }
467 }
468
469 fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
470 let span = snode.span;
471 match &snode.node {
472 Node::LetBinding {
473 pattern,
474 type_ann,
475 value,
476 } => {
477 let inferred = self.infer_type(value, scope);
478 if let BindingPattern::Identifier(name) = pattern {
479 if let Some(expected) = type_ann {
480 if let Some(actual) = &inferred {
481 if !self.types_compatible(expected, actual, scope) {
482 let mut msg = format!(
483 "Type mismatch: '{}' declared as {}, but assigned {}",
484 name,
485 format_type(expected),
486 format_type(actual)
487 );
488 if let Some(detail) = shape_mismatch_detail(expected, actual) {
489 msg.push_str(&format!(" ({})", detail));
490 }
491 self.error_at(msg, span);
492 }
493 }
494 }
495 let ty = type_ann.clone().or(inferred);
496 scope.define_var(name, ty);
497 } else {
498 Self::define_pattern_vars(pattern, scope);
499 }
500 }
501
502 Node::VarBinding {
503 pattern,
504 type_ann,
505 value,
506 } => {
507 let inferred = self.infer_type(value, scope);
508 if let BindingPattern::Identifier(name) = pattern {
509 if let Some(expected) = type_ann {
510 if let Some(actual) = &inferred {
511 if !self.types_compatible(expected, actual, scope) {
512 let mut msg = format!(
513 "Type mismatch: '{}' declared as {}, but assigned {}",
514 name,
515 format_type(expected),
516 format_type(actual)
517 );
518 if let Some(detail) = shape_mismatch_detail(expected, actual) {
519 msg.push_str(&format!(" ({})", detail));
520 }
521 self.error_at(msg, span);
522 }
523 }
524 }
525 let ty = type_ann.clone().or(inferred);
526 scope.define_var(name, ty);
527 } else {
528 Self::define_pattern_vars(pattern, scope);
529 }
530 }
531
532 Node::FnDecl {
533 name,
534 type_params,
535 params,
536 return_type,
537 where_clauses,
538 body,
539 ..
540 } => {
541 let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
542 let sig = FnSignature {
543 params: params
544 .iter()
545 .map(|p| (p.name.clone(), p.type_expr.clone()))
546 .collect(),
547 return_type: return_type.clone(),
548 type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
549 required_params,
550 where_clauses: where_clauses
551 .iter()
552 .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
553 .collect(),
554 };
555 scope.define_fn(name, sig.clone());
556 scope.define_var(name, None);
557 self.check_fn_body(type_params, params, return_type, body, where_clauses);
558 }
559
560 Node::FunctionCall { name, args } => {
561 self.check_call(name, args, scope, span);
562 }
563
564 Node::IfElse {
565 condition,
566 then_body,
567 else_body,
568 } => {
569 self.check_node(condition, scope);
570 let mut then_scope = scope.child();
571 if let Some((var_name, narrowed)) = Self::extract_nil_narrowing(condition, scope) {
573 then_scope.define_var(&var_name, narrowed);
574 }
575 self.check_block(then_body, &mut then_scope);
576 if let Some(else_body) = else_body {
577 let mut else_scope = scope.child();
578 self.check_block(else_body, &mut else_scope);
579 }
580 }
581
582 Node::ForIn {
583 pattern,
584 iterable,
585 body,
586 } => {
587 self.check_node(iterable, scope);
588 let mut loop_scope = scope.child();
589 if let BindingPattern::Identifier(variable) = pattern {
590 let elem_type = match self.infer_type(iterable, scope) {
592 Some(TypeExpr::List(inner)) => Some(*inner),
593 Some(TypeExpr::Named(n)) if n == "string" => {
594 Some(TypeExpr::Named("string".into()))
595 }
596 _ => None,
597 };
598 loop_scope.define_var(variable, elem_type);
599 } else {
600 Self::define_pattern_vars(pattern, &mut loop_scope);
601 }
602 self.check_block(body, &mut loop_scope);
603 }
604
605 Node::WhileLoop { condition, body } => {
606 self.check_node(condition, scope);
607 let mut loop_scope = scope.child();
608 self.check_block(body, &mut loop_scope);
609 }
610
611 Node::TryCatch {
612 body,
613 error_var,
614 catch_body,
615 finally_body,
616 ..
617 } => {
618 let mut try_scope = scope.child();
619 self.check_block(body, &mut try_scope);
620 let mut catch_scope = scope.child();
621 if let Some(var) = error_var {
622 catch_scope.define_var(var, None);
623 }
624 self.check_block(catch_body, &mut catch_scope);
625 if let Some(fb) = finally_body {
626 let mut finally_scope = scope.child();
627 self.check_block(fb, &mut finally_scope);
628 }
629 }
630
631 Node::TryExpr { body } => {
632 let mut try_scope = scope.child();
633 self.check_block(body, &mut try_scope);
634 }
635
636 Node::ReturnStmt {
637 value: Some(val), ..
638 } => {
639 self.check_node(val, scope);
640 }
641
642 Node::Assignment {
643 target, value, op, ..
644 } => {
645 self.check_node(value, scope);
646 if let Node::Identifier(name) = &target.node {
647 if let Some(Some(var_type)) = scope.get_var(name) {
648 let value_type = self.infer_type(value, scope);
649 let assigned = if let Some(op) = op {
650 let var_inferred = scope.get_var(name).cloned().flatten();
651 infer_binary_op_type(op, &var_inferred, &value_type)
652 } else {
653 value_type
654 };
655 if let Some(actual) = &assigned {
656 if !self.types_compatible(var_type, actual, scope) {
657 self.error_at(
658 format!(
659 "Type mismatch: cannot assign {} to '{}' (declared as {})",
660 format_type(actual),
661 name,
662 format_type(var_type)
663 ),
664 span,
665 );
666 }
667 }
668 }
669 }
670 }
671
672 Node::TypeDecl { name, type_expr } => {
673 scope.type_aliases.insert(name.clone(), type_expr.clone());
674 }
675
676 Node::EnumDecl { name, variants } => {
677 let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
678 scope.enums.insert(name.clone(), variant_names);
679 }
680
681 Node::StructDecl { name, fields } => {
682 let field_types: Vec<(String, InferredType)> = fields
683 .iter()
684 .map(|f| (f.name.clone(), f.type_expr.clone()))
685 .collect();
686 scope.structs.insert(name.clone(), field_types);
687 }
688
689 Node::InterfaceDecl { name, methods } => {
690 scope.interfaces.insert(name.clone(), methods.clone());
691 }
692
693 Node::ImplBlock {
694 type_name, methods, ..
695 } => {
696 let sigs: Vec<ImplMethodSig> = methods
698 .iter()
699 .filter_map(|m| {
700 if let Node::FnDecl {
701 name,
702 params,
703 return_type,
704 ..
705 } = &m.node
706 {
707 let non_self: Vec<_> =
708 params.iter().filter(|p| p.name != "self").collect();
709 let param_count = non_self.len();
710 let param_types: Vec<Option<TypeExpr>> =
711 non_self.iter().map(|p| p.type_expr.clone()).collect();
712 Some(ImplMethodSig {
713 name: name.clone(),
714 param_count,
715 param_types,
716 return_type: return_type.clone(),
717 })
718 } else {
719 None
720 }
721 })
722 .collect();
723 scope.impl_methods.insert(type_name.clone(), sigs);
724 for method_sn in methods {
725 self.check_node(method_sn, scope);
726 }
727 }
728
729 Node::TryOperator { operand } => {
730 self.check_node(operand, scope);
731 }
732
733 Node::MatchExpr { value, arms } => {
734 self.check_node(value, scope);
735 let value_type = self.infer_type(value, scope);
736 for arm in arms {
737 self.check_node(&arm.pattern, scope);
738 if let Some(ref vt) = value_type {
740 let value_type_name = format_type(vt);
741 let mismatch = match &arm.pattern.node {
742 Node::StringLiteral(_) => {
743 !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
744 }
745 Node::IntLiteral(_) => {
746 !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
747 && !self.types_compatible(
748 vt,
749 &TypeExpr::Named("float".into()),
750 scope,
751 )
752 }
753 Node::FloatLiteral(_) => {
754 !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
755 && !self.types_compatible(
756 vt,
757 &TypeExpr::Named("int".into()),
758 scope,
759 )
760 }
761 Node::BoolLiteral(_) => {
762 !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
763 }
764 _ => false,
765 };
766 if mismatch {
767 let pattern_type = match &arm.pattern.node {
768 Node::StringLiteral(_) => "string",
769 Node::IntLiteral(_) => "int",
770 Node::FloatLiteral(_) => "float",
771 Node::BoolLiteral(_) => "bool",
772 _ => unreachable!(),
773 };
774 self.warning_at(
775 format!(
776 "Match pattern type mismatch: matching {} against {} literal",
777 value_type_name, pattern_type
778 ),
779 arm.pattern.span,
780 );
781 }
782 }
783 let mut arm_scope = scope.child();
784 self.check_block(&arm.body, &mut arm_scope);
785 }
786 self.check_match_exhaustiveness(value, arms, scope, span);
787 }
788
789 Node::BinaryOp { op, left, right } => {
791 self.check_node(left, scope);
792 self.check_node(right, scope);
793 let lt = self.infer_type(left, scope);
795 let rt = self.infer_type(right, scope);
796 if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (<, &rt) {
797 match op.as_str() {
798 "-" | "*" | "/" | "%" => {
799 let numeric = ["int", "float"];
800 if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
801 self.warning_at(
802 format!(
803 "Operator '{op}' may not be valid for types {} and {}",
804 l, r
805 ),
806 span,
807 );
808 }
809 }
810 "+" => {
811 let valid = ["int", "float", "string", "list", "dict"];
813 if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
814 self.warning_at(
815 format!(
816 "Operator '+' may not be valid for types {} and {}",
817 l, r
818 ),
819 span,
820 );
821 }
822 }
823 _ => {}
824 }
825 }
826 }
827 Node::UnaryOp { operand, .. } => {
828 self.check_node(operand, scope);
829 }
830 Node::MethodCall {
831 object,
832 method,
833 args,
834 ..
835 }
836 | Node::OptionalMethodCall {
837 object,
838 method,
839 args,
840 ..
841 } => {
842 self.check_node(object, scope);
843 for arg in args {
844 self.check_node(arg, scope);
845 }
846 if let Some(TypeExpr::Named(type_name)) = self.infer_type(object, scope) {
850 if scope.is_generic_type_param(&type_name) {
851 if let Some(iface_name) = scope.get_where_constraint(&type_name) {
852 if let Some(iface_methods) = scope.get_interface(iface_name) {
853 let has_method = iface_methods.iter().any(|m| m.name == *method);
854 if !has_method {
855 self.warning_at(
856 format!(
857 "Method '{}' not found in interface '{}' (constraint on '{}')",
858 method, iface_name, type_name
859 ),
860 span,
861 );
862 }
863 }
864 }
865 }
866 }
867 }
868 Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
869 self.check_node(object, scope);
870 }
871 Node::SubscriptAccess { object, index } => {
872 self.check_node(object, scope);
873 self.check_node(index, scope);
874 }
875 Node::SliceAccess { object, start, end } => {
876 self.check_node(object, scope);
877 if let Some(s) = start {
878 self.check_node(s, scope);
879 }
880 if let Some(e) = end {
881 self.check_node(e, scope);
882 }
883 }
884
885 _ => {}
887 }
888 }
889
890 fn check_fn_body(
891 &mut self,
892 type_params: &[TypeParam],
893 params: &[TypedParam],
894 return_type: &Option<TypeExpr>,
895 body: &[SNode],
896 where_clauses: &[WhereClause],
897 ) {
898 let mut fn_scope = self.scope.child();
899 for tp in type_params {
902 fn_scope.generic_type_params.insert(tp.name.clone());
903 }
904 for wc in where_clauses {
906 fn_scope
907 .where_constraints
908 .insert(wc.type_name.clone(), wc.bound.clone());
909 }
910 for param in params {
911 fn_scope.define_var(¶m.name, param.type_expr.clone());
912 if let Some(default) = ¶m.default_value {
913 self.check_node(default, &mut fn_scope);
914 }
915 }
916 self.check_block(body, &mut fn_scope);
917
918 if let Some(ret_type) = return_type {
920 for stmt in body {
921 self.check_return_type(stmt, ret_type, &fn_scope);
922 }
923 }
924 }
925
926 fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
927 let span = snode.span;
928 match &snode.node {
929 Node::ReturnStmt { value: Some(val) } => {
930 let inferred = self.infer_type(val, scope);
931 if let Some(actual) = &inferred {
932 if !self.types_compatible(expected, actual, scope) {
933 self.error_at(
934 format!(
935 "Return type mismatch: expected {}, got {}",
936 format_type(expected),
937 format_type(actual)
938 ),
939 span,
940 );
941 }
942 }
943 }
944 Node::IfElse {
945 then_body,
946 else_body,
947 ..
948 } => {
949 for stmt in then_body {
950 self.check_return_type(stmt, expected, scope);
951 }
952 if let Some(else_body) = else_body {
953 for stmt in else_body {
954 self.check_return_type(stmt, expected, scope);
955 }
956 }
957 }
958 _ => {}
959 }
960 }
961
962 fn satisfies_interface(
968 &self,
969 type_name: &str,
970 interface_name: &str,
971 scope: &TypeScope,
972 ) -> bool {
973 let interface_methods = match scope.get_interface(interface_name) {
974 Some(methods) => methods,
975 None => return false,
976 };
977 let impl_methods = match scope.get_impl_methods(type_name) {
978 Some(methods) => methods,
979 None => return interface_methods.is_empty(),
980 };
981 interface_methods.iter().all(|iface_method| {
982 let iface_params: Vec<_> = iface_method
983 .params
984 .iter()
985 .filter(|p| p.name != "self")
986 .collect();
987 let iface_param_count = iface_params.len();
988 impl_methods.iter().any(|impl_method| {
989 if impl_method.name != iface_method.name
990 || impl_method.param_count != iface_param_count
991 {
992 return false;
993 }
994 for (i, iface_param) in iface_params.iter().enumerate() {
996 if let (Some(expected), Some(actual)) = (
997 &iface_param.type_expr,
998 impl_method.param_types.get(i).and_then(|t| t.as_ref()),
999 ) {
1000 if !self.types_compatible(expected, actual, scope) {
1001 return false;
1002 }
1003 }
1004 }
1005 if let (Some(expected_ret), Some(actual_ret)) =
1007 (&iface_method.return_type, &impl_method.return_type)
1008 {
1009 if !self.types_compatible(expected_ret, actual_ret, scope) {
1010 return false;
1011 }
1012 }
1013 true
1014 })
1015 })
1016 }
1017
1018 fn extract_type_bindings(
1021 param_type: &TypeExpr,
1022 arg_type: &TypeExpr,
1023 type_params: &std::collections::BTreeSet<String>,
1024 bindings: &mut BTreeMap<String, String>,
1025 ) {
1026 match (param_type, arg_type) {
1027 (TypeExpr::Named(param_name), TypeExpr::Named(concrete))
1029 if type_params.contains(param_name) =>
1030 {
1031 bindings
1032 .entry(param_name.clone())
1033 .or_insert(concrete.clone());
1034 }
1035 (TypeExpr::List(p_inner), TypeExpr::List(a_inner)) => {
1037 Self::extract_type_bindings(p_inner, a_inner, type_params, bindings);
1038 }
1039 (TypeExpr::DictType(pk, pv), TypeExpr::DictType(ak, av)) => {
1041 Self::extract_type_bindings(pk, ak, type_params, bindings);
1042 Self::extract_type_bindings(pv, av, type_params, bindings);
1043 }
1044 _ => {}
1045 }
1046 }
1047
1048 fn extract_nil_narrowing(
1049 condition: &SNode,
1050 scope: &TypeScope,
1051 ) -> Option<(String, InferredType)> {
1052 if let Node::BinaryOp { op, left, right } = &condition.node {
1053 if op == "!=" {
1054 let (var_node, nil_node) = if matches!(right.node, Node::NilLiteral) {
1056 (left, right)
1057 } else if matches!(left.node, Node::NilLiteral) {
1058 (right, left)
1059 } else {
1060 return None;
1061 };
1062 let _ = nil_node;
1063 if let Node::Identifier(name) = &var_node.node {
1064 if let Some(Some(TypeExpr::Union(members))) = scope.get_var(name) {
1066 let narrowed: Vec<TypeExpr> = members
1067 .iter()
1068 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1069 .cloned()
1070 .collect();
1071 return if narrowed.len() == 1 {
1072 Some((name.clone(), Some(narrowed.into_iter().next().unwrap())))
1073 } else if narrowed.is_empty() {
1074 None
1075 } else {
1076 Some((name.clone(), Some(TypeExpr::Union(narrowed))))
1077 };
1078 }
1079 }
1080 }
1081 }
1082 None
1083 }
1084
1085 fn check_match_exhaustiveness(
1086 &mut self,
1087 value: &SNode,
1088 arms: &[MatchArm],
1089 scope: &TypeScope,
1090 span: Span,
1091 ) {
1092 let enum_name = match &value.node {
1094 Node::PropertyAccess { object, property } if property == "variant" => {
1095 match self.infer_type(object, scope) {
1097 Some(TypeExpr::Named(name)) => {
1098 if scope.get_enum(&name).is_some() {
1099 Some(name)
1100 } else {
1101 None
1102 }
1103 }
1104 _ => None,
1105 }
1106 }
1107 _ => {
1108 match self.infer_type(value, scope) {
1110 Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
1111 _ => None,
1112 }
1113 }
1114 };
1115
1116 let Some(enum_name) = enum_name else {
1117 return;
1118 };
1119 let Some(variants) = scope.get_enum(&enum_name) else {
1120 return;
1121 };
1122
1123 let mut covered: Vec<String> = Vec::new();
1125 let mut has_wildcard = false;
1126
1127 for arm in arms {
1128 match &arm.pattern.node {
1129 Node::StringLiteral(s) => covered.push(s.clone()),
1131 Node::Identifier(name) if name == "_" || !variants.contains(name) => {
1133 has_wildcard = true;
1134 }
1135 Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
1137 Node::PropertyAccess { property, .. } => covered.push(property.clone()),
1139 _ => {
1140 has_wildcard = true;
1142 }
1143 }
1144 }
1145
1146 if has_wildcard {
1147 return;
1148 }
1149
1150 let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
1151 if !missing.is_empty() {
1152 let missing_str = missing
1153 .iter()
1154 .map(|s| format!("\"{}\"", s))
1155 .collect::<Vec<_>>()
1156 .join(", ");
1157 self.warning_at(
1158 format!(
1159 "Non-exhaustive match on enum {}: missing variants {}",
1160 enum_name, missing_str
1161 ),
1162 span,
1163 );
1164 }
1165 }
1166
1167 fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
1168 let has_spread = args.iter().any(|a| matches!(&a.node, Node::Spread(_)));
1170 if let Some(sig) = scope.get_fn(name).cloned() {
1171 if !has_spread
1172 && !is_builtin(name)
1173 && (args.len() < sig.required_params || args.len() > sig.params.len())
1174 {
1175 let expected = if sig.required_params == sig.params.len() {
1176 format!("{}", sig.params.len())
1177 } else {
1178 format!("{}-{}", sig.required_params, sig.params.len())
1179 };
1180 self.warning_at(
1181 format!(
1182 "Function '{}' expects {} arguments, got {}",
1183 name,
1184 expected,
1185 args.len()
1186 ),
1187 span,
1188 );
1189 }
1190 let call_scope = if sig.type_param_names.is_empty() {
1193 scope.clone()
1194 } else {
1195 let mut s = scope.child();
1196 for tp_name in &sig.type_param_names {
1197 s.generic_type_params.insert(tp_name.clone());
1198 }
1199 s
1200 };
1201 for (i, (arg, (param_name, param_type))) in
1202 args.iter().zip(sig.params.iter()).enumerate()
1203 {
1204 if let Some(expected) = param_type {
1205 let actual = self.infer_type(arg, scope);
1206 if let Some(actual) = &actual {
1207 if !self.types_compatible(expected, actual, &call_scope) {
1208 self.error_at(
1209 format!(
1210 "Argument {} ('{}'): expected {}, got {}",
1211 i + 1,
1212 param_name,
1213 format_type(expected),
1214 format_type(actual)
1215 ),
1216 arg.span,
1217 );
1218 }
1219 }
1220 }
1221 }
1222 if !sig.where_clauses.is_empty() {
1224 let mut type_bindings: BTreeMap<String, String> = BTreeMap::new();
1227 let type_param_set: std::collections::BTreeSet<String> =
1228 sig.type_param_names.iter().cloned().collect();
1229 for (arg, (_param_name, param_type)) in args.iter().zip(sig.params.iter()) {
1230 if let Some(param_ty) = param_type {
1231 if let Some(arg_ty) = self.infer_type(arg, scope) {
1232 Self::extract_type_bindings(
1233 param_ty,
1234 &arg_ty,
1235 &type_param_set,
1236 &mut type_bindings,
1237 );
1238 }
1239 }
1240 }
1241 for (type_param, bound) in &sig.where_clauses {
1242 if let Some(concrete_type) = type_bindings.get(type_param) {
1243 if !self.satisfies_interface(concrete_type, bound, scope) {
1244 self.warning_at(
1245 format!(
1246 "Type '{}' does not satisfy interface '{}': \
1247 required by constraint `where {}: {}`",
1248 concrete_type, bound, type_param, bound
1249 ),
1250 span,
1251 );
1252 }
1253 }
1254 }
1255 }
1256 }
1257 for arg in args {
1259 self.check_node(arg, scope);
1260 }
1261 }
1262
1263 fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
1265 match &snode.node {
1266 Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
1267 Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
1268 Node::StringLiteral(_) | Node::InterpolatedString(_) => {
1269 Some(TypeExpr::Named("string".into()))
1270 }
1271 Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
1272 Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
1273 Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
1274 Node::DictLiteral(entries) => {
1275 let mut fields = Vec::new();
1277 let mut all_string_keys = true;
1278 for entry in entries {
1279 if let Node::StringLiteral(key) = &entry.key.node {
1280 let val_type = self
1281 .infer_type(&entry.value, scope)
1282 .unwrap_or(TypeExpr::Named("nil".into()));
1283 fields.push(ShapeField {
1284 name: key.clone(),
1285 type_expr: val_type,
1286 optional: false,
1287 });
1288 } else {
1289 all_string_keys = false;
1290 break;
1291 }
1292 }
1293 if all_string_keys && !fields.is_empty() {
1294 Some(TypeExpr::Shape(fields))
1295 } else {
1296 Some(TypeExpr::Named("dict".into()))
1297 }
1298 }
1299 Node::Closure { params, body } => {
1300 let all_typed = params.iter().all(|p| p.type_expr.is_some());
1302 if all_typed && !params.is_empty() {
1303 let param_types: Vec<TypeExpr> =
1304 params.iter().filter_map(|p| p.type_expr.clone()).collect();
1305 let ret = body.last().and_then(|last| self.infer_type(last, scope));
1307 if let Some(ret_type) = ret {
1308 return Some(TypeExpr::FnType {
1309 params: param_types,
1310 return_type: Box::new(ret_type),
1311 });
1312 }
1313 }
1314 Some(TypeExpr::Named("closure".into()))
1315 }
1316
1317 Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
1318
1319 Node::FunctionCall { name, .. } => {
1320 if scope.get_struct(name).is_some() {
1322 return Some(TypeExpr::Named(name.clone()));
1323 }
1324 if let Some(sig) = scope.get_fn(name) {
1326 return sig.return_type.clone();
1327 }
1328 builtin_return_type(name)
1330 }
1331
1332 Node::BinaryOp { op, left, right } => {
1333 let lt = self.infer_type(left, scope);
1334 let rt = self.infer_type(right, scope);
1335 infer_binary_op_type(op, <, &rt)
1336 }
1337
1338 Node::UnaryOp { op, operand } => {
1339 let t = self.infer_type(operand, scope);
1340 match op.as_str() {
1341 "!" => Some(TypeExpr::Named("bool".into())),
1342 "-" => t, _ => None,
1344 }
1345 }
1346
1347 Node::Ternary {
1348 true_expr,
1349 false_expr,
1350 ..
1351 } => {
1352 let tt = self.infer_type(true_expr, scope);
1353 let ft = self.infer_type(false_expr, scope);
1354 match (&tt, &ft) {
1355 (Some(a), Some(b)) if a == b => tt,
1356 (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
1357 (Some(_), None) => tt,
1358 (None, Some(_)) => ft,
1359 (None, None) => None,
1360 }
1361 }
1362
1363 Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1364
1365 Node::PropertyAccess { object, property } => {
1366 if let Node::Identifier(name) = &object.node {
1368 if scope.get_enum(name).is_some() {
1369 return Some(TypeExpr::Named(name.clone()));
1370 }
1371 }
1372 if property == "variant" {
1374 let obj_type = self.infer_type(object, scope);
1375 if let Some(TypeExpr::Named(name)) = &obj_type {
1376 if scope.get_enum(name).is_some() {
1377 return Some(TypeExpr::Named("string".into()));
1378 }
1379 }
1380 }
1381 let obj_type = self.infer_type(object, scope);
1383 if let Some(TypeExpr::Shape(fields)) = &obj_type {
1384 if let Some(field) = fields.iter().find(|f| f.name == *property) {
1385 return Some(field.type_expr.clone());
1386 }
1387 }
1388 None
1389 }
1390
1391 Node::SubscriptAccess { object, index } => {
1392 let obj_type = self.infer_type(object, scope);
1393 match &obj_type {
1394 Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1395 Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1396 Some(TypeExpr::Shape(fields)) => {
1397 if let Node::StringLiteral(key) = &index.node {
1399 fields
1400 .iter()
1401 .find(|f| &f.name == key)
1402 .map(|f| f.type_expr.clone())
1403 } else {
1404 None
1405 }
1406 }
1407 Some(TypeExpr::Named(n)) if n == "list" => None,
1408 Some(TypeExpr::Named(n)) if n == "dict" => None,
1409 Some(TypeExpr::Named(n)) if n == "string" => {
1410 Some(TypeExpr::Named("string".into()))
1411 }
1412 _ => None,
1413 }
1414 }
1415 Node::SliceAccess { object, .. } => {
1416 let obj_type = self.infer_type(object, scope);
1418 match &obj_type {
1419 Some(TypeExpr::List(_)) => obj_type,
1420 Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1421 Some(TypeExpr::Named(n)) if n == "string" => {
1422 Some(TypeExpr::Named("string".into()))
1423 }
1424 _ => None,
1425 }
1426 }
1427 Node::MethodCall { object, method, .. }
1428 | Node::OptionalMethodCall { object, method, .. } => {
1429 let obj_type = self.infer_type(object, scope);
1430 let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1431 || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1432 match method.as_str() {
1433 "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1435 Some(TypeExpr::Named("bool".into()))
1436 }
1437 "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1439 "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1441 | "pad_left" | "pad_right" | "repeat" | "join" => {
1442 Some(TypeExpr::Named("string".into()))
1443 }
1444 "split" | "chars" => Some(TypeExpr::Named("list".into())),
1445 "filter" => {
1447 if is_dict {
1448 Some(TypeExpr::Named("dict".into()))
1449 } else {
1450 Some(TypeExpr::Named("list".into()))
1451 }
1452 }
1453 "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1455 "reduce" | "find" | "first" | "last" => None,
1456 "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1458 "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1459 "to_string" => Some(TypeExpr::Named("string".into())),
1461 "to_int" => Some(TypeExpr::Named("int".into())),
1462 "to_float" => Some(TypeExpr::Named("float".into())),
1463 _ => None,
1464 }
1465 }
1466
1467 Node::TryOperator { operand } => {
1469 match self.infer_type(operand, scope) {
1470 Some(TypeExpr::Named(name)) if name == "Result" => None, _ => None,
1472 }
1473 }
1474
1475 _ => None,
1476 }
1477 }
1478
1479 fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1481 if let TypeExpr::Named(name) = expected {
1483 if scope.is_generic_type_param(name) {
1484 return true;
1485 }
1486 }
1487 if let TypeExpr::Named(name) = actual {
1488 if scope.is_generic_type_param(name) {
1489 return true;
1490 }
1491 }
1492 let expected = self.resolve_alias(expected, scope);
1493 let actual = self.resolve_alias(actual, scope);
1494
1495 if let TypeExpr::Named(iface_name) = &expected {
1498 if scope.get_interface(iface_name).is_some() {
1499 if let TypeExpr::Named(type_name) = &actual {
1500 return self.satisfies_interface(type_name, iface_name, scope);
1501 }
1502 return false;
1503 }
1504 }
1505
1506 match (&expected, &actual) {
1507 (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1508 (TypeExpr::Union(members), actual_type) => members
1509 .iter()
1510 .any(|m| self.types_compatible(m, actual_type, scope)),
1511 (expected_type, TypeExpr::Union(members)) => members
1512 .iter()
1513 .all(|m| self.types_compatible(expected_type, m, scope)),
1514 (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1515 (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1516 (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1517 if expected_field.optional {
1518 return true;
1519 }
1520 af.iter().any(|actual_field| {
1521 actual_field.name == expected_field.name
1522 && self.types_compatible(
1523 &expected_field.type_expr,
1524 &actual_field.type_expr,
1525 scope,
1526 )
1527 })
1528 }),
1529 (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1531 let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1532 keys_ok
1533 && af
1534 .iter()
1535 .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1536 }
1537 (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1539 (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1540 self.types_compatible(expected_inner, actual_inner, scope)
1541 }
1542 (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1543 (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1544 (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1545 self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1546 }
1547 (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1548 (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1549 (
1551 TypeExpr::FnType {
1552 params: ep,
1553 return_type: er,
1554 },
1555 TypeExpr::FnType {
1556 params: ap,
1557 return_type: ar,
1558 },
1559 ) => {
1560 ep.len() == ap.len()
1561 && ep
1562 .iter()
1563 .zip(ap.iter())
1564 .all(|(e, a)| self.types_compatible(e, a, scope))
1565 && self.types_compatible(er, ar, scope)
1566 }
1567 (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1569 (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1570 _ => false,
1571 }
1572 }
1573
1574 fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1575 if let TypeExpr::Named(name) = ty {
1576 if let Some(resolved) = scope.resolve_type(name) {
1577 return resolved.clone();
1578 }
1579 }
1580 ty.clone()
1581 }
1582
1583 fn error_at(&mut self, message: String, span: Span) {
1584 self.diagnostics.push(TypeDiagnostic {
1585 message,
1586 severity: DiagnosticSeverity::Error,
1587 span: Some(span),
1588 });
1589 }
1590
1591 fn warning_at(&mut self, message: String, span: Span) {
1592 self.diagnostics.push(TypeDiagnostic {
1593 message,
1594 severity: DiagnosticSeverity::Warning,
1595 span: Some(span),
1596 });
1597 }
1598}
1599
1600impl Default for TypeChecker {
1601 fn default() -> Self {
1602 Self::new()
1603 }
1604}
1605
1606fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1608 match op {
1609 "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1610 "+" => match (left, right) {
1611 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1612 match (l.as_str(), r.as_str()) {
1613 ("int", "int") => Some(TypeExpr::Named("int".into())),
1614 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1615 ("string", _) => Some(TypeExpr::Named("string".into())),
1616 ("list", "list") => Some(TypeExpr::Named("list".into())),
1617 ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1618 _ => Some(TypeExpr::Named("string".into())),
1619 }
1620 }
1621 _ => None,
1622 },
1623 "-" | "*" | "/" | "%" => match (left, right) {
1624 (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1625 match (l.as_str(), r.as_str()) {
1626 ("int", "int") => Some(TypeExpr::Named("int".into())),
1627 ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1628 _ => None,
1629 }
1630 }
1631 _ => None,
1632 },
1633 "??" => match (left, right) {
1634 (Some(TypeExpr::Union(members)), _) => {
1635 let non_nil: Vec<_> = members
1636 .iter()
1637 .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1638 .cloned()
1639 .collect();
1640 if non_nil.len() == 1 {
1641 Some(non_nil[0].clone())
1642 } else if non_nil.is_empty() {
1643 right.clone()
1644 } else {
1645 Some(TypeExpr::Union(non_nil))
1646 }
1647 }
1648 _ => right.clone(),
1649 },
1650 "|>" => None,
1651 _ => None,
1652 }
1653}
1654
1655pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1660 if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1661 let mut details = Vec::new();
1662 for field in ef {
1663 if field.optional {
1664 continue;
1665 }
1666 match af.iter().find(|f| f.name == field.name) {
1667 None => details.push(format!(
1668 "missing field '{}' ({})",
1669 field.name,
1670 format_type(&field.type_expr)
1671 )),
1672 Some(actual_field) => {
1673 let e_str = format_type(&field.type_expr);
1674 let a_str = format_type(&actual_field.type_expr);
1675 if e_str != a_str {
1676 details.push(format!(
1677 "field '{}' has type {}, expected {}",
1678 field.name, a_str, e_str
1679 ));
1680 }
1681 }
1682 }
1683 }
1684 if details.is_empty() {
1685 None
1686 } else {
1687 Some(details.join("; "))
1688 }
1689 } else {
1690 None
1691 }
1692}
1693
1694pub fn format_type(ty: &TypeExpr) -> String {
1695 match ty {
1696 TypeExpr::Named(n) => n.clone(),
1697 TypeExpr::Union(types) => types
1698 .iter()
1699 .map(format_type)
1700 .collect::<Vec<_>>()
1701 .join(" | "),
1702 TypeExpr::Shape(fields) => {
1703 let inner: Vec<String> = fields
1704 .iter()
1705 .map(|f| {
1706 let opt = if f.optional { "?" } else { "" };
1707 format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1708 })
1709 .collect();
1710 format!("{{{}}}", inner.join(", "))
1711 }
1712 TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1713 TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1714 TypeExpr::FnType {
1715 params,
1716 return_type,
1717 } => {
1718 let params_str = params
1719 .iter()
1720 .map(format_type)
1721 .collect::<Vec<_>>()
1722 .join(", ");
1723 format!("fn({}) -> {}", params_str, format_type(return_type))
1724 }
1725 }
1726}
1727
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    /// Text searched for in non-exhaustive `match` warnings.
    const NON_EXHAUSTIVE: &str = "Non-exhaustive";
    /// Text searched for in match-pattern type warnings.
    const PATTERN_MISMATCH: &str = "Match pattern type mismatch";

    /// Lexes, parses and type-checks `source`, returning every diagnostic.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut lexer = Lexer::new(source);
        let tokens = lexer.tokenize().unwrap();
        let mut parser = Parser::new(tokens);
        let program = parser.parse().unwrap();
        TypeChecker::new().check(&program)
    }

    /// Messages of all diagnostics of the given severity, in emission order.
    fn messages_with(source: &str, severity: DiagnosticSeverity) -> Vec<String> {
        let mut messages = Vec::new();
        for diagnostic in check_source(source) {
            if diagnostic.severity == severity {
                messages.push(diagnostic.message);
            }
        }
        messages
    }

    fn errors(source: &str) -> Vec<String> {
        messages_with(source, DiagnosticSeverity::Error)
    }

    fn warnings(source: &str) -> Vec<String> {
        messages_with(source, DiagnosticSeverity::Warning)
    }

    /// Warning messages for `source` that contain `needle`.
    fn warnings_containing(source: &str, needle: &str) -> Vec<String> {
        warnings(source)
            .into_iter()
            .filter(|w| w.contains(needle))
            .collect()
    }

    /// Asserts that `source` type-checks without any error diagnostics.
    fn assert_no_errors(source: &str) {
        let errs = errors(source);
        assert!(errs.is_empty());
    }

    /// Asserts that `source` produces exactly one error and returns its message.
    fn single_error(source: &str) -> String {
        let mut errs = errors(source);
        assert_eq!(errs.len(), 1);
        errs.remove(0)
    }

    /// Asserts that exactly one warning containing `needle` is produced and returns it.
    fn single_warning_containing(source: &str, needle: &str) -> String {
        let mut found = warnings_containing(source, needle);
        assert_eq!(found.len(), 1);
        found.remove(0)
    }

    #[test]
    fn test_no_errors_for_untyped_code() {
        assert_no_errors("pipeline t(task) { let x = 42\nlog(x) }");
    }

    #[test]
    fn test_correct_typed_let() {
        assert_no_errors("pipeline t(task) { let x: int = 42 }");
    }

    #[test]
    fn test_type_mismatch_let() {
        let err = single_error(r#"pipeline t(task) { let x: int = "hello" }"#);
        assert!(err.contains("Type mismatch"));
        assert!(err.contains("int"));
        assert!(err.contains("string"));
    }

    #[test]
    fn test_correct_typed_fn() {
        assert_no_errors(
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
        );
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let err = single_error(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        assert!(err.contains("Argument 1"));
        assert!(err.contains("expected int"));
    }

    #[test]
    fn test_return_type_mismatch() {
        let err = single_error(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert!(err.contains("Return type mismatch"));
    }

    #[test]
    fn test_union_type_compatible() {
        assert_no_errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
    }

    #[test]
    fn test_union_type_mismatch() {
        let err = single_error(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert!(err.contains("Type mismatch"));
    }

    #[test]
    fn test_type_inference_propagation() {
        let err = single_error(
            r#"pipeline t(task) {
 fn add(a: int, b: int) -> int { return a + b }
 let result: string = add(1, 2)
}"#,
        );
        assert!(err.contains("Type mismatch"));
        assert!(err.contains("string"));
        assert!(err.contains("int"));
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let err = single_error(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        assert!(err.contains("string"));
        assert!(err.contains("int"));
    }

    #[test]
    fn test_binary_op_type_inference() {
        single_error("pipeline t(task) { let x: string = 1 + 2 }");
    }

    #[test]
    fn test_comparison_returns_bool() {
        assert_no_errors("pipeline t(task) { let x: bool = 1 < 2 }");
    }

    #[test]
    fn test_int_float_promotion() {
        assert_no_errors("pipeline t(task) { let x: float = 42 }");
    }

    #[test]
    fn test_untyped_code_no_errors() {
        assert_no_errors(
            r#"pipeline t(task) {
 fn process(data) {
 let result = data + " processed"
 return result
 }
 log(process("hello"))
}"#,
        );
    }

    #[test]
    fn test_type_alias() {
        assert_no_errors(
            r#"pipeline t(task) {
 type Name = string
 let x: Name = "hello"
}"#,
        );
    }

    #[test]
    fn test_type_alias_mismatch() {
        single_error(
            r#"pipeline t(task) {
 type Name = string
 let x: Name = 42
}"#,
        );
    }

    #[test]
    fn test_assignment_type_check() {
        let err = single_error(
            r#"pipeline t(task) {
 var x: int = 0
 x = "hello"
}"#,
        );
        assert!(err.contains("cannot assign string"));
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        assert_no_errors(
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
        );
    }

    #[test]
    fn test_covariance_return_type() {
        assert_no_errors("pipeline t(task) { fn get() -> float { return 42 } }");
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        single_error("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
    }

    #[test]
    fn test_exhaustive_match_no_warning() {
        let exhaustive_warns = warnings_containing(
            r#"pipeline t(task) {
 enum Color { Red, Green, Blue }
 let c = Color.Red
 match c.variant {
 "Red" -> { log("r") }
 "Green" -> { log("g") }
 "Blue" -> { log("b") }
 }
}"#,
            NON_EXHAUSTIVE,
        );
        assert!(exhaustive_warns.is_empty());
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let warning = single_warning_containing(
            r#"pipeline t(task) {
 enum Color { Red, Green, Blue }
 let c = Color.Red
 match c.variant {
 "Red" -> { log("r") }
 "Green" -> { log("g") }
 }
}"#,
            NON_EXHAUSTIVE,
        );
        assert!(warning.contains("Blue"));
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let warning = single_warning_containing(
            r#"pipeline t(task) {
 enum Status { Active, Inactive, Pending }
 let s = Status.Active
 match s.variant {
 "Active" -> { log("a") }
 }
}"#,
            NON_EXHAUSTIVE,
        );
        assert!(warning.contains("Inactive"));
        assert!(warning.contains("Pending"));
    }

    #[test]
    fn test_enum_construct_type_inference() {
        assert_no_errors(
            r#"pipeline t(task) {
 enum Color { Red, Green, Blue }
 let c: Color = Color.Red
}"#,
        );
    }

    #[test]
    fn test_nil_coalescing_strips_nil() {
        assert_no_errors(
            r#"pipeline t(task) {
 let x: string | nil = nil
 let y: string = x ?? "default"
}"#,
        );
    }

    #[test]
    fn test_shape_mismatch_detail_missing_field() {
        let err = single_error(
            r#"pipeline t(task) {
 let x: {name: string, age: int} = {name: "hello"}
}"#,
        );
        assert!(
            err.contains("missing field 'age'"),
            "expected detail about missing field, got: {}",
            err
        );
    }

    #[test]
    fn test_shape_mismatch_detail_wrong_type() {
        let err = single_error(
            r#"pipeline t(task) {
 let x: {name: string, age: int} = {name: 42, age: 10}
}"#,
        );
        assert!(
            err.contains("field 'name' has type int, expected string"),
            "expected detail about wrong type, got: {}",
            err
        );
    }

    #[test]
    fn test_match_pattern_string_against_int() {
        let warning = single_warning_containing(
            r#"pipeline t(task) {
 let x: int = 42
 match x {
 "hello" -> { log("bad") }
 42 -> { log("ok") }
 }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(warning.contains("matching int against string literal"));
    }

    #[test]
    fn test_match_pattern_int_against_string() {
        let warning = single_warning_containing(
            r#"pipeline t(task) {
 let x: string = "hello"
 match x {
 42 -> { log("bad") }
 "hello" -> { log("ok") }
 }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(warning.contains("matching string against int literal"));
    }

    #[test]
    fn test_match_pattern_bool_against_int() {
        let warning = single_warning_containing(
            r#"pipeline t(task) {
 let x: int = 42
 match x {
 true -> { log("bad") }
 42 -> { log("ok") }
 }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(warning.contains("matching int against bool literal"));
    }

    #[test]
    fn test_match_pattern_float_against_string() {
        let warning = single_warning_containing(
            r#"pipeline t(task) {
 let x: string = "hello"
 match x {
 3.14 -> { log("bad") }
 "hello" -> { log("ok") }
 }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(warning.contains("matching string against float literal"));
    }

    #[test]
    fn test_match_pattern_int_against_float_ok() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
 let x: float = 3.14
 match x {
 42 -> { log("ok") }
 _ -> { log("default") }
 }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }

    #[test]
    fn test_match_pattern_float_against_int_ok() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
 let x: int = 42
 match x {
 3.14 -> { log("close") }
 _ -> { log("default") }
 }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }

    #[test]
    fn test_match_pattern_correct_types_no_warning() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
 let x: int = 42
 match x {
 1 -> { log("one") }
 2 -> { log("two") }
 _ -> { log("other") }
 }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }

    #[test]
    fn test_match_pattern_wildcard_no_warning() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
 let x: int = 42
 match x {
 _ -> { log("catch all") }
 }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }

    #[test]
    fn test_match_pattern_untyped_no_warning() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
 let x = some_unknown_fn()
 match x {
 "hello" -> { log("string") }
 42 -> { log("int") }
 }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty());
    }
}