1use std::collections::HashMap;
19
20use crate::intern::{Interner, Symbol};
21use crate::analysis::{FieldType, LogosType};
22
/// Identifier of a type variable.
///
/// The wrapped `u32` is used by `UnificationTable` as the index of the
/// variable's binding slot.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TyVar(pub u32);
30
/// Type representation used during inference.
///
/// Besides the concrete types it contains `Var`, a unification variable that
/// is looked up through a `UnificationTable`, and `Unknown`, which unifies
/// with every other type.
#[derive(Clone, PartialEq, Debug)]
pub enum InferType {
    /// Integer; rendered as "Int".
    Int,
    /// Floating point; rendered as "Real" and parsed from "Real" or "Float".
    Float,
    /// Boolean; parsed from "Bool" or "Boolean".
    Bool,
    /// Single character.
    Char,
    /// Byte; unification accepts it against `Int`.
    Byte,
    /// Text; rendered as "Text" and parsed from "Text" or "String".
    String,
    /// Unit; also the type given to the `Nothing` literal.
    Unit,
    /// Natural number; unification accepts it against `Int`.
    Nat,
    /// Type of duration literals.
    Duration,
    /// Type of date literals.
    Date,
    /// Type of moment literals.
    Moment,
    /// Type of time literals.
    Time,
    /// Type of span literals.
    Span,
    /// Sequence with the given element type.
    Seq(Box<InferType>),
    /// Map with the given key and value types.
    Map(Box<InferType>, Box<InferType>),
    /// Set with the given element type.
    Set(Box<InferType>),
    /// Optional value of the given inner type.
    Option(Box<InferType>),
    /// User-defined type identified by its interned name.
    UserDefined(Symbol),
    /// Unification variable; its binding lives in a `UnificationTable`.
    Var(TyVar),
    /// Function type: parameter types followed by the return type.
    Function(Vec<InferType>, Box<InferType>),
    /// Type that could not be determined; unifies with anything.
    Unknown,
}
63
/// Errors reported by unification and related type checks.
#[derive(Debug, Clone)]
pub enum TypeError {
    /// Two types could not be unified. Unification reports its left operand
    /// as `expected` and its right operand as `found`.
    Mismatch { expected: InferType, found: InferType },
    /// The occurs check found `var` inside the type it was about to be bound
    /// to; `ty` is the type reported by that check.
    InfiniteType { var: TyVar, ty: InferType },
    /// Two function types have a different number of parameters.
    ArityMismatch { expected: usize, found: usize },
    /// Carries the interned names of a type and of the field looked up on it.
    FieldNotFound { type_name: Symbol, field_name: Symbol },
    /// Carries the type that was found where a function was required.
    NotAFunction { found: InferType },
}
73
74impl TypeError {
75 pub fn expected_str(&self) -> std::string::String {
76 match self {
77 TypeError::Mismatch { expected, .. } => expected.to_logos_name(),
78 TypeError::ArityMismatch { expected, .. } => format!("{} arguments", expected),
79 TypeError::FieldNotFound { .. } => "a known field".to_string(),
80 TypeError::NotAFunction { .. } => "a function".to_string(),
81 TypeError::InfiniteType { .. } => "a finite type".to_string(),
82 }
83 }
84
85 pub fn found_str(&self) -> std::string::String {
86 match self {
87 TypeError::Mismatch { found, .. } => found.to_logos_name(),
88 TypeError::ArityMismatch { found, .. } => format!("{} arguments", found),
89 TypeError::FieldNotFound { field_name, .. } => format!("{:?}", field_name),
90 TypeError::NotAFunction { found } => found.to_logos_name(),
91 TypeError::InfiniteType { ty, .. } => ty.to_logos_name(),
92 }
93 }
94
95 pub fn to_parse_error_kind(
99 &self,
100 interner: &crate::intern::Interner,
101 ) -> crate::error::ParseErrorKind {
102 use crate::error::ParseErrorKind;
103 match self {
104 TypeError::Mismatch { expected, found } => ParseErrorKind::TypeMismatchDetailed {
105 expected: expected.to_logos_name(),
106 found: found.to_logos_name(),
107 context: String::new(),
108 },
109 TypeError::InfiniteType { var, ty } => ParseErrorKind::InfiniteType {
110 var_description: format!("type variable α{}", var.0),
111 type_description: ty.to_logos_name(),
112 },
113 TypeError::ArityMismatch { expected, found } => ParseErrorKind::ArityMismatch {
114 function: String::from("function"),
115 expected: *expected,
116 found: *found,
117 },
118 TypeError::FieldNotFound { type_name, field_name } => ParseErrorKind::FieldNotFound {
119 type_name: interner.resolve(*type_name).to_string(),
120 field_name: interner.resolve(*field_name).to_string(),
121 available: vec![],
122 },
123 TypeError::NotAFunction { found } => ParseErrorKind::NotAFunction {
124 found_type: found.to_logos_name(),
125 },
126 }
127 }
128}
129
/// A type together with the variables that are generalized in it.
///
/// `UnificationTable::instantiate` replaces every variable listed in `vars`
/// with a fresh one when producing a copy of `body`.
#[derive(Clone, Debug)]
pub struct TypeScheme {
    /// Variables that are replaced by fresh ones on instantiation.
    pub vars: Vec<TyVar>,
    /// The type in which those variables occur.
    pub body: InferType,
}
145
/// Storage for type variables and their bindings.
pub struct UnificationTable {
    /// Slot `i` holds the binding of `TyVar(i)`; `None` while it is unbound.
    bindings: Vec<Option<InferType>>,
    /// Id handed to the next fresh variable; advanced together with every
    /// slot pushed onto `bindings`.
    next_id: u32,
}
155
156impl UnificationTable {
157 pub fn new() -> Self {
158 Self {
159 bindings: Vec::new(),
160 next_id: 0,
161 }
162 }
163
164 pub fn fresh(&mut self) -> InferType {
166 let id = self.next_id;
167 self.next_id += 1;
168 self.bindings.push(None);
169 InferType::Var(TyVar(id))
170 }
171
172 pub fn fresh_var(&mut self) -> TyVar {
174 let id = self.next_id;
175 self.next_id += 1;
176 self.bindings.push(None);
177 TyVar(id)
178 }
179
180 pub fn instantiate(&mut self, scheme: &TypeScheme) -> InferType {
185 if scheme.vars.is_empty() {
186 return scheme.body.clone();
187 }
188 let subst: HashMap<TyVar, TyVar> = scheme.vars.iter()
189 .map(|&old_tv| (old_tv, self.fresh_var()))
190 .collect();
191 self.substitute_vars(&scheme.body, &subst)
192 }
193
194 fn substitute_vars(&self, ty: &InferType, subst: &HashMap<TyVar, TyVar>) -> InferType {
196 match ty {
197 InferType::Var(tv) => {
198 let resolved = self.find(*tv);
199 match &resolved {
200 InferType::Var(rtv) => {
201 if let Some(&new_tv) = subst.get(rtv) {
202 InferType::Var(new_tv)
203 } else {
204 InferType::Var(*rtv)
205 }
206 }
207 other => self.substitute_vars(&other.clone(), subst),
208 }
209 }
210 InferType::Seq(inner) => InferType::Seq(Box::new(self.substitute_vars(inner, subst))),
211 InferType::Map(k, v) => InferType::Map(
212 Box::new(self.substitute_vars(k, subst)),
213 Box::new(self.substitute_vars(v, subst)),
214 ),
215 InferType::Set(inner) => InferType::Set(Box::new(self.substitute_vars(inner, subst))),
216 InferType::Option(inner) => InferType::Option(Box::new(self.substitute_vars(inner, subst))),
217 InferType::Function(params, ret) => InferType::Function(
218 params.iter().map(|p| self.substitute_vars(p, subst)).collect(),
219 Box::new(self.substitute_vars(ret, subst)),
220 ),
221 other => other.clone(),
222 }
223 }
224
225 pub fn find(&self, tv: TyVar) -> InferType {
227 let mut current = tv;
228 loop {
229 match &self.bindings[current.0 as usize] {
230 None => return InferType::Var(current),
231 Some(InferType::Var(tv2)) => current = *tv2,
232 Some(ty) => return ty.clone(),
233 }
234 }
235 }
236
237 fn walk(&self, ty: &InferType) -> InferType {
239 match ty {
240 InferType::Var(tv) => self.find(*tv),
241 other => other.clone(),
242 }
243 }
244
245 pub fn resolve(&self, ty: &InferType) -> InferType {
251 match ty {
252 InferType::Var(tv) => {
253 let resolved = self.find(*tv);
254 match &resolved {
255 InferType::Var(_) => resolved, other => self.resolve(&other.clone()),
257 }
258 }
259 InferType::Seq(inner) => InferType::Seq(Box::new(self.resolve(inner))),
260 InferType::Map(k, v) => {
261 InferType::Map(Box::new(self.resolve(k)), Box::new(self.resolve(v)))
262 }
263 InferType::Set(inner) => InferType::Set(Box::new(self.resolve(inner))),
264 InferType::Option(inner) => InferType::Option(Box::new(self.resolve(inner))),
265 InferType::Function(params, ret) => {
266 let params = params.iter().map(|p| self.resolve(p)).collect();
267 InferType::Function(params, Box::new(self.resolve(ret)))
268 }
269 other => other.clone(),
270 }
271 }
272
273 pub fn zonk(&self, ty: &InferType) -> InferType {
278 match ty {
279 InferType::Var(tv) => {
280 let resolved = self.find(*tv);
281 match &resolved {
282 InferType::Var(_) => InferType::Unknown,
283 other => self.zonk(other),
284 }
285 }
286 InferType::Seq(inner) => InferType::Seq(Box::new(self.zonk(inner))),
287 InferType::Map(k, v) => {
288 InferType::Map(Box::new(self.zonk(k)), Box::new(self.zonk(v)))
289 }
290 InferType::Set(inner) => InferType::Set(Box::new(self.zonk(inner))),
291 InferType::Option(inner) => InferType::Option(Box::new(self.zonk(inner))),
292 InferType::Function(params, ret) => {
293 let params = params.iter().map(|p| self.zonk(p)).collect();
294 InferType::Function(params, Box::new(self.zonk(ret)))
295 }
296 other => other.clone(),
297 }
298 }
299
300 pub fn to_logos_type(&self, ty: &InferType) -> LogosType {
302 let zonked = self.zonk(ty);
303 infer_to_logos(&zonked)
304 }
305
306 pub fn unify(&mut self, a: &InferType, b: &InferType) -> Result<(), TypeError> {
311 let a = self.walk(a);
312 let b = self.walk(b);
313 self.unify_walked(&a, &b)
314 }
315
316 fn unify_walked(&mut self, a: &InferType, b: &InferType) -> Result<(), TypeError> {
317 match (a, b) {
318 (InferType::Var(va), InferType::Var(vb)) if va == vb => Ok(()),
320
321 (InferType::Var(tv), ty) => {
323 let tv = *tv;
324 let ty = ty.clone();
325 self.occurs_check(tv, &ty)?;
326 self.bindings[tv.0 as usize] = Some(ty);
327 Ok(())
328 }
329 (ty, InferType::Var(tv)) => {
330 let tv = *tv;
331 let ty = ty.clone();
332 self.occurs_check(tv, &ty)?;
333 self.bindings[tv.0 as usize] = Some(ty);
334 Ok(())
335 }
336
337 (InferType::Unknown, _) | (_, InferType::Unknown) => Ok(()),
339
340 (InferType::Int, InferType::Int) => Ok(()),
342 (InferType::Float, InferType::Float) => Ok(()),
343 (InferType::Bool, InferType::Bool) => Ok(()),
344 (InferType::Char, InferType::Char) => Ok(()),
345 (InferType::Byte, InferType::Byte) => Ok(()),
346 (InferType::String, InferType::String) => Ok(()),
347 (InferType::Unit, InferType::Unit) => Ok(()),
348 (InferType::Nat, InferType::Nat) => Ok(()),
349 (InferType::Duration, InferType::Duration) => Ok(()),
350 (InferType::Date, InferType::Date) => Ok(()),
351 (InferType::Moment, InferType::Moment) => Ok(()),
352 (InferType::Time, InferType::Time) => Ok(()),
353 (InferType::Span, InferType::Span) => Ok(()),
354
355 (InferType::Nat, InferType::Int) | (InferType::Int, InferType::Nat) => Ok(()),
357
358 (InferType::Byte, InferType::Int) | (InferType::Int, InferType::Byte) => Ok(()),
360
361 (InferType::UserDefined(a), InferType::UserDefined(b)) if a == b => Ok(()),
363
364 (InferType::Seq(a_inner), InferType::Seq(b_inner)) => {
366 let a_inner = (**a_inner).clone();
367 let b_inner = (**b_inner).clone();
368 self.unify(&a_inner, &b_inner)
369 }
370 (InferType::Set(a_inner), InferType::Set(b_inner)) => {
371 let a_inner = (**a_inner).clone();
372 let b_inner = (**b_inner).clone();
373 self.unify(&a_inner, &b_inner)
374 }
375 (InferType::Option(a_inner), InferType::Option(b_inner)) => {
376 let a_inner = (**a_inner).clone();
377 let b_inner = (**b_inner).clone();
378 self.unify(&a_inner, &b_inner)
379 }
380 (InferType::Map(ak, av), InferType::Map(bk, bv)) => {
381 let ak = (**ak).clone();
382 let bk = (**bk).clone();
383 let av = (**av).clone();
384 let bv = (**bv).clone();
385 self.unify(&ak, &bk)?;
386 self.unify(&av, &bv)
387 }
388 (InferType::Function(a_params, a_ret), InferType::Function(b_params, b_ret)) => {
389 if a_params.len() != b_params.len() {
390 return Err(TypeError::ArityMismatch {
391 expected: a_params.len(),
392 found: b_params.len(),
393 });
394 }
395 let a_params = a_params.clone();
396 let b_params = b_params.clone();
397 let a_ret = (**a_ret).clone();
398 let b_ret = (**b_ret).clone();
399 for (ap, bp) in a_params.iter().zip(b_params.iter()) {
400 self.unify(ap, bp)?;
401 }
402 self.unify(&a_ret, &b_ret)
403 }
404
405 (a, b) => Err(TypeError::Mismatch {
407 expected: a.clone(),
408 found: b.clone(),
409 }),
410 }
411 }
412
413 fn occurs_check(&self, tv: TyVar, ty: &InferType) -> Result<(), TypeError> {
417 match ty {
418 InferType::Var(tv2) => {
419 let resolved = self.find(*tv2);
420 match &resolved {
421 InferType::Var(rtv) => {
422 if *rtv == tv {
423 Err(TypeError::InfiniteType { var: tv, ty: ty.clone() })
424 } else {
425 Ok(())
426 }
427 }
428 other => self.occurs_check(tv, &other.clone()),
429 }
430 }
431 InferType::Seq(inner) | InferType::Set(inner) | InferType::Option(inner) => {
432 self.occurs_check(tv, inner)
433 }
434 InferType::Map(k, v) => {
435 self.occurs_check(tv, k)?;
436 self.occurs_check(tv, v)
437 }
438 InferType::Function(params, ret) => {
439 for p in params {
440 self.occurs_check(tv, p)?;
441 }
442 self.occurs_check(tv, ret)
443 }
444 _ => Ok(()),
445 }
446 }
447}
448
449impl InferType {
454 pub fn from_type_expr(ty: &crate::ast::stmt::TypeExpr, interner: &Interner) -> InferType {
460 Self::from_type_expr_with_params(ty, interner, &HashMap::new())
461 }
462
463 pub fn from_type_expr_with_params(
470 ty: &crate::ast::stmt::TypeExpr,
471 interner: &Interner,
472 type_params: &HashMap<Symbol, TyVar>,
473 ) -> InferType {
474 use crate::ast::stmt::TypeExpr;
475 match ty {
476 TypeExpr::Primitive(sym) | TypeExpr::Named(sym) => {
477 if let Some(&tv) = type_params.get(sym) {
479 return InferType::Var(tv);
480 }
481 Self::from_type_name(interner.resolve(*sym))
482 }
483 TypeExpr::Generic { base, params } => {
484 let base_name = interner.resolve(*base);
485 match base_name {
486 "Seq" | "List" | "Vec" => {
487 let elem = params
488 .first()
489 .map(|p| InferType::from_type_expr_with_params(p, interner, type_params))
490 .unwrap_or(InferType::Unit);
491 InferType::Seq(Box::new(elem))
492 }
493 "Map" | "HashMap" => {
494 let key = params
495 .first()
496 .map(|p| InferType::from_type_expr_with_params(p, interner, type_params))
497 .unwrap_or(InferType::String);
498 let val = params
499 .get(1)
500 .map(|p| InferType::from_type_expr_with_params(p, interner, type_params))
501 .unwrap_or(InferType::String);
502 InferType::Map(Box::new(key), Box::new(val))
503 }
504 "Set" | "HashSet" => {
505 let elem = params
506 .first()
507 .map(|p| InferType::from_type_expr_with_params(p, interner, type_params))
508 .unwrap_or(InferType::Unit);
509 InferType::Set(Box::new(elem))
510 }
511 "Option" | "Maybe" => {
512 let inner = params
513 .first()
514 .map(|p| InferType::from_type_expr_with_params(p, interner, type_params))
515 .unwrap_or(InferType::Unit);
516 InferType::Option(Box::new(inner))
517 }
518 _ => InferType::Unknown,
519 }
520 }
521 TypeExpr::Refinement { base, .. } => {
522 InferType::from_type_expr_with_params(base, interner, type_params)
523 }
524 TypeExpr::Persistent { inner } => {
525 InferType::from_type_expr_with_params(inner, interner, type_params)
526 }
527 TypeExpr::Function { inputs, output } => {
528 let param_types: Vec<InferType> = inputs
529 .iter()
530 .map(|p| InferType::from_type_expr_with_params(p, interner, type_params))
531 .collect();
532 let ret_type = InferType::from_type_expr_with_params(output, interner, type_params);
533 InferType::Function(param_types, Box::new(ret_type))
534 }
535 }
536 }
537
538 pub fn from_field_type(
543 ty: &FieldType,
544 interner: &Interner,
545 type_params: &HashMap<Symbol, TyVar>,
546 ) -> InferType {
547 match ty {
548 FieldType::Primitive(sym) => InferType::from_type_name(interner.resolve(*sym)),
549 FieldType::Named(sym) => {
550 let name = interner.resolve(*sym);
551 let primitive = InferType::from_type_name(name);
552 if primitive == InferType::Unknown {
553 InferType::UserDefined(*sym)
554 } else {
555 primitive
556 }
557 }
558 FieldType::Generic { base, params } => {
559 let base_name = interner.resolve(*base);
560 let converted: Vec<InferType> = params
561 .iter()
562 .map(|p| InferType::from_field_type(p, interner, type_params))
563 .collect();
564 match base_name {
565 "Seq" | "List" | "Vec" => {
566 InferType::Seq(Box::new(
567 converted.into_iter().next().unwrap_or(InferType::Unit),
568 ))
569 }
570 "Map" | "HashMap" => {
571 let mut it = converted.into_iter();
572 let k = it.next().unwrap_or(InferType::String);
573 let v = it.next().unwrap_or(InferType::String);
574 InferType::Map(Box::new(k), Box::new(v))
575 }
576 "Set" | "HashSet" => {
577 InferType::Set(Box::new(
578 converted.into_iter().next().unwrap_or(InferType::Unit),
579 ))
580 }
581 "Option" | "Maybe" => {
582 InferType::Option(Box::new(
583 converted.into_iter().next().unwrap_or(InferType::Unit),
584 ))
585 }
586 _ => InferType::Unknown,
587 }
588 }
589 FieldType::TypeParam(sym) => {
590 if let Some(tv) = type_params.get(sym) {
591 InferType::Var(*tv)
592 } else {
593 InferType::Unknown
594 }
595 }
596 }
597 }
598
599 pub fn from_literal(lit: &crate::ast::stmt::Literal) -> InferType {
601 use crate::ast::stmt::Literal;
602 match lit {
603 Literal::Number(_) => InferType::Int,
604 Literal::Float(_) => InferType::Float,
605 Literal::Text(_) => InferType::String,
606 Literal::Boolean(_) => InferType::Bool,
607 Literal::Char(_) => InferType::Char,
608 Literal::Nothing => InferType::Unit,
609 Literal::Duration(_) => InferType::Duration,
610 Literal::Date(_) => InferType::Date,
611 Literal::Moment(_) => InferType::Moment,
612 Literal::Span { .. } => InferType::Span,
613 Literal::Time(_) => InferType::Time,
614 }
615 }
616
617 pub fn from_type_name(name: &str) -> InferType {
619 match name {
620 "Int" => InferType::Int,
621 "Nat" => InferType::Nat,
622 "Real" | "Float" => InferType::Float,
623 "Bool" | "Boolean" => InferType::Bool,
624 "Text" | "String" => InferType::String,
625 "Char" => InferType::Char,
626 "Byte" => InferType::Byte,
627 "Unit" | "()" => InferType::Unit,
628 "Duration" => InferType::Duration,
629 "Date" => InferType::Date,
630 "Moment" => InferType::Moment,
631 "Time" => InferType::Time,
632 "Span" => InferType::Span,
633 _ => InferType::Unknown,
634 }
635 }
636
637 pub fn to_logos_name(&self) -> std::string::String {
639 match self {
640 InferType::Int => "Int".into(),
641 InferType::Float => "Real".into(),
642 InferType::Bool => "Bool".into(),
643 InferType::Char => "Char".into(),
644 InferType::Byte => "Byte".into(),
645 InferType::String => "Text".into(),
646 InferType::Unit => "Unit".into(),
647 InferType::Nat => "Nat".into(),
648 InferType::Duration => "Duration".into(),
649 InferType::Date => "Date".into(),
650 InferType::Moment => "Moment".into(),
651 InferType::Time => "Time".into(),
652 InferType::Span => "Span".into(),
653 InferType::Seq(inner) => format!("Seq of {}", inner.to_logos_name()),
654 InferType::Map(k, v) => {
655 format!("Map of {} and {}", k.to_logos_name(), v.to_logos_name())
656 }
657 InferType::Set(inner) => format!("Set of {}", inner.to_logos_name()),
658 InferType::Option(inner) => format!("Option of {}", inner.to_logos_name()),
659 InferType::UserDefined(_) => "a user-defined type".into(),
660 InferType::Var(_) => "an unknown type".into(),
661 InferType::Function(params, ret) => {
662 let params_str = params
663 .iter()
664 .map(|p| p.to_logos_name())
665 .collect::<Vec<_>>()
666 .join(", ");
667 format!("fn({}) -> {}", params_str, ret.to_logos_name())
668 }
669 InferType::Unknown => "unknown".into(),
670 }
671 }
672
673 pub fn to_logos_type_ground(&self) -> LogosType {
679 match self {
680 InferType::Int => LogosType::Int,
681 InferType::Float => LogosType::Float,
682 InferType::Bool => LogosType::Bool,
683 InferType::Char => LogosType::Char,
684 InferType::Byte => LogosType::Byte,
685 InferType::String => LogosType::String,
686 InferType::Unit => LogosType::Unit,
687 InferType::Nat => LogosType::Nat,
688 InferType::Duration => LogosType::Duration,
689 InferType::Date => LogosType::Date,
690 InferType::Moment => LogosType::Moment,
691 InferType::Time => LogosType::Time,
692 InferType::Span => LogosType::Span,
693 InferType::Seq(inner) => LogosType::Seq(Box::new(inner.to_logos_type_ground())),
694 InferType::Map(k, v) => LogosType::Map(
695 Box::new(k.to_logos_type_ground()),
696 Box::new(v.to_logos_type_ground()),
697 ),
698 InferType::Set(inner) => LogosType::Set(Box::new(inner.to_logos_type_ground())),
699 InferType::Option(inner) => LogosType::Option(Box::new(inner.to_logos_type_ground())),
700 InferType::UserDefined(sym) => LogosType::UserDefined(*sym),
701 InferType::Function(params, ret) => LogosType::Function(
702 params.iter().map(|p| p.to_logos_type_ground()).collect(),
703 Box::new(ret.to_logos_type_ground()),
704 ),
705 InferType::Unknown => LogosType::Unknown,
706 InferType::Var(_) => panic!("to_logos_type_ground called on unresolved Var"),
707 }
708 }
709}
710
711pub fn unify_numeric(a: &InferType, b: &InferType) -> Result<InferType, TypeError> {
718 match (a, b) {
719 (InferType::Float, _) | (_, InferType::Float) => Ok(InferType::Float),
720 (InferType::Int, InferType::Int) => Ok(InferType::Int),
721 (InferType::Nat, InferType::Int) | (InferType::Int, InferType::Nat) => Ok(InferType::Int),
722 (InferType::Nat, InferType::Nat) => Ok(InferType::Nat),
723 (InferType::Byte, InferType::Byte) => Ok(InferType::Byte),
724 (InferType::Byte, InferType::Int) | (InferType::Int, InferType::Byte) => Ok(InferType::Byte),
725 _ => Err(TypeError::Mismatch {
726 expected: InferType::Int,
727 found: a.clone(),
728 }),
729 }
730}
731
732pub fn infer_to_logos(ty: &InferType) -> LogosType {
734 match ty {
735 InferType::Int => LogosType::Int,
736 InferType::Float => LogosType::Float,
737 InferType::Bool => LogosType::Bool,
738 InferType::Char => LogosType::Char,
739 InferType::Byte => LogosType::Byte,
740 InferType::String => LogosType::String,
741 InferType::Unit => LogosType::Unit,
742 InferType::Nat => LogosType::Nat,
743 InferType::Duration => LogosType::Duration,
744 InferType::Date => LogosType::Date,
745 InferType::Moment => LogosType::Moment,
746 InferType::Time => LogosType::Time,
747 InferType::Span => LogosType::Span,
748 InferType::Seq(inner) => LogosType::Seq(Box::new(infer_to_logos(inner))),
749 InferType::Map(k, v) => {
750 LogosType::Map(Box::new(infer_to_logos(k)), Box::new(infer_to_logos(v)))
751 }
752 InferType::Set(inner) => LogosType::Set(Box::new(infer_to_logos(inner))),
753 InferType::Option(inner) => LogosType::Option(Box::new(infer_to_logos(inner))),
754 InferType::UserDefined(sym) => LogosType::UserDefined(*sym),
755 InferType::Function(params, ret) => LogosType::Function(
756 params.iter().map(infer_to_logos).collect(),
757 Box::new(infer_to_logos(ret)),
758 ),
759 InferType::Unknown | InferType::Var(_) => LogosType::Unknown,
760 }
761}
762
763#[cfg(test)]
768mod tests {
769 use super::*;
770 use crate::analysis::{FieldDef, TypeDef};
771
772 #[test]
777 fn fresh_produces_distinct_vars() {
778 let mut table = UnificationTable::new();
779 let a = table.fresh();
780 let b = table.fresh();
781 assert_ne!(a, b);
782 }
783
784 #[test]
785 fn unbound_var_finds_itself() {
786 let mut table = UnificationTable::new();
787 let v = table.fresh();
788 if let InferType::Var(tv) = v {
789 assert_eq!(table.find(tv), InferType::Var(tv));
790 } else {
791 panic!("expected Var");
792 }
793 }
794
795 #[test]
800 fn unify_identical_ground_types() {
801 let mut table = UnificationTable::new();
802 assert!(table.unify(&InferType::Int, &InferType::Int).is_ok());
803 assert!(table.unify(&InferType::Float, &InferType::Float).is_ok());
804 assert!(table.unify(&InferType::Bool, &InferType::Bool).is_ok());
805 assert!(table.unify(&InferType::String, &InferType::String).is_ok());
806 assert!(table.unify(&InferType::Unit, &InferType::Unit).is_ok());
807 }
808
809 #[test]
810 fn unify_different_ground_types_fails() {
811 let mut table = UnificationTable::new();
812 let result = table.unify(&InferType::Int, &InferType::String);
813 assert!(result.is_err());
814 assert!(matches!(result, Err(TypeError::Mismatch { .. })));
815 }
816
817 #[test]
818 fn unify_int_float_fails() {
819 let mut table = UnificationTable::new();
820 let result = table.unify(&InferType::Int, &InferType::Float);
821 assert!(result.is_err());
822 }
823
824 #[test]
825 fn unify_nat_int_succeeds() {
826 let mut table = UnificationTable::new();
827 assert!(table.unify(&InferType::Nat, &InferType::Int).is_ok());
828 assert!(table.unify(&InferType::Int, &InferType::Nat).is_ok());
829 }
830
831 #[test]
832 fn unify_unknown_with_any_succeeds() {
833 let mut table = UnificationTable::new();
834 assert!(table.unify(&InferType::Unknown, &InferType::Int).is_ok());
835 assert!(table.unify(&InferType::String, &InferType::Unknown).is_ok());
836 assert!(table.unify(&InferType::Unknown, &InferType::Unknown).is_ok());
837 }
838
839 #[test]
844 fn var_unifies_with_int() {
845 let mut table = UnificationTable::new();
846 let v = table.fresh();
847 if let InferType::Var(tv) = v {
848 table.unify(&InferType::Var(tv), &InferType::Int).unwrap();
849 assert_eq!(table.find(tv), InferType::Int);
850 }
851 }
852
853 #[test]
854 fn int_unifies_with_var() {
855 let mut table = UnificationTable::new();
856 let v = table.fresh();
857 if let InferType::Var(tv) = v {
858 table.unify(&InferType::Int, &InferType::Var(tv)).unwrap();
859 assert_eq!(table.find(tv), InferType::Int);
860 }
861 }
862
863 #[test]
864 fn two_vars_unify_chain() {
865 let mut table = UnificationTable::new();
866 let va = table.fresh();
867 let vb = table.fresh();
868 let tva = if let InferType::Var(tv) = va { tv } else { panic!() };
869 let tvb = if let InferType::Var(tv) = vb { tv } else { panic!() };
870 table.unify(&InferType::Var(tva), &InferType::Var(tvb)).unwrap();
871 table.unify(&InferType::Var(tvb), &InferType::Int).unwrap();
873 let zonked = table.zonk(&InferType::Var(tva));
874 assert_eq!(zonked, InferType::Int);
875 }
876
877 #[test]
878 fn var_conflicting_types_fails() {
879 let mut table = UnificationTable::new();
880 let v = table.fresh();
881 if let InferType::Var(tv) = v {
882 table.unify(&InferType::Var(tv), &InferType::Int).unwrap();
883 let result = table.unify(&InferType::Var(tv), &InferType::String);
884 assert!(result.is_err());
886 }
887 }
888
889 #[test]
894 fn occurs_check_detects_infinite_type() {
895 let mut table = UnificationTable::new();
896 let v = table.fresh();
897 if let InferType::Var(tv) = v {
898 let circular = InferType::Seq(Box::new(InferType::Var(tv)));
899 let result = table.unify(&InferType::Var(tv), &circular);
900 assert!(result.is_err());
901 assert!(matches!(result, Err(TypeError::InfiniteType { .. })));
902 }
903 }
904
905 #[test]
910 fn zonk_resolves_bound_var() {
911 let mut table = UnificationTable::new();
912 let v = table.fresh();
913 if let InferType::Var(tv) = v {
914 table.unify(&InferType::Var(tv), &InferType::Bool).unwrap();
915 let zonked = table.zonk(&InferType::Var(tv));
916 assert_eq!(zonked, InferType::Bool);
917 }
918 }
919
920 #[test]
921 fn zonk_unbound_var_becomes_unknown() {
922 let mut table = UnificationTable::new();
923 let v = table.fresh();
924 if let InferType::Var(tv) = v {
925 let zonked = table.zonk(&InferType::Var(tv));
926 assert_eq!(zonked, InferType::Unknown);
927 }
928 }
929
930 #[test]
931 fn zonk_nested_resolves_inner_var() {
932 let mut table = UnificationTable::new();
933 let v = table.fresh();
934 if let InferType::Var(tv) = v {
935 table.unify(&InferType::Var(tv), &InferType::Int).unwrap();
936 let ty = InferType::Seq(Box::new(InferType::Var(tv)));
937 let zonked = table.zonk(&ty);
938 assert_eq!(zonked, InferType::Seq(Box::new(InferType::Int)));
939 }
940 }
941
942 #[test]
943 fn zonk_chain_of_vars() {
944 let mut table = UnificationTable::new();
945 let tva = if let InferType::Var(tv) = table.fresh() { tv } else { panic!() };
946 let tvb = if let InferType::Var(tv) = table.fresh() { tv } else { panic!() };
947 let tvc = if let InferType::Var(tv) = table.fresh() { tv } else { panic!() };
948 table.unify(&InferType::Var(tva), &InferType::Var(tvb)).unwrap();
950 table.unify(&InferType::Var(tvb), &InferType::Var(tvc)).unwrap();
951 table.unify(&InferType::Var(tvc), &InferType::Float).unwrap();
952 assert_eq!(table.zonk(&InferType::Var(tva)), InferType::Float);
953 }
954
955 #[test]
960 fn unify_seq_of_same_type() {
961 let mut table = UnificationTable::new();
962 let a = InferType::Seq(Box::new(InferType::Int));
963 let b = InferType::Seq(Box::new(InferType::Int));
964 assert!(table.unify(&a, &b).is_ok());
965 }
966
967 #[test]
968 fn unify_seq_of_different_types_fails() {
969 let mut table = UnificationTable::new();
970 let a = InferType::Seq(Box::new(InferType::Int));
971 let b = InferType::Seq(Box::new(InferType::String));
972 assert!(table.unify(&a, &b).is_err());
973 }
974
975 #[test]
976 fn unify_seq_with_var_element() {
977 let mut table = UnificationTable::new();
978 let v = table.fresh();
979 if let InferType::Var(tv) = v {
980 let a = InferType::Seq(Box::new(InferType::Var(tv)));
981 let b = InferType::Seq(Box::new(InferType::Float));
982 table.unify(&a, &b).unwrap();
983 assert_eq!(table.find(tv), InferType::Float);
984 }
985 }
986
987 #[test]
988 fn unify_map_types() {
989 let mut table = UnificationTable::new();
990 let a = InferType::Map(Box::new(InferType::String), Box::new(InferType::Int));
991 let b = InferType::Map(Box::new(InferType::String), Box::new(InferType::Int));
992 assert!(table.unify(&a, &b).is_ok());
993 }
994
995 #[test]
996 fn unify_function_types_same_arity() {
997 let mut table = UnificationTable::new();
998 let a = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));
999 let b = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));
1000 assert!(table.unify(&a, &b).is_ok());
1001 }
1002
1003 #[test]
1004 fn unify_function_arity_mismatch_fails() {
1005 let mut table = UnificationTable::new();
1006 let a = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));
1007 let b = InferType::Function(
1008 vec![InferType::Int, InferType::Int],
1009 Box::new(InferType::Bool),
1010 );
1011 let result = table.unify(&a, &b);
1012 assert!(matches!(result, Err(TypeError::ArityMismatch { expected: 1, found: 2 })));
1013 }
1014
1015 #[test]
1020 fn to_logos_type_ground_primitives() {
1021 let table = UnificationTable::new();
1022 assert_eq!(table.to_logos_type(&InferType::Int), LogosType::Int);
1023 assert_eq!(table.to_logos_type(&InferType::Float), LogosType::Float);
1024 assert_eq!(table.to_logos_type(&InferType::Bool), LogosType::Bool);
1025 assert_eq!(table.to_logos_type(&InferType::String), LogosType::String);
1026 assert_eq!(table.to_logos_type(&InferType::Unit), LogosType::Unit);
1027 assert_eq!(table.to_logos_type(&InferType::Nat), LogosType::Nat);
1028 }
1029
1030 #[test]
1031 fn to_logos_type_unbound_var_becomes_unknown() {
1032 let mut table = UnificationTable::new();
1033 let v = table.fresh();
1034 assert_eq!(table.to_logos_type(&v), LogosType::Unknown);
1035 }
1036
1037 #[test]
1038 fn to_logos_type_bound_var_resolves() {
1039 let mut table = UnificationTable::new();
1040 let v = table.fresh();
1041 if let InferType::Var(tv) = v {
1042 table.unify(&InferType::Var(tv), &InferType::Int).unwrap();
1043 assert_eq!(table.to_logos_type(&InferType::Var(tv)), LogosType::Int);
1044 }
1045 }
1046
1047 #[test]
1048 fn to_logos_type_seq_resolves_inner() {
1049 let mut table = UnificationTable::new();
1050 let v = table.fresh();
1051 if let InferType::Var(tv) = v {
1052 table.unify(&InferType::Var(tv), &InferType::String).unwrap();
1053 let ty = InferType::Seq(Box::new(InferType::Var(tv)));
1054 assert_eq!(
1055 table.to_logos_type(&ty),
1056 LogosType::Seq(Box::new(LogosType::String))
1057 );
1058 }
1059 }
1060
1061 #[test]
1062 fn to_logos_type_function_converts_to_logos_function() {
1063 let table = UnificationTable::new();
1064 let ty = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));
1065 assert_eq!(
1066 table.to_logos_type(&ty),
1067 LogosType::Function(vec![LogosType::Int], Box::new(LogosType::Bool))
1068 );
1069 }
1070
1071 #[test]
1072 fn to_logos_type_function_two_params_converts() {
1073 let table = UnificationTable::new();
1074 let ty = InferType::Function(
1075 vec![InferType::Int, InferType::String],
1076 Box::new(InferType::Bool),
1077 );
1078 assert_eq!(
1079 table.to_logos_type(&ty),
1080 LogosType::Function(
1081 vec![LogosType::Int, LogosType::String],
1082 Box::new(LogosType::Bool)
1083 )
1084 );
1085 }
1086
1087 #[test]
1088 fn to_logos_type_function_zero_params_converts() {
1089 let table = UnificationTable::new();
1090 let ty = InferType::Function(vec![], Box::new(InferType::Unit));
1091 assert_eq!(
1092 table.to_logos_type(&ty),
1093 LogosType::Function(vec![], Box::new(LogosType::Unit))
1094 );
1095 }
1096
1097 #[test]
1098 fn to_logos_type_function_nested_converts() {
1099 let table = UnificationTable::new();
1101 let inner = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));
1102 let outer = InferType::Function(vec![inner], Box::new(InferType::String));
1103 assert_eq!(
1104 table.to_logos_type(&outer),
1105 LogosType::Function(
1106 vec![LogosType::Function(
1107 vec![LogosType::Int],
1108 Box::new(LogosType::Bool)
1109 )],
1110 Box::new(LogosType::String)
1111 )
1112 );
1113 }
1114
1115 #[test]
1120 fn from_type_expr_function_produces_function_type() {
1121 use crate::ast::stmt::TypeExpr;
1122 let mut interner = Interner::new();
1123 let int_sym = interner.intern("Int");
1124 let bool_sym = interner.intern("Bool");
1125 let int_ty = TypeExpr::Primitive(int_sym);
1126 let bool_ty = TypeExpr::Primitive(bool_sym);
1127 let fn_ty = TypeExpr::Function {
1128 inputs: std::slice::from_ref(&int_ty),
1129 output: &bool_ty,
1130 };
1131 let result = InferType::from_type_expr(&fn_ty, &interner);
1132 assert_eq!(
1133 result,
1134 InferType::Function(vec![InferType::Int], Box::new(InferType::Bool))
1135 );
1136 }
1137
#[test]
/// A generic `Seq` applied to primitive `Int` becomes `InferType::Seq` of `InferType::Int`.
fn from_type_expr_seq_of_int() {
    use crate::ast::stmt::TypeExpr;

    let mut interner = Interner::new();
    let base = interner.intern("Seq");
    let elem_expr = TypeExpr::Primitive(interner.intern("Int"));
    let seq_expr = TypeExpr::Generic {
        base,
        params: std::slice::from_ref(&elem_expr),
    };

    let actual = InferType::from_type_expr(&seq_expr, &interner);
    let expected = InferType::Seq(Box::new(InferType::Int));
    assert_eq!(actual, expected);
}
1152
#[test]
/// A type parameter present in the map resolves to the type variable it is mapped to.
fn from_field_type_type_param_resolves_to_var() {
    let mut interner = Interner::new();
    let param = interner.intern("T");
    let var = TyVar(0);
    let type_params = HashMap::from([(param, var)]);

    let field = FieldType::TypeParam(param);
    let actual = InferType::from_field_type(&field, &interner, &type_params);
    assert_eq!(actual, InferType::Var(var));
}
1169
#[test]
/// A type parameter with no entry in the (empty) map yields `InferType::Unknown`.
fn from_field_type_missing_type_param_becomes_unknown() {
    let mut interner = Interner::new();
    let unbound = FieldType::TypeParam(interner.intern("T"));
    let no_params = HashMap::new();

    assert_eq!(
        InferType::from_field_type(&unbound, &interner, &no_params),
        InferType::Unknown
    );
}
1179
#[test]
/// The primitive named `Int` maps to `InferType::Int` with an empty type-parameter map.
fn from_field_type_primitive() {
    let mut interner = Interner::new();
    let int_field = FieldType::Primitive(interner.intern("Int"));
    let no_params = HashMap::new();

    let actual = InferType::from_field_type(&int_field, &interner, &no_params);
    assert_eq!(actual, InferType::Int);
}
1188
#[test]
/// A type parameter nested inside a generic `Seq` resolves to its mapped type variable.
fn from_field_type_generic_seq_of_type_param() {
    let mut interner = Interner::new();
    let base = interner.intern("Seq");
    let elem = interner.intern("T");
    let var = TyVar(0);
    let type_params = HashMap::from([(elem, var)]);

    let seq_of_t = FieldType::Generic {
        base,
        params: vec![FieldType::TypeParam(elem)],
    };
    let expected = InferType::Seq(Box::new(InferType::Var(var)));
    assert_eq!(
        InferType::from_field_type(&seq_of_t, &interner, &type_params),
        expected
    );
}
1205
#[test]
/// Mixing `Int` and `Float` unifies to `Float`, whichever operand comes first.
fn numeric_float_wins() {
    let operand_orders = [
        (InferType::Int, InferType::Float),
        (InferType::Float, InferType::Int),
    ];
    for (lhs, rhs) in &operand_orders {
        assert_eq!(unify_numeric(lhs, rhs).unwrap(), InferType::Float);
    }
}
1221
#[test]
/// Two `Int` operands unify to `Int`.
fn numeric_int_plus_int_is_int() {
    let unified = unify_numeric(&InferType::Int, &InferType::Int).unwrap();
    assert_eq!(unified, InferType::Int);
}
1229
#[test]
/// `Nat` combined with `Int` unifies to `Int`.
fn numeric_nat_plus_int_is_int() {
    let (lhs, rhs) = (InferType::Nat, InferType::Int);
    let unified = unify_numeric(&lhs, &rhs).unwrap();
    assert_eq!(unified, InferType::Int);
}
1237
#[test]
/// Two `Nat` operands unify to `Nat`.
fn numeric_nat_plus_nat_is_nat() {
    let nat = InferType::Nat;
    let unified = unify_numeric(&nat, &nat).unwrap();
    assert_eq!(unified, InferType::Nat);
}
1245
#[test]
/// Numeric unification of `String` with `Int` returns an error.
fn numeric_string_fails() {
    let (text, int) = (InferType::String, InferType::Int);
    let result = unify_numeric(&text, &int);
    assert!(result.is_err());
}
1251
#[test]
/// Primitive types render under their Logos names (`Float` as "Real", `String` as "Text").
fn logos_name_primitives() {
    let cases = [
        (InferType::Int, "Int"),
        (InferType::Float, "Real"),
        (InferType::String, "Text"),
        (InferType::Bool, "Bool"),
    ];
    for (ty, name) in &cases {
        assert_eq!(ty.to_logos_name(), *name);
    }
}
1263
#[test]
/// A `Seq` of `Int` renders as "Seq of Int".
fn logos_name_seq() {
    let rendered = InferType::Seq(Box::new(InferType::Int)).to_logos_name();
    assert_eq!(rendered, "Seq of Int");
}
1269
#[test]
/// A one-parameter function from `Int` to `Bool` renders as "fn(Int) -> Bool".
fn logos_name_function() {
    let params = vec![InferType::Int];
    let predicate = InferType::Function(params, Box::new(InferType::Bool));
    let rendered = predicate.to_logos_name();
    assert_eq!(rendered, "fn(Int) -> Bool");
}
1275
#[test]
/// A mismatch reports the Logos names of the expected and the found type.
fn type_error_mismatch_strings() {
    let mismatch = TypeError::Mismatch {
        expected: InferType::Int,
        found: InferType::String,
    };

    let expected_text = mismatch.expected_str();
    assert_eq!(expected_text, "Int");

    let found_text = mismatch.found_str();
    assert_eq!(found_text, "Text");
}
1289
#[test]
/// An arity mismatch reports both counts formatted as "<n> arguments".
fn type_error_arity_mismatch_strings() {
    let arity_err = TypeError::ArityMismatch {
        expected: 2,
        found: 3,
    };

    let expected_text = arity_err.expected_str();
    assert_eq!(expected_text, "2 arguments");

    let found_text = arity_err.found_str();
    assert_eq!(found_text, "3 arguments");
}
1296
#[test]
/// `infer_to_logos` converts a one-parameter function type to `LogosType::Function`.
fn infer_to_logos_function_single_param() {
    let predicate = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));

    let converted = super::infer_to_logos(&predicate);
    let expected = LogosType::Function(vec![LogosType::Int], Box::new(LogosType::Bool));
    assert_eq!(converted, expected);
}
1309
#[test]
/// `infer_to_logos` keeps an empty parameter list for a zero-parameter function type.
fn infer_to_logos_function_zero_params() {
    let thunk = InferType::Function(Vec::new(), Box::new(InferType::Unit));

    let converted = super::infer_to_logos(&thunk);
    let expected = LogosType::Function(Vec::new(), Box::new(LogosType::Unit));
    assert_eq!(converted, expected);
}
1318
#[test]
/// `infer_to_logos` converts both parameters and the return type of a two-parameter function.
fn infer_to_logos_function_two_params() {
    let params = vec![InferType::String, InferType::Float];
    let fn_ty = InferType::Function(params, Box::new(InferType::Bool));

    let expected = LogosType::Function(
        vec![LogosType::String, LogosType::Float],
        Box::new(LogosType::Bool),
    );
    assert_eq!(super::infer_to_logos(&fn_ty), expected);
}
1333
#[test]
/// `infer_to_logos` converts a function-typed parameter into a nested `LogosType::Function`.
fn infer_to_logos_function_nested() {
    // Build `fn(fn(Int) -> Bool) -> String` from the inside out.
    let predicate = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));
    let higher_order = InferType::Function(vec![predicate], Box::new(InferType::String));

    // Mirror the same shape on the `LogosType` side.
    let expected_predicate =
        LogosType::Function(vec![LogosType::Int], Box::new(LogosType::Bool));
    let expected = LogosType::Function(vec![expected_predicate], Box::new(LogosType::String));

    assert_eq!(super::infer_to_logos(&higher_order), expected);
}
1350}