// lust/typechecker/mod.rs

1mod expr_checker;
2mod item_checker;
3mod stmt_checker;
4mod type_env;
5use crate::modules::{LoadedModule, ModuleImports};
6use crate::{
7    ast::*,
8    config::LustConfig,
9    error::{LustError, Result},
10};
11pub(super) use alloc::{
12    boxed::Box,
13    format,
14    string::{String, ToString},
15    vec::Vec,
16};
17use core::mem;
18use hashbrown::{HashMap, HashSet};
19pub use type_env::FunctionSignature;
20pub use type_env::TypeEnv;
/// Static type checker for Lust programs.
///
/// Holds the type environment plus per-module bookkeeping gathered while checking, which
/// callers drain afterwards via `take_type_info` and `take_option_coercions`.
pub struct TypeChecker {
    /// Type environment: registered structs/enums/traits/type aliases, function signatures,
    /// and the scope stack managed with `push_scope`/`pop_scope`.
    env: TypeEnv,
    // NOTE(review): not read in this section of the file; presumably the declared return
    // type of the function body currently being checked — confirm.
    current_function_return_type: Option<Type>,
    // NOTE(review): not read in this section; presumably set while checking a loop body
    // (e.g. to validate `break`/`continue`) — confirm.
    in_loop: bool,
    // NOTE(review): not read in this section; presumably generic parameter name -> concrete
    // type for an instantiation in progress — confirm.
    pending_generic_instances: Option<HashMap<String, Type>>,
    // NOTE(review): not read in this section; presumably the (parameter types, optional
    // return type) that a lambda being checked is expected to have — confirm.
    expected_lambda_signature: Option<(Vec<Type>, Option<Type>)>,
    // NOTE(review): not read in this section; presumably generic parameter name -> names of
    // the traits it is bounded by — confirm.
    current_trait_bounds: HashMap<String, Vec<String>>,
    /// Path of the module currently being checked. `check_program` sets it per module and
    /// resets it to `None` when it finishes; it is used to qualify definition names and to
    /// find that module's import table.
    current_module: Option<String>,
    /// Import tables (function, type and module aliases) keyed by module path; installed
    /// with `set_imports_by_module`.
    imports_by_module: HashMap<String, ModuleImports>,
    /// Expression types grouped by module and keyed by source span; drained by
    /// `take_type_info`.
    expr_types_by_module: HashMap<String, HashMap<Span, Type>>,
    /// Variable types grouped by module and keyed by source span; drained by
    /// `take_type_info`.
    variable_types_by_module: HashMap<String, HashMap<Span, Type>>,
    /// Short-circuit facts grouped by module and keyed by span; drained by
    /// `take_option_coercions`.
    short_circuit_info: HashMap<String, HashMap<Span, ShortCircuitInfo>>,
}
34
/// Type information drained from a `TypeChecker` by `take_type_info`.
pub struct TypeCollection {
    /// Expression types, grouped by module and keyed by source span.
    pub expr_types: HashMap<String, HashMap<Span, Type>>,
    /// Variable types, grouped by module and keyed by source span.
    pub variable_types: HashMap<String, HashMap<Span, Type>>,
}
39
/// Facts recorded for a span in `TypeChecker::short_circuit_info`.
// NOTE(review): only `option_inner` is read in this section of the file (by
// `take_option_coercions`); the meanings given for `truthy`/`falsy` are inferred from
// their names — confirm against the expression checker.
#[derive(Clone, Debug)]
struct ShortCircuitInfo {
    /// Presumably the type of the operand when the expression evaluates truthy.
    truthy: Option<Type>,
    /// Presumably the type of the operand when the expression evaluates falsy.
    falsy: Option<Type>,
    /// When `Some`, the span is reported by `take_option_coercions`; presumably the inner
    /// type of an `Option` operand that gets coerced.
    option_inner: Option<Type>,
}
46
47impl TypeChecker {
48    pub fn new() -> Self {
49        Self::with_config(&LustConfig::default())
50    }
51
52    pub fn with_config(config: &LustConfig) -> Self {
53        Self {
54            env: TypeEnv::with_config(config),
55            current_function_return_type: None,
56            in_loop: false,
57            pending_generic_instances: None,
58            expected_lambda_signature: None,
59            current_trait_bounds: HashMap::new(),
60            current_module: None,
61            imports_by_module: HashMap::new(),
62            expr_types_by_module: HashMap::new(),
63            variable_types_by_module: HashMap::new(),
64            short_circuit_info: HashMap::new(),
65        }
66    }
67
    /// Returns a placeholder span built with all four `Span::new` arguments set to zero.
    ///
    /// NOTE(review): presumably this yields `start_line == 0`, which `type_error_at` treats
    /// as "no location" and reports as a plain `TypeError` — confirm `Span::new`'s
    /// argument order.
    fn dummy_span() -> Span {
        Span::new(0, 0, 0, 0)
    }
71
72    pub fn check_module(&mut self, items: &[Item]) -> Result<()> {
73        for item in items {
74            self.register_type_definition(item)?;
75        }
76
77        self.validate_struct_cycles()?;
78        self.env.push_scope();
79        self.register_module_init_locals(items)?;
80        for item in items {
81            self.check_item(item)?;
82        }
83
84        self.env.pop_scope();
85        Ok(())
86    }
87
88    pub fn check_program(&mut self, modules: &[LoadedModule]) -> Result<()> {
89        for m in modules {
90            self.current_module = Some(m.path.clone());
91            for item in &m.items {
92                self.register_type_definition(item)?;
93            }
94        }
95
96        self.validate_struct_cycles()?;
97        for m in modules {
98            self.current_module = Some(m.path.clone());
99            self.env.push_scope();
100            self.register_module_init_locals(&m.items)?;
101            for item in &m.items {
102                self.check_item(item)?;
103            }
104
105            self.env.pop_scope();
106        }
107
108        self.current_module = None;
109        Ok(())
110    }
111
112    fn validate_struct_cycles(&self) -> Result<()> {
113        use hashbrown::{HashMap, HashSet};
114        let struct_defs = self.env.struct_definitions();
115        if struct_defs.is_empty() {
116            return Ok(());
117        }
118
119        let mut simple_to_full: HashMap<String, Vec<String>> = HashMap::new();
120        for name in struct_defs.keys() {
121            let simple = name.rsplit('.').next().unwrap_or(name).to_string();
122            simple_to_full.entry(simple).or_default().push(name.clone());
123        }
124
125        let mut struct_has_weak: HashMap<String, bool> = HashMap::new();
126        for (name, def) in &struct_defs {
127            let has_weak = def
128                .fields
129                .iter()
130                .any(|field| matches!(field.ownership, FieldOwnership::Weak));
131            struct_has_weak.insert(name.clone(), has_weak);
132        }
133
134        let mut graph: HashMap<String, Vec<String>> = HashMap::new();
135        for (name, def) in &struct_defs {
136            let module_prefix = name.rsplit_once('.').map(|(module, _)| module.to_string());
137            let mut edges: HashSet<String> = HashSet::new();
138            for field in &def.fields {
139                if matches!(field.ownership, FieldOwnership::Weak) {
140                    let target = field.weak_target.as_ref().ok_or_else(|| {
141                        self.type_error(format!(
142                            "Field '{}.{}' is marked as 'ref' but has no target type",
143                            name, field.name
144                        ))
145                    })?;
146                    let target_name = if let TypeKind::Named(inner) = &target.kind {
147                        inner
148                    } else {
149                        return Err(self.type_error(format!(
150                            "Field '{}.{}' uses 'ref' but only struct types are supported",
151                            name, field.name
152                        )));
153                    };
154                    let resolved = self.resolve_struct_name_for_cycle(
155                        target_name.as_str(),
156                        module_prefix.as_deref(),
157                        &struct_defs,
158                        &simple_to_full,
159                    );
160                    if resolved.is_none() {
161                        return Err(self.type_error(format!(
162                            "Field '{}.{}' uses 'ref' but '{}' is not a known struct type",
163                            name, field.name, target_name
164                        )));
165                    }
166
167                    continue;
168                }
169
170                self.collect_strong_struct_targets(
171                    &field.ty,
172                    module_prefix.as_deref(),
173                    &struct_defs,
174                    &simple_to_full,
175                    &mut edges,
176                );
177            }
178
179            graph.insert(name.clone(), edges.into_iter().collect());
180        }
181
182        fn dfs(
183            node: &str,
184            graph: &HashMap<String, Vec<String>>,
185            visited: &mut HashSet<String>,
186            on_stack: &mut HashSet<String>,
187            stack: &mut Vec<String>,
188        ) -> Option<Vec<String>> {
189            visited.insert(node.to_string());
190            on_stack.insert(node.to_string());
191            stack.push(node.to_string());
192            if let Some(neighbors) = graph.get(node) {
193                for neighbor in neighbors {
194                    if !visited.contains(neighbor) {
195                        if let Some(cycle) = dfs(neighbor, graph, visited, on_stack, stack) {
196                            return Some(cycle);
197                        }
198                    } else if on_stack.contains(neighbor) {
199                        if let Some(pos) = stack.iter().position(|n| n == neighbor) {
200                            let mut cycle = stack[pos..].to_vec();
201                            cycle.push(neighbor.clone());
202                            return Some(cycle);
203                        }
204                    }
205                }
206            }
207
208            stack.pop();
209            on_stack.remove(node);
210            None
211        }
212
213        let mut visited: HashSet<String> = HashSet::new();
214        let mut on_stack: HashSet<String> = HashSet::new();
215        let mut stack: Vec<String> = Vec::new();
216        for name in struct_defs.keys() {
217            if !visited.contains(name) {
218                if let Some(cycle) = dfs(name, &graph, &mut visited, &mut on_stack, &mut stack) {
219                    let contains_weak = cycle
220                        .iter()
221                        .any(|node| struct_has_weak.get(node).copied().unwrap_or(false));
222                    if contains_weak {
223                        continue;
224                    }
225
226                    let description = cycle.join(" -> ");
227                    return Err(self.type_error(format!(
228                        "Strong ownership cycle detected: {}. Mark at least one field as 'ref' to break the cycle.",
229                        description
230                    )));
231                }
232            }
233        }
234
235        Ok(())
236    }
237
238    fn collect_strong_struct_targets(
239        &self,
240        ty: &Type,
241        parent_module: Option<&str>,
242        struct_defs: &HashMap<String, StructDef>,
243        simple_to_full: &HashMap<String, Vec<String>>,
244        out: &mut HashSet<String>,
245    ) {
246        match &ty.kind {
247            TypeKind::Named(name) => {
248                if let Some(resolved) = self.resolve_struct_name_for_cycle(
249                    name,
250                    parent_module,
251                    struct_defs,
252                    simple_to_full,
253                ) {
254                    out.insert(resolved);
255                }
256            }
257
258            TypeKind::Array(inner)
259            | TypeKind::Ref(inner)
260            | TypeKind::MutRef(inner)
261            | TypeKind::Option(inner) => {
262                self.collect_strong_struct_targets(
263                    inner,
264                    parent_module,
265                    struct_defs,
266                    simple_to_full,
267                    out,
268                );
269            }
270
271            TypeKind::Map(key, value) => {
272                self.collect_strong_struct_targets(
273                    key,
274                    parent_module,
275                    struct_defs,
276                    simple_to_full,
277                    out,
278                );
279                self.collect_strong_struct_targets(
280                    value,
281                    parent_module,
282                    struct_defs,
283                    simple_to_full,
284                    out,
285                );
286            }
287
288            TypeKind::Tuple(elements) | TypeKind::Union(elements) => {
289                for element in elements {
290                    self.collect_strong_struct_targets(
291                        element,
292                        parent_module,
293                        struct_defs,
294                        simple_to_full,
295                        out,
296                    );
297                }
298            }
299
300            TypeKind::Result(ok, err) => {
301                self.collect_strong_struct_targets(
302                    ok,
303                    parent_module,
304                    struct_defs,
305                    simple_to_full,
306                    out,
307                );
308                self.collect_strong_struct_targets(
309                    err,
310                    parent_module,
311                    struct_defs,
312                    simple_to_full,
313                    out,
314                );
315            }
316
317            TypeKind::GenericInstance { type_args, .. } => {
318                for arg in type_args {
319                    self.collect_strong_struct_targets(
320                        arg,
321                        parent_module,
322                        struct_defs,
323                        simple_to_full,
324                        out,
325                    );
326                }
327            }
328
329            _ => {}
330        }
331    }
332
333    fn resolve_struct_name_for_cycle(
334        &self,
335        name: &str,
336        parent_module: Option<&str>,
337        struct_defs: &HashMap<String, StructDef>,
338        simple_to_full: &HashMap<String, Vec<String>>,
339    ) -> Option<String> {
340        if struct_defs.contains_key(name) {
341            return Some(name.to_string());
342        }
343
344        if name.contains('.') {
345            return None;
346        }
347
348        if let Some(candidates) = simple_to_full.get(name) {
349            if candidates.len() == 1 {
350                return Some(candidates[0].clone());
351            }
352
353            if let Some(module) = parent_module {
354                for candidate in candidates {
355                    if let Some((candidate_module, _)) = candidate.rsplit_once('.') {
356                        if candidate_module == module {
357                            return Some(candidate.clone());
358                        }
359                    }
360                }
361            }
362        }
363
364        None
365    }
366
367    pub fn set_imports_by_module(&mut self, map: HashMap<String, ModuleImports>) {
368        self.imports_by_module = map;
369    }
370
371    pub fn take_type_info(&mut self) -> TypeCollection {
372        TypeCollection {
373            expr_types: mem::take(&mut self.expr_types_by_module),
374            variable_types: mem::take(&mut self.variable_types_by_module),
375        }
376    }
377
378    pub fn take_option_coercions(&mut self) -> HashMap<String, HashSet<Span>> {
379        let mut result: HashMap<String, HashSet<Span>> = HashMap::new();
380        let info = mem::take(&mut self.short_circuit_info);
381        for (module, entries) in info {
382            let mut spans: HashSet<Span> = HashSet::new();
383            for (span, entry) in entries {
384                if entry.option_inner.is_some() {
385                    spans.insert(span);
386                }
387            }
388            if !spans.is_empty() {
389                result.insert(module, spans);
390            }
391        }
392
393        result
394    }
395
396    pub fn function_signatures(&self) -> HashMap<String, type_env::FunctionSignature> {
397        self.env.function_signatures()
398    }
399
400    pub fn struct_definitions(&self) -> HashMap<String, StructDef> {
401        self.env.struct_definitions()
402    }
403
404    pub fn enum_definitions(&self) -> HashMap<String, EnumDef> {
405        self.env.enum_definitions()
406    }
407
408    fn register_module_init_locals(&mut self, items: &[Item]) -> Result<()> {
409        let module = match &self.current_module {
410            Some(m) => m.clone(),
411            None => return Ok(()),
412        };
413        let init_name = format!("__init@{}", module);
414        for item in items {
415            if let ItemKind::Function(func) = &item.kind {
416                if func.name == init_name {
417                    for stmt in &func.body {
418                        if let StmtKind::Local {
419                            bindings,
420                            ref mutable,
421                            initializer,
422                        } = &stmt.kind
423                        {
424                            self.check_local_stmt(
425                                bindings.as_slice(),
426                                *mutable,
427                                initializer.as_ref().map(|values| values.as_slice()),
428                            )?;
429                        }
430                    }
431                }
432            }
433        }
434
435        Ok(())
436    }
437
438    pub fn resolve_function_key(&self, name: &str) -> String {
439        if name.contains('.') || name.contains(':') {
440            return name.to_string();
441        }
442
443        if let Some(module) = &self.current_module {
444            if let Some(imports) = self.imports_by_module.get(module) {
445                if let Some(fq) = imports.function_aliases.get(name) {
446                    return fq.clone();
447                }
448            }
449
450            let qualified = format!("{}.{}", module, name);
451            if self.env.lookup_function(&qualified).is_some() {
452                return qualified;
453            }
454
455            if self.env.lookup_function(name).is_some() {
456                return name.to_string();
457            }
458
459            return qualified;
460        }
461
462        name.to_string()
463    }
464
465    pub fn resolve_module_alias(&self, alias: &str) -> Option<String> {
466        if let Some(module) = &self.current_module {
467            if let Some(imports) = self.imports_by_module.get(module) {
468                if let Some(m) = imports.module_aliases.get(alias) {
469                    return Some(m.clone());
470                }
471            }
472        }
473
474        None
475    }
476
477    pub fn resolve_type_key(&self, name: &str) -> String {
478        if let Some((head, tail)) = name.split_once('.') {
479            if let Some(module) = &self.current_module {
480                if let Some(imports) = self.imports_by_module.get(module) {
481                    if let Some(real_module) = imports.module_aliases.get(head) {
482                        if tail.is_empty() {
483                            return real_module.clone();
484                        } else {
485                            return format!("{}.{}", real_module, tail);
486                        }
487                    }
488                }
489            }
490
491            return name.to_string();
492        }
493
494        if self.env.lookup_struct(name).is_some()
495            || self.env.lookup_enum(name).is_some()
496            || self.env.lookup_trait(name).is_some()
497        {
498            return name.to_string();
499        }
500
501        if self.env.is_builtin_type(name) {
502            return name.to_string();
503        }
504
505        if let Some(module) = &self.current_module {
506            if let Some(imports) = self.imports_by_module.get(module) {
507                if let Some(fq) = imports.type_aliases.get(name) {
508                    return fq.clone();
509                }
510            }
511
512            return format!("{}.{}", module, name);
513        }
514
515        name.to_string()
516    }
517
518    fn register_type_definition(&mut self, item: &Item) -> Result<()> {
519        match &item.kind {
520            ItemKind::Struct(s) => {
521                let mut s2 = s.clone();
522                if let Some(module) = &self.current_module {
523                    if !s2.name.contains('.') {
524                        s2.name = format!("{}.{}", module, s2.name);
525                    }
526                }
527
528                self.env.register_struct(&s2)?;
529            }
530
531            ItemKind::Enum(e) => {
532                let mut e2 = e.clone();
533                if let Some(module) = &self.current_module {
534                    if !e2.name.contains('.') {
535                        e2.name = format!("{}.{}", module, e2.name);
536                    }
537                }
538
539                self.env.register_enum(&e2)?;
540            }
541
542            ItemKind::Trait(t) => {
543                let mut t2 = t.clone();
544                if let Some(module) = &self.current_module {
545                    if !t2.name.contains('.') {
546                        t2.name = format!("{}.{}", module, t2.name);
547                    }
548                }
549
550                self.env.register_trait(&t2)?;
551            }
552
553            ItemKind::TypeAlias {
554                name,
555                type_params,
556                target,
557            } => {
558                let qname = if let Some(module) = &self.current_module {
559                    if name.contains('.') {
560                        name.clone()
561                    } else {
562                        format!("{}.{}", module, name)
563                    }
564                } else {
565                    name.clone()
566                };
567                self.env
568                    .register_type_alias(qname, type_params.clone(), target.clone())?;
569            }
570
571            _ => {}
572        }
573
574        Ok(())
575    }
576
    /// Builds a `LustError::TypeError` carrying `message` and no source location.
    fn type_error(&self, message: String) -> LustError {
        LustError::TypeError { message }
    }
580
581    fn type_error_at(&self, message: String, span: Span) -> LustError {
582        if span.start_line > 0 {
583            LustError::TypeErrorWithSpan {
584                message,
585                line: span.start_line,
586                column: span.start_col,
587                module: self.current_module.clone(),
588            }
589        } else {
590            LustError::TypeError { message }
591        }
592    }
593
594    fn types_equal(&self, t1: &Type, t2: &Type) -> bool {
595        t1.kind == t2.kind
596    }
597
598    pub fn canonicalize_type(&self, ty: &Type) -> Type {
599        use crate::ast::TypeKind as TK;
600        match &ty.kind {
601            TK::Named(name) => Type::new(TK::Named(self.resolve_type_key(name)), ty.span),
602            TK::Array(inner) => {
603                Type::new(TK::Array(Box::new(self.canonicalize_type(inner))), ty.span)
604            }
605
606            TK::Tuple(elements) => Type::new(
607                TK::Tuple(elements.iter().map(|t| self.canonicalize_type(t)).collect()),
608                ty.span,
609            ),
610            TK::Option(inner) => {
611                Type::new(TK::Option(Box::new(self.canonicalize_type(inner))), ty.span)
612            }
613
614            TK::Result(ok, err) => Type::new(
615                TK::Result(
616                    Box::new(self.canonicalize_type(ok)),
617                    Box::new(self.canonicalize_type(err)),
618                ),
619                ty.span,
620            ),
621            TK::Map(k, v) => Type::new(
622                TK::Map(
623                    Box::new(self.canonicalize_type(k)),
624                    Box::new(self.canonicalize_type(v)),
625                ),
626                ty.span,
627            ),
628            TK::Ref(inner) => Type::new(TK::Ref(Box::new(self.canonicalize_type(inner))), ty.span),
629            TK::MutRef(inner) => {
630                Type::new(TK::MutRef(Box::new(self.canonicalize_type(inner))), ty.span)
631            }
632
633            TK::Pointer { mutable, pointee } => Type::new(
634                TK::Pointer {
635                    mutable: *mutable,
636                    pointee: Box::new(self.canonicalize_type(pointee)),
637                },
638                ty.span,
639            ),
640            _ => ty.clone(),
641        }
642    }
643
644    fn unify(&self, expected: &Type, actual: &Type) -> Result<()> {
645        let span = if actual.span.start_line > 0 {
646            Some(actual.span)
647        } else if expected.span.start_line > 0 {
648            Some(expected.span)
649        } else {
650            None
651        };
652        self.unify_at(expected, actual, span)
653    }
654
655    fn unify_at(&self, expected: &Type, actual: &Type, span: Option<Span>) -> Result<()> {
656        if matches!(expected.kind, TypeKind::Unknown) || matches!(actual.kind, TypeKind::Unknown) {
657            return Ok(());
658        }
659
660        if matches!(expected.kind, TypeKind::Infer) || matches!(actual.kind, TypeKind::Infer) {
661            return Ok(());
662        }
663
664        match (&expected.kind, &actual.kind) {
665            (TypeKind::Union(expected_types), TypeKind::Union(actual_types)) => {
666                if expected_types.len() != actual_types.len() {
667                    return Err(self.type_error(format!(
668                        "Union types have different number of members: expected {}, got {}",
669                        expected_types.len(),
670                        actual_types.len()
671                    )));
672                }
673
674                for exp_type in expected_types {
675                    let mut found = false;
676                    for act_type in actual_types {
677                        if self.types_equal(exp_type, act_type) {
678                            found = true;
679                            break;
680                        }
681                    }
682
683                    if !found {
684                        return Err(match span {
685                            Some(s) => self.type_error_at(
686                                format!(
687                                    "Union type member '{}' not found in actual union",
688                                    exp_type
689                                ),
690                                s,
691                            ),
692                            None => self.type_error(format!(
693                                "Union type member '{}' not found in actual union",
694                                exp_type
695                            )),
696                        });
697                    }
698                }
699
700                return Ok(());
701            }
702
703            (TypeKind::Union(expected_types), _) => {
704                for union_member in expected_types {
705                    if self.unify(union_member, actual).is_ok() {
706                        return Ok(());
707                    }
708                }
709
710                return Err(match span {
711                    Some(s) => self.type_error_at(
712                        format!("Type '{}' is not compatible with union type", actual),
713                        s,
714                    ),
715                    None => self.type_error(format!(
716                        "Type '{}' is not compatible with union type",
717                        actual
718                    )),
719                });
720            }
721
722            (_, TypeKind::Union(actual_types)) => {
723                for union_member in actual_types {
724                    self.unify(expected, union_member)?;
725                }
726
727                return Ok(());
728            }
729
730            _ => {}
731        }
732
733        match (&expected.kind, &actual.kind) {
734            (TypeKind::Tuple(expected_elems), TypeKind::Tuple(actual_elems)) => {
735                if expected_elems.len() != actual_elems.len() {
736                    return Err(match span {
737                        Some(s) => self.type_error_at(
738                            format!(
739                                "Tuple length mismatch: expected {} element(s), got {}",
740                                expected_elems.len(),
741                                actual_elems.len()
742                            ),
743                            s,
744                        ),
745                        None => self.type_error(format!(
746                            "Tuple length mismatch: expected {} element(s), got {}",
747                            expected_elems.len(),
748                            actual_elems.len()
749                        )),
750                    });
751                }
752
753                for (exp_elem, act_elem) in expected_elems.iter().zip(actual_elems.iter()) {
754                    self.unify(exp_elem, act_elem)?;
755                }
756
757                return Ok(());
758            }
759
760            (TypeKind::Tuple(_), _) | (_, TypeKind::Tuple(_)) => {
761                return Err(match span {
762                    Some(s) => self.type_error_at(
763                        format!("Tuple type is not compatible with type '{}'", actual),
764                        s,
765                    ),
766                    None => self.type_error(format!(
767                        "Tuple type is not compatible with type '{}'",
768                        actual
769                    )),
770                })
771            }
772
773            (TypeKind::Named(name), TypeKind::Array(_))
774            | (TypeKind::Array(_), TypeKind::Named(name))
775                if name == "Array" =>
776            {
777                return Ok(());
778            }
779
780            (TypeKind::Array(exp_el), TypeKind::Array(act_el)) => {
781                if matches!(exp_el.kind, TypeKind::Unknown | TypeKind::Infer)
782                    || matches!(act_el.kind, TypeKind::Unknown | TypeKind::Infer)
783                {
784                    return Ok(());
785                } else {
786                    return self.unify(exp_el, act_el);
787                }
788            }
789
790            (TypeKind::Map(exp_key, exp_value), TypeKind::Map(act_key, act_value)) => {
791                self.unify(exp_key, act_key)?;
792                return self.unify(exp_value, act_value);
793            }
794
795            (TypeKind::Named(name), TypeKind::Option(_))
796            | (TypeKind::Option(_), TypeKind::Named(name))
797                if name == "Option" =>
798            {
799                return Ok(());
800            }
801
802            (TypeKind::Option(exp_inner), TypeKind::Option(act_inner)) => {
803                if matches!(exp_inner.kind, TypeKind::Unknown | TypeKind::Infer)
804                    || matches!(act_inner.kind, TypeKind::Unknown | TypeKind::Infer)
805                {
806                    return Ok(());
807                } else {
808                    return self.unify(exp_inner, act_inner);
809                }
810            }
811
812            (TypeKind::Named(name), TypeKind::Result(_, _))
813            | (TypeKind::Result(_, _), TypeKind::Named(name))
814                if name == "Result" =>
815            {
816                return Ok(());
817            }
818
819            (TypeKind::Result(exp_ok, exp_err), TypeKind::Result(act_ok, act_err)) => {
820                if matches!(exp_ok.kind, TypeKind::Unknown | TypeKind::Infer)
821                    || matches!(act_ok.kind, TypeKind::Unknown | TypeKind::Infer)
822                {
823                    if matches!(exp_err.kind, TypeKind::Unknown | TypeKind::Infer)
824                        || matches!(act_err.kind, TypeKind::Unknown | TypeKind::Infer)
825                    {
826                        return Ok(());
827                    } else {
828                        return self.unify(exp_err, act_err);
829                    }
830                } else {
831                    self.unify(exp_ok, act_ok)?;
832                    return self.unify(exp_err, act_err);
833                }
834            }
835
836            _ => {}
837        }
838
839        if self.types_equal(expected, actual) {
840            Ok(())
841        } else {
842            Err(match span {
843                Some(s) => self.type_error_at(
844                    format!("Type mismatch: expected '{}', got '{}'", expected, actual),
845                    s,
846                ),
847                None => self.type_error(format!(
848                    "Type mismatch: expected '{}', got '{}'",
849                    expected, actual
850                )),
851            })
852        }
853    }
854
855    fn types_compatible(&self, expected: &Type, actual: &Type) -> bool {
856        if matches!(expected.kind, TypeKind::Unknown) || matches!(actual.kind, TypeKind::Unknown) {
857            return true;
858        }
859
860        if matches!(expected.kind, TypeKind::Infer) || matches!(actual.kind, TypeKind::Infer) {
861            return true;
862        }
863
864        match (&expected.kind, &actual.kind) {
865            (TypeKind::Generic(_), TypeKind::Generic(_)) => return true,
866            (TypeKind::Generic(_), _) | (_, TypeKind::Generic(_)) => return true,
867            _ => {}
868        }
869
870        match (&expected.kind, &actual.kind) {
871            (TypeKind::Array(e1), TypeKind::Array(e2)) => {
872                return self.types_compatible(e1, e2);
873            }
874
875            (TypeKind::Named(name), TypeKind::Array(_))
876            | (TypeKind::Array(_), TypeKind::Named(name))
877                if name == "Array" =>
878            {
879                return true;
880            }
881
882            _ => {}
883        }
884
885        match (&expected.kind, &actual.kind) {
886            (TypeKind::Map(k1, v1), TypeKind::Map(k2, v2)) => {
887                return self.types_compatible(k1, k2) && self.types_compatible(v1, v2);
888            }
889
890            _ => {}
891        }
892
893        match (&expected.kind, &actual.kind) {
894            (TypeKind::Option(t1), TypeKind::Option(t2)) => {
895                return self.types_compatible(t1, t2);
896            }
897
898            (TypeKind::Named(name), TypeKind::Option(_))
899            | (TypeKind::Option(_), TypeKind::Named(name))
900                if name == "Option" =>
901            {
902                return true;
903            }
904
905            _ => {}
906        }
907
908        match (&expected.kind, &actual.kind) {
909            (TypeKind::Result(ok1, err1), TypeKind::Result(ok2, err2)) => {
910                return self.types_compatible(ok1, ok2) && self.types_compatible(err1, err2);
911            }
912
913            (TypeKind::Named(name), TypeKind::Result(_, _))
914            | (TypeKind::Result(_, _), TypeKind::Named(name))
915                if name == "Result" =>
916            {
917                return true;
918            }
919
920            _ => {}
921        }
922
923        match (&expected.kind, &actual.kind) {
924            (
925                TypeKind::Function {
926                    params: p1,
927                    return_type: r1,
928                },
929                TypeKind::Function {
930                    params: p2,
931                    return_type: r2,
932                },
933            ) => {
934                if p1.len() != p2.len() {
935                    return false;
936                }
937
938                for (t1, t2) in p1.iter().zip(p2.iter()) {
939                    if !self.types_compatible(t1, t2) {
940                        return false;
941                    }
942                }
943
944                return self.types_compatible(r1, r2);
945            }
946
947            _ => {}
948        }
949
950        self.types_equal(expected, actual)
951    }
952
953    fn unify_with_bounds(&self, expected: &Type, actual: &Type) -> Result<()> {
954        if let TypeKind::Generic(type_param) = &expected.kind {
955            if let Some(trait_names) = self.current_trait_bounds.get(type_param) {
956                for trait_name in trait_names {
957                    if !self.env.type_implements_trait(actual, trait_name) {
958                        return Err(self.type_error(format!(
959                            "Type '{}' does not implement required trait '{}'",
960                            actual, trait_name
961                        )));
962                    }
963                }
964
965                return Ok(());
966            }
967
968            return Ok(());
969        }
970
971        self.unify(expected, actual)
972    }
973
974    fn record_short_circuit_info(&mut self, span: Span, info: &ShortCircuitInfo) {
975        let truthy = info.truthy.as_ref().map(|ty| self.canonicalize_type(ty));
976        let falsy = info.falsy.as_ref().map(|ty| self.canonicalize_type(ty));
977        let option_inner = info
978            .option_inner
979            .as_ref()
980            .map(|ty| self.canonicalize_type(ty));
981        let module_key = self.current_module_key();
982        self.short_circuit_info
983            .entry(module_key)
984            .or_default()
985            .insert(
986                span,
987                ShortCircuitInfo {
988                    truthy,
989                    falsy,
990                    option_inner,
991                },
992            );
993    }
994
995    fn short_circuit_profile(&self, expr: &Expr, ty: &Type) -> ShortCircuitInfo {
996        let module_key = self
997            .current_module
998            .as_ref()
999            .map(String::as_str)
1000            .unwrap_or("");
1001        if let Some(module_map) = self.short_circuit_info.get(module_key) {
1002            if let Some(info) = module_map.get(&expr.span) {
1003                return info.clone();
1004            }
1005        }
1006
1007        ShortCircuitInfo {
1008            truthy: if self.type_can_be_truthy(ty) {
1009                Some(self.canonicalize_type(ty))
1010            } else {
1011                None
1012            },
1013            falsy: self.extract_falsy_type(ty),
1014            option_inner: None,
1015        }
1016    }
1017
1018    fn current_module_key(&self) -> String {
1019        self.current_module
1020            .as_ref()
1021            .cloned()
1022            .unwrap_or_else(|| "".to_string())
1023    }
1024
1025    fn clear_option_for_span(&mut self, span: Span) {
1026        let module_key = self.current_module_key();
1027        if let Some(module_map) = self.short_circuit_info.get_mut(&module_key) {
1028            if let Some(info) = module_map.get_mut(&span) {
1029                info.option_inner = None;
1030            }
1031        }
1032    }
1033
1034    fn type_can_be_truthy(&self, ty: &Type) -> bool {
1035        match &ty.kind {
1036            TypeKind::Union(members) => {
1037                members.iter().any(|member| self.type_can_be_truthy(member))
1038            }
1039            TypeKind::Bool => true,
1040            TypeKind::Unknown => true,
1041            _ => true,
1042        }
1043    }
1044
1045    fn type_can_be_falsy(&self, ty: &Type) -> bool {
1046        match &ty.kind {
1047            TypeKind::Union(members) => members.iter().any(|member| self.type_can_be_falsy(member)),
1048            TypeKind::Bool => true,
1049            TypeKind::Unknown => true,
1050            TypeKind::Option(_) => true,
1051            _ => false,
1052        }
1053    }
1054
1055    fn extract_falsy_type(&self, ty: &Type) -> Option<Type> {
1056        match &ty.kind {
1057            TypeKind::Bool => Some(Type::new(TypeKind::Bool, ty.span)),
1058            TypeKind::Unknown => Some(Type::new(TypeKind::Unknown, ty.span)),
1059            TypeKind::Option(inner) => Some(Type::new(
1060                TypeKind::Option(Box::new(self.canonicalize_type(inner))),
1061                ty.span,
1062            )),
1063            TypeKind::Union(members) => {
1064                let mut parts = Vec::new();
1065                for member in members {
1066                    if let Some(part) = self.extract_falsy_type(member) {
1067                        parts.push(part);
1068                    }
1069                }
1070                self.merge_optional_types(parts)
1071            }
1072            _ => None,
1073        }
1074    }
1075
1076    fn merge_optional_types(&self, types: Vec<Type>) -> Option<Type> {
1077        if types.is_empty() {
1078            return None;
1079        }
1080
1081        Some(self.make_union_from_types(types))
1082    }
1083
1084    fn make_union_from_types(&self, types: Vec<Type>) -> Type {
1085        let mut flat: Vec<Type> = Vec::new();
1086        for ty in types {
1087            let canonical = self.canonicalize_type(&ty);
1088            match &canonical.kind {
1089                TypeKind::Union(members) => {
1090                    for member in members {
1091                        self.push_unique_type(&mut flat, member.clone());
1092                    }
1093                }
1094                _ => self.push_unique_type(&mut flat, canonical),
1095            }
1096        }
1097
1098        match flat.len() {
1099            0 => Type::new(TypeKind::Unknown, Self::dummy_span()),
1100            1 => flat.into_iter().next().unwrap(),
1101            _ => Type::new(TypeKind::Union(flat), Self::dummy_span()),
1102        }
1103    }
1104
1105    fn push_unique_type(&self, list: &mut Vec<Type>, candidate: Type) {
1106        if !list
1107            .iter()
1108            .any(|existing| self.types_equal(existing, &candidate))
1109        {
1110            list.push(candidate);
1111        }
1112    }
1113
1114    fn combine_truthy_falsy(&self, truthy: Option<Type>, falsy: Option<Type>) -> Type {
1115        match (truthy, falsy) {
1116            (Some(t), Some(f)) => self.make_union_from_types(vec![t, f]),
1117            (Some(t), None) => t,
1118            (None, Some(f)) => f,
1119            (None, None) => Type::new(TypeKind::Unknown, Self::dummy_span()),
1120        }
1121    }
1122
1123    fn is_bool_like(&self, ty: &Type) -> bool {
1124        match &ty.kind {
1125            TypeKind::Bool => true,
1126            TypeKind::Union(members) => members.iter().all(|member| self.is_bool_like(member)),
1127            _ => false,
1128        }
1129    }
1130
1131    fn option_inner_type<'a>(&self, ty: &'a Type) -> Option<&'a Type> {
1132        match &ty.kind {
1133            TypeKind::Option(inner) => Some(inner.as_ref()),
1134            TypeKind::Union(members) => {
1135                for member in members {
1136                    if let Some(inner) = self.option_inner_type(member) {
1137                        return Some(inner);
1138                    }
1139                }
1140                None
1141            }
1142            _ => None,
1143        }
1144    }
1145
1146    fn should_optionize(&self, left: &Type, right: &Type) -> bool {
1147        self.is_bool_like(left)
1148            && !self.is_bool_like(right)
1149            && self.option_inner_type(right).is_none()
1150    }
1151}