1use std::collections::HashMap;
19
20use crate::intern::{Interner, Symbol};
21use crate::analysis::{FieldType, LogosType};
22
/// Identifier of a type variable allocated by a `UnificationTable`.
///
/// The wrapped `u32` is used directly as an index into the table's binding
/// slots, so a `TyVar` is only meaningful for the table that allocated it.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub struct TyVar(pub u32);
30
31#[derive(Clone, PartialEq, Debug)]
38pub enum InferType {
39 Int,
41 Float,
42 Bool,
43 Char,
44 Byte,
45 String,
46 Unit,
47 Nat,
48 Duration,
49 Date,
50 Moment,
51 Time,
52 Span,
53 Seq(Box<InferType>),
54 Map(Box<InferType>, Box<InferType>),
55 Set(Box<InferType>),
56 Option(Box<InferType>),
57 UserDefined(Symbol),
58 Var(TyVar),
60 Function(Vec<InferType>, Box<InferType>),
61 Unknown,
62}
63
64#[derive(Debug, Clone)]
66pub enum TypeError {
67 Mismatch { expected: InferType, found: InferType },
68 InfiniteType { var: TyVar, ty: InferType },
69 ArityMismatch { expected: usize, found: usize },
70 FieldNotFound { type_name: Symbol, field_name: Symbol },
71 NotAFunction { found: InferType },
72}
73
74impl TypeError {
75 pub fn expected_str(&self) -> std::string::String {
76 match self {
77 TypeError::Mismatch { expected, .. } => expected.to_logos_name(),
78 TypeError::ArityMismatch { expected, .. } => format!("{} arguments", expected),
79 TypeError::FieldNotFound { .. } => "a known field".to_string(),
80 TypeError::NotAFunction { .. } => "a function".to_string(),
81 TypeError::InfiniteType { .. } => "a finite type".to_string(),
82 }
83 }
84
85 pub fn found_str(&self) -> std::string::String {
86 match self {
87 TypeError::Mismatch { found, .. } => found.to_logos_name(),
88 TypeError::ArityMismatch { found, .. } => format!("{} arguments", found),
89 TypeError::FieldNotFound { field_name, .. } => format!("{:?}", field_name),
90 TypeError::NotAFunction { found } => found.to_logos_name(),
91 TypeError::InfiniteType { ty, .. } => ty.to_logos_name(),
92 }
93 }
94
95 pub fn to_parse_error_kind(
99 &self,
100 interner: &crate::intern::Interner,
101 ) -> crate::error::ParseErrorKind {
102 use crate::error::ParseErrorKind;
103 match self {
104 TypeError::Mismatch { expected, found } => ParseErrorKind::TypeMismatchDetailed {
105 expected: expected.to_logos_name(),
106 found: found.to_logos_name(),
107 context: String::new(),
108 },
109 TypeError::InfiniteType { var, ty } => ParseErrorKind::InfiniteType {
110 var_description: format!("type variable α{}", var.0),
111 type_description: ty.to_logos_name(),
112 },
113 TypeError::ArityMismatch { expected, found } => ParseErrorKind::ArityMismatch {
114 function: String::from("function"),
115 expected: *expected,
116 found: *found,
117 },
118 TypeError::FieldNotFound { type_name, field_name } => ParseErrorKind::FieldNotFound {
119 type_name: interner.resolve(*type_name).to_string(),
120 field_name: interner.resolve(*field_name).to_string(),
121 available: vec![],
122 },
123 TypeError::NotAFunction { found } => ParseErrorKind::NotAFunction {
124 found_type: found.to_logos_name(),
125 },
126 }
127 }
128}
129
130#[derive(Clone, Debug)]
139pub struct TypeScheme {
140 pub vars: Vec<TyVar>,
142 pub body: InferType,
144}
145
146pub struct UnificationTable {
152 bindings: Vec<Option<InferType>>,
153 next_id: u32,
154}
155
156impl UnificationTable {
157 pub fn new() -> Self {
158 Self {
159 bindings: Vec::new(),
160 next_id: 0,
161 }
162 }
163
164 pub fn fresh(&mut self) -> InferType {
166 let id = self.next_id;
167 self.next_id += 1;
168 self.bindings.push(None);
169 InferType::Var(TyVar(id))
170 }
171
172 pub fn fresh_var(&mut self) -> TyVar {
174 let id = self.next_id;
175 self.next_id += 1;
176 self.bindings.push(None);
177 TyVar(id)
178 }
179
180 pub fn instantiate(&mut self, scheme: &TypeScheme) -> InferType {
185 if scheme.vars.is_empty() {
186 return scheme.body.clone();
187 }
188 let subst: HashMap<TyVar, TyVar> = scheme.vars.iter()
189 .map(|&old_tv| (old_tv, self.fresh_var()))
190 .collect();
191 self.substitute_vars(&scheme.body, &subst)
192 }
193
194 fn substitute_vars(&self, ty: &InferType, subst: &HashMap<TyVar, TyVar>) -> InferType {
196 match ty {
197 InferType::Var(tv) => {
198 let resolved = self.find(*tv);
199 match &resolved {
200 InferType::Var(rtv) => {
201 if let Some(&new_tv) = subst.get(rtv) {
202 InferType::Var(new_tv)
203 } else {
204 InferType::Var(*rtv)
205 }
206 }
207 other => self.substitute_vars(&other.clone(), subst),
208 }
209 }
210 InferType::Seq(inner) => InferType::Seq(Box::new(self.substitute_vars(inner, subst))),
211 InferType::Map(k, v) => InferType::Map(
212 Box::new(self.substitute_vars(k, subst)),
213 Box::new(self.substitute_vars(v, subst)),
214 ),
215 InferType::Set(inner) => InferType::Set(Box::new(self.substitute_vars(inner, subst))),
216 InferType::Option(inner) => InferType::Option(Box::new(self.substitute_vars(inner, subst))),
217 InferType::Function(params, ret) => InferType::Function(
218 params.iter().map(|p| self.substitute_vars(p, subst)).collect(),
219 Box::new(self.substitute_vars(ret, subst)),
220 ),
221 other => other.clone(),
222 }
223 }
224
225 pub fn find(&self, tv: TyVar) -> InferType {
227 let mut current = tv;
228 loop {
229 match &self.bindings[current.0 as usize] {
230 None => return InferType::Var(current),
231 Some(InferType::Var(tv2)) => current = *tv2,
232 Some(ty) => return ty.clone(),
233 }
234 }
235 }
236
237 fn walk(&self, ty: &InferType) -> InferType {
239 match ty {
240 InferType::Var(tv) => self.find(*tv),
241 other => other.clone(),
242 }
243 }
244
245 pub fn resolve(&self, ty: &InferType) -> InferType {
251 match ty {
252 InferType::Var(tv) => {
253 let resolved = self.find(*tv);
254 match &resolved {
255 InferType::Var(_) => resolved, other => self.resolve(&other.clone()),
257 }
258 }
259 InferType::Seq(inner) => InferType::Seq(Box::new(self.resolve(inner))),
260 InferType::Map(k, v) => {
261 InferType::Map(Box::new(self.resolve(k)), Box::new(self.resolve(v)))
262 }
263 InferType::Set(inner) => InferType::Set(Box::new(self.resolve(inner))),
264 InferType::Option(inner) => InferType::Option(Box::new(self.resolve(inner))),
265 InferType::Function(params, ret) => {
266 let params = params.iter().map(|p| self.resolve(p)).collect();
267 InferType::Function(params, Box::new(self.resolve(ret)))
268 }
269 other => other.clone(),
270 }
271 }
272
273 pub fn zonk(&self, ty: &InferType) -> InferType {
278 match ty {
279 InferType::Var(tv) => {
280 let resolved = self.find(*tv);
281 match &resolved {
282 InferType::Var(_) => InferType::Unknown,
283 other => self.zonk(other),
284 }
285 }
286 InferType::Seq(inner) => InferType::Seq(Box::new(self.zonk(inner))),
287 InferType::Map(k, v) => {
288 InferType::Map(Box::new(self.zonk(k)), Box::new(self.zonk(v)))
289 }
290 InferType::Set(inner) => InferType::Set(Box::new(self.zonk(inner))),
291 InferType::Option(inner) => InferType::Option(Box::new(self.zonk(inner))),
292 InferType::Function(params, ret) => {
293 let params = params.iter().map(|p| self.zonk(p)).collect();
294 InferType::Function(params, Box::new(self.zonk(ret)))
295 }
296 other => other.clone(),
297 }
298 }
299
300 pub fn to_logos_type(&self, ty: &InferType) -> LogosType {
302 let zonked = self.zonk(ty);
303 infer_to_logos(&zonked)
304 }
305
306 pub fn unify(&mut self, a: &InferType, b: &InferType) -> Result<(), TypeError> {
311 let a = self.walk(a);
312 let b = self.walk(b);
313 self.unify_walked(&a, &b)
314 }
315
316 fn unify_walked(&mut self, a: &InferType, b: &InferType) -> Result<(), TypeError> {
317 match (a, b) {
318 (InferType::Var(va), InferType::Var(vb)) if va == vb => Ok(()),
320
321 (InferType::Var(tv), ty) => {
323 let tv = *tv;
324 let ty = ty.clone();
325 self.occurs_check(tv, &ty)?;
326 self.bindings[tv.0 as usize] = Some(ty);
327 Ok(())
328 }
329 (ty, InferType::Var(tv)) => {
330 let tv = *tv;
331 let ty = ty.clone();
332 self.occurs_check(tv, &ty)?;
333 self.bindings[tv.0 as usize] = Some(ty);
334 Ok(())
335 }
336
337 (InferType::Unknown, _) | (_, InferType::Unknown) => Ok(()),
339
340 (InferType::Int, InferType::Int) => Ok(()),
342 (InferType::Float, InferType::Float) => Ok(()),
343 (InferType::Bool, InferType::Bool) => Ok(()),
344 (InferType::Char, InferType::Char) => Ok(()),
345 (InferType::Byte, InferType::Byte) => Ok(()),
346 (InferType::String, InferType::String) => Ok(()),
347 (InferType::Unit, InferType::Unit) => Ok(()),
348 (InferType::Nat, InferType::Nat) => Ok(()),
349 (InferType::Duration, InferType::Duration) => Ok(()),
350 (InferType::Date, InferType::Date) => Ok(()),
351 (InferType::Moment, InferType::Moment) => Ok(()),
352 (InferType::Time, InferType::Time) => Ok(()),
353 (InferType::Span, InferType::Span) => Ok(()),
354
355 (InferType::Nat, InferType::Int) | (InferType::Int, InferType::Nat) => Ok(()),
357
358 (InferType::UserDefined(a), InferType::UserDefined(b)) if a == b => Ok(()),
360
361 (InferType::Seq(a_inner), InferType::Seq(b_inner)) => {
363 let a_inner = (**a_inner).clone();
364 let b_inner = (**b_inner).clone();
365 self.unify(&a_inner, &b_inner)
366 }
367 (InferType::Set(a_inner), InferType::Set(b_inner)) => {
368 let a_inner = (**a_inner).clone();
369 let b_inner = (**b_inner).clone();
370 self.unify(&a_inner, &b_inner)
371 }
372 (InferType::Option(a_inner), InferType::Option(b_inner)) => {
373 let a_inner = (**a_inner).clone();
374 let b_inner = (**b_inner).clone();
375 self.unify(&a_inner, &b_inner)
376 }
377 (InferType::Map(ak, av), InferType::Map(bk, bv)) => {
378 let ak = (**ak).clone();
379 let bk = (**bk).clone();
380 let av = (**av).clone();
381 let bv = (**bv).clone();
382 self.unify(&ak, &bk)?;
383 self.unify(&av, &bv)
384 }
385 (InferType::Function(a_params, a_ret), InferType::Function(b_params, b_ret)) => {
386 if a_params.len() != b_params.len() {
387 return Err(TypeError::ArityMismatch {
388 expected: a_params.len(),
389 found: b_params.len(),
390 });
391 }
392 let a_params = a_params.clone();
393 let b_params = b_params.clone();
394 let a_ret = (**a_ret).clone();
395 let b_ret = (**b_ret).clone();
396 for (ap, bp) in a_params.iter().zip(b_params.iter()) {
397 self.unify(ap, bp)?;
398 }
399 self.unify(&a_ret, &b_ret)
400 }
401
402 (a, b) => Err(TypeError::Mismatch {
404 expected: a.clone(),
405 found: b.clone(),
406 }),
407 }
408 }
409
410 fn occurs_check(&self, tv: TyVar, ty: &InferType) -> Result<(), TypeError> {
414 match ty {
415 InferType::Var(tv2) => {
416 let resolved = self.find(*tv2);
417 match &resolved {
418 InferType::Var(rtv) => {
419 if *rtv == tv {
420 Err(TypeError::InfiniteType { var: tv, ty: ty.clone() })
421 } else {
422 Ok(())
423 }
424 }
425 other => self.occurs_check(tv, &other.clone()),
426 }
427 }
428 InferType::Seq(inner) | InferType::Set(inner) | InferType::Option(inner) => {
429 self.occurs_check(tv, inner)
430 }
431 InferType::Map(k, v) => {
432 self.occurs_check(tv, k)?;
433 self.occurs_check(tv, v)
434 }
435 InferType::Function(params, ret) => {
436 for p in params {
437 self.occurs_check(tv, p)?;
438 }
439 self.occurs_check(tv, ret)
440 }
441 _ => Ok(()),
442 }
443 }
444}
445
446impl InferType {
451 pub fn from_type_expr(ty: &crate::ast::stmt::TypeExpr, interner: &Interner) -> InferType {
457 Self::from_type_expr_with_params(ty, interner, &HashMap::new())
458 }
459
460 pub fn from_type_expr_with_params(
467 ty: &crate::ast::stmt::TypeExpr,
468 interner: &Interner,
469 type_params: &HashMap<Symbol, TyVar>,
470 ) -> InferType {
471 use crate::ast::stmt::TypeExpr;
472 match ty {
473 TypeExpr::Primitive(sym) | TypeExpr::Named(sym) => {
474 if let Some(&tv) = type_params.get(sym) {
476 return InferType::Var(tv);
477 }
478 Self::from_type_name(interner.resolve(*sym))
479 }
480 TypeExpr::Generic { base, params } => {
481 let base_name = interner.resolve(*base);
482 match base_name {
483 "Seq" | "List" | "Vec" => {
484 let elem = params
485 .first()
486 .map(|p| InferType::from_type_expr_with_params(p, interner, type_params))
487 .unwrap_or(InferType::Unit);
488 InferType::Seq(Box::new(elem))
489 }
490 "Map" | "HashMap" => {
491 let key = params
492 .first()
493 .map(|p| InferType::from_type_expr_with_params(p, interner, type_params))
494 .unwrap_or(InferType::String);
495 let val = params
496 .get(1)
497 .map(|p| InferType::from_type_expr_with_params(p, interner, type_params))
498 .unwrap_or(InferType::String);
499 InferType::Map(Box::new(key), Box::new(val))
500 }
501 "Set" | "HashSet" => {
502 let elem = params
503 .first()
504 .map(|p| InferType::from_type_expr_with_params(p, interner, type_params))
505 .unwrap_or(InferType::Unit);
506 InferType::Set(Box::new(elem))
507 }
508 "Option" | "Maybe" => {
509 let inner = params
510 .first()
511 .map(|p| InferType::from_type_expr_with_params(p, interner, type_params))
512 .unwrap_or(InferType::Unit);
513 InferType::Option(Box::new(inner))
514 }
515 _ => InferType::Unknown,
516 }
517 }
518 TypeExpr::Refinement { base, .. } => {
519 InferType::from_type_expr_with_params(base, interner, type_params)
520 }
521 TypeExpr::Persistent { inner } => {
522 InferType::from_type_expr_with_params(inner, interner, type_params)
523 }
524 TypeExpr::Function { inputs, output } => {
525 let param_types: Vec<InferType> = inputs
526 .iter()
527 .map(|p| InferType::from_type_expr_with_params(p, interner, type_params))
528 .collect();
529 let ret_type = InferType::from_type_expr_with_params(output, interner, type_params);
530 InferType::Function(param_types, Box::new(ret_type))
531 }
532 }
533 }
534
535 pub fn from_field_type(
540 ty: &FieldType,
541 interner: &Interner,
542 type_params: &HashMap<Symbol, TyVar>,
543 ) -> InferType {
544 match ty {
545 FieldType::Primitive(sym) => InferType::from_type_name(interner.resolve(*sym)),
546 FieldType::Named(sym) => {
547 let name = interner.resolve(*sym);
548 let primitive = InferType::from_type_name(name);
549 if primitive == InferType::Unknown {
550 InferType::UserDefined(*sym)
551 } else {
552 primitive
553 }
554 }
555 FieldType::Generic { base, params } => {
556 let base_name = interner.resolve(*base);
557 let converted: Vec<InferType> = params
558 .iter()
559 .map(|p| InferType::from_field_type(p, interner, type_params))
560 .collect();
561 match base_name {
562 "Seq" | "List" | "Vec" => {
563 InferType::Seq(Box::new(
564 converted.into_iter().next().unwrap_or(InferType::Unit),
565 ))
566 }
567 "Map" | "HashMap" => {
568 let mut it = converted.into_iter();
569 let k = it.next().unwrap_or(InferType::String);
570 let v = it.next().unwrap_or(InferType::String);
571 InferType::Map(Box::new(k), Box::new(v))
572 }
573 "Set" | "HashSet" => {
574 InferType::Set(Box::new(
575 converted.into_iter().next().unwrap_or(InferType::Unit),
576 ))
577 }
578 "Option" | "Maybe" => {
579 InferType::Option(Box::new(
580 converted.into_iter().next().unwrap_or(InferType::Unit),
581 ))
582 }
583 _ => InferType::Unknown,
584 }
585 }
586 FieldType::TypeParam(sym) => {
587 if let Some(tv) = type_params.get(sym) {
588 InferType::Var(*tv)
589 } else {
590 InferType::Unknown
591 }
592 }
593 }
594 }
595
596 pub fn from_literal(lit: &crate::ast::stmt::Literal) -> InferType {
598 use crate::ast::stmt::Literal;
599 match lit {
600 Literal::Number(_) => InferType::Int,
601 Literal::Float(_) => InferType::Float,
602 Literal::Text(_) => InferType::String,
603 Literal::Boolean(_) => InferType::Bool,
604 Literal::Char(_) => InferType::Char,
605 Literal::Nothing => InferType::Unit,
606 Literal::Duration(_) => InferType::Duration,
607 Literal::Date(_) => InferType::Date,
608 Literal::Moment(_) => InferType::Moment,
609 Literal::Span { .. } => InferType::Span,
610 Literal::Time(_) => InferType::Time,
611 }
612 }
613
614 pub fn from_type_name(name: &str) -> InferType {
616 match name {
617 "Int" => InferType::Int,
618 "Nat" => InferType::Nat,
619 "Real" | "Float" => InferType::Float,
620 "Bool" | "Boolean" => InferType::Bool,
621 "Text" | "String" => InferType::String,
622 "Char" => InferType::Char,
623 "Byte" => InferType::Byte,
624 "Unit" | "()" => InferType::Unit,
625 "Duration" => InferType::Duration,
626 "Date" => InferType::Date,
627 "Moment" => InferType::Moment,
628 "Time" => InferType::Time,
629 "Span" => InferType::Span,
630 _ => InferType::Unknown,
631 }
632 }
633
634 pub fn to_logos_name(&self) -> std::string::String {
636 match self {
637 InferType::Int => "Int".into(),
638 InferType::Float => "Real".into(),
639 InferType::Bool => "Bool".into(),
640 InferType::Char => "Char".into(),
641 InferType::Byte => "Byte".into(),
642 InferType::String => "Text".into(),
643 InferType::Unit => "Unit".into(),
644 InferType::Nat => "Nat".into(),
645 InferType::Duration => "Duration".into(),
646 InferType::Date => "Date".into(),
647 InferType::Moment => "Moment".into(),
648 InferType::Time => "Time".into(),
649 InferType::Span => "Span".into(),
650 InferType::Seq(inner) => format!("Seq of {}", inner.to_logos_name()),
651 InferType::Map(k, v) => {
652 format!("Map of {} and {}", k.to_logos_name(), v.to_logos_name())
653 }
654 InferType::Set(inner) => format!("Set of {}", inner.to_logos_name()),
655 InferType::Option(inner) => format!("Option of {}", inner.to_logos_name()),
656 InferType::UserDefined(_) => "a user-defined type".into(),
657 InferType::Var(_) => "an unknown type".into(),
658 InferType::Function(params, ret) => {
659 let params_str = params
660 .iter()
661 .map(|p| p.to_logos_name())
662 .collect::<Vec<_>>()
663 .join(", ");
664 format!("fn({}) -> {}", params_str, ret.to_logos_name())
665 }
666 InferType::Unknown => "unknown".into(),
667 }
668 }
669
670 pub fn to_logos_type_ground(&self) -> LogosType {
676 match self {
677 InferType::Int => LogosType::Int,
678 InferType::Float => LogosType::Float,
679 InferType::Bool => LogosType::Bool,
680 InferType::Char => LogosType::Char,
681 InferType::Byte => LogosType::Byte,
682 InferType::String => LogosType::String,
683 InferType::Unit => LogosType::Unit,
684 InferType::Nat => LogosType::Nat,
685 InferType::Duration => LogosType::Duration,
686 InferType::Date => LogosType::Date,
687 InferType::Moment => LogosType::Moment,
688 InferType::Time => LogosType::Time,
689 InferType::Span => LogosType::Span,
690 InferType::Seq(inner) => LogosType::Seq(Box::new(inner.to_logos_type_ground())),
691 InferType::Map(k, v) => LogosType::Map(
692 Box::new(k.to_logos_type_ground()),
693 Box::new(v.to_logos_type_ground()),
694 ),
695 InferType::Set(inner) => LogosType::Set(Box::new(inner.to_logos_type_ground())),
696 InferType::Option(inner) => LogosType::Option(Box::new(inner.to_logos_type_ground())),
697 InferType::UserDefined(sym) => LogosType::UserDefined(*sym),
698 InferType::Function(params, ret) => LogosType::Function(
699 params.iter().map(|p| p.to_logos_type_ground()).collect(),
700 Box::new(ret.to_logos_type_ground()),
701 ),
702 InferType::Unknown => LogosType::Unknown,
703 InferType::Var(_) => panic!("to_logos_type_ground called on unresolved Var"),
704 }
705 }
706}
707
708pub fn unify_numeric(a: &InferType, b: &InferType) -> Result<InferType, TypeError> {
714 match (a, b) {
715 (InferType::Float, _) | (_, InferType::Float) => Ok(InferType::Float),
716 (InferType::Int, InferType::Int) => Ok(InferType::Int),
717 (InferType::Nat, InferType::Int) | (InferType::Int, InferType::Nat) => Ok(InferType::Int),
718 (InferType::Nat, InferType::Nat) => Ok(InferType::Nat),
719 (InferType::Byte, InferType::Byte) => Ok(InferType::Byte),
720 _ => Err(TypeError::Mismatch {
721 expected: InferType::Int,
722 found: a.clone(),
723 }),
724 }
725}
726
727pub fn infer_to_logos(ty: &InferType) -> LogosType {
729 match ty {
730 InferType::Int => LogosType::Int,
731 InferType::Float => LogosType::Float,
732 InferType::Bool => LogosType::Bool,
733 InferType::Char => LogosType::Char,
734 InferType::Byte => LogosType::Byte,
735 InferType::String => LogosType::String,
736 InferType::Unit => LogosType::Unit,
737 InferType::Nat => LogosType::Nat,
738 InferType::Duration => LogosType::Duration,
739 InferType::Date => LogosType::Date,
740 InferType::Moment => LogosType::Moment,
741 InferType::Time => LogosType::Time,
742 InferType::Span => LogosType::Span,
743 InferType::Seq(inner) => LogosType::Seq(Box::new(infer_to_logos(inner))),
744 InferType::Map(k, v) => {
745 LogosType::Map(Box::new(infer_to_logos(k)), Box::new(infer_to_logos(v)))
746 }
747 InferType::Set(inner) => LogosType::Set(Box::new(infer_to_logos(inner))),
748 InferType::Option(inner) => LogosType::Option(Box::new(infer_to_logos(inner))),
749 InferType::UserDefined(sym) => LogosType::UserDefined(*sym),
750 InferType::Function(params, ret) => LogosType::Function(
751 params.iter().map(infer_to_logos).collect(),
752 Box::new(infer_to_logos(ret)),
753 ),
754 InferType::Unknown | InferType::Var(_) => LogosType::Unknown,
755 }
756}
757
758#[cfg(test)]
763mod tests {
764 use super::*;
765 use crate::analysis::{FieldDef, TypeDef};
766
767 #[test]
772 fn fresh_produces_distinct_vars() {
773 let mut table = UnificationTable::new();
774 let a = table.fresh();
775 let b = table.fresh();
776 assert_ne!(a, b);
777 }
778
779 #[test]
780 fn unbound_var_finds_itself() {
781 let mut table = UnificationTable::new();
782 let v = table.fresh();
783 if let InferType::Var(tv) = v {
784 assert_eq!(table.find(tv), InferType::Var(tv));
785 } else {
786 panic!("expected Var");
787 }
788 }
789
790 #[test]
795 fn unify_identical_ground_types() {
796 let mut table = UnificationTable::new();
797 assert!(table.unify(&InferType::Int, &InferType::Int).is_ok());
798 assert!(table.unify(&InferType::Float, &InferType::Float).is_ok());
799 assert!(table.unify(&InferType::Bool, &InferType::Bool).is_ok());
800 assert!(table.unify(&InferType::String, &InferType::String).is_ok());
801 assert!(table.unify(&InferType::Unit, &InferType::Unit).is_ok());
802 }
803
804 #[test]
805 fn unify_different_ground_types_fails() {
806 let mut table = UnificationTable::new();
807 let result = table.unify(&InferType::Int, &InferType::String);
808 assert!(result.is_err());
809 assert!(matches!(result, Err(TypeError::Mismatch { .. })));
810 }
811
812 #[test]
813 fn unify_int_float_fails() {
814 let mut table = UnificationTable::new();
815 let result = table.unify(&InferType::Int, &InferType::Float);
816 assert!(result.is_err());
817 }
818
819 #[test]
820 fn unify_nat_int_succeeds() {
821 let mut table = UnificationTable::new();
822 assert!(table.unify(&InferType::Nat, &InferType::Int).is_ok());
823 assert!(table.unify(&InferType::Int, &InferType::Nat).is_ok());
824 }
825
826 #[test]
827 fn unify_unknown_with_any_succeeds() {
828 let mut table = UnificationTable::new();
829 assert!(table.unify(&InferType::Unknown, &InferType::Int).is_ok());
830 assert!(table.unify(&InferType::String, &InferType::Unknown).is_ok());
831 assert!(table.unify(&InferType::Unknown, &InferType::Unknown).is_ok());
832 }
833
834 #[test]
839 fn var_unifies_with_int() {
840 let mut table = UnificationTable::new();
841 let v = table.fresh();
842 if let InferType::Var(tv) = v {
843 table.unify(&InferType::Var(tv), &InferType::Int).unwrap();
844 assert_eq!(table.find(tv), InferType::Int);
845 }
846 }
847
848 #[test]
849 fn int_unifies_with_var() {
850 let mut table = UnificationTable::new();
851 let v = table.fresh();
852 if let InferType::Var(tv) = v {
853 table.unify(&InferType::Int, &InferType::Var(tv)).unwrap();
854 assert_eq!(table.find(tv), InferType::Int);
855 }
856 }
857
858 #[test]
859 fn two_vars_unify_chain() {
860 let mut table = UnificationTable::new();
861 let va = table.fresh();
862 let vb = table.fresh();
863 let tva = if let InferType::Var(tv) = va { tv } else { panic!() };
864 let tvb = if let InferType::Var(tv) = vb { tv } else { panic!() };
865 table.unify(&InferType::Var(tva), &InferType::Var(tvb)).unwrap();
866 table.unify(&InferType::Var(tvb), &InferType::Int).unwrap();
868 let zonked = table.zonk(&InferType::Var(tva));
869 assert_eq!(zonked, InferType::Int);
870 }
871
872 #[test]
873 fn var_conflicting_types_fails() {
874 let mut table = UnificationTable::new();
875 let v = table.fresh();
876 if let InferType::Var(tv) = v {
877 table.unify(&InferType::Var(tv), &InferType::Int).unwrap();
878 let result = table.unify(&InferType::Var(tv), &InferType::String);
879 assert!(result.is_err());
881 }
882 }
883
884 #[test]
889 fn occurs_check_detects_infinite_type() {
890 let mut table = UnificationTable::new();
891 let v = table.fresh();
892 if let InferType::Var(tv) = v {
893 let circular = InferType::Seq(Box::new(InferType::Var(tv)));
894 let result = table.unify(&InferType::Var(tv), &circular);
895 assert!(result.is_err());
896 assert!(matches!(result, Err(TypeError::InfiniteType { .. })));
897 }
898 }
899
900 #[test]
905 fn zonk_resolves_bound_var() {
906 let mut table = UnificationTable::new();
907 let v = table.fresh();
908 if let InferType::Var(tv) = v {
909 table.unify(&InferType::Var(tv), &InferType::Bool).unwrap();
910 let zonked = table.zonk(&InferType::Var(tv));
911 assert_eq!(zonked, InferType::Bool);
912 }
913 }
914
915 #[test]
916 fn zonk_unbound_var_becomes_unknown() {
917 let mut table = UnificationTable::new();
918 let v = table.fresh();
919 if let InferType::Var(tv) = v {
920 let zonked = table.zonk(&InferType::Var(tv));
921 assert_eq!(zonked, InferType::Unknown);
922 }
923 }
924
925 #[test]
926 fn zonk_nested_resolves_inner_var() {
927 let mut table = UnificationTable::new();
928 let v = table.fresh();
929 if let InferType::Var(tv) = v {
930 table.unify(&InferType::Var(tv), &InferType::Int).unwrap();
931 let ty = InferType::Seq(Box::new(InferType::Var(tv)));
932 let zonked = table.zonk(&ty);
933 assert_eq!(zonked, InferType::Seq(Box::new(InferType::Int)));
934 }
935 }
936
937 #[test]
938 fn zonk_chain_of_vars() {
939 let mut table = UnificationTable::new();
940 let tva = if let InferType::Var(tv) = table.fresh() { tv } else { panic!() };
941 let tvb = if let InferType::Var(tv) = table.fresh() { tv } else { panic!() };
942 let tvc = if let InferType::Var(tv) = table.fresh() { tv } else { panic!() };
943 table.unify(&InferType::Var(tva), &InferType::Var(tvb)).unwrap();
945 table.unify(&InferType::Var(tvb), &InferType::Var(tvc)).unwrap();
946 table.unify(&InferType::Var(tvc), &InferType::Float).unwrap();
947 assert_eq!(table.zonk(&InferType::Var(tva)), InferType::Float);
948 }
949
950 #[test]
955 fn unify_seq_of_same_type() {
956 let mut table = UnificationTable::new();
957 let a = InferType::Seq(Box::new(InferType::Int));
958 let b = InferType::Seq(Box::new(InferType::Int));
959 assert!(table.unify(&a, &b).is_ok());
960 }
961
962 #[test]
963 fn unify_seq_of_different_types_fails() {
964 let mut table = UnificationTable::new();
965 let a = InferType::Seq(Box::new(InferType::Int));
966 let b = InferType::Seq(Box::new(InferType::String));
967 assert!(table.unify(&a, &b).is_err());
968 }
969
970 #[test]
971 fn unify_seq_with_var_element() {
972 let mut table = UnificationTable::new();
973 let v = table.fresh();
974 if let InferType::Var(tv) = v {
975 let a = InferType::Seq(Box::new(InferType::Var(tv)));
976 let b = InferType::Seq(Box::new(InferType::Float));
977 table.unify(&a, &b).unwrap();
978 assert_eq!(table.find(tv), InferType::Float);
979 }
980 }
981
982 #[test]
983 fn unify_map_types() {
984 let mut table = UnificationTable::new();
985 let a = InferType::Map(Box::new(InferType::String), Box::new(InferType::Int));
986 let b = InferType::Map(Box::new(InferType::String), Box::new(InferType::Int));
987 assert!(table.unify(&a, &b).is_ok());
988 }
989
990 #[test]
991 fn unify_function_types_same_arity() {
992 let mut table = UnificationTable::new();
993 let a = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));
994 let b = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));
995 assert!(table.unify(&a, &b).is_ok());
996 }
997
998 #[test]
999 fn unify_function_arity_mismatch_fails() {
1000 let mut table = UnificationTable::new();
1001 let a = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));
1002 let b = InferType::Function(
1003 vec![InferType::Int, InferType::Int],
1004 Box::new(InferType::Bool),
1005 );
1006 let result = table.unify(&a, &b);
1007 assert!(matches!(result, Err(TypeError::ArityMismatch { expected: 1, found: 2 })));
1008 }
1009
1010 #[test]
1015 fn to_logos_type_ground_primitives() {
1016 let table = UnificationTable::new();
1017 assert_eq!(table.to_logos_type(&InferType::Int), LogosType::Int);
1018 assert_eq!(table.to_logos_type(&InferType::Float), LogosType::Float);
1019 assert_eq!(table.to_logos_type(&InferType::Bool), LogosType::Bool);
1020 assert_eq!(table.to_logos_type(&InferType::String), LogosType::String);
1021 assert_eq!(table.to_logos_type(&InferType::Unit), LogosType::Unit);
1022 assert_eq!(table.to_logos_type(&InferType::Nat), LogosType::Nat);
1023 }
1024
1025 #[test]
1026 fn to_logos_type_unbound_var_becomes_unknown() {
1027 let mut table = UnificationTable::new();
1028 let v = table.fresh();
1029 assert_eq!(table.to_logos_type(&v), LogosType::Unknown);
1030 }
1031
1032 #[test]
1033 fn to_logos_type_bound_var_resolves() {
1034 let mut table = UnificationTable::new();
1035 let v = table.fresh();
1036 if let InferType::Var(tv) = v {
1037 table.unify(&InferType::Var(tv), &InferType::Int).unwrap();
1038 assert_eq!(table.to_logos_type(&InferType::Var(tv)), LogosType::Int);
1039 }
1040 }
1041
1042 #[test]
1043 fn to_logos_type_seq_resolves_inner() {
1044 let mut table = UnificationTable::new();
1045 let v = table.fresh();
1046 if let InferType::Var(tv) = v {
1047 table.unify(&InferType::Var(tv), &InferType::String).unwrap();
1048 let ty = InferType::Seq(Box::new(InferType::Var(tv)));
1049 assert_eq!(
1050 table.to_logos_type(&ty),
1051 LogosType::Seq(Box::new(LogosType::String))
1052 );
1053 }
1054 }
1055
1056 #[test]
1057 fn to_logos_type_function_converts_to_logos_function() {
1058 let table = UnificationTable::new();
1059 let ty = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));
1060 assert_eq!(
1061 table.to_logos_type(&ty),
1062 LogosType::Function(vec![LogosType::Int], Box::new(LogosType::Bool))
1063 );
1064 }
1065
1066 #[test]
1067 fn to_logos_type_function_two_params_converts() {
1068 let table = UnificationTable::new();
1069 let ty = InferType::Function(
1070 vec![InferType::Int, InferType::String],
1071 Box::new(InferType::Bool),
1072 );
1073 assert_eq!(
1074 table.to_logos_type(&ty),
1075 LogosType::Function(
1076 vec![LogosType::Int, LogosType::String],
1077 Box::new(LogosType::Bool)
1078 )
1079 );
1080 }
1081
1082 #[test]
1083 fn to_logos_type_function_zero_params_converts() {
1084 let table = UnificationTable::new();
1085 let ty = InferType::Function(vec![], Box::new(InferType::Unit));
1086 assert_eq!(
1087 table.to_logos_type(&ty),
1088 LogosType::Function(vec![], Box::new(LogosType::Unit))
1089 );
1090 }
1091
1092 #[test]
1093 fn to_logos_type_function_nested_converts() {
1094 let table = UnificationTable::new();
1096 let inner = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));
1097 let outer = InferType::Function(vec![inner], Box::new(InferType::String));
1098 assert_eq!(
1099 table.to_logos_type(&outer),
1100 LogosType::Function(
1101 vec![LogosType::Function(
1102 vec![LogosType::Int],
1103 Box::new(LogosType::Bool)
1104 )],
1105 Box::new(LogosType::String)
1106 )
1107 );
1108 }
1109
1110 #[test]
1115 fn from_type_expr_function_produces_function_type() {
1116 use crate::ast::stmt::TypeExpr;
1117 let mut interner = Interner::new();
1118 let int_sym = interner.intern("Int");
1119 let bool_sym = interner.intern("Bool");
1120 let int_ty = TypeExpr::Primitive(int_sym);
1121 let bool_ty = TypeExpr::Primitive(bool_sym);
1122 let fn_ty = TypeExpr::Function {
1123 inputs: std::slice::from_ref(&int_ty),
1124 output: &bool_ty,
1125 };
1126 let result = InferType::from_type_expr(&fn_ty, &interner);
1127 assert_eq!(
1128 result,
1129 InferType::Function(vec![InferType::Int], Box::new(InferType::Bool))
1130 );
1131 }
1132
1133 #[test]
1134 fn from_type_expr_seq_of_int() {
1135 use crate::ast::stmt::TypeExpr;
1136 let mut interner = Interner::new();
1137 let seq_sym = interner.intern("Seq");
1138 let int_sym = interner.intern("Int");
1139 let int_ty = TypeExpr::Primitive(int_sym);
1140 let ty = TypeExpr::Generic {
1141 base: seq_sym,
1142 params: std::slice::from_ref(&int_ty),
1143 };
1144 let result = InferType::from_type_expr(&ty, &interner);
1145 assert_eq!(result, InferType::Seq(Box::new(InferType::Int)));
1146 }
1147
1148 #[test]
1153 fn from_field_type_type_param_resolves_to_var() {
1154 let mut interner = Interner::new();
1155 let t_sym = interner.intern("T");
1156 let tv = TyVar(0);
1157 let mut type_params = HashMap::new();
1158 type_params.insert(t_sym, tv);
1159
1160 let field_ty = FieldType::TypeParam(t_sym);
1161 let result = InferType::from_field_type(&field_ty, &interner, &type_params);
1162 assert_eq!(result, InferType::Var(tv));
1163 }
1164
1165 #[test]
1166 fn from_field_type_missing_type_param_becomes_unknown() {
1167 let mut interner = Interner::new();
1168 let t_sym = interner.intern("T");
1169 let type_params = HashMap::new();
1170 let field_ty = FieldType::TypeParam(t_sym);
1171 let result = InferType::from_field_type(&field_ty, &interner, &type_params);
1172 assert_eq!(result, InferType::Unknown);
1173 }
1174
1175 #[test]
1176 fn from_field_type_primitive() {
1177 let mut interner = Interner::new();
1178 let int_sym = interner.intern("Int");
1179 let field_ty = FieldType::Primitive(int_sym);
1180 let result = InferType::from_field_type(&field_ty, &interner, &HashMap::new());
1181 assert_eq!(result, InferType::Int);
1182 }
1183
1184 #[test]
1185 fn from_field_type_generic_seq_of_type_param() {
1186 let mut interner = Interner::new();
1187 let seq_sym = interner.intern("Seq");
1188 let t_sym = interner.intern("T");
1189 let tv = TyVar(0);
1190 let mut type_params = HashMap::new();
1191 type_params.insert(t_sym, tv);
1192
1193 let field_ty = FieldType::Generic {
1194 base: seq_sym,
1195 params: vec![FieldType::TypeParam(t_sym)],
1196 };
1197 let result = InferType::from_field_type(&field_ty, &interner, &type_params);
1198 assert_eq!(result, InferType::Seq(Box::new(InferType::Var(tv))));
1199 }
1200
1201 #[test]
1206 fn numeric_float_wins() {
1207 assert_eq!(
1208 unify_numeric(&InferType::Int, &InferType::Float).unwrap(),
1209 InferType::Float
1210 );
1211 assert_eq!(
1212 unify_numeric(&InferType::Float, &InferType::Int).unwrap(),
1213 InferType::Float
1214 );
1215 }
1216
1217 #[test]
1218 fn numeric_int_plus_int_is_int() {
1219 assert_eq!(
1220 unify_numeric(&InferType::Int, &InferType::Int).unwrap(),
1221 InferType::Int
1222 );
1223 }
1224
1225 #[test]
1226 fn numeric_nat_plus_int_is_int() {
1227 assert_eq!(
1228 unify_numeric(&InferType::Nat, &InferType::Int).unwrap(),
1229 InferType::Int
1230 );
1231 }
1232
1233 #[test]
1234 fn numeric_nat_plus_nat_is_nat() {
1235 assert_eq!(
1236 unify_numeric(&InferType::Nat, &InferType::Nat).unwrap(),
1237 InferType::Nat
1238 );
1239 }
1240
1241 #[test]
1242 fn numeric_string_fails() {
1243 let result = unify_numeric(&InferType::String, &InferType::Int);
1244 assert!(result.is_err());
1245 }
1246
1247 #[test]
1252 fn logos_name_primitives() {
1253 assert_eq!(InferType::Int.to_logos_name(), "Int");
1254 assert_eq!(InferType::Float.to_logos_name(), "Real");
1255 assert_eq!(InferType::String.to_logos_name(), "Text");
1256 assert_eq!(InferType::Bool.to_logos_name(), "Bool");
1257 }
1258
1259 #[test]
1260 fn logos_name_seq() {
1261 let ty = InferType::Seq(Box::new(InferType::Int));
1262 assert_eq!(ty.to_logos_name(), "Seq of Int");
1263 }
1264
1265 #[test]
1266 fn logos_name_function() {
1267 let ty = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));
1268 assert_eq!(ty.to_logos_name(), "fn(Int) -> Bool");
1269 }
1270
1271 #[test]
1276 fn type_error_mismatch_strings() {
1277 let err = TypeError::Mismatch {
1278 expected: InferType::Int,
1279 found: InferType::String,
1280 };
1281 assert_eq!(err.expected_str(), "Int");
1282 assert_eq!(err.found_str(), "Text");
1283 }
1284
1285 #[test]
1286 fn type_error_arity_mismatch_strings() {
1287 let err = TypeError::ArityMismatch { expected: 2, found: 3 };
1288 assert_eq!(err.expected_str(), "2 arguments");
1289 assert_eq!(err.found_str(), "3 arguments");
1290 }
1291
1292 #[test]
1297 fn infer_to_logos_function_single_param() {
1298 let ty = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));
1299 assert_eq!(
1300 super::infer_to_logos(&ty),
1301 LogosType::Function(vec![LogosType::Int], Box::new(LogosType::Bool))
1302 );
1303 }
1304
1305 #[test]
1306 fn infer_to_logos_function_zero_params() {
1307 let ty = InferType::Function(vec![], Box::new(InferType::Unit));
1308 assert_eq!(
1309 super::infer_to_logos(&ty),
1310 LogosType::Function(vec![], Box::new(LogosType::Unit))
1311 );
1312 }
1313
1314 #[test]
1315 fn infer_to_logos_function_two_params() {
1316 let ty = InferType::Function(
1317 vec![InferType::String, InferType::Float],
1318 Box::new(InferType::Bool),
1319 );
1320 assert_eq!(
1321 super::infer_to_logos(&ty),
1322 LogosType::Function(
1323 vec![LogosType::String, LogosType::Float],
1324 Box::new(LogosType::Bool)
1325 )
1326 );
1327 }
1328
1329 #[test]
1330 fn infer_to_logos_function_nested() {
1331 let inner = InferType::Function(vec![InferType::Int], Box::new(InferType::Bool));
1333 let outer = InferType::Function(vec![inner], Box::new(InferType::String));
1334 assert_eq!(
1335 super::infer_to_logos(&outer),
1336 LogosType::Function(
1337 vec![LogosType::Function(
1338 vec![LogosType::Int],
1339 Box::new(LogosType::Bool)
1340 )],
1341 Box::new(LogosType::String)
1342 )
1343 );
1344 }
1345}