Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6/// A diagnostic produced by the type checker.
7#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9    pub message: String,
10    pub severity: DiagnosticSeverity,
11    pub span: Option<Span>,
12}
13
/// How serious a type-checker diagnostic is.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticSeverity {
    /// Reported for definite problems such as type mismatches.
    Error,
    /// Reported for findings that may or may not be real problems.
    Warning,
}
19
20/// Inferred type of an expression. None means unknown/untyped (gradual typing).
21type InferredType = Option<TypeExpr>;
22
23/// Scope for tracking variable types.
24#[derive(Debug, Clone)]
25struct TypeScope {
26    /// Variable name → inferred type.
27    vars: BTreeMap<String, InferredType>,
28    /// Function name → (param types, return type).
29    functions: BTreeMap<String, FnSignature>,
30    /// Named type aliases.
31    type_aliases: BTreeMap<String, TypeExpr>,
32    /// Enum declarations: name → variant names.
33    enums: BTreeMap<String, Vec<String>>,
34    /// Interface declarations: name → method signatures.
35    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36    /// Struct declarations: name → field types.
37    structs: BTreeMap<String, Vec<(String, InferredType)>>,
38    /// Generic type parameter names in scope (treated as compatible with any type).
39    generic_type_params: std::collections::BTreeSet<String>,
40    parent: Option<Box<TypeScope>>,
41}
42
43#[derive(Debug, Clone)]
44struct FnSignature {
45    params: Vec<(String, InferredType)>,
46    return_type: InferredType,
47    /// Generic type parameter names declared on the function.
48    type_param_names: Vec<String>,
49    /// Number of required parameters (those without defaults).
50    required_params: usize,
51}
52
53impl TypeScope {
54    fn new() -> Self {
55        Self {
56            vars: BTreeMap::new(),
57            functions: BTreeMap::new(),
58            type_aliases: BTreeMap::new(),
59            enums: BTreeMap::new(),
60            interfaces: BTreeMap::new(),
61            structs: BTreeMap::new(),
62            generic_type_params: std::collections::BTreeSet::new(),
63            parent: None,
64        }
65    }
66
67    fn child(&self) -> Self {
68        Self {
69            vars: BTreeMap::new(),
70            functions: BTreeMap::new(),
71            type_aliases: BTreeMap::new(),
72            enums: BTreeMap::new(),
73            interfaces: BTreeMap::new(),
74            structs: BTreeMap::new(),
75            generic_type_params: std::collections::BTreeSet::new(),
76            parent: Some(Box::new(self.clone())),
77        }
78    }
79
80    fn get_var(&self, name: &str) -> Option<&InferredType> {
81        self.vars
82            .get(name)
83            .or_else(|| self.parent.as_ref()?.get_var(name))
84    }
85
86    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
87        self.functions
88            .get(name)
89            .or_else(|| self.parent.as_ref()?.get_fn(name))
90    }
91
92    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
93        self.type_aliases
94            .get(name)
95            .or_else(|| self.parent.as_ref()?.resolve_type(name))
96    }
97
98    fn is_generic_type_param(&self, name: &str) -> bool {
99        self.generic_type_params.contains(name)
100            || self
101                .parent
102                .as_ref()
103                .is_some_and(|p| p.is_generic_type_param(name))
104    }
105
106    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
107        self.enums
108            .get(name)
109            .or_else(|| self.parent.as_ref()?.get_enum(name))
110    }
111
112    #[allow(dead_code)]
113    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
114        self.interfaces
115            .get(name)
116            .or_else(|| self.parent.as_ref()?.get_interface(name))
117    }
118
119    fn define_var(&mut self, name: &str, ty: InferredType) {
120        self.vars.insert(name.to_string(), ty);
121    }
122
123    fn define_fn(&mut self, name: &str, sig: FnSignature) {
124        self.functions.insert(name.to_string(), sig);
125    }
126}
127
128/// Known return types for builtin functions.
129fn builtin_return_type(name: &str) -> InferredType {
130    match name {
131        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
132        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
133            Some(TypeExpr::Named("nil".into()))
134        }
135        "type_of"
136        | "to_string"
137        | "json_stringify"
138        | "read_file"
139        | "http_get"
140        | "http_post"
141        | "llm_call"
142        | "regex_replace"
143        | "path_join"
144        | "temp_dir"
145        | "date_format"
146        | "format"
147        | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
148        "to_int" | "timer_end" | "elapsed" | "sign" => Some(TypeExpr::Named("int".into())),
149        "to_float" | "timestamp" | "date_parse" | "sin" | "cos" | "tan" | "asin" | "acos"
150        | "atan" | "atan2" | "log2" | "log10" | "ln" | "exp" | "pi" | "e" => {
151            Some(TypeExpr::Named("float".into()))
152        }
153        "file_exists" | "json_validate" | "is_nan" | "is_infinite" | "set_contains" => {
154            Some(TypeExpr::Named("bool".into()))
155        }
156        "list_dir" | "mcp_list_tools" | "mcp_list_resources" | "mcp_list_prompts" | "to_list" => {
157            Some(TypeExpr::Named("list".into()))
158        }
159        "stat" | "exec" | "shell" | "date_now" | "agent_loop" | "llm_info" | "llm_usage"
160        | "timer_start" | "metadata_get" | "mcp_server_info" | "mcp_get_prompt" => {
161            Some(TypeExpr::Named("dict".into()))
162        }
163        "metadata_set"
164        | "metadata_save"
165        | "metadata_refresh_hashes"
166        | "invalidate_facts"
167        | "log_json"
168        | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
169        "env" | "regex_match" => Some(TypeExpr::Union(vec![
170            TypeExpr::Named("string".into()),
171            TypeExpr::Named("nil".into()),
172        ])),
173        "json_parse" | "json_extract" => None, // could be any type
174        _ => None,
175    }
176}
177
/// Returns `true` if `name` is a builtin function.
///
/// This covers every builtin that `builtin_return_type` assigns a known result
/// type. It also covers builtins whose result type is not tracked statically,
/// such as the string helpers, `await`, and the `set_*` family. The two lists
/// previously disagreed: names like `timer_end` or `mcp_disconnect` had a known
/// return type but were not recognised as builtins.
fn is_builtin(name: &str) -> bool {
    matches!(
        name,
        // Output and logging
        "log" | "print" | "println" | "log_json"
            // Conversion and introspection
            | "type_of" | "to_string" | "to_int" | "to_float" | "to_list"
            // JSON
            | "json_stringify" | "json_parse" | "json_validate" | "json_extract"
            // Environment, process control and time
            | "env" | "exit" | "sleep" | "timestamp" | "date_now" | "date_format"
            | "date_parse" | "timer_start" | "timer_end" | "elapsed"
            // Files and paths
            | "read_file" | "write_file" | "append_file" | "delete_file" | "copy_file"
            | "file_exists" | "list_dir" | "mkdir" | "temp_dir" | "stat" | "path_join"
            | "dirname" | "basename" | "extname" | "compute_content_hash"
            // External commands
            | "exec" | "shell"
            // Network and LLM
            | "http_get" | "http_post" | "llm_call" | "llm_info" | "llm_usage"
            | "agent_loop"
            // MCP
            | "mcp_list_tools" | "mcp_list_resources" | "mcp_list_prompts"
            | "mcp_server_info" | "mcp_get_prompt" | "mcp_disconnect"
            // Metadata
            | "metadata_get" | "metadata_set" | "metadata_save"
            | "metadata_refresh_hashes" | "invalidate_facts"
            // Tasks
            | "await" | "cancel"
            // Strings and regex
            | "format" | "regex_match" | "regex_replace" | "trim" | "lowercase"
            | "uppercase" | "split" | "starts_with" | "ends_with" | "contains"
            | "replace" | "join" | "len" | "substring"
            // Math
            | "sin" | "cos" | "tan" | "asin" | "acos" | "atan" | "atan2" | "log2"
            | "log10" | "ln" | "exp" | "pi" | "e" | "sign" | "is_nan" | "is_infinite"
            // Sets
            | "set" | "set_add" | "set_remove" | "set_contains" | "set_union"
            | "set_intersect" | "set_difference"
    )
}
262
263/// The static type checker.
264pub struct TypeChecker {
265    diagnostics: Vec<TypeDiagnostic>,
266    scope: TypeScope,
267}
268
269impl TypeChecker {
270    pub fn new() -> Self {
271        Self {
272            diagnostics: Vec::new(),
273            scope: TypeScope::new(),
274        }
275    }
276
277    /// Check a program and return diagnostics.
278    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
279        // First pass: register type and enum declarations into root scope
280        Self::register_declarations_into(&mut self.scope, program);
281
282        // Also scan pipeline bodies for declarations
283        for snode in program {
284            if let Node::Pipeline { body, .. } = &snode.node {
285                Self::register_declarations_into(&mut self.scope, body);
286            }
287        }
288
289        // Check each top-level node
290        for snode in program {
291            match &snode.node {
292                Node::Pipeline { params, body, .. } => {
293                    let mut child = self.scope.child();
294                    for p in params {
295                        child.define_var(p, None);
296                    }
297                    self.check_block(body, &mut child);
298                }
299                Node::FnDecl {
300                    name,
301                    type_params,
302                    params,
303                    return_type,
304                    body,
305                    ..
306                } => {
307                    let required_params =
308                        params.iter().filter(|p| p.default_value.is_none()).count();
309                    let sig = FnSignature {
310                        params: params
311                            .iter()
312                            .map(|p| (p.name.clone(), p.type_expr.clone()))
313                            .collect(),
314                        return_type: return_type.clone(),
315                        type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
316                        required_params,
317                    };
318                    self.scope.define_fn(name, sig);
319                    self.check_fn_body(type_params, params, return_type, body);
320                }
321                _ => {
322                    let mut scope = self.scope.clone();
323                    self.check_node(snode, &mut scope);
324                    // Merge any new definitions back into the top-level scope
325                    for (name, ty) in scope.vars {
326                        self.scope.vars.entry(name).or_insert(ty);
327                    }
328                }
329            }
330        }
331
332        self.diagnostics
333    }
334
335    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
336    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
337        for snode in nodes {
338            match &snode.node {
339                Node::TypeDecl { name, type_expr } => {
340                    scope.type_aliases.insert(name.clone(), type_expr.clone());
341                }
342                Node::EnumDecl { name, variants } => {
343                    let variant_names: Vec<String> =
344                        variants.iter().map(|v| v.name.clone()).collect();
345                    scope.enums.insert(name.clone(), variant_names);
346                }
347                Node::InterfaceDecl { name, methods } => {
348                    scope.interfaces.insert(name.clone(), methods.clone());
349                }
350                Node::StructDecl { name, fields } => {
351                    let field_types: Vec<(String, InferredType)> = fields
352                        .iter()
353                        .map(|f| (f.name.clone(), f.type_expr.clone()))
354                        .collect();
355                    scope.structs.insert(name.clone(), field_types);
356                }
357                _ => {}
358            }
359        }
360    }
361
362    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
363        for stmt in stmts {
364            self.check_node(stmt, scope);
365        }
366    }
367
368    /// Define variables from a destructuring pattern in the given scope (as unknown type).
369    fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
370        match pattern {
371            BindingPattern::Identifier(name) => {
372                scope.define_var(name, None);
373            }
374            BindingPattern::Dict(fields) => {
375                for field in fields {
376                    let name = field.alias.as_deref().unwrap_or(&field.key);
377                    scope.define_var(name, None);
378                }
379            }
380            BindingPattern::List(elements) => {
381                for elem in elements {
382                    scope.define_var(&elem.name, None);
383                }
384            }
385        }
386    }
387
388    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
389        let span = snode.span;
390        match &snode.node {
391            Node::LetBinding {
392                pattern,
393                type_ann,
394                value,
395            } => {
396                let inferred = self.infer_type(value, scope);
397                if let BindingPattern::Identifier(name) = pattern {
398                    if let Some(expected) = type_ann {
399                        if let Some(actual) = &inferred {
400                            if !self.types_compatible(expected, actual, scope) {
401                                let mut msg = format!(
402                                    "Type mismatch: '{}' declared as {}, but assigned {}",
403                                    name,
404                                    format_type(expected),
405                                    format_type(actual)
406                                );
407                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
408                                    msg.push_str(&format!(" ({})", detail));
409                                }
410                                self.error_at(msg, span);
411                            }
412                        }
413                    }
414                    let ty = type_ann.clone().or(inferred);
415                    scope.define_var(name, ty);
416                } else {
417                    Self::define_pattern_vars(pattern, scope);
418                }
419            }
420
421            Node::VarBinding {
422                pattern,
423                type_ann,
424                value,
425            } => {
426                let inferred = self.infer_type(value, scope);
427                if let BindingPattern::Identifier(name) = pattern {
428                    if let Some(expected) = type_ann {
429                        if let Some(actual) = &inferred {
430                            if !self.types_compatible(expected, actual, scope) {
431                                let mut msg = format!(
432                                    "Type mismatch: '{}' declared as {}, but assigned {}",
433                                    name,
434                                    format_type(expected),
435                                    format_type(actual)
436                                );
437                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
438                                    msg.push_str(&format!(" ({})", detail));
439                                }
440                                self.error_at(msg, span);
441                            }
442                        }
443                    }
444                    let ty = type_ann.clone().or(inferred);
445                    scope.define_var(name, ty);
446                } else {
447                    Self::define_pattern_vars(pattern, scope);
448                }
449            }
450
451            Node::FnDecl {
452                name,
453                type_params,
454                params,
455                return_type,
456                body,
457                ..
458            } => {
459                let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
460                let sig = FnSignature {
461                    params: params
462                        .iter()
463                        .map(|p| (p.name.clone(), p.type_expr.clone()))
464                        .collect(),
465                    return_type: return_type.clone(),
466                    type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
467                    required_params,
468                };
469                scope.define_fn(name, sig.clone());
470                scope.define_var(name, None);
471                self.check_fn_body(type_params, params, return_type, body);
472            }
473
474            Node::FunctionCall { name, args } => {
475                self.check_call(name, args, scope, span);
476            }
477
478            Node::IfElse {
479                condition,
480                then_body,
481                else_body,
482            } => {
483                self.check_node(condition, scope);
484                let mut then_scope = scope.child();
485                self.check_block(then_body, &mut then_scope);
486                if let Some(else_body) = else_body {
487                    let mut else_scope = scope.child();
488                    self.check_block(else_body, &mut else_scope);
489                }
490            }
491
492            Node::ForIn {
493                pattern,
494                iterable,
495                body,
496            } => {
497                self.check_node(iterable, scope);
498                let mut loop_scope = scope.child();
499                if let BindingPattern::Identifier(variable) = pattern {
500                    // Infer loop variable type from iterable
501                    let elem_type = match self.infer_type(iterable, scope) {
502                        Some(TypeExpr::List(inner)) => Some(*inner),
503                        Some(TypeExpr::Named(n)) if n == "string" => {
504                            Some(TypeExpr::Named("string".into()))
505                        }
506                        _ => None,
507                    };
508                    loop_scope.define_var(variable, elem_type);
509                } else {
510                    Self::define_pattern_vars(pattern, &mut loop_scope);
511                }
512                self.check_block(body, &mut loop_scope);
513            }
514
515            Node::WhileLoop { condition, body } => {
516                self.check_node(condition, scope);
517                let mut loop_scope = scope.child();
518                self.check_block(body, &mut loop_scope);
519            }
520
521            Node::TryCatch {
522                body,
523                error_var,
524                catch_body,
525                finally_body,
526                ..
527            } => {
528                let mut try_scope = scope.child();
529                self.check_block(body, &mut try_scope);
530                let mut catch_scope = scope.child();
531                if let Some(var) = error_var {
532                    catch_scope.define_var(var, None);
533                }
534                self.check_block(catch_body, &mut catch_scope);
535                if let Some(fb) = finally_body {
536                    let mut finally_scope = scope.child();
537                    self.check_block(fb, &mut finally_scope);
538                }
539            }
540
541            Node::ReturnStmt {
542                value: Some(val), ..
543            } => {
544                self.check_node(val, scope);
545            }
546
547            Node::Assignment {
548                target, value, op, ..
549            } => {
550                self.check_node(value, scope);
551                if let Node::Identifier(name) = &target.node {
552                    if let Some(Some(var_type)) = scope.get_var(name) {
553                        let value_type = self.infer_type(value, scope);
554                        let assigned = if let Some(op) = op {
555                            let var_inferred = scope.get_var(name).cloned().flatten();
556                            infer_binary_op_type(op, &var_inferred, &value_type)
557                        } else {
558                            value_type
559                        };
560                        if let Some(actual) = &assigned {
561                            if !self.types_compatible(var_type, actual, scope) {
562                                self.error_at(
563                                    format!(
564                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
565                                        format_type(actual),
566                                        name,
567                                        format_type(var_type)
568                                    ),
569                                    span,
570                                );
571                            }
572                        }
573                    }
574                }
575            }
576
577            Node::TypeDecl { name, type_expr } => {
578                scope.type_aliases.insert(name.clone(), type_expr.clone());
579            }
580
581            Node::EnumDecl { name, variants } => {
582                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
583                scope.enums.insert(name.clone(), variant_names);
584            }
585
586            Node::StructDecl { name, fields } => {
587                let field_types: Vec<(String, InferredType)> = fields
588                    .iter()
589                    .map(|f| (f.name.clone(), f.type_expr.clone()))
590                    .collect();
591                scope.structs.insert(name.clone(), field_types);
592            }
593
594            Node::InterfaceDecl { name, methods } => {
595                scope.interfaces.insert(name.clone(), methods.clone());
596            }
597
598            Node::MatchExpr { value, arms } => {
599                self.check_node(value, scope);
600                let value_type = self.infer_type(value, scope);
601                for arm in arms {
602                    self.check_node(&arm.pattern, scope);
603                    // Check for incompatible literal pattern types
604                    if let Some(ref vt) = value_type {
605                        let value_type_name = format_type(vt);
606                        let mismatch = match &arm.pattern.node {
607                            Node::StringLiteral(_) => {
608                                !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
609                            }
610                            Node::IntLiteral(_) => {
611                                !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
612                                    && !self.types_compatible(
613                                        vt,
614                                        &TypeExpr::Named("float".into()),
615                                        scope,
616                                    )
617                            }
618                            Node::FloatLiteral(_) => {
619                                !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
620                                    && !self.types_compatible(
621                                        vt,
622                                        &TypeExpr::Named("int".into()),
623                                        scope,
624                                    )
625                            }
626                            Node::BoolLiteral(_) => {
627                                !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
628                            }
629                            _ => false,
630                        };
631                        if mismatch {
632                            let pattern_type = match &arm.pattern.node {
633                                Node::StringLiteral(_) => "string",
634                                Node::IntLiteral(_) => "int",
635                                Node::FloatLiteral(_) => "float",
636                                Node::BoolLiteral(_) => "bool",
637                                _ => unreachable!(),
638                            };
639                            self.warning_at(
640                                format!(
641                                    "Match pattern type mismatch: matching {} against {} literal",
642                                    value_type_name, pattern_type
643                                ),
644                                arm.pattern.span,
645                            );
646                        }
647                    }
648                    let mut arm_scope = scope.child();
649                    self.check_block(&arm.body, &mut arm_scope);
650                }
651                self.check_match_exhaustiveness(value, arms, scope, span);
652            }
653
654            // Recurse into nested expressions + validate binary op types
655            Node::BinaryOp { op, left, right } => {
656                self.check_node(left, scope);
657                self.check_node(right, scope);
658                // Validate operator/type compatibility
659                let lt = self.infer_type(left, scope);
660                let rt = self.infer_type(right, scope);
661                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
662                    match op.as_str() {
663                        "-" | "*" | "/" | "%" => {
664                            let numeric = ["int", "float"];
665                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
666                                self.warning_at(
667                                    format!(
668                                        "Operator '{op}' may not be valid for types {} and {}",
669                                        l, r
670                                    ),
671                                    span,
672                                );
673                            }
674                        }
675                        "+" => {
676                            // + is valid for int, float, string, list, dict
677                            let valid = ["int", "float", "string", "list", "dict"];
678                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
679                                self.warning_at(
680                                    format!(
681                                        "Operator '+' may not be valid for types {} and {}",
682                                        l, r
683                                    ),
684                                    span,
685                                );
686                            }
687                        }
688                        _ => {}
689                    }
690                }
691            }
692            Node::UnaryOp { operand, .. } => {
693                self.check_node(operand, scope);
694            }
695            Node::MethodCall { object, args, .. }
696            | Node::OptionalMethodCall { object, args, .. } => {
697                self.check_node(object, scope);
698                for arg in args {
699                    self.check_node(arg, scope);
700                }
701            }
702            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
703                self.check_node(object, scope);
704            }
705            Node::SubscriptAccess { object, index } => {
706                self.check_node(object, scope);
707                self.check_node(index, scope);
708            }
709            Node::SliceAccess { object, start, end } => {
710                self.check_node(object, scope);
711                if let Some(s) = start {
712                    self.check_node(s, scope);
713                }
714                if let Some(e) = end {
715                    self.check_node(e, scope);
716                }
717            }
718
719            // Terminals — nothing to check
720            _ => {}
721        }
722    }
723
724    fn check_fn_body(
725        &mut self,
726        type_params: &[TypeParam],
727        params: &[TypedParam],
728        return_type: &Option<TypeExpr>,
729        body: &[SNode],
730    ) {
731        let mut fn_scope = self.scope.child();
732        // Register generic type parameters so they are treated as compatible
733        // with any concrete type during type checking.
734        for tp in type_params {
735            fn_scope.generic_type_params.insert(tp.name.clone());
736        }
737        for param in params {
738            fn_scope.define_var(&param.name, param.type_expr.clone());
739            if let Some(default) = &param.default_value {
740                self.check_node(default, &mut fn_scope);
741            }
742        }
743        self.check_block(body, &mut fn_scope);
744
745        // Check return statements against declared return type
746        if let Some(ret_type) = return_type {
747            for stmt in body {
748                self.check_return_type(stmt, ret_type, &fn_scope);
749            }
750        }
751    }
752
753    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
754        let span = snode.span;
755        match &snode.node {
756            Node::ReturnStmt { value: Some(val) } => {
757                let inferred = self.infer_type(val, scope);
758                if let Some(actual) = &inferred {
759                    if !self.types_compatible(expected, actual, scope) {
760                        self.error_at(
761                            format!(
762                                "Return type mismatch: expected {}, got {}",
763                                format_type(expected),
764                                format_type(actual)
765                            ),
766                            span,
767                        );
768                    }
769                }
770            }
771            Node::IfElse {
772                then_body,
773                else_body,
774                ..
775            } => {
776                for stmt in then_body {
777                    self.check_return_type(stmt, expected, scope);
778                }
779                if let Some(else_body) = else_body {
780                    for stmt in else_body {
781                        self.check_return_type(stmt, expected, scope);
782                    }
783                }
784            }
785            _ => {}
786        }
787    }
788
789    /// Check if a match expression on an enum's `.variant` property covers all variants.
790    fn check_match_exhaustiveness(
791        &mut self,
792        value: &SNode,
793        arms: &[MatchArm],
794        scope: &TypeScope,
795        span: Span,
796    ) {
797        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
798        let enum_name = match &value.node {
799            Node::PropertyAccess { object, property } if property == "variant" => {
800                // Infer the type of the object
801                match self.infer_type(object, scope) {
802                    Some(TypeExpr::Named(name)) => {
803                        if scope.get_enum(&name).is_some() {
804                            Some(name)
805                        } else {
806                            None
807                        }
808                    }
809                    _ => None,
810                }
811            }
812            _ => {
813                // Direct match on an enum value: match <expr> { ... }
814                match self.infer_type(value, scope) {
815                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
816                    _ => None,
817                }
818            }
819        };
820
821        let Some(enum_name) = enum_name else {
822            return;
823        };
824        let Some(variants) = scope.get_enum(&enum_name) else {
825            return;
826        };
827
828        // Collect variant names covered by match arms
829        let mut covered: Vec<String> = Vec::new();
830        let mut has_wildcard = false;
831
832        for arm in arms {
833            match &arm.pattern.node {
834                // String literal pattern (matching on .variant): "VariantA"
835                Node::StringLiteral(s) => covered.push(s.clone()),
836                // Identifier pattern acts as a wildcard/catch-all
837                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
838                    has_wildcard = true;
839                }
840                // Direct enum construct pattern: EnumName.Variant
841                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
842                // PropertyAccess pattern: EnumName.Variant (no args)
843                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
844                _ => {
845                    // Unknown pattern shape — conservatively treat as wildcard
846                    has_wildcard = true;
847                }
848            }
849        }
850
851        if has_wildcard {
852            return;
853        }
854
855        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
856        if !missing.is_empty() {
857            let missing_str = missing
858                .iter()
859                .map(|s| format!("\"{}\"", s))
860                .collect::<Vec<_>>()
861                .join(", ");
862            self.warning_at(
863                format!(
864                    "Non-exhaustive match on enum {}: missing variants {}",
865                    enum_name, missing_str
866                ),
867                span,
868            );
869        }
870    }
871
872    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
873        // Check against known function signatures
874        if let Some(sig) = scope.get_fn(name).cloned() {
875            if !is_builtin(name)
876                && (args.len() < sig.required_params || args.len() > sig.params.len())
877            {
878                let expected = if sig.required_params == sig.params.len() {
879                    format!("{}", sig.params.len())
880                } else {
881                    format!("{}-{}", sig.required_params, sig.params.len())
882                };
883                self.warning_at(
884                    format!(
885                        "Function '{}' expects {} arguments, got {}",
886                        name,
887                        expected,
888                        args.len()
889                    ),
890                    span,
891                );
892            }
893            // Build a scope that includes the function's generic type params
894            // so they are treated as compatible with any concrete type.
895            let call_scope = if sig.type_param_names.is_empty() {
896                scope.clone()
897            } else {
898                let mut s = scope.child();
899                for tp_name in &sig.type_param_names {
900                    s.generic_type_params.insert(tp_name.clone());
901                }
902                s
903            };
904            for (i, (arg, (param_name, param_type))) in
905                args.iter().zip(sig.params.iter()).enumerate()
906            {
907                if let Some(expected) = param_type {
908                    let actual = self.infer_type(arg, scope);
909                    if let Some(actual) = &actual {
910                        if !self.types_compatible(expected, actual, &call_scope) {
911                            self.error_at(
912                                format!(
913                                    "Argument {} ('{}'): expected {}, got {}",
914                                    i + 1,
915                                    param_name,
916                                    format_type(expected),
917                                    format_type(actual)
918                                ),
919                                arg.span,
920                            );
921                        }
922                    }
923                }
924            }
925        }
926        // Check args recursively
927        for arg in args {
928            self.check_node(arg, scope);
929        }
930    }
931
932    /// Infer the type of an expression.
933    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
934        match &snode.node {
935            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
936            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
937            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
938                Some(TypeExpr::Named("string".into()))
939            }
940            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
941            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
942            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
943            Node::DictLiteral(entries) => {
944                // Infer shape type when all keys are string literals
945                let mut fields = Vec::new();
946                let mut all_string_keys = true;
947                for entry in entries {
948                    if let Node::StringLiteral(key) = &entry.key.node {
949                        let val_type = self
950                            .infer_type(&entry.value, scope)
951                            .unwrap_or(TypeExpr::Named("nil".into()));
952                        fields.push(ShapeField {
953                            name: key.clone(),
954                            type_expr: val_type,
955                            optional: false,
956                        });
957                    } else {
958                        all_string_keys = false;
959                        break;
960                    }
961                }
962                if all_string_keys && !fields.is_empty() {
963                    Some(TypeExpr::Shape(fields))
964                } else {
965                    Some(TypeExpr::Named("dict".into()))
966                }
967            }
968            Node::Closure { params, body } => {
969                // If all params are typed and we can infer a return type, produce FnType
970                let all_typed = params.iter().all(|p| p.type_expr.is_some());
971                if all_typed && !params.is_empty() {
972                    let param_types: Vec<TypeExpr> =
973                        params.iter().filter_map(|p| p.type_expr.clone()).collect();
974                    // Try to infer return type from last expression in body
975                    let ret = body.last().and_then(|last| self.infer_type(last, scope));
976                    if let Some(ret_type) = ret {
977                        return Some(TypeExpr::FnType {
978                            params: param_types,
979                            return_type: Box::new(ret_type),
980                        });
981                    }
982                }
983                Some(TypeExpr::Named("closure".into()))
984            }
985
986            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
987
988            Node::FunctionCall { name, .. } => {
989                // Check user-defined function return types
990                if let Some(sig) = scope.get_fn(name) {
991                    return sig.return_type.clone();
992                }
993                // Check builtin return types
994                builtin_return_type(name)
995            }
996
997            Node::BinaryOp { op, left, right } => {
998                let lt = self.infer_type(left, scope);
999                let rt = self.infer_type(right, scope);
1000                infer_binary_op_type(op, &lt, &rt)
1001            }
1002
1003            Node::UnaryOp { op, operand } => {
1004                let t = self.infer_type(operand, scope);
1005                match op.as_str() {
1006                    "!" => Some(TypeExpr::Named("bool".into())),
1007                    "-" => t, // negation preserves type
1008                    _ => None,
1009                }
1010            }
1011
1012            Node::Ternary {
1013                true_expr,
1014                false_expr,
1015                ..
1016            } => {
1017                let tt = self.infer_type(true_expr, scope);
1018                let ft = self.infer_type(false_expr, scope);
1019                match (&tt, &ft) {
1020                    (Some(a), Some(b)) if a == b => tt,
1021                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
1022                    (Some(_), None) => tt,
1023                    (None, Some(_)) => ft,
1024                    (None, None) => None,
1025                }
1026            }
1027
1028            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1029
1030            Node::PropertyAccess { object, property } => {
1031                // EnumName.Variant → infer as the enum type
1032                if let Node::Identifier(name) = &object.node {
1033                    if scope.get_enum(name).is_some() {
1034                        return Some(TypeExpr::Named(name.clone()));
1035                    }
1036                }
1037                // .variant on an enum value → string
1038                if property == "variant" {
1039                    let obj_type = self.infer_type(object, scope);
1040                    if let Some(TypeExpr::Named(name)) = &obj_type {
1041                        if scope.get_enum(name).is_some() {
1042                            return Some(TypeExpr::Named("string".into()));
1043                        }
1044                    }
1045                }
1046                // Shape field access: obj.field → field type
1047                let obj_type = self.infer_type(object, scope);
1048                if let Some(TypeExpr::Shape(fields)) = &obj_type {
1049                    if let Some(field) = fields.iter().find(|f| f.name == *property) {
1050                        return Some(field.type_expr.clone());
1051                    }
1052                }
1053                None
1054            }
1055
1056            Node::SubscriptAccess { object, index } => {
1057                let obj_type = self.infer_type(object, scope);
1058                match &obj_type {
1059                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1060                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1061                    Some(TypeExpr::Shape(fields)) => {
1062                        // If index is a string literal, look up the field type
1063                        if let Node::StringLiteral(key) = &index.node {
1064                            fields
1065                                .iter()
1066                                .find(|f| &f.name == key)
1067                                .map(|f| f.type_expr.clone())
1068                        } else {
1069                            None
1070                        }
1071                    }
1072                    Some(TypeExpr::Named(n)) if n == "list" => None,
1073                    Some(TypeExpr::Named(n)) if n == "dict" => None,
1074                    Some(TypeExpr::Named(n)) if n == "string" => {
1075                        Some(TypeExpr::Named("string".into()))
1076                    }
1077                    _ => None,
1078                }
1079            }
1080            Node::SliceAccess { object, .. } => {
1081                // Slicing a list returns the same list type; slicing a string returns string
1082                let obj_type = self.infer_type(object, scope);
1083                match &obj_type {
1084                    Some(TypeExpr::List(_)) => obj_type,
1085                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1086                    Some(TypeExpr::Named(n)) if n == "string" => {
1087                        Some(TypeExpr::Named("string".into()))
1088                    }
1089                    _ => None,
1090                }
1091            }
1092            Node::MethodCall { object, method, .. }
1093            | Node::OptionalMethodCall { object, method, .. } => {
1094                let obj_type = self.infer_type(object, scope);
1095                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1096                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1097                match method.as_str() {
1098                    // Shared: bool-returning methods
1099                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1100                        Some(TypeExpr::Named("bool".into()))
1101                    }
1102                    // Shared: int-returning methods
1103                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1104                    // String methods
1105                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1106                    | "pad_left" | "pad_right" | "repeat" | "join" => {
1107                        Some(TypeExpr::Named("string".into()))
1108                    }
1109                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
1110                    // filter returns dict for dicts, list for lists
1111                    "filter" => {
1112                        if is_dict {
1113                            Some(TypeExpr::Named("dict".into()))
1114                        } else {
1115                            Some(TypeExpr::Named("list".into()))
1116                        }
1117                    }
1118                    // List methods
1119                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1120                    "reduce" | "find" | "first" | "last" => None,
1121                    // Dict methods
1122                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1123                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1124                    // Conversions
1125                    "to_string" => Some(TypeExpr::Named("string".into())),
1126                    "to_int" => Some(TypeExpr::Named("int".into())),
1127                    "to_float" => Some(TypeExpr::Named("float".into())),
1128                    _ => None,
1129                }
1130            }
1131
1132            _ => None,
1133        }
1134    }
1135
1136    /// Check if two types are compatible (actual can be assigned to expected).
1137    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1138        // Generic type parameters match anything.
1139        if let TypeExpr::Named(name) = expected {
1140            if scope.is_generic_type_param(name) {
1141                return true;
1142            }
1143        }
1144        if let TypeExpr::Named(name) = actual {
1145            if scope.is_generic_type_param(name) {
1146                return true;
1147            }
1148        }
1149        let expected = self.resolve_alias(expected, scope);
1150        let actual = self.resolve_alias(actual, scope);
1151
1152        match (&expected, &actual) {
1153            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1154            (TypeExpr::Union(members), actual_type) => members
1155                .iter()
1156                .any(|m| self.types_compatible(m, actual_type, scope)),
1157            (expected_type, TypeExpr::Union(members)) => members
1158                .iter()
1159                .all(|m| self.types_compatible(expected_type, m, scope)),
1160            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1161            (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1162            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1163                if expected_field.optional {
1164                    return true;
1165                }
1166                af.iter().any(|actual_field| {
1167                    actual_field.name == expected_field.name
1168                        && self.types_compatible(
1169                            &expected_field.type_expr,
1170                            &actual_field.type_expr,
1171                            scope,
1172                        )
1173                })
1174            }),
1175            // dict<K, V> expected, Shape actual → all field values must match V
1176            (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1177                let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1178                keys_ok
1179                    && af
1180                        .iter()
1181                        .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1182            }
1183            // Shape expected, dict<K, V> actual → gradual: allow since dict may have the fields
1184            (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1185            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1186                self.types_compatible(expected_inner, actual_inner, scope)
1187            }
1188            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1189            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1190            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1191                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1192            }
1193            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1194            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1195            // FnType compatibility: params match positionally and return types match
1196            (
1197                TypeExpr::FnType {
1198                    params: ep,
1199                    return_type: er,
1200                },
1201                TypeExpr::FnType {
1202                    params: ap,
1203                    return_type: ar,
1204                },
1205            ) => {
1206                ep.len() == ap.len()
1207                    && ep
1208                        .iter()
1209                        .zip(ap.iter())
1210                        .all(|(e, a)| self.types_compatible(e, a, scope))
1211                    && self.types_compatible(er, ar, scope)
1212            }
1213            // FnType is compatible with Named("closure") for backward compat
1214            (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1215            (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1216            _ => false,
1217        }
1218    }
1219
1220    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1221        if let TypeExpr::Named(name) = ty {
1222            if let Some(resolved) = scope.resolve_type(name) {
1223                return resolved.clone();
1224            }
1225        }
1226        ty.clone()
1227    }
1228
1229    fn error_at(&mut self, message: String, span: Span) {
1230        self.diagnostics.push(TypeDiagnostic {
1231            message,
1232            severity: DiagnosticSeverity::Error,
1233            span: Some(span),
1234        });
1235    }
1236
1237    fn warning_at(&mut self, message: String, span: Span) {
1238        self.diagnostics.push(TypeDiagnostic {
1239            message,
1240            severity: DiagnosticSeverity::Warning,
1241            span: Some(span),
1242        });
1243    }
1244}
1245
1246impl Default for TypeChecker {
1247    fn default() -> Self {
1248        Self::new()
1249    }
1250}
1251
1252/// Infer the result type of a binary operation.
1253fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1254    match op {
1255        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1256        "+" => match (left, right) {
1257            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1258                match (l.as_str(), r.as_str()) {
1259                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1260                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1261                    ("string", _) => Some(TypeExpr::Named("string".into())),
1262                    ("list", "list") => Some(TypeExpr::Named("list".into())),
1263                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1264                    _ => Some(TypeExpr::Named("string".into())),
1265                }
1266            }
1267            _ => None,
1268        },
1269        "-" | "*" | "/" | "%" => match (left, right) {
1270            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1271                match (l.as_str(), r.as_str()) {
1272                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1273                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1274                    _ => None,
1275                }
1276            }
1277            _ => None,
1278        },
1279        "??" => match (left, right) {
1280            (Some(TypeExpr::Union(members)), _) => {
1281                let non_nil: Vec<_> = members
1282                    .iter()
1283                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1284                    .cloned()
1285                    .collect();
1286                if non_nil.len() == 1 {
1287                    Some(non_nil[0].clone())
1288                } else if non_nil.is_empty() {
1289                    right.clone()
1290                } else {
1291                    Some(TypeExpr::Union(non_nil))
1292                }
1293            }
1294            _ => right.clone(),
1295        },
1296        "|>" => None,
1297        _ => None,
1298    }
1299}
1300
1301/// Format a type expression for display in error messages.
1302/// Produce a detail string describing why a Shape type is incompatible with
1303/// another Shape type — e.g. "missing field 'age' (int)" or "field 'name'
1304/// has type int, expected string".  Returns `None` if both types are not shapes.
1305pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1306    if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1307        let mut details = Vec::new();
1308        for field in ef {
1309            if field.optional {
1310                continue;
1311            }
1312            match af.iter().find(|f| f.name == field.name) {
1313                None => details.push(format!(
1314                    "missing field '{}' ({})",
1315                    field.name,
1316                    format_type(&field.type_expr)
1317                )),
1318                Some(actual_field) => {
1319                    let e_str = format_type(&field.type_expr);
1320                    let a_str = format_type(&actual_field.type_expr);
1321                    if e_str != a_str {
1322                        details.push(format!(
1323                            "field '{}' has type {}, expected {}",
1324                            field.name, a_str, e_str
1325                        ));
1326                    }
1327                }
1328            }
1329        }
1330        if details.is_empty() {
1331            None
1332        } else {
1333            Some(details.join("; "))
1334        }
1335    } else {
1336        None
1337    }
1338}
1339
1340pub fn format_type(ty: &TypeExpr) -> String {
1341    match ty {
1342        TypeExpr::Named(n) => n.clone(),
1343        TypeExpr::Union(types) => types
1344            .iter()
1345            .map(format_type)
1346            .collect::<Vec<_>>()
1347            .join(" | "),
1348        TypeExpr::Shape(fields) => {
1349            let inner: Vec<String> = fields
1350                .iter()
1351                .map(|f| {
1352                    let opt = if f.optional { "?" } else { "" };
1353                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1354                })
1355                .collect();
1356            format!("{{{}}}", inner.join(", "))
1357        }
1358        TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1359        TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1360        TypeExpr::FnType {
1361            params,
1362            return_type,
1363        } => {
1364            let params_str = params
1365                .iter()
1366                .map(format_type)
1367                .collect::<Vec<_>>()
1368                .join(", ");
1369            format!("fn({}) -> {}", params_str, format_type(return_type))
1370        }
1371    }
1372}
1373
1374#[cfg(test)]
1375mod tests {
1376    use super::*;
1377    use crate::Parser;
1378    use harn_lexer::Lexer;
1379
1380    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
1381        let mut lexer = Lexer::new(source);
1382        let tokens = lexer.tokenize().unwrap();
1383        let mut parser = Parser::new(tokens);
1384        let program = parser.parse().unwrap();
1385        TypeChecker::new().check(&program)
1386    }
1387
1388    fn errors(source: &str) -> Vec<String> {
1389        check_source(source)
1390            .into_iter()
1391            .filter(|d| d.severity == DiagnosticSeverity::Error)
1392            .map(|d| d.message)
1393            .collect()
1394    }
1395
1396    #[test]
1397    fn test_no_errors_for_untyped_code() {
1398        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
1399        assert!(errs.is_empty());
1400    }
1401
1402    #[test]
1403    fn test_correct_typed_let() {
1404        let errs = errors("pipeline t(task) { let x: int = 42 }");
1405        assert!(errs.is_empty());
1406    }
1407
1408    #[test]
1409    fn test_type_mismatch_let() {
1410        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
1411        assert_eq!(errs.len(), 1);
1412        assert!(errs[0].contains("Type mismatch"));
1413        assert!(errs[0].contains("int"));
1414        assert!(errs[0].contains("string"));
1415    }
1416
1417    #[test]
1418    fn test_correct_typed_fn() {
1419        let errs = errors(
1420            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
1421        );
1422        assert!(errs.is_empty());
1423    }
1424
1425    #[test]
1426    fn test_fn_arg_type_mismatch() {
1427        let errs = errors(
1428            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
1429add("hello", 2) }"#,
1430        );
1431        assert_eq!(errs.len(), 1);
1432        assert!(errs[0].contains("Argument 1"));
1433        assert!(errs[0].contains("expected int"));
1434    }
1435
1436    #[test]
1437    fn test_return_type_mismatch() {
1438        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
1439        assert_eq!(errs.len(), 1);
1440        assert!(errs[0].contains("Return type mismatch"));
1441    }
1442
1443    #[test]
1444    fn test_union_type_compatible() {
1445        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
1446        assert!(errs.is_empty());
1447    }
1448
1449    #[test]
1450    fn test_union_type_mismatch() {
1451        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
1452        assert_eq!(errs.len(), 1);
1453        assert!(errs[0].contains("Type mismatch"));
1454    }
1455
1456    #[test]
1457    fn test_type_inference_propagation() {
1458        let errs = errors(
1459            r#"pipeline t(task) {
1460  fn add(a: int, b: int) -> int { return a + b }
1461  let result: string = add(1, 2)
1462}"#,
1463        );
1464        assert_eq!(errs.len(), 1);
1465        assert!(errs[0].contains("Type mismatch"));
1466        assert!(errs[0].contains("string"));
1467        assert!(errs[0].contains("int"));
1468    }
1469
1470    #[test]
1471    fn test_builtin_return_type_inference() {
1472        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
1473        assert_eq!(errs.len(), 1);
1474        assert!(errs[0].contains("string"));
1475        assert!(errs[0].contains("int"));
1476    }
1477
1478    #[test]
1479    fn test_binary_op_type_inference() {
1480        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
1481        assert_eq!(errs.len(), 1);
1482    }
1483
1484    #[test]
1485    fn test_comparison_returns_bool() {
1486        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
1487        assert!(errs.is_empty());
1488    }
1489
1490    #[test]
1491    fn test_int_float_promotion() {
1492        let errs = errors("pipeline t(task) { let x: float = 42 }");
1493        assert!(errs.is_empty());
1494    }
1495
1496    #[test]
1497    fn test_untyped_code_no_errors() {
1498        let errs = errors(
1499            r#"pipeline t(task) {
1500  fn process(data) {
1501    let result = data + " processed"
1502    return result
1503  }
1504  log(process("hello"))
1505}"#,
1506        );
1507        assert!(errs.is_empty());
1508    }
1509
1510    #[test]
1511    fn test_type_alias() {
1512        let errs = errors(
1513            r#"pipeline t(task) {
1514  type Name = string
1515  let x: Name = "hello"
1516}"#,
1517        );
1518        assert!(errs.is_empty());
1519    }
1520
1521    #[test]
1522    fn test_type_alias_mismatch() {
1523        let errs = errors(
1524            r#"pipeline t(task) {
1525  type Name = string
1526  let x: Name = 42
1527}"#,
1528        );
1529        assert_eq!(errs.len(), 1);
1530    }
1531
1532    #[test]
1533    fn test_assignment_type_check() {
1534        let errs = errors(
1535            r#"pipeline t(task) {
1536  var x: int = 0
1537  x = "hello"
1538}"#,
1539        );
1540        assert_eq!(errs.len(), 1);
1541        assert!(errs[0].contains("cannot assign string"));
1542    }
1543
1544    #[test]
1545    fn test_covariance_int_to_float_in_fn() {
1546        let errs = errors(
1547            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
1548        );
1549        assert!(errs.is_empty());
1550    }
1551
1552    #[test]
1553    fn test_covariance_return_type() {
1554        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
1555        assert!(errs.is_empty());
1556    }
1557
1558    #[test]
1559    fn test_no_contravariance_float_to_int() {
1560        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
1561        assert_eq!(errs.len(), 1);
1562    }
1563
1564    // --- Exhaustiveness checking tests ---
1565
1566    fn warnings(source: &str) -> Vec<String> {
1567        check_source(source)
1568            .into_iter()
1569            .filter(|d| d.severity == DiagnosticSeverity::Warning)
1570            .map(|d| d.message)
1571            .collect()
1572    }
1573
1574    #[test]
1575    fn test_exhaustive_match_no_warning() {
1576        let warns = warnings(
1577            r#"pipeline t(task) {
1578  enum Color { Red, Green, Blue }
1579  let c = Color.Red
1580  match c.variant {
1581    "Red" -> { log("r") }
1582    "Green" -> { log("g") }
1583    "Blue" -> { log("b") }
1584  }
1585}"#,
1586        );
1587        let exhaustive_warns: Vec<_> = warns
1588            .iter()
1589            .filter(|w| w.contains("Non-exhaustive"))
1590            .collect();
1591        assert!(exhaustive_warns.is_empty());
1592    }
1593
1594    #[test]
1595    fn test_non_exhaustive_match_warning() {
1596        let warns = warnings(
1597            r#"pipeline t(task) {
1598  enum Color { Red, Green, Blue }
1599  let c = Color.Red
1600  match c.variant {
1601    "Red" -> { log("r") }
1602    "Green" -> { log("g") }
1603  }
1604}"#,
1605        );
1606        let exhaustive_warns: Vec<_> = warns
1607            .iter()
1608            .filter(|w| w.contains("Non-exhaustive"))
1609            .collect();
1610        assert_eq!(exhaustive_warns.len(), 1);
1611        assert!(exhaustive_warns[0].contains("Blue"));
1612    }
1613
1614    #[test]
1615    fn test_non_exhaustive_multiple_missing() {
1616        let warns = warnings(
1617            r#"pipeline t(task) {
1618  enum Status { Active, Inactive, Pending }
1619  let s = Status.Active
1620  match s.variant {
1621    "Active" -> { log("a") }
1622  }
1623}"#,
1624        );
1625        let exhaustive_warns: Vec<_> = warns
1626            .iter()
1627            .filter(|w| w.contains("Non-exhaustive"))
1628            .collect();
1629        assert_eq!(exhaustive_warns.len(), 1);
1630        assert!(exhaustive_warns[0].contains("Inactive"));
1631        assert!(exhaustive_warns[0].contains("Pending"));
1632    }
1633
1634    #[test]
1635    fn test_enum_construct_type_inference() {
1636        let errs = errors(
1637            r#"pipeline t(task) {
1638  enum Color { Red, Green, Blue }
1639  let c: Color = Color.Red
1640}"#,
1641        );
1642        assert!(errs.is_empty());
1643    }
1644
    // --- Type narrowing and shape-mismatch diagnostic tests ---
1646
1647    #[test]
1648    fn test_nil_coalescing_strips_nil() {
1649        // After ??, nil should be stripped from the type
1650        let errs = errors(
1651            r#"pipeline t(task) {
1652  let x: string | nil = nil
1653  let y: string = x ?? "default"
1654}"#,
1655        );
1656        assert!(errs.is_empty());
1657    }
1658
1659    #[test]
1660    fn test_shape_mismatch_detail_missing_field() {
1661        let errs = errors(
1662            r#"pipeline t(task) {
1663  let x: {name: string, age: int} = {name: "hello"}
1664}"#,
1665        );
1666        assert_eq!(errs.len(), 1);
1667        assert!(
1668            errs[0].contains("missing field 'age'"),
1669            "expected detail about missing field, got: {}",
1670            errs[0]
1671        );
1672    }
1673
1674    #[test]
1675    fn test_shape_mismatch_detail_wrong_type() {
1676        let errs = errors(
1677            r#"pipeline t(task) {
1678  let x: {name: string, age: int} = {name: 42, age: 10}
1679}"#,
1680        );
1681        assert_eq!(errs.len(), 1);
1682        assert!(
1683            errs[0].contains("field 'name' has type int, expected string"),
1684            "expected detail about wrong type, got: {}",
1685            errs[0]
1686        );
1687    }
1688
1689    // --- Match pattern type validation tests ---
1690
1691    #[test]
1692    fn test_match_pattern_string_against_int() {
1693        let warns = warnings(
1694            r#"pipeline t(task) {
1695  let x: int = 42
1696  match x {
1697    "hello" -> { log("bad") }
1698    42 -> { log("ok") }
1699  }
1700}"#,
1701        );
1702        let pattern_warns: Vec<_> = warns
1703            .iter()
1704            .filter(|w| w.contains("Match pattern type mismatch"))
1705            .collect();
1706        assert_eq!(pattern_warns.len(), 1);
1707        assert!(pattern_warns[0].contains("matching int against string literal"));
1708    }
1709
1710    #[test]
1711    fn test_match_pattern_int_against_string() {
1712        let warns = warnings(
1713            r#"pipeline t(task) {
1714  let x: string = "hello"
1715  match x {
1716    42 -> { log("bad") }
1717    "hello" -> { log("ok") }
1718  }
1719}"#,
1720        );
1721        let pattern_warns: Vec<_> = warns
1722            .iter()
1723            .filter(|w| w.contains("Match pattern type mismatch"))
1724            .collect();
1725        assert_eq!(pattern_warns.len(), 1);
1726        assert!(pattern_warns[0].contains("matching string against int literal"));
1727    }
1728
1729    #[test]
1730    fn test_match_pattern_bool_against_int() {
1731        let warns = warnings(
1732            r#"pipeline t(task) {
1733  let x: int = 42
1734  match x {
1735    true -> { log("bad") }
1736    42 -> { log("ok") }
1737  }
1738}"#,
1739        );
1740        let pattern_warns: Vec<_> = warns
1741            .iter()
1742            .filter(|w| w.contains("Match pattern type mismatch"))
1743            .collect();
1744        assert_eq!(pattern_warns.len(), 1);
1745        assert!(pattern_warns[0].contains("matching int against bool literal"));
1746    }
1747
1748    #[test]
1749    fn test_match_pattern_float_against_string() {
1750        let warns = warnings(
1751            r#"pipeline t(task) {
1752  let x: string = "hello"
1753  match x {
1754    3.14 -> { log("bad") }
1755    "hello" -> { log("ok") }
1756  }
1757}"#,
1758        );
1759        let pattern_warns: Vec<_> = warns
1760            .iter()
1761            .filter(|w| w.contains("Match pattern type mismatch"))
1762            .collect();
1763        assert_eq!(pattern_warns.len(), 1);
1764        assert!(pattern_warns[0].contains("matching string against float literal"));
1765    }
1766
1767    #[test]
1768    fn test_match_pattern_int_against_float_ok() {
1769        // int and float are compatible for match patterns
1770        let warns = warnings(
1771            r#"pipeline t(task) {
1772  let x: float = 3.14
1773  match x {
1774    42 -> { log("ok") }
1775    _ -> { log("default") }
1776  }
1777}"#,
1778        );
1779        let pattern_warns: Vec<_> = warns
1780            .iter()
1781            .filter(|w| w.contains("Match pattern type mismatch"))
1782            .collect();
1783        assert!(pattern_warns.is_empty());
1784    }
1785
1786    #[test]
1787    fn test_match_pattern_float_against_int_ok() {
1788        // float and int are compatible for match patterns
1789        let warns = warnings(
1790            r#"pipeline t(task) {
1791  let x: int = 42
1792  match x {
1793    3.14 -> { log("close") }
1794    _ -> { log("default") }
1795  }
1796}"#,
1797        );
1798        let pattern_warns: Vec<_> = warns
1799            .iter()
1800            .filter(|w| w.contains("Match pattern type mismatch"))
1801            .collect();
1802        assert!(pattern_warns.is_empty());
1803    }
1804
1805    #[test]
1806    fn test_match_pattern_correct_types_no_warning() {
1807        let warns = warnings(
1808            r#"pipeline t(task) {
1809  let x: int = 42
1810  match x {
1811    1 -> { log("one") }
1812    2 -> { log("two") }
1813    _ -> { log("other") }
1814  }
1815}"#,
1816        );
1817        let pattern_warns: Vec<_> = warns
1818            .iter()
1819            .filter(|w| w.contains("Match pattern type mismatch"))
1820            .collect();
1821        assert!(pattern_warns.is_empty());
1822    }
1823
1824    #[test]
1825    fn test_match_pattern_wildcard_no_warning() {
1826        let warns = warnings(
1827            r#"pipeline t(task) {
1828  let x: int = 42
1829  match x {
1830    _ -> { log("catch all") }
1831  }
1832}"#,
1833        );
1834        let pattern_warns: Vec<_> = warns
1835            .iter()
1836            .filter(|w| w.contains("Match pattern type mismatch"))
1837            .collect();
1838        assert!(pattern_warns.is_empty());
1839    }
1840
1841    #[test]
1842    fn test_match_pattern_untyped_no_warning() {
1843        // When value has no known type, no warning should be emitted
1844        let warns = warnings(
1845            r#"pipeline t(task) {
1846  let x = some_unknown_fn()
1847  match x {
1848    "hello" -> { log("string") }
1849    42 -> { log("int") }
1850  }
1851}"#,
1852        );
1853        let pattern_warns: Vec<_> = warns
1854            .iter()
1855            .filter(|w| w.contains("Match pattern type mismatch"))
1856            .collect();
1857        assert!(pattern_warns.is_empty());
1858    }
1859}