// lust/typechecker/mod.rs

1mod expr_checker;
2mod item_checker;
3mod stmt_checker;
4mod type_env;
5use crate::modules::{LoadedModule, ModuleImports};
6use crate::{
7    ast::*,
8    config::LustConfig,
9    error::{LustError, Result},
10};
11pub(super) use alloc::{
12    boxed::Box,
13    format,
14    string::{String, ToString},
15    vec,
16    vec::Vec,
17};
18use core::mem;
19use hashbrown::{HashMap, HashSet};
20pub use type_env::FunctionSignature;
21pub use type_env::TypeEnv;
22pub struct TypeChecker {
23    env: TypeEnv,
24    current_function_return_type: Option<Type>,
25    in_loop: bool,
26    pending_generic_instances: Option<HashMap<String, Type>>,
27    expected_lambda_signature: Option<(Vec<Type>, Option<Type>)>,
28    current_trait_bounds: HashMap<String, Vec<String>>,
29    current_module: Option<String>,
30    imports_by_module: HashMap<String, ModuleImports>,
31    expr_types_by_module: HashMap<String, HashMap<Span, Type>>,
32    variable_types_by_module: HashMap<String, HashMap<Span, Type>>,
33    short_circuit_info: HashMap<String, HashMap<Span, ShortCircuitInfo>>,
34}
35
36pub struct TypeCollection {
37    pub expr_types: HashMap<String, HashMap<Span, Type>>,
38    pub variable_types: HashMap<String, HashMap<Span, Type>>,
39}
40
41#[derive(Clone, Debug)]
42struct ShortCircuitInfo {
43    truthy: Option<Type>,
44    falsy: Option<Type>,
45    option_inner: Option<Type>,
46}
47
48impl TypeChecker {
49    pub fn new() -> Self {
50        Self::with_config(&LustConfig::default())
51    }
52
53    pub fn with_config(config: &LustConfig) -> Self {
54        Self {
55            env: TypeEnv::with_config(config),
56            current_function_return_type: None,
57            in_loop: false,
58            pending_generic_instances: None,
59            expected_lambda_signature: None,
60            current_trait_bounds: HashMap::new(),
61            current_module: None,
62            imports_by_module: HashMap::new(),
63            expr_types_by_module: HashMap::new(),
64            variable_types_by_module: HashMap::new(),
65            short_circuit_info: HashMap::new(),
66        }
67    }
68
69    fn dummy_span() -> Span {
70        Span::new(0, 0, 0, 0)
71    }
72
73    pub fn check_module(&mut self, items: &[Item]) -> Result<()> {
74        for item in items {
75            self.register_type_definition(item)?;
76        }
77
78        self.validate_struct_cycles()?;
79        self.env.push_scope();
80        self.register_module_init_locals(items)?;
81        for item in items {
82            self.check_item(item)?;
83        }
84
85        self.env.pop_scope();
86        Ok(())
87    }
88
89    pub fn check_program(&mut self, modules: &[LoadedModule]) -> Result<()> {
90        for m in modules {
91            self.current_module = Some(m.path.clone());
92            for item in &m.items {
93                self.register_type_definition(item)?;
94            }
95        }
96
97        self.validate_struct_cycles()?;
98        for m in modules {
99            self.current_module = Some(m.path.clone());
100            self.env.push_scope();
101            self.register_module_init_locals(&m.items)?;
102            for item in &m.items {
103                self.check_item(item)?;
104            }
105
106            self.env.pop_scope();
107        }
108
109        self.current_module = None;
110        Ok(())
111    }
112
113    fn validate_struct_cycles(&self) -> Result<()> {
114        use hashbrown::{HashMap, HashSet};
115        let struct_defs = self.env.struct_definitions();
116        if struct_defs.is_empty() {
117            return Ok(());
118        }
119
120        let mut simple_to_full: HashMap<String, Vec<String>> = HashMap::new();
121        for name in struct_defs.keys() {
122            let simple = name.rsplit('.').next().unwrap_or(name).to_string();
123            simple_to_full.entry(simple).or_default().push(name.clone());
124        }
125
126        let mut struct_has_weak: HashMap<String, bool> = HashMap::new();
127        for (name, def) in &struct_defs {
128            let has_weak = def
129                .fields
130                .iter()
131                .any(|field| matches!(field.ownership, FieldOwnership::Weak));
132            struct_has_weak.insert(name.clone(), has_weak);
133        }
134
135        let mut graph: HashMap<String, Vec<String>> = HashMap::new();
136        for (name, def) in &struct_defs {
137            let module_prefix = name.rsplit_once('.').map(|(module, _)| module.to_string());
138            let mut edges: HashSet<String> = HashSet::new();
139            for field in &def.fields {
140                if matches!(field.ownership, FieldOwnership::Weak) {
141                    let target = field.weak_target.as_ref().ok_or_else(|| {
142                        self.type_error(format!(
143                            "Field '{}.{}' is marked as 'ref' but has no target type",
144                            name, field.name
145                        ))
146                    })?;
147                    let target_name = if let TypeKind::Named(inner) = &target.kind {
148                        inner
149                    } else {
150                        return Err(self.type_error(format!(
151                            "Field '{}.{}' uses 'ref' but only struct types are supported",
152                            name, field.name
153                        )));
154                    };
155                    let resolved = self.resolve_struct_name_for_cycle(
156                        target_name.as_str(),
157                        module_prefix.as_deref(),
158                        &struct_defs,
159                        &simple_to_full,
160                    );
161                    if resolved.is_none() {
162                        return Err(self.type_error(format!(
163                            "Field '{}.{}' uses 'ref' but '{}' is not a known struct type",
164                            name, field.name, target_name
165                        )));
166                    }
167
168                    continue;
169                }
170
171                self.collect_strong_struct_targets(
172                    &field.ty,
173                    module_prefix.as_deref(),
174                    &struct_defs,
175                    &simple_to_full,
176                    &mut edges,
177                );
178            }
179
180            graph.insert(name.clone(), edges.into_iter().collect());
181        }
182
183        fn dfs(
184            node: &str,
185            graph: &HashMap<String, Vec<String>>,
186            visited: &mut HashSet<String>,
187            on_stack: &mut HashSet<String>,
188            stack: &mut Vec<String>,
189        ) -> Option<Vec<String>> {
190            visited.insert(node.to_string());
191            on_stack.insert(node.to_string());
192            stack.push(node.to_string());
193            if let Some(neighbors) = graph.get(node) {
194                for neighbor in neighbors {
195                    if !visited.contains(neighbor) {
196                        if let Some(cycle) = dfs(neighbor, graph, visited, on_stack, stack) {
197                            return Some(cycle);
198                        }
199                    } else if on_stack.contains(neighbor) {
200                        if let Some(pos) = stack.iter().position(|n| n == neighbor) {
201                            let mut cycle = stack[pos..].to_vec();
202                            cycle.push(neighbor.clone());
203                            return Some(cycle);
204                        }
205                    }
206                }
207            }
208
209            stack.pop();
210            on_stack.remove(node);
211            None
212        }
213
214        let mut visited: HashSet<String> = HashSet::new();
215        let mut on_stack: HashSet<String> = HashSet::new();
216        let mut stack: Vec<String> = Vec::new();
217        for name in struct_defs.keys() {
218            if !visited.contains(name) {
219                if let Some(cycle) = dfs(name, &graph, &mut visited, &mut on_stack, &mut stack) {
220                    let contains_weak = cycle
221                        .iter()
222                        .any(|node| struct_has_weak.get(node).copied().unwrap_or(false));
223                    if contains_weak {
224                        continue;
225                    }
226
227                    let description = cycle.join(" -> ");
228                    return Err(self.type_error(format!(
229                        "Strong ownership cycle detected: {}. Mark at least one field as 'ref' to break the cycle.",
230                        description
231                    )));
232                }
233            }
234        }
235
236        Ok(())
237    }
238
239    fn collect_strong_struct_targets(
240        &self,
241        ty: &Type,
242        parent_module: Option<&str>,
243        struct_defs: &HashMap<String, StructDef>,
244        simple_to_full: &HashMap<String, Vec<String>>,
245        out: &mut HashSet<String>,
246    ) {
247        match &ty.kind {
248            TypeKind::Named(name) => {
249                if let Some(resolved) = self.resolve_struct_name_for_cycle(
250                    name,
251                    parent_module,
252                    struct_defs,
253                    simple_to_full,
254                ) {
255                    out.insert(resolved);
256                }
257            }
258
259            TypeKind::Array(inner)
260            | TypeKind::Ref(inner)
261            | TypeKind::MutRef(inner)
262            | TypeKind::Option(inner) => {
263                self.collect_strong_struct_targets(
264                    inner,
265                    parent_module,
266                    struct_defs,
267                    simple_to_full,
268                    out,
269                );
270            }
271
272            TypeKind::Map(key, value) => {
273                self.collect_strong_struct_targets(
274                    key,
275                    parent_module,
276                    struct_defs,
277                    simple_to_full,
278                    out,
279                );
280                self.collect_strong_struct_targets(
281                    value,
282                    parent_module,
283                    struct_defs,
284                    simple_to_full,
285                    out,
286                );
287            }
288
289            TypeKind::Tuple(elements) | TypeKind::Union(elements) => {
290                for element in elements {
291                    self.collect_strong_struct_targets(
292                        element,
293                        parent_module,
294                        struct_defs,
295                        simple_to_full,
296                        out,
297                    );
298                }
299            }
300
301            TypeKind::Result(ok, err) => {
302                self.collect_strong_struct_targets(
303                    ok,
304                    parent_module,
305                    struct_defs,
306                    simple_to_full,
307                    out,
308                );
309                self.collect_strong_struct_targets(
310                    err,
311                    parent_module,
312                    struct_defs,
313                    simple_to_full,
314                    out,
315                );
316            }
317
318            TypeKind::GenericInstance { type_args, .. } => {
319                for arg in type_args {
320                    self.collect_strong_struct_targets(
321                        arg,
322                        parent_module,
323                        struct_defs,
324                        simple_to_full,
325                        out,
326                    );
327                }
328            }
329
330            _ => {}
331        }
332    }
333
334    fn resolve_struct_name_for_cycle(
335        &self,
336        name: &str,
337        parent_module: Option<&str>,
338        struct_defs: &HashMap<String, StructDef>,
339        simple_to_full: &HashMap<String, Vec<String>>,
340    ) -> Option<String> {
341        if struct_defs.contains_key(name) {
342            return Some(name.to_string());
343        }
344
345        if name.contains('.') {
346            return None;
347        }
348
349        if let Some(candidates) = simple_to_full.get(name) {
350            if candidates.len() == 1 {
351                return Some(candidates[0].clone());
352            }
353
354            if let Some(module) = parent_module {
355                for candidate in candidates {
356                    if let Some((candidate_module, _)) = candidate.rsplit_once('.') {
357                        if candidate_module == module {
358                            return Some(candidate.clone());
359                        }
360                    }
361                }
362            }
363        }
364
365        None
366    }
367
368    pub fn set_imports_by_module(&mut self, map: HashMap<String, ModuleImports>) {
369        self.imports_by_module = map;
370    }
371
372    pub fn take_type_info(&mut self) -> TypeCollection {
373        TypeCollection {
374            expr_types: mem::take(&mut self.expr_types_by_module),
375            variable_types: mem::take(&mut self.variable_types_by_module),
376        }
377    }
378
379    pub fn take_option_coercions(&mut self) -> HashMap<String, HashSet<Span>> {
380        let mut result: HashMap<String, HashSet<Span>> = HashMap::new();
381        let info = mem::take(&mut self.short_circuit_info);
382        for (module, entries) in info {
383            let mut spans: HashSet<Span> = HashSet::new();
384            for (span, entry) in entries {
385                if entry.option_inner.is_some() {
386                    spans.insert(span);
387                }
388            }
389            if !spans.is_empty() {
390                result.insert(module, spans);
391            }
392        }
393
394        result
395    }
396
397    pub fn function_signatures(&self) -> HashMap<String, type_env::FunctionSignature> {
398        self.env.function_signatures()
399    }
400
401    pub fn struct_definitions(&self) -> HashMap<String, StructDef> {
402        self.env.struct_definitions()
403    }
404
405    pub fn enum_definitions(&self) -> HashMap<String, EnumDef> {
406        self.env.enum_definitions()
407    }
408
409    fn register_module_init_locals(&mut self, items: &[Item]) -> Result<()> {
410        let module = match &self.current_module {
411            Some(m) => m.clone(),
412            None => return Ok(()),
413        };
414        let init_name = format!("__init@{}", module);
415        for item in items {
416            if let ItemKind::Function(func) = &item.kind {
417                if func.name == init_name {
418                    for stmt in &func.body {
419                        if let StmtKind::Local {
420                            bindings,
421                            ref mutable,
422                            initializer,
423                        } = &stmt.kind
424                        {
425                            self.check_local_stmt(
426                                bindings.as_slice(),
427                                *mutable,
428                                initializer.as_ref().map(|values| values.as_slice()),
429                            )?;
430                        }
431                    }
432                }
433            }
434        }
435
436        Ok(())
437    }
438
439    pub fn resolve_function_key(&self, name: &str) -> String {
440        if name.contains('.') || name.contains(':') {
441            return name.to_string();
442        }
443
444        if let Some(module) = &self.current_module {
445            if let Some(imports) = self.imports_by_module.get(module) {
446                if let Some(fq) = imports.function_aliases.get(name) {
447                    return fq.clone();
448                }
449            }
450
451            let qualified = format!("{}.{}", module, name);
452            if self.env.lookup_function(&qualified).is_some() {
453                return qualified;
454            }
455
456            if self.env.lookup_function(name).is_some() {
457                return name.to_string();
458            }
459
460            return qualified;
461        }
462
463        name.to_string()
464    }
465
466    pub fn resolve_module_alias(&self, alias: &str) -> Option<String> {
467        if let Some(module) = &self.current_module {
468            if let Some(imports) = self.imports_by_module.get(module) {
469                if let Some(m) = imports.module_aliases.get(alias) {
470                    return Some(m.clone());
471                }
472            }
473        }
474
475        None
476    }
477
478    pub fn register_external_struct(&mut self, mut def: StructDef) -> Result<()> {
479        def.name = self.resolve_type_key(&def.name);
480        for field in &mut def.fields {
481            field.ty = self.canonicalize_type(&field.ty);
482            if let Some(target) = &field.weak_target {
483                field.weak_target = Some(self.canonicalize_type(target));
484            }
485        }
486        self.env.register_struct(&def)
487    }
488
489    pub fn register_external_enum(&mut self, mut def: EnumDef) -> Result<()> {
490        def.name = self.resolve_type_key(&def.name);
491        for variant in &mut def.variants {
492            if let Some(fields) = &mut variant.fields {
493                for field in fields {
494                    *field = self.canonicalize_type(field);
495                }
496            }
497        }
498        self.env.register_enum(&def)
499    }
500
501    pub fn register_external_trait(&mut self, mut def: TraitDef) -> Result<()> {
502        def.name = self.resolve_type_key(&def.name);
503        for method in &mut def.methods {
504            for param in &mut method.params {
505                param.ty = self.canonicalize_type(&param.ty);
506            }
507            if let Some(ret) = method.return_type.clone() {
508                method.return_type = Some(self.canonicalize_type(&ret));
509            }
510        }
511        self.env.register_trait(&def)
512    }
513
514    pub fn register_external_function(
515        &mut self,
516        (name, mut signature): (String, FunctionSignature),
517    ) -> Result<()> {
518        signature.params = signature
519            .params
520            .into_iter()
521            .map(|ty| self.canonicalize_type(&ty))
522            .collect();
523        signature.return_type = self.canonicalize_type(&signature.return_type);
524        let canonical = self.resolve_type_key(&name);
525        self.env.register_or_update_function(canonical, signature)
526    }
527
528    pub fn register_external_impl(&mut self, mut impl_block: ImplBlock) -> Result<()> {
529        impl_block.target_type = self.canonicalize_type(&impl_block.target_type);
530        if let Some(trait_name) = &impl_block.trait_name {
531            impl_block.trait_name = Some(self.resolve_type_key(trait_name));
532        }
533        for method in &mut impl_block.methods {
534            for param in &mut method.params {
535                param.ty = self.canonicalize_type(&param.ty);
536            }
537            if let Some(ret) = method.return_type.clone() {
538                method.return_type = Some(self.canonicalize_type(&ret));
539            }
540        }
541
542        let type_name = match &impl_block.target_type.kind {
543            TypeKind::Named(name) => self.resolve_type_key(name),
544            TypeKind::GenericInstance { name, .. } => self.resolve_type_key(name),
545            _ => {
546                return Err(self.type_error(
547                    "Impl target must be a named type when registering from Rust".to_string(),
548                ))
549            }
550        };
551
552        self.env.register_impl(&impl_block);
553        for method in &impl_block.methods {
554            let params: Vec<Type> = method.params.iter().map(|p| p.ty.clone()).collect();
555            let return_type = method
556                .return_type
557                .clone()
558                .unwrap_or(Type::new(TypeKind::Unit, Span::dummy()));
559            let has_self = method.params.iter().any(|p| p.is_self);
560            let canonical_name = if method.name.contains(':') || method.name.contains('.') {
561                self.resolve_type_key(&method.name)
562            } else if has_self {
563                format!("{}:{}", type_name, method.name)
564            } else {
565                format!("{}.{}", type_name, method.name)
566            };
567            #[cfg(debug_assertions)]
568            eprintln!(
569                "register_external_impl canonical method {} (has_self={})",
570                canonical_name, has_self
571            );
572            let signature = FunctionSignature {
573                params,
574                return_type,
575                is_method: has_self,
576            };
577            self.env
578                .register_or_update_function(canonical_name, signature)?;
579        }
580
581        Ok(())
582    }
583
584    pub fn resolve_type_key(&self, name: &str) -> String {
585        if let Some((head, tail)) = name.split_once('.') {
586            if let Some(module) = &self.current_module {
587                if let Some(imports) = self.imports_by_module.get(module) {
588                    if let Some(real_module) = imports.module_aliases.get(head) {
589                        if tail.is_empty() {
590                            return real_module.clone();
591                        } else {
592                            return format!("{}.{}", real_module, tail);
593                        }
594                    }
595                }
596            }
597
598            return name.to_string();
599        }
600
601        if self.env.lookup_struct(name).is_some()
602            || self.env.lookup_enum(name).is_some()
603            || self.env.lookup_trait(name).is_some()
604        {
605            return name.to_string();
606        }
607
608        if self.env.is_builtin_type(name) {
609            return name.to_string();
610        }
611
612        if let Some(module) = &self.current_module {
613            if let Some(imports) = self.imports_by_module.get(module) {
614                if let Some(fq) = imports.type_aliases.get(name) {
615                    return fq.clone();
616                }
617            }
618
619            return format!("{}.{}", module, name);
620        }
621
622        name.to_string()
623    }
624
625    fn register_type_definition(&mut self, item: &Item) -> Result<()> {
626        match &item.kind {
627            ItemKind::Struct(s) => {
628                let mut s2 = s.clone();
629                if let Some(module) = &self.current_module {
630                    if !s2.name.contains('.') {
631                        s2.name = format!("{}.{}", module, s2.name);
632                    }
633                }
634
635                self.env.register_struct(&s2)?;
636            }
637
638            ItemKind::Enum(e) => {
639                let mut e2 = e.clone();
640                if let Some(module) = &self.current_module {
641                    if !e2.name.contains('.') {
642                        e2.name = format!("{}.{}", module, e2.name);
643                    }
644                }
645
646                self.env.register_enum(&e2)?;
647            }
648
649            ItemKind::Trait(t) => {
650                let mut t2 = t.clone();
651                if let Some(module) = &self.current_module {
652                    if !t2.name.contains('.') {
653                        t2.name = format!("{}.{}", module, t2.name);
654                    }
655                }
656
657                self.env.register_trait(&t2)?;
658            }
659
660            ItemKind::TypeAlias {
661                name,
662                type_params,
663                target,
664            } => {
665                let qname = if let Some(module) = &self.current_module {
666                    if name.contains('.') {
667                        name.clone()
668                    } else {
669                        format!("{}.{}", module, name)
670                    }
671                } else {
672                    name.clone()
673                };
674                self.env
675                    .register_type_alias(qname, type_params.clone(), target.clone())?;
676            }
677
678            _ => {}
679        }
680
681        Ok(())
682    }
683
684    fn type_error(&self, message: String) -> LustError {
685        LustError::TypeError { message }
686    }
687
688    fn type_error_at(&self, message: String, span: Span) -> LustError {
689        if span.start_line > 0 {
690            LustError::TypeErrorWithSpan {
691                message,
692                line: span.start_line,
693                column: span.start_col,
694                module: self.current_module.clone(),
695            }
696        } else {
697            LustError::TypeError { message }
698        }
699    }
700
701    fn types_equal(&self, t1: &Type, t2: &Type) -> bool {
702        t1.kind == t2.kind
703    }
704
705    pub fn canonicalize_type(&self, ty: &Type) -> Type {
706        use crate::ast::TypeKind as TK;
707        match &ty.kind {
708            TK::Named(name) => Type::new(TK::Named(self.resolve_type_key(name)), ty.span),
709            TK::Array(inner) => {
710                Type::new(TK::Array(Box::new(self.canonicalize_type(inner))), ty.span)
711            }
712
713            TK::Tuple(elements) => Type::new(
714                TK::Tuple(elements.iter().map(|t| self.canonicalize_type(t)).collect()),
715                ty.span,
716            ),
717            TK::Option(inner) => {
718                Type::new(TK::Option(Box::new(self.canonicalize_type(inner))), ty.span)
719            }
720
721            TK::Result(ok, err) => Type::new(
722                TK::Result(
723                    Box::new(self.canonicalize_type(ok)),
724                    Box::new(self.canonicalize_type(err)),
725                ),
726                ty.span,
727            ),
728            TK::Map(k, v) => Type::new(
729                TK::Map(
730                    Box::new(self.canonicalize_type(k)),
731                    Box::new(self.canonicalize_type(v)),
732                ),
733                ty.span,
734            ),
735            TK::Ref(inner) => Type::new(TK::Ref(Box::new(self.canonicalize_type(inner))), ty.span),
736            TK::MutRef(inner) => {
737                Type::new(TK::MutRef(Box::new(self.canonicalize_type(inner))), ty.span)
738            }
739
740            TK::Pointer { mutable, pointee } => Type::new(
741                TK::Pointer {
742                    mutable: *mutable,
743                    pointee: Box::new(self.canonicalize_type(pointee)),
744                },
745                ty.span,
746            ),
747            _ => ty.clone(),
748        }
749    }
750
751    fn unify(&self, expected: &Type, actual: &Type) -> Result<()> {
752        let span = if actual.span.start_line > 0 {
753            Some(actual.span)
754        } else if expected.span.start_line > 0 {
755            Some(expected.span)
756        } else {
757            None
758        };
759        self.unify_at(expected, actual, span)
760    }
761
762    fn unify_at(&self, expected: &Type, actual: &Type, span: Option<Span>) -> Result<()> {
763        if matches!(expected.kind, TypeKind::Unknown) || matches!(actual.kind, TypeKind::Unknown) {
764            return Ok(());
765        }
766
767        if matches!(expected.kind, TypeKind::Infer) || matches!(actual.kind, TypeKind::Infer) {
768            return Ok(());
769        }
770
771        match (&expected.kind, &actual.kind) {
772            (TypeKind::Union(expected_types), TypeKind::Union(actual_types)) => {
773                if expected_types.len() != actual_types.len() {
774                    return Err(self.type_error(format!(
775                        "Union types have different number of members: expected {}, got {}",
776                        expected_types.len(),
777                        actual_types.len()
778                    )));
779                }
780
781                for exp_type in expected_types {
782                    let mut found = false;
783                    for act_type in actual_types {
784                        if self.types_equal(exp_type, act_type) {
785                            found = true;
786                            break;
787                        }
788                    }
789
790                    if !found {
791                        return Err(match span {
792                            Some(s) => self.type_error_at(
793                                format!(
794                                    "Union type member '{}' not found in actual union",
795                                    exp_type
796                                ),
797                                s,
798                            ),
799                            None => self.type_error(format!(
800                                "Union type member '{}' not found in actual union",
801                                exp_type
802                            )),
803                        });
804                    }
805                }
806
807                return Ok(());
808            }
809
810            (TypeKind::Union(expected_types), _) => {
811                for union_member in expected_types {
812                    if self.unify(union_member, actual).is_ok() {
813                        return Ok(());
814                    }
815                }
816
817                return Err(match span {
818                    Some(s) => self.type_error_at(
819                        format!("Type '{}' is not compatible with union type", actual),
820                        s,
821                    ),
822                    None => self.type_error(format!(
823                        "Type '{}' is not compatible with union type",
824                        actual
825                    )),
826                });
827            }
828
829            (_, TypeKind::Union(actual_types)) => {
830                for union_member in actual_types {
831                    self.unify(expected, union_member)?;
832                }
833
834                return Ok(());
835            }
836
837            _ => {}
838        }
839
840        match (&expected.kind, &actual.kind) {
841            (TypeKind::Tuple(expected_elems), TypeKind::Tuple(actual_elems)) => {
842                if expected_elems.len() != actual_elems.len() {
843                    return Err(match span {
844                        Some(s) => self.type_error_at(
845                            format!(
846                                "Tuple length mismatch: expected {} element(s), got {}",
847                                expected_elems.len(),
848                                actual_elems.len()
849                            ),
850                            s,
851                        ),
852                        None => self.type_error(format!(
853                            "Tuple length mismatch: expected {} element(s), got {}",
854                            expected_elems.len(),
855                            actual_elems.len()
856                        )),
857                    });
858                }
859
860                for (exp_elem, act_elem) in expected_elems.iter().zip(actual_elems.iter()) {
861                    self.unify(exp_elem, act_elem)?;
862                }
863
864                return Ok(());
865            }
866
867            (TypeKind::Tuple(_), _) | (_, TypeKind::Tuple(_)) => {
868                return Err(match span {
869                    Some(s) => self.type_error_at(
870                        format!("Tuple type is not compatible with type '{}'", actual),
871                        s,
872                    ),
873                    None => self.type_error(format!(
874                        "Tuple type is not compatible with type '{}'",
875                        actual
876                    )),
877                })
878            }
879
880            (TypeKind::Named(name), TypeKind::Array(_))
881            | (TypeKind::Array(_), TypeKind::Named(name))
882                if name == "Array" =>
883            {
884                return Ok(());
885            }
886
887            (TypeKind::Array(exp_el), TypeKind::Array(act_el)) => {
888                if matches!(exp_el.kind, TypeKind::Unknown | TypeKind::Infer)
889                    || matches!(act_el.kind, TypeKind::Unknown | TypeKind::Infer)
890                {
891                    return Ok(());
892                } else {
893                    return self.unify(exp_el, act_el);
894                }
895            }
896
897            (TypeKind::Map(exp_key, exp_value), TypeKind::Map(act_key, act_value)) => {
898                self.unify(exp_key, act_key)?;
899                return self.unify(exp_value, act_value);
900            }
901
902            (TypeKind::Named(name), TypeKind::Option(_))
903            | (TypeKind::Option(_), TypeKind::Named(name))
904                if name == "Option" =>
905            {
906                return Ok(());
907            }
908
909            (TypeKind::Option(exp_inner), TypeKind::Option(act_inner)) => {
910                if matches!(exp_inner.kind, TypeKind::Unknown | TypeKind::Infer)
911                    || matches!(act_inner.kind, TypeKind::Unknown | TypeKind::Infer)
912                {
913                    return Ok(());
914                } else {
915                    return self.unify(exp_inner, act_inner);
916                }
917            }
918
919            (TypeKind::Named(name), TypeKind::Result(_, _))
920            | (TypeKind::Result(_, _), TypeKind::Named(name))
921                if name == "Result" =>
922            {
923                return Ok(());
924            }
925
926            (TypeKind::Result(exp_ok, exp_err), TypeKind::Result(act_ok, act_err)) => {
927                if matches!(exp_ok.kind, TypeKind::Unknown | TypeKind::Infer)
928                    || matches!(act_ok.kind, TypeKind::Unknown | TypeKind::Infer)
929                {
930                    if matches!(exp_err.kind, TypeKind::Unknown | TypeKind::Infer)
931                        || matches!(act_err.kind, TypeKind::Unknown | TypeKind::Infer)
932                    {
933                        return Ok(());
934                    } else {
935                        return self.unify(exp_err, act_err);
936                    }
937                } else {
938                    self.unify(exp_ok, act_ok)?;
939                    return self.unify(exp_err, act_err);
940                }
941            }
942
943            _ => {}
944        }
945
946        if self.types_equal(expected, actual) {
947            Ok(())
948        } else {
949            Err(match span {
950                Some(s) => self.type_error_at(
951                    format!("Type mismatch: expected '{}', got '{}'", expected, actual),
952                    s,
953                ),
954                None => self.type_error(format!(
955                    "Type mismatch: expected '{}', got '{}'",
956                    expected, actual
957                )),
958            })
959        }
960    }
961
962    fn types_compatible(&self, expected: &Type, actual: &Type) -> bool {
963        if matches!(expected.kind, TypeKind::Unknown) || matches!(actual.kind, TypeKind::Unknown) {
964            return true;
965        }
966
967        if matches!(expected.kind, TypeKind::Infer) || matches!(actual.kind, TypeKind::Infer) {
968            return true;
969        }
970
971        match (&expected.kind, &actual.kind) {
972            (TypeKind::Generic(_), TypeKind::Generic(_)) => return true,
973            (TypeKind::Generic(_), _) | (_, TypeKind::Generic(_)) => return true,
974            _ => {}
975        }
976
977        match (&expected.kind, &actual.kind) {
978            (TypeKind::Array(e1), TypeKind::Array(e2)) => {
979                return self.types_compatible(e1, e2);
980            }
981
982            (TypeKind::Named(name), TypeKind::Array(_))
983            | (TypeKind::Array(_), TypeKind::Named(name))
984                if name == "Array" =>
985            {
986                return true;
987            }
988
989            _ => {}
990        }
991
992        match (&expected.kind, &actual.kind) {
993            (TypeKind::Map(k1, v1), TypeKind::Map(k2, v2)) => {
994                return self.types_compatible(k1, k2) && self.types_compatible(v1, v2);
995            }
996
997            _ => {}
998        }
999
1000        match (&expected.kind, &actual.kind) {
1001            (TypeKind::Option(t1), TypeKind::Option(t2)) => {
1002                return self.types_compatible(t1, t2);
1003            }
1004
1005            (TypeKind::Named(name), TypeKind::Option(_))
1006            | (TypeKind::Option(_), TypeKind::Named(name))
1007                if name == "Option" =>
1008            {
1009                return true;
1010            }
1011
1012            _ => {}
1013        }
1014
1015        match (&expected.kind, &actual.kind) {
1016            (TypeKind::Result(ok1, err1), TypeKind::Result(ok2, err2)) => {
1017                return self.types_compatible(ok1, ok2) && self.types_compatible(err1, err2);
1018            }
1019
1020            (TypeKind::Named(name), TypeKind::Result(_, _))
1021            | (TypeKind::Result(_, _), TypeKind::Named(name))
1022                if name == "Result" =>
1023            {
1024                return true;
1025            }
1026
1027            _ => {}
1028        }
1029
1030        match (&expected.kind, &actual.kind) {
1031            (
1032                TypeKind::Function {
1033                    params: p1,
1034                    return_type: r1,
1035                },
1036                TypeKind::Function {
1037                    params: p2,
1038                    return_type: r2,
1039                },
1040            ) => {
1041                if p1.len() != p2.len() {
1042                    return false;
1043                }
1044
1045                for (t1, t2) in p1.iter().zip(p2.iter()) {
1046                    if !self.types_compatible(t1, t2) {
1047                        return false;
1048                    }
1049                }
1050
1051                return self.types_compatible(r1, r2);
1052            }
1053
1054            _ => {}
1055        }
1056
1057        self.types_equal(expected, actual)
1058    }
1059
1060    fn unify_with_bounds(&self, expected: &Type, actual: &Type) -> Result<()> {
1061        if let TypeKind::Generic(type_param) = &expected.kind {
1062            if let Some(trait_names) = self.current_trait_bounds.get(type_param) {
1063                for trait_name in trait_names {
1064                    if !self.env.type_implements_trait(actual, trait_name) {
1065                        return Err(self.type_error(format!(
1066                            "Type '{}' does not implement required trait '{}'",
1067                            actual, trait_name
1068                        )));
1069                    }
1070                }
1071
1072                return Ok(());
1073            }
1074
1075            return Ok(());
1076        }
1077
1078        self.unify(expected, actual)
1079    }
1080
1081    fn record_short_circuit_info(&mut self, span: Span, info: &ShortCircuitInfo) {
1082        let truthy = info.truthy.as_ref().map(|ty| self.canonicalize_type(ty));
1083        let falsy = info.falsy.as_ref().map(|ty| self.canonicalize_type(ty));
1084        let option_inner = info
1085            .option_inner
1086            .as_ref()
1087            .map(|ty| self.canonicalize_type(ty));
1088        let module_key = self.current_module_key();
1089        self.short_circuit_info
1090            .entry(module_key)
1091            .or_default()
1092            .insert(
1093                span,
1094                ShortCircuitInfo {
1095                    truthy,
1096                    falsy,
1097                    option_inner,
1098                },
1099            );
1100    }
1101
1102    fn short_circuit_profile(&self, expr: &Expr, ty: &Type) -> ShortCircuitInfo {
1103        let module_key = self
1104            .current_module
1105            .as_ref()
1106            .map(String::as_str)
1107            .unwrap_or("");
1108        if let Some(module_map) = self.short_circuit_info.get(module_key) {
1109            if let Some(info) = module_map.get(&expr.span) {
1110                return info.clone();
1111            }
1112        }
1113
1114        ShortCircuitInfo {
1115            truthy: if self.type_can_be_truthy(ty) {
1116                Some(self.canonicalize_type(ty))
1117            } else {
1118                None
1119            },
1120            falsy: self.extract_falsy_type(ty),
1121            option_inner: None,
1122        }
1123    }
1124
1125    fn current_module_key(&self) -> String {
1126        self.current_module
1127            .as_ref()
1128            .cloned()
1129            .unwrap_or_else(|| "".to_string())
1130    }
1131
1132    fn clear_option_for_span(&mut self, span: Span) {
1133        let module_key = self.current_module_key();
1134        if let Some(module_map) = self.short_circuit_info.get_mut(&module_key) {
1135            if let Some(info) = module_map.get_mut(&span) {
1136                info.option_inner = None;
1137            }
1138        }
1139    }
1140
1141    fn type_can_be_truthy(&self, ty: &Type) -> bool {
1142        match &ty.kind {
1143            TypeKind::Union(members) => {
1144                members.iter().any(|member| self.type_can_be_truthy(member))
1145            }
1146            TypeKind::Bool => true,
1147            TypeKind::Unknown => true,
1148            _ => true,
1149        }
1150    }
1151
1152    fn type_can_be_falsy(&self, ty: &Type) -> bool {
1153        match &ty.kind {
1154            TypeKind::Union(members) => members.iter().any(|member| self.type_can_be_falsy(member)),
1155            TypeKind::Bool => true,
1156            TypeKind::Unknown => true,
1157            TypeKind::Option(_) => true,
1158            _ => false,
1159        }
1160    }
1161
1162    fn extract_falsy_type(&self, ty: &Type) -> Option<Type> {
1163        match &ty.kind {
1164            TypeKind::Bool => Some(Type::new(TypeKind::Bool, ty.span)),
1165            TypeKind::Unknown => Some(Type::new(TypeKind::Unknown, ty.span)),
1166            TypeKind::Option(inner) => Some(Type::new(
1167                TypeKind::Option(Box::new(self.canonicalize_type(inner))),
1168                ty.span,
1169            )),
1170            TypeKind::Union(members) => {
1171                let mut parts = Vec::new();
1172                for member in members {
1173                    if let Some(part) = self.extract_falsy_type(member) {
1174                        parts.push(part);
1175                    }
1176                }
1177                self.merge_optional_types(parts)
1178            }
1179            _ => None,
1180        }
1181    }
1182
1183    fn merge_optional_types(&self, types: Vec<Type>) -> Option<Type> {
1184        if types.is_empty() {
1185            return None;
1186        }
1187
1188        Some(self.make_union_from_types(types))
1189    }
1190
1191    fn make_union_from_types(&self, types: Vec<Type>) -> Type {
1192        let mut flat: Vec<Type> = Vec::new();
1193        for ty in types {
1194            let canonical = self.canonicalize_type(&ty);
1195            match &canonical.kind {
1196                TypeKind::Union(members) => {
1197                    for member in members {
1198                        self.push_unique_type(&mut flat, member.clone());
1199                    }
1200                }
1201                _ => self.push_unique_type(&mut flat, canonical),
1202            }
1203        }
1204
1205        match flat.len() {
1206            0 => Type::new(TypeKind::Unknown, Self::dummy_span()),
1207            1 => flat.into_iter().next().unwrap(),
1208            _ => Type::new(TypeKind::Union(flat), Self::dummy_span()),
1209        }
1210    }
1211
1212    fn push_unique_type(&self, list: &mut Vec<Type>, candidate: Type) {
1213        if !list
1214            .iter()
1215            .any(|existing| self.types_equal(existing, &candidate))
1216        {
1217            list.push(candidate);
1218        }
1219    }
1220
1221    fn combine_truthy_falsy(&self, truthy: Option<Type>, falsy: Option<Type>) -> Type {
1222        match (truthy, falsy) {
1223            (Some(t), Some(f)) => self.make_union_from_types(vec![t, f]),
1224            (Some(t), None) => t,
1225            (None, Some(f)) => f,
1226            (None, None) => Type::new(TypeKind::Unknown, Self::dummy_span()),
1227        }
1228    }
1229
1230    fn is_bool_like(&self, ty: &Type) -> bool {
1231        match &ty.kind {
1232            TypeKind::Bool => true,
1233            TypeKind::Union(members) => members.iter().all(|member| self.is_bool_like(member)),
1234            _ => false,
1235        }
1236    }
1237
1238    fn option_inner_type<'a>(&self, ty: &'a Type) -> Option<&'a Type> {
1239        match &ty.kind {
1240            TypeKind::Option(inner) => Some(inner.as_ref()),
1241            TypeKind::Union(members) => {
1242                for member in members {
1243                    if let Some(inner) = self.option_inner_type(member) {
1244                        return Some(inner);
1245                    }
1246                }
1247                None
1248            }
1249            _ => None,
1250        }
1251    }
1252
1253    fn should_optionize(&self, left: &Type, right: &Type) -> bool {
1254        self.is_bool_like(left)
1255            && !self.is_bool_like(right)
1256            && self.option_inner_type(right).is_none()
1257    }
1258}