Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
/// A diagnostic produced by the type checker.
///
/// Diagnostics are collected while a program is walked and returned from
/// `TypeChecker::check`; the checker never aborts on the first problem.
#[derive(Debug, Clone)]
pub struct TypeDiagnostic {
    /// Human-readable description of the problem.
    pub message: String,
    /// Whether this is an error or a warning.
    pub severity: DiagnosticSeverity,
    /// Source location the diagnostic refers to, when one is known.
    pub span: Option<Span>,
}
13
/// How serious a [`TypeDiagnostic`] is.
///
/// In this file, type mismatches on bindings, assignments, and return values
/// are reported via `error_at`, while operator, match-pattern, and
/// exhaustiveness findings go through `warning_at` (presumably mapping to
/// `Error` and `Warning` respectively; those helpers are defined elsewhere).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    Error,
    Warning,
}
19
/// Inferred type of an expression. `None` means unknown/untyped (gradual typing):
/// checks that would need the missing type are skipped rather than reported.
type InferredType = Option<TypeExpr>;
22
/// Scope for tracking variable types and declarations during checking.
///
/// Each table holds only the entries defined directly in this scope; lookups
/// fall back to `parent`. `TypeScope::child` stores a *clone* of the parent, so
/// definitions added to a child never flow back into the parent.
#[derive(Debug, Clone)]
struct TypeScope {
    /// Variable name → inferred type (`None` when unknown).
    vars: BTreeMap<String, InferredType>,
    /// Function name → signature (parameters, return type, generics, arity).
    functions: BTreeMap<String, FnSignature>,
    /// Named type aliases: alias name → aliased type expression.
    type_aliases: BTreeMap<String, TypeExpr>,
    /// Enum declarations: name → variant names.
    enums: BTreeMap<String, Vec<String>>,
    /// Interface declarations: name → method signatures.
    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
    /// Struct declarations: name → (field name, declared field type) pairs.
    structs: BTreeMap<String, Vec<(String, InferredType)>>,
    /// Generic type parameter names in scope (treated as compatible with any type).
    generic_type_params: std::collections::BTreeSet<String>,
    /// Enclosing scope, consulted when a name is not found in this one.
    parent: Option<Box<TypeScope>>,
}
42
/// Signature of a user-declared function, as recorded by the checker.
#[derive(Debug, Clone)]
struct FnSignature {
    /// Parameters in declaration order: (name, declared type or `None` if untyped).
    params: Vec<(String, InferredType)>,
    /// Declared return type, or `None` if not annotated.
    return_type: InferredType,
    /// Generic type parameter names declared on the function.
    type_param_names: Vec<String>,
    /// Number of required parameters (those without defaults).
    required_params: usize,
}
52
53impl TypeScope {
54    fn new() -> Self {
55        Self {
56            vars: BTreeMap::new(),
57            functions: BTreeMap::new(),
58            type_aliases: BTreeMap::new(),
59            enums: BTreeMap::new(),
60            interfaces: BTreeMap::new(),
61            structs: BTreeMap::new(),
62            generic_type_params: std::collections::BTreeSet::new(),
63            parent: None,
64        }
65    }
66
67    fn child(&self) -> Self {
68        Self {
69            vars: BTreeMap::new(),
70            functions: BTreeMap::new(),
71            type_aliases: BTreeMap::new(),
72            enums: BTreeMap::new(),
73            interfaces: BTreeMap::new(),
74            structs: BTreeMap::new(),
75            generic_type_params: std::collections::BTreeSet::new(),
76            parent: Some(Box::new(self.clone())),
77        }
78    }
79
80    fn get_var(&self, name: &str) -> Option<&InferredType> {
81        self.vars
82            .get(name)
83            .or_else(|| self.parent.as_ref()?.get_var(name))
84    }
85
86    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
87        self.functions
88            .get(name)
89            .or_else(|| self.parent.as_ref()?.get_fn(name))
90    }
91
92    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
93        self.type_aliases
94            .get(name)
95            .or_else(|| self.parent.as_ref()?.resolve_type(name))
96    }
97
98    fn is_generic_type_param(&self, name: &str) -> bool {
99        self.generic_type_params.contains(name)
100            || self
101                .parent
102                .as_ref()
103                .is_some_and(|p| p.is_generic_type_param(name))
104    }
105
106    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
107        self.enums
108            .get(name)
109            .or_else(|| self.parent.as_ref()?.get_enum(name))
110    }
111
112    #[allow(dead_code)]
113    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
114        self.interfaces
115            .get(name)
116            .or_else(|| self.parent.as_ref()?.get_interface(name))
117    }
118
119    fn define_var(&mut self, name: &str, ty: InferredType) {
120        self.vars.insert(name.to_string(), ty);
121    }
122
123    fn define_fn(&mut self, name: &str, sig: FnSignature) {
124        self.functions.insert(name.to_string(), sig);
125    }
126}
127
128/// Known return types for builtin functions.
129fn builtin_return_type(name: &str) -> InferredType {
130    match name {
131        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
132        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
133            Some(TypeExpr::Named("nil".into()))
134        }
135        "type_of" | "to_string" | "json_stringify" | "read_file" | "http_get" | "http_post"
136        | "llm_call" | "agent_loop" | "regex_replace" | "path_join" | "temp_dir"
137        | "date_format" | "format" => Some(TypeExpr::Named("string".into())),
138        "to_int" => Some(TypeExpr::Named("int".into())),
139        "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
140        "file_exists" | "json_validate" => Some(TypeExpr::Named("bool".into())),
141        "list_dir" => Some(TypeExpr::Named("list".into())),
142        "stat" | "exec" | "shell" | "date_now" => Some(TypeExpr::Named("dict".into())),
143        "env" | "regex_match" => Some(TypeExpr::Union(vec![
144            TypeExpr::Named("string".into()),
145            TypeExpr::Named("nil".into()),
146        ])),
147        "json_parse" | "json_extract" => None, // could be any type
148        _ => None,
149    }
150}
151
/// Check if a name is a known builtin.
///
/// Membership test against a fixed table of builtin names. `check_call` skips
/// its user-function arity check for any name accepted here.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        "log", "print", "println", "type_of", "to_string", "to_int", "to_float",
        "json_stringify", "json_parse", "env", "timestamp", "sleep", "read_file",
        "write_file", "exit", "regex_match", "regex_replace", "http_get", "http_post",
        "llm_call", "agent_loop", "await", "cancel", "file_exists", "delete_file",
        "list_dir", "mkdir", "path_join", "copy_file", "append_file", "temp_dir", "stat",
        "exec", "shell", "date_now", "date_format", "date_parse", "format", "json_validate",
        "json_extract", "trim", "lowercase", "uppercase", "split", "starts_with",
        "ends_with", "contains", "replace", "join", "len", "substring", "dirname",
        "basename", "extname",
    ];
    BUILTINS.contains(&name)
}
212
/// The static type checker.
///
/// Create one with [`TypeChecker::new`], then call [`TypeChecker::check`], which
/// consumes the checker and returns the diagnostics it collected.
pub struct TypeChecker {
    /// Diagnostics collected so far.
    diagnostics: Vec<TypeDiagnostic>,
    /// Root (program-level) scope; nested scopes are derived from it.
    scope: TypeScope,
}
218
219impl TypeChecker {
220    pub fn new() -> Self {
221        Self {
222            diagnostics: Vec::new(),
223            scope: TypeScope::new(),
224        }
225    }
226
227    /// Check a program and return diagnostics.
228    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
229        // First pass: register type and enum declarations into root scope
230        Self::register_declarations_into(&mut self.scope, program);
231
232        // Also scan pipeline bodies for declarations
233        for snode in program {
234            if let Node::Pipeline { body, .. } = &snode.node {
235                Self::register_declarations_into(&mut self.scope, body);
236            }
237        }
238
239        // Check each top-level node
240        for snode in program {
241            match &snode.node {
242                Node::Pipeline { params, body, .. } => {
243                    let mut child = self.scope.child();
244                    for p in params {
245                        child.define_var(p, None);
246                    }
247                    self.check_block(body, &mut child);
248                }
249                Node::FnDecl {
250                    name,
251                    type_params,
252                    params,
253                    return_type,
254                    body,
255                    ..
256                } => {
257                    let required_params = params
258                        .iter()
259                        .filter(|p| p.default_value.is_none())
260                        .count();
261                    let sig = FnSignature {
262                        params: params
263                            .iter()
264                            .map(|p| (p.name.clone(), p.type_expr.clone()))
265                            .collect(),
266                        return_type: return_type.clone(),
267                        type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
268                        required_params,
269                    };
270                    self.scope.define_fn(name, sig);
271                    self.check_fn_body(type_params, params, return_type, body);
272                }
273                _ => {
274                    let mut scope = self.scope.clone();
275                    self.check_node(snode, &mut scope);
276                    // Merge any new definitions back into the top-level scope
277                    for (name, ty) in scope.vars {
278                        self.scope.vars.entry(name).or_insert(ty);
279                    }
280                }
281            }
282        }
283
284        self.diagnostics
285    }
286
287    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
288    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
289        for snode in nodes {
290            match &snode.node {
291                Node::TypeDecl { name, type_expr } => {
292                    scope.type_aliases.insert(name.clone(), type_expr.clone());
293                }
294                Node::EnumDecl { name, variants } => {
295                    let variant_names: Vec<String> =
296                        variants.iter().map(|v| v.name.clone()).collect();
297                    scope.enums.insert(name.clone(), variant_names);
298                }
299                Node::InterfaceDecl { name, methods } => {
300                    scope.interfaces.insert(name.clone(), methods.clone());
301                }
302                Node::StructDecl { name, fields } => {
303                    let field_types: Vec<(String, InferredType)> = fields
304                        .iter()
305                        .map(|f| (f.name.clone(), f.type_expr.clone()))
306                        .collect();
307                    scope.structs.insert(name.clone(), field_types);
308                }
309                _ => {}
310            }
311        }
312    }
313
314    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
315        for stmt in stmts {
316            self.check_node(stmt, scope);
317        }
318    }
319
320    /// Define variables from a destructuring pattern in the given scope (as unknown type).
321    fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
322        match pattern {
323            BindingPattern::Identifier(name) => {
324                scope.define_var(name, None);
325            }
326            BindingPattern::Dict(fields) => {
327                for field in fields {
328                    let name = field.alias.as_deref().unwrap_or(&field.key);
329                    scope.define_var(name, None);
330                }
331            }
332            BindingPattern::List(elements) => {
333                for elem in elements {
334                    scope.define_var(&elem.name, None);
335                }
336            }
337        }
338    }
339
340    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
341        let span = snode.span;
342        match &snode.node {
343            Node::LetBinding {
344                pattern,
345                type_ann,
346                value,
347            } => {
348                let inferred = self.infer_type(value, scope);
349                if let BindingPattern::Identifier(name) = pattern {
350                    if let Some(expected) = type_ann {
351                        if let Some(actual) = &inferred {
352                            if !self.types_compatible(expected, actual, scope) {
353                                let mut msg = format!(
354                                    "Type mismatch: '{}' declared as {}, but assigned {}",
355                                    name,
356                                    format_type(expected),
357                                    format_type(actual)
358                                );
359                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
360                                    msg.push_str(&format!(" ({})", detail));
361                                }
362                                self.error_at(msg, span);
363                            }
364                        }
365                    }
366                    let ty = type_ann.clone().or(inferred);
367                    scope.define_var(name, ty);
368                } else {
369                    Self::define_pattern_vars(pattern, scope);
370                }
371            }
372
373            Node::VarBinding {
374                pattern,
375                type_ann,
376                value,
377            } => {
378                let inferred = self.infer_type(value, scope);
379                if let BindingPattern::Identifier(name) = pattern {
380                    if let Some(expected) = type_ann {
381                        if let Some(actual) = &inferred {
382                            if !self.types_compatible(expected, actual, scope) {
383                                let mut msg = format!(
384                                    "Type mismatch: '{}' declared as {}, but assigned {}",
385                                    name,
386                                    format_type(expected),
387                                    format_type(actual)
388                                );
389                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
390                                    msg.push_str(&format!(" ({})", detail));
391                                }
392                                self.error_at(msg, span);
393                            }
394                        }
395                    }
396                    let ty = type_ann.clone().or(inferred);
397                    scope.define_var(name, ty);
398                } else {
399                    Self::define_pattern_vars(pattern, scope);
400                }
401            }
402
403            Node::FnDecl {
404                name,
405                type_params,
406                params,
407                return_type,
408                body,
409                ..
410            } => {
411                let required_params = params
412                    .iter()
413                    .filter(|p| p.default_value.is_none())
414                    .count();
415                let sig = FnSignature {
416                    params: params
417                        .iter()
418                        .map(|p| (p.name.clone(), p.type_expr.clone()))
419                        .collect(),
420                    return_type: return_type.clone(),
421                    type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
422                    required_params,
423                };
424                scope.define_fn(name, sig.clone());
425                scope.define_var(name, None);
426                self.check_fn_body(type_params, params, return_type, body);
427            }
428
429            Node::FunctionCall { name, args } => {
430                self.check_call(name, args, scope, span);
431            }
432
433            Node::IfElse {
434                condition,
435                then_body,
436                else_body,
437            } => {
438                self.check_node(condition, scope);
439                let mut then_scope = scope.child();
440                self.check_block(then_body, &mut then_scope);
441                if let Some(else_body) = else_body {
442                    let mut else_scope = scope.child();
443                    self.check_block(else_body, &mut else_scope);
444                }
445            }
446
447            Node::ForIn {
448                pattern,
449                iterable,
450                body,
451            } => {
452                self.check_node(iterable, scope);
453                let mut loop_scope = scope.child();
454                if let BindingPattern::Identifier(variable) = pattern {
455                    // Infer loop variable type from iterable
456                    let elem_type = match self.infer_type(iterable, scope) {
457                        Some(TypeExpr::List(inner)) => Some(*inner),
458                        Some(TypeExpr::Named(n)) if n == "string" => {
459                            Some(TypeExpr::Named("string".into()))
460                        }
461                        _ => None,
462                    };
463                    loop_scope.define_var(variable, elem_type);
464                } else {
465                    Self::define_pattern_vars(pattern, &mut loop_scope);
466                }
467                self.check_block(body, &mut loop_scope);
468            }
469
470            Node::WhileLoop { condition, body } => {
471                self.check_node(condition, scope);
472                let mut loop_scope = scope.child();
473                self.check_block(body, &mut loop_scope);
474            }
475
476            Node::TryCatch {
477                body,
478                error_var,
479                catch_body,
480                finally_body,
481                ..
482            } => {
483                let mut try_scope = scope.child();
484                self.check_block(body, &mut try_scope);
485                let mut catch_scope = scope.child();
486                if let Some(var) = error_var {
487                    catch_scope.define_var(var, None);
488                }
489                self.check_block(catch_body, &mut catch_scope);
490                if let Some(fb) = finally_body {
491                    let mut finally_scope = scope.child();
492                    self.check_block(fb, &mut finally_scope);
493                }
494            }
495
496            Node::ReturnStmt {
497                value: Some(val), ..
498            } => {
499                self.check_node(val, scope);
500            }
501
502            Node::Assignment {
503                target, value, op, ..
504            } => {
505                self.check_node(value, scope);
506                if let Node::Identifier(name) = &target.node {
507                    if let Some(Some(var_type)) = scope.get_var(name) {
508                        let value_type = self.infer_type(value, scope);
509                        let assigned = if let Some(op) = op {
510                            let var_inferred = scope.get_var(name).cloned().flatten();
511                            infer_binary_op_type(op, &var_inferred, &value_type)
512                        } else {
513                            value_type
514                        };
515                        if let Some(actual) = &assigned {
516                            if !self.types_compatible(var_type, actual, scope) {
517                                self.error_at(
518                                    format!(
519                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
520                                        format_type(actual),
521                                        name,
522                                        format_type(var_type)
523                                    ),
524                                    span,
525                                );
526                            }
527                        }
528                    }
529                }
530            }
531
532            Node::TypeDecl { name, type_expr } => {
533                scope.type_aliases.insert(name.clone(), type_expr.clone());
534            }
535
536            Node::EnumDecl { name, variants } => {
537                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
538                scope.enums.insert(name.clone(), variant_names);
539            }
540
541            Node::StructDecl { name, fields } => {
542                let field_types: Vec<(String, InferredType)> = fields
543                    .iter()
544                    .map(|f| (f.name.clone(), f.type_expr.clone()))
545                    .collect();
546                scope.structs.insert(name.clone(), field_types);
547            }
548
549            Node::InterfaceDecl { name, methods } => {
550                scope.interfaces.insert(name.clone(), methods.clone());
551            }
552
553            Node::MatchExpr { value, arms } => {
554                self.check_node(value, scope);
555                let value_type = self.infer_type(value, scope);
556                for arm in arms {
557                    self.check_node(&arm.pattern, scope);
558                    // Check for incompatible literal pattern types
559                    if let Some(ref vt) = value_type {
560                        let value_type_name = format_type(vt);
561                        let mismatch = match &arm.pattern.node {
562                            Node::StringLiteral(_) => {
563                                !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
564                            }
565                            Node::IntLiteral(_) => {
566                                !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
567                                    && !self.types_compatible(
568                                        vt,
569                                        &TypeExpr::Named("float".into()),
570                                        scope,
571                                    )
572                            }
573                            Node::FloatLiteral(_) => {
574                                !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
575                                    && !self.types_compatible(
576                                        vt,
577                                        &TypeExpr::Named("int".into()),
578                                        scope,
579                                    )
580                            }
581                            Node::BoolLiteral(_) => {
582                                !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
583                            }
584                            _ => false,
585                        };
586                        if mismatch {
587                            let pattern_type = match &arm.pattern.node {
588                                Node::StringLiteral(_) => "string",
589                                Node::IntLiteral(_) => "int",
590                                Node::FloatLiteral(_) => "float",
591                                Node::BoolLiteral(_) => "bool",
592                                _ => unreachable!(),
593                            };
594                            self.warning_at(
595                                format!(
596                                    "Match pattern type mismatch: matching {} against {} literal",
597                                    value_type_name, pattern_type
598                                ),
599                                arm.pattern.span,
600                            );
601                        }
602                    }
603                    let mut arm_scope = scope.child();
604                    self.check_block(&arm.body, &mut arm_scope);
605                }
606                self.check_match_exhaustiveness(value, arms, scope, span);
607            }
608
609            // Recurse into nested expressions + validate binary op types
610            Node::BinaryOp { op, left, right } => {
611                self.check_node(left, scope);
612                self.check_node(right, scope);
613                // Validate operator/type compatibility
614                let lt = self.infer_type(left, scope);
615                let rt = self.infer_type(right, scope);
616                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
617                    match op.as_str() {
618                        "-" | "*" | "/" | "%" => {
619                            let numeric = ["int", "float"];
620                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
621                                self.warning_at(
622                                    format!(
623                                        "Operator '{op}' may not be valid for types {} and {}",
624                                        l, r
625                                    ),
626                                    span,
627                                );
628                            }
629                        }
630                        "+" => {
631                            // + is valid for int, float, string, list, dict
632                            let valid = ["int", "float", "string", "list", "dict"];
633                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
634                                self.warning_at(
635                                    format!(
636                                        "Operator '+' may not be valid for types {} and {}",
637                                        l, r
638                                    ),
639                                    span,
640                                );
641                            }
642                        }
643                        _ => {}
644                    }
645                }
646            }
647            Node::UnaryOp { operand, .. } => {
648                self.check_node(operand, scope);
649            }
650            Node::MethodCall { object, args, .. }
651            | Node::OptionalMethodCall { object, args, .. } => {
652                self.check_node(object, scope);
653                for arg in args {
654                    self.check_node(arg, scope);
655                }
656            }
657            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
658                self.check_node(object, scope);
659            }
660            Node::SubscriptAccess { object, index } => {
661                self.check_node(object, scope);
662                self.check_node(index, scope);
663            }
664            Node::SliceAccess { object, start, end } => {
665                self.check_node(object, scope);
666                if let Some(s) = start {
667                    self.check_node(s, scope);
668                }
669                if let Some(e) = end {
670                    self.check_node(e, scope);
671                }
672            }
673
674            // Terminals — nothing to check
675            _ => {}
676        }
677    }
678
679    fn check_fn_body(
680        &mut self,
681        type_params: &[TypeParam],
682        params: &[TypedParam],
683        return_type: &Option<TypeExpr>,
684        body: &[SNode],
685    ) {
686        let mut fn_scope = self.scope.child();
687        // Register generic type parameters so they are treated as compatible
688        // with any concrete type during type checking.
689        for tp in type_params {
690            fn_scope.generic_type_params.insert(tp.name.clone());
691        }
692        for param in params {
693            fn_scope.define_var(&param.name, param.type_expr.clone());
694            if let Some(default) = &param.default_value {
695                self.check_node(default, &mut fn_scope);
696            }
697        }
698        self.check_block(body, &mut fn_scope);
699
700        // Check return statements against declared return type
701        if let Some(ret_type) = return_type {
702            for stmt in body {
703                self.check_return_type(stmt, ret_type, &fn_scope);
704            }
705        }
706    }
707
708    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
709        let span = snode.span;
710        match &snode.node {
711            Node::ReturnStmt { value: Some(val) } => {
712                let inferred = self.infer_type(val, scope);
713                if let Some(actual) = &inferred {
714                    if !self.types_compatible(expected, actual, scope) {
715                        self.error_at(
716                            format!(
717                                "Return type mismatch: expected {}, got {}",
718                                format_type(expected),
719                                format_type(actual)
720                            ),
721                            span,
722                        );
723                    }
724                }
725            }
726            Node::IfElse {
727                then_body,
728                else_body,
729                ..
730            } => {
731                for stmt in then_body {
732                    self.check_return_type(stmt, expected, scope);
733                }
734                if let Some(else_body) = else_body {
735                    for stmt in else_body {
736                        self.check_return_type(stmt, expected, scope);
737                    }
738                }
739            }
740            _ => {}
741        }
742    }
743
744    /// Check if a match expression on an enum's `.variant` property covers all variants.
745    fn check_match_exhaustiveness(
746        &mut self,
747        value: &SNode,
748        arms: &[MatchArm],
749        scope: &TypeScope,
750        span: Span,
751    ) {
752        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
753        let enum_name = match &value.node {
754            Node::PropertyAccess { object, property } if property == "variant" => {
755                // Infer the type of the object
756                match self.infer_type(object, scope) {
757                    Some(TypeExpr::Named(name)) => {
758                        if scope.get_enum(&name).is_some() {
759                            Some(name)
760                        } else {
761                            None
762                        }
763                    }
764                    _ => None,
765                }
766            }
767            _ => {
768                // Direct match on an enum value: match <expr> { ... }
769                match self.infer_type(value, scope) {
770                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
771                    _ => None,
772                }
773            }
774        };
775
776        let Some(enum_name) = enum_name else {
777            return;
778        };
779        let Some(variants) = scope.get_enum(&enum_name) else {
780            return;
781        };
782
783        // Collect variant names covered by match arms
784        let mut covered: Vec<String> = Vec::new();
785        let mut has_wildcard = false;
786
787        for arm in arms {
788            match &arm.pattern.node {
789                // String literal pattern (matching on .variant): "VariantA"
790                Node::StringLiteral(s) => covered.push(s.clone()),
791                // Identifier pattern acts as a wildcard/catch-all
792                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
793                    has_wildcard = true;
794                }
795                // Direct enum construct pattern: EnumName.Variant
796                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
797                // PropertyAccess pattern: EnumName.Variant (no args)
798                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
799                _ => {
800                    // Unknown pattern shape — conservatively treat as wildcard
801                    has_wildcard = true;
802                }
803            }
804        }
805
806        if has_wildcard {
807            return;
808        }
809
810        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
811        if !missing.is_empty() {
812            let missing_str = missing
813                .iter()
814                .map(|s| format!("\"{}\"", s))
815                .collect::<Vec<_>>()
816                .join(", ");
817            self.warning_at(
818                format!(
819                    "Non-exhaustive match on enum {}: missing variants {}",
820                    enum_name, missing_str
821                ),
822                span,
823            );
824        }
825    }
826
827    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
828        // Check against known function signatures
829        if let Some(sig) = scope.get_fn(name).cloned() {
830            if !is_builtin(name)
831                && (args.len() < sig.required_params || args.len() > sig.params.len())
832            {
833                let expected = if sig.required_params == sig.params.len() {
834                    format!("{}", sig.params.len())
835                } else {
836                    format!("{}-{}", sig.required_params, sig.params.len())
837                };
838                self.warning_at(
839                    format!(
840                        "Function '{}' expects {} arguments, got {}",
841                        name,
842                        expected,
843                        args.len()
844                    ),
845                    span,
846                );
847            }
848            // Build a scope that includes the function's generic type params
849            // so they are treated as compatible with any concrete type.
850            let call_scope = if sig.type_param_names.is_empty() {
851                scope.clone()
852            } else {
853                let mut s = scope.child();
854                for tp_name in &sig.type_param_names {
855                    s.generic_type_params.insert(tp_name.clone());
856                }
857                s
858            };
859            for (i, (arg, (param_name, param_type))) in
860                args.iter().zip(sig.params.iter()).enumerate()
861            {
862                if let Some(expected) = param_type {
863                    let actual = self.infer_type(arg, scope);
864                    if let Some(actual) = &actual {
865                        if !self.types_compatible(expected, actual, &call_scope) {
866                            self.error_at(
867                                format!(
868                                    "Argument {} ('{}'): expected {}, got {}",
869                                    i + 1,
870                                    param_name,
871                                    format_type(expected),
872                                    format_type(actual)
873                                ),
874                                arg.span,
875                            );
876                        }
877                    }
878                }
879            }
880        }
881        // Check args recursively
882        for arg in args {
883            self.check_node(arg, scope);
884        }
885    }
886
887    /// Infer the type of an expression.
888    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
889        match &snode.node {
890            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
891            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
892            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
893                Some(TypeExpr::Named("string".into()))
894            }
895            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
896            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
897            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
898            Node::DictLiteral(entries) => {
899                // Infer shape type when all keys are string literals
900                let mut fields = Vec::new();
901                let mut all_string_keys = true;
902                for entry in entries {
903                    if let Node::StringLiteral(key) = &entry.key.node {
904                        let val_type = self
905                            .infer_type(&entry.value, scope)
906                            .unwrap_or(TypeExpr::Named("nil".into()));
907                        fields.push(ShapeField {
908                            name: key.clone(),
909                            type_expr: val_type,
910                            optional: false,
911                        });
912                    } else {
913                        all_string_keys = false;
914                        break;
915                    }
916                }
917                if all_string_keys && !fields.is_empty() {
918                    Some(TypeExpr::Shape(fields))
919                } else {
920                    Some(TypeExpr::Named("dict".into()))
921                }
922            }
923            Node::Closure { params, body } => {
924                // If all params are typed and we can infer a return type, produce FnType
925                let all_typed = params.iter().all(|p| p.type_expr.is_some());
926                if all_typed && !params.is_empty() {
927                    let param_types: Vec<TypeExpr> =
928                        params.iter().filter_map(|p| p.type_expr.clone()).collect();
929                    // Try to infer return type from last expression in body
930                    let ret = body.last().and_then(|last| self.infer_type(last, scope));
931                    if let Some(ret_type) = ret {
932                        return Some(TypeExpr::FnType {
933                            params: param_types,
934                            return_type: Box::new(ret_type),
935                        });
936                    }
937                }
938                Some(TypeExpr::Named("closure".into()))
939            }
940
941            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
942
943            Node::FunctionCall { name, .. } => {
944                // Check user-defined function return types
945                if let Some(sig) = scope.get_fn(name) {
946                    return sig.return_type.clone();
947                }
948                // Check builtin return types
949                builtin_return_type(name)
950            }
951
952            Node::BinaryOp { op, left, right } => {
953                let lt = self.infer_type(left, scope);
954                let rt = self.infer_type(right, scope);
955                infer_binary_op_type(op, &lt, &rt)
956            }
957
958            Node::UnaryOp { op, operand } => {
959                let t = self.infer_type(operand, scope);
960                match op.as_str() {
961                    "!" => Some(TypeExpr::Named("bool".into())),
962                    "-" => t, // negation preserves type
963                    _ => None,
964                }
965            }
966
967            Node::Ternary {
968                true_expr,
969                false_expr,
970                ..
971            } => {
972                let tt = self.infer_type(true_expr, scope);
973                let ft = self.infer_type(false_expr, scope);
974                match (&tt, &ft) {
975                    (Some(a), Some(b)) if a == b => tt,
976                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
977                    (Some(_), None) => tt,
978                    (None, Some(_)) => ft,
979                    (None, None) => None,
980                }
981            }
982
983            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
984
985            Node::PropertyAccess { object, property } => {
986                // EnumName.Variant → infer as the enum type
987                if let Node::Identifier(name) = &object.node {
988                    if scope.get_enum(name).is_some() {
989                        return Some(TypeExpr::Named(name.clone()));
990                    }
991                }
992                // .variant on an enum value → string
993                if property == "variant" {
994                    let obj_type = self.infer_type(object, scope);
995                    if let Some(TypeExpr::Named(name)) = &obj_type {
996                        if scope.get_enum(name).is_some() {
997                            return Some(TypeExpr::Named("string".into()));
998                        }
999                    }
1000                }
1001                // Shape field access: obj.field → field type
1002                let obj_type = self.infer_type(object, scope);
1003                if let Some(TypeExpr::Shape(fields)) = &obj_type {
1004                    if let Some(field) = fields.iter().find(|f| f.name == *property) {
1005                        return Some(field.type_expr.clone());
1006                    }
1007                }
1008                None
1009            }
1010
1011            Node::SubscriptAccess { object, index } => {
1012                let obj_type = self.infer_type(object, scope);
1013                match &obj_type {
1014                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1015                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1016                    Some(TypeExpr::Shape(fields)) => {
1017                        // If index is a string literal, look up the field type
1018                        if let Node::StringLiteral(key) = &index.node {
1019                            fields
1020                                .iter()
1021                                .find(|f| &f.name == key)
1022                                .map(|f| f.type_expr.clone())
1023                        } else {
1024                            None
1025                        }
1026                    }
1027                    Some(TypeExpr::Named(n)) if n == "list" => None,
1028                    Some(TypeExpr::Named(n)) if n == "dict" => None,
1029                    Some(TypeExpr::Named(n)) if n == "string" => {
1030                        Some(TypeExpr::Named("string".into()))
1031                    }
1032                    _ => None,
1033                }
1034            }
1035            Node::SliceAccess { object, .. } => {
1036                // Slicing a list returns the same list type; slicing a string returns string
1037                let obj_type = self.infer_type(object, scope);
1038                match &obj_type {
1039                    Some(TypeExpr::List(_)) => obj_type,
1040                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1041                    Some(TypeExpr::Named(n)) if n == "string" => {
1042                        Some(TypeExpr::Named("string".into()))
1043                    }
1044                    _ => None,
1045                }
1046            }
1047            Node::MethodCall { object, method, .. }
1048            | Node::OptionalMethodCall { object, method, .. } => {
1049                let obj_type = self.infer_type(object, scope);
1050                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1051                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1052                match method.as_str() {
1053                    // Shared: bool-returning methods
1054                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1055                        Some(TypeExpr::Named("bool".into()))
1056                    }
1057                    // Shared: int-returning methods
1058                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1059                    // String methods
1060                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1061                    | "pad_left" | "pad_right" | "repeat" | "join" => {
1062                        Some(TypeExpr::Named("string".into()))
1063                    }
1064                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
1065                    // filter returns dict for dicts, list for lists
1066                    "filter" => {
1067                        if is_dict {
1068                            Some(TypeExpr::Named("dict".into()))
1069                        } else {
1070                            Some(TypeExpr::Named("list".into()))
1071                        }
1072                    }
1073                    // List methods
1074                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1075                    "reduce" | "find" | "first" | "last" => None,
1076                    // Dict methods
1077                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1078                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1079                    // Conversions
1080                    "to_string" => Some(TypeExpr::Named("string".into())),
1081                    "to_int" => Some(TypeExpr::Named("int".into())),
1082                    "to_float" => Some(TypeExpr::Named("float".into())),
1083                    _ => None,
1084                }
1085            }
1086
1087            _ => None,
1088        }
1089    }
1090
1091    /// Check if two types are compatible (actual can be assigned to expected).
1092    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1093        // Generic type parameters match anything.
1094        if let TypeExpr::Named(name) = expected {
1095            if scope.is_generic_type_param(name) {
1096                return true;
1097            }
1098        }
1099        if let TypeExpr::Named(name) = actual {
1100            if scope.is_generic_type_param(name) {
1101                return true;
1102            }
1103        }
1104        let expected = self.resolve_alias(expected, scope);
1105        let actual = self.resolve_alias(actual, scope);
1106
1107        match (&expected, &actual) {
1108            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1109            (TypeExpr::Union(members), actual_type) => members
1110                .iter()
1111                .any(|m| self.types_compatible(m, actual_type, scope)),
1112            (expected_type, TypeExpr::Union(members)) => members
1113                .iter()
1114                .all(|m| self.types_compatible(expected_type, m, scope)),
1115            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1116            (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1117            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1118                if expected_field.optional {
1119                    return true;
1120                }
1121                af.iter().any(|actual_field| {
1122                    actual_field.name == expected_field.name
1123                        && self.types_compatible(
1124                            &expected_field.type_expr,
1125                            &actual_field.type_expr,
1126                            scope,
1127                        )
1128                })
1129            }),
1130            // dict<K, V> expected, Shape actual → all field values must match V
1131            (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1132                let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1133                keys_ok
1134                    && af
1135                        .iter()
1136                        .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1137            }
1138            // Shape expected, dict<K, V> actual → gradual: allow since dict may have the fields
1139            (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1140            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1141                self.types_compatible(expected_inner, actual_inner, scope)
1142            }
1143            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1144            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1145            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1146                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1147            }
1148            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1149            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1150            // FnType compatibility: params match positionally and return types match
1151            (
1152                TypeExpr::FnType {
1153                    params: ep,
1154                    return_type: er,
1155                },
1156                TypeExpr::FnType {
1157                    params: ap,
1158                    return_type: ar,
1159                },
1160            ) => {
1161                ep.len() == ap.len()
1162                    && ep
1163                        .iter()
1164                        .zip(ap.iter())
1165                        .all(|(e, a)| self.types_compatible(e, a, scope))
1166                    && self.types_compatible(er, ar, scope)
1167            }
1168            // FnType is compatible with Named("closure") for backward compat
1169            (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1170            (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1171            _ => false,
1172        }
1173    }
1174
1175    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1176        if let TypeExpr::Named(name) = ty {
1177            if let Some(resolved) = scope.resolve_type(name) {
1178                return resolved.clone();
1179            }
1180        }
1181        ty.clone()
1182    }
1183
1184    fn error_at(&mut self, message: String, span: Span) {
1185        self.diagnostics.push(TypeDiagnostic {
1186            message,
1187            severity: DiagnosticSeverity::Error,
1188            span: Some(span),
1189        });
1190    }
1191
1192    fn warning_at(&mut self, message: String, span: Span) {
1193        self.diagnostics.push(TypeDiagnostic {
1194            message,
1195            severity: DiagnosticSeverity::Warning,
1196            span: Some(span),
1197        });
1198    }
1199}
1200
1201impl Default for TypeChecker {
1202    fn default() -> Self {
1203        Self::new()
1204    }
1205}
1206
1207/// Infer the result type of a binary operation.
1208fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1209    match op {
1210        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1211        "+" => match (left, right) {
1212            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1213                match (l.as_str(), r.as_str()) {
1214                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1215                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1216                    ("string", _) => Some(TypeExpr::Named("string".into())),
1217                    ("list", "list") => Some(TypeExpr::Named("list".into())),
1218                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1219                    _ => Some(TypeExpr::Named("string".into())),
1220                }
1221            }
1222            _ => None,
1223        },
1224        "-" | "*" | "/" | "%" => match (left, right) {
1225            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1226                match (l.as_str(), r.as_str()) {
1227                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1228                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1229                    _ => None,
1230                }
1231            }
1232            _ => None,
1233        },
1234        "??" => match (left, right) {
1235            (Some(TypeExpr::Union(members)), _) => {
1236                let non_nil: Vec<_> = members
1237                    .iter()
1238                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1239                    .cloned()
1240                    .collect();
1241                if non_nil.len() == 1 {
1242                    Some(non_nil[0].clone())
1243                } else if non_nil.is_empty() {
1244                    right.clone()
1245                } else {
1246                    Some(TypeExpr::Union(non_nil))
1247                }
1248            }
1249            _ => right.clone(),
1250        },
1251        "|>" => None,
1252        _ => None,
1253    }
1254}
1255
1256/// Format a type expression for display in error messages.
1257/// Produce a detail string describing why a Shape type is incompatible with
1258/// another Shape type — e.g. "missing field 'age' (int)" or "field 'name'
1259/// has type int, expected string".  Returns `None` if both types are not shapes.
1260pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1261    if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1262        let mut details = Vec::new();
1263        for field in ef {
1264            if field.optional {
1265                continue;
1266            }
1267            match af.iter().find(|f| f.name == field.name) {
1268                None => details.push(format!(
1269                    "missing field '{}' ({})",
1270                    field.name,
1271                    format_type(&field.type_expr)
1272                )),
1273                Some(actual_field) => {
1274                    let e_str = format_type(&field.type_expr);
1275                    let a_str = format_type(&actual_field.type_expr);
1276                    if e_str != a_str {
1277                        details.push(format!(
1278                            "field '{}' has type {}, expected {}",
1279                            field.name, a_str, e_str
1280                        ));
1281                    }
1282                }
1283            }
1284        }
1285        if details.is_empty() {
1286            None
1287        } else {
1288            Some(details.join("; "))
1289        }
1290    } else {
1291        None
1292    }
1293}
1294
1295pub fn format_type(ty: &TypeExpr) -> String {
1296    match ty {
1297        TypeExpr::Named(n) => n.clone(),
1298        TypeExpr::Union(types) => types
1299            .iter()
1300            .map(format_type)
1301            .collect::<Vec<_>>()
1302            .join(" | "),
1303        TypeExpr::Shape(fields) => {
1304            let inner: Vec<String> = fields
1305                .iter()
1306                .map(|f| {
1307                    let opt = if f.optional { "?" } else { "" };
1308                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1309                })
1310                .collect();
1311            format!("{{{}}}", inner.join(", "))
1312        }
1313        TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1314        TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1315        TypeExpr::FnType {
1316            params,
1317            return_type,
1318        } => {
1319            let params_str = params
1320                .iter()
1321                .map(format_type)
1322                .collect::<Vec<_>>()
1323                .join(", ");
1324            format!("fn({}) -> {}", params_str, format_type(return_type))
1325        }
1326    }
1327}
1328
1329#[cfg(test)]
1330mod tests {
1331    use super::*;
1332    use crate::Parser;
1333    use harn_lexer::Lexer;
1334
1335    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
1336        let mut lexer = Lexer::new(source);
1337        let tokens = lexer.tokenize().unwrap();
1338        let mut parser = Parser::new(tokens);
1339        let program = parser.parse().unwrap();
1340        TypeChecker::new().check(&program)
1341    }
1342
1343    fn errors(source: &str) -> Vec<String> {
1344        check_source(source)
1345            .into_iter()
1346            .filter(|d| d.severity == DiagnosticSeverity::Error)
1347            .map(|d| d.message)
1348            .collect()
1349    }
1350
1351    #[test]
1352    fn test_no_errors_for_untyped_code() {
1353        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
1354        assert!(errs.is_empty());
1355    }
1356
1357    #[test]
1358    fn test_correct_typed_let() {
1359        let errs = errors("pipeline t(task) { let x: int = 42 }");
1360        assert!(errs.is_empty());
1361    }
1362
1363    #[test]
1364    fn test_type_mismatch_let() {
1365        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
1366        assert_eq!(errs.len(), 1);
1367        assert!(errs[0].contains("Type mismatch"));
1368        assert!(errs[0].contains("int"));
1369        assert!(errs[0].contains("string"));
1370    }
1371
1372    #[test]
1373    fn test_correct_typed_fn() {
1374        let errs = errors(
1375            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
1376        );
1377        assert!(errs.is_empty());
1378    }
1379
1380    #[test]
1381    fn test_fn_arg_type_mismatch() {
1382        let errs = errors(
1383            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
1384add("hello", 2) }"#,
1385        );
1386        assert_eq!(errs.len(), 1);
1387        assert!(errs[0].contains("Argument 1"));
1388        assert!(errs[0].contains("expected int"));
1389    }
1390
1391    #[test]
1392    fn test_return_type_mismatch() {
1393        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
1394        assert_eq!(errs.len(), 1);
1395        assert!(errs[0].contains("Return type mismatch"));
1396    }
1397
1398    #[test]
1399    fn test_union_type_compatible() {
1400        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
1401        assert!(errs.is_empty());
1402    }
1403
1404    #[test]
1405    fn test_union_type_mismatch() {
1406        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
1407        assert_eq!(errs.len(), 1);
1408        assert!(errs[0].contains("Type mismatch"));
1409    }
1410
1411    #[test]
1412    fn test_type_inference_propagation() {
1413        let errs = errors(
1414            r#"pipeline t(task) {
1415  fn add(a: int, b: int) -> int { return a + b }
1416  let result: string = add(1, 2)
1417}"#,
1418        );
1419        assert_eq!(errs.len(), 1);
1420        assert!(errs[0].contains("Type mismatch"));
1421        assert!(errs[0].contains("string"));
1422        assert!(errs[0].contains("int"));
1423    }
1424
1425    #[test]
1426    fn test_builtin_return_type_inference() {
1427        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
1428        assert_eq!(errs.len(), 1);
1429        assert!(errs[0].contains("string"));
1430        assert!(errs[0].contains("int"));
1431    }
1432
1433    #[test]
1434    fn test_binary_op_type_inference() {
1435        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
1436        assert_eq!(errs.len(), 1);
1437    }
1438
1439    #[test]
1440    fn test_comparison_returns_bool() {
1441        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
1442        assert!(errs.is_empty());
1443    }
1444
1445    #[test]
1446    fn test_int_float_promotion() {
1447        let errs = errors("pipeline t(task) { let x: float = 42 }");
1448        assert!(errs.is_empty());
1449    }
1450
1451    #[test]
1452    fn test_untyped_code_no_errors() {
1453        let errs = errors(
1454            r#"pipeline t(task) {
1455  fn process(data) {
1456    let result = data + " processed"
1457    return result
1458  }
1459  log(process("hello"))
1460}"#,
1461        );
1462        assert!(errs.is_empty());
1463    }
1464
1465    #[test]
1466    fn test_type_alias() {
1467        let errs = errors(
1468            r#"pipeline t(task) {
1469  type Name = string
1470  let x: Name = "hello"
1471}"#,
1472        );
1473        assert!(errs.is_empty());
1474    }
1475
1476    #[test]
1477    fn test_type_alias_mismatch() {
1478        let errs = errors(
1479            r#"pipeline t(task) {
1480  type Name = string
1481  let x: Name = 42
1482}"#,
1483        );
1484        assert_eq!(errs.len(), 1);
1485    }
1486
1487    #[test]
1488    fn test_assignment_type_check() {
1489        let errs = errors(
1490            r#"pipeline t(task) {
1491  var x: int = 0
1492  x = "hello"
1493}"#,
1494        );
1495        assert_eq!(errs.len(), 1);
1496        assert!(errs[0].contains("cannot assign string"));
1497    }
1498
1499    #[test]
1500    fn test_covariance_int_to_float_in_fn() {
1501        let errs = errors(
1502            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
1503        );
1504        assert!(errs.is_empty());
1505    }
1506
1507    #[test]
1508    fn test_covariance_return_type() {
1509        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
1510        assert!(errs.is_empty());
1511    }
1512
1513    #[test]
1514    fn test_no_contravariance_float_to_int() {
1515        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
1516        assert_eq!(errs.len(), 1);
1517    }
1518
1519    // --- Exhaustiveness checking tests ---
1520
1521    fn warnings(source: &str) -> Vec<String> {
1522        check_source(source)
1523            .into_iter()
1524            .filter(|d| d.severity == DiagnosticSeverity::Warning)
1525            .map(|d| d.message)
1526            .collect()
1527    }
1528
1529    #[test]
1530    fn test_exhaustive_match_no_warning() {
1531        let warns = warnings(
1532            r#"pipeline t(task) {
1533  enum Color { Red, Green, Blue }
1534  let c = Color.Red
1535  match c.variant {
1536    "Red" -> { log("r") }
1537    "Green" -> { log("g") }
1538    "Blue" -> { log("b") }
1539  }
1540}"#,
1541        );
1542        let exhaustive_warns: Vec<_> = warns
1543            .iter()
1544            .filter(|w| w.contains("Non-exhaustive"))
1545            .collect();
1546        assert!(exhaustive_warns.is_empty());
1547    }
1548
1549    #[test]
1550    fn test_non_exhaustive_match_warning() {
1551        let warns = warnings(
1552            r#"pipeline t(task) {
1553  enum Color { Red, Green, Blue }
1554  let c = Color.Red
1555  match c.variant {
1556    "Red" -> { log("r") }
1557    "Green" -> { log("g") }
1558  }
1559}"#,
1560        );
1561        let exhaustive_warns: Vec<_> = warns
1562            .iter()
1563            .filter(|w| w.contains("Non-exhaustive"))
1564            .collect();
1565        assert_eq!(exhaustive_warns.len(), 1);
1566        assert!(exhaustive_warns[0].contains("Blue"));
1567    }
1568
1569    #[test]
1570    fn test_non_exhaustive_multiple_missing() {
1571        let warns = warnings(
1572            r#"pipeline t(task) {
1573  enum Status { Active, Inactive, Pending }
1574  let s = Status.Active
1575  match s.variant {
1576    "Active" -> { log("a") }
1577  }
1578}"#,
1579        );
1580        let exhaustive_warns: Vec<_> = warns
1581            .iter()
1582            .filter(|w| w.contains("Non-exhaustive"))
1583            .collect();
1584        assert_eq!(exhaustive_warns.len(), 1);
1585        assert!(exhaustive_warns[0].contains("Inactive"));
1586        assert!(exhaustive_warns[0].contains("Pending"));
1587    }
1588
1589    #[test]
1590    fn test_enum_construct_type_inference() {
1591        let errs = errors(
1592            r#"pipeline t(task) {
1593  enum Color { Red, Green, Blue }
1594  let c: Color = Color.Red
1595}"#,
1596        );
1597        assert!(errs.is_empty());
1598    }
1599
    // --- Type narrowing and shape-mismatch diagnostic tests ---
1601
1602    #[test]
1603    fn test_nil_coalescing_strips_nil() {
1604        // After ??, nil should be stripped from the type
1605        let errs = errors(
1606            r#"pipeline t(task) {
1607  let x: string | nil = nil
1608  let y: string = x ?? "default"
1609}"#,
1610        );
1611        assert!(errs.is_empty());
1612    }
1613
1614    #[test]
1615    fn test_shape_mismatch_detail_missing_field() {
1616        let errs = errors(
1617            r#"pipeline t(task) {
1618  let x: {name: string, age: int} = {name: "hello"}
1619}"#,
1620        );
1621        assert_eq!(errs.len(), 1);
1622        assert!(
1623            errs[0].contains("missing field 'age'"),
1624            "expected detail about missing field, got: {}",
1625            errs[0]
1626        );
1627    }
1628
1629    #[test]
1630    fn test_shape_mismatch_detail_wrong_type() {
1631        let errs = errors(
1632            r#"pipeline t(task) {
1633  let x: {name: string, age: int} = {name: 42, age: 10}
1634}"#,
1635        );
1636        assert_eq!(errs.len(), 1);
1637        assert!(
1638            errs[0].contains("field 'name' has type int, expected string"),
1639            "expected detail about wrong type, got: {}",
1640            errs[0]
1641        );
1642    }
1643
1644    // --- Match pattern type validation tests ---
1645
1646    #[test]
1647    fn test_match_pattern_string_against_int() {
1648        let warns = warnings(
1649            r#"pipeline t(task) {
1650  let x: int = 42
1651  match x {
1652    "hello" -> { log("bad") }
1653    42 -> { log("ok") }
1654  }
1655}"#,
1656        );
1657        let pattern_warns: Vec<_> = warns
1658            .iter()
1659            .filter(|w| w.contains("Match pattern type mismatch"))
1660            .collect();
1661        assert_eq!(pattern_warns.len(), 1);
1662        assert!(pattern_warns[0].contains("matching int against string literal"));
1663    }
1664
1665    #[test]
1666    fn test_match_pattern_int_against_string() {
1667        let warns = warnings(
1668            r#"pipeline t(task) {
1669  let x: string = "hello"
1670  match x {
1671    42 -> { log("bad") }
1672    "hello" -> { log("ok") }
1673  }
1674}"#,
1675        );
1676        let pattern_warns: Vec<_> = warns
1677            .iter()
1678            .filter(|w| w.contains("Match pattern type mismatch"))
1679            .collect();
1680        assert_eq!(pattern_warns.len(), 1);
1681        assert!(pattern_warns[0].contains("matching string against int literal"));
1682    }
1683
1684    #[test]
1685    fn test_match_pattern_bool_against_int() {
1686        let warns = warnings(
1687            r#"pipeline t(task) {
1688  let x: int = 42
1689  match x {
1690    true -> { log("bad") }
1691    42 -> { log("ok") }
1692  }
1693}"#,
1694        );
1695        let pattern_warns: Vec<_> = warns
1696            .iter()
1697            .filter(|w| w.contains("Match pattern type mismatch"))
1698            .collect();
1699        assert_eq!(pattern_warns.len(), 1);
1700        assert!(pattern_warns[0].contains("matching int against bool literal"));
1701    }
1702
1703    #[test]
1704    fn test_match_pattern_float_against_string() {
1705        let warns = warnings(
1706            r#"pipeline t(task) {
1707  let x: string = "hello"
1708  match x {
1709    3.14 -> { log("bad") }
1710    "hello" -> { log("ok") }
1711  }
1712}"#,
1713        );
1714        let pattern_warns: Vec<_> = warns
1715            .iter()
1716            .filter(|w| w.contains("Match pattern type mismatch"))
1717            .collect();
1718        assert_eq!(pattern_warns.len(), 1);
1719        assert!(pattern_warns[0].contains("matching string against float literal"));
1720    }
1721
1722    #[test]
1723    fn test_match_pattern_int_against_float_ok() {
1724        // int and float are compatible for match patterns
1725        let warns = warnings(
1726            r#"pipeline t(task) {
1727  let x: float = 3.14
1728  match x {
1729    42 -> { log("ok") }
1730    _ -> { log("default") }
1731  }
1732}"#,
1733        );
1734        let pattern_warns: Vec<_> = warns
1735            .iter()
1736            .filter(|w| w.contains("Match pattern type mismatch"))
1737            .collect();
1738        assert!(pattern_warns.is_empty());
1739    }
1740
1741    #[test]
1742    fn test_match_pattern_float_against_int_ok() {
1743        // float and int are compatible for match patterns
1744        let warns = warnings(
1745            r#"pipeline t(task) {
1746  let x: int = 42
1747  match x {
1748    3.14 -> { log("close") }
1749    _ -> { log("default") }
1750  }
1751}"#,
1752        );
1753        let pattern_warns: Vec<_> = warns
1754            .iter()
1755            .filter(|w| w.contains("Match pattern type mismatch"))
1756            .collect();
1757        assert!(pattern_warns.is_empty());
1758    }
1759
1760    #[test]
1761    fn test_match_pattern_correct_types_no_warning() {
1762        let warns = warnings(
1763            r#"pipeline t(task) {
1764  let x: int = 42
1765  match x {
1766    1 -> { log("one") }
1767    2 -> { log("two") }
1768    _ -> { log("other") }
1769  }
1770}"#,
1771        );
1772        let pattern_warns: Vec<_> = warns
1773            .iter()
1774            .filter(|w| w.contains("Match pattern type mismatch"))
1775            .collect();
1776        assert!(pattern_warns.is_empty());
1777    }
1778
1779    #[test]
1780    fn test_match_pattern_wildcard_no_warning() {
1781        let warns = warnings(
1782            r#"pipeline t(task) {
1783  let x: int = 42
1784  match x {
1785    _ -> { log("catch all") }
1786  }
1787}"#,
1788        );
1789        let pattern_warns: Vec<_> = warns
1790            .iter()
1791            .filter(|w| w.contains("Match pattern type mismatch"))
1792            .collect();
1793        assert!(pattern_warns.is_empty());
1794    }
1795
1796    #[test]
1797    fn test_match_pattern_untyped_no_warning() {
1798        // When value has no known type, no warning should be emitted
1799        let warns = warnings(
1800            r#"pipeline t(task) {
1801  let x = some_unknown_fn()
1802  match x {
1803    "hello" -> { log("string") }
1804    42 -> { log("int") }
1805  }
1806}"#,
1807        );
1808        let pattern_warns: Vec<_> = warns
1809            .iter()
1810            .filter(|w| w.contains("Match pattern type mismatch"))
1811            .collect();
1812        assert!(pattern_warns.is_empty());
1813    }
1814}