1use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
2use std::cell::{OnceCell, RefCell};
3use std::rc::Rc;
4
5use ecow::EcoString;
6
/// Returns the last `.`-separated segment of a qualified id.
///
/// `"prelude.Result"` yields `"Result"` and `"**nominal.int"` yields `"int"`.
/// An id with no `.` is returned unchanged, and an id ending in `.` yields `""`.
pub fn unqualified_name(id: &str) -> &str {
    match id.rfind('.') {
        // `.` is a single byte, so the byte after it is always a char boundary.
        Some(dot) => &id[dot + 1..],
        None => id,
    }
}
13
/// Maps type-parameter names to the types that replace them.
///
/// Consumed by `substitute`: each `Type::Parameter(name)` whose `name` is a key
/// is replaced by the mapped type.
pub type SubstitutionMap = HashMap<EcoString, Type>;
16
17pub fn substitute(ty: &Type, map: &HashMap<EcoString, Type>) -> Type {
18 if map.is_empty() {
19 return ty.clone();
20 }
21 match ty {
22 Type::Parameter(name) => map.get(name).cloned().unwrap_or_else(|| ty.clone()),
23 Type::Constructor {
24 id,
25 params,
26 underlying_ty: underlying,
27 } => Type::Constructor {
28 id: id.clone(),
29 params: params.iter().map(|p| substitute(p, map)).collect(),
30 underlying_ty: underlying.as_ref().map(|u| Box::new(substitute(u, map))),
31 },
32 Type::Function {
33 params,
34 param_mutability,
35 bounds,
36 return_type,
37 } => Type::Function {
38 params: params.iter().map(|p| substitute(p, map)).collect(),
39 param_mutability: param_mutability.clone(),
40 bounds: bounds
41 .iter()
42 .map(|b| Bound {
43 param_name: b.param_name.clone(),
44 generic: substitute(&b.generic, map),
45 ty: substitute(&b.ty, map),
46 })
47 .collect(),
48 return_type: Box::new(substitute(return_type, map)),
49 },
50 Type::Variable(_) | Type::Error => ty.clone(),
51 Type::Forall { vars, body } => {
52 let has_overlap = map.keys().any(|k| vars.contains(k));
53 let substituted_body = if has_overlap {
54 let filtered_map: HashMap<EcoString, Type> = map
55 .iter()
56 .filter(|(k, _)| !vars.contains(*k))
57 .map(|(k, v)| (k.clone(), v.clone()))
58 .collect();
59 substitute(body, &filtered_map)
60 } else {
61 substitute(body, map)
62 };
63 Type::Forall {
64 vars: vars.clone(),
65 body: Box::new(substituted_body),
66 }
67 }
68 Type::Tuple(elements) => Type::Tuple(elements.iter().map(|e| substitute(e, map)).collect()),
69 Type::Never => ty.clone(),
70 }
71}
72
/// A bound attached to a `Type::Function` (stored in its `bounds` field).
///
/// `substitute`, `Type::resolve` and `Type::remove_vars` rewrite `generic` and
/// `ty`, and copy `param_name` unchanged.
#[derive(Debug, Clone, PartialEq)]
pub struct Bound {
    /// Name of the bounded parameter. NOTE(review): presumably the source-level
    /// generic name — confirm against the code that builds bounds.
    pub param_name: EcoString,
    /// The generic side of the bound. NOTE(review): exact role relative to `ty` is
    /// not visible in this file — confirm.
    pub generic: Type,
    /// The type the bound refers to. NOTE(review): confirm semantics with the checker.
    pub ty: Type,
}
79
/// The mutable state held inside a `Type::Variable` cell.
#[derive(Clone)]
pub enum TypeVariableState {
    /// Not yet solved.
    ///
    /// `id` identifies the variable; `PartialEq` compares unbound states by `id`
    /// only. `Type::uninferred` and `Type::ignored` use negative sentinel ids.
    /// `hint`, when present, is the name that `Debug` and `Type::remove_vars`
    /// show instead of a number.
    Unbound { id: i32, hint: Option<EcoString> },
    /// Solved: the variable stands for the linked type. `Type::resolve` follows
    /// chains of links and overwrites them with the fully resolved target.
    Link(Type),
}
85
86impl TypeVariableState {
87 pub fn is_unbound(&self) -> bool {
88 matches!(self, TypeVariableState::Unbound { .. })
89 }
90}
91
92impl std::fmt::Debug for TypeVariableState {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 match self {
95 TypeVariableState::Unbound { id, hint } => match hint {
96 Some(name) => write!(f, "{}", name),
97 None => write!(f, "{}", id),
98 },
99 TypeVariableState::Link(ty) => write!(f, "{:?}", ty),
100 }
101 }
102}
103
104impl PartialEq for TypeVariableState {
105 fn eq(&self, other: &Self) -> bool {
106 match (self, other) {
107 (
108 TypeVariableState::Unbound { id: id1, .. },
109 TypeVariableState::Unbound { id: id2, .. },
110 ) => id1 == id2,
111 (TypeVariableState::Link(ty1), TypeVariableState::Link(ty2)) => ty1 == ty2,
112 _ => false,
113 }
114 }
115}
116
/// The semantic representation of a type.
///
/// `Clone` is derived; `Debug` and `PartialEq` are implemented by hand.
#[derive(Clone)]
pub enum Type {
    /// A named type applied to type arguments.
    ///
    /// `id` is the qualified identifier (e.g. `prelude.Result`; built-ins created
    /// by `Type::nominal` use ids such as `**nominal.int`). `params` are the type
    /// arguments. `underlying_ty` is consulted by `get_underlying`,
    /// `get_function_params` and `underlying_numeric_type`; presumably it is the
    /// type a named type was declared over — TODO confirm. `PartialEq` ignores
    /// `underlying_ty`.
    Constructor {
        id: EcoString,
        params: Vec<Type>,
        underlying_ty: Option<Box<Type>>,
    },

    /// A function signature.
    ///
    /// NOTE(review): `param_mutability` appears to hold one flag per entry of
    /// `params` (`with_receiver_placeholder` extends both together) — confirm.
    Function {
        params: Vec<Type>,
        param_mutability: Vec<bool>,
        bounds: Vec<Bound>,
        return_type: Box<Type>,
    },

    /// An inference variable. Clones share the same `RefCell`, so linking the
    /// variable through any clone is visible through all of them.
    Variable(Rc<RefCell<TypeVariableState>>),

    /// A type quantified over the parameter names in `vars`. `substitute` does not
    /// replace those names inside `body`.
    Forall {
        vars: Vec<EcoString>,
        body: Box<Type>,
    },

    /// A reference to a type parameter by name; replaced by `substitute`.
    Parameter(EcoString),

    /// The `Never` type. NOTE(review): presumably the bottom type of diverging
    /// expressions — confirm.
    Never,

    /// A tuple of element types.
    Tuple(Vec<Type>),

    /// Placeholder for a type that could not be determined (presumably after a
    /// reported error — confirm). `PartialEq` never treats `Error` as equal to
    /// anything, including another `Error`.
    Error,
}
149
150impl std::fmt::Debug for Type {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 match self {
153 Type::Constructor { id, params, .. } => f
154 .debug_struct("Constructor")
155 .field("id", id)
156 .field("params", params)
157 .finish(),
158 Type::Function {
159 params,
160 param_mutability,
161 bounds,
162 return_type,
163 } => {
164 let mut s = f.debug_struct("Function");
165 s.field("params", params);
166 if param_mutability.iter().any(|m| *m) {
167 s.field("param_mutability", param_mutability);
168 }
169 s.field("bounds", bounds)
170 .field("return_type", return_type)
171 .finish()
172 }
173 Type::Variable(type_var) => f
174 .debug_tuple("Variable")
175 .field(&*type_var.borrow())
176 .finish(),
177 Type::Forall { vars, body } => f
178 .debug_struct("Forall")
179 .field("vars", vars)
180 .field("body", body)
181 .finish(),
182 Type::Parameter(name) => f.debug_tuple("Parameter").field(name).finish(),
183 Type::Never => write!(f, "Never"),
184 Type::Tuple(elements) => f.debug_tuple("Tuple").field(elements).finish(),
185 Type::Error => write!(f, "Error"),
186 }
187 }
188}
189
190impl PartialEq for Type {
191 fn eq(&self, other: &Self) -> bool {
192 match (self, other) {
193 (
194 Type::Constructor {
195 id: id1,
196 params: params1,
197 ..
198 },
199 Type::Constructor {
200 id: id2,
201 params: params2,
202 ..
203 },
204 ) => id1 == id2 && params1 == params2,
205 (
206 Type::Function {
207 params: p1,
208 param_mutability: m1,
209 bounds: b1,
210 return_type: r1,
211 },
212 Type::Function {
213 params: p2,
214 param_mutability: m2,
215 bounds: b2,
216 return_type: r2,
217 },
218 ) => p1 == p2 && m1 == m2 && b1 == b2 && r1 == r2,
219 (Type::Variable(v1), Type::Variable(v2)) => {
220 Rc::ptr_eq(v1, v2) || *v1.borrow() == *v2.borrow()
221 }
222 (
223 Type::Forall {
224 vars: vars1,
225 body: body1,
226 },
227 Type::Forall {
228 vars: vars2,
229 body: body2,
230 },
231 ) => vars1 == vars2 && body1 == body2,
232 (Type::Parameter(name1), Type::Parameter(name2)) => name1 == name2,
233 (Type::Never, Type::Never) => true,
234 (Type::Tuple(elems1), Type::Tuple(elems2)) => elems1 == elems2,
235 _ => false,
236 }
237 }
238}
239
// Per-thread caches for frequently used built-in nominal types.
// Each cell is filled on first use by the matching constructor (`Type::int()`,
// `Type::string()`, ...), which then returns clones of the cached value. The
// `Type` is therefore built at most once per thread.
thread_local! {
    static INTERNED_INT: OnceCell<Type> = const { OnceCell::new() };
    static INTERNED_STRING: OnceCell<Type> = const { OnceCell::new() };
    static INTERNED_BOOL: OnceCell<Type> = const { OnceCell::new() };
    static INTERNED_UNIT: OnceCell<Type> = const { OnceCell::new() };
    static INTERNED_FLOAT64: OnceCell<Type> = const { OnceCell::new() };
    static INTERNED_RUNE: OnceCell<Type> = const { OnceCell::new() };
}
248
249impl Type {
250 pub fn int() -> Type {
251 INTERNED_INT.with(|cell| cell.get_or_init(|| Self::nominal("int")).clone())
252 }
253
254 pub fn string() -> Type {
255 INTERNED_STRING.with(|cell| cell.get_or_init(|| Self::nominal("string")).clone())
256 }
257
258 pub fn bool() -> Type {
259 INTERNED_BOOL.with(|cell| cell.get_or_init(|| Self::nominal("bool")).clone())
260 }
261
262 pub fn unit() -> Type {
263 INTERNED_UNIT.with(|cell| cell.get_or_init(|| Self::nominal("Unit")).clone())
264 }
265
266 pub fn float64() -> Type {
267 INTERNED_FLOAT64.with(|cell| cell.get_or_init(|| Self::nominal("float64")).clone())
268 }
269
270 pub fn rune() -> Type {
271 INTERNED_RUNE.with(|cell| cell.get_or_init(|| Self::nominal("rune")).clone())
272 }
273}
274
275impl Type {
276 const UNINFERRED_ID: i32 = -1;
277 const IGNORED_ID: i32 = -333;
278
279 pub fn nominal(name: &str) -> Self {
280 Self::Constructor {
281 id: format!("**nominal.{}", name).into(),
282 params: vec![],
283 underlying_ty: None,
284 }
285 }
286
287 pub fn uninferred() -> Self {
288 Self::Variable(Rc::new(RefCell::new(TypeVariableState::Unbound {
289 id: Self::UNINFERRED_ID,
290 hint: None,
291 })))
292 }
293
294 pub fn ignored() -> Self {
295 Self::Variable(Rc::new(RefCell::new(TypeVariableState::Unbound {
296 id: Self::IGNORED_ID,
297 hint: None,
298 })))
299 }
300
301 pub fn get_type_params(&self) -> Option<&[Type]> {
302 match self {
303 Type::Constructor { params, .. } => Some(params),
304 _ => None,
305 }
306 }
307}
308
/// Unqualified names of built-in types accepted by `Type::is_numeric`.
const ARITHMETIC_TYPES: &[&str] = &[
    "byte",
    "complex128",
    "complex64",
    "float32",
    "float64",
    "int",
    "int16",
    "int32",
    "int64",
    "int8",
    "rune",
    "uint",
    "uint16",
    "uint32",
    "uint64",
    "uint8",
];

/// Unqualified names accepted by `Type::is_ordered`: the arithmetic types minus
/// the complex ones.
const ORDERED_TYPES: &[&str] = &[
    "byte", "float32", "float64", "int", "int16", "int32", "int64", "int8", "rune", "uint",
    "uint16", "uint32", "uint64", "uint8",
];

/// Unqualified names accepted by `Type::is_unsigned_int` and classified as
/// `NumericFamily::UnsignedInt` (`byte` included).
const UNSIGNED_INT_TYPES: &[&str] = &["byte", "uint", "uint8", "uint16", "uint32", "uint64"];
334
335impl Type {
336 pub fn get_function_ret(&self) -> Option<&Type> {
337 match self {
338 Type::Function { return_type, .. } => Some(return_type),
339 _ => None,
340 }
341 }
342
343 pub fn has_name(&self, name: &str) -> bool {
344 match self {
345 Type::Constructor { id, .. } => unqualified_name(id) == name,
346 _ => false,
347 }
348 }
349
350 pub fn get_qualified_id(&self) -> Option<&str> {
351 match self {
352 Type::Constructor { id, .. } => Some(id.as_str()),
353 _ => None,
354 }
355 }
356
357 pub fn get_underlying(&self) -> Option<&Type> {
358 match self {
359 Type::Constructor {
360 underlying_ty: underlying,
361 ..
362 } => underlying.as_deref(),
363 _ => None,
364 }
365 }
366
367 pub fn is_result(&self) -> bool {
368 self.has_qualified_id("prelude.Result")
369 }
370
371 pub fn is_option(&self) -> bool {
372 self.has_qualified_id("prelude.Option")
373 }
374
375 pub fn is_partial(&self) -> bool {
376 self.has_qualified_id("prelude.Partial")
377 }
378
379 fn has_qualified_id(&self, qualified_id: &str) -> bool {
380 matches!(self, Type::Constructor { id, .. } if id.as_str() == qualified_id)
381 }
382
383 pub fn is_unit(&self) -> bool {
384 matches!(self.resolve(), Type::Constructor { ref id, .. } if id.as_ref() == "**nominal.Unit")
385 }
386
387 pub fn tuple_arity(&self) -> Option<usize> {
388 match self {
389 Type::Tuple(elements) => Some(elements.len()),
390 _ => None,
391 }
392 }
393
394 pub fn is_tuple(&self) -> bool {
395 matches!(self, Type::Tuple(_))
396 }
397
398 pub fn is_ref(&self) -> bool {
399 self.has_name("Ref")
400 }
401
402 pub fn is_receiver_placeholder(&self) -> bool {
403 self.has_name("__receiver__")
404 }
405
406 pub fn is_unknown(&self) -> bool {
407 self.has_name("Unknown")
408 }
409
410 pub fn is_receiver(&self) -> bool {
411 self.has_name("Receiver")
412 }
413
414 pub fn is_ignored(&self) -> bool {
415 match self {
416 Type::Variable(var) => {
417 matches!(&*var.borrow(), TypeVariableState::Unbound { id, .. } if *id == Self::IGNORED_ID)
418 }
419 _ => false,
420 }
421 }
422
423 pub fn is_variadic(&self) -> Option<Type> {
424 let args = self.get_function_params()?;
425 let last = args.last()?;
426
427 if last.get_name()? == "VarArgs" {
428 return last.inner();
429 }
430
431 None
432 }
433
434 pub fn is_string(&self) -> bool {
435 self.has_name("string")
436 }
437
438 pub fn is_slice_of(&self, element_name: &str) -> bool {
439 match self {
440 Type::Constructor { id, params, .. } => {
441 if unqualified_name(id) != "Slice" || params.len() != 1 {
442 return false;
443 }
444 params[0].resolve().has_name(element_name)
445 }
446 _ => false,
447 }
448 }
449
450 pub fn is_byte_slice(&self) -> bool {
451 self.is_slice_of("byte") || self.is_slice_of("uint8")
452 }
453
454 pub fn is_rune_slice(&self) -> bool {
455 self.is_slice_of("rune")
456 }
457
458 pub fn is_byte_or_rune_slice(&self) -> bool {
459 self.is_byte_slice() || self.is_rune_slice()
460 }
461
462 pub fn has_byte_or_rune_slice_underlying(&self) -> bool {
463 if self.is_byte_or_rune_slice() {
464 return true;
465 }
466 match self {
467 Type::Constructor { underlying_ty, .. } => underlying_ty
468 .as_deref()
469 .is_some_and(|u| u.has_byte_or_rune_slice_underlying()),
470 _ => false,
471 }
472 }
473
474 pub fn is_boolean(&self) -> bool {
475 self.has_name("bool")
476 }
477
478 pub fn is_rune(&self) -> bool {
479 self.has_name("rune")
480 }
481
482 pub fn is_float64(&self) -> bool {
483 self.has_name("float64")
484 }
485
486 pub fn is_float32(&self) -> bool {
487 self.has_name("float32")
488 }
489
490 pub fn is_float(&self) -> bool {
491 self.is_float64() || self.is_float32()
492 }
493
494 pub fn is_variable(&self) -> bool {
495 matches!(self, Type::Variable(_))
496 }
497
498 pub fn is_unbound_variable(&self) -> bool {
499 matches!(self, Type::Variable(cell) if cell.borrow().is_unbound())
500 }
501
502 pub fn is_numeric(&self) -> bool {
503 match self {
504 Type::Constructor { id, .. } => ARITHMETIC_TYPES.contains(&unqualified_name(id)),
505 _ => false,
506 }
507 }
508
509 pub fn is_ordered(&self) -> bool {
510 match self {
511 Type::Constructor { id, .. } => ORDERED_TYPES.contains(&unqualified_name(id)),
512 _ => false,
513 }
514 }
515
516 pub fn is_complex(&self) -> bool {
517 match self {
518 Type::Constructor { id, .. } => {
519 matches!(unqualified_name(id), "complex128" | "complex64")
520 }
521 _ => false,
522 }
523 }
524
525 pub fn is_unsigned_int(&self) -> bool {
526 match self {
527 Type::Constructor { id, .. } => UNSIGNED_INT_TYPES.contains(&unqualified_name(id)),
528 _ => false,
529 }
530 }
531
532 pub fn is_never(&self) -> bool {
533 matches!(self.shallow_resolve(), Type::Never)
534 }
535
536 pub fn is_error(&self) -> bool {
537 matches!(self.shallow_resolve(), Type::Error)
538 }
539
540 pub fn has_unbound_variables(&self) -> bool {
541 match self {
542 Type::Variable(type_var) => match &*type_var.borrow() {
543 TypeVariableState::Unbound { hint, .. } => hint.is_some(),
544 TypeVariableState::Link(ty) => ty.has_unbound_variables(),
545 },
546 Type::Constructor { params, .. } => params.iter().any(|p| p.has_unbound_variables()),
547 Type::Function {
548 params,
549 return_type,
550 ..
551 } => {
552 params.iter().any(|p| p.has_unbound_variables())
553 || return_type.has_unbound_variables()
554 }
555 Type::Forall { body, .. } => body.has_unbound_variables(),
556 Type::Tuple(elements) => elements.iter().any(|e| e.has_unbound_variables()),
557 Type::Parameter(_) | Type::Never | Type::Error => false,
558 }
559 }
560
561 pub fn remove_found_type_names(&self, names: &mut HashSet<EcoString>) {
562 if names.is_empty() {
563 return;
564 }
565
566 match self {
567 Type::Constructor { id, params, .. } => {
568 names.remove(unqualified_name(id));
569 for param in params {
570 param.remove_found_type_names(names);
571 }
572 }
573 Type::Function {
574 params,
575 return_type,
576 bounds,
577 ..
578 } => {
579 for param in params {
580 param.remove_found_type_names(names);
581 }
582 return_type.remove_found_type_names(names);
583 for bound in bounds {
584 bound.generic.remove_found_type_names(names);
585 bound.ty.remove_found_type_names(names);
586 }
587 }
588 Type::Forall { body, .. } => {
589 body.remove_found_type_names(names);
590 }
591 Type::Variable(type_var) => {
592 if let TypeVariableState::Link(ty) = &*type_var.borrow() {
593 ty.remove_found_type_names(names);
594 }
595 }
596 Type::Parameter(name) => {
597 names.remove(name);
598 }
599 Type::Tuple(elements) => {
600 for element in elements {
601 element.remove_found_type_names(names);
602 }
603 }
604 Type::Never | Type::Error => {}
605 }
606 }
607}
608
609impl Type {
610 pub fn get_name(&self) -> Option<&str> {
611 match self {
612 Type::Constructor { id, params, .. } => {
613 let name = unqualified_name(id);
614 if name == "Ref" {
615 return params.first().and_then(|inner| inner.get_name());
616 }
617 if let Some(module_path) = id.strip_prefix("@import/") {
618 let path = module_path.strip_prefix("go:").unwrap_or(module_path);
619 return path.rsplit('/').next();
620 }
621 Some(name)
622 }
623 _ => None,
624 }
625 }
626
627 pub fn wraps(&self, name: &str, inner: &Type) -> bool {
628 self.get_name().is_some_and(|n| n == name)
629 && self
630 .get_type_params()
631 .and_then(|p| p.first())
632 .is_some_and(|first| *first == *inner)
633 }
634
635 pub fn get_function_params(&self) -> Option<&[Type]> {
636 match self {
637 Type::Function { params, .. } => Some(params),
638 Type::Constructor {
639 underlying_ty: Some(inner),
640 ..
641 } => inner.get_function_params(),
642 _ => None,
643 }
644 }
645
646 pub fn param_count(&self) -> usize {
647 match self {
648 Type::Function { params, .. } => params.len(),
649 _ => 0,
650 }
651 }
652
653 pub fn get_param_mutability(&self) -> &[bool] {
654 match self {
655 Type::Function {
656 param_mutability, ..
657 } => param_mutability,
658 _ => &[],
659 }
660 }
661
662 pub fn with_replaced_first_param(&self, new_first: &Type) -> Type {
663 match self {
664 Type::Function {
665 params,
666 param_mutability,
667 bounds,
668 return_type,
669 } => {
670 if params.is_empty() {
671 return self.clone();
672 }
673 let mut new_params = params.clone();
674 new_params[0] = new_first.clone();
675 Type::Function {
676 params: new_params,
677 param_mutability: param_mutability.clone(),
678 bounds: bounds.clone(),
679 return_type: return_type.clone(),
680 }
681 }
682 Type::Forall { vars, body } => Type::Forall {
683 vars: vars.clone(),
684 body: Box::new(body.with_replaced_first_param(new_first)),
685 },
686 _ => self.clone(),
687 }
688 }
689
690 pub fn get_bounds(&self) -> &[Bound] {
691 match self {
692 Type::Function { bounds, .. } => bounds,
693 Type::Forall { body, .. } => body.get_bounds(),
694 _ => &[],
695 }
696 }
697
698 pub fn get_qualified_name(&self) -> EcoString {
699 match self.strip_refs() {
700 Type::Constructor { id, .. } => id,
701 _ => panic!("called get_qualified_name on {:#?}", self),
702 }
703 }
704
705 pub fn inner(&self) -> Option<Type> {
706 self.get_type_params()
707 .and_then(|args| args.first().cloned())
708 }
709
710 pub fn ok_type(&self) -> Type {
711 debug_assert!(
712 self.is_result() || self.is_option() || self.is_partial(),
713 "ok_type called on non-Result/Option/Partial type"
714 );
715 self.inner()
716 .expect("Result/Option/Partial should have inner type")
717 }
718
719 pub fn err_type(&self) -> Type {
720 debug_assert!(
721 self.is_result() || self.is_partial(),
722 "err_type called on non-Result/Partial type"
723 );
724 self.get_type_params()
725 .and_then(|args| args.get(1).cloned())
726 .expect("Result/Partial should have error type")
727 }
728}
729
730impl Type {
731 pub fn unwrap_forall(&self) -> &Type {
732 match self {
733 Type::Forall { body, .. } => body.as_ref(),
734 other => other,
735 }
736 }
737
738 pub fn strip_refs(&self) -> Type {
739 if self.is_ref() {
740 return self.inner().expect("ref type must have inner").strip_refs();
741 }
742
743 self.clone()
744 }
745
746 pub fn with_receiver_placeholder(self) -> Type {
747 match self {
748 Type::Function {
749 params,
750 param_mutability,
751 bounds,
752 return_type,
753 } => {
754 let mut new_params = vec![Type::nominal("__receiver__")];
755 new_params.extend(params);
756
757 let mut new_mutability = vec![false];
758 new_mutability.extend(param_mutability);
759
760 Type::Function {
761 params: new_params,
762 param_mutability: new_mutability,
763 bounds,
764 return_type,
765 }
766 }
767 _ => unreachable!(
768 "with_receiver_placeholder called on non-function type: {:?}",
769 self
770 ),
771 }
772 }
773
774 pub fn remove_vars(types: &[&Type]) -> (Vec<Type>, Vec<EcoString>) {
775 let mut vars = HashMap::default();
776 let types = types
777 .iter()
778 .map(|v| Self::remove_vars_impl(v, &mut vars))
779 .collect();
780
781 (types, vars.into_values().collect())
782 }
783
784 fn remove_vars_impl(ty: &Type, vars: &mut HashMap<i32, EcoString>) -> Type {
785 match ty {
786 Type::Constructor {
787 id: name,
788 params: args,
789 underlying_ty: underlying,
790 } => Type::Constructor {
791 id: name.clone(),
792 params: args
793 .iter()
794 .map(|a| Self::remove_vars_impl(a, vars))
795 .collect(),
796 underlying_ty: underlying
797 .as_ref()
798 .map(|u| Box::new(Self::remove_vars_impl(u, vars))),
799 },
800
801 Type::Function {
802 params: args,
803 param_mutability,
804 bounds,
805 return_type,
806 } => Type::Function {
807 params: args
808 .iter()
809 .map(|a| Self::remove_vars_impl(a, vars))
810 .collect(),
811 param_mutability: param_mutability.clone(),
812 bounds: bounds
813 .iter()
814 .map(|b| Bound {
815 param_name: b.param_name.clone(),
816 generic: Self::remove_vars_impl(&b.generic, vars),
817 ty: Self::remove_vars_impl(&b.ty, vars),
818 })
819 .collect(),
820 return_type: Self::remove_vars_impl(return_type, vars).into(),
821 },
822
823 Type::Variable(type_var) => match &*type_var.borrow() {
824 TypeVariableState::Unbound { id, hint } => match vars.get(id) {
825 Some(g) => Self::nominal(g),
826 None => {
827 let name: EcoString = hint.clone().unwrap_or_else(|| {
828 char::from_digit(
829 (vars.len() + 10)
830 .try_into()
831 .expect("type var count fits in u32"),
832 16,
833 )
834 .expect("type var index is valid hex digit")
835 .to_uppercase()
836 .to_string()
837 .into()
838 });
839
840 vars.insert(*id, name.clone());
841 Self::nominal(&name)
842 }
843 },
844 TypeVariableState::Link(ty) => Self::remove_vars_impl(ty, vars),
845 },
846
847 Type::Forall { body, .. } => Self::remove_vars_impl(body, vars),
848 Type::Tuple(elements) => Type::Tuple(
849 elements
850 .iter()
851 .map(|e| Self::remove_vars_impl(e, vars))
852 .collect(),
853 ),
854 Type::Parameter(name) => Type::Parameter(name.clone()),
855 Type::Never | Type::Error => ty.clone(),
856 }
857 }
858
859 pub fn contains_type(&self, target: &Type) -> bool {
860 if *self == *target {
861 return true;
862 }
863 match self {
864 Type::Constructor { params, .. } => params.iter().any(|p| p.contains_type(target)),
865 Type::Function {
866 params,
867 return_type,
868 ..
869 } => {
870 params.iter().any(|p| p.contains_type(target)) || return_type.contains_type(target)
871 }
872 Type::Variable(var) => {
873 if let TypeVariableState::Link(linked) = &*var.borrow() {
874 linked.contains_type(target)
875 } else {
876 false
877 }
878 }
879 Type::Forall { body, .. } => body.contains_type(target),
880 Type::Tuple(elements) => elements.iter().any(|e| e.contains_type(target)),
881 Type::Parameter(_) | Type::Never | Type::Error => false,
882 }
883 }
884
885 pub fn shallow_resolve(&self) -> Type {
889 match self {
890 Type::Variable(type_var) => {
891 let state = type_var.borrow();
892 match &*state {
893 TypeVariableState::Unbound { .. } => self.clone(),
894 TypeVariableState::Link(linked) => linked.shallow_resolve(),
895 }
896 }
897 _ => self.clone(),
898 }
899 }
900
901 pub fn resolve(&self) -> Type {
902 match self {
903 Type::Variable(type_var) => {
904 let state = type_var.borrow();
905 match &*state {
906 TypeVariableState::Unbound { .. } => self.clone(),
907 TypeVariableState::Link(linked) => {
908 let resolved = linked.resolve();
909 drop(state);
910 *type_var.borrow_mut() = TypeVariableState::Link(resolved.clone());
911 resolved
912 }
913 }
914 }
915 Type::Constructor {
916 id,
917 params,
918 underlying_ty: underlying,
919 } => Type::Constructor {
920 id: id.clone(),
921 params: params.iter().map(|p| p.resolve()).collect(),
922 underlying_ty: underlying.as_ref().map(|u| Box::new(u.resolve())),
923 },
924 Type::Function {
925 params,
926 param_mutability,
927 bounds,
928 return_type,
929 } => Type::Function {
930 params: params.iter().map(|p| p.resolve()).collect(),
931 param_mutability: param_mutability.clone(),
932 bounds: bounds
933 .iter()
934 .map(|b| Bound {
935 param_name: b.param_name.clone(),
936 generic: b.generic.resolve(),
937 ty: b.ty.resolve(),
938 })
939 .collect(),
940 return_type: Box::new(return_type.resolve()),
941 },
942 Type::Forall { body, .. } => body.resolve(),
943 Type::Tuple(elements) => Type::Tuple(elements.iter().map(|e| e.resolve()).collect()),
944 Type::Parameter(_) | Type::Error => self.clone(),
945 Type::Never => Type::Never,
946 }
947 }
948}
949
/// Coarse classification of built-in numeric types, produced by
/// `Type::numeric_family` from a constructor's unqualified name.
///
/// Complex types (`complex64`, `complex128`) have no family; `numeric_family`
/// returns `None` for them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericFamily {
    /// `int`, `int8`, `int16`, `int32`, `int64` and `rune`.
    SignedInt,
    /// `byte`, `uint`, `uint8`, `uint16`, `uint32` and `uint64`.
    UnsignedInt,
    /// `float32` and `float64`.
    Float,
}
956
/// Unqualified names classified as `NumericFamily::SignedInt` (`rune` included).
const SIGNED_INT_TYPES: &[&str] = &["int", "int8", "int16", "int32", "int64", "rune"];
/// Unqualified names classified as `NumericFamily::Float`.
const FLOAT_TYPES: &[&str] = &["float32", "float64"];
959
960impl Type {
961 pub fn underlying_numeric_type(&self) -> Option<Type> {
962 self.underlying_numeric_type_recursive(&mut HashSet::default())
963 }
964
965 pub fn has_underlying_numeric_type(&self) -> bool {
966 self.underlying_numeric_type().is_some()
967 }
968
969 fn underlying_numeric_type_recursive(&self, visited: &mut HashSet<EcoString>) -> Option<Type> {
970 match self {
971 Type::Constructor {
972 id,
973 underlying_ty: underlying,
974 ..
975 } => {
976 if self.is_numeric() {
977 return Some(self.clone());
978 }
979
980 if !visited.insert(id.clone()) {
981 return None;
982 }
983
984 underlying
985 .as_ref()?
986 .underlying_numeric_type_recursive(visited)
987 }
988 _ => None,
989 }
990 }
991
992 pub fn numeric_family(&self) -> Option<NumericFamily> {
993 let name = match self {
994 Type::Constructor { id, .. } => unqualified_name(id),
995 _ => return None,
996 };
997
998 if SIGNED_INT_TYPES.contains(&name) {
999 Some(NumericFamily::SignedInt)
1000 } else if UNSIGNED_INT_TYPES.contains(&name) {
1001 Some(NumericFamily::UnsignedInt)
1002 } else if FLOAT_TYPES.contains(&name) {
1003 Some(NumericFamily::Float)
1004 } else {
1005 None
1006 }
1007 }
1008
1009 pub fn is_numeric_compatible_with(&self, other: &Type) -> bool {
1010 let self_underlying_ty = self.underlying_numeric_type();
1011 let other_underlying_ty = other.underlying_numeric_type();
1012
1013 match (self_underlying_ty, other_underlying_ty) {
1014 (Some(s), Some(o)) => s.numeric_family() == o.numeric_family(),
1015 _ => false,
1016 }
1017 }
1018
1019 pub fn is_aliased_numeric_type(&self) -> bool {
1020 match self {
1021 Type::Constructor { underlying_ty, .. } => {
1022 underlying_ty.is_some() && !self.is_numeric()
1023 }
1024 _ => false,
1025 }
1026 }
1027}