Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6/// A diagnostic produced by the type checker.
7#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9    pub message: String,
10    pub severity: DiagnosticSeverity,
11    pub span: Option<Span>,
12}
13
/// Severity level attached to every [`TypeDiagnostic`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticSeverity {
    /// Error-level diagnostic.
    Error,
    /// Warning-level diagnostic; less severe than [`DiagnosticSeverity::Error`].
    Warning,
}
19
20/// Inferred type of an expression. None means unknown/untyped (gradual typing).
21type InferredType = Option<TypeExpr>;
22
23/// Scope for tracking variable types.
24#[derive(Debug, Clone)]
25struct TypeScope {
26    /// Variable name → inferred type.
27    vars: BTreeMap<String, InferredType>,
28    /// Function name → (param types, return type).
29    functions: BTreeMap<String, FnSignature>,
30    /// Named type aliases.
31    type_aliases: BTreeMap<String, TypeExpr>,
32    /// Enum declarations: name → variant names.
33    enums: BTreeMap<String, Vec<String>>,
34    /// Interface declarations: name → method signatures.
35    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36    /// Struct declarations: name → field types.
37    structs: BTreeMap<String, Vec<(String, InferredType)>>,
38    /// Impl block methods: type_name → method signatures.
39    impl_methods: BTreeMap<String, Vec<ImplMethodSig>>,
40    /// Generic type parameter names in scope (treated as compatible with any type).
41    generic_type_params: std::collections::BTreeSet<String>,
42    parent: Option<Box<TypeScope>>,
43}
44
45/// Method signature extracted from an impl block (for interface checking).
46#[derive(Debug, Clone)]
47struct ImplMethodSig {
48    name: String,
49    /// Number of parameters excluding `self`.
50    param_count: usize,
51    /// Parameter types (excluding `self`), None means untyped.
52    param_types: Vec<Option<TypeExpr>>,
53    /// Return type, None means untyped.
54    return_type: Option<TypeExpr>,
55}
56
57#[derive(Debug, Clone)]
58struct FnSignature {
59    params: Vec<(String, InferredType)>,
60    return_type: InferredType,
61    /// Generic type parameter names declared on the function.
62    type_param_names: Vec<String>,
63    /// Number of required parameters (those without defaults).
64    required_params: usize,
65    /// Where-clause constraints: (type_param_name, interface_bound).
66    where_clauses: Vec<(String, String)>,
67}
68
69impl TypeScope {
70    fn new() -> Self {
71        Self {
72            vars: BTreeMap::new(),
73            functions: BTreeMap::new(),
74            type_aliases: BTreeMap::new(),
75            enums: BTreeMap::new(),
76            interfaces: BTreeMap::new(),
77            structs: BTreeMap::new(),
78            impl_methods: BTreeMap::new(),
79            generic_type_params: std::collections::BTreeSet::new(),
80            parent: None,
81        }
82    }
83
84    fn child(&self) -> Self {
85        Self {
86            vars: BTreeMap::new(),
87            functions: BTreeMap::new(),
88            type_aliases: BTreeMap::new(),
89            enums: BTreeMap::new(),
90            interfaces: BTreeMap::new(),
91            structs: BTreeMap::new(),
92            impl_methods: BTreeMap::new(),
93            generic_type_params: std::collections::BTreeSet::new(),
94            parent: Some(Box::new(self.clone())),
95        }
96    }
97
98    fn get_var(&self, name: &str) -> Option<&InferredType> {
99        self.vars
100            .get(name)
101            .or_else(|| self.parent.as_ref()?.get_var(name))
102    }
103
104    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
105        self.functions
106            .get(name)
107            .or_else(|| self.parent.as_ref()?.get_fn(name))
108    }
109
110    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
111        self.type_aliases
112            .get(name)
113            .or_else(|| self.parent.as_ref()?.resolve_type(name))
114    }
115
116    fn is_generic_type_param(&self, name: &str) -> bool {
117        self.generic_type_params.contains(name)
118            || self
119                .parent
120                .as_ref()
121                .is_some_and(|p| p.is_generic_type_param(name))
122    }
123
124    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
125        self.enums
126            .get(name)
127            .or_else(|| self.parent.as_ref()?.get_enum(name))
128    }
129
130    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
131        self.interfaces
132            .get(name)
133            .or_else(|| self.parent.as_ref()?.get_interface(name))
134    }
135
136    fn get_struct(&self, name: &str) -> Option<&Vec<(String, InferredType)>> {
137        self.structs
138            .get(name)
139            .or_else(|| self.parent.as_ref()?.get_struct(name))
140    }
141
142    fn get_impl_methods(&self, name: &str) -> Option<&Vec<ImplMethodSig>> {
143        self.impl_methods
144            .get(name)
145            .or_else(|| self.parent.as_ref()?.get_impl_methods(name))
146    }
147
148    fn define_var(&mut self, name: &str, ty: InferredType) {
149        self.vars.insert(name.to_string(), ty);
150    }
151
152    fn define_fn(&mut self, name: &str, sig: FnSignature) {
153        self.functions.insert(name.to_string(), sig);
154    }
155}
156
157/// Known return types for builtin functions.
158fn builtin_return_type(name: &str) -> InferredType {
159    match name {
160        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
161        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
162            Some(TypeExpr::Named("nil".into()))
163        }
164        "type_of"
165        | "to_string"
166        | "json_stringify"
167        | "read_file"
168        | "http_get"
169        | "http_post"
170        | "llm_call"
171        | "regex_replace"
172        | "path_join"
173        | "temp_dir"
174        | "date_format"
175        | "format"
176        | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
177        "to_int" | "timer_end" | "elapsed" | "sign" => Some(TypeExpr::Named("int".into())),
178        "to_float" | "timestamp" | "date_parse" | "sin" | "cos" | "tan" | "asin" | "acos"
179        | "atan" | "atan2" | "log2" | "log10" | "ln" | "exp" | "pi" | "e" => {
180            Some(TypeExpr::Named("float".into()))
181        }
182        "file_exists" | "json_validate" | "is_nan" | "is_infinite" | "set_contains" => {
183            Some(TypeExpr::Named("bool".into()))
184        }
185        "list_dir" | "mcp_list_tools" | "mcp_list_resources" | "mcp_list_prompts" | "to_list"
186        | "regex_captures" => Some(TypeExpr::Named("list".into())),
187        "stat" | "exec" | "shell" | "date_now" | "agent_loop" | "llm_info" | "llm_usage"
188        | "timer_start" | "metadata_get" | "mcp_server_info" | "mcp_get_prompt" => {
189            Some(TypeExpr::Named("dict".into()))
190        }
191        "metadata_set"
192        | "metadata_save"
193        | "metadata_refresh_hashes"
194        | "invalidate_facts"
195        | "log_json"
196        | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
197        "env" | "regex_match" => Some(TypeExpr::Union(vec![
198            TypeExpr::Named("string".into()),
199            TypeExpr::Named("nil".into()),
200        ])),
201        "json_parse" | "json_extract" => None, // could be any type
202        _ => None,
203    }
204}
205
/// Returns `true` if `name` is one of the builtin functions known to the type checker.
///
/// Keep this list in sync with `builtin_return_type`: every builtin that function
/// assigns a return type to must also be recognized here. Otherwise a builtin whose
/// result type is known would still be reported as "not a builtin". Builtins without a
/// known return type (e.g. `trim`, `set_add`, `await`) appear only in this list.
fn is_builtin(name: &str) -> bool {
    matches!(
        name,
        // Output, environment, and process control
        "log" | "print" | "println" | "log_json" | "exit" | "sleep" | "env" | "exec" | "shell"
            // Conversion and introspection
            | "type_of" | "to_string" | "to_int" | "to_float" | "to_list" | "len"
            // JSON
            | "json_stringify" | "json_parse" | "json_validate" | "json_extract"
            // Files and paths
            | "read_file" | "write_file" | "append_file" | "copy_file" | "delete_file"
            | "file_exists" | "list_dir" | "mkdir" | "stat" | "temp_dir" | "path_join"
            | "dirname" | "basename" | "extname" | "compute_content_hash"
            // Dates and timing
            | "timestamp" | "date_now" | "date_format" | "date_parse"
            | "timer_start" | "timer_end" | "elapsed"
            // Strings and regular expressions
            | "format" | "trim" | "lowercase" | "uppercase" | "split" | "starts_with"
            | "ends_with" | "contains" | "replace" | "join" | "substring"
            | "regex_match" | "regex_replace" | "regex_captures"
            // HTTP, LLM, and agents
            | "http_get" | "http_post" | "llm_call" | "llm_info" | "llm_usage" | "agent_loop"
            // Concurrency
            | "await" | "cancel"
            // Math
            | "sin" | "cos" | "tan" | "asin" | "acos" | "atan" | "atan2" | "log2" | "log10"
            | "ln" | "exp" | "pi" | "e" | "sign" | "is_nan" | "is_infinite"
            // Sets
            | "set" | "set_add" | "set_remove" | "set_contains" | "set_union"
            | "set_intersect" | "set_difference"
            // Metadata and facts
            | "metadata_get" | "metadata_set" | "metadata_save" | "metadata_refresh_hashes"
            | "invalidate_facts"
            // MCP
            | "mcp_list_tools" | "mcp_list_resources" | "mcp_list_prompts"
            | "mcp_server_info" | "mcp_get_prompt" | "mcp_disconnect"
    )
}
291
292/// The static type checker.
293pub struct TypeChecker {
294    diagnostics: Vec<TypeDiagnostic>,
295    scope: TypeScope,
296}
297
298impl TypeChecker {
299    pub fn new() -> Self {
300        Self {
301            diagnostics: Vec::new(),
302            scope: TypeScope::new(),
303        }
304    }
305
306    /// Check a program and return diagnostics.
307    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
308        // First pass: register type and enum declarations into root scope
309        Self::register_declarations_into(&mut self.scope, program);
310
311        // Also scan pipeline bodies for declarations
312        for snode in program {
313            if let Node::Pipeline { body, .. } = &snode.node {
314                Self::register_declarations_into(&mut self.scope, body);
315            }
316        }
317
318        // Check each top-level node
319        for snode in program {
320            match &snode.node {
321                Node::Pipeline { params, body, .. } => {
322                    let mut child = self.scope.child();
323                    for p in params {
324                        child.define_var(p, None);
325                    }
326                    self.check_block(body, &mut child);
327                }
328                Node::FnDecl {
329                    name,
330                    type_params,
331                    params,
332                    return_type,
333                    where_clauses,
334                    body,
335                    ..
336                } => {
337                    let required_params =
338                        params.iter().filter(|p| p.default_value.is_none()).count();
339                    let sig = FnSignature {
340                        params: params
341                            .iter()
342                            .map(|p| (p.name.clone(), p.type_expr.clone()))
343                            .collect(),
344                        return_type: return_type.clone(),
345                        type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
346                        required_params,
347                        where_clauses: where_clauses
348                            .iter()
349                            .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
350                            .collect(),
351                    };
352                    self.scope.define_fn(name, sig);
353                    self.check_fn_body(type_params, params, return_type, body);
354                }
355                _ => {
356                    let mut scope = self.scope.clone();
357                    self.check_node(snode, &mut scope);
358                    // Merge any new definitions back into the top-level scope
359                    for (name, ty) in scope.vars {
360                        self.scope.vars.entry(name).or_insert(ty);
361                    }
362                }
363            }
364        }
365
366        self.diagnostics
367    }
368
369    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
370    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
371        for snode in nodes {
372            match &snode.node {
373                Node::TypeDecl { name, type_expr } => {
374                    scope.type_aliases.insert(name.clone(), type_expr.clone());
375                }
376                Node::EnumDecl { name, variants } => {
377                    let variant_names: Vec<String> =
378                        variants.iter().map(|v| v.name.clone()).collect();
379                    scope.enums.insert(name.clone(), variant_names);
380                }
381                Node::InterfaceDecl { name, methods } => {
382                    scope.interfaces.insert(name.clone(), methods.clone());
383                }
384                Node::StructDecl { name, fields } => {
385                    let field_types: Vec<(String, InferredType)> = fields
386                        .iter()
387                        .map(|f| (f.name.clone(), f.type_expr.clone()))
388                        .collect();
389                    scope.structs.insert(name.clone(), field_types);
390                }
391                Node::ImplBlock {
392                    type_name, methods, ..
393                } => {
394                    let sigs: Vec<ImplMethodSig> = methods
395                        .iter()
396                        .filter_map(|m| {
397                            if let Node::FnDecl {
398                                name,
399                                params,
400                                return_type,
401                                ..
402                            } = &m.node
403                            {
404                                let non_self: Vec<_> =
405                                    params.iter().filter(|p| p.name != "self").collect();
406                                let param_count = non_self.len();
407                                let param_types: Vec<Option<TypeExpr>> =
408                                    non_self.iter().map(|p| p.type_expr.clone()).collect();
409                                Some(ImplMethodSig {
410                                    name: name.clone(),
411                                    param_count,
412                                    param_types,
413                                    return_type: return_type.clone(),
414                                })
415                            } else {
416                                None
417                            }
418                        })
419                        .collect();
420                    scope.impl_methods.insert(type_name.clone(), sigs);
421                }
422                _ => {}
423            }
424        }
425    }
426
427    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
428        for stmt in stmts {
429            self.check_node(stmt, scope);
430        }
431    }
432
433    /// Define variables from a destructuring pattern in the given scope (as unknown type).
434    fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
435        match pattern {
436            BindingPattern::Identifier(name) => {
437                scope.define_var(name, None);
438            }
439            BindingPattern::Dict(fields) => {
440                for field in fields {
441                    let name = field.alias.as_deref().unwrap_or(&field.key);
442                    scope.define_var(name, None);
443                }
444            }
445            BindingPattern::List(elements) => {
446                for elem in elements {
447                    scope.define_var(&elem.name, None);
448                }
449            }
450        }
451    }
452
453    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
454        let span = snode.span;
455        match &snode.node {
456            Node::LetBinding {
457                pattern,
458                type_ann,
459                value,
460            } => {
461                let inferred = self.infer_type(value, scope);
462                if let BindingPattern::Identifier(name) = pattern {
463                    if let Some(expected) = type_ann {
464                        if let Some(actual) = &inferred {
465                            if !self.types_compatible(expected, actual, scope) {
466                                let mut msg = format!(
467                                    "Type mismatch: '{}' declared as {}, but assigned {}",
468                                    name,
469                                    format_type(expected),
470                                    format_type(actual)
471                                );
472                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
473                                    msg.push_str(&format!(" ({})", detail));
474                                }
475                                self.error_at(msg, span);
476                            }
477                        }
478                    }
479                    let ty = type_ann.clone().or(inferred);
480                    scope.define_var(name, ty);
481                } else {
482                    Self::define_pattern_vars(pattern, scope);
483                }
484            }
485
486            Node::VarBinding {
487                pattern,
488                type_ann,
489                value,
490            } => {
491                let inferred = self.infer_type(value, scope);
492                if let BindingPattern::Identifier(name) = pattern {
493                    if let Some(expected) = type_ann {
494                        if let Some(actual) = &inferred {
495                            if !self.types_compatible(expected, actual, scope) {
496                                let mut msg = format!(
497                                    "Type mismatch: '{}' declared as {}, but assigned {}",
498                                    name,
499                                    format_type(expected),
500                                    format_type(actual)
501                                );
502                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
503                                    msg.push_str(&format!(" ({})", detail));
504                                }
505                                self.error_at(msg, span);
506                            }
507                        }
508                    }
509                    let ty = type_ann.clone().or(inferred);
510                    scope.define_var(name, ty);
511                } else {
512                    Self::define_pattern_vars(pattern, scope);
513                }
514            }
515
516            Node::FnDecl {
517                name,
518                type_params,
519                params,
520                return_type,
521                where_clauses,
522                body,
523                ..
524            } => {
525                let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
526                let sig = FnSignature {
527                    params: params
528                        .iter()
529                        .map(|p| (p.name.clone(), p.type_expr.clone()))
530                        .collect(),
531                    return_type: return_type.clone(),
532                    type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
533                    required_params,
534                    where_clauses: where_clauses
535                        .iter()
536                        .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
537                        .collect(),
538                };
539                scope.define_fn(name, sig.clone());
540                scope.define_var(name, None);
541                self.check_fn_body(type_params, params, return_type, body);
542            }
543
544            Node::FunctionCall { name, args } => {
545                self.check_call(name, args, scope, span);
546            }
547
548            Node::IfElse {
549                condition,
550                then_body,
551                else_body,
552            } => {
553                self.check_node(condition, scope);
554                let mut then_scope = scope.child();
555                // Narrow union types after nil checks: if x != nil, narrow x
556                if let Some((var_name, narrowed)) = Self::extract_nil_narrowing(condition, scope) {
557                    then_scope.define_var(&var_name, narrowed);
558                }
559                self.check_block(then_body, &mut then_scope);
560                if let Some(else_body) = else_body {
561                    let mut else_scope = scope.child();
562                    self.check_block(else_body, &mut else_scope);
563                }
564            }
565
566            Node::ForIn {
567                pattern,
568                iterable,
569                body,
570            } => {
571                self.check_node(iterable, scope);
572                let mut loop_scope = scope.child();
573                if let BindingPattern::Identifier(variable) = pattern {
574                    // Infer loop variable type from iterable
575                    let elem_type = match self.infer_type(iterable, scope) {
576                        Some(TypeExpr::List(inner)) => Some(*inner),
577                        Some(TypeExpr::Named(n)) if n == "string" => {
578                            Some(TypeExpr::Named("string".into()))
579                        }
580                        _ => None,
581                    };
582                    loop_scope.define_var(variable, elem_type);
583                } else {
584                    Self::define_pattern_vars(pattern, &mut loop_scope);
585                }
586                self.check_block(body, &mut loop_scope);
587            }
588
589            Node::WhileLoop { condition, body } => {
590                self.check_node(condition, scope);
591                let mut loop_scope = scope.child();
592                self.check_block(body, &mut loop_scope);
593            }
594
595            Node::TryCatch {
596                body,
597                error_var,
598                catch_body,
599                finally_body,
600                ..
601            } => {
602                let mut try_scope = scope.child();
603                self.check_block(body, &mut try_scope);
604                let mut catch_scope = scope.child();
605                if let Some(var) = error_var {
606                    catch_scope.define_var(var, None);
607                }
608                self.check_block(catch_body, &mut catch_scope);
609                if let Some(fb) = finally_body {
610                    let mut finally_scope = scope.child();
611                    self.check_block(fb, &mut finally_scope);
612                }
613            }
614
615            Node::TryExpr { body } => {
616                let mut try_scope = scope.child();
617                self.check_block(body, &mut try_scope);
618            }
619
620            Node::ReturnStmt {
621                value: Some(val), ..
622            } => {
623                self.check_node(val, scope);
624            }
625
626            Node::Assignment {
627                target, value, op, ..
628            } => {
629                self.check_node(value, scope);
630                if let Node::Identifier(name) = &target.node {
631                    if let Some(Some(var_type)) = scope.get_var(name) {
632                        let value_type = self.infer_type(value, scope);
633                        let assigned = if let Some(op) = op {
634                            let var_inferred = scope.get_var(name).cloned().flatten();
635                            infer_binary_op_type(op, &var_inferred, &value_type)
636                        } else {
637                            value_type
638                        };
639                        if let Some(actual) = &assigned {
640                            if !self.types_compatible(var_type, actual, scope) {
641                                self.error_at(
642                                    format!(
643                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
644                                        format_type(actual),
645                                        name,
646                                        format_type(var_type)
647                                    ),
648                                    span,
649                                );
650                            }
651                        }
652                    }
653                }
654            }
655
656            Node::TypeDecl { name, type_expr } => {
657                scope.type_aliases.insert(name.clone(), type_expr.clone());
658            }
659
660            Node::EnumDecl { name, variants } => {
661                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
662                scope.enums.insert(name.clone(), variant_names);
663            }
664
665            Node::StructDecl { name, fields } => {
666                let field_types: Vec<(String, InferredType)> = fields
667                    .iter()
668                    .map(|f| (f.name.clone(), f.type_expr.clone()))
669                    .collect();
670                scope.structs.insert(name.clone(), field_types);
671            }
672
673            Node::InterfaceDecl { name, methods } => {
674                scope.interfaces.insert(name.clone(), methods.clone());
675            }
676
677            Node::ImplBlock {
678                type_name, methods, ..
679            } => {
680                // Register impl methods for interface satisfaction checking
681                let sigs: Vec<ImplMethodSig> = methods
682                    .iter()
683                    .filter_map(|m| {
684                        if let Node::FnDecl {
685                            name,
686                            params,
687                            return_type,
688                            ..
689                        } = &m.node
690                        {
691                            let non_self: Vec<_> =
692                                params.iter().filter(|p| p.name != "self").collect();
693                            let param_count = non_self.len();
694                            let param_types: Vec<Option<TypeExpr>> =
695                                non_self.iter().map(|p| p.type_expr.clone()).collect();
696                            Some(ImplMethodSig {
697                                name: name.clone(),
698                                param_count,
699                                param_types,
700                                return_type: return_type.clone(),
701                            })
702                        } else {
703                            None
704                        }
705                    })
706                    .collect();
707                scope.impl_methods.insert(type_name.clone(), sigs);
708                for method_sn in methods {
709                    self.check_node(method_sn, scope);
710                }
711            }
712
713            Node::TryOperator { operand } => {
714                self.check_node(operand, scope);
715            }
716
717            Node::MatchExpr { value, arms } => {
718                self.check_node(value, scope);
719                let value_type = self.infer_type(value, scope);
720                for arm in arms {
721                    self.check_node(&arm.pattern, scope);
722                    // Check for incompatible literal pattern types
723                    if let Some(ref vt) = value_type {
724                        let value_type_name = format_type(vt);
725                        let mismatch = match &arm.pattern.node {
726                            Node::StringLiteral(_) => {
727                                !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
728                            }
729                            Node::IntLiteral(_) => {
730                                !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
731                                    && !self.types_compatible(
732                                        vt,
733                                        &TypeExpr::Named("float".into()),
734                                        scope,
735                                    )
736                            }
737                            Node::FloatLiteral(_) => {
738                                !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
739                                    && !self.types_compatible(
740                                        vt,
741                                        &TypeExpr::Named("int".into()),
742                                        scope,
743                                    )
744                            }
745                            Node::BoolLiteral(_) => {
746                                !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
747                            }
748                            _ => false,
749                        };
750                        if mismatch {
751                            let pattern_type = match &arm.pattern.node {
752                                Node::StringLiteral(_) => "string",
753                                Node::IntLiteral(_) => "int",
754                                Node::FloatLiteral(_) => "float",
755                                Node::BoolLiteral(_) => "bool",
756                                _ => unreachable!(),
757                            };
758                            self.warning_at(
759                                format!(
760                                    "Match pattern type mismatch: matching {} against {} literal",
761                                    value_type_name, pattern_type
762                                ),
763                                arm.pattern.span,
764                            );
765                        }
766                    }
767                    let mut arm_scope = scope.child();
768                    self.check_block(&arm.body, &mut arm_scope);
769                }
770                self.check_match_exhaustiveness(value, arms, scope, span);
771            }
772
773            // Recurse into nested expressions + validate binary op types
774            Node::BinaryOp { op, left, right } => {
775                self.check_node(left, scope);
776                self.check_node(right, scope);
777                // Validate operator/type compatibility
778                let lt = self.infer_type(left, scope);
779                let rt = self.infer_type(right, scope);
780                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
781                    match op.as_str() {
782                        "-" | "*" | "/" | "%" => {
783                            let numeric = ["int", "float"];
784                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
785                                self.warning_at(
786                                    format!(
787                                        "Operator '{op}' may not be valid for types {} and {}",
788                                        l, r
789                                    ),
790                                    span,
791                                );
792                            }
793                        }
794                        "+" => {
795                            // + is valid for int, float, string, list, dict
796                            let valid = ["int", "float", "string", "list", "dict"];
797                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
798                                self.warning_at(
799                                    format!(
800                                        "Operator '+' may not be valid for types {} and {}",
801                                        l, r
802                                    ),
803                                    span,
804                                );
805                            }
806                        }
807                        _ => {}
808                    }
809                }
810            }
811            Node::UnaryOp { operand, .. } => {
812                self.check_node(operand, scope);
813            }
814            Node::MethodCall { object, args, .. }
815            | Node::OptionalMethodCall { object, args, .. } => {
816                self.check_node(object, scope);
817                for arg in args {
818                    self.check_node(arg, scope);
819                }
820            }
821            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
822                self.check_node(object, scope);
823            }
824            Node::SubscriptAccess { object, index } => {
825                self.check_node(object, scope);
826                self.check_node(index, scope);
827            }
828            Node::SliceAccess { object, start, end } => {
829                self.check_node(object, scope);
830                if let Some(s) = start {
831                    self.check_node(s, scope);
832                }
833                if let Some(e) = end {
834                    self.check_node(e, scope);
835                }
836            }
837
838            // Terminals — nothing to check
839            _ => {}
840        }
841    }
842
843    fn check_fn_body(
844        &mut self,
845        type_params: &[TypeParam],
846        params: &[TypedParam],
847        return_type: &Option<TypeExpr>,
848        body: &[SNode],
849    ) {
850        let mut fn_scope = self.scope.child();
851        // Register generic type parameters so they are treated as compatible
852        // with any concrete type during type checking.
853        for tp in type_params {
854            fn_scope.generic_type_params.insert(tp.name.clone());
855        }
856        for param in params {
857            fn_scope.define_var(&param.name, param.type_expr.clone());
858            if let Some(default) = &param.default_value {
859                self.check_node(default, &mut fn_scope);
860            }
861        }
862        self.check_block(body, &mut fn_scope);
863
864        // Check return statements against declared return type
865        if let Some(ret_type) = return_type {
866            for stmt in body {
867                self.check_return_type(stmt, ret_type, &fn_scope);
868            }
869        }
870    }
871
872    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
873        let span = snode.span;
874        match &snode.node {
875            Node::ReturnStmt { value: Some(val) } => {
876                let inferred = self.infer_type(val, scope);
877                if let Some(actual) = &inferred {
878                    if !self.types_compatible(expected, actual, scope) {
879                        self.error_at(
880                            format!(
881                                "Return type mismatch: expected {}, got {}",
882                                format_type(expected),
883                                format_type(actual)
884                            ),
885                            span,
886                        );
887                    }
888                }
889            }
890            Node::IfElse {
891                then_body,
892                else_body,
893                ..
894            } => {
895                for stmt in then_body {
896                    self.check_return_type(stmt, expected, scope);
897                }
898                if let Some(else_body) = else_body {
899                    for stmt in else_body {
900                        self.check_return_type(stmt, expected, scope);
901                    }
902                }
903            }
904            _ => {}
905        }
906    }
907
908    /// Check if a match expression on an enum's `.variant` property covers all variants.
909    /// Extract narrowing info from nil-check conditions like `x != nil`.
910    /// Returns (var_name, narrowed_type) where narrowed_type removes nil from a union.
911    /// Check if a type satisfies an interface (Go-style implicit satisfaction).
912    /// A type satisfies an interface if its impl block has all the required methods.
913    fn satisfies_interface(
914        &self,
915        type_name: &str,
916        interface_name: &str,
917        scope: &TypeScope,
918    ) -> bool {
919        let interface_methods = match scope.get_interface(interface_name) {
920            Some(methods) => methods,
921            None => return false,
922        };
923        let impl_methods = match scope.get_impl_methods(type_name) {
924            Some(methods) => methods,
925            None => return interface_methods.is_empty(),
926        };
927        interface_methods.iter().all(|iface_method| {
928            let iface_params: Vec<_> = iface_method
929                .params
930                .iter()
931                .filter(|p| p.name != "self")
932                .collect();
933            let iface_param_count = iface_params.len();
934            impl_methods.iter().any(|impl_method| {
935                if impl_method.name != iface_method.name
936                    || impl_method.param_count != iface_param_count
937                {
938                    return false;
939                }
940                // Check parameter types where both sides specify them
941                for (i, iface_param) in iface_params.iter().enumerate() {
942                    if let (Some(expected), Some(actual)) = (
943                        &iface_param.type_expr,
944                        impl_method.param_types.get(i).and_then(|t| t.as_ref()),
945                    ) {
946                        if !self.types_compatible(expected, actual, scope) {
947                            return false;
948                        }
949                    }
950                }
951                // Check return type where both sides specify it
952                if let (Some(expected_ret), Some(actual_ret)) =
953                    (&iface_method.return_type, &impl_method.return_type)
954                {
955                    if !self.types_compatible(expected_ret, actual_ret, scope) {
956                        return false;
957                    }
958                }
959                true
960            })
961        })
962    }
963
964    /// Recursively extract type parameter bindings from matching param/arg types.
965    /// E.g., param_type=list<T> + arg_type=list<Dog> → binds T=Dog.
966    fn extract_type_bindings(
967        param_type: &TypeExpr,
968        arg_type: &TypeExpr,
969        type_params: &std::collections::BTreeSet<String>,
970        bindings: &mut BTreeMap<String, String>,
971    ) {
972        match (param_type, arg_type) {
973            // Direct type param match: T → concrete
974            (TypeExpr::Named(param_name), TypeExpr::Named(concrete))
975                if type_params.contains(param_name) =>
976            {
977                bindings
978                    .entry(param_name.clone())
979                    .or_insert(concrete.clone());
980            }
981            // list<T> + list<Dog>
982            (TypeExpr::List(p_inner), TypeExpr::List(a_inner)) => {
983                Self::extract_type_bindings(p_inner, a_inner, type_params, bindings);
984            }
985            // dict<K, V> + dict<string, int>
986            (TypeExpr::DictType(pk, pv), TypeExpr::DictType(ak, av)) => {
987                Self::extract_type_bindings(pk, ak, type_params, bindings);
988                Self::extract_type_bindings(pv, av, type_params, bindings);
989            }
990            _ => {}
991        }
992    }
993
994    fn extract_nil_narrowing(
995        condition: &SNode,
996        scope: &TypeScope,
997    ) -> Option<(String, InferredType)> {
998        if let Node::BinaryOp { op, left, right } = &condition.node {
999            if op == "!=" {
1000                // Check for `x != nil` or `nil != x`
1001                let (var_node, nil_node) = if matches!(right.node, Node::NilLiteral) {
1002                    (left, right)
1003                } else if matches!(left.node, Node::NilLiteral) {
1004                    (right, left)
1005                } else {
1006                    return None;
1007                };
1008                let _ = nil_node;
1009                if let Node::Identifier(name) = &var_node.node {
1010                    // Look up the variable's type and narrow it
1011                    if let Some(Some(TypeExpr::Union(members))) = scope.get_var(name) {
1012                        let narrowed: Vec<TypeExpr> = members
1013                            .iter()
1014                            .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1015                            .cloned()
1016                            .collect();
1017                        return if narrowed.len() == 1 {
1018                            Some((name.clone(), Some(narrowed.into_iter().next().unwrap())))
1019                        } else if narrowed.is_empty() {
1020                            None
1021                        } else {
1022                            Some((name.clone(), Some(TypeExpr::Union(narrowed))))
1023                        };
1024                    }
1025                }
1026            }
1027        }
1028        None
1029    }
1030
1031    fn check_match_exhaustiveness(
1032        &mut self,
1033        value: &SNode,
1034        arms: &[MatchArm],
1035        scope: &TypeScope,
1036        span: Span,
1037    ) {
1038        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
1039        let enum_name = match &value.node {
1040            Node::PropertyAccess { object, property } if property == "variant" => {
1041                // Infer the type of the object
1042                match self.infer_type(object, scope) {
1043                    Some(TypeExpr::Named(name)) => {
1044                        if scope.get_enum(&name).is_some() {
1045                            Some(name)
1046                        } else {
1047                            None
1048                        }
1049                    }
1050                    _ => None,
1051                }
1052            }
1053            _ => {
1054                // Direct match on an enum value: match <expr> { ... }
1055                match self.infer_type(value, scope) {
1056                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
1057                    _ => None,
1058                }
1059            }
1060        };
1061
1062        let Some(enum_name) = enum_name else {
1063            return;
1064        };
1065        let Some(variants) = scope.get_enum(&enum_name) else {
1066            return;
1067        };
1068
1069        // Collect variant names covered by match arms
1070        let mut covered: Vec<String> = Vec::new();
1071        let mut has_wildcard = false;
1072
1073        for arm in arms {
1074            match &arm.pattern.node {
1075                // String literal pattern (matching on .variant): "VariantA"
1076                Node::StringLiteral(s) => covered.push(s.clone()),
1077                // Identifier pattern acts as a wildcard/catch-all
1078                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
1079                    has_wildcard = true;
1080                }
1081                // Direct enum construct pattern: EnumName.Variant
1082                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
1083                // PropertyAccess pattern: EnumName.Variant (no args)
1084                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
1085                _ => {
1086                    // Unknown pattern shape — conservatively treat as wildcard
1087                    has_wildcard = true;
1088                }
1089            }
1090        }
1091
1092        if has_wildcard {
1093            return;
1094        }
1095
1096        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
1097        if !missing.is_empty() {
1098            let missing_str = missing
1099                .iter()
1100                .map(|s| format!("\"{}\"", s))
1101                .collect::<Vec<_>>()
1102                .join(", ");
1103            self.warning_at(
1104                format!(
1105                    "Non-exhaustive match on enum {}: missing variants {}",
1106                    enum_name, missing_str
1107                ),
1108                span,
1109            );
1110        }
1111    }
1112
1113    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
1114        // Check against known function signatures
1115        let has_spread = args.iter().any(|a| matches!(&a.node, Node::Spread(_)));
1116        if let Some(sig) = scope.get_fn(name).cloned() {
1117            if !has_spread
1118                && !is_builtin(name)
1119                && (args.len() < sig.required_params || args.len() > sig.params.len())
1120            {
1121                let expected = if sig.required_params == sig.params.len() {
1122                    format!("{}", sig.params.len())
1123                } else {
1124                    format!("{}-{}", sig.required_params, sig.params.len())
1125                };
1126                self.warning_at(
1127                    format!(
1128                        "Function '{}' expects {} arguments, got {}",
1129                        name,
1130                        expected,
1131                        args.len()
1132                    ),
1133                    span,
1134                );
1135            }
1136            // Build a scope that includes the function's generic type params
1137            // so they are treated as compatible with any concrete type.
1138            let call_scope = if sig.type_param_names.is_empty() {
1139                scope.clone()
1140            } else {
1141                let mut s = scope.child();
1142                for tp_name in &sig.type_param_names {
1143                    s.generic_type_params.insert(tp_name.clone());
1144                }
1145                s
1146            };
1147            for (i, (arg, (param_name, param_type))) in
1148                args.iter().zip(sig.params.iter()).enumerate()
1149            {
1150                if let Some(expected) = param_type {
1151                    let actual = self.infer_type(arg, scope);
1152                    if let Some(actual) = &actual {
1153                        if !self.types_compatible(expected, actual, &call_scope) {
1154                            self.error_at(
1155                                format!(
1156                                    "Argument {} ('{}'): expected {}, got {}",
1157                                    i + 1,
1158                                    param_name,
1159                                    format_type(expected),
1160                                    format_type(actual)
1161                                ),
1162                                arg.span,
1163                            );
1164                        }
1165                    }
1166                }
1167            }
1168            // Enforce where-clause constraints at call site
1169            if !sig.where_clauses.is_empty() {
1170                // Build mapping: type_param → concrete type from inferred args.
1171                // Recursively walks Generic types so list<T> + list<Dog> binds T=Dog.
1172                let mut type_bindings: BTreeMap<String, String> = BTreeMap::new();
1173                let type_param_set: std::collections::BTreeSet<String> =
1174                    sig.type_param_names.iter().cloned().collect();
1175                for (arg, (_param_name, param_type)) in args.iter().zip(sig.params.iter()) {
1176                    if let Some(param_ty) = param_type {
1177                        if let Some(arg_ty) = self.infer_type(arg, scope) {
1178                            Self::extract_type_bindings(
1179                                param_ty,
1180                                &arg_ty,
1181                                &type_param_set,
1182                                &mut type_bindings,
1183                            );
1184                        }
1185                    }
1186                }
1187                for (type_param, bound) in &sig.where_clauses {
1188                    if let Some(concrete_type) = type_bindings.get(type_param) {
1189                        if !self.satisfies_interface(concrete_type, bound, scope) {
1190                            self.warning_at(
1191                                format!(
1192                                    "Type '{}' does not satisfy interface '{}': \
1193                                     required by constraint `where {}: {}`",
1194                                    concrete_type, bound, type_param, bound
1195                                ),
1196                                span,
1197                            );
1198                        }
1199                    }
1200                }
1201            }
1202        }
1203        // Check args recursively
1204        for arg in args {
1205            self.check_node(arg, scope);
1206        }
1207    }
1208
1209    /// Infer the type of an expression.
1210    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
1211        match &snode.node {
1212            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
1213            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
1214            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
1215                Some(TypeExpr::Named("string".into()))
1216            }
1217            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
1218            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
1219            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
1220            Node::DictLiteral(entries) => {
1221                // Infer shape type when all keys are string literals
1222                let mut fields = Vec::new();
1223                let mut all_string_keys = true;
1224                for entry in entries {
1225                    if let Node::StringLiteral(key) = &entry.key.node {
1226                        let val_type = self
1227                            .infer_type(&entry.value, scope)
1228                            .unwrap_or(TypeExpr::Named("nil".into()));
1229                        fields.push(ShapeField {
1230                            name: key.clone(),
1231                            type_expr: val_type,
1232                            optional: false,
1233                        });
1234                    } else {
1235                        all_string_keys = false;
1236                        break;
1237                    }
1238                }
1239                if all_string_keys && !fields.is_empty() {
1240                    Some(TypeExpr::Shape(fields))
1241                } else {
1242                    Some(TypeExpr::Named("dict".into()))
1243                }
1244            }
1245            Node::Closure { params, body } => {
1246                // If all params are typed and we can infer a return type, produce FnType
1247                let all_typed = params.iter().all(|p| p.type_expr.is_some());
1248                if all_typed && !params.is_empty() {
1249                    let param_types: Vec<TypeExpr> =
1250                        params.iter().filter_map(|p| p.type_expr.clone()).collect();
1251                    // Try to infer return type from last expression in body
1252                    let ret = body.last().and_then(|last| self.infer_type(last, scope));
1253                    if let Some(ret_type) = ret {
1254                        return Some(TypeExpr::FnType {
1255                            params: param_types,
1256                            return_type: Box::new(ret_type),
1257                        });
1258                    }
1259                }
1260                Some(TypeExpr::Named("closure".into()))
1261            }
1262
1263            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
1264
1265            Node::FunctionCall { name, .. } => {
1266                // Struct constructor calls return the struct type
1267                if scope.get_struct(name).is_some() {
1268                    return Some(TypeExpr::Named(name.clone()));
1269                }
1270                // Check user-defined function return types
1271                if let Some(sig) = scope.get_fn(name) {
1272                    return sig.return_type.clone();
1273                }
1274                // Check builtin return types
1275                builtin_return_type(name)
1276            }
1277
1278            Node::BinaryOp { op, left, right } => {
1279                let lt = self.infer_type(left, scope);
1280                let rt = self.infer_type(right, scope);
1281                infer_binary_op_type(op, &lt, &rt)
1282            }
1283
1284            Node::UnaryOp { op, operand } => {
1285                let t = self.infer_type(operand, scope);
1286                match op.as_str() {
1287                    "!" => Some(TypeExpr::Named("bool".into())),
1288                    "-" => t, // negation preserves type
1289                    _ => None,
1290                }
1291            }
1292
1293            Node::Ternary {
1294                true_expr,
1295                false_expr,
1296                ..
1297            } => {
1298                let tt = self.infer_type(true_expr, scope);
1299                let ft = self.infer_type(false_expr, scope);
1300                match (&tt, &ft) {
1301                    (Some(a), Some(b)) if a == b => tt,
1302                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
1303                    (Some(_), None) => tt,
1304                    (None, Some(_)) => ft,
1305                    (None, None) => None,
1306                }
1307            }
1308
1309            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1310
1311            Node::PropertyAccess { object, property } => {
1312                // EnumName.Variant → infer as the enum type
1313                if let Node::Identifier(name) = &object.node {
1314                    if scope.get_enum(name).is_some() {
1315                        return Some(TypeExpr::Named(name.clone()));
1316                    }
1317                }
1318                // .variant on an enum value → string
1319                if property == "variant" {
1320                    let obj_type = self.infer_type(object, scope);
1321                    if let Some(TypeExpr::Named(name)) = &obj_type {
1322                        if scope.get_enum(name).is_some() {
1323                            return Some(TypeExpr::Named("string".into()));
1324                        }
1325                    }
1326                }
1327                // Shape field access: obj.field → field type
1328                let obj_type = self.infer_type(object, scope);
1329                if let Some(TypeExpr::Shape(fields)) = &obj_type {
1330                    if let Some(field) = fields.iter().find(|f| f.name == *property) {
1331                        return Some(field.type_expr.clone());
1332                    }
1333                }
1334                None
1335            }
1336
1337            Node::SubscriptAccess { object, index } => {
1338                let obj_type = self.infer_type(object, scope);
1339                match &obj_type {
1340                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1341                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1342                    Some(TypeExpr::Shape(fields)) => {
1343                        // If index is a string literal, look up the field type
1344                        if let Node::StringLiteral(key) = &index.node {
1345                            fields
1346                                .iter()
1347                                .find(|f| &f.name == key)
1348                                .map(|f| f.type_expr.clone())
1349                        } else {
1350                            None
1351                        }
1352                    }
1353                    Some(TypeExpr::Named(n)) if n == "list" => None,
1354                    Some(TypeExpr::Named(n)) if n == "dict" => None,
1355                    Some(TypeExpr::Named(n)) if n == "string" => {
1356                        Some(TypeExpr::Named("string".into()))
1357                    }
1358                    _ => None,
1359                }
1360            }
1361            Node::SliceAccess { object, .. } => {
1362                // Slicing a list returns the same list type; slicing a string returns string
1363                let obj_type = self.infer_type(object, scope);
1364                match &obj_type {
1365                    Some(TypeExpr::List(_)) => obj_type,
1366                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1367                    Some(TypeExpr::Named(n)) if n == "string" => {
1368                        Some(TypeExpr::Named("string".into()))
1369                    }
1370                    _ => None,
1371                }
1372            }
1373            Node::MethodCall { object, method, .. }
1374            | Node::OptionalMethodCall { object, method, .. } => {
1375                let obj_type = self.infer_type(object, scope);
1376                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1377                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1378                match method.as_str() {
1379                    // Shared: bool-returning methods
1380                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1381                        Some(TypeExpr::Named("bool".into()))
1382                    }
1383                    // Shared: int-returning methods
1384                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1385                    // String methods
1386                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1387                    | "pad_left" | "pad_right" | "repeat" | "join" => {
1388                        Some(TypeExpr::Named("string".into()))
1389                    }
1390                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
1391                    // filter returns dict for dicts, list for lists
1392                    "filter" => {
1393                        if is_dict {
1394                            Some(TypeExpr::Named("dict".into()))
1395                        } else {
1396                            Some(TypeExpr::Named("list".into()))
1397                        }
1398                    }
1399                    // List methods
1400                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1401                    "reduce" | "find" | "first" | "last" => None,
1402                    // Dict methods
1403                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1404                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1405                    // Conversions
1406                    "to_string" => Some(TypeExpr::Named("string".into())),
1407                    "to_int" => Some(TypeExpr::Named("int".into())),
1408                    "to_float" => Some(TypeExpr::Named("float".into())),
1409                    _ => None,
1410                }
1411            }
1412
1413            // TryOperator on Result<T, E> produces T
1414            Node::TryOperator { operand } => {
1415                match self.infer_type(operand, scope) {
1416                    Some(TypeExpr::Named(name)) if name == "Result" => None, // unknown inner type
1417                    _ => None,
1418                }
1419            }
1420
1421            _ => None,
1422        }
1423    }
1424
1425    /// Check if two types are compatible (actual can be assigned to expected).
1426    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1427        // Generic type parameters match anything.
1428        if let TypeExpr::Named(name) = expected {
1429            if scope.is_generic_type_param(name) {
1430                return true;
1431            }
1432        }
1433        if let TypeExpr::Named(name) = actual {
1434            if scope.is_generic_type_param(name) {
1435                return true;
1436            }
1437        }
1438        let expected = self.resolve_alias(expected, scope);
1439        let actual = self.resolve_alias(actual, scope);
1440
1441        // Interface satisfaction: if expected is an interface name, check if actual type
1442        // has all required methods (Go-style implicit satisfaction).
1443        if let TypeExpr::Named(iface_name) = &expected {
1444            if scope.get_interface(iface_name).is_some() {
1445                if let TypeExpr::Named(type_name) = &actual {
1446                    return self.satisfies_interface(type_name, iface_name, scope);
1447                }
1448                return false;
1449            }
1450        }
1451
1452        match (&expected, &actual) {
1453            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1454            (TypeExpr::Union(members), actual_type) => members
1455                .iter()
1456                .any(|m| self.types_compatible(m, actual_type, scope)),
1457            (expected_type, TypeExpr::Union(members)) => members
1458                .iter()
1459                .all(|m| self.types_compatible(expected_type, m, scope)),
1460            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1461            (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1462            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1463                if expected_field.optional {
1464                    return true;
1465                }
1466                af.iter().any(|actual_field| {
1467                    actual_field.name == expected_field.name
1468                        && self.types_compatible(
1469                            &expected_field.type_expr,
1470                            &actual_field.type_expr,
1471                            scope,
1472                        )
1473                })
1474            }),
1475            // dict<K, V> expected, Shape actual → all field values must match V
1476            (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1477                let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1478                keys_ok
1479                    && af
1480                        .iter()
1481                        .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1482            }
1483            // Shape expected, dict<K, V> actual → gradual: allow since dict may have the fields
1484            (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1485            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1486                self.types_compatible(expected_inner, actual_inner, scope)
1487            }
1488            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1489            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1490            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1491                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1492            }
1493            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1494            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1495            // FnType compatibility: params match positionally and return types match
1496            (
1497                TypeExpr::FnType {
1498                    params: ep,
1499                    return_type: er,
1500                },
1501                TypeExpr::FnType {
1502                    params: ap,
1503                    return_type: ar,
1504                },
1505            ) => {
1506                ep.len() == ap.len()
1507                    && ep
1508                        .iter()
1509                        .zip(ap.iter())
1510                        .all(|(e, a)| self.types_compatible(e, a, scope))
1511                    && self.types_compatible(er, ar, scope)
1512            }
1513            // FnType is compatible with Named("closure") for backward compat
1514            (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1515            (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1516            _ => false,
1517        }
1518    }
1519
1520    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1521        if let TypeExpr::Named(name) = ty {
1522            if let Some(resolved) = scope.resolve_type(name) {
1523                return resolved.clone();
1524            }
1525        }
1526        ty.clone()
1527    }
1528
1529    fn error_at(&mut self, message: String, span: Span) {
1530        self.diagnostics.push(TypeDiagnostic {
1531            message,
1532            severity: DiagnosticSeverity::Error,
1533            span: Some(span),
1534        });
1535    }
1536
1537    fn warning_at(&mut self, message: String, span: Span) {
1538        self.diagnostics.push(TypeDiagnostic {
1539            message,
1540            severity: DiagnosticSeverity::Warning,
1541            span: Some(span),
1542        });
1543    }
1544}
1545
1546impl Default for TypeChecker {
1547    fn default() -> Self {
1548        Self::new()
1549    }
1550}
1551
1552/// Infer the result type of a binary operation.
1553fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1554    match op {
1555        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1556        "+" => match (left, right) {
1557            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1558                match (l.as_str(), r.as_str()) {
1559                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1560                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1561                    ("string", _) => Some(TypeExpr::Named("string".into())),
1562                    ("list", "list") => Some(TypeExpr::Named("list".into())),
1563                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1564                    _ => Some(TypeExpr::Named("string".into())),
1565                }
1566            }
1567            _ => None,
1568        },
1569        "-" | "*" | "/" | "%" => match (left, right) {
1570            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1571                match (l.as_str(), r.as_str()) {
1572                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1573                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1574                    _ => None,
1575                }
1576            }
1577            _ => None,
1578        },
1579        "??" => match (left, right) {
1580            (Some(TypeExpr::Union(members)), _) => {
1581                let non_nil: Vec<_> = members
1582                    .iter()
1583                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1584                    .cloned()
1585                    .collect();
1586                if non_nil.len() == 1 {
1587                    Some(non_nil[0].clone())
1588                } else if non_nil.is_empty() {
1589                    right.clone()
1590                } else {
1591                    Some(TypeExpr::Union(non_nil))
1592                }
1593            }
1594            _ => right.clone(),
1595        },
1596        "|>" => None,
1597        _ => None,
1598    }
1599}
1600
1601/// Format a type expression for display in error messages.
1602/// Produce a detail string describing why a Shape type is incompatible with
1603/// another Shape type — e.g. "missing field 'age' (int)" or "field 'name'
1604/// has type int, expected string".  Returns `None` if both types are not shapes.
1605pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1606    if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1607        let mut details = Vec::new();
1608        for field in ef {
1609            if field.optional {
1610                continue;
1611            }
1612            match af.iter().find(|f| f.name == field.name) {
1613                None => details.push(format!(
1614                    "missing field '{}' ({})",
1615                    field.name,
1616                    format_type(&field.type_expr)
1617                )),
1618                Some(actual_field) => {
1619                    let e_str = format_type(&field.type_expr);
1620                    let a_str = format_type(&actual_field.type_expr);
1621                    if e_str != a_str {
1622                        details.push(format!(
1623                            "field '{}' has type {}, expected {}",
1624                            field.name, a_str, e_str
1625                        ));
1626                    }
1627                }
1628            }
1629        }
1630        if details.is_empty() {
1631            None
1632        } else {
1633            Some(details.join("; "))
1634        }
1635    } else {
1636        None
1637    }
1638}
1639
1640pub fn format_type(ty: &TypeExpr) -> String {
1641    match ty {
1642        TypeExpr::Named(n) => n.clone(),
1643        TypeExpr::Union(types) => types
1644            .iter()
1645            .map(format_type)
1646            .collect::<Vec<_>>()
1647            .join(" | "),
1648        TypeExpr::Shape(fields) => {
1649            let inner: Vec<String> = fields
1650                .iter()
1651                .map(|f| {
1652                    let opt = if f.optional { "?" } else { "" };
1653                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1654                })
1655                .collect();
1656            format!("{{{}}}", inner.join(", "))
1657        }
1658        TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1659        TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1660        TypeExpr::FnType {
1661            params,
1662            return_type,
1663        } => {
1664            let params_str = params
1665                .iter()
1666                .map(format_type)
1667                .collect::<Vec<_>>()
1668                .join(", ");
1669            format!("fn({}) -> {}", params_str, format_type(return_type))
1670        }
1671    }
1672}
1673
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    // Substrings that identify the warning categories asserted on below.
    const NON_EXHAUSTIVE: &str = "Non-exhaustive";
    const PATTERN_MISMATCH: &str = "Match pattern type mismatch";

    /// Lex, parse and type-check `source`, returning every diagnostic.
    ///
    /// Panics if lexing or parsing fails: every fixture here must be valid syntax.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut scanner = Lexer::new(source);
        let token_stream = scanner.tokenize().unwrap();
        let mut syntax = Parser::new(token_stream);
        let ast = syntax.parse().unwrap();
        TypeChecker::new().check(&ast)
    }

    /// Messages of every diagnostic with the given `severity` for `source`.
    fn messages(source: &str, severity: DiagnosticSeverity) -> Vec<String> {
        check_source(source)
            .into_iter()
            .filter(|diag| diag.severity == severity)
            .map(|diag| diag.message)
            .collect()
    }

    fn errors(source: &str) -> Vec<String> {
        messages(source, DiagnosticSeverity::Error)
    }

    fn warnings(source: &str) -> Vec<String> {
        messages(source, DiagnosticSeverity::Warning)
    }

    /// Warnings for `source` whose message contains `needle`.
    fn warnings_containing(source: &str, needle: &str) -> Vec<String> {
        warnings(source)
            .into_iter()
            .filter(|w| w.contains(needle))
            .collect()
    }

    /// Assert that `message` contains every string in `needles`.
    fn assert_mentions(message: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(message.contains(needle), "expected {needle:?} in {message:?}");
        }
    }

    /// Assert that `source` yields exactly one match-pattern warning and that it
    /// mentions `fragment`.
    fn assert_single_pattern_warning(source: &str, fragment: &str) {
        let found = warnings_containing(source, PATTERN_MISMATCH);
        assert_eq!(found.len(), 1);
        assert!(found[0].contains(fragment));
    }

    /// Assert that `source` yields no match-pattern warnings at all.
    fn assert_no_pattern_warnings(source: &str) {
        assert!(warnings_containing(source, PATTERN_MISMATCH).is_empty());
    }

    #[test]
    fn test_no_errors_for_untyped_code() {
        assert!(errors("pipeline t(task) { let x = 42\nlog(x) }").is_empty());
    }

    #[test]
    fn test_correct_typed_let() {
        assert!(errors("pipeline t(task) { let x: int = 42 }").is_empty());
    }

    #[test]
    fn test_type_mismatch_let() {
        let found = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
        assert_eq!(found.len(), 1);
        assert_mentions(&found[0], &["Type mismatch", "int", "string"]);
    }

    #[test]
    fn test_correct_typed_fn() {
        let source =
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }";
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let found = errors(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        assert_eq!(found.len(), 1);
        assert_mentions(&found[0], &["Argument 1", "expected int"]);
    }

    #[test]
    fn test_return_type_mismatch() {
        let found = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert_eq!(found.len(), 1);
        assert_mentions(&found[0], &["Return type mismatch"]);
    }

    #[test]
    fn test_union_type_compatible() {
        assert!(errors(r#"pipeline t(task) { let x: string | nil = nil }"#).is_empty());
    }

    #[test]
    fn test_union_type_mismatch() {
        let found = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert_eq!(found.len(), 1);
        assert_mentions(&found[0], &["Type mismatch"]);
    }

    #[test]
    fn test_type_inference_propagation() {
        let found = errors(
            r#"pipeline t(task) {
  fn add(a: int, b: int) -> int { return a + b }
  let result: string = add(1, 2)
}"#,
        );
        assert_eq!(found.len(), 1);
        assert_mentions(&found[0], &["Type mismatch", "string", "int"]);
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let found = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        assert_eq!(found.len(), 1);
        assert_mentions(&found[0], &["string", "int"]);
    }

    #[test]
    fn test_binary_op_type_inference() {
        assert_eq!(errors("pipeline t(task) { let x: string = 1 + 2 }").len(), 1);
    }

    #[test]
    fn test_comparison_returns_bool() {
        assert!(errors("pipeline t(task) { let x: bool = 1 < 2 }").is_empty());
    }

    #[test]
    fn test_int_float_promotion() {
        assert!(errors("pipeline t(task) { let x: float = 42 }").is_empty());
    }

    #[test]
    fn test_untyped_code_no_errors() {
        let found = errors(
            r#"pipeline t(task) {
  fn process(data) {
    let result = data + " processed"
    return result
  }
  log(process("hello"))
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_type_alias() {
        let found = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = "hello"
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_type_alias_mismatch() {
        let found = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = 42
}"#,
        );
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_assignment_type_check() {
        let found = errors(
            r#"pipeline t(task) {
  var x: int = 0
  x = "hello"
}"#,
        );
        assert_eq!(found.len(), 1);
        assert_mentions(&found[0], &["cannot assign string"]);
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        let source =
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }";
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_covariance_return_type() {
        assert!(errors("pipeline t(task) { fn get() -> float { return 42 } }").is_empty());
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        let source = "pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }";
        assert_eq!(errors(source).len(), 1);
    }

    // --- Exhaustiveness checking tests ---

    #[test]
    fn test_exhaustive_match_no_warning() {
        let found = warnings_containing(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
    "Blue" -> { log("b") }
  }
}"#,
            NON_EXHAUSTIVE,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let found = warnings_containing(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
  }
}"#,
            NON_EXHAUSTIVE,
        );
        assert_eq!(found.len(), 1);
        assert_mentions(&found[0], &["Blue"]);
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let found = warnings_containing(
            r#"pipeline t(task) {
  enum Status { Active, Inactive, Pending }
  let s = Status.Active
  match s.variant {
    "Active" -> { log("a") }
  }
}"#,
            NON_EXHAUSTIVE,
        );
        assert_eq!(found.len(), 1);
        assert_mentions(&found[0], &["Inactive", "Pending"]);
    }

    #[test]
    fn test_enum_construct_type_inference() {
        let found = errors(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c: Color = Color.Red
}"#,
        );
        assert!(found.is_empty());
    }

    // --- Type narrowing tests ---

    #[test]
    fn test_nil_coalescing_strips_nil() {
        // `??` removes nil from the left operand's type, so `y` is a plain string.
        let found = errors(
            r#"pipeline t(task) {
  let x: string | nil = nil
  let y: string = x ?? "default"
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_shape_mismatch_detail_missing_field() {
        let found = errors(
            r#"pipeline t(task) {
  let x: {name: string, age: int} = {name: "hello"}
}"#,
        );
        assert_eq!(found.len(), 1);
        assert!(
            found[0].contains("missing field 'age'"),
            "expected detail about missing field, got: {}",
            found[0]
        );
    }

    #[test]
    fn test_shape_mismatch_detail_wrong_type() {
        let found = errors(
            r#"pipeline t(task) {
  let x: {name: string, age: int} = {name: 42, age: 10}
}"#,
        );
        assert_eq!(found.len(), 1);
        assert!(
            found[0].contains("field 'name' has type int, expected string"),
            "expected detail about wrong type, got: {}",
            found[0]
        );
    }

    // --- Match pattern type validation tests ---

    #[test]
    fn test_match_pattern_string_against_int() {
        assert_single_pattern_warning(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    "hello" -> { log("bad") }
    42 -> { log("ok") }
  }
}"#,
            "matching int against string literal",
        );
    }

    #[test]
    fn test_match_pattern_int_against_string() {
        assert_single_pattern_warning(
            r#"pipeline t(task) {
  let x: string = "hello"
  match x {
    42 -> { log("bad") }
    "hello" -> { log("ok") }
  }
}"#,
            "matching string against int literal",
        );
    }

    #[test]
    fn test_match_pattern_bool_against_int() {
        assert_single_pattern_warning(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    true -> { log("bad") }
    42 -> { log("ok") }
  }
}"#,
            "matching int against bool literal",
        );
    }

    #[test]
    fn test_match_pattern_float_against_string() {
        assert_single_pattern_warning(
            r#"pipeline t(task) {
  let x: string = "hello"
  match x {
    3.14 -> { log("bad") }
    "hello" -> { log("ok") }
  }
}"#,
            "matching string against float literal",
        );
    }

    #[test]
    fn test_match_pattern_int_against_float_ok() {
        // An int literal pattern is accepted against a float value.
        assert_no_pattern_warnings(
            r#"pipeline t(task) {
  let x: float = 3.14
  match x {
    42 -> { log("ok") }
    _ -> { log("default") }
  }
}"#,
        );
    }

    #[test]
    fn test_match_pattern_float_against_int_ok() {
        // A float literal pattern is accepted against an int value.
        assert_no_pattern_warnings(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    3.14 -> { log("close") }
    _ -> { log("default") }
  }
}"#,
        );
    }

    #[test]
    fn test_match_pattern_correct_types_no_warning() {
        assert_no_pattern_warnings(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    1 -> { log("one") }
    2 -> { log("two") }
    _ -> { log("other") }
  }
}"#,
        );
    }

    #[test]
    fn test_match_pattern_wildcard_no_warning() {
        assert_no_pattern_warnings(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    _ -> { log("catch all") }
  }
}"#,
        );
    }

    #[test]
    fn test_match_pattern_untyped_no_warning() {
        // A value with no known type must not produce pattern warnings.
        assert_no_pattern_warnings(
            r#"pipeline t(task) {
  let x = some_unknown_fn()
  match x {
    "hello" -> { log("string") }
    42 -> { log("int") }
  }
}"#,
        );
    }
}