Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
/// A diagnostic produced by the type checker.
///
/// Diagnostics are collected while checking and returned from
/// `TypeChecker::check`.
#[derive(Debug, Clone)]
pub struct TypeDiagnostic {
    /// Human-readable description of the problem.
    pub message: String,
    /// Whether this diagnostic is an error or a warning.
    pub severity: DiagnosticSeverity,
    /// Source location the diagnostic refers to, if known.
    pub span: Option<Span>,
}
13
/// Severity of a [`TypeDiagnostic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// A definite type error.
    Error,
    /// A likely problem that the checker cannot prove is wrong.
    Warning,
}
19
/// Inferred type of an expression. None means unknown/untyped (gradual typing).
///
/// Checks are skipped whenever either side of a comparison is `None`, so an
/// unknown type never produces a diagnostic on its own.
type InferredType = Option<TypeExpr>;
22
/// Scope for tracking variable types and declarations.
///
/// Every `get_*` lookup searches this scope first and then walks the `parent`
/// chain. A child scope stores a cloned snapshot of its parent (see `child`),
/// so definitions made in a child are never visible from the parent.
#[derive(Debug, Clone)]
struct TypeScope {
    /// Variable name → inferred type.
    vars: BTreeMap<String, InferredType>,
    /// Function name → (param types, return type).
    functions: BTreeMap<String, FnSignature>,
    /// Named type aliases.
    type_aliases: BTreeMap<String, TypeExpr>,
    /// Enum declarations: name → variant names.
    enums: BTreeMap<String, Vec<String>>,
    /// Interface declarations: name → method signatures.
    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
    /// Struct declarations: name → field types.
    structs: BTreeMap<String, Vec<(String, InferredType)>>,
    /// Impl block methods: type_name → method signatures.
    impl_methods: BTreeMap<String, Vec<ImplMethodSig>>,
    /// Generic type parameter names in scope (treated as compatible with any type).
    generic_type_params: std::collections::BTreeSet<String>,
    /// Enclosing scope (an owned copy), or `None` for the root scope.
    parent: Option<Box<TypeScope>>,
}
44
/// Method signature extracted from an impl block (for interface checking).
///
/// Only the name and arity are recorded; parameter types are not tracked here.
#[derive(Debug, Clone)]
struct ImplMethodSig {
    /// Method name as declared in the impl block.
    name: String,
    /// Number of parameters excluding `self`.
    param_count: usize,
}
52
/// Signature of a declared function, built from its `fn` declaration.
#[derive(Debug, Clone)]
struct FnSignature {
    /// Parameter names paired with their declared types (`None` = untyped).
    params: Vec<(String, InferredType)>,
    /// Declared return type (`None` = untyped).
    return_type: InferredType,
    /// Generic type parameter names declared on the function.
    type_param_names: Vec<String>,
    /// Number of required parameters (those without defaults).
    required_params: usize,
    /// Where-clause constraints: (type_param_name, interface_bound).
    where_clauses: Vec<(String, String)>,
}
64
65impl TypeScope {
66    fn new() -> Self {
67        Self {
68            vars: BTreeMap::new(),
69            functions: BTreeMap::new(),
70            type_aliases: BTreeMap::new(),
71            enums: BTreeMap::new(),
72            interfaces: BTreeMap::new(),
73            structs: BTreeMap::new(),
74            impl_methods: BTreeMap::new(),
75            generic_type_params: std::collections::BTreeSet::new(),
76            parent: None,
77        }
78    }
79
80    fn child(&self) -> Self {
81        Self {
82            vars: BTreeMap::new(),
83            functions: BTreeMap::new(),
84            type_aliases: BTreeMap::new(),
85            enums: BTreeMap::new(),
86            interfaces: BTreeMap::new(),
87            structs: BTreeMap::new(),
88            impl_methods: BTreeMap::new(),
89            generic_type_params: std::collections::BTreeSet::new(),
90            parent: Some(Box::new(self.clone())),
91        }
92    }
93
94    fn get_var(&self, name: &str) -> Option<&InferredType> {
95        self.vars
96            .get(name)
97            .or_else(|| self.parent.as_ref()?.get_var(name))
98    }
99
100    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
101        self.functions
102            .get(name)
103            .or_else(|| self.parent.as_ref()?.get_fn(name))
104    }
105
106    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
107        self.type_aliases
108            .get(name)
109            .or_else(|| self.parent.as_ref()?.resolve_type(name))
110    }
111
112    fn is_generic_type_param(&self, name: &str) -> bool {
113        self.generic_type_params.contains(name)
114            || self
115                .parent
116                .as_ref()
117                .is_some_and(|p| p.is_generic_type_param(name))
118    }
119
120    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
121        self.enums
122            .get(name)
123            .or_else(|| self.parent.as_ref()?.get_enum(name))
124    }
125
126    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
127        self.interfaces
128            .get(name)
129            .or_else(|| self.parent.as_ref()?.get_interface(name))
130    }
131
132    fn get_struct(&self, name: &str) -> Option<&Vec<(String, InferredType)>> {
133        self.structs
134            .get(name)
135            .or_else(|| self.parent.as_ref()?.get_struct(name))
136    }
137
138    fn get_impl_methods(&self, name: &str) -> Option<&Vec<ImplMethodSig>> {
139        self.impl_methods
140            .get(name)
141            .or_else(|| self.parent.as_ref()?.get_impl_methods(name))
142    }
143
144    fn define_var(&mut self, name: &str, ty: InferredType) {
145        self.vars.insert(name.to_string(), ty);
146    }
147
148    fn define_fn(&mut self, name: &str, sig: FnSignature) {
149        self.functions.insert(name.to_string(), sig);
150    }
151}
152
153/// Known return types for builtin functions.
154fn builtin_return_type(name: &str) -> InferredType {
155    match name {
156        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
157        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
158            Some(TypeExpr::Named("nil".into()))
159        }
160        "type_of"
161        | "to_string"
162        | "json_stringify"
163        | "read_file"
164        | "http_get"
165        | "http_post"
166        | "llm_call"
167        | "regex_replace"
168        | "path_join"
169        | "temp_dir"
170        | "date_format"
171        | "format"
172        | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
173        "to_int" | "timer_end" | "elapsed" | "sign" => Some(TypeExpr::Named("int".into())),
174        "to_float" | "timestamp" | "date_parse" | "sin" | "cos" | "tan" | "asin" | "acos"
175        | "atan" | "atan2" | "log2" | "log10" | "ln" | "exp" | "pi" | "e" => {
176            Some(TypeExpr::Named("float".into()))
177        }
178        "file_exists" | "json_validate" | "is_nan" | "is_infinite" | "set_contains" => {
179            Some(TypeExpr::Named("bool".into()))
180        }
181        "list_dir" | "mcp_list_tools" | "mcp_list_resources" | "mcp_list_prompts" | "to_list"
182        | "regex_captures" => Some(TypeExpr::Named("list".into())),
183        "stat" | "exec" | "shell" | "date_now" | "agent_loop" | "llm_info" | "llm_usage"
184        | "timer_start" | "metadata_get" | "mcp_server_info" | "mcp_get_prompt" => {
185            Some(TypeExpr::Named("dict".into()))
186        }
187        "metadata_set"
188        | "metadata_save"
189        | "metadata_refresh_hashes"
190        | "invalidate_facts"
191        | "log_json"
192        | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
193        "env" | "regex_match" => Some(TypeExpr::Union(vec![
194            TypeExpr::Named("string".into()),
195            TypeExpr::Named("nil".into()),
196        ])),
197        "json_parse" | "json_extract" => None, // could be any type
198        _ => None,
199    }
200}
201
/// Check if a name is a known builtin.
///
/// Matching is exact and case-sensitive against a fixed table of names.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        // I/O and logging
        "log", "print", "println", "read_file", "write_file", "exit",
        // Conversions and JSON
        "type_of", "to_string", "to_int", "to_float", "json_stringify", "json_parse",
        "json_validate", "json_extract",
        // Environment, time, and process control
        "env", "timestamp", "sleep", "date_now", "date_format", "date_parse",
        // Regex
        "regex_match", "regex_replace", "regex_captures",
        // Network, LLM, and async
        "http_get", "http_post", "llm_call", "agent_loop", "await", "cancel",
        // Filesystem and paths
        "file_exists", "delete_file", "list_dir", "mkdir", "path_join", "copy_file",
        "append_file", "temp_dir", "stat", "exec", "shell", "dirname", "basename", "extname",
        // Strings and collections
        "format", "trim", "lowercase", "uppercase", "split", "starts_with", "ends_with",
        "contains", "replace", "join", "len", "substring",
        // Math
        "sin", "cos", "tan", "asin", "acos", "atan", "atan2", "log2", "log10", "ln", "exp",
        "pi", "e", "sign", "is_nan", "is_infinite",
        // Sets
        "set", "set_add", "set_remove", "set_contains", "set_union", "set_intersect",
        "set_difference", "to_list",
    ];
    BUILTINS.contains(&name)
}
287
/// The static type checker.
///
/// Construct with `TypeChecker::new` and consume with `check`, which returns
/// every diagnostic gathered while walking the program.
pub struct TypeChecker {
    /// Diagnostics recorded so far, in the order they were produced.
    diagnostics: Vec<TypeDiagnostic>,
    /// Root scope holding top-level declarations and bindings.
    scope: TypeScope,
}
293
294impl TypeChecker {
295    pub fn new() -> Self {
296        Self {
297            diagnostics: Vec::new(),
298            scope: TypeScope::new(),
299        }
300    }
301
302    /// Check a program and return diagnostics.
303    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
304        // First pass: register type and enum declarations into root scope
305        Self::register_declarations_into(&mut self.scope, program);
306
307        // Also scan pipeline bodies for declarations
308        for snode in program {
309            if let Node::Pipeline { body, .. } = &snode.node {
310                Self::register_declarations_into(&mut self.scope, body);
311            }
312        }
313
314        // Check each top-level node
315        for snode in program {
316            match &snode.node {
317                Node::Pipeline { params, body, .. } => {
318                    let mut child = self.scope.child();
319                    for p in params {
320                        child.define_var(p, None);
321                    }
322                    self.check_block(body, &mut child);
323                }
324                Node::FnDecl {
325                    name,
326                    type_params,
327                    params,
328                    return_type,
329                    where_clauses,
330                    body,
331                    ..
332                } => {
333                    let required_params =
334                        params.iter().filter(|p| p.default_value.is_none()).count();
335                    let sig = FnSignature {
336                        params: params
337                            .iter()
338                            .map(|p| (p.name.clone(), p.type_expr.clone()))
339                            .collect(),
340                        return_type: return_type.clone(),
341                        type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
342                        required_params,
343                        where_clauses: where_clauses
344                            .iter()
345                            .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
346                            .collect(),
347                    };
348                    self.scope.define_fn(name, sig);
349                    self.check_fn_body(type_params, params, return_type, body);
350                }
351                _ => {
352                    let mut scope = self.scope.clone();
353                    self.check_node(snode, &mut scope);
354                    // Merge any new definitions back into the top-level scope
355                    for (name, ty) in scope.vars {
356                        self.scope.vars.entry(name).or_insert(ty);
357                    }
358                }
359            }
360        }
361
362        self.diagnostics
363    }
364
365    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
366    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
367        for snode in nodes {
368            match &snode.node {
369                Node::TypeDecl { name, type_expr } => {
370                    scope.type_aliases.insert(name.clone(), type_expr.clone());
371                }
372                Node::EnumDecl { name, variants } => {
373                    let variant_names: Vec<String> =
374                        variants.iter().map(|v| v.name.clone()).collect();
375                    scope.enums.insert(name.clone(), variant_names);
376                }
377                Node::InterfaceDecl { name, methods } => {
378                    scope.interfaces.insert(name.clone(), methods.clone());
379                }
380                Node::StructDecl { name, fields } => {
381                    let field_types: Vec<(String, InferredType)> = fields
382                        .iter()
383                        .map(|f| (f.name.clone(), f.type_expr.clone()))
384                        .collect();
385                    scope.structs.insert(name.clone(), field_types);
386                }
387                Node::ImplBlock {
388                    type_name, methods, ..
389                } => {
390                    let sigs: Vec<ImplMethodSig> = methods
391                        .iter()
392                        .filter_map(|m| {
393                            if let Node::FnDecl { name, params, .. } = &m.node {
394                                // Exclude `self` from param count
395                                let param_count =
396                                    params.iter().filter(|p| p.name != "self").count();
397                                Some(ImplMethodSig {
398                                    name: name.clone(),
399                                    param_count,
400                                })
401                            } else {
402                                None
403                            }
404                        })
405                        .collect();
406                    scope.impl_methods.insert(type_name.clone(), sigs);
407                }
408                _ => {}
409            }
410        }
411    }
412
413    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
414        for stmt in stmts {
415            self.check_node(stmt, scope);
416        }
417    }
418
419    /// Define variables from a destructuring pattern in the given scope (as unknown type).
420    fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
421        match pattern {
422            BindingPattern::Identifier(name) => {
423                scope.define_var(name, None);
424            }
425            BindingPattern::Dict(fields) => {
426                for field in fields {
427                    let name = field.alias.as_deref().unwrap_or(&field.key);
428                    scope.define_var(name, None);
429                }
430            }
431            BindingPattern::List(elements) => {
432                for elem in elements {
433                    scope.define_var(&elem.name, None);
434                }
435            }
436        }
437    }
438
439    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
440        let span = snode.span;
441        match &snode.node {
442            Node::LetBinding {
443                pattern,
444                type_ann,
445                value,
446            } => {
447                let inferred = self.infer_type(value, scope);
448                if let BindingPattern::Identifier(name) = pattern {
449                    if let Some(expected) = type_ann {
450                        if let Some(actual) = &inferred {
451                            if !self.types_compatible(expected, actual, scope) {
452                                let mut msg = format!(
453                                    "Type mismatch: '{}' declared as {}, but assigned {}",
454                                    name,
455                                    format_type(expected),
456                                    format_type(actual)
457                                );
458                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
459                                    msg.push_str(&format!(" ({})", detail));
460                                }
461                                self.error_at(msg, span);
462                            }
463                        }
464                    }
465                    let ty = type_ann.clone().or(inferred);
466                    scope.define_var(name, ty);
467                } else {
468                    Self::define_pattern_vars(pattern, scope);
469                }
470            }
471
472            Node::VarBinding {
473                pattern,
474                type_ann,
475                value,
476            } => {
477                let inferred = self.infer_type(value, scope);
478                if let BindingPattern::Identifier(name) = pattern {
479                    if let Some(expected) = type_ann {
480                        if let Some(actual) = &inferred {
481                            if !self.types_compatible(expected, actual, scope) {
482                                let mut msg = format!(
483                                    "Type mismatch: '{}' declared as {}, but assigned {}",
484                                    name,
485                                    format_type(expected),
486                                    format_type(actual)
487                                );
488                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
489                                    msg.push_str(&format!(" ({})", detail));
490                                }
491                                self.error_at(msg, span);
492                            }
493                        }
494                    }
495                    let ty = type_ann.clone().or(inferred);
496                    scope.define_var(name, ty);
497                } else {
498                    Self::define_pattern_vars(pattern, scope);
499                }
500            }
501
502            Node::FnDecl {
503                name,
504                type_params,
505                params,
506                return_type,
507                where_clauses,
508                body,
509                ..
510            } => {
511                let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
512                let sig = FnSignature {
513                    params: params
514                        .iter()
515                        .map(|p| (p.name.clone(), p.type_expr.clone()))
516                        .collect(),
517                    return_type: return_type.clone(),
518                    type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
519                    required_params,
520                    where_clauses: where_clauses
521                        .iter()
522                        .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
523                        .collect(),
524                };
525                scope.define_fn(name, sig.clone());
526                scope.define_var(name, None);
527                self.check_fn_body(type_params, params, return_type, body);
528            }
529
530            Node::FunctionCall { name, args } => {
531                self.check_call(name, args, scope, span);
532            }
533
534            Node::IfElse {
535                condition,
536                then_body,
537                else_body,
538            } => {
539                self.check_node(condition, scope);
540                let mut then_scope = scope.child();
541                // Narrow union types after nil checks: if x != nil, narrow x
542                if let Some((var_name, narrowed)) = Self::extract_nil_narrowing(condition, scope) {
543                    then_scope.define_var(&var_name, narrowed);
544                }
545                self.check_block(then_body, &mut then_scope);
546                if let Some(else_body) = else_body {
547                    let mut else_scope = scope.child();
548                    self.check_block(else_body, &mut else_scope);
549                }
550            }
551
552            Node::ForIn {
553                pattern,
554                iterable,
555                body,
556            } => {
557                self.check_node(iterable, scope);
558                let mut loop_scope = scope.child();
559                if let BindingPattern::Identifier(variable) = pattern {
560                    // Infer loop variable type from iterable
561                    let elem_type = match self.infer_type(iterable, scope) {
562                        Some(TypeExpr::List(inner)) => Some(*inner),
563                        Some(TypeExpr::Named(n)) if n == "string" => {
564                            Some(TypeExpr::Named("string".into()))
565                        }
566                        _ => None,
567                    };
568                    loop_scope.define_var(variable, elem_type);
569                } else {
570                    Self::define_pattern_vars(pattern, &mut loop_scope);
571                }
572                self.check_block(body, &mut loop_scope);
573            }
574
575            Node::WhileLoop { condition, body } => {
576                self.check_node(condition, scope);
577                let mut loop_scope = scope.child();
578                self.check_block(body, &mut loop_scope);
579            }
580
581            Node::TryCatch {
582                body,
583                error_var,
584                catch_body,
585                finally_body,
586                ..
587            } => {
588                let mut try_scope = scope.child();
589                self.check_block(body, &mut try_scope);
590                let mut catch_scope = scope.child();
591                if let Some(var) = error_var {
592                    catch_scope.define_var(var, None);
593                }
594                self.check_block(catch_body, &mut catch_scope);
595                if let Some(fb) = finally_body {
596                    let mut finally_scope = scope.child();
597                    self.check_block(fb, &mut finally_scope);
598                }
599            }
600
601            Node::TryExpr { body } => {
602                let mut try_scope = scope.child();
603                self.check_block(body, &mut try_scope);
604            }
605
606            Node::ReturnStmt {
607                value: Some(val), ..
608            } => {
609                self.check_node(val, scope);
610            }
611
612            Node::Assignment {
613                target, value, op, ..
614            } => {
615                self.check_node(value, scope);
616                if let Node::Identifier(name) = &target.node {
617                    if let Some(Some(var_type)) = scope.get_var(name) {
618                        let value_type = self.infer_type(value, scope);
619                        let assigned = if let Some(op) = op {
620                            let var_inferred = scope.get_var(name).cloned().flatten();
621                            infer_binary_op_type(op, &var_inferred, &value_type)
622                        } else {
623                            value_type
624                        };
625                        if let Some(actual) = &assigned {
626                            if !self.types_compatible(var_type, actual, scope) {
627                                self.error_at(
628                                    format!(
629                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
630                                        format_type(actual),
631                                        name,
632                                        format_type(var_type)
633                                    ),
634                                    span,
635                                );
636                            }
637                        }
638                    }
639                }
640            }
641
642            Node::TypeDecl { name, type_expr } => {
643                scope.type_aliases.insert(name.clone(), type_expr.clone());
644            }
645
646            Node::EnumDecl { name, variants } => {
647                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
648                scope.enums.insert(name.clone(), variant_names);
649            }
650
651            Node::StructDecl { name, fields } => {
652                let field_types: Vec<(String, InferredType)> = fields
653                    .iter()
654                    .map(|f| (f.name.clone(), f.type_expr.clone()))
655                    .collect();
656                scope.structs.insert(name.clone(), field_types);
657            }
658
659            Node::InterfaceDecl { name, methods } => {
660                scope.interfaces.insert(name.clone(), methods.clone());
661            }
662
663            Node::ImplBlock {
664                type_name, methods, ..
665            } => {
666                // Register impl methods for interface satisfaction checking
667                let sigs: Vec<ImplMethodSig> = methods
668                    .iter()
669                    .filter_map(|m| {
670                        if let Node::FnDecl { name, params, .. } = &m.node {
671                            let param_count = params.iter().filter(|p| p.name != "self").count();
672                            Some(ImplMethodSig {
673                                name: name.clone(),
674                                param_count,
675                            })
676                        } else {
677                            None
678                        }
679                    })
680                    .collect();
681                scope.impl_methods.insert(type_name.clone(), sigs);
682                for method_sn in methods {
683                    self.check_node(method_sn, scope);
684                }
685            }
686
687            Node::TryOperator { operand } => {
688                self.check_node(operand, scope);
689            }
690
691            Node::MatchExpr { value, arms } => {
692                self.check_node(value, scope);
693                let value_type = self.infer_type(value, scope);
694                for arm in arms {
695                    self.check_node(&arm.pattern, scope);
696                    // Check for incompatible literal pattern types
697                    if let Some(ref vt) = value_type {
698                        let value_type_name = format_type(vt);
699                        let mismatch = match &arm.pattern.node {
700                            Node::StringLiteral(_) => {
701                                !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
702                            }
703                            Node::IntLiteral(_) => {
704                                !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
705                                    && !self.types_compatible(
706                                        vt,
707                                        &TypeExpr::Named("float".into()),
708                                        scope,
709                                    )
710                            }
711                            Node::FloatLiteral(_) => {
712                                !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
713                                    && !self.types_compatible(
714                                        vt,
715                                        &TypeExpr::Named("int".into()),
716                                        scope,
717                                    )
718                            }
719                            Node::BoolLiteral(_) => {
720                                !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
721                            }
722                            _ => false,
723                        };
724                        if mismatch {
725                            let pattern_type = match &arm.pattern.node {
726                                Node::StringLiteral(_) => "string",
727                                Node::IntLiteral(_) => "int",
728                                Node::FloatLiteral(_) => "float",
729                                Node::BoolLiteral(_) => "bool",
730                                _ => unreachable!(),
731                            };
732                            self.warning_at(
733                                format!(
734                                    "Match pattern type mismatch: matching {} against {} literal",
735                                    value_type_name, pattern_type
736                                ),
737                                arm.pattern.span,
738                            );
739                        }
740                    }
741                    let mut arm_scope = scope.child();
742                    self.check_block(&arm.body, &mut arm_scope);
743                }
744                self.check_match_exhaustiveness(value, arms, scope, span);
745            }
746
747            // Recurse into nested expressions + validate binary op types
748            Node::BinaryOp { op, left, right } => {
749                self.check_node(left, scope);
750                self.check_node(right, scope);
751                // Validate operator/type compatibility
752                let lt = self.infer_type(left, scope);
753                let rt = self.infer_type(right, scope);
754                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
755                    match op.as_str() {
756                        "-" | "*" | "/" | "%" => {
757                            let numeric = ["int", "float"];
758                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
759                                self.warning_at(
760                                    format!(
761                                        "Operator '{op}' may not be valid for types {} and {}",
762                                        l, r
763                                    ),
764                                    span,
765                                );
766                            }
767                        }
768                        "+" => {
769                            // + is valid for int, float, string, list, dict
770                            let valid = ["int", "float", "string", "list", "dict"];
771                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
772                                self.warning_at(
773                                    format!(
774                                        "Operator '+' may not be valid for types {} and {}",
775                                        l, r
776                                    ),
777                                    span,
778                                );
779                            }
780                        }
781                        _ => {}
782                    }
783                }
784            }
785            Node::UnaryOp { operand, .. } => {
786                self.check_node(operand, scope);
787            }
788            Node::MethodCall { object, args, .. }
789            | Node::OptionalMethodCall { object, args, .. } => {
790                self.check_node(object, scope);
791                for arg in args {
792                    self.check_node(arg, scope);
793                }
794            }
795            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
796                self.check_node(object, scope);
797            }
798            Node::SubscriptAccess { object, index } => {
799                self.check_node(object, scope);
800                self.check_node(index, scope);
801            }
802            Node::SliceAccess { object, start, end } => {
803                self.check_node(object, scope);
804                if let Some(s) = start {
805                    self.check_node(s, scope);
806                }
807                if let Some(e) = end {
808                    self.check_node(e, scope);
809                }
810            }
811
812            // Terminals — nothing to check
813            _ => {}
814        }
815    }
816
817    fn check_fn_body(
818        &mut self,
819        type_params: &[TypeParam],
820        params: &[TypedParam],
821        return_type: &Option<TypeExpr>,
822        body: &[SNode],
823    ) {
824        let mut fn_scope = self.scope.child();
825        // Register generic type parameters so they are treated as compatible
826        // with any concrete type during type checking.
827        for tp in type_params {
828            fn_scope.generic_type_params.insert(tp.name.clone());
829        }
830        for param in params {
831            fn_scope.define_var(&param.name, param.type_expr.clone());
832            if let Some(default) = &param.default_value {
833                self.check_node(default, &mut fn_scope);
834            }
835        }
836        self.check_block(body, &mut fn_scope);
837
838        // Check return statements against declared return type
839        if let Some(ret_type) = return_type {
840            for stmt in body {
841                self.check_return_type(stmt, ret_type, &fn_scope);
842            }
843        }
844    }
845
846    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
847        let span = snode.span;
848        match &snode.node {
849            Node::ReturnStmt { value: Some(val) } => {
850                let inferred = self.infer_type(val, scope);
851                if let Some(actual) = &inferred {
852                    if !self.types_compatible(expected, actual, scope) {
853                        self.error_at(
854                            format!(
855                                "Return type mismatch: expected {}, got {}",
856                                format_type(expected),
857                                format_type(actual)
858                            ),
859                            span,
860                        );
861                    }
862                }
863            }
864            Node::IfElse {
865                then_body,
866                else_body,
867                ..
868            } => {
869                for stmt in then_body {
870                    self.check_return_type(stmt, expected, scope);
871                }
872                if let Some(else_body) = else_body {
873                    for stmt in else_body {
874                        self.check_return_type(stmt, expected, scope);
875                    }
876                }
877            }
878            _ => {}
879        }
880    }
881
882    /// Check if a match expression on an enum's `.variant` property covers all variants.
883    /// Extract narrowing info from nil-check conditions like `x != nil`.
884    /// Returns (var_name, narrowed_type) where narrowed_type removes nil from a union.
885    /// Check if a type satisfies an interface (Go-style implicit satisfaction).
886    /// A type satisfies an interface if its impl block has all the required methods.
887    fn satisfies_interface(
888        &self,
889        type_name: &str,
890        interface_name: &str,
891        scope: &TypeScope,
892    ) -> bool {
893        let interface_methods = match scope.get_interface(interface_name) {
894            Some(methods) => methods,
895            None => return false,
896        };
897        let impl_methods = match scope.get_impl_methods(type_name) {
898            Some(methods) => methods,
899            None => return interface_methods.is_empty(),
900        };
901        interface_methods.iter().all(|iface_method| {
902            let iface_param_count = iface_method
903                .params
904                .iter()
905                .filter(|p| p.name != "self")
906                .count();
907            impl_methods.iter().any(|impl_method| {
908                impl_method.name == iface_method.name
909                    && impl_method.param_count == iface_param_count
910            })
911        })
912    }
913
914    fn extract_nil_narrowing(
915        condition: &SNode,
916        scope: &TypeScope,
917    ) -> Option<(String, InferredType)> {
918        if let Node::BinaryOp { op, left, right } = &condition.node {
919            if op == "!=" {
920                // Check for `x != nil` or `nil != x`
921                let (var_node, nil_node) = if matches!(right.node, Node::NilLiteral) {
922                    (left, right)
923                } else if matches!(left.node, Node::NilLiteral) {
924                    (right, left)
925                } else {
926                    return None;
927                };
928                let _ = nil_node;
929                if let Node::Identifier(name) = &var_node.node {
930                    // Look up the variable's type and narrow it
931                    if let Some(Some(TypeExpr::Union(members))) = scope.get_var(name) {
932                        let narrowed: Vec<TypeExpr> = members
933                            .iter()
934                            .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
935                            .cloned()
936                            .collect();
937                        return if narrowed.len() == 1 {
938                            Some((name.clone(), Some(narrowed.into_iter().next().unwrap())))
939                        } else if narrowed.is_empty() {
940                            None
941                        } else {
942                            Some((name.clone(), Some(TypeExpr::Union(narrowed))))
943                        };
944                    }
945                }
946            }
947        }
948        None
949    }
950
951    fn check_match_exhaustiveness(
952        &mut self,
953        value: &SNode,
954        arms: &[MatchArm],
955        scope: &TypeScope,
956        span: Span,
957    ) {
958        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
959        let enum_name = match &value.node {
960            Node::PropertyAccess { object, property } if property == "variant" => {
961                // Infer the type of the object
962                match self.infer_type(object, scope) {
963                    Some(TypeExpr::Named(name)) => {
964                        if scope.get_enum(&name).is_some() {
965                            Some(name)
966                        } else {
967                            None
968                        }
969                    }
970                    _ => None,
971                }
972            }
973            _ => {
974                // Direct match on an enum value: match <expr> { ... }
975                match self.infer_type(value, scope) {
976                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
977                    _ => None,
978                }
979            }
980        };
981
982        let Some(enum_name) = enum_name else {
983            return;
984        };
985        let Some(variants) = scope.get_enum(&enum_name) else {
986            return;
987        };
988
989        // Collect variant names covered by match arms
990        let mut covered: Vec<String> = Vec::new();
991        let mut has_wildcard = false;
992
993        for arm in arms {
994            match &arm.pattern.node {
995                // String literal pattern (matching on .variant): "VariantA"
996                Node::StringLiteral(s) => covered.push(s.clone()),
997                // Identifier pattern acts as a wildcard/catch-all
998                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
999                    has_wildcard = true;
1000                }
1001                // Direct enum construct pattern: EnumName.Variant
1002                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
1003                // PropertyAccess pattern: EnumName.Variant (no args)
1004                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
1005                _ => {
1006                    // Unknown pattern shape — conservatively treat as wildcard
1007                    has_wildcard = true;
1008                }
1009            }
1010        }
1011
1012        if has_wildcard {
1013            return;
1014        }
1015
1016        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
1017        if !missing.is_empty() {
1018            let missing_str = missing
1019                .iter()
1020                .map(|s| format!("\"{}\"", s))
1021                .collect::<Vec<_>>()
1022                .join(", ");
1023            self.warning_at(
1024                format!(
1025                    "Non-exhaustive match on enum {}: missing variants {}",
1026                    enum_name, missing_str
1027                ),
1028                span,
1029            );
1030        }
1031    }
1032
1033    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
1034        // Check against known function signatures
1035        let has_spread = args.iter().any(|a| matches!(&a.node, Node::Spread(_)));
1036        if let Some(sig) = scope.get_fn(name).cloned() {
1037            if !has_spread
1038                && !is_builtin(name)
1039                && (args.len() < sig.required_params || args.len() > sig.params.len())
1040            {
1041                let expected = if sig.required_params == sig.params.len() {
1042                    format!("{}", sig.params.len())
1043                } else {
1044                    format!("{}-{}", sig.required_params, sig.params.len())
1045                };
1046                self.warning_at(
1047                    format!(
1048                        "Function '{}' expects {} arguments, got {}",
1049                        name,
1050                        expected,
1051                        args.len()
1052                    ),
1053                    span,
1054                );
1055            }
1056            // Build a scope that includes the function's generic type params
1057            // so they are treated as compatible with any concrete type.
1058            let call_scope = if sig.type_param_names.is_empty() {
1059                scope.clone()
1060            } else {
1061                let mut s = scope.child();
1062                for tp_name in &sig.type_param_names {
1063                    s.generic_type_params.insert(tp_name.clone());
1064                }
1065                s
1066            };
1067            for (i, (arg, (param_name, param_type))) in
1068                args.iter().zip(sig.params.iter()).enumerate()
1069            {
1070                if let Some(expected) = param_type {
1071                    let actual = self.infer_type(arg, scope);
1072                    if let Some(actual) = &actual {
1073                        if !self.types_compatible(expected, actual, &call_scope) {
1074                            self.error_at(
1075                                format!(
1076                                    "Argument {} ('{}'): expected {}, got {}",
1077                                    i + 1,
1078                                    param_name,
1079                                    format_type(expected),
1080                                    format_type(actual)
1081                                ),
1082                                arg.span,
1083                            );
1084                        }
1085                    }
1086                }
1087            }
1088            // Enforce where-clause constraints at call site
1089            if !sig.where_clauses.is_empty() {
1090                // Build mapping: type_param → concrete type from inferred args
1091                let mut type_bindings: BTreeMap<String, String> = BTreeMap::new();
1092                for (arg, (_param_name, param_type)) in args.iter().zip(sig.params.iter()) {
1093                    if let Some(TypeExpr::Named(param_ty)) = param_type {
1094                        if sig.type_param_names.contains(param_ty) {
1095                            if let Some(TypeExpr::Named(concrete)) = self.infer_type(arg, scope) {
1096                                type_bindings.entry(param_ty.clone()).or_insert(concrete);
1097                            }
1098                        }
1099                    }
1100                }
1101                for (type_param, bound) in &sig.where_clauses {
1102                    if let Some(concrete_type) = type_bindings.get(type_param) {
1103                        if !self.satisfies_interface(concrete_type, bound, scope) {
1104                            self.warning_at(
1105                                format!(
1106                                    "Type '{}' does not satisfy interface '{}': \
1107                                     required by constraint `where {}: {}`",
1108                                    concrete_type, bound, type_param, bound
1109                                ),
1110                                span,
1111                            );
1112                        }
1113                    }
1114                }
1115            }
1116        }
1117        // Check args recursively
1118        for arg in args {
1119            self.check_node(arg, scope);
1120        }
1121    }
1122
1123    /// Infer the type of an expression.
1124    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
1125        match &snode.node {
1126            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
1127            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
1128            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
1129                Some(TypeExpr::Named("string".into()))
1130            }
1131            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
1132            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
1133            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
1134            Node::DictLiteral(entries) => {
1135                // Infer shape type when all keys are string literals
1136                let mut fields = Vec::new();
1137                let mut all_string_keys = true;
1138                for entry in entries {
1139                    if let Node::StringLiteral(key) = &entry.key.node {
1140                        let val_type = self
1141                            .infer_type(&entry.value, scope)
1142                            .unwrap_or(TypeExpr::Named("nil".into()));
1143                        fields.push(ShapeField {
1144                            name: key.clone(),
1145                            type_expr: val_type,
1146                            optional: false,
1147                        });
1148                    } else {
1149                        all_string_keys = false;
1150                        break;
1151                    }
1152                }
1153                if all_string_keys && !fields.is_empty() {
1154                    Some(TypeExpr::Shape(fields))
1155                } else {
1156                    Some(TypeExpr::Named("dict".into()))
1157                }
1158            }
1159            Node::Closure { params, body } => {
1160                // If all params are typed and we can infer a return type, produce FnType
1161                let all_typed = params.iter().all(|p| p.type_expr.is_some());
1162                if all_typed && !params.is_empty() {
1163                    let param_types: Vec<TypeExpr> =
1164                        params.iter().filter_map(|p| p.type_expr.clone()).collect();
1165                    // Try to infer return type from last expression in body
1166                    let ret = body.last().and_then(|last| self.infer_type(last, scope));
1167                    if let Some(ret_type) = ret {
1168                        return Some(TypeExpr::FnType {
1169                            params: param_types,
1170                            return_type: Box::new(ret_type),
1171                        });
1172                    }
1173                }
1174                Some(TypeExpr::Named("closure".into()))
1175            }
1176
1177            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
1178
1179            Node::FunctionCall { name, .. } => {
1180                // Struct constructor calls return the struct type
1181                if scope.get_struct(name).is_some() {
1182                    return Some(TypeExpr::Named(name.clone()));
1183                }
1184                // Check user-defined function return types
1185                if let Some(sig) = scope.get_fn(name) {
1186                    return sig.return_type.clone();
1187                }
1188                // Check builtin return types
1189                builtin_return_type(name)
1190            }
1191
1192            Node::BinaryOp { op, left, right } => {
1193                let lt = self.infer_type(left, scope);
1194                let rt = self.infer_type(right, scope);
1195                infer_binary_op_type(op, &lt, &rt)
1196            }
1197
1198            Node::UnaryOp { op, operand } => {
1199                let t = self.infer_type(operand, scope);
1200                match op.as_str() {
1201                    "!" => Some(TypeExpr::Named("bool".into())),
1202                    "-" => t, // negation preserves type
1203                    _ => None,
1204                }
1205            }
1206
1207            Node::Ternary {
1208                true_expr,
1209                false_expr,
1210                ..
1211            } => {
1212                let tt = self.infer_type(true_expr, scope);
1213                let ft = self.infer_type(false_expr, scope);
1214                match (&tt, &ft) {
1215                    (Some(a), Some(b)) if a == b => tt,
1216                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
1217                    (Some(_), None) => tt,
1218                    (None, Some(_)) => ft,
1219                    (None, None) => None,
1220                }
1221            }
1222
1223            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1224
1225            Node::PropertyAccess { object, property } => {
1226                // EnumName.Variant → infer as the enum type
1227                if let Node::Identifier(name) = &object.node {
1228                    if scope.get_enum(name).is_some() {
1229                        return Some(TypeExpr::Named(name.clone()));
1230                    }
1231                }
1232                // .variant on an enum value → string
1233                if property == "variant" {
1234                    let obj_type = self.infer_type(object, scope);
1235                    if let Some(TypeExpr::Named(name)) = &obj_type {
1236                        if scope.get_enum(name).is_some() {
1237                            return Some(TypeExpr::Named("string".into()));
1238                        }
1239                    }
1240                }
1241                // Shape field access: obj.field → field type
1242                let obj_type = self.infer_type(object, scope);
1243                if let Some(TypeExpr::Shape(fields)) = &obj_type {
1244                    if let Some(field) = fields.iter().find(|f| f.name == *property) {
1245                        return Some(field.type_expr.clone());
1246                    }
1247                }
1248                None
1249            }
1250
1251            Node::SubscriptAccess { object, index } => {
1252                let obj_type = self.infer_type(object, scope);
1253                match &obj_type {
1254                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1255                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1256                    Some(TypeExpr::Shape(fields)) => {
1257                        // If index is a string literal, look up the field type
1258                        if let Node::StringLiteral(key) = &index.node {
1259                            fields
1260                                .iter()
1261                                .find(|f| &f.name == key)
1262                                .map(|f| f.type_expr.clone())
1263                        } else {
1264                            None
1265                        }
1266                    }
1267                    Some(TypeExpr::Named(n)) if n == "list" => None,
1268                    Some(TypeExpr::Named(n)) if n == "dict" => None,
1269                    Some(TypeExpr::Named(n)) if n == "string" => {
1270                        Some(TypeExpr::Named("string".into()))
1271                    }
1272                    _ => None,
1273                }
1274            }
1275            Node::SliceAccess { object, .. } => {
1276                // Slicing a list returns the same list type; slicing a string returns string
1277                let obj_type = self.infer_type(object, scope);
1278                match &obj_type {
1279                    Some(TypeExpr::List(_)) => obj_type,
1280                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1281                    Some(TypeExpr::Named(n)) if n == "string" => {
1282                        Some(TypeExpr::Named("string".into()))
1283                    }
1284                    _ => None,
1285                }
1286            }
1287            Node::MethodCall { object, method, .. }
1288            | Node::OptionalMethodCall { object, method, .. } => {
1289                let obj_type = self.infer_type(object, scope);
1290                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1291                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1292                match method.as_str() {
1293                    // Shared: bool-returning methods
1294                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1295                        Some(TypeExpr::Named("bool".into()))
1296                    }
1297                    // Shared: int-returning methods
1298                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1299                    // String methods
1300                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1301                    | "pad_left" | "pad_right" | "repeat" | "join" => {
1302                        Some(TypeExpr::Named("string".into()))
1303                    }
1304                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
1305                    // filter returns dict for dicts, list for lists
1306                    "filter" => {
1307                        if is_dict {
1308                            Some(TypeExpr::Named("dict".into()))
1309                        } else {
1310                            Some(TypeExpr::Named("list".into()))
1311                        }
1312                    }
1313                    // List methods
1314                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1315                    "reduce" | "find" | "first" | "last" => None,
1316                    // Dict methods
1317                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1318                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1319                    // Conversions
1320                    "to_string" => Some(TypeExpr::Named("string".into())),
1321                    "to_int" => Some(TypeExpr::Named("int".into())),
1322                    "to_float" => Some(TypeExpr::Named("float".into())),
1323                    _ => None,
1324                }
1325            }
1326
1327            // TryOperator on Result<T, E> produces T
1328            Node::TryOperator { operand } => {
1329                match self.infer_type(operand, scope) {
1330                    Some(TypeExpr::Named(name)) if name == "Result" => None, // unknown inner type
1331                    _ => None,
1332                }
1333            }
1334
1335            _ => None,
1336        }
1337    }
1338
1339    /// Check if two types are compatible (actual can be assigned to expected).
1340    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1341        // Generic type parameters match anything.
1342        if let TypeExpr::Named(name) = expected {
1343            if scope.is_generic_type_param(name) {
1344                return true;
1345            }
1346        }
1347        if let TypeExpr::Named(name) = actual {
1348            if scope.is_generic_type_param(name) {
1349                return true;
1350            }
1351        }
1352        let expected = self.resolve_alias(expected, scope);
1353        let actual = self.resolve_alias(actual, scope);
1354
1355        // Interface satisfaction: if expected is an interface name, check if actual type
1356        // has all required methods (Go-style implicit satisfaction).
1357        if let TypeExpr::Named(iface_name) = &expected {
1358            if scope.get_interface(iface_name).is_some() {
1359                if let TypeExpr::Named(type_name) = &actual {
1360                    return self.satisfies_interface(type_name, iface_name, scope);
1361                }
1362                return false;
1363            }
1364        }
1365
1366        match (&expected, &actual) {
1367            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1368            (TypeExpr::Union(members), actual_type) => members
1369                .iter()
1370                .any(|m| self.types_compatible(m, actual_type, scope)),
1371            (expected_type, TypeExpr::Union(members)) => members
1372                .iter()
1373                .all(|m| self.types_compatible(expected_type, m, scope)),
1374            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1375            (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1376            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1377                if expected_field.optional {
1378                    return true;
1379                }
1380                af.iter().any(|actual_field| {
1381                    actual_field.name == expected_field.name
1382                        && self.types_compatible(
1383                            &expected_field.type_expr,
1384                            &actual_field.type_expr,
1385                            scope,
1386                        )
1387                })
1388            }),
1389            // dict<K, V> expected, Shape actual → all field values must match V
1390            (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1391                let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1392                keys_ok
1393                    && af
1394                        .iter()
1395                        .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1396            }
1397            // Shape expected, dict<K, V> actual → gradual: allow since dict may have the fields
1398            (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1399            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1400                self.types_compatible(expected_inner, actual_inner, scope)
1401            }
1402            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1403            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1404            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1405                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1406            }
1407            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1408            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1409            // FnType compatibility: params match positionally and return types match
1410            (
1411                TypeExpr::FnType {
1412                    params: ep,
1413                    return_type: er,
1414                },
1415                TypeExpr::FnType {
1416                    params: ap,
1417                    return_type: ar,
1418                },
1419            ) => {
1420                ep.len() == ap.len()
1421                    && ep
1422                        .iter()
1423                        .zip(ap.iter())
1424                        .all(|(e, a)| self.types_compatible(e, a, scope))
1425                    && self.types_compatible(er, ar, scope)
1426            }
1427            // FnType is compatible with Named("closure") for backward compat
1428            (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1429            (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1430            _ => false,
1431        }
1432    }
1433
1434    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1435        if let TypeExpr::Named(name) = ty {
1436            if let Some(resolved) = scope.resolve_type(name) {
1437                return resolved.clone();
1438            }
1439        }
1440        ty.clone()
1441    }
1442
1443    fn error_at(&mut self, message: String, span: Span) {
1444        self.diagnostics.push(TypeDiagnostic {
1445            message,
1446            severity: DiagnosticSeverity::Error,
1447            span: Some(span),
1448        });
1449    }
1450
1451    fn warning_at(&mut self, message: String, span: Span) {
1452        self.diagnostics.push(TypeDiagnostic {
1453            message,
1454            severity: DiagnosticSeverity::Warning,
1455            span: Some(span),
1456        });
1457    }
1458}
1459
1460impl Default for TypeChecker {
1461    fn default() -> Self {
1462        Self::new()
1463    }
1464}
1465
1466/// Infer the result type of a binary operation.
1467fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1468    match op {
1469        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1470        "+" => match (left, right) {
1471            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1472                match (l.as_str(), r.as_str()) {
1473                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1474                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1475                    ("string", _) => Some(TypeExpr::Named("string".into())),
1476                    ("list", "list") => Some(TypeExpr::Named("list".into())),
1477                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1478                    _ => Some(TypeExpr::Named("string".into())),
1479                }
1480            }
1481            _ => None,
1482        },
1483        "-" | "*" | "/" | "%" => match (left, right) {
1484            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1485                match (l.as_str(), r.as_str()) {
1486                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1487                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1488                    _ => None,
1489                }
1490            }
1491            _ => None,
1492        },
1493        "??" => match (left, right) {
1494            (Some(TypeExpr::Union(members)), _) => {
1495                let non_nil: Vec<_> = members
1496                    .iter()
1497                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1498                    .cloned()
1499                    .collect();
1500                if non_nil.len() == 1 {
1501                    Some(non_nil[0].clone())
1502                } else if non_nil.is_empty() {
1503                    right.clone()
1504                } else {
1505                    Some(TypeExpr::Union(non_nil))
1506                }
1507            }
1508            _ => right.clone(),
1509        },
1510        "|>" => None,
1511        _ => None,
1512    }
1513}
1514
1515/// Format a type expression for display in error messages.
1516/// Produce a detail string describing why a Shape type is incompatible with
1517/// another Shape type — e.g. "missing field 'age' (int)" or "field 'name'
1518/// has type int, expected string".  Returns `None` if both types are not shapes.
1519pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1520    if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1521        let mut details = Vec::new();
1522        for field in ef {
1523            if field.optional {
1524                continue;
1525            }
1526            match af.iter().find(|f| f.name == field.name) {
1527                None => details.push(format!(
1528                    "missing field '{}' ({})",
1529                    field.name,
1530                    format_type(&field.type_expr)
1531                )),
1532                Some(actual_field) => {
1533                    let e_str = format_type(&field.type_expr);
1534                    let a_str = format_type(&actual_field.type_expr);
1535                    if e_str != a_str {
1536                        details.push(format!(
1537                            "field '{}' has type {}, expected {}",
1538                            field.name, a_str, e_str
1539                        ));
1540                    }
1541                }
1542            }
1543        }
1544        if details.is_empty() {
1545            None
1546        } else {
1547            Some(details.join("; "))
1548        }
1549    } else {
1550        None
1551    }
1552}
1553
1554pub fn format_type(ty: &TypeExpr) -> String {
1555    match ty {
1556        TypeExpr::Named(n) => n.clone(),
1557        TypeExpr::Union(types) => types
1558            .iter()
1559            .map(format_type)
1560            .collect::<Vec<_>>()
1561            .join(" | "),
1562        TypeExpr::Shape(fields) => {
1563            let inner: Vec<String> = fields
1564                .iter()
1565                .map(|f| {
1566                    let opt = if f.optional { "?" } else { "" };
1567                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1568                })
1569                .collect();
1570            format!("{{{}}}", inner.join(", "))
1571        }
1572        TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1573        TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1574        TypeExpr::FnType {
1575            params,
1576            return_type,
1577        } => {
1578            let params_str = params
1579                .iter()
1580                .map(format_type)
1581                .collect::<Vec<_>>()
1582                .join(", ");
1583            format!("fn({}) -> {}", params_str, format_type(return_type))
1584        }
1585    }
1586}
1587
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    /// Substring identifying exhaustiveness warnings.
    const NON_EXHAUSTIVE: &str = "Non-exhaustive";
    /// Substring identifying match-pattern type warnings.
    const PATTERN_MISMATCH: &str = "Match pattern type mismatch";

    /// Lex, parse and type-check `source`, returning every diagnostic.
    ///
    /// Panics if the source fails to lex or parse. These tests only exercise
    /// the type checker, so a syntax error means the test itself is broken.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut lexer = Lexer::new(source);
        let tokens = lexer.tokenize().expect("test source should tokenize");
        let mut parser = Parser::new(tokens);
        let program = parser.parse().expect("test source should parse");
        TypeChecker::new().check(&program)
    }

    /// Messages of every diagnostic for `source` with the given severity.
    fn messages(source: &str, severity: DiagnosticSeverity) -> Vec<String> {
        check_source(source)
            .into_iter()
            .filter(|d| d.severity == severity)
            .map(|d| d.message)
            .collect()
    }

    /// Error messages for `source`.
    fn errors(source: &str) -> Vec<String> {
        messages(source, DiagnosticSeverity::Error)
    }

    /// Warning messages for `source`.
    fn warnings(source: &str) -> Vec<String> {
        messages(source, DiagnosticSeverity::Warning)
    }

    /// Warning messages for `source` that contain `needle`.
    fn warnings_containing(source: &str, needle: &str) -> Vec<String> {
        warnings(source)
            .into_iter()
            .filter(|w| w.contains(needle))
            .collect()
    }

    #[test]
    fn test_no_errors_for_untyped_code() {
        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
        assert!(errs.is_empty(), "unexpected errors: {errs:?}");
    }

    #[test]
    fn test_correct_typed_let() {
        let errs = errors("pipeline t(task) { let x: int = 42 }");
        assert!(errs.is_empty(), "unexpected errors: {errs:?}");
    }

    #[test]
    fn test_type_mismatch_let() {
        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
        assert_eq!(errs.len(), 1, "errors: {errs:?}");
        assert!(errs[0].contains("Type mismatch"));
        assert!(errs[0].contains("int"));
        assert!(errs[0].contains("string"));
    }

    #[test]
    fn test_correct_typed_fn() {
        let errs = errors(
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
        );
        assert!(errs.is_empty(), "unexpected errors: {errs:?}");
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let errs = errors(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        assert_eq!(errs.len(), 1, "errors: {errs:?}");
        assert!(errs[0].contains("Argument 1"));
        assert!(errs[0].contains("expected int"));
    }

    #[test]
    fn test_return_type_mismatch() {
        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert_eq!(errs.len(), 1, "errors: {errs:?}");
        assert!(errs[0].contains("Return type mismatch"));
    }

    #[test]
    fn test_union_type_compatible() {
        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
        assert!(errs.is_empty(), "unexpected errors: {errs:?}");
    }

    #[test]
    fn test_union_type_mismatch() {
        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert_eq!(errs.len(), 1, "errors: {errs:?}");
        assert!(errs[0].contains("Type mismatch"));
    }

    #[test]
    fn test_type_inference_propagation() {
        let errs = errors(
            r#"pipeline t(task) {
  fn add(a: int, b: int) -> int { return a + b }
  let result: string = add(1, 2)
}"#,
        );
        assert_eq!(errs.len(), 1, "errors: {errs:?}");
        assert!(errs[0].contains("Type mismatch"));
        assert!(errs[0].contains("string"));
        assert!(errs[0].contains("int"));
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        assert_eq!(errs.len(), 1, "errors: {errs:?}");
        assert!(errs[0].contains("string"));
        assert!(errs[0].contains("int"));
    }

    #[test]
    fn test_binary_op_type_inference() {
        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
        assert_eq!(errs.len(), 1, "errors: {errs:?}");
        // `1 + 2` must be inferred as int for the mismatch to be reported.
        assert!(errs[0].contains("Type mismatch"));
        assert!(errs[0].contains("int"));
    }

    #[test]
    fn test_comparison_returns_bool() {
        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
        assert!(errs.is_empty(), "unexpected errors: {errs:?}");
    }

    #[test]
    fn test_int_float_promotion() {
        let errs = errors("pipeline t(task) { let x: float = 42 }");
        assert!(errs.is_empty(), "unexpected errors: {errs:?}");
    }

    #[test]
    fn test_untyped_code_no_errors() {
        let errs = errors(
            r#"pipeline t(task) {
  fn process(data) {
    let result = data + " processed"
    return result
  }
  log(process("hello"))
}"#,
        );
        assert!(errs.is_empty(), "unexpected errors: {errs:?}");
    }

    #[test]
    fn test_type_alias() {
        let errs = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = "hello"
}"#,
        );
        assert!(errs.is_empty(), "unexpected errors: {errs:?}");
    }

    #[test]
    fn test_type_alias_mismatch() {
        let errs = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = 42
}"#,
        );
        assert_eq!(errs.len(), 1, "errors: {errs:?}");
        assert!(errs[0].contains("Type mismatch"));
    }

    #[test]
    fn test_assignment_type_check() {
        let errs = errors(
            r#"pipeline t(task) {
  var x: int = 0
  x = "hello"
}"#,
        );
        assert_eq!(errs.len(), 1, "errors: {errs:?}");
        assert!(errs[0].contains("cannot assign string"));
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        let errs = errors(
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
        );
        assert!(errs.is_empty(), "unexpected errors: {errs:?}");
    }

    #[test]
    fn test_covariance_return_type() {
        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
        assert!(errs.is_empty(), "unexpected errors: {errs:?}");
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
        assert_eq!(errs.len(), 1, "errors: {errs:?}");
        assert!(errs[0].contains("Argument 1"));
        assert!(errs[0].contains("expected int"));
    }

    // --- Exhaustiveness checking tests ---

    #[test]
    fn test_exhaustive_match_no_warning() {
        let exhaustive_warns = warnings_containing(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
    "Blue" -> { log("b") }
  }
}"#,
            NON_EXHAUSTIVE,
        );
        assert!(exhaustive_warns.is_empty(), "unexpected warnings: {exhaustive_warns:?}");
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let exhaustive_warns = warnings_containing(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
  }
}"#,
            NON_EXHAUSTIVE,
        );
        assert_eq!(exhaustive_warns.len(), 1, "warnings: {exhaustive_warns:?}");
        assert!(exhaustive_warns[0].contains("Blue"));
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let exhaustive_warns = warnings_containing(
            r#"pipeline t(task) {
  enum Status { Active, Inactive, Pending }
  let s = Status.Active
  match s.variant {
    "Active" -> { log("a") }
  }
}"#,
            NON_EXHAUSTIVE,
        );
        assert_eq!(exhaustive_warns.len(), 1, "warnings: {exhaustive_warns:?}");
        assert!(exhaustive_warns[0].contains("Inactive"));
        assert!(exhaustive_warns[0].contains("Pending"));
    }

    #[test]
    fn test_enum_construct_type_inference() {
        let errs = errors(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c: Color = Color.Red
}"#,
        );
        assert!(errs.is_empty(), "unexpected errors: {errs:?}");
    }

    // --- Type narrowing tests ---

    #[test]
    fn test_nil_coalescing_strips_nil() {
        // After ??, nil should be stripped from the type
        let errs = errors(
            r#"pipeline t(task) {
  let x: string | nil = nil
  let y: string = x ?? "default"
}"#,
        );
        assert!(errs.is_empty(), "unexpected errors: {errs:?}");
    }

    #[test]
    fn test_shape_mismatch_detail_missing_field() {
        let errs = errors(
            r#"pipeline t(task) {
  let x: {name: string, age: int} = {name: "hello"}
}"#,
        );
        assert_eq!(errs.len(), 1, "errors: {errs:?}");
        assert!(
            errs[0].contains("missing field 'age'"),
            "expected detail about missing field, got: {}",
            errs[0]
        );
    }

    #[test]
    fn test_shape_mismatch_detail_wrong_type() {
        let errs = errors(
            r#"pipeline t(task) {
  let x: {name: string, age: int} = {name: 42, age: 10}
}"#,
        );
        assert_eq!(errs.len(), 1, "errors: {errs:?}");
        assert!(
            errs[0].contains("field 'name' has type int, expected string"),
            "expected detail about wrong type, got: {}",
            errs[0]
        );
    }

    // --- Match pattern type validation tests ---

    #[test]
    fn test_match_pattern_string_against_int() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    "hello" -> { log("bad") }
    42 -> { log("ok") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(pattern_warns.len(), 1, "warnings: {pattern_warns:?}");
        assert!(pattern_warns[0].contains("matching int against string literal"));
    }

    #[test]
    fn test_match_pattern_int_against_string() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: string = "hello"
  match x {
    42 -> { log("bad") }
    "hello" -> { log("ok") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(pattern_warns.len(), 1, "warnings: {pattern_warns:?}");
        assert!(pattern_warns[0].contains("matching string against int literal"));
    }

    #[test]
    fn test_match_pattern_bool_against_int() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    true -> { log("bad") }
    42 -> { log("ok") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(pattern_warns.len(), 1, "warnings: {pattern_warns:?}");
        assert!(pattern_warns[0].contains("matching int against bool literal"));
    }

    #[test]
    fn test_match_pattern_float_against_string() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: string = "hello"
  match x {
    3.14 -> { log("bad") }
    "hello" -> { log("ok") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert_eq!(pattern_warns.len(), 1, "warnings: {pattern_warns:?}");
        assert!(pattern_warns[0].contains("matching string against float literal"));
    }

    #[test]
    fn test_match_pattern_int_against_float_ok() {
        // int and float are compatible for match patterns
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: float = 3.14
  match x {
    42 -> { log("ok") }
    _ -> { log("default") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty(), "unexpected warnings: {pattern_warns:?}");
    }

    #[test]
    fn test_match_pattern_float_against_int_ok() {
        // float and int are compatible for match patterns
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    3.14 -> { log("close") }
    _ -> { log("default") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty(), "unexpected warnings: {pattern_warns:?}");
    }

    #[test]
    fn test_match_pattern_correct_types_no_warning() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    1 -> { log("one") }
    2 -> { log("two") }
    _ -> { log("other") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty(), "unexpected warnings: {pattern_warns:?}");
    }

    #[test]
    fn test_match_pattern_wildcard_no_warning() {
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x: int = 42
  match x {
    _ -> { log("catch all") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty(), "unexpected warnings: {pattern_warns:?}");
    }

    #[test]
    fn test_match_pattern_untyped_no_warning() {
        // When value has no known type, no warning should be emitted
        let pattern_warns = warnings_containing(
            r#"pipeline t(task) {
  let x = some_unknown_fn()
  match x {
    "hello" -> { log("string") }
    42 -> { log("int") }
  }
}"#,
            PATTERN_MISMATCH,
        );
        assert!(pattern_warns.is_empty(), "unexpected warnings: {pattern_warns:?}");
    }
}
2073}