Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6/// A diagnostic produced by the type checker.
7#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9    pub message: String,
10    pub severity: DiagnosticSeverity,
11    pub span: Option<Span>,
12}
13
/// How serious a [`TypeDiagnostic`] is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// Reported through the checker's error path (e.g. type mismatches).
    Error,
    /// Reported through the checker's warning path (e.g. suspicious operators).
    Warning,
}
19
20/// Inferred type of an expression. None means unknown/untyped (gradual typing).
21type InferredType = Option<TypeExpr>;
22
23/// Scope for tracking variable types.
24#[derive(Debug, Clone)]
25struct TypeScope {
26    /// Variable name → inferred type.
27    vars: BTreeMap<String, InferredType>,
28    /// Function name → (param types, return type).
29    functions: BTreeMap<String, FnSignature>,
30    /// Named type aliases.
31    type_aliases: BTreeMap<String, TypeExpr>,
32    /// Enum declarations: name → variant names.
33    enums: BTreeMap<String, Vec<String>>,
34    /// Interface declarations: name → method signatures.
35    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
36    /// Struct declarations: name → field types.
37    structs: BTreeMap<String, Vec<(String, InferredType)>>,
38    parent: Option<Box<TypeScope>>,
39}
40
41#[derive(Debug, Clone)]
42struct FnSignature {
43    params: Vec<(String, InferredType)>,
44    return_type: InferredType,
45}
46
47impl TypeScope {
48    fn new() -> Self {
49        Self {
50            vars: BTreeMap::new(),
51            functions: BTreeMap::new(),
52            type_aliases: BTreeMap::new(),
53            enums: BTreeMap::new(),
54            interfaces: BTreeMap::new(),
55            structs: BTreeMap::new(),
56            parent: None,
57        }
58    }
59
60    fn child(&self) -> Self {
61        Self {
62            vars: BTreeMap::new(),
63            functions: BTreeMap::new(),
64            type_aliases: BTreeMap::new(),
65            enums: BTreeMap::new(),
66            interfaces: BTreeMap::new(),
67            structs: BTreeMap::new(),
68            parent: Some(Box::new(self.clone())),
69        }
70    }
71
72    fn get_var(&self, name: &str) -> Option<&InferredType> {
73        self.vars
74            .get(name)
75            .or_else(|| self.parent.as_ref()?.get_var(name))
76    }
77
78    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
79        self.functions
80            .get(name)
81            .or_else(|| self.parent.as_ref()?.get_fn(name))
82    }
83
84    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
85        self.type_aliases
86            .get(name)
87            .or_else(|| self.parent.as_ref()?.resolve_type(name))
88    }
89
90    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
91        self.enums
92            .get(name)
93            .or_else(|| self.parent.as_ref()?.get_enum(name))
94    }
95
96    #[allow(dead_code)]
97    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
98        self.interfaces
99            .get(name)
100            .or_else(|| self.parent.as_ref()?.get_interface(name))
101    }
102
103    fn define_var(&mut self, name: &str, ty: InferredType) {
104        self.vars.insert(name.to_string(), ty);
105    }
106
107    fn define_fn(&mut self, name: &str, sig: FnSignature) {
108        self.functions.insert(name.to_string(), sig);
109    }
110}
111
112/// Known return types for builtin functions.
113fn builtin_return_type(name: &str) -> InferredType {
114    match name {
115        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
116        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
117            Some(TypeExpr::Named("nil".into()))
118        }
119        "type_of" | "to_string" | "json_stringify" | "read_file" | "http_get" | "http_post"
120        | "llm_call" | "agent_loop" | "regex_replace" | "path_join" | "temp_dir"
121        | "date_format" | "format" => Some(TypeExpr::Named("string".into())),
122        "to_int" => Some(TypeExpr::Named("int".into())),
123        "to_float" | "timestamp" | "date_parse" => Some(TypeExpr::Named("float".into())),
124        "file_exists" => Some(TypeExpr::Named("bool".into())),
125        "list_dir" => Some(TypeExpr::Named("list".into())),
126        "stat" | "exec" | "shell" | "date_now" => Some(TypeExpr::Named("dict".into())),
127        "env" | "regex_match" => Some(TypeExpr::Union(vec![
128            TypeExpr::Named("string".into()),
129            TypeExpr::Named("nil".into()),
130        ])),
131        "json_parse" => None, // could be any type
132        _ => None,
133    }
134}
135
/// Reports whether `name` is one of the known builtin function names.
///
/// The comparison is exact and case-sensitive.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        // Output, conversion, and JSON
        "log",
        "print",
        "println",
        "type_of",
        "to_string",
        "to_int",
        "to_float",
        "json_stringify",
        "json_parse",
        // Environment, time, and process
        "env",
        "timestamp",
        "sleep",
        "read_file",
        "write_file",
        "exit",
        // Regex, HTTP, and LLM
        "regex_match",
        "regex_replace",
        "http_get",
        "http_post",
        "llm_call",
        "agent_loop",
        "await",
        "cancel",
        // Filesystem
        "file_exists",
        "delete_file",
        "list_dir",
        "mkdir",
        "path_join",
        "copy_file",
        "append_file",
        "temp_dir",
        "stat",
        // Subprocesses
        "exec",
        "shell",
        // Dates and formatting
        "date_now",
        "date_format",
        "date_parse",
        "format",
    ];
    BUILTINS.contains(&name)
}
180
181/// The static type checker.
182pub struct TypeChecker {
183    diagnostics: Vec<TypeDiagnostic>,
184    scope: TypeScope,
185}
186
187impl TypeChecker {
188    pub fn new() -> Self {
189        Self {
190            diagnostics: Vec::new(),
191            scope: TypeScope::new(),
192        }
193    }
194
195    /// Check a program and return diagnostics.
196    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
197        // First pass: register type and enum declarations into root scope
198        Self::register_declarations_into(&mut self.scope, program);
199
200        // Also scan pipeline bodies for declarations
201        for snode in program {
202            if let Node::Pipeline { body, .. } = &snode.node {
203                Self::register_declarations_into(&mut self.scope, body);
204            }
205        }
206
207        // Check each top-level node
208        for snode in program {
209            match &snode.node {
210                Node::Pipeline { params, body, .. } => {
211                    let mut child = self.scope.child();
212                    for p in params {
213                        child.define_var(p, None);
214                    }
215                    self.check_block(body, &mut child);
216                }
217                Node::FnDecl {
218                    name,
219                    params,
220                    return_type,
221                    body,
222                    ..
223                } => {
224                    let sig = FnSignature {
225                        params: params
226                            .iter()
227                            .map(|p| (p.name.clone(), p.type_expr.clone()))
228                            .collect(),
229                        return_type: return_type.clone(),
230                    };
231                    self.scope.define_fn(name, sig);
232                    self.check_fn_body(params, return_type, body);
233                }
234                _ => {
235                    let mut scope = self.scope.clone();
236                    self.check_node(snode, &mut scope);
237                    // Merge any new definitions back into the top-level scope
238                    for (name, ty) in scope.vars {
239                        self.scope.vars.entry(name).or_insert(ty);
240                    }
241                }
242            }
243        }
244
245        self.diagnostics
246    }
247
248    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
249    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
250        for snode in nodes {
251            match &snode.node {
252                Node::TypeDecl { name, type_expr } => {
253                    scope.type_aliases.insert(name.clone(), type_expr.clone());
254                }
255                Node::EnumDecl { name, variants } => {
256                    let variant_names: Vec<String> =
257                        variants.iter().map(|v| v.name.clone()).collect();
258                    scope.enums.insert(name.clone(), variant_names);
259                }
260                Node::InterfaceDecl { name, methods } => {
261                    scope.interfaces.insert(name.clone(), methods.clone());
262                }
263                Node::StructDecl { name, fields } => {
264                    let field_types: Vec<(String, InferredType)> = fields
265                        .iter()
266                        .map(|f| (f.name.clone(), f.type_expr.clone()))
267                        .collect();
268                    scope.structs.insert(name.clone(), field_types);
269                }
270                _ => {}
271            }
272        }
273    }
274
275    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
276        for stmt in stmts {
277            self.check_node(stmt, scope);
278        }
279    }
280
281    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
282        let span = snode.span;
283        match &snode.node {
284            Node::LetBinding {
285                name,
286                type_ann,
287                value,
288            } => {
289                let inferred = self.infer_type(value, scope);
290                if let Some(expected) = type_ann {
291                    if let Some(actual) = &inferred {
292                        if !self.types_compatible(expected, actual, scope) {
293                            self.error_at(
294                                format!(
295                                    "Type mismatch: '{}' declared as {}, but assigned {}",
296                                    name,
297                                    format_type(expected),
298                                    format_type(actual)
299                                ),
300                                span,
301                            );
302                        }
303                    }
304                }
305                let ty = type_ann.clone().or(inferred);
306                scope.define_var(name, ty);
307            }
308
309            Node::VarBinding {
310                name,
311                type_ann,
312                value,
313            } => {
314                let inferred = self.infer_type(value, scope);
315                if let Some(expected) = type_ann {
316                    if let Some(actual) = &inferred {
317                        if !self.types_compatible(expected, actual, scope) {
318                            self.error_at(
319                                format!(
320                                    "Type mismatch: '{}' declared as {}, but assigned {}",
321                                    name,
322                                    format_type(expected),
323                                    format_type(actual)
324                                ),
325                                span,
326                            );
327                        }
328                    }
329                }
330                let ty = type_ann.clone().or(inferred);
331                scope.define_var(name, ty);
332            }
333
334            Node::FnDecl {
335                name,
336                params,
337                return_type,
338                body,
339                ..
340            } => {
341                let sig = FnSignature {
342                    params: params
343                        .iter()
344                        .map(|p| (p.name.clone(), p.type_expr.clone()))
345                        .collect(),
346                    return_type: return_type.clone(),
347                };
348                scope.define_fn(name, sig.clone());
349                scope.define_var(name, None);
350                self.check_fn_body(params, return_type, body);
351            }
352
353            Node::FunctionCall { name, args } => {
354                self.check_call(name, args, scope, span);
355            }
356
357            Node::IfElse {
358                condition,
359                then_body,
360                else_body,
361            } => {
362                self.check_node(condition, scope);
363                let mut then_scope = scope.child();
364                self.check_block(then_body, &mut then_scope);
365                if let Some(else_body) = else_body {
366                    let mut else_scope = scope.child();
367                    self.check_block(else_body, &mut else_scope);
368                }
369            }
370
371            Node::ForIn {
372                variable,
373                iterable,
374                body,
375            } => {
376                self.check_node(iterable, scope);
377                let mut loop_scope = scope.child();
378                // Infer loop variable type from iterable
379                let elem_type = match self.infer_type(iterable, scope) {
380                    Some(TypeExpr::List(inner)) => Some(*inner),
381                    Some(TypeExpr::Named(n)) if n == "string" => {
382                        Some(TypeExpr::Named("string".into()))
383                    }
384                    _ => None,
385                };
386                loop_scope.define_var(variable, elem_type);
387                self.check_block(body, &mut loop_scope);
388            }
389
390            Node::WhileLoop { condition, body } => {
391                self.check_node(condition, scope);
392                let mut loop_scope = scope.child();
393                self.check_block(body, &mut loop_scope);
394            }
395
396            Node::TryCatch {
397                body,
398                error_var,
399                catch_body,
400                ..
401            } => {
402                let mut try_scope = scope.child();
403                self.check_block(body, &mut try_scope);
404                let mut catch_scope = scope.child();
405                if let Some(var) = error_var {
406                    catch_scope.define_var(var, None);
407                }
408                self.check_block(catch_body, &mut catch_scope);
409            }
410
411            Node::ReturnStmt {
412                value: Some(val), ..
413            } => {
414                self.check_node(val, scope);
415            }
416
417            Node::Assignment {
418                target, value, op, ..
419            } => {
420                self.check_node(value, scope);
421                if let Node::Identifier(name) = &target.node {
422                    if let Some(Some(var_type)) = scope.get_var(name) {
423                        let value_type = self.infer_type(value, scope);
424                        let assigned = if let Some(op) = op {
425                            let var_inferred = scope.get_var(name).cloned().flatten();
426                            infer_binary_op_type(op, &var_inferred, &value_type)
427                        } else {
428                            value_type
429                        };
430                        if let Some(actual) = &assigned {
431                            if !self.types_compatible(var_type, actual, scope) {
432                                self.error_at(
433                                    format!(
434                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
435                                        format_type(actual),
436                                        name,
437                                        format_type(var_type)
438                                    ),
439                                    span,
440                                );
441                            }
442                        }
443                    }
444                }
445            }
446
447            Node::TypeDecl { name, type_expr } => {
448                scope.type_aliases.insert(name.clone(), type_expr.clone());
449            }
450
451            Node::EnumDecl { name, variants } => {
452                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
453                scope.enums.insert(name.clone(), variant_names);
454            }
455
456            Node::StructDecl { name, fields } => {
457                let field_types: Vec<(String, InferredType)> = fields
458                    .iter()
459                    .map(|f| (f.name.clone(), f.type_expr.clone()))
460                    .collect();
461                scope.structs.insert(name.clone(), field_types);
462            }
463
464            Node::InterfaceDecl { name, methods } => {
465                scope.interfaces.insert(name.clone(), methods.clone());
466            }
467
468            Node::MatchExpr { value, arms } => {
469                self.check_node(value, scope);
470                for arm in arms {
471                    self.check_node(&arm.pattern, scope);
472                    let mut arm_scope = scope.child();
473                    self.check_block(&arm.body, &mut arm_scope);
474                }
475                self.check_match_exhaustiveness(value, arms, scope, span);
476            }
477
478            // Recurse into nested expressions + validate binary op types
479            Node::BinaryOp { op, left, right } => {
480                self.check_node(left, scope);
481                self.check_node(right, scope);
482                // Validate operator/type compatibility
483                let lt = self.infer_type(left, scope);
484                let rt = self.infer_type(right, scope);
485                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
486                    match op.as_str() {
487                        "-" | "*" | "/" | "%" => {
488                            let numeric = ["int", "float"];
489                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
490                                self.warning_at(
491                                    format!(
492                                        "Operator '{op}' may not be valid for types {} and {}",
493                                        l, r
494                                    ),
495                                    span,
496                                );
497                            }
498                        }
499                        "+" => {
500                            // + is valid for int, float, string, list, dict
501                            let valid = ["int", "float", "string", "list", "dict"];
502                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
503                                self.warning_at(
504                                    format!(
505                                        "Operator '+' may not be valid for types {} and {}",
506                                        l, r
507                                    ),
508                                    span,
509                                );
510                            }
511                        }
512                        _ => {}
513                    }
514                }
515            }
516            Node::UnaryOp { operand, .. } => {
517                self.check_node(operand, scope);
518            }
519            Node::MethodCall { object, args, .. } => {
520                self.check_node(object, scope);
521                for arg in args {
522                    self.check_node(arg, scope);
523                }
524            }
525            Node::PropertyAccess { object, .. } => {
526                self.check_node(object, scope);
527            }
528            Node::SubscriptAccess { object, index } => {
529                self.check_node(object, scope);
530                self.check_node(index, scope);
531            }
532
533            // Terminals — nothing to check
534            _ => {}
535        }
536    }
537
538    fn check_fn_body(
539        &mut self,
540        params: &[TypedParam],
541        return_type: &Option<TypeExpr>,
542        body: &[SNode],
543    ) {
544        let mut fn_scope = self.scope.child();
545        for param in params {
546            fn_scope.define_var(&param.name, param.type_expr.clone());
547        }
548        self.check_block(body, &mut fn_scope);
549
550        // Check return statements against declared return type
551        if let Some(ret_type) = return_type {
552            for stmt in body {
553                self.check_return_type(stmt, ret_type, &fn_scope);
554            }
555        }
556    }
557
558    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
559        let span = snode.span;
560        match &snode.node {
561            Node::ReturnStmt { value: Some(val) } => {
562                let inferred = self.infer_type(val, scope);
563                if let Some(actual) = &inferred {
564                    if !self.types_compatible(expected, actual, scope) {
565                        self.error_at(
566                            format!(
567                                "Return type mismatch: expected {}, got {}",
568                                format_type(expected),
569                                format_type(actual)
570                            ),
571                            span,
572                        );
573                    }
574                }
575            }
576            Node::IfElse {
577                then_body,
578                else_body,
579                ..
580            } => {
581                for stmt in then_body {
582                    self.check_return_type(stmt, expected, scope);
583                }
584                if let Some(else_body) = else_body {
585                    for stmt in else_body {
586                        self.check_return_type(stmt, expected, scope);
587                    }
588                }
589            }
590            _ => {}
591        }
592    }
593
594    /// Check if a match expression on an enum's `.variant` property covers all variants.
595    fn check_match_exhaustiveness(
596        &mut self,
597        value: &SNode,
598        arms: &[MatchArm],
599        scope: &TypeScope,
600        span: Span,
601    ) {
602        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
603        let enum_name = match &value.node {
604            Node::PropertyAccess { object, property } if property == "variant" => {
605                // Infer the type of the object
606                match self.infer_type(object, scope) {
607                    Some(TypeExpr::Named(name)) => {
608                        if scope.get_enum(&name).is_some() {
609                            Some(name)
610                        } else {
611                            None
612                        }
613                    }
614                    _ => None,
615                }
616            }
617            _ => {
618                // Direct match on an enum value: match <expr> { ... }
619                match self.infer_type(value, scope) {
620                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
621                    _ => None,
622                }
623            }
624        };
625
626        let Some(enum_name) = enum_name else {
627            return;
628        };
629        let Some(variants) = scope.get_enum(&enum_name) else {
630            return;
631        };
632
633        // Collect variant names covered by match arms
634        let mut covered: Vec<String> = Vec::new();
635        let mut has_wildcard = false;
636
637        for arm in arms {
638            match &arm.pattern.node {
639                // String literal pattern (matching on .variant): "VariantA"
640                Node::StringLiteral(s) => covered.push(s.clone()),
641                // Identifier pattern acts as a wildcard/catch-all
642                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
643                    has_wildcard = true;
644                }
645                // Direct enum construct pattern: EnumName.Variant
646                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
647                // PropertyAccess pattern: EnumName.Variant (no args)
648                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
649                _ => {
650                    // Unknown pattern shape — conservatively treat as wildcard
651                    has_wildcard = true;
652                }
653            }
654        }
655
656        if has_wildcard {
657            return;
658        }
659
660        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
661        if !missing.is_empty() {
662            let missing_str = missing
663                .iter()
664                .map(|s| format!("\"{}\"", s))
665                .collect::<Vec<_>>()
666                .join(", ");
667            self.warning_at(
668                format!(
669                    "Non-exhaustive match on enum {}: missing variants {}",
670                    enum_name, missing_str
671                ),
672                span,
673            );
674        }
675    }
676
677    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
678        // Check against known function signatures
679        if let Some(sig) = scope.get_fn(name).cloned() {
680            if args.len() != sig.params.len() && !is_builtin(name) {
681                self.warning_at(
682                    format!(
683                        "Function '{}' expects {} arguments, got {}",
684                        name,
685                        sig.params.len(),
686                        args.len()
687                    ),
688                    span,
689                );
690            }
691            for (i, (arg, (param_name, param_type))) in
692                args.iter().zip(sig.params.iter()).enumerate()
693            {
694                if let Some(expected) = param_type {
695                    let actual = self.infer_type(arg, scope);
696                    if let Some(actual) = &actual {
697                        if !self.types_compatible(expected, actual, scope) {
698                            self.error_at(
699                                format!(
700                                    "Argument {} ('{}'): expected {}, got {}",
701                                    i + 1,
702                                    param_name,
703                                    format_type(expected),
704                                    format_type(actual)
705                                ),
706                                arg.span,
707                            );
708                        }
709                    }
710                }
711            }
712        }
713        // Check args recursively
714        for arg in args {
715            self.check_node(arg, scope);
716        }
717    }
718
719    /// Infer the type of an expression.
720    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
721        match &snode.node {
722            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
723            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
724            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
725                Some(TypeExpr::Named("string".into()))
726            }
727            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
728            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
729            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
730            Node::DictLiteral(_) => Some(TypeExpr::Named("dict".into())),
731            Node::Closure { .. } => Some(TypeExpr::Named("closure".into())),
732
733            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
734
735            Node::FunctionCall { name, .. } => {
736                // Check user-defined function return types
737                if let Some(sig) = scope.get_fn(name) {
738                    return sig.return_type.clone();
739                }
740                // Check builtin return types
741                builtin_return_type(name)
742            }
743
744            Node::BinaryOp { op, left, right } => {
745                let lt = self.infer_type(left, scope);
746                let rt = self.infer_type(right, scope);
747                infer_binary_op_type(op, &lt, &rt)
748            }
749
750            Node::UnaryOp { op, operand } => {
751                let t = self.infer_type(operand, scope);
752                match op.as_str() {
753                    "!" => Some(TypeExpr::Named("bool".into())),
754                    "-" => t, // negation preserves type
755                    _ => None,
756                }
757            }
758
759            Node::Ternary {
760                true_expr,
761                false_expr,
762                ..
763            } => {
764                let tt = self.infer_type(true_expr, scope);
765                let ft = self.infer_type(false_expr, scope);
766                match (&tt, &ft) {
767                    (Some(a), Some(b)) if a == b => tt,
768                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
769                    (Some(_), None) => tt,
770                    (None, Some(_)) => ft,
771                    (None, None) => None,
772                }
773            }
774
775            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
776
777            Node::PropertyAccess { object, property } => {
778                // EnumName.Variant → infer as the enum type
779                if let Node::Identifier(name) = &object.node {
780                    if scope.get_enum(name).is_some() {
781                        return Some(TypeExpr::Named(name.clone()));
782                    }
783                }
784                // .variant on an enum value → string
785                if property == "variant" {
786                    let obj_type = self.infer_type(object, scope);
787                    if let Some(TypeExpr::Named(name)) = &obj_type {
788                        if scope.get_enum(name).is_some() {
789                            return Some(TypeExpr::Named("string".into()));
790                        }
791                    }
792                }
793                None
794            }
795
796            Node::SubscriptAccess { object, .. } => {
797                let obj_type = self.infer_type(object, scope);
798                match &obj_type {
799                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
800                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
801                    Some(TypeExpr::Named(n)) if n == "list" => None,
802                    Some(TypeExpr::Named(n)) if n == "dict" => None,
803                    Some(TypeExpr::Named(n)) if n == "string" => {
804                        Some(TypeExpr::Named("string".into()))
805                    }
806                    _ => None,
807                }
808            }
809            Node::MethodCall { object, method, .. } => {
810                let obj_type = self.infer_type(object, scope);
811                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
812                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
813                match method.as_str() {
814                    // Shared: bool-returning methods
815                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
816                        Some(TypeExpr::Named("bool".into()))
817                    }
818                    // Shared: int-returning methods
819                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
820                    // String methods
821                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
822                    | "pad_left" | "pad_right" | "repeat" | "join" => {
823                        Some(TypeExpr::Named("string".into()))
824                    }
825                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
826                    // filter returns dict for dicts, list for lists
827                    "filter" => {
828                        if is_dict {
829                            Some(TypeExpr::Named("dict".into()))
830                        } else {
831                            Some(TypeExpr::Named("list".into()))
832                        }
833                    }
834                    // List methods
835                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
836                    "reduce" | "find" | "first" | "last" => None,
837                    // Dict methods
838                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
839                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
840                    // Conversions
841                    "to_string" => Some(TypeExpr::Named("string".into())),
842                    "to_int" => Some(TypeExpr::Named("int".into())),
843                    "to_float" => Some(TypeExpr::Named("float".into())),
844                    _ => None,
845                }
846            }
847
848            _ => None,
849        }
850    }
851
852    /// Check if two types are compatible (actual can be assigned to expected).
853    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
854        let expected = self.resolve_alias(expected, scope);
855        let actual = self.resolve_alias(actual, scope);
856
857        match (&expected, &actual) {
858            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
859            (TypeExpr::Union(members), actual_type) => members
860                .iter()
861                .any(|m| self.types_compatible(m, actual_type, scope)),
862            (expected_type, TypeExpr::Union(members)) => members
863                .iter()
864                .all(|m| self.types_compatible(expected_type, m, scope)),
865            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
866            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
867                if expected_field.optional {
868                    return true;
869                }
870                af.iter().any(|actual_field| {
871                    actual_field.name == expected_field.name
872                        && self.types_compatible(
873                            &expected_field.type_expr,
874                            &actual_field.type_expr,
875                            scope,
876                        )
877                })
878            }),
879            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
880                self.types_compatible(expected_inner, actual_inner, scope)
881            }
882            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
883            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
884            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
885                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
886            }
887            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
888            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
889            _ => false,
890        }
891    }
892
893    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
894        if let TypeExpr::Named(name) = ty {
895            if let Some(resolved) = scope.resolve_type(name) {
896                return resolved.clone();
897            }
898        }
899        ty.clone()
900    }
901
902    fn error_at(&mut self, message: String, span: Span) {
903        self.diagnostics.push(TypeDiagnostic {
904            message,
905            severity: DiagnosticSeverity::Error,
906            span: Some(span),
907        });
908    }
909
910    fn warning_at(&mut self, message: String, span: Span) {
911        self.diagnostics.push(TypeDiagnostic {
912            message,
913            severity: DiagnosticSeverity::Warning,
914            span: Some(span),
915        });
916    }
917}
918
919impl Default for TypeChecker {
920    fn default() -> Self {
921        Self::new()
922    }
923}
924
925/// Infer the result type of a binary operation.
926fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
927    match op {
928        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" => Some(TypeExpr::Named("bool".into())),
929        "+" => match (left, right) {
930            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
931                match (l.as_str(), r.as_str()) {
932                    ("int", "int") => Some(TypeExpr::Named("int".into())),
933                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
934                    ("string", _) => Some(TypeExpr::Named("string".into())),
935                    ("list", "list") => Some(TypeExpr::Named("list".into())),
936                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
937                    _ => Some(TypeExpr::Named("string".into())),
938                }
939            }
940            _ => None,
941        },
942        "-" | "*" | "/" | "%" => match (left, right) {
943            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
944                match (l.as_str(), r.as_str()) {
945                    ("int", "int") => Some(TypeExpr::Named("int".into())),
946                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
947                    _ => None,
948                }
949            }
950            _ => None,
951        },
952        "??" => match (left, right) {
953            (Some(TypeExpr::Union(members)), _) => {
954                let non_nil: Vec<_> = members
955                    .iter()
956                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
957                    .cloned()
958                    .collect();
959                if non_nil.len() == 1 {
960                    Some(non_nil[0].clone())
961                } else if non_nil.is_empty() {
962                    right.clone()
963                } else {
964                    Some(TypeExpr::Union(non_nil))
965                }
966            }
967            _ => right.clone(),
968        },
969        "|>" => None,
970        _ => None,
971    }
972}
973
974/// Format a type expression for display in error messages.
975pub fn format_type(ty: &TypeExpr) -> String {
976    match ty {
977        TypeExpr::Named(n) => n.clone(),
978        TypeExpr::Union(types) => types
979            .iter()
980            .map(format_type)
981            .collect::<Vec<_>>()
982            .join(" | "),
983        TypeExpr::Shape(fields) => {
984            let inner: Vec<String> = fields
985                .iter()
986                .map(|f| {
987                    let opt = if f.optional { "?" } else { "" };
988                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
989                })
990                .collect();
991            format!("{{{}}}", inner.join(", "))
992        }
993        TypeExpr::List(inner) => format!("list[{}]", format_type(inner)),
994        TypeExpr::DictType(k, v) => format!("dict[{}, {}]", format_type(k), format_type(v)),
995    }
996}
997
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parser;
    use harn_lexer::Lexer;

    // --- Shared helpers ---

    /// Lex, parse and type-check `source`, returning every diagnostic produced.
    ///
    /// Panics if lexing or parsing fails.
    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
        let mut lexer = Lexer::new(source);
        let tokens = lexer.tokenize().unwrap();
        let mut parser = Parser::new(tokens);
        let program = parser.parse().unwrap();
        TypeChecker::new().check(&program)
    }

    /// Messages of all diagnostics for `source` that have the given severity.
    fn messages_with(source: &str, severity: DiagnosticSeverity) -> Vec<String> {
        let mut messages = Vec::new();
        for diagnostic in check_source(source) {
            if diagnostic.severity == severity {
                messages.push(diagnostic.message);
            }
        }
        messages
    }

    /// Error messages reported for `source`.
    fn errors(source: &str) -> Vec<String> {
        messages_with(source, DiagnosticSeverity::Error)
    }

    /// Warning messages reported for `source`.
    fn warnings(source: &str) -> Vec<String> {
        messages_with(source, DiagnosticSeverity::Warning)
    }

    /// Warnings for `source` whose message mentions "Non-exhaustive".
    fn non_exhaustive_warnings(source: &str) -> Vec<String> {
        let mut flagged = warnings(source);
        flagged.retain(|message| message.contains("Non-exhaustive"));
        flagged
    }

    /// Assert that `source` type-checks without any error.
    #[track_caller]
    fn assert_no_errors(source: &str) {
        let errs = errors(source);
        assert!(errs.is_empty());
    }

    /// Assert that `source` produces exactly one error and hand back its message.
    #[track_caller]
    fn single_error(source: &str) -> String {
        let mut errs = errors(source);
        assert_eq!(errs.len(), 1);
        errs.remove(0)
    }

    // --- Basic annotation and inference tests ---

    #[test]
    fn test_no_errors_for_untyped_code() {
        assert_no_errors("pipeline t(task) { let x = 42\nlog(x) }");
    }

    #[test]
    fn test_correct_typed_let() {
        assert_no_errors("pipeline t(task) { let x: int = 42 }");
    }

    #[test]
    fn test_type_mismatch_let() {
        let message = single_error(r#"pipeline t(task) { let x: int = "hello" }"#);
        for needle in ["Type mismatch", "int", "string"] {
            assert!(message.contains(needle));
        }
    }

    #[test]
    fn test_correct_typed_fn() {
        assert_no_errors(
            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
        );
    }

    #[test]
    fn test_fn_arg_type_mismatch() {
        let message = single_error(
            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
        );
        assert!(message.contains("Argument 1"));
        assert!(message.contains("expected int"));
    }

    #[test]
    fn test_return_type_mismatch() {
        let message = single_error(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
        assert!(message.contains("Return type mismatch"));
    }

    #[test]
    fn test_union_type_compatible() {
        assert_no_errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
    }

    #[test]
    fn test_union_type_mismatch() {
        let message = single_error(r#"pipeline t(task) { let x: string | nil = 42 }"#);
        assert!(message.contains("Type mismatch"));
    }

    #[test]
    fn test_type_inference_propagation() {
        let message = single_error(
            r#"pipeline t(task) {
  fn add(a: int, b: int) -> int { return a + b }
  let result: string = add(1, 2)
}"#,
        );
        for needle in ["Type mismatch", "string", "int"] {
            assert!(message.contains(needle));
        }
    }

    #[test]
    fn test_builtin_return_type_inference() {
        let message = single_error(r#"pipeline t(task) { let x: string = to_int("42") }"#);
        assert!(message.contains("string"));
        assert!(message.contains("int"));
    }

    #[test]
    fn test_binary_op_type_inference() {
        single_error("pipeline t(task) { let x: string = 1 + 2 }");
    }

    #[test]
    fn test_comparison_returns_bool() {
        assert_no_errors("pipeline t(task) { let x: bool = 1 < 2 }");
    }

    #[test]
    fn test_int_float_promotion() {
        assert_no_errors("pipeline t(task) { let x: float = 42 }");
    }

    #[test]
    fn test_untyped_code_no_errors() {
        assert_no_errors(
            r#"pipeline t(task) {
  fn process(data) {
    let result = data + " processed"
    return result
  }
  log(process("hello"))
}"#,
        );
    }

    #[test]
    fn test_type_alias() {
        assert_no_errors(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = "hello"
}"#,
        );
    }

    #[test]
    fn test_type_alias_mismatch() {
        single_error(
            r#"pipeline t(task) {
  type Name = string
  let x: Name = 42
}"#,
        );
    }

    #[test]
    fn test_assignment_type_check() {
        let message = single_error(
            r#"pipeline t(task) {
  var x: int = 0
  x = "hello"
}"#,
        );
        assert!(message.contains("cannot assign string"));
    }

    #[test]
    fn test_covariance_int_to_float_in_fn() {
        assert_no_errors(
            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
        );
    }

    #[test]
    fn test_covariance_return_type() {
        assert_no_errors("pipeline t(task) { fn get() -> float { return 42 } }");
    }

    #[test]
    fn test_no_contravariance_float_to_int() {
        single_error("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
    }

    // --- Exhaustiveness checking tests ---

    #[test]
    fn test_exhaustive_match_no_warning() {
        let flagged = non_exhaustive_warnings(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
    "Blue" -> { log("b") }
  }
}"#,
        );
        assert!(flagged.is_empty());
    }

    #[test]
    fn test_non_exhaustive_match_warning() {
        let flagged = non_exhaustive_warnings(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c = Color.Red
  match c.variant {
    "Red" -> { log("r") }
    "Green" -> { log("g") }
  }
}"#,
        );
        assert_eq!(flagged.len(), 1);
        assert!(flagged[0].contains("Blue"));
    }

    #[test]
    fn test_non_exhaustive_multiple_missing() {
        let flagged = non_exhaustive_warnings(
            r#"pipeline t(task) {
  enum Status { Active, Inactive, Pending }
  let s = Status.Active
  match s.variant {
    "Active" -> { log("a") }
  }
}"#,
        );
        assert_eq!(flagged.len(), 1);
        assert!(flagged[0].contains("Inactive"));
        assert!(flagged[0].contains("Pending"));
    }

    #[test]
    fn test_enum_construct_type_inference() {
        assert_no_errors(
            r#"pipeline t(task) {
  enum Color { Red, Green, Blue }
  let c: Color = Color.Red
}"#,
        );
    }

    // --- Type narrowing tests ---

    #[test]
    fn test_nil_coalescing_strips_nil() {
        // After ??, nil should be stripped from the type
        assert_no_errors(
            r#"pipeline t(task) {
  let x: string | nil = nil
  let y: string = x ?? "default"
}"#,
        );
    }
}