Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6/// A diagnostic produced by the type checker.
7#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9    pub message: String,
10    pub severity: DiagnosticSeverity,
11    pub span: Option<Span>,
12    pub help: Option<String>,
13}
14
/// Severity level of a [`TypeDiagnostic`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticSeverity {
    /// Reported at error level.
    Error,
    /// Reported at warning level.
    Warning,
}
20
/// Inferred type of an expression. None means unknown/untyped (gradual typing).
///
/// Checks that need a concrete type (for example, comparing a `let` annotation against the
/// assigned value) are skipped when the inferred type is `None`.
type InferredType = Option<TypeExpr>;
23
24/// Scope for tracking variable types.
25#[derive(Debug, Clone)]
26struct TypeScope {
27    /// Variable name → inferred type.
28    vars: BTreeMap<String, InferredType>,
29    /// Function name → (param types, return type).
30    functions: BTreeMap<String, FnSignature>,
31    /// Named type aliases.
32    type_aliases: BTreeMap<String, TypeExpr>,
33    /// Enum declarations: name → variant names.
34    enums: BTreeMap<String, Vec<String>>,
35    /// Interface declarations: name → method signatures.
36    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
37    /// Struct declarations: name → field types.
38    structs: BTreeMap<String, Vec<(String, InferredType)>>,
39    /// Impl block methods: type_name → method signatures.
40    impl_methods: BTreeMap<String, Vec<ImplMethodSig>>,
41    /// Generic type parameter names in scope (treated as compatible with any type).
42    generic_type_params: std::collections::BTreeSet<String>,
43    /// Where-clause constraints: type_param → interface_bound.
44    /// Used for definition-site checking of generic function bodies.
45    where_constraints: BTreeMap<String, String>,
46    parent: Option<Box<TypeScope>>,
47}
48
49/// Method signature extracted from an impl block (for interface checking).
50#[derive(Debug, Clone)]
51struct ImplMethodSig {
52    name: String,
53    /// Number of parameters excluding `self`.
54    param_count: usize,
55    /// Parameter types (excluding `self`), None means untyped.
56    param_types: Vec<Option<TypeExpr>>,
57    /// Return type, None means untyped.
58    return_type: Option<TypeExpr>,
59}
60
61#[derive(Debug, Clone)]
62struct FnSignature {
63    params: Vec<(String, InferredType)>,
64    return_type: InferredType,
65    /// Generic type parameter names declared on the function.
66    type_param_names: Vec<String>,
67    /// Number of required parameters (those without defaults).
68    required_params: usize,
69    /// Where-clause constraints: (type_param_name, interface_bound).
70    where_clauses: Vec<(String, String)>,
71}
72
73impl TypeScope {
74    fn new() -> Self {
75        Self {
76            vars: BTreeMap::new(),
77            functions: BTreeMap::new(),
78            type_aliases: BTreeMap::new(),
79            enums: BTreeMap::new(),
80            interfaces: BTreeMap::new(),
81            structs: BTreeMap::new(),
82            impl_methods: BTreeMap::new(),
83            generic_type_params: std::collections::BTreeSet::new(),
84            where_constraints: BTreeMap::new(),
85            parent: None,
86        }
87    }
88
89    fn child(&self) -> Self {
90        Self {
91            vars: BTreeMap::new(),
92            functions: BTreeMap::new(),
93            type_aliases: BTreeMap::new(),
94            enums: BTreeMap::new(),
95            interfaces: BTreeMap::new(),
96            structs: BTreeMap::new(),
97            impl_methods: BTreeMap::new(),
98            generic_type_params: std::collections::BTreeSet::new(),
99            where_constraints: BTreeMap::new(),
100            parent: Some(Box::new(self.clone())),
101        }
102    }
103
104    fn get_var(&self, name: &str) -> Option<&InferredType> {
105        self.vars
106            .get(name)
107            .or_else(|| self.parent.as_ref()?.get_var(name))
108    }
109
110    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
111        self.functions
112            .get(name)
113            .or_else(|| self.parent.as_ref()?.get_fn(name))
114    }
115
116    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
117        self.type_aliases
118            .get(name)
119            .or_else(|| self.parent.as_ref()?.resolve_type(name))
120    }
121
122    fn is_generic_type_param(&self, name: &str) -> bool {
123        self.generic_type_params.contains(name)
124            || self
125                .parent
126                .as_ref()
127                .is_some_and(|p| p.is_generic_type_param(name))
128    }
129
130    fn get_where_constraint(&self, type_param: &str) -> Option<&str> {
131        self.where_constraints
132            .get(type_param)
133            .map(|s| s.as_str())
134            .or_else(|| {
135                self.parent
136                    .as_ref()
137                    .and_then(|p| p.get_where_constraint(type_param))
138            })
139    }
140
141    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
142        self.enums
143            .get(name)
144            .or_else(|| self.parent.as_ref()?.get_enum(name))
145    }
146
147    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
148        self.interfaces
149            .get(name)
150            .or_else(|| self.parent.as_ref()?.get_interface(name))
151    }
152
153    fn get_struct(&self, name: &str) -> Option<&Vec<(String, InferredType)>> {
154        self.structs
155            .get(name)
156            .or_else(|| self.parent.as_ref()?.get_struct(name))
157    }
158
159    fn get_impl_methods(&self, name: &str) -> Option<&Vec<ImplMethodSig>> {
160        self.impl_methods
161            .get(name)
162            .or_else(|| self.parent.as_ref()?.get_impl_methods(name))
163    }
164
165    fn define_var(&mut self, name: &str, ty: InferredType) {
166        self.vars.insert(name.to_string(), ty);
167    }
168
169    fn define_fn(&mut self, name: &str, sig: FnSignature) {
170        self.functions.insert(name.to_string(), sig);
171    }
172}
173
174/// Known return types for builtin functions.
175fn builtin_return_type(name: &str) -> InferredType {
176    match name {
177        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
178        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
179            Some(TypeExpr::Named("nil".into()))
180        }
181        "type_of"
182        | "to_string"
183        | "json_stringify"
184        | "read_file"
185        | "http_get"
186        | "http_post"
187        | "regex_replace"
188        | "path_join"
189        | "temp_dir"
190        | "date_format"
191        | "format"
192        | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
193        "to_int" | "timer_end" | "elapsed" | "sign" => Some(TypeExpr::Named("int".into())),
194        "to_float" | "timestamp" | "date_parse" | "sin" | "cos" | "tan" | "asin" | "acos"
195        | "atan" | "atan2" | "log2" | "log10" | "ln" | "exp" | "pi" | "e" => {
196            Some(TypeExpr::Named("float".into()))
197        }
198        "file_exists" | "json_validate" | "is_nan" | "is_infinite" | "set_contains" => {
199            Some(TypeExpr::Named("bool".into()))
200        }
201        "list_dir"
202        | "mcp_list_tools"
203        | "mcp_list_resources"
204        | "mcp_list_prompts"
205        | "to_list"
206        | "regex_captures"
207        | "artifact_select"
208        | "transcript_messages"
209        | "transcript_events" => Some(TypeExpr::Named("list".into())),
210        "stat"
211        | "exec"
212        | "exec_at"
213        | "shell"
214        | "shell_at"
215        | "date_now"
216        | "llm_call"
217        | "llm_completion"
218        | "agent_loop"
219        | "llm_info"
220        | "llm_usage"
221        | "timer_start"
222        | "metadata_get"
223        | "mcp_server_info"
224        | "mcp_get_prompt"
225        | "llm_pick_model"
226        | "transcript"
227        | "transcript_from_messages"
228        | "transcript_reset"
229        | "transcript_archive"
230        | "transcript_abandon"
231        | "transcript_resume"
232        | "workflow_graph"
233        | "workflow_validate"
234        | "workflow_inspect"
235        | "workflow_policy_report"
236        | "workflow_clone"
237        | "workflow_insert_node"
238        | "workflow_replace_node"
239        | "workflow_rewire"
240        | "workflow_set_model_policy"
241        | "workflow_set_context_policy"
242        | "workflow_set_transcript_policy"
243        | "workflow_diff"
244        | "workflow_commit"
245        | "artifact"
246        | "artifact_derive"
247        | "artifact_workspace_file"
248        | "artifact_workspace_snapshot"
249        | "artifact_editor_selection"
250        | "artifact_verification_result"
251        | "artifact_test_result"
252        | "artifact_command_result"
253        | "artifact_diff"
254        | "artifact_git_diff"
255        | "artifact_diff_review"
256        | "artifact_review_decision"
257        | "artifact_patch_proposal"
258        | "artifact_verification_bundle"
259        | "artifact_apply_intent"
260        | "run_record"
261        | "load_run_tree"
262        | "run_record_save"
263        | "run_record_load"
264        | "run_record_fixture"
265        | "run_record_eval"
266        | "run_record_eval_suite"
267        | "run_record_diff"
268        | "eval_suite_manifest"
269        | "eval_suite_run"
270        | "workflow_execute"
271        | "resume_agent"
272        | "transcript_compact"
273        | "transcript_summarize"
274        | "host_capabilities" => Some(TypeExpr::Named("dict".into())),
275        "transcript_render_visible"
276        | "transcript_render_full"
277        | "artifact_context"
278        | "transcript_export"
279        | "transcript_id" => Some(TypeExpr::Named("string".into())),
280        "transcript_summary" => Some(TypeExpr::Union(vec![
281            TypeExpr::Named("string".into()),
282            TypeExpr::Named("nil".into()),
283        ])),
284        "host_has" => Some(TypeExpr::Named("bool".into())),
285        "metadata_set"
286        | "metadata_save"
287        | "metadata_refresh_hashes"
288        | "invalidate_facts"
289        | "log_json"
290        | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
291        "env" | "regex_match" => Some(TypeExpr::Union(vec![
292            TypeExpr::Named("string".into()),
293            TypeExpr::Named("nil".into()),
294        ])),
295        "json_parse" | "json_extract" => None, // could be any type
296        _ => None,
297    }
298}
299
/// Check if a name is a known builtin.
///
/// NOTE(review): some names that `builtin_return_type` assigns a return type to (e.g.
/// `timer_start`, `metadata_get`, `mcp_list_tools`, `log_json`) are not listed here —
/// confirm whether that is intentional.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        // Output, conversion, environment, I/O
        "log", "print", "println", "type_of", "to_string", "to_int", "to_float",
        "json_stringify", "json_parse", "env", "timestamp", "sleep", "read_file",
        "write_file", "exit", "regex_match", "regex_replace", "regex_captures", "http_get",
        "http_post",
        // LLM and async
        "llm_call", "llm_completion", "agent_loop", "llm_pick_model", "await", "cancel",
        // Filesystem
        "file_exists", "delete_file", "list_dir", "mkdir", "path_join", "copy_file",
        "append_file", "temp_dir",
        // Transcripts
        "transcript", "transcript_from_messages", "transcript_messages", "transcript_events",
        "transcript_summary", "transcript_id", "transcript_export", "transcript_import",
        "transcript_fork", "transcript_reset", "transcript_archive", "transcript_abandon",
        "transcript_resume", "transcript_render_visible", "transcript_render_full",
        "transcript_compact", "transcript_summarize",
        // Host
        "host_capabilities", "host_has", "host_invoke",
        // Workflows
        "workflow_graph", "workflow_validate", "workflow_inspect", "workflow_policy_report",
        "workflow_clone", "workflow_insert_node", "workflow_replace_node", "workflow_rewire",
        "workflow_set_model_policy", "workflow_set_context_policy",
        "workflow_set_transcript_policy", "workflow_diff", "workflow_commit",
        "workflow_execute", "resume_agent",
        // Artifacts
        "artifact", "artifact_derive", "artifact_workspace_file", "artifact_workspace_snapshot",
        "artifact_editor_selection", "artifact_verification_result", "artifact_test_result",
        "artifact_command_result", "artifact_diff", "artifact_git_diff", "artifact_diff_review",
        "artifact_review_decision", "artifact_patch_proposal", "artifact_verification_bundle",
        "artifact_apply_intent", "artifact_select", "artifact_context",
        // Run records and evals
        "run_record", "load_run_tree", "run_record_save", "run_record_load",
        "run_record_fixture", "run_record_eval", "run_record_eval_suite", "run_record_diff",
        "eval_suite_manifest", "eval_suite_run",
        // Processes and dates
        "stat", "exec", "exec_at", "shell", "shell_at", "date_now", "date_format", "date_parse",
        // Strings and JSON
        "format", "json_validate", "json_extract", "trim", "lowercase", "uppercase", "split",
        "starts_with", "ends_with", "contains", "replace", "join", "len", "substring",
        // Paths
        "dirname", "basename", "extname",
        // Math
        "sin", "cos", "tan", "asin", "acos", "atan", "atan2", "log2", "log10", "ln", "exp",
        "pi", "e", "sign", "is_nan", "is_infinite",
        // Sets
        "set", "set_add", "set_remove", "set_contains", "set_union", "set_intersect",
        "set_difference", "to_list",
    ];
    BUILTINS.contains(&name)
}
451
452/// The static type checker.
453pub struct TypeChecker {
454    diagnostics: Vec<TypeDiagnostic>,
455    scope: TypeScope,
456}
457
458impl TypeChecker {
459    pub fn new() -> Self {
460        Self {
461            diagnostics: Vec::new(),
462            scope: TypeScope::new(),
463        }
464    }
465
466    /// Check a program and return diagnostics.
467    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
468        // First pass: register type and enum declarations into root scope
469        Self::register_declarations_into(&mut self.scope, program);
470
471        // Also scan pipeline bodies for declarations
472        for snode in program {
473            if let Node::Pipeline { body, .. } = &snode.node {
474                Self::register_declarations_into(&mut self.scope, body);
475            }
476        }
477
478        // Check each top-level node
479        for snode in program {
480            match &snode.node {
481                Node::Pipeline { params, body, .. } => {
482                    let mut child = self.scope.child();
483                    for p in params {
484                        child.define_var(p, None);
485                    }
486                    self.check_block(body, &mut child);
487                }
488                Node::FnDecl {
489                    name,
490                    type_params,
491                    params,
492                    return_type,
493                    where_clauses,
494                    body,
495                    ..
496                } => {
497                    let required_params =
498                        params.iter().filter(|p| p.default_value.is_none()).count();
499                    let sig = FnSignature {
500                        params: params
501                            .iter()
502                            .map(|p| (p.name.clone(), p.type_expr.clone()))
503                            .collect(),
504                        return_type: return_type.clone(),
505                        type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
506                        required_params,
507                        where_clauses: where_clauses
508                            .iter()
509                            .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
510                            .collect(),
511                    };
512                    self.scope.define_fn(name, sig);
513                    self.check_fn_body(type_params, params, return_type, body, where_clauses);
514                }
515                _ => {
516                    let mut scope = self.scope.clone();
517                    self.check_node(snode, &mut scope);
518                    // Merge any new definitions back into the top-level scope
519                    for (name, ty) in scope.vars {
520                        self.scope.vars.entry(name).or_insert(ty);
521                    }
522                }
523            }
524        }
525
526        self.diagnostics
527    }
528
529    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
530    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
531        for snode in nodes {
532            match &snode.node {
533                Node::TypeDecl { name, type_expr } => {
534                    scope.type_aliases.insert(name.clone(), type_expr.clone());
535                }
536                Node::EnumDecl { name, variants } => {
537                    let variant_names: Vec<String> =
538                        variants.iter().map(|v| v.name.clone()).collect();
539                    scope.enums.insert(name.clone(), variant_names);
540                }
541                Node::InterfaceDecl { name, methods } => {
542                    scope.interfaces.insert(name.clone(), methods.clone());
543                }
544                Node::StructDecl { name, fields } => {
545                    let field_types: Vec<(String, InferredType)> = fields
546                        .iter()
547                        .map(|f| (f.name.clone(), f.type_expr.clone()))
548                        .collect();
549                    scope.structs.insert(name.clone(), field_types);
550                }
551                Node::ImplBlock {
552                    type_name, methods, ..
553                } => {
554                    let sigs: Vec<ImplMethodSig> = methods
555                        .iter()
556                        .filter_map(|m| {
557                            if let Node::FnDecl {
558                                name,
559                                params,
560                                return_type,
561                                ..
562                            } = &m.node
563                            {
564                                let non_self: Vec<_> =
565                                    params.iter().filter(|p| p.name != "self").collect();
566                                let param_count = non_self.len();
567                                let param_types: Vec<Option<TypeExpr>> =
568                                    non_self.iter().map(|p| p.type_expr.clone()).collect();
569                                Some(ImplMethodSig {
570                                    name: name.clone(),
571                                    param_count,
572                                    param_types,
573                                    return_type: return_type.clone(),
574                                })
575                            } else {
576                                None
577                            }
578                        })
579                        .collect();
580                    scope.impl_methods.insert(type_name.clone(), sigs);
581                }
582                _ => {}
583            }
584        }
585    }
586
587    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
588        for stmt in stmts {
589            self.check_node(stmt, scope);
590        }
591    }
592
593    /// Define variables from a destructuring pattern in the given scope (as unknown type).
594    fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
595        match pattern {
596            BindingPattern::Identifier(name) => {
597                scope.define_var(name, None);
598            }
599            BindingPattern::Dict(fields) => {
600                for field in fields {
601                    let name = field.alias.as_deref().unwrap_or(&field.key);
602                    scope.define_var(name, None);
603                }
604            }
605            BindingPattern::List(elements) => {
606                for elem in elements {
607                    scope.define_var(&elem.name, None);
608                }
609            }
610        }
611    }
612
613    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
614        let span = snode.span;
615        match &snode.node {
616            Node::LetBinding {
617                pattern,
618                type_ann,
619                value,
620            } => {
621                let inferred = self.infer_type(value, scope);
622                if let BindingPattern::Identifier(name) = pattern {
623                    if let Some(expected) = type_ann {
624                        if let Some(actual) = &inferred {
625                            if !self.types_compatible(expected, actual, scope) {
626                                let mut msg = format!(
627                                    "Type mismatch: '{}' declared as {}, but assigned {}",
628                                    name,
629                                    format_type(expected),
630                                    format_type(actual)
631                                );
632                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
633                                    msg.push_str(&format!(" ({})", detail));
634                                }
635                                self.error_at(msg, span);
636                            }
637                        }
638                    }
639                    let ty = type_ann.clone().or(inferred);
640                    scope.define_var(name, ty);
641                } else {
642                    Self::define_pattern_vars(pattern, scope);
643                }
644            }
645
646            Node::VarBinding {
647                pattern,
648                type_ann,
649                value,
650            } => {
651                let inferred = self.infer_type(value, scope);
652                if let BindingPattern::Identifier(name) = pattern {
653                    if let Some(expected) = type_ann {
654                        if let Some(actual) = &inferred {
655                            if !self.types_compatible(expected, actual, scope) {
656                                let mut msg = format!(
657                                    "Type mismatch: '{}' declared as {}, but assigned {}",
658                                    name,
659                                    format_type(expected),
660                                    format_type(actual)
661                                );
662                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
663                                    msg.push_str(&format!(" ({})", detail));
664                                }
665                                self.error_at(msg, span);
666                            }
667                        }
668                    }
669                    let ty = type_ann.clone().or(inferred);
670                    scope.define_var(name, ty);
671                } else {
672                    Self::define_pattern_vars(pattern, scope);
673                }
674            }
675
676            Node::FnDecl {
677                name,
678                type_params,
679                params,
680                return_type,
681                where_clauses,
682                body,
683                ..
684            } => {
685                let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
686                let sig = FnSignature {
687                    params: params
688                        .iter()
689                        .map(|p| (p.name.clone(), p.type_expr.clone()))
690                        .collect(),
691                    return_type: return_type.clone(),
692                    type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
693                    required_params,
694                    where_clauses: where_clauses
695                        .iter()
696                        .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
697                        .collect(),
698                };
699                scope.define_fn(name, sig.clone());
700                scope.define_var(name, None);
701                self.check_fn_body(type_params, params, return_type, body, where_clauses);
702            }
703
704            Node::FunctionCall { name, args } => {
705                self.check_call(name, args, scope, span);
706            }
707
708            Node::IfElse {
709                condition,
710                then_body,
711                else_body,
712            } => {
713                self.check_node(condition, scope);
714                let mut then_scope = scope.child();
715                // Narrow union types after nil checks: if x != nil, narrow x
716                if let Some((var_name, narrowed)) = Self::extract_nil_narrowing(condition, scope) {
717                    then_scope.define_var(&var_name, narrowed);
718                }
719                self.check_block(then_body, &mut then_scope);
720                if let Some(else_body) = else_body {
721                    let mut else_scope = scope.child();
722                    self.check_block(else_body, &mut else_scope);
723                }
724            }
725
726            Node::ForIn {
727                pattern,
728                iterable,
729                body,
730            } => {
731                self.check_node(iterable, scope);
732                let mut loop_scope = scope.child();
733                if let BindingPattern::Identifier(variable) = pattern {
734                    // Infer loop variable type from iterable
735                    let elem_type = match self.infer_type(iterable, scope) {
736                        Some(TypeExpr::List(inner)) => Some(*inner),
737                        Some(TypeExpr::Named(n)) if n == "string" => {
738                            Some(TypeExpr::Named("string".into()))
739                        }
740                        _ => None,
741                    };
742                    loop_scope.define_var(variable, elem_type);
743                } else {
744                    Self::define_pattern_vars(pattern, &mut loop_scope);
745                }
746                self.check_block(body, &mut loop_scope);
747            }
748
749            Node::WhileLoop { condition, body } => {
750                self.check_node(condition, scope);
751                let mut loop_scope = scope.child();
752                self.check_block(body, &mut loop_scope);
753            }
754
755            Node::TryCatch {
756                body,
757                error_var,
758                catch_body,
759                finally_body,
760                ..
761            } => {
762                let mut try_scope = scope.child();
763                self.check_block(body, &mut try_scope);
764                let mut catch_scope = scope.child();
765                if let Some(var) = error_var {
766                    catch_scope.define_var(var, None);
767                }
768                self.check_block(catch_body, &mut catch_scope);
769                if let Some(fb) = finally_body {
770                    let mut finally_scope = scope.child();
771                    self.check_block(fb, &mut finally_scope);
772                }
773            }
774
775            Node::TryExpr { body } => {
776                let mut try_scope = scope.child();
777                self.check_block(body, &mut try_scope);
778            }
779
780            Node::ReturnStmt {
781                value: Some(val), ..
782            } => {
783                self.check_node(val, scope);
784            }
785
786            Node::Assignment {
787                target, value, op, ..
788            } => {
789                self.check_node(value, scope);
790                if let Node::Identifier(name) = &target.node {
791                    if let Some(Some(var_type)) = scope.get_var(name) {
792                        let value_type = self.infer_type(value, scope);
793                        let assigned = if let Some(op) = op {
794                            let var_inferred = scope.get_var(name).cloned().flatten();
795                            infer_binary_op_type(op, &var_inferred, &value_type)
796                        } else {
797                            value_type
798                        };
799                        if let Some(actual) = &assigned {
800                            if !self.types_compatible(var_type, actual, scope) {
801                                self.error_at(
802                                    format!(
803                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
804                                        format_type(actual),
805                                        name,
806                                        format_type(var_type)
807                                    ),
808                                    span,
809                                );
810                            }
811                        }
812                    }
813                }
814            }
815
816            Node::TypeDecl { name, type_expr } => {
817                scope.type_aliases.insert(name.clone(), type_expr.clone());
818            }
819
820            Node::EnumDecl { name, variants } => {
821                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
822                scope.enums.insert(name.clone(), variant_names);
823            }
824
825            Node::StructDecl { name, fields } => {
826                let field_types: Vec<(String, InferredType)> = fields
827                    .iter()
828                    .map(|f| (f.name.clone(), f.type_expr.clone()))
829                    .collect();
830                scope.structs.insert(name.clone(), field_types);
831            }
832
833            Node::InterfaceDecl { name, methods } => {
834                scope.interfaces.insert(name.clone(), methods.clone());
835            }
836
837            Node::ImplBlock {
838                type_name, methods, ..
839            } => {
840                // Register impl methods for interface satisfaction checking
841                let sigs: Vec<ImplMethodSig> = methods
842                    .iter()
843                    .filter_map(|m| {
844                        if let Node::FnDecl {
845                            name,
846                            params,
847                            return_type,
848                            ..
849                        } = &m.node
850                        {
851                            let non_self: Vec<_> =
852                                params.iter().filter(|p| p.name != "self").collect();
853                            let param_count = non_self.len();
854                            let param_types: Vec<Option<TypeExpr>> =
855                                non_self.iter().map(|p| p.type_expr.clone()).collect();
856                            Some(ImplMethodSig {
857                                name: name.clone(),
858                                param_count,
859                                param_types,
860                                return_type: return_type.clone(),
861                            })
862                        } else {
863                            None
864                        }
865                    })
866                    .collect();
867                scope.impl_methods.insert(type_name.clone(), sigs);
868                for method_sn in methods {
869                    self.check_node(method_sn, scope);
870                }
871            }
872
873            Node::TryOperator { operand } => {
874                self.check_node(operand, scope);
875            }
876
877            Node::MatchExpr { value, arms } => {
878                self.check_node(value, scope);
879                let value_type = self.infer_type(value, scope);
880                for arm in arms {
881                    self.check_node(&arm.pattern, scope);
882                    // Check for incompatible literal pattern types
883                    if let Some(ref vt) = value_type {
884                        let value_type_name = format_type(vt);
885                        let mismatch = match &arm.pattern.node {
886                            Node::StringLiteral(_) => {
887                                !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
888                            }
889                            Node::IntLiteral(_) => {
890                                !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
891                                    && !self.types_compatible(
892                                        vt,
893                                        &TypeExpr::Named("float".into()),
894                                        scope,
895                                    )
896                            }
897                            Node::FloatLiteral(_) => {
898                                !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
899                                    && !self.types_compatible(
900                                        vt,
901                                        &TypeExpr::Named("int".into()),
902                                        scope,
903                                    )
904                            }
905                            Node::BoolLiteral(_) => {
906                                !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
907                            }
908                            _ => false,
909                        };
910                        if mismatch {
911                            let pattern_type = match &arm.pattern.node {
912                                Node::StringLiteral(_) => "string",
913                                Node::IntLiteral(_) => "int",
914                                Node::FloatLiteral(_) => "float",
915                                Node::BoolLiteral(_) => "bool",
916                                _ => unreachable!(),
917                            };
918                            self.warning_at(
919                                format!(
920                                    "Match pattern type mismatch: matching {} against {} literal",
921                                    value_type_name, pattern_type
922                                ),
923                                arm.pattern.span,
924                            );
925                        }
926                    }
927                    let mut arm_scope = scope.child();
928                    self.check_block(&arm.body, &mut arm_scope);
929                }
930                self.check_match_exhaustiveness(value, arms, scope, span);
931            }
932
933            // Recurse into nested expressions + validate binary op types
934            Node::BinaryOp { op, left, right } => {
935                self.check_node(left, scope);
936                self.check_node(right, scope);
937                // Validate operator/type compatibility
938                let lt = self.infer_type(left, scope);
939                let rt = self.infer_type(right, scope);
940                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
941                    match op.as_str() {
942                        "-" | "*" | "/" | "%" => {
943                            let numeric = ["int", "float"];
944                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
945                                self.warning_at(
946                                    format!(
947                                        "Operator '{op}' may not be valid for types {} and {}",
948                                        l, r
949                                    ),
950                                    span,
951                                );
952                            }
953                        }
954                        "+" => {
955                            // + is valid for int, float, string, list, dict
956                            let valid = ["int", "float", "string", "list", "dict"];
957                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
958                                self.warning_at(
959                                    format!(
960                                        "Operator '+' may not be valid for types {} and {}",
961                                        l, r
962                                    ),
963                                    span,
964                                );
965                            }
966                        }
967                        _ => {}
968                    }
969                }
970            }
971            Node::UnaryOp { operand, .. } => {
972                self.check_node(operand, scope);
973            }
974            Node::MethodCall {
975                object,
976                method,
977                args,
978                ..
979            }
980            | Node::OptionalMethodCall {
981                object,
982                method,
983                args,
984                ..
985            } => {
986                self.check_node(object, scope);
987                for arg in args {
988                    self.check_node(arg, scope);
989                }
990                // Definition-site generic checking: if the object's type is a
991                // constrained generic param (where T: Interface), verify the
992                // method exists in the bound interface.
993                if let Some(TypeExpr::Named(type_name)) = self.infer_type(object, scope) {
994                    if scope.is_generic_type_param(&type_name) {
995                        if let Some(iface_name) = scope.get_where_constraint(&type_name) {
996                            if let Some(iface_methods) = scope.get_interface(iface_name) {
997                                let has_method = iface_methods.iter().any(|m| m.name == *method);
998                                if !has_method {
999                                    self.warning_at(
1000                                        format!(
1001                                            "Method '{}' not found in interface '{}' (constraint on '{}')",
1002                                            method, iface_name, type_name
1003                                        ),
1004                                        span,
1005                                    );
1006                                }
1007                            }
1008                        }
1009                    }
1010                }
1011            }
1012            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
1013                self.check_node(object, scope);
1014            }
1015            Node::SubscriptAccess { object, index } => {
1016                self.check_node(object, scope);
1017                self.check_node(index, scope);
1018            }
1019            Node::SliceAccess { object, start, end } => {
1020                self.check_node(object, scope);
1021                if let Some(s) = start {
1022                    self.check_node(s, scope);
1023                }
1024                if let Some(e) = end {
1025                    self.check_node(e, scope);
1026                }
1027            }
1028
1029            // Terminals — nothing to check
1030            _ => {}
1031        }
1032    }
1033
1034    fn check_fn_body(
1035        &mut self,
1036        type_params: &[TypeParam],
1037        params: &[TypedParam],
1038        return_type: &Option<TypeExpr>,
1039        body: &[SNode],
1040        where_clauses: &[WhereClause],
1041    ) {
1042        let mut fn_scope = self.scope.child();
1043        // Register generic type parameters so they are treated as compatible
1044        // with any concrete type during type checking.
1045        for tp in type_params {
1046            fn_scope.generic_type_params.insert(tp.name.clone());
1047        }
1048        // Store where-clause constraints for definition-site checking
1049        for wc in where_clauses {
1050            fn_scope
1051                .where_constraints
1052                .insert(wc.type_name.clone(), wc.bound.clone());
1053        }
1054        for param in params {
1055            fn_scope.define_var(&param.name, param.type_expr.clone());
1056            if let Some(default) = &param.default_value {
1057                self.check_node(default, &mut fn_scope);
1058            }
1059        }
1060        self.check_block(body, &mut fn_scope);
1061
1062        // Check return statements against declared return type
1063        if let Some(ret_type) = return_type {
1064            for stmt in body {
1065                self.check_return_type(stmt, ret_type, &fn_scope);
1066            }
1067        }
1068    }
1069
1070    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
1071        let span = snode.span;
1072        match &snode.node {
1073            Node::ReturnStmt { value: Some(val) } => {
1074                let inferred = self.infer_type(val, scope);
1075                if let Some(actual) = &inferred {
1076                    if !self.types_compatible(expected, actual, scope) {
1077                        self.error_at(
1078                            format!(
1079                                "Return type mismatch: expected {}, got {}",
1080                                format_type(expected),
1081                                format_type(actual)
1082                            ),
1083                            span,
1084                        );
1085                    }
1086                }
1087            }
1088            Node::IfElse {
1089                then_body,
1090                else_body,
1091                ..
1092            } => {
1093                for stmt in then_body {
1094                    self.check_return_type(stmt, expected, scope);
1095                }
1096                if let Some(else_body) = else_body {
1097                    for stmt in else_body {
1098                        self.check_return_type(stmt, expected, scope);
1099                    }
1100                }
1101            }
1102            _ => {}
1103        }
1104    }
1105
1106    /// Check if a match expression on an enum's `.variant` property covers all variants.
1107    /// Extract narrowing info from nil-check conditions like `x != nil`.
1108    /// Returns (var_name, narrowed_type) where narrowed_type removes nil from a union.
1109    /// Check if a type satisfies an interface (Go-style implicit satisfaction).
1110    /// A type satisfies an interface if its impl block has all the required methods.
1111    fn satisfies_interface(
1112        &self,
1113        type_name: &str,
1114        interface_name: &str,
1115        scope: &TypeScope,
1116    ) -> bool {
1117        self.interface_mismatch_reason(type_name, interface_name, scope)
1118            .is_none()
1119    }
1120
1121    /// Return a detailed reason why a type does not satisfy an interface, or None
1122    /// if it does satisfy it.  Used for producing actionable warning messages.
1123    fn interface_mismatch_reason(
1124        &self,
1125        type_name: &str,
1126        interface_name: &str,
1127        scope: &TypeScope,
1128    ) -> Option<String> {
1129        let interface_methods = match scope.get_interface(interface_name) {
1130            Some(methods) => methods,
1131            None => return Some(format!("interface '{}' not found", interface_name)),
1132        };
1133        let impl_methods = match scope.get_impl_methods(type_name) {
1134            Some(methods) => methods,
1135            None => {
1136                if interface_methods.is_empty() {
1137                    return None;
1138                }
1139                let names: Vec<_> = interface_methods.iter().map(|m| m.name.as_str()).collect();
1140                return Some(format!("missing method(s): {}", names.join(", ")));
1141            }
1142        };
1143        for iface_method in interface_methods {
1144            let iface_params: Vec<_> = iface_method
1145                .params
1146                .iter()
1147                .filter(|p| p.name != "self")
1148                .collect();
1149            let iface_param_count = iface_params.len();
1150            let matching_impl = impl_methods.iter().find(|im| im.name == iface_method.name);
1151            let impl_method = match matching_impl {
1152                Some(m) => m,
1153                None => {
1154                    return Some(format!("missing method '{}'", iface_method.name));
1155                }
1156            };
1157            if impl_method.param_count != iface_param_count {
1158                return Some(format!(
1159                    "method '{}' has {} parameter(s), expected {}",
1160                    iface_method.name, impl_method.param_count, iface_param_count
1161                ));
1162            }
1163            // Check parameter types where both sides specify them
1164            for (i, iface_param) in iface_params.iter().enumerate() {
1165                if let (Some(expected), Some(actual)) = (
1166                    &iface_param.type_expr,
1167                    impl_method.param_types.get(i).and_then(|t| t.as_ref()),
1168                ) {
1169                    if !self.types_compatible(expected, actual, scope) {
1170                        return Some(format!(
1171                            "method '{}' parameter {} has type '{}', expected '{}'",
1172                            iface_method.name,
1173                            i + 1,
1174                            format_type(actual),
1175                            format_type(expected),
1176                        ));
1177                    }
1178                }
1179            }
1180            // Check return type where both sides specify it
1181            if let (Some(expected_ret), Some(actual_ret)) =
1182                (&iface_method.return_type, &impl_method.return_type)
1183            {
1184                if !self.types_compatible(expected_ret, actual_ret, scope) {
1185                    return Some(format!(
1186                        "method '{}' returns '{}', expected '{}'",
1187                        iface_method.name,
1188                        format_type(actual_ret),
1189                        format_type(expected_ret),
1190                    ));
1191                }
1192            }
1193        }
1194        None
1195    }
1196
1197    /// Recursively extract type parameter bindings from matching param/arg types.
1198    /// E.g., param_type=list<T> + arg_type=list<Dog> → binds T=Dog.
1199    fn extract_type_bindings(
1200        param_type: &TypeExpr,
1201        arg_type: &TypeExpr,
1202        type_params: &std::collections::BTreeSet<String>,
1203        bindings: &mut BTreeMap<String, String>,
1204    ) {
1205        match (param_type, arg_type) {
1206            // Direct type param match: T → concrete
1207            (TypeExpr::Named(param_name), TypeExpr::Named(concrete))
1208                if type_params.contains(param_name) =>
1209            {
1210                bindings
1211                    .entry(param_name.clone())
1212                    .or_insert(concrete.clone());
1213            }
1214            // list<T> + list<Dog>
1215            (TypeExpr::List(p_inner), TypeExpr::List(a_inner)) => {
1216                Self::extract_type_bindings(p_inner, a_inner, type_params, bindings);
1217            }
1218            // dict<K, V> + dict<string, int>
1219            (TypeExpr::DictType(pk, pv), TypeExpr::DictType(ak, av)) => {
1220                Self::extract_type_bindings(pk, ak, type_params, bindings);
1221                Self::extract_type_bindings(pv, av, type_params, bindings);
1222            }
1223            _ => {}
1224        }
1225    }
1226
1227    fn extract_nil_narrowing(
1228        condition: &SNode,
1229        scope: &TypeScope,
1230    ) -> Option<(String, InferredType)> {
1231        if let Node::BinaryOp { op, left, right } = &condition.node {
1232            if op == "!=" {
1233                // Check for `x != nil` or `nil != x`
1234                let (var_node, nil_node) = if matches!(right.node, Node::NilLiteral) {
1235                    (left, right)
1236                } else if matches!(left.node, Node::NilLiteral) {
1237                    (right, left)
1238                } else {
1239                    return None;
1240                };
1241                let _ = nil_node;
1242                if let Node::Identifier(name) = &var_node.node {
1243                    // Look up the variable's type and narrow it
1244                    if let Some(Some(TypeExpr::Union(members))) = scope.get_var(name) {
1245                        let narrowed: Vec<TypeExpr> = members
1246                            .iter()
1247                            .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1248                            .cloned()
1249                            .collect();
1250                        return if narrowed.len() == 1 {
1251                            Some((name.clone(), Some(narrowed.into_iter().next().unwrap())))
1252                        } else if narrowed.is_empty() {
1253                            None
1254                        } else {
1255                            Some((name.clone(), Some(TypeExpr::Union(narrowed))))
1256                        };
1257                    }
1258                }
1259            }
1260        }
1261        None
1262    }
1263
1264    fn check_match_exhaustiveness(
1265        &mut self,
1266        value: &SNode,
1267        arms: &[MatchArm],
1268        scope: &TypeScope,
1269        span: Span,
1270    ) {
1271        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
1272        let enum_name = match &value.node {
1273            Node::PropertyAccess { object, property } if property == "variant" => {
1274                // Infer the type of the object
1275                match self.infer_type(object, scope) {
1276                    Some(TypeExpr::Named(name)) => {
1277                        if scope.get_enum(&name).is_some() {
1278                            Some(name)
1279                        } else {
1280                            None
1281                        }
1282                    }
1283                    _ => None,
1284                }
1285            }
1286            _ => {
1287                // Direct match on an enum value: match <expr> { ... }
1288                match self.infer_type(value, scope) {
1289                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
1290                    _ => None,
1291                }
1292            }
1293        };
1294
1295        let Some(enum_name) = enum_name else {
1296            return;
1297        };
1298        let Some(variants) = scope.get_enum(&enum_name) else {
1299            return;
1300        };
1301
1302        // Collect variant names covered by match arms
1303        let mut covered: Vec<String> = Vec::new();
1304        let mut has_wildcard = false;
1305
1306        for arm in arms {
1307            match &arm.pattern.node {
1308                // String literal pattern (matching on .variant): "VariantA"
1309                Node::StringLiteral(s) => covered.push(s.clone()),
1310                // Identifier pattern acts as a wildcard/catch-all
1311                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
1312                    has_wildcard = true;
1313                }
1314                // Direct enum construct pattern: EnumName.Variant
1315                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
1316                // PropertyAccess pattern: EnumName.Variant (no args)
1317                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
1318                _ => {
1319                    // Unknown pattern shape — conservatively treat as wildcard
1320                    has_wildcard = true;
1321                }
1322            }
1323        }
1324
1325        if has_wildcard {
1326            return;
1327        }
1328
1329        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
1330        if !missing.is_empty() {
1331            let missing_str = missing
1332                .iter()
1333                .map(|s| format!("\"{}\"", s))
1334                .collect::<Vec<_>>()
1335                .join(", ");
1336            self.warning_at(
1337                format!(
1338                    "Non-exhaustive match on enum {}: missing variants {}",
1339                    enum_name, missing_str
1340                ),
1341                span,
1342            );
1343        }
1344    }
1345
1346    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
1347        // Check against known function signatures
1348        let has_spread = args.iter().any(|a| matches!(&a.node, Node::Spread(_)));
1349        if let Some(sig) = scope.get_fn(name).cloned() {
1350            if !has_spread
1351                && !is_builtin(name)
1352                && (args.len() < sig.required_params || args.len() > sig.params.len())
1353            {
1354                let expected = if sig.required_params == sig.params.len() {
1355                    format!("{}", sig.params.len())
1356                } else {
1357                    format!("{}-{}", sig.required_params, sig.params.len())
1358                };
1359                self.warning_at(
1360                    format!(
1361                        "Function '{}' expects {} arguments, got {}",
1362                        name,
1363                        expected,
1364                        args.len()
1365                    ),
1366                    span,
1367                );
1368            }
1369            // Build a scope that includes the function's generic type params
1370            // so they are treated as compatible with any concrete type.
1371            let call_scope = if sig.type_param_names.is_empty() {
1372                scope.clone()
1373            } else {
1374                let mut s = scope.child();
1375                for tp_name in &sig.type_param_names {
1376                    s.generic_type_params.insert(tp_name.clone());
1377                }
1378                s
1379            };
1380            for (i, (arg, (param_name, param_type))) in
1381                args.iter().zip(sig.params.iter()).enumerate()
1382            {
1383                if let Some(expected) = param_type {
1384                    let actual = self.infer_type(arg, scope);
1385                    if let Some(actual) = &actual {
1386                        if !self.types_compatible(expected, actual, &call_scope) {
1387                            self.error_at(
1388                                format!(
1389                                    "Argument {} ('{}'): expected {}, got {}",
1390                                    i + 1,
1391                                    param_name,
1392                                    format_type(expected),
1393                                    format_type(actual)
1394                                ),
1395                                arg.span,
1396                            );
1397                        }
1398                    }
1399                }
1400            }
1401            // Enforce where-clause constraints at call site
1402            if !sig.where_clauses.is_empty() {
1403                // Build mapping: type_param → concrete type from inferred args.
1404                // Recursively walks Generic types so list<T> + list<Dog> binds T=Dog.
1405                let mut type_bindings: BTreeMap<String, String> = BTreeMap::new();
1406                let type_param_set: std::collections::BTreeSet<String> =
1407                    sig.type_param_names.iter().cloned().collect();
1408                for (arg, (_param_name, param_type)) in args.iter().zip(sig.params.iter()) {
1409                    if let Some(param_ty) = param_type {
1410                        if let Some(arg_ty) = self.infer_type(arg, scope) {
1411                            Self::extract_type_bindings(
1412                                param_ty,
1413                                &arg_ty,
1414                                &type_param_set,
1415                                &mut type_bindings,
1416                            );
1417                        }
1418                    }
1419                }
1420                for (type_param, bound) in &sig.where_clauses {
1421                    if let Some(concrete_type) = type_bindings.get(type_param) {
1422                        if let Some(reason) =
1423                            self.interface_mismatch_reason(concrete_type, bound, scope)
1424                        {
1425                            self.warning_at(
1426                                format!(
1427                                    "Type '{}' does not satisfy interface '{}': {} \
1428                                     (required by constraint `where {}: {}`)",
1429                                    concrete_type, bound, reason, type_param, bound
1430                                ),
1431                                span,
1432                            );
1433                        }
1434                    }
1435                }
1436            }
1437        }
1438        // Check args recursively
1439        for arg in args {
1440            self.check_node(arg, scope);
1441        }
1442    }
1443
1444    /// Infer the type of an expression.
1445    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
1446        match &snode.node {
1447            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
1448            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
1449            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
1450                Some(TypeExpr::Named("string".into()))
1451            }
1452            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
1453            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
1454            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
1455            Node::DictLiteral(entries) => {
1456                // Infer shape type when all keys are string literals
1457                let mut fields = Vec::new();
1458                let mut all_string_keys = true;
1459                for entry in entries {
1460                    if let Node::StringLiteral(key) = &entry.key.node {
1461                        let val_type = self
1462                            .infer_type(&entry.value, scope)
1463                            .unwrap_or(TypeExpr::Named("nil".into()));
1464                        fields.push(ShapeField {
1465                            name: key.clone(),
1466                            type_expr: val_type,
1467                            optional: false,
1468                        });
1469                    } else {
1470                        all_string_keys = false;
1471                        break;
1472                    }
1473                }
1474                if all_string_keys && !fields.is_empty() {
1475                    Some(TypeExpr::Shape(fields))
1476                } else {
1477                    Some(TypeExpr::Named("dict".into()))
1478                }
1479            }
1480            Node::Closure { params, body, .. } => {
1481                // If all params are typed and we can infer a return type, produce FnType
1482                let all_typed = params.iter().all(|p| p.type_expr.is_some());
1483                if all_typed && !params.is_empty() {
1484                    let param_types: Vec<TypeExpr> =
1485                        params.iter().filter_map(|p| p.type_expr.clone()).collect();
1486                    // Try to infer return type from last expression in body
1487                    let ret = body.last().and_then(|last| self.infer_type(last, scope));
1488                    if let Some(ret_type) = ret {
1489                        return Some(TypeExpr::FnType {
1490                            params: param_types,
1491                            return_type: Box::new(ret_type),
1492                        });
1493                    }
1494                }
1495                Some(TypeExpr::Named("closure".into()))
1496            }
1497
1498            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
1499
1500            Node::FunctionCall { name, .. } => {
1501                // Struct constructor calls return the struct type
1502                if scope.get_struct(name).is_some() {
1503                    return Some(TypeExpr::Named(name.clone()));
1504                }
1505                // Check user-defined function return types
1506                if let Some(sig) = scope.get_fn(name) {
1507                    return sig.return_type.clone();
1508                }
1509                // Check builtin return types
1510                builtin_return_type(name)
1511            }
1512
1513            Node::BinaryOp { op, left, right } => {
1514                let lt = self.infer_type(left, scope);
1515                let rt = self.infer_type(right, scope);
1516                infer_binary_op_type(op, &lt, &rt)
1517            }
1518
1519            Node::UnaryOp { op, operand } => {
1520                let t = self.infer_type(operand, scope);
1521                match op.as_str() {
1522                    "!" => Some(TypeExpr::Named("bool".into())),
1523                    "-" => t, // negation preserves type
1524                    _ => None,
1525                }
1526            }
1527
1528            Node::Ternary {
1529                true_expr,
1530                false_expr,
1531                ..
1532            } => {
1533                let tt = self.infer_type(true_expr, scope);
1534                let ft = self.infer_type(false_expr, scope);
1535                match (&tt, &ft) {
1536                    (Some(a), Some(b)) if a == b => tt,
1537                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
1538                    (Some(_), None) => tt,
1539                    (None, Some(_)) => ft,
1540                    (None, None) => None,
1541                }
1542            }
1543
1544            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1545
1546            Node::PropertyAccess { object, property } => {
1547                // EnumName.Variant → infer as the enum type
1548                if let Node::Identifier(name) = &object.node {
1549                    if scope.get_enum(name).is_some() {
1550                        return Some(TypeExpr::Named(name.clone()));
1551                    }
1552                }
1553                // .variant on an enum value → string
1554                if property == "variant" {
1555                    let obj_type = self.infer_type(object, scope);
1556                    if let Some(TypeExpr::Named(name)) = &obj_type {
1557                        if scope.get_enum(name).is_some() {
1558                            return Some(TypeExpr::Named("string".into()));
1559                        }
1560                    }
1561                }
1562                // Shape field access: obj.field → field type
1563                let obj_type = self.infer_type(object, scope);
1564                if let Some(TypeExpr::Shape(fields)) = &obj_type {
1565                    if let Some(field) = fields.iter().find(|f| f.name == *property) {
1566                        return Some(field.type_expr.clone());
1567                    }
1568                }
1569                None
1570            }
1571
1572            Node::SubscriptAccess { object, index } => {
1573                let obj_type = self.infer_type(object, scope);
1574                match &obj_type {
1575                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1576                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1577                    Some(TypeExpr::Shape(fields)) => {
1578                        // If index is a string literal, look up the field type
1579                        if let Node::StringLiteral(key) = &index.node {
1580                            fields
1581                                .iter()
1582                                .find(|f| &f.name == key)
1583                                .map(|f| f.type_expr.clone())
1584                        } else {
1585                            None
1586                        }
1587                    }
1588                    Some(TypeExpr::Named(n)) if n == "list" => None,
1589                    Some(TypeExpr::Named(n)) if n == "dict" => None,
1590                    Some(TypeExpr::Named(n)) if n == "string" => {
1591                        Some(TypeExpr::Named("string".into()))
1592                    }
1593                    _ => None,
1594                }
1595            }
1596            Node::SliceAccess { object, .. } => {
1597                // Slicing a list returns the same list type; slicing a string returns string
1598                let obj_type = self.infer_type(object, scope);
1599                match &obj_type {
1600                    Some(TypeExpr::List(_)) => obj_type,
1601                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1602                    Some(TypeExpr::Named(n)) if n == "string" => {
1603                        Some(TypeExpr::Named("string".into()))
1604                    }
1605                    _ => None,
1606                }
1607            }
1608            Node::MethodCall { object, method, .. }
1609            | Node::OptionalMethodCall { object, method, .. } => {
1610                let obj_type = self.infer_type(object, scope);
1611                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1612                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1613                match method.as_str() {
1614                    // Shared: bool-returning methods
1615                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1616                        Some(TypeExpr::Named("bool".into()))
1617                    }
1618                    // Shared: int-returning methods
1619                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1620                    // String methods
1621                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1622                    | "pad_left" | "pad_right" | "repeat" | "join" => {
1623                        Some(TypeExpr::Named("string".into()))
1624                    }
1625                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
1626                    // filter returns dict for dicts, list for lists
1627                    "filter" => {
1628                        if is_dict {
1629                            Some(TypeExpr::Named("dict".into()))
1630                        } else {
1631                            Some(TypeExpr::Named("list".into()))
1632                        }
1633                    }
1634                    // List methods
1635                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1636                    "reduce" | "find" | "first" | "last" => None,
1637                    // Dict methods
1638                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1639                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1640                    // Conversions
1641                    "to_string" => Some(TypeExpr::Named("string".into())),
1642                    "to_int" => Some(TypeExpr::Named("int".into())),
1643                    "to_float" => Some(TypeExpr::Named("float".into())),
1644                    _ => None,
1645                }
1646            }
1647
1648            // TryOperator on Result<T, E> produces T
1649            Node::TryOperator { operand } => {
1650                match self.infer_type(operand, scope) {
1651                    Some(TypeExpr::Named(name)) if name == "Result" => None, // unknown inner type
1652                    _ => None,
1653                }
1654            }
1655
1656            _ => None,
1657        }
1658    }
1659
1660    /// Check if two types are compatible (actual can be assigned to expected).
1661    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1662        // Generic type parameters match anything.
1663        if let TypeExpr::Named(name) = expected {
1664            if scope.is_generic_type_param(name) {
1665                return true;
1666            }
1667        }
1668        if let TypeExpr::Named(name) = actual {
1669            if scope.is_generic_type_param(name) {
1670                return true;
1671            }
1672        }
1673        let expected = self.resolve_alias(expected, scope);
1674        let actual = self.resolve_alias(actual, scope);
1675
1676        // Interface satisfaction: if expected is an interface name, check if actual type
1677        // has all required methods (Go-style implicit satisfaction).
1678        if let TypeExpr::Named(iface_name) = &expected {
1679            if scope.get_interface(iface_name).is_some() {
1680                if let TypeExpr::Named(type_name) = &actual {
1681                    return self.satisfies_interface(type_name, iface_name, scope);
1682                }
1683                return false;
1684            }
1685        }
1686
1687        match (&expected, &actual) {
1688            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1689            (TypeExpr::Union(members), actual_type) => members
1690                .iter()
1691                .any(|m| self.types_compatible(m, actual_type, scope)),
1692            (expected_type, TypeExpr::Union(members)) => members
1693                .iter()
1694                .all(|m| self.types_compatible(expected_type, m, scope)),
1695            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1696            (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1697            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1698                if expected_field.optional {
1699                    return true;
1700                }
1701                af.iter().any(|actual_field| {
1702                    actual_field.name == expected_field.name
1703                        && self.types_compatible(
1704                            &expected_field.type_expr,
1705                            &actual_field.type_expr,
1706                            scope,
1707                        )
1708                })
1709            }),
1710            // dict<K, V> expected, Shape actual → all field values must match V
1711            (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1712                let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1713                keys_ok
1714                    && af
1715                        .iter()
1716                        .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1717            }
1718            // Shape expected, dict<K, V> actual → gradual: allow since dict may have the fields
1719            (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1720            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1721                self.types_compatible(expected_inner, actual_inner, scope)
1722            }
1723            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1724            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1725            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1726                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1727            }
1728            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1729            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1730            // FnType compatibility: params match positionally and return types match
1731            (
1732                TypeExpr::FnType {
1733                    params: ep,
1734                    return_type: er,
1735                },
1736                TypeExpr::FnType {
1737                    params: ap,
1738                    return_type: ar,
1739                },
1740            ) => {
1741                ep.len() == ap.len()
1742                    && ep
1743                        .iter()
1744                        .zip(ap.iter())
1745                        .all(|(e, a)| self.types_compatible(e, a, scope))
1746                    && self.types_compatible(er, ar, scope)
1747            }
1748            // FnType is compatible with Named("closure") for backward compat
1749            (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1750            (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1751            _ => false,
1752        }
1753    }
1754
1755    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1756        if let TypeExpr::Named(name) = ty {
1757            if let Some(resolved) = scope.resolve_type(name) {
1758                return resolved.clone();
1759            }
1760        }
1761        ty.clone()
1762    }
1763
1764    fn error_at(&mut self, message: String, span: Span) {
1765        self.diagnostics.push(TypeDiagnostic {
1766            message,
1767            severity: DiagnosticSeverity::Error,
1768            span: Some(span),
1769            help: None,
1770        });
1771    }
1772
1773    #[allow(dead_code)]
1774    fn error_at_with_help(&mut self, message: String, span: Span, help: String) {
1775        self.diagnostics.push(TypeDiagnostic {
1776            message,
1777            severity: DiagnosticSeverity::Error,
1778            span: Some(span),
1779            help: Some(help),
1780        });
1781    }
1782
1783    fn warning_at(&mut self, message: String, span: Span) {
1784        self.diagnostics.push(TypeDiagnostic {
1785            message,
1786            severity: DiagnosticSeverity::Warning,
1787            span: Some(span),
1788            help: None,
1789        });
1790    }
1791
1792    #[allow(dead_code)]
1793    fn warning_at_with_help(&mut self, message: String, span: Span, help: String) {
1794        self.diagnostics.push(TypeDiagnostic {
1795            message,
1796            severity: DiagnosticSeverity::Warning,
1797            span: Some(span),
1798            help: Some(help),
1799        });
1800    }
1801}
1802
1803impl Default for TypeChecker {
1804    fn default() -> Self {
1805        Self::new()
1806    }
1807}
1808
1809/// Infer the result type of a binary operation.
1810fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1811    match op {
1812        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" | "in" | "not_in" => {
1813            Some(TypeExpr::Named("bool".into()))
1814        }
1815        "+" => match (left, right) {
1816            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1817                match (l.as_str(), r.as_str()) {
1818                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1819                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1820                    ("string", _) => Some(TypeExpr::Named("string".into())),
1821                    ("list", "list") => Some(TypeExpr::Named("list".into())),
1822                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1823                    _ => Some(TypeExpr::Named("string".into())),
1824                }
1825            }
1826            _ => None,
1827        },
1828        "-" | "*" | "/" | "%" => match (left, right) {
1829            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1830                match (l.as_str(), r.as_str()) {
1831                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1832                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1833                    _ => None,
1834                }
1835            }
1836            _ => None,
1837        },
1838        "??" => match (left, right) {
1839            (Some(TypeExpr::Union(members)), _) => {
1840                let non_nil: Vec<_> = members
1841                    .iter()
1842                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1843                    .cloned()
1844                    .collect();
1845                if non_nil.len() == 1 {
1846                    Some(non_nil[0].clone())
1847                } else if non_nil.is_empty() {
1848                    right.clone()
1849                } else {
1850                    Some(TypeExpr::Union(non_nil))
1851                }
1852            }
1853            _ => right.clone(),
1854        },
1855        "|>" => None,
1856        _ => None,
1857    }
1858}
1859
1860/// Format a type expression for display in error messages.
1861/// Produce a detail string describing why a Shape type is incompatible with
1862/// another Shape type — e.g. "missing field 'age' (int)" or "field 'name'
1863/// has type int, expected string".  Returns `None` if both types are not shapes.
1864pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1865    if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1866        let mut details = Vec::new();
1867        for field in ef {
1868            if field.optional {
1869                continue;
1870            }
1871            match af.iter().find(|f| f.name == field.name) {
1872                None => details.push(format!(
1873                    "missing field '{}' ({})",
1874                    field.name,
1875                    format_type(&field.type_expr)
1876                )),
1877                Some(actual_field) => {
1878                    let e_str = format_type(&field.type_expr);
1879                    let a_str = format_type(&actual_field.type_expr);
1880                    if e_str != a_str {
1881                        details.push(format!(
1882                            "field '{}' has type {}, expected {}",
1883                            field.name, a_str, e_str
1884                        ));
1885                    }
1886                }
1887            }
1888        }
1889        if details.is_empty() {
1890            None
1891        } else {
1892            Some(details.join("; "))
1893        }
1894    } else {
1895        None
1896    }
1897}
1898
1899pub fn format_type(ty: &TypeExpr) -> String {
1900    match ty {
1901        TypeExpr::Named(n) => n.clone(),
1902        TypeExpr::Union(types) => types
1903            .iter()
1904            .map(format_type)
1905            .collect::<Vec<_>>()
1906            .join(" | "),
1907        TypeExpr::Shape(fields) => {
1908            let inner: Vec<String> = fields
1909                .iter()
1910                .map(|f| {
1911                    let opt = if f.optional { "?" } else { "" };
1912                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1913                })
1914                .collect();
1915            format!("{{{}}}", inner.join(", "))
1916        }
1917        TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1918        TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1919        TypeExpr::FnType {
1920            params,
1921            return_type,
1922        } => {
1923            let params_str = params
1924                .iter()
1925                .map(format_type)
1926                .collect::<Vec<_>>()
1927                .join(", ");
1928            format!("fn({}) -> {}", params_str, format_type(return_type))
1929        }
1930    }
1931}
1932
1933#[cfg(test)]
1934mod tests {
1935    use super::*;
1936    use crate::Parser;
1937    use harn_lexer::Lexer;
1938
1939    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
1940        let mut lexer = Lexer::new(source);
1941        let tokens = lexer.tokenize().unwrap();
1942        let mut parser = Parser::new(tokens);
1943        let program = parser.parse().unwrap();
1944        TypeChecker::new().check(&program)
1945    }
1946
1947    fn errors(source: &str) -> Vec<String> {
1948        check_source(source)
1949            .into_iter()
1950            .filter(|d| d.severity == DiagnosticSeverity::Error)
1951            .map(|d| d.message)
1952            .collect()
1953    }
1954
1955    #[test]
1956    fn test_no_errors_for_untyped_code() {
1957        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
1958        assert!(errs.is_empty());
1959    }
1960
1961    #[test]
1962    fn test_correct_typed_let() {
1963        let errs = errors("pipeline t(task) { let x: int = 42 }");
1964        assert!(errs.is_empty());
1965    }
1966
1967    #[test]
1968    fn test_type_mismatch_let() {
1969        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
1970        assert_eq!(errs.len(), 1);
1971        assert!(errs[0].contains("Type mismatch"));
1972        assert!(errs[0].contains("int"));
1973        assert!(errs[0].contains("string"));
1974    }
1975
1976    #[test]
1977    fn test_correct_typed_fn() {
1978        let errs = errors(
1979            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
1980        );
1981        assert!(errs.is_empty());
1982    }
1983
1984    #[test]
1985    fn test_fn_arg_type_mismatch() {
1986        let errs = errors(
1987            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
1988add("hello", 2) }"#,
1989        );
1990        assert_eq!(errs.len(), 1);
1991        assert!(errs[0].contains("Argument 1"));
1992        assert!(errs[0].contains("expected int"));
1993    }
1994
1995    #[test]
1996    fn test_return_type_mismatch() {
1997        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
1998        assert_eq!(errs.len(), 1);
1999        assert!(errs[0].contains("Return type mismatch"));
2000    }
2001
2002    #[test]
2003    fn test_union_type_compatible() {
2004        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
2005        assert!(errs.is_empty());
2006    }
2007
2008    #[test]
2009    fn test_union_type_mismatch() {
2010        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
2011        assert_eq!(errs.len(), 1);
2012        assert!(errs[0].contains("Type mismatch"));
2013    }
2014
2015    #[test]
2016    fn test_type_inference_propagation() {
2017        let errs = errors(
2018            r#"pipeline t(task) {
2019  fn add(a: int, b: int) -> int { return a + b }
2020  let result: string = add(1, 2)
2021}"#,
2022        );
2023        assert_eq!(errs.len(), 1);
2024        assert!(errs[0].contains("Type mismatch"));
2025        assert!(errs[0].contains("string"));
2026        assert!(errs[0].contains("int"));
2027    }
2028
2029    #[test]
2030    fn test_builtin_return_type_inference() {
2031        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
2032        assert_eq!(errs.len(), 1);
2033        assert!(errs[0].contains("string"));
2034        assert!(errs[0].contains("int"));
2035    }
2036
2037    #[test]
2038    fn test_workflow_and_transcript_builtins_are_known() {
2039        let errs = errors(
2040            r#"pipeline t(task) {
2041  let flow = workflow_graph({name: "demo", entry: "act", nodes: {act: {kind: "stage"}}})
2042  let report: dict = workflow_policy_report(flow, {tools: ["read"], capabilities: {workspace: ["read_text"]}})
2043  let run: dict = workflow_execute("task", flow, [], {})
2044  let tree: dict = load_run_tree("run.json")
2045  let fixture: dict = run_record_fixture(run?.run)
2046  let suite: dict = run_record_eval_suite([{run: run?.run, fixture: fixture}])
2047  let diff: dict = run_record_diff(run?.run, run?.run)
2048  let manifest: dict = eval_suite_manifest({cases: [{run_path: "run.json"}]})
2049  let suite_report: dict = eval_suite_run(manifest)
2050  let wf: dict = artifact_workspace_file("src/main.rs", "fn main() {}", {source: "host"})
2051  let snap: dict = artifact_workspace_snapshot(["src/main.rs"], "snapshot")
2052  let selection: dict = artifact_editor_selection("src/main.rs", "main")
2053  let verify: dict = artifact_verification_result("verify", "ok")
2054  let test_result: dict = artifact_test_result("tests", "pass")
2055  let cmd: dict = artifact_command_result("cargo test", {status: 0})
2056  let patch: dict = artifact_diff("src/main.rs", "old", "new")
2057  let git: dict = artifact_git_diff("diff --git a b")
2058  let review: dict = artifact_diff_review(patch, "review me")
2059  let decision: dict = artifact_review_decision(review, "accepted")
2060  let proposal: dict = artifact_patch_proposal(review, "*** Begin Patch")
2061  let bundle: dict = artifact_verification_bundle("checks", [{name: "fmt", ok: true}])
2062  let apply: dict = artifact_apply_intent(review, "apply")
2063  let transcript = transcript_reset({metadata: {source: "test"}})
2064  let visible: string = transcript_render_visible(transcript_archive(transcript))
2065  let events: list = transcript_events(transcript)
2066  let context: string = artifact_context([], {max_artifacts: 1})
2067  println(report)
2068  println(run)
2069  println(tree)
2070  println(fixture)
2071  println(suite)
2072  println(diff)
2073  println(manifest)
2074  println(suite_report)
2075  println(wf)
2076  println(snap)
2077  println(selection)
2078  println(verify)
2079  println(test_result)
2080  println(cmd)
2081  println(patch)
2082  println(git)
2083  println(review)
2084  println(decision)
2085  println(proposal)
2086  println(bundle)
2087  println(apply)
2088  println(visible)
2089  println(events)
2090  println(context)
2091}"#,
2092        );
2093        assert!(errs.is_empty(), "unexpected type errors: {errs:?}");
2094    }
2095
2096    #[test]
2097    fn test_binary_op_type_inference() {
2098        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
2099        assert_eq!(errs.len(), 1);
2100    }
2101
2102    #[test]
2103    fn test_comparison_returns_bool() {
2104        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
2105        assert!(errs.is_empty());
2106    }
2107
2108    #[test]
2109    fn test_int_float_promotion() {
2110        let errs = errors("pipeline t(task) { let x: float = 42 }");
2111        assert!(errs.is_empty());
2112    }
2113
2114    #[test]
2115    fn test_untyped_code_no_errors() {
2116        let errs = errors(
2117            r#"pipeline t(task) {
2118  fn process(data) {
2119    let result = data + " processed"
2120    return result
2121  }
2122  log(process("hello"))
2123}"#,
2124        );
2125        assert!(errs.is_empty());
2126    }
2127
2128    #[test]
2129    fn test_type_alias() {
2130        let errs = errors(
2131            r#"pipeline t(task) {
2132  type Name = string
2133  let x: Name = "hello"
2134}"#,
2135        );
2136        assert!(errs.is_empty());
2137    }
2138
2139    #[test]
2140    fn test_type_alias_mismatch() {
2141        let errs = errors(
2142            r#"pipeline t(task) {
2143  type Name = string
2144  let x: Name = 42
2145}"#,
2146        );
2147        assert_eq!(errs.len(), 1);
2148    }
2149
2150    #[test]
2151    fn test_assignment_type_check() {
2152        let errs = errors(
2153            r#"pipeline t(task) {
2154  var x: int = 0
2155  x = "hello"
2156}"#,
2157        );
2158        assert_eq!(errs.len(), 1);
2159        assert!(errs[0].contains("cannot assign string"));
2160    }
2161
2162    #[test]
2163    fn test_covariance_int_to_float_in_fn() {
2164        let errs = errors(
2165            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
2166        );
2167        assert!(errs.is_empty());
2168    }
2169
2170    #[test]
2171    fn test_covariance_return_type() {
2172        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
2173        assert!(errs.is_empty());
2174    }
2175
2176    #[test]
2177    fn test_no_contravariance_float_to_int() {
2178        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
2179        assert_eq!(errs.len(), 1);
2180    }
2181
2182    // --- Exhaustiveness checking tests ---
2183
2184    fn warnings(source: &str) -> Vec<String> {
2185        check_source(source)
2186            .into_iter()
2187            .filter(|d| d.severity == DiagnosticSeverity::Warning)
2188            .map(|d| d.message)
2189            .collect()
2190    }
2191
2192    #[test]
2193    fn test_exhaustive_match_no_warning() {
2194        let warns = warnings(
2195            r#"pipeline t(task) {
2196  enum Color { Red, Green, Blue }
2197  let c = Color.Red
2198  match c.variant {
2199    "Red" -> { log("r") }
2200    "Green" -> { log("g") }
2201    "Blue" -> { log("b") }
2202  }
2203}"#,
2204        );
2205        let exhaustive_warns: Vec<_> = warns
2206            .iter()
2207            .filter(|w| w.contains("Non-exhaustive"))
2208            .collect();
2209        assert!(exhaustive_warns.is_empty());
2210    }
2211
2212    #[test]
2213    fn test_non_exhaustive_match_warning() {
2214        let warns = warnings(
2215            r#"pipeline t(task) {
2216  enum Color { Red, Green, Blue }
2217  let c = Color.Red
2218  match c.variant {
2219    "Red" -> { log("r") }
2220    "Green" -> { log("g") }
2221  }
2222}"#,
2223        );
2224        let exhaustive_warns: Vec<_> = warns
2225            .iter()
2226            .filter(|w| w.contains("Non-exhaustive"))
2227            .collect();
2228        assert_eq!(exhaustive_warns.len(), 1);
2229        assert!(exhaustive_warns[0].contains("Blue"));
2230    }
2231
2232    #[test]
2233    fn test_non_exhaustive_multiple_missing() {
2234        let warns = warnings(
2235            r#"pipeline t(task) {
2236  enum Status { Active, Inactive, Pending }
2237  let s = Status.Active
2238  match s.variant {
2239    "Active" -> { log("a") }
2240  }
2241}"#,
2242        );
2243        let exhaustive_warns: Vec<_> = warns
2244            .iter()
2245            .filter(|w| w.contains("Non-exhaustive"))
2246            .collect();
2247        assert_eq!(exhaustive_warns.len(), 1);
2248        assert!(exhaustive_warns[0].contains("Inactive"));
2249        assert!(exhaustive_warns[0].contains("Pending"));
2250    }
2251
2252    #[test]
2253    fn test_enum_construct_type_inference() {
2254        let errs = errors(
2255            r#"pipeline t(task) {
2256  enum Color { Red, Green, Blue }
2257  let c: Color = Color.Red
2258}"#,
2259        );
2260        assert!(errs.is_empty());
2261    }
2262
2263    // --- Type narrowing tests ---
2264
2265    #[test]
2266    fn test_nil_coalescing_strips_nil() {
2267        // After ??, nil should be stripped from the type
2268        let errs = errors(
2269            r#"pipeline t(task) {
2270  let x: string | nil = nil
2271  let y: string = x ?? "default"
2272}"#,
2273        );
2274        assert!(errs.is_empty());
2275    }
2276
2277    #[test]
2278    fn test_shape_mismatch_detail_missing_field() {
2279        let errs = errors(
2280            r#"pipeline t(task) {
2281  let x: {name: string, age: int} = {name: "hello"}
2282}"#,
2283        );
2284        assert_eq!(errs.len(), 1);
2285        assert!(
2286            errs[0].contains("missing field 'age'"),
2287            "expected detail about missing field, got: {}",
2288            errs[0]
2289        );
2290    }
2291
2292    #[test]
2293    fn test_shape_mismatch_detail_wrong_type() {
2294        let errs = errors(
2295            r#"pipeline t(task) {
2296  let x: {name: string, age: int} = {name: 42, age: 10}
2297}"#,
2298        );
2299        assert_eq!(errs.len(), 1);
2300        assert!(
2301            errs[0].contains("field 'name' has type int, expected string"),
2302            "expected detail about wrong type, got: {}",
2303            errs[0]
2304        );
2305    }
2306
2307    // --- Match pattern type validation tests ---
2308
2309    #[test]
2310    fn test_match_pattern_string_against_int() {
2311        let warns = warnings(
2312            r#"pipeline t(task) {
2313  let x: int = 42
2314  match x {
2315    "hello" -> { log("bad") }
2316    42 -> { log("ok") }
2317  }
2318}"#,
2319        );
2320        let pattern_warns: Vec<_> = warns
2321            .iter()
2322            .filter(|w| w.contains("Match pattern type mismatch"))
2323            .collect();
2324        assert_eq!(pattern_warns.len(), 1);
2325        assert!(pattern_warns[0].contains("matching int against string literal"));
2326    }
2327
2328    #[test]
2329    fn test_match_pattern_int_against_string() {
2330        let warns = warnings(
2331            r#"pipeline t(task) {
2332  let x: string = "hello"
2333  match x {
2334    42 -> { log("bad") }
2335    "hello" -> { log("ok") }
2336  }
2337}"#,
2338        );
2339        let pattern_warns: Vec<_> = warns
2340            .iter()
2341            .filter(|w| w.contains("Match pattern type mismatch"))
2342            .collect();
2343        assert_eq!(pattern_warns.len(), 1);
2344        assert!(pattern_warns[0].contains("matching string against int literal"));
2345    }
2346
2347    #[test]
2348    fn test_match_pattern_bool_against_int() {
2349        let warns = warnings(
2350            r#"pipeline t(task) {
2351  let x: int = 42
2352  match x {
2353    true -> { log("bad") }
2354    42 -> { log("ok") }
2355  }
2356}"#,
2357        );
2358        let pattern_warns: Vec<_> = warns
2359            .iter()
2360            .filter(|w| w.contains("Match pattern type mismatch"))
2361            .collect();
2362        assert_eq!(pattern_warns.len(), 1);
2363        assert!(pattern_warns[0].contains("matching int against bool literal"));
2364    }
2365
2366    #[test]
2367    fn test_match_pattern_float_against_string() {
2368        let warns = warnings(
2369            r#"pipeline t(task) {
2370  let x: string = "hello"
2371  match x {
2372    3.14 -> { log("bad") }
2373    "hello" -> { log("ok") }
2374  }
2375}"#,
2376        );
2377        let pattern_warns: Vec<_> = warns
2378            .iter()
2379            .filter(|w| w.contains("Match pattern type mismatch"))
2380            .collect();
2381        assert_eq!(pattern_warns.len(), 1);
2382        assert!(pattern_warns[0].contains("matching string against float literal"));
2383    }
2384
2385    #[test]
2386    fn test_match_pattern_int_against_float_ok() {
2387        // int and float are compatible for match patterns
2388        let warns = warnings(
2389            r#"pipeline t(task) {
2390  let x: float = 3.14
2391  match x {
2392    42 -> { log("ok") }
2393    _ -> { log("default") }
2394  }
2395}"#,
2396        );
2397        let pattern_warns: Vec<_> = warns
2398            .iter()
2399            .filter(|w| w.contains("Match pattern type mismatch"))
2400            .collect();
2401        assert!(pattern_warns.is_empty());
2402    }
2403
2404    #[test]
2405    fn test_match_pattern_float_against_int_ok() {
2406        // float and int are compatible for match patterns
2407        let warns = warnings(
2408            r#"pipeline t(task) {
2409  let x: int = 42
2410  match x {
2411    3.14 -> { log("close") }
2412    _ -> { log("default") }
2413  }
2414}"#,
2415        );
2416        let pattern_warns: Vec<_> = warns
2417            .iter()
2418            .filter(|w| w.contains("Match pattern type mismatch"))
2419            .collect();
2420        assert!(pattern_warns.is_empty());
2421    }
2422
2423    #[test]
2424    fn test_match_pattern_correct_types_no_warning() {
2425        let warns = warnings(
2426            r#"pipeline t(task) {
2427  let x: int = 42
2428  match x {
2429    1 -> { log("one") }
2430    2 -> { log("two") }
2431    _ -> { log("other") }
2432  }
2433}"#,
2434        );
2435        let pattern_warns: Vec<_> = warns
2436            .iter()
2437            .filter(|w| w.contains("Match pattern type mismatch"))
2438            .collect();
2439        assert!(pattern_warns.is_empty());
2440    }
2441
2442    #[test]
2443    fn test_match_pattern_wildcard_no_warning() {
2444        let warns = warnings(
2445            r#"pipeline t(task) {
2446  let x: int = 42
2447  match x {
2448    _ -> { log("catch all") }
2449  }
2450}"#,
2451        );
2452        let pattern_warns: Vec<_> = warns
2453            .iter()
2454            .filter(|w| w.contains("Match pattern type mismatch"))
2455            .collect();
2456        assert!(pattern_warns.is_empty());
2457    }
2458
2459    #[test]
2460    fn test_match_pattern_untyped_no_warning() {
2461        // When value has no known type, no warning should be emitted
2462        let warns = warnings(
2463            r#"pipeline t(task) {
2464  let x = some_unknown_fn()
2465  match x {
2466    "hello" -> { log("string") }
2467    42 -> { log("int") }
2468  }
2469}"#,
2470        );
2471        let pattern_warns: Vec<_> = warns
2472            .iter()
2473            .filter(|w| w.contains("Match pattern type mismatch"))
2474            .collect();
2475        assert!(pattern_warns.is_empty());
2476    }
2477
2478    // --- Interface constraint type checking tests ---
2479
2480    fn iface_warns(source: &str) -> Vec<String> {
2481        warnings(source)
2482            .into_iter()
2483            .filter(|w| w.contains("does not satisfy interface"))
2484            .collect()
2485    }
2486
2487    #[test]
2488    fn test_interface_constraint_return_type_mismatch() {
2489        let warns = iface_warns(
2490            r#"pipeline t(task) {
2491  interface Sizable {
2492    fn size(self) -> int
2493  }
2494  struct Box { width: int }
2495  impl Box {
2496    fn size(self) -> string { return "nope" }
2497  }
2498  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2499  measure(Box({width: 3}))
2500}"#,
2501        );
2502        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2503        assert!(
2504            warns[0].contains("method 'size' returns 'string', expected 'int'"),
2505            "unexpected message: {}",
2506            warns[0]
2507        );
2508    }
2509
2510    #[test]
2511    fn test_interface_constraint_param_type_mismatch() {
2512        let warns = iface_warns(
2513            r#"pipeline t(task) {
2514  interface Processor {
2515    fn process(self, x: int) -> string
2516  }
2517  struct MyProc { name: string }
2518  impl MyProc {
2519    fn process(self, x: string) -> string { return x }
2520  }
2521  fn run_proc<T>(p: T) where T: Processor { log(p.process(42)) }
2522  run_proc(MyProc({name: "a"}))
2523}"#,
2524        );
2525        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2526        assert!(
2527            warns[0].contains("method 'process' parameter 1 has type 'string', expected 'int'"),
2528            "unexpected message: {}",
2529            warns[0]
2530        );
2531    }
2532
2533    #[test]
2534    fn test_interface_constraint_missing_method() {
2535        let warns = iface_warns(
2536            r#"pipeline t(task) {
2537  interface Sizable {
2538    fn size(self) -> int
2539  }
2540  struct Box { width: int }
2541  impl Box {
2542    fn area(self) -> int { return self.width }
2543  }
2544  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2545  measure(Box({width: 3}))
2546}"#,
2547        );
2548        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2549        assert!(
2550            warns[0].contains("missing method 'size'"),
2551            "unexpected message: {}",
2552            warns[0]
2553        );
2554    }
2555
2556    #[test]
2557    fn test_interface_constraint_param_count_mismatch() {
2558        let warns = iface_warns(
2559            r#"pipeline t(task) {
2560  interface Doubler {
2561    fn double(self, x: int) -> int
2562  }
2563  struct Bad { v: int }
2564  impl Bad {
2565    fn double(self) -> int { return self.v * 2 }
2566  }
2567  fn run_double<T>(d: T) where T: Doubler { log(d.double(3)) }
2568  run_double(Bad({v: 5}))
2569}"#,
2570        );
2571        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2572        assert!(
2573            warns[0].contains("method 'double' has 0 parameter(s), expected 1"),
2574            "unexpected message: {}",
2575            warns[0]
2576        );
2577    }
2578
2579    #[test]
2580    fn test_interface_constraint_satisfied() {
2581        let warns = iface_warns(
2582            r#"pipeline t(task) {
2583  interface Sizable {
2584    fn size(self) -> int
2585  }
2586  struct Box { width: int, height: int }
2587  impl Box {
2588    fn size(self) -> int { return self.width * self.height }
2589  }
2590  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2591  measure(Box({width: 3, height: 4}))
2592}"#,
2593        );
2594        assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
2595    }
2596
2597    #[test]
2598    fn test_interface_constraint_untyped_impl_compatible() {
2599        // Gradual typing: untyped impl return should not trigger warning
2600        let warns = iface_warns(
2601            r#"pipeline t(task) {
2602  interface Sizable {
2603    fn size(self) -> int
2604  }
2605  struct Box { width: int }
2606  impl Box {
2607    fn size(self) { return self.width }
2608  }
2609  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2610  measure(Box({width: 3}))
2611}"#,
2612        );
2613        assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
2614    }
2615
2616    #[test]
2617    fn test_interface_constraint_int_float_covariance() {
2618        // int is compatible with float (covariance)
2619        let warns = iface_warns(
2620            r#"pipeline t(task) {
2621  interface Measurable {
2622    fn value(self) -> float
2623  }
2624  struct Gauge { v: int }
2625  impl Gauge {
2626    fn value(self) -> int { return self.v }
2627  }
2628  fn read_val<T>(g: T) where T: Measurable { log(g.value()) }
2629  read_val(Gauge({v: 42}))
2630}"#,
2631        );
2632        assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
2633    }
2634}