Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6/// A diagnostic produced by the type checker.
7#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9    pub message: String,
10    pub severity: DiagnosticSeverity,
11    pub span: Option<Span>,
12}
13
/// How serious a type-checker finding is.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticSeverity {
    /// A type error.
    Error,
    /// A possible problem that is worth flagging but not an outright error.
    Warning,
}
19
20/// Inferred type of an expression. None means unknown/untyped (gradual typing).
21type InferredType = Option<TypeExpr>;
22
23/// Scope for tracking variable types.
24#[derive(Debug, Clone)]
25struct TypeScope {
26    /// Variable name → inferred type.
27    vars: BTreeMap<String, InferredType>,
28    /// Function name → (param types, return type).
29    functions: BTreeMap<String, FnSignature>,
30    /// Named type aliases.
31    type_aliases: BTreeMap<String, TypeExpr>,
32    /// Enum declarations: name → variant names.
33    enums: BTreeMap<String, Vec<String>>,
34    /// Interface declarations: name → method signatures.
35    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36    /// Struct declarations: name → field types.
37    structs: BTreeMap<String, Vec<(String, InferredType)>>,
38    parent: Option<Box<TypeScope>>,
39}
40
41#[derive(Debug, Clone)]
42struct FnSignature {
43    params: Vec<(String, InferredType)>,
44    return_type: InferredType,
45}
46
47impl TypeScope {
48    fn new() -> Self {
49        Self {
50            vars: BTreeMap::new(),
51            functions: BTreeMap::new(),
52            type_aliases: BTreeMap::new(),
53            enums: BTreeMap::new(),
54            interfaces: BTreeMap::new(),
55            structs: BTreeMap::new(),
56            parent: None,
57        }
58    }
59
60    fn child(&self) -> Self {
61        Self {
62            vars: BTreeMap::new(),
63            functions: BTreeMap::new(),
64            type_aliases: BTreeMap::new(),
65            enums: BTreeMap::new(),
66            interfaces: BTreeMap::new(),
67            structs: BTreeMap::new(),
68            parent: Some(Box::new(self.clone())),
69        }
70    }
71
72    fn get_var(&self, name: &str) -> Option<&InferredType> {
73        self.vars
74            .get(name)
75            .or_else(|| self.parent.as_ref()?.get_var(name))
76    }
77
78    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
79        self.functions
80            .get(name)
81            .or_else(|| self.parent.as_ref()?.get_fn(name))
82    }
83
84    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
85        self.type_aliases
86            .get(name)
87            .or_else(|| self.parent.as_ref()?.resolve_type(name))
88    }
89
90    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
91        self.enums
92            .get(name)
93            .or_else(|| self.parent.as_ref()?.get_enum(name))
94    }
95
96    #[allow(dead_code)]
97    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
98        self.interfaces
99            .get(name)
100            .or_else(|| self.parent.as_ref()?.get_interface(name))
101    }
102
103    fn define_var(&mut self, name: &str, ty: InferredType) {
104        self.vars.insert(name.to_string(), ty);
105    }
106
107    fn define_fn(&mut self, name: &str, sig: FnSignature) {
108        self.functions.insert(name.to_string(), sig);
109    }
110}
111
112/// Known return types for builtin functions.
113fn builtin_return_type(name: &str) -> InferredType {
114    match name {
115        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
116        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
117            Some(TypeExpr::Named("nil".into()))
118        }
119        "type_of" | "to_string" | "json_stringify" | "read_file" | "http_get" | "http_post"
120        | "llm_call" | "agent_loop" | "regex_replace" | "path_join" | "temp_dir"
121        | "date_format" | "format" => Some(TypeExpr::Named("string".into())),
122        "to_int" => Some(TypeExpr::Named("int".into())),
123        "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
124        "file_exists" | "json_validate" => Some(TypeExpr::Named("bool".into())),
125        "list_dir" => Some(TypeExpr::Named("list".into())),
126        "stat" | "exec" | "shell" | "date_now" => Some(TypeExpr::Named("dict".into())),
127        "env" | "regex_match" => Some(TypeExpr::Union(vec![
128            TypeExpr::Named("string".into()),
129            TypeExpr::Named("nil".into()),
130        ])),
131        "json_parse" | "json_extract" => None, // could be any type
132        _ => None,
133    }
134}
135
/// Return `true` when `name` is one of the language's built-in functions.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        // I/O and conversion
        "log",
        "print",
        "println",
        "type_of",
        "to_string",
        "to_int",
        "to_float",
        "json_stringify",
        "json_parse",
        "env",
        "timestamp",
        "sleep",
        "read_file",
        "write_file",
        "exit",
        "regex_match",
        "regex_replace",
        "http_get",
        "http_post",
        "llm_call",
        "agent_loop",
        "await",
        "cancel",
        // filesystem and processes
        "file_exists",
        "delete_file",
        "list_dir",
        "mkdir",
        "path_join",
        "copy_file",
        "append_file",
        "temp_dir",
        "stat",
        "exec",
        "shell",
        // dates and formatting
        "date_now",
        "date_format",
        "date_parse",
        "format",
        "json_validate",
        "json_extract",
        // string and path helpers
        "trim",
        "lowercase",
        "uppercase",
        "split",
        "starts_with",
        "ends_with",
        "contains",
        "replace",
        "join",
        "len",
        "substring",
        "dirname",
        "basename",
        "extname",
    ];
    BUILTINS.contains(&name)
}
196
197/// The static type checker.
198pub struct TypeChecker {
199    diagnostics: Vec<TypeDiagnostic>,
200    scope: TypeScope,
201}
202
203impl TypeChecker {
204    pub fn new() -> Self {
205        Self {
206            diagnostics: Vec::new(),
207            scope: TypeScope::new(),
208        }
209    }
210
211    /// Check a program and return diagnostics.
212    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
213        // First pass: register type and enum declarations into root scope
214        Self::register_declarations_into(&mut self.scope, program);
215
216        // Also scan pipeline bodies for declarations
217        for snode in program {
218            if let Node::Pipeline { body, .. } = &snode.node {
219                Self::register_declarations_into(&mut self.scope, body);
220            }
221        }
222
223        // Check each top-level node
224        for snode in program {
225            match &snode.node {
226                Node::Pipeline { params, body, .. } => {
227                    let mut child = self.scope.child();
228                    for p in params {
229                        child.define_var(p, None);
230                    }
231                    self.check_block(body, &mut child);
232                }
233                Node::FnDecl {
234                    name,
235                    params,
236                    return_type,
237                    body,
238                    ..
239                } => {
240                    let sig = FnSignature {
241                        params: params
242                            .iter()
243                            .map(|p| (p.name.clone(), p.type_expr.clone()))
244                            .collect(),
245                        return_type: return_type.clone(),
246                    };
247                    self.scope.define_fn(name, sig);
248                    self.check_fn_body(params, return_type, body);
249                }
250                _ => {
251                    let mut scope = self.scope.clone();
252                    self.check_node(snode, &mut scope);
253                    // Merge any new definitions back into the top-level scope
254                    for (name, ty) in scope.vars {
255                        self.scope.vars.entry(name).or_insert(ty);
256                    }
257                }
258            }
259        }
260
261        self.diagnostics
262    }
263
264    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
265    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
266        for snode in nodes {
267            match &snode.node {
268                Node::TypeDecl { name, type_expr } => {
269                    scope.type_aliases.insert(name.clone(), type_expr.clone());
270                }
271                Node::EnumDecl { name, variants } => {
272                    let variant_names: Vec<String> =
273                        variants.iter().map(|v| v.name.clone()).collect();
274                    scope.enums.insert(name.clone(), variant_names);
275                }
276                Node::InterfaceDecl { name, methods } => {
277                    scope.interfaces.insert(name.clone(), methods.clone());
278                }
279                Node::StructDecl { name, fields } => {
280                    let field_types: Vec<(String, InferredType)> = fields
281                        .iter()
282                        .map(|f| (f.name.clone(), f.type_expr.clone()))
283                        .collect();
284                    scope.structs.insert(name.clone(), field_types);
285                }
286                _ => {}
287            }
288        }
289    }
290
291    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
292        for stmt in stmts {
293            self.check_node(stmt, scope);
294        }
295    }
296
297    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
298        let span = snode.span;
299        match &snode.node {
300            Node::LetBinding {
301                name,
302                type_ann,
303                value,
304            } => {
305                let inferred = self.infer_type(value, scope);
306                if let Some(expected) = type_ann {
307                    if let Some(actual) = &inferred {
308                        if !self.types_compatible(expected, actual, scope) {
309                            self.error_at(
310                                format!(
311                                    "Type mismatch: '{}' declared as {}, but assigned {}",
312                                    name,
313                                    format_type(expected),
314                                    format_type(actual)
315                                ),
316                                span,
317                            );
318                        }
319                    }
320                }
321                let ty = type_ann.clone().or(inferred);
322                scope.define_var(name, ty);
323            }
324
325            Node::VarBinding {
326                name,
327                type_ann,
328                value,
329            } => {
330                let inferred = self.infer_type(value, scope);
331                if let Some(expected) = type_ann {
332                    if let Some(actual) = &inferred {
333                        if !self.types_compatible(expected, actual, scope) {
334                            self.error_at(
335                                format!(
336                                    "Type mismatch: '{}' declared as {}, but assigned {}",
337                                    name,
338                                    format_type(expected),
339                                    format_type(actual)
340                                ),
341                                span,
342                            );
343                        }
344                    }
345                }
346                let ty = type_ann.clone().or(inferred);
347                scope.define_var(name, ty);
348            }
349
350            Node::FnDecl {
351                name,
352                params,
353                return_type,
354                body,
355                ..
356            } => {
357                let sig = FnSignature {
358                    params: params
359                        .iter()
360                        .map(|p| (p.name.clone(), p.type_expr.clone()))
361                        .collect(),
362                    return_type: return_type.clone(),
363                };
364                scope.define_fn(name, sig.clone());
365                scope.define_var(name, None);
366                self.check_fn_body(params, return_type, body);
367            }
368
369            Node::FunctionCall { name, args } => {
370                self.check_call(name, args, scope, span);
371            }
372
373            Node::IfElse {
374                condition,
375                then_body,
376                else_body,
377            } => {
378                self.check_node(condition, scope);
379                let mut then_scope = scope.child();
380                self.check_block(then_body, &mut then_scope);
381                if let Some(else_body) = else_body {
382                    let mut else_scope = scope.child();
383                    self.check_block(else_body, &mut else_scope);
384                }
385            }
386
387            Node::ForIn {
388                variable,
389                iterable,
390                body,
391            } => {
392                self.check_node(iterable, scope);
393                let mut loop_scope = scope.child();
394                // Infer loop variable type from iterable
395                let elem_type = match self.infer_type(iterable, scope) {
396                    Some(TypeExpr::List(inner)) => Some(*inner),
397                    Some(TypeExpr::Named(n)) if n == "string" => {
398                        Some(TypeExpr::Named("string".into()))
399                    }
400                    _ => None,
401                };
402                loop_scope.define_var(variable, elem_type);
403                self.check_block(body, &mut loop_scope);
404            }
405
406            Node::WhileLoop { condition, body } => {
407                self.check_node(condition, scope);
408                let mut loop_scope = scope.child();
409                self.check_block(body, &mut loop_scope);
410            }
411
412            Node::TryCatch {
413                body,
414                error_var,
415                catch_body,
416                ..
417            } => {
418                let mut try_scope = scope.child();
419                self.check_block(body, &mut try_scope);
420                let mut catch_scope = scope.child();
421                if let Some(var) = error_var {
422                    catch_scope.define_var(var, None);
423                }
424                self.check_block(catch_body, &mut catch_scope);
425            }
426
427            Node::ReturnStmt {
428                value: Some(val), ..
429            } => {
430                self.check_node(val, scope);
431            }
432
433            Node::Assignment {
434                target, value, op, ..
435            } => {
436                self.check_node(value, scope);
437                if let Node::Identifier(name) = &target.node {
438                    if let Some(Some(var_type)) = scope.get_var(name) {
439                        let value_type = self.infer_type(value, scope);
440                        let assigned = if let Some(op) = op {
441                            let var_inferred = scope.get_var(name).cloned().flatten();
442                            infer_binary_op_type(op, &var_inferred, &value_type)
443                        } else {
444                            value_type
445                        };
446                        if let Some(actual) = &assigned {
447                            if !self.types_compatible(var_type, actual, scope) {
448                                self.error_at(
449                                    format!(
450                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
451                                        format_type(actual),
452                                        name,
453                                        format_type(var_type)
454                                    ),
455                                    span,
456                                );
457                            }
458                        }
459                    }
460                }
461            }
462
463            Node::TypeDecl { name, type_expr } => {
464                scope.type_aliases.insert(name.clone(), type_expr.clone());
465            }
466
467            Node::EnumDecl { name, variants } => {
468                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
469                scope.enums.insert(name.clone(), variant_names);
470            }
471
472            Node::StructDecl { name, fields } => {
473                let field_types: Vec<(String, InferredType)> = fields
474                    .iter()
475                    .map(|f| (f.name.clone(), f.type_expr.clone()))
476                    .collect();
477                scope.structs.insert(name.clone(), field_types);
478            }
479
480            Node::InterfaceDecl { name, methods } => {
481                scope.interfaces.insert(name.clone(), methods.clone());
482            }
483
484            Node::MatchExpr { value, arms } => {
485                self.check_node(value, scope);
486                for arm in arms {
487                    self.check_node(&arm.pattern, scope);
488                    let mut arm_scope = scope.child();
489                    self.check_block(&arm.body, &mut arm_scope);
490                }
491                self.check_match_exhaustiveness(value, arms, scope, span);
492            }
493
494            // Recurse into nested expressions + validate binary op types
495            Node::BinaryOp { op, left, right } => {
496                self.check_node(left, scope);
497                self.check_node(right, scope);
498                // Validate operator/type compatibility
499                let lt = self.infer_type(left, scope);
500                let rt = self.infer_type(right, scope);
501                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
502                    match op.as_str() {
503                        "-" | "*" | "/" | "%" => {
504                            let numeric = ["int", "float"];
505                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
506                                self.warning_at(
507                                    format!(
508                                        "Operator '{op}' may not be valid for types {} and {}",
509                                        l, r
510                                    ),
511                                    span,
512                                );
513                            }
514                        }
515                        "+" => {
516                            // + is valid for int, float, string, list, dict
517                            let valid = ["int", "float", "string", "list", "dict"];
518                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
519                                self.warning_at(
520                                    format!(
521                                        "Operator '+' may not be valid for types {} and {}",
522                                        l, r
523                                    ),
524                                    span,
525                                );
526                            }
527                        }
528                        _ => {}
529                    }
530                }
531            }
532            Node::UnaryOp { operand, .. } => {
533                self.check_node(operand, scope);
534            }
535            Node::MethodCall { object, args, .. }
536            | Node::OptionalMethodCall { object, args, .. } => {
537                self.check_node(object, scope);
538                for arg in args {
539                    self.check_node(arg, scope);
540                }
541            }
542            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
543                self.check_node(object, scope);
544            }
545            Node::SubscriptAccess { object, index } => {
546                self.check_node(object, scope);
547                self.check_node(index, scope);
548            }
549            Node::SliceAccess { object, start, end } => {
550                self.check_node(object, scope);
551                if let Some(s) = start {
552                    self.check_node(s, scope);
553                }
554                if let Some(e) = end {
555                    self.check_node(e, scope);
556                }
557            }
558
559            // Terminals — nothing to check
560            _ => {}
561        }
562    }
563
564    fn check_fn_body(
565        &mut self,
566        params: &[TypedParam],
567        return_type: &Option<TypeExpr>,
568        body: &[SNode],
569    ) {
570        let mut fn_scope = self.scope.child();
571        for param in params {
572            fn_scope.define_var(&param.name, param.type_expr.clone());
573        }
574        self.check_block(body, &mut fn_scope);
575
576        // Check return statements against declared return type
577        if let Some(ret_type) = return_type {
578            for stmt in body {
579                self.check_return_type(stmt, ret_type, &fn_scope);
580            }
581        }
582    }
583
584    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
585        let span = snode.span;
586        match &snode.node {
587            Node::ReturnStmt { value: Some(val) } => {
588                let inferred = self.infer_type(val, scope);
589                if let Some(actual) = &inferred {
590                    if !self.types_compatible(expected, actual, scope) {
591                        self.error_at(
592                            format!(
593                                "Return type mismatch: expected {}, got {}",
594                                format_type(expected),
595                                format_type(actual)
596                            ),
597                            span,
598                        );
599                    }
600                }
601            }
602            Node::IfElse {
603                then_body,
604                else_body,
605                ..
606            } => {
607                for stmt in then_body {
608                    self.check_return_type(stmt, expected, scope);
609                }
610                if let Some(else_body) = else_body {
611                    for stmt in else_body {
612                        self.check_return_type(stmt, expected, scope);
613                    }
614                }
615            }
616            _ => {}
617        }
618    }
619
620    /// Check if a match expression on an enum's `.variant` property covers all variants.
621    fn check_match_exhaustiveness(
622        &mut self,
623        value: &SNode,
624        arms: &[MatchArm],
625        scope: &TypeScope,
626        span: Span,
627    ) {
628        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
629        let enum_name = match &value.node {
630            Node::PropertyAccess { object, property } if property == "variant" => {
631                // Infer the type of the object
632                match self.infer_type(object, scope) {
633                    Some(TypeExpr::Named(name)) => {
634                        if scope.get_enum(&name).is_some() {
635                            Some(name)
636                        } else {
637                            None
638                        }
639                    }
640                    _ => None,
641                }
642            }
643            _ => {
644                // Direct match on an enum value: match <expr> { ... }
645                match self.infer_type(value, scope) {
646                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
647                    _ => None,
648                }
649            }
650        };
651
652        let Some(enum_name) = enum_name else {
653            return;
654        };
655        let Some(variants) = scope.get_enum(&enum_name) else {
656            return;
657        };
658
659        // Collect variant names covered by match arms
660        let mut covered: Vec<String> = Vec::new();
661        let mut has_wildcard = false;
662
663        for arm in arms {
664            match &arm.pattern.node {
665                // String literal pattern (matching on .variant): "VariantA"
666                Node::StringLiteral(s) => covered.push(s.clone()),
667                // Identifier pattern acts as a wildcard/catch-all
668                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
669                    has_wildcard = true;
670                }
671                // Direct enum construct pattern: EnumName.Variant
672                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
673                // PropertyAccess pattern: EnumName.Variant (no args)
674                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
675                _ => {
676                    // Unknown pattern shape — conservatively treat as wildcard
677                    has_wildcard = true;
678                }
679            }
680        }
681
682        if has_wildcard {
683            return;
684        }
685
686        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
687        if !missing.is_empty() {
688            let missing_str = missing
689                .iter()
690                .map(|s| format!("\"{}\"", s))
691                .collect::<Vec<_>>()
692                .join(", ");
693            self.warning_at(
694                format!(
695                    "Non-exhaustive match on enum {}: missing variants {}",
696                    enum_name, missing_str
697                ),
698                span,
699            );
700        }
701    }
702
703    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
704        // Check against known function signatures
705        if let Some(sig) = scope.get_fn(name).cloned() {
706            if args.len() != sig.params.len() && !is_builtin(name) {
707                self.warning_at(
708                    format!(
709                        "Function '{}' expects {} arguments, got {}",
710                        name,
711                        sig.params.len(),
712                        args.len()
713                    ),
714                    span,
715                );
716            }
717            for (i, (arg, (param_name, param_type))) in
718                args.iter().zip(sig.params.iter()).enumerate()
719            {
720                if let Some(expected) = param_type {
721                    let actual = self.infer_type(arg, scope);
722                    if let Some(actual) = &actual {
723                        if !self.types_compatible(expected, actual, scope) {
724                            self.error_at(
725                                format!(
726                                    "Argument {} ('{}'): expected {}, got {}",
727                                    i + 1,
728                                    param_name,
729                                    format_type(expected),
730                                    format_type(actual)
731                                ),
732                                arg.span,
733                            );
734                        }
735                    }
736                }
737            }
738        }
739        // Check args recursively
740        for arg in args {
741            self.check_node(arg, scope);
742        }
743    }
744
745    /// Infer the type of an expression.
746    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
747        match &snode.node {
748            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
749            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
750            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
751                Some(TypeExpr::Named("string".into()))
752            }
753            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
754            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
755            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
756            Node::DictLiteral(entries) => {
757                // Infer shape type when all keys are string literals
758                let mut fields = Vec::new();
759                let mut all_string_keys = true;
760                for entry in entries {
761                    if let Node::StringLiteral(key) = &entry.key.node {
762                        let val_type = self
763                            .infer_type(&entry.value, scope)
764                            .unwrap_or(TypeExpr::Named("nil".into()));
765                        fields.push(ShapeField {
766                            name: key.clone(),
767                            type_expr: val_type,
768                            optional: false,
769                        });
770                    } else {
771                        all_string_keys = false;
772                        break;
773                    }
774                }
775                if all_string_keys && !fields.is_empty() {
776                    Some(TypeExpr::Shape(fields))
777                } else {
778                    Some(TypeExpr::Named("dict".into()))
779                }
780            }
781            Node::Closure { .. } => Some(TypeExpr::Named("closure".into())),
782
783            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
784
785            Node::FunctionCall { name, .. } => {
786                // Check user-defined function return types
787                if let Some(sig) = scope.get_fn(name) {
788                    return sig.return_type.clone();
789                }
790                // Check builtin return types
791                builtin_return_type(name)
792            }
793
794            Node::BinaryOp { op, left, right } => {
795                let lt = self.infer_type(left, scope);
796                let rt = self.infer_type(right, scope);
797                infer_binary_op_type(op, &lt, &rt)
798            }
799
800            Node::UnaryOp { op, operand } => {
801                let t = self.infer_type(operand, scope);
802                match op.as_str() {
803                    "!" => Some(TypeExpr::Named("bool".into())),
804                    "-" => t, // negation preserves type
805                    _ => None,
806                }
807            }
808
809            Node::Ternary {
810                true_expr,
811                false_expr,
812                ..
813            } => {
814                let tt = self.infer_type(true_expr, scope);
815                let ft = self.infer_type(false_expr, scope);
816                match (&tt, &ft) {
817                    (Some(a), Some(b)) if a == b => tt,
818                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
819                    (Some(_), None) => tt,
820                    (None, Some(_)) => ft,
821                    (None, None) => None,
822                }
823            }
824
825            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
826
827            Node::PropertyAccess { object, property } => {
828                // EnumName.Variant → infer as the enum type
829                if let Node::Identifier(name) = &object.node {
830                    if scope.get_enum(name).is_some() {
831                        return Some(TypeExpr::Named(name.clone()));
832                    }
833                }
834                // .variant on an enum value → string
835                if property == "variant" {
836                    let obj_type = self.infer_type(object, scope);
837                    if let Some(TypeExpr::Named(name)) = &obj_type {
838                        if scope.get_enum(name).is_some() {
839                            return Some(TypeExpr::Named("string".into()));
840                        }
841                    }
842                }
843                // Shape field access: obj.field → field type
844                let obj_type = self.infer_type(object, scope);
845                if let Some(TypeExpr::Shape(fields)) = &obj_type {
846                    if let Some(field) = fields.iter().find(|f| f.name == *property) {
847                        return Some(field.type_expr.clone());
848                    }
849                }
850                None
851            }
852
853            Node::SubscriptAccess { object, index } => {
854                let obj_type = self.infer_type(object, scope);
855                match &obj_type {
856                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
857                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
858                    Some(TypeExpr::Shape(fields)) => {
859                        // If index is a string literal, look up the field type
860                        if let Node::StringLiteral(key) = &index.node {
861                            fields
862                                .iter()
863                                .find(|f| &f.name == key)
864                                .map(|f| f.type_expr.clone())
865                        } else {
866                            None
867                        }
868                    }
869                    Some(TypeExpr::Named(n)) if n == "list" => None,
870                    Some(TypeExpr::Named(n)) if n == "dict" => None,
871                    Some(TypeExpr::Named(n)) if n == "string" => {
872                        Some(TypeExpr::Named("string".into()))
873                    }
874                    _ => None,
875                }
876            }
877            Node::SliceAccess { object, .. } => {
878                // Slicing a list returns the same list type; slicing a string returns string
879                let obj_type = self.infer_type(object, scope);
880                match &obj_type {
881                    Some(TypeExpr::List(_)) => obj_type,
882                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
883                    Some(TypeExpr::Named(n)) if n == "string" => {
884                        Some(TypeExpr::Named("string".into()))
885                    }
886                    _ => None,
887                }
888            }
889            Node::MethodCall { object, method, .. }
890            | Node::OptionalMethodCall { object, method, .. } => {
891                let obj_type = self.infer_type(object, scope);
892                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
893                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
894                match method.as_str() {
895                    // Shared: bool-returning methods
896                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
897                        Some(TypeExpr::Named("bool".into()))
898                    }
899                    // Shared: int-returning methods
900                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
901                    // String methods
902                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
903                    | "pad_left" | "pad_right" | "repeat" | "join" => {
904                        Some(TypeExpr::Named("string".into()))
905                    }
906                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
907                    // filter returns dict for dicts, list for lists
908                    "filter" => {
909                        if is_dict {
910                            Some(TypeExpr::Named("dict".into()))
911                        } else {
912                            Some(TypeExpr::Named("list".into()))
913                        }
914                    }
915                    // List methods
916                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
917                    "reduce" | "find" | "first" | "last" => None,
918                    // Dict methods
919                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
920                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
921                    // Conversions
922                    "to_string" => Some(TypeExpr::Named("string".into())),
923                    "to_int" => Some(TypeExpr::Named("int".into())),
924                    "to_float" => Some(TypeExpr::Named("float".into())),
925                    _ => None,
926                }
927            }
928
929            _ => None,
930        }
931    }
932
933    /// Check if two types are compatible (actual can be assigned to expected).
934    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
935        let expected = self.resolve_alias(expected, scope);
936        let actual = self.resolve_alias(actual, scope);
937
938        match (&expected, &actual) {
939            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
940            (TypeExpr::Union(members), actual_type) => members
941                .iter()
942                .any(|m| self.types_compatible(m, actual_type, scope)),
943            (expected_type, TypeExpr::Union(members)) => members
944                .iter()
945                .all(|m| self.types_compatible(expected_type, m, scope)),
946            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
947            (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
948            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
949                if expected_field.optional {
950                    return true;
951                }
952                af.iter().any(|actual_field| {
953                    actual_field.name == expected_field.name
954                        && self.types_compatible(
955                            &expected_field.type_expr,
956                            &actual_field.type_expr,
957                            scope,
958                        )
959                })
960            }),
961            // dict[K, V] expected, Shape actual → all field values must match V
962            (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
963                let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
964                keys_ok
965                    && af
966                        .iter()
967                        .all(|f| self.types_compatible(ev, &f.type_expr, scope))
968            }
969            // Shape expected, dict[K, V] actual → gradual: allow since dict may have the fields
970            (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
971            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
972                self.types_compatible(expected_inner, actual_inner, scope)
973            }
974            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
975            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
976            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
977                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
978            }
979            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
980            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
981            _ => false,
982        }
983    }
984
985    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
986        if let TypeExpr::Named(name) = ty {
987            if let Some(resolved) = scope.resolve_type(name) {
988                return resolved.clone();
989            }
990        }
991        ty.clone()
992    }
993
994    fn error_at(&mut self, message: String, span: Span) {
995        self.diagnostics.push(TypeDiagnostic {
996            message,
997            severity: DiagnosticSeverity::Error,
998            span: Some(span),
999        });
1000    }
1001
1002    fn warning_at(&mut self, message: String, span: Span) {
1003        self.diagnostics.push(TypeDiagnostic {
1004            message,
1005            severity: DiagnosticSeverity::Warning,
1006            span: Some(span),
1007        });
1008    }
1009}
1010
1011impl Default for TypeChecker {
1012    fn default() -> Self {
1013        Self::new()
1014    }
1015}
1016
1017/// Infer the result type of a binary operation.
1018fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1019    match op {
1020        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1021        "+" => match (left, right) {
1022            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1023                match (l.as_str(), r.as_str()) {
1024                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1025                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1026                    ("string", _) => Some(TypeExpr::Named("string".into())),
1027                    ("list", "list") => Some(TypeExpr::Named("list".into())),
1028                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1029                    _ => Some(TypeExpr::Named("string".into())),
1030                }
1031            }
1032            _ => None,
1033        },
1034        "-" | "*" | "/" | "%" => match (left, right) {
1035            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1036                match (l.as_str(), r.as_str()) {
1037                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1038                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1039                    _ => None,
1040                }
1041            }
1042            _ => None,
1043        },
1044        "??" => match (left, right) {
1045            (Some(TypeExpr::Union(members)), _) => {
1046                let non_nil: Vec<_> = members
1047                    .iter()
1048                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1049                    .cloned()
1050                    .collect();
1051                if non_nil.len() == 1 {
1052                    Some(non_nil[0].clone())
1053                } else if non_nil.is_empty() {
1054                    right.clone()
1055                } else {
1056                    Some(TypeExpr::Union(non_nil))
1057                }
1058            }
1059            _ => right.clone(),
1060        },
1061        "|>" => None,
1062        _ => None,
1063    }
1064}
1065
1066/// Format a type expression for display in error messages.
1067pub fn format_type(ty: &TypeExpr) -> String {
1068    match ty {
1069        TypeExpr::Named(n) => n.clone(),
1070        TypeExpr::Union(types) => types
1071            .iter()
1072            .map(format_type)
1073            .collect::<Vec<_>>()
1074            .join(" | "),
1075        TypeExpr::Shape(fields) => {
1076            let inner: Vec<String> = fields
1077                .iter()
1078                .map(|f| {
1079                    let opt = if f.optional { "?" } else { "" };
1080                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1081                })
1082                .collect();
1083            format!("{{{}}}", inner.join(", "))
1084        }
1085        TypeExpr::List(inner) => format!("list[{}]", format_type(inner)),
1086        TypeExpr::DictType(k, v) => format!("dict[{}, {}]", format_type(k), format_type(v)),
1087    }
1088}
1089
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    // --- Helpers ---

    /// Lex, parse and type-check `source`, returning every diagnostic.
    ///
    /// Panics if lexing or parsing fails; every fixture below is expected to be
    /// syntactically valid.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut lexer_state = Lexer::new(source);
        let token_stream = lexer_state.tokenize().unwrap();
        let mut parser_state = Parser::new(token_stream);
        let ast = parser_state.parse().unwrap();
        TypeChecker::new().check(&ast)
    }

    /// Messages of the diagnostics for `source` that have the given severity.
    fn messages_with(source: &str, severity: DiagnosticSeverity) -> Vec<String> {
        let mut out = Vec::new();
        for diagnostic in check_source(source) {
            if diagnostic.severity == severity {
                out.push(diagnostic.message);
            }
        }
        out
    }

    fn errors(source: &str) -> Vec<String> {
        messages_with(source, DiagnosticSeverity::Error)
    }

    fn warnings(source: &str) -> Vec<String> {
        messages_with(source, DiagnosticSeverity::Warning)
    }

    /// Warnings about non-exhaustive `match` statements only.
    fn non_exhaustive_warnings(source: &str) -> Vec<String> {
        warnings(source)
            .into_iter()
            .filter(|w| w.contains("Non-exhaustive"))
            .collect()
    }

    fn assert_no_errors(source: &str) {
        let errs = errors(source);
        assert!(errs.is_empty(), "expected no errors, got {errs:?}");
    }

    /// Assert `source` yields exactly one error and return its message.
    fn single_error(source: &str) -> String {
        let mut errs = errors(source);
        assert_eq!(errs.len(), 1, "expected exactly one error, got {errs:?}");
        errs.pop().unwrap()
    }

    fn assert_mentions(message: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(message.contains(needle), "{message:?} does not mention {needle:?}");
        }
    }

    // --- Basic checking and inference tests ---

    #[test]
    fn test_no_errors_for_untyped_code() {
        assert_no_errors("pipeline t(task) { let x = 42\nlog(x) }");
    }

    #[test]
    fn test_correct_typed_let() {
        assert_no_errors("pipeline t(task) { let x: int = 42 }");
    }

    #[test]
    fn test_type_mismatch_let() {
        let msg = single_error(r#"pipeline t(task) { let x: int = "hello" }"#);
        assert_mentions(&msg, &["Type mismatch", "int", "string"]);
    }

    #[test]
    fn test_correct_typed_fn() {
        assert_no_errors(
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
        );
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let msg = single_error(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        assert_mentions(&msg, &["Argument 1", "expected int"]);
    }

    #[test]
    fn test_return_type_mismatch() {
        let msg = single_error(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert_mentions(&msg, &["Return type mismatch"]);
    }

    #[test]
    fn test_union_type_compatible() {
        assert_no_errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
    }

    #[test]
    fn test_union_type_mismatch() {
        let msg = single_error(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert_mentions(&msg, &["Type mismatch"]);
    }

    #[test]
    fn test_type_inference_propagation() {
        let msg = single_error(
            r#"pipeline t(task) {
  fn add(a: int, b: int) -> int { return a + b }
  let result: string = add(1, 2)
}"#,
        );
        assert_mentions(&msg, &["Type mismatch", "string", "int"]);
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let msg = single_error(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        assert_mentions(&msg, &["string", "int"]);
    }

    #[test]
    fn test_binary_op_type_inference() {
        single_error("pipeline t(task) { let x: string = 1 + 2 }");
    }

    #[test]
    fn test_comparison_returns_bool() {
        assert_no_errors("pipeline t(task) { let x: bool = 1 < 2 }");
    }

    #[test]
    fn test_int_float_promotion() {
        assert_no_errors("pipeline t(task) { let x: float = 42 }");
    }

    #[test]
    fn test_untyped_code_no_errors() {
        assert_no_errors(
            r#"pipeline t(task) {
  fn process(data) {
    let result = data + " processed"
    return result
  }
  log(process("hello"))
}"#,
        );
    }

    #[test]
    fn test_type_alias() {
        assert_no_errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = "hello"
}"#,
        );
    }

    #[test]
    fn test_type_alias_mismatch() {
        single_error(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = 42
}"#,
        );
    }

    #[test]
    fn test_assignment_type_check() {
        let msg = single_error(
            r#"pipeline t(task) {
  var x: int = 0
  x = "hello"
}"#,
        );
        assert_mentions(&msg, &["cannot assign string"]);
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        assert_no_errors(
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
        );
    }

    #[test]
    fn test_covariance_return_type() {
        assert_no_errors("pipeline t(task) { fn get() -> float { return 42 } }");
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        single_error("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
    }

    // --- Exhaustiveness checking tests ---

    #[test]
    fn test_exhaustive_match_no_warning() {
        let missing = non_exhaustive_warnings(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
    "Blue" -> { log("b") }
  }
}"#,
        );
        assert!(missing.is_empty(), "unexpected warnings: {missing:?}");
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let missing = non_exhaustive_warnings(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
  }
}"#,
        );
        assert_eq!(missing.len(), 1);
        assert_mentions(&missing[0], &["Blue"]);
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let missing = non_exhaustive_warnings(
            r#"pipeline t(task) {
  enum Status { Active, Inactive, Pending }
  let s = Status.Active
  match s.variant {
    "Active" -> { log("a") }
  }
}"#,
        );
        assert_eq!(missing.len(), 1);
        assert_mentions(&missing[0], &["Inactive", "Pending"]);
    }

    #[test]
    fn test_enum_construct_type_inference() {
        assert_no_errors(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c: Color = Color.Red
}"#,
        );
    }

    // --- Type narrowing tests ---

    #[test]
    fn test_nil_coalescing_strips_nil() {
        // After ??, nil should be stripped from the type
        assert_no_errors(
            r#"pipeline t(task) {
  let x: string | nil = nil
  let y: string = x ?? "default"
}"#,
        );
    }
}