Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use crate::builtin_signatures;
5use harn_lexer::Span;
6
7/// A diagnostic produced by the type checker.
8#[derive(Debug, Clone)]
9pub struct TypeDiagnostic {
10    pub message: String,
11    pub severity: DiagnosticSeverity,
12    pub span: Option<Span>,
13    pub help: Option<String>,
14}
15
/// Severity level attached to a [`TypeDiagnostic`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticSeverity {
    /// A type error.
    Error,
    /// A possible problem that is reported but not treated as an error.
    Warning,
}
21
22/// Inferred type of an expression. None means unknown/untyped (gradual typing).
23type InferredType = Option<TypeExpr>;
24
25/// Scope for tracking variable types.
26#[derive(Debug, Clone)]
27struct TypeScope {
28    /// Variable name → inferred type.
29    vars: BTreeMap<String, InferredType>,
30    /// Function name → (param types, return type).
31    functions: BTreeMap<String, FnSignature>,
32    /// Named type aliases.
33    type_aliases: BTreeMap<String, TypeExpr>,
34    /// Enum declarations: name → variant names.
35    enums: BTreeMap<String, Vec<String>>,
36    /// Interface declarations: name → method signatures.
37    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
38    /// Struct declarations: name → field types.
39    structs: BTreeMap<String, Vec<(String, InferredType)>>,
40    /// Impl block methods: type_name → method signatures.
41    impl_methods: BTreeMap<String, Vec<ImplMethodSig>>,
42    /// Generic type parameter names in scope (treated as compatible with any type).
43    generic_type_params: std::collections::BTreeSet<String>,
44    /// Where-clause constraints: type_param → interface_bound.
45    /// Used for definition-site checking of generic function bodies.
46    where_constraints: BTreeMap<String, String>,
47    parent: Option<Box<TypeScope>>,
48}
49
50/// Method signature extracted from an impl block (for interface checking).
51#[derive(Debug, Clone)]
52struct ImplMethodSig {
53    name: String,
54    /// Number of parameters excluding `self`.
55    param_count: usize,
56    /// Parameter types (excluding `self`), None means untyped.
57    param_types: Vec<Option<TypeExpr>>,
58    /// Return type, None means untyped.
59    return_type: Option<TypeExpr>,
60}
61
62#[derive(Debug, Clone)]
63struct FnSignature {
64    params: Vec<(String, InferredType)>,
65    return_type: InferredType,
66    /// Generic type parameter names declared on the function.
67    type_param_names: Vec<String>,
68    /// Number of required parameters (those without defaults).
69    required_params: usize,
70    /// Where-clause constraints: (type_param_name, interface_bound).
71    where_clauses: Vec<(String, String)>,
72}
73
74impl TypeScope {
75    fn new() -> Self {
76        Self {
77            vars: BTreeMap::new(),
78            functions: BTreeMap::new(),
79            type_aliases: BTreeMap::new(),
80            enums: BTreeMap::new(),
81            interfaces: BTreeMap::new(),
82            structs: BTreeMap::new(),
83            impl_methods: BTreeMap::new(),
84            generic_type_params: std::collections::BTreeSet::new(),
85            where_constraints: BTreeMap::new(),
86            parent: None,
87        }
88    }
89
90    fn child(&self) -> Self {
91        Self {
92            vars: BTreeMap::new(),
93            functions: BTreeMap::new(),
94            type_aliases: BTreeMap::new(),
95            enums: BTreeMap::new(),
96            interfaces: BTreeMap::new(),
97            structs: BTreeMap::new(),
98            impl_methods: BTreeMap::new(),
99            generic_type_params: std::collections::BTreeSet::new(),
100            where_constraints: BTreeMap::new(),
101            parent: Some(Box::new(self.clone())),
102        }
103    }
104
105    fn get_var(&self, name: &str) -> Option<&InferredType> {
106        self.vars
107            .get(name)
108            .or_else(|| self.parent.as_ref()?.get_var(name))
109    }
110
111    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
112        self.functions
113            .get(name)
114            .or_else(|| self.parent.as_ref()?.get_fn(name))
115    }
116
117    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
118        self.type_aliases
119            .get(name)
120            .or_else(|| self.parent.as_ref()?.resolve_type(name))
121    }
122
123    fn is_generic_type_param(&self, name: &str) -> bool {
124        self.generic_type_params.contains(name)
125            || self
126                .parent
127                .as_ref()
128                .is_some_and(|p| p.is_generic_type_param(name))
129    }
130
131    fn get_where_constraint(&self, type_param: &str) -> Option<&str> {
132        self.where_constraints
133            .get(type_param)
134            .map(|s| s.as_str())
135            .or_else(|| {
136                self.parent
137                    .as_ref()
138                    .and_then(|p| p.get_where_constraint(type_param))
139            })
140    }
141
142    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
143        self.enums
144            .get(name)
145            .or_else(|| self.parent.as_ref()?.get_enum(name))
146    }
147
148    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
149        self.interfaces
150            .get(name)
151            .or_else(|| self.parent.as_ref()?.get_interface(name))
152    }
153
154    fn get_struct(&self, name: &str) -> Option<&Vec<(String, InferredType)>> {
155        self.structs
156            .get(name)
157            .or_else(|| self.parent.as_ref()?.get_struct(name))
158    }
159
160    fn get_impl_methods(&self, name: &str) -> Option<&Vec<ImplMethodSig>> {
161        self.impl_methods
162            .get(name)
163            .or_else(|| self.parent.as_ref()?.get_impl_methods(name))
164    }
165
166    fn define_var(&mut self, name: &str, ty: InferredType) {
167        self.vars.insert(name.to_string(), ty);
168    }
169
170    fn define_fn(&mut self, name: &str, sig: FnSignature) {
171        self.functions.insert(name.to_string(), sig);
172    }
173}
174
175/// Known return types for builtin functions. Delegates to the shared
176/// [`builtin_signatures`] registry — see that module for the full table.
177fn builtin_return_type(name: &str) -> InferredType {
178    builtin_signatures::builtin_return_type(name)
179}
180
181/// Check if a name is a known builtin. Delegates to the shared
182/// [`builtin_signatures`] registry.
183fn is_builtin(name: &str) -> bool {
184    builtin_signatures::is_builtin(name)
185}
186
187/// The static type checker.
188pub struct TypeChecker {
189    diagnostics: Vec<TypeDiagnostic>,
190    scope: TypeScope,
191}
192
193impl TypeChecker {
194    pub fn new() -> Self {
195        Self {
196            diagnostics: Vec::new(),
197            scope: TypeScope::new(),
198        }
199    }
200
201    /// Check a program and return diagnostics.
202    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
203        // First pass: register type and enum declarations into root scope
204        Self::register_declarations_into(&mut self.scope, program);
205
206        // Also scan pipeline bodies for declarations
207        for snode in program {
208            if let Node::Pipeline { body, .. } = &snode.node {
209                Self::register_declarations_into(&mut self.scope, body);
210            }
211        }
212
213        // Check each top-level node
214        for snode in program {
215            match &snode.node {
216                Node::Pipeline { params, body, .. } => {
217                    let mut child = self.scope.child();
218                    for p in params {
219                        child.define_var(p, None);
220                    }
221                    self.check_block(body, &mut child);
222                }
223                Node::FnDecl {
224                    name,
225                    type_params,
226                    params,
227                    return_type,
228                    where_clauses,
229                    body,
230                    ..
231                } => {
232                    let required_params =
233                        params.iter().filter(|p| p.default_value.is_none()).count();
234                    let sig = FnSignature {
235                        params: params
236                            .iter()
237                            .map(|p| (p.name.clone(), p.type_expr.clone()))
238                            .collect(),
239                        return_type: return_type.clone(),
240                        type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
241                        required_params,
242                        where_clauses: where_clauses
243                            .iter()
244                            .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
245                            .collect(),
246                    };
247                    self.scope.define_fn(name, sig);
248                    self.check_fn_body(type_params, params, return_type, body, where_clauses);
249                }
250                _ => {
251                    let mut scope = self.scope.clone();
252                    self.check_node(snode, &mut scope);
253                    // Merge any new definitions back into the top-level scope
254                    for (name, ty) in scope.vars {
255                        self.scope.vars.entry(name).or_insert(ty);
256                    }
257                }
258            }
259        }
260
261        self.diagnostics
262    }
263
264    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
265    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
266        for snode in nodes {
267            match &snode.node {
268                Node::TypeDecl { name, type_expr } => {
269                    scope.type_aliases.insert(name.clone(), type_expr.clone());
270                }
271                Node::EnumDecl { name, variants, .. } => {
272                    let variant_names: Vec<String> =
273                        variants.iter().map(|v| v.name.clone()).collect();
274                    scope.enums.insert(name.clone(), variant_names);
275                }
276                Node::InterfaceDecl { name, methods, .. } => {
277                    scope.interfaces.insert(name.clone(), methods.clone());
278                }
279                Node::StructDecl { name, fields, .. } => {
280                    let field_types: Vec<(String, InferredType)> = fields
281                        .iter()
282                        .map(|f| (f.name.clone(), f.type_expr.clone()))
283                        .collect();
284                    scope.structs.insert(name.clone(), field_types);
285                }
286                Node::ImplBlock {
287                    type_name, methods, ..
288                } => {
289                    let sigs: Vec<ImplMethodSig> = methods
290                        .iter()
291                        .filter_map(|m| {
292                            if let Node::FnDecl {
293                                name,
294                                params,
295                                return_type,
296                                ..
297                            } = &m.node
298                            {
299                                let non_self: Vec<_> =
300                                    params.iter().filter(|p| p.name != "self").collect();
301                                let param_count = non_self.len();
302                                let param_types: Vec<Option<TypeExpr>> =
303                                    non_self.iter().map(|p| p.type_expr.clone()).collect();
304                                Some(ImplMethodSig {
305                                    name: name.clone(),
306                                    param_count,
307                                    param_types,
308                                    return_type: return_type.clone(),
309                                })
310                            } else {
311                                None
312                            }
313                        })
314                        .collect();
315                    scope.impl_methods.insert(type_name.clone(), sigs);
316                }
317                _ => {}
318            }
319        }
320    }
321
322    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
323        for stmt in stmts {
324            self.check_node(stmt, scope);
325        }
326    }
327
328    /// Define variables from a destructuring pattern in the given scope (as unknown type).
329    fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
330        match pattern {
331            BindingPattern::Identifier(name) => {
332                scope.define_var(name, None);
333            }
334            BindingPattern::Dict(fields) => {
335                for field in fields {
336                    let name = field.alias.as_deref().unwrap_or(&field.key);
337                    scope.define_var(name, None);
338                }
339            }
340            BindingPattern::List(elements) => {
341                for elem in elements {
342                    scope.define_var(&elem.name, None);
343                }
344            }
345        }
346    }
347
348    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
349        let span = snode.span;
350        match &snode.node {
351            Node::LetBinding {
352                pattern,
353                type_ann,
354                value,
355            } => {
356                let inferred = self.infer_type(value, scope);
357                if let BindingPattern::Identifier(name) = pattern {
358                    if let Some(expected) = type_ann {
359                        if let Some(actual) = &inferred {
360                            if !self.types_compatible(expected, actual, scope) {
361                                let mut msg = format!(
362                                    "Type mismatch: '{}' declared as {}, but assigned {}",
363                                    name,
364                                    format_type(expected),
365                                    format_type(actual)
366                                );
367                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
368                                    msg.push_str(&format!(" ({})", detail));
369                                }
370                                self.error_at(msg, span);
371                            }
372                        }
373                    }
374                    let ty = type_ann.clone().or(inferred);
375                    scope.define_var(name, ty);
376                } else {
377                    Self::define_pattern_vars(pattern, scope);
378                }
379            }
380
381            Node::VarBinding {
382                pattern,
383                type_ann,
384                value,
385            } => {
386                let inferred = self.infer_type(value, scope);
387                if let BindingPattern::Identifier(name) = pattern {
388                    if let Some(expected) = type_ann {
389                        if let Some(actual) = &inferred {
390                            if !self.types_compatible(expected, actual, scope) {
391                                let mut msg = format!(
392                                    "Type mismatch: '{}' declared as {}, but assigned {}",
393                                    name,
394                                    format_type(expected),
395                                    format_type(actual)
396                                );
397                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
398                                    msg.push_str(&format!(" ({})", detail));
399                                }
400                                self.error_at(msg, span);
401                            }
402                        }
403                    }
404                    let ty = type_ann.clone().or(inferred);
405                    scope.define_var(name, ty);
406                } else {
407                    Self::define_pattern_vars(pattern, scope);
408                }
409            }
410
411            Node::FnDecl {
412                name,
413                type_params,
414                params,
415                return_type,
416                where_clauses,
417                body,
418                ..
419            } => {
420                let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
421                let sig = FnSignature {
422                    params: params
423                        .iter()
424                        .map(|p| (p.name.clone(), p.type_expr.clone()))
425                        .collect(),
426                    return_type: return_type.clone(),
427                    type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
428                    required_params,
429                    where_clauses: where_clauses
430                        .iter()
431                        .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
432                        .collect(),
433                };
434                scope.define_fn(name, sig.clone());
435                scope.define_var(name, None);
436                self.check_fn_body(type_params, params, return_type, body, where_clauses);
437            }
438
439            Node::FunctionCall { name, args } => {
440                self.check_call(name, args, scope, span);
441            }
442
443            Node::IfElse {
444                condition,
445                then_body,
446                else_body,
447            } => {
448                self.check_node(condition, scope);
449                let mut then_scope = scope.child();
450                // Narrow union types after nil checks: if x != nil, narrow x
451                if let Some((var_name, narrowed)) = Self::extract_nil_narrowing(condition, scope) {
452                    then_scope.define_var(&var_name, narrowed);
453                }
454                self.check_block(then_body, &mut then_scope);
455                if let Some(else_body) = else_body {
456                    let mut else_scope = scope.child();
457                    self.check_block(else_body, &mut else_scope);
458                }
459            }
460
461            Node::ForIn {
462                pattern,
463                iterable,
464                body,
465            } => {
466                self.check_node(iterable, scope);
467                let mut loop_scope = scope.child();
468                if let BindingPattern::Identifier(variable) = pattern {
469                    // Infer loop variable type from iterable
470                    let elem_type = match self.infer_type(iterable, scope) {
471                        Some(TypeExpr::List(inner)) => Some(*inner),
472                        Some(TypeExpr::Named(n)) if n == "string" => {
473                            Some(TypeExpr::Named("string".into()))
474                        }
475                        _ => None,
476                    };
477                    loop_scope.define_var(variable, elem_type);
478                } else {
479                    Self::define_pattern_vars(pattern, &mut loop_scope);
480                }
481                self.check_block(body, &mut loop_scope);
482            }
483
484            Node::WhileLoop { condition, body } => {
485                self.check_node(condition, scope);
486                let mut loop_scope = scope.child();
487                self.check_block(body, &mut loop_scope);
488            }
489
490            Node::RequireStmt { condition, message } => {
491                self.check_node(condition, scope);
492                if let Some(message) = message {
493                    self.check_node(message, scope);
494                }
495            }
496
497            Node::TryCatch {
498                body,
499                error_var,
500                catch_body,
501                finally_body,
502                ..
503            } => {
504                let mut try_scope = scope.child();
505                self.check_block(body, &mut try_scope);
506                let mut catch_scope = scope.child();
507                if let Some(var) = error_var {
508                    catch_scope.define_var(var, None);
509                }
510                self.check_block(catch_body, &mut catch_scope);
511                if let Some(fb) = finally_body {
512                    let mut finally_scope = scope.child();
513                    self.check_block(fb, &mut finally_scope);
514                }
515            }
516
517            Node::TryExpr { body } => {
518                let mut try_scope = scope.child();
519                self.check_block(body, &mut try_scope);
520            }
521
522            Node::ReturnStmt {
523                value: Some(val), ..
524            } => {
525                self.check_node(val, scope);
526            }
527
528            Node::Assignment {
529                target, value, op, ..
530            } => {
531                self.check_node(value, scope);
532                if let Node::Identifier(name) = &target.node {
533                    if let Some(Some(var_type)) = scope.get_var(name) {
534                        let value_type = self.infer_type(value, scope);
535                        let assigned = if let Some(op) = op {
536                            let var_inferred = scope.get_var(name).cloned().flatten();
537                            infer_binary_op_type(op, &var_inferred, &value_type)
538                        } else {
539                            value_type
540                        };
541                        if let Some(actual) = &assigned {
542                            if !self.types_compatible(var_type, actual, scope) {
543                                self.error_at(
544                                    format!(
545                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
546                                        format_type(actual),
547                                        name,
548                                        format_type(var_type)
549                                    ),
550                                    span,
551                                );
552                            }
553                        }
554                    }
555                }
556            }
557
558            Node::TypeDecl { name, type_expr } => {
559                scope.type_aliases.insert(name.clone(), type_expr.clone());
560            }
561
562            Node::EnumDecl { name, variants, .. } => {
563                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
564                scope.enums.insert(name.clone(), variant_names);
565            }
566
567            Node::StructDecl { name, fields, .. } => {
568                let field_types: Vec<(String, InferredType)> = fields
569                    .iter()
570                    .map(|f| (f.name.clone(), f.type_expr.clone()))
571                    .collect();
572                scope.structs.insert(name.clone(), field_types);
573            }
574
575            Node::InterfaceDecl { name, methods, .. } => {
576                scope.interfaces.insert(name.clone(), methods.clone());
577            }
578
579            Node::ImplBlock {
580                type_name, methods, ..
581            } => {
582                // Register impl methods for interface satisfaction checking
583                let sigs: Vec<ImplMethodSig> = methods
584                    .iter()
585                    .filter_map(|m| {
586                        if let Node::FnDecl {
587                            name,
588                            params,
589                            return_type,
590                            ..
591                        } = &m.node
592                        {
593                            let non_self: Vec<_> =
594                                params.iter().filter(|p| p.name != "self").collect();
595                            let param_count = non_self.len();
596                            let param_types: Vec<Option<TypeExpr>> =
597                                non_self.iter().map(|p| p.type_expr.clone()).collect();
598                            Some(ImplMethodSig {
599                                name: name.clone(),
600                                param_count,
601                                param_types,
602                                return_type: return_type.clone(),
603                            })
604                        } else {
605                            None
606                        }
607                    })
608                    .collect();
609                scope.impl_methods.insert(type_name.clone(), sigs);
610                for method_sn in methods {
611                    self.check_node(method_sn, scope);
612                }
613            }
614
615            Node::TryOperator { operand } => {
616                self.check_node(operand, scope);
617            }
618
619            Node::MatchExpr { value, arms } => {
620                self.check_node(value, scope);
621                let value_type = self.infer_type(value, scope);
622                for arm in arms {
623                    self.check_node(&arm.pattern, scope);
624                    // Check for incompatible literal pattern types
625                    if let Some(ref vt) = value_type {
626                        let value_type_name = format_type(vt);
627                        let mismatch = match &arm.pattern.node {
628                            Node::StringLiteral(_) => {
629                                !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
630                            }
631                            Node::IntLiteral(_) => {
632                                !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
633                                    && !self.types_compatible(
634                                        vt,
635                                        &TypeExpr::Named("float".into()),
636                                        scope,
637                                    )
638                            }
639                            Node::FloatLiteral(_) => {
640                                !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
641                                    && !self.types_compatible(
642                                        vt,
643                                        &TypeExpr::Named("int".into()),
644                                        scope,
645                                    )
646                            }
647                            Node::BoolLiteral(_) => {
648                                !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
649                            }
650                            _ => false,
651                        };
652                        if mismatch {
653                            let pattern_type = match &arm.pattern.node {
654                                Node::StringLiteral(_) => "string",
655                                Node::IntLiteral(_) => "int",
656                                Node::FloatLiteral(_) => "float",
657                                Node::BoolLiteral(_) => "bool",
658                                _ => unreachable!(),
659                            };
660                            self.warning_at(
661                                format!(
662                                    "Match pattern type mismatch: matching {} against {} literal",
663                                    value_type_name, pattern_type
664                                ),
665                                arm.pattern.span,
666                            );
667                        }
668                    }
669                    let mut arm_scope = scope.child();
670                    self.check_block(&arm.body, &mut arm_scope);
671                }
672                self.check_match_exhaustiveness(value, arms, scope, span);
673            }
674
675            // Recurse into nested expressions + validate binary op types
676            Node::BinaryOp { op, left, right } => {
677                self.check_node(left, scope);
678                self.check_node(right, scope);
679                // Validate operator/type compatibility
680                let lt = self.infer_type(left, scope);
681                let rt = self.infer_type(right, scope);
682                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
683                    match op.as_str() {
684                        "-" | "*" | "/" | "%" => {
685                            let numeric = ["int", "float"];
686                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
687                                self.warning_at(
688                                    format!(
689                                        "Operator '{op}' may not be valid for types {} and {}",
690                                        l, r
691                                    ),
692                                    span,
693                                );
694                            }
695                        }
696                        "+" => {
697                            // + is valid for int, float, string, list, dict
698                            let valid = ["int", "float", "string", "list", "dict"];
699                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
700                                self.warning_at(
701                                    format!(
702                                        "Operator '+' may not be valid for types {} and {}",
703                                        l, r
704                                    ),
705                                    span,
706                                );
707                            }
708                        }
709                        _ => {}
710                    }
711                }
712            }
713            Node::UnaryOp { operand, .. } => {
714                self.check_node(operand, scope);
715            }
716            Node::MethodCall {
717                object,
718                method,
719                args,
720                ..
721            }
722            | Node::OptionalMethodCall {
723                object,
724                method,
725                args,
726                ..
727            } => {
728                self.check_node(object, scope);
729                for arg in args {
730                    self.check_node(arg, scope);
731                }
732                // Definition-site generic checking: if the object's type is a
733                // constrained generic param (where T: Interface), verify the
734                // method exists in the bound interface.
735                if let Some(TypeExpr::Named(type_name)) = self.infer_type(object, scope) {
736                    if scope.is_generic_type_param(&type_name) {
737                        if let Some(iface_name) = scope.get_where_constraint(&type_name) {
738                            if let Some(iface_methods) = scope.get_interface(iface_name) {
739                                let has_method = iface_methods.iter().any(|m| m.name == *method);
740                                if !has_method {
741                                    self.warning_at(
742                                        format!(
743                                            "Method '{}' not found in interface '{}' (constraint on '{}')",
744                                            method, iface_name, type_name
745                                        ),
746                                        span,
747                                    );
748                                }
749                            }
750                        }
751                    }
752                }
753            }
754            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
755                self.check_node(object, scope);
756            }
757            Node::SubscriptAccess { object, index } => {
758                self.check_node(object, scope);
759                self.check_node(index, scope);
760            }
761            Node::SliceAccess { object, start, end } => {
762                self.check_node(object, scope);
763                if let Some(s) = start {
764                    self.check_node(s, scope);
765                }
766                if let Some(e) = end {
767                    self.check_node(e, scope);
768                }
769            }
770
771            // --- Compound nodes: recurse into children ---
772            Node::Ternary {
773                condition,
774                true_expr,
775                false_expr,
776            } => {
777                self.check_node(condition, scope);
778                self.check_node(true_expr, scope);
779                self.check_node(false_expr, scope);
780            }
781
782            Node::ThrowStmt { value } => {
783                self.check_node(value, scope);
784            }
785
786            Node::GuardStmt {
787                condition,
788                else_body,
789            } => {
790                self.check_node(condition, scope);
791                let mut else_scope = scope.child();
792                self.check_block(else_body, &mut else_scope);
793            }
794
795            Node::SpawnExpr { body } => {
796                let mut spawn_scope = scope.child();
797                self.check_block(body, &mut spawn_scope);
798            }
799
800            Node::Parallel {
801                count,
802                variable,
803                body,
804            } => {
805                self.check_node(count, scope);
806                let mut par_scope = scope.child();
807                if let Some(var) = variable {
808                    par_scope.define_var(var, Some(TypeExpr::Named("int".into())));
809                }
810                self.check_block(body, &mut par_scope);
811            }
812
813            Node::ParallelMap {
814                list,
815                variable,
816                body,
817            }
818            | Node::ParallelSettle {
819                list,
820                variable,
821                body,
822            } => {
823                self.check_node(list, scope);
824                let mut par_scope = scope.child();
825                let elem_type = match self.infer_type(list, scope) {
826                    Some(TypeExpr::List(inner)) => Some(*inner),
827                    _ => None,
828                };
829                par_scope.define_var(variable, elem_type);
830                self.check_block(body, &mut par_scope);
831            }
832
833            Node::SelectExpr {
834                cases,
835                timeout,
836                default_body,
837            } => {
838                for case in cases {
839                    self.check_node(&case.channel, scope);
840                    let mut case_scope = scope.child();
841                    case_scope.define_var(&case.variable, None);
842                    self.check_block(&case.body, &mut case_scope);
843                }
844                if let Some((dur, body)) = timeout {
845                    self.check_node(dur, scope);
846                    let mut timeout_scope = scope.child();
847                    self.check_block(body, &mut timeout_scope);
848                }
849                if let Some(body) = default_body {
850                    let mut default_scope = scope.child();
851                    self.check_block(body, &mut default_scope);
852                }
853            }
854
855            Node::DeadlineBlock { duration, body } => {
856                self.check_node(duration, scope);
857                let mut block_scope = scope.child();
858                self.check_block(body, &mut block_scope);
859            }
860
861            Node::MutexBlock { body } => {
862                let mut block_scope = scope.child();
863                self.check_block(body, &mut block_scope);
864            }
865
866            Node::Retry { count, body } => {
867                self.check_node(count, scope);
868                let mut retry_scope = scope.child();
869                self.check_block(body, &mut retry_scope);
870            }
871
872            Node::Closure { params, body, .. } => {
873                let mut closure_scope = scope.child();
874                for p in params {
875                    closure_scope.define_var(&p.name, p.type_expr.clone());
876                }
877                self.check_block(body, &mut closure_scope);
878            }
879
880            Node::ListLiteral(elements) => {
881                for elem in elements {
882                    self.check_node(elem, scope);
883                }
884            }
885
886            Node::DictLiteral(entries) | Node::AskExpr { fields: entries } => {
887                for entry in entries {
888                    self.check_node(&entry.key, scope);
889                    self.check_node(&entry.value, scope);
890                }
891            }
892
893            Node::RangeExpr { start, end, .. } => {
894                self.check_node(start, scope);
895                self.check_node(end, scope);
896            }
897
898            Node::Spread(inner) => {
899                self.check_node(inner, scope);
900            }
901
902            Node::Block(stmts) => {
903                let mut block_scope = scope.child();
904                self.check_block(stmts, &mut block_scope);
905            }
906
907            Node::YieldExpr { value } => {
908                if let Some(v) = value {
909                    self.check_node(v, scope);
910                }
911            }
912
913            // --- Struct construction: validate fields against declaration ---
914            Node::StructConstruct {
915                struct_name,
916                fields,
917            } => {
918                for entry in fields {
919                    self.check_node(&entry.key, scope);
920                    self.check_node(&entry.value, scope);
921                }
922                if let Some(declared_fields) = scope.get_struct(struct_name).cloned() {
923                    // Warn on unknown fields
924                    for entry in fields {
925                        if let Node::StringLiteral(key) | Node::Identifier(key) = &entry.key.node {
926                            if !declared_fields.iter().any(|(name, _)| name == key) {
927                                self.warning_at(
928                                    format!("Unknown field '{}' in struct '{}'", key, struct_name),
929                                    entry.key.span,
930                                );
931                            }
932                        }
933                    }
934                    // Warn on missing required fields
935                    let provided: Vec<String> = fields
936                        .iter()
937                        .filter_map(|e| match &e.key.node {
938                            Node::StringLiteral(k) | Node::Identifier(k) => Some(k.clone()),
939                            _ => None,
940                        })
941                        .collect();
942                    for (name, _) in &declared_fields {
943                        if !provided.contains(name) {
944                            self.warning_at(
945                                format!(
946                                    "Missing field '{}' in struct '{}' construction",
947                                    name, struct_name
948                                ),
949                                span,
950                            );
951                        }
952                    }
953                }
954            }
955
956            // --- Enum construction: validate variant exists ---
957            Node::EnumConstruct {
958                enum_name,
959                variant,
960                args,
961            } => {
962                for arg in args {
963                    self.check_node(arg, scope);
964                }
965                if let Some(variants) = scope.get_enum(enum_name) {
966                    if !variants.contains(variant) {
967                        self.warning_at(
968                            format!("Unknown variant '{}' in enum '{}'", variant, enum_name),
969                            span,
970                        );
971                    }
972                }
973            }
974
975            // --- InterpolatedString: segments are lexer-level, no SNode children ---
976            Node::InterpolatedString(_) => {}
977
978            // --- Terminals: no children to check ---
979            Node::StringLiteral(_)
980            | Node::IntLiteral(_)
981            | Node::FloatLiteral(_)
982            | Node::BoolLiteral(_)
983            | Node::NilLiteral
984            | Node::Identifier(_)
985            | Node::DurationLiteral(_)
986            | Node::BreakStmt
987            | Node::ContinueStmt
988            | Node::ReturnStmt { value: None }
989            | Node::ImportDecl { .. }
990            | Node::SelectiveImport { .. } => {}
991
992            // Declarations already handled above; catch remaining variants
993            // that have no meaningful type-check behavior.
994            Node::Pipeline { body, .. } | Node::OverrideDecl { body, .. } => {
995                let mut decl_scope = scope.child();
996                self.check_block(body, &mut decl_scope);
997            }
998        }
999    }
1000
1001    fn check_fn_body(
1002        &mut self,
1003        type_params: &[TypeParam],
1004        params: &[TypedParam],
1005        return_type: &Option<TypeExpr>,
1006        body: &[SNode],
1007        where_clauses: &[WhereClause],
1008    ) {
1009        let mut fn_scope = self.scope.child();
1010        // Register generic type parameters so they are treated as compatible
1011        // with any concrete type during type checking.
1012        for tp in type_params {
1013            fn_scope.generic_type_params.insert(tp.name.clone());
1014        }
1015        // Store where-clause constraints for definition-site checking
1016        for wc in where_clauses {
1017            fn_scope
1018                .where_constraints
1019                .insert(wc.type_name.clone(), wc.bound.clone());
1020        }
1021        for param in params {
1022            fn_scope.define_var(&param.name, param.type_expr.clone());
1023            if let Some(default) = &param.default_value {
1024                self.check_node(default, &mut fn_scope);
1025            }
1026        }
1027        self.check_block(body, &mut fn_scope);
1028
1029        // Check return statements against declared return type
1030        if let Some(ret_type) = return_type {
1031            for stmt in body {
1032                self.check_return_type(stmt, ret_type, &fn_scope);
1033            }
1034        }
1035    }
1036
1037    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
1038        let span = snode.span;
1039        match &snode.node {
1040            Node::ReturnStmt { value: Some(val) } => {
1041                let inferred = self.infer_type(val, scope);
1042                if let Some(actual) = &inferred {
1043                    if !self.types_compatible(expected, actual, scope) {
1044                        self.error_at(
1045                            format!(
1046                                "Return type mismatch: expected {}, got {}",
1047                                format_type(expected),
1048                                format_type(actual)
1049                            ),
1050                            span,
1051                        );
1052                    }
1053                }
1054            }
1055            Node::IfElse {
1056                then_body,
1057                else_body,
1058                ..
1059            } => {
1060                for stmt in then_body {
1061                    self.check_return_type(stmt, expected, scope);
1062                }
1063                if let Some(else_body) = else_body {
1064                    for stmt in else_body {
1065                        self.check_return_type(stmt, expected, scope);
1066                    }
1067                }
1068            }
1069            _ => {}
1070        }
1071    }
1072
1073    /// Check if a match expression on an enum's `.variant` property covers all variants.
1074    /// Extract narrowing info from nil-check conditions like `x != nil`.
1075    /// Returns (var_name, narrowed_type) where narrowed_type removes nil from a union.
1076    /// Check if a type satisfies an interface (Go-style implicit satisfaction).
1077    /// A type satisfies an interface if its impl block has all the required methods.
1078    fn satisfies_interface(
1079        &self,
1080        type_name: &str,
1081        interface_name: &str,
1082        scope: &TypeScope,
1083    ) -> bool {
1084        self.interface_mismatch_reason(type_name, interface_name, scope)
1085            .is_none()
1086    }
1087
1088    /// Return a detailed reason why a type does not satisfy an interface, or None
1089    /// if it does satisfy it.  Used for producing actionable warning messages.
1090    fn interface_mismatch_reason(
1091        &self,
1092        type_name: &str,
1093        interface_name: &str,
1094        scope: &TypeScope,
1095    ) -> Option<String> {
1096        let interface_methods = match scope.get_interface(interface_name) {
1097            Some(methods) => methods,
1098            None => return Some(format!("interface '{}' not found", interface_name)),
1099        };
1100        let impl_methods = match scope.get_impl_methods(type_name) {
1101            Some(methods) => methods,
1102            None => {
1103                if interface_methods.is_empty() {
1104                    return None;
1105                }
1106                let names: Vec<_> = interface_methods.iter().map(|m| m.name.as_str()).collect();
1107                return Some(format!("missing method(s): {}", names.join(", ")));
1108            }
1109        };
1110        for iface_method in interface_methods {
1111            let iface_params: Vec<_> = iface_method
1112                .params
1113                .iter()
1114                .filter(|p| p.name != "self")
1115                .collect();
1116            let iface_param_count = iface_params.len();
1117            let matching_impl = impl_methods.iter().find(|im| im.name == iface_method.name);
1118            let impl_method = match matching_impl {
1119                Some(m) => m,
1120                None => {
1121                    return Some(format!("missing method '{}'", iface_method.name));
1122                }
1123            };
1124            if impl_method.param_count != iface_param_count {
1125                return Some(format!(
1126                    "method '{}' has {} parameter(s), expected {}",
1127                    iface_method.name, impl_method.param_count, iface_param_count
1128                ));
1129            }
1130            // Check parameter types where both sides specify them
1131            for (i, iface_param) in iface_params.iter().enumerate() {
1132                if let (Some(expected), Some(actual)) = (
1133                    &iface_param.type_expr,
1134                    impl_method.param_types.get(i).and_then(|t| t.as_ref()),
1135                ) {
1136                    if !self.types_compatible(expected, actual, scope) {
1137                        return Some(format!(
1138                            "method '{}' parameter {} has type '{}', expected '{}'",
1139                            iface_method.name,
1140                            i + 1,
1141                            format_type(actual),
1142                            format_type(expected),
1143                        ));
1144                    }
1145                }
1146            }
1147            // Check return type where both sides specify it
1148            if let (Some(expected_ret), Some(actual_ret)) =
1149                (&iface_method.return_type, &impl_method.return_type)
1150            {
1151                if !self.types_compatible(expected_ret, actual_ret, scope) {
1152                    return Some(format!(
1153                        "method '{}' returns '{}', expected '{}'",
1154                        iface_method.name,
1155                        format_type(actual_ret),
1156                        format_type(expected_ret),
1157                    ));
1158                }
1159            }
1160        }
1161        None
1162    }
1163
1164    /// Recursively extract type parameter bindings from matching param/arg types.
1165    /// E.g., param_type=list<T> + arg_type=list<Dog> → binds T=Dog.
1166    fn extract_type_bindings(
1167        param_type: &TypeExpr,
1168        arg_type: &TypeExpr,
1169        type_params: &std::collections::BTreeSet<String>,
1170        bindings: &mut BTreeMap<String, String>,
1171    ) {
1172        match (param_type, arg_type) {
1173            // Direct type param match: T → concrete
1174            (TypeExpr::Named(param_name), TypeExpr::Named(concrete))
1175                if type_params.contains(param_name) =>
1176            {
1177                bindings
1178                    .entry(param_name.clone())
1179                    .or_insert(concrete.clone());
1180            }
1181            // list<T> + list<Dog>
1182            (TypeExpr::List(p_inner), TypeExpr::List(a_inner)) => {
1183                Self::extract_type_bindings(p_inner, a_inner, type_params, bindings);
1184            }
1185            // dict<K, V> + dict<string, int>
1186            (TypeExpr::DictType(pk, pv), TypeExpr::DictType(ak, av)) => {
1187                Self::extract_type_bindings(pk, ak, type_params, bindings);
1188                Self::extract_type_bindings(pv, av, type_params, bindings);
1189            }
1190            _ => {}
1191        }
1192    }
1193
1194    fn extract_nil_narrowing(
1195        condition: &SNode,
1196        scope: &TypeScope,
1197    ) -> Option<(String, InferredType)> {
1198        if let Node::BinaryOp { op, left, right } = &condition.node {
1199            if op == "!=" {
1200                // Check for `x != nil` or `nil != x`
1201                let (var_node, nil_node) = if matches!(right.node, Node::NilLiteral) {
1202                    (left, right)
1203                } else if matches!(left.node, Node::NilLiteral) {
1204                    (right, left)
1205                } else {
1206                    return None;
1207                };
1208                let _ = nil_node;
1209                if let Node::Identifier(name) = &var_node.node {
1210                    // Look up the variable's type and narrow it
1211                    if let Some(Some(TypeExpr::Union(members))) = scope.get_var(name) {
1212                        let narrowed: Vec<TypeExpr> = members
1213                            .iter()
1214                            .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1215                            .cloned()
1216                            .collect();
1217                        return if narrowed.len() == 1 {
1218                            Some((name.clone(), Some(narrowed.into_iter().next().unwrap())))
1219                        } else if narrowed.is_empty() {
1220                            None
1221                        } else {
1222                            Some((name.clone(), Some(TypeExpr::Union(narrowed))))
1223                        };
1224                    }
1225                }
1226            }
1227        }
1228        None
1229    }
1230
1231    fn check_match_exhaustiveness(
1232        &mut self,
1233        value: &SNode,
1234        arms: &[MatchArm],
1235        scope: &TypeScope,
1236        span: Span,
1237    ) {
1238        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
1239        let enum_name = match &value.node {
1240            Node::PropertyAccess { object, property } if property == "variant" => {
1241                // Infer the type of the object
1242                match self.infer_type(object, scope) {
1243                    Some(TypeExpr::Named(name)) => {
1244                        if scope.get_enum(&name).is_some() {
1245                            Some(name)
1246                        } else {
1247                            None
1248                        }
1249                    }
1250                    _ => None,
1251                }
1252            }
1253            _ => {
1254                // Direct match on an enum value: match <expr> { ... }
1255                match self.infer_type(value, scope) {
1256                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
1257                    _ => None,
1258                }
1259            }
1260        };
1261
1262        let Some(enum_name) = enum_name else {
1263            return;
1264        };
1265        let Some(variants) = scope.get_enum(&enum_name) else {
1266            return;
1267        };
1268
1269        // Collect variant names covered by match arms
1270        let mut covered: Vec<String> = Vec::new();
1271        let mut has_wildcard = false;
1272
1273        for arm in arms {
1274            match &arm.pattern.node {
1275                // String literal pattern (matching on .variant): "VariantA"
1276                Node::StringLiteral(s) => covered.push(s.clone()),
1277                // Identifier pattern acts as a wildcard/catch-all
1278                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
1279                    has_wildcard = true;
1280                }
1281                // Direct enum construct pattern: EnumName.Variant
1282                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
1283                // PropertyAccess pattern: EnumName.Variant (no args)
1284                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
1285                _ => {
1286                    // Unknown pattern shape — conservatively treat as wildcard
1287                    has_wildcard = true;
1288                }
1289            }
1290        }
1291
1292        if has_wildcard {
1293            return;
1294        }
1295
1296        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
1297        if !missing.is_empty() {
1298            let missing_str = missing
1299                .iter()
1300                .map(|s| format!("\"{}\"", s))
1301                .collect::<Vec<_>>()
1302                .join(", ");
1303            self.warning_at(
1304                format!(
1305                    "Non-exhaustive match on enum {}: missing variants {}",
1306                    enum_name, missing_str
1307                ),
1308                span,
1309            );
1310        }
1311    }
1312
1313    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
1314        // Check against known function signatures
1315        let has_spread = args.iter().any(|a| matches!(&a.node, Node::Spread(_)));
1316        if let Some(sig) = scope.get_fn(name).cloned() {
1317            if !has_spread
1318                && !is_builtin(name)
1319                && (args.len() < sig.required_params || args.len() > sig.params.len())
1320            {
1321                let expected = if sig.required_params == sig.params.len() {
1322                    format!("{}", sig.params.len())
1323                } else {
1324                    format!("{}-{}", sig.required_params, sig.params.len())
1325                };
1326                self.warning_at(
1327                    format!(
1328                        "Function '{}' expects {} arguments, got {}",
1329                        name,
1330                        expected,
1331                        args.len()
1332                    ),
1333                    span,
1334                );
1335            }
1336            // Build a scope that includes the function's generic type params
1337            // so they are treated as compatible with any concrete type.
1338            let call_scope = if sig.type_param_names.is_empty() {
1339                scope.clone()
1340            } else {
1341                let mut s = scope.child();
1342                for tp_name in &sig.type_param_names {
1343                    s.generic_type_params.insert(tp_name.clone());
1344                }
1345                s
1346            };
1347            for (i, (arg, (param_name, param_type))) in
1348                args.iter().zip(sig.params.iter()).enumerate()
1349            {
1350                if let Some(expected) = param_type {
1351                    let actual = self.infer_type(arg, scope);
1352                    if let Some(actual) = &actual {
1353                        if !self.types_compatible(expected, actual, &call_scope) {
1354                            self.error_at(
1355                                format!(
1356                                    "Argument {} ('{}'): expected {}, got {}",
1357                                    i + 1,
1358                                    param_name,
1359                                    format_type(expected),
1360                                    format_type(actual)
1361                                ),
1362                                arg.span,
1363                            );
1364                        }
1365                    }
1366                }
1367            }
1368            // Enforce where-clause constraints at call site
1369            if !sig.where_clauses.is_empty() {
1370                // Build mapping: type_param → concrete type from inferred args.
1371                // Recursively walks Generic types so list<T> + list<Dog> binds T=Dog.
1372                let mut type_bindings: BTreeMap<String, String> = BTreeMap::new();
1373                let type_param_set: std::collections::BTreeSet<String> =
1374                    sig.type_param_names.iter().cloned().collect();
1375                for (arg, (_param_name, param_type)) in args.iter().zip(sig.params.iter()) {
1376                    if let Some(param_ty) = param_type {
1377                        if let Some(arg_ty) = self.infer_type(arg, scope) {
1378                            Self::extract_type_bindings(
1379                                param_ty,
1380                                &arg_ty,
1381                                &type_param_set,
1382                                &mut type_bindings,
1383                            );
1384                        }
1385                    }
1386                }
1387                for (type_param, bound) in &sig.where_clauses {
1388                    if let Some(concrete_type) = type_bindings.get(type_param) {
1389                        if let Some(reason) =
1390                            self.interface_mismatch_reason(concrete_type, bound, scope)
1391                        {
1392                            self.warning_at(
1393                                format!(
1394                                    "Type '{}' does not satisfy interface '{}': {} \
1395                                     (required by constraint `where {}: {}`)",
1396                                    concrete_type, bound, reason, type_param, bound
1397                                ),
1398                                span,
1399                            );
1400                        }
1401                    }
1402                }
1403            }
1404        }
1405        // Check args recursively
1406        for arg in args {
1407            self.check_node(arg, scope);
1408        }
1409    }
1410
1411    /// Infer the type of an expression.
1412    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
1413        match &snode.node {
1414            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
1415            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
1416            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
1417                Some(TypeExpr::Named("string".into()))
1418            }
1419            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
1420            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
1421            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
1422            Node::DictLiteral(entries) => {
1423                // Infer shape type when all keys are string literals
1424                let mut fields = Vec::new();
1425                let mut all_string_keys = true;
1426                for entry in entries {
1427                    if let Node::StringLiteral(key) = &entry.key.node {
1428                        let val_type = self
1429                            .infer_type(&entry.value, scope)
1430                            .unwrap_or(TypeExpr::Named("nil".into()));
1431                        fields.push(ShapeField {
1432                            name: key.clone(),
1433                            type_expr: val_type,
1434                            optional: false,
1435                        });
1436                    } else {
1437                        all_string_keys = false;
1438                        break;
1439                    }
1440                }
1441                if all_string_keys && !fields.is_empty() {
1442                    Some(TypeExpr::Shape(fields))
1443                } else {
1444                    Some(TypeExpr::Named("dict".into()))
1445                }
1446            }
1447            Node::Closure { params, body, .. } => {
1448                // If all params are typed and we can infer a return type, produce FnType
1449                let all_typed = params.iter().all(|p| p.type_expr.is_some());
1450                if all_typed && !params.is_empty() {
1451                    let param_types: Vec<TypeExpr> =
1452                        params.iter().filter_map(|p| p.type_expr.clone()).collect();
1453                    // Try to infer return type from last expression in body
1454                    let ret = body.last().and_then(|last| self.infer_type(last, scope));
1455                    if let Some(ret_type) = ret {
1456                        return Some(TypeExpr::FnType {
1457                            params: param_types,
1458                            return_type: Box::new(ret_type),
1459                        });
1460                    }
1461                }
1462                Some(TypeExpr::Named("closure".into()))
1463            }
1464
1465            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
1466
1467            Node::FunctionCall { name, .. } => {
1468                // Struct constructor calls return the struct type
1469                if scope.get_struct(name).is_some() {
1470                    return Some(TypeExpr::Named(name.clone()));
1471                }
1472                // Check user-defined function return types
1473                if let Some(sig) = scope.get_fn(name) {
1474                    return sig.return_type.clone();
1475                }
1476                // Check builtin return types
1477                builtin_return_type(name)
1478            }
1479
1480            Node::BinaryOp { op, left, right } => {
1481                let lt = self.infer_type(left, scope);
1482                let rt = self.infer_type(right, scope);
1483                infer_binary_op_type(op, &lt, &rt)
1484            }
1485
1486            Node::UnaryOp { op, operand } => {
1487                let t = self.infer_type(operand, scope);
1488                match op.as_str() {
1489                    "!" => Some(TypeExpr::Named("bool".into())),
1490                    "-" => t, // negation preserves type
1491                    _ => None,
1492                }
1493            }
1494
1495            Node::Ternary {
1496                true_expr,
1497                false_expr,
1498                ..
1499            } => {
1500                let tt = self.infer_type(true_expr, scope);
1501                let ft = self.infer_type(false_expr, scope);
1502                match (&tt, &ft) {
1503                    (Some(a), Some(b)) if a == b => tt,
1504                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
1505                    (Some(_), None) => tt,
1506                    (None, Some(_)) => ft,
1507                    (None, None) => None,
1508                }
1509            }
1510
1511            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1512
1513            Node::PropertyAccess { object, property } => {
1514                // EnumName.Variant → infer as the enum type
1515                if let Node::Identifier(name) = &object.node {
1516                    if scope.get_enum(name).is_some() {
1517                        return Some(TypeExpr::Named(name.clone()));
1518                    }
1519                }
1520                // .variant on an enum value → string
1521                if property == "variant" {
1522                    let obj_type = self.infer_type(object, scope);
1523                    if let Some(TypeExpr::Named(name)) = &obj_type {
1524                        if scope.get_enum(name).is_some() {
1525                            return Some(TypeExpr::Named("string".into()));
1526                        }
1527                    }
1528                }
1529                // Shape field access: obj.field → field type
1530                let obj_type = self.infer_type(object, scope);
1531                if let Some(TypeExpr::Shape(fields)) = &obj_type {
1532                    if let Some(field) = fields.iter().find(|f| f.name == *property) {
1533                        return Some(field.type_expr.clone());
1534                    }
1535                }
1536                None
1537            }
1538
1539            Node::SubscriptAccess { object, index } => {
1540                let obj_type = self.infer_type(object, scope);
1541                match &obj_type {
1542                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1543                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1544                    Some(TypeExpr::Shape(fields)) => {
1545                        // If index is a string literal, look up the field type
1546                        if let Node::StringLiteral(key) = &index.node {
1547                            fields
1548                                .iter()
1549                                .find(|f| &f.name == key)
1550                                .map(|f| f.type_expr.clone())
1551                        } else {
1552                            None
1553                        }
1554                    }
1555                    Some(TypeExpr::Named(n)) if n == "list" => None,
1556                    Some(TypeExpr::Named(n)) if n == "dict" => None,
1557                    Some(TypeExpr::Named(n)) if n == "string" => {
1558                        Some(TypeExpr::Named("string".into()))
1559                    }
1560                    _ => None,
1561                }
1562            }
1563            Node::SliceAccess { object, .. } => {
1564                // Slicing a list returns the same list type; slicing a string returns string
1565                let obj_type = self.infer_type(object, scope);
1566                match &obj_type {
1567                    Some(TypeExpr::List(_)) => obj_type,
1568                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1569                    Some(TypeExpr::Named(n)) if n == "string" => {
1570                        Some(TypeExpr::Named("string".into()))
1571                    }
1572                    _ => None,
1573                }
1574            }
1575            Node::MethodCall { object, method, .. }
1576            | Node::OptionalMethodCall { object, method, .. } => {
1577                let obj_type = self.infer_type(object, scope);
1578                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1579                    || matches!(&obj_type, Some(TypeExpr::DictType(..)))
1580                    || matches!(&obj_type, Some(TypeExpr::Shape(_)));
1581                match method.as_str() {
1582                    // Shared: bool-returning methods
1583                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1584                        Some(TypeExpr::Named("bool".into()))
1585                    }
1586                    // Shared: int-returning methods
1587                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1588                    // String methods
1589                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1590                    | "pad_left" | "pad_right" | "repeat" | "join" => {
1591                        Some(TypeExpr::Named("string".into()))
1592                    }
1593                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
1594                    // filter returns dict for dicts, list for lists
1595                    "filter" => {
1596                        if is_dict {
1597                            Some(TypeExpr::Named("dict".into()))
1598                        } else {
1599                            Some(TypeExpr::Named("list".into()))
1600                        }
1601                    }
1602                    // List methods
1603                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1604                    "reduce" | "find" | "first" | "last" => None,
1605                    // Dict methods
1606                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1607                    "merge" | "map_values" | "rekey" | "map_keys" => {
1608                        // Rekey/map_keys transform keys; resulting dict still keys-by-string.
1609                        // Preserve the value-type parameter when known so downstream code can
1610                        // still rely on dict<string, V> typing after a key-rename.
1611                        if let Some(TypeExpr::DictType(_, v)) = &obj_type {
1612                            Some(TypeExpr::DictType(
1613                                Box::new(TypeExpr::Named("string".into())),
1614                                v.clone(),
1615                            ))
1616                        } else {
1617                            Some(TypeExpr::Named("dict".into()))
1618                        }
1619                    }
1620                    // Conversions
1621                    "to_string" => Some(TypeExpr::Named("string".into())),
1622                    "to_int" => Some(TypeExpr::Named("int".into())),
1623                    "to_float" => Some(TypeExpr::Named("float".into())),
1624                    _ => None,
1625                }
1626            }
1627
1628            // TryOperator on Result<T, E> produces T
1629            Node::TryOperator { operand } => {
1630                match self.infer_type(operand, scope) {
1631                    Some(TypeExpr::Named(name)) if name == "Result" => None, // unknown inner type
1632                    _ => None,
1633                }
1634            }
1635
1636            _ => None,
1637        }
1638    }
1639
1640    /// Check if two types are compatible (actual can be assigned to expected).
1641    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1642        // Generic type parameters match anything.
1643        if let TypeExpr::Named(name) = expected {
1644            if scope.is_generic_type_param(name) {
1645                return true;
1646            }
1647        }
1648        if let TypeExpr::Named(name) = actual {
1649            if scope.is_generic_type_param(name) {
1650                return true;
1651            }
1652        }
1653        let expected = self.resolve_alias(expected, scope);
1654        let actual = self.resolve_alias(actual, scope);
1655
1656        // Interface satisfaction: if expected is an interface name, check if actual type
1657        // has all required methods (Go-style implicit satisfaction).
1658        if let TypeExpr::Named(iface_name) = &expected {
1659            if scope.get_interface(iface_name).is_some() {
1660                if let TypeExpr::Named(type_name) = &actual {
1661                    return self.satisfies_interface(type_name, iface_name, scope);
1662                }
1663                return false;
1664            }
1665        }
1666
1667        match (&expected, &actual) {
1668            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1669            (TypeExpr::Union(members), actual_type) => members
1670                .iter()
1671                .any(|m| self.types_compatible(m, actual_type, scope)),
1672            (expected_type, TypeExpr::Union(members)) => members
1673                .iter()
1674                .all(|m| self.types_compatible(expected_type, m, scope)),
1675            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1676            (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1677            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1678                if expected_field.optional {
1679                    return true;
1680                }
1681                af.iter().any(|actual_field| {
1682                    actual_field.name == expected_field.name
1683                        && self.types_compatible(
1684                            &expected_field.type_expr,
1685                            &actual_field.type_expr,
1686                            scope,
1687                        )
1688                })
1689            }),
1690            // dict<K, V> expected, Shape actual → all field values must match V
1691            (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1692                let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1693                keys_ok
1694                    && af
1695                        .iter()
1696                        .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1697            }
1698            // Shape expected, dict<K, V> actual → gradual: allow since dict may have the fields
1699            (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1700            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1701                self.types_compatible(expected_inner, actual_inner, scope)
1702            }
1703            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1704            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1705            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1706                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1707            }
1708            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1709            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1710            // FnType compatibility: params match positionally and return types match
1711            (
1712                TypeExpr::FnType {
1713                    params: ep,
1714                    return_type: er,
1715                },
1716                TypeExpr::FnType {
1717                    params: ap,
1718                    return_type: ar,
1719                },
1720            ) => {
1721                ep.len() == ap.len()
1722                    && ep
1723                        .iter()
1724                        .zip(ap.iter())
1725                        .all(|(e, a)| self.types_compatible(e, a, scope))
1726                    && self.types_compatible(er, ar, scope)
1727            }
1728            // FnType is compatible with Named("closure") for backward compat
1729            (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1730            (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1731            _ => false,
1732        }
1733    }
1734
1735    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1736        if let TypeExpr::Named(name) = ty {
1737            if let Some(resolved) = scope.resolve_type(name) {
1738                return resolved.clone();
1739            }
1740        }
1741        ty.clone()
1742    }
1743
1744    fn error_at(&mut self, message: String, span: Span) {
1745        self.diagnostics.push(TypeDiagnostic {
1746            message,
1747            severity: DiagnosticSeverity::Error,
1748            span: Some(span),
1749            help: None,
1750        });
1751    }
1752
1753    #[allow(dead_code)]
1754    fn error_at_with_help(&mut self, message: String, span: Span, help: String) {
1755        self.diagnostics.push(TypeDiagnostic {
1756            message,
1757            severity: DiagnosticSeverity::Error,
1758            span: Some(span),
1759            help: Some(help),
1760        });
1761    }
1762
1763    fn warning_at(&mut self, message: String, span: Span) {
1764        self.diagnostics.push(TypeDiagnostic {
1765            message,
1766            severity: DiagnosticSeverity::Warning,
1767            span: Some(span),
1768            help: None,
1769        });
1770    }
1771
1772    #[allow(dead_code)]
1773    fn warning_at_with_help(&mut self, message: String, span: Span, help: String) {
1774        self.diagnostics.push(TypeDiagnostic {
1775            message,
1776            severity: DiagnosticSeverity::Warning,
1777            span: Some(span),
1778            help: Some(help),
1779        });
1780    }
1781}
1782
1783impl Default for TypeChecker {
1784    fn default() -> Self {
1785        Self::new()
1786    }
1787}
1788
1789/// Infer the result type of a binary operation.
1790fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1791    match op {
1792        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" | "in" | "not_in" => {
1793            Some(TypeExpr::Named("bool".into()))
1794        }
1795        "+" => match (left, right) {
1796            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1797                match (l.as_str(), r.as_str()) {
1798                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1799                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1800                    ("string", _) => Some(TypeExpr::Named("string".into())),
1801                    ("list", "list") => Some(TypeExpr::Named("list".into())),
1802                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1803                    _ => Some(TypeExpr::Named("string".into())),
1804                }
1805            }
1806            _ => None,
1807        },
1808        "-" | "*" | "/" | "%" => match (left, right) {
1809            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1810                match (l.as_str(), r.as_str()) {
1811                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1812                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1813                    _ => None,
1814                }
1815            }
1816            _ => None,
1817        },
1818        "??" => match (left, right) {
1819            (Some(TypeExpr::Union(members)), _) => {
1820                let non_nil: Vec<_> = members
1821                    .iter()
1822                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1823                    .cloned()
1824                    .collect();
1825                if non_nil.len() == 1 {
1826                    Some(non_nil[0].clone())
1827                } else if non_nil.is_empty() {
1828                    right.clone()
1829                } else {
1830                    Some(TypeExpr::Union(non_nil))
1831                }
1832            }
1833            _ => right.clone(),
1834        },
1835        "|>" => None,
1836        _ => None,
1837    }
1838}
1839
1840/// Format a type expression for display in error messages.
1841/// Produce a detail string describing why a Shape type is incompatible with
1842/// another Shape type — e.g. "missing field 'age' (int)" or "field 'name'
1843/// has type int, expected string".  Returns `None` if both types are not shapes.
1844pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1845    if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1846        let mut details = Vec::new();
1847        for field in ef {
1848            if field.optional {
1849                continue;
1850            }
1851            match af.iter().find(|f| f.name == field.name) {
1852                None => details.push(format!(
1853                    "missing field '{}' ({})",
1854                    field.name,
1855                    format_type(&field.type_expr)
1856                )),
1857                Some(actual_field) => {
1858                    let e_str = format_type(&field.type_expr);
1859                    let a_str = format_type(&actual_field.type_expr);
1860                    if e_str != a_str {
1861                        details.push(format!(
1862                            "field '{}' has type {}, expected {}",
1863                            field.name, a_str, e_str
1864                        ));
1865                    }
1866                }
1867            }
1868        }
1869        if details.is_empty() {
1870            None
1871        } else {
1872            Some(details.join("; "))
1873        }
1874    } else {
1875        None
1876    }
1877}
1878
1879pub fn format_type(ty: &TypeExpr) -> String {
1880    match ty {
1881        TypeExpr::Named(n) => n.clone(),
1882        TypeExpr::Union(types) => types
1883            .iter()
1884            .map(format_type)
1885            .collect::<Vec<_>>()
1886            .join(" | "),
1887        TypeExpr::Shape(fields) => {
1888            let inner: Vec<String> = fields
1889                .iter()
1890                .map(|f| {
1891                    let opt = if f.optional { "?" } else { "" };
1892                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1893                })
1894                .collect();
1895            format!("{{{}}}", inner.join(", "))
1896        }
1897        TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1898        TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1899        TypeExpr::FnType {
1900            params,
1901            return_type,
1902        } => {
1903            let params_str = params
1904                .iter()
1905                .map(format_type)
1906                .collect::<Vec<_>>()
1907                .join(", ");
1908            format!("fn({}) -> {}", params_str, format_type(return_type))
1909        }
1910    }
1911}
1912
1913#[cfg(test)]
1914mod tests {
1915    use super::*;
1916    use crate::Parser;
1917    use harn_lexer::Lexer;
1918
1919    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
1920        let mut lexer = Lexer::new(source);
1921        let tokens = lexer.tokenize().unwrap();
1922        let mut parser = Parser::new(tokens);
1923        let program = parser.parse().unwrap();
1924        TypeChecker::new().check(&program)
1925    }
1926
1927    fn errors(source: &str) -> Vec<String> {
1928        check_source(source)
1929            .into_iter()
1930            .filter(|d| d.severity == DiagnosticSeverity::Error)
1931            .map(|d| d.message)
1932            .collect()
1933    }
1934
1935    #[test]
1936    fn test_no_errors_for_untyped_code() {
1937        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
1938        assert!(errs.is_empty());
1939    }
1940
1941    #[test]
1942    fn test_correct_typed_let() {
1943        let errs = errors("pipeline t(task) { let x: int = 42 }");
1944        assert!(errs.is_empty());
1945    }
1946
1947    #[test]
1948    fn test_type_mismatch_let() {
1949        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
1950        assert_eq!(errs.len(), 1);
1951        assert!(errs[0].contains("Type mismatch"));
1952        assert!(errs[0].contains("int"));
1953        assert!(errs[0].contains("string"));
1954    }
1955
1956    #[test]
1957    fn test_correct_typed_fn() {
1958        let errs = errors(
1959            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
1960        );
1961        assert!(errs.is_empty());
1962    }
1963
1964    #[test]
1965    fn test_fn_arg_type_mismatch() {
1966        let errs = errors(
1967            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
1968add("hello", 2) }"#,
1969        );
1970        assert_eq!(errs.len(), 1);
1971        assert!(errs[0].contains("Argument 1"));
1972        assert!(errs[0].contains("expected int"));
1973    }
1974
1975    #[test]
1976    fn test_return_type_mismatch() {
1977        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
1978        assert_eq!(errs.len(), 1);
1979        assert!(errs[0].contains("Return type mismatch"));
1980    }
1981
1982    #[test]
1983    fn test_union_type_compatible() {
1984        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
1985        assert!(errs.is_empty());
1986    }
1987
1988    #[test]
1989    fn test_union_type_mismatch() {
1990        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
1991        assert_eq!(errs.len(), 1);
1992        assert!(errs[0].contains("Type mismatch"));
1993    }
1994
1995    #[test]
1996    fn test_type_inference_propagation() {
1997        let errs = errors(
1998            r#"pipeline t(task) {
1999  fn add(a: int, b: int) -> int { return a + b }
2000  let result: string = add(1, 2)
2001}"#,
2002        );
2003        assert_eq!(errs.len(), 1);
2004        assert!(errs[0].contains("Type mismatch"));
2005        assert!(errs[0].contains("string"));
2006        assert!(errs[0].contains("int"));
2007    }
2008
2009    #[test]
2010    fn test_builtin_return_type_inference() {
2011        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
2012        assert_eq!(errs.len(), 1);
2013        assert!(errs[0].contains("string"));
2014        assert!(errs[0].contains("int"));
2015    }
2016
2017    #[test]
2018    fn test_workflow_and_transcript_builtins_are_known() {
2019        let errs = errors(
2020            r#"pipeline t(task) {
2021  let flow = workflow_graph({name: "demo", entry: "act", nodes: {act: {kind: "stage"}}})
2022  let report: dict = workflow_policy_report(flow, {tools: tool_registry(), capabilities: {workspace: ["read_text"]}})
2023  let run: dict = workflow_execute("task", flow, [], {})
2024  let tree: dict = load_run_tree("run.json")
2025  let fixture: dict = run_record_fixture(run?.run)
2026  let suite: dict = run_record_eval_suite([{run: run?.run, fixture: fixture}])
2027  let diff: dict = run_record_diff(run?.run, run?.run)
2028  let manifest: dict = eval_suite_manifest({cases: [{run_path: "run.json"}]})
2029  let suite_report: dict = eval_suite_run(manifest)
2030  let wf: dict = artifact_workspace_file("src/main.rs", "fn main() {}", {source: "host"})
2031  let snap: dict = artifact_workspace_snapshot(["src/main.rs"], "snapshot")
2032  let selection: dict = artifact_editor_selection("src/main.rs", "main")
2033  let verify: dict = artifact_verification_result("verify", "ok")
2034  let test_result: dict = artifact_test_result("tests", "pass")
2035  let cmd: dict = artifact_command_result("cargo test", {status: 0})
2036  let patch: dict = artifact_diff("src/main.rs", "old", "new")
2037  let git: dict = artifact_git_diff("diff --git a b")
2038  let review: dict = artifact_diff_review(patch, "review me")
2039  let decision: dict = artifact_review_decision(review, "accepted")
2040  let proposal: dict = artifact_patch_proposal(review, "*** Begin Patch")
2041  let bundle: dict = artifact_verification_bundle("checks", [{name: "fmt", ok: true}])
2042  let apply: dict = artifact_apply_intent(review, "apply")
2043  let transcript = transcript_reset({metadata: {source: "test"}})
2044  let visible: string = transcript_render_visible(transcript_archive(transcript))
2045  let events: list = transcript_events(transcript)
2046  let context: string = artifact_context([], {max_artifacts: 1})
2047  println(report)
2048  println(run)
2049  println(tree)
2050  println(fixture)
2051  println(suite)
2052  println(diff)
2053  println(manifest)
2054  println(suite_report)
2055  println(wf)
2056  println(snap)
2057  println(selection)
2058  println(verify)
2059  println(test_result)
2060  println(cmd)
2061  println(patch)
2062  println(git)
2063  println(review)
2064  println(decision)
2065  println(proposal)
2066  println(bundle)
2067  println(apply)
2068  println(visible)
2069  println(events)
2070  println(context)
2071}"#,
2072        );
2073        assert!(errs.is_empty(), "unexpected type errors: {errs:?}");
2074    }
2075
2076    #[test]
2077    fn test_binary_op_type_inference() {
2078        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
2079        assert_eq!(errs.len(), 1);
2080    }
2081
2082    #[test]
2083    fn test_comparison_returns_bool() {
2084        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
2085        assert!(errs.is_empty());
2086    }
2087
2088    #[test]
2089    fn test_int_float_promotion() {
2090        let errs = errors("pipeline t(task) { let x: float = 42 }");
2091        assert!(errs.is_empty());
2092    }
2093
2094    #[test]
2095    fn test_untyped_code_no_errors() {
2096        let errs = errors(
2097            r#"pipeline t(task) {
2098  fn process(data) {
2099    let result = data + " processed"
2100    return result
2101  }
2102  log(process("hello"))
2103}"#,
2104        );
2105        assert!(errs.is_empty());
2106    }
2107
2108    #[test]
2109    fn test_type_alias() {
2110        let errs = errors(
2111            r#"pipeline t(task) {
2112  type Name = string
2113  let x: Name = "hello"
2114}"#,
2115        );
2116        assert!(errs.is_empty());
2117    }
2118
2119    #[test]
2120    fn test_type_alias_mismatch() {
2121        let errs = errors(
2122            r#"pipeline t(task) {
2123  type Name = string
2124  let x: Name = 42
2125}"#,
2126        );
2127        assert_eq!(errs.len(), 1);
2128    }
2129
2130    #[test]
2131    fn test_assignment_type_check() {
2132        let errs = errors(
2133            r#"pipeline t(task) {
2134  var x: int = 0
2135  x = "hello"
2136}"#,
2137        );
2138        assert_eq!(errs.len(), 1);
2139        assert!(errs[0].contains("cannot assign string"));
2140    }
2141
2142    #[test]
2143    fn test_covariance_int_to_float_in_fn() {
2144        let errs = errors(
2145            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
2146        );
2147        assert!(errs.is_empty());
2148    }
2149
2150    #[test]
2151    fn test_covariance_return_type() {
2152        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
2153        assert!(errs.is_empty());
2154    }
2155
2156    #[test]
2157    fn test_no_contravariance_float_to_int() {
2158        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
2159        assert_eq!(errs.len(), 1);
2160    }
2161
2162    // --- Exhaustiveness checking tests ---
2163
2164    fn warnings(source: &str) -> Vec<String> {
2165        check_source(source)
2166            .into_iter()
2167            .filter(|d| d.severity == DiagnosticSeverity::Warning)
2168            .map(|d| d.message)
2169            .collect()
2170    }
2171
2172    #[test]
2173    fn test_exhaustive_match_no_warning() {
2174        let warns = warnings(
2175            r#"pipeline t(task) {
2176  enum Color { Red, Green, Blue }
2177  let c = Color.Red
2178  match c.variant {
2179    "Red" -> { log("r") }
2180    "Green" -> { log("g") }
2181    "Blue" -> { log("b") }
2182  }
2183}"#,
2184        );
2185        let exhaustive_warns: Vec<_> = warns
2186            .iter()
2187            .filter(|w| w.contains("Non-exhaustive"))
2188            .collect();
2189        assert!(exhaustive_warns.is_empty());
2190    }
2191
2192    #[test]
2193    fn test_non_exhaustive_match_warning() {
2194        let warns = warnings(
2195            r#"pipeline t(task) {
2196  enum Color { Red, Green, Blue }
2197  let c = Color.Red
2198  match c.variant {
2199    "Red" -> { log("r") }
2200    "Green" -> { log("g") }
2201  }
2202}"#,
2203        );
2204        let exhaustive_warns: Vec<_> = warns
2205            .iter()
2206            .filter(|w| w.contains("Non-exhaustive"))
2207            .collect();
2208        assert_eq!(exhaustive_warns.len(), 1);
2209        assert!(exhaustive_warns[0].contains("Blue"));
2210    }
2211
2212    #[test]
2213    fn test_non_exhaustive_multiple_missing() {
2214        let warns = warnings(
2215            r#"pipeline t(task) {
2216  enum Status { Active, Inactive, Pending }
2217  let s = Status.Active
2218  match s.variant {
2219    "Active" -> { log("a") }
2220  }
2221}"#,
2222        );
2223        let exhaustive_warns: Vec<_> = warns
2224            .iter()
2225            .filter(|w| w.contains("Non-exhaustive"))
2226            .collect();
2227        assert_eq!(exhaustive_warns.len(), 1);
2228        assert!(exhaustive_warns[0].contains("Inactive"));
2229        assert!(exhaustive_warns[0].contains("Pending"));
2230    }
2231
2232    #[test]
2233    fn test_enum_construct_type_inference() {
2234        let errs = errors(
2235            r#"pipeline t(task) {
2236  enum Color { Red, Green, Blue }
2237  let c: Color = Color.Red
2238}"#,
2239        );
2240        assert!(errs.is_empty());
2241    }
2242
2243    // --- Type narrowing tests ---
2244
2245    #[test]
2246    fn test_nil_coalescing_strips_nil() {
2247        // After ??, nil should be stripped from the type
2248        let errs = errors(
2249            r#"pipeline t(task) {
2250  let x: string | nil = nil
2251  let y: string = x ?? "default"
2252}"#,
2253        );
2254        assert!(errs.is_empty());
2255    }
2256
2257    #[test]
2258    fn test_shape_mismatch_detail_missing_field() {
2259        let errs = errors(
2260            r#"pipeline t(task) {
2261  let x: {name: string, age: int} = {name: "hello"}
2262}"#,
2263        );
2264        assert_eq!(errs.len(), 1);
2265        assert!(
2266            errs[0].contains("missing field 'age'"),
2267            "expected detail about missing field, got: {}",
2268            errs[0]
2269        );
2270    }
2271
2272    #[test]
2273    fn test_shape_mismatch_detail_wrong_type() {
2274        let errs = errors(
2275            r#"pipeline t(task) {
2276  let x: {name: string, age: int} = {name: 42, age: 10}
2277}"#,
2278        );
2279        assert_eq!(errs.len(), 1);
2280        assert!(
2281            errs[0].contains("field 'name' has type int, expected string"),
2282            "expected detail about wrong type, got: {}",
2283            errs[0]
2284        );
2285    }
2286
2287    // --- Match pattern type validation tests ---
2288
2289    #[test]
2290    fn test_match_pattern_string_against_int() {
2291        let warns = warnings(
2292            r#"pipeline t(task) {
2293  let x: int = 42
2294  match x {
2295    "hello" -> { log("bad") }
2296    42 -> { log("ok") }
2297  }
2298}"#,
2299        );
2300        let pattern_warns: Vec<_> = warns
2301            .iter()
2302            .filter(|w| w.contains("Match pattern type mismatch"))
2303            .collect();
2304        assert_eq!(pattern_warns.len(), 1);
2305        assert!(pattern_warns[0].contains("matching int against string literal"));
2306    }
2307
2308    #[test]
2309    fn test_match_pattern_int_against_string() {
2310        let warns = warnings(
2311            r#"pipeline t(task) {
2312  let x: string = "hello"
2313  match x {
2314    42 -> { log("bad") }
2315    "hello" -> { log("ok") }
2316  }
2317}"#,
2318        );
2319        let pattern_warns: Vec<_> = warns
2320            .iter()
2321            .filter(|w| w.contains("Match pattern type mismatch"))
2322            .collect();
2323        assert_eq!(pattern_warns.len(), 1);
2324        assert!(pattern_warns[0].contains("matching string against int literal"));
2325    }
2326
2327    #[test]
2328    fn test_match_pattern_bool_against_int() {
2329        let warns = warnings(
2330            r#"pipeline t(task) {
2331  let x: int = 42
2332  match x {
2333    true -> { log("bad") }
2334    42 -> { log("ok") }
2335  }
2336}"#,
2337        );
2338        let pattern_warns: Vec<_> = warns
2339            .iter()
2340            .filter(|w| w.contains("Match pattern type mismatch"))
2341            .collect();
2342        assert_eq!(pattern_warns.len(), 1);
2343        assert!(pattern_warns[0].contains("matching int against bool literal"));
2344    }
2345
2346    #[test]
2347    fn test_match_pattern_float_against_string() {
2348        let warns = warnings(
2349            r#"pipeline t(task) {
2350  let x: string = "hello"
2351  match x {
2352    3.14 -> { log("bad") }
2353    "hello" -> { log("ok") }
2354  }
2355}"#,
2356        );
2357        let pattern_warns: Vec<_> = warns
2358            .iter()
2359            .filter(|w| w.contains("Match pattern type mismatch"))
2360            .collect();
2361        assert_eq!(pattern_warns.len(), 1);
2362        assert!(pattern_warns[0].contains("matching string against float literal"));
2363    }
2364
2365    #[test]
2366    fn test_match_pattern_int_against_float_ok() {
2367        // int and float are compatible for match patterns
2368        let warns = warnings(
2369            r#"pipeline t(task) {
2370  let x: float = 3.14
2371  match x {
2372    42 -> { log("ok") }
2373    _ -> { log("default") }
2374  }
2375}"#,
2376        );
2377        let pattern_warns: Vec<_> = warns
2378            .iter()
2379            .filter(|w| w.contains("Match pattern type mismatch"))
2380            .collect();
2381        assert!(pattern_warns.is_empty());
2382    }
2383
2384    #[test]
2385    fn test_match_pattern_float_against_int_ok() {
2386        // float and int are compatible for match patterns
2387        let warns = warnings(
2388            r#"pipeline t(task) {
2389  let x: int = 42
2390  match x {
2391    3.14 -> { log("close") }
2392    _ -> { log("default") }
2393  }
2394}"#,
2395        );
2396        let pattern_warns: Vec<_> = warns
2397            .iter()
2398            .filter(|w| w.contains("Match pattern type mismatch"))
2399            .collect();
2400        assert!(pattern_warns.is_empty());
2401    }
2402
2403    #[test]
2404    fn test_match_pattern_correct_types_no_warning() {
2405        let warns = warnings(
2406            r#"pipeline t(task) {
2407  let x: int = 42
2408  match x {
2409    1 -> { log("one") }
2410    2 -> { log("two") }
2411    _ -> { log("other") }
2412  }
2413}"#,
2414        );
2415        let pattern_warns: Vec<_> = warns
2416            .iter()
2417            .filter(|w| w.contains("Match pattern type mismatch"))
2418            .collect();
2419        assert!(pattern_warns.is_empty());
2420    }
2421
2422    #[test]
2423    fn test_match_pattern_wildcard_no_warning() {
2424        let warns = warnings(
2425            r#"pipeline t(task) {
2426  let x: int = 42
2427  match x {
2428    _ -> { log("catch all") }
2429  }
2430}"#,
2431        );
2432        let pattern_warns: Vec<_> = warns
2433            .iter()
2434            .filter(|w| w.contains("Match pattern type mismatch"))
2435            .collect();
2436        assert!(pattern_warns.is_empty());
2437    }
2438
2439    #[test]
2440    fn test_match_pattern_untyped_no_warning() {
2441        // When value has no known type, no warning should be emitted
2442        let warns = warnings(
2443            r#"pipeline t(task) {
2444  let x = some_unknown_fn()
2445  match x {
2446    "hello" -> { log("string") }
2447    42 -> { log("int") }
2448  }
2449}"#,
2450        );
2451        let pattern_warns: Vec<_> = warns
2452            .iter()
2453            .filter(|w| w.contains("Match pattern type mismatch"))
2454            .collect();
2455        assert!(pattern_warns.is_empty());
2456    }
2457
2458    // --- Interface constraint type checking tests ---
2459
2460    fn iface_warns(source: &str) -> Vec<String> {
2461        warnings(source)
2462            .into_iter()
2463            .filter(|w| w.contains("does not satisfy interface"))
2464            .collect()
2465    }
2466
2467    #[test]
2468    fn test_interface_constraint_return_type_mismatch() {
2469        let warns = iface_warns(
2470            r#"pipeline t(task) {
2471  interface Sizable {
2472    fn size(self) -> int
2473  }
2474  struct Box { width: int }
2475  impl Box {
2476    fn size(self) -> string { return "nope" }
2477  }
2478  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2479  measure(Box({width: 3}))
2480}"#,
2481        );
2482        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2483        assert!(
2484            warns[0].contains("method 'size' returns 'string', expected 'int'"),
2485            "unexpected message: {}",
2486            warns[0]
2487        );
2488    }
2489
2490    #[test]
2491    fn test_interface_constraint_param_type_mismatch() {
2492        let warns = iface_warns(
2493            r#"pipeline t(task) {
2494  interface Processor {
2495    fn process(self, x: int) -> string
2496  }
2497  struct MyProc { name: string }
2498  impl MyProc {
2499    fn process(self, x: string) -> string { return x }
2500  }
2501  fn run_proc<T>(p: T) where T: Processor { log(p.process(42)) }
2502  run_proc(MyProc({name: "a"}))
2503}"#,
2504        );
2505        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2506        assert!(
2507            warns[0].contains("method 'process' parameter 1 has type 'string', expected 'int'"),
2508            "unexpected message: {}",
2509            warns[0]
2510        );
2511    }
2512
2513    #[test]
2514    fn test_interface_constraint_missing_method() {
2515        let warns = iface_warns(
2516            r#"pipeline t(task) {
2517  interface Sizable {
2518    fn size(self) -> int
2519  }
2520  struct Box { width: int }
2521  impl Box {
2522    fn area(self) -> int { return self.width }
2523  }
2524  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2525  measure(Box({width: 3}))
2526}"#,
2527        );
2528        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2529        assert!(
2530            warns[0].contains("missing method 'size'"),
2531            "unexpected message: {}",
2532            warns[0]
2533        );
2534    }
2535
2536    #[test]
2537    fn test_interface_constraint_param_count_mismatch() {
2538        let warns = iface_warns(
2539            r#"pipeline t(task) {
2540  interface Doubler {
2541    fn double(self, x: int) -> int
2542  }
2543  struct Bad { v: int }
2544  impl Bad {
2545    fn double(self) -> int { return self.v * 2 }
2546  }
2547  fn run_double<T>(d: T) where T: Doubler { log(d.double(3)) }
2548  run_double(Bad({v: 5}))
2549}"#,
2550        );
2551        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2552        assert!(
2553            warns[0].contains("method 'double' has 0 parameter(s), expected 1"),
2554            "unexpected message: {}",
2555            warns[0]
2556        );
2557    }
2558
2559    #[test]
2560    fn test_interface_constraint_satisfied() {
2561        let warns = iface_warns(
2562            r#"pipeline t(task) {
2563  interface Sizable {
2564    fn size(self) -> int
2565  }
2566  struct Box { width: int, height: int }
2567  impl Box {
2568    fn size(self) -> int { return self.width * self.height }
2569  }
2570  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2571  measure(Box({width: 3, height: 4}))
2572}"#,
2573        );
2574        assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
2575    }
2576
2577    #[test]
2578    fn test_interface_constraint_untyped_impl_compatible() {
2579        // Gradual typing: untyped impl return should not trigger warning
2580        let warns = iface_warns(
2581            r#"pipeline t(task) {
2582  interface Sizable {
2583    fn size(self) -> int
2584  }
2585  struct Box { width: int }
2586  impl Box {
2587    fn size(self) { return self.width }
2588  }
2589  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2590  measure(Box({width: 3}))
2591}"#,
2592        );
2593        assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
2594    }
2595
2596    #[test]
2597    fn test_interface_constraint_int_float_covariance() {
2598        // int is compatible with float (covariance)
2599        let warns = iface_warns(
2600            r#"pipeline t(task) {
2601  interface Measurable {
2602    fn value(self) -> float
2603  }
2604  struct Gauge { v: int }
2605  impl Gauge {
2606    fn value(self) -> int { return self.v }
2607  }
2608  fn read_val<T>(g: T) where T: Measurable { log(g.value()) }
2609  read_val(Gauge({v: 42}))
2610}"#,
2611        );
2612        assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
2613    }
2614}