Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6/// A diagnostic produced by the type checker.
7#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9    pub message: String,
10    pub severity: DiagnosticSeverity,
11    pub span: Option<Span>,
12}
13
/// How serious a type-checker diagnostic is.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticSeverity {
    /// Reported as an error.
    Error,
    /// Reported as a warning (e.g. an operator that may not suit its operand types).
    Warning,
}
19
20/// Inferred type of an expression. None means unknown/untyped (gradual typing).
21type InferredType = Option<TypeExpr>;
22
23/// Scope for tracking variable types.
24#[derive(Debug, Clone)]
25struct TypeScope {
26    /// Variable name → inferred type.
27    vars: BTreeMap<String, InferredType>,
28    /// Function name → (param types, return type).
29    functions: BTreeMap<String, FnSignature>,
30    /// Named type aliases.
31    type_aliases: BTreeMap<String, TypeExpr>,
32    /// Enum declarations: name → variant names.
33    enums: BTreeMap<String, Vec<String>>,
34    /// Interface declarations: name → method signatures.
35    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36    /// Struct declarations: name → field types.
37    structs: BTreeMap<String, Vec<(String, InferredType)>>,
38    /// Generic type parameter names in scope (treated as compatible with any type).
39    generic_type_params: std::collections::BTreeSet<String>,
40    parent: Option<Box<TypeScope>>,
41}
42
43#[derive(Debug, Clone)]
44struct FnSignature {
45    params: Vec<(String, InferredType)>,
46    return_type: InferredType,
47    /// Generic type parameter names declared on the function.
48    type_param_names: Vec<String>,
49    /// Number of required parameters (those without defaults).
50    required_params: usize,
51}
52
53impl TypeScope {
54    fn new() -> Self {
55        Self {
56            vars: BTreeMap::new(),
57            functions: BTreeMap::new(),
58            type_aliases: BTreeMap::new(),
59            enums: BTreeMap::new(),
60            interfaces: BTreeMap::new(),
61            structs: BTreeMap::new(),
62            generic_type_params: std::collections::BTreeSet::new(),
63            parent: None,
64        }
65    }
66
67    fn child(&self) -> Self {
68        Self {
69            vars: BTreeMap::new(),
70            functions: BTreeMap::new(),
71            type_aliases: BTreeMap::new(),
72            enums: BTreeMap::new(),
73            interfaces: BTreeMap::new(),
74            structs: BTreeMap::new(),
75            generic_type_params: std::collections::BTreeSet::new(),
76            parent: Some(Box::new(self.clone())),
77        }
78    }
79
80    fn get_var(&self, name: &str) -> Option<&InferredType> {
81        self.vars
82            .get(name)
83            .or_else(|| self.parent.as_ref()?.get_var(name))
84    }
85
86    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
87        self.functions
88            .get(name)
89            .or_else(|| self.parent.as_ref()?.get_fn(name))
90    }
91
92    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
93        self.type_aliases
94            .get(name)
95            .or_else(|| self.parent.as_ref()?.resolve_type(name))
96    }
97
98    fn is_generic_type_param(&self, name: &str) -> bool {
99        self.generic_type_params.contains(name)
100            || self
101                .parent
102                .as_ref()
103                .is_some_and(|p| p.is_generic_type_param(name))
104    }
105
106    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
107        self.enums
108            .get(name)
109            .or_else(|| self.parent.as_ref()?.get_enum(name))
110    }
111
112    #[allow(dead_code)]
113    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
114        self.interfaces
115            .get(name)
116            .or_else(|| self.parent.as_ref()?.get_interface(name))
117    }
118
119    fn define_var(&mut self, name: &str, ty: InferredType) {
120        self.vars.insert(name.to_string(), ty);
121    }
122
123    fn define_fn(&mut self, name: &str, sig: FnSignature) {
124        self.functions.insert(name.to_string(), sig);
125    }
126}
127
128/// Known return types for builtin functions.
129fn builtin_return_type(name: &str) -> InferredType {
130    match name {
131        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
132        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
133            Some(TypeExpr::Named("nil".into()))
134        }
135        "type_of"
136        | "to_string"
137        | "json_stringify"
138        | "read_file"
139        | "http_get"
140        | "http_post"
141        | "llm_call"
142        | "regex_replace"
143        | "path_join"
144        | "temp_dir"
145        | "date_format"
146        | "format"
147        | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
148        "to_int" | "timer_end" | "elapsed" => Some(TypeExpr::Named("int".into())),
149        "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
150        "file_exists" | "json_validate" => Some(TypeExpr::Named("bool".into())),
151        "list_dir" | "mcp_list_tools" | "mcp_list_resources" | "mcp_list_prompts" => {
152            Some(TypeExpr::Named("list".into()))
153        }
154        "stat" | "exec" | "shell" | "date_now" | "agent_loop" | "llm_info" | "llm_usage"
155        | "timer_start" | "metadata_get" | "mcp_server_info" | "mcp_get_prompt" => {
156            Some(TypeExpr::Named("dict".into()))
157        }
158        "metadata_set"
159        | "metadata_save"
160        | "metadata_refresh_hashes"
161        | "invalidate_facts"
162        | "log_json"
163        | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
164        "env" | "regex_match" => Some(TypeExpr::Union(vec![
165            TypeExpr::Named("string".into()),
166            TypeExpr::Named("nil".into()),
167        ])),
168        "json_parse" | "json_extract" => None, // could be any type
169        _ => None,
170    }
171}
172
/// Check if a name is a known builtin.
///
/// Every builtin that `builtin_return_type` assigns a result type to is listed
/// here too, so the two tables agree on which names are builtins.
fn is_builtin(name: &str) -> bool {
    matches!(
        name,
        "log"
            | "print"
            | "println"
            | "type_of"
            | "to_string"
            | "to_int"
            | "to_float"
            | "json_stringify"
            | "json_parse"
            | "env"
            | "timestamp"
            | "sleep"
            | "read_file"
            | "write_file"
            | "exit"
            | "regex_match"
            | "regex_replace"
            | "http_get"
            | "http_post"
            | "llm_call"
            | "agent_loop"
            | "await"
            | "cancel"
            | "file_exists"
            | "delete_file"
            | "list_dir"
            | "mkdir"
            | "path_join"
            | "copy_file"
            | "append_file"
            | "temp_dir"
            | "stat"
            | "exec"
            | "shell"
            | "date_now"
            | "date_format"
            | "date_parse"
            | "format"
            | "json_validate"
            | "json_extract"
            | "trim"
            | "lowercase"
            | "uppercase"
            | "split"
            | "starts_with"
            | "ends_with"
            | "contains"
            | "replace"
            | "join"
            | "len"
            | "substring"
            | "dirname"
            | "basename"
            | "extname"
            // Builtins with a known return type in `builtin_return_type` that were
            // previously missing from this list.
            | "compute_content_hash"
            | "timer_start"
            | "timer_end"
            | "elapsed"
            | "llm_info"
            | "llm_usage"
            | "log_json"
            | "invalidate_facts"
            | "metadata_get"
            | "metadata_set"
            | "metadata_save"
            | "metadata_refresh_hashes"
            | "mcp_list_tools"
            | "mcp_list_resources"
            | "mcp_list_prompts"
            | "mcp_server_info"
            | "mcp_get_prompt"
            | "mcp_disconnect"
    )
}
233
234/// The static type checker.
235pub struct TypeChecker {
236    diagnostics: Vec<TypeDiagnostic>,
237    scope: TypeScope,
238}
239
240impl TypeChecker {
241    pub fn new() -> Self {
242        Self {
243            diagnostics: Vec::new(),
244            scope: TypeScope::new(),
245        }
246    }
247
248    /// Check a program and return diagnostics.
249    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
250        // First pass: register type and enum declarations into root scope
251        Self::register_declarations_into(&mut self.scope, program);
252
253        // Also scan pipeline bodies for declarations
254        for snode in program {
255            if let Node::Pipeline { body, .. } = &snode.node {
256                Self::register_declarations_into(&mut self.scope, body);
257            }
258        }
259
260        // Check each top-level node
261        for snode in program {
262            match &snode.node {
263                Node::Pipeline { params, body, .. } => {
264                    let mut child = self.scope.child();
265                    for p in params {
266                        child.define_var(p, None);
267                    }
268                    self.check_block(body, &mut child);
269                }
270                Node::FnDecl {
271                    name,
272                    type_params,
273                    params,
274                    return_type,
275                    body,
276                    ..
277                } => {
278                    let required_params =
279                        params.iter().filter(|p| p.default_value.is_none()).count();
280                    let sig = FnSignature {
281                        params: params
282                            .iter()
283                            .map(|p| (p.name.clone(), p.type_expr.clone()))
284                            .collect(),
285                        return_type: return_type.clone(),
286                        type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
287                        required_params,
288                    };
289                    self.scope.define_fn(name, sig);
290                    self.check_fn_body(type_params, params, return_type, body);
291                }
292                _ => {
293                    let mut scope = self.scope.clone();
294                    self.check_node(snode, &mut scope);
295                    // Merge any new definitions back into the top-level scope
296                    for (name, ty) in scope.vars {
297                        self.scope.vars.entry(name).or_insert(ty);
298                    }
299                }
300            }
301        }
302
303        self.diagnostics
304    }
305
306    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
307    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
308        for snode in nodes {
309            match &snode.node {
310                Node::TypeDecl { name, type_expr } => {
311                    scope.type_aliases.insert(name.clone(), type_expr.clone());
312                }
313                Node::EnumDecl { name, variants } => {
314                    let variant_names: Vec<String> =
315                        variants.iter().map(|v| v.name.clone()).collect();
316                    scope.enums.insert(name.clone(), variant_names);
317                }
318                Node::InterfaceDecl { name, methods } => {
319                    scope.interfaces.insert(name.clone(), methods.clone());
320                }
321                Node::StructDecl { name, fields } => {
322                    let field_types: Vec<(String, InferredType)> = fields
323                        .iter()
324                        .map(|f| (f.name.clone(), f.type_expr.clone()))
325                        .collect();
326                    scope.structs.insert(name.clone(), field_types);
327                }
328                _ => {}
329            }
330        }
331    }
332
333    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
334        for stmt in stmts {
335            self.check_node(stmt, scope);
336        }
337    }
338
339    /// Define variables from a destructuring pattern in the given scope (as unknown type).
340    fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
341        match pattern {
342            BindingPattern::Identifier(name) => {
343                scope.define_var(name, None);
344            }
345            BindingPattern::Dict(fields) => {
346                for field in fields {
347                    let name = field.alias.as_deref().unwrap_or(&field.key);
348                    scope.define_var(name, None);
349                }
350            }
351            BindingPattern::List(elements) => {
352                for elem in elements {
353                    scope.define_var(&elem.name, None);
354                }
355            }
356        }
357    }
358
359    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
360        let span = snode.span;
361        match &snode.node {
362            Node::LetBinding {
363                pattern,
364                type_ann,
365                value,
366            } => {
367                let inferred = self.infer_type(value, scope);
368                if let BindingPattern::Identifier(name) = pattern {
369                    if let Some(expected) = type_ann {
370                        if let Some(actual) = &inferred {
371                            if !self.types_compatible(expected, actual, scope) {
372                                let mut msg = format!(
373                                    "Type mismatch: '{}' declared as {}, but assigned {}",
374                                    name,
375                                    format_type(expected),
376                                    format_type(actual)
377                                );
378                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
379                                    msg.push_str(&format!(" ({})", detail));
380                                }
381                                self.error_at(msg, span);
382                            }
383                        }
384                    }
385                    let ty = type_ann.clone().or(inferred);
386                    scope.define_var(name, ty);
387                } else {
388                    Self::define_pattern_vars(pattern, scope);
389                }
390            }
391
392            Node::VarBinding {
393                pattern,
394                type_ann,
395                value,
396            } => {
397                let inferred = self.infer_type(value, scope);
398                if let BindingPattern::Identifier(name) = pattern {
399                    if let Some(expected) = type_ann {
400                        if let Some(actual) = &inferred {
401                            if !self.types_compatible(expected, actual, scope) {
402                                let mut msg = format!(
403                                    "Type mismatch: '{}' declared as {}, but assigned {}",
404                                    name,
405                                    format_type(expected),
406                                    format_type(actual)
407                                );
408                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
409                                    msg.push_str(&format!(" ({})", detail));
410                                }
411                                self.error_at(msg, span);
412                            }
413                        }
414                    }
415                    let ty = type_ann.clone().or(inferred);
416                    scope.define_var(name, ty);
417                } else {
418                    Self::define_pattern_vars(pattern, scope);
419                }
420            }
421
422            Node::FnDecl {
423                name,
424                type_params,
425                params,
426                return_type,
427                body,
428                ..
429            } => {
430                let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
431                let sig = FnSignature {
432                    params: params
433                        .iter()
434                        .map(|p| (p.name.clone(), p.type_expr.clone()))
435                        .collect(),
436                    return_type: return_type.clone(),
437                    type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
438                    required_params,
439                };
440                scope.define_fn(name, sig.clone());
441                scope.define_var(name, None);
442                self.check_fn_body(type_params, params, return_type, body);
443            }
444
445            Node::FunctionCall { name, args } => {
446                self.check_call(name, args, scope, span);
447            }
448
449            Node::IfElse {
450                condition,
451                then_body,
452                else_body,
453            } => {
454                self.check_node(condition, scope);
455                let mut then_scope = scope.child();
456                self.check_block(then_body, &mut then_scope);
457                if let Some(else_body) = else_body {
458                    let mut else_scope = scope.child();
459                    self.check_block(else_body, &mut else_scope);
460                }
461            }
462
463            Node::ForIn {
464                pattern,
465                iterable,
466                body,
467            } => {
468                self.check_node(iterable, scope);
469                let mut loop_scope = scope.child();
470                if let BindingPattern::Identifier(variable) = pattern {
471                    // Infer loop variable type from iterable
472                    let elem_type = match self.infer_type(iterable, scope) {
473                        Some(TypeExpr::List(inner)) => Some(*inner),
474                        Some(TypeExpr::Named(n)) if n == "string" => {
475                            Some(TypeExpr::Named("string".into()))
476                        }
477                        _ => None,
478                    };
479                    loop_scope.define_var(variable, elem_type);
480                } else {
481                    Self::define_pattern_vars(pattern, &mut loop_scope);
482                }
483                self.check_block(body, &mut loop_scope);
484            }
485
486            Node::WhileLoop { condition, body } => {
487                self.check_node(condition, scope);
488                let mut loop_scope = scope.child();
489                self.check_block(body, &mut loop_scope);
490            }
491
492            Node::TryCatch {
493                body,
494                error_var,
495                catch_body,
496                finally_body,
497                ..
498            } => {
499                let mut try_scope = scope.child();
500                self.check_block(body, &mut try_scope);
501                let mut catch_scope = scope.child();
502                if let Some(var) = error_var {
503                    catch_scope.define_var(var, None);
504                }
505                self.check_block(catch_body, &mut catch_scope);
506                if let Some(fb) = finally_body {
507                    let mut finally_scope = scope.child();
508                    self.check_block(fb, &mut finally_scope);
509                }
510            }
511
512            Node::ReturnStmt {
513                value: Some(val), ..
514            } => {
515                self.check_node(val, scope);
516            }
517
518            Node::Assignment {
519                target, value, op, ..
520            } => {
521                self.check_node(value, scope);
522                if let Node::Identifier(name) = &target.node {
523                    if let Some(Some(var_type)) = scope.get_var(name) {
524                        let value_type = self.infer_type(value, scope);
525                        let assigned = if let Some(op) = op {
526                            let var_inferred = scope.get_var(name).cloned().flatten();
527                            infer_binary_op_type(op, &var_inferred, &value_type)
528                        } else {
529                            value_type
530                        };
531                        if let Some(actual) = &assigned {
532                            if !self.types_compatible(var_type, actual, scope) {
533                                self.error_at(
534                                    format!(
535                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
536                                        format_type(actual),
537                                        name,
538                                        format_type(var_type)
539                                    ),
540                                    span,
541                                );
542                            }
543                        }
544                    }
545                }
546            }
547
548            Node::TypeDecl { name, type_expr } => {
549                scope.type_aliases.insert(name.clone(), type_expr.clone());
550            }
551
552            Node::EnumDecl { name, variants } => {
553                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
554                scope.enums.insert(name.clone(), variant_names);
555            }
556
557            Node::StructDecl { name, fields } => {
558                let field_types: Vec<(String, InferredType)> = fields
559                    .iter()
560                    .map(|f| (f.name.clone(), f.type_expr.clone()))
561                    .collect();
562                scope.structs.insert(name.clone(), field_types);
563            }
564
565            Node::InterfaceDecl { name, methods } => {
566                scope.interfaces.insert(name.clone(), methods.clone());
567            }
568
569            Node::MatchExpr { value, arms } => {
570                self.check_node(value, scope);
571                let value_type = self.infer_type(value, scope);
572                for arm in arms {
573                    self.check_node(&arm.pattern, scope);
574                    // Check for incompatible literal pattern types
575                    if let Some(ref vt) = value_type {
576                        let value_type_name = format_type(vt);
577                        let mismatch = match &arm.pattern.node {
578                            Node::StringLiteral(_) => {
579                                !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
580                            }
581                            Node::IntLiteral(_) => {
582                                !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
583                                    && !self.types_compatible(
584                                        vt,
585                                        &TypeExpr::Named("float".into()),
586                                        scope,
587                                    )
588                            }
589                            Node::FloatLiteral(_) => {
590                                !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
591                                    && !self.types_compatible(
592                                        vt,
593                                        &TypeExpr::Named("int".into()),
594                                        scope,
595                                    )
596                            }
597                            Node::BoolLiteral(_) => {
598                                !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
599                            }
600                            _ => false,
601                        };
602                        if mismatch {
603                            let pattern_type = match &arm.pattern.node {
604                                Node::StringLiteral(_) => "string",
605                                Node::IntLiteral(_) => "int",
606                                Node::FloatLiteral(_) => "float",
607                                Node::BoolLiteral(_) => "bool",
608                                _ => unreachable!(),
609                            };
610                            self.warning_at(
611                                format!(
612                                    "Match pattern type mismatch: matching {} against {} literal",
613                                    value_type_name, pattern_type
614                                ),
615                                arm.pattern.span,
616                            );
617                        }
618                    }
619                    let mut arm_scope = scope.child();
620                    self.check_block(&arm.body, &mut arm_scope);
621                }
622                self.check_match_exhaustiveness(value, arms, scope, span);
623            }
624
625            // Recurse into nested expressions + validate binary op types
626            Node::BinaryOp { op, left, right } => {
627                self.check_node(left, scope);
628                self.check_node(right, scope);
629                // Validate operator/type compatibility
630                let lt = self.infer_type(left, scope);
631                let rt = self.infer_type(right, scope);
632                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
633                    match op.as_str() {
634                        "-" | "*" | "/" | "%" => {
635                            let numeric = ["int", "float"];
636                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
637                                self.warning_at(
638                                    format!(
639                                        "Operator '{op}' may not be valid for types {} and {}",
640                                        l, r
641                                    ),
642                                    span,
643                                );
644                            }
645                        }
646                        "+" => {
647                            // + is valid for int, float, string, list, dict
648                            let valid = ["int", "float", "string", "list", "dict"];
649                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
650                                self.warning_at(
651                                    format!(
652                                        "Operator '+' may not be valid for types {} and {}",
653                                        l, r
654                                    ),
655                                    span,
656                                );
657                            }
658                        }
659                        _ => {}
660                    }
661                }
662            }
663            Node::UnaryOp { operand, .. } => {
664                self.check_node(operand, scope);
665            }
666            Node::MethodCall { object, args, .. }
667            | Node::OptionalMethodCall { object, args, .. } => {
668                self.check_node(object, scope);
669                for arg in args {
670                    self.check_node(arg, scope);
671                }
672            }
673            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
674                self.check_node(object, scope);
675            }
676            Node::SubscriptAccess { object, index } => {
677                self.check_node(object, scope);
678                self.check_node(index, scope);
679            }
680            Node::SliceAccess { object, start, end } => {
681                self.check_node(object, scope);
682                if let Some(s) = start {
683                    self.check_node(s, scope);
684                }
685                if let Some(e) = end {
686                    self.check_node(e, scope);
687                }
688            }
689
690            // Terminals — nothing to check
691            _ => {}
692        }
693    }
694
695    fn check_fn_body(
696        &mut self,
697        type_params: &[TypeParam],
698        params: &[TypedParam],
699        return_type: &Option<TypeExpr>,
700        body: &[SNode],
701    ) {
702        let mut fn_scope = self.scope.child();
703        // Register generic type parameters so they are treated as compatible
704        // with any concrete type during type checking.
705        for tp in type_params {
706            fn_scope.generic_type_params.insert(tp.name.clone());
707        }
708        for param in params {
709            fn_scope.define_var(&param.name, param.type_expr.clone());
710            if let Some(default) = &param.default_value {
711                self.check_node(default, &mut fn_scope);
712            }
713        }
714        self.check_block(body, &mut fn_scope);
715
716        // Check return statements against declared return type
717        if let Some(ret_type) = return_type {
718            for stmt in body {
719                self.check_return_type(stmt, ret_type, &fn_scope);
720            }
721        }
722    }
723
724    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
725        let span = snode.span;
726        match &snode.node {
727            Node::ReturnStmt { value: Some(val) } => {
728                let inferred = self.infer_type(val, scope);
729                if let Some(actual) = &inferred {
730                    if !self.types_compatible(expected, actual, scope) {
731                        self.error_at(
732                            format!(
733                                "Return type mismatch: expected {}, got {}",
734                                format_type(expected),
735                                format_type(actual)
736                            ),
737                            span,
738                        );
739                    }
740                }
741            }
742            Node::IfElse {
743                then_body,
744                else_body,
745                ..
746            } => {
747                for stmt in then_body {
748                    self.check_return_type(stmt, expected, scope);
749                }
750                if let Some(else_body) = else_body {
751                    for stmt in else_body {
752                        self.check_return_type(stmt, expected, scope);
753                    }
754                }
755            }
756            _ => {}
757        }
758    }
759
760    /// Check if a match expression on an enum's `.variant` property covers all variants.
761    fn check_match_exhaustiveness(
762        &mut self,
763        value: &SNode,
764        arms: &[MatchArm],
765        scope: &TypeScope,
766        span: Span,
767    ) {
768        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
769        let enum_name = match &value.node {
770            Node::PropertyAccess { object, property } if property == "variant" => {
771                // Infer the type of the object
772                match self.infer_type(object, scope) {
773                    Some(TypeExpr::Named(name)) => {
774                        if scope.get_enum(&name).is_some() {
775                            Some(name)
776                        } else {
777                            None
778                        }
779                    }
780                    _ => None,
781                }
782            }
783            _ => {
784                // Direct match on an enum value: match <expr> { ... }
785                match self.infer_type(value, scope) {
786                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
787                    _ => None,
788                }
789            }
790        };
791
792        let Some(enum_name) = enum_name else {
793            return;
794        };
795        let Some(variants) = scope.get_enum(&enum_name) else {
796            return;
797        };
798
799        // Collect variant names covered by match arms
800        let mut covered: Vec<String> = Vec::new();
801        let mut has_wildcard = false;
802
803        for arm in arms {
804            match &arm.pattern.node {
805                // String literal pattern (matching on .variant): "VariantA"
806                Node::StringLiteral(s) => covered.push(s.clone()),
807                // Identifier pattern acts as a wildcard/catch-all
808                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
809                    has_wildcard = true;
810                }
811                // Direct enum construct pattern: EnumName.Variant
812                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
813                // PropertyAccess pattern: EnumName.Variant (no args)
814                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
815                _ => {
816                    // Unknown pattern shape — conservatively treat as wildcard
817                    has_wildcard = true;
818                }
819            }
820        }
821
822        if has_wildcard {
823            return;
824        }
825
826        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
827        if !missing.is_empty() {
828            let missing_str = missing
829                .iter()
830                .map(|s| format!("\"{}\"", s))
831                .collect::<Vec<_>>()
832                .join(", ");
833            self.warning_at(
834                format!(
835                    "Non-exhaustive match on enum {}: missing variants {}",
836                    enum_name, missing_str
837                ),
838                span,
839            );
840        }
841    }
842
843    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
844        // Check against known function signatures
845        if let Some(sig) = scope.get_fn(name).cloned() {
846            if !is_builtin(name)
847                && (args.len() < sig.required_params || args.len() > sig.params.len())
848            {
849                let expected = if sig.required_params == sig.params.len() {
850                    format!("{}", sig.params.len())
851                } else {
852                    format!("{}-{}", sig.required_params, sig.params.len())
853                };
854                self.warning_at(
855                    format!(
856                        "Function '{}' expects {} arguments, got {}",
857                        name,
858                        expected,
859                        args.len()
860                    ),
861                    span,
862                );
863            }
864            // Build a scope that includes the function's generic type params
865            // so they are treated as compatible with any concrete type.
866            let call_scope = if sig.type_param_names.is_empty() {
867                scope.clone()
868            } else {
869                let mut s = scope.child();
870                for tp_name in &sig.type_param_names {
871                    s.generic_type_params.insert(tp_name.clone());
872                }
873                s
874            };
875            for (i, (arg, (param_name, param_type))) in
876                args.iter().zip(sig.params.iter()).enumerate()
877            {
878                if let Some(expected) = param_type {
879                    let actual = self.infer_type(arg, scope);
880                    if let Some(actual) = &actual {
881                        if !self.types_compatible(expected, actual, &call_scope) {
882                            self.error_at(
883                                format!(
884                                    "Argument {} ('{}'): expected {}, got {}",
885                                    i + 1,
886                                    param_name,
887                                    format_type(expected),
888                                    format_type(actual)
889                                ),
890                                arg.span,
891                            );
892                        }
893                    }
894                }
895            }
896        }
897        // Check args recursively
898        for arg in args {
899            self.check_node(arg, scope);
900        }
901    }
902
903    /// Infer the type of an expression.
904    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
905        match &snode.node {
906            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
907            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
908            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
909                Some(TypeExpr::Named("string".into()))
910            }
911            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
912            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
913            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
914            Node::DictLiteral(entries) => {
915                // Infer shape type when all keys are string literals
916                let mut fields = Vec::new();
917                let mut all_string_keys = true;
918                for entry in entries {
919                    if let Node::StringLiteral(key) = &entry.key.node {
920                        let val_type = self
921                            .infer_type(&entry.value, scope)
922                            .unwrap_or(TypeExpr::Named("nil".into()));
923                        fields.push(ShapeField {
924                            name: key.clone(),
925                            type_expr: val_type,
926                            optional: false,
927                        });
928                    } else {
929                        all_string_keys = false;
930                        break;
931                    }
932                }
933                if all_string_keys && !fields.is_empty() {
934                    Some(TypeExpr::Shape(fields))
935                } else {
936                    Some(TypeExpr::Named("dict".into()))
937                }
938            }
939            Node::Closure { params, body } => {
940                // If all params are typed and we can infer a return type, produce FnType
941                let all_typed = params.iter().all(|p| p.type_expr.is_some());
942                if all_typed && !params.is_empty() {
943                    let param_types: Vec<TypeExpr> =
944                        params.iter().filter_map(|p| p.type_expr.clone()).collect();
945                    // Try to infer return type from last expression in body
946                    let ret = body.last().and_then(|last| self.infer_type(last, scope));
947                    if let Some(ret_type) = ret {
948                        return Some(TypeExpr::FnType {
949                            params: param_types,
950                            return_type: Box::new(ret_type),
951                        });
952                    }
953                }
954                Some(TypeExpr::Named("closure".into()))
955            }
956
957            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
958
959            Node::FunctionCall { name, .. } => {
960                // Check user-defined function return types
961                if let Some(sig) = scope.get_fn(name) {
962                    return sig.return_type.clone();
963                }
964                // Check builtin return types
965                builtin_return_type(name)
966            }
967
968            Node::BinaryOp { op, left, right } => {
969                let lt = self.infer_type(left, scope);
970                let rt = self.infer_type(right, scope);
971                infer_binary_op_type(op, &lt, &rt)
972            }
973
974            Node::UnaryOp { op, operand } => {
975                let t = self.infer_type(operand, scope);
976                match op.as_str() {
977                    "!" => Some(TypeExpr::Named("bool".into())),
978                    "-" => t, // negation preserves type
979                    _ => None,
980                }
981            }
982
983            Node::Ternary {
984                true_expr,
985                false_expr,
986                ..
987            } => {
988                let tt = self.infer_type(true_expr, scope);
989                let ft = self.infer_type(false_expr, scope);
990                match (&tt, &ft) {
991                    (Some(a), Some(b)) if a == b => tt,
992                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
993                    (Some(_), None) => tt,
994                    (None, Some(_)) => ft,
995                    (None, None) => None,
996                }
997            }
998
999            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1000
1001            Node::PropertyAccess { object, property } => {
1002                // EnumName.Variant → infer as the enum type
1003                if let Node::Identifier(name) = &object.node {
1004                    if scope.get_enum(name).is_some() {
1005                        return Some(TypeExpr::Named(name.clone()));
1006                    }
1007                }
1008                // .variant on an enum value → string
1009                if property == "variant" {
1010                    let obj_type = self.infer_type(object, scope);
1011                    if let Some(TypeExpr::Named(name)) = &obj_type {
1012                        if scope.get_enum(name).is_some() {
1013                            return Some(TypeExpr::Named("string".into()));
1014                        }
1015                    }
1016                }
1017                // Shape field access: obj.field → field type
1018                let obj_type = self.infer_type(object, scope);
1019                if let Some(TypeExpr::Shape(fields)) = &obj_type {
1020                    if let Some(field) = fields.iter().find(|f| f.name == *property) {
1021                        return Some(field.type_expr.clone());
1022                    }
1023                }
1024                None
1025            }
1026
1027            Node::SubscriptAccess { object, index } => {
1028                let obj_type = self.infer_type(object, scope);
1029                match &obj_type {
1030                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1031                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1032                    Some(TypeExpr::Shape(fields)) => {
1033                        // If index is a string literal, look up the field type
1034                        if let Node::StringLiteral(key) = &index.node {
1035                            fields
1036                                .iter()
1037                                .find(|f| &f.name == key)
1038                                .map(|f| f.type_expr.clone())
1039                        } else {
1040                            None
1041                        }
1042                    }
1043                    Some(TypeExpr::Named(n)) if n == "list" => None,
1044                    Some(TypeExpr::Named(n)) if n == "dict" => None,
1045                    Some(TypeExpr::Named(n)) if n == "string" => {
1046                        Some(TypeExpr::Named("string".into()))
1047                    }
1048                    _ => None,
1049                }
1050            }
1051            Node::SliceAccess { object, .. } => {
1052                // Slicing a list returns the same list type; slicing a string returns string
1053                let obj_type = self.infer_type(object, scope);
1054                match &obj_type {
1055                    Some(TypeExpr::List(_)) => obj_type,
1056                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1057                    Some(TypeExpr::Named(n)) if n == "string" => {
1058                        Some(TypeExpr::Named("string".into()))
1059                    }
1060                    _ => None,
1061                }
1062            }
1063            Node::MethodCall { object, method, .. }
1064            | Node::OptionalMethodCall { object, method, .. } => {
1065                let obj_type = self.infer_type(object, scope);
1066                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1067                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1068                match method.as_str() {
1069                    // Shared: bool-returning methods
1070                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1071                        Some(TypeExpr::Named("bool".into()))
1072                    }
1073                    // Shared: int-returning methods
1074                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1075                    // String methods
1076                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1077                    | "pad_left" | "pad_right" | "repeat" | "join" => {
1078                        Some(TypeExpr::Named("string".into()))
1079                    }
1080                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
1081                    // filter returns dict for dicts, list for lists
1082                    "filter" => {
1083                        if is_dict {
1084                            Some(TypeExpr::Named("dict".into()))
1085                        } else {
1086                            Some(TypeExpr::Named("list".into()))
1087                        }
1088                    }
1089                    // List methods
1090                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1091                    "reduce" | "find" | "first" | "last" => None,
1092                    // Dict methods
1093                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1094                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1095                    // Conversions
1096                    "to_string" => Some(TypeExpr::Named("string".into())),
1097                    "to_int" => Some(TypeExpr::Named("int".into())),
1098                    "to_float" => Some(TypeExpr::Named("float".into())),
1099                    _ => None,
1100                }
1101            }
1102
1103            _ => None,
1104        }
1105    }
1106
1107    /// Check if two types are compatible (actual can be assigned to expected).
1108    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1109        // Generic type parameters match anything.
1110        if let TypeExpr::Named(name) = expected {
1111            if scope.is_generic_type_param(name) {
1112                return true;
1113            }
1114        }
1115        if let TypeExpr::Named(name) = actual {
1116            if scope.is_generic_type_param(name) {
1117                return true;
1118            }
1119        }
1120        let expected = self.resolve_alias(expected, scope);
1121        let actual = self.resolve_alias(actual, scope);
1122
1123        match (&expected, &actual) {
1124            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1125            (TypeExpr::Union(members), actual_type) => members
1126                .iter()
1127                .any(|m| self.types_compatible(m, actual_type, scope)),
1128            (expected_type, TypeExpr::Union(members)) => members
1129                .iter()
1130                .all(|m| self.types_compatible(expected_type, m, scope)),
1131            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1132            (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1133            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1134                if expected_field.optional {
1135                    return true;
1136                }
1137                af.iter().any(|actual_field| {
1138                    actual_field.name == expected_field.name
1139                        && self.types_compatible(
1140                            &expected_field.type_expr,
1141                            &actual_field.type_expr,
1142                            scope,
1143                        )
1144                })
1145            }),
1146            // dict<K, V> expected, Shape actual → all field values must match V
1147            (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1148                let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1149                keys_ok
1150                    && af
1151                        .iter()
1152                        .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1153            }
1154            // Shape expected, dict<K, V> actual → gradual: allow since dict may have the fields
1155            (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1156            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1157                self.types_compatible(expected_inner, actual_inner, scope)
1158            }
1159            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1160            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1161            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1162                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1163            }
1164            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1165            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1166            // FnType compatibility: params match positionally and return types match
1167            (
1168                TypeExpr::FnType {
1169                    params: ep,
1170                    return_type: er,
1171                },
1172                TypeExpr::FnType {
1173                    params: ap,
1174                    return_type: ar,
1175                },
1176            ) => {
1177                ep.len() == ap.len()
1178                    && ep
1179                        .iter()
1180                        .zip(ap.iter())
1181                        .all(|(e, a)| self.types_compatible(e, a, scope))
1182                    && self.types_compatible(er, ar, scope)
1183            }
1184            // FnType is compatible with Named("closure") for backward compat
1185            (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1186            (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1187            _ => false,
1188        }
1189    }
1190
1191    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1192        if let TypeExpr::Named(name) = ty {
1193            if let Some(resolved) = scope.resolve_type(name) {
1194                return resolved.clone();
1195            }
1196        }
1197        ty.clone()
1198    }
1199
1200    fn error_at(&mut self, message: String, span: Span) {
1201        self.diagnostics.push(TypeDiagnostic {
1202            message,
1203            severity: DiagnosticSeverity::Error,
1204            span: Some(span),
1205        });
1206    }
1207
1208    fn warning_at(&mut self, message: String, span: Span) {
1209        self.diagnostics.push(TypeDiagnostic {
1210            message,
1211            severity: DiagnosticSeverity::Warning,
1212            span: Some(span),
1213        });
1214    }
1215}
1216
1217impl Default for TypeChecker {
1218    fn default() -> Self {
1219        Self::new()
1220    }
1221}
1222
1223/// Infer the result type of a binary operation.
1224fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1225    match op {
1226        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1227        "+" => match (left, right) {
1228            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1229                match (l.as_str(), r.as_str()) {
1230                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1231                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1232                    ("string", _) => Some(TypeExpr::Named("string".into())),
1233                    ("list", "list") => Some(TypeExpr::Named("list".into())),
1234                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1235                    _ => Some(TypeExpr::Named("string".into())),
1236                }
1237            }
1238            _ => None,
1239        },
1240        "-" | "*" | "/" | "%" => match (left, right) {
1241            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1242                match (l.as_str(), r.as_str()) {
1243                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1244                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1245                    _ => None,
1246                }
1247            }
1248            _ => None,
1249        },
1250        "??" => match (left, right) {
1251            (Some(TypeExpr::Union(members)), _) => {
1252                let non_nil: Vec<_> = members
1253                    .iter()
1254                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1255                    .cloned()
1256                    .collect();
1257                if non_nil.len() == 1 {
1258                    Some(non_nil[0].clone())
1259                } else if non_nil.is_empty() {
1260                    right.clone()
1261                } else {
1262                    Some(TypeExpr::Union(non_nil))
1263                }
1264            }
1265            _ => right.clone(),
1266        },
1267        "|>" => None,
1268        _ => None,
1269    }
1270}
1271
1272/// Format a type expression for display in error messages.
1273/// Produce a detail string describing why a Shape type is incompatible with
1274/// another Shape type — e.g. "missing field 'age' (int)" or "field 'name'
1275/// has type int, expected string".  Returns `None` if both types are not shapes.
1276pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1277    if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1278        let mut details = Vec::new();
1279        for field in ef {
1280            if field.optional {
1281                continue;
1282            }
1283            match af.iter().find(|f| f.name == field.name) {
1284                None => details.push(format!(
1285                    "missing field '{}' ({})",
1286                    field.name,
1287                    format_type(&field.type_expr)
1288                )),
1289                Some(actual_field) => {
1290                    let e_str = format_type(&field.type_expr);
1291                    let a_str = format_type(&actual_field.type_expr);
1292                    if e_str != a_str {
1293                        details.push(format!(
1294                            "field '{}' has type {}, expected {}",
1295                            field.name, a_str, e_str
1296                        ));
1297                    }
1298                }
1299            }
1300        }
1301        if details.is_empty() {
1302            None
1303        } else {
1304            Some(details.join("; "))
1305        }
1306    } else {
1307        None
1308    }
1309}
1310
1311pub fn format_type(ty: &TypeExpr) -> String {
1312    match ty {
1313        TypeExpr::Named(n) => n.clone(),
1314        TypeExpr::Union(types) => types
1315            .iter()
1316            .map(format_type)
1317            .collect::<Vec<_>>()
1318            .join(" | "),
1319        TypeExpr::Shape(fields) => {
1320            let inner: Vec<String> = fields
1321                .iter()
1322                .map(|f| {
1323                    let opt = if f.optional { "?" } else { "" };
1324                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1325                })
1326                .collect();
1327            format!("{{{}}}", inner.join(", "))
1328        }
1329        TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1330        TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1331        TypeExpr::FnType {
1332            params,
1333            return_type,
1334        } => {
1335            let params_str = params
1336                .iter()
1337                .map(format_type)
1338                .collect::<Vec<_>>()
1339                .join(", ");
1340            format!("fn({}) -> {}", params_str, format_type(return_type))
1341        }
1342    }
1343}
1344
1345#[cfg(test)]
1346mod tests {
1347    use super::*;
1348    use crate::Parser;
1349    use harn_lexer::Lexer;
1350
1351    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
1352        let mut lexer = Lexer::new(source);
1353        let tokens = lexer.tokenize().unwrap();
1354        let mut parser = Parser::new(tokens);
1355        let program = parser.parse().unwrap();
1356        TypeChecker::new().check(&program)
1357    }
1358
1359    fn errors(source: &str) -> Vec<String> {
1360        check_source(source)
1361            .into_iter()
1362            .filter(|d| d.severity == DiagnosticSeverity::Error)
1363            .map(|d| d.message)
1364            .collect()
1365    }
1366
1367    #[test]
1368    fn test_no_errors_for_untyped_code() {
1369        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
1370        assert!(errs.is_empty());
1371    }
1372
1373    #[test]
1374    fn test_correct_typed_let() {
1375        let errs = errors("pipeline t(task) { let x: int = 42 }");
1376        assert!(errs.is_empty());
1377    }
1378
1379    #[test]
1380    fn test_type_mismatch_let() {
1381        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
1382        assert_eq!(errs.len(), 1);
1383        assert!(errs[0].contains("Type mismatch"));
1384        assert!(errs[0].contains("int"));
1385        assert!(errs[0].contains("string"));
1386    }
1387
1388    #[test]
1389    fn test_correct_typed_fn() {
1390        let errs = errors(
1391            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
1392        );
1393        assert!(errs.is_empty());
1394    }
1395
1396    #[test]
1397    fn test_fn_arg_type_mismatch() {
1398        let errs = errors(
1399            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
1400add("hello", 2) }"#,
1401        );
1402        assert_eq!(errs.len(), 1);
1403        assert!(errs[0].contains("Argument 1"));
1404        assert!(errs[0].contains("expected int"));
1405    }
1406
1407    #[test]
1408    fn test_return_type_mismatch() {
1409        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
1410        assert_eq!(errs.len(), 1);
1411        assert!(errs[0].contains("Return type mismatch"));
1412    }
1413
1414    #[test]
1415    fn test_union_type_compatible() {
1416        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
1417        assert!(errs.is_empty());
1418    }
1419
1420    #[test]
1421    fn test_union_type_mismatch() {
1422        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
1423        assert_eq!(errs.len(), 1);
1424        assert!(errs[0].contains("Type mismatch"));
1425    }
1426
1427    #[test]
1428    fn test_type_inference_propagation() {
1429        let errs = errors(
1430            r#"pipeline t(task) {
1431  fn add(a: int, b: int) -> int { return a + b }
1432  let result: string = add(1, 2)
1433}"#,
1434        );
1435        assert_eq!(errs.len(), 1);
1436        assert!(errs[0].contains("Type mismatch"));
1437        assert!(errs[0].contains("string"));
1438        assert!(errs[0].contains("int"));
1439    }
1440
1441    #[test]
1442    fn test_builtin_return_type_inference() {
1443        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
1444        assert_eq!(errs.len(), 1);
1445        assert!(errs[0].contains("string"));
1446        assert!(errs[0].contains("int"));
1447    }
1448
1449    #[test]
1450    fn test_binary_op_type_inference() {
1451        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
1452        assert_eq!(errs.len(), 1);
1453    }
1454
1455    #[test]
1456    fn test_comparison_returns_bool() {
1457        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
1458        assert!(errs.is_empty());
1459    }
1460
1461    #[test]
1462    fn test_int_float_promotion() {
1463        let errs = errors("pipeline t(task) { let x: float = 42 }");
1464        assert!(errs.is_empty());
1465    }
1466
1467    #[test]
1468    fn test_untyped_code_no_errors() {
1469        let errs = errors(
1470            r#"pipeline t(task) {
1471  fn process(data) {
1472    let result = data + " processed"
1473    return result
1474  }
1475  log(process("hello"))
1476}"#,
1477        );
1478        assert!(errs.is_empty());
1479    }
1480
1481    #[test]
1482    fn test_type_alias() {
1483        let errs = errors(
1484            r#"pipeline t(task) {
1485  type Name = string
1486  let x: Name = "hello"
1487}"#,
1488        );
1489        assert!(errs.is_empty());
1490    }
1491
1492    #[test]
1493    fn test_type_alias_mismatch() {
1494        let errs = errors(
1495            r#"pipeline t(task) {
1496  type Name = string
1497  let x: Name = 42
1498}"#,
1499        );
1500        assert_eq!(errs.len(), 1);
1501    }
1502
1503    #[test]
1504    fn test_assignment_type_check() {
1505        let errs = errors(
1506            r#"pipeline t(task) {
1507  var x: int = 0
1508  x = "hello"
1509}"#,
1510        );
1511        assert_eq!(errs.len(), 1);
1512        assert!(errs[0].contains("cannot assign string"));
1513    }
1514
1515    #[test]
1516    fn test_covariance_int_to_float_in_fn() {
1517        let errs = errors(
1518            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
1519        );
1520        assert!(errs.is_empty());
1521    }
1522
1523    #[test]
1524    fn test_covariance_return_type() {
1525        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
1526        assert!(errs.is_empty());
1527    }
1528
1529    #[test]
1530    fn test_no_contravariance_float_to_int() {
1531        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
1532        assert_eq!(errs.len(), 1);
1533    }
1534
1535    // --- Exhaustiveness checking tests ---
1536
1537    fn warnings(source: &str) -> Vec<String> {
1538        check_source(source)
1539            .into_iter()
1540            .filter(|d| d.severity == DiagnosticSeverity::Warning)
1541            .map(|d| d.message)
1542            .collect()
1543    }
1544
1545    #[test]
1546    fn test_exhaustive_match_no_warning() {
1547        let warns = warnings(
1548            r#"pipeline t(task) {
1549  enum Color { Red, Green, Blue }
1550  let c = Color.Red
1551  match c.variant {
1552    "Red" -> { log("r") }
1553    "Green" -> { log("g") }
1554    "Blue" -> { log("b") }
1555  }
1556}"#,
1557        );
1558        let exhaustive_warns: Vec<_> = warns
1559            .iter()
1560            .filter(|w| w.contains("Non-exhaustive"))
1561            .collect();
1562        assert!(exhaustive_warns.is_empty());
1563    }
1564
1565    #[test]
1566    fn test_non_exhaustive_match_warning() {
1567        let warns = warnings(
1568            r#"pipeline t(task) {
1569  enum Color { Red, Green, Blue }
1570  let c = Color.Red
1571  match c.variant {
1572    "Red" -> { log("r") }
1573    "Green" -> { log("g") }
1574  }
1575}"#,
1576        );
1577        let exhaustive_warns: Vec<_> = warns
1578            .iter()
1579            .filter(|w| w.contains("Non-exhaustive"))
1580            .collect();
1581        assert_eq!(exhaustive_warns.len(), 1);
1582        assert!(exhaustive_warns[0].contains("Blue"));
1583    }
1584
1585    #[test]
1586    fn test_non_exhaustive_multiple_missing() {
1587        let warns = warnings(
1588            r#"pipeline t(task) {
1589  enum Status { Active, Inactive, Pending }
1590  let s = Status.Active
1591  match s.variant {
1592    "Active" -> { log("a") }
1593  }
1594}"#,
1595        );
1596        let exhaustive_warns: Vec<_> = warns
1597            .iter()
1598            .filter(|w| w.contains("Non-exhaustive"))
1599            .collect();
1600        assert_eq!(exhaustive_warns.len(), 1);
1601        assert!(exhaustive_warns[0].contains("Inactive"));
1602        assert!(exhaustive_warns[0].contains("Pending"));
1603    }
1604
1605    #[test]
1606    fn test_enum_construct_type_inference() {
1607        let errs = errors(
1608            r#"pipeline t(task) {
1609  enum Color { Red, Green, Blue }
1610  let c: Color = Color.Red
1611}"#,
1612        );
1613        assert!(errs.is_empty());
1614    }
1615
1616    // --- Type narrowing tests ---
1617
1618    #[test]
1619    fn test_nil_coalescing_strips_nil() {
1620        // After ??, nil should be stripped from the type
1621        let errs = errors(
1622            r#"pipeline t(task) {
1623  let x: string | nil = nil
1624  let y: string = x ?? "default"
1625}"#,
1626        );
1627        assert!(errs.is_empty());
1628    }
1629
1630    #[test]
1631    fn test_shape_mismatch_detail_missing_field() {
1632        let errs = errors(
1633            r#"pipeline t(task) {
1634  let x: {name: string, age: int} = {name: "hello"}
1635}"#,
1636        );
1637        assert_eq!(errs.len(), 1);
1638        assert!(
1639            errs[0].contains("missing field 'age'"),
1640            "expected detail about missing field, got: {}",
1641            errs[0]
1642        );
1643    }
1644
1645    #[test]
1646    fn test_shape_mismatch_detail_wrong_type() {
1647        let errs = errors(
1648            r#"pipeline t(task) {
1649  let x: {name: string, age: int} = {name: 42, age: 10}
1650}"#,
1651        );
1652        assert_eq!(errs.len(), 1);
1653        assert!(
1654            errs[0].contains("field 'name' has type int, expected string"),
1655            "expected detail about wrong type, got: {}",
1656            errs[0]
1657        );
1658    }
1659
1660    // --- Match pattern type validation tests ---
1661
1662    #[test]
1663    fn test_match_pattern_string_against_int() {
1664        let warns = warnings(
1665            r#"pipeline t(task) {
1666  let x: int = 42
1667  match x {
1668    "hello" -> { log("bad") }
1669    42 -> { log("ok") }
1670  }
1671}"#,
1672        );
1673        let pattern_warns: Vec<_> = warns
1674            .iter()
1675            .filter(|w| w.contains("Match pattern type mismatch"))
1676            .collect();
1677        assert_eq!(pattern_warns.len(), 1);
1678        assert!(pattern_warns[0].contains("matching int against string literal"));
1679    }
1680
1681    #[test]
1682    fn test_match_pattern_int_against_string() {
1683        let warns = warnings(
1684            r#"pipeline t(task) {
1685  let x: string = "hello"
1686  match x {
1687    42 -> { log("bad") }
1688    "hello" -> { log("ok") }
1689  }
1690}"#,
1691        );
1692        let pattern_warns: Vec<_> = warns
1693            .iter()
1694            .filter(|w| w.contains("Match pattern type mismatch"))
1695            .collect();
1696        assert_eq!(pattern_warns.len(), 1);
1697        assert!(pattern_warns[0].contains("matching string against int literal"));
1698    }
1699
1700    #[test]
1701    fn test_match_pattern_bool_against_int() {
1702        let warns = warnings(
1703            r#"pipeline t(task) {
1704  let x: int = 42
1705  match x {
1706    true -> { log("bad") }
1707    42 -> { log("ok") }
1708  }
1709}"#,
1710        );
1711        let pattern_warns: Vec<_> = warns
1712            .iter()
1713            .filter(|w| w.contains("Match pattern type mismatch"))
1714            .collect();
1715        assert_eq!(pattern_warns.len(), 1);
1716        assert!(pattern_warns[0].contains("matching int against bool literal"));
1717    }
1718
1719    #[test]
1720    fn test_match_pattern_float_against_string() {
1721        let warns = warnings(
1722            r#"pipeline t(task) {
1723  let x: string = "hello"
1724  match x {
1725    3.14 -> { log("bad") }
1726    "hello" -> { log("ok") }
1727  }
1728}"#,
1729        );
1730        let pattern_warns: Vec<_> = warns
1731            .iter()
1732            .filter(|w| w.contains("Match pattern type mismatch"))
1733            .collect();
1734        assert_eq!(pattern_warns.len(), 1);
1735        assert!(pattern_warns[0].contains("matching string against float literal"));
1736    }
1737
1738    #[test]
1739    fn test_match_pattern_int_against_float_ok() {
1740        // int and float are compatible for match patterns
1741        let warns = warnings(
1742            r#"pipeline t(task) {
1743  let x: float = 3.14
1744  match x {
1745    42 -> { log("ok") }
1746    _ -> { log("default") }
1747  }
1748}"#,
1749        );
1750        let pattern_warns: Vec<_> = warns
1751            .iter()
1752            .filter(|w| w.contains("Match pattern type mismatch"))
1753            .collect();
1754        assert!(pattern_warns.is_empty());
1755    }
1756
1757    #[test]
1758    fn test_match_pattern_float_against_int_ok() {
1759        // float and int are compatible for match patterns
1760        let warns = warnings(
1761            r#"pipeline t(task) {
1762  let x: int = 42
1763  match x {
1764    3.14 -> { log("close") }
1765    _ -> { log("default") }
1766  }
1767}"#,
1768        );
1769        let pattern_warns: Vec<_> = warns
1770            .iter()
1771            .filter(|w| w.contains("Match pattern type mismatch"))
1772            .collect();
1773        assert!(pattern_warns.is_empty());
1774    }
1775
1776    #[test]
1777    fn test_match_pattern_correct_types_no_warning() {
1778        let warns = warnings(
1779            r#"pipeline t(task) {
1780  let x: int = 42
1781  match x {
1782    1 -> { log("one") }
1783    2 -> { log("two") }
1784    _ -> { log("other") }
1785  }
1786}"#,
1787        );
1788        let pattern_warns: Vec<_> = warns
1789            .iter()
1790            .filter(|w| w.contains("Match pattern type mismatch"))
1791            .collect();
1792        assert!(pattern_warns.is_empty());
1793    }
1794
1795    #[test]
1796    fn test_match_pattern_wildcard_no_warning() {
1797        let warns = warnings(
1798            r#"pipeline t(task) {
1799  let x: int = 42
1800  match x {
1801    _ -> { log("catch all") }
1802  }
1803}"#,
1804        );
1805        let pattern_warns: Vec<_> = warns
1806            .iter()
1807            .filter(|w| w.contains("Match pattern type mismatch"))
1808            .collect();
1809        assert!(pattern_warns.is_empty());
1810    }
1811
1812    #[test]
1813    fn test_match_pattern_untyped_no_warning() {
1814        // When value has no known type, no warning should be emitted
1815        let warns = warnings(
1816            r#"pipeline t(task) {
1817  let x = some_unknown_fn()
1818  match x {
1819    "hello" -> { log("string") }
1820    42 -> { log("int") }
1821  }
1822}"#,
1823        );
1824        let pattern_warns: Vec<_> = warns
1825            .iter()
1826            .filter(|w| w.contains("Match pattern type mismatch"))
1827            .collect();
1828        assert!(pattern_warns.is_empty());
1829    }
1830}