Skip to main content

harn_parser/
typechecker.rs

1use std::collections::BTreeMap;
2
3use crate::ast::*;
4use harn_lexer::Span;
5
/// A diagnostic produced by the type checker.
///
/// Diagnostics are accumulated by [`TypeChecker`] and returned from
/// `TypeChecker::check`.
#[derive(Debug, Clone)]
pub struct TypeDiagnostic {
    /// Description of the problem (e.g. a "Type mismatch: ..." message).
    pub message: String,
    /// Whether this diagnostic is an error or a warning.
    pub severity: DiagnosticSeverity,
    /// Source location the diagnostic refers to, when one is available.
    pub span: Option<Span>,
    /// Optional supplementary help text.
    pub help: Option<String>,
}

/// Severity level attached to a [`TypeDiagnostic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
    Error,
    Warning,
}
20
/// Inferred type of an expression. None means unknown/untyped (gradual typing).
///
/// Checks in this file only compare types when both sides are `Some`; for
/// example, a `let` binding with an annotation is only checked when the
/// value's type could be inferred.
type InferredType = Option<TypeExpr>;
23
/// Scope for tracking variable types.
///
/// Each scope holds only its own definitions. Lookups (`get_var`, `get_fn`,
/// `resolve_type`, ...) fall back to `parent` when a name is not found locally.
/// `TypeScope::child` stores a boxed *clone* of the current scope as the
/// parent, so definitions added to a scope after a child is created are not
/// visible through that child.
#[derive(Debug, Clone)]
struct TypeScope {
    /// Variable name → inferred type.
    vars: BTreeMap<String, InferredType>,
    /// Function name → (param types, return type).
    functions: BTreeMap<String, FnSignature>,
    /// Named type aliases.
    type_aliases: BTreeMap<String, TypeExpr>,
    /// Enum declarations: name → variant names.
    enums: BTreeMap<String, Vec<String>>,
    /// Interface declarations: name → method signatures.
    interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
    /// Struct declarations: name → field types.
    structs: BTreeMap<String, Vec<(String, InferredType)>>,
    /// Impl block methods: type_name → method signatures.
    impl_methods: BTreeMap<String, Vec<ImplMethodSig>>,
    /// Generic type parameter names in scope (treated as compatible with any type).
    generic_type_params: std::collections::BTreeSet<String>,
    /// Where-clause constraints: type_param → interface_bound.
    /// Used for definition-site checking of generic function bodies.
    where_constraints: BTreeMap<String, String>,
    /// Enclosing scope (a snapshot taken when this scope was created), or
    /// `None` for the root scope.
    parent: Option<Box<TypeScope>>,
}
48
/// Method signature extracted from an impl block (for interface checking).
///
/// `self` is excluded from both `param_count` and `param_types`. When built
/// from an impl block, `param_count == param_types.len()`.
#[derive(Debug, Clone)]
struct ImplMethodSig {
    /// Method name as declared in the impl block.
    name: String,
    /// Number of parameters excluding `self`.
    param_count: usize,
    /// Parameter types (excluding `self`), None means untyped.
    param_types: Vec<Option<TypeExpr>>,
    /// Return type, None means untyped.
    return_type: Option<TypeExpr>,
}
60
/// Signature of a user-declared function, recorded when its `fn` declaration
/// is checked.
#[derive(Debug, Clone)]
struct FnSignature {
    /// Parameters in declaration order: (name, declared type or `None`).
    params: Vec<(String, InferredType)>,
    /// Declared return type, or `None` when not annotated.
    return_type: InferredType,
    /// Generic type parameter names declared on the function.
    type_param_names: Vec<String>,
    /// Number of required parameters (those without defaults).
    required_params: usize,
    /// Where-clause constraints: (type_param_name, interface_bound).
    where_clauses: Vec<(String, String)>,
}
72
73impl TypeScope {
74    fn new() -> Self {
75        Self {
76            vars: BTreeMap::new(),
77            functions: BTreeMap::new(),
78            type_aliases: BTreeMap::new(),
79            enums: BTreeMap::new(),
80            interfaces: BTreeMap::new(),
81            structs: BTreeMap::new(),
82            impl_methods: BTreeMap::new(),
83            generic_type_params: std::collections::BTreeSet::new(),
84            where_constraints: BTreeMap::new(),
85            parent: None,
86        }
87    }
88
89    fn child(&self) -> Self {
90        Self {
91            vars: BTreeMap::new(),
92            functions: BTreeMap::new(),
93            type_aliases: BTreeMap::new(),
94            enums: BTreeMap::new(),
95            interfaces: BTreeMap::new(),
96            structs: BTreeMap::new(),
97            impl_methods: BTreeMap::new(),
98            generic_type_params: std::collections::BTreeSet::new(),
99            where_constraints: BTreeMap::new(),
100            parent: Some(Box::new(self.clone())),
101        }
102    }
103
104    fn get_var(&self, name: &str) -> Option<&InferredType> {
105        self.vars
106            .get(name)
107            .or_else(|| self.parent.as_ref()?.get_var(name))
108    }
109
110    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
111        self.functions
112            .get(name)
113            .or_else(|| self.parent.as_ref()?.get_fn(name))
114    }
115
116    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
117        self.type_aliases
118            .get(name)
119            .or_else(|| self.parent.as_ref()?.resolve_type(name))
120    }
121
122    fn is_generic_type_param(&self, name: &str) -> bool {
123        self.generic_type_params.contains(name)
124            || self
125                .parent
126                .as_ref()
127                .is_some_and(|p| p.is_generic_type_param(name))
128    }
129
130    fn get_where_constraint(&self, type_param: &str) -> Option<&str> {
131        self.where_constraints
132            .get(type_param)
133            .map(|s| s.as_str())
134            .or_else(|| {
135                self.parent
136                    .as_ref()
137                    .and_then(|p| p.get_where_constraint(type_param))
138            })
139    }
140
141    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
142        self.enums
143            .get(name)
144            .or_else(|| self.parent.as_ref()?.get_enum(name))
145    }
146
147    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
148        self.interfaces
149            .get(name)
150            .or_else(|| self.parent.as_ref()?.get_interface(name))
151    }
152
153    fn get_struct(&self, name: &str) -> Option<&Vec<(String, InferredType)>> {
154        self.structs
155            .get(name)
156            .or_else(|| self.parent.as_ref()?.get_struct(name))
157    }
158
159    fn get_impl_methods(&self, name: &str) -> Option<&Vec<ImplMethodSig>> {
160        self.impl_methods
161            .get(name)
162            .or_else(|| self.parent.as_ref()?.get_impl_methods(name))
163    }
164
165    fn define_var(&mut self, name: &str, ty: InferredType) {
166        self.vars.insert(name.to_string(), ty);
167    }
168
169    fn define_fn(&mut self, name: &str, sig: FnSignature) {
170        self.functions.insert(name.to_string(), sig);
171    }
172}
173
174/// Known return types for builtin functions.
175fn builtin_return_type(name: &str) -> InferredType {
176    match name {
177        "log" | "print" | "println" | "write_file" | "sleep" | "cancel" | "exit"
178        | "delete_file" | "mkdir" | "copy_file" | "append_file" => {
179            Some(TypeExpr::Named("nil".into()))
180        }
181        "type_of"
182        | "to_string"
183        | "json_stringify"
184        | "read_file"
185        | "http_get"
186        | "http_post"
187        | "regex_replace"
188        | "path_join"
189        | "temp_dir"
190        | "date_format"
191        | "format"
192        | "compute_content_hash" => Some(TypeExpr::Named("string".into())),
193        "to_int" | "timer_end" | "elapsed" | "sign" => Some(TypeExpr::Named("int".into())),
194        "to_float" | "timestamp" | "date_parse" | "sin" | "cos" | "tan" | "asin" | "acos"
195        | "atan" | "atan2" | "log2" | "log10" | "ln" | "exp" | "pi" | "e" => {
196            Some(TypeExpr::Named("float".into()))
197        }
198        "file_exists" | "json_validate" | "is_nan" | "is_infinite" | "set_contains" => {
199            Some(TypeExpr::Named("bool".into()))
200        }
201        "list_dir"
202        | "mcp_list_tools"
203        | "mcp_list_resources"
204        | "mcp_list_prompts"
205        | "to_list"
206        | "regex_captures"
207        | "artifact_select"
208        | "transcript_messages"
209        | "transcript_events" => Some(TypeExpr::Named("list".into())),
210        "stat"
211        | "exec"
212        | "shell"
213        | "date_now"
214        | "llm_call"
215        | "llm_completion"
216        | "agent_loop"
217        | "llm_info"
218        | "llm_usage"
219        | "timer_start"
220        | "metadata_get"
221        | "mcp_server_info"
222        | "mcp_get_prompt"
223        | "llm_pick_model"
224        | "transcript"
225        | "transcript_from_messages"
226        | "transcript_reset"
227        | "transcript_archive"
228        | "transcript_abandon"
229        | "transcript_resume"
230        | "workflow_graph"
231        | "workflow_validate"
232        | "workflow_inspect"
233        | "workflow_policy_report"
234        | "workflow_clone"
235        | "workflow_insert_node"
236        | "workflow_replace_node"
237        | "workflow_rewire"
238        | "workflow_set_model_policy"
239        | "workflow_set_context_policy"
240        | "workflow_set_transcript_policy"
241        | "workflow_diff"
242        | "workflow_commit"
243        | "artifact"
244        | "artifact_derive"
245        | "artifact_workspace_file"
246        | "artifact_workspace_snapshot"
247        | "artifact_editor_selection"
248        | "artifact_verification_result"
249        | "artifact_test_result"
250        | "artifact_command_result"
251        | "artifact_diff"
252        | "artifact_git_diff"
253        | "artifact_diff_review"
254        | "artifact_review_decision"
255        | "run_record"
256        | "run_record_save"
257        | "run_record_load"
258        | "run_record_fixture"
259        | "run_record_eval"
260        | "run_record_eval_suite"
261        | "run_record_diff"
262        | "eval_suite_manifest"
263        | "eval_suite_run"
264        | "workflow_execute"
265        | "transcript_compact"
266        | "transcript_summarize"
267        | "host_capabilities" => Some(TypeExpr::Named("dict".into())),
268        "transcript_render_visible"
269        | "transcript_render_full"
270        | "artifact_context"
271        | "transcript_export"
272        | "transcript_id" => Some(TypeExpr::Named("string".into())),
273        "transcript_summary" => Some(TypeExpr::Union(vec![
274            TypeExpr::Named("string".into()),
275            TypeExpr::Named("nil".into()),
276        ])),
277        "host_has" => Some(TypeExpr::Named("bool".into())),
278        "metadata_set"
279        | "metadata_save"
280        | "metadata_refresh_hashes"
281        | "invalidate_facts"
282        | "log_json"
283        | "mcp_disconnect" => Some(TypeExpr::Named("nil".into())),
284        "env" | "regex_match" => Some(TypeExpr::Union(vec![
285            TypeExpr::Named("string".into()),
286            TypeExpr::Named("nil".into()),
287        ])),
288        "json_parse" | "json_extract" => None, // could be any type
289        _ => None,
290    }
291}
292
/// Check if a name is a known builtin.
///
/// NOTE(review): this table is maintained separately from
/// `builtin_return_type`. Some names that have a known return type there
/// (e.g. `timer_end`, `elapsed`, `mcp_list_tools`, `metadata_get`,
/// `compute_content_hash`) are not listed here. Confirm whether that is
/// intentional.
fn is_builtin(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        "log",
        "print",
        "println",
        "type_of",
        "to_string",
        "to_int",
        "to_float",
        "json_stringify",
        "json_parse",
        "env",
        "timestamp",
        "sleep",
        "read_file",
        "write_file",
        "exit",
        "regex_match",
        "regex_replace",
        "regex_captures",
        "http_get",
        "http_post",
        "llm_call",
        "llm_completion",
        "agent_loop",
        "llm_pick_model",
        "await",
        "cancel",
        "file_exists",
        "delete_file",
        "list_dir",
        "mkdir",
        "path_join",
        "copy_file",
        "append_file",
        "temp_dir",
        "transcript",
        "transcript_from_messages",
        "transcript_messages",
        "transcript_events",
        "transcript_summary",
        "transcript_id",
        "transcript_export",
        "transcript_import",
        "transcript_fork",
        "transcript_reset",
        "transcript_archive",
        "transcript_abandon",
        "transcript_resume",
        "transcript_render_visible",
        "transcript_render_full",
        "transcript_compact",
        "transcript_summarize",
        "host_capabilities",
        "host_has",
        "host_invoke",
        "workflow_graph",
        "workflow_validate",
        "workflow_inspect",
        "workflow_policy_report",
        "workflow_clone",
        "workflow_insert_node",
        "workflow_replace_node",
        "workflow_rewire",
        "workflow_set_model_policy",
        "workflow_set_context_policy",
        "workflow_set_transcript_policy",
        "workflow_diff",
        "workflow_commit",
        "workflow_execute",
        "artifact",
        "artifact_derive",
        "artifact_workspace_file",
        "artifact_workspace_snapshot",
        "artifact_editor_selection",
        "artifact_verification_result",
        "artifact_test_result",
        "artifact_command_result",
        "artifact_diff",
        "artifact_git_diff",
        "artifact_diff_review",
        "artifact_review_decision",
        "artifact_select",
        "artifact_context",
        "run_record",
        "run_record_save",
        "run_record_load",
        "run_record_fixture",
        "run_record_eval",
        "run_record_eval_suite",
        "run_record_diff",
        "eval_suite_manifest",
        "eval_suite_run",
        "stat",
        "exec",
        "shell",
        "date_now",
        "date_format",
        "date_parse",
        "format",
        "json_validate",
        "json_extract",
        "trim",
        "lowercase",
        "uppercase",
        "split",
        "starts_with",
        "ends_with",
        "contains",
        "replace",
        "join",
        "len",
        "substring",
        "dirname",
        "basename",
        "extname",
        "sin",
        "cos",
        "tan",
        "asin",
        "acos",
        "atan",
        "atan2",
        "log2",
        "log10",
        "ln",
        "exp",
        "pi",
        "e",
        "sign",
        "is_nan",
        "is_infinite",
        "set",
        "set_add",
        "set_remove",
        "set_contains",
        "set_union",
        "set_intersect",
        "set_difference",
        "to_list",
    ];
    BUILTINS.contains(&name)
}
437
/// The static type checker.
///
/// Construct with `TypeChecker::new` and consume with `TypeChecker::check`,
/// which returns the accumulated diagnostics.
pub struct TypeChecker {
    /// Diagnostics collected so far while checking.
    diagnostics: Vec<TypeDiagnostic>,
    /// Root (top-level) scope for the program being checked.
    scope: TypeScope,
}
443
444impl TypeChecker {
445    pub fn new() -> Self {
446        Self {
447            diagnostics: Vec::new(),
448            scope: TypeScope::new(),
449        }
450    }
451
452    /// Check a program and return diagnostics.
453    pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
454        // First pass: register type and enum declarations into root scope
455        Self::register_declarations_into(&mut self.scope, program);
456
457        // Also scan pipeline bodies for declarations
458        for snode in program {
459            if let Node::Pipeline { body, .. } = &snode.node {
460                Self::register_declarations_into(&mut self.scope, body);
461            }
462        }
463
464        // Check each top-level node
465        for snode in program {
466            match &snode.node {
467                Node::Pipeline { params, body, .. } => {
468                    let mut child = self.scope.child();
469                    for p in params {
470                        child.define_var(p, None);
471                    }
472                    self.check_block(body, &mut child);
473                }
474                Node::FnDecl {
475                    name,
476                    type_params,
477                    params,
478                    return_type,
479                    where_clauses,
480                    body,
481                    ..
482                } => {
483                    let required_params =
484                        params.iter().filter(|p| p.default_value.is_none()).count();
485                    let sig = FnSignature {
486                        params: params
487                            .iter()
488                            .map(|p| (p.name.clone(), p.type_expr.clone()))
489                            .collect(),
490                        return_type: return_type.clone(),
491                        type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
492                        required_params,
493                        where_clauses: where_clauses
494                            .iter()
495                            .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
496                            .collect(),
497                    };
498                    self.scope.define_fn(name, sig);
499                    self.check_fn_body(type_params, params, return_type, body, where_clauses);
500                }
501                _ => {
502                    let mut scope = self.scope.clone();
503                    self.check_node(snode, &mut scope);
504                    // Merge any new definitions back into the top-level scope
505                    for (name, ty) in scope.vars {
506                        self.scope.vars.entry(name).or_insert(ty);
507                    }
508                }
509            }
510        }
511
512        self.diagnostics
513    }
514
515    /// Register type, enum, interface, and struct declarations from AST nodes into a scope.
516    fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
517        for snode in nodes {
518            match &snode.node {
519                Node::TypeDecl { name, type_expr } => {
520                    scope.type_aliases.insert(name.clone(), type_expr.clone());
521                }
522                Node::EnumDecl { name, variants } => {
523                    let variant_names: Vec<String> =
524                        variants.iter().map(|v| v.name.clone()).collect();
525                    scope.enums.insert(name.clone(), variant_names);
526                }
527                Node::InterfaceDecl { name, methods } => {
528                    scope.interfaces.insert(name.clone(), methods.clone());
529                }
530                Node::StructDecl { name, fields } => {
531                    let field_types: Vec<(String, InferredType)> = fields
532                        .iter()
533                        .map(|f| (f.name.clone(), f.type_expr.clone()))
534                        .collect();
535                    scope.structs.insert(name.clone(), field_types);
536                }
537                Node::ImplBlock {
538                    type_name, methods, ..
539                } => {
540                    let sigs: Vec<ImplMethodSig> = methods
541                        .iter()
542                        .filter_map(|m| {
543                            if let Node::FnDecl {
544                                name,
545                                params,
546                                return_type,
547                                ..
548                            } = &m.node
549                            {
550                                let non_self: Vec<_> =
551                                    params.iter().filter(|p| p.name != "self").collect();
552                                let param_count = non_self.len();
553                                let param_types: Vec<Option<TypeExpr>> =
554                                    non_self.iter().map(|p| p.type_expr.clone()).collect();
555                                Some(ImplMethodSig {
556                                    name: name.clone(),
557                                    param_count,
558                                    param_types,
559                                    return_type: return_type.clone(),
560                                })
561                            } else {
562                                None
563                            }
564                        })
565                        .collect();
566                    scope.impl_methods.insert(type_name.clone(), sigs);
567                }
568                _ => {}
569            }
570        }
571    }
572
573    fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
574        for stmt in stmts {
575            self.check_node(stmt, scope);
576        }
577    }
578
579    /// Define variables from a destructuring pattern in the given scope (as unknown type).
580    fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
581        match pattern {
582            BindingPattern::Identifier(name) => {
583                scope.define_var(name, None);
584            }
585            BindingPattern::Dict(fields) => {
586                for field in fields {
587                    let name = field.alias.as_deref().unwrap_or(&field.key);
588                    scope.define_var(name, None);
589                }
590            }
591            BindingPattern::List(elements) => {
592                for elem in elements {
593                    scope.define_var(&elem.name, None);
594                }
595            }
596        }
597    }
598
599    fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
600        let span = snode.span;
601        match &snode.node {
602            Node::LetBinding {
603                pattern,
604                type_ann,
605                value,
606            } => {
607                let inferred = self.infer_type(value, scope);
608                if let BindingPattern::Identifier(name) = pattern {
609                    if let Some(expected) = type_ann {
610                        if let Some(actual) = &inferred {
611                            if !self.types_compatible(expected, actual, scope) {
612                                let mut msg = format!(
613                                    "Type mismatch: '{}' declared as {}, but assigned {}",
614                                    name,
615                                    format_type(expected),
616                                    format_type(actual)
617                                );
618                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
619                                    msg.push_str(&format!(" ({})", detail));
620                                }
621                                self.error_at(msg, span);
622                            }
623                        }
624                    }
625                    let ty = type_ann.clone().or(inferred);
626                    scope.define_var(name, ty);
627                } else {
628                    Self::define_pattern_vars(pattern, scope);
629                }
630            }
631
632            Node::VarBinding {
633                pattern,
634                type_ann,
635                value,
636            } => {
637                let inferred = self.infer_type(value, scope);
638                if let BindingPattern::Identifier(name) = pattern {
639                    if let Some(expected) = type_ann {
640                        if let Some(actual) = &inferred {
641                            if !self.types_compatible(expected, actual, scope) {
642                                let mut msg = format!(
643                                    "Type mismatch: '{}' declared as {}, but assigned {}",
644                                    name,
645                                    format_type(expected),
646                                    format_type(actual)
647                                );
648                                if let Some(detail) = shape_mismatch_detail(expected, actual) {
649                                    msg.push_str(&format!(" ({})", detail));
650                                }
651                                self.error_at(msg, span);
652                            }
653                        }
654                    }
655                    let ty = type_ann.clone().or(inferred);
656                    scope.define_var(name, ty);
657                } else {
658                    Self::define_pattern_vars(pattern, scope);
659                }
660            }
661
662            Node::FnDecl {
663                name,
664                type_params,
665                params,
666                return_type,
667                where_clauses,
668                body,
669                ..
670            } => {
671                let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
672                let sig = FnSignature {
673                    params: params
674                        .iter()
675                        .map(|p| (p.name.clone(), p.type_expr.clone()))
676                        .collect(),
677                    return_type: return_type.clone(),
678                    type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
679                    required_params,
680                    where_clauses: where_clauses
681                        .iter()
682                        .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
683                        .collect(),
684                };
685                scope.define_fn(name, sig.clone());
686                scope.define_var(name, None);
687                self.check_fn_body(type_params, params, return_type, body, where_clauses);
688            }
689
690            Node::FunctionCall { name, args } => {
691                self.check_call(name, args, scope, span);
692            }
693
694            Node::IfElse {
695                condition,
696                then_body,
697                else_body,
698            } => {
699                self.check_node(condition, scope);
700                let mut then_scope = scope.child();
701                // Narrow union types after nil checks: if x != nil, narrow x
702                if let Some((var_name, narrowed)) = Self::extract_nil_narrowing(condition, scope) {
703                    then_scope.define_var(&var_name, narrowed);
704                }
705                self.check_block(then_body, &mut then_scope);
706                if let Some(else_body) = else_body {
707                    let mut else_scope = scope.child();
708                    self.check_block(else_body, &mut else_scope);
709                }
710            }
711
712            Node::ForIn {
713                pattern,
714                iterable,
715                body,
716            } => {
717                self.check_node(iterable, scope);
718                let mut loop_scope = scope.child();
719                if let BindingPattern::Identifier(variable) = pattern {
720                    // Infer loop variable type from iterable
721                    let elem_type = match self.infer_type(iterable, scope) {
722                        Some(TypeExpr::List(inner)) => Some(*inner),
723                        Some(TypeExpr::Named(n)) if n == "string" => {
724                            Some(TypeExpr::Named("string".into()))
725                        }
726                        _ => None,
727                    };
728                    loop_scope.define_var(variable, elem_type);
729                } else {
730                    Self::define_pattern_vars(pattern, &mut loop_scope);
731                }
732                self.check_block(body, &mut loop_scope);
733            }
734
735            Node::WhileLoop { condition, body } => {
736                self.check_node(condition, scope);
737                let mut loop_scope = scope.child();
738                self.check_block(body, &mut loop_scope);
739            }
740
741            Node::TryCatch {
742                body,
743                error_var,
744                catch_body,
745                finally_body,
746                ..
747            } => {
748                let mut try_scope = scope.child();
749                self.check_block(body, &mut try_scope);
750                let mut catch_scope = scope.child();
751                if let Some(var) = error_var {
752                    catch_scope.define_var(var, None);
753                }
754                self.check_block(catch_body, &mut catch_scope);
755                if let Some(fb) = finally_body {
756                    let mut finally_scope = scope.child();
757                    self.check_block(fb, &mut finally_scope);
758                }
759            }
760
761            Node::TryExpr { body } => {
762                let mut try_scope = scope.child();
763                self.check_block(body, &mut try_scope);
764            }
765
766            Node::ReturnStmt {
767                value: Some(val), ..
768            } => {
769                self.check_node(val, scope);
770            }
771
772            Node::Assignment {
773                target, value, op, ..
774            } => {
775                self.check_node(value, scope);
776                if let Node::Identifier(name) = &target.node {
777                    if let Some(Some(var_type)) = scope.get_var(name) {
778                        let value_type = self.infer_type(value, scope);
779                        let assigned = if let Some(op) = op {
780                            let var_inferred = scope.get_var(name).cloned().flatten();
781                            infer_binary_op_type(op, &var_inferred, &value_type)
782                        } else {
783                            value_type
784                        };
785                        if let Some(actual) = &assigned {
786                            if !self.types_compatible(var_type, actual, scope) {
787                                self.error_at(
788                                    format!(
789                                        "Type mismatch: cannot assign {} to '{}' (declared as {})",
790                                        format_type(actual),
791                                        name,
792                                        format_type(var_type)
793                                    ),
794                                    span,
795                                );
796                            }
797                        }
798                    }
799                }
800            }
801
802            Node::TypeDecl { name, type_expr } => {
803                scope.type_aliases.insert(name.clone(), type_expr.clone());
804            }
805
806            Node::EnumDecl { name, variants } => {
807                let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
808                scope.enums.insert(name.clone(), variant_names);
809            }
810
811            Node::StructDecl { name, fields } => {
812                let field_types: Vec<(String, InferredType)> = fields
813                    .iter()
814                    .map(|f| (f.name.clone(), f.type_expr.clone()))
815                    .collect();
816                scope.structs.insert(name.clone(), field_types);
817            }
818
819            Node::InterfaceDecl { name, methods } => {
820                scope.interfaces.insert(name.clone(), methods.clone());
821            }
822
823            Node::ImplBlock {
824                type_name, methods, ..
825            } => {
826                // Register impl methods for interface satisfaction checking
827                let sigs: Vec<ImplMethodSig> = methods
828                    .iter()
829                    .filter_map(|m| {
830                        if let Node::FnDecl {
831                            name,
832                            params,
833                            return_type,
834                            ..
835                        } = &m.node
836                        {
837                            let non_self: Vec<_> =
838                                params.iter().filter(|p| p.name != "self").collect();
839                            let param_count = non_self.len();
840                            let param_types: Vec<Option<TypeExpr>> =
841                                non_self.iter().map(|p| p.type_expr.clone()).collect();
842                            Some(ImplMethodSig {
843                                name: name.clone(),
844                                param_count,
845                                param_types,
846                                return_type: return_type.clone(),
847                            })
848                        } else {
849                            None
850                        }
851                    })
852                    .collect();
853                scope.impl_methods.insert(type_name.clone(), sigs);
854                for method_sn in methods {
855                    self.check_node(method_sn, scope);
856                }
857            }
858
859            Node::TryOperator { operand } => {
860                self.check_node(operand, scope);
861            }
862
863            Node::MatchExpr { value, arms } => {
864                self.check_node(value, scope);
865                let value_type = self.infer_type(value, scope);
866                for arm in arms {
867                    self.check_node(&arm.pattern, scope);
868                    // Check for incompatible literal pattern types
869                    if let Some(ref vt) = value_type {
870                        let value_type_name = format_type(vt);
871                        let mismatch = match &arm.pattern.node {
872                            Node::StringLiteral(_) => {
873                                !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
874                            }
875                            Node::IntLiteral(_) => {
876                                !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
877                                    && !self.types_compatible(
878                                        vt,
879                                        &TypeExpr::Named("float".into()),
880                                        scope,
881                                    )
882                            }
883                            Node::FloatLiteral(_) => {
884                                !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
885                                    && !self.types_compatible(
886                                        vt,
887                                        &TypeExpr::Named("int".into()),
888                                        scope,
889                                    )
890                            }
891                            Node::BoolLiteral(_) => {
892                                !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
893                            }
894                            _ => false,
895                        };
896                        if mismatch {
897                            let pattern_type = match &arm.pattern.node {
898                                Node::StringLiteral(_) => "string",
899                                Node::IntLiteral(_) => "int",
900                                Node::FloatLiteral(_) => "float",
901                                Node::BoolLiteral(_) => "bool",
902                                _ => unreachable!(),
903                            };
904                            self.warning_at(
905                                format!(
906                                    "Match pattern type mismatch: matching {} against {} literal",
907                                    value_type_name, pattern_type
908                                ),
909                                arm.pattern.span,
910                            );
911                        }
912                    }
913                    let mut arm_scope = scope.child();
914                    self.check_block(&arm.body, &mut arm_scope);
915                }
916                self.check_match_exhaustiveness(value, arms, scope, span);
917            }
918
919            // Recurse into nested expressions + validate binary op types
920            Node::BinaryOp { op, left, right } => {
921                self.check_node(left, scope);
922                self.check_node(right, scope);
923                // Validate operator/type compatibility
924                let lt = self.infer_type(left, scope);
925                let rt = self.infer_type(right, scope);
926                if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
927                    match op.as_str() {
928                        "-" | "*" | "/" | "%" => {
929                            let numeric = ["int", "float"];
930                            if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
931                                self.warning_at(
932                                    format!(
933                                        "Operator '{op}' may not be valid for types {} and {}",
934                                        l, r
935                                    ),
936                                    span,
937                                );
938                            }
939                        }
940                        "+" => {
941                            // + is valid for int, float, string, list, dict
942                            let valid = ["int", "float", "string", "list", "dict"];
943                            if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
944                                self.warning_at(
945                                    format!(
946                                        "Operator '+' may not be valid for types {} and {}",
947                                        l, r
948                                    ),
949                                    span,
950                                );
951                            }
952                        }
953                        _ => {}
954                    }
955                }
956            }
957            Node::UnaryOp { operand, .. } => {
958                self.check_node(operand, scope);
959            }
960            Node::MethodCall {
961                object,
962                method,
963                args,
964                ..
965            }
966            | Node::OptionalMethodCall {
967                object,
968                method,
969                args,
970                ..
971            } => {
972                self.check_node(object, scope);
973                for arg in args {
974                    self.check_node(arg, scope);
975                }
976                // Definition-site generic checking: if the object's type is a
977                // constrained generic param (where T: Interface), verify the
978                // method exists in the bound interface.
979                if let Some(TypeExpr::Named(type_name)) = self.infer_type(object, scope) {
980                    if scope.is_generic_type_param(&type_name) {
981                        if let Some(iface_name) = scope.get_where_constraint(&type_name) {
982                            if let Some(iface_methods) = scope.get_interface(iface_name) {
983                                let has_method = iface_methods.iter().any(|m| m.name == *method);
984                                if !has_method {
985                                    self.warning_at(
986                                        format!(
987                                            "Method '{}' not found in interface '{}' (constraint on '{}')",
988                                            method, iface_name, type_name
989                                        ),
990                                        span,
991                                    );
992                                }
993                            }
994                        }
995                    }
996                }
997            }
998            Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
999                self.check_node(object, scope);
1000            }
1001            Node::SubscriptAccess { object, index } => {
1002                self.check_node(object, scope);
1003                self.check_node(index, scope);
1004            }
1005            Node::SliceAccess { object, start, end } => {
1006                self.check_node(object, scope);
1007                if let Some(s) = start {
1008                    self.check_node(s, scope);
1009                }
1010                if let Some(e) = end {
1011                    self.check_node(e, scope);
1012                }
1013            }
1014
1015            // Terminals — nothing to check
1016            _ => {}
1017        }
1018    }
1019
1020    fn check_fn_body(
1021        &mut self,
1022        type_params: &[TypeParam],
1023        params: &[TypedParam],
1024        return_type: &Option<TypeExpr>,
1025        body: &[SNode],
1026        where_clauses: &[WhereClause],
1027    ) {
1028        let mut fn_scope = self.scope.child();
1029        // Register generic type parameters so they are treated as compatible
1030        // with any concrete type during type checking.
1031        for tp in type_params {
1032            fn_scope.generic_type_params.insert(tp.name.clone());
1033        }
1034        // Store where-clause constraints for definition-site checking
1035        for wc in where_clauses {
1036            fn_scope
1037                .where_constraints
1038                .insert(wc.type_name.clone(), wc.bound.clone());
1039        }
1040        for param in params {
1041            fn_scope.define_var(&param.name, param.type_expr.clone());
1042            if let Some(default) = &param.default_value {
1043                self.check_node(default, &mut fn_scope);
1044            }
1045        }
1046        self.check_block(body, &mut fn_scope);
1047
1048        // Check return statements against declared return type
1049        if let Some(ret_type) = return_type {
1050            for stmt in body {
1051                self.check_return_type(stmt, ret_type, &fn_scope);
1052            }
1053        }
1054    }
1055
1056    fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
1057        let span = snode.span;
1058        match &snode.node {
1059            Node::ReturnStmt { value: Some(val) } => {
1060                let inferred = self.infer_type(val, scope);
1061                if let Some(actual) = &inferred {
1062                    if !self.types_compatible(expected, actual, scope) {
1063                        self.error_at(
1064                            format!(
1065                                "Return type mismatch: expected {}, got {}",
1066                                format_type(expected),
1067                                format_type(actual)
1068                            ),
1069                            span,
1070                        );
1071                    }
1072                }
1073            }
1074            Node::IfElse {
1075                then_body,
1076                else_body,
1077                ..
1078            } => {
1079                for stmt in then_body {
1080                    self.check_return_type(stmt, expected, scope);
1081                }
1082                if let Some(else_body) = else_body {
1083                    for stmt in else_body {
1084                        self.check_return_type(stmt, expected, scope);
1085                    }
1086                }
1087            }
1088            _ => {}
1089        }
1090    }
1091
1092    /// Check if a match expression on an enum's `.variant` property covers all variants.
1093    /// Extract narrowing info from nil-check conditions like `x != nil`.
1094    /// Returns (var_name, narrowed_type) where narrowed_type removes nil from a union.
1095    /// Check if a type satisfies an interface (Go-style implicit satisfaction).
1096    /// A type satisfies an interface if its impl block has all the required methods.
1097    fn satisfies_interface(
1098        &self,
1099        type_name: &str,
1100        interface_name: &str,
1101        scope: &TypeScope,
1102    ) -> bool {
1103        self.interface_mismatch_reason(type_name, interface_name, scope)
1104            .is_none()
1105    }
1106
1107    /// Return a detailed reason why a type does not satisfy an interface, or None
1108    /// if it does satisfy it.  Used for producing actionable warning messages.
1109    fn interface_mismatch_reason(
1110        &self,
1111        type_name: &str,
1112        interface_name: &str,
1113        scope: &TypeScope,
1114    ) -> Option<String> {
1115        let interface_methods = match scope.get_interface(interface_name) {
1116            Some(methods) => methods,
1117            None => return Some(format!("interface '{}' not found", interface_name)),
1118        };
1119        let impl_methods = match scope.get_impl_methods(type_name) {
1120            Some(methods) => methods,
1121            None => {
1122                if interface_methods.is_empty() {
1123                    return None;
1124                }
1125                let names: Vec<_> = interface_methods.iter().map(|m| m.name.as_str()).collect();
1126                return Some(format!("missing method(s): {}", names.join(", ")));
1127            }
1128        };
1129        for iface_method in interface_methods {
1130            let iface_params: Vec<_> = iface_method
1131                .params
1132                .iter()
1133                .filter(|p| p.name != "self")
1134                .collect();
1135            let iface_param_count = iface_params.len();
1136            let matching_impl = impl_methods.iter().find(|im| im.name == iface_method.name);
1137            let impl_method = match matching_impl {
1138                Some(m) => m,
1139                None => {
1140                    return Some(format!("missing method '{}'", iface_method.name));
1141                }
1142            };
1143            if impl_method.param_count != iface_param_count {
1144                return Some(format!(
1145                    "method '{}' has {} parameter(s), expected {}",
1146                    iface_method.name, impl_method.param_count, iface_param_count
1147                ));
1148            }
1149            // Check parameter types where both sides specify them
1150            for (i, iface_param) in iface_params.iter().enumerate() {
1151                if let (Some(expected), Some(actual)) = (
1152                    &iface_param.type_expr,
1153                    impl_method.param_types.get(i).and_then(|t| t.as_ref()),
1154                ) {
1155                    if !self.types_compatible(expected, actual, scope) {
1156                        return Some(format!(
1157                            "method '{}' parameter {} has type '{}', expected '{}'",
1158                            iface_method.name,
1159                            i + 1,
1160                            format_type(actual),
1161                            format_type(expected),
1162                        ));
1163                    }
1164                }
1165            }
1166            // Check return type where both sides specify it
1167            if let (Some(expected_ret), Some(actual_ret)) =
1168                (&iface_method.return_type, &impl_method.return_type)
1169            {
1170                if !self.types_compatible(expected_ret, actual_ret, scope) {
1171                    return Some(format!(
1172                        "method '{}' returns '{}', expected '{}'",
1173                        iface_method.name,
1174                        format_type(actual_ret),
1175                        format_type(expected_ret),
1176                    ));
1177                }
1178            }
1179        }
1180        None
1181    }
1182
1183    /// Recursively extract type parameter bindings from matching param/arg types.
1184    /// E.g., param_type=list<T> + arg_type=list<Dog> → binds T=Dog.
1185    fn extract_type_bindings(
1186        param_type: &TypeExpr,
1187        arg_type: &TypeExpr,
1188        type_params: &std::collections::BTreeSet<String>,
1189        bindings: &mut BTreeMap<String, String>,
1190    ) {
1191        match (param_type, arg_type) {
1192            // Direct type param match: T → concrete
1193            (TypeExpr::Named(param_name), TypeExpr::Named(concrete))
1194                if type_params.contains(param_name) =>
1195            {
1196                bindings
1197                    .entry(param_name.clone())
1198                    .or_insert(concrete.clone());
1199            }
1200            // list<T> + list<Dog>
1201            (TypeExpr::List(p_inner), TypeExpr::List(a_inner)) => {
1202                Self::extract_type_bindings(p_inner, a_inner, type_params, bindings);
1203            }
1204            // dict<K, V> + dict<string, int>
1205            (TypeExpr::DictType(pk, pv), TypeExpr::DictType(ak, av)) => {
1206                Self::extract_type_bindings(pk, ak, type_params, bindings);
1207                Self::extract_type_bindings(pv, av, type_params, bindings);
1208            }
1209            _ => {}
1210        }
1211    }
1212
1213    fn extract_nil_narrowing(
1214        condition: &SNode,
1215        scope: &TypeScope,
1216    ) -> Option<(String, InferredType)> {
1217        if let Node::BinaryOp { op, left, right } = &condition.node {
1218            if op == "!=" {
1219                // Check for `x != nil` or `nil != x`
1220                let (var_node, nil_node) = if matches!(right.node, Node::NilLiteral) {
1221                    (left, right)
1222                } else if matches!(left.node, Node::NilLiteral) {
1223                    (right, left)
1224                } else {
1225                    return None;
1226                };
1227                let _ = nil_node;
1228                if let Node::Identifier(name) = &var_node.node {
1229                    // Look up the variable's type and narrow it
1230                    if let Some(Some(TypeExpr::Union(members))) = scope.get_var(name) {
1231                        let narrowed: Vec<TypeExpr> = members
1232                            .iter()
1233                            .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1234                            .cloned()
1235                            .collect();
1236                        return if narrowed.len() == 1 {
1237                            Some((name.clone(), Some(narrowed.into_iter().next().unwrap())))
1238                        } else if narrowed.is_empty() {
1239                            None
1240                        } else {
1241                            Some((name.clone(), Some(TypeExpr::Union(narrowed))))
1242                        };
1243                    }
1244                }
1245            }
1246        }
1247        None
1248    }
1249
1250    fn check_match_exhaustiveness(
1251        &mut self,
1252        value: &SNode,
1253        arms: &[MatchArm],
1254        scope: &TypeScope,
1255        span: Span,
1256    ) {
1257        // Detect pattern: match <expr>.variant { "VariantA" -> ... }
1258        let enum_name = match &value.node {
1259            Node::PropertyAccess { object, property } if property == "variant" => {
1260                // Infer the type of the object
1261                match self.infer_type(object, scope) {
1262                    Some(TypeExpr::Named(name)) => {
1263                        if scope.get_enum(&name).is_some() {
1264                            Some(name)
1265                        } else {
1266                            None
1267                        }
1268                    }
1269                    _ => None,
1270                }
1271            }
1272            _ => {
1273                // Direct match on an enum value: match <expr> { ... }
1274                match self.infer_type(value, scope) {
1275                    Some(TypeExpr::Named(name)) if scope.get_enum(&name).is_some() => Some(name),
1276                    _ => None,
1277                }
1278            }
1279        };
1280
1281        let Some(enum_name) = enum_name else {
1282            return;
1283        };
1284        let Some(variants) = scope.get_enum(&enum_name) else {
1285            return;
1286        };
1287
1288        // Collect variant names covered by match arms
1289        let mut covered: Vec<String> = Vec::new();
1290        let mut has_wildcard = false;
1291
1292        for arm in arms {
1293            match &arm.pattern.node {
1294                // String literal pattern (matching on .variant): "VariantA"
1295                Node::StringLiteral(s) => covered.push(s.clone()),
1296                // Identifier pattern acts as a wildcard/catch-all
1297                Node::Identifier(name) if name == "_" || !variants.contains(name) => {
1298                    has_wildcard = true;
1299                }
1300                // Direct enum construct pattern: EnumName.Variant
1301                Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
1302                // PropertyAccess pattern: EnumName.Variant (no args)
1303                Node::PropertyAccess { property, .. } => covered.push(property.clone()),
1304                _ => {
1305                    // Unknown pattern shape — conservatively treat as wildcard
1306                    has_wildcard = true;
1307                }
1308            }
1309        }
1310
1311        if has_wildcard {
1312            return;
1313        }
1314
1315        let missing: Vec<&String> = variants.iter().filter(|v| !covered.contains(v)).collect();
1316        if !missing.is_empty() {
1317            let missing_str = missing
1318                .iter()
1319                .map(|s| format!("\"{}\"", s))
1320                .collect::<Vec<_>>()
1321                .join(", ");
1322            self.warning_at(
1323                format!(
1324                    "Non-exhaustive match on enum {}: missing variants {}",
1325                    enum_name, missing_str
1326                ),
1327                span,
1328            );
1329        }
1330    }
1331
1332    fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
1333        // Check against known function signatures
1334        let has_spread = args.iter().any(|a| matches!(&a.node, Node::Spread(_)));
1335        if let Some(sig) = scope.get_fn(name).cloned() {
1336            if !has_spread
1337                && !is_builtin(name)
1338                && (args.len() < sig.required_params || args.len() > sig.params.len())
1339            {
1340                let expected = if sig.required_params == sig.params.len() {
1341                    format!("{}", sig.params.len())
1342                } else {
1343                    format!("{}-{}", sig.required_params, sig.params.len())
1344                };
1345                self.warning_at(
1346                    format!(
1347                        "Function '{}' expects {} arguments, got {}",
1348                        name,
1349                        expected,
1350                        args.len()
1351                    ),
1352                    span,
1353                );
1354            }
1355            // Build a scope that includes the function's generic type params
1356            // so they are treated as compatible with any concrete type.
1357            let call_scope = if sig.type_param_names.is_empty() {
1358                scope.clone()
1359            } else {
1360                let mut s = scope.child();
1361                for tp_name in &sig.type_param_names {
1362                    s.generic_type_params.insert(tp_name.clone());
1363                }
1364                s
1365            };
1366            for (i, (arg, (param_name, param_type))) in
1367                args.iter().zip(sig.params.iter()).enumerate()
1368            {
1369                if let Some(expected) = param_type {
1370                    let actual = self.infer_type(arg, scope);
1371                    if let Some(actual) = &actual {
1372                        if !self.types_compatible(expected, actual, &call_scope) {
1373                            self.error_at(
1374                                format!(
1375                                    "Argument {} ('{}'): expected {}, got {}",
1376                                    i + 1,
1377                                    param_name,
1378                                    format_type(expected),
1379                                    format_type(actual)
1380                                ),
1381                                arg.span,
1382                            );
1383                        }
1384                    }
1385                }
1386            }
1387            // Enforce where-clause constraints at call site
1388            if !sig.where_clauses.is_empty() {
1389                // Build mapping: type_param → concrete type from inferred args.
1390                // Recursively walks Generic types so list<T> + list<Dog> binds T=Dog.
1391                let mut type_bindings: BTreeMap<String, String> = BTreeMap::new();
1392                let type_param_set: std::collections::BTreeSet<String> =
1393                    sig.type_param_names.iter().cloned().collect();
1394                for (arg, (_param_name, param_type)) in args.iter().zip(sig.params.iter()) {
1395                    if let Some(param_ty) = param_type {
1396                        if let Some(arg_ty) = self.infer_type(arg, scope) {
1397                            Self::extract_type_bindings(
1398                                param_ty,
1399                                &arg_ty,
1400                                &type_param_set,
1401                                &mut type_bindings,
1402                            );
1403                        }
1404                    }
1405                }
1406                for (type_param, bound) in &sig.where_clauses {
1407                    if let Some(concrete_type) = type_bindings.get(type_param) {
1408                        if let Some(reason) =
1409                            self.interface_mismatch_reason(concrete_type, bound, scope)
1410                        {
1411                            self.warning_at(
1412                                format!(
1413                                    "Type '{}' does not satisfy interface '{}': {} \
1414                                     (required by constraint `where {}: {}`)",
1415                                    concrete_type, bound, reason, type_param, bound
1416                                ),
1417                                span,
1418                            );
1419                        }
1420                    }
1421                }
1422            }
1423        }
1424        // Check args recursively
1425        for arg in args {
1426            self.check_node(arg, scope);
1427        }
1428    }
1429
1430    /// Infer the type of an expression.
1431    fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
1432        match &snode.node {
1433            Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
1434            Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
1435            Node::StringLiteral(_) | Node::InterpolatedString(_) => {
1436                Some(TypeExpr::Named("string".into()))
1437            }
1438            Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
1439            Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
1440            Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
1441            Node::DictLiteral(entries) => {
1442                // Infer shape type when all keys are string literals
1443                let mut fields = Vec::new();
1444                let mut all_string_keys = true;
1445                for entry in entries {
1446                    if let Node::StringLiteral(key) = &entry.key.node {
1447                        let val_type = self
1448                            .infer_type(&entry.value, scope)
1449                            .unwrap_or(TypeExpr::Named("nil".into()));
1450                        fields.push(ShapeField {
1451                            name: key.clone(),
1452                            type_expr: val_type,
1453                            optional: false,
1454                        });
1455                    } else {
1456                        all_string_keys = false;
1457                        break;
1458                    }
1459                }
1460                if all_string_keys && !fields.is_empty() {
1461                    Some(TypeExpr::Shape(fields))
1462                } else {
1463                    Some(TypeExpr::Named("dict".into()))
1464                }
1465            }
1466            Node::Closure { params, body, .. } => {
1467                // If all params are typed and we can infer a return type, produce FnType
1468                let all_typed = params.iter().all(|p| p.type_expr.is_some());
1469                if all_typed && !params.is_empty() {
1470                    let param_types: Vec<TypeExpr> =
1471                        params.iter().filter_map(|p| p.type_expr.clone()).collect();
1472                    // Try to infer return type from last expression in body
1473                    let ret = body.last().and_then(|last| self.infer_type(last, scope));
1474                    if let Some(ret_type) = ret {
1475                        return Some(TypeExpr::FnType {
1476                            params: param_types,
1477                            return_type: Box::new(ret_type),
1478                        });
1479                    }
1480                }
1481                Some(TypeExpr::Named("closure".into()))
1482            }
1483
1484            Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
1485
1486            Node::FunctionCall { name, .. } => {
1487                // Struct constructor calls return the struct type
1488                if scope.get_struct(name).is_some() {
1489                    return Some(TypeExpr::Named(name.clone()));
1490                }
1491                // Check user-defined function return types
1492                if let Some(sig) = scope.get_fn(name) {
1493                    return sig.return_type.clone();
1494                }
1495                // Check builtin return types
1496                builtin_return_type(name)
1497            }
1498
1499            Node::BinaryOp { op, left, right } => {
1500                let lt = self.infer_type(left, scope);
1501                let rt = self.infer_type(right, scope);
1502                infer_binary_op_type(op, &lt, &rt)
1503            }
1504
1505            Node::UnaryOp { op, operand } => {
1506                let t = self.infer_type(operand, scope);
1507                match op.as_str() {
1508                    "!" => Some(TypeExpr::Named("bool".into())),
1509                    "-" => t, // negation preserves type
1510                    _ => None,
1511                }
1512            }
1513
1514            Node::Ternary {
1515                true_expr,
1516                false_expr,
1517                ..
1518            } => {
1519                let tt = self.infer_type(true_expr, scope);
1520                let ft = self.infer_type(false_expr, scope);
1521                match (&tt, &ft) {
1522                    (Some(a), Some(b)) if a == b => tt,
1523                    (Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
1524                    (Some(_), None) => tt,
1525                    (None, Some(_)) => ft,
1526                    (None, None) => None,
1527                }
1528            }
1529
1530            Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
1531
1532            Node::PropertyAccess { object, property } => {
1533                // EnumName.Variant → infer as the enum type
1534                if let Node::Identifier(name) = &object.node {
1535                    if scope.get_enum(name).is_some() {
1536                        return Some(TypeExpr::Named(name.clone()));
1537                    }
1538                }
1539                // .variant on an enum value → string
1540                if property == "variant" {
1541                    let obj_type = self.infer_type(object, scope);
1542                    if let Some(TypeExpr::Named(name)) = &obj_type {
1543                        if scope.get_enum(name).is_some() {
1544                            return Some(TypeExpr::Named("string".into()));
1545                        }
1546                    }
1547                }
1548                // Shape field access: obj.field → field type
1549                let obj_type = self.infer_type(object, scope);
1550                if let Some(TypeExpr::Shape(fields)) = &obj_type {
1551                    if let Some(field) = fields.iter().find(|f| f.name == *property) {
1552                        return Some(field.type_expr.clone());
1553                    }
1554                }
1555                None
1556            }
1557
1558            Node::SubscriptAccess { object, index } => {
1559                let obj_type = self.infer_type(object, scope);
1560                match &obj_type {
1561                    Some(TypeExpr::List(inner)) => Some(*inner.clone()),
1562                    Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
1563                    Some(TypeExpr::Shape(fields)) => {
1564                        // If index is a string literal, look up the field type
1565                        if let Node::StringLiteral(key) = &index.node {
1566                            fields
1567                                .iter()
1568                                .find(|f| &f.name == key)
1569                                .map(|f| f.type_expr.clone())
1570                        } else {
1571                            None
1572                        }
1573                    }
1574                    Some(TypeExpr::Named(n)) if n == "list" => None,
1575                    Some(TypeExpr::Named(n)) if n == "dict" => None,
1576                    Some(TypeExpr::Named(n)) if n == "string" => {
1577                        Some(TypeExpr::Named("string".into()))
1578                    }
1579                    _ => None,
1580                }
1581            }
1582            Node::SliceAccess { object, .. } => {
1583                // Slicing a list returns the same list type; slicing a string returns string
1584                let obj_type = self.infer_type(object, scope);
1585                match &obj_type {
1586                    Some(TypeExpr::List(_)) => obj_type,
1587                    Some(TypeExpr::Named(n)) if n == "list" => obj_type,
1588                    Some(TypeExpr::Named(n)) if n == "string" => {
1589                        Some(TypeExpr::Named("string".into()))
1590                    }
1591                    _ => None,
1592                }
1593            }
1594            Node::MethodCall { object, method, .. }
1595            | Node::OptionalMethodCall { object, method, .. } => {
1596                let obj_type = self.infer_type(object, scope);
1597                let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
1598                    || matches!(&obj_type, Some(TypeExpr::DictType(..)));
1599                match method.as_str() {
1600                    // Shared: bool-returning methods
1601                    "contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
1602                        Some(TypeExpr::Named("bool".into()))
1603                    }
1604                    // Shared: int-returning methods
1605                    "count" | "index_of" => Some(TypeExpr::Named("int".into())),
1606                    // String methods
1607                    "trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
1608                    | "pad_left" | "pad_right" | "repeat" | "join" => {
1609                        Some(TypeExpr::Named("string".into()))
1610                    }
1611                    "split" | "chars" => Some(TypeExpr::Named("list".into())),
1612                    // filter returns dict for dicts, list for lists
1613                    "filter" => {
1614                        if is_dict {
1615                            Some(TypeExpr::Named("dict".into()))
1616                        } else {
1617                            Some(TypeExpr::Named("list".into()))
1618                        }
1619                    }
1620                    // List methods
1621                    "map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
1622                    "reduce" | "find" | "first" | "last" => None,
1623                    // Dict methods
1624                    "keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
1625                    "merge" | "map_values" => Some(TypeExpr::Named("dict".into())),
1626                    // Conversions
1627                    "to_string" => Some(TypeExpr::Named("string".into())),
1628                    "to_int" => Some(TypeExpr::Named("int".into())),
1629                    "to_float" => Some(TypeExpr::Named("float".into())),
1630                    _ => None,
1631                }
1632            }
1633
1634            // TryOperator on Result<T, E> produces T
1635            Node::TryOperator { operand } => {
1636                match self.infer_type(operand, scope) {
1637                    Some(TypeExpr::Named(name)) if name == "Result" => None, // unknown inner type
1638                    _ => None,
1639                }
1640            }
1641
1642            _ => None,
1643        }
1644    }
1645
1646    /// Check if two types are compatible (actual can be assigned to expected).
1647    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
1648        // Generic type parameters match anything.
1649        if let TypeExpr::Named(name) = expected {
1650            if scope.is_generic_type_param(name) {
1651                return true;
1652            }
1653        }
1654        if let TypeExpr::Named(name) = actual {
1655            if scope.is_generic_type_param(name) {
1656                return true;
1657            }
1658        }
1659        let expected = self.resolve_alias(expected, scope);
1660        let actual = self.resolve_alias(actual, scope);
1661
1662        // Interface satisfaction: if expected is an interface name, check if actual type
1663        // has all required methods (Go-style implicit satisfaction).
1664        if let TypeExpr::Named(iface_name) = &expected {
1665            if scope.get_interface(iface_name).is_some() {
1666                if let TypeExpr::Named(type_name) = &actual {
1667                    return self.satisfies_interface(type_name, iface_name, scope);
1668                }
1669                return false;
1670            }
1671        }
1672
1673        match (&expected, &actual) {
1674            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
1675            (TypeExpr::Union(members), actual_type) => members
1676                .iter()
1677                .any(|m| self.types_compatible(m, actual_type, scope)),
1678            (expected_type, TypeExpr::Union(members)) => members
1679                .iter()
1680                .all(|m| self.types_compatible(expected_type, m, scope)),
1681            (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
1682            (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
1683            (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
1684                if expected_field.optional {
1685                    return true;
1686                }
1687                af.iter().any(|actual_field| {
1688                    actual_field.name == expected_field.name
1689                        && self.types_compatible(
1690                            &expected_field.type_expr,
1691                            &actual_field.type_expr,
1692                            scope,
1693                        )
1694                })
1695            }),
1696            // dict<K, V> expected, Shape actual → all field values must match V
1697            (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
1698                let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
1699                keys_ok
1700                    && af
1701                        .iter()
1702                        .all(|f| self.types_compatible(ev, &f.type_expr, scope))
1703            }
1704            // Shape expected, dict<K, V> actual → gradual: allow since dict may have the fields
1705            (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
1706            (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
1707                self.types_compatible(expected_inner, actual_inner, scope)
1708            }
1709            (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
1710            (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
1711            (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
1712                self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
1713            }
1714            (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
1715            (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
1716            // FnType compatibility: params match positionally and return types match
1717            (
1718                TypeExpr::FnType {
1719                    params: ep,
1720                    return_type: er,
1721                },
1722                TypeExpr::FnType {
1723                    params: ap,
1724                    return_type: ar,
1725                },
1726            ) => {
1727                ep.len() == ap.len()
1728                    && ep
1729                        .iter()
1730                        .zip(ap.iter())
1731                        .all(|(e, a)| self.types_compatible(e, a, scope))
1732                    && self.types_compatible(er, ar, scope)
1733            }
1734            // FnType is compatible with Named("closure") for backward compat
1735            (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
1736            (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
1737            _ => false,
1738        }
1739    }
1740
1741    fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
1742        if let TypeExpr::Named(name) = ty {
1743            if let Some(resolved) = scope.resolve_type(name) {
1744                return resolved.clone();
1745            }
1746        }
1747        ty.clone()
1748    }
1749
1750    fn error_at(&mut self, message: String, span: Span) {
1751        self.diagnostics.push(TypeDiagnostic {
1752            message,
1753            severity: DiagnosticSeverity::Error,
1754            span: Some(span),
1755            help: None,
1756        });
1757    }
1758
1759    #[allow(dead_code)]
1760    fn error_at_with_help(&mut self, message: String, span: Span, help: String) {
1761        self.diagnostics.push(TypeDiagnostic {
1762            message,
1763            severity: DiagnosticSeverity::Error,
1764            span: Some(span),
1765            help: Some(help),
1766        });
1767    }
1768
1769    fn warning_at(&mut self, message: String, span: Span) {
1770        self.diagnostics.push(TypeDiagnostic {
1771            message,
1772            severity: DiagnosticSeverity::Warning,
1773            span: Some(span),
1774            help: None,
1775        });
1776    }
1777
1778    #[allow(dead_code)]
1779    fn warning_at_with_help(&mut self, message: String, span: Span, help: String) {
1780        self.diagnostics.push(TypeDiagnostic {
1781            message,
1782            severity: DiagnosticSeverity::Warning,
1783            span: Some(span),
1784            help: Some(help),
1785        });
1786    }
1787}
1788
1789impl Default for TypeChecker {
1790    fn default() -> Self {
1791        Self::new()
1792    }
1793}
1794
1795/// Infer the result type of a binary operation.
1796fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
1797    match op {
1798        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" | "in" | "not_in" => {
1799            Some(TypeExpr::Named("bool".into()))
1800        }
1801        "+" => match (left, right) {
1802            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1803                match (l.as_str(), r.as_str()) {
1804                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1805                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1806                    ("string", _) => Some(TypeExpr::Named("string".into())),
1807                    ("list", "list") => Some(TypeExpr::Named("list".into())),
1808                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
1809                    _ => Some(TypeExpr::Named("string".into())),
1810                }
1811            }
1812            _ => None,
1813        },
1814        "-" | "*" | "/" | "%" => match (left, right) {
1815            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
1816                match (l.as_str(), r.as_str()) {
1817                    ("int", "int") => Some(TypeExpr::Named("int".into())),
1818                    ("float", _) | (_, "float") => Some(TypeExpr::Named("float".into())),
1819                    _ => None,
1820                }
1821            }
1822            _ => None,
1823        },
1824        "??" => match (left, right) {
1825            (Some(TypeExpr::Union(members)), _) => {
1826                let non_nil: Vec<_> = members
1827                    .iter()
1828                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
1829                    .cloned()
1830                    .collect();
1831                if non_nil.len() == 1 {
1832                    Some(non_nil[0].clone())
1833                } else if non_nil.is_empty() {
1834                    right.clone()
1835                } else {
1836                    Some(TypeExpr::Union(non_nil))
1837                }
1838            }
1839            _ => right.clone(),
1840        },
1841        "|>" => None,
1842        _ => None,
1843    }
1844}
1845
1846/// Format a type expression for display in error messages.
1847/// Produce a detail string describing why a Shape type is incompatible with
1848/// another Shape type — e.g. "missing field 'age' (int)" or "field 'name'
1849/// has type int, expected string".  Returns `None` if both types are not shapes.
1850pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
1851    if let (TypeExpr::Shape(ef), TypeExpr::Shape(af)) = (expected, actual) {
1852        let mut details = Vec::new();
1853        for field in ef {
1854            if field.optional {
1855                continue;
1856            }
1857            match af.iter().find(|f| f.name == field.name) {
1858                None => details.push(format!(
1859                    "missing field '{}' ({})",
1860                    field.name,
1861                    format_type(&field.type_expr)
1862                )),
1863                Some(actual_field) => {
1864                    let e_str = format_type(&field.type_expr);
1865                    let a_str = format_type(&actual_field.type_expr);
1866                    if e_str != a_str {
1867                        details.push(format!(
1868                            "field '{}' has type {}, expected {}",
1869                            field.name, a_str, e_str
1870                        ));
1871                    }
1872                }
1873            }
1874        }
1875        if details.is_empty() {
1876            None
1877        } else {
1878            Some(details.join("; "))
1879        }
1880    } else {
1881        None
1882    }
1883}
1884
1885pub fn format_type(ty: &TypeExpr) -> String {
1886    match ty {
1887        TypeExpr::Named(n) => n.clone(),
1888        TypeExpr::Union(types) => types
1889            .iter()
1890            .map(format_type)
1891            .collect::<Vec<_>>()
1892            .join(" | "),
1893        TypeExpr::Shape(fields) => {
1894            let inner: Vec<String> = fields
1895                .iter()
1896                .map(|f| {
1897                    let opt = if f.optional { "?" } else { "" };
1898                    format!("{}{opt}: {}", f.name, format_type(&f.type_expr))
1899                })
1900                .collect();
1901            format!("{{{}}}", inner.join(", "))
1902        }
1903        TypeExpr::List(inner) => format!("list<{}>", format_type(inner)),
1904        TypeExpr::DictType(k, v) => format!("dict<{}, {}>", format_type(k), format_type(v)),
1905        TypeExpr::FnType {
1906            params,
1907            return_type,
1908        } => {
1909            let params_str = params
1910                .iter()
1911                .map(format_type)
1912                .collect::<Vec<_>>()
1913                .join(", ");
1914            format!("fn({}) -> {}", params_str, format_type(return_type))
1915        }
1916    }
1917}
1918
1919#[cfg(test)]
1920mod tests {
1921    use super::*;
1922    use crate::Parser;
1923    use harn_lexer::Lexer;
1924
1925    fn check_source(source: &str) -> Vec<TypeDiagnostic> {
1926        let mut lexer = Lexer::new(source);
1927        let tokens = lexer.tokenize().unwrap();
1928        let mut parser = Parser::new(tokens);
1929        let program = parser.parse().unwrap();
1930        TypeChecker::new().check(&program)
1931    }
1932
1933    fn errors(source: &str) -> Vec<String> {
1934        check_source(source)
1935            .into_iter()
1936            .filter(|d| d.severity == DiagnosticSeverity::Error)
1937            .map(|d| d.message)
1938            .collect()
1939    }
1940
1941    #[test]
1942    fn test_no_errors_for_untyped_code() {
1943        let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
1944        assert!(errs.is_empty());
1945    }
1946
1947    #[test]
1948    fn test_correct_typed_let() {
1949        let errs = errors("pipeline t(task) { let x: int = 42 }");
1950        assert!(errs.is_empty());
1951    }
1952
1953    #[test]
1954    fn test_type_mismatch_let() {
1955        let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
1956        assert_eq!(errs.len(), 1);
1957        assert!(errs[0].contains("Type mismatch"));
1958        assert!(errs[0].contains("int"));
1959        assert!(errs[0].contains("string"));
1960    }
1961
1962    #[test]
1963    fn test_correct_typed_fn() {
1964        let errs = errors(
1965            "pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
1966        );
1967        assert!(errs.is_empty());
1968    }
1969
1970    #[test]
1971    fn test_fn_arg_type_mismatch() {
1972        let errs = errors(
1973            r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
1974add("hello", 2) }"#,
1975        );
1976        assert_eq!(errs.len(), 1);
1977        assert!(errs[0].contains("Argument 1"));
1978        assert!(errs[0].contains("expected int"));
1979    }
1980
1981    #[test]
1982    fn test_return_type_mismatch() {
1983        let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
1984        assert_eq!(errs.len(), 1);
1985        assert!(errs[0].contains("Return type mismatch"));
1986    }
1987
1988    #[test]
1989    fn test_union_type_compatible() {
1990        let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
1991        assert!(errs.is_empty());
1992    }
1993
1994    #[test]
1995    fn test_union_type_mismatch() {
1996        let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
1997        assert_eq!(errs.len(), 1);
1998        assert!(errs[0].contains("Type mismatch"));
1999    }
2000
2001    #[test]
2002    fn test_type_inference_propagation() {
2003        let errs = errors(
2004            r#"pipeline t(task) {
2005  fn add(a: int, b: int) -> int { return a + b }
2006  let result: string = add(1, 2)
2007}"#,
2008        );
2009        assert_eq!(errs.len(), 1);
2010        assert!(errs[0].contains("Type mismatch"));
2011        assert!(errs[0].contains("string"));
2012        assert!(errs[0].contains("int"));
2013    }
2014
2015    #[test]
2016    fn test_builtin_return_type_inference() {
2017        let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
2018        assert_eq!(errs.len(), 1);
2019        assert!(errs[0].contains("string"));
2020        assert!(errs[0].contains("int"));
2021    }
2022
2023    #[test]
2024    fn test_workflow_and_transcript_builtins_are_known() {
2025        let errs = errors(
2026            r#"pipeline t(task) {
2027  let flow = workflow_graph({name: "demo", entry: "act", nodes: {act: {kind: "stage"}}})
2028  let report: dict = workflow_policy_report(flow, {tools: ["read"], capabilities: {workspace: ["read_text"]}})
2029  let run: dict = workflow_execute("task", flow, [], {})
2030  let fixture: dict = run_record_fixture(run?.run)
2031  let suite: dict = run_record_eval_suite([{run: run?.run, fixture: fixture}])
2032  let diff: dict = run_record_diff(run?.run, run?.run)
2033  let manifest: dict = eval_suite_manifest({cases: [{run_path: "run.json"}]})
2034  let suite_report: dict = eval_suite_run(manifest)
2035  let wf: dict = artifact_workspace_file("src/main.rs", "fn main() {}", {source: "host"})
2036  let snap: dict = artifact_workspace_snapshot(["src/main.rs"], "snapshot")
2037  let selection: dict = artifact_editor_selection("src/main.rs", "main")
2038  let verify: dict = artifact_verification_result("verify", "ok")
2039  let test_result: dict = artifact_test_result("tests", "pass")
2040  let cmd: dict = artifact_command_result("cargo test", {status: 0})
2041  let patch: dict = artifact_diff("src/main.rs", "old", "new")
2042  let git: dict = artifact_git_diff("diff --git a b")
2043  let review: dict = artifact_diff_review(patch, "review me")
2044  let decision: dict = artifact_review_decision(review, "accepted")
2045  let transcript = transcript_reset({metadata: {source: "test"}})
2046  let visible: string = transcript_render_visible(transcript_archive(transcript))
2047  let events: list = transcript_events(transcript)
2048  let context: string = artifact_context([], {max_artifacts: 1})
2049  println(report)
2050  println(run)
2051  println(fixture)
2052  println(suite)
2053  println(diff)
2054  println(manifest)
2055  println(suite_report)
2056  println(wf)
2057  println(snap)
2058  println(selection)
2059  println(verify)
2060  println(test_result)
2061  println(cmd)
2062  println(patch)
2063  println(git)
2064  println(review)
2065  println(decision)
2066  println(visible)
2067  println(events)
2068  println(context)
2069}"#,
2070        );
2071        assert!(errs.is_empty(), "unexpected type errors: {errs:?}");
2072    }
2073
2074    #[test]
2075    fn test_binary_op_type_inference() {
2076        let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
2077        assert_eq!(errs.len(), 1);
2078    }
2079
2080    #[test]
2081    fn test_comparison_returns_bool() {
2082        let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
2083        assert!(errs.is_empty());
2084    }
2085
2086    #[test]
2087    fn test_int_float_promotion() {
2088        let errs = errors("pipeline t(task) { let x: float = 42 }");
2089        assert!(errs.is_empty());
2090    }
2091
2092    #[test]
2093    fn test_untyped_code_no_errors() {
2094        let errs = errors(
2095            r#"pipeline t(task) {
2096  fn process(data) {
2097    let result = data + " processed"
2098    return result
2099  }
2100  log(process("hello"))
2101}"#,
2102        );
2103        assert!(errs.is_empty());
2104    }
2105
2106    #[test]
2107    fn test_type_alias() {
2108        let errs = errors(
2109            r#"pipeline t(task) {
2110  type Name = string
2111  let x: Name = "hello"
2112}"#,
2113        );
2114        assert!(errs.is_empty());
2115    }
2116
2117    #[test]
2118    fn test_type_alias_mismatch() {
2119        let errs = errors(
2120            r#"pipeline t(task) {
2121  type Name = string
2122  let x: Name = 42
2123}"#,
2124        );
2125        assert_eq!(errs.len(), 1);
2126    }
2127
2128    #[test]
2129    fn test_assignment_type_check() {
2130        let errs = errors(
2131            r#"pipeline t(task) {
2132  var x: int = 0
2133  x = "hello"
2134}"#,
2135        );
2136        assert_eq!(errs.len(), 1);
2137        assert!(errs[0].contains("cannot assign string"));
2138    }
2139
2140    #[test]
2141    fn test_covariance_int_to_float_in_fn() {
2142        let errs = errors(
2143            "pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
2144        );
2145        assert!(errs.is_empty());
2146    }
2147
2148    #[test]
2149    fn test_covariance_return_type() {
2150        let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
2151        assert!(errs.is_empty());
2152    }
2153
2154    #[test]
2155    fn test_no_contravariance_float_to_int() {
2156        let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
2157        assert_eq!(errs.len(), 1);
2158    }
2159
2160    // --- Exhaustiveness checking tests ---
2161
2162    fn warnings(source: &str) -> Vec<String> {
2163        check_source(source)
2164            .into_iter()
2165            .filter(|d| d.severity == DiagnosticSeverity::Warning)
2166            .map(|d| d.message)
2167            .collect()
2168    }
2169
2170    #[test]
2171    fn test_exhaustive_match_no_warning() {
2172        let warns = warnings(
2173            r#"pipeline t(task) {
2174  enum Color { Red, Green, Blue }
2175  let c = Color.Red
2176  match c.variant {
2177    "Red" -> { log("r") }
2178    "Green" -> { log("g") }
2179    "Blue" -> { log("b") }
2180  }
2181}"#,
2182        );
2183        let exhaustive_warns: Vec<_> = warns
2184            .iter()
2185            .filter(|w| w.contains("Non-exhaustive"))
2186            .collect();
2187        assert!(exhaustive_warns.is_empty());
2188    }
2189
2190    #[test]
2191    fn test_non_exhaustive_match_warning() {
2192        let warns = warnings(
2193            r#"pipeline t(task) {
2194  enum Color { Red, Green, Blue }
2195  let c = Color.Red
2196  match c.variant {
2197    "Red" -> { log("r") }
2198    "Green" -> { log("g") }
2199  }
2200}"#,
2201        );
2202        let exhaustive_warns: Vec<_> = warns
2203            .iter()
2204            .filter(|w| w.contains("Non-exhaustive"))
2205            .collect();
2206        assert_eq!(exhaustive_warns.len(), 1);
2207        assert!(exhaustive_warns[0].contains("Blue"));
2208    }
2209
2210    #[test]
2211    fn test_non_exhaustive_multiple_missing() {
2212        let warns = warnings(
2213            r#"pipeline t(task) {
2214  enum Status { Active, Inactive, Pending }
2215  let s = Status.Active
2216  match s.variant {
2217    "Active" -> { log("a") }
2218  }
2219}"#,
2220        );
2221        let exhaustive_warns: Vec<_> = warns
2222            .iter()
2223            .filter(|w| w.contains("Non-exhaustive"))
2224            .collect();
2225        assert_eq!(exhaustive_warns.len(), 1);
2226        assert!(exhaustive_warns[0].contains("Inactive"));
2227        assert!(exhaustive_warns[0].contains("Pending"));
2228    }
2229
2230    #[test]
2231    fn test_enum_construct_type_inference() {
2232        let errs = errors(
2233            r#"pipeline t(task) {
2234  enum Color { Red, Green, Blue }
2235  let c: Color = Color.Red
2236}"#,
2237        );
2238        assert!(errs.is_empty());
2239    }
2240
2241    // --- Type narrowing tests ---
2242
2243    #[test]
2244    fn test_nil_coalescing_strips_nil() {
2245        // After ??, nil should be stripped from the type
2246        let errs = errors(
2247            r#"pipeline t(task) {
2248  let x: string | nil = nil
2249  let y: string = x ?? "default"
2250}"#,
2251        );
2252        assert!(errs.is_empty());
2253    }
2254
2255    #[test]
2256    fn test_shape_mismatch_detail_missing_field() {
2257        let errs = errors(
2258            r#"pipeline t(task) {
2259  let x: {name: string, age: int} = {name: "hello"}
2260}"#,
2261        );
2262        assert_eq!(errs.len(), 1);
2263        assert!(
2264            errs[0].contains("missing field 'age'"),
2265            "expected detail about missing field, got: {}",
2266            errs[0]
2267        );
2268    }
2269
2270    #[test]
2271    fn test_shape_mismatch_detail_wrong_type() {
2272        let errs = errors(
2273            r#"pipeline t(task) {
2274  let x: {name: string, age: int} = {name: 42, age: 10}
2275}"#,
2276        );
2277        assert_eq!(errs.len(), 1);
2278        assert!(
2279            errs[0].contains("field 'name' has type int, expected string"),
2280            "expected detail about wrong type, got: {}",
2281            errs[0]
2282        );
2283    }
2284
2285    // --- Match pattern type validation tests ---
2286
2287    #[test]
2288    fn test_match_pattern_string_against_int() {
2289        let warns = warnings(
2290            r#"pipeline t(task) {
2291  let x: int = 42
2292  match x {
2293    "hello" -> { log("bad") }
2294    42 -> { log("ok") }
2295  }
2296}"#,
2297        );
2298        let pattern_warns: Vec<_> = warns
2299            .iter()
2300            .filter(|w| w.contains("Match pattern type mismatch"))
2301            .collect();
2302        assert_eq!(pattern_warns.len(), 1);
2303        assert!(pattern_warns[0].contains("matching int against string literal"));
2304    }
2305
2306    #[test]
2307    fn test_match_pattern_int_against_string() {
2308        let warns = warnings(
2309            r#"pipeline t(task) {
2310  let x: string = "hello"
2311  match x {
2312    42 -> { log("bad") }
2313    "hello" -> { log("ok") }
2314  }
2315}"#,
2316        );
2317        let pattern_warns: Vec<_> = warns
2318            .iter()
2319            .filter(|w| w.contains("Match pattern type mismatch"))
2320            .collect();
2321        assert_eq!(pattern_warns.len(), 1);
2322        assert!(pattern_warns[0].contains("matching string against int literal"));
2323    }
2324
2325    #[test]
2326    fn test_match_pattern_bool_against_int() {
2327        let warns = warnings(
2328            r#"pipeline t(task) {
2329  let x: int = 42
2330  match x {
2331    true -> { log("bad") }
2332    42 -> { log("ok") }
2333  }
2334}"#,
2335        );
2336        let pattern_warns: Vec<_> = warns
2337            .iter()
2338            .filter(|w| w.contains("Match pattern type mismatch"))
2339            .collect();
2340        assert_eq!(pattern_warns.len(), 1);
2341        assert!(pattern_warns[0].contains("matching int against bool literal"));
2342    }
2343
2344    #[test]
2345    fn test_match_pattern_float_against_string() {
2346        let warns = warnings(
2347            r#"pipeline t(task) {
2348  let x: string = "hello"
2349  match x {
2350    3.14 -> { log("bad") }
2351    "hello" -> { log("ok") }
2352  }
2353}"#,
2354        );
2355        let pattern_warns: Vec<_> = warns
2356            .iter()
2357            .filter(|w| w.contains("Match pattern type mismatch"))
2358            .collect();
2359        assert_eq!(pattern_warns.len(), 1);
2360        assert!(pattern_warns[0].contains("matching string against float literal"));
2361    }
2362
2363    #[test]
2364    fn test_match_pattern_int_against_float_ok() {
2365        // int and float are compatible for match patterns
2366        let warns = warnings(
2367            r#"pipeline t(task) {
2368  let x: float = 3.14
2369  match x {
2370    42 -> { log("ok") }
2371    _ -> { log("default") }
2372  }
2373}"#,
2374        );
2375        let pattern_warns: Vec<_> = warns
2376            .iter()
2377            .filter(|w| w.contains("Match pattern type mismatch"))
2378            .collect();
2379        assert!(pattern_warns.is_empty());
2380    }
2381
2382    #[test]
2383    fn test_match_pattern_float_against_int_ok() {
2384        // float and int are compatible for match patterns
2385        let warns = warnings(
2386            r#"pipeline t(task) {
2387  let x: int = 42
2388  match x {
2389    3.14 -> { log("close") }
2390    _ -> { log("default") }
2391  }
2392}"#,
2393        );
2394        let pattern_warns: Vec<_> = warns
2395            .iter()
2396            .filter(|w| w.contains("Match pattern type mismatch"))
2397            .collect();
2398        assert!(pattern_warns.is_empty());
2399    }
2400
2401    #[test]
2402    fn test_match_pattern_correct_types_no_warning() {
2403        let warns = warnings(
2404            r#"pipeline t(task) {
2405  let x: int = 42
2406  match x {
2407    1 -> { log("one") }
2408    2 -> { log("two") }
2409    _ -> { log("other") }
2410  }
2411}"#,
2412        );
2413        let pattern_warns: Vec<_> = warns
2414            .iter()
2415            .filter(|w| w.contains("Match pattern type mismatch"))
2416            .collect();
2417        assert!(pattern_warns.is_empty());
2418    }
2419
2420    #[test]
2421    fn test_match_pattern_wildcard_no_warning() {
2422        let warns = warnings(
2423            r#"pipeline t(task) {
2424  let x: int = 42
2425  match x {
2426    _ -> { log("catch all") }
2427  }
2428}"#,
2429        );
2430        let pattern_warns: Vec<_> = warns
2431            .iter()
2432            .filter(|w| w.contains("Match pattern type mismatch"))
2433            .collect();
2434        assert!(pattern_warns.is_empty());
2435    }
2436
2437    #[test]
2438    fn test_match_pattern_untyped_no_warning() {
2439        // When value has no known type, no warning should be emitted
2440        let warns = warnings(
2441            r#"pipeline t(task) {
2442  let x = some_unknown_fn()
2443  match x {
2444    "hello" -> { log("string") }
2445    42 -> { log("int") }
2446  }
2447}"#,
2448        );
2449        let pattern_warns: Vec<_> = warns
2450            .iter()
2451            .filter(|w| w.contains("Match pattern type mismatch"))
2452            .collect();
2453        assert!(pattern_warns.is_empty());
2454    }
2455
2456    // --- Interface constraint type checking tests ---
2457
2458    fn iface_warns(source: &str) -> Vec<String> {
2459        warnings(source)
2460            .into_iter()
2461            .filter(|w| w.contains("does not satisfy interface"))
2462            .collect()
2463    }
2464
2465    #[test]
2466    fn test_interface_constraint_return_type_mismatch() {
2467        let warns = iface_warns(
2468            r#"pipeline t(task) {
2469  interface Sizable {
2470    fn size(self) -> int
2471  }
2472  struct Box { width: int }
2473  impl Box {
2474    fn size(self) -> string { return "nope" }
2475  }
2476  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2477  measure(Box({width: 3}))
2478}"#,
2479        );
2480        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2481        assert!(
2482            warns[0].contains("method 'size' returns 'string', expected 'int'"),
2483            "unexpected message: {}",
2484            warns[0]
2485        );
2486    }
2487
2488    #[test]
2489    fn test_interface_constraint_param_type_mismatch() {
2490        let warns = iface_warns(
2491            r#"pipeline t(task) {
2492  interface Processor {
2493    fn process(self, x: int) -> string
2494  }
2495  struct MyProc { name: string }
2496  impl MyProc {
2497    fn process(self, x: string) -> string { return x }
2498  }
2499  fn run_proc<T>(p: T) where T: Processor { log(p.process(42)) }
2500  run_proc(MyProc({name: "a"}))
2501}"#,
2502        );
2503        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2504        assert!(
2505            warns[0].contains("method 'process' parameter 1 has type 'string', expected 'int'"),
2506            "unexpected message: {}",
2507            warns[0]
2508        );
2509    }
2510
2511    #[test]
2512    fn test_interface_constraint_missing_method() {
2513        let warns = iface_warns(
2514            r#"pipeline t(task) {
2515  interface Sizable {
2516    fn size(self) -> int
2517  }
2518  struct Box { width: int }
2519  impl Box {
2520    fn area(self) -> int { return self.width }
2521  }
2522  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2523  measure(Box({width: 3}))
2524}"#,
2525        );
2526        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2527        assert!(
2528            warns[0].contains("missing method 'size'"),
2529            "unexpected message: {}",
2530            warns[0]
2531        );
2532    }
2533
2534    #[test]
2535    fn test_interface_constraint_param_count_mismatch() {
2536        let warns = iface_warns(
2537            r#"pipeline t(task) {
2538  interface Doubler {
2539    fn double(self, x: int) -> int
2540  }
2541  struct Bad { v: int }
2542  impl Bad {
2543    fn double(self) -> int { return self.v * 2 }
2544  }
2545  fn run_double<T>(d: T) where T: Doubler { log(d.double(3)) }
2546  run_double(Bad({v: 5}))
2547}"#,
2548        );
2549        assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
2550        assert!(
2551            warns[0].contains("method 'double' has 0 parameter(s), expected 1"),
2552            "unexpected message: {}",
2553            warns[0]
2554        );
2555    }
2556
2557    #[test]
2558    fn test_interface_constraint_satisfied() {
2559        let warns = iface_warns(
2560            r#"pipeline t(task) {
2561  interface Sizable {
2562    fn size(self) -> int
2563  }
2564  struct Box { width: int, height: int }
2565  impl Box {
2566    fn size(self) -> int { return self.width * self.height }
2567  }
2568  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2569  measure(Box({width: 3, height: 4}))
2570}"#,
2571        );
2572        assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
2573    }
2574
2575    #[test]
2576    fn test_interface_constraint_untyped_impl_compatible() {
2577        // Gradual typing: untyped impl return should not trigger warning
2578        let warns = iface_warns(
2579            r#"pipeline t(task) {
2580  interface Sizable {
2581    fn size(self) -> int
2582  }
2583  struct Box { width: int }
2584  impl Box {
2585    fn size(self) { return self.width }
2586  }
2587  fn measure<T>(item: T) where T: Sizable { log(item.size()) }
2588  measure(Box({width: 3}))
2589}"#,
2590        );
2591        assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
2592    }
2593
2594    #[test]
2595    fn test_interface_constraint_int_float_covariance() {
2596        // int is compatible with float (covariance)
2597        let warns = iface_warns(
2598            r#"pipeline t(task) {
2599  interface Measurable {
2600    fn value(self) -> float
2601  }
2602  struct Gauge { v: int }
2603  impl Gauge {
2604    fn value(self) -> int { return self.v }
2605  }
2606  fn read_val<T>(g: T) where T: Measurable { log(g.value()) }
2607  read_val(Gauge({v: 42}))
2608}"#,
2609        );
2610        assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
2611    }
2612}