Skip to main content

lisette_syntax/
types.rs

1use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
2use std::cell::{OnceCell, RefCell};
3use std::rc::Rc;
4
5use ecow::EcoString;
6
/// Return the last dot-separated segment of a qualified identifier.
///
/// Identifiers that contain no `.` are returned unchanged; a trailing `.`
/// yields the empty string.
///
/// `"prelude.Option"` → `"Option"`, `"**nominal.int"` → `"int"`, `"foo"` → `"foo"`
pub fn unqualified_name(id: &str) -> &str {
    match id.rfind('.') {
        // `.` is a single byte, so `dot + 1` is always a char boundary.
        Some(dot) => &id[dot + 1..],
        None => id,
    }
}
13
14/// type param name -> type variable
15pub type SubstitutionMap = HashMap<EcoString, Type>;
16
17pub fn substitute(ty: &Type, map: &HashMap<EcoString, Type>) -> Type {
18    if map.is_empty() {
19        return ty.clone();
20    }
21    match ty {
22        Type::Parameter(name) => map.get(name).cloned().unwrap_or_else(|| ty.clone()),
23        Type::Constructor {
24            id,
25            params,
26            underlying_ty: underlying,
27        } => Type::Constructor {
28            id: id.clone(),
29            params: params.iter().map(|p| substitute(p, map)).collect(),
30            underlying_ty: underlying.as_ref().map(|u| Box::new(substitute(u, map))),
31        },
32        Type::Function {
33            params,
34            param_mutability,
35            bounds,
36            return_type,
37        } => Type::Function {
38            params: params.iter().map(|p| substitute(p, map)).collect(),
39            param_mutability: param_mutability.clone(),
40            bounds: bounds
41                .iter()
42                .map(|b| Bound {
43                    param_name: b.param_name.clone(),
44                    generic: substitute(&b.generic, map),
45                    ty: substitute(&b.ty, map),
46                })
47                .collect(),
48            return_type: Box::new(substitute(return_type, map)),
49        },
50        Type::Variable(_) | Type::Error => ty.clone(),
51        Type::Forall { vars, body } => {
52            let has_overlap = map.keys().any(|k| vars.contains(k));
53            let substituted_body = if has_overlap {
54                let filtered_map: HashMap<EcoString, Type> = map
55                    .iter()
56                    .filter(|(k, _)| !vars.contains(*k))
57                    .map(|(k, v)| (k.clone(), v.clone()))
58                    .collect();
59                substitute(body, &filtered_map)
60            } else {
61                substitute(body, map)
62            };
63            Type::Forall {
64                vars: vars.clone(),
65                body: Box::new(substituted_body),
66            }
67        }
68        Type::Tuple(elements) => Type::Tuple(elements.iter().map(|e| substitute(e, map)).collect()),
69        Type::Never => ty.clone(),
70    }
71}
72
73#[derive(Debug, Clone, PartialEq)]
74pub struct Bound {
75    pub param_name: EcoString,
76    pub generic: Type,
77    pub ty: Type,
78}
79
80#[derive(Clone)]
81pub enum TypeVariableState {
82    Unbound { id: i32, hint: Option<EcoString> },
83    Link(Type),
84}
85
86impl TypeVariableState {
87    pub fn is_unbound(&self) -> bool {
88        matches!(self, TypeVariableState::Unbound { .. })
89    }
90}
91
92impl std::fmt::Debug for TypeVariableState {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        match self {
95            TypeVariableState::Unbound { id, hint } => match hint {
96                Some(name) => write!(f, "{}", name),
97                None => write!(f, "{}", id),
98            },
99            TypeVariableState::Link(ty) => write!(f, "{:?}", ty),
100        }
101    }
102}
103
104impl PartialEq for TypeVariableState {
105    fn eq(&self, other: &Self) -> bool {
106        match (self, other) {
107            (
108                TypeVariableState::Unbound { id: id1, .. },
109                TypeVariableState::Unbound { id: id2, .. },
110            ) => id1 == id2,
111            (TypeVariableState::Link(ty1), TypeVariableState::Link(ty2)) => ty1 == ty2,
112            _ => false,
113        }
114    }
115}
116
117#[derive(Clone)]
118pub enum Type {
119    Constructor {
120        id: EcoString,
121        params: Vec<Type>,
122        underlying_ty: Option<Box<Type>>,
123    },
124
125    Function {
126        params: Vec<Type>,
127        param_mutability: Vec<bool>,
128        bounds: Vec<Bound>,
129        return_type: Box<Type>,
130    },
131
132    Variable(Rc<RefCell<TypeVariableState>>),
133
134    Forall {
135        vars: Vec<EcoString>,
136        body: Box<Type>,
137    },
138
139    Parameter(EcoString),
140
141    Never,
142
143    Tuple(Vec<Type>),
144
145    /// Poison type returned after an error has been reported.
146    /// Unifies with everything silently, preventing cascading diagnostics.
147    Error,
148}
149
150impl std::fmt::Debug for Type {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        match self {
153            Type::Constructor { id, params, .. } => f
154                .debug_struct("Constructor")
155                .field("id", id)
156                .field("params", params)
157                .finish(),
158            Type::Function {
159                params,
160                param_mutability,
161                bounds,
162                return_type,
163            } => {
164                let mut s = f.debug_struct("Function");
165                s.field("params", params);
166                if param_mutability.iter().any(|m| *m) {
167                    s.field("param_mutability", param_mutability);
168                }
169                s.field("bounds", bounds)
170                    .field("return_type", return_type)
171                    .finish()
172            }
173            Type::Variable(type_var) => f
174                .debug_tuple("Variable")
175                .field(&*type_var.borrow())
176                .finish(),
177            Type::Forall { vars, body } => f
178                .debug_struct("Forall")
179                .field("vars", vars)
180                .field("body", body)
181                .finish(),
182            Type::Parameter(name) => f.debug_tuple("Parameter").field(name).finish(),
183            Type::Never => write!(f, "Never"),
184            Type::Tuple(elements) => f.debug_tuple("Tuple").field(elements).finish(),
185            Type::Error => write!(f, "Error"),
186        }
187    }
188}
189
190impl PartialEq for Type {
191    fn eq(&self, other: &Self) -> bool {
192        match (self, other) {
193            (
194                Type::Constructor {
195                    id: id1,
196                    params: params1,
197                    ..
198                },
199                Type::Constructor {
200                    id: id2,
201                    params: params2,
202                    ..
203                },
204            ) => id1 == id2 && params1 == params2,
205            (
206                Type::Function {
207                    params: p1,
208                    param_mutability: m1,
209                    bounds: b1,
210                    return_type: r1,
211                },
212                Type::Function {
213                    params: p2,
214                    param_mutability: m2,
215                    bounds: b2,
216                    return_type: r2,
217                },
218            ) => p1 == p2 && m1 == m2 && b1 == b2 && r1 == r2,
219            (Type::Variable(v1), Type::Variable(v2)) => {
220                Rc::ptr_eq(v1, v2) || *v1.borrow() == *v2.borrow()
221            }
222            (
223                Type::Forall {
224                    vars: vars1,
225                    body: body1,
226                },
227                Type::Forall {
228                    vars: vars2,
229                    body: body2,
230                },
231            ) => vars1 == vars2 && body1 == body2,
232            (Type::Parameter(name1), Type::Parameter(name2)) => name1 == name2,
233            (Type::Never, Type::Never) => true,
234            (Type::Tuple(elems1), Type::Tuple(elems2)) => elems1 == elems2,
235            _ => false,
236        }
237    }
238}
239
240thread_local! {
241    static INTERNED_INT: OnceCell<Type> = const { OnceCell::new() };
242    static INTERNED_STRING: OnceCell<Type> = const { OnceCell::new() };
243    static INTERNED_BOOL: OnceCell<Type> = const { OnceCell::new() };
244    static INTERNED_UNIT: OnceCell<Type> = const { OnceCell::new() };
245    static INTERNED_FLOAT64: OnceCell<Type> = const { OnceCell::new() };
246    static INTERNED_RUNE: OnceCell<Type> = const { OnceCell::new() };
247}
248
249impl Type {
250    pub fn int() -> Type {
251        INTERNED_INT.with(|cell| cell.get_or_init(|| Self::nominal("int")).clone())
252    }
253
254    pub fn string() -> Type {
255        INTERNED_STRING.with(|cell| cell.get_or_init(|| Self::nominal("string")).clone())
256    }
257
258    pub fn bool() -> Type {
259        INTERNED_BOOL.with(|cell| cell.get_or_init(|| Self::nominal("bool")).clone())
260    }
261
262    pub fn unit() -> Type {
263        INTERNED_UNIT.with(|cell| cell.get_or_init(|| Self::nominal("Unit")).clone())
264    }
265
266    pub fn float64() -> Type {
267        INTERNED_FLOAT64.with(|cell| cell.get_or_init(|| Self::nominal("float64")).clone())
268    }
269
270    pub fn rune() -> Type {
271        INTERNED_RUNE.with(|cell| cell.get_or_init(|| Self::nominal("rune")).clone())
272    }
273}
274
275impl Type {
276    const UNINFERRED_ID: i32 = -1;
277    const IGNORED_ID: i32 = -333;
278
279    pub fn nominal(name: &str) -> Self {
280        Self::Constructor {
281            id: format!("**nominal.{}", name).into(),
282            params: vec![],
283            underlying_ty: None,
284        }
285    }
286
287    pub fn uninferred() -> Self {
288        Self::Variable(Rc::new(RefCell::new(TypeVariableState::Unbound {
289            id: Self::UNINFERRED_ID,
290            hint: None,
291        })))
292    }
293
294    pub fn ignored() -> Self {
295        Self::Variable(Rc::new(RefCell::new(TypeVariableState::Unbound {
296            id: Self::IGNORED_ID,
297            hint: None,
298        })))
299    }
300
301    pub fn get_type_params(&self) -> Option<&[Type]> {
302        match self {
303            Type::Constructor { params, .. } => Some(params),
304            _ => None,
305        }
306    }
307}
308
/// Unqualified names of built-in types that `Type::is_numeric` accepts:
/// integers (including the `byte`/`rune` aliases), floats and complex numbers.
const ARITHMETIC_TYPES: &[&str] = &[
    // signed integers
    "int", "int8", "int16", "int32", "int64",
    // unsigned integers
    "uint", "uint8", "uint16", "uint32", "uint64",
    // aliases
    "byte", "rune",
    // floating point
    "float32", "float64",
    // complex
    "complex64", "complex128",
];

/// Unqualified names that `Type::is_ordered` accepts: every arithmetic type
/// except the complex ones.
const ORDERED_TYPES: &[&str] = &[
    // signed integers
    "int", "int8", "int16", "int32", "int64",
    // unsigned integers
    "uint", "uint8", "uint16", "uint32", "uint64",
    // aliases
    "byte", "rune",
    // floating point
    "float32", "float64",
];

/// Unqualified names of the unsigned integer types (`byte` included).
const UNSIGNED_INT_TYPES: &[&str] = &[
    "uint", "uint8", "uint16", "uint32", "uint64",
    "byte",
];
334
335impl Type {
336    pub fn get_function_ret(&self) -> Option<&Type> {
337        match self {
338            Type::Function { return_type, .. } => Some(return_type),
339            _ => None,
340        }
341    }
342
343    pub fn has_name(&self, name: &str) -> bool {
344        match self {
345            Type::Constructor { id, .. } => unqualified_name(id) == name,
346            _ => false,
347        }
348    }
349
350    pub fn get_qualified_id(&self) -> Option<&str> {
351        match self {
352            Type::Constructor { id, .. } => Some(id.as_str()),
353            _ => None,
354        }
355    }
356
357    pub fn get_underlying(&self) -> Option<&Type> {
358        match self {
359            Type::Constructor {
360                underlying_ty: underlying,
361                ..
362            } => underlying.as_deref(),
363            _ => None,
364        }
365    }
366
367    pub fn is_result(&self) -> bool {
368        self.has_name("Result")
369    }
370
371    pub fn is_option(&self) -> bool {
372        self.has_name("Option")
373    }
374
375    pub fn is_partial(&self) -> bool {
376        self.has_name("Partial")
377    }
378
379    pub fn is_unit(&self) -> bool {
380        matches!(self.resolve(), Type::Constructor { ref id, .. } if id.as_ref() == "**nominal.Unit")
381    }
382
383    pub fn tuple_arity(&self) -> Option<usize> {
384        match self {
385            Type::Tuple(elements) => Some(elements.len()),
386            _ => None,
387        }
388    }
389
390    pub fn is_tuple(&self) -> bool {
391        matches!(self, Type::Tuple(_))
392    }
393
394    pub fn is_ref(&self) -> bool {
395        self.has_name("Ref")
396    }
397
398    pub fn is_receiver_placeholder(&self) -> bool {
399        self.has_name("__receiver__")
400    }
401
402    pub fn is_unknown(&self) -> bool {
403        self.has_name("Unknown")
404    }
405
406    pub fn is_receiver(&self) -> bool {
407        self.has_name("Receiver")
408    }
409
410    pub fn is_ignored(&self) -> bool {
411        match self {
412            Type::Variable(var) => {
413                matches!(&*var.borrow(), TypeVariableState::Unbound { id, .. } if *id == Self::IGNORED_ID)
414            }
415            _ => false,
416        }
417    }
418
419    pub fn is_variadic(&self) -> Option<Type> {
420        let args = self.get_function_params()?;
421        let last = args.last()?;
422
423        if last.get_name()? == "VarArgs" {
424            return last.inner();
425        }
426
427        None
428    }
429
430    pub fn is_string(&self) -> bool {
431        self.has_name("string")
432    }
433
434    pub fn is_slice_of(&self, element_name: &str) -> bool {
435        match self {
436            Type::Constructor { id, params, .. } => {
437                if unqualified_name(id) != "Slice" || params.len() != 1 {
438                    return false;
439                }
440                params[0].resolve().has_name(element_name)
441            }
442            _ => false,
443        }
444    }
445
446    pub fn is_byte_slice(&self) -> bool {
447        self.is_slice_of("byte") || self.is_slice_of("uint8")
448    }
449
450    pub fn is_rune_slice(&self) -> bool {
451        self.is_slice_of("rune")
452    }
453
454    pub fn is_byte_or_rune_slice(&self) -> bool {
455        self.is_byte_slice() || self.is_rune_slice()
456    }
457
458    pub fn has_byte_or_rune_slice_underlying(&self) -> bool {
459        if self.is_byte_or_rune_slice() {
460            return true;
461        }
462        match self {
463            Type::Constructor { underlying_ty, .. } => underlying_ty
464                .as_deref()
465                .is_some_and(|u| u.has_byte_or_rune_slice_underlying()),
466            _ => false,
467        }
468    }
469
470    pub fn is_boolean(&self) -> bool {
471        self.has_name("bool")
472    }
473
474    pub fn is_rune(&self) -> bool {
475        self.has_name("rune")
476    }
477
478    pub fn is_float64(&self) -> bool {
479        self.has_name("float64")
480    }
481
482    pub fn is_float32(&self) -> bool {
483        self.has_name("float32")
484    }
485
486    pub fn is_float(&self) -> bool {
487        self.is_float64() || self.is_float32()
488    }
489
490    pub fn is_variable(&self) -> bool {
491        matches!(self, Type::Variable(_))
492    }
493
494    pub fn is_unbound_variable(&self) -> bool {
495        matches!(self, Type::Variable(cell) if cell.borrow().is_unbound())
496    }
497
498    pub fn is_numeric(&self) -> bool {
499        match self {
500            Type::Constructor { id, .. } => ARITHMETIC_TYPES.contains(&unqualified_name(id)),
501            _ => false,
502        }
503    }
504
505    pub fn is_ordered(&self) -> bool {
506        match self {
507            Type::Constructor { id, .. } => ORDERED_TYPES.contains(&unqualified_name(id)),
508            _ => false,
509        }
510    }
511
512    pub fn is_complex(&self) -> bool {
513        match self {
514            Type::Constructor { id, .. } => {
515                matches!(unqualified_name(id), "complex128" | "complex64")
516            }
517            _ => false,
518        }
519    }
520
521    pub fn is_unsigned_int(&self) -> bool {
522        match self {
523            Type::Constructor { id, .. } => UNSIGNED_INT_TYPES.contains(&unqualified_name(id)),
524            _ => false,
525        }
526    }
527
528    pub fn is_never(&self) -> bool {
529        matches!(self.shallow_resolve(), Type::Never)
530    }
531
532    pub fn is_error(&self) -> bool {
533        matches!(self.shallow_resolve(), Type::Error)
534    }
535
536    pub fn has_unbound_variables(&self) -> bool {
537        match self {
538            Type::Variable(type_var) => match &*type_var.borrow() {
539                TypeVariableState::Unbound { hint, .. } => hint.is_some(),
540                TypeVariableState::Link(ty) => ty.has_unbound_variables(),
541            },
542            Type::Constructor { params, .. } => params.iter().any(|p| p.has_unbound_variables()),
543            Type::Function {
544                params,
545                return_type,
546                ..
547            } => {
548                params.iter().any(|p| p.has_unbound_variables())
549                    || return_type.has_unbound_variables()
550            }
551            Type::Forall { body, .. } => body.has_unbound_variables(),
552            Type::Tuple(elements) => elements.iter().any(|e| e.has_unbound_variables()),
553            Type::Parameter(_) | Type::Never | Type::Error => false,
554        }
555    }
556
557    pub fn remove_found_type_names(&self, names: &mut HashSet<EcoString>) {
558        if names.is_empty() {
559            return;
560        }
561
562        match self {
563            Type::Constructor { id, params, .. } => {
564                names.remove(unqualified_name(id));
565                for param in params {
566                    param.remove_found_type_names(names);
567                }
568            }
569            Type::Function {
570                params,
571                return_type,
572                bounds,
573                ..
574            } => {
575                for param in params {
576                    param.remove_found_type_names(names);
577                }
578                return_type.remove_found_type_names(names);
579                for bound in bounds {
580                    bound.generic.remove_found_type_names(names);
581                    bound.ty.remove_found_type_names(names);
582                }
583            }
584            Type::Forall { body, .. } => {
585                body.remove_found_type_names(names);
586            }
587            Type::Variable(type_var) => {
588                if let TypeVariableState::Link(ty) = &*type_var.borrow() {
589                    ty.remove_found_type_names(names);
590                }
591            }
592            Type::Parameter(name) => {
593                names.remove(name);
594            }
595            Type::Tuple(elements) => {
596                for element in elements {
597                    element.remove_found_type_names(names);
598                }
599            }
600            Type::Never | Type::Error => {}
601        }
602    }
603}
604
605impl Type {
606    pub fn get_name(&self) -> Option<&str> {
607        match self {
608            Type::Constructor { id, params, .. } => {
609                let name = unqualified_name(id);
610                if name == "Ref" {
611                    return params.first().and_then(|inner| inner.get_name());
612                }
613                if let Some(module_path) = id.strip_prefix("@import/") {
614                    let path = module_path.strip_prefix("go:").unwrap_or(module_path);
615                    return path.rsplit('/').next();
616                }
617                Some(name)
618            }
619            _ => None,
620        }
621    }
622
623    pub fn wraps(&self, name: &str, inner: &Type) -> bool {
624        self.get_name().is_some_and(|n| n == name)
625            && self
626                .get_type_params()
627                .and_then(|p| p.first())
628                .is_some_and(|first| *first == *inner)
629    }
630
631    pub fn get_function_params(&self) -> Option<&[Type]> {
632        match self {
633            Type::Function { params, .. } => Some(params),
634            Type::Constructor {
635                underlying_ty: Some(inner),
636                ..
637            } => inner.get_function_params(),
638            _ => None,
639        }
640    }
641
642    pub fn param_count(&self) -> usize {
643        match self {
644            Type::Function { params, .. } => params.len(),
645            _ => 0,
646        }
647    }
648
649    pub fn get_param_mutability(&self) -> &[bool] {
650        match self {
651            Type::Function {
652                param_mutability, ..
653            } => param_mutability,
654            _ => &[],
655        }
656    }
657
658    pub fn with_replaced_first_param(&self, new_first: &Type) -> Type {
659        match self {
660            Type::Function {
661                params,
662                param_mutability,
663                bounds,
664                return_type,
665            } => {
666                if params.is_empty() {
667                    return self.clone();
668                }
669                let mut new_params = params.clone();
670                new_params[0] = new_first.clone();
671                Type::Function {
672                    params: new_params,
673                    param_mutability: param_mutability.clone(),
674                    bounds: bounds.clone(),
675                    return_type: return_type.clone(),
676                }
677            }
678            Type::Forall { vars, body } => Type::Forall {
679                vars: vars.clone(),
680                body: Box::new(body.with_replaced_first_param(new_first)),
681            },
682            _ => self.clone(),
683        }
684    }
685
686    pub fn get_bounds(&self) -> &[Bound] {
687        match self {
688            Type::Function { bounds, .. } => bounds,
689            Type::Forall { body, .. } => body.get_bounds(),
690            _ => &[],
691        }
692    }
693
694    pub fn get_qualified_name(&self) -> EcoString {
695        match self.strip_refs() {
696            Type::Constructor { id, .. } => id,
697            _ => panic!("called get_qualified_name on {:#?}", self),
698        }
699    }
700
701    pub fn inner(&self) -> Option<Type> {
702        self.get_type_params()
703            .and_then(|args| args.first().cloned())
704    }
705
706    pub fn ok_type(&self) -> Type {
707        debug_assert!(
708            self.is_result() || self.is_option() || self.is_partial(),
709            "ok_type called on non-Result/Option/Partial type"
710        );
711        self.inner()
712            .expect("Result/Option/Partial should have inner type")
713    }
714
715    pub fn err_type(&self) -> Type {
716        debug_assert!(
717            self.is_result() || self.is_partial(),
718            "err_type called on non-Result/Partial type"
719        );
720        self.get_type_params()
721            .and_then(|args| args.get(1).cloned())
722            .expect("Result/Partial should have error type")
723    }
724}
725
726impl Type {
727    pub fn unwrap_forall(&self) -> &Type {
728        match self {
729            Type::Forall { body, .. } => body.as_ref(),
730            other => other,
731        }
732    }
733
734    pub fn strip_refs(&self) -> Type {
735        if self.is_ref() {
736            return self.inner().expect("ref type must have inner").strip_refs();
737        }
738
739        self.clone()
740    }
741
742    pub fn with_receiver_placeholder(self) -> Type {
743        match self {
744            Type::Function {
745                params,
746                param_mutability,
747                bounds,
748                return_type,
749            } => {
750                let mut new_params = vec![Type::nominal("__receiver__")];
751                new_params.extend(params);
752
753                let mut new_mutability = vec![false];
754                new_mutability.extend(param_mutability);
755
756                Type::Function {
757                    params: new_params,
758                    param_mutability: new_mutability,
759                    bounds,
760                    return_type,
761                }
762            }
763            _ => unreachable!(
764                "with_receiver_placeholder called on non-function type: {:?}",
765                self
766            ),
767        }
768    }
769
770    pub fn remove_vars(types: &[&Type]) -> (Vec<Type>, Vec<EcoString>) {
771        let mut vars = HashMap::default();
772        let types = types
773            .iter()
774            .map(|v| Self::remove_vars_impl(v, &mut vars))
775            .collect();
776
777        (types, vars.into_values().collect())
778    }
779
780    fn remove_vars_impl(ty: &Type, vars: &mut HashMap<i32, EcoString>) -> Type {
781        match ty {
782            Type::Constructor {
783                id: name,
784                params: args,
785                underlying_ty: underlying,
786            } => Type::Constructor {
787                id: name.clone(),
788                params: args
789                    .iter()
790                    .map(|a| Self::remove_vars_impl(a, vars))
791                    .collect(),
792                underlying_ty: underlying
793                    .as_ref()
794                    .map(|u| Box::new(Self::remove_vars_impl(u, vars))),
795            },
796
797            Type::Function {
798                params: args,
799                param_mutability,
800                bounds,
801                return_type,
802            } => Type::Function {
803                params: args
804                    .iter()
805                    .map(|a| Self::remove_vars_impl(a, vars))
806                    .collect(),
807                param_mutability: param_mutability.clone(),
808                bounds: bounds
809                    .iter()
810                    .map(|b| Bound {
811                        param_name: b.param_name.clone(),
812                        generic: Self::remove_vars_impl(&b.generic, vars),
813                        ty: Self::remove_vars_impl(&b.ty, vars),
814                    })
815                    .collect(),
816                return_type: Self::remove_vars_impl(return_type, vars).into(),
817            },
818
819            Type::Variable(type_var) => match &*type_var.borrow() {
820                TypeVariableState::Unbound { id, hint } => match vars.get(id) {
821                    Some(g) => Self::nominal(g),
822                    None => {
823                        let name: EcoString = hint.clone().unwrap_or_else(|| {
824                            char::from_digit(
825                                (vars.len() + 10)
826                                    .try_into()
827                                    .expect("type var count fits in u32"),
828                                16,
829                            )
830                            .expect("type var index is valid hex digit")
831                            .to_uppercase()
832                            .to_string()
833                            .into()
834                        });
835
836                        vars.insert(*id, name.clone());
837                        Self::nominal(&name)
838                    }
839                },
840                TypeVariableState::Link(ty) => Self::remove_vars_impl(ty, vars),
841            },
842
843            Type::Forall { body, .. } => Self::remove_vars_impl(body, vars),
844            Type::Tuple(elements) => Type::Tuple(
845                elements
846                    .iter()
847                    .map(|e| Self::remove_vars_impl(e, vars))
848                    .collect(),
849            ),
850            Type::Parameter(name) => Type::Parameter(name.clone()),
851            Type::Never | Type::Error => ty.clone(),
852        }
853    }
854
855    pub fn contains_type(&self, target: &Type) -> bool {
856        if *self == *target {
857            return true;
858        }
859        match self {
860            Type::Constructor { params, .. } => params.iter().any(|p| p.contains_type(target)),
861            Type::Function {
862                params,
863                return_type,
864                ..
865            } => {
866                params.iter().any(|p| p.contains_type(target)) || return_type.contains_type(target)
867            }
868            Type::Variable(var) => {
869                if let TypeVariableState::Link(linked) = &*var.borrow() {
870                    linked.contains_type(target)
871                } else {
872                    false
873                }
874            }
875            Type::Forall { body, .. } => body.contains_type(target),
876            Type::Tuple(elements) => elements.iter().any(|e| e.contains_type(target)),
877            Type::Parameter(_) | Type::Never | Type::Error => false,
878        }
879    }
880
881    /// Follow Variable::Link chains to the outermost non-variable type.
882    /// Does NOT recurse into Constructor params, Function params, etc.
883    /// Use this when you only need the outermost type (e.g. is_never, is_unknown, has_name).
884    pub fn shallow_resolve(&self) -> Type {
885        match self {
886            Type::Variable(type_var) => {
887                let state = type_var.borrow();
888                match &*state {
889                    TypeVariableState::Unbound { .. } => self.clone(),
890                    TypeVariableState::Link(linked) => linked.shallow_resolve(),
891                }
892            }
893            _ => self.clone(),
894        }
895    }
896
897    pub fn resolve(&self) -> Type {
898        match self {
899            Type::Variable(type_var) => {
900                let state = type_var.borrow();
901                match &*state {
902                    TypeVariableState::Unbound { .. } => self.clone(),
903                    TypeVariableState::Link(linked) => {
904                        let resolved = linked.resolve();
905                        drop(state);
906                        *type_var.borrow_mut() = TypeVariableState::Link(resolved.clone());
907                        resolved
908                    }
909                }
910            }
911            Type::Constructor {
912                id,
913                params,
914                underlying_ty: underlying,
915            } => Type::Constructor {
916                id: id.clone(),
917                params: params.iter().map(|p| p.resolve()).collect(),
918                underlying_ty: underlying.as_ref().map(|u| Box::new(u.resolve())),
919            },
920            Type::Function {
921                params,
922                param_mutability,
923                bounds,
924                return_type,
925            } => Type::Function {
926                params: params.iter().map(|p| p.resolve()).collect(),
927                param_mutability: param_mutability.clone(),
928                bounds: bounds
929                    .iter()
930                    .map(|b| Bound {
931                        param_name: b.param_name.clone(),
932                        generic: b.generic.resolve(),
933                        ty: b.ty.resolve(),
934                    })
935                    .collect(),
936                return_type: Box::new(return_type.resolve()),
937            },
938            Type::Forall { .. } => {
939                unreachable!("Forall types are always instantiated before resolve")
940            }
941            Type::Tuple(elements) => Type::Tuple(elements.iter().map(|e| e.resolve()).collect()),
942            Type::Parameter(_) | Type::Error => self.clone(),
943            Type::Never => Type::Never,
944        }
945    }
946}
947
/// Coarse grouping of built-in numeric types, used to decide whether two
/// numeric types are compatible.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NumericFamily {
    /// `int`, `int8`…`int64` and `rune`.
    SignedInt,
    /// `uint`, `uint8`…`uint64` and `byte`.
    UnsignedInt,
    /// `float32` and `float64`.
    Float,
}
954
/// Unqualified names that map to `NumericFamily::SignedInt` (`rune` included).
const SIGNED_INT_TYPES: &[&str] = &[
    "int", "int8", "int16", "int32", "int64",
    "rune",
];
/// Unqualified names that map to `NumericFamily::Float`.
const FLOAT_TYPES: &[&str] = &["float64", "float32"];
957
958impl Type {
959    pub fn underlying_numeric_type(&self) -> Option<Type> {
960        self.underlying_numeric_type_recursive(&mut HashSet::default())
961    }
962
963    pub fn has_underlying_numeric_type(&self) -> bool {
964        self.underlying_numeric_type().is_some()
965    }
966
967    fn underlying_numeric_type_recursive(&self, visited: &mut HashSet<EcoString>) -> Option<Type> {
968        match self {
969            Type::Constructor {
970                id,
971                underlying_ty: underlying,
972                ..
973            } => {
974                if self.is_numeric() {
975                    return Some(self.clone());
976                }
977
978                if !visited.insert(id.clone()) {
979                    return None;
980                }
981
982                underlying
983                    .as_ref()?
984                    .underlying_numeric_type_recursive(visited)
985            }
986            _ => None,
987        }
988    }
989
990    pub fn numeric_family(&self) -> Option<NumericFamily> {
991        let name = match self {
992            Type::Constructor { id, .. } => unqualified_name(id),
993            _ => return None,
994        };
995
996        if SIGNED_INT_TYPES.contains(&name) {
997            Some(NumericFamily::SignedInt)
998        } else if UNSIGNED_INT_TYPES.contains(&name) {
999            Some(NumericFamily::UnsignedInt)
1000        } else if FLOAT_TYPES.contains(&name) {
1001            Some(NumericFamily::Float)
1002        } else {
1003            None
1004        }
1005    }
1006
1007    pub fn is_numeric_compatible_with(&self, other: &Type) -> bool {
1008        let self_underlying_ty = self.underlying_numeric_type();
1009        let other_underlying_ty = other.underlying_numeric_type();
1010
1011        match (self_underlying_ty, other_underlying_ty) {
1012            (Some(s), Some(o)) => s.numeric_family() == o.numeric_family(),
1013            _ => false,
1014        }
1015    }
1016
1017    pub fn is_aliased_numeric_type(&self) -> bool {
1018        match self {
1019            Type::Constructor { underlying_ty, .. } => {
1020                underlying_ty.is_some() && !self.is_numeric()
1021            }
1022            _ => false,
1023        }
1024    }
1025}