1use crate::ast::{BinOp, Expr, Literal, MatchArm, Pattern};
40use crate::error::{TypeError, TypeErrorKind};
41use crate::types::{Substitution, Type, TypeEnv, TypeScheme, TypeVar};
42use std::collections::HashMap;
43
44#[derive(Debug, Clone, PartialEq)]
49pub enum Constraint {
50 Equal(Type, Type),
52}
53
54pub struct TypeInference {
58 next_var_id: usize,
60 constraints: Vec<Constraint>,
62}
63
64#[allow(clippy::result_large_err)]
65impl TypeInference {
66 pub fn new() -> Self {
68 TypeInference {
69 next_var_id: 0,
70 constraints: Vec::new(),
71 }
72 }
73
74 pub fn fresh_var(&mut self) -> TypeVar {
78 let id = self.next_var_id;
79 self.next_var_id += 1;
80 TypeVar::fresh(id)
81 }
82
83 fn add_constraint(&mut self, constraint: Constraint) {
85 self.constraints.push(constraint);
86 }
87
88 fn expr_references_var(expr: &Expr, name: &str) -> bool {
94 match expr {
95 Expr::Var(var_name) => var_name == name,
96 Expr::Lambda { param, body } => {
97 if param == name {
99 false
100 } else {
101 Self::expr_references_var(body, name)
102 }
103 }
104 Expr::App { func, arg } => {
105 Self::expr_references_var(func, name) || Self::expr_references_var(arg, name)
106 }
107 Expr::Let {
108 name: let_name,
109 value,
110 body,
111 } => {
112 Self::expr_references_var(value, name)
114 || (let_name != name && Self::expr_references_var(body, name))
115 }
116 Expr::LetRec {
117 name: rec_name,
118 value,
119 body,
120 } => {
121 Self::expr_references_var(value, name)
123 || (rec_name != name && Self::expr_references_var(body, name))
124 }
125 Expr::LetRecMutual { bindings, body } => {
126 bindings
128 .iter()
129 .any(|(_, expr)| Self::expr_references_var(expr, name))
130 || (!bindings.iter().any(|(n, _)| n == name)
132 && Self::expr_references_var(body, name))
133 }
134 Expr::If {
135 cond,
136 then_branch,
137 else_branch,
138 } => {
139 Self::expr_references_var(cond, name)
140 || Self::expr_references_var(then_branch, name)
141 || Self::expr_references_var(else_branch, name)
142 }
143 Expr::BinOp { left, right, .. } => {
144 Self::expr_references_var(left, name) || Self::expr_references_var(right, name)
145 }
146 Expr::Tuple(elements) | Expr::List(elements) | Expr::Array(elements) => {
147 elements.iter().any(|e| Self::expr_references_var(e, name))
148 }
149 Expr::Cons { head, tail } => {
150 Self::expr_references_var(head, name) || Self::expr_references_var(tail, name)
151 }
152 Expr::ArrayIndex { array, index } => {
153 Self::expr_references_var(array, name) || Self::expr_references_var(index, name)
154 }
155 Expr::ArrayUpdate {
156 array,
157 index,
158 value,
159 } => {
160 Self::expr_references_var(array, name)
161 || Self::expr_references_var(index, name)
162 || Self::expr_references_var(value, name)
163 }
164 Expr::ArrayLength(array) => Self::expr_references_var(array, name),
165 Expr::RecordLiteral { fields, .. } => fields
166 .iter()
167 .any(|(_, expr)| Self::expr_references_var(expr, name)),
168 Expr::RecordAccess { record, .. } => Self::expr_references_var(record, name),
169 Expr::RecordUpdate { record, fields } => {
170 Self::expr_references_var(record, name)
171 || fields
172 .iter()
173 .any(|(_, expr)| Self::expr_references_var(expr, name))
174 }
175 Expr::VariantConstruct { fields, .. } => fields
176 .iter()
177 .any(|expr| Self::expr_references_var(expr, name)),
178 Expr::Match { scrutinee, arms } => {
179 Self::expr_references_var(scrutinee, name)
180 || arms.iter().any(|arm| {
181 let pattern_binds = Self::pattern_binds(&arm.pattern, name);
183 !pattern_binds && Self::expr_references_var(&arm.body, name)
185 })
186 }
187 Expr::MethodCall { receiver, args, .. } => {
188 Self::expr_references_var(receiver, name)
189 || args.iter().any(|e| Self::expr_references_var(e, name))
190 }
191 Expr::While { cond, body } => {
192 Self::expr_references_var(cond, name) || Self::expr_references_var(body, name)
193 }
194 Expr::ComputationExpr { body, .. } => {
195 body.iter().any(|stmt| {
197 use crate::ast::CEStatement;
198 match stmt {
199 CEStatement::Let { value, .. }
200 | CEStatement::LetBang { value, .. }
201 | CEStatement::DoBang { value }
202 | CEStatement::Return { value }
203 | CEStatement::ReturnBang { value }
204 | CEStatement::Yield { value }
205 | CEStatement::YieldBang { value }
206 | CEStatement::Expr { value } => Self::expr_references_var(value, name),
207 }
208 })
209 }
210 Expr::Lit(_) | Expr::Break | Expr::Continue => false,
212 }
213 }
214
215 fn pattern_binds(pattern: &Pattern, name: &str) -> bool {
217 match pattern {
218 Pattern::Var(var_name) => var_name == name,
219 Pattern::Tuple(patterns) | Pattern::Variant { patterns, .. } => {
220 patterns.iter().any(|p| Self::pattern_binds(p, name))
221 }
222 Pattern::Wildcard | Pattern::Literal(_) => false,
223 }
224 }
225
226 pub fn infer(&mut self, expr: &Expr, env: &TypeEnv) -> Result<Type, TypeError> {
240 match expr {
241 Expr::Lit(lit) => Ok(self.infer_literal(lit)),
243
244 Expr::Var(name) => self.infer_var(name, env),
246
247 Expr::Lambda { param, body } => self.infer_lambda(param, body, env),
249
250 Expr::App { func, arg } => self.infer_app(func, arg, env),
252
253 Expr::Let { name, value, body } => self.infer_let(name, value, body, env, false),
255
256 Expr::LetRec { name, value, body } => self.infer_let(name, value, body, env, true),
258
259 Expr::LetRecMutual { bindings, body } => self.infer_let_rec_mutual(bindings, body, env),
261
262 Expr::If {
264 cond,
265 then_branch,
266 else_branch,
267 } => self.infer_if(cond, then_branch, else_branch, env),
268
269 Expr::BinOp { op, left, right } => self.infer_binop(*op, left, right, env),
271
272 Expr::Tuple(elements) => self.infer_tuple(elements, env),
274
275 Expr::List(elements) => self.infer_list(elements, env),
277
278 Expr::Cons { head, tail } => self.infer_cons(head, tail, env),
280
281 Expr::Array(elements) => self.infer_array(elements, env),
283
284 Expr::ArrayIndex { array, index } => self.infer_array_index(array, index, env),
286
287 Expr::ArrayUpdate {
289 array,
290 index,
291 value,
292 } => self.infer_array_update(array, index, value, env),
293
294 Expr::ArrayLength(array) => self.infer_array_length(array, env),
296
297 Expr::RecordLiteral { type_name, fields } => {
299 self.infer_record_literal(type_name, fields, env)
300 }
301
302 Expr::RecordAccess { record, field } => self.infer_record_access(record, field, env),
304
305 Expr::RecordUpdate { record, fields } => self.infer_record_update(record, fields, env),
307
308 Expr::VariantConstruct {
310 type_name,
311 variant,
312 fields,
313 } => self.infer_variant_construct(type_name, variant, fields, env),
314
315 Expr::Match { scrutinee, arms } => self.infer_match(scrutinee, arms, env),
317
318 Expr::MethodCall {
320 receiver,
321 method_name: _,
322 args: _,
323 } => {
324 self.infer(receiver, env)?;
327 Ok(Type::Var(self.fresh_var()))
329 }
330
331 Expr::While { cond, body } => {
333 let cond_ty = self.infer(cond, env)?;
335 self.unify(&cond_ty, &Type::Bool)?;
336 self.infer(body, env)?;
338 Ok(Type::Unit)
340 }
341
342 Expr::Break => {
344 Ok(Type::Unit)
347 }
348
349 Expr::Continue => {
351 Ok(Type::Unit)
354 }
355
356 Expr::ComputationExpr {
358 builder: _,
359 body: _,
360 } => {
361 Ok(Type::Var(self.fresh_var()))
364 }
365 }
366 }
367
368 fn infer_literal(&self, lit: &Literal) -> Type {
370 match lit {
371 Literal::Int(_) => Type::Int,
372 Literal::Float(_) => Type::Float,
373 Literal::Bool(_) => Type::Bool,
374 Literal::Str(_) => Type::String,
375 Literal::Unit => Type::Unit,
376 }
377 }
378
379 fn infer_var(&mut self, name: &str, env: &TypeEnv) -> Result<Type, TypeError> {
381 match env.lookup(name) {
382 Some(scheme) => {
383 Ok(env.instantiate(scheme, &mut || self.fresh_var()))
385 }
386 None => Err(TypeError::new(TypeErrorKind::UnboundVariable {
387 name: name.to_string(),
388 })),
389 }
390 }
391
392 fn infer_lambda(&mut self, param: &str, body: &Expr, env: &TypeEnv) -> Result<Type, TypeError> {
400 let param_type = Type::Var(self.fresh_var());
401 let param_scheme = TypeScheme::mono(param_type.clone());
402 let extended_env = env.extend(param.to_string(), param_scheme);
403
404 let body_type = self.infer(body, &extended_env)?;
405
406 Ok(Type::Function(Box::new(param_type), Box::new(body_type)))
407 }
408
409 fn infer_app(&mut self, func: &Expr, arg: &Expr, env: &TypeEnv) -> Result<Type, TypeError> {
418 let func_type = self.infer(func, env)?;
419 let arg_type = self.infer(arg, env)?;
420 let result_type = Type::Var(self.fresh_var());
421
422 let expected_func_type = Type::Function(Box::new(arg_type), Box::new(result_type.clone()));
424 self.add_constraint(Constraint::Equal(func_type, expected_func_type));
425
426 Ok(result_type)
427 }
428
429 fn infer_let(
440 &mut self,
441 name: &str,
442 value: &Expr,
443 body: &Expr,
444 env: &TypeEnv,
445 is_recursive: bool,
446 ) -> Result<Type, TypeError> {
447 let auto_recursive = !is_recursive && Self::expr_references_var(value, name);
449 let treat_as_recursive = is_recursive || auto_recursive;
450
451 let value_type = if treat_as_recursive {
452 let rec_var = Type::Var(self.fresh_var());
454 let rec_scheme = TypeScheme::mono(rec_var.clone());
455 let rec_env = env.extend(name.to_string(), rec_scheme);
456
457 let inferred = self.infer(value, &rec_env)?;
459
460 self.add_constraint(Constraint::Equal(rec_var, inferred.clone()));
462 inferred
463 } else {
464 self.infer(value, env)?
466 };
467
468 let value_scheme = env.generalize(&value_type);
470
471 let extended_env = env.extend(name.to_string(), value_scheme);
473 self.infer(body, &extended_env)
474 }
475
476 fn infer_let_rec_mutual(
478 &mut self,
479 bindings: &[(String, Expr)],
480 body: &Expr,
481 env: &TypeEnv,
482 ) -> Result<Type, TypeError> {
483 let mut rec_env = env.clone();
485 let mut binding_vars = Vec::new();
486
487 for (name, _) in bindings {
488 let var = Type::Var(self.fresh_var());
489 rec_env.insert(name.clone(), TypeScheme::mono(var.clone()));
490 binding_vars.push((name.clone(), var));
491 }
492
493 for ((_, expr), (_name, var)) in bindings.iter().zip(binding_vars.iter()) {
495 let inferred = self.infer(expr, &rec_env)?;
496 self.add_constraint(Constraint::Equal(var.clone(), inferred));
497 }
498
499 self.infer(body, &rec_env)
501 }
502
503 fn infer_if(
511 &mut self,
512 cond: &Expr,
513 then_branch: &Expr,
514 else_branch: &Expr,
515 env: &TypeEnv,
516 ) -> Result<Type, TypeError> {
517 let cond_type = self.infer(cond, env)?;
518 self.add_constraint(Constraint::Equal(cond_type, Type::Bool));
519
520 let then_type = self.infer(then_branch, env)?;
521 let else_type = self.infer(else_branch, env)?;
522
523 self.add_constraint(Constraint::Equal(then_type.clone(), else_type));
525
526 Ok(then_type)
527 }
528
529 fn infer_binop(
531 &mut self,
532 op: BinOp,
533 left: &Expr,
534 right: &Expr,
535 env: &TypeEnv,
536 ) -> Result<Type, TypeError> {
537 let left_type = self.infer(left, env)?;
538 let right_type = self.infer(right, env)?;
539
540 if op.is_arithmetic() {
541 self.add_constraint(Constraint::Equal(left_type.clone(), Type::Int));
544 self.add_constraint(Constraint::Equal(right_type, Type::Int));
545 Ok(Type::Int)
546 } else if op.is_comparison() {
547 self.add_constraint(Constraint::Equal(left_type, right_type));
549 Ok(Type::Bool)
550 } else if op.is_logical() {
551 self.add_constraint(Constraint::Equal(left_type, Type::Bool));
553 self.add_constraint(Constraint::Equal(right_type, Type::Bool));
554 Ok(Type::Bool)
555 } else {
556 unreachable!("Unknown binary operator")
557 }
558 }
559
560 fn infer_tuple(&mut self, elements: &[Expr], env: &TypeEnv) -> Result<Type, TypeError> {
562 let mut element_types = Vec::new();
563 for element in elements {
564 element_types.push(self.infer(element, env)?);
565 }
566 Ok(Type::Tuple(element_types))
567 }
568
569 fn infer_list(&mut self, elements: &[Expr], env: &TypeEnv) -> Result<Type, TypeError> {
573 if elements.is_empty() {
574 Ok(Type::List(Box::new(Type::Var(self.fresh_var()))))
576 } else {
577 let first_type = self.infer(&elements[0], env)?;
578 for element in &elements[1..] {
579 let element_type = self.infer(element, env)?;
580 self.add_constraint(Constraint::Equal(first_type.clone(), element_type));
581 }
582 Ok(Type::List(Box::new(first_type)))
583 }
584 }
585
586 fn infer_cons(&mut self, head: &Expr, tail: &Expr, env: &TypeEnv) -> Result<Type, TypeError> {
588 let head_type = self.infer(head, env)?;
589 let tail_type = self.infer(tail, env)?;
590
591 let expected_tail = Type::List(Box::new(head_type.clone()));
592 self.add_constraint(Constraint::Equal(tail_type, expected_tail));
593
594 Ok(Type::List(Box::new(head_type)))
595 }
596
597 fn infer_array(&mut self, elements: &[Expr], env: &TypeEnv) -> Result<Type, TypeError> {
599 if elements.is_empty() {
600 Ok(Type::Array(Box::new(Type::Var(self.fresh_var()))))
601 } else {
602 let first_type = self.infer(&elements[0], env)?;
603 for element in &elements[1..] {
604 let element_type = self.infer(element, env)?;
605 self.add_constraint(Constraint::Equal(first_type.clone(), element_type));
606 }
607 Ok(Type::Array(Box::new(first_type)))
608 }
609 }
610
611 fn infer_array_index(
613 &mut self,
614 array: &Expr,
615 index: &Expr,
616 env: &TypeEnv,
617 ) -> Result<Type, TypeError> {
618 let array_type = self.infer(array, env)?;
619 let index_type = self.infer(index, env)?;
620
621 self.add_constraint(Constraint::Equal(index_type, Type::Int));
623
624 let element_type = Type::Var(self.fresh_var());
626 let expected_array_type = Type::Array(Box::new(element_type.clone()));
627 self.add_constraint(Constraint::Equal(array_type, expected_array_type));
628
629 Ok(element_type)
630 }
631
632 fn infer_array_update(
634 &mut self,
635 array: &Expr,
636 index: &Expr,
637 value: &Expr,
638 env: &TypeEnv,
639 ) -> Result<Type, TypeError> {
640 let array_type = self.infer(array, env)?;
641 let index_type = self.infer(index, env)?;
642 let value_type = self.infer(value, env)?;
643
644 self.add_constraint(Constraint::Equal(index_type, Type::Int));
646
647 let expected_array_type = Type::Array(Box::new(value_type));
649 self.add_constraint(Constraint::Equal(array_type.clone(), expected_array_type));
650
651 Ok(array_type)
652 }
653
654 fn infer_array_length(&mut self, array: &Expr, env: &TypeEnv) -> Result<Type, TypeError> {
656 let array_type = self.infer(array, env)?;
657
658 let element_type = Type::Var(self.fresh_var());
660 let expected_array_type = Type::Array(Box::new(element_type));
661 self.add_constraint(Constraint::Equal(array_type, expected_array_type));
662
663 Ok(Type::Int)
664 }
665
666 fn infer_record_literal(
668 &mut self,
669 _type_name: &str,
670 fields: &[(String, Box<Expr>)],
671 env: &TypeEnv,
672 ) -> Result<Type, TypeError> {
673 let mut field_types = HashMap::new();
674
675 for (field_name, field_expr) in fields {
676 let field_type = self.infer(field_expr, env)?;
677 field_types.insert(field_name.clone(), field_type);
678 }
679
680 Ok(Type::Record(field_types))
681 }
682
683 fn infer_record_access(
685 &mut self,
686 record: &Expr,
687 field: &str,
688 env: &TypeEnv,
689 ) -> Result<Type, TypeError> {
690 let record_type = self.infer(record, env)?;
691
692 let field_type = Type::Var(self.fresh_var());
694
695 let mut expected_fields = HashMap::new();
697 expected_fields.insert(field.to_string(), field_type.clone());
698 let expected_record = Type::Record(expected_fields);
699
700 self.add_constraint(Constraint::Equal(record_type, expected_record));
703
704 Ok(field_type)
705 }
706
707 fn infer_record_update(
709 &mut self,
710 record: &Expr,
711 fields: &[(String, Box<Expr>)],
712 env: &TypeEnv,
713 ) -> Result<Type, TypeError> {
714 let record_type = self.infer(record, env)?;
715
716 let mut update_field_types = HashMap::new();
718 for (field_name, field_expr) in fields {
719 let field_type = self.infer(field_expr, env)?;
720 update_field_types.insert(field_name.clone(), field_type);
721 }
722
723 Ok(record_type)
726 }
727
728 fn infer_variant_construct(
730 &mut self,
731 _type_name: &str,
732 variant: &str,
733 fields: &[Box<Expr>],
734 env: &TypeEnv,
735 ) -> Result<Type, TypeError> {
736 let mut field_types = Vec::new();
738 for field in fields {
739 field_types.push(self.infer(field, env)?);
740 }
741
742 Ok(Type::Variant(variant.to_string(), field_types))
744 }
745
746 fn infer_match(
755 &mut self,
756 scrutinee: &Expr,
757 arms: &[MatchArm],
758 env: &TypeEnv,
759 ) -> Result<Type, TypeError> {
760 if arms.is_empty() {
761 return Err(TypeError::new(TypeErrorKind::Custom {
762 message: "Match expression must have at least one arm".to_string(),
763 }));
764 }
765
766 let scrutinee_type = self.infer(scrutinee, env)?;
767
768 let (_first_pattern_env, first_result_type) =
770 self.infer_match_arm(&arms[0], &scrutinee_type, env)?;
771
772 for arm in &arms[1..] {
774 let (_, arm_type) = self.infer_match_arm(arm, &scrutinee_type, env)?;
775 self.add_constraint(Constraint::Equal(first_result_type.clone(), arm_type));
776 }
777
778 Ok(first_result_type)
779 }
780
781 fn infer_match_arm(
785 &mut self,
786 arm: &MatchArm,
787 scrutinee_type: &Type,
788 env: &TypeEnv,
789 ) -> Result<(TypeEnv, Type), TypeError> {
790 let pattern_env = self.infer_pattern(&arm.pattern, scrutinee_type, env)?;
792
793 let body_type = self.infer(&arm.body, &pattern_env)?;
795
796 Ok((pattern_env, body_type))
797 }
798
799 pub fn infer_pattern(
803 &mut self,
804 pattern: &Pattern,
805 scrutinee_ty: &Type,
806 env: &TypeEnv,
807 ) -> Result<TypeEnv, TypeError> {
808 match pattern {
809 Pattern::Wildcard => Ok(env.clone()),
811
812 Pattern::Var(name) => {
814 let scheme = TypeScheme::mono(scrutinee_ty.clone());
815 Ok(env.extend(name.clone(), scheme))
816 }
817
818 Pattern::Literal(lit) => {
820 let lit_type = self.infer_literal(lit);
821 self.add_constraint(Constraint::Equal(scrutinee_ty.clone(), lit_type));
822 Ok(env.clone())
823 }
824
825 Pattern::Tuple(patterns) => {
827 let mut pattern_types = Vec::new();
829 for _ in patterns {
830 pattern_types.push(Type::Var(self.fresh_var()));
831 }
832
833 let expected_tuple = Type::Tuple(pattern_types.clone());
834 self.add_constraint(Constraint::Equal(scrutinee_ty.clone(), expected_tuple));
835
836 let mut extended_env = env.clone();
838 for (pattern, pattern_type) in patterns.iter().zip(pattern_types.iter()) {
839 extended_env = self.infer_pattern(pattern, pattern_type, &extended_env)?;
840 }
841
842 Ok(extended_env)
843 }
844
845 Pattern::Variant { variant, patterns } => {
847 let mut field_types = Vec::new();
849 for _ in patterns {
850 field_types.push(Type::Var(self.fresh_var()));
851 }
852
853 let expected_variant = Type::Variant(variant.clone(), field_types.clone());
854 self.add_constraint(Constraint::Equal(scrutinee_ty.clone(), expected_variant));
855
856 let mut extended_env = env.clone();
858 for (pattern, field_type) in patterns.iter().zip(field_types.iter()) {
859 extended_env = self.infer_pattern(pattern, field_type, &extended_env)?;
860 }
861
862 Ok(extended_env)
863 }
864 }
865 }
866
867 pub fn solve_constraints(&mut self) -> Result<Substitution, TypeError> {
871 let mut subst = Substitution::empty();
872
873 for constraint in &self.constraints {
874 match constraint {
875 Constraint::Equal(t1, t2) => {
876 let t1_subst = t1.apply(&subst);
878 let t2_subst = t2.apply(&subst);
879
880 let new_subst = self.unify(&t1_subst, &t2_subst)?;
882 subst = Substitution::compose(&new_subst, &subst);
883 }
884 }
885 }
886
887 Ok(subst)
888 }
889
890 #[allow(clippy::only_used_in_recursion)]
894 pub fn unify(&self, t1: &Type, t2: &Type) -> Result<Substitution, TypeError> {
895 match (t1, t2) {
896 (Type::Int, Type::Int)
898 | (Type::Bool, Type::Bool)
899 | (Type::String, Type::String)
900 | (Type::Unit, Type::Unit)
901 | (Type::Float, Type::Float) => Ok(Substitution::empty()),
902
903 (Type::Var(v1), Type::Var(v2)) if v1 == v2 => Ok(Substitution::empty()),
905
906 (Type::Var(v), t) | (t, Type::Var(v)) => {
908 if t.occurs_check(v) {
909 Err(TypeError::new(TypeErrorKind::OccursCheck {
910 var: v.clone(),
911 in_type: t.clone(),
912 }))
913 } else {
914 Ok(Substitution::singleton(v.clone(), t.clone()))
915 }
916 }
917
918 (Type::Function(a1, r1), Type::Function(a2, r2)) => {
920 let subst1 = self.unify(a1, a2)?;
921 let r1_subst = r1.apply(&subst1);
922 let r2_subst = r2.apply(&subst1);
923 let subst2 = self.unify(&r1_subst, &r2_subst)?;
924 Ok(Substitution::compose(&subst2, &subst1))
925 }
926
927 (Type::Tuple(ts1), Type::Tuple(ts2)) => {
929 if ts1.len() != ts2.len() {
930 return Err(TypeError::new(TypeErrorKind::Mismatch {
931 expected: t1.clone(),
932 got: t2.clone(),
933 }));
934 }
935
936 let mut subst = Substitution::empty();
937 for (ty1, ty2) in ts1.iter().zip(ts2.iter()) {
938 let ty1_subst = ty1.apply(&subst);
939 let ty2_subst = ty2.apply(&subst);
940 let new_subst = self.unify(&ty1_subst, &ty2_subst)?;
941 subst = Substitution::compose(&new_subst, &subst);
942 }
943 Ok(subst)
944 }
945
946 (Type::List(t1), Type::List(t2)) => self.unify(t1, t2),
948
949 (Type::Array(t1), Type::Array(t2)) => self.unify(t1, t2),
951
952 (Type::Record(fields1), Type::Record(fields2)) => {
954 if fields1.len() != fields2.len() {
955 return Err(TypeError::new(TypeErrorKind::Mismatch {
956 expected: t1.clone(),
957 got: t2.clone(),
958 }));
959 }
960
961 let mut subst = Substitution::empty();
962 for (name, ty1) in fields1 {
963 match fields2.get(name) {
964 Some(ty2) => {
965 let ty1_subst = ty1.apply(&subst);
966 let ty2_subst = ty2.apply(&subst);
967 let new_subst = self.unify(&ty1_subst, &ty2_subst)?;
968 subst = Substitution::compose(&new_subst, &subst);
969 }
970 None => {
971 return Err(TypeError::new(TypeErrorKind::FieldNotFound {
972 record_type: t1.clone(),
973 field: name.clone(),
974 }));
975 }
976 }
977 }
978 Ok(subst)
979 }
980
981 (Type::Variant(name1, fields1), Type::Variant(name2, fields2)) => {
983 if name1 != name2 {
984 return Err(TypeError::new(TypeErrorKind::Mismatch {
985 expected: t1.clone(),
986 got: t2.clone(),
987 }));
988 }
989
990 if fields1.len() != fields2.len() {
991 return Err(TypeError::new(TypeErrorKind::Mismatch {
992 expected: t1.clone(),
993 got: t2.clone(),
994 }));
995 }
996
997 let mut subst = Substitution::empty();
998 for (ty1, ty2) in fields1.iter().zip(fields2.iter()) {
999 let ty1_subst = ty1.apply(&subst);
1000 let ty2_subst = ty2.apply(&subst);
1001 let new_subst = self.unify(&ty1_subst, &ty2_subst)?;
1002 subst = Substitution::compose(&new_subst, &subst);
1003 }
1004 Ok(subst)
1005 }
1006
1007 _ => Err(TypeError::new(TypeErrorKind::Mismatch {
1009 expected: t1.clone(),
1010 got: t2.clone(),
1011 })),
1012 }
1013 }
1014
1015 pub fn infer_and_solve(&mut self, expr: &Expr, env: &TypeEnv) -> Result<Type, TypeError> {
1033 self.constraints.clear();
1035
1036 let ty = self.infer(expr, env)?;
1038
1039 let subst = self.solve_constraints()?;
1041
1042 Ok(ty.apply(&subst))
1044 }
1045}
1046
1047impl Default for TypeInference {
1048 fn default() -> Self {
1049 Self::new()
1050 }
1051}
1052
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs inference plus constraint solving on `expr` in a fresh engine and empty env.
    fn solve(expr: &Expr) -> Result<Type, TypeError> {
        let mut engine = TypeInference::new();
        engine.infer_and_solve(expr, &TypeEnv::new())
    }

    fn lit_int(n: i64) -> Expr {
        Expr::Lit(Literal::Int(n))
    }

    fn var(name: &str) -> Expr {
        Expr::Var(name.to_string())
    }

    fn lambda(param: &str, body: Expr) -> Expr {
        Expr::Lambda {
            param: param.to_string(),
            body: Box::new(body),
        }
    }

    fn app(func: Expr, arg: Expr) -> Expr {
        Expr::App {
            func: Box::new(func),
            arg: Box::new(arg),
        }
    }

    fn let_expr(name: &str, value: Expr, body: Expr) -> Expr {
        Expr::Let {
            name: name.to_string(),
            value: Box::new(value),
            body: Box::new(body),
        }
    }

    /// Builds `left <op> right`.
    fn binop(op: BinOp, left: Expr, right: Expr) -> Expr {
        Expr::BinOp {
            op,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    #[test]
    fn test_infer_literal_int() {
        assert_eq!(solve(&lit_int(42)).unwrap(), Type::Int);
    }

    #[test]
    fn test_infer_literal_bool() {
        assert_eq!(solve(&Expr::Lit(Literal::Bool(true))).unwrap(), Type::Bool);
    }

    #[test]
    fn test_infer_identity_function() {
        // fun x -> x  :  'a -> 'a
        match solve(&lambda("x", var("x"))).unwrap() {
            Type::Function(arg, ret) => {
                if let (Type::Var(arg_var), Type::Var(ret_var)) = (*arg, *ret) {
                    assert_eq!(arg_var, ret_var);
                } else {
                    panic!("Expected function with type variables");
                }
            }
            _ => panic!("Expected function type"),
        }
    }

    #[test]
    fn test_infer_const_function() {
        // fun x -> 42  :  'a -> int
        if let Type::Function(_, ret) = solve(&lambda("x", lit_int(42))).unwrap() {
            assert_eq!(*ret, Type::Int);
        } else {
            panic!("Expected function type");
        }
    }

    #[test]
    fn test_infer_application() {
        // (fun x -> x) 42  :  int
        let program = app(lambda("x", var("x")), lit_int(42));
        assert_eq!(solve(&program).unwrap(), Type::Int);
    }

    #[test]
    fn test_infer_unbound_variable() {
        let result = solve(&var("x"));
        assert!(result.is_err());
        if let TypeErrorKind::UnboundVariable { name } = result.unwrap_err().kind {
            assert_eq!(name, "x");
        } else {
            panic!("Expected UnboundVariable error");
        }
    }

    #[test]
    fn test_auto_recursive_lambda_factorial() {
        // let factorial = fun n -> if n <= 1 then 1 else n * factorial (n - 1)
        // in factorial 5
        let recursive_call = app(var("factorial"), binop(BinOp::Sub, var("n"), lit_int(1)));
        let factorial_body = Expr::If {
            cond: Box::new(binop(BinOp::Lte, var("n"), lit_int(1))),
            then_branch: Box::new(lit_int(1)),
            else_branch: Box::new(binop(BinOp::Mul, var("n"), recursive_call)),
        };
        let program = let_expr(
            "factorial",
            lambda("n", factorial_body),
            app(var("factorial"), lit_int(5)),
        );
        assert_eq!(solve(&program).unwrap(), Type::Int);
    }

    #[test]
    fn test_auto_recursive_simple() {
        // let f = fun x -> f x in f 42
        let program = let_expr(
            "f",
            lambda("x", app(var("f"), var("x"))),
            app(var("f"), lit_int(42)),
        );
        assert!(solve(&program).is_ok());
    }

    #[test]
    fn test_non_recursive_lambda_still_works() {
        // let double = fun x -> x * 2 in double 21
        let program = let_expr(
            "double",
            lambda("x", binop(BinOp::Mul, var("x"), lit_int(2))),
            app(var("double"), lit_int(21)),
        );
        assert_eq!(solve(&program).unwrap(), Type::Int);
    }

    #[test]
    fn test_shadowing_prevents_auto_recursion() {
        // let f = fun f -> f in f 42 -- the parameter shadows `f`, so no recursion.
        let program = let_expr("f", lambda("f", var("f")), app(var("f"), lit_int(42)));
        assert_eq!(solve(&program).unwrap(), Type::Int);
    }
}