Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
/// A diagnostic produced by the type checker.
///
/// Diagnostics are accumulated while checking and returned all together from
/// `TypeChecker::check`; reporting one does not stop the check.
#[derive(Debug, Clone)]
pub struct TypeDiagnostic {
    /// Human-readable description of the problem.
    pub message: String,
    /// Whether this is a hard error or an advisory warning.
    pub severity: DiagnosticSeverity,
    /// Source location the diagnostic refers to, when one is available.
    pub span: Option<Span>,
}

/// How serious a [`TypeDiagnostic`] is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// A definite type error (e.g. a value incompatible with its annotated type).
    Error,
    /// A possible problem (e.g. a suspicious operator or a non-exhaustive match).
    Warning,
}

/// Inferred type of an expression. None means unknown/untyped (gradual typing):
/// checks that would need an unknown type are skipped rather than reported.
type InferredType = Option<TypeExpr>;
22
23/// Scope for tracking variable types.
24#[derive(Debug, Clone)]
25struct TypeScope {
26    /// Variable name → inferred type.
27    vars: BTreeMap<String, InferredType>,
28    /// Function name → (param types, return type).
29    functions: BTreeMap<String, FnSignature>,
30    /// Named type aliases.
31    type_aliases: BTreeMap<String, TypeExpr>,
32    /// Enum declarations: name → variant names.
33    enums: BTreeMap<String, Vec<String>>,
34    /// Interface declarations: name → method signatures.
35    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36    /// Struct declarations: name → field types.
37    structs: BTreeMap<String, Vec<(String, InferredType)>>,
38    parent: Option<Box<TypeScope>>,
39}
40
41#[derive(Debug, Clone)]
42struct FnSignature {
43    params: Vec<(String, InferredType)>,
44    return_type: InferredType,
45}
46
47impl TypeScope {
48    fn new() -> Self {
49        Self {
50            vars: BTreeMap::new(),
51            functions: BTreeMap::new(),
52            type_aliases: BTreeMap::new(),
53            enums: BTreeMap::new(),
54            interfaces: BTreeMap::new(),
55            structs: BTreeMap::new(),
56            parent: None,
57        }
58    }
59
60    fn child(&self) -> Self {
61        Self {
62            vars: BTreeMap::new(),
63            functions: BTreeMap::new(),
64            type_aliases: BTreeMap::new(),
65            enums: BTreeMap::new(),
66            interfaces: BTreeMap::new(),
67            structs: BTreeMap::new(),
68            parent: Some(Box::new(self.clone())),
69        }
70    }
71
72    fn get_var(&self, name: &str) -> Option<&InferredType> {
73        self.vars
74            .get(name)
75            .or_else(|| self.parent.as_ref()?.get_var(name))
76    }
77
78    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
79        self.functions
80            .get(name)
81            .or_else(|| self.parent.as_ref()?.get_fn(name))
82    }
83
84    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
85        self.type_aliases
86            .get(name)
87            .or_else(|| self.parent.as_ref()?.resolve_type(name))
88    }
89
90    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
91        self.enums
92            .get(name)
93            .or_else(|| self.parent.as_ref()?.get_enum(name))
94    }
95
96    #[allow(dead_code)]
97    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
98        self.interfaces
99            .get(name)
100            .or_else(|| self.parent.as_ref()?.get_interface(name))
101    }
102
103    fn define_var(&mut self, name: &str, ty: InferredType) {
104        self.vars.insert(name.to_string(), ty);
105    }
106
107    fn define_fn(&mut self, name: &str, sig: FnSignature) {
108        self.functions.insert(name.to_string(), sig);
109    }
110}
111
112/// Known return types for builtin functions.
113fn builtin_return_type(name: &str) -> InferredType {
114    match name {
115        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
116        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
117            Some(TypeExpr::Named("nil".into()))
118        }
119        "type_of" | "to_string" | "json_stringify" | "read_file" | "http_get" | "http_post"
120        | "llm_call" | "agent_loop" | "regex_replace" | "path_join" | "temp_dir"
121        | "date_format" | "format" => Some(TypeExpr::Named("string".into())),
122        "to_int" => Some(TypeExpr::Named("int".into())),
123        "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
124        "file_exists" | "json_validate" => Some(TypeExpr::Named("bool".into())),
125        "list_dir" => Some(TypeExpr::Named("list".into())),
126        "stat" | "exec" | "shell" | "date_now" => Some(TypeExpr::Named("dict".into())),
127        "env" | "regex_match" => Some(TypeExpr::Union(vec![
128            TypeExpr::Named("string".into()),
129            TypeExpr::Named("nil".into()),
130        ])),
131        "json_parse" | "json_extract" => None, // could be any type
132        _ => None,
133    }
134}
135
/// Check if a name is a known builtin.
///
/// This set is broader than the builtins with known return types in
/// `builtin_return_type`; builtins missing there are treated as untyped.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        "log", "print", "println", "type_of", "to_string", "to_int", "to_float",
        "json_stringify", "json_parse", "env", "timestamp", "sleep", "read_file",
        "write_file", "exit", "regex_match", "regex_replace", "http_get", "http_post",
        "llm_call", "agent_loop", "await", "cancel", "file_exists", "delete_file",
        "list_dir", "mkdir", "path_join", "copy_file", "append_file", "temp_dir", "stat",
        "exec", "shell", "date_now", "date_format", "date_parse", "format",
        "json_validate", "json_extract", "trim", "lowercase", "uppercase", "split",
        "starts_with", "ends_with", "contains", "replace", "join", "len", "substring",
        "dirname", "basename", "extname",
    ];
    BUILTINS.contains(&name)
}
196
/// The static type checker.
///
/// Holds the root scope and the diagnostics accumulated while walking a program.
pub struct TypeChecker {
    /// Diagnostics reported so far; returned by `check`.
    diagnostics: Vec<TypeDiagnostic>,
    /// Root scope: global declarations, top-level function signatures, and
    /// top-level variable bindings.
    scope: TypeScope,
}
202
203impl TypeChecker {
204    pub fn new() -> Self {
205        Self {
206            diagnostics: Vec::new(),
207            scope: TypeScope::new(),
208        }
209    }
210
211    /// Check a program and return diagnostics.
212    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
213        // First pass: register type and enum declarations into root scope
214        Self::register_declarations_into(&mut self.scope, program);
215
216        // Also scan pipeline bodies for declarations
217        for snode in program {
218            if let Node::Pipeline { body, .. } = &snode.node {
219                Self::register_declarations_into(&mut self.scope, body);
220            }
221        }
222
223        // Check each top-level node
224        for snode in program {
225            match &snode.node {
226                Node::Pipeline { params, body, .. } => {
227                    let mut child = self.scope.child();
228                    for p in params {
229                        child.define_var(p, None);
230                    }
231                    self.check_block(body, &mut child);
232                }
233                Node::FnDecl {
234                    name,
235                    params,
236                    return_type,
237                    body,
238                    ..
239                } => {
240                    let sig = FnSignature {
241                        params: params
242                            .iter()
243                            .map(|p| (p.name.clone(), p.type_expr.clone()))
244                            .collect(),
245                        return_type: return_type.clone(),
246                    };
247                    self.scope.define_fn(name, sig);
248                    self.check_fn_body(params, return_type, body);
249                }
250                _ => {
251                    let mut scope = self.scope.clone();
252                    self.check_node(snode, &mut scope);
253                    // Merge any new definitions back into the top-level scope
254                    for (name, ty) in scope.vars {
255                        self.scope.vars.entry(name).or_insert(ty);
256                    }
257                }
258            }
259        }
260
261        self.diagnostics
262    }
263
264    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
265    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
266        for snode in nodes {
267            match &snode.node {
268                Node::TypeDecl { name, type_expr } => {
269                    scope.type_aliases.insert(name.clone(), type_expr.clone());
270                }
271                Node::EnumDecl { name, variants } => {
272                    let variant_names: Vec<String> =
273                        variants.iter().map(|v| v.name.clone()).collect();
274                    scope.enums.insert(name.clone(), variant_names);
275                }
276                Node::InterfaceDecl { name, methods } => {
277                    scope.interfaces.insert(name.clone(), methods.clone());
278                }
279                Node::StructDecl { name, fields } => {
280                    let field_types: Vec<(String, InferredType)> = fields
281                        .iter()
282                        .map(|f| (f.name.clone(), f.type_expr.clone()))
283                        .collect();
284                    scope.structs.insert(name.clone(), field_types);
285                }
286                _ => {}
287            }
288        }
289    }
290
291    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
292        for stmt in stmts {
293            self.check_node(stmt, scope);
294        }
295    }
296
297    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
298        let span = snode.span;
299        match &snode.node {
300            Node::LetBinding {
301                name,
302                type_ann,
303                value,
304            } => {
305                let inferred = self.infer_type(value, scope);
306                if let Some(expected) = type_ann {
307                    if let Some(actual) = &inferred {
308                        if !self.types_compatible(expected, actual, scope) {
309                            self.error_at(
310                                format!(
311                                    "Type mismatch: '{}' declared as {}, but assigned {}",
312                                    name,
313                                    format_type(expected),
314                                    format_type(actual)
315                                ),
316                                span,
317                            );
318                        }
319                    }
320                }
321                let ty = type_ann.clone().or(inferred);
322                scope.define_var(name, ty);
323            }
324
325            Node::VarBinding {
326                name,
327                type_ann,
328                value,
329            } => {
330                let inferred = self.infer_type(value, scope);
331                if let Some(expected) = type_ann {
332                    if let Some(actual) = &inferred {
333                        if !self.types_compatible(expected, actual, scope) {
334                            self.error_at(
335                                format!(
336                                    "Type mismatch: '{}' declared as {}, but assigned {}",
337                                    name,
338                                    format_type(expected),
339                                    format_type(actual)
340                                ),
341                                span,
342                            );
343                        }
344                    }
345                }
346                let ty = type_ann.clone().or(inferred);
347                scope.define_var(name, ty);
348            }
349
350            Node::FnDecl {
351                name,
352                params,
353                return_type,
354                body,
355                ..
356            } => {
357                let sig = FnSignature {
358                    params: params
359                        .iter()
360                        .map(|p| (p.name.clone(), p.type_expr.clone()))
361                        .collect(),
362                    return_type: return_type.clone(),
363                };
364                scope.define_fn(name, sig.clone());
365                scope.define_var(name, None);
366                self.check_fn_body(params, return_type, body);
367            }
368
369            Node::FunctionCall { name, args } => {
370                self.check_call(name, args, scope, span);
371            }
372
373            Node::IfElse {
374                condition,
375                then_body,
376                else_body,
377            } => {
378                self.check_node(condition, scope);
379                let mut then_scope = scope.child();
380                self.check_block(then_body, &mut then_scope);
381                if let Some(else_body) = else_body {
382                    let mut else_scope = scope.child();
383                    self.check_block(else_body, &mut else_scope);
384                }
385            }
386
387            Node::ForIn {
388                variable,
389                iterable,
390                body,
391            } => {
392                self.check_node(iterable, scope);
393                let mut loop_scope = scope.child();
394                // Infer loop variable type from iterable
395                let elem_type = match self.infer_type(iterable, scope) {
396                    Some(TypeExpr::List(inner)) => Some(*inner),
397                    Some(TypeExpr::Named(n)) if n == "string" => {
398                        Some(TypeExpr::Named("string".into()))
399                    }
400                    _ => None,
401                };
402                loop_scope.define_var(variable, elem_type);
403                self.check_block(body, &mut loop_scope);
404            }
405
406            Node::WhileLoop { condition, body } => {
407                self.check_node(condition, scope);
408                let mut loop_scope = scope.child();
409                self.check_block(body, &mut loop_scope);
410            }
411
412            Node::TryCatch {
413                body,
414                error_var,
415                catch_body,
416                ..
417            } => {
418                let mut try_scope = scope.child();
419                self.check_block(body, &mut try_scope);
420                let mut catch_scope = scope.child();
421                if let Some(var) = error_var {
422                    catch_scope.define_var(var, None);
423                }
424                self.check_block(catch_body, &mut catch_scope);
425            }
426
427            Node::ReturnStmt {
428                value: Some(val), ..
429            } => {
430                self.check_node(val, scope);
431            }
432
433            Node::Assignment {
434                target, value, op, ..
435            } => {
436                self.check_node(value, scope);
437                if let Node::Identifier(name) = &target.node {
438                    if let Some(Some(var_type)) = scope.get_var(name) {
439                        let value_type = self.infer_type(value, scope);
440                        let assigned = if let Some(op) = op {
441                            let var_inferred = scope.get_var(name).cloned().flatten();
442                            infer_binary_op_type(op, &var_inferred, &value_type)
443                        } else {
444                            value_type
445                        };
446                        if let Some(actual) = &assigned {
447                            if !self.types_compatible(var_type, actual, scope) {
448                                self.error_at(
449                                    format!(
450                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
451                                        format_type(actual),
452                                        name,
453                                        format_type(var_type)
454                                    ),
455                                    span,
456                                );
457                            }
458                        }
459                    }
460                }
461            }
462
463            Node::TypeDecl { name, type_expr } => {
464                scope.type_aliases.insert(name.clone(), type_expr.clone());
465            }
466
467            Node::EnumDecl { name, variants } => {
468                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
469                scope.enums.insert(name.clone(), variant_names);
470            }
471
472            Node::StructDecl { name, fields } => {
473                let field_types: Vec<(String, InferredType)> = fields
474                    .iter()
475                    .map(|f| (f.name.clone(), f.type_expr.clone()))
476                    .collect();
477                scope.structs.insert(name.clone(), field_types);
478            }
479
480            Node::InterfaceDecl { name, methods } => {
481                scope.interfaces.insert(name.clone(), methods.clone());
482            }
483
484            Node::MatchExpr { value, arms } => {
485                self.check_node(value, scope);
486                for arm in arms {
487                    self.check_node(&arm.pattern, scope);
488                    let mut arm_scope = scope.child();
489                    self.check_block(&arm.body, &mut arm_scope);
490                }
491                self.check_match_exhaustiveness(value, arms, scope, span);
492            }
493
494            // Recurse into nested expressions + validate binary op types
495            Node::BinaryOp { op, left, right } => {
496                self.check_node(left, scope);
497                self.check_node(right, scope);
498                // Validate operator/type compatibility
499                let lt = self.infer_type(left, scope);
500                let rt = self.infer_type(right, scope);
501                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
502                    match op.as_str() {
503                        "-" | "*" | "/" | "%" => {
504                            let numeric = ["int", "float"];
505                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
506                                self.warning_at(
507                                    format!(
508                                        "Operator '{op}' may not be valid for types {} and {}",
509                                        l, r
510                                    ),
511                                    span,
512                                );
513                            }
514                        }
515                        "+" => {
516                            // + is valid for int, float, string, list, dict
517                            let valid = ["int", "float", "string", "list", "dict"];
518                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
519                                self.warning_at(
520                                    format!(
521                                        "Operator '+' may not be valid for types {} and {}",
522                                        l, r
523                                    ),
524                                    span,
525                                );
526                            }
527                        }
528                        _ => {}
529                    }
530                }
531            }
532            Node::UnaryOp { operand, .. } => {
533                self.check_node(operand, scope);
534            }
535            Node::MethodCall { object, args, .. }
536            | Node::OptionalMethodCall { object, args, .. } => {
537                self.check_node(object, scope);
538                for arg in args {
539                    self.check_node(arg, scope);
540                }
541            }
542            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
543                self.check_node(object, scope);
544            }
545            Node::SubscriptAccess { object, index } => {
546                self.check_node(object, scope);
547                self.check_node(index, scope);
548            }
549            Node::SliceAccess { object, start, end } => {
550                self.check_node(object, scope);
551                if let Some(s) = start {
552                    self.check_node(s, scope);
553                }
554                if let Some(e) = end {
555                    self.check_node(e, scope);
556                }
557            }
558
559            // Terminals — nothing to check
560            _ => {}
561        }
562    }
563
564    fn check_fn_body(
565        &mut self,
566        params: &[TypedParam],
567        return_type: &Option<TypeExpr>,
568        body: &[SNode],
569    ) {
570        let mut fn_scope = self.scope.child();
571        for param in params {
572            fn_scope.define_var(&param.name, param.type_expr.clone());
573        }
574        self.check_block(body, &mut fn_scope);
575
576        // Check return statements against declared return type
577        if let Some(ret_type) = return_type {
578            for stmt in body {
579                self.check_return_type(stmt, ret_type, &fn_scope);
580            }
581        }
582    }
583
584    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
585        let span = snode.span;
586        match &snode.node {
587            Node::ReturnStmt { value: Some(val) } => {
588                let inferred = self.infer_type(val, scope);
589                if let Some(actual) = &inferred {
590                    if !self.types_compatible(expected, actual, scope) {
591                        self.error_at(
592                            format!(
593                                "Return type mismatch: expected {}, got {}",
594                                format_type(expected),
595                                format_type(actual)
596                            ),
597                            span,
598                        );
599                    }
600                }
601            }
602            Node::IfElse {
603                then_body,
604                else_body,
605                ..
606            } => {
607                for stmt in then_body {
608                    self.check_return_type(stmt, expected, scope);
609                }
610                if let Some(else_body) = else_body {
611                    for stmt in else_body {
612                        self.check_return_type(stmt, expected, scope);
613                    }
614                }
615            }
616            _ => {}
617        }
618    }
619
620    /// Check if a match expression on an enum's `.variant` property covers all variants.
621    fn check_match_exhaustiveness(
622        &mut self,
623        value: &SNode,
624        arms: &[MatchArm],
625        scope: &TypeScope,
626        span: Span,
627    ) {
628        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
629        let enum_name = match &value.node {
630            Node::PropertyAccess { object, property } if property == "variant" => {
631                // Infer the type of the object
632                match self.infer_type(object, scope) {
633                    Some(TypeExpr::Named(name)) => {
634                        if scope.get_enum(&name).is_some() {
635                            Some(name)
636                        } else {
637                            None
638                        }
639                    }
640                    _ => None,
641                }
642            }
643            _ => {
644                // Direct match on an enum value: match <expr> { ... }
645                match self.infer_type(value, scope) {
646                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
647                    _ => None,
648                }
649            }
650        };
651
652        let Some(enum_name) = enum_name else {
653            return;
654        };
655        let Some(variants) = scope.get_enum(&enum_name) else {
656            return;
657        };
658
659        // Collect variant names covered by match arms
660        let mut covered: Vec<String> = Vec::new();
661        let mut has_wildcard = false;
662
663        for arm in arms {
664            match &arm.pattern.node {
665                // String literal pattern (matching on .variant): "VariantA"
666                Node::StringLiteral(s) => covered.push(s.clone()),
667                // Identifier pattern acts as a wildcard/catch-all
668                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
669                    has_wildcard = true;
670                }
671                // Direct enum construct pattern: EnumName.Variant
672                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
673                // PropertyAccess pattern: EnumName.Variant (no args)
674                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
675                _ => {
676                    // Unknown pattern shape — conservatively treat as wildcard
677                    has_wildcard = true;
678                }
679            }
680        }
681
682        if has_wildcard {
683            return;
684        }
685
686        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
687        if !missing.is_empty() {
688            let missing_str = missing
689                .iter()
690                .map(|s| format!("\"{}\"", s))
691                .collect::<Vec<_>>()
692                .join(", ");
693            self.warning_at(
694                format!(
695                    "Non-exhaustive match on enum {}: missing variants {}",
696                    enum_name, missing_str
697                ),
698                span,
699            );
700        }
701    }
702
703    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
704        // Check against known function signatures
705        if let Some(sig) = scope.get_fn(name).cloned() {
706            if args.len() != sig.params.len() && !is_builtin(name) {
707                self.warning_at(
708                    format!(
709                        "Function '{}' expects {} arguments, got {}",
710                        name,
711                        sig.params.len(),
712                        args.len()
713                    ),
714                    span,
715                );
716            }
717            for (i, (arg, (param_name, param_type))) in
718                args.iter().zip(sig.params.iter()).enumerate()
719            {
720                if let Some(expected) = param_type {
721                    let actual = self.infer_type(arg, scope);
722                    if let Some(actual) = &actual {
723                        if !self.types_compatible(expected, actual, scope) {
724                            self.error_at(
725                                format!(
726                                    "Argument {} ('{}'): expected {}, got {}",
727                                    i + 1,
728                                    param_name,
729                                    format_type(expected),
730                                    format_type(actual)
731                                ),
732                                arg.span,
733                            );
734                        }
735                    }
736                }
737            }
738        }
739        // Check args recursively
740        for arg in args {
741            self.check_node(arg, scope);
742        }
743    }
744
745    /// Infer the type of an expression.
746    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
747        match &snode.node {
748            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
749            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
750            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
751                Some(TypeExpr::Named("string".into()))
752            }
753            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
754            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
755            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
756            Node::DictLiteral(_) => Some(TypeExpr::Named("dict".into())),
757            Node::Closure { .. } => Some(TypeExpr::Named("closure".into())),
758
759            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
760
761            Node::FunctionCall { name, .. } => {
762                // Check user-defined function return types
763                if let Some(sig) = scope.get_fn(name) {
764                    return sig.return_type.clone();
765                }
766                // Check builtin return types
767                builtin_return_type(name)
768            }
769
770            Node::BinaryOp { op, left, right } => {
771                let lt = self.infer_type(left, scope);
772                let rt = self.infer_type(right, scope);
773                infer_binary_op_type(op, &lt, &rt)
774            }
775
776            Node::UnaryOp { op, operand } => {
777                let t = self.infer_type(operand, scope);
778                match op.as_str() {
779                    "!" => Some(TypeExpr::Named("bool".into())),
780                    "-" => t, // negation preserves type
781                    _ => None,
782                }
783            }
784
785            Node::Ternary {
786                true_expr,
787                false_expr,
788                ..
789            } => {
790                let tt = self.infer_type(true_expr, scope);
791                let ft = self.infer_type(false_expr, scope);
792                match (&tt, &ft) {
793                    (Some(a), Some(b)) if a == b => tt,
794                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
795                    (Some(_), None) => tt,
796                    (None, Some(_)) => ft,
797                    (None, None) => None,
798                }
799            }
800
801            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
802
803            Node::PropertyAccess { object, property } => {
804                // EnumName.Variant → infer as the enum type
805                if let Node::Identifier(name) = &object.node {
806                    if scope.get_enum(name).is_some() {
807                        return Some(TypeExpr::Named(name.clone()));
808                    }
809                }
810                // .variant on an enum value → string
811                if property == "variant" {
812                    let obj_type = self.infer_type(object, scope);
813                    if let Some(TypeExpr::Named(name)) = &obj_type {
814                        if scope.get_enum(name).is_some() {
815                            return Some(TypeExpr::Named("string".into()));
816                        }
817                    }
818                }
819                None
820            }
821
822            Node::SubscriptAccess { object, .. } => {
823                let obj_type = self.infer_type(object, scope);
824                match &obj_type {
825                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
826                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
827                    Some(TypeExpr::Named(n)) if n == "list" => None,
828                    Some(TypeExpr::Named(n)) if n == "dict" => None,
829                    Some(TypeExpr::Named(n)) if n == "string" => {
830                        Some(TypeExpr::Named("string".into()))
831                    }
832                    _ => None,
833                }
834            }
835            Node::SliceAccess { object, .. } => {
836                // Slicing a list returns the same list type; slicing a string returns string
837                let obj_type = self.infer_type(object, scope);
838                match &obj_type {
839                    Some(TypeExpr::List(_)) => obj_type,
840                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
841                    Some(TypeExpr::Named(n)) if n == "string" => {
842                        Some(TypeExpr::Named("string".into()))
843                    }
844                    _ => None,
845                }
846            }
847            Node::MethodCall { object, method, .. }
848            | Node::OptionalMethodCall { object, method, .. } => {
849                let obj_type = self.infer_type(object, scope);
850                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
851                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
852                match method.as_str() {
853                    // Shared: bool-returning methods
854                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
855                        Some(TypeExpr::Named("bool".into()))
856                    }
857                    // Shared: int-returning methods
858                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
859                    // String methods
860                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
861                    | "pad_left" | "pad_right" | "repeat" | "join" => {
862                        Some(TypeExpr::Named("string".into()))
863                    }
864                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
865                    // filter returns dict for dicts, list for lists
866                    "filter" => {
867                        if is_dict {
868                            Some(TypeExpr::Named("dict".into()))
869                        } else {
870                            Some(TypeExpr::Named("list".into()))
871                        }
872                    }
873                    // List methods
874                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
875                    "reduce" | "find" | "first" | "last" => None,
876                    // Dict methods
877                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
878                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
879                    // Conversions
880                    "to_string" => Some(TypeExpr::Named("string".into())),
881                    "to_int" => Some(TypeExpr::Named("int".into())),
882                    "to_float" => Some(TypeExpr::Named("float".into())),
883                    _ => None,
884                }
885            }
886
887            _ => None,
888        }
889    }
890
891    /// Check if two types are compatible (actual can be assigned to expected).
892    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
893        let expected = self.resolve_alias(expected, scope);
894        let actual = self.resolve_alias(actual, scope);
895
896        match (&expected, &actual) {
897            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
898            (TypeExpr::Union(members), actual_type) => members
899                .iter()
900                .any(|m| self.types_compatible(m, actual_type, scope)),
901            (expected_type, TypeExpr::Union(members)) => members
902                .iter()
903                .all(|m| self.types_compatible(expected_type, m, scope)),
904            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
905            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
906                if expected_field.optional {
907                    return true;
908                }
909                af.iter().any(|actual_field| {
910                    actual_field.name == expected_field.name
911                        && self.types_compatible(
912                            &expected_field.type_expr,
913                            &actual_field.type_expr,
914                            scope,
915                        )
916                })
917            }),
918            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
919                self.types_compatible(expected_inner, actual_inner, scope)
920            }
921            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
922            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
923            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
924                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
925            }
926            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
927            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
928            _ => false,
929        }
930    }
931
932    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
933        if let TypeExpr::Named(name) = ty {
934            if let Some(resolved) = scope.resolve_type(name) {
935                return resolved.clone();
936            }
937        }
938        ty.clone()
939    }
940
941    fn error_at(&mut self, message: String, span: Span) {
942        self.diagnostics.push(TypeDiagnostic {
943            message,
944            severity: DiagnosticSeverity::Error,
945            span: Some(span),
946        });
947    }
948
949    fn warning_at(&mut self, message: String, span: Span) {
950        self.diagnostics.push(TypeDiagnostic {
951            message,
952            severity: DiagnosticSeverity::Warning,
953            span: Some(span),
954        });
955    }
956}
957
958impl Default for TypeChecker {
959    fn default() -> Self {
960        Self::new()
961    }
962}
963
964/// Infer the result type of a binary operation.
965fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
966    match op {
967        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
968        "+" => match (left, right) {
969            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
970                match (l.as_str(), r.as_str()) {
971                    ("int", "int") => Some(TypeExpr::Named("int".into())),
972                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
973                    ("string", _) => Some(TypeExpr::Named("string".into())),
974                    ("list", "list") => Some(TypeExpr::Named("list".into())),
975                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
976                    _ => Some(TypeExpr::Named("string".into())),
977                }
978            }
979            _ => None,
980        },
981        "-" | "*" | "/" | "%" => match (left, right) {
982            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
983                match (l.as_str(), r.as_str()) {
984                    ("int", "int") => Some(TypeExpr::Named("int".into())),
985                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
986                    _ => None,
987                }
988            }
989            _ => None,
990        },
991        "??" => match (left, right) {
992            (Some(TypeExpr::Union(members)), _) => {
993                let non_nil: Vec<_> = members
994                    .iter()
995                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
996                    .cloned()
997                    .collect();
998                if non_nil.len() == 1 {
999                    Some(non_nil[0].clone())
1000                } else if non_nil.is_empty() {
1001                    right.clone()
1002                } else {
1003                    Some(TypeExpr::Union(non_nil))
1004                }
1005            }
1006            _ => right.clone(),
1007        },
1008        "|>" => None,
1009        _ => None,
1010    }
1011}
1012
1013/// Format a type expression for display in error messages.
1014pub fn format_type(ty: &TypeExpr) -> String {
1015    match ty {
1016        TypeExpr::Named(n) => n.clone(),
1017        TypeExpr::Union(types) => types
1018            .iter()
1019            .map(format_type)
1020            .collect::<Vec<_>>()
1021            .join(" | "),
1022        TypeExpr::Shape(fields) => {
1023            let inner: Vec<String> = fields
1024                .iter()
1025                .map(|f| {
1026                    let opt = if f.optional { "?" } else { "" };
1027                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1028                })
1029                .collect();
1030            format!("{{{}}}", inner.join(", "))
1031        }
1032        TypeExpr::List(inner) => format!("list[{}]", format_type(inner)),
1033        TypeExpr::DictType(k, v) => format!("dict[{}, {}]", format_type(k), format_type(v)),
1034    }
1035}
1036
// Unit tests for the type checker. Each test lexes and parses a small Harn
// program, runs `TypeChecker::check` on it, and inspects the diagnostics.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    /// Lex, parse, and type-check `source`, returning every diagnostic.
    ///
    /// Panics (via `unwrap`) if lexing or parsing fails, so test inputs must be
    /// syntactically valid.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut lexer = Lexer::new(source);
        let tokens = lexer.tokenize().unwrap();
        let mut parser = Parser::new(tokens);
        let program = parser.parse().unwrap();
        TypeChecker::new().check(&program)
    }

    /// Messages of the error-severity diagnostics produced for `source`.
    fn errors(source: &str) -> Vec<String> {
        check_source(source)
            .into_iter()
            .filter(|d| d.severity == DiagnosticSeverity::Error)
            .map(|d| d.message)
            .collect()
    }

    // --- Basic let / fn / return checking ---

    #[test]
    fn test_no_errors_for_untyped_code() {
        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_correct_typed_let() {
        let errs = errors("pipeline t(task) { let x: int = 42 }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_type_mismatch_let() {
        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Type mismatch"));
        assert!(errs[0].contains("int"));
        assert!(errs[0].contains("string"));
    }

    #[test]
    fn test_correct_typed_fn() {
        let errs = errors(
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let errs = errors(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Argument 1"));
        assert!(errs[0].contains("expected int"));
    }

    #[test]
    fn test_return_type_mismatch() {
        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Return type mismatch"));
    }

    // --- Union types ---

    #[test]
    fn test_union_type_compatible() {
        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
        assert!(errs.is_empty());
    }

    #[test]
    fn test_union_type_mismatch() {
        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Type mismatch"));
    }

    // --- Inference from calls and operators ---

    #[test]
    fn test_type_inference_propagation() {
        let errs = errors(
            r#"pipeline t(task) {
  fn add(a: int, b: int) -> int { return a + b }
  let result: string = add(1, 2)
}"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("Type mismatch"));
        assert!(errs[0].contains("string"));
        assert!(errs[0].contains("int"));
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("string"));
        assert!(errs[0].contains("int"));
    }

    #[test]
    fn test_binary_op_type_inference() {
        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
        assert_eq!(errs.len(), 1);
    }

    #[test]
    fn test_comparison_returns_bool() {
        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
        assert!(errs.is_empty());
    }

    // An int value is accepted where a float is expected.
    #[test]
    fn test_int_float_promotion() {
        let errs = errors("pipeline t(task) { let x: float = 42 }");
        assert!(errs.is_empty());
    }

    // Gradual typing: unannotated code must never produce errors.
    #[test]
    fn test_untyped_code_no_errors() {
        let errs = errors(
            r#"pipeline t(task) {
  fn process(data) {
    let result = data + " processed"
    return result
  }
  log(process("hello"))
}"#,
        );
        assert!(errs.is_empty());
    }

    // --- Type aliases and assignment ---

    #[test]
    fn test_type_alias() {
        let errs = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = "hello"
}"#,
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_type_alias_mismatch() {
        let errs = errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = 42
}"#,
        );
        assert_eq!(errs.len(), 1);
    }

    #[test]
    fn test_assignment_type_check() {
        let errs = errors(
            r#"pipeline t(task) {
  var x: int = 0
  x = "hello"
}"#,
        );
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("cannot assign string"));
    }

    // --- int -> float promotion in calls and returns (one direction only) ---

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        let errs = errors(
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
        );
        assert!(errs.is_empty());
    }

    #[test]
    fn test_covariance_return_type() {
        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
        assert!(errs.is_empty());
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
        assert_eq!(errs.len(), 1);
    }

    // --- Exhaustiveness checking tests ---

    /// Messages of the warning-severity diagnostics produced for `source`.
    fn warnings(source: &str) -> Vec<String> {
        check_source(source)
            .into_iter()
            .filter(|d| d.severity == DiagnosticSeverity::Warning)
            .map(|d| d.message)
            .collect()
    }

    #[test]
    fn test_exhaustive_match_no_warning() {
        let warns = warnings(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
    "Blue" -> { log("b") }
  }
}"#,
        );
        let exhaustive_warns: Vec<_> = warns
            .iter()
            .filter(|w| w.contains("Non-exhaustive"))
            .collect();
        assert!(exhaustive_warns.is_empty());
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let warns = warnings(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
  }
}"#,
        );
        let exhaustive_warns: Vec<_> = warns
            .iter()
            .filter(|w| w.contains("Non-exhaustive"))
            .collect();
        assert_eq!(exhaustive_warns.len(), 1);
        assert!(exhaustive_warns[0].contains("Blue"));
    }

    // All missing variants are reported in a single warning.
    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let warns = warnings(
            r#"pipeline t(task) {
  enum Status { Active, Inactive, Pending }
  let s = Status.Active
  match s.variant {
    "Active" -> { log("a") }
  }
}"#,
        );
        let exhaustive_warns: Vec<_> = warns
            .iter()
            .filter(|w| w.contains("Non-exhaustive"))
            .collect();
        assert_eq!(exhaustive_warns.len(), 1);
        assert!(exhaustive_warns[0].contains("Inactive"));
        assert!(exhaustive_warns[0].contains("Pending"));
    }

    // `EnumName.Variant` is inferred as the enum's own named type.
    #[test]
    fn test_enum_construct_type_inference() {
        let errs = errors(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c: Color = Color.Red
}"#,
        );
        assert!(errs.is_empty());
    }

    // --- Type narrowing tests ---

    #[test]
    fn test_nil_coalescing_strips_nil() {
        // After ??, nil should be stripped from the type
        let errs = errors(
            r#"pipeline t(task) {
  let x: string | nil = nil
  let y: string = x ?? "default"
}"#,
        );
        assert!(errs.is_empty());
    }
}