Skip to main content

lust/typechecker/
mod.rs

1mod expr_checker;
2mod item_checker;
3mod stmt_checker;
4mod type_env;
5use crate::modules::{LoadedModule, ModuleImports};
6use crate::{
7    ast::*,
8    config::LustConfig,
9    error::{LustError, Result},
10};
11pub(super) use alloc::{
12    boxed::Box,
13    format,
14    string::{String, ToString},
15    vec,
16    vec::Vec,
17};
18use core::mem;
19use hashbrown::{HashMap, HashSet};
20pub use type_env::FunctionSignature;
21pub use type_env::TypeEnv;
/// Static type checker for Lust programs.
///
/// Holds the type environment together with the per-module state used while
/// checking: which module is active, each module's import tables, and the type
/// information recorded per module for later stages.
pub struct TypeChecker {
    /// Type environment: scopes plus registered structs, enums, traits,
    /// functions, constants and type aliases.
    env: TypeEnv,
    /// NOTE(review): presumably the declared return type of the function being
    /// checked; not used in this part of the file — confirm.
    current_function_return_type: Option<Type>,
    /// NOTE(review): presumably `true` while checking a loop body; not used in
    /// this part of the file — confirm.
    in_loop: bool,
    /// NOTE(review): looks like pending generic-parameter bindings by name; not
    /// used in this part of the file — confirm.
    pending_generic_instances: Option<HashMap<String, Type>>,
    /// NOTE(review): looks like (parameter types, optional return type) expected
    /// for a lambda from its context — confirm.
    expected_lambda_signature: Option<(Vec<Type>, Option<Type>)>,
    /// NOTE(review): presumably generic parameter name -> trait bound names in
    /// effect — confirm.
    current_trait_bounds: HashMap<String, Vec<String>>,
    /// Path of the module being checked. `check_program` sets it per module and
    /// resets it to `None` when it finishes; it is used to qualify bare names.
    current_module: Option<String>,
    /// Import tables keyed by module path, installed by `set_imports_by_module`
    /// and consulted by the `resolve_*` helpers.
    imports_by_module: HashMap<String, ModuleImports>,
    /// Expression types keyed by module path, then by span; drained by
    /// `take_type_info`.
    expr_types_by_module: HashMap<String, HashMap<Span, Type>>,
    /// Variable types keyed by module path, then by span; drained by
    /// `take_type_info`.
    variable_types_by_module: HashMap<String, HashMap<Span, Type>>,
    /// Short-circuit info keyed by module path, then by span; drained by
    /// `take_option_coercions`.
    short_circuit_info: HashMap<String, HashMap<Span, ShortCircuitInfo>>,
}
35
/// Type information drained from a `TypeChecker` by `take_type_info`.
///
/// Both maps are keyed by module path first, then by source span.
pub struct TypeCollection {
    /// Type recorded for each expression span.
    pub expr_types: HashMap<String, HashMap<Span, Type>>,
    /// Type recorded for each variable span.
    pub variable_types: HashMap<String, HashMap<Span, Type>>,
}
40
/// Per-span entry stored in `TypeChecker::short_circuit_info`.
///
/// NOTE(review): the field meanings below are inferred from their names and
/// should be confirmed against the expression checker that fills them in.
#[derive(Clone, Debug)]
struct ShortCircuitInfo {
    /// NOTE(review): presumably the type produced on the truthy path — confirm.
    truthy: Option<Type>,
    /// NOTE(review): presumably the type produced on the falsy path — confirm.
    falsy: Option<Type>,
    /// When `Some`, `take_option_coercions` reports this span as an option
    /// coercion site.
    option_inner: Option<Type>,
}
47
48impl TypeChecker {
49    pub fn new() -> Self {
50        Self::with_config(&LustConfig::default())
51    }
52
53    pub fn with_config(config: &LustConfig) -> Self {
54        Self {
55            env: TypeEnv::with_config(config),
56            current_function_return_type: None,
57            in_loop: false,
58            pending_generic_instances: None,
59            expected_lambda_signature: None,
60            current_trait_bounds: HashMap::new(),
61            current_module: None,
62            imports_by_module: HashMap::new(),
63            expr_types_by_module: HashMap::new(),
64            variable_types_by_module: HashMap::new(),
65            short_circuit_info: HashMap::new(),
66        }
67    }
68
    /// Returns a placeholder span built as `Span::new(0, 0, 0, 0)`.
    ///
    /// Presumably this yields `start_line == 0`, which `type_error_at` and
    /// `unify` treat as "no usable location".
    fn dummy_span() -> Span {
        Span::new(0, 0, 0, 0)
    }
72
73    pub fn check_module(&mut self, items: &[Item]) -> Result<()> {
74        for item in items {
75            self.register_type_definition(item)?;
76        }
77
78        self.validate_struct_cycles()?;
79        self.env.push_scope();
80        self.register_module_init_locals(items)?;
81        for item in items {
82            self.check_item(item)?;
83        }
84
85        self.env.pop_scope();
86        Ok(())
87    }
88
89    pub fn check_program(&mut self, modules: &[LoadedModule]) -> Result<()> {
90        for m in modules {
91            self.current_module = Some(m.path.clone());
92            for item in &m.items {
93                self.register_type_definition(item)?;
94            }
95        }
96
97        self.validate_struct_cycles()?;
98        for m in modules {
99            self.current_module = Some(m.path.clone());
100            self.env.push_scope();
101            self.register_module_init_locals(&m.items)?;
102            for item in &m.items {
103                self.check_item(item)?;
104            }
105
106            self.env.pop_scope();
107        }
108
109        self.current_module = None;
110        Ok(())
111    }
112
113    fn validate_struct_cycles(&self) -> Result<()> {
114        use hashbrown::{HashMap, HashSet};
115        let struct_defs = self.env.struct_definitions();
116        if struct_defs.is_empty() {
117            return Ok(());
118        }
119
120        let mut simple_to_full: HashMap<String, Vec<String>> = HashMap::new();
121        for name in struct_defs.keys() {
122            let simple = name.rsplit('.').next().unwrap_or(name).to_string();
123            simple_to_full.entry(simple).or_default().push(name.clone());
124        }
125
126        let mut struct_has_weak: HashMap<String, bool> = HashMap::new();
127        for (name, def) in &struct_defs {
128            let has_weak = def
129                .fields
130                .iter()
131                .any(|field| matches!(field.ownership, FieldOwnership::Weak));
132            struct_has_weak.insert(name.clone(), has_weak);
133        }
134
135        let mut graph: HashMap<String, Vec<String>> = HashMap::new();
136        for (name, def) in &struct_defs {
137            let module_prefix = name.rsplit_once('.').map(|(module, _)| module.to_string());
138            let mut edges: HashSet<String> = HashSet::new();
139            for field in &def.fields {
140                if matches!(field.ownership, FieldOwnership::Weak) {
141                    let target = field.weak_target.as_ref().ok_or_else(|| {
142                        self.type_error(format!(
143                            "Field '{}.{}' is marked as 'ref' but has no target type",
144                            name, field.name
145                        ))
146                    })?;
147                    let target_name = if let TypeKind::Named(inner) = &target.kind {
148                        inner
149                    } else {
150                        return Err(self.type_error(format!(
151                            "Field '{}.{}' uses 'ref' but only struct types are supported",
152                            name, field.name
153                        )));
154                    };
155                    let resolved = self.resolve_struct_name_for_cycle(
156                        target_name.as_str(),
157                        module_prefix.as_deref(),
158                        &struct_defs,
159                        &simple_to_full,
160                    );
161                    if resolved.is_none() {
162                        return Err(self.type_error(format!(
163                            "Field '{}.{}' uses 'ref' but '{}' is not a known struct type",
164                            name, field.name, target_name
165                        )));
166                    }
167
168                    continue;
169                }
170
171                self.collect_strong_struct_targets(
172                    &field.ty,
173                    module_prefix.as_deref(),
174                    &struct_defs,
175                    &simple_to_full,
176                    &mut edges,
177                );
178            }
179
180            graph.insert(name.clone(), edges.into_iter().collect());
181        }
182
183        fn dfs(
184            node: &str,
185            graph: &HashMap<String, Vec<String>>,
186            visited: &mut HashSet<String>,
187            on_stack: &mut HashSet<String>,
188            stack: &mut Vec<String>,
189        ) -> Option<Vec<String>> {
190            visited.insert(node.to_string());
191            on_stack.insert(node.to_string());
192            stack.push(node.to_string());
193            if let Some(neighbors) = graph.get(node) {
194                for neighbor in neighbors {
195                    if !visited.contains(neighbor) {
196                        if let Some(cycle) = dfs(neighbor, graph, visited, on_stack, stack) {
197                            return Some(cycle);
198                        }
199                    } else if on_stack.contains(neighbor) {
200                        if let Some(pos) = stack.iter().position(|n| n == neighbor) {
201                            let mut cycle = stack[pos..].to_vec();
202                            cycle.push(neighbor.clone());
203                            return Some(cycle);
204                        }
205                    }
206                }
207            }
208
209            stack.pop();
210            on_stack.remove(node);
211            None
212        }
213
214        let mut visited: HashSet<String> = HashSet::new();
215        let mut on_stack: HashSet<String> = HashSet::new();
216        let mut stack: Vec<String> = Vec::new();
217        for name in struct_defs.keys() {
218            if !visited.contains(name) {
219                if let Some(cycle) = dfs(name, &graph, &mut visited, &mut on_stack, &mut stack) {
220                    let contains_weak = cycle
221                        .iter()
222                        .any(|node| struct_has_weak.get(node).copied().unwrap_or(false));
223                    if contains_weak {
224                        continue;
225                    }
226
227                    // let description = cycle.join(" -> ");
228                    break;
229                    // return Err(self.type_error(format!(
230                    //     "Strong ownership cycle detected: {}. Mark at least one field as 'ref' to break the cycle.",
231                    //     description
232                    // )));
233                }
234            }
235        }
236
237        Ok(())
238    }
239
240    fn collect_strong_struct_targets(
241        &self,
242        ty: &Type,
243        parent_module: Option<&str>,
244        struct_defs: &HashMap<String, StructDef>,
245        simple_to_full: &HashMap<String, Vec<String>>,
246        out: &mut HashSet<String>,
247    ) {
248        match &ty.kind {
249            TypeKind::Named(name) => {
250                if let Some(resolved) = self.resolve_struct_name_for_cycle(
251                    name,
252                    parent_module,
253                    struct_defs,
254                    simple_to_full,
255                ) {
256                    out.insert(resolved);
257                }
258            }
259
260            TypeKind::Array(inner)
261            | TypeKind::Ref(inner)
262            | TypeKind::MutRef(inner)
263            | TypeKind::Option(inner) => {
264                self.collect_strong_struct_targets(
265                    inner,
266                    parent_module,
267                    struct_defs,
268                    simple_to_full,
269                    out,
270                );
271            }
272
273            TypeKind::Map(key, value) => {
274                self.collect_strong_struct_targets(
275                    key,
276                    parent_module,
277                    struct_defs,
278                    simple_to_full,
279                    out,
280                );
281                self.collect_strong_struct_targets(
282                    value,
283                    parent_module,
284                    struct_defs,
285                    simple_to_full,
286                    out,
287                );
288            }
289
290            TypeKind::Tuple(elements) | TypeKind::Union(elements) => {
291                for element in elements {
292                    self.collect_strong_struct_targets(
293                        element,
294                        parent_module,
295                        struct_defs,
296                        simple_to_full,
297                        out,
298                    );
299                }
300            }
301
302            TypeKind::Result(ok, err) => {
303                self.collect_strong_struct_targets(
304                    ok,
305                    parent_module,
306                    struct_defs,
307                    simple_to_full,
308                    out,
309                );
310                self.collect_strong_struct_targets(
311                    err,
312                    parent_module,
313                    struct_defs,
314                    simple_to_full,
315                    out,
316                );
317            }
318
319            TypeKind::GenericInstance { type_args, .. } => {
320                for arg in type_args {
321                    self.collect_strong_struct_targets(
322                        arg,
323                        parent_module,
324                        struct_defs,
325                        simple_to_full,
326                        out,
327                    );
328                }
329            }
330
331            _ => {}
332        }
333    }
334
335    fn resolve_struct_name_for_cycle(
336        &self,
337        name: &str,
338        parent_module: Option<&str>,
339        struct_defs: &HashMap<String, StructDef>,
340        simple_to_full: &HashMap<String, Vec<String>>,
341    ) -> Option<String> {
342        if struct_defs.contains_key(name) {
343            return Some(name.to_string());
344        }
345
346        if name.contains('.') {
347            return None;
348        }
349
350        if let Some(candidates) = simple_to_full.get(name) {
351            if candidates.len() == 1 {
352                return Some(candidates[0].clone());
353            }
354
355            if let Some(module) = parent_module {
356                for candidate in candidates {
357                    if let Some((candidate_module, _)) = candidate.rsplit_once('.') {
358                        if candidate_module == module {
359                            return Some(candidate.clone());
360                        }
361                    }
362                }
363            }
364        }
365
366        None
367    }
368
    /// Installs the per-module import tables (keyed by module path), replacing
    /// any tables set previously.
    pub fn set_imports_by_module(&mut self, map: HashMap<String, ModuleImports>) {
        self.imports_by_module = map;
    }
372
373    pub fn take_type_info(&mut self) -> TypeCollection {
374        TypeCollection {
375            expr_types: mem::take(&mut self.expr_types_by_module),
376            variable_types: mem::take(&mut self.variable_types_by_module),
377        }
378    }
379
380    pub fn take_option_coercions(&mut self) -> HashMap<String, HashSet<Span>> {
381        let mut result: HashMap<String, HashSet<Span>> = HashMap::new();
382        let info = mem::take(&mut self.short_circuit_info);
383        for (module, entries) in info {
384            let mut spans: HashSet<Span> = HashSet::new();
385            for (span, entry) in entries {
386                if entry.option_inner.is_some() {
387                    spans.insert(span);
388                }
389            }
390            if !spans.is_empty() {
391                result.insert(module, spans);
392            }
393        }
394
395        result
396    }
397
    /// Returns the function signatures known to the environment, keyed by
    /// function key. Delegates to `TypeEnv::function_signatures`.
    pub fn function_signatures(&self) -> HashMap<String, type_env::FunctionSignature> {
        self.env.function_signatures()
    }
401
    /// Takes the function signatures out of the environment. Delegates to
    /// `TypeEnv::take_function_signatures`; presumably the environment's table is
    /// left empty — confirm there.
    pub fn take_function_signatures(&mut self) -> HashMap<String, type_env::FunctionSignature> {
        self.env.take_function_signatures()
    }
405
    /// Returns the registered struct definitions keyed by struct name.
    /// Delegates to `TypeEnv::struct_definitions`.
    pub fn struct_definitions(&self) -> HashMap<String, StructDef> {
        self.env.struct_definitions()
    }
409
    /// Takes the struct definitions out of the environment. Delegates to
    /// `TypeEnv::take_struct_definitions`; presumably the environment's table is
    /// left empty — confirm there.
    pub fn take_struct_definitions(&mut self) -> HashMap<String, StructDef> {
        self.env.take_struct_definitions()
    }
413
    /// Returns the registered enum definitions keyed by enum name.
    /// Delegates to `TypeEnv::enum_definitions`.
    pub fn enum_definitions(&self) -> HashMap<String, EnumDef> {
        self.env.enum_definitions()
    }
417
    /// Takes the enum definitions out of the environment. Delegates to
    /// `TypeEnv::take_enum_definitions`; presumably the environment's table is
    /// left empty — confirm there.
    pub fn take_enum_definitions(&mut self) -> HashMap<String, EnumDef> {
        self.env.take_enum_definitions()
    }
421
422    fn register_module_init_locals(&mut self, items: &[Item]) -> Result<()> {
423        let module = match &self.current_module {
424            Some(m) => m.clone(),
425            None => return Ok(()),
426        };
427        let init_name = format!("__init@{}", module);
428        for item in items {
429            if let ItemKind::Function(func) = &item.kind {
430                if func.name == init_name {
431                    for stmt in &func.body {
432                        if let StmtKind::Local {
433                            bindings,
434                            ref mutable,
435                            initializer,
436                        } = &stmt.kind
437                        {
438                            self.check_local_stmt(
439                                bindings.as_slice(),
440                                *mutable,
441                                initializer.as_ref().map(|values| values.as_slice()),
442                            )?;
443                        }
444                    }
445                }
446            }
447        }
448
449        Ok(())
450    }
451
452    pub fn resolve_function_key(&self, name: &str) -> String {
453        if name.contains('.') || name.contains(':') {
454            return name.to_string();
455        }
456
457        if let Some(module) = &self.current_module {
458            if let Some(imports) = self.imports_by_module.get(module) {
459                if let Some(fq) = imports.function_aliases.get(name) {
460                    return fq.clone();
461                }
462            }
463
464            let qualified = format!("{}.{}", module, name);
465            if self.env.lookup_function(&qualified).is_some() {
466                return qualified;
467            }
468
469            if self.env.lookup_function(name).is_some() {
470                return name.to_string();
471            }
472
473            return qualified;
474        }
475
476        name.to_string()
477    }
478
479    pub fn resolve_value_key(&self, name: &str) -> String {
480        if name.contains('.') || name.contains(':') {
481            return name.to_string();
482        }
483
484        if let Some(module) = &self.current_module {
485            if let Some(imports) = self.imports_by_module.get(module) {
486                if let Some(fq) = imports.function_aliases.get(name) {
487                    return fq.clone();
488                }
489            }
490
491            return format!("{}.{}", module, name);
492        }
493
494        name.to_string()
495    }
496
497    pub fn resolve_module_alias(&self, alias: &str) -> Option<String> {
498        if let Some(module) = &self.current_module {
499            if let Some(imports) = self.imports_by_module.get(module) {
500                if let Some(m) = imports.module_aliases.get(alias) {
501                    return Some(m.clone());
502                }
503            }
504        }
505
506        None
507    }
508
509    pub fn register_external_struct(&mut self, mut def: StructDef) -> Result<()> {
510        def.name = self.resolve_type_key(&def.name);
511        for field in &mut def.fields {
512            field.ty = self.canonicalize_type(&field.ty);
513            if let Some(target) = &field.weak_target {
514                field.weak_target = Some(self.canonicalize_type(target));
515            }
516        }
517        self.env.register_struct(&def)
518    }
519
520    pub fn register_external_enum(&mut self, mut def: EnumDef) -> Result<()> {
521        def.name = self.resolve_type_key(&def.name);
522        for variant in &mut def.variants {
523            if let Some(fields) = &mut variant.fields {
524                for field in fields {
525                    *field = self.canonicalize_type(field);
526                }
527            }
528        }
529        self.env.register_enum(&def)
530    }
531
532    pub fn register_external_trait(&mut self, mut def: TraitDef) -> Result<()> {
533        def.name = self.resolve_type_key(&def.name);
534        for method in &mut def.methods {
535            for param in &mut method.params {
536                param.ty = self.canonicalize_type(&param.ty);
537            }
538            if let Some(ret) = method.return_type.clone() {
539                method.return_type = Some(self.canonicalize_type(&ret));
540            }
541        }
542        self.env.register_trait(&def)
543    }
544
545    pub fn register_external_function(
546        &mut self,
547        (name, mut signature): (String, FunctionSignature),
548    ) -> Result<()> {
549        signature.params = signature
550            .params
551            .into_iter()
552            .map(|ty| self.canonicalize_type(&ty))
553            .collect();
554        signature.return_type = self.canonicalize_type(&signature.return_type);
555        let canonical = self.resolve_type_key(&name);
556        self.env.register_or_update_function(canonical, signature)
557    }
558
559    pub fn register_external_constant(&mut self, name: String, ty: Type) -> Result<()> {
560        let canonical_ty = self.canonicalize_type(&ty);
561        let canonical_name = self.resolve_value_key(&name);
562        self.env.register_constant(canonical_name, canonical_ty)
563    }
564
565    pub fn register_external_impl(&mut self, mut impl_block: ImplBlock) -> Result<()> {
566        impl_block.target_type = self.canonicalize_type(&impl_block.target_type);
567        if let Some(trait_name) = &impl_block.trait_name {
568            impl_block.trait_name = Some(self.resolve_type_key(trait_name));
569        }
570        for method in &mut impl_block.methods {
571            for param in &mut method.params {
572                param.ty = self.canonicalize_type(&param.ty);
573            }
574            if let Some(ret) = method.return_type.clone() {
575                method.return_type = Some(self.canonicalize_type(&ret));
576            }
577        }
578
579        let type_name = match &impl_block.target_type.kind {
580            TypeKind::Named(name) => self.resolve_type_key(name),
581            TypeKind::GenericInstance { name, .. } => self.resolve_type_key(name),
582            _ => {
583                return Err(self.type_error(
584                    "Impl target must be a named type when registering from Rust".to_string(),
585                ))
586            }
587        };
588
589        self.env.register_impl(&impl_block);
590        for method in &impl_block.methods {
591            let params: Vec<Type> = method.params.iter().map(|p| p.ty.clone()).collect();
592            let return_type = method
593                .return_type
594                .clone()
595                .unwrap_or(Type::new(TypeKind::Unit, Span::dummy()));
596            let has_self = method.params.iter().any(|p| p.is_self);
597            let canonical_name = if method.name.contains(':') || method.name.contains('.') {
598                self.resolve_type_key(&method.name)
599            } else if has_self {
600                format!("{}:{}", type_name, method.name)
601            } else {
602                format!("{}.{}", type_name, method.name)
603            };
604            #[cfg(all(debug_assertions, feature = "std"))]
605            eprintln!(
606                "register_external_impl canonical method {} (has_self={})",
607                canonical_name, has_self
608            );
609            let signature = FunctionSignature {
610                params,
611                return_type,
612                is_method: has_self,
613            };
614            self.env
615                .register_or_update_function(canonical_name, signature)?;
616        }
617
618        Ok(())
619    }
620
621    pub fn resolve_type_key(&self, name: &str) -> String {
622        if let Some((head, tail)) = name.split_once('.') {
623            if let Some(module) = &self.current_module {
624                if let Some(imports) = self.imports_by_module.get(module) {
625                    if let Some(real_module) = imports.module_aliases.get(head) {
626                        if tail.is_empty() {
627                            return real_module.clone();
628                        } else {
629                            return format!("{}.{}", real_module, tail);
630                        }
631                    }
632                }
633            }
634
635            return name.to_string();
636        }
637
638        if self.env.lookup_struct(name).is_some()
639            || self.env.lookup_enum(name).is_some()
640            || self.env.lookup_trait(name).is_some()
641        {
642            return name.to_string();
643        }
644
645        if self.env.is_builtin_type(name) {
646            return name.to_string();
647        }
648
649        if let Some(module) = &self.current_module {
650            if let Some(imports) = self.imports_by_module.get(module) {
651                if let Some(fq) = imports.type_aliases.get(name) {
652                    return fq.clone();
653                }
654            }
655
656            return format!("{}.{}", module, name);
657        }
658
659        name.to_string()
660    }
661
662    fn register_type_definition(&mut self, item: &Item) -> Result<()> {
663        match &item.kind {
664            ItemKind::Struct(s) => {
665                let mut s2 = s.clone();
666                if let Some(module) = &self.current_module {
667                    if !s2.name.contains('.') {
668                        s2.name = format!("{}.{}", module, s2.name);
669                    }
670                }
671
672                for field in &mut s2.fields {
673                    field.ty = self.canonicalize_type(&field.ty);
674                    if let Some(target) = &field.weak_target {
675                        field.weak_target = Some(self.canonicalize_type(target));
676                    }
677                }
678
679                self.env.register_struct(&s2)?;
680            }
681
682            ItemKind::Enum(e) => {
683                let mut e2 = e.clone();
684                if let Some(module) = &self.current_module {
685                    if !e2.name.contains('.') {
686                        e2.name = format!("{}.{}", module, e2.name);
687                    }
688                }
689
690                for variant in &mut e2.variants {
691                    if let Some(fields) = &mut variant.fields {
692                        for field in fields {
693                            *field = self.canonicalize_type(field);
694                        }
695                    }
696                }
697
698                self.env.register_enum(&e2)?;
699            }
700
701            ItemKind::Trait(t) => {
702                let mut t2 = t.clone();
703                if let Some(module) = &self.current_module {
704                    if !t2.name.contains('.') {
705                        t2.name = format!("{}.{}", module, t2.name);
706                    }
707                }
708
709                for method in &mut t2.methods {
710                    for param in &mut method.params {
711                        param.ty = self.canonicalize_type(&param.ty);
712                    }
713                    if let Some(ret) = method.return_type.clone() {
714                        method.return_type = Some(self.canonicalize_type(&ret));
715                    }
716                }
717
718                self.env.register_trait(&t2)?;
719            }
720
721            ItemKind::TypeAlias {
722                name,
723                type_params,
724                target,
725            } => {
726                let qname = if let Some(module) = &self.current_module {
727                    if name.contains('.') {
728                        name.clone()
729                    } else {
730                        format!("{}.{}", module, name)
731                    }
732                } else {
733                    name.clone()
734                };
735                self.env.register_type_alias(
736                    qname,
737                    type_params.clone(),
738                    self.canonicalize_type(target),
739                )?;
740            }
741
742            ItemKind::Extern { items, .. } => {
743                for ext in items {
744                    match ext {
745                        ExternItem::Struct(def) => {
746                            self.register_external_struct(def.clone())?;
747                        }
748                        ExternItem::Enum(def) => {
749                            self.register_external_enum(def.clone())?;
750                        }
751                        ExternItem::Const { name, ty } => {
752                            let key = self.resolve_value_key(name);
753                            self.env
754                                .register_constant(key, self.canonicalize_type(ty))?;
755                        }
756                        ExternItem::Function { .. } => {}
757                    }
758                }
759            }
760
761            _ => {}
762        }
763
764        Ok(())
765    }
766
    /// Builds a `LustError::TypeError` carrying `message` and no source location.
    fn type_error(&self, message: String) -> LustError {
        LustError::TypeError { message }
    }
770
771    fn type_error_at(&self, message: String, span: Span) -> LustError {
772        if span.start_line > 0 {
773            LustError::TypeErrorWithSpan {
774                message,
775                line: span.start_line,
776                column: span.start_col,
777                module: self.current_module.clone(),
778            }
779        } else {
780            LustError::TypeError { message }
781        }
782    }
783
    /// Structural type equality. Only the `kind`s are compared, so spans (and any
    /// other `Type` fields) are ignored.
    fn types_equal(&self, t1: &Type, t2: &Type) -> bool {
        t1.kind == t2.kind
    }
787
788    pub fn canonicalize_type(&self, ty: &Type) -> Type {
789        use crate::ast::TypeKind as TK;
790        match &ty.kind {
791            TK::Named(name) => Type::new(TK::Named(self.resolve_type_key(name)), ty.span),
792            TK::Array(inner) => {
793                Type::new(TK::Array(Box::new(self.canonicalize_type(inner))), ty.span)
794            }
795
796            TK::Tuple(elements) => Type::new(
797                TK::Tuple(elements.iter().map(|t| self.canonicalize_type(t)).collect()),
798                ty.span,
799            ),
800            TK::Option(inner) => {
801                Type::new(TK::Option(Box::new(self.canonicalize_type(inner))), ty.span)
802            }
803
804            TK::Result(ok, err) => Type::new(
805                TK::Result(
806                    Box::new(self.canonicalize_type(ok)),
807                    Box::new(self.canonicalize_type(err)),
808                ),
809                ty.span,
810            ),
811            TK::Map(k, v) => Type::new(
812                TK::Map(
813                    Box::new(self.canonicalize_type(k)),
814                    Box::new(self.canonicalize_type(v)),
815                ),
816                ty.span,
817            ),
818            TK::Ref(inner) => Type::new(TK::Ref(Box::new(self.canonicalize_type(inner))), ty.span),
819            TK::MutRef(inner) => {
820                Type::new(TK::MutRef(Box::new(self.canonicalize_type(inner))), ty.span)
821            }
822
823            TK::Pointer { mutable, pointee } => Type::new(
824                TK::Pointer {
825                    mutable: *mutable,
826                    pointee: Box::new(self.canonicalize_type(pointee)),
827                },
828                ty.span,
829            ),
830            _ => ty.clone(),
831        }
832    }
833
834    fn unify(&self, expected: &Type, actual: &Type) -> Result<()> {
835        let span = if actual.span.start_line > 0 {
836            Some(actual.span)
837        } else if expected.span.start_line > 0 {
838            Some(expected.span)
839        } else {
840            None
841        };
842        self.unify_at(expected, actual, span)
843    }
844
845    fn unify_at(&self, expected: &Type, actual: &Type, span: Option<Span>) -> Result<()> {
846        if matches!(expected.kind, TypeKind::Unknown) || matches!(actual.kind, TypeKind::Unknown) {
847            return Ok(());
848        }
849
850        if matches!(expected.kind, TypeKind::Infer) || matches!(actual.kind, TypeKind::Infer) {
851            return Ok(());
852        }
853
854        if self.is_lua_multi_return(expected) || self.is_lua_multi_return(actual) {
855            return Ok(());
856        }
857
858        if matches!(&expected.kind, TypeKind::Named(name) if name == "LuaValue")
859            || matches!(&actual.kind, TypeKind::Named(name) if name == "LuaValue")
860        {
861            return Ok(());
862        }
863
864        match (&expected.kind, &actual.kind) {
865            (TypeKind::Union(expected_types), TypeKind::Union(actual_types)) => {
866                if expected_types.len() != actual_types.len() {
867                    return Err(self.type_error(format!(
868                        "Union types have different number of members: expected {}, got {}",
869                        expected_types.len(),
870                        actual_types.len()
871                    )));
872                }
873
874                for exp_type in expected_types {
875                    let mut found = false;
876                    for act_type in actual_types {
877                        if self.types_equal(exp_type, act_type) {
878                            found = true;
879                            break;
880                        }
881                    }
882
883                    if !found {
884                        return Err(match span {
885                            Some(s) => self.type_error_at(
886                                format!(
887                                    "Union type member '{}' not found in actual union",
888                                    exp_type
889                                ),
890                                s,
891                            ),
892                            None => self.type_error(format!(
893                                "Union type member '{}' not found in actual union",
894                                exp_type
895                            )),
896                        });
897                    }
898                }
899
900                return Ok(());
901            }
902
903            (TypeKind::Union(expected_types), _) => {
904                for union_member in expected_types {
905                    if self.unify(union_member, actual).is_ok() {
906                        return Ok(());
907                    }
908                }
909
910                return Err(match span {
911                    Some(s) => self.type_error_at(
912                        format!("Type '{}' is not compatible with union type", actual),
913                        s,
914                    ),
915                    None => self.type_error(format!(
916                        "Type '{}' is not compatible with union type",
917                        actual
918                    )),
919                });
920            }
921
922            (_, TypeKind::Union(actual_types)) => {
923                for union_member in actual_types {
924                    self.unify(expected, union_member)?;
925                }
926
927                return Ok(());
928            }
929
930            _ => {}
931        }
932
933        match (&expected.kind, &actual.kind) {
934            (TypeKind::Tuple(expected_elems), TypeKind::Tuple(actual_elems)) => {
935                if expected_elems.len() != actual_elems.len() {
936                    return Err(match span {
937                        Some(s) => self.type_error_at(
938                            format!(
939                                "Tuple length mismatch: expected {} element(s), got {}",
940                                expected_elems.len(),
941                                actual_elems.len()
942                            ),
943                            s,
944                        ),
945                        None => self.type_error(format!(
946                            "Tuple length mismatch: expected {} element(s), got {}",
947                            expected_elems.len(),
948                            actual_elems.len()
949                        )),
950                    });
951                }
952
953                for (exp_elem, act_elem) in expected_elems.iter().zip(actual_elems.iter()) {
954                    self.unify(exp_elem, act_elem)?;
955                }
956
957                return Ok(());
958            }
959
960            (TypeKind::Tuple(_), _) | (_, TypeKind::Tuple(_)) => {
961                return Err(match span {
962                    Some(s) => self.type_error_at(
963                        format!("Tuple type is not compatible with type '{}'", actual),
964                        s,
965                    ),
966                    None => self.type_error(format!(
967                        "Tuple type is not compatible with type '{}'",
968                        actual
969                    )),
970                })
971            }
972
973            (TypeKind::Named(name), TypeKind::Array(_))
974            | (TypeKind::Array(_), TypeKind::Named(name))
975                if name == "Array" =>
976            {
977                return Ok(());
978            }
979
980            (TypeKind::Array(exp_el), TypeKind::Array(act_el)) => {
981                if matches!(exp_el.kind, TypeKind::Unknown | TypeKind::Infer)
982                    || matches!(act_el.kind, TypeKind::Unknown | TypeKind::Infer)
983                {
984                    return Ok(());
985                } else {
986                    return self.unify(exp_el, act_el);
987                }
988            }
989
990            (TypeKind::Map(exp_key, exp_value), TypeKind::Map(act_key, act_value)) => {
991                self.unify(exp_key, act_key)?;
992                return self.unify(exp_value, act_value);
993            }
994
995            (TypeKind::Named(name), TypeKind::Option(_))
996            | (TypeKind::Option(_), TypeKind::Named(name))
997                if name == "Option" =>
998            {
999                return Ok(());
1000            }
1001
1002            (TypeKind::Option(exp_inner), TypeKind::Option(act_inner)) => {
1003                if matches!(exp_inner.kind, TypeKind::Unknown | TypeKind::Infer)
1004                    || matches!(act_inner.kind, TypeKind::Unknown | TypeKind::Infer)
1005                {
1006                    return Ok(());
1007                } else {
1008                    return self.unify(exp_inner, act_inner);
1009                }
1010            }
1011
1012            (TypeKind::Named(name), TypeKind::Result(_, _))
1013            | (TypeKind::Result(_, _), TypeKind::Named(name))
1014                if name == "Result" =>
1015            {
1016                return Ok(());
1017            }
1018
1019            (TypeKind::Result(exp_ok, exp_err), TypeKind::Result(act_ok, act_err)) => {
1020                if matches!(exp_ok.kind, TypeKind::Unknown | TypeKind::Infer)
1021                    || matches!(act_ok.kind, TypeKind::Unknown | TypeKind::Infer)
1022                {
1023                    if matches!(exp_err.kind, TypeKind::Unknown | TypeKind::Infer)
1024                        || matches!(act_err.kind, TypeKind::Unknown | TypeKind::Infer)
1025                    {
1026                        return Ok(());
1027                    } else {
1028                        return self.unify(exp_err, act_err);
1029                    }
1030                } else {
1031                    self.unify(exp_ok, act_ok)?;
1032                    return self.unify(exp_err, act_err);
1033                }
1034            }
1035
1036            _ => {}
1037        }
1038
1039        if self.types_equal(expected, actual) {
1040            Ok(())
1041        } else {
1042            Err(match span {
1043                Some(s) => self.type_error_at(
1044                    format!("Type mismatch: expected '{}', got '{}'", expected, actual),
1045                    s,
1046                ),
1047                None => self.type_error(format!(
1048                    "Type mismatch: expected '{}', got '{}'",
1049                    expected, actual
1050                )),
1051            })
1052        }
1053    }
1054
1055    fn types_compatible(&self, expected: &Type, actual: &Type) -> bool {
1056        if matches!(expected.kind, TypeKind::Unknown) || matches!(actual.kind, TypeKind::Unknown) {
1057            return true;
1058        }
1059
1060        if matches!(expected.kind, TypeKind::Infer) || matches!(actual.kind, TypeKind::Infer) {
1061            return true;
1062        }
1063
1064        match (&expected.kind, &actual.kind) {
1065            (TypeKind::Generic(_), TypeKind::Generic(_)) => return true,
1066            (TypeKind::Generic(_), _) | (_, TypeKind::Generic(_)) => return true,
1067            _ => {}
1068        }
1069
1070        match (&expected.kind, &actual.kind) {
1071            (TypeKind::Array(e1), TypeKind::Array(e2)) => {
1072                return self.types_compatible(e1, e2);
1073            }
1074
1075            (TypeKind::Named(name), TypeKind::Array(_))
1076            | (TypeKind::Array(_), TypeKind::Named(name))
1077                if name == "Array" =>
1078            {
1079                return true;
1080            }
1081
1082            _ => {}
1083        }
1084
1085        match (&expected.kind, &actual.kind) {
1086            (TypeKind::Map(k1, v1), TypeKind::Map(k2, v2)) => {
1087                return self.types_compatible(k1, k2) && self.types_compatible(v1, v2);
1088            }
1089
1090            _ => {}
1091        }
1092
1093        match (&expected.kind, &actual.kind) {
1094            (TypeKind::Option(t1), TypeKind::Option(t2)) => {
1095                return self.types_compatible(t1, t2);
1096            }
1097
1098            (TypeKind::Named(name), TypeKind::Option(_))
1099            | (TypeKind::Option(_), TypeKind::Named(name))
1100                if name == "Option" =>
1101            {
1102                return true;
1103            }
1104
1105            _ => {}
1106        }
1107
1108        match (&expected.kind, &actual.kind) {
1109            (TypeKind::Result(ok1, err1), TypeKind::Result(ok2, err2)) => {
1110                return self.types_compatible(ok1, ok2) && self.types_compatible(err1, err2);
1111            }
1112
1113            (TypeKind::Named(name), TypeKind::Result(_, _))
1114            | (TypeKind::Result(_, _), TypeKind::Named(name))
1115                if name == "Result" =>
1116            {
1117                return true;
1118            }
1119
1120            _ => {}
1121        }
1122
1123        match (&expected.kind, &actual.kind) {
1124            (
1125                TypeKind::Function {
1126                    params: p1,
1127                    return_type: r1,
1128                },
1129                TypeKind::Function {
1130                    params: p2,
1131                    return_type: r2,
1132                },
1133            ) => {
1134                if p1.len() != p2.len() {
1135                    return false;
1136                }
1137
1138                for (t1, t2) in p1.iter().zip(p2.iter()) {
1139                    if !self.types_compatible(t1, t2) {
1140                        return false;
1141                    }
1142                }
1143
1144                return self.types_compatible(r1, r2);
1145            }
1146
1147            _ => {}
1148        }
1149
1150        self.types_equal(expected, actual)
1151    }
1152
1153    fn unify_with_bounds(&self, expected: &Type, actual: &Type) -> Result<()> {
1154        if let TypeKind::Generic(type_param) = &expected.kind {
1155            if let Some(trait_names) = self.current_trait_bounds.get(type_param) {
1156                for trait_name in trait_names {
1157                    if !self.env.type_implements_trait(actual, trait_name) {
1158                        return Err(self.type_error(format!(
1159                            "Type '{}' does not implement required trait '{}'",
1160                            actual, trait_name
1161                        )));
1162                    }
1163                }
1164
1165                return Ok(());
1166            }
1167
1168            return Ok(());
1169        }
1170
1171        self.unify(expected, actual)
1172    }
1173
1174    fn is_lua_multi_return(&self, ty: &Type) -> bool {
1175        if let TypeKind::Array(inner) = &ty.kind {
1176            return matches!(inner.kind, TypeKind::Unknown)
1177                || matches!(&inner.kind, TypeKind::Named(name) if name == "LuaValue");
1178        }
1179        false
1180    }
1181
1182    fn record_short_circuit_info(&mut self, span: Span, info: &ShortCircuitInfo) {
1183        let truthy = info.truthy.as_ref().map(|ty| self.canonicalize_type(ty));
1184        let falsy = info.falsy.as_ref().map(|ty| self.canonicalize_type(ty));
1185        let option_inner = info
1186            .option_inner
1187            .as_ref()
1188            .map(|ty| self.canonicalize_type(ty));
1189        let module_key = self.current_module_key();
1190        self.short_circuit_info
1191            .entry(module_key)
1192            .or_default()
1193            .insert(
1194                span,
1195                ShortCircuitInfo {
1196                    truthy,
1197                    falsy,
1198                    option_inner,
1199                },
1200            );
1201    }
1202
1203    fn short_circuit_profile(&self, expr: &Expr, ty: &Type) -> ShortCircuitInfo {
1204        let module_key = self
1205            .current_module
1206            .as_ref()
1207            .map(String::as_str)
1208            .unwrap_or("");
1209        if let Some(module_map) = self.short_circuit_info.get(module_key) {
1210            if let Some(info) = module_map.get(&expr.span) {
1211                return info.clone();
1212            }
1213        }
1214
1215        ShortCircuitInfo {
1216            truthy: if self.type_can_be_truthy(ty) {
1217                Some(self.canonicalize_type(ty))
1218            } else {
1219                None
1220            },
1221            falsy: self.extract_falsy_type(ty),
1222            option_inner: None,
1223        }
1224    }
1225
1226    fn current_module_key(&self) -> String {
1227        self.current_module
1228            .as_ref()
1229            .cloned()
1230            .unwrap_or_else(|| "".to_string())
1231    }
1232
1233    fn clear_option_for_span(&mut self, span: Span) {
1234        let module_key = self.current_module_key();
1235        if let Some(module_map) = self.short_circuit_info.get_mut(&module_key) {
1236            if let Some(info) = module_map.get_mut(&span) {
1237                info.option_inner = None;
1238            }
1239        }
1240    }
1241
1242    fn type_can_be_truthy(&self, ty: &Type) -> bool {
1243        match &ty.kind {
1244            TypeKind::Union(members) => {
1245                members.iter().any(|member| self.type_can_be_truthy(member))
1246            }
1247            TypeKind::Bool => true,
1248            TypeKind::Unknown => true,
1249            _ => true,
1250        }
1251    }
1252
1253    fn type_can_be_falsy(&self, ty: &Type) -> bool {
1254        match &ty.kind {
1255            TypeKind::Union(members) => members.iter().any(|member| self.type_can_be_falsy(member)),
1256            TypeKind::Bool => true,
1257            TypeKind::Unknown => true,
1258            TypeKind::Option(_) => true,
1259            _ => false,
1260        }
1261    }
1262
1263    fn extract_falsy_type(&self, ty: &Type) -> Option<Type> {
1264        match &ty.kind {
1265            TypeKind::Bool => Some(Type::new(TypeKind::Bool, ty.span)),
1266            TypeKind::Unknown => Some(Type::new(TypeKind::Unknown, ty.span)),
1267            TypeKind::Option(inner) => Some(Type::new(
1268                TypeKind::Option(Box::new(self.canonicalize_type(inner))),
1269                ty.span,
1270            )),
1271            TypeKind::Union(members) => {
1272                let mut parts = Vec::new();
1273                for member in members {
1274                    if let Some(part) = self.extract_falsy_type(member) {
1275                        parts.push(part);
1276                    }
1277                }
1278                self.merge_optional_types(parts)
1279            }
1280            _ => None,
1281        }
1282    }
1283
1284    fn merge_optional_types(&self, types: Vec<Type>) -> Option<Type> {
1285        if types.is_empty() {
1286            return None;
1287        }
1288
1289        Some(self.make_union_from_types(types))
1290    }
1291
1292    fn make_union_from_types(&self, types: Vec<Type>) -> Type {
1293        let mut flat: Vec<Type> = Vec::new();
1294        for ty in types {
1295            let canonical = self.canonicalize_type(&ty);
1296            match &canonical.kind {
1297                TypeKind::Union(members) => {
1298                    for member in members {
1299                        self.push_unique_type(&mut flat, member.clone());
1300                    }
1301                }
1302                _ => self.push_unique_type(&mut flat, canonical),
1303            }
1304        }
1305
1306        match flat.len() {
1307            0 => Type::new(TypeKind::Unknown, Self::dummy_span()),
1308            1 => flat.into_iter().next().unwrap(),
1309            _ => Type::new(TypeKind::Union(flat), Self::dummy_span()),
1310        }
1311    }
1312
1313    fn push_unique_type(&self, list: &mut Vec<Type>, candidate: Type) {
1314        if !list
1315            .iter()
1316            .any(|existing| self.types_equal(existing, &candidate))
1317        {
1318            list.push(candidate);
1319        }
1320    }
1321
1322    fn combine_truthy_falsy(&self, truthy: Option<Type>, falsy: Option<Type>) -> Type {
1323        match (truthy, falsy) {
1324            (Some(t), Some(f)) => self.make_union_from_types(vec![t, f]),
1325            (Some(t), None) => t,
1326            (None, Some(f)) => f,
1327            (None, None) => Type::new(TypeKind::Unknown, Self::dummy_span()),
1328        }
1329    }
1330
1331    fn is_bool_like(&self, ty: &Type) -> bool {
1332        match &ty.kind {
1333            TypeKind::Bool => true,
1334            TypeKind::Union(members) => members.iter().all(|member| self.is_bool_like(member)),
1335            _ => false,
1336        }
1337    }
1338
1339    fn option_inner_type<'a>(&self, ty: &'a Type) -> Option<&'a Type> {
1340        match &ty.kind {
1341            TypeKind::Option(inner) => Some(inner.as_ref()),
1342            TypeKind::Union(members) => {
1343                for member in members {
1344                    if let Some(inner) = self.option_inner_type(member) {
1345                        return Some(inner);
1346                    }
1347                }
1348                None
1349            }
1350            _ => None,
1351        }
1352    }
1353
1354    fn should_optionize(&self, left: &Type, right: &Type) -> bool {
1355        self.is_bool_like(left)
1356            && !self.is_bool_like(right)
1357            && self.option_inner_type(right).is_none()
1358    }
1359}