Skip to main content

lisette_syntax/
types.rs

1use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
2use std::cell::{OnceCell, RefCell};
3use std::rc::Rc;
4
5use ecow::EcoString;
6
/// Return the last `.`-separated segment of a (possibly qualified) identifier.
///
/// Identifiers that contain no `.` are returned unchanged; a trailing `.` yields `""`.
///
/// `"prelude.Option"` → `"Option"`, `"**nominal.int"` → `"int"`, `"foo"` → `"foo"`
pub fn unqualified_name(id: &str) -> &str {
    match id.rfind('.') {
        // `.` is a single ASCII byte, so `dot + 1` is always a char boundary.
        Some(dot) => &id[dot + 1..],
        None => id,
    }
}
13
14/// type param name -> type variable
15pub type SubstitutionMap = HashMap<EcoString, Type>;
16
17pub fn substitute(ty: &Type, map: &HashMap<EcoString, Type>) -> Type {
18    if map.is_empty() {
19        return ty.clone();
20    }
21    match ty {
22        Type::Parameter(name) => map.get(name).cloned().unwrap_or_else(|| ty.clone()),
23        Type::Constructor {
24            id,
25            params,
26            underlying_ty: underlying,
27        } => Type::Constructor {
28            id: id.clone(),
29            params: params.iter().map(|p| substitute(p, map)).collect(),
30            underlying_ty: underlying.as_ref().map(|u| Box::new(substitute(u, map))),
31        },
32        Type::Function {
33            params,
34            param_mutability,
35            bounds,
36            return_type,
37        } => Type::Function {
38            params: params.iter().map(|p| substitute(p, map)).collect(),
39            param_mutability: param_mutability.clone(),
40            bounds: bounds
41                .iter()
42                .map(|b| Bound {
43                    param_name: b.param_name.clone(),
44                    generic: substitute(&b.generic, map),
45                    ty: substitute(&b.ty, map),
46                })
47                .collect(),
48            return_type: Box::new(substitute(return_type, map)),
49        },
50        Type::Variable(_) | Type::Error => ty.clone(),
51        Type::Forall { vars, body } => {
52            let has_overlap = map.keys().any(|k| vars.contains(k));
53            let substituted_body = if has_overlap {
54                let filtered_map: HashMap<EcoString, Type> = map
55                    .iter()
56                    .filter(|(k, _)| !vars.contains(*k))
57                    .map(|(k, v)| (k.clone(), v.clone()))
58                    .collect();
59                substitute(body, &filtered_map)
60            } else {
61                substitute(body, map)
62            };
63            Type::Forall {
64                vars: vars.clone(),
65                body: Box::new(substituted_body),
66            }
67        }
68        Type::Tuple(elements) => Type::Tuple(elements.iter().map(|e| substitute(e, map)).collect()),
69        Type::Never => ty.clone(),
70    }
71}
72
73#[derive(Debug, Clone, PartialEq)]
74pub struct Bound {
75    pub param_name: EcoString,
76    pub generic: Type,
77    pub ty: Type,
78}
79
80#[derive(Clone)]
81pub enum TypeVariableState {
82    Unbound { id: i32, hint: Option<EcoString> },
83    Link(Type),
84}
85
86impl TypeVariableState {
87    pub fn is_unbound(&self) -> bool {
88        matches!(self, TypeVariableState::Unbound { .. })
89    }
90}
91
92impl std::fmt::Debug for TypeVariableState {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        match self {
95            TypeVariableState::Unbound { id, hint } => match hint {
96                Some(name) => write!(f, "{}", name),
97                None => write!(f, "{}", id),
98            },
99            TypeVariableState::Link(ty) => write!(f, "{:?}", ty),
100        }
101    }
102}
103
104impl PartialEq for TypeVariableState {
105    fn eq(&self, other: &Self) -> bool {
106        match (self, other) {
107            (
108                TypeVariableState::Unbound { id: id1, .. },
109                TypeVariableState::Unbound { id: id2, .. },
110            ) => id1 == id2,
111            (TypeVariableState::Link(ty1), TypeVariableState::Link(ty2)) => ty1 == ty2,
112            _ => false,
113        }
114    }
115}
116
117#[derive(Clone)]
118pub enum Type {
119    Constructor {
120        id: EcoString,
121        params: Vec<Type>,
122        underlying_ty: Option<Box<Type>>,
123    },
124
125    Function {
126        params: Vec<Type>,
127        param_mutability: Vec<bool>,
128        bounds: Vec<Bound>,
129        return_type: Box<Type>,
130    },
131
132    Variable(Rc<RefCell<TypeVariableState>>),
133
134    Forall {
135        vars: Vec<EcoString>,
136        body: Box<Type>,
137    },
138
139    Parameter(EcoString),
140
141    Never,
142
143    Tuple(Vec<Type>),
144
145    /// Poison type returned after an error has been reported.
146    /// Unifies with everything silently, preventing cascading diagnostics.
147    Error,
148}
149
150impl std::fmt::Debug for Type {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        match self {
153            Type::Constructor { id, params, .. } => f
154                .debug_struct("Constructor")
155                .field("id", id)
156                .field("params", params)
157                .finish(),
158            Type::Function {
159                params,
160                param_mutability,
161                bounds,
162                return_type,
163            } => {
164                let mut s = f.debug_struct("Function");
165                s.field("params", params);
166                if param_mutability.iter().any(|m| *m) {
167                    s.field("param_mutability", param_mutability);
168                }
169                s.field("bounds", bounds)
170                    .field("return_type", return_type)
171                    .finish()
172            }
173            Type::Variable(type_var) => f
174                .debug_tuple("Variable")
175                .field(&*type_var.borrow())
176                .finish(),
177            Type::Forall { vars, body } => f
178                .debug_struct("Forall")
179                .field("vars", vars)
180                .field("body", body)
181                .finish(),
182            Type::Parameter(name) => f.debug_tuple("Parameter").field(name).finish(),
183            Type::Never => write!(f, "Never"),
184            Type::Tuple(elements) => f.debug_tuple("Tuple").field(elements).finish(),
185            Type::Error => write!(f, "Error"),
186        }
187    }
188}
189
190impl PartialEq for Type {
191    fn eq(&self, other: &Self) -> bool {
192        match (self, other) {
193            (
194                Type::Constructor {
195                    id: id1,
196                    params: params1,
197                    ..
198                },
199                Type::Constructor {
200                    id: id2,
201                    params: params2,
202                    ..
203                },
204            ) => id1 == id2 && params1 == params2,
205            (
206                Type::Function {
207                    params: p1,
208                    param_mutability: m1,
209                    bounds: b1,
210                    return_type: r1,
211                },
212                Type::Function {
213                    params: p2,
214                    param_mutability: m2,
215                    bounds: b2,
216                    return_type: r2,
217                },
218            ) => p1 == p2 && m1 == m2 && b1 == b2 && r1 == r2,
219            (Type::Variable(v1), Type::Variable(v2)) => {
220                Rc::ptr_eq(v1, v2) || *v1.borrow() == *v2.borrow()
221            }
222            (
223                Type::Forall {
224                    vars: vars1,
225                    body: body1,
226                },
227                Type::Forall {
228                    vars: vars2,
229                    body: body2,
230                },
231            ) => vars1 == vars2 && body1 == body2,
232            (Type::Parameter(name1), Type::Parameter(name2)) => name1 == name2,
233            (Type::Never, Type::Never) => true,
234            (Type::Tuple(elems1), Type::Tuple(elems2)) => elems1 == elems2,
235            _ => false,
236        }
237    }
238}
239
240thread_local! {
241    static INTERNED_INT: OnceCell<Type> = const { OnceCell::new() };
242    static INTERNED_STRING: OnceCell<Type> = const { OnceCell::new() };
243    static INTERNED_BOOL: OnceCell<Type> = const { OnceCell::new() };
244    static INTERNED_UNIT: OnceCell<Type> = const { OnceCell::new() };
245    static INTERNED_FLOAT64: OnceCell<Type> = const { OnceCell::new() };
246    static INTERNED_RUNE: OnceCell<Type> = const { OnceCell::new() };
247}
248
249impl Type {
250    pub fn int() -> Type {
251        INTERNED_INT.with(|cell| cell.get_or_init(|| Self::nominal("int")).clone())
252    }
253
254    pub fn string() -> Type {
255        INTERNED_STRING.with(|cell| cell.get_or_init(|| Self::nominal("string")).clone())
256    }
257
258    pub fn bool() -> Type {
259        INTERNED_BOOL.with(|cell| cell.get_or_init(|| Self::nominal("bool")).clone())
260    }
261
262    pub fn unit() -> Type {
263        INTERNED_UNIT.with(|cell| cell.get_or_init(|| Self::nominal("Unit")).clone())
264    }
265
266    pub fn float64() -> Type {
267        INTERNED_FLOAT64.with(|cell| cell.get_or_init(|| Self::nominal("float64")).clone())
268    }
269
270    pub fn rune() -> Type {
271        INTERNED_RUNE.with(|cell| cell.get_or_init(|| Self::nominal("rune")).clone())
272    }
273}
274
275impl Type {
276    const UNINFERRED_ID: i32 = -1;
277    const IGNORED_ID: i32 = -333;
278
279    pub fn nominal(name: &str) -> Self {
280        Self::Constructor {
281            id: format!("**nominal.{}", name).into(),
282            params: vec![],
283            underlying_ty: None,
284        }
285    }
286
287    pub fn uninferred() -> Self {
288        Self::Variable(Rc::new(RefCell::new(TypeVariableState::Unbound {
289            id: Self::UNINFERRED_ID,
290            hint: None,
291        })))
292    }
293
294    pub fn ignored() -> Self {
295        Self::Variable(Rc::new(RefCell::new(TypeVariableState::Unbound {
296            id: Self::IGNORED_ID,
297            hint: None,
298        })))
299    }
300
301    pub fn get_type_params(&self) -> Option<&[Type]> {
302        match self {
303            Type::Constructor { params, .. } => Some(params),
304            _ => None,
305        }
306    }
307}
308
/// Unqualified names of built-in types that support arithmetic (see `Type::is_numeric`).
const ARITHMETIC_TYPES: &[&str] = &[
    "byte", "complex128", "complex64", "float32", "float64", "int", "int16", "int32", "int64",
    "int8", "rune", "uint", "uint16", "uint32", "uint64", "uint8",
];

/// Unqualified names of built-in types that support ordering (see `Type::is_ordered`).
/// These are the arithmetic types minus the two complex types.
const ORDERED_TYPES: &[&str] = &[
    "byte",
    "float32",
    "float64",
    "int",
    "int16",
    "int32",
    "int64",
    "int8",
    "rune",
    "uint",
    "uint16",
    "uint32",
    "uint64",
    "uint8",
];

/// Unqualified names of the unsigned integer types
/// (see `Type::is_unsigned_int` and `Type::numeric_family`).
const UNSIGNED_INT_TYPES: &[&str] = &["byte", "uint", "uint8", "uint16", "uint32", "uint64"];
334
335impl Type {
336    pub fn get_function_ret(&self) -> Option<&Type> {
337        match self {
338            Type::Function { return_type, .. } => Some(return_type),
339            _ => None,
340        }
341    }
342
343    pub fn has_name(&self, name: &str) -> bool {
344        match self {
345            Type::Constructor { id, .. } => unqualified_name(id) == name,
346            _ => false,
347        }
348    }
349
350    pub fn get_qualified_id(&self) -> Option<&str> {
351        match self {
352            Type::Constructor { id, .. } => Some(id.as_str()),
353            _ => None,
354        }
355    }
356
357    pub fn get_underlying(&self) -> Option<&Type> {
358        match self {
359            Type::Constructor {
360                underlying_ty: underlying,
361                ..
362            } => underlying.as_deref(),
363            _ => None,
364        }
365    }
366
367    pub fn is_result(&self) -> bool {
368        self.has_qualified_id("prelude.Result")
369    }
370
371    pub fn is_option(&self) -> bool {
372        self.has_qualified_id("prelude.Option")
373    }
374
375    pub fn is_partial(&self) -> bool {
376        self.has_qualified_id("prelude.Partial")
377    }
378
379    fn has_qualified_id(&self, qualified_id: &str) -> bool {
380        matches!(self, Type::Constructor { id, .. } if id.as_str() == qualified_id)
381    }
382
383    pub fn is_unit(&self) -> bool {
384        matches!(self.resolve(), Type::Constructor { ref id, .. } if id.as_ref() == "**nominal.Unit")
385    }
386
387    pub fn tuple_arity(&self) -> Option<usize> {
388        match self {
389            Type::Tuple(elements) => Some(elements.len()),
390            _ => None,
391        }
392    }
393
394    pub fn is_tuple(&self) -> bool {
395        matches!(self, Type::Tuple(_))
396    }
397
398    pub fn is_ref(&self) -> bool {
399        self.has_name("Ref")
400    }
401
402    pub fn is_receiver_placeholder(&self) -> bool {
403        self.has_name("__receiver__")
404    }
405
406    pub fn is_unknown(&self) -> bool {
407        self.has_name("Unknown")
408    }
409
410    pub fn is_receiver(&self) -> bool {
411        self.has_name("Receiver")
412    }
413
414    pub fn is_ignored(&self) -> bool {
415        match self {
416            Type::Variable(var) => {
417                matches!(&*var.borrow(), TypeVariableState::Unbound { id, .. } if *id == Self::IGNORED_ID)
418            }
419            _ => false,
420        }
421    }
422
423    pub fn is_variadic(&self) -> Option<Type> {
424        let args = self.get_function_params()?;
425        let last = args.last()?;
426
427        if last.get_name()? == "VarArgs" {
428            return last.inner();
429        }
430
431        None
432    }
433
434    pub fn is_string(&self) -> bool {
435        self.has_name("string")
436    }
437
438    pub fn is_slice_of(&self, element_name: &str) -> bool {
439        match self {
440            Type::Constructor { id, params, .. } => {
441                if unqualified_name(id) != "Slice" || params.len() != 1 {
442                    return false;
443                }
444                params[0].resolve().has_name(element_name)
445            }
446            _ => false,
447        }
448    }
449
450    pub fn is_byte_slice(&self) -> bool {
451        self.is_slice_of("byte") || self.is_slice_of("uint8")
452    }
453
454    pub fn is_rune_slice(&self) -> bool {
455        self.is_slice_of("rune")
456    }
457
458    pub fn is_byte_or_rune_slice(&self) -> bool {
459        self.is_byte_slice() || self.is_rune_slice()
460    }
461
462    pub fn has_byte_or_rune_slice_underlying(&self) -> bool {
463        if self.is_byte_or_rune_slice() {
464            return true;
465        }
466        match self {
467            Type::Constructor { underlying_ty, .. } => underlying_ty
468                .as_deref()
469                .is_some_and(|u| u.has_byte_or_rune_slice_underlying()),
470            _ => false,
471        }
472    }
473
474    pub fn is_boolean(&self) -> bool {
475        self.has_name("bool")
476    }
477
478    pub fn is_rune(&self) -> bool {
479        self.has_name("rune")
480    }
481
482    pub fn is_float64(&self) -> bool {
483        self.has_name("float64")
484    }
485
486    pub fn is_float32(&self) -> bool {
487        self.has_name("float32")
488    }
489
490    pub fn is_float(&self) -> bool {
491        self.is_float64() || self.is_float32()
492    }
493
494    pub fn is_variable(&self) -> bool {
495        matches!(self, Type::Variable(_))
496    }
497
498    pub fn is_unbound_variable(&self) -> bool {
499        matches!(self, Type::Variable(cell) if cell.borrow().is_unbound())
500    }
501
502    pub fn is_numeric(&self) -> bool {
503        match self {
504            Type::Constructor { id, .. } => ARITHMETIC_TYPES.contains(&unqualified_name(id)),
505            _ => false,
506        }
507    }
508
509    pub fn is_ordered(&self) -> bool {
510        match self {
511            Type::Constructor { id, .. } => ORDERED_TYPES.contains(&unqualified_name(id)),
512            _ => false,
513        }
514    }
515
516    pub fn is_complex(&self) -> bool {
517        match self {
518            Type::Constructor { id, .. } => {
519                matches!(unqualified_name(id), "complex128" | "complex64")
520            }
521            _ => false,
522        }
523    }
524
525    pub fn is_unsigned_int(&self) -> bool {
526        match self {
527            Type::Constructor { id, .. } => UNSIGNED_INT_TYPES.contains(&unqualified_name(id)),
528            _ => false,
529        }
530    }
531
532    pub fn is_never(&self) -> bool {
533        matches!(self.shallow_resolve(), Type::Never)
534    }
535
536    pub fn is_error(&self) -> bool {
537        matches!(self.shallow_resolve(), Type::Error)
538    }
539
540    pub fn has_unbound_variables(&self) -> bool {
541        match self {
542            Type::Variable(type_var) => match &*type_var.borrow() {
543                TypeVariableState::Unbound { hint, .. } => hint.is_some(),
544                TypeVariableState::Link(ty) => ty.has_unbound_variables(),
545            },
546            Type::Constructor { params, .. } => params.iter().any(|p| p.has_unbound_variables()),
547            Type::Function {
548                params,
549                return_type,
550                ..
551            } => {
552                params.iter().any(|p| p.has_unbound_variables())
553                    || return_type.has_unbound_variables()
554            }
555            Type::Forall { body, .. } => body.has_unbound_variables(),
556            Type::Tuple(elements) => elements.iter().any(|e| e.has_unbound_variables()),
557            Type::Parameter(_) | Type::Never | Type::Error => false,
558        }
559    }
560
561    pub fn remove_found_type_names(&self, names: &mut HashSet<EcoString>) {
562        if names.is_empty() {
563            return;
564        }
565
566        match self {
567            Type::Constructor { id, params, .. } => {
568                names.remove(unqualified_name(id));
569                for param in params {
570                    param.remove_found_type_names(names);
571                }
572            }
573            Type::Function {
574                params,
575                return_type,
576                bounds,
577                ..
578            } => {
579                for param in params {
580                    param.remove_found_type_names(names);
581                }
582                return_type.remove_found_type_names(names);
583                for bound in bounds {
584                    bound.generic.remove_found_type_names(names);
585                    bound.ty.remove_found_type_names(names);
586                }
587            }
588            Type::Forall { body, .. } => {
589                body.remove_found_type_names(names);
590            }
591            Type::Variable(type_var) => {
592                if let TypeVariableState::Link(ty) = &*type_var.borrow() {
593                    ty.remove_found_type_names(names);
594                }
595            }
596            Type::Parameter(name) => {
597                names.remove(name);
598            }
599            Type::Tuple(elements) => {
600                for element in elements {
601                    element.remove_found_type_names(names);
602                }
603            }
604            Type::Never | Type::Error => {}
605        }
606    }
607}
608
609impl Type {
610    pub fn get_name(&self) -> Option<&str> {
611        match self {
612            Type::Constructor { id, params, .. } => {
613                let name = unqualified_name(id);
614                if name == "Ref" {
615                    return params.first().and_then(|inner| inner.get_name());
616                }
617                if let Some(module_path) = id.strip_prefix("@import/") {
618                    let path = module_path.strip_prefix("go:").unwrap_or(module_path);
619                    return path.rsplit('/').next();
620                }
621                Some(name)
622            }
623            _ => None,
624        }
625    }
626
627    pub fn wraps(&self, name: &str, inner: &Type) -> bool {
628        self.get_name().is_some_and(|n| n == name)
629            && self
630                .get_type_params()
631                .and_then(|p| p.first())
632                .is_some_and(|first| *first == *inner)
633    }
634
635    pub fn get_function_params(&self) -> Option<&[Type]> {
636        match self {
637            Type::Function { params, .. } => Some(params),
638            Type::Constructor {
639                underlying_ty: Some(inner),
640                ..
641            } => inner.get_function_params(),
642            _ => None,
643        }
644    }
645
646    pub fn param_count(&self) -> usize {
647        match self {
648            Type::Function { params, .. } => params.len(),
649            _ => 0,
650        }
651    }
652
653    pub fn get_param_mutability(&self) -> &[bool] {
654        match self {
655            Type::Function {
656                param_mutability, ..
657            } => param_mutability,
658            _ => &[],
659        }
660    }
661
662    pub fn with_replaced_first_param(&self, new_first: &Type) -> Type {
663        match self {
664            Type::Function {
665                params,
666                param_mutability,
667                bounds,
668                return_type,
669            } => {
670                if params.is_empty() {
671                    return self.clone();
672                }
673                let mut new_params = params.clone();
674                new_params[0] = new_first.clone();
675                Type::Function {
676                    params: new_params,
677                    param_mutability: param_mutability.clone(),
678                    bounds: bounds.clone(),
679                    return_type: return_type.clone(),
680                }
681            }
682            Type::Forall { vars, body } => Type::Forall {
683                vars: vars.clone(),
684                body: Box::new(body.with_replaced_first_param(new_first)),
685            },
686            _ => self.clone(),
687        }
688    }
689
690    pub fn get_bounds(&self) -> &[Bound] {
691        match self {
692            Type::Function { bounds, .. } => bounds,
693            Type::Forall { body, .. } => body.get_bounds(),
694            _ => &[],
695        }
696    }
697
698    pub fn get_qualified_name(&self) -> EcoString {
699        match self.strip_refs() {
700            Type::Constructor { id, .. } => id,
701            _ => panic!("called get_qualified_name on {:#?}", self),
702        }
703    }
704
705    pub fn inner(&self) -> Option<Type> {
706        self.get_type_params()
707            .and_then(|args| args.first().cloned())
708    }
709
710    pub fn ok_type(&self) -> Type {
711        debug_assert!(
712            self.is_result() || self.is_option() || self.is_partial(),
713            "ok_type called on non-Result/Option/Partial type"
714        );
715        self.inner()
716            .expect("Result/Option/Partial should have inner type")
717    }
718
719    pub fn err_type(&self) -> Type {
720        debug_assert!(
721            self.is_result() || self.is_partial(),
722            "err_type called on non-Result/Partial type"
723        );
724        self.get_type_params()
725            .and_then(|args| args.get(1).cloned())
726            .expect("Result/Partial should have error type")
727    }
728}
729
730impl Type {
731    pub fn unwrap_forall(&self) -> &Type {
732        match self {
733            Type::Forall { body, .. } => body.as_ref(),
734            other => other,
735        }
736    }
737
738    pub fn strip_refs(&self) -> Type {
739        if self.is_ref() {
740            return self.inner().expect("ref type must have inner").strip_refs();
741        }
742
743        self.clone()
744    }
745
746    pub fn with_receiver_placeholder(self) -> Type {
747        match self {
748            Type::Function {
749                params,
750                param_mutability,
751                bounds,
752                return_type,
753            } => {
754                let mut new_params = vec![Type::nominal("__receiver__")];
755                new_params.extend(params);
756
757                let mut new_mutability = vec![false];
758                new_mutability.extend(param_mutability);
759
760                Type::Function {
761                    params: new_params,
762                    param_mutability: new_mutability,
763                    bounds,
764                    return_type,
765                }
766            }
767            _ => unreachable!(
768                "with_receiver_placeholder called on non-function type: {:?}",
769                self
770            ),
771        }
772    }
773
774    pub fn remove_vars(types: &[&Type]) -> (Vec<Type>, Vec<EcoString>) {
775        let mut vars = HashMap::default();
776        let types = types
777            .iter()
778            .map(|v| Self::remove_vars_impl(v, &mut vars))
779            .collect();
780
781        (types, vars.into_values().collect())
782    }
783
784    fn remove_vars_impl(ty: &Type, vars: &mut HashMap<i32, EcoString>) -> Type {
785        match ty {
786            Type::Constructor {
787                id: name,
788                params: args,
789                underlying_ty: underlying,
790            } => Type::Constructor {
791                id: name.clone(),
792                params: args
793                    .iter()
794                    .map(|a| Self::remove_vars_impl(a, vars))
795                    .collect(),
796                underlying_ty: underlying
797                    .as_ref()
798                    .map(|u| Box::new(Self::remove_vars_impl(u, vars))),
799            },
800
801            Type::Function {
802                params: args,
803                param_mutability,
804                bounds,
805                return_type,
806            } => Type::Function {
807                params: args
808                    .iter()
809                    .map(|a| Self::remove_vars_impl(a, vars))
810                    .collect(),
811                param_mutability: param_mutability.clone(),
812                bounds: bounds
813                    .iter()
814                    .map(|b| Bound {
815                        param_name: b.param_name.clone(),
816                        generic: Self::remove_vars_impl(&b.generic, vars),
817                        ty: Self::remove_vars_impl(&b.ty, vars),
818                    })
819                    .collect(),
820                return_type: Self::remove_vars_impl(return_type, vars).into(),
821            },
822
823            Type::Variable(type_var) => match &*type_var.borrow() {
824                TypeVariableState::Unbound { id, hint } => match vars.get(id) {
825                    Some(g) => Self::nominal(g),
826                    None => {
827                        let name: EcoString = hint.clone().unwrap_or_else(|| {
828                            char::from_digit(
829                                (vars.len() + 10)
830                                    .try_into()
831                                    .expect("type var count fits in u32"),
832                                16,
833                            )
834                            .expect("type var index is valid hex digit")
835                            .to_uppercase()
836                            .to_string()
837                            .into()
838                        });
839
840                        vars.insert(*id, name.clone());
841                        Self::nominal(&name)
842                    }
843                },
844                TypeVariableState::Link(ty) => Self::remove_vars_impl(ty, vars),
845            },
846
847            Type::Forall { body, .. } => Self::remove_vars_impl(body, vars),
848            Type::Tuple(elements) => Type::Tuple(
849                elements
850                    .iter()
851                    .map(|e| Self::remove_vars_impl(e, vars))
852                    .collect(),
853            ),
854            Type::Parameter(name) => Type::Parameter(name.clone()),
855            Type::Never | Type::Error => ty.clone(),
856        }
857    }
858
859    pub fn contains_type(&self, target: &Type) -> bool {
860        if *self == *target {
861            return true;
862        }
863        match self {
864            Type::Constructor { params, .. } => params.iter().any(|p| p.contains_type(target)),
865            Type::Function {
866                params,
867                return_type,
868                ..
869            } => {
870                params.iter().any(|p| p.contains_type(target)) || return_type.contains_type(target)
871            }
872            Type::Variable(var) => {
873                if let TypeVariableState::Link(linked) = &*var.borrow() {
874                    linked.contains_type(target)
875                } else {
876                    false
877                }
878            }
879            Type::Forall { body, .. } => body.contains_type(target),
880            Type::Tuple(elements) => elements.iter().any(|e| e.contains_type(target)),
881            Type::Parameter(_) | Type::Never | Type::Error => false,
882        }
883    }
884
885    /// Follow Variable::Link chains to the outermost non-variable type.
886    /// Does NOT recurse into Constructor params, Function params, etc.
887    /// Use this when you only need the outermost type (e.g. is_never, is_unknown, has_name).
888    pub fn shallow_resolve(&self) -> Type {
889        match self {
890            Type::Variable(type_var) => {
891                let state = type_var.borrow();
892                match &*state {
893                    TypeVariableState::Unbound { .. } => self.clone(),
894                    TypeVariableState::Link(linked) => linked.shallow_resolve(),
895                }
896            }
897            _ => self.clone(),
898        }
899    }
900
901    pub fn resolve(&self) -> Type {
902        match self {
903            Type::Variable(type_var) => {
904                let state = type_var.borrow();
905                match &*state {
906                    TypeVariableState::Unbound { .. } => self.clone(),
907                    TypeVariableState::Link(linked) => {
908                        let resolved = linked.resolve();
909                        drop(state);
910                        *type_var.borrow_mut() = TypeVariableState::Link(resolved.clone());
911                        resolved
912                    }
913                }
914            }
915            Type::Constructor {
916                id,
917                params,
918                underlying_ty: underlying,
919            } => Type::Constructor {
920                id: id.clone(),
921                params: params.iter().map(|p| p.resolve()).collect(),
922                underlying_ty: underlying.as_ref().map(|u| Box::new(u.resolve())),
923            },
924            Type::Function {
925                params,
926                param_mutability,
927                bounds,
928                return_type,
929            } => Type::Function {
930                params: params.iter().map(|p| p.resolve()).collect(),
931                param_mutability: param_mutability.clone(),
932                bounds: bounds
933                    .iter()
934                    .map(|b| Bound {
935                        param_name: b.param_name.clone(),
936                        generic: b.generic.resolve(),
937                        ty: b.ty.resolve(),
938                    })
939                    .collect(),
940                return_type: Box::new(return_type.resolve()),
941            },
942            Type::Forall { body, .. } => body.resolve(),
943            Type::Tuple(elements) => Type::Tuple(elements.iter().map(|e| e.resolve()).collect()),
944            Type::Parameter(_) | Type::Error => self.clone(),
945            Type::Never => Type::Never,
946        }
947    }
948}
949
/// Coarse grouping of built-in numeric types, returned by `Type::numeric_family`.
///
/// Two numeric types in the same family are treated as compatible by
/// `Type::is_numeric_compatible_with`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericFamily {
    /// `int`, `int8`..`int64` and `rune`.
    SignedInt,
    /// `byte`, `uint` and `uint8`..`uint64`.
    UnsignedInt,
    /// `float32` and `float64`.
    Float,
}
956
/// Unqualified names placed in [`NumericFamily::SignedInt`] (note: includes `rune`).
const SIGNED_INT_TYPES: &[&str] = &["int", "int8", "int16", "int32", "int64", "rune"];
/// Unqualified names placed in [`NumericFamily::Float`].
const FLOAT_TYPES: &[&str] = &["float32", "float64"];
959
960impl Type {
961    pub fn underlying_numeric_type(&self) -> Option<Type> {
962        self.underlying_numeric_type_recursive(&mut HashSet::default())
963    }
964
965    pub fn has_underlying_numeric_type(&self) -> bool {
966        self.underlying_numeric_type().is_some()
967    }
968
969    fn underlying_numeric_type_recursive(&self, visited: &mut HashSet<EcoString>) -> Option<Type> {
970        match self {
971            Type::Constructor {
972                id,
973                underlying_ty: underlying,
974                ..
975            } => {
976                if self.is_numeric() {
977                    return Some(self.clone());
978                }
979
980                if !visited.insert(id.clone()) {
981                    return None;
982                }
983
984                underlying
985                    .as_ref()?
986                    .underlying_numeric_type_recursive(visited)
987            }
988            _ => None,
989        }
990    }
991
992    pub fn numeric_family(&self) -> Option<NumericFamily> {
993        let name = match self {
994            Type::Constructor { id, .. } => unqualified_name(id),
995            _ => return None,
996        };
997
998        if SIGNED_INT_TYPES.contains(&name) {
999            Some(NumericFamily::SignedInt)
1000        } else if UNSIGNED_INT_TYPES.contains(&name) {
1001            Some(NumericFamily::UnsignedInt)
1002        } else if FLOAT_TYPES.contains(&name) {
1003            Some(NumericFamily::Float)
1004        } else {
1005            None
1006        }
1007    }
1008
1009    pub fn is_numeric_compatible_with(&self, other: &Type) -> bool {
1010        let self_underlying_ty = self.underlying_numeric_type();
1011        let other_underlying_ty = other.underlying_numeric_type();
1012
1013        match (self_underlying_ty, other_underlying_ty) {
1014            (Some(s), Some(o)) => s.numeric_family() == o.numeric_family(),
1015            _ => false,
1016        }
1017    }
1018
1019    pub fn is_aliased_numeric_type(&self) -> bool {
1020        match self {
1021            Type::Constructor { underlying_ty, .. } => {
1022                underlying_ty.is_some() && !self.is_numeric()
1023            }
1024            _ => false,
1025        }
1026    }
1027}