Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
/// A diagnostic produced by the type checker.
#[derive(Debug, Clone)]
pub struct TypeDiagnostic {
    /// Human-readable description of the problem.
    pub message: String,
    /// Whether this diagnostic is an error or a warning.
    pub severity: DiagnosticSeverity,
    /// Source location the diagnostic refers to, if one is attached.
    pub span: Option<Span>,
}
13
/// Severity level of a [`TypeDiagnostic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// A type error (presumably reported through `error_at`, e.g. declared vs.
    /// assigned/returned type mismatches).
    Error,
    /// A possible problem (presumably reported through `warning_at`, e.g. an
    /// operator that may not apply to its operand types).
    Warning,
}
19
/// Inferred type of an expression. None means unknown/untyped (gradual typing).
///
/// Checks that need a concrete type skip the comparison when the type is
/// `None` instead of reporting a diagnostic.
type InferredType = Option<TypeExpr>;
22
/// Scope for tracking variable types.
///
/// Scopes form a chain through `parent`; the lookup helpers (`get_var`,
/// `get_fn`, `resolve_type`, ...) search this scope first and then each
/// enclosing scope.
#[derive(Debug, Clone)]
struct TypeScope {
    /// Variable name → inferred type (`None` when the type is unknown).
    vars: BTreeMap<String, InferredType>,
    /// Function name → signature (parameters, return type, generic parameter
    /// names, and number of required parameters).
    functions: BTreeMap<String, FnSignature>,
    /// Named type aliases.
    type_aliases: BTreeMap<String, TypeExpr>,
    /// Enum declarations: name → variant names.
    enums: BTreeMap<String, Vec<String>>,
    /// Interface declarations: name → method signatures.
    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
    /// Struct declarations: name → (field name, field type) pairs in
    /// declaration order; the type is `None` when the field has no type
    /// expression.
    structs: BTreeMap<String, Vec<(String, InferredType)>>,
    /// Generic type parameter names in scope (treated as compatible with any type).
    generic_type_params: std::collections::BTreeSet<String>,
    /// Enclosing scope. `child()` stores an owned clone of the parent, so later
    /// changes to the original parent are not visible through this link.
    parent: Option<Box<TypeScope>>,
}
42
/// Signature of a user-declared function, recorded from its `fn` declaration.
#[derive(Debug, Clone)]
struct FnSignature {
    /// Parameters as (name, declared type) in declaration order; the type is
    /// `None` when the parameter has no type expression.
    params: Vec<(String, InferredType)>,
    /// Declared return type, or `None` if the declaration has none.
    return_type: InferredType,
    /// Generic type parameter names declared on the function.
    type_param_names: Vec<String>,
    /// Number of required parameters (those without defaults).
    required_params: usize,
}
52
53impl TypeScope {
54    fn new() -> Self {
55        Self {
56            vars: BTreeMap::new(),
57            functions: BTreeMap::new(),
58            type_aliases: BTreeMap::new(),
59            enums: BTreeMap::new(),
60            interfaces: BTreeMap::new(),
61            structs: BTreeMap::new(),
62            generic_type_params: std::collections::BTreeSet::new(),
63            parent: None,
64        }
65    }
66
67    fn child(&self) -> Self {
68        Self {
69            vars: BTreeMap::new(),
70            functions: BTreeMap::new(),
71            type_aliases: BTreeMap::new(),
72            enums: BTreeMap::new(),
73            interfaces: BTreeMap::new(),
74            structs: BTreeMap::new(),
75            generic_type_params: std::collections::BTreeSet::new(),
76            parent: Some(Box::new(self.clone())),
77        }
78    }
79
80    fn get_var(&self, name: &str) -> Option<&InferredType> {
81        self.vars
82            .get(name)
83            .or_else(|| self.parent.as_ref()?.get_var(name))
84    }
85
86    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
87        self.functions
88            .get(name)
89            .or_else(|| self.parent.as_ref()?.get_fn(name))
90    }
91
92    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
93        self.type_aliases
94            .get(name)
95            .or_else(|| self.parent.as_ref()?.resolve_type(name))
96    }
97
98    fn is_generic_type_param(&self, name: &str) -> bool {
99        self.generic_type_params.contains(name)
100            || self
101                .parent
102                .as_ref()
103                .is_some_and(|p| p.is_generic_type_param(name))
104    }
105
106    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
107        self.enums
108            .get(name)
109            .or_else(|| self.parent.as_ref()?.get_enum(name))
110    }
111
112    #[allow(dead_code)]
113    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
114        self.interfaces
115            .get(name)
116            .or_else(|| self.parent.as_ref()?.get_interface(name))
117    }
118
119    fn define_var(&mut self, name: &str, ty: InferredType) {
120        self.vars.insert(name.to_string(), ty);
121    }
122
123    fn define_fn(&mut self, name: &str, sig: FnSignature) {
124        self.functions.insert(name.to_string(), sig);
125    }
126}
127
128/// Known return types for builtin functions.
129fn builtin_return_type(name: &str) -> InferredType {
130    match name {
131        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
132        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
133            Some(TypeExpr::Named("nil".into()))
134        }
135        "type_of"
136        | "to_string"
137        | "json_stringify"
138        | "read_file"
139        | "http_get"
140        | "http_post"
141        | "llm_call"
142        | "regex_replace"
143        | "path_join"
144        | "temp_dir"
145        | "date_format"
146        | "format"
147        | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
148        "to_int" | "timer_end" | "elapsed" | "sign" => Some(TypeExpr::Named("int".into())),
149        "to_float" | "timestamp" | "date_parse" | "sin" | "cos" | "tan" | "asin" | "acos"
150        | "atan" | "atan2" | "log2" | "log10" | "ln" | "exp" | "pi" | "e" => {
151            Some(TypeExpr::Named("float".into()))
152        }
153        "file_exists" | "json_validate" | "is_nan" | "is_infinite" | "set_contains" => {
154            Some(TypeExpr::Named("bool".into()))
155        }
156        "list_dir" | "mcp_list_tools" | "mcp_list_resources" | "mcp_list_prompts" | "to_list" => {
157            Some(TypeExpr::Named("list".into()))
158        }
159        "stat" | "exec" | "shell" | "date_now" | "agent_loop" | "llm_info" | "llm_usage"
160        | "timer_start" | "metadata_get" | "mcp_server_info" | "mcp_get_prompt" => {
161            Some(TypeExpr::Named("dict".into()))
162        }
163        "metadata_set"
164        | "metadata_save"
165        | "metadata_refresh_hashes"
166        | "invalidate_facts"
167        | "log_json"
168        | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
169        "env" | "regex_match" => Some(TypeExpr::Union(vec![
170            TypeExpr::Named("string".into()),
171            TypeExpr::Named("nil".into()),
172        ])),
173        "json_parse" | "json_extract" => None, // could be any type
174        _ => None,
175    }
176}
177
/// Check if a name is a known builtin.
///
/// The check is an exact, case-sensitive match against a fixed table of names.
fn is_builtin(name: &str) -> bool {
    const BUILTIN_NAMES: &[&str] = &[
        "log",
        "print",
        "println",
        "type_of",
        "to_string",
        "to_int",
        "to_float",
        "json_stringify",
        "json_parse",
        "env",
        "timestamp",
        "sleep",
        "read_file",
        "write_file",
        "exit",
        "regex_match",
        "regex_replace",
        "http_get",
        "http_post",
        "llm_call",
        "agent_loop",
        "await",
        "cancel",
        "file_exists",
        "delete_file",
        "list_dir",
        "mkdir",
        "path_join",
        "copy_file",
        "append_file",
        "temp_dir",
        "stat",
        "exec",
        "shell",
        "date_now",
        "date_format",
        "date_parse",
        "format",
        "json_validate",
        "json_extract",
        "trim",
        "lowercase",
        "uppercase",
        "split",
        "starts_with",
        "ends_with",
        "contains",
        "replace",
        "join",
        "len",
        "substring",
        "dirname",
        "basename",
        "extname",
        "sin",
        "cos",
        "tan",
        "asin",
        "acos",
        "atan",
        "atan2",
        "log2",
        "log10",
        "ln",
        "exp",
        "pi",
        "e",
        "sign",
        "is_nan",
        "is_infinite",
        "set",
        "set_add",
        "set_remove",
        "set_contains",
        "set_union",
        "set_intersect",
        "set_difference",
        "to_list",
    ];
    BUILTIN_NAMES.contains(&name)
}
262
/// The static type checker.
///
/// Create one with [`TypeChecker::new`] and consume it with
/// [`TypeChecker::check`], which returns the collected diagnostics.
pub struct TypeChecker {
    /// Diagnostics collected while checking.
    diagnostics: Vec<TypeDiagnostic>,
    /// Root (top-level) scope; function bodies are checked in children of it.
    scope: TypeScope,
}
268
269impl TypeChecker {
270    pub fn new() -> Self {
271        Self {
272            diagnostics: Vec::new(),
273            scope: TypeScope::new(),
274        }
275    }
276
277    /// Check a program and return diagnostics.
278    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
279        // First pass: register type and enum declarations into root scope
280        Self::register_declarations_into(&mut self.scope, program);
281
282        // Also scan pipeline bodies for declarations
283        for snode in program {
284            if let Node::Pipeline { body, .. } = &snode.node {
285                Self::register_declarations_into(&mut self.scope, body);
286            }
287        }
288
289        // Check each top-level node
290        for snode in program {
291            match &snode.node {
292                Node::Pipeline { params, body, .. } => {
293                    let mut child = self.scope.child();
294                    for p in params {
295                        child.define_var(p, None);
296                    }
297                    self.check_block(body, &mut child);
298                }
299                Node::FnDecl {
300                    name,
301                    type_params,
302                    params,
303                    return_type,
304                    body,
305                    ..
306                } => {
307                    let required_params =
308                        params.iter().filter(|p| p.default_value.is_none()).count();
309                    let sig = FnSignature {
310                        params: params
311                            .iter()
312                            .map(|p| (p.name.clone(), p.type_expr.clone()))
313                            .collect(),
314                        return_type: return_type.clone(),
315                        type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
316                        required_params,
317                    };
318                    self.scope.define_fn(name, sig);
319                    self.check_fn_body(type_params, params, return_type, body);
320                }
321                _ => {
322                    let mut scope = self.scope.clone();
323                    self.check_node(snode, &mut scope);
324                    // Merge any new definitions back into the top-level scope
325                    for (name, ty) in scope.vars {
326                        self.scope.vars.entry(name).or_insert(ty);
327                    }
328                }
329            }
330        }
331
332        self.diagnostics
333    }
334
335    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
336    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
337        for snode in nodes {
338            match &snode.node {
339                Node::TypeDecl { name, type_expr } => {
340                    scope.type_aliases.insert(name.clone(), type_expr.clone());
341                }
342                Node::EnumDecl { name, variants } => {
343                    let variant_names: Vec<String> =
344                        variants.iter().map(|v| v.name.clone()).collect();
345                    scope.enums.insert(name.clone(), variant_names);
346                }
347                Node::InterfaceDecl { name, methods } => {
348                    scope.interfaces.insert(name.clone(), methods.clone());
349                }
350                Node::StructDecl { name, fields } => {
351                    let field_types: Vec<(String, InferredType)> = fields
352                        .iter()
353                        .map(|f| (f.name.clone(), f.type_expr.clone()))
354                        .collect();
355                    scope.structs.insert(name.clone(), field_types);
356                }
357                _ => {}
358            }
359        }
360    }
361
362    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
363        for stmt in stmts {
364            self.check_node(stmt, scope);
365        }
366    }
367
368    /// Define variables from a destructuring pattern in the given scope (as unknown type).
369    fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
370        match pattern {
371            BindingPattern::Identifier(name) => {
372                scope.define_var(name, None);
373            }
374            BindingPattern::Dict(fields) => {
375                for field in fields {
376                    let name = field.alias.as_deref().unwrap_or(&field.key);
377                    scope.define_var(name, None);
378                }
379            }
380            BindingPattern::List(elements) => {
381                for elem in elements {
382                    scope.define_var(&elem.name, None);
383                }
384            }
385        }
386    }
387
388    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
389        let span = snode.span;
390        match &snode.node {
391            Node::LetBinding {
392                pattern,
393                type_ann,
394                value,
395            } => {
396                let inferred = self.infer_type(value, scope);
397                if let BindingPattern::Identifier(name) = pattern {
398                    if let Some(expected) = type_ann {
399                        if let Some(actual) = &inferred {
400                            if !self.types_compatible(expected, actual, scope) {
401                                let mut msg = format!(
402                                    "Type mismatch: '{}' declared as {}, but assigned {}",
403                                    name,
404                                    format_type(expected),
405                                    format_type(actual)
406                                );
407                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
408                                    msg.push_str(&format!(" ({})", detail));
409                                }
410                                self.error_at(msg, span);
411                            }
412                        }
413                    }
414                    let ty = type_ann.clone().or(inferred);
415                    scope.define_var(name, ty);
416                } else {
417                    Self::define_pattern_vars(pattern, scope);
418                }
419            }
420
421            Node::VarBinding {
422                pattern,
423                type_ann,
424                value,
425            } => {
426                let inferred = self.infer_type(value, scope);
427                if let BindingPattern::Identifier(name) = pattern {
428                    if let Some(expected) = type_ann {
429                        if let Some(actual) = &inferred {
430                            if !self.types_compatible(expected, actual, scope) {
431                                let mut msg = format!(
432                                    "Type mismatch: '{}' declared as {}, but assigned {}",
433                                    name,
434                                    format_type(expected),
435                                    format_type(actual)
436                                );
437                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
438                                    msg.push_str(&format!(" ({})", detail));
439                                }
440                                self.error_at(msg, span);
441                            }
442                        }
443                    }
444                    let ty = type_ann.clone().or(inferred);
445                    scope.define_var(name, ty);
446                } else {
447                    Self::define_pattern_vars(pattern, scope);
448                }
449            }
450
451            Node::FnDecl {
452                name,
453                type_params,
454                params,
455                return_type,
456                body,
457                ..
458            } => {
459                let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
460                let sig = FnSignature {
461                    params: params
462                        .iter()
463                        .map(|p| (p.name.clone(), p.type_expr.clone()))
464                        .collect(),
465                    return_type: return_type.clone(),
466                    type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
467                    required_params,
468                };
469                scope.define_fn(name, sig.clone());
470                scope.define_var(name, None);
471                self.check_fn_body(type_params, params, return_type, body);
472            }
473
474            Node::FunctionCall { name, args } => {
475                self.check_call(name, args, scope, span);
476            }
477
478            Node::IfElse {
479                condition,
480                then_body,
481                else_body,
482            } => {
483                self.check_node(condition, scope);
484                let mut then_scope = scope.child();
485                // Narrow union types after nil checks: if x != nil, narrow x
486                if let Some((var_name, narrowed)) = Self::extract_nil_narrowing(condition, scope) {
487                    then_scope.define_var(&var_name, narrowed);
488                }
489                self.check_block(then_body, &mut then_scope);
490                if let Some(else_body) = else_body {
491                    let mut else_scope = scope.child();
492                    self.check_block(else_body, &mut else_scope);
493                }
494            }
495
496            Node::ForIn {
497                pattern,
498                iterable,
499                body,
500            } => {
501                self.check_node(iterable, scope);
502                let mut loop_scope = scope.child();
503                if let BindingPattern::Identifier(variable) = pattern {
504                    // Infer loop variable type from iterable
505                    let elem_type = match self.infer_type(iterable, scope) {
506                        Some(TypeExpr::List(inner)) => Some(*inner),
507                        Some(TypeExpr::Named(n)) if n == "string" => {
508                            Some(TypeExpr::Named("string".into()))
509                        }
510                        _ => None,
511                    };
512                    loop_scope.define_var(variable, elem_type);
513                } else {
514                    Self::define_pattern_vars(pattern, &mut loop_scope);
515                }
516                self.check_block(body, &mut loop_scope);
517            }
518
519            Node::WhileLoop { condition, body } => {
520                self.check_node(condition, scope);
521                let mut loop_scope = scope.child();
522                self.check_block(body, &mut loop_scope);
523            }
524
525            Node::TryCatch {
526                body,
527                error_var,
528                catch_body,
529                finally_body,
530                ..
531            } => {
532                let mut try_scope = scope.child();
533                self.check_block(body, &mut try_scope);
534                let mut catch_scope = scope.child();
535                if let Some(var) = error_var {
536                    catch_scope.define_var(var, None);
537                }
538                self.check_block(catch_body, &mut catch_scope);
539                if let Some(fb) = finally_body {
540                    let mut finally_scope = scope.child();
541                    self.check_block(fb, &mut finally_scope);
542                }
543            }
544
545            Node::ReturnStmt {
546                value: Some(val), ..
547            } => {
548                self.check_node(val, scope);
549            }
550
551            Node::Assignment {
552                target, value, op, ..
553            } => {
554                self.check_node(value, scope);
555                if let Node::Identifier(name) = &target.node {
556                    if let Some(Some(var_type)) = scope.get_var(name) {
557                        let value_type = self.infer_type(value, scope);
558                        let assigned = if let Some(op) = op {
559                            let var_inferred = scope.get_var(name).cloned().flatten();
560                            infer_binary_op_type(op, &var_inferred, &value_type)
561                        } else {
562                            value_type
563                        };
564                        if let Some(actual) = &assigned {
565                            if !self.types_compatible(var_type, actual, scope) {
566                                self.error_at(
567                                    format!(
568                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
569                                        format_type(actual),
570                                        name,
571                                        format_type(var_type)
572                                    ),
573                                    span,
574                                );
575                            }
576                        }
577                    }
578                }
579            }
580
581            Node::TypeDecl { name, type_expr } => {
582                scope.type_aliases.insert(name.clone(), type_expr.clone());
583            }
584
585            Node::EnumDecl { name, variants } => {
586                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
587                scope.enums.insert(name.clone(), variant_names);
588            }
589
590            Node::StructDecl { name, fields } => {
591                let field_types: Vec<(String, InferredType)> = fields
592                    .iter()
593                    .map(|f| (f.name.clone(), f.type_expr.clone()))
594                    .collect();
595                scope.structs.insert(name.clone(), field_types);
596            }
597
598            Node::InterfaceDecl { name, methods } => {
599                scope.interfaces.insert(name.clone(), methods.clone());
600            }
601
602            Node::ImplBlock { methods, .. } => {
603                for method_sn in methods {
604                    self.check_node(method_sn, scope);
605                }
606            }
607
608            Node::TryOperator { operand } => {
609                self.check_node(operand, scope);
610            }
611
612            Node::MatchExpr { value, arms } => {
613                self.check_node(value, scope);
614                let value_type = self.infer_type(value, scope);
615                for arm in arms {
616                    self.check_node(&arm.pattern, scope);
617                    // Check for incompatible literal pattern types
618                    if let Some(ref vt) = value_type {
619                        let value_type_name = format_type(vt);
620                        let mismatch = match &arm.pattern.node {
621                            Node::StringLiteral(_) => {
622                                !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
623                            }
624                            Node::IntLiteral(_) => {
625                                !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
626                                    && !self.types_compatible(
627                                        vt,
628                                        &TypeExpr::Named("float".into()),
629                                        scope,
630                                    )
631                            }
632                            Node::FloatLiteral(_) => {
633                                !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
634                                    && !self.types_compatible(
635                                        vt,
636                                        &TypeExpr::Named("int".into()),
637                                        scope,
638                                    )
639                            }
640                            Node::BoolLiteral(_) => {
641                                !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
642                            }
643                            _ => false,
644                        };
645                        if mismatch {
646                            let pattern_type = match &arm.pattern.node {
647                                Node::StringLiteral(_) => "string",
648                                Node::IntLiteral(_) => "int",
649                                Node::FloatLiteral(_) => "float",
650                                Node::BoolLiteral(_) => "bool",
651                                _ => unreachable!(),
652                            };
653                            self.warning_at(
654                                format!(
655                                    "Match pattern type mismatch: matching {} against {} literal",
656                                    value_type_name, pattern_type
657                                ),
658                                arm.pattern.span,
659                            );
660                        }
661                    }
662                    let mut arm_scope = scope.child();
663                    self.check_block(&arm.body, &mut arm_scope);
664                }
665                self.check_match_exhaustiveness(value, arms, scope, span);
666            }
667
668            // Recurse into nested expressions + validate binary op types
669            Node::BinaryOp { op, left, right } => {
670                self.check_node(left, scope);
671                self.check_node(right, scope);
672                // Validate operator/type compatibility
673                let lt = self.infer_type(left, scope);
674                let rt = self.infer_type(right, scope);
675                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
676                    match op.as_str() {
677                        "-" | "*" | "/" | "%" => {
678                            let numeric = ["int", "float"];
679                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
680                                self.warning_at(
681                                    format!(
682                                        "Operator '{op}' may not be valid for types {} and {}",
683                                        l, r
684                                    ),
685                                    span,
686                                );
687                            }
688                        }
689                        "+" => {
690                            // + is valid for int, float, string, list, dict
691                            let valid = ["int", "float", "string", "list", "dict"];
692                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
693                                self.warning_at(
694                                    format!(
695                                        "Operator '+' may not be valid for types {} and {}",
696                                        l, r
697                                    ),
698                                    span,
699                                );
700                            }
701                        }
702                        _ => {}
703                    }
704                }
705            }
706            Node::UnaryOp { operand, .. } => {
707                self.check_node(operand, scope);
708            }
709            Node::MethodCall { object, args, .. }
710            | Node::OptionalMethodCall { object, args, .. } => {
711                self.check_node(object, scope);
712                for arg in args {
713                    self.check_node(arg, scope);
714                }
715            }
716            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
717                self.check_node(object, scope);
718            }
719            Node::SubscriptAccess { object, index } => {
720                self.check_node(object, scope);
721                self.check_node(index, scope);
722            }
723            Node::SliceAccess { object, start, end } => {
724                self.check_node(object, scope);
725                if let Some(s) = start {
726                    self.check_node(s, scope);
727                }
728                if let Some(e) = end {
729                    self.check_node(e, scope);
730                }
731            }
732
733            // Terminals — nothing to check
734            _ => {}
735        }
736    }
737
738    fn check_fn_body(
739        &mut self,
740        type_params: &[TypeParam],
741        params: &[TypedParam],
742        return_type: &Option<TypeExpr>,
743        body: &[SNode],
744    ) {
745        let mut fn_scope = self.scope.child();
746        // Register generic type parameters so they are treated as compatible
747        // with any concrete type during type checking.
748        for tp in type_params {
749            fn_scope.generic_type_params.insert(tp.name.clone());
750        }
751        for param in params {
752            fn_scope.define_var(&param.name, param.type_expr.clone());
753            if let Some(default) = &param.default_value {
754                self.check_node(default, &mut fn_scope);
755            }
756        }
757        self.check_block(body, &mut fn_scope);
758
759        // Check return statements against declared return type
760        if let Some(ret_type) = return_type {
761            for stmt in body {
762                self.check_return_type(stmt, ret_type, &fn_scope);
763            }
764        }
765    }
766
767    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
768        let span = snode.span;
769        match &snode.node {
770            Node::ReturnStmt { value: Some(val) } => {
771                let inferred = self.infer_type(val, scope);
772                if let Some(actual) = &inferred {
773                    if !self.types_compatible(expected, actual, scope) {
774                        self.error_at(
775                            format!(
776                                "Return type mismatch: expected {}, got {}",
777                                format_type(expected),
778                                format_type(actual)
779                            ),
780                            span,
781                        );
782                    }
783                }
784            }
785            Node::IfElse {
786                then_body,
787                else_body,
788                ..
789            } => {
790                for stmt in then_body {
791                    self.check_return_type(stmt, expected, scope);
792                }
793                if let Some(else_body) = else_body {
794                    for stmt in else_body {
795                        self.check_return_type(stmt, expected, scope);
796                    }
797                }
798            }
799            _ => {}
800        }
801    }
802
803    /// Check if a match expression on an enum's `.variant` property covers all variants.
804    /// Extract narrowing info from nil-check conditions like `x != nil`.
805    /// Returns (var_name, narrowed_type) where narrowed_type removes nil from a union.
806    fn extract_nil_narrowing(
807        condition: &SNode,
808        scope: &TypeScope,
809    ) -> Option<(String, InferredType)> {
810        if let Node::BinaryOp { op, left, right } = &condition.node {
811            if op == "!=" {
812                // Check for `x != nil` or `nil != x`
813                let (var_node, nil_node) = if matches!(right.node, Node::NilLiteral) {
814                    (left, right)
815                } else if matches!(left.node, Node::NilLiteral) {
816                    (right, left)
817                } else {
818                    return None;
819                };
820                let _ = nil_node;
821                if let Node::Identifier(name) = &var_node.node {
822                    // Look up the variable's type and narrow it
823                    if let Some(Some(TypeExpr::Union(members))) = scope.get_var(name) {
824                        let narrowed: Vec<TypeExpr> = members
825                            .iter()
826                            .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
827                            .cloned()
828                            .collect();
829                        return if narrowed.len() == 1 {
830                            Some((name.clone(), Some(narrowed.into_iter().next().unwrap())))
831                        } else if narrowed.is_empty() {
832                            None
833                        } else {
834                            Some((name.clone(), Some(TypeExpr::Union(narrowed))))
835                        };
836                    }
837                }
838            }
839        }
840        None
841    }
842
843    fn check_match_exhaustiveness(
844        &mut self,
845        value: &SNode,
846        arms: &[MatchArm],
847        scope: &TypeScope,
848        span: Span,
849    ) {
850        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
851        let enum_name = match &value.node {
852            Node::PropertyAccess { object, property } if property == "variant" => {
853                // Infer the type of the object
854                match self.infer_type(object, scope) {
855                    Some(TypeExpr::Named(name)) => {
856                        if scope.get_enum(&name).is_some() {
857                            Some(name)
858                        } else {
859                            None
860                        }
861                    }
862                    _ => None,
863                }
864            }
865            _ => {
866                // Direct match on an enum value: match <expr> { ... }
867                match self.infer_type(value, scope) {
868                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
869                    _ => None,
870                }
871            }
872        };
873
874        let Some(enum_name) = enum_name else {
875            return;
876        };
877        let Some(variants) = scope.get_enum(&enum_name) else {
878            return;
879        };
880
881        // Collect variant names covered by match arms
882        let mut covered: Vec<String> = Vec::new();
883        let mut has_wildcard = false;
884
885        for arm in arms {
886            match &arm.pattern.node {
887                // String literal pattern (matching on .variant): "VariantA"
888                Node::StringLiteral(s) => covered.push(s.clone()),
889                // Identifier pattern acts as a wildcard/catch-all
890                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
891                    has_wildcard = true;
892                }
893                // Direct enum construct pattern: EnumName.Variant
894                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
895                // PropertyAccess pattern: EnumName.Variant (no args)
896                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
897                _ => {
898                    // Unknown pattern shape — conservatively treat as wildcard
899                    has_wildcard = true;
900                }
901            }
902        }
903
904        if has_wildcard {
905            return;
906        }
907
908        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
909        if !missing.is_empty() {
910            let missing_str = missing
911                .iter()
912                .map(|s| format!("\"{}\"", s))
913                .collect::<Vec<_>>()
914                .join(", ");
915            self.warning_at(
916                format!(
917                    "Non-exhaustive match on enum {}: missing variants {}",
918                    enum_name, missing_str
919                ),
920                span,
921            );
922        }
923    }
924
925    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
926        // Check against known function signatures
927        if let Some(sig) = scope.get_fn(name).cloned() {
928            if !is_builtin(name)
929                && (args.len() < sig.required_params || args.len() > sig.params.len())
930            {
931                let expected = if sig.required_params == sig.params.len() {
932                    format!("{}", sig.params.len())
933                } else {
934                    format!("{}-{}", sig.required_params, sig.params.len())
935                };
936                self.warning_at(
937                    format!(
938                        "Function '{}' expects {} arguments, got {}",
939                        name,
940                        expected,
941                        args.len()
942                    ),
943                    span,
944                );
945            }
946            // Build a scope that includes the function's generic type params
947            // so they are treated as compatible with any concrete type.
948            let call_scope = if sig.type_param_names.is_empty() {
949                scope.clone()
950            } else {
951                let mut s = scope.child();
952                for tp_name in &sig.type_param_names {
953                    s.generic_type_params.insert(tp_name.clone());
954                }
955                s
956            };
957            for (i, (arg, (param_name, param_type))) in
958                args.iter().zip(sig.params.iter()).enumerate()
959            {
960                if let Some(expected) = param_type {
961                    let actual = self.infer_type(arg, scope);
962                    if let Some(actual) = &actual {
963                        if !self.types_compatible(expected, actual, &call_scope) {
964                            self.error_at(
965                                format!(
966                                    "Argument {} ('{}'): expected {}, got {}",
967                                    i + 1,
968                                    param_name,
969                                    format_type(expected),
970                                    format_type(actual)
971                                ),
972                                arg.span,
973                            );
974                        }
975                    }
976                }
977            }
978        }
979        // Check args recursively
980        for arg in args {
981            self.check_node(arg, scope);
982        }
983    }
984
985    /// Infer the type of an expression.
986    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
987        match &snode.node {
988            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
989            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
990            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
991                Some(TypeExpr::Named("string".into()))
992            }
993            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
994            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
995            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
996            Node::DictLiteral(entries) => {
997                // Infer shape type when all keys are string literals
998                let mut fields = Vec::new();
999                let mut all_string_keys = true;
1000                for entry in entries {
1001                    if let Node::StringLiteral(key) = &entry.key.node {
1002                        let val_type = self
1003                            .infer_type(&entry.value, scope)
1004                            .unwrap_or(TypeExpr::Named("nil".into()));
1005                        fields.push(ShapeField {
1006                            name: key.clone(),
1007                            type_expr: val_type,
1008                            optional: false,
1009                        });
1010                    } else {
1011                        all_string_keys = false;
1012                        break;
1013                    }
1014                }
1015                if all_string_keys && !fields.is_empty() {
1016                    Some(TypeExpr::Shape(fields))
1017                } else {
1018                    Some(TypeExpr::Named("dict".into()))
1019                }
1020            }
1021            Node::Closure { params, body } => {
1022                // If all params are typed and we can infer a return type, produce FnType
1023                let all_typed = params.iter().all(|p| p.type_expr.is_some());
1024                if all_typed && !params.is_empty() {
1025                    let param_types: Vec<TypeExpr> =
1026                        params.iter().filter_map(|p| p.type_expr.clone()).collect();
1027                    // Try to infer return type from last expression in body
1028                    let ret = body.last().and_then(|last| self.infer_type(last, scope));
1029                    if let Some(ret_type) = ret {
1030                        return Some(TypeExpr::FnType {
1031                            params: param_types,
1032                            return_type: Box::new(ret_type),
1033                        });
1034                    }
1035                }
1036                Some(TypeExpr::Named("closure".into()))
1037            }
1038
1039            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
1040
1041            Node::FunctionCall { name, .. } => {
1042                // Check user-defined function return types
1043                if let Some(sig) = scope.get_fn(name) {
1044                    return sig.return_type.clone();
1045                }
1046                // Check builtin return types
1047                builtin_return_type(name)
1048            }
1049
1050            Node::BinaryOp { op, left, right } => {
1051                let lt = self.infer_type(left, scope);
1052                let rt = self.infer_type(right, scope);
1053                infer_binary_op_type(op, &lt, &rt)
1054            }
1055
1056            Node::UnaryOp { op, operand } => {
1057                let t = self.infer_type(operand, scope);
1058                match op.as_str() {
1059                    "!" => Some(TypeExpr::Named("bool".into())),
1060                    "-" => t, // negation preserves type
1061                    _ => None,
1062                }
1063            }
1064
1065            Node::Ternary {
1066                true_expr,
1067                false_expr,
1068                ..
1069            } => {
1070                let tt = self.infer_type(true_expr, scope);
1071                let ft = self.infer_type(false_expr, scope);
1072                match (&tt, &ft) {
1073                    (Some(a), Some(b)) if a == b => tt,
1074                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
1075                    (Some(_), None) => tt,
1076                    (None, Some(_)) => ft,
1077                    (None, None) => None,
1078                }
1079            }
1080
1081            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1082
1083            Node::PropertyAccess { object, property } => {
1084                // EnumName.Variant → infer as the enum type
1085                if let Node::Identifier(name) = &object.node {
1086                    if scope.get_enum(name).is_some() {
1087                        return Some(TypeExpr::Named(name.clone()));
1088                    }
1089                }
1090                // .variant on an enum value → string
1091                if property == "variant" {
1092                    let obj_type = self.infer_type(object, scope);
1093                    if let Some(TypeExpr::Named(name)) = &obj_type {
1094                        if scope.get_enum(name).is_some() {
1095                            return Some(TypeExpr::Named("string".into()));
1096                        }
1097                    }
1098                }
1099                // Shape field access: obj.field → field type
1100                let obj_type = self.infer_type(object, scope);
1101                if let Some(TypeExpr::Shape(fields)) = &obj_type {
1102                    if let Some(field) = fields.iter().find(|f| f.name == *property) {
1103                        return Some(field.type_expr.clone());
1104                    }
1105                }
1106                None
1107            }
1108
1109            Node::SubscriptAccess { object, index } => {
1110                let obj_type = self.infer_type(object, scope);
1111                match &obj_type {
1112                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1113                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1114                    Some(TypeExpr::Shape(fields)) => {
1115                        // If index is a string literal, look up the field type
1116                        if let Node::StringLiteral(key) = &index.node {
1117                            fields
1118                                .iter()
1119                                .find(|f| &f.name == key)
1120                                .map(|f| f.type_expr.clone())
1121                        } else {
1122                            None
1123                        }
1124                    }
1125                    Some(TypeExpr::Named(n)) if n == "list" => None,
1126                    Some(TypeExpr::Named(n)) if n == "dict" => None,
1127                    Some(TypeExpr::Named(n)) if n == "string" => {
1128                        Some(TypeExpr::Named("string".into()))
1129                    }
1130                    _ => None,
1131                }
1132            }
1133            Node::SliceAccess { object, .. } => {
1134                // Slicing a list returns the same list type; slicing a string returns string
1135                let obj_type = self.infer_type(object, scope);
1136                match &obj_type {
1137                    Some(TypeExpr::List(_)) => obj_type,
1138                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1139                    Some(TypeExpr::Named(n)) if n == "string" => {
1140                        Some(TypeExpr::Named("string".into()))
1141                    }
1142                    _ => None,
1143                }
1144            }
1145            Node::MethodCall { object, method, .. }
1146            | Node::OptionalMethodCall { object, method, .. } => {
1147                let obj_type = self.infer_type(object, scope);
1148                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1149                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1150                match method.as_str() {
1151                    // Shared: bool-returning methods
1152                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1153                        Some(TypeExpr::Named("bool".into()))
1154                    }
1155                    // Shared: int-returning methods
1156                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1157                    // String methods
1158                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1159                    | "pad_left" | "pad_right" | "repeat" | "join" => {
1160                        Some(TypeExpr::Named("string".into()))
1161                    }
1162                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
1163                    // filter returns dict for dicts, list for lists
1164                    "filter" => {
1165                        if is_dict {
1166                            Some(TypeExpr::Named("dict".into()))
1167                        } else {
1168                            Some(TypeExpr::Named("list".into()))
1169                        }
1170                    }
1171                    // List methods
1172                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1173                    "reduce" | "find" | "first" | "last" => None,
1174                    // Dict methods
1175                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1176                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1177                    // Conversions
1178                    "to_string" => Some(TypeExpr::Named("string".into())),
1179                    "to_int" => Some(TypeExpr::Named("int".into())),
1180                    "to_float" => Some(TypeExpr::Named("float".into())),
1181                    _ => None,
1182                }
1183            }
1184
1185            // TryOperator on Result<T, E> produces T
1186            Node::TryOperator { operand } => {
1187                match self.infer_type(operand, scope) {
1188                    Some(TypeExpr::Named(name)) if name == "Result" => None, // unknown inner type
1189                    _ => None,
1190                }
1191            }
1192
1193            _ => None,
1194        }
1195    }
1196
1197    /// Check if two types are compatible (actual can be assigned to expected).
1198    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1199        // Generic type parameters match anything.
1200        if let TypeExpr::Named(name) = expected {
1201            if scope.is_generic_type_param(name) {
1202                return true;
1203            }
1204        }
1205        if let TypeExpr::Named(name) = actual {
1206            if scope.is_generic_type_param(name) {
1207                return true;
1208            }
1209        }
1210        let expected = self.resolve_alias(expected, scope);
1211        let actual = self.resolve_alias(actual, scope);
1212
1213        match (&expected, &actual) {
1214            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1215            (TypeExpr::Union(members), actual_type) => members
1216                .iter()
1217                .any(|m| self.types_compatible(m, actual_type, scope)),
1218            (expected_type, TypeExpr::Union(members)) => members
1219                .iter()
1220                .all(|m| self.types_compatible(expected_type, m, scope)),
1221            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1222            (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1223            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1224                if expected_field.optional {
1225                    return true;
1226                }
1227                af.iter().any(|actual_field| {
1228                    actual_field.name == expected_field.name
1229                        && self.types_compatible(
1230                            &expected_field.type_expr,
1231                            &actual_field.type_expr,
1232                            scope,
1233                        )
1234                })
1235            }),
1236            // dict<K, V> expected, Shape actual → all field values must match V
1237            (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1238                let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1239                keys_ok
1240                    && af
1241                        .iter()
1242                        .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1243            }
1244            // Shape expected, dict<K, V> actual → gradual: allow since dict may have the fields
1245            (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1246            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1247                self.types_compatible(expected_inner, actual_inner, scope)
1248            }
1249            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1250            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1251            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1252                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1253            }
1254            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1255            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1256            // FnType compatibility: params match positionally and return types match
1257            (
1258                TypeExpr::FnType {
1259                    params: ep,
1260                    return_type: er,
1261                },
1262                TypeExpr::FnType {
1263                    params: ap,
1264                    return_type: ar,
1265                },
1266            ) => {
1267                ep.len() == ap.len()
1268                    && ep
1269                        .iter()
1270                        .zip(ap.iter())
1271                        .all(|(e, a)| self.types_compatible(e, a, scope))
1272                    && self.types_compatible(er, ar, scope)
1273            }
1274            // FnType is compatible with Named("closure") for backward compat
1275            (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1276            (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1277            _ => false,
1278        }
1279    }
1280
1281    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1282        if let TypeExpr::Named(name) = ty {
1283            if let Some(resolved) = scope.resolve_type(name) {
1284                return resolved.clone();
1285            }
1286        }
1287        ty.clone()
1288    }
1289
1290    fn error_at(&mut self, message: String, span: Span) {
1291        self.diagnostics.push(TypeDiagnostic {
1292            message,
1293            severity: DiagnosticSeverity::Error,
1294            span: Some(span),
1295        });
1296    }
1297
1298    fn warning_at(&mut self, message: String, span: Span) {
1299        self.diagnostics.push(TypeDiagnostic {
1300            message,
1301            severity: DiagnosticSeverity::Warning,
1302            span: Some(span),
1303        });
1304    }
1305}
1306
1307impl Default for TypeChecker {
1308    fn default() -> Self {
1309        Self::new()
1310    }
1311}
1312
1313/// Infer the result type of a binary operation.
1314fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1315    match op {
1316        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1317        "+" => match (left, right) {
1318            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1319                match (l.as_str(), r.as_str()) {
1320                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1321                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1322                    ("string", _) => Some(TypeExpr::Named("string".into())),
1323                    ("list", "list") => Some(TypeExpr::Named("list".into())),
1324                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1325                    _ => Some(TypeExpr::Named("string".into())),
1326                }
1327            }
1328            _ => None,
1329        },
1330        "-" | "*" | "/" | "%" => match (left, right) {
1331            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1332                match (l.as_str(), r.as_str()) {
1333                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1334                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1335                    _ => None,
1336                }
1337            }
1338            _ => None,
1339        },
1340        "??" => match (left, right) {
1341            (Some(TypeExpr::Union(members)), _) => {
1342                let non_nil: Vec<_> = members
1343                    .iter()
1344                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1345                    .cloned()
1346                    .collect();
1347                if non_nil.len() == 1 {
1348                    Some(non_nil[0].clone())
1349                } else if non_nil.is_empty() {
1350                    right.clone()
1351                } else {
1352                    Some(TypeExpr::Union(non_nil))
1353                }
1354            }
1355            _ => right.clone(),
1356        },
1357        "|>" => None,
1358        _ => None,
1359    }
1360}
1361
1362/// Format a type expression for display in error messages.
1363/// Produce a detail string describing why a Shape type is incompatible with
1364/// another Shape type — e.g. "missing field 'age' (int)" or "field 'name'
1365/// has type int, expected string".  Returns `None` if both types are not shapes.
1366pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1367    if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1368        let mut details = Vec::new();
1369        for field in ef {
1370            if field.optional {
1371                continue;
1372            }
1373            match af.iter().find(|f| f.name == field.name) {
1374                None => details.push(format!(
1375                    "missing field '{}' ({})",
1376                    field.name,
1377                    format_type(&field.type_expr)
1378                )),
1379                Some(actual_field) => {
1380                    let e_str = format_type(&field.type_expr);
1381                    let a_str = format_type(&actual_field.type_expr);
1382                    if e_str != a_str {
1383                        details.push(format!(
1384                            "field '{}' has type {}, expected {}",
1385                            field.name, a_str, e_str
1386                        ));
1387                    }
1388                }
1389            }
1390        }
1391        if details.is_empty() {
1392            None
1393        } else {
1394            Some(details.join("; "))
1395        }
1396    } else {
1397        None
1398    }
1399}
1400
1401pub fn format_type(ty: &TypeExpr) -> String {
1402    match ty {
1403        TypeExpr::Named(n) => n.clone(),
1404        TypeExpr::Union(types) => types
1405            .iter()
1406            .map(format_type)
1407            .collect::<Vec<_>>()
1408            .join(" | "),
1409        TypeExpr::Shape(fields) => {
1410            let inner: Vec<String> = fields
1411                .iter()
1412                .map(|f| {
1413                    let opt = if f.optional { "?" } else { "" };
1414                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1415                })
1416                .collect();
1417            format!("{{{}}}", inner.join(", "))
1418        }
1419        TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1420        TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1421        TypeExpr::FnType {
1422            params,
1423            return_type,
1424        } => {
1425            let params_str = params
1426                .iter()
1427                .map(format_type)
1428                .collect::<Vec<_>>()
1429                .join(", ");
1430            format!("fn({}) -> {}", params_str, format_type(return_type))
1431        }
1432    }
1433}
1434
1435#[cfg(test)]
1436mod tests {
1437    use super::*;
1438    use crate::Parser;
1439    use harn_lexer::Lexer;
1440
1441    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
1442        let mut lexer = Lexer::new(source);
1443        let tokens = lexer.tokenize().unwrap();
1444        let mut parser = Parser::new(tokens);
1445        let program = parser.parse().unwrap();
1446        TypeChecker::new().check(&program)
1447    }
1448
1449    fn errors(source: &str) -> Vec<String> {
1450        check_source(source)
1451            .into_iter()
1452            .filter(|d| d.severity == DiagnosticSeverity::Error)
1453            .map(|d| d.message)
1454            .collect()
1455    }
1456
1457    #[test]
1458    fn test_no_errors_for_untyped_code() {
1459        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
1460        assert!(errs.is_empty());
1461    }
1462
1463    #[test]
1464    fn test_correct_typed_let() {
1465        let errs = errors("pipeline t(task) { let x: int = 42 }");
1466        assert!(errs.is_empty());
1467    }
1468
1469    #[test]
1470    fn test_type_mismatch_let() {
1471        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
1472        assert_eq!(errs.len(), 1);
1473        assert!(errs[0].contains("Type mismatch"));
1474        assert!(errs[0].contains("int"));
1475        assert!(errs[0].contains("string"));
1476    }
1477
1478    #[test]
1479    fn test_correct_typed_fn() {
1480        let errs = errors(
1481            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
1482        );
1483        assert!(errs.is_empty());
1484    }
1485
1486    #[test]
1487    fn test_fn_arg_type_mismatch() {
1488        let errs = errors(
1489            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
1490add("hello", 2) }"#,
1491        );
1492        assert_eq!(errs.len(), 1);
1493        assert!(errs[0].contains("Argument 1"));
1494        assert!(errs[0].contains("expected int"));
1495    }
1496
1497    #[test]
1498    fn test_return_type_mismatch() {
1499        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
1500        assert_eq!(errs.len(), 1);
1501        assert!(errs[0].contains("Return type mismatch"));
1502    }
1503
1504    #[test]
1505    fn test_union_type_compatible() {
1506        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
1507        assert!(errs.is_empty());
1508    }
1509
1510    #[test]
1511    fn test_union_type_mismatch() {
1512        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
1513        assert_eq!(errs.len(), 1);
1514        assert!(errs[0].contains("Type mismatch"));
1515    }
1516
1517    #[test]
1518    fn test_type_inference_propagation() {
1519        let errs = errors(
1520            r#"pipeline t(task) {
1521  fn add(a: int, b: int) -> int { return a + b }
1522  let result: string = add(1, 2)
1523}"#,
1524        );
1525        assert_eq!(errs.len(), 1);
1526        assert!(errs[0].contains("Type mismatch"));
1527        assert!(errs[0].contains("string"));
1528        assert!(errs[0].contains("int"));
1529    }
1530
1531    #[test]
1532    fn test_builtin_return_type_inference() {
1533        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
1534        assert_eq!(errs.len(), 1);
1535        assert!(errs[0].contains("string"));
1536        assert!(errs[0].contains("int"));
1537    }
1538
1539    #[test]
1540    fn test_binary_op_type_inference() {
1541        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
1542        assert_eq!(errs.len(), 1);
1543    }
1544
1545    #[test]
1546    fn test_comparison_returns_bool() {
1547        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
1548        assert!(errs.is_empty());
1549    }
1550
1551    #[test]
1552    fn test_int_float_promotion() {
1553        let errs = errors("pipeline t(task) { let x: float = 42 }");
1554        assert!(errs.is_empty());
1555    }
1556
1557    #[test]
1558    fn test_untyped_code_no_errors() {
1559        let errs = errors(
1560            r#"pipeline t(task) {
1561  fn process(data) {
1562    let result = data + " processed"
1563    return result
1564  }
1565  log(process("hello"))
1566}"#,
1567        );
1568        assert!(errs.is_empty());
1569    }
1570
1571    #[test]
1572    fn test_type_alias() {
1573        let errs = errors(
1574            r#"pipeline t(task) {
1575  type Name = string
1576  let x: Name = "hello"
1577}"#,
1578        );
1579        assert!(errs.is_empty());
1580    }
1581
1582    #[test]
1583    fn test_type_alias_mismatch() {
1584        let errs = errors(
1585            r#"pipeline t(task) {
1586  type Name = string
1587  let x: Name = 42
1588}"#,
1589        );
1590        assert_eq!(errs.len(), 1);
1591    }
1592
1593    #[test]
1594    fn test_assignment_type_check() {
1595        let errs = errors(
1596            r#"pipeline t(task) {
1597  var x: int = 0
1598  x = "hello"
1599}"#,
1600        );
1601        assert_eq!(errs.len(), 1);
1602        assert!(errs[0].contains("cannot assign string"));
1603    }
1604
1605    #[test]
1606    fn test_covariance_int_to_float_in_fn() {
1607        let errs = errors(
1608            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
1609        );
1610        assert!(errs.is_empty());
1611    }
1612
1613    #[test]
1614    fn test_covariance_return_type() {
1615        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
1616        assert!(errs.is_empty());
1617    }
1618
1619    #[test]
1620    fn test_no_contravariance_float_to_int() {
1621        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
1622        assert_eq!(errs.len(), 1);
1623    }
1624
1625    // --- Exhaustiveness checking tests ---
1626
1627    fn warnings(source: &str) -> Vec<String> {
1628        check_source(source)
1629            .into_iter()
1630            .filter(|d| d.severity == DiagnosticSeverity::Warning)
1631            .map(|d| d.message)
1632            .collect()
1633    }
1634
1635    #[test]
1636    fn test_exhaustive_match_no_warning() {
1637        let warns = warnings(
1638            r#"pipeline t(task) {
1639  enum Color { Red, Green, Blue }
1640  let c = Color.Red
1641  match c.variant {
1642    "Red" -> { log("r") }
1643    "Green" -> { log("g") }
1644    "Blue" -> { log("b") }
1645  }
1646}"#,
1647        );
1648        let exhaustive_warns: Vec<_> = warns
1649            .iter()
1650            .filter(|w| w.contains("Non-exhaustive"))
1651            .collect();
1652        assert!(exhaustive_warns.is_empty());
1653    }
1654
1655    #[test]
1656    fn test_non_exhaustive_match_warning() {
1657        let warns = warnings(
1658            r#"pipeline t(task) {
1659  enum Color { Red, Green, Blue }
1660  let c = Color.Red
1661  match c.variant {
1662    "Red" -> { log("r") }
1663    "Green" -> { log("g") }
1664  }
1665}"#,
1666        );
1667        let exhaustive_warns: Vec<_> = warns
1668            .iter()
1669            .filter(|w| w.contains("Non-exhaustive"))
1670            .collect();
1671        assert_eq!(exhaustive_warns.len(), 1);
1672        assert!(exhaustive_warns[0].contains("Blue"));
1673    }
1674
1675    #[test]
1676    fn test_non_exhaustive_multiple_missing() {
1677        let warns = warnings(
1678            r#"pipeline t(task) {
1679  enum Status { Active, Inactive, Pending }
1680  let s = Status.Active
1681  match s.variant {
1682    "Active" -> { log("a") }
1683  }
1684}"#,
1685        );
1686        let exhaustive_warns: Vec<_> = warns
1687            .iter()
1688            .filter(|w| w.contains("Non-exhaustive"))
1689            .collect();
1690        assert_eq!(exhaustive_warns.len(), 1);
1691        assert!(exhaustive_warns[0].contains("Inactive"));
1692        assert!(exhaustive_warns[0].contains("Pending"));
1693    }
1694
1695    #[test]
1696    fn test_enum_construct_type_inference() {
1697        let errs = errors(
1698            r#"pipeline t(task) {
1699  enum Color { Red, Green, Blue }
1700  let c: Color = Color.Red
1701}"#,
1702        );
1703        assert!(errs.is_empty());
1704    }
1705
    // --- Type narrowing and shape-mismatch detail tests ---
1707
1708    #[test]
1709    fn test_nil_coalescing_strips_nil() {
1710        // After ??, nil should be stripped from the type
1711        let errs = errors(
1712            r#"pipeline t(task) {
1713  let x: string | nil = nil
1714  let y: string = x ?? "default"
1715}"#,
1716        );
1717        assert!(errs.is_empty());
1718    }
1719
1720    #[test]
1721    fn test_shape_mismatch_detail_missing_field() {
1722        let errs = errors(
1723            r#"pipeline t(task) {
1724  let x: {name: string, age: int} = {name: "hello"}
1725}"#,
1726        );
1727        assert_eq!(errs.len(), 1);
1728        assert!(
1729            errs[0].contains("missing field 'age'"),
1730            "expected detail about missing field, got: {}",
1731            errs[0]
1732        );
1733    }
1734
1735    #[test]
1736    fn test_shape_mismatch_detail_wrong_type() {
1737        let errs = errors(
1738            r#"pipeline t(task) {
1739  let x: {name: string, age: int} = {name: 42, age: 10}
1740}"#,
1741        );
1742        assert_eq!(errs.len(), 1);
1743        assert!(
1744            errs[0].contains("field 'name' has type int, expected string"),
1745            "expected detail about wrong type, got: {}",
1746            errs[0]
1747        );
1748    }
1749
1750    // --- Match pattern type validation tests ---
1751
1752    #[test]
1753    fn test_match_pattern_string_against_int() {
1754        let warns = warnings(
1755            r#"pipeline t(task) {
1756  let x: int = 42
1757  match x {
1758    "hello" -> { log("bad") }
1759    42 -> { log("ok") }
1760  }
1761}"#,
1762        );
1763        let pattern_warns: Vec<_> = warns
1764            .iter()
1765            .filter(|w| w.contains("Match pattern type mismatch"))
1766            .collect();
1767        assert_eq!(pattern_warns.len(), 1);
1768        assert!(pattern_warns[0].contains("matching int against string literal"));
1769    }
1770
1771    #[test]
1772    fn test_match_pattern_int_against_string() {
1773        let warns = warnings(
1774            r#"pipeline t(task) {
1775  let x: string = "hello"
1776  match x {
1777    42 -> { log("bad") }
1778    "hello" -> { log("ok") }
1779  }
1780}"#,
1781        );
1782        let pattern_warns: Vec<_> = warns
1783            .iter()
1784            .filter(|w| w.contains("Match pattern type mismatch"))
1785            .collect();
1786        assert_eq!(pattern_warns.len(), 1);
1787        assert!(pattern_warns[0].contains("matching string against int literal"));
1788    }
1789
1790    #[test]
1791    fn test_match_pattern_bool_against_int() {
1792        let warns = warnings(
1793            r#"pipeline t(task) {
1794  let x: int = 42
1795  match x {
1796    true -> { log("bad") }
1797    42 -> { log("ok") }
1798  }
1799}"#,
1800        );
1801        let pattern_warns: Vec<_> = warns
1802            .iter()
1803            .filter(|w| w.contains("Match pattern type mismatch"))
1804            .collect();
1805        assert_eq!(pattern_warns.len(), 1);
1806        assert!(pattern_warns[0].contains("matching int against bool literal"));
1807    }
1808
1809    #[test]
1810    fn test_match_pattern_float_against_string() {
1811        let warns = warnings(
1812            r#"pipeline t(task) {
1813  let x: string = "hello"
1814  match x {
1815    3.14 -> { log("bad") }
1816    "hello" -> { log("ok") }
1817  }
1818}"#,
1819        );
1820        let pattern_warns: Vec<_> = warns
1821            .iter()
1822            .filter(|w| w.contains("Match pattern type mismatch"))
1823            .collect();
1824        assert_eq!(pattern_warns.len(), 1);
1825        assert!(pattern_warns[0].contains("matching string against float literal"));
1826    }
1827
1828    #[test]
1829    fn test_match_pattern_int_against_float_ok() {
1830        // int and float are compatible for match patterns
1831        let warns = warnings(
1832            r#"pipeline t(task) {
1833  let x: float = 3.14
1834  match x {
1835    42 -> { log("ok") }
1836    _ -> { log("default") }
1837  }
1838}"#,
1839        );
1840        let pattern_warns: Vec<_> = warns
1841            .iter()
1842            .filter(|w| w.contains("Match pattern type mismatch"))
1843            .collect();
1844        assert!(pattern_warns.is_empty());
1845    }
1846
1847    #[test]
1848    fn test_match_pattern_float_against_int_ok() {
1849        // float and int are compatible for match patterns
1850        let warns = warnings(
1851            r#"pipeline t(task) {
1852  let x: int = 42
1853  match x {
1854    3.14 -> { log("close") }
1855    _ -> { log("default") }
1856  }
1857}"#,
1858        );
1859        let pattern_warns: Vec<_> = warns
1860            .iter()
1861            .filter(|w| w.contains("Match pattern type mismatch"))
1862            .collect();
1863        assert!(pattern_warns.is_empty());
1864    }
1865
1866    #[test]
1867    fn test_match_pattern_correct_types_no_warning() {
1868        let warns = warnings(
1869            r#"pipeline t(task) {
1870  let x: int = 42
1871  match x {
1872    1 -> { log("one") }
1873    2 -> { log("two") }
1874    _ -> { log("other") }
1875  }
1876}"#,
1877        );
1878        let pattern_warns: Vec<_> = warns
1879            .iter()
1880            .filter(|w| w.contains("Match pattern type mismatch"))
1881            .collect();
1882        assert!(pattern_warns.is_empty());
1883    }
1884
1885    #[test]
1886    fn test_match_pattern_wildcard_no_warning() {
1887        let warns = warnings(
1888            r#"pipeline t(task) {
1889  let x: int = 42
1890  match x {
1891    _ -> { log("catch all") }
1892  }
1893}"#,
1894        );
1895        let pattern_warns: Vec<_> = warns
1896            .iter()
1897            .filter(|w| w.contains("Match pattern type mismatch"))
1898            .collect();
1899        assert!(pattern_warns.is_empty());
1900    }
1901
1902    #[test]
1903    fn test_match_pattern_untyped_no_warning() {
1904        // When value has no known type, no warning should be emitted
1905        let warns = warnings(
1906            r#"pipeline t(task) {
1907  let x = some_unknown_fn()
1908  match x {
1909    "hello" -> { log("string") }
1910    42 -> { log("int") }
1911  }
1912}"#,
1913        );
1914        let pattern_warns: Vec<_> = warns
1915            .iter()
1916            .filter(|w| w.contains("Match pattern type mismatch"))
1917            .collect();
1918        assert!(pattern_warns.is_empty());
1919    }
1920}