Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6/// A diagnostic produced by the type checker.
7#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9    pub message: String,
10    pub severity: DiagnosticSeverity,
11    pub span: Option<Span>,
12}
13
/// Severity level attached to a [`TypeDiagnostic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// Error-level diagnostic.
    Error,
    /// Warning-level diagnostic.
    Warning,
}
19
20/// Inferred type of an expression. None means unknown/untyped (gradual typing).
21type InferredType = Option<TypeExpr>;
22
23/// Scope for tracking variable types.
24#[derive(Debug, Clone)]
25struct TypeScope {
26    /// Variable name → inferred type.
27    vars: BTreeMap<String, InferredType>,
28    /// Function name → (param types, return type).
29    functions: BTreeMap<String, FnSignature>,
30    /// Named type aliases.
31    type_aliases: BTreeMap<String, TypeExpr>,
32    /// Enum declarations: name → variant names.
33    enums: BTreeMap<String, Vec<String>>,
34    /// Interface declarations: name → method signatures.
35    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36    /// Struct declarations: name → field types.
37    structs: BTreeMap<String, Vec<(String, InferredType)>>,
38    parent: Option<Box<TypeScope>>,
39}
40
41#[derive(Debug, Clone)]
42struct FnSignature {
43    params: Vec<(String, InferredType)>,
44    return_type: InferredType,
45}
46
47impl TypeScope {
48    fn new() -> Self {
49        Self {
50            vars: BTreeMap::new(),
51            functions: BTreeMap::new(),
52            type_aliases: BTreeMap::new(),
53            enums: BTreeMap::new(),
54            interfaces: BTreeMap::new(),
55            structs: BTreeMap::new(),
56            parent: None,
57        }
58    }
59
60    fn child(&self) -> Self {
61        Self {
62            vars: BTreeMap::new(),
63            functions: BTreeMap::new(),
64            type_aliases: BTreeMap::new(),
65            enums: BTreeMap::new(),
66            interfaces: BTreeMap::new(),
67            structs: BTreeMap::new(),
68            parent: Some(Box::new(self.clone())),
69        }
70    }
71
72    fn get_var(&self, name: &str) -> Option<&InferredType> {
73        self.vars
74            .get(name)
75            .or_else(|| self.parent.as_ref()?.get_var(name))
76    }
77
78    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
79        self.functions
80            .get(name)
81            .or_else(|| self.parent.as_ref()?.get_fn(name))
82    }
83
84    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
85        self.type_aliases
86            .get(name)
87            .or_else(|| self.parent.as_ref()?.resolve_type(name))
88    }
89
90    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
91        self.enums
92            .get(name)
93            .or_else(|| self.parent.as_ref()?.get_enum(name))
94    }
95
96    #[allow(dead_code)]
97    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
98        self.interfaces
99            .get(name)
100            .or_else(|| self.parent.as_ref()?.get_interface(name))
101    }
102
103    fn define_var(&mut self, name: &str, ty: InferredType) {
104        self.vars.insert(name.to_string(), ty);
105    }
106
107    fn define_fn(&mut self, name: &str, sig: FnSignature) {
108        self.functions.insert(name.to_string(), sig);
109    }
110}
111
112/// Known return types for builtin functions.
113fn builtin_return_type(name: &str) -> InferredType {
114    match name {
115        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
116        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
117            Some(TypeExpr::Named("nil".into()))
118        }
119        "type_of" | "to_string" | "json_stringify" | "read_file" | "http_get" | "http_post"
120        | "llm_call" | "agent_loop" | "regex_replace" | "path_join" | "temp_dir"
121        | "date_format" | "format" => Some(TypeExpr::Named("string".into())),
122        "to_int" => Some(TypeExpr::Named("int".into())),
123        "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
124        "file_exists" | "json_validate" => Some(TypeExpr::Named("bool".into())),
125        "list_dir" => Some(TypeExpr::Named("list".into())),
126        "stat" | "exec" | "shell" | "date_now" => Some(TypeExpr::Named("dict".into())),
127        "env" | "regex_match" => Some(TypeExpr::Union(vec![
128            TypeExpr::Named("string".into()),
129            TypeExpr::Named("nil".into()),
130        ])),
131        "json_parse" | "json_extract" => None, // could be any type
132        _ => None,
133    }
134}
135
/// Reports whether `name` is one of the known builtin function names.
///
/// The comparison is exact and case-sensitive.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        "log",
        "print",
        "println",
        "type_of",
        "to_string",
        "to_int",
        "to_float",
        "json_stringify",
        "json_parse",
        "env",
        "timestamp",
        "sleep",
        "read_file",
        "write_file",
        "exit",
        "regex_match",
        "regex_replace",
        "http_get",
        "http_post",
        "llm_call",
        "agent_loop",
        "await",
        "cancel",
        "file_exists",
        "delete_file",
        "list_dir",
        "mkdir",
        "path_join",
        "copy_file",
        "append_file",
        "temp_dir",
        "stat",
        "exec",
        "shell",
        "date_now",
        "date_format",
        "date_parse",
        "format",
        "json_validate",
        "json_extract",
        "trim",
        "lowercase",
        "uppercase",
        "split",
        "starts_with",
        "ends_with",
        "contains",
        "replace",
        "join",
        "len",
        "substring",
        "dirname",
        "basename",
        "extname",
    ];

    BUILTINS.iter().any(|&known| known == name)
}
196
197/// The static type checker.
198pub struct TypeChecker {
199    diagnostics: Vec<TypeDiagnostic>,
200    scope: TypeScope,
201}
202
203impl TypeChecker {
204    pub fn new() -> Self {
205        Self {
206            diagnostics: Vec::new(),
207            scope: TypeScope::new(),
208        }
209    }
210
211    /// Check a program and return diagnostics.
212    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
213        // First pass: register type and enum declarations into root scope
214        Self::register_declarations_into(&mut self.scope, program);
215
216        // Also scan pipeline bodies for declarations
217        for snode in program {
218            if let Node::Pipeline { body, .. } = &snode.node {
219                Self::register_declarations_into(&mut self.scope, body);
220            }
221        }
222
223        // Check each top-level node
224        for snode in program {
225            match &snode.node {
226                Node::Pipeline { params, body, .. } => {
227                    let mut child = self.scope.child();
228                    for p in params {
229                        child.define_var(p, None);
230                    }
231                    self.check_block(body, &mut child);
232                }
233                Node::FnDecl {
234                    name,
235                    params,
236                    return_type,
237                    body,
238                    ..
239                } => {
240                    let sig = FnSignature {
241                        params: params
242                            .iter()
243                            .map(|p| (p.name.clone(), p.type_expr.clone()))
244                            .collect(),
245                        return_type: return_type.clone(),
246                    };
247                    self.scope.define_fn(name, sig);
248                    self.check_fn_body(params, return_type, body);
249                }
250                _ => {
251                    let mut scope = self.scope.clone();
252                    self.check_node(snode, &mut scope);
253                    // Merge any new definitions back into the top-level scope
254                    for (name, ty) in scope.vars {
255                        self.scope.vars.entry(name).or_insert(ty);
256                    }
257                }
258            }
259        }
260
261        self.diagnostics
262    }
263
264    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
265    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
266        for snode in nodes {
267            match &snode.node {
268                Node::TypeDecl { name, type_expr } => {
269                    scope.type_aliases.insert(name.clone(), type_expr.clone());
270                }
271                Node::EnumDecl { name, variants } => {
272                    let variant_names: Vec<String> =
273                        variants.iter().map(|v| v.name.clone()).collect();
274                    scope.enums.insert(name.clone(), variant_names);
275                }
276                Node::InterfaceDecl { name, methods } => {
277                    scope.interfaces.insert(name.clone(), methods.clone());
278                }
279                Node::StructDecl { name, fields } => {
280                    let field_types: Vec<(String, InferredType)> = fields
281                        .iter()
282                        .map(|f| (f.name.clone(), f.type_expr.clone()))
283                        .collect();
284                    scope.structs.insert(name.clone(), field_types);
285                }
286                _ => {}
287            }
288        }
289    }
290
291    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
292        for stmt in stmts {
293            self.check_node(stmt, scope);
294        }
295    }
296
297    /// Define variables from a destructuring pattern in the given scope (as unknown type).
298    fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
299        match pattern {
300            BindingPattern::Identifier(name) => {
301                scope.define_var(name, None);
302            }
303            BindingPattern::Dict(fields) => {
304                for field in fields {
305                    let name = field.alias.as_deref().unwrap_or(&field.key);
306                    scope.define_var(name, None);
307                }
308            }
309            BindingPattern::List(elements) => {
310                for elem in elements {
311                    scope.define_var(&elem.name, None);
312                }
313            }
314        }
315    }
316
317    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
318        let span = snode.span;
319        match &snode.node {
320            Node::LetBinding {
321                pattern,
322                type_ann,
323                value,
324            } => {
325                let inferred = self.infer_type(value, scope);
326                if let BindingPattern::Identifier(name) = pattern {
327                    if let Some(expected) = type_ann {
328                        if let Some(actual) = &inferred {
329                            if !self.types_compatible(expected, actual, scope) {
330                                self.error_at(
331                                    format!(
332                                        "Type mismatch: '{}' declared as {}, but assigned {}",
333                                        name,
334                                        format_type(expected),
335                                        format_type(actual)
336                                    ),
337                                    span,
338                                );
339                            }
340                        }
341                    }
342                    let ty = type_ann.clone().or(inferred);
343                    scope.define_var(name, ty);
344                } else {
345                    Self::define_pattern_vars(pattern, scope);
346                }
347            }
348
349            Node::VarBinding {
350                pattern,
351                type_ann,
352                value,
353            } => {
354                let inferred = self.infer_type(value, scope);
355                if let BindingPattern::Identifier(name) = pattern {
356                    if let Some(expected) = type_ann {
357                        if let Some(actual) = &inferred {
358                            if !self.types_compatible(expected, actual, scope) {
359                                self.error_at(
360                                    format!(
361                                        "Type mismatch: '{}' declared as {}, but assigned {}",
362                                        name,
363                                        format_type(expected),
364                                        format_type(actual)
365                                    ),
366                                    span,
367                                );
368                            }
369                        }
370                    }
371                    let ty = type_ann.clone().or(inferred);
372                    scope.define_var(name, ty);
373                } else {
374                    Self::define_pattern_vars(pattern, scope);
375                }
376            }
377
378            Node::FnDecl {
379                name,
380                params,
381                return_type,
382                body,
383                ..
384            } => {
385                let sig = FnSignature {
386                    params: params
387                        .iter()
388                        .map(|p| (p.name.clone(), p.type_expr.clone()))
389                        .collect(),
390                    return_type: return_type.clone(),
391                };
392                scope.define_fn(name, sig.clone());
393                scope.define_var(name, None);
394                self.check_fn_body(params, return_type, body);
395            }
396
397            Node::FunctionCall { name, args } => {
398                self.check_call(name, args, scope, span);
399            }
400
401            Node::IfElse {
402                condition,
403                then_body,
404                else_body,
405            } => {
406                self.check_node(condition, scope);
407                let mut then_scope = scope.child();
408                self.check_block(then_body, &mut then_scope);
409                if let Some(else_body) = else_body {
410                    let mut else_scope = scope.child();
411                    self.check_block(else_body, &mut else_scope);
412                }
413            }
414
415            Node::ForIn {
416                pattern,
417                iterable,
418                body,
419            } => {
420                self.check_node(iterable, scope);
421                let mut loop_scope = scope.child();
422                if let BindingPattern::Identifier(variable) = pattern {
423                    // Infer loop variable type from iterable
424                    let elem_type = match self.infer_type(iterable, scope) {
425                        Some(TypeExpr::List(inner)) => Some(*inner),
426                        Some(TypeExpr::Named(n)) if n == "string" => {
427                            Some(TypeExpr::Named("string".into()))
428                        }
429                        _ => None,
430                    };
431                    loop_scope.define_var(variable, elem_type);
432                } else {
433                    Self::define_pattern_vars(pattern, &mut loop_scope);
434                }
435                self.check_block(body, &mut loop_scope);
436            }
437
438            Node::WhileLoop { condition, body } => {
439                self.check_node(condition, scope);
440                let mut loop_scope = scope.child();
441                self.check_block(body, &mut loop_scope);
442            }
443
444            Node::TryCatch {
445                body,
446                error_var,
447                catch_body,
448                ..
449            } => {
450                let mut try_scope = scope.child();
451                self.check_block(body, &mut try_scope);
452                let mut catch_scope = scope.child();
453                if let Some(var) = error_var {
454                    catch_scope.define_var(var, None);
455                }
456                self.check_block(catch_body, &mut catch_scope);
457            }
458
459            Node::ReturnStmt {
460                value: Some(val), ..
461            } => {
462                self.check_node(val, scope);
463            }
464
465            Node::Assignment {
466                target, value, op, ..
467            } => {
468                self.check_node(value, scope);
469                if let Node::Identifier(name) = &target.node {
470                    if let Some(Some(var_type)) = scope.get_var(name) {
471                        let value_type = self.infer_type(value, scope);
472                        let assigned = if let Some(op) = op {
473                            let var_inferred = scope.get_var(name).cloned().flatten();
474                            infer_binary_op_type(op, &var_inferred, &value_type)
475                        } else {
476                            value_type
477                        };
478                        if let Some(actual) = &assigned {
479                            if !self.types_compatible(var_type, actual, scope) {
480                                self.error_at(
481                                    format!(
482                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
483                                        format_type(actual),
484                                        name,
485                                        format_type(var_type)
486                                    ),
487                                    span,
488                                );
489                            }
490                        }
491                    }
492                }
493            }
494
495            Node::TypeDecl { name, type_expr } => {
496                scope.type_aliases.insert(name.clone(), type_expr.clone());
497            }
498
499            Node::EnumDecl { name, variants } => {
500                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
501                scope.enums.insert(name.clone(), variant_names);
502            }
503
504            Node::StructDecl { name, fields } => {
505                let field_types: Vec<(String, InferredType)> = fields
506                    .iter()
507                    .map(|f| (f.name.clone(), f.type_expr.clone()))
508                    .collect();
509                scope.structs.insert(name.clone(), field_types);
510            }
511
512            Node::InterfaceDecl { name, methods } => {
513                scope.interfaces.insert(name.clone(), methods.clone());
514            }
515
516            Node::MatchExpr { value, arms } => {
517                self.check_node(value, scope);
518                for arm in arms {
519                    self.check_node(&arm.pattern, scope);
520                    let mut arm_scope = scope.child();
521                    self.check_block(&arm.body, &mut arm_scope);
522                }
523                self.check_match_exhaustiveness(value, arms, scope, span);
524            }
525
526            // Recurse into nested expressions + validate binary op types
527            Node::BinaryOp { op, left, right } => {
528                self.check_node(left, scope);
529                self.check_node(right, scope);
530                // Validate operator/type compatibility
531                let lt = self.infer_type(left, scope);
532                let rt = self.infer_type(right, scope);
533                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
534                    match op.as_str() {
535                        "-" | "*" | "/" | "%" => {
536                            let numeric = ["int", "float"];
537                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
538                                self.warning_at(
539                                    format!(
540                                        "Operator '{op}' may not be valid for types {} and {}",
541                                        l, r
542                                    ),
543                                    span,
544                                );
545                            }
546                        }
547                        "+" => {
548                            // + is valid for int, float, string, list, dict
549                            let valid = ["int", "float", "string", "list", "dict"];
550                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
551                                self.warning_at(
552                                    format!(
553                                        "Operator '+' may not be valid for types {} and {}",
554                                        l, r
555                                    ),
556                                    span,
557                                );
558                            }
559                        }
560                        _ => {}
561                    }
562                }
563            }
564            Node::UnaryOp { operand, .. } => {
565                self.check_node(operand, scope);
566            }
567            Node::MethodCall { object, args, .. }
568            | Node::OptionalMethodCall { object, args, .. } => {
569                self.check_node(object, scope);
570                for arg in args {
571                    self.check_node(arg, scope);
572                }
573            }
574            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
575                self.check_node(object, scope);
576            }
577            Node::SubscriptAccess { object, index } => {
578                self.check_node(object, scope);
579                self.check_node(index, scope);
580            }
581            Node::SliceAccess { object, start, end } => {
582                self.check_node(object, scope);
583                if let Some(s) = start {
584                    self.check_node(s, scope);
585                }
586                if let Some(e) = end {
587                    self.check_node(e, scope);
588                }
589            }
590
591            // Terminals — nothing to check
592            _ => {}
593        }
594    }
595
596    fn check_fn_body(
597        &mut self,
598        params: &[TypedParam],
599        return_type: &Option<TypeExpr>,
600        body: &[SNode],
601    ) {
602        let mut fn_scope = self.scope.child();
603        for param in params {
604            fn_scope.define_var(&param.name, param.type_expr.clone());
605        }
606        self.check_block(body, &mut fn_scope);
607
608        // Check return statements against declared return type
609        if let Some(ret_type) = return_type {
610            for stmt in body {
611                self.check_return_type(stmt, ret_type, &fn_scope);
612            }
613        }
614    }
615
616    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
617        let span = snode.span;
618        match &snode.node {
619            Node::ReturnStmt { value: Some(val) } => {
620                let inferred = self.infer_type(val, scope);
621                if let Some(actual) = &inferred {
622                    if !self.types_compatible(expected, actual, scope) {
623                        self.error_at(
624                            format!(
625                                "Return type mismatch: expected {}, got {}",
626                                format_type(expected),
627                                format_type(actual)
628                            ),
629                            span,
630                        );
631                    }
632                }
633            }
634            Node::IfElse {
635                then_body,
636                else_body,
637                ..
638            } => {
639                for stmt in then_body {
640                    self.check_return_type(stmt, expected, scope);
641                }
642                if let Some(else_body) = else_body {
643                    for stmt in else_body {
644                        self.check_return_type(stmt, expected, scope);
645                    }
646                }
647            }
648            _ => {}
649        }
650    }
651
652    /// Check if a match expression on an enum's `.variant` property covers all variants.
653    fn check_match_exhaustiveness(
654        &mut self,
655        value: &SNode,
656        arms: &[MatchArm],
657        scope: &TypeScope,
658        span: Span,
659    ) {
660        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
661        let enum_name = match &value.node {
662            Node::PropertyAccess { object, property } if property == "variant" => {
663                // Infer the type of the object
664                match self.infer_type(object, scope) {
665                    Some(TypeExpr::Named(name)) => {
666                        if scope.get_enum(&name).is_some() {
667                            Some(name)
668                        } else {
669                            None
670                        }
671                    }
672                    _ => None,
673                }
674            }
675            _ => {
676                // Direct match on an enum value: match <expr> { ... }
677                match self.infer_type(value, scope) {
678                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
679                    _ => None,
680                }
681            }
682        };
683
684        let Some(enum_name) = enum_name else {
685            return;
686        };
687        let Some(variants) = scope.get_enum(&enum_name) else {
688            return;
689        };
690
691        // Collect variant names covered by match arms
692        let mut covered: Vec<String> = Vec::new();
693        let mut has_wildcard = false;
694
695        for arm in arms {
696            match &arm.pattern.node {
697                // String literal pattern (matching on .variant): "VariantA"
698                Node::StringLiteral(s) => covered.push(s.clone()),
699                // Identifier pattern acts as a wildcard/catch-all
700                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
701                    has_wildcard = true;
702                }
703                // Direct enum construct pattern: EnumName.Variant
704                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
705                // PropertyAccess pattern: EnumName.Variant (no args)
706                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
707                _ => {
708                    // Unknown pattern shape — conservatively treat as wildcard
709                    has_wildcard = true;
710                }
711            }
712        }
713
714        if has_wildcard {
715            return;
716        }
717
718        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
719        if !missing.is_empty() {
720            let missing_str = missing
721                .iter()
722                .map(|s| format!("\"{}\"", s))
723                .collect::<Vec<_>>()
724                .join(", ");
725            self.warning_at(
726                format!(
727                    "Non-exhaustive match on enum {}: missing variants {}",
728                    enum_name, missing_str
729                ),
730                span,
731            );
732        }
733    }
734
735    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
736        // Check against known function signatures
737        if let Some(sig) = scope.get_fn(name).cloned() {
738            if args.len() != sig.params.len() && !is_builtin(name) {
739                self.warning_at(
740                    format!(
741                        "Function '{}' expects {} arguments, got {}",
742                        name,
743                        sig.params.len(),
744                        args.len()
745                    ),
746                    span,
747                );
748            }
749            for (i, (arg, (param_name, param_type))) in
750                args.iter().zip(sig.params.iter()).enumerate()
751            {
752                if let Some(expected) = param_type {
753                    let actual = self.infer_type(arg, scope);
754                    if let Some(actual) = &actual {
755                        if !self.types_compatible(expected, actual, scope) {
756                            self.error_at(
757                                format!(
758                                    "Argument {} ('{}'): expected {}, got {}",
759                                    i + 1,
760                                    param_name,
761                                    format_type(expected),
762                                    format_type(actual)
763                                ),
764                                arg.span,
765                            );
766                        }
767                    }
768                }
769            }
770        }
771        // Check args recursively
772        for arg in args {
773            self.check_node(arg, scope);
774        }
775    }
776
777    /// Infer the type of an expression.
778    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
779        match &snode.node {
780            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
781            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
782            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
783                Some(TypeExpr::Named("string".into()))
784            }
785            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
786            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
787            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
788            Node::DictLiteral(entries) => {
789                // Infer shape type when all keys are string literals
790                let mut fields = Vec::new();
791                let mut all_string_keys = true;
792                for entry in entries {
793                    if let Node::StringLiteral(key) = &entry.key.node {
794                        let val_type = self
795                            .infer_type(&entry.value, scope)
796                            .unwrap_or(TypeExpr::Named("nil".into()));
797                        fields.push(ShapeField {
798                            name: key.clone(),
799                            type_expr: val_type,
800                            optional: false,
801                        });
802                    } else {
803                        all_string_keys = false;
804                        break;
805                    }
806                }
807                if all_string_keys && !fields.is_empty() {
808                    Some(TypeExpr::Shape(fields))
809                } else {
810                    Some(TypeExpr::Named("dict".into()))
811                }
812            }
813            Node::Closure { .. } => Some(TypeExpr::Named("closure".into())),
814
815            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
816
817            Node::FunctionCall { name, .. } => {
818                // Check user-defined function return types
819                if let Some(sig) = scope.get_fn(name) {
820                    return sig.return_type.clone();
821                }
822                // Check builtin return types
823                builtin_return_type(name)
824            }
825
826            Node::BinaryOp { op, left, right } => {
827                let lt = self.infer_type(left, scope);
828                let rt = self.infer_type(right, scope);
829                infer_binary_op_type(op, &lt, &rt)
830            }
831
832            Node::UnaryOp { op, operand } => {
833                let t = self.infer_type(operand, scope);
834                match op.as_str() {
835                    "!" => Some(TypeExpr::Named("bool".into())),
836                    "-" => t, // negation preserves type
837                    _ => None,
838                }
839            }
840
841            Node::Ternary {
842                true_expr,
843                false_expr,
844                ..
845            } => {
846                let tt = self.infer_type(true_expr, scope);
847                let ft = self.infer_type(false_expr, scope);
848                match (&tt, &ft) {
849                    (Some(a), Some(b)) if a == b => tt,
850                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
851                    (Some(_), None) => tt,
852                    (None, Some(_)) => ft,
853                    (None, None) => None,
854                }
855            }
856
857            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
858
859            Node::PropertyAccess { object, property } => {
860                // EnumName.Variant → infer as the enum type
861                if let Node::Identifier(name) = &object.node {
862                    if scope.get_enum(name).is_some() {
863                        return Some(TypeExpr::Named(name.clone()));
864                    }
865                }
866                // .variant on an enum value → string
867                if property == "variant" {
868                    let obj_type = self.infer_type(object, scope);
869                    if let Some(TypeExpr::Named(name)) = &obj_type {
870                        if scope.get_enum(name).is_some() {
871                            return Some(TypeExpr::Named("string".into()));
872                        }
873                    }
874                }
875                // Shape field access: obj.field → field type
876                let obj_type = self.infer_type(object, scope);
877                if let Some(TypeExpr::Shape(fields)) = &obj_type {
878                    if let Some(field) = fields.iter().find(|f| f.name == *property) {
879                        return Some(field.type_expr.clone());
880                    }
881                }
882                None
883            }
884
885            Node::SubscriptAccess { object, index } => {
886                let obj_type = self.infer_type(object, scope);
887                match &obj_type {
888                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
889                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
890                    Some(TypeExpr::Shape(fields)) => {
891                        // If index is a string literal, look up the field type
892                        if let Node::StringLiteral(key) = &index.node {
893                            fields
894                                .iter()
895                                .find(|f| &f.name == key)
896                                .map(|f| f.type_expr.clone())
897                        } else {
898                            None
899                        }
900                    }
901                    Some(TypeExpr::Named(n)) if n == "list" => None,
902                    Some(TypeExpr::Named(n)) if n == "dict" => None,
903                    Some(TypeExpr::Named(n)) if n == "string" => {
904                        Some(TypeExpr::Named("string".into()))
905                    }
906                    _ => None,
907                }
908            }
909            Node::SliceAccess { object, .. } => {
910                // Slicing a list returns the same list type; slicing a string returns string
911                let obj_type = self.infer_type(object, scope);
912                match &obj_type {
913                    Some(TypeExpr::List(_)) => obj_type,
914                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
915                    Some(TypeExpr::Named(n)) if n == "string" => {
916                        Some(TypeExpr::Named("string".into()))
917                    }
918                    _ => None,
919                }
920            }
921            Node::MethodCall { object, method, .. }
922            | Node::OptionalMethodCall { object, method, .. } => {
923                let obj_type = self.infer_type(object, scope);
924                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
925                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
926                match method.as_str() {
927                    // Shared: bool-returning methods
928                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
929                        Some(TypeExpr::Named("bool".into()))
930                    }
931                    // Shared: int-returning methods
932                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
933                    // String methods
934                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
935                    | "pad_left" | "pad_right" | "repeat" | "join" => {
936                        Some(TypeExpr::Named("string".into()))
937                    }
938                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
939                    // filter returns dict for dicts, list for lists
940                    "filter" => {
941                        if is_dict {
942                            Some(TypeExpr::Named("dict".into()))
943                        } else {
944                            Some(TypeExpr::Named("list".into()))
945                        }
946                    }
947                    // List methods
948                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
949                    "reduce" | "find" | "first" | "last" => None,
950                    // Dict methods
951                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
952                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
953                    // Conversions
954                    "to_string" => Some(TypeExpr::Named("string".into())),
955                    "to_int" => Some(TypeExpr::Named("int".into())),
956                    "to_float" => Some(TypeExpr::Named("float".into())),
957                    _ => None,
958                }
959            }
960
961            _ => None,
962        }
963    }
964
965    /// Check if two types are compatible (actual can be assigned to expected).
966    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
967        let expected = self.resolve_alias(expected, scope);
968        let actual = self.resolve_alias(actual, scope);
969
970        match (&expected, &actual) {
971            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
972            (TypeExpr::Union(members), actual_type) => members
973                .iter()
974                .any(|m| self.types_compatible(m, actual_type, scope)),
975            (expected_type, TypeExpr::Union(members)) => members
976                .iter()
977                .all(|m| self.types_compatible(expected_type, m, scope)),
978            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
979            (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
980            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
981                if expected_field.optional {
982                    return true;
983                }
984                af.iter().any(|actual_field| {
985                    actual_field.name == expected_field.name
986                        && self.types_compatible(
987                            &expected_field.type_expr,
988                            &actual_field.type_expr,
989                            scope,
990                        )
991                })
992            }),
993            // dict[K, V] expected, Shape actual → all field values must match V
994            (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
995                let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
996                keys_ok
997                    && af
998                        .iter()
999                        .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1000            }
1001            // Shape expected, dict[K, V] actual → gradual: allow since dict may have the fields
1002            (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1003            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1004                self.types_compatible(expected_inner, actual_inner, scope)
1005            }
1006            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1007            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1008            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1009                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1010            }
1011            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1012            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1013            _ => false,
1014        }
1015    }
1016
1017    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1018        if let TypeExpr::Named(name) = ty {
1019            if let Some(resolved) = scope.resolve_type(name) {
1020                return resolved.clone();
1021            }
1022        }
1023        ty.clone()
1024    }
1025
1026    fn error_at(&mut self, message: String, span: Span) {
1027        self.diagnostics.push(TypeDiagnostic {
1028            message,
1029            severity: DiagnosticSeverity::Error,
1030            span: Some(span),
1031        });
1032    }
1033
1034    fn warning_at(&mut self, message: String, span: Span) {
1035        self.diagnostics.push(TypeDiagnostic {
1036            message,
1037            severity: DiagnosticSeverity::Warning,
1038            span: Some(span),
1039        });
1040    }
1041}
1042
1043impl Default for TypeChecker {
1044    fn default() -> Self {
1045        Self::new()
1046    }
1047}
1048
1049/// Infer the result type of a binary operation.
1050fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1051    match op {
1052        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
1053        "+" => match (left, right) {
1054            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1055                match (l.as_str(), r.as_str()) {
1056                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1057                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1058                    ("string", _) => Some(TypeExpr::Named("string".into())),
1059                    ("list", "list") => Some(TypeExpr::Named("list".into())),
1060                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1061                    _ => Some(TypeExpr::Named("string".into())),
1062                }
1063            }
1064            _ => None,
1065        },
1066        "-" | "*" | "/" | "%" => match (left, right) {
1067            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1068                match (l.as_str(), r.as_str()) {
1069                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1070                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1071                    _ => None,
1072                }
1073            }
1074            _ => None,
1075        },
1076        "??" => match (left, right) {
1077            (Some(TypeExpr::Union(members)), _) => {
1078                let non_nil: Vec<_> = members
1079                    .iter()
1080                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1081                    .cloned()
1082                    .collect();
1083                if non_nil.len() == 1 {
1084                    Some(non_nil[0].clone())
1085                } else if non_nil.is_empty() {
1086                    right.clone()
1087                } else {
1088                    Some(TypeExpr::Union(non_nil))
1089                }
1090            }
1091            _ => right.clone(),
1092        },
1093        "|>" => None,
1094        _ => None,
1095    }
1096}
1097
1098/// Format a type expression for display in error messages.
1099pub fn format_type(ty: &TypeExpr) -> String {
1100    match ty {
1101        TypeExpr::Named(n) => n.clone(),
1102        TypeExpr::Union(types) => types
1103            .iter()
1104            .map(format_type)
1105            .collect::<Vec<_>>()
1106            .join(" | "),
1107        TypeExpr::Shape(fields) => {
1108            let inner: Vec<String> = fields
1109                .iter()
1110                .map(|f| {
1111                    let opt = if f.optional { "?" } else { "" };
1112                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1113                })
1114                .collect();
1115            format!("{{{}}}", inner.join(", "))
1116        }
1117        TypeExpr::List(inner) => format!("list[{}]", format_type(inner)),
1118        TypeExpr::DictType(k, v) => format!("dict[{}, {}]", format_type(k), format_type(v)),
1119    }
1120}
1121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    /// Lex, parse and type-check `source`, returning every diagnostic produced.
    /// Panics if the source fails to tokenize or parse.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut lex = Lexer::new(source);
        let token_stream = lex.tokenize().unwrap();
        let mut parse = Parser::new(token_stream);
        let ast = parse.parse().unwrap();
        TypeChecker::new().check(&ast)
    }

    /// Messages of all diagnostics with the given severity reported for `source`.
    fn messages_with(source: &str, wanted: DiagnosticSeverity) -> Vec<String> {
        let mut collected = Vec::new();
        for diagnostic in check_source(source) {
            if diagnostic.severity == wanted {
                collected.push(diagnostic.message);
            }
        }
        collected
    }

    /// Error messages reported for `source`.
    fn errors(source: &str) -> Vec<String> {
        messages_with(source, DiagnosticSeverity::Error)
    }

    /// Warning messages reported for `source`.
    fn warnings(source: &str) -> Vec<String> {
        messages_with(source, DiagnosticSeverity::Warning)
    }

    /// Warnings for `source` that mention a non-exhaustive match.
    fn non_exhaustive_warnings(source: &str) -> Vec<String> {
        let mut matching = warnings(source);
        matching.retain(|text| text.contains("Non-exhaustive"));
        matching
    }

    #[test]
    fn test_no_errors_for_untyped_code() {
        let source = "pipeline t(task) { let x = 42\nlog(x) }";
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_correct_typed_let() {
        let source = "pipeline t(task) { let x: int = 42 }";
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_type_mismatch_let() {
        let source = r#"pipeline t(task) { let x: int = "hello" }"#;
        let found = errors(source);
        assert_eq!(found.len(), 1);
        let message = &found[0];
        assert!(message.contains("Type mismatch"));
        assert!(message.contains("int"));
        assert!(message.contains("string"));
    }

    #[test]
    fn test_correct_typed_fn() {
        let source =
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }";
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let source = r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#;
        let found = errors(source);
        assert_eq!(found.len(), 1);
        let message = &found[0];
        assert!(message.contains("Argument 1"));
        assert!(message.contains("expected int"));
    }

    #[test]
    fn test_return_type_mismatch() {
        let source = r#"pipeline t(task) { fn get() -> int { return "hello" } }"#;
        let found = errors(source);
        assert_eq!(found.len(), 1);
        assert!(found[0].contains("Return type mismatch"));
    }

    #[test]
    fn test_union_type_compatible() {
        let source = r#"pipeline t(task) { let x: string | nil = nil }"#;
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_union_type_mismatch() {
        let source = r#"pipeline t(task) { let x: string | nil = 42 }"#;
        let found = errors(source);
        assert_eq!(found.len(), 1);
        assert!(found[0].contains("Type mismatch"));
    }

    #[test]
    fn test_type_inference_propagation() {
        let source = r#"pipeline t(task) {
  fn add(a: int, b: int) -> int { return a + b }
  let result: string = add(1, 2)
}"#;
        let found = errors(source);
        assert_eq!(found.len(), 1);
        let message = &found[0];
        assert!(message.contains("Type mismatch"));
        assert!(message.contains("string"));
        assert!(message.contains("int"));
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let source = r#"pipeline t(task) { let x: string = to_int("42") }"#;
        let found = errors(source);
        assert_eq!(found.len(), 1);
        let message = &found[0];
        assert!(message.contains("string"));
        assert!(message.contains("int"));
    }

    #[test]
    fn test_binary_op_type_inference() {
        let source = "pipeline t(task) { let x: string = 1 + 2 }";
        assert_eq!(errors(source).len(), 1);
    }

    #[test]
    fn test_comparison_returns_bool() {
        let source = "pipeline t(task) { let x: bool = 1 < 2 }";
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_int_float_promotion() {
        let source = "pipeline t(task) { let x: float = 42 }";
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_untyped_code_no_errors() {
        let source = r#"pipeline t(task) {
  fn process(data) {
    let result = data + " processed"
    return result
  }
  log(process("hello"))
}"#;
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_type_alias() {
        let source = r#"pipeline t(task) {
  type Name = string
  let x: Name = "hello"
}"#;
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_type_alias_mismatch() {
        let source = r#"pipeline t(task) {
  type Name = string
  let x: Name = 42
}"#;
        assert_eq!(errors(source).len(), 1);
    }

    #[test]
    fn test_assignment_type_check() {
        let source = r#"pipeline t(task) {
  var x: int = 0
  x = "hello"
}"#;
        let found = errors(source);
        assert_eq!(found.len(), 1);
        assert!(found[0].contains("cannot assign string"));
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        let source =
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }";
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_covariance_return_type() {
        let source = "pipeline t(task) { fn get() -> float { return 42 } }";
        assert!(errors(source).is_empty());
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        let source = "pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }";
        assert_eq!(errors(source).len(), 1);
    }

    // --- Exhaustiveness checking tests ---

    #[test]
    fn test_exhaustive_match_no_warning() {
        let source = r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
    "Blue" -> { log("b") }
  }
}"#;
        assert!(non_exhaustive_warnings(source).is_empty());
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let source = r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
  }
}"#;
        let flagged = non_exhaustive_warnings(source);
        assert_eq!(flagged.len(), 1);
        assert!(flagged[0].contains("Blue"));
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let source = r#"pipeline t(task) {
  enum Status { Active, Inactive, Pending }
  let s = Status.Active
  match s.variant {
    "Active" -> { log("a") }
  }
}"#;
        let flagged = non_exhaustive_warnings(source);
        assert_eq!(flagged.len(), 1);
        let message = &flagged[0];
        assert!(message.contains("Inactive"));
        assert!(message.contains("Pending"));
    }

    #[test]
    fn test_enum_construct_type_inference() {
        let source = r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c: Color = Color.Red
}"#;
        assert!(errors(source).is_empty());
    }

    // --- Type narrowing tests ---

    #[test]
    fn test_nil_coalescing_strips_nil() {
        // After ??, nil should be stripped from the type
        let source = r#"pipeline t(task) {
  let x: string | nil = nil
  let y: string = x ?? "default"
}"#;
        assert!(errors(source).is_empty());
    }
}