Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6/// A diagnostic produced by the type checker.
7#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9    pub message: String,
10    pub severity: DiagnosticSeverity,
11    pub span: Option<Span>,
12}
13
/// How serious a type-checker finding is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// A definite type error (e.g. an annotated binding assigned an incompatible value).
    Error,
    /// A possible problem (e.g. a questionable operator or a non-exhaustive match).
    Warning,
}
19
20/// Inferred type of an expression. None means unknown/untyped (gradual typing).
21type InferredType = Option<TypeExpr>;
22
23/// Scope for tracking variable types.
24#[derive(Debug, Clone)]
25struct TypeScope {
26    /// Variable name → inferred type.
27    vars: BTreeMap<String, InferredType>,
28    /// Function name → (param types, return type).
29    functions: BTreeMap<String, FnSignature>,
30    /// Named type aliases.
31    type_aliases: BTreeMap<String, TypeExpr>,
32    /// Enum declarations: name → variant names.
33    enums: BTreeMap<String, Vec<String>>,
34    /// Interface declarations: name → method signatures.
35    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36    /// Struct declarations: name → field types.
37    structs: BTreeMap<String, Vec<(String, InferredType)>>,
38    /// Generic type parameter names in scope (treated as compatible with any type).
39    generic_type_params: std::collections::BTreeSet<String>,
40    parent: Option<Box<TypeScope>>,
41}
42
43#[derive(Debug, Clone)]
44struct FnSignature {
45    params: Vec<(String, InferredType)>,
46    return_type: InferredType,
47    /// Generic type parameter names declared on the function.
48    type_param_names: Vec<String>,
49    /// Number of required parameters (those without defaults).
50    required_params: usize,
51}
52
53impl TypeScope {
54    fn new() -> Self {
55        Self {
56            vars: BTreeMap::new(),
57            functions: BTreeMap::new(),
58            type_aliases: BTreeMap::new(),
59            enums: BTreeMap::new(),
60            interfaces: BTreeMap::new(),
61            structs: BTreeMap::new(),
62            generic_type_params: std::collections::BTreeSet::new(),
63            parent: None,
64        }
65    }
66
67    fn child(&self) -> Self {
68        Self {
69            vars: BTreeMap::new(),
70            functions: BTreeMap::new(),
71            type_aliases: BTreeMap::new(),
72            enums: BTreeMap::new(),
73            interfaces: BTreeMap::new(),
74            structs: BTreeMap::new(),
75            generic_type_params: std::collections::BTreeSet::new(),
76            parent: Some(Box::new(self.clone())),
77        }
78    }
79
80    fn get_var(&self, name: &str) -> Option<&InferredType> {
81        self.vars
82            .get(name)
83            .or_else(|| self.parent.as_ref()?.get_var(name))
84    }
85
86    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
87        self.functions
88            .get(name)
89            .or_else(|| self.parent.as_ref()?.get_fn(name))
90    }
91
92    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
93        self.type_aliases
94            .get(name)
95            .or_else(|| self.parent.as_ref()?.resolve_type(name))
96    }
97
98    fn is_generic_type_param(&self, name: &str) -> bool {
99        self.generic_type_params.contains(name)
100            || self
101                .parent
102                .as_ref()
103                .is_some_and(|p| p.is_generic_type_param(name))
104    }
105
106    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
107        self.enums
108            .get(name)
109            .or_else(|| self.parent.as_ref()?.get_enum(name))
110    }
111
112    #[allow(dead_code)]
113    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
114        self.interfaces
115            .get(name)
116            .or_else(|| self.parent.as_ref()?.get_interface(name))
117    }
118
119    fn define_var(&mut self, name: &str, ty: InferredType) {
120        self.vars.insert(name.to_string(), ty);
121    }
122
123    fn define_fn(&mut self, name: &str, sig: FnSignature) {
124        self.functions.insert(name.to_string(), sig);
125    }
126}
127
128/// Known return types for builtin functions.
129fn builtin_return_type(name: &str) -> InferredType {
130    match name {
131        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
132        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
133            Some(TypeExpr::Named("nil".into()))
134        }
135        "type_of" | "to_string" | "json_stringify" | "read_file" | "http_get" | "http_post"
136        | "llm_call" | "regex_replace" | "path_join" | "temp_dir"
137        | "date_format" | "format" | "compute_content_hash" => {
138            Some(TypeExpr::Named("string".into()))
139        }
140        "to_int" | "timer_end" | "elapsed" => Some(TypeExpr::Named("int".into())),
141        "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
142        "file_exists" | "json_validate" => Some(TypeExpr::Named("bool".into())),
143        "list_dir" => Some(TypeExpr::Named("list".into())),
144        "stat" | "exec" | "shell" | "date_now" | "agent_loop" | "llm_info" | "llm_usage"
145        | "timer_start" | "metadata_get" => Some(TypeExpr::Named("dict".into())),
146        "metadata_set" | "metadata_save" | "metadata_refresh_hashes"
147        | "invalidate_facts" | "log_json" => Some(TypeExpr::Named("nil".into())),
148        "env" | "regex_match" => Some(TypeExpr::Union(vec![
149            TypeExpr::Named("string".into()),
150            TypeExpr::Named("nil".into()),
151        ])),
152        "json_parse" | "json_extract" => None, // could be any type
153        _ => None,
154    }
155}
156
/// Returns whether `name` is one of the known builtin functions.
///
/// Every builtin that `builtin_return_type` assigns a result type to is also
/// listed here. Otherwise the two tables would disagree about whether a name
/// is a builtin at all.
fn is_builtin(name: &str) -> bool {
    matches!(
        name,
        // Output and logging
        "log"
            | "print"
            | "println"
            | "log_json"
            // Conversion and introspection
            | "type_of"
            | "to_string"
            | "to_int"
            | "to_float"
            | "format"
            // JSON
            | "json_stringify"
            | "json_parse"
            | "json_validate"
            | "json_extract"
            // Environment, process control and time
            | "env"
            | "exit"
            | "sleep"
            | "timestamp"
            | "date_now"
            | "date_format"
            | "date_parse"
            | "timer_start"
            | "timer_end"
            | "elapsed"
            // Files and paths
            | "read_file"
            | "write_file"
            | "append_file"
            | "copy_file"
            | "delete_file"
            | "file_exists"
            | "list_dir"
            | "mkdir"
            | "stat"
            | "temp_dir"
            | "path_join"
            | "dirname"
            | "basename"
            | "extname"
            | "compute_content_hash"
            // Shell
            | "exec"
            | "shell"
            // Network, LLM and agents
            | "http_get"
            | "http_post"
            | "llm_call"
            | "llm_info"
            | "llm_usage"
            | "agent_loop"
            // Concurrency
            | "await"
            | "cancel"
            // Regular expressions
            | "regex_match"
            | "regex_replace"
            // Strings and collections
            | "trim"
            | "lowercase"
            | "uppercase"
            | "split"
            | "starts_with"
            | "ends_with"
            | "contains"
            | "replace"
            | "join"
            | "len"
            | "substring"
            // Metadata
            | "metadata_get"
            | "metadata_set"
            | "metadata_save"
            | "metadata_refresh_hashes"
            | "invalidate_facts"
    )
}
217
218/// The static type checker.
219pub struct TypeChecker {
220    diagnostics: Vec<TypeDiagnostic>,
221    scope: TypeScope,
222}
223
224impl TypeChecker {
225    pub fn new() -> Self {
226        Self {
227            diagnostics: Vec::new(),
228            scope: TypeScope::new(),
229        }
230    }
231
232    /// Check a program and return diagnostics.
233    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
234        // First pass: register type and enum declarations into root scope
235        Self::register_declarations_into(&mut self.scope, program);
236
237        // Also scan pipeline bodies for declarations
238        for snode in program {
239            if let Node::Pipeline { body, .. } = &snode.node {
240                Self::register_declarations_into(&mut self.scope, body);
241            }
242        }
243
244        // Check each top-level node
245        for snode in program {
246            match &snode.node {
247                Node::Pipeline { params, body, .. } => {
248                    let mut child = self.scope.child();
249                    for p in params {
250                        child.define_var(p, None);
251                    }
252                    self.check_block(body, &mut child);
253                }
254                Node::FnDecl {
255                    name,
256                    type_params,
257                    params,
258                    return_type,
259                    body,
260                    ..
261                } => {
262                    let required_params = params
263                        .iter()
264                        .filter(|p| p.default_value.is_none())
265                        .count();
266                    let sig = FnSignature {
267                        params: params
268                            .iter()
269                            .map(|p| (p.name.clone(), p.type_expr.clone()))
270                            .collect(),
271                        return_type: return_type.clone(),
272                        type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
273                        required_params,
274                    };
275                    self.scope.define_fn(name, sig);
276                    self.check_fn_body(type_params, params, return_type, body);
277                }
278                _ => {
279                    let mut scope = self.scope.clone();
280                    self.check_node(snode, &mut scope);
281                    // Merge any new definitions back into the top-level scope
282                    for (name, ty) in scope.vars {
283                        self.scope.vars.entry(name).or_insert(ty);
284                    }
285                }
286            }
287        }
288
289        self.diagnostics
290    }
291
292    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
293    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
294        for snode in nodes {
295            match &snode.node {
296                Node::TypeDecl { name, type_expr } => {
297                    scope.type_aliases.insert(name.clone(), type_expr.clone());
298                }
299                Node::EnumDecl { name, variants } => {
300                    let variant_names: Vec<String> =
301                        variants.iter().map(|v| v.name.clone()).collect();
302                    scope.enums.insert(name.clone(), variant_names);
303                }
304                Node::InterfaceDecl { name, methods } => {
305                    scope.interfaces.insert(name.clone(), methods.clone());
306                }
307                Node::StructDecl { name, fields } => {
308                    let field_types: Vec<(String, InferredType)> = fields
309                        .iter()
310                        .map(|f| (f.name.clone(), f.type_expr.clone()))
311                        .collect();
312                    scope.structs.insert(name.clone(), field_types);
313                }
314                _ => {}
315            }
316        }
317    }
318
319    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
320        for stmt in stmts {
321            self.check_node(stmt, scope);
322        }
323    }
324
325    /// Define variables from a destructuring pattern in the given scope (as unknown type).
326    fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
327        match pattern {
328            BindingPattern::Identifier(name) => {
329                scope.define_var(name, None);
330            }
331            BindingPattern::Dict(fields) => {
332                for field in fields {
333                    let name = field.alias.as_deref().unwrap_or(&field.key);
334                    scope.define_var(name, None);
335                }
336            }
337            BindingPattern::List(elements) => {
338                for elem in elements {
339                    scope.define_var(&elem.name, None);
340                }
341            }
342        }
343    }
344
345    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
346        let span = snode.span;
347        match &snode.node {
348            Node::LetBinding {
349                pattern,
350                type_ann,
351                value,
352            } => {
353                let inferred = self.infer_type(value, scope);
354                if let BindingPattern::Identifier(name) = pattern {
355                    if let Some(expected) = type_ann {
356                        if let Some(actual) = &inferred {
357                            if !self.types_compatible(expected, actual, scope) {
358                                let mut msg = format!(
359                                    "Type mismatch: '{}' declared as {}, but assigned {}",
360                                    name,
361                                    format_type(expected),
362                                    format_type(actual)
363                                );
364                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
365                                    msg.push_str(&format!(" ({})", detail));
366                                }
367                                self.error_at(msg, span);
368                            }
369                        }
370                    }
371                    let ty = type_ann.clone().or(inferred);
372                    scope.define_var(name, ty);
373                } else {
374                    Self::define_pattern_vars(pattern, scope);
375                }
376            }
377
378            Node::VarBinding {
379                pattern,
380                type_ann,
381                value,
382            } => {
383                let inferred = self.infer_type(value, scope);
384                if let BindingPattern::Identifier(name) = pattern {
385                    if let Some(expected) = type_ann {
386                        if let Some(actual) = &inferred {
387                            if !self.types_compatible(expected, actual, scope) {
388                                let mut msg = format!(
389                                    "Type mismatch: '{}' declared as {}, but assigned {}",
390                                    name,
391                                    format_type(expected),
392                                    format_type(actual)
393                                );
394                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
395                                    msg.push_str(&format!(" ({})", detail));
396                                }
397                                self.error_at(msg, span);
398                            }
399                        }
400                    }
401                    let ty = type_ann.clone().or(inferred);
402                    scope.define_var(name, ty);
403                } else {
404                    Self::define_pattern_vars(pattern, scope);
405                }
406            }
407
408            Node::FnDecl {
409                name,
410                type_params,
411                params,
412                return_type,
413                body,
414                ..
415            } => {
416                let required_params = params
417                    .iter()
418                    .filter(|p| p.default_value.is_none())
419                    .count();
420                let sig = FnSignature {
421                    params: params
422                        .iter()
423                        .map(|p| (p.name.clone(), p.type_expr.clone()))
424                        .collect(),
425                    return_type: return_type.clone(),
426                    type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
427                    required_params,
428                };
429                scope.define_fn(name, sig.clone());
430                scope.define_var(name, None);
431                self.check_fn_body(type_params, params, return_type, body);
432            }
433
434            Node::FunctionCall { name, args } => {
435                self.check_call(name, args, scope, span);
436            }
437
438            Node::IfElse {
439                condition,
440                then_body,
441                else_body,
442            } => {
443                self.check_node(condition, scope);
444                let mut then_scope = scope.child();
445                self.check_block(then_body, &mut then_scope);
446                if let Some(else_body) = else_body {
447                    let mut else_scope = scope.child();
448                    self.check_block(else_body, &mut else_scope);
449                }
450            }
451
452            Node::ForIn {
453                pattern,
454                iterable,
455                body,
456            } => {
457                self.check_node(iterable, scope);
458                let mut loop_scope = scope.child();
459                if let BindingPattern::Identifier(variable) = pattern {
460                    // Infer loop variable type from iterable
461                    let elem_type = match self.infer_type(iterable, scope) {
462                        Some(TypeExpr::List(inner)) => Some(*inner),
463                        Some(TypeExpr::Named(n)) if n == "string" => {
464                            Some(TypeExpr::Named("string".into()))
465                        }
466                        _ => None,
467                    };
468                    loop_scope.define_var(variable, elem_type);
469                } else {
470                    Self::define_pattern_vars(pattern, &mut loop_scope);
471                }
472                self.check_block(body, &mut loop_scope);
473            }
474
475            Node::WhileLoop { condition, body } => {
476                self.check_node(condition, scope);
477                let mut loop_scope = scope.child();
478                self.check_block(body, &mut loop_scope);
479            }
480
481            Node::TryCatch {
482                body,
483                error_var,
484                catch_body,
485                finally_body,
486                ..
487            } => {
488                let mut try_scope = scope.child();
489                self.check_block(body, &mut try_scope);
490                let mut catch_scope = scope.child();
491                if let Some(var) = error_var {
492                    catch_scope.define_var(var, None);
493                }
494                self.check_block(catch_body, &mut catch_scope);
495                if let Some(fb) = finally_body {
496                    let mut finally_scope = scope.child();
497                    self.check_block(fb, &mut finally_scope);
498                }
499            }
500
501            Node::ReturnStmt {
502                value: Some(val), ..
503            } => {
504                self.check_node(val, scope);
505            }
506
507            Node::Assignment {
508                target, value, op, ..
509            } => {
510                self.check_node(value, scope);
511                if let Node::Identifier(name) = &target.node {
512                    if let Some(Some(var_type)) = scope.get_var(name) {
513                        let value_type = self.infer_type(value, scope);
514                        let assigned = if let Some(op) = op {
515                            let var_inferred = scope.get_var(name).cloned().flatten();
516                            infer_binary_op_type(op, &var_inferred, &value_type)
517                        } else {
518                            value_type
519                        };
520                        if let Some(actual) = &assigned {
521                            if !self.types_compatible(var_type, actual, scope) {
522                                self.error_at(
523                                    format!(
524                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
525                                        format_type(actual),
526                                        name,
527                                        format_type(var_type)
528                                    ),
529                                    span,
530                                );
531                            }
532                        }
533                    }
534                }
535            }
536
537            Node::TypeDecl { name, type_expr } => {
538                scope.type_aliases.insert(name.clone(), type_expr.clone());
539            }
540
541            Node::EnumDecl { name, variants } => {
542                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
543                scope.enums.insert(name.clone(), variant_names);
544            }
545
546            Node::StructDecl { name, fields } => {
547                let field_types: Vec<(String, InferredType)> = fields
548                    .iter()
549                    .map(|f| (f.name.clone(), f.type_expr.clone()))
550                    .collect();
551                scope.structs.insert(name.clone(), field_types);
552            }
553
554            Node::InterfaceDecl { name, methods } => {
555                scope.interfaces.insert(name.clone(), methods.clone());
556            }
557
558            Node::MatchExpr { value, arms } => {
559                self.check_node(value, scope);
560                let value_type = self.infer_type(value, scope);
561                for arm in arms {
562                    self.check_node(&arm.pattern, scope);
563                    // Check for incompatible literal pattern types
564                    if let Some(ref vt) = value_type {
565                        let value_type_name = format_type(vt);
566                        let mismatch = match &arm.pattern.node {
567                            Node::StringLiteral(_) => {
568                                !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
569                            }
570                            Node::IntLiteral(_) => {
571                                !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
572                                    && !self.types_compatible(
573                                        vt,
574                                        &TypeExpr::Named("float".into()),
575                                        scope,
576                                    )
577                            }
578                            Node::FloatLiteral(_) => {
579                                !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
580                                    && !self.types_compatible(
581                                        vt,
582                                        &TypeExpr::Named("int".into()),
583                                        scope,
584                                    )
585                            }
586                            Node::BoolLiteral(_) => {
587                                !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
588                            }
589                            _ => false,
590                        };
591                        if mismatch {
592                            let pattern_type = match &arm.pattern.node {
593                                Node::StringLiteral(_) => "string",
594                                Node::IntLiteral(_) => "int",
595                                Node::FloatLiteral(_) => "float",
596                                Node::BoolLiteral(_) => "bool",
597                                _ => unreachable!(),
598                            };
599                            self.warning_at(
600                                format!(
601                                    "Match pattern type mismatch: matching {} against {} literal",
602                                    value_type_name, pattern_type
603                                ),
604                                arm.pattern.span,
605                            );
606                        }
607                    }
608                    let mut arm_scope = scope.child();
609                    self.check_block(&arm.body, &mut arm_scope);
610                }
611                self.check_match_exhaustiveness(value, arms, scope, span);
612            }
613
614            // Recurse into nested expressions + validate binary op types
615            Node::BinaryOp { op, left, right } => {
616                self.check_node(left, scope);
617                self.check_node(right, scope);
618                // Validate operator/type compatibility
619                let lt = self.infer_type(left, scope);
620                let rt = self.infer_type(right, scope);
621                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
622                    match op.as_str() {
623                        "-" | "*" | "/" | "%" => {
624                            let numeric = ["int", "float"];
625                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
626                                self.warning_at(
627                                    format!(
628                                        "Operator '{op}' may not be valid for types {} and {}",
629                                        l, r
630                                    ),
631                                    span,
632                                );
633                            }
634                        }
635                        "+" => {
636                            // + is valid for int, float, string, list, dict
637                            let valid = ["int", "float", "string", "list", "dict"];
638                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
639                                self.warning_at(
640                                    format!(
641                                        "Operator '+' may not be valid for types {} and {}",
642                                        l, r
643                                    ),
644                                    span,
645                                );
646                            }
647                        }
648                        _ => {}
649                    }
650                }
651            }
652            Node::UnaryOp { operand, .. } => {
653                self.check_node(operand, scope);
654            }
655            Node::MethodCall { object, args, .. }
656            | Node::OptionalMethodCall { object, args, .. } => {
657                self.check_node(object, scope);
658                for arg in args {
659                    self.check_node(arg, scope);
660                }
661            }
662            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
663                self.check_node(object, scope);
664            }
665            Node::SubscriptAccess { object, index } => {
666                self.check_node(object, scope);
667                self.check_node(index, scope);
668            }
669            Node::SliceAccess { object, start, end } => {
670                self.check_node(object, scope);
671                if let Some(s) = start {
672                    self.check_node(s, scope);
673                }
674                if let Some(e) = end {
675                    self.check_node(e, scope);
676                }
677            }
678
679            // Terminals — nothing to check
680            _ => {}
681        }
682    }
683
684    fn check_fn_body(
685        &mut self,
686        type_params: &[TypeParam],
687        params: &[TypedParam],
688        return_type: &Option<TypeExpr>,
689        body: &[SNode],
690    ) {
691        let mut fn_scope = self.scope.child();
692        // Register generic type parameters so they are treated as compatible
693        // with any concrete type during type checking.
694        for tp in type_params {
695            fn_scope.generic_type_params.insert(tp.name.clone());
696        }
697        for param in params {
698            fn_scope.define_var(&param.name, param.type_expr.clone());
699            if let Some(default) = &param.default_value {
700                self.check_node(default, &mut fn_scope);
701            }
702        }
703        self.check_block(body, &mut fn_scope);
704
705        // Check return statements against declared return type
706        if let Some(ret_type) = return_type {
707            for stmt in body {
708                self.check_return_type(stmt, ret_type, &fn_scope);
709            }
710        }
711    }
712
713    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
714        let span = snode.span;
715        match &snode.node {
716            Node::ReturnStmt { value: Some(val) } => {
717                let inferred = self.infer_type(val, scope);
718                if let Some(actual) = &inferred {
719                    if !self.types_compatible(expected, actual, scope) {
720                        self.error_at(
721                            format!(
722                                "Return type mismatch: expected {}, got {}",
723                                format_type(expected),
724                                format_type(actual)
725                            ),
726                            span,
727                        );
728                    }
729                }
730            }
731            Node::IfElse {
732                then_body,
733                else_body,
734                ..
735            } => {
736                for stmt in then_body {
737                    self.check_return_type(stmt, expected, scope);
738                }
739                if let Some(else_body) = else_body {
740                    for stmt in else_body {
741                        self.check_return_type(stmt, expected, scope);
742                    }
743                }
744            }
745            _ => {}
746        }
747    }
748
749    /// Check if a match expression on an enum's `.variant` property covers all variants.
750    fn check_match_exhaustiveness(
751        &mut self,
752        value: &SNode,
753        arms: &[MatchArm],
754        scope: &TypeScope,
755        span: Span,
756    ) {
757        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
758        let enum_name = match &value.node {
759            Node::PropertyAccess { object, property } if property == "variant" => {
760                // Infer the type of the object
761                match self.infer_type(object, scope) {
762                    Some(TypeExpr::Named(name)) => {
763                        if scope.get_enum(&name).is_some() {
764                            Some(name)
765                        } else {
766                            None
767                        }
768                    }
769                    _ => None,
770                }
771            }
772            _ => {
773                // Direct match on an enum value: match <expr> { ... }
774                match self.infer_type(value, scope) {
775                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
776                    _ => None,
777                }
778            }
779        };
780
781        let Some(enum_name) = enum_name else {
782            return;
783        };
784        let Some(variants) = scope.get_enum(&enum_name) else {
785            return;
786        };
787
788        // Collect variant names covered by match arms
789        let mut covered: Vec<String> = Vec::new();
790        let mut has_wildcard = false;
791
792        for arm in arms {
793            match &arm.pattern.node {
794                // String literal pattern (matching on .variant): "VariantA"
795                Node::StringLiteral(s) => covered.push(s.clone()),
796                // Identifier pattern acts as a wildcard/catch-all
797                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
798                    has_wildcard = true;
799                }
800                // Direct enum construct pattern: EnumName.Variant
801                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
802                // PropertyAccess pattern: EnumName.Variant (no args)
803                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
804                _ => {
805                    // Unknown pattern shape — conservatively treat as wildcard
806                    has_wildcard = true;
807                }
808            }
809        }
810
811        if has_wildcard {
812            return;
813        }
814
815        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
816        if !missing.is_empty() {
817            let missing_str = missing
818                .iter()
819                .map(|s| format!("\"{}\"", s))
820                .collect::<Vec<_>>()
821                .join(", ");
822            self.warning_at(
823                format!(
824                    "Non-exhaustive match on enum {}: missing variants {}",
825                    enum_name, missing_str
826                ),
827                span,
828            );
829        }
830    }
831
832    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
833        // Check against known function signatures
834        if let Some(sig) = scope.get_fn(name).cloned() {
835            if !is_builtin(name)
836                && (args.len() < sig.required_params || args.len() > sig.params.len())
837            {
838                let expected = if sig.required_params == sig.params.len() {
839                    format!("{}", sig.params.len())
840                } else {
841                    format!("{}-{}", sig.required_params, sig.params.len())
842                };
843                self.warning_at(
844                    format!(
845                        "Function '{}' expects {} arguments, got {}",
846                        name,
847                        expected,
848                        args.len()
849                    ),
850                    span,
851                );
852            }
853            // Build a scope that includes the function's generic type params
854            // so they are treated as compatible with any concrete type.
855            let call_scope = if sig.type_param_names.is_empty() {
856                scope.clone()
857            } else {
858                let mut s = scope.child();
859                for tp_name in &sig.type_param_names {
860                    s.generic_type_params.insert(tp_name.clone());
861                }
862                s
863            };
864            for (i, (arg, (param_name, param_type))) in
865                args.iter().zip(sig.params.iter()).enumerate()
866            {
867                if let Some(expected) = param_type {
868                    let actual = self.infer_type(arg, scope);
869                    if let Some(actual) = &actual {
870                        if !self.types_compatible(expected, actual, &call_scope) {
871                            self.error_at(
872                                format!(
873                                    "Argument {} ('{}'): expected {}, got {}",
874                                    i + 1,
875                                    param_name,
876                                    format_type(expected),
877                                    format_type(actual)
878                                ),
879                                arg.span,
880                            );
881                        }
882                    }
883                }
884            }
885        }
886        // Check args recursively
887        for arg in args {
888            self.check_node(arg, scope);
889        }
890    }
891
892    /// Infer the type of an expression.
893    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
894        match &snode.node {
895            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
896            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
897            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
898                Some(TypeExpr::Named("string".into()))
899            }
900            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
901            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
902            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
903            Node::DictLiteral(entries) => {
904                // Infer shape type when all keys are string literals
905                let mut fields = Vec::new();
906                let mut all_string_keys = true;
907                for entry in entries {
908                    if let Node::StringLiteral(key) = &entry.key.node {
909                        let val_type = self
910                            .infer_type(&entry.value, scope)
911                            .unwrap_or(TypeExpr::Named("nil".into()));
912                        fields.push(ShapeField {
913                            name: key.clone(),
914                            type_expr: val_type,
915                            optional: false,
916                        });
917                    } else {
918                        all_string_keys = false;
919                        break;
920                    }
921                }
922                if all_string_keys && !fields.is_empty() {
923                    Some(TypeExpr::Shape(fields))
924                } else {
925                    Some(TypeExpr::Named("dict".into()))
926                }
927            }
928            Node::Closure { params, body } => {
929                // If all params are typed and we can infer a return type, produce FnType
930                let all_typed = params.iter().all(|p| p.type_expr.is_some());
931                if all_typed && !params.is_empty() {
932                    let param_types: Vec<TypeExpr> =
933                        params.iter().filter_map(|p| p.type_expr.clone()).collect();
934                    // Try to infer return type from last expression in body
935                    let ret = body.last().and_then(|last| self.infer_type(last, scope));
936                    if let Some(ret_type) = ret {
937                        return Some(TypeExpr::FnType {
938                            params: param_types,
939                            return_type: Box::new(ret_type),
940                        });
941                    }
942                }
943                Some(TypeExpr::Named("closure".into()))
944            }
945
946            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
947
948            Node::FunctionCall { name, .. } => {
949                // Check user-defined function return types
950                if let Some(sig) = scope.get_fn(name) {
951                    return sig.return_type.clone();
952                }
953                // Check builtin return types
954                builtin_return_type(name)
955            }
956
957            Node::BinaryOp { op, left, right } => {
958                let lt = self.infer_type(left, scope);
959                let rt = self.infer_type(right, scope);
960                infer_binary_op_type(op, &lt, &rt)
961            }
962
963            Node::UnaryOp { op, operand } => {
964                let t = self.infer_type(operand, scope);
965                match op.as_str() {
966                    "!" => Some(TypeExpr::Named("bool".into())),
967                    "-" => t, // negation preserves type
968                    _ => None,
969                }
970            }
971
972            Node::Ternary {
973                true_expr,
974                false_expr,
975                ..
976            } => {
977                let tt = self.infer_type(true_expr, scope);
978                let ft = self.infer_type(false_expr, scope);
979                match (&tt, &ft) {
980                    (Some(a), Some(b)) if a == b => tt,
981                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
982                    (Some(_), None) => tt,
983                    (None, Some(_)) => ft,
984                    (None, None) => None,
985                }
986            }
987
988            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
989
990            Node::PropertyAccess { object, property } => {
991                // EnumName.Variant → infer as the enum type
992                if let Node::Identifier(name) = &object.node {
993                    if scope.get_enum(name).is_some() {
994                        return Some(TypeExpr::Named(name.clone()));
995                    }
996                }
997                // .variant on an enum value → string
998                if property == "variant" {
999                    let obj_type = self.infer_type(object, scope);
1000                    if let Some(TypeExpr::Named(name)) = &obj_type {
1001                        if scope.get_enum(name).is_some() {
1002                            return Some(TypeExpr::Named("string".into()));
1003                        }
1004                    }
1005                }
1006                // Shape field access: obj.field → field type
1007                let obj_type = self.infer_type(object, scope);
1008                if let Some(TypeExpr::Shape(fields)) = &obj_type {
1009                    if let Some(field) = fields.iter().find(|f| f.name == *property) {
1010                        return Some(field.type_expr.clone());
1011                    }
1012                }
1013                None
1014            }
1015
1016            Node::SubscriptAccess { object, index } => {
1017                let obj_type = self.infer_type(object, scope);
1018                match &obj_type {
1019                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1020                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1021                    Some(TypeExpr::Shape(fields)) => {
1022                        // If index is a string literal, look up the field type
1023                        if let Node::StringLiteral(key) = &index.node {
1024                            fields
1025                                .iter()
1026                                .find(|f| &f.name == key)
1027                                .map(|f| f.type_expr.clone())
1028                        } else {
1029                            None
1030                        }
1031                    }
1032                    Some(TypeExpr::Named(n)) if n == "list" => None,
1033                    Some(TypeExpr::Named(n)) if n == "dict" => None,
1034                    Some(TypeExpr::Named(n)) if n == "string" => {
1035                        Some(TypeExpr::Named("string".into()))
1036                    }
1037                    _ => None,
1038                }
1039            }
1040            Node::SliceAccess { object, .. } => {
1041                // Slicing a list returns the same list type; slicing a string returns string
1042                let obj_type = self.infer_type(object, scope);
1043                match &obj_type {
1044                    Some(TypeExpr::List(_)) => obj_type,
1045                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1046                    Some(TypeExpr::Named(n)) if n == "string" => {
1047                        Some(TypeExpr::Named("string".into()))
1048                    }
1049                    _ => None,
1050                }
1051            }
1052            Node::MethodCall { object, method, .. }
1053            | Node::OptionalMethodCall { object, method, .. } => {
1054                let obj_type = self.infer_type(object, scope);
1055                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1056                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1057                match method.as_str() {
1058                    // Shared: bool-returning methods
1059                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1060                        Some(TypeExpr::Named("bool".into()))
1061                    }
1062                    // Shared: int-returning methods
1063                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1064                    // String methods
1065                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1066                    | "pad_left" | "pad_right" | "repeat" | "join" => {
1067                        Some(TypeExpr::Named("string".into()))
1068                    }
1069                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
1070                    // filter returns dict for dicts, list for lists
1071                    "filter" => {
1072                        if is_dict {
1073                            Some(TypeExpr::Named("dict".into()))
1074                        } else {
1075                            Some(TypeExpr::Named("list".into()))
1076                        }
1077                    }
1078                    // List methods
1079                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1080                    "reduce" | "find" | "first" | "last" => None,
1081                    // Dict methods
1082                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1083                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1084                    // Conversions
1085                    "to_string" => Some(TypeExpr::Named("string".into())),
1086                    "to_int" => Some(TypeExpr::Named("int".into())),
1087                    "to_float" => Some(TypeExpr::Named("float".into())),
1088                    _ => None,
1089                }
1090            }
1091
1092            _ => None,
1093        }
1094    }
1095
1096    /// Check if two types are compatible (actual can be assigned to expected).
1097    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1098        // Generic type parameters match anything.
1099        if let TypeExpr::Named(name) = expected {
1100            if scope.is_generic_type_param(name) {
1101                return true;
1102            }
1103        }
1104        if let TypeExpr::Named(name) = actual {
1105            if scope.is_generic_type_param(name) {
1106                return true;
1107            }
1108        }
1109        let expected = self.resolve_alias(expected, scope);
1110        let actual = self.resolve_alias(actual, scope);
1111
1112        match (&expected, &actual) {
1113            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1114            (TypeExpr::Union(members), actual_type) => members
1115                .iter()
1116                .any(|m| self.types_compatible(m, actual_type, scope)),
1117            (expected_type, TypeExpr::Union(members)) => members
1118                .iter()
1119                .all(|m| self.types_compatible(expected_type, m, scope)),
1120            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1121            (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1122            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1123                if expected_field.optional {
1124                    return true;
1125                }
1126                af.iter().any(|actual_field| {
1127                    actual_field.name == expected_field.name
1128                        && self.types_compatible(
1129                            &expected_field.type_expr,
1130                            &actual_field.type_expr,
1131                            scope,
1132                        )
1133                })
1134            }),
1135            // dict<K, V> expected, Shape actual → all field values must match V
1136            (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1137                let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1138                keys_ok
1139                    && af
1140                        .iter()
1141                        .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1142            }
1143            // Shape expected, dict<K, V> actual → gradual: allow since dict may have the fields
1144            (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1145            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1146                self.types_compatible(expected_inner, actual_inner, scope)
1147            }
1148            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1149            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1150            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1151                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1152            }
1153            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1154            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1155            // FnType compatibility: params match positionally and return types match
1156            (
1157                TypeExpr::FnType {
1158                    params: ep,
1159                    return_type: er,
1160                },
1161                TypeExpr::FnType {
1162                    params: ap,
1163                    return_type: ar,
1164                },
1165            ) => {
1166                ep.len() == ap.len()
1167                    && ep
1168                        .iter()
1169                        .zip(ap.iter())
1170                        .all(|(e, a)| self.types_compatible(e, a, scope))
1171                    && self.types_compatible(er, ar, scope)
1172            }
1173            // FnType is compatible with Named("closure") for backward compat
1174            (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1175            (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1176            _ => false,
1177        }
1178    }
1179
1180    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1181        if let TypeExpr::Named(name) = ty {
1182            if let Some(resolved) = scope.resolve_type(name) {
1183                return resolved.clone();
1184            }
1185        }
1186        ty.clone()
1187    }
1188
1189    fn error_at(&mut self, message: String, span: Span) {
1190        self.diagnostics.push(TypeDiagnostic {
1191            message,
1192            severity: DiagnosticSeverity::Error,
1193            span: Some(span),
1194        });
1195    }
1196
1197    fn warning_at(&mut self, message: String, span: Span) {
1198        self.diagnostics.push(TypeDiagnostic {
1199            message,
1200            severity: DiagnosticSeverity::Warning,
1201            span: Some(span),
1202        });
1203    }
1204}
1205
1206impl Default for TypeChecker {
1207    fn default() -> Self {
1208        Self::new()
1209    }
1210}
1211
1212/// Infer the result type of a binary operation.
1213fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1214    match op {
1215        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1216        "+" => match (left, right) {
1217            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1218                match (l.as_str(), r.as_str()) {
1219                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1220                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1221                    ("string", _) => Some(TypeExpr::Named("string".into())),
1222                    ("list", "list") => Some(TypeExpr::Named("list".into())),
1223                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1224                    _ => Some(TypeExpr::Named("string".into())),
1225                }
1226            }
1227            _ => None,
1228        },
1229        "-" | "*" | "/" | "%" => match (left, right) {
1230            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1231                match (l.as_str(), r.as_str()) {
1232                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1233                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1234                    _ => None,
1235                }
1236            }
1237            _ => None,
1238        },
1239        "??" => match (left, right) {
1240            (Some(TypeExpr::Union(members)), _) => {
1241                let non_nil: Vec<_> = members
1242                    .iter()
1243                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1244                    .cloned()
1245                    .collect();
1246                if non_nil.len() == 1 {
1247                    Some(non_nil[0].clone())
1248                } else if non_nil.is_empty() {
1249                    right.clone()
1250                } else {
1251                    Some(TypeExpr::Union(non_nil))
1252                }
1253            }
1254            _ => right.clone(),
1255        },
1256        "|>" => None,
1257        _ => None,
1258    }
1259}
1260
1261/// Format a type expression for display in error messages.
1262/// Produce a detail string describing why a Shape type is incompatible with
1263/// another Shape type — e.g. "missing field 'age' (int)" or "field 'name'
1264/// has type int, expected string".  Returns `None` if both types are not shapes.
1265pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1266    if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1267        let mut details = Vec::new();
1268        for field in ef {
1269            if field.optional {
1270                continue;
1271            }
1272            match af.iter().find(|f| f.name == field.name) {
1273                None => details.push(format!(
1274                    "missing field '{}' ({})",
1275                    field.name,
1276                    format_type(&field.type_expr)
1277                )),
1278                Some(actual_field) => {
1279                    let e_str = format_type(&field.type_expr);
1280                    let a_str = format_type(&actual_field.type_expr);
1281                    if e_str != a_str {
1282                        details.push(format!(
1283                            "field '{}' has type {}, expected {}",
1284                            field.name, a_str, e_str
1285                        ));
1286                    }
1287                }
1288            }
1289        }
1290        if details.is_empty() {
1291            None
1292        } else {
1293            Some(details.join("; "))
1294        }
1295    } else {
1296        None
1297    }
1298}
1299
1300pub fn format_type(ty: &TypeExpr) -> String {
1301    match ty {
1302        TypeExpr::Named(n) => n.clone(),
1303        TypeExpr::Union(types) => types
1304            .iter()
1305            .map(format_type)
1306            .collect::<Vec<_>>()
1307            .join(" | "),
1308        TypeExpr::Shape(fields) => {
1309            let inner: Vec<String> = fields
1310                .iter()
1311                .map(|f| {
1312                    let opt = if f.optional { "?" } else { "" };
1313                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1314                })
1315                .collect();
1316            format!("{{{}}}", inner.join(", "))
1317        }
1318        TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1319        TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1320        TypeExpr::FnType {
1321            params,
1322            return_type,
1323        } => {
1324            let params_str = params
1325                .iter()
1326                .map(format_type)
1327                .collect::<Vec<_>>()
1328                .join(", ");
1329            format!("fn({}) -> {}", params_str, format_type(return_type))
1330        }
1331    }
1332}
1333
1334#[cfg(test)]
1335mod tests {
1336    use super::*;
1337    use crate::Parser;
1338    use harn_lexer::Lexer;
1339
1340    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
1341        let mut lexer = Lexer::new(source);
1342        let tokens = lexer.tokenize().unwrap();
1343        let mut parser = Parser::new(tokens);
1344        let program = parser.parse().unwrap();
1345        TypeChecker::new().check(&program)
1346    }
1347
1348    fn errors(source: &str) -> Vec<String> {
1349        check_source(source)
1350            .into_iter()
1351            .filter(|d| d.severity == DiagnosticSeverity::Error)
1352            .map(|d| d.message)
1353            .collect()
1354    }
1355
1356    #[test]
1357    fn test_no_errors_for_untyped_code() {
1358        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
1359        assert!(errs.is_empty());
1360    }
1361
1362    #[test]
1363    fn test_correct_typed_let() {
1364        let errs = errors("pipeline t(task) { let x: int = 42 }");
1365        assert!(errs.is_empty());
1366    }
1367
1368    #[test]
1369    fn test_type_mismatch_let() {
1370        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
1371        assert_eq!(errs.len(), 1);
1372        assert!(errs[0].contains("Type mismatch"));
1373        assert!(errs[0].contains("int"));
1374        assert!(errs[0].contains("string"));
1375    }
1376
1377    #[test]
1378    fn test_correct_typed_fn() {
1379        let errs = errors(
1380            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
1381        );
1382        assert!(errs.is_empty());
1383    }
1384
1385    #[test]
1386    fn test_fn_arg_type_mismatch() {
1387        let errs = errors(
1388            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
1389add("hello", 2) }"#,
1390        );
1391        assert_eq!(errs.len(), 1);
1392        assert!(errs[0].contains("Argument 1"));
1393        assert!(errs[0].contains("expected int"));
1394    }
1395
1396    #[test]
1397    fn test_return_type_mismatch() {
1398        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
1399        assert_eq!(errs.len(), 1);
1400        assert!(errs[0].contains("Return type mismatch"));
1401    }
1402
1403    #[test]
1404    fn test_union_type_compatible() {
1405        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
1406        assert!(errs.is_empty());
1407    }
1408
1409    #[test]
1410    fn test_union_type_mismatch() {
1411        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
1412        assert_eq!(errs.len(), 1);
1413        assert!(errs[0].contains("Type mismatch"));
1414    }
1415
1416    #[test]
1417    fn test_type_inference_propagation() {
1418        let errs = errors(
1419            r#"pipeline t(task) {
1420  fn add(a: int, b: int) -> int { return a + b }
1421  let result: string = add(1, 2)
1422}"#,
1423        );
1424        assert_eq!(errs.len(), 1);
1425        assert!(errs[0].contains("Type mismatch"));
1426        assert!(errs[0].contains("string"));
1427        assert!(errs[0].contains("int"));
1428    }
1429
1430    #[test]
1431    fn test_builtin_return_type_inference() {
1432        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
1433        assert_eq!(errs.len(), 1);
1434        assert!(errs[0].contains("string"));
1435        assert!(errs[0].contains("int"));
1436    }
1437
1438    #[test]
1439    fn test_binary_op_type_inference() {
1440        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
1441        assert_eq!(errs.len(), 1);
1442    }
1443
1444    #[test]
1445    fn test_comparison_returns_bool() {
1446        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
1447        assert!(errs.is_empty());
1448    }
1449
1450    #[test]
1451    fn test_int_float_promotion() {
1452        let errs = errors("pipeline t(task) { let x: float = 42 }");
1453        assert!(errs.is_empty());
1454    }
1455
1456    #[test]
1457    fn test_untyped_code_no_errors() {
1458        let errs = errors(
1459            r#"pipeline t(task) {
1460  fn process(data) {
1461    let result = data + " processed"
1462    return result
1463  }
1464  log(process("hello"))
1465}"#,
1466        );
1467        assert!(errs.is_empty());
1468    }
1469
1470    #[test]
1471    fn test_type_alias() {
1472        let errs = errors(
1473            r#"pipeline t(task) {
1474  type Name = string
1475  let x: Name = "hello"
1476}"#,
1477        );
1478        assert!(errs.is_empty());
1479    }
1480
1481    #[test]
1482    fn test_type_alias_mismatch() {
1483        let errs = errors(
1484            r#"pipeline t(task) {
1485  type Name = string
1486  let x: Name = 42
1487}"#,
1488        );
1489        assert_eq!(errs.len(), 1);
1490    }
1491
1492    #[test]
1493    fn test_assignment_type_check() {
1494        let errs = errors(
1495            r#"pipeline t(task) {
1496  var x: int = 0
1497  x = "hello"
1498}"#,
1499        );
1500        assert_eq!(errs.len(), 1);
1501        assert!(errs[0].contains("cannot assign string"));
1502    }
1503
1504    #[test]
1505    fn test_covariance_int_to_float_in_fn() {
1506        let errs = errors(
1507            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
1508        );
1509        assert!(errs.is_empty());
1510    }
1511
1512    #[test]
1513    fn test_covariance_return_type() {
1514        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
1515        assert!(errs.is_empty());
1516    }
1517
1518    #[test]
1519    fn test_no_contravariance_float_to_int() {
1520        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
1521        assert_eq!(errs.len(), 1);
1522    }
1523
1524    // --- Exhaustiveness checking tests ---
1525
1526    fn warnings(source: &str) -> Vec<String> {
1527        check_source(source)
1528            .into_iter()
1529            .filter(|d| d.severity == DiagnosticSeverity::Warning)
1530            .map(|d| d.message)
1531            .collect()
1532    }
1533
1534    #[test]
1535    fn test_exhaustive_match_no_warning() {
1536        let warns = warnings(
1537            r#"pipeline t(task) {
1538  enum Color { Red, Green, Blue }
1539  let c = Color.Red
1540  match c.variant {
1541    "Red" -> { log("r") }
1542    "Green" -> { log("g") }
1543    "Blue" -> { log("b") }
1544  }
1545}"#,
1546        );
1547        let exhaustive_warns: Vec<_> = warns
1548            .iter()
1549            .filter(|w| w.contains("Non-exhaustive"))
1550            .collect();
1551        assert!(exhaustive_warns.is_empty());
1552    }
1553
1554    #[test]
1555    fn test_non_exhaustive_match_warning() {
1556        let warns = warnings(
1557            r#"pipeline t(task) {
1558  enum Color { Red, Green, Blue }
1559  let c = Color.Red
1560  match c.variant {
1561    "Red" -> { log("r") }
1562    "Green" -> { log("g") }
1563  }
1564}"#,
1565        );
1566        let exhaustive_warns: Vec<_> = warns
1567            .iter()
1568            .filter(|w| w.contains("Non-exhaustive"))
1569            .collect();
1570        assert_eq!(exhaustive_warns.len(), 1);
1571        assert!(exhaustive_warns[0].contains("Blue"));
1572    }
1573
1574    #[test]
1575    fn test_non_exhaustive_multiple_missing() {
1576        let warns = warnings(
1577            r#"pipeline t(task) {
1578  enum Status { Active, Inactive, Pending }
1579  let s = Status.Active
1580  match s.variant {
1581    "Active" -> { log("a") }
1582  }
1583}"#,
1584        );
1585        let exhaustive_warns: Vec<_> = warns
1586            .iter()
1587            .filter(|w| w.contains("Non-exhaustive"))
1588            .collect();
1589        assert_eq!(exhaustive_warns.len(), 1);
1590        assert!(exhaustive_warns[0].contains("Inactive"));
1591        assert!(exhaustive_warns[0].contains("Pending"));
1592    }
1593
1594    #[test]
1595    fn test_enum_construct_type_inference() {
1596        let errs = errors(
1597            r#"pipeline t(task) {
1598  enum Color { Red, Green, Blue }
1599  let c: Color = Color.Red
1600}"#,
1601        );
1602        assert!(errs.is_empty());
1603    }
1604
1605    // --- Type narrowing tests ---
1606
1607    #[test]
1608    fn test_nil_coalescing_strips_nil() {
1609        // After ??, nil should be stripped from the type
1610        let errs = errors(
1611            r#"pipeline t(task) {
1612  let x: string | nil = nil
1613  let y: string = x ?? "default"
1614}"#,
1615        );
1616        assert!(errs.is_empty());
1617    }
1618
1619    #[test]
1620    fn test_shape_mismatch_detail_missing_field() {
1621        let errs = errors(
1622            r#"pipeline t(task) {
1623  let x: {name: string, age: int} = {name: "hello"}
1624}"#,
1625        );
1626        assert_eq!(errs.len(), 1);
1627        assert!(
1628            errs[0].contains("missing field 'age'"),
1629            "expected detail about missing field, got: {}",
1630            errs[0]
1631        );
1632    }
1633
1634    #[test]
1635    fn test_shape_mismatch_detail_wrong_type() {
1636        let errs = errors(
1637            r#"pipeline t(task) {
1638  let x: {name: string, age: int} = {name: 42, age: 10}
1639}"#,
1640        );
1641        assert_eq!(errs.len(), 1);
1642        assert!(
1643            errs[0].contains("field 'name' has type int, expected string"),
1644            "expected detail about wrong type, got: {}",
1645            errs[0]
1646        );
1647    }
1648
1649    // --- Match pattern type validation tests ---
1650
1651    #[test]
1652    fn test_match_pattern_string_against_int() {
1653        let warns = warnings(
1654            r#"pipeline t(task) {
1655  let x: int = 42
1656  match x {
1657    "hello" -> { log("bad") }
1658    42 -> { log("ok") }
1659  }
1660}"#,
1661        );
1662        let pattern_warns: Vec<_> = warns
1663            .iter()
1664            .filter(|w| w.contains("Match pattern type mismatch"))
1665            .collect();
1666        assert_eq!(pattern_warns.len(), 1);
1667        assert!(pattern_warns[0].contains("matching int against string literal"));
1668    }
1669
1670    #[test]
1671    fn test_match_pattern_int_against_string() {
1672        let warns = warnings(
1673            r#"pipeline t(task) {
1674  let x: string = "hello"
1675  match x {
1676    42 -> { log("bad") }
1677    "hello" -> { log("ok") }
1678  }
1679}"#,
1680        );
1681        let pattern_warns: Vec<_> = warns
1682            .iter()
1683            .filter(|w| w.contains("Match pattern type mismatch"))
1684            .collect();
1685        assert_eq!(pattern_warns.len(), 1);
1686        assert!(pattern_warns[0].contains("matching string against int literal"));
1687    }
1688
1689    #[test]
1690    fn test_match_pattern_bool_against_int() {
1691        let warns = warnings(
1692            r#"pipeline t(task) {
1693  let x: int = 42
1694  match x {
1695    true -> { log("bad") }
1696    42 -> { log("ok") }
1697  }
1698}"#,
1699        );
1700        let pattern_warns: Vec<_> = warns
1701            .iter()
1702            .filter(|w| w.contains("Match pattern type mismatch"))
1703            .collect();
1704        assert_eq!(pattern_warns.len(), 1);
1705        assert!(pattern_warns[0].contains("matching int against bool literal"));
1706    }
1707
1708    #[test]
1709    fn test_match_pattern_float_against_string() {
1710        let warns = warnings(
1711            r#"pipeline t(task) {
1712  let x: string = "hello"
1713  match x {
1714    3.14 -> { log("bad") }
1715    "hello" -> { log("ok") }
1716  }
1717}"#,
1718        );
1719        let pattern_warns: Vec<_> = warns
1720            .iter()
1721            .filter(|w| w.contains("Match pattern type mismatch"))
1722            .collect();
1723        assert_eq!(pattern_warns.len(), 1);
1724        assert!(pattern_warns[0].contains("matching string against float literal"));
1725    }
1726
1727    #[test]
1728    fn test_match_pattern_int_against_float_ok() {
1729        // int and float are compatible for match patterns
1730        let warns = warnings(
1731            r#"pipeline t(task) {
1732  let x: float = 3.14
1733  match x {
1734    42 -> { log("ok") }
1735    _ -> { log("default") }
1736  }
1737}"#,
1738        );
1739        let pattern_warns: Vec<_> = warns
1740            .iter()
1741            .filter(|w| w.contains("Match pattern type mismatch"))
1742            .collect();
1743        assert!(pattern_warns.is_empty());
1744    }
1745
1746    #[test]
1747    fn test_match_pattern_float_against_int_ok() {
1748        // float and int are compatible for match patterns
1749        let warns = warnings(
1750            r#"pipeline t(task) {
1751  let x: int = 42
1752  match x {
1753    3.14 -> { log("close") }
1754    _ -> { log("default") }
1755  }
1756}"#,
1757        );
1758        let pattern_warns: Vec<_> = warns
1759            .iter()
1760            .filter(|w| w.contains("Match pattern type mismatch"))
1761            .collect();
1762        assert!(pattern_warns.is_empty());
1763    }
1764
1765    #[test]
1766    fn test_match_pattern_correct_types_no_warning() {
1767        let warns = warnings(
1768            r#"pipeline t(task) {
1769  let x: int = 42
1770  match x {
1771    1 -> { log("one") }
1772    2 -> { log("two") }
1773    _ -> { log("other") }
1774  }
1775}"#,
1776        );
1777        let pattern_warns: Vec<_> = warns
1778            .iter()
1779            .filter(|w| w.contains("Match pattern type mismatch"))
1780            .collect();
1781        assert!(pattern_warns.is_empty());
1782    }
1783
1784    #[test]
1785    fn test_match_pattern_wildcard_no_warning() {
1786        let warns = warnings(
1787            r#"pipeline t(task) {
1788  let x: int = 42
1789  match x {
1790    _ -> { log("catch all") }
1791  }
1792}"#,
1793        );
1794        let pattern_warns: Vec<_> = warns
1795            .iter()
1796            .filter(|w| w.contains("Match pattern type mismatch"))
1797            .collect();
1798        assert!(pattern_warns.is_empty());
1799    }
1800
1801    #[test]
1802    fn test_match_pattern_untyped_no_warning() {
1803        // When value has no known type, no warning should be emitted
1804        let warns = warnings(
1805            r#"pipeline t(task) {
1806  let x = some_unknown_fn()
1807  match x {
1808    "hello" -> { log("string") }
1809    42 -> { log("int") }
1810  }
1811}"#,
1812        );
1813        let pattern_warns: Vec<_> = warns
1814            .iter()
1815            .filter(|w| w.contains("Match pattern type mismatch"))
1816            .collect();
1817        assert!(pattern_warns.is_empty());
1818    }
1819}