Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
6/// A diagnostic produced by the type checker.
7#[derive(Debug, Clone)]
8pub struct TypeDiagnostic {
9    pub message: String,
10    pub severity: DiagnosticSeverity,
11    pub span: Option<Span>,
12    pub help: Option<String>,
13}
14
/// How serious a type-checker diagnostic is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    /// A type error.
    Error,
    /// A lower-severity finding.
    Warning,
}
20
21/// Inferred type of an expression. None means unknown/untyped (gradual typing).
22type InferredType = Option<TypeExpr>;
23
24/// Scope for tracking variable types.
25#[derive(Debug, Clone)]
26struct TypeScope {
27    /// Variable name → inferred type.
28    vars: BTreeMap<String, InferredType>,
29    /// Function name → (param types, return type).
30    functions: BTreeMap<String, FnSignature>,
31    /// Named type aliases.
32    type_aliases: BTreeMap<String, TypeExpr>,
33    /// Enum declarations: name → variant names.
34    enums: BTreeMap<String, Vec<String>>,
35    /// Interface declarations: name → method signatures.
36    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
37    /// Struct declarations: name → field types.
38    structs: BTreeMap<String, Vec<(String, InferredType)>>,
39    /// Impl block methods: type_name → method signatures.
40    impl_methods: BTreeMap<String, Vec<ImplMethodSig>>,
41    /// Generic type parameter names in scope (treated as compatible with any type).
42    generic_type_params: std::collections::BTreeSet<String>,
43    /// Where-clause constraints: type_param → interface_bound.
44    /// Used for definition-site checking of generic function bodies.
45    where_constraints: BTreeMap<String, String>,
46    parent: Option<Box<TypeScope>>,
47}
48
49/// Method signature extracted from an impl block (for interface checking).
50#[derive(Debug, Clone)]
51struct ImplMethodSig {
52    name: String,
53    /// Number of parameters excluding `self`.
54    param_count: usize,
55    /// Parameter types (excluding `self`), None means untyped.
56    param_types: Vec<Option<TypeExpr>>,
57    /// Return type, None means untyped.
58    return_type: Option<TypeExpr>,
59}
60
61#[derive(Debug, Clone)]
62struct FnSignature {
63    params: Vec<(String, InferredType)>,
64    return_type: InferredType,
65    /// Generic type parameter names declared on the function.
66    type_param_names: Vec<String>,
67    /// Number of required parameters (those without defaults).
68    required_params: usize,
69    /// Where-clause constraints: (type_param_name, interface_bound).
70    where_clauses: Vec<(String, String)>,
71}
72
73impl TypeScope {
74    fn new() -> Self {
75        Self {
76            vars: BTreeMap::new(),
77            functions: BTreeMap::new(),
78            type_aliases: BTreeMap::new(),
79            enums: BTreeMap::new(),
80            interfaces: BTreeMap::new(),
81            structs: BTreeMap::new(),
82            impl_methods: BTreeMap::new(),
83            generic_type_params: std::collections::BTreeSet::new(),
84            where_constraints: BTreeMap::new(),
85            parent: None,
86        }
87    }
88
89    fn child(&self) -> Self {
90        Self {
91            vars: BTreeMap::new(),
92            functions: BTreeMap::new(),
93            type_aliases: BTreeMap::new(),
94            enums: BTreeMap::new(),
95            interfaces: BTreeMap::new(),
96            structs: BTreeMap::new(),
97            impl_methods: BTreeMap::new(),
98            generic_type_params: std::collections::BTreeSet::new(),
99            where_constraints: BTreeMap::new(),
100            parent: Some(Box::new(self.clone())),
101        }
102    }
103
104    fn get_var(&self, name: &str) -> Option<&InferredType> {
105        self.vars
106            .get(name)
107            .or_else(|| self.parent.as_ref()?.get_var(name))
108    }
109
110    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
111        self.functions
112            .get(name)
113            .or_else(|| self.parent.as_ref()?.get_fn(name))
114    }
115
116    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
117        self.type_aliases
118            .get(name)
119            .or_else(|| self.parent.as_ref()?.resolve_type(name))
120    }
121
122    fn is_generic_type_param(&self, name: &str) -> bool {
123        self.generic_type_params.contains(name)
124            || self
125                .parent
126                .as_ref()
127                .is_some_and(|p| p.is_generic_type_param(name))
128    }
129
130    fn get_where_constraint(&self, type_param: &str) -> Option<&str> {
131        self.where_constraints
132            .get(type_param)
133            .map(|s| s.as_str())
134            .or_else(|| {
135                self.parent
136                    .as_ref()
137                    .and_then(|p| p.get_where_constraint(type_param))
138            })
139    }
140
141    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
142        self.enums
143            .get(name)
144            .or_else(|| self.parent.as_ref()?.get_enum(name))
145    }
146
147    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
148        self.interfaces
149            .get(name)
150            .or_else(|| self.parent.as_ref()?.get_interface(name))
151    }
152
153    fn get_struct(&self, name: &str) -> Option<&Vec<(String, InferredType)>> {
154        self.structs
155            .get(name)
156            .or_else(|| self.parent.as_ref()?.get_struct(name))
157    }
158
159    fn get_impl_methods(&self, name: &str) -> Option<&Vec<ImplMethodSig>> {
160        self.impl_methods
161            .get(name)
162            .or_else(|| self.parent.as_ref()?.get_impl_methods(name))
163    }
164
165    fn define_var(&mut self, name: &str, ty: InferredType) {
166        self.vars.insert(name.to_string(), ty);
167    }
168
169    fn define_fn(&mut self, name: &str, sig: FnSignature) {
170        self.functions.insert(name.to_string(), sig);
171    }
172}
173
174/// Known return types for builtin functions.
175fn builtin_return_type(name: &str) -> InferredType {
176    match name {
177        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
178        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
179            Some(TypeExpr::Named("nil".into()))
180        }
181        "type_of"
182        | "to_string"
183        | "json_stringify"
184        | "read_file"
185        | "http_get"
186        | "http_post"
187        | "regex_replace"
188        | "path_join"
189        | "temp_dir"
190        | "date_format"
191        | "format"
192        | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
193        "to_int" | "timer_end" | "elapsed" | "sign" => Some(TypeExpr::Named("int".into())),
194        "to_float" | "timestamp" | "date_parse" | "sin" | "cos" | "tan" | "asin" | "acos"
195        | "atan" | "atan2" | "log2" | "log10" | "ln" | "exp" | "pi" | "e" => {
196            Some(TypeExpr::Named("float".into()))
197        }
198        "file_exists" | "json_validate" | "is_nan" | "is_infinite" | "set_contains" => {
199            Some(TypeExpr::Named("bool".into()))
200        }
201        "list_dir"
202        | "mcp_list_tools"
203        | "mcp_list_resources"
204        | "mcp_list_prompts"
205        | "to_list"
206        | "regex_captures"
207        | "artifact_select"
208        | "transcript_messages"
209        | "transcript_events" => Some(TypeExpr::Named("list".into())),
210        "stat"
211        | "exec"
212        | "exec_at"
213        | "shell"
214        | "shell_at"
215        | "date_now"
216        | "llm_call"
217        | "llm_completion"
218        | "agent_loop"
219        | "llm_info"
220        | "llm_usage"
221        | "timer_start"
222        | "metadata_get"
223        | "mcp_server_info"
224        | "mcp_get_prompt"
225        | "llm_pick_model"
226        | "transcript"
227        | "transcript_from_messages"
228        | "transcript_reset"
229        | "transcript_archive"
230        | "transcript_abandon"
231        | "transcript_resume"
232        | "workflow_graph"
233        | "workflow_validate"
234        | "workflow_inspect"
235        | "workflow_policy_report"
236        | "workflow_clone"
237        | "workflow_insert_node"
238        | "workflow_replace_node"
239        | "workflow_rewire"
240        | "workflow_set_model_policy"
241        | "workflow_set_context_policy"
242        | "workflow_set_transcript_policy"
243        | "workflow_diff"
244        | "workflow_commit"
245        | "artifact"
246        | "artifact_derive"
247        | "artifact_workspace_file"
248        | "artifact_workspace_snapshot"
249        | "artifact_editor_selection"
250        | "artifact_verification_result"
251        | "artifact_test_result"
252        | "artifact_command_result"
253        | "artifact_diff"
254        | "artifact_git_diff"
255        | "artifact_diff_review"
256        | "artifact_review_decision"
257        | "run_record"
258        | "run_record_save"
259        | "run_record_load"
260        | "run_record_fixture"
261        | "run_record_eval"
262        | "run_record_eval_suite"
263        | "run_record_diff"
264        | "eval_suite_manifest"
265        | "eval_suite_run"
266        | "workflow_execute"
267        | "resume_agent"
268        | "transcript_compact"
269        | "transcript_summarize"
270        | "host_capabilities" => Some(TypeExpr::Named("dict".into())),
271        "transcript_render_visible"
272        | "transcript_render_full"
273        | "artifact_context"
274        | "transcript_export"
275        | "transcript_id" => Some(TypeExpr::Named("string".into())),
276        "transcript_summary" => Some(TypeExpr::Union(vec![
277            TypeExpr::Named("string".into()),
278            TypeExpr::Named("nil".into()),
279        ])),
280        "host_has" => Some(TypeExpr::Named("bool".into())),
281        "metadata_set"
282        | "metadata_save"
283        | "metadata_refresh_hashes"
284        | "invalidate_facts"
285        | "log_json"
286        | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
287        "env" | "regex_match" => Some(TypeExpr::Union(vec![
288            TypeExpr::Named("string".into()),
289            TypeExpr::Named("nil".into()),
290        ])),
291        "json_parse" | "json_extract" => None, // could be any type
292        _ => None,
293    }
294}
295
/// Returns `true` if `name` is a builtin function known to the checker.
///
/// Every builtin that `builtin_return_type` assigns a result type to is also
/// listed here, so the two tables agree on what counts as a builtin.
fn is_builtin(name: &str) -> bool {
    matches!(
        name,
        "log"
            | "print"
            | "println"
            | "type_of"
            | "to_string"
            | "to_int"
            | "to_float"
            | "json_stringify"
            | "json_parse"
            | "env"
            | "timestamp"
            | "sleep"
            | "read_file"
            | "write_file"
            | "exit"
            | "regex_match"
            | "regex_replace"
            | "regex_captures"
            | "http_get"
            | "http_post"
            | "llm_call"
            | "llm_completion"
            | "agent_loop"
            | "llm_pick_model"
            | "await"
            | "cancel"
            | "file_exists"
            | "delete_file"
            | "list_dir"
            | "mkdir"
            | "path_join"
            | "copy_file"
            | "append_file"
            | "temp_dir"
            | "transcript"
            | "transcript_from_messages"
            | "transcript_messages"
            | "transcript_events"
            | "transcript_summary"
            | "transcript_id"
            | "transcript_export"
            | "transcript_import"
            | "transcript_fork"
            | "transcript_reset"
            | "transcript_archive"
            | "transcript_abandon"
            | "transcript_resume"
            | "transcript_render_visible"
            | "transcript_render_full"
            | "transcript_compact"
            | "transcript_summarize"
            | "host_capabilities"
            | "host_has"
            | "host_invoke"
            | "workflow_graph"
            | "workflow_validate"
            | "workflow_inspect"
            | "workflow_policy_report"
            | "workflow_clone"
            | "workflow_insert_node"
            | "workflow_replace_node"
            | "workflow_rewire"
            | "workflow_set_model_policy"
            | "workflow_set_context_policy"
            | "workflow_set_transcript_policy"
            | "workflow_diff"
            | "workflow_commit"
            | "workflow_execute"
            | "resume_agent"
            | "artifact"
            | "artifact_derive"
            | "artifact_workspace_file"
            | "artifact_workspace_snapshot"
            | "artifact_editor_selection"
            | "artifact_verification_result"
            | "artifact_test_result"
            | "artifact_command_result"
            | "artifact_diff"
            | "artifact_git_diff"
            | "artifact_diff_review"
            | "artifact_review_decision"
            | "artifact_select"
            | "artifact_context"
            | "run_record"
            | "run_record_save"
            | "run_record_load"
            | "run_record_fixture"
            | "run_record_eval"
            | "run_record_eval_suite"
            | "run_record_diff"
            | "eval_suite_manifest"
            | "eval_suite_run"
            | "stat"
            | "exec"
            | "exec_at"
            | "shell"
            | "shell_at"
            | "date_now"
            | "date_format"
            | "date_parse"
            | "format"
            | "json_validate"
            | "json_extract"
            | "trim"
            | "lowercase"
            | "uppercase"
            | "split"
            | "starts_with"
            | "ends_with"
            | "contains"
            | "replace"
            | "join"
            | "len"
            | "substring"
            | "dirname"
            | "basename"
            | "extname"
            | "sin"
            | "cos"
            | "tan"
            | "asin"
            | "acos"
            | "atan"
            | "atan2"
            | "log2"
            | "log10"
            | "ln"
            | "exp"
            | "pi"
            | "e"
            | "sign"
            | "is_nan"
            | "is_infinite"
            | "set"
            | "set_add"
            | "set_remove"
            | "set_contains"
            | "set_union"
            | "set_intersect"
            | "set_difference"
            | "to_list"
            // Builtins typed by `builtin_return_type` that must be recognized here too.
            | "compute_content_hash"
            | "timer_start"
            | "timer_end"
            | "elapsed"
            | "llm_info"
            | "llm_usage"
            | "log_json"
            | "invalidate_facts"
            | "metadata_get"
            | "metadata_set"
            | "metadata_save"
            | "metadata_refresh_hashes"
            | "mcp_list_tools"
            | "mcp_list_resources"
            | "mcp_list_prompts"
            | "mcp_server_info"
            | "mcp_get_prompt"
            | "mcp_disconnect"
    )
}
443
444/// The static type checker.
445pub struct TypeChecker {
446    diagnostics: Vec<TypeDiagnostic>,
447    scope: TypeScope,
448}
449
450impl TypeChecker {
451    pub fn new() -> Self {
452        Self {
453            diagnostics: Vec::new(),
454            scope: TypeScope::new(),
455        }
456    }
457
458    /// Check a program and return diagnostics.
459    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
460        // First pass: register type and enum declarations into root scope
461        Self::register_declarations_into(&mut self.scope, program);
462
463        // Also scan pipeline bodies for declarations
464        for snode in program {
465            if let Node::Pipeline { body, .. } = &snode.node {
466                Self::register_declarations_into(&mut self.scope, body);
467            }
468        }
469
470        // Check each top-level node
471        for snode in program {
472            match &snode.node {
473                Node::Pipeline { params, body, .. } => {
474                    let mut child = self.scope.child();
475                    for p in params {
476                        child.define_var(p, None);
477                    }
478                    self.check_block(body, &mut child);
479                }
480                Node::FnDecl {
481                    name,
482                    type_params,
483                    params,
484                    return_type,
485                    where_clauses,
486                    body,
487                    ..
488                } => {
489                    let required_params =
490                        params.iter().filter(|p| p.default_value.is_none()).count();
491                    let sig = FnSignature {
492                        params: params
493                            .iter()
494                            .map(|p| (p.name.clone(), p.type_expr.clone()))
495                            .collect(),
496                        return_type: return_type.clone(),
497                        type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
498                        required_params,
499                        where_clauses: where_clauses
500                            .iter()
501                            .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
502                            .collect(),
503                    };
504                    self.scope.define_fn(name, sig);
505                    self.check_fn_body(type_params, params, return_type, body, where_clauses);
506                }
507                _ => {
508                    let mut scope = self.scope.clone();
509                    self.check_node(snode, &mut scope);
510                    // Merge any new definitions back into the top-level scope
511                    for (name, ty) in scope.vars {
512                        self.scope.vars.entry(name).or_insert(ty);
513                    }
514                }
515            }
516        }
517
518        self.diagnostics
519    }
520
521    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
522    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
523        for snode in nodes {
524            match &snode.node {
525                Node::TypeDecl { name, type_expr } => {
526                    scope.type_aliases.insert(name.clone(), type_expr.clone());
527                }
528                Node::EnumDecl { name, variants } => {
529                    let variant_names: Vec<String> =
530                        variants.iter().map(|v| v.name.clone()).collect();
531                    scope.enums.insert(name.clone(), variant_names);
532                }
533                Node::InterfaceDecl { name, methods } => {
534                    scope.interfaces.insert(name.clone(), methods.clone());
535                }
536                Node::StructDecl { name, fields } => {
537                    let field_types: Vec<(String, InferredType)> = fields
538                        .iter()
539                        .map(|f| (f.name.clone(), f.type_expr.clone()))
540                        .collect();
541                    scope.structs.insert(name.clone(), field_types);
542                }
543                Node::ImplBlock {
544                    type_name, methods, ..
545                } => {
546                    let sigs: Vec<ImplMethodSig> = methods
547                        .iter()
548                        .filter_map(|m| {
549                            if let Node::FnDecl {
550                                name,
551                                params,
552                                return_type,
553                                ..
554                            } = &m.node
555                            {
556                                let non_self: Vec<_> =
557                                    params.iter().filter(|p| p.name != "self").collect();
558                                let param_count = non_self.len();
559                                let param_types: Vec<Option<TypeExpr>> =
560                                    non_self.iter().map(|p| p.type_expr.clone()).collect();
561                                Some(ImplMethodSig {
562                                    name: name.clone(),
563                                    param_count,
564                                    param_types,
565                                    return_type: return_type.clone(),
566                                })
567                            } else {
568                                None
569                            }
570                        })
571                        .collect();
572                    scope.impl_methods.insert(type_name.clone(), sigs);
573                }
574                _ => {}
575            }
576        }
577    }
578
579    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
580        for stmt in stmts {
581            self.check_node(stmt, scope);
582        }
583    }
584
585    /// Define variables from a destructuring pattern in the given scope (as unknown type).
586    fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
587        match pattern {
588            BindingPattern::Identifier(name) => {
589                scope.define_var(name, None);
590            }
591            BindingPattern::Dict(fields) => {
592                for field in fields {
593                    let name = field.alias.as_deref().unwrap_or(&field.key);
594                    scope.define_var(name, None);
595                }
596            }
597            BindingPattern::List(elements) => {
598                for elem in elements {
599                    scope.define_var(&elem.name, None);
600                }
601            }
602        }
603    }
604
605    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
606        let span = snode.span;
607        match &snode.node {
608            Node::LetBinding {
609                pattern,
610                type_ann,
611                value,
612            } => {
613                let inferred = self.infer_type(value, scope);
614                if let BindingPattern::Identifier(name) = pattern {
615                    if let Some(expected) = type_ann {
616                        if let Some(actual) = &inferred {
617                            if !self.types_compatible(expected, actual, scope) {
618                                let mut msg = format!(
619                                    "Type mismatch: '{}' declared as {}, but assigned {}",
620                                    name,
621                                    format_type(expected),
622                                    format_type(actual)
623                                );
624                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
625                                    msg.push_str(&format!(" ({})", detail));
626                                }
627                                self.error_at(msg, span);
628                            }
629                        }
630                    }
631                    let ty = type_ann.clone().or(inferred);
632                    scope.define_var(name, ty);
633                } else {
634                    Self::define_pattern_vars(pattern, scope);
635                }
636            }
637
638            Node::VarBinding {
639                pattern,
640                type_ann,
641                value,
642            } => {
643                let inferred = self.infer_type(value, scope);
644                if let BindingPattern::Identifier(name) = pattern {
645                    if let Some(expected) = type_ann {
646                        if let Some(actual) = &inferred {
647                            if !self.types_compatible(expected, actual, scope) {
648                                let mut msg = format!(
649                                    "Type mismatch: '{}' declared as {}, but assigned {}",
650                                    name,
651                                    format_type(expected),
652                                    format_type(actual)
653                                );
654                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
655                                    msg.push_str(&format!(" ({})", detail));
656                                }
657                                self.error_at(msg, span);
658                            }
659                        }
660                    }
661                    let ty = type_ann.clone().or(inferred);
662                    scope.define_var(name, ty);
663                } else {
664                    Self::define_pattern_vars(pattern, scope);
665                }
666            }
667
668            Node::FnDecl {
669                name,
670                type_params,
671                params,
672                return_type,
673                where_clauses,
674                body,
675                ..
676            } => {
677                let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
678                let sig = FnSignature {
679                    params: params
680                        .iter()
681                        .map(|p| (p.name.clone(), p.type_expr.clone()))
682                        .collect(),
683                    return_type: return_type.clone(),
684                    type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
685                    required_params,
686                    where_clauses: where_clauses
687                        .iter()
688                        .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
689                        .collect(),
690                };
691                scope.define_fn(name, sig.clone());
692                scope.define_var(name, None);
693                self.check_fn_body(type_params, params, return_type, body, where_clauses);
694            }
695
696            Node::FunctionCall { name, args } => {
697                self.check_call(name, args, scope, span);
698            }
699
700            Node::IfElse {
701                condition,
702                then_body,
703                else_body,
704            } => {
705                self.check_node(condition, scope);
706                let mut then_scope = scope.child();
707                // Narrow union types after nil checks: if x != nil, narrow x
708                if let Some((var_name, narrowed)) = Self::extract_nil_narrowing(condition, scope) {
709                    then_scope.define_var(&var_name, narrowed);
710                }
711                self.check_block(then_body, &mut then_scope);
712                if let Some(else_body) = else_body {
713                    let mut else_scope = scope.child();
714                    self.check_block(else_body, &mut else_scope);
715                }
716            }
717
718            Node::ForIn {
719                pattern,
720                iterable,
721                body,
722            } => {
723                self.check_node(iterable, scope);
724                let mut loop_scope = scope.child();
725                if let BindingPattern::Identifier(variable) = pattern {
726                    // Infer loop variable type from iterable
727                    let elem_type = match self.infer_type(iterable, scope) {
728                        Some(TypeExpr::List(inner)) => Some(*inner),
729                        Some(TypeExpr::Named(n)) if n == "string" => {
730                            Some(TypeExpr::Named("string".into()))
731                        }
732                        _ => None,
733                    };
734                    loop_scope.define_var(variable, elem_type);
735                } else {
736                    Self::define_pattern_vars(pattern, &mut loop_scope);
737                }
738                self.check_block(body, &mut loop_scope);
739            }
740
741            Node::WhileLoop { condition, body } => {
742                self.check_node(condition, scope);
743                let mut loop_scope = scope.child();
744                self.check_block(body, &mut loop_scope);
745            }
746
747            Node::TryCatch {
748                body,
749                error_var,
750                catch_body,
751                finally_body,
752                ..
753            } => {
754                let mut try_scope = scope.child();
755                self.check_block(body, &mut try_scope);
756                let mut catch_scope = scope.child();
757                if let Some(var) = error_var {
758                    catch_scope.define_var(var, None);
759                }
760                self.check_block(catch_body, &mut catch_scope);
761                if let Some(fb) = finally_body {
762                    let mut finally_scope = scope.child();
763                    self.check_block(fb, &mut finally_scope);
764                }
765            }
766
767            Node::TryExpr { body } => {
768                let mut try_scope = scope.child();
769                self.check_block(body, &mut try_scope);
770            }
771
772            Node::ReturnStmt {
773                value: Some(val), ..
774            } => {
775                self.check_node(val, scope);
776            }
777
778            Node::Assignment {
779                target, value, op, ..
780            } => {
781                self.check_node(value, scope);
782                if let Node::Identifier(name) = &target.node {
783                    if let Some(Some(var_type)) = scope.get_var(name) {
784                        let value_type = self.infer_type(value, scope);
785                        let assigned = if let Some(op) = op {
786                            let var_inferred = scope.get_var(name).cloned().flatten();
787                            infer_binary_op_type(op, &var_inferred, &value_type)
788                        } else {
789                            value_type
790                        };
791                        if let Some(actual) = &assigned {
792                            if !self.types_compatible(var_type, actual, scope) {
793                                self.error_at(
794                                    format!(
795                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
796                                        format_type(actual),
797                                        name,
798                                        format_type(var_type)
799                                    ),
800                                    span,
801                                );
802                            }
803                        }
804                    }
805                }
806            }
807
808            Node::TypeDecl { name, type_expr } => {
809                scope.type_aliases.insert(name.clone(), type_expr.clone());
810            }
811
812            Node::EnumDecl { name, variants } => {
813                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
814                scope.enums.insert(name.clone(), variant_names);
815            }
816
817            Node::StructDecl { name, fields } => {
818                let field_types: Vec<(String, InferredType)> = fields
819                    .iter()
820                    .map(|f| (f.name.clone(), f.type_expr.clone()))
821                    .collect();
822                scope.structs.insert(name.clone(), field_types);
823            }
824
825            Node::InterfaceDecl { name, methods } => {
826                scope.interfaces.insert(name.clone(), methods.clone());
827            }
828
829            Node::ImplBlock {
830                type_name, methods, ..
831            } => {
832                // Register impl methods for interface satisfaction checking
833                let sigs: Vec<ImplMethodSig> = methods
834                    .iter()
835                    .filter_map(|m| {
836                        if let Node::FnDecl {
837                            name,
838                            params,
839                            return_type,
840                            ..
841                        } = &m.node
842                        {
843                            let non_self: Vec<_> =
844                                params.iter().filter(|p| p.name != "self").collect();
845                            let param_count = non_self.len();
846                            let param_types: Vec<Option<TypeExpr>> =
847                                non_self.iter().map(|p| p.type_expr.clone()).collect();
848                            Some(ImplMethodSig {
849                                name: name.clone(),
850                                param_count,
851                                param_types,
852                                return_type: return_type.clone(),
853                            })
854                        } else {
855                            None
856                        }
857                    })
858                    .collect();
859                scope.impl_methods.insert(type_name.clone(), sigs);
860                for method_sn in methods {
861                    self.check_node(method_sn, scope);
862                }
863            }
864
865            Node::TryOperator { operand } => {
866                self.check_node(operand, scope);
867            }
868
869            Node::MatchExpr { value, arms } => {
870                self.check_node(value, scope);
871                let value_type = self.infer_type(value, scope);
872                for arm in arms {
873                    self.check_node(&arm.pattern, scope);
874                    // Check for incompatible literal pattern types
875                    if let Some(ref vt) = value_type {
876                        let value_type_name = format_type(vt);
877                        let mismatch = match &arm.pattern.node {
878                            Node::StringLiteral(_) => {
879                                !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
880                            }
881                            Node::IntLiteral(_) => {
882                                !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
883                                    && !self.types_compatible(
884                                        vt,
885                                        &TypeExpr::Named("float".into()),
886                                        scope,
887                                    )
888                            }
889                            Node::FloatLiteral(_) => {
890                                !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
891                                    && !self.types_compatible(
892                                        vt,
893                                        &TypeExpr::Named("int".into()),
894                                        scope,
895                                    )
896                            }
897                            Node::BoolLiteral(_) => {
898                                !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
899                            }
900                            _ => false,
901                        };
902                        if mismatch {
903                            let pattern_type = match &arm.pattern.node {
904                                Node::StringLiteral(_) => "string",
905                                Node::IntLiteral(_) => "int",
906                                Node::FloatLiteral(_) => "float",
907                                Node::BoolLiteral(_) => "bool",
908                                _ => unreachable!(),
909                            };
910                            self.warning_at(
911                                format!(
912                                    "Match pattern type mismatch: matching {} against {} literal",
913                                    value_type_name, pattern_type
914                                ),
915                                arm.pattern.span,
916                            );
917                        }
918                    }
919                    let mut arm_scope = scope.child();
920                    self.check_block(&arm.body, &mut arm_scope);
921                }
922                self.check_match_exhaustiveness(value, arms, scope, span);
923            }
924
925            // Recurse into nested expressions + validate binary op types
926            Node::BinaryOp { op, left, right } => {
927                self.check_node(left, scope);
928                self.check_node(right, scope);
929                // Validate operator/type compatibility
930                let lt = self.infer_type(left, scope);
931                let rt = self.infer_type(right, scope);
932                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
933                    match op.as_str() {
934                        "-" | "*" | "/" | "%" => {
935                            let numeric = ["int", "float"];
936                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
937                                self.warning_at(
938                                    format!(
939                                        "Operator '{op}' may not be valid for types {} and {}",
940                                        l, r
941                                    ),
942                                    span,
943                                );
944                            }
945                        }
946                        "+" => {
947                            // + is valid for int, float, string, list, dict
948                            let valid = ["int", "float", "string", "list", "dict"];
949                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
950                                self.warning_at(
951                                    format!(
952                                        "Operator '+' may not be valid for types {} and {}",
953                                        l, r
954                                    ),
955                                    span,
956                                );
957                            }
958                        }
959                        _ => {}
960                    }
961                }
962            }
963            Node::UnaryOp { operand, .. } => {
964                self.check_node(operand, scope);
965            }
966            Node::MethodCall {
967                object,
968                method,
969                args,
970                ..
971            }
972            | Node::OptionalMethodCall {
973                object,
974                method,
975                args,
976                ..
977            } => {
978                self.check_node(object, scope);
979                for arg in args {
980                    self.check_node(arg, scope);
981                }
982                // Definition-site generic checking: if the object's type is a
983                // constrained generic param (where T: Interface), verify the
984                // method exists in the bound interface.
985                if let Some(TypeExpr::Named(type_name)) = self.infer_type(object, scope) {
986                    if scope.is_generic_type_param(&type_name) {
987                        if let Some(iface_name) = scope.get_where_constraint(&type_name) {
988                            if let Some(iface_methods) = scope.get_interface(iface_name) {
989                                let has_method = iface_methods.iter().any(|m| m.name == *method);
990                                if !has_method {
991                                    self.warning_at(
992                                        format!(
993                                            "Method '{}' not found in interface '{}' (constraint on '{}')",
994                                            method, iface_name, type_name
995                                        ),
996                                        span,
997                                    );
998                                }
999                            }
1000                        }
1001                    }
1002                }
1003            }
1004            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
1005                self.check_node(object, scope);
1006            }
1007            Node::SubscriptAccess { object, index } => {
1008                self.check_node(object, scope);
1009                self.check_node(index, scope);
1010            }
1011            Node::SliceAccess { object, start, end } => {
1012                self.check_node(object, scope);
1013                if let Some(s) = start {
1014                    self.check_node(s, scope);
1015                }
1016                if let Some(e) = end {
1017                    self.check_node(e, scope);
1018                }
1019            }
1020
1021            // Terminals — nothing to check
1022            _ => {}
1023        }
1024    }
1025
1026    fn check_fn_body(
1027        &mut self,
1028        type_params: &[TypeParam],
1029        params: &[TypedParam],
1030        return_type: &Option<TypeExpr>,
1031        body: &[SNode],
1032        where_clauses: &[WhereClause],
1033    ) {
1034        let mut fn_scope = self.scope.child();
1035        // Register generic type parameters so they are treated as compatible
1036        // with any concrete type during type checking.
1037        for tp in type_params {
1038            fn_scope.generic_type_params.insert(tp.name.clone());
1039        }
1040        // Store where-clause constraints for definition-site checking
1041        for wc in where_clauses {
1042            fn_scope
1043                .where_constraints
1044                .insert(wc.type_name.clone(), wc.bound.clone());
1045        }
1046        for param in params {
1047            fn_scope.define_var(&param.name, param.type_expr.clone());
1048            if let Some(default) = &param.default_value {
1049                self.check_node(default, &mut fn_scope);
1050            }
1051        }
1052        self.check_block(body, &mut fn_scope);
1053
1054        // Check return statements against declared return type
1055        if let Some(ret_type) = return_type {
1056            for stmt in body {
1057                self.check_return_type(stmt, ret_type, &fn_scope);
1058            }
1059        }
1060    }
1061
1062    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
1063        let span = snode.span;
1064        match &snode.node {
1065            Node::ReturnStmt { value: Some(val) } => {
1066                let inferred = self.infer_type(val, scope);
1067                if let Some(actual) = &inferred {
1068                    if !self.types_compatible(expected, actual, scope) {
1069                        self.error_at(
1070                            format!(
1071                                "Return type mismatch: expected {}, got {}",
1072                                format_type(expected),
1073                                format_type(actual)
1074                            ),
1075                            span,
1076                        );
1077                    }
1078                }
1079            }
1080            Node::IfElse {
1081                then_body,
1082                else_body,
1083                ..
1084            } => {
1085                for stmt in then_body {
1086                    self.check_return_type(stmt, expected, scope);
1087                }
1088                if let Some(else_body) = else_body {
1089                    for stmt in else_body {
1090                        self.check_return_type(stmt, expected, scope);
1091                    }
1092                }
1093            }
1094            _ => {}
1095        }
1096    }
1097
1098    /// Check if a match expression on an enum's `.variant` property covers all variants.
1099    /// Extract narrowing info from nil-check conditions like `x != nil`.
1100    /// Returns (var_name, narrowed_type) where narrowed_type removes nil from a union.
1101    /// Check if a type satisfies an interface (Go-style implicit satisfaction).
1102    /// A type satisfies an interface if its impl block has all the required methods.
1103    fn satisfies_interface(
1104        &self,
1105        type_name: &str,
1106        interface_name: &str,
1107        scope: &TypeScope,
1108    ) -> bool {
1109        self.interface_mismatch_reason(type_name, interface_name, scope)
1110            .is_none()
1111    }
1112
1113    /// Return a detailed reason why a type does not satisfy an interface, or None
1114    /// if it does satisfy it.  Used for producing actionable warning messages.
1115    fn interface_mismatch_reason(
1116        &self,
1117        type_name: &str,
1118        interface_name: &str,
1119        scope: &TypeScope,
1120    ) -> Option<String> {
1121        let interface_methods = match scope.get_interface(interface_name) {
1122            Some(methods) => methods,
1123            None => return Some(format!("interface '{}' not found", interface_name)),
1124        };
1125        let impl_methods = match scope.get_impl_methods(type_name) {
1126            Some(methods) => methods,
1127            None => {
1128                if interface_methods.is_empty() {
1129                    return None;
1130                }
1131                let names: Vec<_> = interface_methods.iter().map(|m| m.name.as_str()).collect();
1132                return Some(format!("missing method(s): {}", names.join(", ")));
1133            }
1134        };
1135        for iface_method in interface_methods {
1136            let iface_params: Vec<_> = iface_method
1137                .params
1138                .iter()
1139                .filter(|p| p.name != "self")
1140                .collect();
1141            let iface_param_count = iface_params.len();
1142            let matching_impl = impl_methods.iter().find(|im| im.name == iface_method.name);
1143            let impl_method = match matching_impl {
1144                Some(m) => m,
1145                None => {
1146                    return Some(format!("missing method '{}'", iface_method.name));
1147                }
1148            };
1149            if impl_method.param_count != iface_param_count {
1150                return Some(format!(
1151                    "method '{}' has {} parameter(s), expected {}",
1152                    iface_method.name, impl_method.param_count, iface_param_count
1153                ));
1154            }
1155            // Check parameter types where both sides specify them
1156            for (i, iface_param) in iface_params.iter().enumerate() {
1157                if let (Some(expected), Some(actual)) = (
1158                    &iface_param.type_expr,
1159                    impl_method.param_types.get(i).and_then(|t| t.as_ref()),
1160                ) {
1161                    if !self.types_compatible(expected, actual, scope) {
1162                        return Some(format!(
1163                            "method '{}' parameter {} has type '{}', expected '{}'",
1164                            iface_method.name,
1165                            i + 1,
1166                            format_type(actual),
1167                            format_type(expected),
1168                        ));
1169                    }
1170                }
1171            }
1172            // Check return type where both sides specify it
1173            if let (Some(expected_ret), Some(actual_ret)) =
1174                (&iface_method.return_type, &impl_method.return_type)
1175            {
1176                if !self.types_compatible(expected_ret, actual_ret, scope) {
1177                    return Some(format!(
1178                        "method '{}' returns '{}', expected '{}'",
1179                        iface_method.name,
1180                        format_type(actual_ret),
1181                        format_type(expected_ret),
1182                    ));
1183                }
1184            }
1185        }
1186        None
1187    }
1188
1189    /// Recursively extract type parameter bindings from matching param/arg types.
1190    /// E.g., param_type=list<T> + arg_type=list<Dog> → binds T=Dog.
1191    fn extract_type_bindings(
1192        param_type: &TypeExpr,
1193        arg_type: &TypeExpr,
1194        type_params: &std::collections::BTreeSet<String>,
1195        bindings: &mut BTreeMap<String, String>,
1196    ) {
1197        match (param_type, arg_type) {
1198            // Direct type param match: T → concrete
1199            (TypeExpr::Named(param_name), TypeExpr::Named(concrete))
1200                if type_params.contains(param_name) =>
1201            {
1202                bindings
1203                    .entry(param_name.clone())
1204                    .or_insert(concrete.clone());
1205            }
1206            // list<T> + list<Dog>
1207            (TypeExpr::List(p_inner), TypeExpr::List(a_inner)) => {
1208                Self::extract_type_bindings(p_inner, a_inner, type_params, bindings);
1209            }
1210            // dict<K, V> + dict<string, int>
1211            (TypeExpr::DictType(pk, pv), TypeExpr::DictType(ak, av)) => {
1212                Self::extract_type_bindings(pk, ak, type_params, bindings);
1213                Self::extract_type_bindings(pv, av, type_params, bindings);
1214            }
1215            _ => {}
1216        }
1217    }
1218
1219    fn extract_nil_narrowing(
1220        condition: &SNode,
1221        scope: &TypeScope,
1222    ) -> Option<(String, InferredType)> {
1223        if let Node::BinaryOp { op, left, right } = &condition.node {
1224            if op == "!=" {
1225                // Check for `x != nil` or `nil != x`
1226                let (var_node, nil_node) = if matches!(right.node, Node::NilLiteral) {
1227                    (left, right)
1228                } else if matches!(left.node, Node::NilLiteral) {
1229                    (right, left)
1230                } else {
1231                    return None;
1232                };
1233                let _ = nil_node;
1234                if let Node::Identifier(name) = &var_node.node {
1235                    // Look up the variable's type and narrow it
1236                    if let Some(Some(TypeExpr::Union(members))) = scope.get_var(name) {
1237                        let narrowed: Vec<TypeExpr> = members
1238                            .iter()
1239                            .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1240                            .cloned()
1241                            .collect();
1242                        return if narrowed.len() == 1 {
1243                            Some((name.clone(), Some(narrowed.into_iter().next().unwrap())))
1244                        } else if narrowed.is_empty() {
1245                            None
1246                        } else {
1247                            Some((name.clone(), Some(TypeExpr::Union(narrowed))))
1248                        };
1249                    }
1250                }
1251            }
1252        }
1253        None
1254    }
1255
1256    fn check_match_exhaustiveness(
1257        &mut self,
1258        value: &SNode,
1259        arms: &[MatchArm],
1260        scope: &TypeScope,
1261        span: Span,
1262    ) {
1263        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
1264        let enum_name = match &value.node {
1265            Node::PropertyAccess { object, property } if property == "variant" => {
1266                // Infer the type of the object
1267                match self.infer_type(object, scope) {
1268                    Some(TypeExpr::Named(name)) => {
1269                        if scope.get_enum(&name).is_some() {
1270                            Some(name)
1271                        } else {
1272                            None
1273                        }
1274                    }
1275                    _ => None,
1276                }
1277            }
1278            _ => {
1279                // Direct match on an enum value: match <expr> { ... }
1280                match self.infer_type(value, scope) {
1281                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
1282                    _ => None,
1283                }
1284            }
1285        };
1286
1287        let Some(enum_name) = enum_name else {
1288            return;
1289        };
1290        let Some(variants) = scope.get_enum(&enum_name) else {
1291            return;
1292        };
1293
1294        // Collect variant names covered by match arms
1295        let mut covered: Vec<String> = Vec::new();
1296        let mut has_wildcard = false;
1297
1298        for arm in arms {
1299            match &arm.pattern.node {
1300                // String literal pattern (matching on .variant): "VariantA"
1301                Node::StringLiteral(s) => covered.push(s.clone()),
1302                // Identifier pattern acts as a wildcard/catch-all
1303                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
1304                    has_wildcard = true;
1305                }
1306                // Direct enum construct pattern: EnumName.Variant
1307                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
1308                // PropertyAccess pattern: EnumName.Variant (no args)
1309                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
1310                _ => {
1311                    // Unknown pattern shape — conservatively treat as wildcard
1312                    has_wildcard = true;
1313                }
1314            }
1315        }
1316
1317        if has_wildcard {
1318            return;
1319        }
1320
1321        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
1322        if !missing.is_empty() {
1323            let missing_str = missing
1324                .iter()
1325                .map(|s| format!("\"{}\"", s))
1326                .collect::<Vec<_>>()
1327                .join(", ");
1328            self.warning_at(
1329                format!(
1330                    "Non-exhaustive match on enum {}: missing variants {}",
1331                    enum_name, missing_str
1332                ),
1333                span,
1334            );
1335        }
1336    }
1337
1338    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
1339        // Check against known function signatures
1340        let has_spread = args.iter().any(|a| matches!(&a.node, Node::Spread(_)));
1341        if let Some(sig) = scope.get_fn(name).cloned() {
1342            if !has_spread
1343                && !is_builtin(name)
1344                && (args.len() < sig.required_params || args.len() > sig.params.len())
1345            {
1346                let expected = if sig.required_params == sig.params.len() {
1347                    format!("{}", sig.params.len())
1348                } else {
1349                    format!("{}-{}", sig.required_params, sig.params.len())
1350                };
1351                self.warning_at(
1352                    format!(
1353                        "Function '{}' expects {} arguments, got {}",
1354                        name,
1355                        expected,
1356                        args.len()
1357                    ),
1358                    span,
1359                );
1360            }
1361            // Build a scope that includes the function's generic type params
1362            // so they are treated as compatible with any concrete type.
1363            let call_scope = if sig.type_param_names.is_empty() {
1364                scope.clone()
1365            } else {
1366                let mut s = scope.child();
1367                for tp_name in &sig.type_param_names {
1368                    s.generic_type_params.insert(tp_name.clone());
1369                }
1370                s
1371            };
1372            for (i, (arg, (param_name, param_type))) in
1373                args.iter().zip(sig.params.iter()).enumerate()
1374            {
1375                if let Some(expected) = param_type {
1376                    let actual = self.infer_type(arg, scope);
1377                    if let Some(actual) = &actual {
1378                        if !self.types_compatible(expected, actual, &call_scope) {
1379                            self.error_at(
1380                                format!(
1381                                    "Argument {} ('{}'): expected {}, got {}",
1382                                    i + 1,
1383                                    param_name,
1384                                    format_type(expected),
1385                                    format_type(actual)
1386                                ),
1387                                arg.span,
1388                            );
1389                        }
1390                    }
1391                }
1392            }
1393            // Enforce where-clause constraints at call site
1394            if !sig.where_clauses.is_empty() {
1395                // Build mapping: type_param → concrete type from inferred args.
1396                // Recursively walks Generic types so list<T> + list<Dog> binds T=Dog.
1397                let mut type_bindings: BTreeMap<String, String> = BTreeMap::new();
1398                let type_param_set: std::collections::BTreeSet<String> =
1399                    sig.type_param_names.iter().cloned().collect();
1400                for (arg, (_param_name, param_type)) in args.iter().zip(sig.params.iter()) {
1401                    if let Some(param_ty) = param_type {
1402                        if let Some(arg_ty) = self.infer_type(arg, scope) {
1403                            Self::extract_type_bindings(
1404                                param_ty,
1405                                &arg_ty,
1406                                &type_param_set,
1407                                &mut type_bindings,
1408                            );
1409                        }
1410                    }
1411                }
1412                for (type_param, bound) in &sig.where_clauses {
1413                    if let Some(concrete_type) = type_bindings.get(type_param) {
1414                        if let Some(reason) =
1415                            self.interface_mismatch_reason(concrete_type, bound, scope)
1416                        {
1417                            self.warning_at(
1418                                format!(
1419                                    "Type '{}' does not satisfy interface '{}': {} \
1420                                     (required by constraint `where {}: {}`)",
1421                                    concrete_type, bound, reason, type_param, bound
1422                                ),
1423                                span,
1424                            );
1425                        }
1426                    }
1427                }
1428            }
1429        }
1430        // Check args recursively
1431        for arg in args {
1432            self.check_node(arg, scope);
1433        }
1434    }
1435
1436    /// Infer the type of an expression.
1437    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
1438        match &snode.node {
1439            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
1440            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
1441            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
1442                Some(TypeExpr::Named("string".into()))
1443            }
1444            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
1445            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
1446            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
1447            Node::DictLiteral(entries) => {
1448                // Infer shape type when all keys are string literals
1449                let mut fields = Vec::new();
1450                let mut all_string_keys = true;
1451                for entry in entries {
1452                    if let Node::StringLiteral(key) = &entry.key.node {
1453                        let val_type = self
1454                            .infer_type(&entry.value, scope)
1455                            .unwrap_or(TypeExpr::Named("nil".into()));
1456                        fields.push(ShapeField {
1457                            name: key.clone(),
1458                            type_expr: val_type,
1459                            optional: false,
1460                        });
1461                    } else {
1462                        all_string_keys = false;
1463                        break;
1464                    }
1465                }
1466                if all_string_keys && !fields.is_empty() {
1467                    Some(TypeExpr::Shape(fields))
1468                } else {
1469                    Some(TypeExpr::Named("dict".into()))
1470                }
1471            }
1472            Node::Closure { params, body, .. } => {
1473                // If all params are typed and we can infer a return type, produce FnType
1474                let all_typed = params.iter().all(|p| p.type_expr.is_some());
1475                if all_typed && !params.is_empty() {
1476                    let param_types: Vec<TypeExpr> =
1477                        params.iter().filter_map(|p| p.type_expr.clone()).collect();
1478                    // Try to infer return type from last expression in body
1479                    let ret = body.last().and_then(|last| self.infer_type(last, scope));
1480                    if let Some(ret_type) = ret {
1481                        return Some(TypeExpr::FnType {
1482                            params: param_types,
1483                            return_type: Box::new(ret_type),
1484                        });
1485                    }
1486                }
1487                Some(TypeExpr::Named("closure".into()))
1488            }
1489
1490            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
1491
1492            Node::FunctionCall { name, .. } => {
1493                // Struct constructor calls return the struct type
1494                if scope.get_struct(name).is_some() {
1495                    return Some(TypeExpr::Named(name.clone()));
1496                }
1497                // Check user-defined function return types
1498                if let Some(sig) = scope.get_fn(name) {
1499                    return sig.return_type.clone();
1500                }
1501                // Check builtin return types
1502                builtin_return_type(name)
1503            }
1504
1505            Node::BinaryOp { op, left, right } => {
1506                let lt = self.infer_type(left, scope);
1507                let rt = self.infer_type(right, scope);
1508                infer_binary_op_type(op, &lt, &rt)
1509            }
1510
1511            Node::UnaryOp { op, operand } => {
1512                let t = self.infer_type(operand, scope);
1513                match op.as_str() {
1514                    "!" => Some(TypeExpr::Named("bool".into())),
1515                    "-" => t, // negation preserves type
1516                    _ => None,
1517                }
1518            }
1519
1520            Node::Ternary {
1521                true_expr,
1522                false_expr,
1523                ..
1524            } => {
1525                let tt = self.infer_type(true_expr, scope);
1526                let ft = self.infer_type(false_expr, scope);
1527                match (&tt, &ft) {
1528                    (Some(a), Some(b)) if a == b => tt,
1529                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
1530                    (Some(_), None) => tt,
1531                    (None, Some(_)) => ft,
1532                    (None, None) => None,
1533                }
1534            }
1535
1536            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1537
1538            Node::PropertyAccess { object, property } => {
1539                // EnumName.Variant → infer as the enum type
1540                if let Node::Identifier(name) = &object.node {
1541                    if scope.get_enum(name).is_some() {
1542                        return Some(TypeExpr::Named(name.clone()));
1543                    }
1544                }
1545                // .variant on an enum value → string
1546                if property == "variant" {
1547                    let obj_type = self.infer_type(object, scope);
1548                    if let Some(TypeExpr::Named(name)) = &obj_type {
1549                        if scope.get_enum(name).is_some() {
1550                            return Some(TypeExpr::Named("string".into()));
1551                        }
1552                    }
1553                }
1554                // Shape field access: obj.field → field type
1555                let obj_type = self.infer_type(object, scope);
1556                if let Some(TypeExpr::Shape(fields)) = &obj_type {
1557                    if let Some(field) = fields.iter().find(|f| f.name == *property) {
1558                        return Some(field.type_expr.clone());
1559                    }
1560                }
1561                None
1562            }
1563
1564            Node::SubscriptAccess { object, index } => {
1565                let obj_type = self.infer_type(object, scope);
1566                match &obj_type {
1567                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1568                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1569                    Some(TypeExpr::Shape(fields)) => {
1570                        // If index is a string literal, look up the field type
1571                        if let Node::StringLiteral(key) = &index.node {
1572                            fields
1573                                .iter()
1574                                .find(|f| &f.name == key)
1575                                .map(|f| f.type_expr.clone())
1576                        } else {
1577                            None
1578                        }
1579                    }
1580                    Some(TypeExpr::Named(n)) if n == "list" => None,
1581                    Some(TypeExpr::Named(n)) if n == "dict" => None,
1582                    Some(TypeExpr::Named(n)) if n == "string" => {
1583                        Some(TypeExpr::Named("string".into()))
1584                    }
1585                    _ => None,
1586                }
1587            }
1588            Node::SliceAccess { object, .. } => {
1589                // Slicing a list returns the same list type; slicing a string returns string
1590                let obj_type = self.infer_type(object, scope);
1591                match &obj_type {
1592                    Some(TypeExpr::List(_)) => obj_type,
1593                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1594                    Some(TypeExpr::Named(n)) if n == "string" => {
1595                        Some(TypeExpr::Named("string".into()))
1596                    }
1597                    _ => None,
1598                }
1599            }
1600            Node::MethodCall { object, method, .. }
1601            | Node::OptionalMethodCall { object, method, .. } => {
1602                let obj_type = self.infer_type(object, scope);
1603                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1604                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1605                match method.as_str() {
1606                    // Shared: bool-returning methods
1607                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1608                        Some(TypeExpr::Named("bool".into()))
1609                    }
1610                    // Shared: int-returning methods
1611                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1612                    // String methods
1613                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1614                    | "pad_left" | "pad_right" | "repeat" | "join" => {
1615                        Some(TypeExpr::Named("string".into()))
1616                    }
1617                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
1618                    // filter returns dict for dicts, list for lists
1619                    "filter" => {
1620                        if is_dict {
1621                            Some(TypeExpr::Named("dict".into()))
1622                        } else {
1623                            Some(TypeExpr::Named("list".into()))
1624                        }
1625                    }
1626                    // List methods
1627                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1628                    "reduce" | "find" | "first" | "last" => None,
1629                    // Dict methods
1630                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1631                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1632                    // Conversions
1633                    "to_string" => Some(TypeExpr::Named("string".into())),
1634                    "to_int" => Some(TypeExpr::Named("int".into())),
1635                    "to_float" => Some(TypeExpr::Named("float".into())),
1636                    _ => None,
1637                }
1638            }
1639
1640            // TryOperator on Result<T, E> produces T
1641            Node::TryOperator { operand } => {
1642                match self.infer_type(operand, scope) {
1643                    Some(TypeExpr::Named(name)) if name == "Result" => None, // unknown inner type
1644                    _ => None,
1645                }
1646            }
1647
1648            _ => None,
1649        }
1650    }
1651
1652    /// Check if two types are compatible (actual can be assigned to expected).
1653    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1654        // Generic type parameters match anything.
1655        if let TypeExpr::Named(name) = expected {
1656            if scope.is_generic_type_param(name) {
1657                return true;
1658            }
1659        }
1660        if let TypeExpr::Named(name) = actual {
1661            if scope.is_generic_type_param(name) {
1662                return true;
1663            }
1664        }
1665        let expected = self.resolve_alias(expected, scope);
1666        let actual = self.resolve_alias(actual, scope);
1667
1668        // Interface satisfaction: if expected is an interface name, check if actual type
1669        // has all required methods (Go-style implicit satisfaction).
1670        if let TypeExpr::Named(iface_name) = &expected {
1671            if scope.get_interface(iface_name).is_some() {
1672                if let TypeExpr::Named(type_name) = &actual {
1673                    return self.satisfies_interface(type_name, iface_name, scope);
1674                }
1675                return false;
1676            }
1677        }
1678
1679        match (&expected, &actual) {
1680            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1681            (TypeExpr::Union(members), actual_type) => members
1682                .iter()
1683                .any(|m| self.types_compatible(m, actual_type, scope)),
1684            (expected_type, TypeExpr::Union(members)) => members
1685                .iter()
1686                .all(|m| self.types_compatible(expected_type, m, scope)),
1687            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1688            (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1689            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1690                if expected_field.optional {
1691                    return true;
1692                }
1693                af.iter().any(|actual_field| {
1694                    actual_field.name == expected_field.name
1695                        && self.types_compatible(
1696                            &expected_field.type_expr,
1697                            &actual_field.type_expr,
1698                            scope,
1699                        )
1700                })
1701            }),
1702            // dict<K, V> expected, Shape actual → all field values must match V
1703            (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1704                let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1705                keys_ok
1706                    && af
1707                        .iter()
1708                        .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1709            }
1710            // Shape expected, dict<K, V> actual → gradual: allow since dict may have the fields
1711            (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1712            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1713                self.types_compatible(expected_inner, actual_inner, scope)
1714            }
1715            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1716            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1717            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1718                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1719            }
1720            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1721            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1722            // FnType compatibility: params match positionally and return types match
1723            (
1724                TypeExpr::FnType {
1725                    params: ep,
1726                    return_type: er,
1727                },
1728                TypeExpr::FnType {
1729                    params: ap,
1730                    return_type: ar,
1731                },
1732            ) => {
1733                ep.len() == ap.len()
1734                    && ep
1735                        .iter()
1736                        .zip(ap.iter())
1737                        .all(|(e, a)| self.types_compatible(e, a, scope))
1738                    && self.types_compatible(er, ar, scope)
1739            }
1740            // FnType is compatible with Named("closure") for backward compat
1741            (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1742            (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1743            _ => false,
1744        }
1745    }
1746
1747    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1748        if let TypeExpr::Named(name) = ty {
1749            if let Some(resolved) = scope.resolve_type(name) {
1750                return resolved.clone();
1751            }
1752        }
1753        ty.clone()
1754    }
1755
1756    fn error_at(&mut self, message: String, span: Span) {
1757        self.diagnostics.push(TypeDiagnostic {
1758            message,
1759            severity: DiagnosticSeverity::Error,
1760            span: Some(span),
1761            help: None,
1762        });
1763    }
1764
1765    #[allow(dead_code)]
1766    fn error_at_with_help(&mut self, message: String, span: Span, help: String) {
1767        self.diagnostics.push(TypeDiagnostic {
1768            message,
1769            severity: DiagnosticSeverity::Error,
1770            span: Some(span),
1771            help: Some(help),
1772        });
1773    }
1774
1775    fn warning_at(&mut self, message: String, span: Span) {
1776        self.diagnostics.push(TypeDiagnostic {
1777            message,
1778            severity: DiagnosticSeverity::Warning,
1779            span: Some(span),
1780            help: None,
1781        });
1782    }
1783
1784    #[allow(dead_code)]
1785    fn warning_at_with_help(&mut self, message: String, span: Span, help: String) {
1786        self.diagnostics.push(TypeDiagnostic {
1787            message,
1788            severity: DiagnosticSeverity::Warning,
1789            span: Some(span),
1790            help: Some(help),
1791        });
1792    }
1793}
1794
1795impl Default for TypeChecker {
1796    fn default() -> Self {
1797        Self::new()
1798    }
1799}
1800
1801/// Infer the result type of a binary operation.
1802fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1803    match op {
1804        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" | "in" | "not_in" => {
1805            Some(TypeExpr::Named("bool".into()))
1806        }
1807        "+" => match (left, right) {
1808            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1809                match (l.as_str(), r.as_str()) {
1810                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1811                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1812                    ("string", _) => Some(TypeExpr::Named("string".into())),
1813                    ("list", "list") => Some(TypeExpr::Named("list".into())),
1814                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1815                    _ => Some(TypeExpr::Named("string".into())),
1816                }
1817            }
1818            _ => None,
1819        },
1820        "-" | "*" | "/" | "%" => match (left, right) {
1821            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1822                match (l.as_str(), r.as_str()) {
1823                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1824                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1825                    _ => None,
1826                }
1827            }
1828            _ => None,
1829        },
1830        "??" => match (left, right) {
1831            (Some(TypeExpr::Union(members)), _) => {
1832                let non_nil: Vec<_> = members
1833                    .iter()
1834                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1835                    .cloned()
1836                    .collect();
1837                if non_nil.len() == 1 {
1838                    Some(non_nil[0].clone())
1839                } else if non_nil.is_empty() {
1840                    right.clone()
1841                } else {
1842                    Some(TypeExpr::Union(non_nil))
1843                }
1844            }
1845            _ => right.clone(),
1846        },
1847        "|>" => None,
1848        _ => None,
1849    }
1850}
1851
1852/// Format a type expression for display in error messages.
1853/// Produce a detail string describing why a Shape type is incompatible with
1854/// another Shape type — e.g. "missing field 'age' (int)" or "field 'name'
1855/// has type int, expected string".  Returns `None` if both types are not shapes.
1856pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1857    if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1858        let mut details = Vec::new();
1859        for field in ef {
1860            if field.optional {
1861                continue;
1862            }
1863            match af.iter().find(|f| f.name == field.name) {
1864                None => details.push(format!(
1865                    "missing field '{}' ({})",
1866                    field.name,
1867                    format_type(&field.type_expr)
1868                )),
1869                Some(actual_field) => {
1870                    let e_str = format_type(&field.type_expr);
1871                    let a_str = format_type(&actual_field.type_expr);
1872                    if e_str != a_str {
1873                        details.push(format!(
1874                            "field '{}' has type {}, expected {}",
1875                            field.name, a_str, e_str
1876                        ));
1877                    }
1878                }
1879            }
1880        }
1881        if details.is_empty() {
1882            None
1883        } else {
1884            Some(details.join("; "))
1885        }
1886    } else {
1887        None
1888    }
1889}
1890
1891pub fn format_type(ty: &TypeExpr) -> String {
1892    match ty {
1893        TypeExpr::Named(n) => n.clone(),
1894        TypeExpr::Union(types) => types
1895            .iter()
1896            .map(format_type)
1897            .collect::<Vec<_>>()
1898            .join(" | "),
1899        TypeExpr::Shape(fields) => {
1900            let inner: Vec<String> = fields
1901                .iter()
1902                .map(|f| {
1903                    let opt = if f.optional { "?" } else { "" };
1904                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1905                })
1906                .collect();
1907            format!("{{{}}}", inner.join(", "))
1908        }
1909        TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1910        TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1911        TypeExpr::FnType {
1912            params,
1913            return_type,
1914        } => {
1915            let params_str = params
1916                .iter()
1917                .map(format_type)
1918                .collect::<Vec<_>>()
1919                .join(", ");
1920            format!("fn({}) -> {}", params_str, format_type(return_type))
1921        }
1922    }
1923}
1924
1925#[cfg(test)]
1926mod tests {
1927    use super::*;
1928    use crate::Parser;
1929    use harn_lexer::Lexer;
1930
1931    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
1932        let mut lexer = Lexer::new(source);
1933        let tokens = lexer.tokenize().unwrap();
1934        let mut parser = Parser::new(tokens);
1935        let program = parser.parse().unwrap();
1936        TypeChecker::new().check(&program)
1937    }
1938
1939    fn errors(source: &str) -> Vec<String> {
1940        check_source(source)
1941            .into_iter()
1942            .filter(|d| d.severity == DiagnosticSeverity::Error)
1943            .map(|d| d.message)
1944            .collect()
1945    }
1946
1947    #[test]
1948    fn test_no_errors_for_untyped_code() {
1949        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
1950        assert!(errs.is_empty());
1951    }
1952
1953    #[test]
1954    fn test_correct_typed_let() {
1955        let errs = errors("pipeline t(task) { let x: int = 42 }");
1956        assert!(errs.is_empty());
1957    }
1958
1959    #[test]
1960    fn test_type_mismatch_let() {
1961        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
1962        assert_eq!(errs.len(), 1);
1963        assert!(errs[0].contains("Type mismatch"));
1964        assert!(errs[0].contains("int"));
1965        assert!(errs[0].contains("string"));
1966    }
1967
1968    #[test]
1969    fn test_correct_typed_fn() {
1970        let errs = errors(
1971            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
1972        );
1973        assert!(errs.is_empty());
1974    }
1975
1976    #[test]
1977    fn test_fn_arg_type_mismatch() {
1978        let errs = errors(
1979            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
1980add("hello", 2) }"#,
1981        );
1982        assert_eq!(errs.len(), 1);
1983        assert!(errs[0].contains("Argument 1"));
1984        assert!(errs[0].contains("expected int"));
1985    }
1986
1987    #[test]
1988    fn test_return_type_mismatch() {
1989        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
1990        assert_eq!(errs.len(), 1);
1991        assert!(errs[0].contains("Return type mismatch"));
1992    }
1993
1994    #[test]
1995    fn test_union_type_compatible() {
1996        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
1997        assert!(errs.is_empty());
1998    }
1999
2000    #[test]
2001    fn test_union_type_mismatch() {
2002        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
2003        assert_eq!(errs.len(), 1);
2004        assert!(errs[0].contains("Type mismatch"));
2005    }
2006
2007    #[test]
2008    fn test_type_inference_propagation() {
2009        let errs = errors(
2010            r#"pipeline t(task) {
2011  fn add(a: int, b: int) -> int { return a + b }
2012  let result: string = add(1, 2)
2013}"#,
2014        );
2015        assert_eq!(errs.len(), 1);
2016        assert!(errs[0].contains("Type mismatch"));
2017        assert!(errs[0].contains("string"));
2018        assert!(errs[0].contains("int"));
2019    }
2020
2021    #[test]
2022    fn test_builtin_return_type_inference() {
2023        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
2024        assert_eq!(errs.len(), 1);
2025        assert!(errs[0].contains("string"));
2026        assert!(errs[0].contains("int"));
2027    }
2028
2029    #[test]
2030    fn test_workflow_and_transcript_builtins_are_known() {
2031        let errs = errors(
2032            r#"pipeline t(task) {
2033  let flow = workflow_graph({name: "demo", entry: "act", nodes: {act: {kind: "stage"}}})
2034  let report: dict = workflow_policy_report(flow, {tools: ["read"], capabilities: {workspace: ["read_text"]}})
2035  let run: dict = workflow_execute("task", flow, [], {})
2036  let fixture: dict = run_record_fixture(run?.run)
2037  let suite: dict = run_record_eval_suite([{run: run?.run, fixture: fixture}])
2038  let diff: dict = run_record_diff(run?.run, run?.run)
2039  let manifest: dict = eval_suite_manifest({cases: [{run_path: "run.json"}]})
2040  let suite_report: dict = eval_suite_run(manifest)
2041  let wf: dict = artifact_workspace_file("src/main.rs", "fn main() {}", {source: "host"})
2042  let snap: dict = artifact_workspace_snapshot(["src/main.rs"], "snapshot")
2043  let selection: dict = artifact_editor_selection("src/main.rs", "main")
2044  let verify: dict = artifact_verification_result("verify", "ok")
2045  let test_result: dict = artifact_test_result("tests", "pass")
2046  let cmd: dict = artifact_command_result("cargo test", {status: 0})
2047  let patch: dict = artifact_diff("src/main.rs", "old", "new")
2048  let git: dict = artifact_git_diff("diff --git a b")
2049  let review: dict = artifact_diff_review(patch, "review me")
2050  let decision: dict = artifact_review_decision(review, "accepted")
2051  let transcript = transcript_reset({metadata: {source: "test"}})
2052  let visible: string = transcript_render_visible(transcript_archive(transcript))
2053  let events: list = transcript_events(transcript)
2054  let context: string = artifact_context([], {max_artifacts: 1})
2055  println(report)
2056  println(run)
2057  println(fixture)
2058  println(suite)
2059  println(diff)
2060  println(manifest)
2061  println(suite_report)
2062  println(wf)
2063  println(snap)
2064  println(selection)
2065  println(verify)
2066  println(test_result)
2067  println(cmd)
2068  println(patch)
2069  println(git)
2070  println(review)
2071  println(decision)
2072  println(visible)
2073  println(events)
2074  println(context)
2075}"#,
2076        );
2077        assert!(errs.is_empty(), "unexpected type errors: {errs:?}");
2078    }
2079
2080    #[test]
2081    fn test_binary_op_type_inference() {
2082        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
2083        assert_eq!(errs.len(), 1);
2084    }
2085
2086    #[test]
2087    fn test_comparison_returns_bool() {
2088        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
2089        assert!(errs.is_empty());
2090    }
2091
2092    #[test]
2093    fn test_int_float_promotion() {
2094        let errs = errors("pipeline t(task) { let x: float = 42 }");
2095        assert!(errs.is_empty());
2096    }
2097
2098    #[test]
2099    fn test_untyped_code_no_errors() {
2100        let errs = errors(
2101            r#"pipeline t(task) {
2102  fn process(data) {
2103    let result = data + " processed"
2104    return result
2105  }
2106  log(process("hello"))
2107}"#,
2108        );
2109        assert!(errs.is_empty());
2110    }
2111
2112    #[test]
2113    fn test_type_alias() {
2114        let errs = errors(
2115            r#"pipeline t(task) {
2116  type Name = string
2117  let x: Name = "hello"
2118}"#,
2119        );
2120        assert!(errs.is_empty());
2121    }
2122
2123    #[test]
2124    fn test_type_alias_mismatch() {
2125        let errs = errors(
2126            r#"pipeline t(task) {
2127  type Name = string
2128  let x: Name = 42
2129}"#,
2130        );
2131        assert_eq!(errs.len(), 1);
2132    }
2133
2134    #[test]
2135    fn test_assignment_type_check() {
2136        let errs = errors(
2137            r#"pipeline t(task) {
2138  var x: int = 0
2139  x = "hello"
2140}"#,
2141        );
2142        assert_eq!(errs.len(), 1);
2143        assert!(errs[0].contains("cannot assign string"));
2144    }
2145
2146    #[test]
2147    fn test_covariance_int_to_float_in_fn() {
2148        let errs = errors(
2149            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
2150        );
2151        assert!(errs.is_empty());
2152    }
2153
2154    #[test]
2155    fn test_covariance_return_type() {
2156        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
2157        assert!(errs.is_empty());
2158    }
2159
2160    #[test]
2161    fn test_no_contravariance_float_to_int() {
2162        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
2163        assert_eq!(errs.len(), 1);
2164    }
2165
2166    // --- Exhaustiveness checking tests ---
2167
2168    fn warnings(source: &str) -> Vec<String> {
2169        check_source(source)
2170            .into_iter()
2171            .filter(|d| d.severity == DiagnosticSeverity::Warning)
2172            .map(|d| d.message)
2173            .collect()
2174    }
2175
2176    #[test]
2177    fn test_exhaustive_match_no_warning() {
2178        let warns = warnings(
2179            r#"pipeline t(task) {
2180  enum Color { Red, Green, Blue }
2181  let c = Color.Red
2182  match c.variant {
2183    "Red" -> { log("r") }
2184    "Green" -> { log("g") }
2185    "Blue" -> { log("b") }
2186  }
2187}"#,
2188        );
2189        let exhaustive_warns: Vec<_> = warns
2190            .iter()
2191            .filter(|w| w.contains("Non-exhaustive"))
2192            .collect();
2193        assert!(exhaustive_warns.is_empty());
2194    }
2195
2196    #[test]
2197    fn test_non_exhaustive_match_warning() {
2198        let warns = warnings(
2199            r#"pipeline t(task) {
2200  enum Color { Red, Green, Blue }
2201  let c = Color.Red
2202  match c.variant {
2203    "Red" -> { log("r") }
2204    "Green" -> { log("g") }
2205  }
2206}"#,
2207        );
2208        let exhaustive_warns: Vec<_> = warns
2209            .iter()
2210            .filter(|w| w.contains("Non-exhaustive"))
2211            .collect();
2212        assert_eq!(exhaustive_warns.len(), 1);
2213        assert!(exhaustive_warns[0].contains("Blue"));
2214    }
2215
2216    #[test]
2217    fn test_non_exhaustive_multiple_missing() {
2218        let warns = warnings(
2219            r#"pipeline t(task) {
2220  enum Status { Active, Inactive, Pending }
2221  let s = Status.Active
2222  match s.variant {
2223    "Active" -> { log("a") }
2224  }
2225}"#,
2226        );
2227        let exhaustive_warns: Vec<_> = warns
2228            .iter()
2229            .filter(|w| w.contains("Non-exhaustive"))
2230            .collect();
2231        assert_eq!(exhaustive_warns.len(), 1);
2232        assert!(exhaustive_warns[0].contains("Inactive"));
2233        assert!(exhaustive_warns[0].contains("Pending"));
2234    }
2235
2236    #[test]
2237    fn test_enum_construct_type_inference() {
2238        let errs = errors(
2239            r#"pipeline t(task) {
2240  enum Color { Red, Green, Blue }
2241  let c: Color = Color.Red
2242}"#,
2243        );
2244        assert!(errs.is_empty());
2245    }
2246
2247    // --- Type narrowing tests ---
2248
2249    #[test]
2250    fn test_nil_coalescing_strips_nil() {
2251        // After ??, nil should be stripped from the type
2252        let errs = errors(
2253            r#"pipeline t(task) {
2254  let x: string | nil = nil
2255  let y: string = x ?? "default"
2256}"#,
2257        );
2258        assert!(errs.is_empty());
2259    }
2260
2261    #[test]
2262    fn test_shape_mismatch_detail_missing_field() {
2263        let errs = errors(
2264            r#"pipeline t(task) {
2265  let x: {name: string, age: int} = {name: "hello"}
2266}"#,
2267        );
2268        assert_eq!(errs.len(), 1);
2269        assert!(
2270            errs[0].contains("missing field 'age'"),
2271            "expected detail about missing field, got: {}",
2272            errs[0]
2273        );
2274    }
2275
2276    #[test]
2277    fn test_shape_mismatch_detail_wrong_type() {
2278        let errs = errors(
2279            r#"pipeline t(task) {
2280  let x: {name: string, age: int} = {name: 42, age: 10}
2281}"#,
2282        );
2283        assert_eq!(errs.len(), 1);
2284        assert!(
2285            errs[0].contains("field 'name' has type int, expected string"),
2286            "expected detail about wrong type, got: {}",
2287            errs[0]
2288        );
2289    }
2290
2291    // --- Match pattern type validation tests ---
2292
2293    #[test]
2294    fn test_match_pattern_string_against_int() {
2295        let warns = warnings(
2296            r#"pipeline t(task) {
2297  let x: int = 42
2298  match x {
2299    "hello" -> { log("bad") }
2300    42 -> { log("ok") }
2301  }
2302}"#,
2303        );
2304        let pattern_warns: Vec<_> = warns
2305            .iter()
2306            .filter(|w| w.contains("Match pattern type mismatch"))
2307            .collect();
2308        assert_eq!(pattern_warns.len(), 1);
2309        assert!(pattern_warns[0].contains("matching int against string literal"));
2310    }
2311
2312    #[test]
2313    fn test_match_pattern_int_against_string() {
2314        let warns = warnings(
2315            r#"pipeline t(task) {
2316  let x: string = "hello"
2317  match x {
2318    42 -> { log("bad") }
2319    "hello" -> { log("ok") }
2320  }
2321}"#,
2322        );
2323        let pattern_warns: Vec<_> = warns
2324            .iter()
2325            .filter(|w| w.contains("Match pattern type mismatch"))
2326            .collect();
2327        assert_eq!(pattern_warns.len(), 1);
2328        assert!(pattern_warns[0].contains("matching string against int literal"));
2329    }
2330
2331    #[test]
2332    fn test_match_pattern_bool_against_int() {
2333        let warns = warnings(
2334            r#"pipeline t(task) {
2335  let x: int = 42
2336  match x {
2337    true -> { log("bad") }
2338    42 -> { log("ok") }
2339  }
2340}"#,
2341        );
2342        let pattern_warns: Vec<_> = warns
2343            .iter()
2344            .filter(|w| w.contains("Match pattern type mismatch"))
2345            .collect();
2346        assert_eq!(pattern_warns.len(), 1);
2347        assert!(pattern_warns[0].contains("matching int against bool literal"));
2348    }
2349
2350    #[test]
2351    fn test_match_pattern_float_against_string() {
2352        let warns = warnings(
2353            r#"pipeline t(task) {
2354  let x: string = "hello"
2355  match x {
2356    3.14 -> { log("bad") }
2357    "hello" -> { log("ok") }
2358  }
2359}"#,
2360        );
2361        let pattern_warns: Vec<_> = warns
2362            .iter()
2363            .filter(|w| w.contains("Match pattern type mismatch"))
2364            .collect();
2365        assert_eq!(pattern_warns.len(), 1);
2366        assert!(pattern_warns[0].contains("matching string against float literal"));
2367    }
2368
2369    #[test]
2370    fn test_match_pattern_int_against_float_ok() {
2371        // int and float are compatible for match patterns
2372        let warns = warnings(
2373            r#"pipeline t(task) {
2374  let x: float = 3.14
2375  match x {
2376    42 -> { log("ok") }
2377    _ -> { log("default") }
2378  }
2379}"#,
2380        );
2381        let pattern_warns: Vec<_> = warns
2382            .iter()
2383            .filter(|w| w.contains("Match pattern type mismatch"))
2384            .collect();
2385        assert!(pattern_warns.is_empty());
2386    }
2387
2388    #[test]
2389    fn test_match_pattern_float_against_int_ok() {
2390        // float and int are compatible for match patterns
2391        let warns = warnings(
2392            r#"pipeline t(task) {
2393  let x: int = 42
2394  match x {
2395    3.14 -> { log("close") }
2396    _ -> { log("default") }
2397  }
2398}"#,
2399        );
2400        let pattern_warns: Vec<_> = warns
2401            .iter()
2402            .filter(|w| w.contains("Match pattern type mismatch"))
2403            .collect();
2404        assert!(pattern_warns.is_empty());
2405    }
2406
2407    #[test]
2408    fn test_match_pattern_correct_types_no_warning() {
2409        let warns = warnings(
2410            r#"pipeline t(task) {
2411  let x: int = 42
2412  match x {
2413    1 -> { log("one") }
2414    2 -> { log("two") }
2415    _ -> { log("other") }
2416  }
2417}"#,
2418        );
2419        let pattern_warns: Vec<_> = warns
2420            .iter()
2421            .filter(|w| w.contains("Match pattern type mismatch"))
2422            .collect();
2423        assert!(pattern_warns.is_empty());
2424    }
2425
2426    #[test]
2427    fn test_match_pattern_wildcard_no_warning() {
2428        let warns = warnings(
2429            r#"pipeline t(task) {
2430  let x: int = 42
2431  match x {
2432    _ -> { log("catch all") }
2433  }
2434}"#,
2435        );
2436        let pattern_warns: Vec<_> = warns
2437            .iter()
2438            .filter(|w| w.contains("Match pattern type mismatch"))
2439            .collect();
2440        assert!(pattern_warns.is_empty());
2441    }
2442
2443    #[test]
2444    fn test_match_pattern_untyped_no_warning() {
2445        // When value has no known type, no warning should be emitted
2446        let warns = warnings(
2447            r#"pipeline t(task) {
2448  let x = some_unknown_fn()
2449  match x {
2450    "hello" -> { log("string") }
2451    42 -> { log("int") }
2452  }
2453}"#,
2454        );
2455        let pattern_warns: Vec<_> = warns
2456            .iter()
2457            .filter(|w| w.contains("Match pattern type mismatch"))
2458            .collect();
2459        assert!(pattern_warns.is_empty());
2460    }
2461
2462    // --- Interface constraint type checking tests ---
2463
2464    fn iface_warns(source: &str) -> Vec<String> {
2465        warnings(source)
2466            .into_iter()
2467            .filter(|w| w.contains("does not satisfy interface"))
2468            .collect()
2469    }
2470
2471    #[test]
2472    fn test_interface_constraint_return_type_mismatch() {
2473        let warns = iface_warns(
2474            r#"pipeline t(task) {
2475  interface Sizable {
2476    fn size(self) -> int
2477  }
2478  struct Box { width: int }
2479  impl Box {
2480    fn size(self) -> string { return "nope" }
2481  }
2482  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2483  measure(Box({width: 3}))
2484}"#,
2485        );
2486        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2487        assert!(
2488            warns[0].contains("method 'size' returns 'string', expected 'int'"),
2489            "unexpected message: {}",
2490            warns[0]
2491        );
2492    }
2493
2494    #[test]
2495    fn test_interface_constraint_param_type_mismatch() {
2496        let warns = iface_warns(
2497            r#"pipeline t(task) {
2498  interface Processor {
2499    fn process(self, x: int) -> string
2500  }
2501  struct MyProc { name: string }
2502  impl MyProc {
2503    fn process(self, x: string) -> string { return x }
2504  }
2505  fn run_proc<T>(p: T) where T: Processor { log(p.process(42)) }
2506  run_proc(MyProc({name: "a"}))
2507}"#,
2508        );
2509        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2510        assert!(
2511            warns[0].contains("method 'process' parameter 1 has type 'string', expected 'int'"),
2512            "unexpected message: {}",
2513            warns[0]
2514        );
2515    }
2516
2517    #[test]
2518    fn test_interface_constraint_missing_method() {
2519        let warns = iface_warns(
2520            r#"pipeline t(task) {
2521  interface Sizable {
2522    fn size(self) -> int
2523  }
2524  struct Box { width: int }
2525  impl Box {
2526    fn area(self) -> int { return self.width }
2527  }
2528  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2529  measure(Box({width: 3}))
2530}"#,
2531        );
2532        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2533        assert!(
2534            warns[0].contains("missing method 'size'"),
2535            "unexpected message: {}",
2536            warns[0]
2537        );
2538    }
2539
2540    #[test]
2541    fn test_interface_constraint_param_count_mismatch() {
2542        let warns = iface_warns(
2543            r#"pipeline t(task) {
2544  interface Doubler {
2545    fn double(self, x: int) -> int
2546  }
2547  struct Bad { v: int }
2548  impl Bad {
2549    fn double(self) -> int { return self.v * 2 }
2550  }
2551  fn run_double<T>(d: T) where T: Doubler { log(d.double(3)) }
2552  run_double(Bad({v: 5}))
2553}"#,
2554        );
2555        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2556        assert!(
2557            warns[0].contains("method 'double' has 0 parameter(s), expected 1"),
2558            "unexpected message: {}",
2559            warns[0]
2560        );
2561    }
2562
2563    #[test]
2564    fn test_interface_constraint_satisfied() {
2565        let warns = iface_warns(
2566            r#"pipeline t(task) {
2567  interface Sizable {
2568    fn size(self) -> int
2569  }
2570  struct Box { width: int, height: int }
2571  impl Box {
2572    fn size(self) -> int { return self.width * self.height }
2573  }
2574  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2575  measure(Box({width: 3, height: 4}))
2576}"#,
2577        );
2578        assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
2579    }
2580
2581    #[test]
2582    fn test_interface_constraint_untyped_impl_compatible() {
2583        // Gradual typing: untyped impl return should not trigger warning
2584        let warns = iface_warns(
2585            r#"pipeline t(task) {
2586  interface Sizable {
2587    fn size(self) -> int
2588  }
2589  struct Box { width: int }
2590  impl Box {
2591    fn size(self) { return self.width }
2592  }
2593  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2594  measure(Box({width: 3}))
2595}"#,
2596        );
2597        assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
2598    }
2599
2600    #[test]
2601    fn test_interface_constraint_int_float_covariance() {
2602        // int is compatible with float (covariance)
2603        let warns = iface_warns(
2604            r#"pipeline t(task) {
2605  interface Measurable {
2606    fn value(self) -> float
2607  }
2608  struct Gauge { v: int }
2609  impl Gauge {
2610    fn value(self) -> int { return self.v }
2611  }
2612  fn read_val<T>(g: T) where T: Measurable { log(g.value()) }
2613  read_val(Gauge({v: 42}))
2614}"#,
2615        );
2616        assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
2617    }
2618}