1use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
2use std::cell::{OnceCell, RefCell};
3use std::rc::Rc;
4
5use ecow::EcoString;
6
/// Returns the last `.`-separated segment of a qualified identifier.
///
/// `"prelude.Result"` yields `"Result"`; an id without any `.` is returned
/// whole, and an id ending in `.` yields the empty string.
pub fn unqualified_name(id: &str) -> &str {
    match id.rfind('.') {
        // `.` is a single byte, so `dot + 1` is always a char boundary.
        Some(dot) => &id[dot + 1..],
        None => id,
    }
}
13
/// Mapping from type-parameter names to the types that replace them; the same
/// map type that [`substitute`] takes.
pub type SubstitutionMap = HashMap<EcoString, Type>;
16
17pub fn substitute(ty: &Type, map: &HashMap<EcoString, Type>) -> Type {
18 if map.is_empty() {
19 return ty.clone();
20 }
21 match ty {
22 Type::Parameter(name) => map.get(name).cloned().unwrap_or_else(|| ty.clone()),
23 Type::Constructor {
24 id,
25 params,
26 underlying_ty: underlying,
27 } => Type::Constructor {
28 id: id.clone(),
29 params: params.iter().map(|p| substitute(p, map)).collect(),
30 underlying_ty: underlying.as_ref().map(|u| Box::new(substitute(u, map))),
31 },
32 Type::Function {
33 params,
34 param_mutability,
35 bounds,
36 return_type,
37 } => Type::Function {
38 params: params.iter().map(|p| substitute(p, map)).collect(),
39 param_mutability: param_mutability.clone(),
40 bounds: bounds
41 .iter()
42 .map(|b| Bound {
43 param_name: b.param_name.clone(),
44 generic: substitute(&b.generic, map),
45 ty: substitute(&b.ty, map),
46 })
47 .collect(),
48 return_type: Box::new(substitute(return_type, map)),
49 },
50 Type::Variable(_) | Type::Error => ty.clone(),
51 Type::Forall { vars, body } => {
52 let has_overlap = map.keys().any(|k| vars.contains(k));
53 let substituted_body = if has_overlap {
54 let filtered_map: HashMap<EcoString, Type> = map
55 .iter()
56 .filter(|(k, _)| !vars.contains(*k))
57 .map(|(k, v)| (k.clone(), v.clone()))
58 .collect();
59 substitute(body, &filtered_map)
60 } else {
61 substitute(body, map)
62 };
63 Type::Forall {
64 vars: vars.clone(),
65 body: Box::new(substituted_body),
66 }
67 }
68 Type::Tuple(elements) => Type::Tuple(elements.iter().map(|e| substitute(e, map)).collect()),
69 Type::Never => ty.clone(),
70 }
71}
72
/// A constraint carried in the `bounds` of a [`Type::Function`].
///
/// NOTE(review): the field meanings below are inferred from their names and
/// from how `substitute`/`resolve` treat them — confirm against the checker.
#[derive(Debug, Clone, PartialEq)]
pub struct Bound {
    /// Name of the constrained type parameter (presumably); copied unchanged by
    /// `substitute` and `resolve`.
    pub param_name: EcoString,
    /// The constrained generic type; substituted and resolved with the function.
    pub generic: Type,
    /// The type the generic is required to satisfy (presumably an interface or
    /// trait-like type); substituted and resolved with the function.
    pub ty: Type,
}
79
/// State of a unification variable held by [`Type::Variable`].
#[derive(Clone)]
pub enum TypeVariableState {
    /// Not yet bound. `id` is what `PartialEq` compares; `hint`, when present, is
    /// the name shown by `Debug` and reused by `Type::remove_vars`.
    Unbound { id: i32, hint: Option<EcoString> },
    /// Bound to the given type; `Type::resolve` follows these links and rewrites
    /// them to point at the fully resolved target.
    Link(Type),
}
85
86impl TypeVariableState {
87 pub fn is_unbound(&self) -> bool {
88 matches!(self, TypeVariableState::Unbound { .. })
89 }
90}
91
92impl std::fmt::Debug for TypeVariableState {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 match self {
95 TypeVariableState::Unbound { id, hint } => match hint {
96 Some(name) => write!(f, "{}", name),
97 None => write!(f, "{}", id),
98 },
99 TypeVariableState::Link(ty) => write!(f, "{:?}", ty),
100 }
101 }
102}
103
104impl PartialEq for TypeVariableState {
105 fn eq(&self, other: &Self) -> bool {
106 match (self, other) {
107 (
108 TypeVariableState::Unbound { id: id1, .. },
109 TypeVariableState::Unbound { id: id2, .. },
110 ) => id1 == id2,
111 (TypeVariableState::Link(ty1), TypeVariableState::Link(ty2)) => ty1 == ty2,
112 _ => false,
113 }
114 }
115}
116
/// A type in the checker's representation.
///
/// The derived `Clone` copies the whole tree, except that clones of a
/// `Variable` share the same `Rc<RefCell<_>>` cell.
#[derive(Clone)]
pub enum Type {
    /// A named type applied to `params`.
    ///
    /// `id` is the qualified name (e.g. `prelude.Result`); built-ins created by
    /// `Type::nominal` use the `**nominal.` prefix. `underlying_ty` is the type
    /// this one is defined over, if any. It is ignored by `PartialEq` and
    /// omitted from `Debug` output.
    Constructor {
        id: EcoString,
        params: Vec<Type>,
        underlying_ty: Option<Box<Type>>,
    },

    /// A function type. `param_mutability` presumably holds one flag per entry
    /// of `params` (`with_receiver_placeholder` extends both together); lengths
    /// are not checked here.
    Function {
        params: Vec<Type>,
        param_mutability: Vec<bool>,
        bounds: Vec<Bound>,
        return_type: Box<Type>,
    },

    /// A unification variable; every clone shares one mutable cell.
    Variable(Rc<RefCell<TypeVariableState>>),

    /// A type quantified over the parameter names in `vars`; `substitute`
    /// treats those names as bound inside `body`.
    Forall {
        vars: Vec<EcoString>,
        body: Box<Type>,
    },

    /// A named generic parameter, replaced by `substitute`.
    Parameter(EcoString),

    /// The `Never` type (presumably the type of diverging expressions — confirm).
    Never,

    /// A tuple of element types.
    Tuple(Vec<Type>),

    /// Placeholder for a type that could not be determined (presumably after a
    /// reported error). NOTE(review): `PartialEq` treats `Error` as unequal to
    /// everything, including another `Error` — confirm that is intended.
    Error,
}
149
150impl std::fmt::Debug for Type {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 match self {
153 Type::Constructor { id, params, .. } => f
154 .debug_struct("Constructor")
155 .field("id", id)
156 .field("params", params)
157 .finish(),
158 Type::Function {
159 params,
160 param_mutability,
161 bounds,
162 return_type,
163 } => {
164 let mut s = f.debug_struct("Function");
165 s.field("params", params);
166 if param_mutability.iter().any(|m| *m) {
167 s.field("param_mutability", param_mutability);
168 }
169 s.field("bounds", bounds)
170 .field("return_type", return_type)
171 .finish()
172 }
173 Type::Variable(type_var) => f
174 .debug_tuple("Variable")
175 .field(&*type_var.borrow())
176 .finish(),
177 Type::Forall { vars, body } => f
178 .debug_struct("Forall")
179 .field("vars", vars)
180 .field("body", body)
181 .finish(),
182 Type::Parameter(name) => f.debug_tuple("Parameter").field(name).finish(),
183 Type::Never => write!(f, "Never"),
184 Type::Tuple(elements) => f.debug_tuple("Tuple").field(elements).finish(),
185 Type::Error => write!(f, "Error"),
186 }
187 }
188}
189
190impl PartialEq for Type {
191 fn eq(&self, other: &Self) -> bool {
192 match (self, other) {
193 (
194 Type::Constructor {
195 id: id1,
196 params: params1,
197 ..
198 },
199 Type::Constructor {
200 id: id2,
201 params: params2,
202 ..
203 },
204 ) => id1 == id2 && params1 == params2,
205 (
206 Type::Function {
207 params: p1,
208 param_mutability: m1,
209 bounds: b1,
210 return_type: r1,
211 },
212 Type::Function {
213 params: p2,
214 param_mutability: m2,
215 bounds: b2,
216 return_type: r2,
217 },
218 ) => p1 == p2 && m1 == m2 && b1 == b2 && r1 == r2,
219 (Type::Variable(v1), Type::Variable(v2)) => {
220 Rc::ptr_eq(v1, v2) || *v1.borrow() == *v2.borrow()
221 }
222 (
223 Type::Forall {
224 vars: vars1,
225 body: body1,
226 },
227 Type::Forall {
228 vars: vars2,
229 body: body2,
230 },
231 ) => vars1 == vars2 && body1 == body2,
232 (Type::Parameter(name1), Type::Parameter(name2)) => name1 == name2,
233 (Type::Never, Type::Never) => true,
234 (Type::Tuple(elems1), Type::Tuple(elems2)) => elems1 == elems2,
235 _ => false,
236 }
237 }
238}
239
// Per-thread caches for the built-in nominal types returned by `Type::int()`,
// `Type::string()` and the other constructors below. Each cell is filled on
// first use and cloned afterwards. They are thread-local because `Type` holds
// `Rc`, so it is not `Sync` and cannot live in an ordinary `static`.
thread_local! {
    static INTERNED_INT: OnceCell<Type> = const { OnceCell::new() };
    static INTERNED_STRING: OnceCell<Type> = const { OnceCell::new() };
    static INTERNED_BOOL: OnceCell<Type> = const { OnceCell::new() };
    static INTERNED_UNIT: OnceCell<Type> = const { OnceCell::new() };
    static INTERNED_FLOAT64: OnceCell<Type> = const { OnceCell::new() };
    static INTERNED_RUNE: OnceCell<Type> = const { OnceCell::new() };
    static INTERNED_BYTE: OnceCell<Type> = const { OnceCell::new() };
}
249
250impl Type {
251 pub fn int() -> Type {
252 INTERNED_INT.with(|cell| cell.get_or_init(|| Self::nominal("int")).clone())
253 }
254
255 pub fn string() -> Type {
256 INTERNED_STRING.with(|cell| cell.get_or_init(|| Self::nominal("string")).clone())
257 }
258
259 pub fn bool() -> Type {
260 INTERNED_BOOL.with(|cell| cell.get_or_init(|| Self::nominal("bool")).clone())
261 }
262
263 pub fn unit() -> Type {
264 INTERNED_UNIT.with(|cell| cell.get_or_init(|| Self::nominal("Unit")).clone())
265 }
266
267 pub fn float64() -> Type {
268 INTERNED_FLOAT64.with(|cell| cell.get_or_init(|| Self::nominal("float64")).clone())
269 }
270
271 pub fn rune() -> Type {
272 INTERNED_RUNE.with(|cell| cell.get_or_init(|| Self::nominal("rune")).clone())
273 }
274
275 pub fn byte() -> Type {
276 INTERNED_BYTE.with(|cell| cell.get_or_init(|| Self::nominal("byte")).clone())
277 }
278}
279
280impl Type {
281 const UNINFERRED_ID: i32 = -1;
282 const IGNORED_ID: i32 = -333;
283
284 pub fn nominal(name: &str) -> Self {
285 Self::Constructor {
286 id: format!("**nominal.{}", name).into(),
287 params: vec![],
288 underlying_ty: None,
289 }
290 }
291
292 pub fn uninferred() -> Self {
293 Self::Variable(Rc::new(RefCell::new(TypeVariableState::Unbound {
294 id: Self::UNINFERRED_ID,
295 hint: None,
296 })))
297 }
298
299 pub fn ignored() -> Self {
300 Self::Variable(Rc::new(RefCell::new(TypeVariableState::Unbound {
301 id: Self::IGNORED_ID,
302 hint: None,
303 })))
304 }
305
306 pub fn get_type_params(&self) -> Option<&[Type]> {
307 match self {
308 Type::Constructor { params, .. } => Some(params),
309 _ => None,
310 }
311 }
312}
313
/// Unqualified names accepted by `Type::is_numeric`: the integer, float and
/// complex types, plus `byte` and `rune`.
const ARITHMETIC_TYPES: &[&str] = &[
    "byte",
    "complex128",
    "complex64",
    "float32",
    "float64",
    "int",
    "int16",
    "int32",
    "int64",
    "int8",
    "rune",
    "uint",
    "uint16",
    "uint32",
    "uint64",
    "uint8",
];
332
/// Unqualified names accepted by `Type::is_ordered`: `ARITHMETIC_TYPES` without
/// the complex types.
/// NOTE(review): `string` is not listed — confirm string ordering is handled
/// elsewhere.
const ORDERED_TYPES: &[&str] = &[
    "byte", "float32", "float64", "int", "int16", "int32", "int64", "int8", "rune", "uint",
    "uint16", "uint32", "uint64", "uint8",
];
337
/// Unqualified names of the unsigned integer types, used by
/// `Type::is_unsigned_int` and `Type::numeric_family`.
const UNSIGNED_INT_TYPES: &[&str] = &["byte", "uint", "uint8", "uint16", "uint32", "uint64"];
339
340impl Type {
341 pub fn get_function_ret(&self) -> Option<&Type> {
342 match self {
343 Type::Function { return_type, .. } => Some(return_type),
344 _ => None,
345 }
346 }
347
348 pub fn has_name(&self, name: &str) -> bool {
349 match self {
350 Type::Constructor { id, .. } => unqualified_name(id) == name,
351 _ => false,
352 }
353 }
354
355 pub fn get_qualified_id(&self) -> Option<&str> {
356 match self {
357 Type::Constructor { id, .. } => Some(id.as_str()),
358 _ => None,
359 }
360 }
361
362 pub fn get_underlying(&self) -> Option<&Type> {
363 match self {
364 Type::Constructor {
365 underlying_ty: underlying,
366 ..
367 } => underlying.as_deref(),
368 _ => None,
369 }
370 }
371
372 pub fn is_result(&self) -> bool {
373 self.has_qualified_id("prelude.Result")
374 }
375
376 pub fn is_option(&self) -> bool {
377 self.has_qualified_id("prelude.Option")
378 }
379
380 pub fn is_partial(&self) -> bool {
381 self.has_qualified_id("prelude.Partial")
382 }
383
384 fn has_qualified_id(&self, qualified_id: &str) -> bool {
385 matches!(self, Type::Constructor { id, .. } if id.as_str() == qualified_id)
386 }
387
388 pub fn is_unit(&self) -> bool {
389 matches!(self.resolve(), Type::Constructor { ref id, .. } if id.as_ref() == "**nominal.Unit")
390 }
391
392 pub fn tuple_arity(&self) -> Option<usize> {
393 match self {
394 Type::Tuple(elements) => Some(elements.len()),
395 _ => None,
396 }
397 }
398
399 pub fn is_tuple(&self) -> bool {
400 matches!(self, Type::Tuple(_))
401 }
402
403 pub fn is_ref(&self) -> bool {
404 self.has_name("Ref")
405 }
406
407 pub fn is_receiver_placeholder(&self) -> bool {
408 self.has_name("__receiver__")
409 }
410
411 pub fn is_unknown(&self) -> bool {
412 self.has_name("Unknown")
413 }
414
415 pub fn is_receiver(&self) -> bool {
416 self.has_name("Receiver")
417 }
418
419 pub fn is_ignored(&self) -> bool {
420 match self {
421 Type::Variable(var) => {
422 matches!(&*var.borrow(), TypeVariableState::Unbound { id, .. } if *id == Self::IGNORED_ID)
423 }
424 _ => false,
425 }
426 }
427
428 pub fn is_variadic(&self) -> Option<Type> {
429 let args = self.get_function_params()?;
430 let last = args.last()?;
431
432 if last.get_name()? == "VarArgs" {
433 return last.inner();
434 }
435
436 None
437 }
438
439 pub fn is_string(&self) -> bool {
440 self.has_name("string")
441 }
442
443 pub fn is_slice_of(&self, element_name: &str) -> bool {
444 match self {
445 Type::Constructor { id, params, .. } => {
446 if unqualified_name(id) != "Slice" || params.len() != 1 {
447 return false;
448 }
449 params[0].resolve().has_name(element_name)
450 }
451 _ => false,
452 }
453 }
454
455 pub fn is_byte_slice(&self) -> bool {
456 self.is_slice_of("byte") || self.is_slice_of("uint8")
457 }
458
459 pub fn is_rune_slice(&self) -> bool {
460 self.is_slice_of("rune")
461 }
462
463 pub fn is_byte_or_rune_slice(&self) -> bool {
464 self.is_byte_slice() || self.is_rune_slice()
465 }
466
467 pub fn has_byte_or_rune_slice_underlying(&self) -> bool {
468 if self.is_byte_or_rune_slice() {
469 return true;
470 }
471 match self {
472 Type::Constructor { underlying_ty, .. } => underlying_ty
473 .as_deref()
474 .is_some_and(|u| u.has_byte_or_rune_slice_underlying()),
475 _ => false,
476 }
477 }
478
479 pub fn is_boolean(&self) -> bool {
480 self.has_name("bool")
481 }
482
483 pub fn is_rune(&self) -> bool {
484 self.has_name("rune")
485 }
486
487 pub fn is_float64(&self) -> bool {
488 self.has_name("float64")
489 }
490
491 pub fn is_float32(&self) -> bool {
492 self.has_name("float32")
493 }
494
495 pub fn is_float(&self) -> bool {
496 self.is_float64() || self.is_float32()
497 }
498
499 pub fn is_variable(&self) -> bool {
500 matches!(self, Type::Variable(_))
501 }
502
503 pub fn is_unbound_variable(&self) -> bool {
504 matches!(self, Type::Variable(cell) if cell.borrow().is_unbound())
505 }
506
507 pub fn is_numeric(&self) -> bool {
508 match self {
509 Type::Constructor { id, .. } => ARITHMETIC_TYPES.contains(&unqualified_name(id)),
510 _ => false,
511 }
512 }
513
514 pub fn is_ordered(&self) -> bool {
515 match self {
516 Type::Constructor { id, .. } => ORDERED_TYPES.contains(&unqualified_name(id)),
517 _ => false,
518 }
519 }
520
521 pub fn is_complex(&self) -> bool {
522 match self {
523 Type::Constructor { id, .. } => {
524 matches!(unqualified_name(id), "complex128" | "complex64")
525 }
526 _ => false,
527 }
528 }
529
530 pub fn is_unsigned_int(&self) -> bool {
531 match self {
532 Type::Constructor { id, .. } => UNSIGNED_INT_TYPES.contains(&unqualified_name(id)),
533 _ => false,
534 }
535 }
536
537 pub fn is_never(&self) -> bool {
538 matches!(self.shallow_resolve(), Type::Never)
539 }
540
541 pub fn is_error(&self) -> bool {
542 matches!(self.shallow_resolve(), Type::Error)
543 }
544
545 pub fn has_unbound_variables(&self) -> bool {
546 match self {
547 Type::Variable(type_var) => match &*type_var.borrow() {
548 TypeVariableState::Unbound { hint, .. } => hint.is_some(),
549 TypeVariableState::Link(ty) => ty.has_unbound_variables(),
550 },
551 Type::Constructor { params, .. } => params.iter().any(|p| p.has_unbound_variables()),
552 Type::Function {
553 params,
554 return_type,
555 ..
556 } => {
557 params.iter().any(|p| p.has_unbound_variables())
558 || return_type.has_unbound_variables()
559 }
560 Type::Forall { body, .. } => body.has_unbound_variables(),
561 Type::Tuple(elements) => elements.iter().any(|e| e.has_unbound_variables()),
562 Type::Parameter(_) | Type::Never | Type::Error => false,
563 }
564 }
565
566 pub fn remove_found_type_names(&self, names: &mut HashSet<EcoString>) {
567 if names.is_empty() {
568 return;
569 }
570
571 match self {
572 Type::Constructor { id, params, .. } => {
573 names.remove(unqualified_name(id));
574 for param in params {
575 param.remove_found_type_names(names);
576 }
577 }
578 Type::Function {
579 params,
580 return_type,
581 bounds,
582 ..
583 } => {
584 for param in params {
585 param.remove_found_type_names(names);
586 }
587 return_type.remove_found_type_names(names);
588 for bound in bounds {
589 bound.generic.remove_found_type_names(names);
590 bound.ty.remove_found_type_names(names);
591 }
592 }
593 Type::Forall { body, .. } => {
594 body.remove_found_type_names(names);
595 }
596 Type::Variable(type_var) => {
597 if let TypeVariableState::Link(ty) = &*type_var.borrow() {
598 ty.remove_found_type_names(names);
599 }
600 }
601 Type::Parameter(name) => {
602 names.remove(name);
603 }
604 Type::Tuple(elements) => {
605 for element in elements {
606 element.remove_found_type_names(names);
607 }
608 }
609 Type::Never | Type::Error => {}
610 }
611 }
612}
613
614impl Type {
615 pub fn get_name(&self) -> Option<&str> {
616 match self {
617 Type::Constructor { id, params, .. } => {
618 let name = unqualified_name(id);
619 if name == "Ref" {
620 return params.first().and_then(|inner| inner.get_name());
621 }
622 if let Some(module_path) = id.strip_prefix("@import/") {
623 let path = module_path.strip_prefix("go:").unwrap_or(module_path);
624 return path.rsplit('/').next();
625 }
626 Some(name)
627 }
628 _ => None,
629 }
630 }
631
632 pub fn wraps(&self, name: &str, inner: &Type) -> bool {
633 self.get_name().is_some_and(|n| n == name)
634 && self
635 .get_type_params()
636 .and_then(|p| p.first())
637 .is_some_and(|first| *first == *inner)
638 }
639
640 pub fn get_function_params(&self) -> Option<&[Type]> {
641 match self {
642 Type::Function { params, .. } => Some(params),
643 Type::Constructor {
644 underlying_ty: Some(inner),
645 ..
646 } => inner.get_function_params(),
647 _ => None,
648 }
649 }
650
651 pub fn param_count(&self) -> usize {
652 match self {
653 Type::Function { params, .. } => params.len(),
654 _ => 0,
655 }
656 }
657
658 pub fn get_param_mutability(&self) -> &[bool] {
659 match self {
660 Type::Function {
661 param_mutability, ..
662 } => param_mutability,
663 _ => &[],
664 }
665 }
666
667 pub fn with_replaced_first_param(&self, new_first: &Type) -> Type {
668 match self {
669 Type::Function {
670 params,
671 param_mutability,
672 bounds,
673 return_type,
674 } => {
675 if params.is_empty() {
676 return self.clone();
677 }
678 let mut new_params = params.clone();
679 new_params[0] = new_first.clone();
680 Type::Function {
681 params: new_params,
682 param_mutability: param_mutability.clone(),
683 bounds: bounds.clone(),
684 return_type: return_type.clone(),
685 }
686 }
687 Type::Forall { vars, body } => Type::Forall {
688 vars: vars.clone(),
689 body: Box::new(body.with_replaced_first_param(new_first)),
690 },
691 _ => self.clone(),
692 }
693 }
694
695 pub fn get_bounds(&self) -> &[Bound] {
696 match self {
697 Type::Function { bounds, .. } => bounds,
698 Type::Forall { body, .. } => body.get_bounds(),
699 _ => &[],
700 }
701 }
702
703 pub fn get_qualified_name(&self) -> EcoString {
704 match self.strip_refs() {
705 Type::Constructor { id, .. } => id,
706 _ => panic!("called get_qualified_name on {:#?}", self),
707 }
708 }
709
710 pub fn inner(&self) -> Option<Type> {
711 self.get_type_params()
712 .and_then(|args| args.first().cloned())
713 }
714
715 pub fn ok_type(&self) -> Type {
716 debug_assert!(
717 self.is_result() || self.is_option() || self.is_partial(),
718 "ok_type called on non-Result/Option/Partial type"
719 );
720 self.inner()
721 .expect("Result/Option/Partial should have inner type")
722 }
723
724 pub fn err_type(&self) -> Type {
725 debug_assert!(
726 self.is_result() || self.is_partial(),
727 "err_type called on non-Result/Partial type"
728 );
729 self.get_type_params()
730 .and_then(|args| args.get(1).cloned())
731 .expect("Result/Partial should have error type")
732 }
733}
734
735impl Type {
736 pub fn unwrap_forall(&self) -> &Type {
737 match self {
738 Type::Forall { body, .. } => body.as_ref(),
739 other => other,
740 }
741 }
742
743 pub fn strip_refs(&self) -> Type {
744 if self.is_ref() {
745 return self.inner().expect("ref type must have inner").strip_refs();
746 }
747
748 self.clone()
749 }
750
751 pub fn with_receiver_placeholder(self) -> Type {
752 match self {
753 Type::Function {
754 params,
755 param_mutability,
756 bounds,
757 return_type,
758 } => {
759 let mut new_params = vec![Type::nominal("__receiver__")];
760 new_params.extend(params);
761
762 let mut new_mutability = vec![false];
763 new_mutability.extend(param_mutability);
764
765 Type::Function {
766 params: new_params,
767 param_mutability: new_mutability,
768 bounds,
769 return_type,
770 }
771 }
772 _ => unreachable!(
773 "with_receiver_placeholder called on non-function type: {:?}",
774 self
775 ),
776 }
777 }
778
779 pub fn remove_vars(types: &[&Type]) -> (Vec<Type>, Vec<EcoString>) {
780 let mut vars = HashMap::default();
781 let types = types
782 .iter()
783 .map(|v| Self::remove_vars_impl(v, &mut vars))
784 .collect();
785
786 (types, vars.into_values().collect())
787 }
788
789 fn remove_vars_impl(ty: &Type, vars: &mut HashMap<i32, EcoString>) -> Type {
790 match ty {
791 Type::Constructor {
792 id: name,
793 params: args,
794 underlying_ty: underlying,
795 } => Type::Constructor {
796 id: name.clone(),
797 params: args
798 .iter()
799 .map(|a| Self::remove_vars_impl(a, vars))
800 .collect(),
801 underlying_ty: underlying
802 .as_ref()
803 .map(|u| Box::new(Self::remove_vars_impl(u, vars))),
804 },
805
806 Type::Function {
807 params: args,
808 param_mutability,
809 bounds,
810 return_type,
811 } => Type::Function {
812 params: args
813 .iter()
814 .map(|a| Self::remove_vars_impl(a, vars))
815 .collect(),
816 param_mutability: param_mutability.clone(),
817 bounds: bounds
818 .iter()
819 .map(|b| Bound {
820 param_name: b.param_name.clone(),
821 generic: Self::remove_vars_impl(&b.generic, vars),
822 ty: Self::remove_vars_impl(&b.ty, vars),
823 })
824 .collect(),
825 return_type: Self::remove_vars_impl(return_type, vars).into(),
826 },
827
828 Type::Variable(type_var) => match &*type_var.borrow() {
829 TypeVariableState::Unbound { id, hint } => match vars.get(id) {
830 Some(g) => Self::nominal(g),
831 None => {
832 let name: EcoString = hint.clone().unwrap_or_else(|| {
833 char::from_digit(
834 (vars.len() + 10)
835 .try_into()
836 .expect("type var count fits in u32"),
837 16,
838 )
839 .expect("type var index is valid hex digit")
840 .to_uppercase()
841 .to_string()
842 .into()
843 });
844
845 vars.insert(*id, name.clone());
846 Self::nominal(&name)
847 }
848 },
849 TypeVariableState::Link(ty) => Self::remove_vars_impl(ty, vars),
850 },
851
852 Type::Forall { body, .. } => Self::remove_vars_impl(body, vars),
853 Type::Tuple(elements) => Type::Tuple(
854 elements
855 .iter()
856 .map(|e| Self::remove_vars_impl(e, vars))
857 .collect(),
858 ),
859 Type::Parameter(name) => Type::Parameter(name.clone()),
860 Type::Never | Type::Error => ty.clone(),
861 }
862 }
863
864 pub fn contains_type(&self, target: &Type) -> bool {
865 if *self == *target {
866 return true;
867 }
868 match self {
869 Type::Constructor { params, .. } => params.iter().any(|p| p.contains_type(target)),
870 Type::Function {
871 params,
872 return_type,
873 ..
874 } => {
875 params.iter().any(|p| p.contains_type(target)) || return_type.contains_type(target)
876 }
877 Type::Variable(var) => {
878 if let TypeVariableState::Link(linked) = &*var.borrow() {
879 linked.contains_type(target)
880 } else {
881 false
882 }
883 }
884 Type::Forall { body, .. } => body.contains_type(target),
885 Type::Tuple(elements) => elements.iter().any(|e| e.contains_type(target)),
886 Type::Parameter(_) | Type::Never | Type::Error => false,
887 }
888 }
889
890 pub fn shallow_resolve(&self) -> Type {
894 match self {
895 Type::Variable(type_var) => {
896 let state = type_var.borrow();
897 match &*state {
898 TypeVariableState::Unbound { .. } => self.clone(),
899 TypeVariableState::Link(linked) => linked.shallow_resolve(),
900 }
901 }
902 _ => self.clone(),
903 }
904 }
905
906 pub fn resolve(&self) -> Type {
907 match self {
908 Type::Variable(type_var) => {
909 let state = type_var.borrow();
910 match &*state {
911 TypeVariableState::Unbound { .. } => self.clone(),
912 TypeVariableState::Link(linked) => {
913 let resolved = linked.resolve();
914 drop(state);
915 *type_var.borrow_mut() = TypeVariableState::Link(resolved.clone());
916 resolved
917 }
918 }
919 }
920 Type::Constructor {
921 id,
922 params,
923 underlying_ty: underlying,
924 } => Type::Constructor {
925 id: id.clone(),
926 params: params.iter().map(|p| p.resolve()).collect(),
927 underlying_ty: underlying.as_ref().map(|u| Box::new(u.resolve())),
928 },
929 Type::Function {
930 params,
931 param_mutability,
932 bounds,
933 return_type,
934 } => Type::Function {
935 params: params.iter().map(|p| p.resolve()).collect(),
936 param_mutability: param_mutability.clone(),
937 bounds: bounds
938 .iter()
939 .map(|b| Bound {
940 param_name: b.param_name.clone(),
941 generic: b.generic.resolve(),
942 ty: b.ty.resolve(),
943 })
944 .collect(),
945 return_type: Box::new(return_type.resolve()),
946 },
947 Type::Forall { body, .. } => body.resolve(),
948 Type::Tuple(elements) => Type::Tuple(elements.iter().map(|e| e.resolve()).collect()),
949 Type::Parameter(_) | Type::Error => self.clone(),
950 Type::Never => Type::Never,
951 }
952 }
953}
954
/// Coarse grouping of the built-in numeric types, as returned by
/// `Type::numeric_family`. The complex types belong to no family.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericFamily {
    /// `int`, `int8`..`int64` and `rune` (see `SIGNED_INT_TYPES`).
    SignedInt,
    /// `uint`, `uint8`..`uint64` and `byte` (see `UNSIGNED_INT_TYPES`).
    UnsignedInt,
    /// `float32` and `float64` (see `FLOAT_TYPES`).
    Float,
}
961
/// Unqualified names that `Type::numeric_family` classifies as
/// `NumericFamily::SignedInt` (`rune` included).
const SIGNED_INT_TYPES: &[&str] = &["int", "int8", "int16", "int32", "int64", "rune"];
/// Unqualified names that `Type::numeric_family` classifies as `NumericFamily::Float`.
const FLOAT_TYPES: &[&str] = &["float32", "float64"];
964
965impl Type {
966 pub fn underlying_numeric_type(&self) -> Option<Type> {
967 self.underlying_numeric_type_recursive(&mut HashSet::default())
968 }
969
970 pub fn has_underlying_numeric_type(&self) -> bool {
971 self.underlying_numeric_type().is_some()
972 }
973
974 fn underlying_numeric_type_recursive(&self, visited: &mut HashSet<EcoString>) -> Option<Type> {
975 match self {
976 Type::Constructor {
977 id,
978 underlying_ty: underlying,
979 ..
980 } => {
981 if self.is_numeric() {
982 return Some(self.clone());
983 }
984
985 if !visited.insert(id.clone()) {
986 return None;
987 }
988
989 underlying
990 .as_ref()?
991 .underlying_numeric_type_recursive(visited)
992 }
993 _ => None,
994 }
995 }
996
997 pub fn numeric_family(&self) -> Option<NumericFamily> {
998 let name = match self {
999 Type::Constructor { id, .. } => unqualified_name(id),
1000 _ => return None,
1001 };
1002
1003 if SIGNED_INT_TYPES.contains(&name) {
1004 Some(NumericFamily::SignedInt)
1005 } else if UNSIGNED_INT_TYPES.contains(&name) {
1006 Some(NumericFamily::UnsignedInt)
1007 } else if FLOAT_TYPES.contains(&name) {
1008 Some(NumericFamily::Float)
1009 } else {
1010 None
1011 }
1012 }
1013
1014 pub fn is_numeric_compatible_with(&self, other: &Type) -> bool {
1015 let self_underlying_ty = self.underlying_numeric_type();
1016 let other_underlying_ty = other.underlying_numeric_type();
1017
1018 match (self_underlying_ty, other_underlying_ty) {
1019 (Some(s), Some(o)) => s.numeric_family() == o.numeric_family(),
1020 _ => false,
1021 }
1022 }
1023
1024 pub fn is_aliased_numeric_type(&self) -> bool {
1025 match self {
1026 Type::Constructor { underlying_ty, .. } => {
1027 underlying_ty.is_some() && !self.is_numeric() && self.has_underlying_numeric_type()
1028 }
1029 _ => false,
1030 }
1031 }
1032}