Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6/// A diagnostic produced by the type checker.
7#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9    pub message: String,
10    pub severity: DiagnosticSeverity,
11    pub span: Option<Span>,
12}
13
/// Severity attached to a type-checker diagnostic.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticSeverity {
    /// Used for type mismatches (bindings, assignments, arguments, returns).
    Error,
    /// Used for likely-but-uncertain problems such as argument-count
    /// mismatches, questionable operator operands and non-exhaustive matches.
    Warning,
}
19
20/// Inferred type of an expression. None means unknown/untyped (gradual typing).
21type InferredType = Option<TypeExpr>;
22
23/// Scope for tracking variable types.
24#[derive(Debug, Clone)]
25struct TypeScope {
26    /// Variable name → inferred type.
27    vars: BTreeMap<String, InferredType>,
28    /// Function name → (param types, return type).
29    functions: BTreeMap<String, FnSignature>,
30    /// Named type aliases.
31    type_aliases: BTreeMap<String, TypeExpr>,
32    /// Enum declarations: name → variant names.
33    enums: BTreeMap<String, Vec<String>>,
34    /// Interface declarations: name → method signatures.
35    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36    /// Struct declarations: name → field types.
37    structs: BTreeMap<String, Vec<(String, InferredType)>>,
38    parent: Option<Box<TypeScope>>,
39}
40
41#[derive(Debug, Clone)]
42struct FnSignature {
43    params: Vec<(String, InferredType)>,
44    return_type: InferredType,
45}
46
47impl TypeScope {
48    fn new() -> Self {
49        Self {
50            vars: BTreeMap::new(),
51            functions: BTreeMap::new(),
52            type_aliases: BTreeMap::new(),
53            enums: BTreeMap::new(),
54            interfaces: BTreeMap::new(),
55            structs: BTreeMap::new(),
56            parent: None,
57        }
58    }
59
60    fn child(&self) -> Self {
61        Self {
62            vars: BTreeMap::new(),
63            functions: BTreeMap::new(),
64            type_aliases: BTreeMap::new(),
65            enums: BTreeMap::new(),
66            interfaces: BTreeMap::new(),
67            structs: BTreeMap::new(),
68            parent: Some(Box::new(self.clone())),
69        }
70    }
71
72    fn get_var(&self, name: &str) -> Option<&InferredType> {
73        self.vars
74            .get(name)
75            .or_else(|| self.parent.as_ref()?.get_var(name))
76    }
77
78    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
79        self.functions
80            .get(name)
81            .or_else(|| self.parent.as_ref()?.get_fn(name))
82    }
83
84    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
85        self.type_aliases
86            .get(name)
87            .or_else(|| self.parent.as_ref()?.resolve_type(name))
88    }
89
90    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
91        self.enums
92            .get(name)
93            .or_else(|| self.parent.as_ref()?.get_enum(name))
94    }
95
96    #[allow(dead_code)]
97    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
98        self.interfaces
99            .get(name)
100            .or_else(|| self.parent.as_ref()?.get_interface(name))
101    }
102
103    fn define_var(&mut self, name: &str, ty: InferredType) {
104        self.vars.insert(name.to_string(), ty);
105    }
106
107    fn define_fn(&mut self, name: &str, sig: FnSignature) {
108        self.functions.insert(name.to_string(), sig);
109    }
110}
111
112/// Known return types for builtin functions.
113fn builtin_return_type(name: &str) -> InferredType {
114    match name {
115        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
116        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
117            Some(TypeExpr::Named("nil".into()))
118        }
119        "type_of" | "to_string" | "json_stringify" | "read_file" | "http_get" | "http_post"
120        | "llm_call" | "agent_loop" | "regex_replace" | "path_join" | "temp_dir"
121        | "date_format" | "format" => Some(TypeExpr::Named("string".into())),
122        "to_int" => Some(TypeExpr::Named("int".into())),
123        "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
124        "file_exists" | "json_validate" => Some(TypeExpr::Named("bool".into())),
125        "list_dir" => Some(TypeExpr::Named("list".into())),
126        "stat" | "exec" | "shell" | "date_now" => Some(TypeExpr::Named("dict".into())),
127        "env" | "regex_match" => Some(TypeExpr::Union(vec![
128            TypeExpr::Named("string".into()),
129            TypeExpr::Named("nil".into()),
130        ])),
131        "json_parse" | "json_extract" => None, // could be any type
132        _ => None,
133    }
134}
135
/// Reports whether `name` is one of the builtin functions known to the checker.
///
/// `check_call` consults this to skip argument-count warnings for builtins.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        "log", "print", "println", "type_of", "to_string", "to_int", "to_float",
        "json_stringify", "json_parse", "env", "timestamp", "sleep", "read_file",
        "write_file", "exit", "regex_match", "regex_replace", "http_get", "http_post",
        "llm_call", "agent_loop", "await", "cancel", "file_exists", "delete_file",
        "list_dir", "mkdir", "path_join", "copy_file", "append_file", "temp_dir", "stat",
        "exec", "shell", "date_now", "date_format", "date_parse", "format",
        "json_validate", "json_extract", "trim", "lowercase", "uppercase", "split",
        "starts_with", "ends_with", "contains", "replace", "join", "len", "substring",
        "dirname", "basename", "extname",
    ];
    BUILTINS.contains(&name)
}
196
197/// The static type checker.
198pub struct TypeChecker {
199    diagnostics: Vec<TypeDiagnostic>,
200    scope: TypeScope,
201}
202
203impl TypeChecker {
204    pub fn new() -> Self {
205        Self {
206            diagnostics: Vec::new(),
207            scope: TypeScope::new(),
208        }
209    }
210
211    /// Check a program and return diagnostics.
212    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
213        // First pass: register type and enum declarations into root scope
214        Self::register_declarations_into(&mut self.scope, program);
215
216        // Also scan pipeline bodies for declarations
217        for snode in program {
218            if let Node::Pipeline { body, .. } = &snode.node {
219                Self::register_declarations_into(&mut self.scope, body);
220            }
221        }
222
223        // Check each top-level node
224        for snode in program {
225            match &snode.node {
226                Node::Pipeline { params, body, .. } => {
227                    let mut child = self.scope.child();
228                    for p in params {
229                        child.define_var(p, None);
230                    }
231                    self.check_block(body, &mut child);
232                }
233                Node::FnDecl {
234                    name,
235                    params,
236                    return_type,
237                    body,
238                    ..
239                } => {
240                    let sig = FnSignature {
241                        params: params
242                            .iter()
243                            .map(|p| (p.name.clone(), p.type_expr.clone()))
244                            .collect(),
245                        return_type: return_type.clone(),
246                    };
247                    self.scope.define_fn(name, sig);
248                    self.check_fn_body(params, return_type, body);
249                }
250                _ => {
251                    let mut scope = self.scope.clone();
252                    self.check_node(snode, &mut scope);
253                    // Merge any new definitions back into the top-level scope
254                    for (name, ty) in scope.vars {
255                        self.scope.vars.entry(name).or_insert(ty);
256                    }
257                }
258            }
259        }
260
261        self.diagnostics
262    }
263
264    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
265    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
266        for snode in nodes {
267            match &snode.node {
268                Node::TypeDecl { name, type_expr } => {
269                    scope.type_aliases.insert(name.clone(), type_expr.clone());
270                }
271                Node::EnumDecl { name, variants } => {
272                    let variant_names: Vec<String> =
273                        variants.iter().map(|v| v.name.clone()).collect();
274                    scope.enums.insert(name.clone(), variant_names);
275                }
276                Node::InterfaceDecl { name, methods } => {
277                    scope.interfaces.insert(name.clone(), methods.clone());
278                }
279                Node::StructDecl { name, fields } => {
280                    let field_types: Vec<(String, InferredType)> = fields
281                        .iter()
282                        .map(|f| (f.name.clone(), f.type_expr.clone()))
283                        .collect();
284                    scope.structs.insert(name.clone(), field_types);
285                }
286                _ => {}
287            }
288        }
289    }
290
291    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
292        for stmt in stmts {
293            self.check_node(stmt, scope);
294        }
295    }
296
297    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
298        let span = snode.span;
299        match &snode.node {
300            Node::LetBinding {
301                name,
302                type_ann,
303                value,
304            } => {
305                let inferred = self.infer_type(value, scope);
306                if let Some(expected) = type_ann {
307                    if let Some(actual) = &inferred {
308                        if !self.types_compatible(expected, actual, scope) {
309                            self.error_at(
310                                format!(
311                                    "Type mismatch: '{}' declared as {}, but assigned {}",
312                                    name,
313                                    format_type(expected),
314                                    format_type(actual)
315                                ),
316                                span,
317                            );
318                        }
319                    }
320                }
321                let ty = type_ann.clone().or(inferred);
322                scope.define_var(name, ty);
323            }
324
325            Node::VarBinding {
326                name,
327                type_ann,
328                value,
329            } => {
330                let inferred = self.infer_type(value, scope);
331                if let Some(expected) = type_ann {
332                    if let Some(actual) = &inferred {
333                        if !self.types_compatible(expected, actual, scope) {
334                            self.error_at(
335                                format!(
336                                    "Type mismatch: '{}' declared as {}, but assigned {}",
337                                    name,
338                                    format_type(expected),
339                                    format_type(actual)
340                                ),
341                                span,
342                            );
343                        }
344                    }
345                }
346                let ty = type_ann.clone().or(inferred);
347                scope.define_var(name, ty);
348            }
349
350            Node::FnDecl {
351                name,
352                params,
353                return_type,
354                body,
355                ..
356            } => {
357                let sig = FnSignature {
358                    params: params
359                        .iter()
360                        .map(|p| (p.name.clone(), p.type_expr.clone()))
361                        .collect(),
362                    return_type: return_type.clone(),
363                };
364                scope.define_fn(name, sig.clone());
365                scope.define_var(name, None);
366                self.check_fn_body(params, return_type, body);
367            }
368
369            Node::FunctionCall { name, args } => {
370                self.check_call(name, args, scope, span);
371            }
372
373            Node::IfElse {
374                condition,
375                then_body,
376                else_body,
377            } => {
378                self.check_node(condition, scope);
379                let mut then_scope = scope.child();
380                self.check_block(then_body, &mut then_scope);
381                if let Some(else_body) = else_body {
382                    let mut else_scope = scope.child();
383                    self.check_block(else_body, &mut else_scope);
384                }
385            }
386
387            Node::ForIn {
388                variable,
389                iterable,
390                body,
391            } => {
392                self.check_node(iterable, scope);
393                let mut loop_scope = scope.child();
394                // Infer loop variable type from iterable
395                let elem_type = match self.infer_type(iterable, scope) {
396                    Some(TypeExpr::List(inner)) => Some(*inner),
397                    Some(TypeExpr::Named(n)) if n == "string" => {
398                        Some(TypeExpr::Named("string".into()))
399                    }
400                    _ => None,
401                };
402                loop_scope.define_var(variable, elem_type);
403                self.check_block(body, &mut loop_scope);
404            }
405
406            Node::WhileLoop { condition, body } => {
407                self.check_node(condition, scope);
408                let mut loop_scope = scope.child();
409                self.check_block(body, &mut loop_scope);
410            }
411
412            Node::TryCatch {
413                body,
414                error_var,
415                catch_body,
416                ..
417            } => {
418                let mut try_scope = scope.child();
419                self.check_block(body, &mut try_scope);
420                let mut catch_scope = scope.child();
421                if let Some(var) = error_var {
422                    catch_scope.define_var(var, None);
423                }
424                self.check_block(catch_body, &mut catch_scope);
425            }
426
427            Node::ReturnStmt {
428                value: Some(val), ..
429            } => {
430                self.check_node(val, scope);
431            }
432
433            Node::Assignment {
434                target, value, op, ..
435            } => {
436                self.check_node(value, scope);
437                if let Node::Identifier(name) = &target.node {
438                    if let Some(Some(var_type)) = scope.get_var(name) {
439                        let value_type = self.infer_type(value, scope);
440                        let assigned = if let Some(op) = op {
441                            let var_inferred = scope.get_var(name).cloned().flatten();
442                            infer_binary_op_type(op, &var_inferred, &value_type)
443                        } else {
444                            value_type
445                        };
446                        if let Some(actual) = &assigned {
447                            if !self.types_compatible(var_type, actual, scope) {
448                                self.error_at(
449                                    format!(
450                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
451                                        format_type(actual),
452                                        name,
453                                        format_type(var_type)
454                                    ),
455                                    span,
456                                );
457                            }
458                        }
459                    }
460                }
461            }
462
463            Node::TypeDecl { name, type_expr } => {
464                scope.type_aliases.insert(name.clone(), type_expr.clone());
465            }
466
467            Node::EnumDecl { name, variants } => {
468                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
469                scope.enums.insert(name.clone(), variant_names);
470            }
471
472            Node::StructDecl { name, fields } => {
473                let field_types: Vec<(String, InferredType)> = fields
474                    .iter()
475                    .map(|f| (f.name.clone(), f.type_expr.clone()))
476                    .collect();
477                scope.structs.insert(name.clone(), field_types);
478            }
479
480            Node::InterfaceDecl { name, methods } => {
481                scope.interfaces.insert(name.clone(), methods.clone());
482            }
483
484            Node::MatchExpr { value, arms } => {
485                self.check_node(value, scope);
486                for arm in arms {
487                    self.check_node(&arm.pattern, scope);
488                    let mut arm_scope = scope.child();
489                    self.check_block(&arm.body, &mut arm_scope);
490                }
491                self.check_match_exhaustiveness(value, arms, scope, span);
492            }
493
494            // Recurse into nested expressions + validate binary op types
495            Node::BinaryOp { op, left, right } => {
496                self.check_node(left, scope);
497                self.check_node(right, scope);
498                // Validate operator/type compatibility
499                let lt = self.infer_type(left, scope);
500                let rt = self.infer_type(right, scope);
501                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
502                    match op.as_str() {
503                        "-" | "*" | "/" | "%" => {
504                            let numeric = ["int", "float"];
505                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
506                                self.warning_at(
507                                    format!(
508                                        "Operator '{op}' may not be valid for types {} and {}",
509                                        l, r
510                                    ),
511                                    span,
512                                );
513                            }
514                        }
515                        "+" => {
516                            // + is valid for int, float, string, list, dict
517                            let valid = ["int", "float", "string", "list", "dict"];
518                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
519                                self.warning_at(
520                                    format!(
521                                        "Operator '+' may not be valid for types {} and {}",
522                                        l, r
523                                    ),
524                                    span,
525                                );
526                            }
527                        }
528                        _ => {}
529                    }
530                }
531            }
532            Node::UnaryOp { operand, .. } => {
533                self.check_node(operand, scope);
534            }
535            Node::MethodCall { object, args, .. } => {
536                self.check_node(object, scope);
537                for arg in args {
538                    self.check_node(arg, scope);
539                }
540            }
541            Node::PropertyAccess { object, .. } => {
542                self.check_node(object, scope);
543            }
544            Node::SubscriptAccess { object, index } => {
545                self.check_node(object, scope);
546                self.check_node(index, scope);
547            }
548
549            // Terminals — nothing to check
550            _ => {}
551        }
552    }
553
554    fn check_fn_body(
555        &mut self,
556        params: &[TypedParam],
557        return_type: &Option<TypeExpr>,
558        body: &[SNode],
559    ) {
560        let mut fn_scope = self.scope.child();
561        for param in params {
562            fn_scope.define_var(&param.name, param.type_expr.clone());
563        }
564        self.check_block(body, &mut fn_scope);
565
566        // Check return statements against declared return type
567        if let Some(ret_type) = return_type {
568            for stmt in body {
569                self.check_return_type(stmt, ret_type, &fn_scope);
570            }
571        }
572    }
573
574    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
575        let span = snode.span;
576        match &snode.node {
577            Node::ReturnStmt { value: Some(val) } => {
578                let inferred = self.infer_type(val, scope);
579                if let Some(actual) = &inferred {
580                    if !self.types_compatible(expected, actual, scope) {
581                        self.error_at(
582                            format!(
583                                "Return type mismatch: expected {}, got {}",
584                                format_type(expected),
585                                format_type(actual)
586                            ),
587                            span,
588                        );
589                    }
590                }
591            }
592            Node::IfElse {
593                then_body,
594                else_body,
595                ..
596            } => {
597                for stmt in then_body {
598                    self.check_return_type(stmt, expected, scope);
599                }
600                if let Some(else_body) = else_body {
601                    for stmt in else_body {
602                        self.check_return_type(stmt, expected, scope);
603                    }
604                }
605            }
606            _ => {}
607        }
608    }
609
610    /// Check if a match expression on an enum's `.variant` property covers all variants.
611    fn check_match_exhaustiveness(
612        &mut self,
613        value: &SNode,
614        arms: &[MatchArm],
615        scope: &TypeScope,
616        span: Span,
617    ) {
618        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
619        let enum_name = match &value.node {
620            Node::PropertyAccess { object, property } if property == "variant" => {
621                // Infer the type of the object
622                match self.infer_type(object, scope) {
623                    Some(TypeExpr::Named(name)) => {
624                        if scope.get_enum(&name).is_some() {
625                            Some(name)
626                        } else {
627                            None
628                        }
629                    }
630                    _ => None,
631                }
632            }
633            _ => {
634                // Direct match on an enum value: match <expr> { ... }
635                match self.infer_type(value, scope) {
636                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
637                    _ => None,
638                }
639            }
640        };
641
642        let Some(enum_name) = enum_name else {
643            return;
644        };
645        let Some(variants) = scope.get_enum(&enum_name) else {
646            return;
647        };
648
649        // Collect variant names covered by match arms
650        let mut covered: Vec<String> = Vec::new();
651        let mut has_wildcard = false;
652
653        for arm in arms {
654            match &arm.pattern.node {
655                // String literal pattern (matching on .variant): "VariantA"
656                Node::StringLiteral(s) => covered.push(s.clone()),
657                // Identifier pattern acts as a wildcard/catch-all
658                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
659                    has_wildcard = true;
660                }
661                // Direct enum construct pattern: EnumName.Variant
662                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
663                // PropertyAccess pattern: EnumName.Variant (no args)
664                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
665                _ => {
666                    // Unknown pattern shape — conservatively treat as wildcard
667                    has_wildcard = true;
668                }
669            }
670        }
671
672        if has_wildcard {
673            return;
674        }
675
676        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
677        if !missing.is_empty() {
678            let missing_str = missing
679                .iter()
680                .map(|s| format!("\"{}\"", s))
681                .collect::<Vec<_>>()
682                .join(", ");
683            self.warning_at(
684                format!(
685                    "Non-exhaustive match on enum {}: missing variants {}",
686                    enum_name, missing_str
687                ),
688                span,
689            );
690        }
691    }
692
693    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
694        // Check against known function signatures
695        if let Some(sig) = scope.get_fn(name).cloned() {
696            if args.len() != sig.params.len() && !is_builtin(name) {
697                self.warning_at(
698                    format!(
699                        "Function '{}' expects {} arguments, got {}",
700                        name,
701                        sig.params.len(),
702                        args.len()
703                    ),
704                    span,
705                );
706            }
707            for (i, (arg, (param_name, param_type))) in
708                args.iter().zip(sig.params.iter()).enumerate()
709            {
710                if let Some(expected) = param_type {
711                    let actual = self.infer_type(arg, scope);
712                    if let Some(actual) = &actual {
713                        if !self.types_compatible(expected, actual, scope) {
714                            self.error_at(
715                                format!(
716                                    "Argument {} ('{}'): expected {}, got {}",
717                                    i + 1,
718                                    param_name,
719                                    format_type(expected),
720                                    format_type(actual)
721                                ),
722                                arg.span,
723                            );
724                        }
725                    }
726                }
727            }
728        }
729        // Check args recursively
730        for arg in args {
731            self.check_node(arg, scope);
732        }
733    }
734
735    /// Infer the type of an expression.
736    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
737        match &snode.node {
738            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
739            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
740            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
741                Some(TypeExpr::Named("string".into()))
742            }
743            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
744            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
745            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
746            Node::DictLiteral(_) => Some(TypeExpr::Named("dict".into())),
747            Node::Closure { .. } => Some(TypeExpr::Named("closure".into())),
748
749            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
750
751            Node::FunctionCall { name, .. } => {
752                // Check user-defined function return types
753                if let Some(sig) = scope.get_fn(name) {
754                    return sig.return_type.clone();
755                }
756                // Check builtin return types
757                builtin_return_type(name)
758            }
759
760            Node::BinaryOp { op, left, right } => {
761                let lt = self.infer_type(left, scope);
762                let rt = self.infer_type(right, scope);
763                infer_binary_op_type(op, &lt, &rt)
764            }
765
766            Node::UnaryOp { op, operand } => {
767                let t = self.infer_type(operand, scope);
768                match op.as_str() {
769                    "!" => Some(TypeExpr::Named("bool".into())),
770                    "-" => t, // negation preserves type
771                    _ => None,
772                }
773            }
774
775            Node::Ternary {
776                true_expr,
777                false_expr,
778                ..
779            } => {
780                let tt = self.infer_type(true_expr, scope);
781                let ft = self.infer_type(false_expr, scope);
782                match (&tt, &ft) {
783                    (Some(a), Some(b)) if a == b => tt,
784                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
785                    (Some(_), None) => tt,
786                    (None, Some(_)) => ft,
787                    (None, None) => None,
788                }
789            }
790
791            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
792
793            Node::PropertyAccess { object, property } => {
794                // EnumName.Variant → infer as the enum type
795                if let Node::Identifier(name) = &object.node {
796                    if scope.get_enum(name).is_some() {
797                        return Some(TypeExpr::Named(name.clone()));
798                    }
799                }
800                // .variant on an enum value → string
801                if property == "variant" {
802                    let obj_type = self.infer_type(object, scope);
803                    if let Some(TypeExpr::Named(name)) = &obj_type {
804                        if scope.get_enum(name).is_some() {
805                            return Some(TypeExpr::Named("string".into()));
806                        }
807                    }
808                }
809                None
810            }
811
812            Node::SubscriptAccess { object, .. } => {
813                let obj_type = self.infer_type(object, scope);
814                match &obj_type {
815                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
816                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
817                    Some(TypeExpr::Named(n)) if n == "list" => None,
818                    Some(TypeExpr::Named(n)) if n == "dict" => None,
819                    Some(TypeExpr::Named(n)) if n == "string" => {
820                        Some(TypeExpr::Named("string".into()))
821                    }
822                    _ => None,
823                }
824            }
825            Node::MethodCall { object, method, .. } => {
826                let obj_type = self.infer_type(object, scope);
827                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
828                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
829                match method.as_str() {
830                    // Shared: bool-returning methods
831                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
832                        Some(TypeExpr::Named("bool".into()))
833                    }
834                    // Shared: int-returning methods
835                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
836                    // String methods
837                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
838                    | "pad_left" | "pad_right" | "repeat" | "join" => {
839                        Some(TypeExpr::Named("string".into()))
840                    }
841                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
842                    // filter returns dict for dicts, list for lists
843                    "filter" => {
844                        if is_dict {
845                            Some(TypeExpr::Named("dict".into()))
846                        } else {
847                            Some(TypeExpr::Named("list".into()))
848                        }
849                    }
850                    // List methods
851                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
852                    "reduce" | "find" | "first" | "last" => None,
853                    // Dict methods
854                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
855                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
856                    // Conversions
857                    "to_string" => Some(TypeExpr::Named("string".into())),
858                    "to_int" => Some(TypeExpr::Named("int".into())),
859                    "to_float" => Some(TypeExpr::Named("float".into())),
860                    _ => None,
861                }
862            }
863
864            _ => None,
865        }
866    }
867
868    /// Check if two types are compatible (actual can be assigned to expected).
869    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
870        let expected = self.resolve_alias(expected, scope);
871        let actual = self.resolve_alias(actual, scope);
872
873        match (&expected, &actual) {
874            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
875            (TypeExpr::Union(members), actual_type) => members
876                .iter()
877                .any(|m| self.types_compatible(m, actual_type, scope)),
878            (expected_type, TypeExpr::Union(members)) => members
879                .iter()
880                .all(|m| self.types_compatible(expected_type, m, scope)),
881            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
882            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
883                if expected_field.optional {
884                    return true;
885                }
886                af.iter().any(|actual_field| {
887                    actual_field.name == expected_field.name
888                        && self.types_compatible(
889                            &expected_field.type_expr,
890                            &actual_field.type_expr,
891                            scope,
892                        )
893                })
894            }),
895            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
896                self.types_compatible(expected_inner, actual_inner, scope)
897            }
898            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
899            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
900            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
901                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
902            }
903            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
904            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
905            _ => false,
906        }
907    }
908
909    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
910        if let TypeExpr::Named(name) = ty {
911            if let Some(resolved) = scope.resolve_type(name) {
912                return resolved.clone();
913            }
914        }
915        ty.clone()
916    }
917
918    fn error_at(&mut self, message: String, span: Span) {
919        self.diagnostics.push(TypeDiagnostic {
920            message,
921            severity: DiagnosticSeverity::Error,
922            span: Some(span),
923        });
924    }
925
926    fn warning_at(&mut self, message: String, span: Span) {
927        self.diagnostics.push(TypeDiagnostic {
928            message,
929            severity: DiagnosticSeverity::Warning,
930            span: Some(span),
931        });
932    }
933}
934
935impl Default for TypeChecker {
936    fn default() -> Self {
937        Self::new()
938    }
939}
940
941/// Infer the result type of a binary operation.
942fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
943    match op {
944        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
945        "+" => match (left, right) {
946            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
947                match (l.as_str(), r.as_str()) {
948                    ("int", "int") => Some(TypeExpr::Named("int".into())),
949                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
950                    ("string", _) => Some(TypeExpr::Named("string".into())),
951                    ("list", "list") => Some(TypeExpr::Named("list".into())),
952                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
953                    _ => Some(TypeExpr::Named("string".into())),
954                }
955            }
956            _ => None,
957        },
958        "-" | "*" | "/" | "%" => match (left, right) {
959            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
960                match (l.as_str(), r.as_str()) {
961                    ("int", "int") => Some(TypeExpr::Named("int".into())),
962                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
963                    _ => None,
964                }
965            }
966            _ => None,
967        },
968        "??" => match (left, right) {
969            (Some(TypeExpr::Union(members)), _) => {
970                let non_nil: Vec<_> = members
971                    .iter()
972                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
973                    .cloned()
974                    .collect();
975                if non_nil.len() == 1 {
976                    Some(non_nil[0].clone())
977                } else if non_nil.is_empty() {
978                    right.clone()
979                } else {
980                    Some(TypeExpr::Union(non_nil))
981                }
982            }
983            _ => right.clone(),
984        },
985        "|>" => None,
986        _ => None,
987    }
988}
989
990/// Format a type expression for display in error messages.
991pub fn format_type(ty: &TypeExpr) -> String {
992    match ty {
993        TypeExpr::Named(n) => n.clone(),
994        TypeExpr::Union(types) => types
995            .iter()
996            .map(format_type)
997            .collect::<Vec<_>>()
998            .join(" | "),
999        TypeExpr::Shape(fields) => {
1000            let inner: Vec<String> = fields
1001                .iter()
1002                .map(|f| {
1003                    let opt = if f.optional { "?" } else { "" };
1004                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1005                })
1006                .collect();
1007            format!("{{{}}}", inner.join(", "))
1008        }
1009        TypeExpr::List(inner) => format!("list[{}]", format_type(inner)),
1010        TypeExpr::DictType(k, v) => format!("dict[{}, {}]", format_type(k), format_type(v)),
1011    }
1012}
1013
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    /// Lex, parse, and type-check `source`, returning every diagnostic produced.
    ///
    /// Panics if the fixture fails to lex or parse; all fixtures here are valid.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut lexer = Lexer::new(source);
        let mut parser = Parser::new(lexer.tokenize().unwrap());
        let program = parser.parse().unwrap();
        TypeChecker::new().check(&program)
    }

    /// Messages of the diagnostics for `source` that carry the given severity.
    fn messages_with(source: &str, severity: DiagnosticSeverity) -> Vec<String> {
        let mut messages = Vec::new();
        for diagnostic in check_source(source) {
            if diagnostic.severity == severity {
                messages.push(diagnostic.message);
            }
        }
        messages
    }

    fn errors(source: &str) -> Vec<String> {
        messages_with(source, DiagnosticSeverity::Error)
    }

    /// Assert that `source` type-checks without any errors.
    fn assert_clean(source: &str) {
        assert!(errors(source).is_empty());
    }

    /// Assert that `source` produces exactly one error and return its message.
    fn sole_error(source: &str) -> String {
        let mut errs = errors(source);
        assert_eq!(errs.len(), 1);
        errs.pop().unwrap()
    }

    #[test]
    fn test_no_errors_for_untyped_code() {
        assert_clean("pipeline t(task) { let x = 42\nlog(x) }");
    }

    #[test]
    fn test_correct_typed_let() {
        assert_clean("pipeline t(task) { let x: int = 42 }");
    }

    #[test]
    fn test_type_mismatch_let() {
        let err = sole_error(r#"pipeline t(task) { let x: int = "hello" }"#);
        for needle in ["Type mismatch", "int", "string"] {
            assert!(err.contains(needle));
        }
    }

    #[test]
    fn test_correct_typed_fn() {
        assert_clean(
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
        );
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let err = sole_error(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        for needle in ["Argument 1", "expected int"] {
            assert!(err.contains(needle));
        }
    }

    #[test]
    fn test_return_type_mismatch() {
        let err = sole_error(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert!(err.contains("Return type mismatch"));
    }

    #[test]
    fn test_union_type_compatible() {
        assert_clean(r#"pipeline t(task) { let x: string | nil = nil }"#);
    }

    #[test]
    fn test_union_type_mismatch() {
        let err = sole_error(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert!(err.contains("Type mismatch"));
    }

    #[test]
    fn test_type_inference_propagation() {
        let err = sole_error(
            r#"pipeline t(task) {
  fn add(a: int, b: int) -> int { return a + b }
  let result: string = add(1, 2)
}"#,
        );
        for needle in ["Type mismatch", "string", "int"] {
            assert!(err.contains(needle));
        }
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let err = sole_error(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        for needle in ["string", "int"] {
            assert!(err.contains(needle));
        }
    }

    #[test]
    fn test_binary_op_type_inference() {
        sole_error("pipeline t(task) { let x: string = 1 + 2 }");
    }

    #[test]
    fn test_comparison_returns_bool() {
        assert_clean("pipeline t(task) { let x: bool = 1 < 2 }");
    }

    #[test]
    fn test_int_float_promotion() {
        assert_clean("pipeline t(task) { let x: float = 42 }");
    }

    #[test]
    fn test_untyped_code_no_errors() {
        assert_clean(
            r#"pipeline t(task) {
  fn process(data) {
    let result = data + " processed"
    return result
  }
  log(process("hello"))
}"#,
        );
    }

    #[test]
    fn test_type_alias() {
        assert_clean(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = "hello"
}"#,
        );
    }

    #[test]
    fn test_type_alias_mismatch() {
        sole_error(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = 42
}"#,
        );
    }

    #[test]
    fn test_assignment_type_check() {
        let err = sole_error(
            r#"pipeline t(task) {
  var x: int = 0
  x = "hello"
}"#,
        );
        assert!(err.contains("cannot assign string"));
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        assert_clean(
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
        );
    }

    #[test]
    fn test_covariance_return_type() {
        assert_clean("pipeline t(task) { fn get() -> float { return 42 } }");
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        sole_error("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
    }

    // --- Exhaustiveness checking tests ---

    fn warnings(source: &str) -> Vec<String> {
        messages_with(source, DiagnosticSeverity::Warning)
    }

    /// Warnings for `source` that report a non-exhaustive `match`.
    fn non_exhaustive_warnings(source: &str) -> Vec<String> {
        warnings(source)
            .into_iter()
            .filter(|w| w.contains("Non-exhaustive"))
            .collect()
    }

    #[test]
    fn test_exhaustive_match_no_warning() {
        let found = non_exhaustive_warnings(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
    "Blue" -> { log("b") }
  }
}"#,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let found = non_exhaustive_warnings(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
  }
}"#,
        );
        assert_eq!(found.len(), 1);
        assert!(found[0].contains("Blue"));
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let found = non_exhaustive_warnings(
            r#"pipeline t(task) {
  enum Status { Active, Inactive, Pending }
  let s = Status.Active
  match s.variant {
    "Active" -> { log("a") }
  }
}"#,
        );
        assert_eq!(found.len(), 1);
        for missing in ["Inactive", "Pending"] {
            assert!(found[0].contains(missing));
        }
    }

    #[test]
    fn test_enum_construct_type_inference() {
        assert_clean(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c: Color = Color.Red
}"#,
        );
    }

    // --- Type narrowing tests ---

    #[test]
    fn test_nil_coalescing_strips_nil() {
        // `??` removes `nil` from the left operand's type, so the result is `string`.
        assert_clean(
            r#"pipeline t(task) {
  let x: string | nil = nil
  let y: string = x ?? "default"
}"#,
        );
    }
}