Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6/// A diagnostic produced by the type checker.
7#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9    pub message: String,
10    pub severity: DiagnosticSeverity,
11    pub span: Option<Span>,
12}
13
/// How serious a type-checker diagnostic is.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticSeverity {
    /// Used for findings such as declared-type and return-type mismatches.
    Error,
    /// Used for advisory findings such as arity, operator and exhaustiveness issues.
    Warning,
}
19
20/// Inferred type of an expression. None means unknown/untyped (gradual typing).
21type InferredType = Option<TypeExpr>;
22
23/// Scope for tracking variable types.
24#[derive(Debug, Clone)]
25struct TypeScope {
26    /// Variable name → inferred type.
27    vars: BTreeMap<String, InferredType>,
28    /// Function name → (param types, return type).
29    functions: BTreeMap<String, FnSignature>,
30    /// Named type aliases.
31    type_aliases: BTreeMap<String, TypeExpr>,
32    /// Enum declarations: name → variant names.
33    enums: BTreeMap<String, Vec<String>>,
34    /// Interface declarations: name → method signatures.
35    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36    /// Struct declarations: name → field types.
37    structs: BTreeMap<String, Vec<(String, InferredType)>>,
38    /// Generic type parameter names in scope (treated as compatible with any type).
39    generic_type_params: std::collections::BTreeSet<String>,
40    parent: Option<Box<TypeScope>>,
41}
42
43#[derive(Debug, Clone)]
44struct FnSignature {
45    params: Vec<(String, InferredType)>,
46    return_type: InferredType,
47    /// Generic type parameter names declared on the function.
48    type_param_names: Vec<String>,
49}
50
51impl TypeScope {
52    fn new() -> Self {
53        Self {
54            vars: BTreeMap::new(),
55            functions: BTreeMap::new(),
56            type_aliases: BTreeMap::new(),
57            enums: BTreeMap::new(),
58            interfaces: BTreeMap::new(),
59            structs: BTreeMap::new(),
60            generic_type_params: std::collections::BTreeSet::new(),
61            parent: None,
62        }
63    }
64
65    fn child(&self) -> Self {
66        Self {
67            vars: BTreeMap::new(),
68            functions: BTreeMap::new(),
69            type_aliases: BTreeMap::new(),
70            enums: BTreeMap::new(),
71            interfaces: BTreeMap::new(),
72            structs: BTreeMap::new(),
73            generic_type_params: std::collections::BTreeSet::new(),
74            parent: Some(Box::new(self.clone())),
75        }
76    }
77
78    fn get_var(&self, name: &str) -> Option<&InferredType> {
79        self.vars
80            .get(name)
81            .or_else(|| self.parent.as_ref()?.get_var(name))
82    }
83
84    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
85        self.functions
86            .get(name)
87            .or_else(|| self.parent.as_ref()?.get_fn(name))
88    }
89
90    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
91        self.type_aliases
92            .get(name)
93            .or_else(|| self.parent.as_ref()?.resolve_type(name))
94    }
95
96    fn is_generic_type_param(&self, name: &str) -> bool {
97        self.generic_type_params.contains(name)
98            || self
99                .parent
100                .as_ref()
101                .is_some_and(|p| p.is_generic_type_param(name))
102    }
103
104    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
105        self.enums
106            .get(name)
107            .or_else(|| self.parent.as_ref()?.get_enum(name))
108    }
109
110    #[allow(dead_code)]
111    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
112        self.interfaces
113            .get(name)
114            .or_else(|| self.parent.as_ref()?.get_interface(name))
115    }
116
117    fn define_var(&mut self, name: &str, ty: InferredType) {
118        self.vars.insert(name.to_string(), ty);
119    }
120
121    fn define_fn(&mut self, name: &str, sig: FnSignature) {
122        self.functions.insert(name.to_string(), sig);
123    }
124}
125
126/// Known return types for builtin functions.
127fn builtin_return_type(name: &str) -> InferredType {
128    match name {
129        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
130        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
131            Some(TypeExpr::Named("nil".into()))
132        }
133        "type_of" | "to_string" | "json_stringify" | "read_file" | "http_get" | "http_post"
134        | "llm_call" | "agent_loop" | "regex_replace" | "path_join" | "temp_dir"
135        | "date_format" | "format" => Some(TypeExpr::Named("string".into())),
136        "to_int" => Some(TypeExpr::Named("int".into())),
137        "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
138        "file_exists" | "json_validate" => Some(TypeExpr::Named("bool".into())),
139        "list_dir" => Some(TypeExpr::Named("list".into())),
140        "stat" | "exec" | "shell" | "date_now" => Some(TypeExpr::Named("dict".into())),
141        "env" | "regex_match" => Some(TypeExpr::Union(vec![
142            TypeExpr::Named("string".into()),
143            TypeExpr::Named("nil".into()),
144        ])),
145        "json_parse" | "json_extract" => None, // could be any type
146        _ => None,
147    }
148}
149
/// Returns `true` when `name` is one of the runtime's builtin functions.
///
/// The match is exact and case-sensitive.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        "log",
        "print",
        "println",
        "type_of",
        "to_string",
        "to_int",
        "to_float",
        "json_stringify",
        "json_parse",
        "env",
        "timestamp",
        "sleep",
        "read_file",
        "write_file",
        "exit",
        "regex_match",
        "regex_replace",
        "http_get",
        "http_post",
        "llm_call",
        "agent_loop",
        "await",
        "cancel",
        "file_exists",
        "delete_file",
        "list_dir",
        "mkdir",
        "path_join",
        "copy_file",
        "append_file",
        "temp_dir",
        "stat",
        "exec",
        "shell",
        "date_now",
        "date_format",
        "date_parse",
        "format",
        "json_validate",
        "json_extract",
        "trim",
        "lowercase",
        "uppercase",
        "split",
        "starts_with",
        "ends_with",
        "contains",
        "replace",
        "join",
        "len",
        "substring",
        "dirname",
        "basename",
        "extname",
    ];
    BUILTINS.contains(&name)
}
210
211/// The static type checker.
212pub struct TypeChecker {
213    diagnostics: Vec<TypeDiagnostic>,
214    scope: TypeScope,
215}
216
217impl TypeChecker {
218    pub fn new() -> Self {
219        Self {
220            diagnostics: Vec::new(),
221            scope: TypeScope::new(),
222        }
223    }
224
225    /// Check a program and return diagnostics.
226    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
227        // First pass: register type and enum declarations into root scope
228        Self::register_declarations_into(&mut self.scope, program);
229
230        // Also scan pipeline bodies for declarations
231        for snode in program {
232            if let Node::Pipeline { body, .. } = &snode.node {
233                Self::register_declarations_into(&mut self.scope, body);
234            }
235        }
236
237        // Check each top-level node
238        for snode in program {
239            match &snode.node {
240                Node::Pipeline { params, body, .. } => {
241                    let mut child = self.scope.child();
242                    for p in params {
243                        child.define_var(p, None);
244                    }
245                    self.check_block(body, &mut child);
246                }
247                Node::FnDecl {
248                    name,
249                    type_params,
250                    params,
251                    return_type,
252                    body,
253                    ..
254                } => {
255                    let sig = FnSignature {
256                        params: params
257                            .iter()
258                            .map(|p| (p.name.clone(), p.type_expr.clone()))
259                            .collect(),
260                        return_type: return_type.clone(),
261                        type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
262                    };
263                    self.scope.define_fn(name, sig);
264                    self.check_fn_body(type_params, params, return_type, body);
265                }
266                _ => {
267                    let mut scope = self.scope.clone();
268                    self.check_node(snode, &mut scope);
269                    // Merge any new definitions back into the top-level scope
270                    for (name, ty) in scope.vars {
271                        self.scope.vars.entry(name).or_insert(ty);
272                    }
273                }
274            }
275        }
276
277        self.diagnostics
278    }
279
280    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
281    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
282        for snode in nodes {
283            match &snode.node {
284                Node::TypeDecl { name, type_expr } => {
285                    scope.type_aliases.insert(name.clone(), type_expr.clone());
286                }
287                Node::EnumDecl { name, variants } => {
288                    let variant_names: Vec<String> =
289                        variants.iter().map(|v| v.name.clone()).collect();
290                    scope.enums.insert(name.clone(), variant_names);
291                }
292                Node::InterfaceDecl { name, methods } => {
293                    scope.interfaces.insert(name.clone(), methods.clone());
294                }
295                Node::StructDecl { name, fields } => {
296                    let field_types: Vec<(String, InferredType)> = fields
297                        .iter()
298                        .map(|f| (f.name.clone(), f.type_expr.clone()))
299                        .collect();
300                    scope.structs.insert(name.clone(), field_types);
301                }
302                _ => {}
303            }
304        }
305    }
306
307    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
308        for stmt in stmts {
309            self.check_node(stmt, scope);
310        }
311    }
312
313    /// Define variables from a destructuring pattern in the given scope (as unknown type).
314    fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
315        match pattern {
316            BindingPattern::Identifier(name) => {
317                scope.define_var(name, None);
318            }
319            BindingPattern::Dict(fields) => {
320                for field in fields {
321                    let name = field.alias.as_deref().unwrap_or(&field.key);
322                    scope.define_var(name, None);
323                }
324            }
325            BindingPattern::List(elements) => {
326                for elem in elements {
327                    scope.define_var(&elem.name, None);
328                }
329            }
330        }
331    }
332
333    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
334        let span = snode.span;
335        match &snode.node {
336            Node::LetBinding {
337                pattern,
338                type_ann,
339                value,
340            } => {
341                let inferred = self.infer_type(value, scope);
342                if let BindingPattern::Identifier(name) = pattern {
343                    if let Some(expected) = type_ann {
344                        if let Some(actual) = &inferred {
345                            if !self.types_compatible(expected, actual, scope) {
346                                let mut msg = format!(
347                                    "Type mismatch: '{}' declared as {}, but assigned {}",
348                                    name,
349                                    format_type(expected),
350                                    format_type(actual)
351                                );
352                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
353                                    msg.push_str(&format!(" ({})", detail));
354                                }
355                                self.error_at(msg, span);
356                            }
357                        }
358                    }
359                    let ty = type_ann.clone().or(inferred);
360                    scope.define_var(name, ty);
361                } else {
362                    Self::define_pattern_vars(pattern, scope);
363                }
364            }
365
366            Node::VarBinding {
367                pattern,
368                type_ann,
369                value,
370            } => {
371                let inferred = self.infer_type(value, scope);
372                if let BindingPattern::Identifier(name) = pattern {
373                    if let Some(expected) = type_ann {
374                        if let Some(actual) = &inferred {
375                            if !self.types_compatible(expected, actual, scope) {
376                                let mut msg = format!(
377                                    "Type mismatch: '{}' declared as {}, but assigned {}",
378                                    name,
379                                    format_type(expected),
380                                    format_type(actual)
381                                );
382                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
383                                    msg.push_str(&format!(" ({})", detail));
384                                }
385                                self.error_at(msg, span);
386                            }
387                        }
388                    }
389                    let ty = type_ann.clone().or(inferred);
390                    scope.define_var(name, ty);
391                } else {
392                    Self::define_pattern_vars(pattern, scope);
393                }
394            }
395
396            Node::FnDecl {
397                name,
398                type_params,
399                params,
400                return_type,
401                body,
402                ..
403            } => {
404                let sig = FnSignature {
405                    params: params
406                        .iter()
407                        .map(|p| (p.name.clone(), p.type_expr.clone()))
408                        .collect(),
409                    return_type: return_type.clone(),
410                    type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
411                };
412                scope.define_fn(name, sig.clone());
413                scope.define_var(name, None);
414                self.check_fn_body(type_params, params, return_type, body);
415            }
416
417            Node::FunctionCall { name, args } => {
418                self.check_call(name, args, scope, span);
419            }
420
421            Node::IfElse {
422                condition,
423                then_body,
424                else_body,
425            } => {
426                self.check_node(condition, scope);
427                let mut then_scope = scope.child();
428                self.check_block(then_body, &mut then_scope);
429                if let Some(else_body) = else_body {
430                    let mut else_scope = scope.child();
431                    self.check_block(else_body, &mut else_scope);
432                }
433            }
434
435            Node::ForIn {
436                pattern,
437                iterable,
438                body,
439            } => {
440                self.check_node(iterable, scope);
441                let mut loop_scope = scope.child();
442                if let BindingPattern::Identifier(variable) = pattern {
443                    // Infer loop variable type from iterable
444                    let elem_type = match self.infer_type(iterable, scope) {
445                        Some(TypeExpr::List(inner)) => Some(*inner),
446                        Some(TypeExpr::Named(n)) if n == "string" => {
447                            Some(TypeExpr::Named("string".into()))
448                        }
449                        _ => None,
450                    };
451                    loop_scope.define_var(variable, elem_type);
452                } else {
453                    Self::define_pattern_vars(pattern, &mut loop_scope);
454                }
455                self.check_block(body, &mut loop_scope);
456            }
457
458            Node::WhileLoop { condition, body } => {
459                self.check_node(condition, scope);
460                let mut loop_scope = scope.child();
461                self.check_block(body, &mut loop_scope);
462            }
463
464            Node::TryCatch {
465                body,
466                error_var,
467                catch_body,
468                ..
469            } => {
470                let mut try_scope = scope.child();
471                self.check_block(body, &mut try_scope);
472                let mut catch_scope = scope.child();
473                if let Some(var) = error_var {
474                    catch_scope.define_var(var, None);
475                }
476                self.check_block(catch_body, &mut catch_scope);
477            }
478
479            Node::ReturnStmt {
480                value: Some(val), ..
481            } => {
482                self.check_node(val, scope);
483            }
484
485            Node::Assignment {
486                target, value, op, ..
487            } => {
488                self.check_node(value, scope);
489                if let Node::Identifier(name) = &target.node {
490                    if let Some(Some(var_type)) = scope.get_var(name) {
491                        let value_type = self.infer_type(value, scope);
492                        let assigned = if let Some(op) = op {
493                            let var_inferred = scope.get_var(name).cloned().flatten();
494                            infer_binary_op_type(op, &var_inferred, &value_type)
495                        } else {
496                            value_type
497                        };
498                        if let Some(actual) = &assigned {
499                            if !self.types_compatible(var_type, actual, scope) {
500                                self.error_at(
501                                    format!(
502                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
503                                        format_type(actual),
504                                        name,
505                                        format_type(var_type)
506                                    ),
507                                    span,
508                                );
509                            }
510                        }
511                    }
512                }
513            }
514
515            Node::TypeDecl { name, type_expr } => {
516                scope.type_aliases.insert(name.clone(), type_expr.clone());
517            }
518
519            Node::EnumDecl { name, variants } => {
520                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
521                scope.enums.insert(name.clone(), variant_names);
522            }
523
524            Node::StructDecl { name, fields } => {
525                let field_types: Vec<(String, InferredType)> = fields
526                    .iter()
527                    .map(|f| (f.name.clone(), f.type_expr.clone()))
528                    .collect();
529                scope.structs.insert(name.clone(), field_types);
530            }
531
532            Node::InterfaceDecl { name, methods } => {
533                scope.interfaces.insert(name.clone(), methods.clone());
534            }
535
536            Node::MatchExpr { value, arms } => {
537                self.check_node(value, scope);
538                let value_type = self.infer_type(value, scope);
539                for arm in arms {
540                    self.check_node(&arm.pattern, scope);
541                    // Check for incompatible literal pattern types
542                    if let Some(ref vt) = value_type {
543                        let value_type_name = format_type(vt);
544                        let mismatch = match &arm.pattern.node {
545                            Node::StringLiteral(_) => {
546                                !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
547                            }
548                            Node::IntLiteral(_) => {
549                                !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
550                                    && !self.types_compatible(
551                                        vt,
552                                        &TypeExpr::Named("float".into()),
553                                        scope,
554                                    )
555                            }
556                            Node::FloatLiteral(_) => {
557                                !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
558                                    && !self.types_compatible(
559                                        vt,
560                                        &TypeExpr::Named("int".into()),
561                                        scope,
562                                    )
563                            }
564                            Node::BoolLiteral(_) => {
565                                !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
566                            }
567                            _ => false,
568                        };
569                        if mismatch {
570                            let pattern_type = match &arm.pattern.node {
571                                Node::StringLiteral(_) => "string",
572                                Node::IntLiteral(_) => "int",
573                                Node::FloatLiteral(_) => "float",
574                                Node::BoolLiteral(_) => "bool",
575                                _ => unreachable!(),
576                            };
577                            self.warning_at(
578                                format!(
579                                    "Match pattern type mismatch: matching {} against {} literal",
580                                    value_type_name, pattern_type
581                                ),
582                                arm.pattern.span,
583                            );
584                        }
585                    }
586                    let mut arm_scope = scope.child();
587                    self.check_block(&arm.body, &mut arm_scope);
588                }
589                self.check_match_exhaustiveness(value, arms, scope, span);
590            }
591
592            // Recurse into nested expressions + validate binary op types
593            Node::BinaryOp { op, left, right } => {
594                self.check_node(left, scope);
595                self.check_node(right, scope);
596                // Validate operator/type compatibility
597                let lt = self.infer_type(left, scope);
598                let rt = self.infer_type(right, scope);
599                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
600                    match op.as_str() {
601                        "-" | "*" | "/" | "%" => {
602                            let numeric = ["int", "float"];
603                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
604                                self.warning_at(
605                                    format!(
606                                        "Operator '{op}' may not be valid for types {} and {}",
607                                        l, r
608                                    ),
609                                    span,
610                                );
611                            }
612                        }
613                        "+" => {
614                            // + is valid for int, float, string, list, dict
615                            let valid = ["int", "float", "string", "list", "dict"];
616                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
617                                self.warning_at(
618                                    format!(
619                                        "Operator '+' may not be valid for types {} and {}",
620                                        l, r
621                                    ),
622                                    span,
623                                );
624                            }
625                        }
626                        _ => {}
627                    }
628                }
629            }
630            Node::UnaryOp { operand, .. } => {
631                self.check_node(operand, scope);
632            }
633            Node::MethodCall { object, args, .. }
634            | Node::OptionalMethodCall { object, args, .. } => {
635                self.check_node(object, scope);
636                for arg in args {
637                    self.check_node(arg, scope);
638                }
639            }
640            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
641                self.check_node(object, scope);
642            }
643            Node::SubscriptAccess { object, index } => {
644                self.check_node(object, scope);
645                self.check_node(index, scope);
646            }
647            Node::SliceAccess { object, start, end } => {
648                self.check_node(object, scope);
649                if let Some(s) = start {
650                    self.check_node(s, scope);
651                }
652                if let Some(e) = end {
653                    self.check_node(e, scope);
654                }
655            }
656
657            // Terminals — nothing to check
658            _ => {}
659        }
660    }
661
662    fn check_fn_body(
663        &mut self,
664        type_params: &[TypeParam],
665        params: &[TypedParam],
666        return_type: &Option<TypeExpr>,
667        body: &[SNode],
668    ) {
669        let mut fn_scope = self.scope.child();
670        // Register generic type parameters so they are treated as compatible
671        // with any concrete type during type checking.
672        for tp in type_params {
673            fn_scope.generic_type_params.insert(tp.name.clone());
674        }
675        for param in params {
676            fn_scope.define_var(&param.name, param.type_expr.clone());
677        }
678        self.check_block(body, &mut fn_scope);
679
680        // Check return statements against declared return type
681        if let Some(ret_type) = return_type {
682            for stmt in body {
683                self.check_return_type(stmt, ret_type, &fn_scope);
684            }
685        }
686    }
687
688    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
689        let span = snode.span;
690        match &snode.node {
691            Node::ReturnStmt { value: Some(val) } => {
692                let inferred = self.infer_type(val, scope);
693                if let Some(actual) = &inferred {
694                    if !self.types_compatible(expected, actual, scope) {
695                        self.error_at(
696                            format!(
697                                "Return type mismatch: expected {}, got {}",
698                                format_type(expected),
699                                format_type(actual)
700                            ),
701                            span,
702                        );
703                    }
704                }
705            }
706            Node::IfElse {
707                then_body,
708                else_body,
709                ..
710            } => {
711                for stmt in then_body {
712                    self.check_return_type(stmt, expected, scope);
713                }
714                if let Some(else_body) = else_body {
715                    for stmt in else_body {
716                        self.check_return_type(stmt, expected, scope);
717                    }
718                }
719            }
720            _ => {}
721        }
722    }
723
724    /// Check if a match expression on an enum's `.variant` property covers all variants.
725    fn check_match_exhaustiveness(
726        &mut self,
727        value: &SNode,
728        arms: &[MatchArm],
729        scope: &TypeScope,
730        span: Span,
731    ) {
732        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
733        let enum_name = match &value.node {
734            Node::PropertyAccess { object, property } if property == "variant" => {
735                // Infer the type of the object
736                match self.infer_type(object, scope) {
737                    Some(TypeExpr::Named(name)) => {
738                        if scope.get_enum(&name).is_some() {
739                            Some(name)
740                        } else {
741                            None
742                        }
743                    }
744                    _ => None,
745                }
746            }
747            _ => {
748                // Direct match on an enum value: match <expr> { ... }
749                match self.infer_type(value, scope) {
750                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
751                    _ => None,
752                }
753            }
754        };
755
756        let Some(enum_name) = enum_name else {
757            return;
758        };
759        let Some(variants) = scope.get_enum(&enum_name) else {
760            return;
761        };
762
763        // Collect variant names covered by match arms
764        let mut covered: Vec<String> = Vec::new();
765        let mut has_wildcard = false;
766
767        for arm in arms {
768            match &arm.pattern.node {
769                // String literal pattern (matching on .variant): "VariantA"
770                Node::StringLiteral(s) => covered.push(s.clone()),
771                // Identifier pattern acts as a wildcard/catch-all
772                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
773                    has_wildcard = true;
774                }
775                // Direct enum construct pattern: EnumName.Variant
776                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
777                // PropertyAccess pattern: EnumName.Variant (no args)
778                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
779                _ => {
780                    // Unknown pattern shape — conservatively treat as wildcard
781                    has_wildcard = true;
782                }
783            }
784        }
785
786        if has_wildcard {
787            return;
788        }
789
790        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
791        if !missing.is_empty() {
792            let missing_str = missing
793                .iter()
794                .map(|s| format!("\"{}\"", s))
795                .collect::<Vec<_>>()
796                .join(", ");
797            self.warning_at(
798                format!(
799                    "Non-exhaustive match on enum {}: missing variants {}",
800                    enum_name, missing_str
801                ),
802                span,
803            );
804        }
805    }
806
807    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
808        // Check against known function signatures
809        if let Some(sig) = scope.get_fn(name).cloned() {
810            if args.len() != sig.params.len() && !is_builtin(name) {
811                self.warning_at(
812                    format!(
813                        "Function '{}' expects {} arguments, got {}",
814                        name,
815                        sig.params.len(),
816                        args.len()
817                    ),
818                    span,
819                );
820            }
821            // Build a scope that includes the function's generic type params
822            // so they are treated as compatible with any concrete type.
823            let call_scope = if sig.type_param_names.is_empty() {
824                scope.clone()
825            } else {
826                let mut s = scope.child();
827                for tp_name in &sig.type_param_names {
828                    s.generic_type_params.insert(tp_name.clone());
829                }
830                s
831            };
832            for (i, (arg, (param_name, param_type))) in
833                args.iter().zip(sig.params.iter()).enumerate()
834            {
835                if let Some(expected) = param_type {
836                    let actual = self.infer_type(arg, scope);
837                    if let Some(actual) = &actual {
838                        if !self.types_compatible(expected, actual, &call_scope) {
839                            self.error_at(
840                                format!(
841                                    "Argument {} ('{}'): expected {}, got {}",
842                                    i + 1,
843                                    param_name,
844                                    format_type(expected),
845                                    format_type(actual)
846                                ),
847                                arg.span,
848                            );
849                        }
850                    }
851                }
852            }
853        }
854        // Check args recursively
855        for arg in args {
856            self.check_node(arg, scope);
857        }
858    }
859
860    /// Infer the type of an expression.
861    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
862        match &snode.node {
863            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
864            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
865            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
866                Some(TypeExpr::Named("string".into()))
867            }
868            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
869            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
870            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
871            Node::DictLiteral(entries) => {
872                // Infer shape type when all keys are string literals
873                let mut fields = Vec::new();
874                let mut all_string_keys = true;
875                for entry in entries {
876                    if let Node::StringLiteral(key) = &entry.key.node {
877                        let val_type = self
878                            .infer_type(&entry.value, scope)
879                            .unwrap_or(TypeExpr::Named("nil".into()));
880                        fields.push(ShapeField {
881                            name: key.clone(),
882                            type_expr: val_type,
883                            optional: false,
884                        });
885                    } else {
886                        all_string_keys = false;
887                        break;
888                    }
889                }
890                if all_string_keys && !fields.is_empty() {
891                    Some(TypeExpr::Shape(fields))
892                } else {
893                    Some(TypeExpr::Named("dict".into()))
894                }
895            }
896            Node::Closure { params, body } => {
897                // If all params are typed and we can infer a return type, produce FnType
898                let all_typed = params.iter().all(|p| p.type_expr.is_some());
899                if all_typed && !params.is_empty() {
900                    let param_types: Vec<TypeExpr> =
901                        params.iter().filter_map(|p| p.type_expr.clone()).collect();
902                    // Try to infer return type from last expression in body
903                    let ret = body.last().and_then(|last| self.infer_type(last, scope));
904                    if let Some(ret_type) = ret {
905                        return Some(TypeExpr::FnType {
906                            params: param_types,
907                            return_type: Box::new(ret_type),
908                        });
909                    }
910                }
911                Some(TypeExpr::Named("closure".into()))
912            }
913
914            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
915
916            Node::FunctionCall { name, .. } => {
917                // Check user-defined function return types
918                if let Some(sig) = scope.get_fn(name) {
919                    return sig.return_type.clone();
920                }
921                // Check builtin return types
922                builtin_return_type(name)
923            }
924
925            Node::BinaryOp { op, left, right } => {
926                let lt = self.infer_type(left, scope);
927                let rt = self.infer_type(right, scope);
928                infer_binary_op_type(op, &lt, &rt)
929            }
930
931            Node::UnaryOp { op, operand } => {
932                let t = self.infer_type(operand, scope);
933                match op.as_str() {
934                    "!" => Some(TypeExpr::Named("bool".into())),
935                    "-" => t, // negation preserves type
936                    _ => None,
937                }
938            }
939
940            Node::Ternary {
941                true_expr,
942                false_expr,
943                ..
944            } => {
945                let tt = self.infer_type(true_expr, scope);
946                let ft = self.infer_type(false_expr, scope);
947                match (&tt, &ft) {
948                    (Some(a), Some(b)) if a == b => tt,
949                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
950                    (Some(_), None) => tt,
951                    (None, Some(_)) => ft,
952                    (None, None) => None,
953                }
954            }
955
956            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
957
958            Node::PropertyAccess { object, property } => {
959                // EnumName.Variant → infer as the enum type
960                if let Node::Identifier(name) = &object.node {
961                    if scope.get_enum(name).is_some() {
962                        return Some(TypeExpr::Named(name.clone()));
963                    }
964                }
965                // .variant on an enum value → string
966                if property == "variant" {
967                    let obj_type = self.infer_type(object, scope);
968                    if let Some(TypeExpr::Named(name)) = &obj_type {
969                        if scope.get_enum(name).is_some() {
970                            return Some(TypeExpr::Named("string".into()));
971                        }
972                    }
973                }
974                // Shape field access: obj.field → field type
975                let obj_type = self.infer_type(object, scope);
976                if let Some(TypeExpr::Shape(fields)) = &obj_type {
977                    if let Some(field) = fields.iter().find(|f| f.name == *property) {
978                        return Some(field.type_expr.clone());
979                    }
980                }
981                None
982            }
983
984            Node::SubscriptAccess { object, index } => {
985                let obj_type = self.infer_type(object, scope);
986                match &obj_type {
987                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
988                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
989                    Some(TypeExpr::Shape(fields)) => {
990                        // If index is a string literal, look up the field type
991                        if let Node::StringLiteral(key) = &index.node {
992                            fields
993                                .iter()
994                                .find(|f| &f.name == key)
995                                .map(|f| f.type_expr.clone())
996                        } else {
997                            None
998                        }
999                    }
1000                    Some(TypeExpr::Named(n)) if n == "list" => None,
1001                    Some(TypeExpr::Named(n)) if n == "dict" => None,
1002                    Some(TypeExpr::Named(n)) if n == "string" => {
1003                        Some(TypeExpr::Named("string".into()))
1004                    }
1005                    _ => None,
1006                }
1007            }
1008            Node::SliceAccess { object, .. } => {
1009                // Slicing a list returns the same list type; slicing a string returns string
1010                let obj_type = self.infer_type(object, scope);
1011                match &obj_type {
1012                    Some(TypeExpr::List(_)) => obj_type,
1013                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1014                    Some(TypeExpr::Named(n)) if n == "string" => {
1015                        Some(TypeExpr::Named("string".into()))
1016                    }
1017                    _ => None,
1018                }
1019            }
1020            Node::MethodCall { object, method, .. }
1021            | Node::OptionalMethodCall { object, method, .. } => {
1022                let obj_type = self.infer_type(object, scope);
1023                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1024                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1025                match method.as_str() {
1026                    // Shared: bool-returning methods
1027                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1028                        Some(TypeExpr::Named("bool".into()))
1029                    }
1030                    // Shared: int-returning methods
1031                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1032                    // String methods
1033                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1034                    | "pad_left" | "pad_right" | "repeat" | "join" => {
1035                        Some(TypeExpr::Named("string".into()))
1036                    }
1037                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
1038                    // filter returns dict for dicts, list for lists
1039                    "filter" => {
1040                        if is_dict {
1041                            Some(TypeExpr::Named("dict".into()))
1042                        } else {
1043                            Some(TypeExpr::Named("list".into()))
1044                        }
1045                    }
1046                    // List methods
1047                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1048                    "reduce" | "find" | "first" | "last" => None,
1049                    // Dict methods
1050                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1051                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1052                    // Conversions
1053                    "to_string" => Some(TypeExpr::Named("string".into())),
1054                    "to_int" => Some(TypeExpr::Named("int".into())),
1055                    "to_float" => Some(TypeExpr::Named("float".into())),
1056                    _ => None,
1057                }
1058            }
1059
1060            _ => None,
1061        }
1062    }
1063
1064    /// Check if two types are compatible (actual can be assigned to expected).
1065    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1066        // Generic type parameters match anything.
1067        if let TypeExpr::Named(name) = expected {
1068            if scope.is_generic_type_param(name) {
1069                return true;
1070            }
1071        }
1072        if let TypeExpr::Named(name) = actual {
1073            if scope.is_generic_type_param(name) {
1074                return true;
1075            }
1076        }
1077        let expected = self.resolve_alias(expected, scope);
1078        let actual = self.resolve_alias(actual, scope);
1079
1080        match (&expected, &actual) {
1081            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1082            (TypeExpr::Union(members), actual_type) => members
1083                .iter()
1084                .any(|m| self.types_compatible(m, actual_type, scope)),
1085            (expected_type, TypeExpr::Union(members)) => members
1086                .iter()
1087                .all(|m| self.types_compatible(expected_type, m, scope)),
1088            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1089            (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1090            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1091                if expected_field.optional {
1092                    return true;
1093                }
1094                af.iter().any(|actual_field| {
1095                    actual_field.name == expected_field.name
1096                        && self.types_compatible(
1097                            &expected_field.type_expr,
1098                            &actual_field.type_expr,
1099                            scope,
1100                        )
1101                })
1102            }),
1103            // dict<K, V> expected, Shape actual → all field values must match V
1104            (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1105                let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1106                keys_ok
1107                    && af
1108                        .iter()
1109                        .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1110            }
1111            // Shape expected, dict<K, V> actual → gradual: allow since dict may have the fields
1112            (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1113            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1114                self.types_compatible(expected_inner, actual_inner, scope)
1115            }
1116            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1117            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1118            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1119                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1120            }
1121            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1122            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1123            // FnType compatibility: params match positionally and return types match
1124            (
1125                TypeExpr::FnType {
1126                    params: ep,
1127                    return_type: er,
1128                },
1129                TypeExpr::FnType {
1130                    params: ap,
1131                    return_type: ar,
1132                },
1133            ) => {
1134                ep.len() == ap.len()
1135                    && ep
1136                        .iter()
1137                        .zip(ap.iter())
1138                        .all(|(e, a)| self.types_compatible(e, a, scope))
1139                    && self.types_compatible(er, ar, scope)
1140            }
1141            // FnType is compatible with Named("closure") for backward compat
1142            (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1143            (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1144            _ => false,
1145        }
1146    }
1147
1148    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1149        if let TypeExpr::Named(name) = ty {
1150            if let Some(resolved) = scope.resolve_type(name) {
1151                return resolved.clone();
1152            }
1153        }
1154        ty.clone()
1155    }
1156
1157    fn error_at(&mut self, message: String, span: Span) {
1158        self.diagnostics.push(TypeDiagnostic {
1159            message,
1160            severity: DiagnosticSeverity::Error,
1161            span: Some(span),
1162        });
1163    }
1164
1165    fn warning_at(&mut self, message: String, span: Span) {
1166        self.diagnostics.push(TypeDiagnostic {
1167            message,
1168            severity: DiagnosticSeverity::Warning,
1169            span: Some(span),
1170        });
1171    }
1172}
1173
1174impl Default for TypeChecker {
1175    fn default() -> Self {
1176        Self::new()
1177    }
1178}
1179
1180/// Infer the result type of a binary operation.
1181fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1182    match op {
1183        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1184        "+" => match (left, right) {
1185            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1186                match (l.as_str(), r.as_str()) {
1187                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1188                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1189                    ("string", _) => Some(TypeExpr::Named("string".into())),
1190                    ("list", "list") => Some(TypeExpr::Named("list".into())),
1191                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1192                    _ => Some(TypeExpr::Named("string".into())),
1193                }
1194            }
1195            _ => None,
1196        },
1197        "-" | "*" | "/" | "%" => match (left, right) {
1198            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1199                match (l.as_str(), r.as_str()) {
1200                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1201                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1202                    _ => None,
1203                }
1204            }
1205            _ => None,
1206        },
1207        "??" => match (left, right) {
1208            (Some(TypeExpr::Union(members)), _) => {
1209                let non_nil: Vec<_> = members
1210                    .iter()
1211                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1212                    .cloned()
1213                    .collect();
1214                if non_nil.len() == 1 {
1215                    Some(non_nil[0].clone())
1216                } else if non_nil.is_empty() {
1217                    right.clone()
1218                } else {
1219                    Some(TypeExpr::Union(non_nil))
1220                }
1221            }
1222            _ => right.clone(),
1223        },
1224        "|>" => None,
1225        _ => None,
1226    }
1227}
1228
1229/// Format a type expression for display in error messages.
1230/// Produce a detail string describing why a Shape type is incompatible with
1231/// another Shape type — e.g. "missing field 'age' (int)" or "field 'name'
1232/// has type int, expected string".  Returns `None` if both types are not shapes.
1233pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1234    if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1235        let mut details = Vec::new();
1236        for field in ef {
1237            if field.optional {
1238                continue;
1239            }
1240            match af.iter().find(|f| f.name == field.name) {
1241                None => details.push(format!(
1242                    "missing field '{}' ({})",
1243                    field.name,
1244                    format_type(&field.type_expr)
1245                )),
1246                Some(actual_field) => {
1247                    let e_str = format_type(&field.type_expr);
1248                    let a_str = format_type(&actual_field.type_expr);
1249                    if e_str != a_str {
1250                        details.push(format!(
1251                            "field '{}' has type {}, expected {}",
1252                            field.name, a_str, e_str
1253                        ));
1254                    }
1255                }
1256            }
1257        }
1258        if details.is_empty() {
1259            None
1260        } else {
1261            Some(details.join("; "))
1262        }
1263    } else {
1264        None
1265    }
1266}
1267
1268pub fn format_type(ty: &TypeExpr) -> String {
1269    match ty {
1270        TypeExpr::Named(n) => n.clone(),
1271        TypeExpr::Union(types) => types
1272            .iter()
1273            .map(format_type)
1274            .collect::<Vec<_>>()
1275            .join(" | "),
1276        TypeExpr::Shape(fields) => {
1277            let inner: Vec<String> = fields
1278                .iter()
1279                .map(|f| {
1280                    let opt = if f.optional { "?" } else { "" };
1281                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1282                })
1283                .collect();
1284            format!("{{{}}}", inner.join(", "))
1285        }
1286        TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1287        TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1288        TypeExpr::FnType {
1289            params,
1290            return_type,
1291        } => {
1292            let params_str = params
1293                .iter()
1294                .map(format_type)
1295                .collect::<Vec<_>>()
1296                .join(", ");
1297            format!("fn({}) -> {}", params_str, format_type(return_type))
1298        }
1299    }
1300}
1301
1302#[cfg(test)]
1303mod tests {
1304    use super::*;
1305    use crate::Parser;
1306    use harn_lexer::Lexer;
1307
1308    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
1309        let mut lexer = Lexer::new(source);
1310        let tokens = lexer.tokenize().unwrap();
1311        let mut parser = Parser::new(tokens);
1312        let program = parser.parse().unwrap();
1313        TypeChecker::new().check(&program)
1314    }
1315
1316    fn errors(source: &str) -> Vec<String> {
1317        check_source(source)
1318            .into_iter()
1319            .filter(|d| d.severity == DiagnosticSeverity::Error)
1320            .map(|d| d.message)
1321            .collect()
1322    }
1323
1324    #[test]
1325    fn test_no_errors_for_untyped_code() {
1326        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
1327        assert!(errs.is_empty());
1328    }
1329
1330    #[test]
1331    fn test_correct_typed_let() {
1332        let errs = errors("pipeline t(task) { let x: int = 42 }");
1333        assert!(errs.is_empty());
1334    }
1335
1336    #[test]
1337    fn test_type_mismatch_let() {
1338        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
1339        assert_eq!(errs.len(), 1);
1340        assert!(errs[0].contains("Type mismatch"));
1341        assert!(errs[0].contains("int"));
1342        assert!(errs[0].contains("string"));
1343    }
1344
1345    #[test]
1346    fn test_correct_typed_fn() {
1347        let errs = errors(
1348            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
1349        );
1350        assert!(errs.is_empty());
1351    }
1352
1353    #[test]
1354    fn test_fn_arg_type_mismatch() {
1355        let errs = errors(
1356            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
1357add("hello", 2) }"#,
1358        );
1359        assert_eq!(errs.len(), 1);
1360        assert!(errs[0].contains("Argument 1"));
1361        assert!(errs[0].contains("expected int"));
1362    }
1363
1364    #[test]
1365    fn test_return_type_mismatch() {
1366        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
1367        assert_eq!(errs.len(), 1);
1368        assert!(errs[0].contains("Return type mismatch"));
1369    }
1370
1371    #[test]
1372    fn test_union_type_compatible() {
1373        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
1374        assert!(errs.is_empty());
1375    }
1376
1377    #[test]
1378    fn test_union_type_mismatch() {
1379        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
1380        assert_eq!(errs.len(), 1);
1381        assert!(errs[0].contains("Type mismatch"));
1382    }
1383
1384    #[test]
1385    fn test_type_inference_propagation() {
1386        let errs = errors(
1387            r#"pipeline t(task) {
1388  fn add(a: int, b: int) -> int { return a + b }
1389  let result: string = add(1, 2)
1390}"#,
1391        );
1392        assert_eq!(errs.len(), 1);
1393        assert!(errs[0].contains("Type mismatch"));
1394        assert!(errs[0].contains("string"));
1395        assert!(errs[0].contains("int"));
1396    }
1397
1398    #[test]
1399    fn test_builtin_return_type_inference() {
1400        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
1401        assert_eq!(errs.len(), 1);
1402        assert!(errs[0].contains("string"));
1403        assert!(errs[0].contains("int"));
1404    }
1405
1406    #[test]
1407    fn test_binary_op_type_inference() {
1408        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
1409        assert_eq!(errs.len(), 1);
1410    }
1411
1412    #[test]
1413    fn test_comparison_returns_bool() {
1414        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
1415        assert!(errs.is_empty());
1416    }
1417
1418    #[test]
1419    fn test_int_float_promotion() {
1420        let errs = errors("pipeline t(task) { let x: float = 42 }");
1421        assert!(errs.is_empty());
1422    }
1423
1424    #[test]
1425    fn test_untyped_code_no_errors() {
1426        let errs = errors(
1427            r#"pipeline t(task) {
1428  fn process(data) {
1429    let result = data + " processed"
1430    return result
1431  }
1432  log(process("hello"))
1433}"#,
1434        );
1435        assert!(errs.is_empty());
1436    }
1437
1438    #[test]
1439    fn test_type_alias() {
1440        let errs = errors(
1441            r#"pipeline t(task) {
1442  type Name = string
1443  let x: Name = "hello"
1444}"#,
1445        );
1446        assert!(errs.is_empty());
1447    }
1448
1449    #[test]
1450    fn test_type_alias_mismatch() {
1451        let errs = errors(
1452            r#"pipeline t(task) {
1453  type Name = string
1454  let x: Name = 42
1455}"#,
1456        );
1457        assert_eq!(errs.len(), 1);
1458    }
1459
1460    #[test]
1461    fn test_assignment_type_check() {
1462        let errs = errors(
1463            r#"pipeline t(task) {
1464  var x: int = 0
1465  x = "hello"
1466}"#,
1467        );
1468        assert_eq!(errs.len(), 1);
1469        assert!(errs[0].contains("cannot assign string"));
1470    }
1471
1472    #[test]
1473    fn test_covariance_int_to_float_in_fn() {
1474        let errs = errors(
1475            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
1476        );
1477        assert!(errs.is_empty());
1478    }
1479
1480    #[test]
1481    fn test_covariance_return_type() {
1482        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
1483        assert!(errs.is_empty());
1484    }
1485
1486    #[test]
1487    fn test_no_contravariance_float_to_int() {
1488        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
1489        assert_eq!(errs.len(), 1);
1490    }
1491
1492    // --- Exhaustiveness checking tests ---
1493
1494    fn warnings(source: &str) -> Vec<String> {
1495        check_source(source)
1496            .into_iter()
1497            .filter(|d| d.severity == DiagnosticSeverity::Warning)
1498            .map(|d| d.message)
1499            .collect()
1500    }
1501
1502    #[test]
1503    fn test_exhaustive_match_no_warning() {
1504        let warns = warnings(
1505            r#"pipeline t(task) {
1506  enum Color { Red, Green, Blue }
1507  let c = Color.Red
1508  match c.variant {
1509    "Red" -> { log("r") }
1510    "Green" -> { log("g") }
1511    "Blue" -> { log("b") }
1512  }
1513}"#,
1514        );
1515        let exhaustive_warns: Vec<_> = warns
1516            .iter()
1517            .filter(|w| w.contains("Non-exhaustive"))
1518            .collect();
1519        assert!(exhaustive_warns.is_empty());
1520    }
1521
1522    #[test]
1523    fn test_non_exhaustive_match_warning() {
1524        let warns = warnings(
1525            r#"pipeline t(task) {
1526  enum Color { Red, Green, Blue }
1527  let c = Color.Red
1528  match c.variant {
1529    "Red" -> { log("r") }
1530    "Green" -> { log("g") }
1531  }
1532}"#,
1533        );
1534        let exhaustive_warns: Vec<_> = warns
1535            .iter()
1536            .filter(|w| w.contains("Non-exhaustive"))
1537            .collect();
1538        assert_eq!(exhaustive_warns.len(), 1);
1539        assert!(exhaustive_warns[0].contains("Blue"));
1540    }
1541
1542    #[test]
1543    fn test_non_exhaustive_multiple_missing() {
1544        let warns = warnings(
1545            r#"pipeline t(task) {
1546  enum Status { Active, Inactive, Pending }
1547  let s = Status.Active
1548  match s.variant {
1549    "Active" -> { log("a") }
1550  }
1551}"#,
1552        );
1553        let exhaustive_warns: Vec<_> = warns
1554            .iter()
1555            .filter(|w| w.contains("Non-exhaustive"))
1556            .collect();
1557        assert_eq!(exhaustive_warns.len(), 1);
1558        assert!(exhaustive_warns[0].contains("Inactive"));
1559        assert!(exhaustive_warns[0].contains("Pending"));
1560    }
1561
1562    #[test]
1563    fn test_enum_construct_type_inference() {
1564        let errs = errors(
1565            r#"pipeline t(task) {
1566  enum Color { Red, Green, Blue }
1567  let c: Color = Color.Red
1568}"#,
1569        );
1570        assert!(errs.is_empty());
1571    }
1572
1573    // --- Type narrowing tests ---
1574
1575    #[test]
1576    fn test_nil_coalescing_strips_nil() {
1577        // After ??, nil should be stripped from the type
1578        let errs = errors(
1579            r#"pipeline t(task) {
1580  let x: string | nil = nil
1581  let y: string = x ?? "default"
1582}"#,
1583        );
1584        assert!(errs.is_empty());
1585    }
1586
1587    #[test]
1588    fn test_shape_mismatch_detail_missing_field() {
1589        let errs = errors(
1590            r#"pipeline t(task) {
1591  let x: {name: string, age: int} = {name: "hello"}
1592}"#,
1593        );
1594        assert_eq!(errs.len(), 1);
1595        assert!(
1596            errs[0].contains("missing field 'age'"),
1597            "expected detail about missing field, got: {}",
1598            errs[0]
1599        );
1600    }
1601
1602    #[test]
1603    fn test_shape_mismatch_detail_wrong_type() {
1604        let errs = errors(
1605            r#"pipeline t(task) {
1606  let x: {name: string, age: int} = {name: 42, age: 10}
1607}"#,
1608        );
1609        assert_eq!(errs.len(), 1);
1610        assert!(
1611            errs[0].contains("field 'name' has type int, expected string"),
1612            "expected detail about wrong type, got: {}",
1613            errs[0]
1614        );
1615    }
1616
1617    // --- Match pattern type validation tests ---
1618
1619    #[test]
1620    fn test_match_pattern_string_against_int() {
1621        let warns = warnings(
1622            r#"pipeline t(task) {
1623  let x: int = 42
1624  match x {
1625    "hello" -> { log("bad") }
1626    42 -> { log("ok") }
1627  }
1628}"#,
1629        );
1630        let pattern_warns: Vec<_> = warns
1631            .iter()
1632            .filter(|w| w.contains("Match pattern type mismatch"))
1633            .collect();
1634        assert_eq!(pattern_warns.len(), 1);
1635        assert!(pattern_warns[0].contains("matching int against string literal"));
1636    }
1637
1638    #[test]
1639    fn test_match_pattern_int_against_string() {
1640        let warns = warnings(
1641            r#"pipeline t(task) {
1642  let x: string = "hello"
1643  match x {
1644    42 -> { log("bad") }
1645    "hello" -> { log("ok") }
1646  }
1647}"#,
1648        );
1649        let pattern_warns: Vec<_> = warns
1650            .iter()
1651            .filter(|w| w.contains("Match pattern type mismatch"))
1652            .collect();
1653        assert_eq!(pattern_warns.len(), 1);
1654        assert!(pattern_warns[0].contains("matching string against int literal"));
1655    }
1656
1657    #[test]
1658    fn test_match_pattern_bool_against_int() {
1659        let warns = warnings(
1660            r#"pipeline t(task) {
1661  let x: int = 42
1662  match x {
1663    true -> { log("bad") }
1664    42 -> { log("ok") }
1665  }
1666}"#,
1667        );
1668        let pattern_warns: Vec<_> = warns
1669            .iter()
1670            .filter(|w| w.contains("Match pattern type mismatch"))
1671            .collect();
1672        assert_eq!(pattern_warns.len(), 1);
1673        assert!(pattern_warns[0].contains("matching int against bool literal"));
1674    }
1675
1676    #[test]
1677    fn test_match_pattern_float_against_string() {
1678        let warns = warnings(
1679            r#"pipeline t(task) {
1680  let x: string = "hello"
1681  match x {
1682    3.14 -> { log("bad") }
1683    "hello" -> { log("ok") }
1684  }
1685}"#,
1686        );
1687        let pattern_warns: Vec<_> = warns
1688            .iter()
1689            .filter(|w| w.contains("Match pattern type mismatch"))
1690            .collect();
1691        assert_eq!(pattern_warns.len(), 1);
1692        assert!(pattern_warns[0].contains("matching string against float literal"));
1693    }
1694
1695    #[test]
1696    fn test_match_pattern_int_against_float_ok() {
1697        // int and float are compatible for match patterns
1698        let warns = warnings(
1699            r#"pipeline t(task) {
1700  let x: float = 3.14
1701  match x {
1702    42 -> { log("ok") }
1703    _ -> { log("default") }
1704  }
1705}"#,
1706        );
1707        let pattern_warns: Vec<_> = warns
1708            .iter()
1709            .filter(|w| w.contains("Match pattern type mismatch"))
1710            .collect();
1711        assert!(pattern_warns.is_empty());
1712    }
1713
1714    #[test]
1715    fn test_match_pattern_float_against_int_ok() {
1716        // float and int are compatible for match patterns
1717        let warns = warnings(
1718            r#"pipeline t(task) {
1719  let x: int = 42
1720  match x {
1721    3.14 -> { log("close") }
1722    _ -> { log("default") }
1723  }
1724}"#,
1725        );
1726        let pattern_warns: Vec<_> = warns
1727            .iter()
1728            .filter(|w| w.contains("Match pattern type mismatch"))
1729            .collect();
1730        assert!(pattern_warns.is_empty());
1731    }
1732
1733    #[test]
1734    fn test_match_pattern_correct_types_no_warning() {
1735        let warns = warnings(
1736            r#"pipeline t(task) {
1737  let x: int = 42
1738  match x {
1739    1 -> { log("one") }
1740    2 -> { log("two") }
1741    _ -> { log("other") }
1742  }
1743}"#,
1744        );
1745        let pattern_warns: Vec<_> = warns
1746            .iter()
1747            .filter(|w| w.contains("Match pattern type mismatch"))
1748            .collect();
1749        assert!(pattern_warns.is_empty());
1750    }
1751
1752    #[test]
1753    fn test_match_pattern_wildcard_no_warning() {
1754        let warns = warnings(
1755            r#"pipeline t(task) {
1756  let x: int = 42
1757  match x {
1758    _ -> { log("catch all") }
1759  }
1760}"#,
1761        );
1762        let pattern_warns: Vec<_> = warns
1763            .iter()
1764            .filter(|w| w.contains("Match pattern type mismatch"))
1765            .collect();
1766        assert!(pattern_warns.is_empty());
1767    }
1768
1769    #[test]
1770    fn test_match_pattern_untyped_no_warning() {
1771        // When value has no known type, no warning should be emitted
1772        let warns = warnings(
1773            r#"pipeline t(task) {
1774  let x = some_unknown_fn()
1775  match x {
1776    "hello" -> { log("string") }
1777    42 -> { log("int") }
1778  }
1779}"#,
1780        );
1781        let pattern_warns: Vec<_> = warns
1782            .iter()
1783            .filter(|w| w.contains("Match pattern type mismatch"))
1784            .collect();
1785        assert!(pattern_warns.is_empty());
1786    }
1787}