1use rustc_hash::FxHashMap;
4use slotmap::SlotMap;
5use std::collections::BTreeMap;
6use std::fmt;
7
slotmap::new_key_type! {
    /// Identifier of a type variable; keys are allocated by `TypeVarContext::fresh`.
    pub struct TypeVarId;
}
12
/// A type of the language's type system, as manipulated by inference.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Type {
    /// Displayed as `int`.
    Int,
    /// Displayed as `float`.
    Float,
    /// Displayed as `number`. Note: `can_match` does not relate it to `Int` or `Float`.
    Number,
    /// Displayed as `string`.
    String,
    /// Displayed as `bool`.
    Bool,
    /// Displayed as `symbol`.
    Symbol,
    /// The `none` type; `is_nullable` reports it (alone or inside a union).
    None,
    /// Displayed as `markdown`.
    Markdown,
    /// Homogeneous array, displayed as `[elem]`.
    Array(Box<Type>),
    /// Fixed-length tuple, displayed as `(a, b, ...)`; matching requires equal lengths.
    Tuple(Vec<Type>),
    /// Dictionary from key type to value type, displayed as `{key: value}`.
    Dict(Box<Type>, Box<Type>),
    /// Function from parameter types to a return type, displayed as `(params) -> ret`.
    Function(Vec<Type>, Box<Type>),
    /// Union of member types; normally built through `Type::union` so it stays normalized.
    Union(Vec<Type>),
    /// Record: known fields (sorted by name) plus the rest of the row.
    /// `RowEmpty` as the rest means a closed record; any other type is displayed as
    /// an open tail (`{a: int | 'r}`).
    Record(BTreeMap<String, Type>, Box<Type>),
    /// The empty row that closes a record.
    RowEmpty,
    /// A type variable.
    Var(TypeVarId),
    /// The bottom type: `can_match` accepts it against anything, and `subtract` yields it
    /// when no union member remains.
    Never,
}
67
68impl Type {
69 pub fn function(params: Vec<Type>, ret: Type) -> Self {
71 Type::Function(params, Box::new(ret))
72 }
73
74 pub fn array(elem: Type) -> Self {
76 Type::Array(Box::new(elem))
77 }
78
79 pub fn tuple(elems: Vec<Type>) -> Self {
81 Type::Tuple(elems)
82 }
83
84 pub fn dict(key: Type, value: Type) -> Self {
86 Type::Dict(Box::new(key), Box::new(value))
87 }
88
89 pub fn record(fields: BTreeMap<String, Type>, rest: Type) -> Self {
91 Type::Record(fields, Box::new(rest))
92 }
93
94 pub fn union(types: Vec<Type>) -> Self {
97 let mut normalized = Vec::with_capacity(types.len());
98 for ty in types {
99 match ty {
100 Type::Union(inner) => normalized.extend(inner),
102 _ => normalized.push(ty),
103 }
104 }
105
106 if normalized.len() <= 4 {
108 let mut i = 0;
110 while i < normalized.len() {
111 if normalized[..i].iter().any(|t| t == &normalized[i]) {
112 normalized.swap_remove(i);
113 } else {
114 i += 1;
115 }
116 }
117 normalized.sort_by_key(|t| t.discriminant());
118 } else {
119 let mut seen = rustc_hash::FxHashSet::default();
121 normalized.retain(|t| seen.insert(t.clone()));
122 normalized.sort_by_key(|t| t.discriminant());
123 }
124
125 if normalized.len() == 1 {
127 normalized.into_iter().next().unwrap()
128 } else {
129 Type::Union(normalized)
130 }
131 }
132
133 pub fn subtract(&self, exclude: &Type) -> Type {
140 match self {
141 Type::Union(members) => {
142 let remaining: Vec<Type> = members
143 .iter()
144 .filter(|t| std::mem::discriminant(*t) != std::mem::discriminant(exclude))
145 .cloned()
146 .collect();
147 if remaining.is_empty() {
148 Type::Never
149 } else {
150 Type::union(remaining)
151 }
152 }
153 _ => self.clone(),
154 }
155 }
156
157 pub fn is_nullable(&self) -> bool {
159 match self {
160 Type::None => true,
161 Type::Union(members) => members.iter().any(|m| matches!(m, Type::None)),
162 _ => false,
163 }
164 }
165
166 pub fn is_never(&self) -> bool {
168 matches!(self, Type::Never)
169 }
170
171 fn discriminant(&self) -> u8 {
174 match self {
175 Type::Int => 0,
176 Type::Float => 1,
177 Type::Number => 2,
178 Type::String => 3,
179 Type::Bool => 4,
180 Type::Symbol => 5,
181 Type::None => 6,
182 Type::Markdown => 7,
183 Type::Array(_) => 8,
184 Type::Tuple(_) => 9,
185 Type::Dict(_, _) => 10,
186 Type::Function(_, _) => 11,
187 Type::Union(_) => 12,
188 Type::Record(_, _) => 13,
189 Type::RowEmpty => 14,
190 Type::Var(_) => 15,
191 Type::Never => 16,
192 }
193 }
194
195 pub fn is_var(&self) -> bool {
197 matches!(self, Type::Var(_))
198 }
199
200 pub fn is_concrete(&self) -> bool {
202 self.free_vars().is_empty()
203 }
204
205 pub fn is_union(&self) -> bool {
207 matches!(self, Type::Union(_))
208 }
209
210 pub fn as_var(&self) -> Option<TypeVarId> {
212 match self {
213 Type::Var(id) => Some(*id),
214 _ => None,
215 }
216 }
217
218 pub fn apply_subst(&self, subst: &Substitution) -> Type {
220 if subst.is_empty() {
221 return self.clone();
222 }
223 match self {
224 Type::Var(id) => subst.lookup(*id).map_or_else(|| self.clone(), |t| t.apply_subst(subst)),
225 Type::Array(elem) => Type::Array(Box::new(elem.apply_subst(subst))),
226 Type::Tuple(elems) => Type::Tuple(elems.iter().map(|e| e.apply_subst(subst)).collect()),
227 Type::Dict(key, value) => Type::Dict(Box::new(key.apply_subst(subst)), Box::new(value.apply_subst(subst))),
228 Type::Function(params, ret) => {
229 let new_params = params.iter().map(|p| p.apply_subst(subst)).collect();
230 Type::Function(new_params, Box::new(ret.apply_subst(subst)))
231 }
232 Type::Union(types) => {
233 let new_types = types.iter().map(|t| t.apply_subst(subst)).collect();
234 Type::union(new_types)
235 }
236 Type::Record(fields, rest) => {
237 let new_fields = fields.iter().map(|(k, v)| (k.clone(), v.apply_subst(subst))).collect();
238 Type::Record(new_fields, Box::new(rest.apply_subst(subst)))
239 }
240 _ => self.clone(),
242 }
243 }
244
245 pub fn free_vars(&self) -> Vec<TypeVarId> {
247 match self {
248 Type::Var(id) => vec![*id],
249 Type::Array(elem) => elem.free_vars(),
250 Type::Tuple(elems) => elems.iter().flat_map(|e| e.free_vars()).collect(),
251 Type::Dict(key, value) => {
252 let mut vars = key.free_vars();
253 vars.extend(value.free_vars());
254 vars
255 }
256 Type::Function(params, ret) => {
257 let mut vars: Vec<TypeVarId> = params.iter().flat_map(|p| p.free_vars()).collect();
258 vars.extend(ret.free_vars());
259 vars
260 }
261 Type::Union(types) => types.iter().flat_map(|t| t.free_vars()).collect(),
262 Type::Record(fields, rest) => {
263 let mut vars: Vec<TypeVarId> = fields.values().flat_map(|v| v.free_vars()).collect();
264 vars.extend(rest.free_vars());
265 vars
266 }
267 _ => Vec::new(),
268 }
269 }
270
271 pub fn can_match(&self, other: &Type) -> bool {
277 match (self, other) {
278 (Type::Never, _) | (_, Type::Never) => true,
280
281 (Type::Var(_), _) | (_, Type::Var(_)) => true,
283
284 (Type::Union(types), other) => types.iter().any(|t| t.can_match(other)),
286 (other, Type::Union(types)) => types.iter().any(|t| other.can_match(t)),
287
288 (Type::Int, Type::Int)
290 | (Type::Float, Type::Float)
291 | (Type::Number, Type::Number)
292 | (Type::String, Type::String)
293 | (Type::Bool, Type::Bool)
294 | (Type::Symbol, Type::Symbol)
295 | (Type::None, Type::None)
296 | (Type::Markdown, Type::Markdown) => true,
297
298 (Type::Array(elem1), Type::Array(elem2)) => elem1.can_match(elem2),
300
301 (Type::Tuple(elems1), Type::Tuple(elems2)) => {
303 elems1.len() == elems2.len() && elems1.iter().zip(elems2.iter()).all(|(e1, e2)| e1.can_match(e2))
304 }
305
306 (Type::Tuple(_), Type::Array(_)) | (Type::Array(_), Type::Tuple(_)) => true,
308
309 (Type::Dict(k1, v1), Type::Dict(k2, v2)) => k1.can_match(k2) && v1.can_match(v2),
311
312 (Type::Function(params1, ret1), Type::Function(params2, ret2)) => {
314 params1.len() == params2.len()
315 && params1.iter().zip(params2.iter()).all(|(p1, p2)| p1.can_match(p2))
316 && ret1.can_match(ret2)
317 }
318
319 (Type::Record(f1, r1), Type::Record(f2, r2)) => {
321 for (k, v1) in f1 {
323 if let Some(v2) = f2.get(k)
324 && !v1.can_match(v2)
325 {
326 return false;
327 }
328 }
329 r1.can_match(r2)
330 }
331
332 (Type::Record(_, _), Type::Dict(_, _)) | (Type::Dict(_, _), Type::Record(_, _)) => true,
334
335 (Type::RowEmpty, Type::RowEmpty) => true,
337
338 _ => false,
340 }
341 }
342
343 pub fn can_branch_unify_with(&self, other: &Type) -> bool {
356 match (self, other) {
357 (Type::Never, _) | (_, Type::Never) => true,
359
360 (Type::Var(_), Type::Var(_)) => true,
362
363 (Type::Var(_), _) | (_, Type::Var(_)) => false,
365
366 (Type::Union(types), other) => types.iter().any(|t| t.can_branch_unify_with(other)),
368 (other, Type::Union(types)) => types.iter().any(|t| other.can_branch_unify_with(t)),
369
370 (Type::Int, Type::Int)
372 | (Type::Float, Type::Float)
373 | (Type::Number, Type::Number)
374 | (Type::String, Type::String)
375 | (Type::Bool, Type::Bool)
376 | (Type::Symbol, Type::Symbol)
377 | (Type::None, Type::None)
378 | (Type::Markdown, Type::Markdown) => true,
379
380 (Type::Array(elem1), Type::Array(elem2)) => elem1.can_branch_unify_with(elem2),
382
383 (Type::Tuple(elems1), Type::Tuple(elems2)) => {
385 elems1.len() == elems2.len()
386 && elems1
387 .iter()
388 .zip(elems2.iter())
389 .all(|(e1, e2)| e1.can_branch_unify_with(e2))
390 }
391
392 (Type::Dict(k1, v1), Type::Dict(k2, v2)) => k1.can_branch_unify_with(k2) && v1.can_branch_unify_with(v2),
394
395 (Type::Function(params1, ret1), Type::Function(params2, ret2)) => {
397 params1.len() == params2.len()
398 && params1
399 .iter()
400 .zip(params2.iter())
401 .all(|(p1, p2)| p1.can_branch_unify_with(p2))
402 && ret1.can_branch_unify_with(ret2)
403 }
404
405 _ => false,
407 }
408 }
409
410 pub fn match_score(&self, other: &Type) -> Option<u32> {
419 if !self.can_match(other) {
420 return None;
421 }
422
423 match (self, other) {
424 (Type::Never, _) | (_, Type::Never) => Some(1),
426
427 (Type::Int, Type::Int)
429 | (Type::Float, Type::Float)
430 | (Type::Number, Type::Number)
431 | (Type::String, Type::String)
432 | (Type::Bool, Type::Bool)
433 | (Type::Symbol, Type::Symbol)
434 | (Type::None, Type::None)
435 | (Type::Markdown, Type::Markdown) => Some(100),
436
437 (Type::Var(_), _) | (_, Type::Var(_)) => Some(10),
442
443 (Type::Union(types), other) => types
445 .iter()
446 .filter_map(|t| t.match_score(other))
447 .max()
448 .map(|s| s.saturating_sub(15)),
449 (other, Type::Union(types)) => types
450 .iter()
451 .filter_map(|t| other.match_score(t))
452 .max()
453 .map(|s| s.saturating_sub(15)),
454
455 (Type::Array(elem1), Type::Array(elem2)) => elem1.match_score(elem2).map(|s| s + 20),
457
458 (Type::Tuple(elems1), Type::Tuple(elems2)) if elems1.len() == elems2.len() => {
460 let total: u32 = elems1
461 .iter()
462 .zip(elems2.iter())
463 .map(|(e1, e2)| e1.match_score(e2).unwrap_or(0))
464 .sum();
465 Some(total / elems1.len() as u32 + 20)
466 }
467
468 (Type::Tuple(_), Type::Array(_)) | (Type::Array(_), Type::Tuple(_)) => Some(15),
470
471 (Type::Dict(k1, v1), Type::Dict(k2, v2)) => {
473 let key_score = k1.match_score(k2)?;
474 let val_score = v1.match_score(v2)?;
475 Some((key_score + val_score) / 2 + 20)
476 }
477
478 (Type::Record(f1, r1), Type::Record(f2, r2)) => {
480 let mut total = 0u32;
481 let mut count = 0u32;
482 for (k, v1) in f1 {
483 if let Some(v2) = f2.get(k) {
484 total += v1.match_score(v2)?;
485 count += 1;
486 }
487 }
488 let field_score = if count > 0 { total / count } else { 10 };
489 let rest_score = r1.match_score(r2).unwrap_or(10);
490 Some(field_score + rest_score + 20)
491 }
492
493 (Type::Record(_, _), Type::Dict(_, _)) | (Type::Dict(_, _), Type::Record(_, _)) => Some(15),
495
496 (Type::RowEmpty, Type::RowEmpty) => Some(100),
497
498 (Type::Function(params1, ret1), Type::Function(params2, ret2)) => {
500 let param_score: u32 = params1
501 .iter()
502 .zip(params2.iter())
503 .map(|(p1, p2)| p1.match_score(p2).unwrap_or(0))
504 .sum();
505 let ret_score = ret1.match_score(ret2)?;
506 Some(param_score + ret_score)
507 }
508
509 _ => None,
510 }
511 }
512}
513
514impl Type {
515 pub fn display_resolved(&self) -> String {
518 match self {
519 Type::Int => "int".to_string(),
520 Type::Float => "float".to_string(),
521 Type::Number => "number".to_string(),
522 Type::String => "string".to_string(),
523 Type::Bool => "bool".to_string(),
524 Type::Symbol => "symbol".to_string(),
525 Type::None => "none".to_string(),
526 Type::Markdown => "markdown".to_string(),
527 Type::Array(elem) => format!("[{}]", elem.display_resolved()),
528 Type::Tuple(elems) => {
529 let elems_str = elems
530 .iter()
531 .map(|e| e.display_resolved())
532 .collect::<Vec<_>>()
533 .join(", ");
534 format!("({})", elems_str)
535 }
536 Type::Dict(key, value) => format!("{{{}: {}}}", key.display_resolved(), value.display_resolved()),
537 Type::Record(fields, rest) => {
538 let fields_str = fields
539 .iter()
540 .map(|(k, v)| format!("{}: {}", k, v.display_resolved()))
541 .collect::<Vec<_>>()
542 .join(", ");
543 match rest.as_ref() {
544 Type::RowEmpty => format!("{{{}}}", fields_str),
545 _ => {
546 if fields_str.is_empty() {
547 format!("{{| {}}}", rest.display_resolved())
548 } else {
549 format!("{{{} | {}}}", fields_str, rest.display_resolved())
550 }
551 }
552 }
553 }
554 Type::RowEmpty => "{}".to_string(),
555 Type::Function(params, ret) => {
556 let params_str = params
557 .iter()
558 .map(|p| p.display_resolved())
559 .collect::<Vec<_>>()
560 .join(", ");
561 format!("({}) -> {}", params_str, ret.display_resolved())
562 }
563 Type::Union(types) => {
564 let types_str = types
565 .iter()
566 .map(|t| t.display_resolved())
567 .collect::<Vec<_>>()
568 .join(" | ");
569 format!("({})", types_str)
570 }
571 Type::Var(id) => {
572 type_var_name(*id)
574 }
575 Type::Never => "never".to_string(),
576 }
577 }
578
579 pub fn display_renumbered(&self) -> String {
585 let mut var_map = FxHashMap::default();
586 let mut counter = 0usize;
587 self.fmt_renumbered(&mut var_map, &mut counter)
588 }
589
590 pub(crate) fn fmt_renumbered(&self, var_map: &mut FxHashMap<TypeVarId, usize>, counter: &mut usize) -> String {
592 match self {
593 Type::Int => "int".to_string(),
594 Type::Float => "float".to_string(),
595 Type::Number => "number".to_string(),
596 Type::String => "string".to_string(),
597 Type::Bool => "bool".to_string(),
598 Type::Symbol => "symbol".to_string(),
599 Type::None => "none".to_string(),
600 Type::Markdown => "markdown".to_string(),
601 Type::Array(elem) => format!("[{}]", elem.fmt_renumbered(var_map, counter)),
602 Type::Tuple(elems) => {
603 let elems_str = elems
604 .iter()
605 .map(|e| e.fmt_renumbered(var_map, counter))
606 .collect::<Vec<_>>()
607 .join(", ");
608 format!("({})", elems_str)
609 }
610 Type::Dict(key, value) => {
611 format!(
612 "{{{}: {}}}",
613 key.fmt_renumbered(var_map, counter),
614 value.fmt_renumbered(var_map, counter)
615 )
616 }
617 Type::Function(params, ret) => {
618 let params_str = params
619 .iter()
620 .map(|p| p.fmt_renumbered(var_map, counter))
621 .collect::<Vec<_>>()
622 .join(", ");
623 format!("({}) -> {}", params_str, ret.fmt_renumbered(var_map, counter))
624 }
625 Type::Union(types) => {
626 let types_str = types
627 .iter()
628 .map(|t| t.fmt_renumbered(var_map, counter))
629 .collect::<Vec<_>>()
630 .join(" | ");
631 format!("({})", types_str)
632 }
633 Type::Record(fields, rest) => {
634 let fields_str = fields
635 .iter()
636 .map(|(k, v)| format!("{}: {}", k, v.fmt_renumbered(var_map, counter)))
637 .collect::<Vec<_>>()
638 .join(", ");
639 match rest.as_ref() {
640 Type::RowEmpty => format!("{{{}}}", fields_str),
641 _ => {
642 let rest_str = rest.fmt_renumbered(var_map, counter);
643 if fields_str.is_empty() {
644 format!("{{| {}}}", rest_str)
645 } else {
646 format!("{{{} | {}}}", fields_str, rest_str)
647 }
648 }
649 }
650 }
651 Type::RowEmpty => "{}".to_string(),
652 Type::Never => "never".to_string(),
653 Type::Var(id) => {
654 let index = *var_map.entry(*id).or_insert_with(|| {
655 let i = *counter;
656 *counter += 1;
657 i
658 });
659 format_var_name(index)
660 }
661 }
662 }
663}
664
665fn type_var_name(id: TypeVarId) -> String {
670 use slotmap::Key;
671 let index = id.data().as_ffi() as u32 as usize;
672 format_var_name(index)
673}
674
/// Formats the `index`-th type-variable name: `'a`..`'z`, then `'a1`..`'z1`, `'a2`, ...
pub fn format_var_name(index: usize) -> String {
    const ALPHABET_LEN: usize = 26;
    let letter = char::from(b'a' + (index % ALPHABET_LEN) as u8);
    match index / ALPHABET_LEN {
        0 => format!("'{letter}"),
        round => format!("'{letter}{round}"),
    }
}
687
688pub(crate) fn format_type_list(types: &[Type]) -> String {
692 types
693 .iter()
694 .map(|t| t.display_renumbered())
695 .collect::<Vec<_>>()
696 .join(", ")
697}
698
699impl fmt::Display for Type {
700 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
701 write!(f, "{}", self.display_renumbered())
704 }
705}
706
/// A type scheme `forall quantified. ty`; monomorphic when `quantified` is empty.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypeScheme {
    /// Variables bound by the `forall`; `instantiate` replaces them with fresh variables.
    pub quantified: Vec<TypeVarId>,
    /// The body type.
    pub ty: Type,
}
718
719impl TypeScheme {
720 pub fn mono(ty: Type) -> Self {
722 Self {
723 quantified: Vec::new(),
724 ty,
725 }
726 }
727
728 pub fn poly(quantified: Vec<TypeVarId>, ty: Type) -> Self {
730 Self { quantified, ty }
731 }
732
733 pub fn instantiate(&self, ctx: &mut TypeVarContext) -> Type {
735 if self.quantified.is_empty() {
736 return self.ty.clone();
737 }
738
739 let mut subst = Substitution::empty();
741 for var_id in &self.quantified {
742 let fresh = ctx.fresh();
743 subst.insert(*var_id, Type::Var(fresh));
744 }
745
746 self.ty.apply_subst(&subst)
747 }
748
749 pub fn generalize(ty: Type, env_vars: &[TypeVarId]) -> Self {
751 let ty_vars = ty.free_vars();
752 let quantified: Vec<TypeVarId> = ty_vars.into_iter().filter(|v| !env_vars.contains(v)).collect();
753 Self::poly(quantified, ty)
754 }
755}
756
757impl fmt::Display for TypeScheme {
758 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
759 if self.quantified.is_empty() {
760 write!(f, "{}", self.ty.display_renumbered())
762 } else {
763 let mut var_map: FxHashMap<TypeVarId, usize> = FxHashMap::default();
765 for (i, var) in self.quantified.iter().enumerate() {
766 var_map.insert(*var, i);
767 }
768 let mut counter = self.quantified.len();
769
770 write!(f, "forall ")?;
771 for (i, var) in self.quantified.iter().enumerate() {
772 if i > 0 {
773 write!(f, " ")?;
774 }
775 write!(f, "{}", format_var_name(var_map[var]))?;
776 }
777 write!(f, ". {}", self.ty.fmt_renumbered(&mut var_map, &mut counter))
778 }
779 }
780}
781
/// Allocator and binding store for type variables.
pub struct TypeVarContext {
    /// One slot per allocated variable: `None` until bound with `set`, then `Some(ty)`.
    vars: SlotMap<TypeVarId, Option<Type>>,
}
786
787impl TypeVarContext {
788 pub fn new() -> Self {
790 Self {
791 vars: SlotMap::with_key(),
792 }
793 }
794
795 pub fn fresh(&mut self) -> TypeVarId {
797 self.vars.insert(None)
798 }
799
800 pub fn get(&self, var: TypeVarId) -> Option<&Type> {
802 self.vars.get(var).and_then(|opt| opt.as_ref())
803 }
804
805 pub fn set(&mut self, var: TypeVarId, ty: Type) {
807 if let Some(slot) = self.vars.get_mut(var) {
808 *slot = Some(ty);
809 }
810 }
811
812 pub fn is_resolved(&self, var: TypeVarId) -> bool {
814 self.vars.get(var).and_then(|opt| opt.as_ref()).is_some()
815 }
816}
817
818impl Default for TypeVarContext {
819 fn default() -> Self {
820 Self::new()
821 }
822}
823
/// A finite mapping from type variables to types, applied with `Type::apply_subst`.
#[derive(Debug, Clone, Default)]
pub struct Substitution {
    /// Replacement type for each mapped variable.
    map: FxHashMap<TypeVarId, Type>,
}
829
830impl Substitution {
831 pub fn empty() -> Self {
833 Self {
834 map: FxHashMap::default(),
835 }
836 }
837
838 pub fn insert(&mut self, var: TypeVarId, ty: Type) {
840 self.map.insert(var, ty);
841 }
842
843 pub fn lookup(&self, var: TypeVarId) -> Option<&Type> {
845 self.map.get(&var)
846 }
847
848 pub fn is_empty(&self) -> bool {
850 self.map.is_empty()
851 }
852
853 pub fn compose(&self, other: &Substitution) -> Substitution {
855 let mut result = Substitution::empty();
856
857 for (var, ty) in &self.map {
859 result.insert(*var, ty.apply_subst(other));
860 }
861
862 for (var, ty) in &other.map {
864 if !self.map.contains_key(var) {
865 result.insert(*var, ty.clone());
866 }
867 }
868
869 result
870 }
871}
872
#[cfg(test)]
mod tests {
    //! Unit tests for type construction, display, substitution, schemes and matching.
    use super::*;
    use rstest::rstest;

    // `Display` renders primitives by name, arrays as `[elem]`, functions as `(params) -> ret`.
    #[test]
    fn test_type_display() {
        assert_eq!(Type::Number.to_string(), "number");
        assert_eq!(Type::String.to_string(), "string");
        assert_eq!(Type::array(Type::Number).to_string(), "[number]");
        assert_eq!(
            Type::function(vec![Type::Number, Type::String], Type::Bool).to_string(),
            "(number, string) -> bool"
        );
    }

    // Each `fresh` call yields a distinct id.
    #[test]
    fn test_type_var_context() {
        let mut ctx = TypeVarContext::new();
        let var1 = ctx.fresh();
        let var2 = ctx.fresh();
        assert_ne!(var1, var2);
    }

    // A mapped variable is replaced by its substitution target.
    #[test]
    fn test_substitution() {
        let mut ctx = TypeVarContext::new();
        let var = ctx.fresh();
        let ty = Type::Var(var);

        let mut subst = Substitution::empty();
        subst.insert(var, Type::Number);

        let result = ty.apply_subst(&subst);
        assert_eq!(result, Type::Number);
    }

    // Two instantiations of the same scheme must use different fresh variables.
    #[test]
    fn test_type_scheme_instantiate() {
        let mut ctx = TypeVarContext::new();
        let var = ctx.fresh();

        let scheme = TypeScheme::poly(vec![var], Type::Var(var));
        let inst1 = scheme.instantiate(&mut ctx);
        let inst2 = scheme.instantiate(&mut ctx);

        assert_ne!(inst1, inst2);
    }

    #[test]
    fn test_can_match_concrete_types() {
        assert!(Type::Number.can_match(&Type::Number));
        assert!(Type::String.can_match(&Type::String));
        assert!(!Type::Number.can_match(&Type::String));
    }

    // A type variable matches any type, on either side.
    #[test]
    fn test_can_match_type_variables() {
        let mut ctx = TypeVarContext::new();
        let var = ctx.fresh();

        assert!(Type::Var(var).can_match(&Type::Number));
        assert!(Type::Number.can_match(&Type::Var(var)));
        assert!(Type::Var(var).can_match(&Type::String));
    }

    #[test]
    fn test_can_match_arrays() {
        let arr_num = Type::array(Type::Number);
        let arr_str = Type::array(Type::String);

        assert!(arr_num.can_match(&arr_num));
        assert!(!arr_num.can_match(&arr_str));
    }

    // Functions match only when every parameter and the return type match.
    #[test]
    fn test_can_match_functions() {
        let func1 = Type::function(vec![Type::Number], Type::String);
        let func2 = Type::function(vec![Type::Number], Type::String);
        let func3 = Type::function(vec![Type::String], Type::String);

        assert!(func1.can_match(&func2));
        assert!(!func1.can_match(&func3));
    }

    // Exact primitives score 100, a variable 10, and incompatible types `None`.
    #[test]
    fn test_match_score() {
        assert_eq!(Type::Number.match_score(&Type::Number), Some(100));
        assert_eq!(Type::String.match_score(&Type::String), Some(100));

        let mut ctx = TypeVarContext::new();
        let var = ctx.fresh();
        assert_eq!(Type::Var(var).match_score(&Type::Number), Some(10));

        assert_eq!(Type::Number.match_score(&Type::String), None);
    }

    // `Type::union` collapses singletons, flattens nested unions and removes duplicates.
    #[rstest]
    #[case(vec![Type::Number, Type::Number], Type::Number)]
    #[case(vec![Type::Number, Type::String], Type::union(vec![Type::Number, Type::String]))]
    #[case(vec![Type::union(vec![Type::Number, Type::String]), Type::Bool], Type::union(vec![Type::Number, Type::String, Type::Bool]))]
    #[case(vec![Type::Number, Type::String, Type::Number], Type::union(vec![Type::Number, Type::String]))]
    fn test_type_union(#[case] types: Vec<Type>, #[case] expected: Type) {
        assert_eq!(Type::union(types), expected);
    }

    // Subtracting removes matching union members; a non-union type is returned unchanged.
    #[rstest]
    #[case(Type::union(vec![Type::Number, Type::String]), &Type::Number, Type::String)]
    #[case(Type::union(vec![Type::Number, Type::String, Type::Bool]), &Type::String, Type::union(vec![Type::Number, Type::Bool]))]
    #[case(Type::Number, &Type::Number, Type::Number)]
    fn test_type_subtract(#[case] ty: Type, #[case] exclude: &Type, #[case] expected: Type) {
        assert_eq!(ty.subtract(exclude), expected);
    }

    // Includes the lenient tuple/array and record/dict pairings.
    #[rstest]
    #[case(Type::Number, Type::Number, true)]
    #[case(Type::Number, Type::String, false)]
    #[case(Type::array(Type::Number), Type::array(Type::Number), true)]
    #[case(Type::array(Type::Number), Type::array(Type::String), false)]
    #[case(Type::tuple(vec![Type::Number]), Type::array(Type::Number), true)]
    #[case(Type::record(BTreeMap::from([("a".to_string(), Type::Number)]), Type::RowEmpty), Type::dict(Type::String, Type::Number), true)]
    fn test_can_match_complex(#[case] t1: Type, #[case] t2: Type, #[case] expected: bool) {
        assert_eq!(t1.can_match(&t2), expected);
    }

    #[rstest]
    #[case(Type::Number, Type::Number, true)]
    #[case(Type::Number, Type::String, false)]
    fn test_can_branch_unify_with(#[case] t1: Type, #[case] t2: Type, #[case] expected: bool) {
        assert_eq!(t1.can_branch_unify_with(&t2), expected);
    }

    // Unlike `can_match`, a variable is compatible only with another variable.
    #[test]
    fn test_can_branch_unify_with_vars() {
        let mut ctx = TypeVarContext::new();
        let v1 = ctx.fresh();
        let v2 = ctx.fresh();
        assert!(Type::Var(v1).can_branch_unify_with(&Type::Var(v2)));
        assert!(!Type::Var(v1).can_branch_unify_with(&Type::Number));
    }

    // `s1.compose(&s2)` applies `s2` to `s1`'s targets and keeps `s2`'s own mappings.
    #[test]
    fn test_substitution_compose() {
        let mut ctx = TypeVarContext::new();
        let v1 = ctx.fresh();
        let v2 = ctx.fresh();

        let mut s1 = Substitution::empty();
        s1.insert(v1, Type::Var(v2));

        let mut s2 = Substitution::empty();
        s2.insert(v2, Type::Number);

        let s3 = s1.compose(&s2);
        assert_eq!(s3.lookup(v1), Some(&Type::Number));
        assert_eq!(s3.lookup(v2), Some(&Type::Number));
    }

    // Variables present in the environment are not quantified.
    #[test]
    fn test_type_scheme_generalize() {
        let mut ctx = TypeVarContext::new();
        let v1 = ctx.fresh();
        let v2 = ctx.fresh();

        let ty = Type::function(vec![Type::Var(v1)], Type::Var(v2));
        let env_vars = vec![v1];

        let scheme = TypeScheme::generalize(ty, &env_vars);
        assert_eq!(scheme.quantified, vec![v2]);
    }

    // Variables are renamed 'a, 'b, ... in order of first appearance.
    #[test]
    fn test_display_renumbered() {
        let mut ctx = TypeVarContext::new();
        let v1 = ctx.fresh();
        let v2 = ctx.fresh();
        let ty = Type::function(vec![Type::Var(v1)], Type::Var(v2));

        assert_eq!(ty.display_renumbered(), "('a) -> 'b");
    }
}