1mod error;
2
3use std::{collections::BTreeMap, rc::Rc};
4
5use error::TypeConstraintError;
6pub use petr_bind::FunctionId;
7use petr_resolve::{Expr, ExprKind, QueryableResolvedItems};
8pub use petr_resolve::{Intrinsic as ResolvedIntrinsic, IntrinsicName, Literal};
9use petr_utils::{idx_map_key, Identifier, IndexMap, Span, SpannedItem, SymbolId, TypeId};
10
/// A type-checking error paired with the source span it was reported at.
pub type TypeError = SpannedItem<TypeConstraintError>;
/// Convenience alias for results whose error is a [`TypeError`].
pub type TResult<T> = Result<T, TypeError>;
13
14pub fn type_check(resolved: QueryableResolvedItems) -> (Vec<TypeError>, TypeChecker) {
17 let mut type_checker = TypeChecker::new(resolved);
18 type_checker.fully_type_check();
19 (type_checker.errors.clone(), type_checker)
20}
21
/// Key for `TypeChecker::type_map`: either a declared type or a function, both of which are
/// assigned a type variable during type checking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TypeOrFunctionId {
    /// A user-declared type.
    TypeId(TypeId),
    /// A function (including generated type-constructor functions).
    FunctionId(FunctionId),
}
27
28impl From<TypeId> for TypeOrFunctionId {
29 fn from(type_id: TypeId) -> Self {
30 TypeOrFunctionId::TypeId(type_id)
31 }
32}
33
34impl From<FunctionId> for TypeOrFunctionId {
35 fn from(function_id: FunctionId) -> Self {
36 TypeOrFunctionId::FunctionId(function_id)
37 }
38}
39
40impl From<&TypeId> for TypeOrFunctionId {
41 fn from(type_id: &TypeId) -> Self {
42 TypeOrFunctionId::TypeId(*type_id)
43 }
44}
45
46impl From<&FunctionId> for TypeOrFunctionId {
47 fn from(function_id: &FunctionId) -> Self {
48 TypeOrFunctionId::FunctionId(*function_id)
49 }
50}
51
// `TypeVariable` is the key into `TypeContext::types` (an `IndexMap<TypeVariable, PetrType>`);
// its definition is generated by `petr_utils::idx_map_key!`.
idx_map_key!(TypeVariable);
53
/// A constraint between two type variables, recorded during expression checking and solved
/// afterwards by `TypeChecker::apply_constraints`.
#[derive(Clone, Copy, Debug)]
pub struct TypeConstraint {
    kind: TypeConstraintKind,
    /// Where to report an error if the constraint cannot be satisfied.
    span: Span,
}
60impl TypeConstraint {
61 fn unify(
62 t1: TypeVariable,
63 t2: TypeVariable,
64 span: Span,
65 ) -> Self {
66 Self {
67 kind: TypeConstraintKind::Unify(t1, t2),
68 span,
69 }
70 }
71
72 fn satisfies(
73 t1: TypeVariable,
74 t2: TypeVariable,
75 span: Span,
76 ) -> Self {
77 Self {
78 kind: TypeConstraintKind::Satisfies(t1, t2),
79 span,
80 }
81 }
82}
83
/// The relationship a [`TypeConstraint`] imposes on its two variables.
#[derive(Clone, Copy, Debug)]
pub enum TypeConstraintKind {
    /// Both variables must resolve to the same type; an unknown (`Infer`) side on either end is
    /// bound to the other.
    Unify(TypeVariable, TypeVariable),
    /// The first variable (e.g. an argument) must satisfy the second (e.g. a parameter). When
    /// solved, only an unknown second side is bound; an unknown first side against a known
    /// second side is reported as a failure.
    Satisfies(TypeVariable, TypeVariable),
}
90
/// Storage for all types known to the checker plus the constraints collected so far.
pub struct TypeContext {
    /// Every type, indexed by its type variable.
    types: IndexMap<TypeVariable, PetrType>,
    /// Constraints recorded during checking, solved in insertion order.
    constraints: Vec<TypeConstraint>,
    /// Pre-allocated variable holding `PetrType::Unit`.
    unit_ty: TypeVariable,
    /// Pre-allocated variable holding `PetrType::String`.
    string_ty: TypeVariable,
    /// Pre-allocated variable holding `PetrType::Integer`.
    int_ty: TypeVariable,
    /// Pre-allocated variable holding `PetrType::ErrorRecovery`.
    error_recovery: TypeVariable,
}
100
101impl Default for TypeContext {
102 fn default() -> Self {
103 let mut types = IndexMap::default();
104 let unit_ty = types.insert(PetrType::Unit);
106 let string_ty = types.insert(PetrType::String);
107 let int_ty = types.insert(PetrType::Integer);
108 let error_recovery = types.insert(PetrType::ErrorRecovery);
109 TypeContext {
111 types,
112 constraints: Default::default(),
113 unit_ty,
114 string_ty,
115 int_ty,
116 error_recovery,
117 }
118 }
119}
120
121impl TypeContext {
122 fn unify(
123 &mut self,
124 ty1: TypeVariable,
125 ty2: TypeVariable,
126 span: Span,
127 ) {
128 self.constraints.push(TypeConstraint::unify(ty1, ty2, span));
129 }
130
131 fn satisfies(
132 &mut self,
133 ty1: TypeVariable,
134 ty2: TypeVariable,
135 span: Span,
136 ) {
137 self.constraints.push(TypeConstraint::satisfies(ty1, ty2, span));
138 }
139
140 fn new_variable(&mut self) -> TypeVariable {
141 let infer_id = self.types.len();
143 self.types.insert(PetrType::Infer(infer_id))
144 }
145
146 fn update_type(
148 &mut self,
149 t1: TypeVariable,
150 known: PetrType,
151 ) {
152 *self.types.get_mut(t1) = known;
153 }
154}
155
/// Drives type checking over a set of resolved items and holds the results.
pub struct TypeChecker {
    /// Type table and pending constraints.
    ctx: TypeContext,
    /// Type variable assigned to each declared type and each function.
    type_map: BTreeMap<TypeOrFunctionId, TypeVariable>,
    /// Type-checked function bodies, by function id.
    typed_functions: BTreeMap<FunctionId, Function>,
    /// Errors reported so far.
    errors: Vec<TypeError>,
    /// The resolved program being checked.
    resolved: QueryableResolvedItems,
    /// Stack of scopes (innermost last) binding identifiers to type variables. Holds parameters,
    /// `let` bindings and generic type names.
    variable_scope: Vec<BTreeMap<Identifier, TypeVariable>>,
}
164
/// A type as stored in the type table.
#[derive(Clone, PartialEq, Debug, Eq, PartialOrd, Ord)]
pub enum PetrType {
    Unit,
    Integer,
    Boolean,
    String,
    /// Indirection: this variable has the same type as the referenced variable.
    Ref(TypeVariable),
    /// A user-declared type.
    UserDefined {
        /// The declared name of the type.
        name: Identifier,
        /// One entry per declared variant.
        variants: Vec<TypeVariant>,
    },
    /// A function type: parameter types followed by the return type.
    Arrow(Vec<TypeVariable>),
    /// Placeholder used after an error; constraints involving it are ignored.
    ErrorRecovery,
    /// A list whose elements have the referenced type.
    List(TypeVariable),
    /// A not-yet-known type; the payload is the table size when the variable was created.
    Infer(usize),
}
185
/// One variant of a user-defined type.
#[derive(Clone, PartialEq, Debug, Eq, PartialOrd, Ord)]
pub struct TypeVariant {
    /// The type of each field, in declaration order.
    pub fields: Box<[TypeVariable]>,
}
190
191impl TypeChecker {
192 pub fn insert_type(
193 &mut self,
194 ty: PetrType,
195 ) -> TypeVariable {
196 self.ctx.types.insert(ty)
198 }
199
200 pub fn look_up_variable(
201 &self,
202 ty: TypeVariable,
203 ) -> &PetrType {
204 self.ctx.types.get(ty)
205 }
206
207 pub fn get_symbol(
208 &self,
209 id: SymbolId,
210 ) -> Rc<str> {
211 self.resolved.interner.get(id).clone()
212 }
213
214 fn with_type_scope<T>(
215 &mut self,
216 f: impl FnOnce(&mut Self) -> T,
217 ) -> T {
218 self.variable_scope.push(Default::default());
219 let res = f(self);
220 self.variable_scope.pop();
221 res
222 }
223
224 fn generic_type(
225 &mut self,
226 id: &Identifier,
227 ) -> TypeVariable {
228 for scope in self.variable_scope.iter().rev() {
229 if let Some(ty) = scope.get(id) {
230 return *ty;
231 }
232 }
233 let fresh_ty = self.fresh_ty_var();
234 match self.variable_scope.last_mut() {
235 Some(entry) => {
236 entry.insert(*id, fresh_ty);
237 },
238 None => {
239 self.errors.push(id.span.with_item(TypeConstraintError::Internal(
240 "attempted to insert generic type into variable scope when no variable scope existed".into(),
241 )));
242 self.ctx.update_type(fresh_ty, PetrType::ErrorRecovery);
243 },
244 };
245 fresh_ty
246 }
247
248 fn find_variable(
249 &self,
250 id: Identifier,
251 ) -> Option<TypeVariable> {
252 for scope in self.variable_scope.iter().rev() {
253 if let Some(ty) = scope.get(&id) {
254 return Some(*ty);
255 }
256 }
257 None
258 }
259
260 fn fully_type_check(&mut self) {
261 for (id, decl) in self.resolved.types() {
262 let ty = self.fresh_ty_var();
263 let variants = decl
264 .variants
265 .iter()
266 .map(|variant| {
267 self.with_type_scope(|ctx| {
268 let fields = variant.fields.iter().map(|field| ctx.to_type_var(&field.ty)).collect::<Vec<_>>();
269 TypeVariant {
270 fields: fields.into_boxed_slice(),
271 }
272 })
273 })
274 .collect::<Vec<_>>();
275 self.ctx.update_type(ty, PetrType::UserDefined { name: decl.name, variants });
276 self.type_map.insert(id.into(), ty);
277 }
278
279 for (id, func) in self.resolved.functions() {
280 let typed_function = func.type_check(self);
281
282 let ty = self.arrow_type([typed_function.params.iter().map(|(_, b)| *b).collect(), vec![typed_function.return_ty]].concat());
283 self.type_map.insert(id.into(), ty);
284 self.typed_functions.insert(id, typed_function);
285 }
286
287 self.apply_constraints();
289 }
290
291 fn apply_constraints(&mut self) {
296 let constraints = self.ctx.constraints.clone();
297 for constraint in constraints {
298 match &constraint.kind {
299 TypeConstraintKind::Unify(t1, t2) => {
300 self.apply_unify_constraint(*t1, *t2, constraint.span);
301 },
302 TypeConstraintKind::Satisfies(t1, t2) => {
303 self.apply_satisfies_constraint(*t1, *t2, constraint.span);
304 },
305 }
306 }
307 }
308
309 fn apply_unify_constraint(
312 &mut self,
313 t1: TypeVariable,
314 t2: TypeVariable,
315 span: Span,
316 ) {
317 let ty1 = self.ctx.types.get(t1).clone();
318 let ty2 = self.ctx.types.get(t2).clone();
319 use PetrType::*;
320 match (ty1, ty2) {
321 (a, b) if a == b => (),
322 (ErrorRecovery, _) | (_, ErrorRecovery) => (),
323 (Ref(a), _) => self.apply_unify_constraint(a, t2, span),
324 (_, Ref(b)) => self.apply_unify_constraint(t1, b, span),
325 (Infer(id), Infer(id2)) if id != id2 => {
326 self.ctx.update_type(t2, Ref(t1));
329 },
330 (Infer(_), known) => {
332 self.ctx.update_type(t1, known);
333 },
334 (known, Infer(_)) => {
335 self.ctx.update_type(t2, known);
336 },
337 (a, b) => {
339 self.push_error(span.with_item(TypeConstraintError::UnificationFailure(a, b)));
340 },
341 }
342 }
343
344 fn apply_satisfies_constraint(
347 &mut self,
348 t1: TypeVariable,
349 t2: TypeVariable,
350 span: Span,
351 ) {
352 let ty1 = self.ctx.types.get(t1);
353 let ty2 = self.ctx.types.get(t2);
354 use PetrType::*;
355 match (ty1, ty2) {
356 (a, b) if a == b => (),
357 (ErrorRecovery, _) | (_, ErrorRecovery) => (),
358 (Ref(a), _) => self.apply_satisfies_constraint(*a, t2, span),
359 (_, Ref(b)) => self.apply_satisfies_constraint(t1, *b, span),
360 (_known, Infer(_)) => {
362 self.ctx.update_type(t2, Ref(t1));
363 },
364 (a, b) => {
366 self.push_error(span.with_item(TypeConstraintError::FailedToSatisfy(a.clone(), b.clone())));
367 },
368 }
369 }
370
371 pub fn new(resolved: QueryableResolvedItems) -> Self {
372 let ctx = TypeContext::default();
373 let mut type_checker = TypeChecker {
374 ctx,
375 type_map: Default::default(),
376 errors: Default::default(),
377 typed_functions: Default::default(),
378 resolved,
379 variable_scope: Default::default(),
380 };
381
382 type_checker.fully_type_check();
383 type_checker
384 }
385
386 pub fn insert_variable(
387 &mut self,
388 id: Identifier,
389 ty: TypeVariable,
390 ) {
391 self.variable_scope
392 .last_mut()
393 .expect("inserted variable when no scope existed")
394 .insert(id, ty);
395 }
396
397 pub fn fresh_ty_var(&mut self) -> TypeVariable {
398 self.ctx.new_variable()
399 }
400
401 fn arrow_type(
402 &mut self,
403 tys: Vec<TypeVariable>,
404 ) -> TypeVariable {
405 assert!(!tys.is_empty(), "arrow_type: tys is empty");
406
407 if tys.len() == 1 {
408 return tys[0];
409 }
410
411 let ty = PetrType::Arrow(tys);
412 self.ctx.types.insert(ty)
413 }
414
415 pub fn to_type_var(
416 &mut self,
417 ty: &petr_resolve::Type,
418 ) -> TypeVariable {
419 let ty = match ty {
420 petr_resolve::Type::Integer => PetrType::Integer,
421 petr_resolve::Type::Bool => PetrType::Boolean,
422 petr_resolve::Type::Unit => PetrType::Unit,
423 petr_resolve::Type::String => PetrType::String,
424 petr_resolve::Type::ErrorRecovery => {
425 return self.fresh_ty_var();
427 },
428 petr_resolve::Type::Named(ty_id) => PetrType::Ref(*self.type_map.get(&ty_id.into()).expect("type did not exist in type map")),
429 petr_resolve::Type::Generic(generic_name) => {
430 return self.generic_type(generic_name);
431 },
432 };
433 self.ctx.types.insert(ty)
434 }
435
436 pub fn get_type(
437 &self,
438 key: impl Into<TypeOrFunctionId>,
439 ) -> &TypeVariable {
440 self.type_map.get(&key.into()).expect("type did not exist in type map")
441 }
442
443 fn convert_literal_to_type(
444 &mut self,
445 literal: &petr_resolve::Literal,
446 ) -> TypeVariable {
447 use petr_resolve::Literal::*;
448 let ty = match literal {
449 Integer(_) => PetrType::Integer,
450 Boolean(_) => PetrType::Boolean,
451 String(_) => PetrType::String,
452 };
453 self.ctx.types.insert(ty)
454 }
455
456 fn push_error(
457 &mut self,
458 e: TypeError,
459 ) {
460 self.errors.push(e);
461 }
462
463 pub fn unify(
464 &mut self,
465 ty1: TypeVariable,
466 ty2: TypeVariable,
467 span: Span,
468 ) {
469 self.ctx.unify(ty1, ty2, span);
470 }
471
472 pub fn satisfies(
473 &mut self,
474 ty1: TypeVariable,
475 ty2: TypeVariable,
476 span: Span,
477 ) {
478 self.ctx.satisfies(ty1, ty2, span);
479 }
480
481 fn get_untyped_function(
482 &self,
483 function: FunctionId,
484 ) -> &petr_resolve::Function {
485 self.resolved.get_function(function)
486 }
487
488 fn realize_symbol(
491 &self,
492 id: petr_utils::SymbolId,
493 ) -> Rc<str> {
494 self.resolved.interner.get(id)
495 }
496
497 pub fn get_function(
498 &self,
499 id: &FunctionId,
500 ) -> &Function {
501 self.typed_functions.get(id).expect("invariant: should exist")
502 }
503
504 pub fn functions(&self) -> impl Iterator<Item = (FunctionId, Function)> {
506 self.typed_functions.iter().map(|(a, b)| (*a, b.clone())).collect::<Vec<_>>().into_iter()
507 }
508
509 pub fn expr_ty(
510 &self,
511 expr: &TypedExpr,
512 ) -> TypeVariable {
513 use TypedExprKind::*;
514 match &expr.kind {
515 FunctionCall { ty, .. } => *ty,
516 Literal { ty, .. } => *ty,
517 List { ty, .. } => *ty,
518 Unit => self.unit(),
519 Variable { ty, .. } => *ty,
520 Intrinsic { ty, .. } => *ty,
521 ErrorRecovery(..) => self.ctx.error_recovery,
522 ExprWithBindings { expression, .. } => self.expr_ty(expression),
523 TypeConstructor { ty, .. } => *ty,
524 }
525 }
526
527 pub fn unify_expr_return(
529 &mut self,
530 ty: TypeVariable,
531 expr: &TypedExpr,
532 ) {
533 let expr_ty = self.expr_ty(expr);
534 self.unify(ty, expr_ty, expr.span());
535 }
536
537 pub fn string(&self) -> TypeVariable {
538 self.ctx.string_ty
539 }
540
541 pub fn unit(&self) -> TypeVariable {
542 self.ctx.unit_ty
543 }
544
545 pub fn int(&self) -> TypeVariable {
546 self.ctx.int_ty
547 }
548
549 pub fn error_recovery(
553 &mut self,
554 err: TypeError,
555 ) -> TypeVariable {
556 self.push_error(err);
557 self.ctx.error_recovery
558 }
559
560 pub fn errors(&self) -> &[TypeError] {
561 &self.errors
562 }
563}
564
/// A type-checked call to a built-in intrinsic, holding its typed argument expressions.
#[derive(Clone)]
pub enum Intrinsic {
    /// `@puts(expr)`; the argument is constrained to `string`.
    Puts(Box<TypedExpr>),
    /// `@add(lhs, rhs)`; both operands are constrained to `int`.
    Add(Box<TypedExpr>, Box<TypedExpr>),
    /// `@multiply(lhs, rhs)`; both operands are constrained to `int`.
    Multiply(Box<TypedExpr>, Box<TypedExpr>),
    /// `@divide(lhs, rhs)`; both operands are constrained to `int`.
    Divide(Box<TypedExpr>, Box<TypedExpr>),
    /// `@subtract(lhs, rhs)`; both operands are constrained to `int`.
    Subtract(Box<TypedExpr>, Box<TypedExpr>),
    /// `@malloc(size)`; the size is constrained to `int`.
    Malloc(Box<TypedExpr>),
}
574
575impl std::fmt::Debug for Intrinsic {
576 fn fmt(
577 &self,
578 f: &mut std::fmt::Formatter<'_>,
579 ) -> std::fmt::Result {
580 match self {
581 Intrinsic::Puts(expr) => write!(f, "@puts({:?})", expr),
582 Intrinsic::Add(lhs, rhs) => write!(f, "@add({:?}, {:?})", lhs, rhs),
583 Intrinsic::Multiply(lhs, rhs) => write!(f, "@multiply({:?}, {:?})", lhs, rhs),
584 Intrinsic::Divide(lhs, rhs) => write!(f, "@divide({:?}, {:?})", lhs, rhs),
585 Intrinsic::Subtract(lhs, rhs) => write!(f, "@subtract({:?}, {:?})", lhs, rhs),
586 Intrinsic::Malloc(size) => write!(f, "@malloc({:?})", size),
587 }
588 }
589}
590
/// An expression after type checking, together with its source span.
#[derive(Clone)]
pub struct TypedExpr {
    pub kind: TypedExprKind,
    span: Span,
}
596
impl TypedExpr {
    /// The source span this expression was checked from.
    pub fn span(&self) -> Span {
        self.span
    }
}
602
/// The shape of a [`TypedExpr`]. Most variants carry the type variable of the value they produce.
#[derive(Clone, Debug)]
pub enum TypedExprKind {
    /// A call to `func`, with each argument paired with the parameter name it binds.
    FunctionCall {
        func: FunctionId,
        args: Vec<(Identifier, TypedExpr)>,
        /// The callee's declared return type.
        ty: TypeVariable,
    },
    Literal {
        value: Literal,
        ty: TypeVariable,
    },
    List {
        elements: Vec<TypedExpr>,
        ty: TypeVariable,
    },
    Unit,
    Variable {
        ty: TypeVariable,
        name: Identifier,
    },
    Intrinsic {
        ty: TypeVariable,
        intrinsic: Intrinsic,
    },
    /// Placeholder for an expression that could not be checked; typed as error recovery.
    ErrorRecovery(Span),
    /// `let` bindings followed by the expression they scope over; its type is the type of
    /// `expression`.
    ExprWithBindings {
        bindings: Vec<(Identifier, TypedExpr)>,
        expression: Box<TypedExpr>,
    },
    /// Construction of a value of a user-defined type.
    TypeConstructor {
        /// The constructed type.
        ty: TypeVariable,
        args: Box<[TypedExpr]>,
    },
}
637
638impl std::fmt::Debug for TypedExpr {
639 fn fmt(
640 &self,
641 f: &mut std::fmt::Formatter<'_>,
642 ) -> std::fmt::Result {
643 use TypedExprKind::*;
644 match &self.kind {
645 FunctionCall { func, args, .. } => {
646 write!(f, "function call to {} with args: ", func)?;
647 for (name, arg) in args {
648 write!(f, "{}: {:?}, ", name.id, arg)?;
649 }
650 Ok(())
651 },
652 Literal { value, .. } => write!(f, "literal: {}", value),
653 List { elements, .. } => {
654 write!(f, "list: [")?;
655 for elem in elements {
656 write!(f, "{:?}, ", elem)?;
657 }
658 write!(f, "]")
659 },
660 Unit => write!(f, "unit"),
661 Variable { name, .. } => write!(f, "variable: {}", name.id),
662 Intrinsic { intrinsic, .. } => write!(f, "intrinsic: {:?}", intrinsic),
663 ErrorRecovery(..) => write!(f, "error recovery"),
664 ExprWithBindings { bindings, expression } => {
665 write!(f, "bindings: ")?;
666 for (name, expr) in bindings {
667 write!(f, "{}: {:?}, ", name.id, expr)?;
668 }
669 write!(f, "expression: {:?}", expression)
670 },
671 TypeConstructor { ty, .. } => write!(f, "type constructor: {:?}", ty),
672 }
673 }
674}
675
676impl TypeCheck for Expr {
677 type Output = TypedExpr;
678
679 fn type_check(
680 &self,
681 ctx: &mut TypeChecker,
682 ) -> Self::Output {
683 let kind = match &self.kind {
684 ExprKind::Literal(lit) => {
685 let ty = ctx.convert_literal_to_type(lit);
686 TypedExprKind::Literal { value: lit.clone(), ty }
687 },
688 ExprKind::List(exprs) => {
689 if exprs.is_empty() {
690 let ty = ctx.unit();
691 TypedExprKind::List { elements: vec![], ty }
692 } else {
693 let type_checked_exprs = exprs.iter().map(|expr| expr.type_check(ctx)).collect::<Vec<_>>();
694 let first_ty = ctx.expr_ty(&type_checked_exprs[0]);
696 for expr in type_checked_exprs.iter().skip(1) {
697 let second_ty = ctx.expr_ty(expr);
698 ctx.unify(first_ty, second_ty, expr.span());
699 }
700 TypedExprKind::List {
701 elements: type_checked_exprs,
702 ty: ctx.insert_type(PetrType::List(first_ty)),
703 }
704 }
705 },
706 ExprKind::FunctionCall(call) => {
707 let func_decl = ctx.get_untyped_function(call.function).clone();
710 if call.args.len() != func_decl.params.len() {
711 ctx.push_error(call.span().with_item(TypeConstraintError::ArgumentCountMismatch {
712 expected: func_decl.params.len(),
713 got: call.args.len(),
714 function: ctx.realize_symbol(func_decl.name.id).to_string(),
715 }));
716 return TypedExpr {
717 kind: TypedExprKind::ErrorRecovery(self.span),
718 span: self.span,
719 };
720 }
721 let mut args = Vec::with_capacity(call.args.len());
722 let mut arg_types = Vec::with_capacity(call.args.len());
723
724 for (arg, (param_name, param)) in call.args.iter().zip(func_decl.params.iter()) {
725 let arg_expr = arg.type_check(ctx);
726 let param_ty = ctx.to_type_var(param);
727 let arg_ty = ctx.expr_ty(&arg_expr);
728 ctx.satisfies(arg_ty, param_ty, arg_expr.span());
729 arg_types.push(arg_ty);
730 args.push((*param_name, arg_expr));
731 }
732 TypedExprKind::FunctionCall {
733 func: call.function,
734 args,
735 ty: ctx.to_type_var(&func_decl.return_type),
736 }
737 },
738 ExprKind::Unit => TypedExprKind::Unit,
739 ExprKind::ErrorRecovery => TypedExprKind::ErrorRecovery(self.span),
740 ExprKind::Variable { name, ty } => {
741 let var_ty = ctx.find_variable(*name).expect("variable not found in scope");
744 let ty = ctx.to_type_var(ty);
745
746 ctx.unify(var_ty, ty, name.span());
747
748 TypedExprKind::Variable { ty, name: *name }
749 },
750 ExprKind::Intrinsic(intrinsic) => return self.span.with_item(intrinsic.clone()).type_check(ctx),
751 ExprKind::TypeConstructor(parent_type_id, args) => {
752 let args = args.iter().map(|arg| arg.type_check(ctx)).collect::<Vec<_>>();
756 let ty = ctx.get_type(*parent_type_id);
757 TypedExprKind::TypeConstructor {
758 ty: *ty,
759 args: args.into_boxed_slice(),
760 }
761 },
762 ExprKind::ExpressionWithBindings { bindings, expression } => {
763 ctx.with_type_scope(|ctx| {
765 let mut type_checked_bindings = Vec::with_capacity(bindings.len());
766 for binding in bindings {
767 let binding_ty = binding.expression.type_check(ctx);
768 let binding_expr_return_ty = ctx.expr_ty(&binding_ty);
769 ctx.insert_variable(binding.name, binding_expr_return_ty);
770 type_checked_bindings.push((binding.name, binding_ty));
771 }
772
773 TypedExprKind::ExprWithBindings {
774 bindings: type_checked_bindings,
775 expression: Box::new(expression.type_check(ctx)),
776 }
777 })
778 },
779 };
780
781 TypedExpr { kind, span: self.span }
782 }
783}
784
785fn unify_basic_math_op(
786 lhs: &Expr,
787 rhs: &Expr,
788 ctx: &mut TypeChecker,
789) -> (TypedExpr, TypedExpr) {
790 let lhs = lhs.type_check(ctx);
791 let rhs = rhs.type_check(ctx);
792 let lhs_ty = ctx.expr_ty(&lhs);
793 let rhs_ty = ctx.expr_ty(&rhs);
794 let int_ty = ctx.int();
795 ctx.unify(lhs_ty, int_ty, lhs.span());
796 ctx.unify(rhs_ty, int_ty, rhs.span());
797 (lhs, rhs)
798}
799
800impl TypeCheck for SpannedItem<ResolvedIntrinsic> {
801 type Output = TypedExpr;
802
803 fn type_check(
804 &self,
805 ctx: &mut TypeChecker,
806 ) -> Self::Output {
807 use petr_resolve::IntrinsicName::*;
808 let string_ty = ctx.string();
809 let kind = match self.item().intrinsic {
810 Puts => {
811 if self.item().args.len() != 1 {
812 todo!("puts arg len check");
813 }
814 let arg = self.item().args[0].type_check(ctx);
816 ctx.unify_expr_return(string_ty, &arg);
817 TypedExprKind::Intrinsic {
818 intrinsic: Intrinsic::Puts(Box::new(arg)),
819 ty: ctx.unit(),
820 }
821 },
822 Add => {
823 if self.item().args.len() != 2 {
824 todo!("add arg len check");
825 }
826 let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx);
827 TypedExprKind::Intrinsic {
828 intrinsic: Intrinsic::Add(Box::new(lhs), Box::new(rhs)),
829 ty: ctx.int(),
830 }
831 },
832 Subtract => {
833 if self.item().args.len() != 2 {
834 todo!("sub arg len check");
835 }
836 let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx);
837 TypedExprKind::Intrinsic {
838 intrinsic: Intrinsic::Subtract(Box::new(lhs), Box::new(rhs)),
839 ty: ctx.int(),
840 }
841 },
842 Multiply => {
843 if self.item().args.len() != 2 {
844 todo!("mult arg len check");
845 }
846
847 let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx);
848 TypedExprKind::Intrinsic {
849 intrinsic: Intrinsic::Multiply(Box::new(lhs), Box::new(rhs)),
850 ty: ctx.int(),
851 }
852 },
853
854 Divide => {
855 if self.item().args.len() != 2 {
856 todo!("Divide arg len check");
857 }
858
859 let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx);
860 TypedExprKind::Intrinsic {
861 intrinsic: Intrinsic::Divide(Box::new(lhs), Box::new(rhs)),
862 ty: ctx.int(),
863 }
864 },
865 Malloc => {
866 if self.item().args.len() != 1 {
872 todo!("malloc arg len check");
873 }
874 let arg = self.item().args[0].type_check(ctx);
875 let arg_ty = ctx.expr_ty(&arg);
876 let int_ty = ctx.int();
877 ctx.unify(arg_ty, int_ty, arg.span());
878 TypedExprKind::Intrinsic {
879 intrinsic: Intrinsic::Malloc(Box::new(arg)),
880 ty: int_ty,
881 }
882 },
883 };
884
885 TypedExpr { kind, span: self.span() }
886 }
887}
888
/// Something that can be type checked against a [`TypeChecker`], producing a typed form.
trait TypeCheck {
    /// The typed result of checking `Self`.
    type Output;
    /// Type checks `self`, recording constraints and errors on `ctx`.
    fn type_check(
        &self,
        ctx: &mut TypeChecker,
    ) -> Self::Output;
}
896
/// A type-checked function.
#[derive(Clone, Debug)]
pub struct Function {
    pub name: Identifier,
    /// Parameter names with their type variables, in declaration order.
    pub params: Vec<(Identifier, TypeVariable)>,
    /// The type-checked body.
    pub body: TypedExpr,
    /// The declared return type.
    pub return_ty: TypeVariable,
}
904
905impl TypeCheck for petr_resolve::Function {
906 type Output = Function;
907
908 fn type_check(
909 &self,
910 ctx: &mut TypeChecker,
911 ) -> Self::Output {
912 ctx.with_type_scope(|ctx| {
913 let params = self.params.iter().map(|(name, ty)| (*name, ctx.to_type_var(ty))).collect::<Vec<_>>();
914
915 for (name, ty) in ¶ms {
916 ctx.insert_variable(*name, *ty);
917 }
918
919 let body = self.body.type_check(ctx);
921
922 let declared_return_type = ctx.to_type_var(&self.return_type);
923
924 Function {
925 name: self.name,
926 params,
927 return_ty: declared_return_type,
928 body,
929 }
930 })
931 }
934}
935
936impl TypeCheck for petr_resolve::FunctionCall {
937 type Output = ();
938
939 fn type_check(
940 &self,
941 ctx: &mut TypeChecker,
942 ) -> Self::Output {
943 let func_type = *ctx.get_type(self.function);
944 let args = self.args.iter().map(|arg| arg.type_check(ctx)).collect::<Vec<_>>();
945
946 let mut arg_types = Vec::with_capacity(args.len());
947
948 for arg in args.iter() {
949 arg_types.push(ctx.expr_ty(arg));
950 }
951
952 let arg_type = ctx.arrow_type(arg_types);
953
954 ctx.unify(func_type, arg_type, self.span());
955 }
956}
957
#[cfg(test)]
mod tests {
    use expect_test::{expect, Expect};
    use petr_resolve::resolve_symbols;
    use petr_utils::{render_error, SourceId};

    use super::*;

    /// Parses and resolves `input`, type checks it, and compares the pretty-printed checker state
    /// (including rendered errors, if any) against the `expect` snapshot.
    ///
    /// Panics if the input fails to parse or has unresolved symbols.
    fn check(
        input: impl Into<String>,
        expect: Expect,
    ) {
        let input = input.into();
        let parser = petr_parse::Parser::new(vec![("test", input)]);
        let (ast, errs, interner, source_map) = parser.into_result();
        if !errs.is_empty() {
            errs.into_iter().for_each(|err| eprintln!("{:?}", render_error(&source_map, err)));
            panic!("fmt failed: code didn't parse");
        }
        let (errs, resolved) = resolve_symbols(ast, interner, Default::default());
        assert!(errs.is_empty(), "can't typecheck: unresolved symbols");
        // `TypeChecker::new` runs the full type-checking pass itself.
        let type_checker = TypeChecker::new(resolved);
        let res = pretty_print_type_checker(type_checker, &source_map);

        expect.assert_eq(&res);
    }

    /// Renders every entry of `type_map` (a type or function name plus its type). Function
    /// entries are followed by their typed body. Rendered errors are appended at the end.
    fn pretty_print_type_checker(
        type_checker: TypeChecker,
        source_map: &IndexMap<SourceId, (&'static str, &'static str)>,
    ) -> String {
        let mut s = String::new();
        for (id, ty) in &type_checker.type_map {
            let text = match id {
                TypeOrFunctionId::TypeId(id) => {
                    let ty = type_checker.resolved.get_type(*id);

                    let name = type_checker.resolved.interner.get(ty.name.id);
                    format!("type {}", name)
                },
                TypeOrFunctionId::FunctionId(id) => {
                    let func = type_checker.resolved.get_function(*id);

                    let name = type_checker.resolved.interner.get(func.name.id);

                    format!("fn {}", name)
                },
            };
            s.push_str(&text);
            s.push_str(": ");
            s.push_str(&pretty_print_ty(ty, &type_checker));

            s.push('\n');
            match id {
                TypeOrFunctionId::TypeId(_) => (),
                TypeOrFunctionId::FunctionId(func) => {
                    let func = type_checker.typed_functions.get(func).unwrap();
                    let body = &func.body;
                    s.push_str(&pretty_print_typed_expr(body, &type_checker));
                    s.push('\n');
                },
            }

            s.push('\n');
        }

        if !type_checker.errors.is_empty() {
            s.push_str("\nErrors:\n");
            for error in type_checker.errors {
                let rendered = render_error(source_map, error);
                s.push_str(&format!("{:?}\n", rendered));
            }
        }
        s
    }

    /// Renders a type variable, following `Ref` indirections first. Unknown types print as `t<id>`.
    fn pretty_print_ty(
        ty: &TypeVariable,
        type_checker: &TypeChecker,
    ) -> String {
        let mut ty = type_checker.look_up_variable(*ty);
        while let PetrType::Ref(t) = ty {
            ty = type_checker.look_up_variable(*t);
        }
        match ty {
            PetrType::Unit => "unit".to_string(),
            PetrType::Integer => "int".to_string(),
            PetrType::Boolean => "bool".to_string(),
            PetrType::String => "string".to_string(),
            // Unreachable after the loop above, kept for exhaustiveness.
            PetrType::Ref(ty) => pretty_print_ty(ty, type_checker),
            PetrType::UserDefined { name, variants: _ } => {
                let name = type_checker.resolved.interner.get(name.id);
                name.to_string()
            },
            PetrType::Arrow(tys) => {
                let mut s = String::new();
                s.push('(');
                for (ix, ty) in tys.iter().enumerate() {
                    let is_last = ix == tys.len() - 1;

                    s.push_str(&pretty_print_ty(ty, type_checker));
                    if !is_last {
                        s.push_str(" → ");
                    }
                }
                s.push(')');
                s
            },
            PetrType::ErrorRecovery => "error recovery".to_string(),
            PetrType::List(ty) => format!("[{}]", pretty_print_ty(ty, type_checker)),
            PetrType::Infer(id) => format!("t{id}"),
        }
    }

    /// Renders a typed expression with resolved types for bindings, variables, calls and type
    /// constructors; other kinds fall back to `TypedExpr`'s `Debug` output.
    fn pretty_print_typed_expr(
        typed_expr: &TypedExpr,
        type_checker: &TypeChecker,
    ) -> String {
        let interner = &type_checker.resolved.interner;
        match &typed_expr.kind {
            TypedExprKind::ExprWithBindings { bindings, expression } => {
                let mut s = String::new();
                for (name, expr) in bindings {
                    let ident = interner.get(name.id);
                    let ty = type_checker.expr_ty(expr);
                    let ty = pretty_print_ty(&ty, type_checker);
                    s.push_str(&format!("{ident}: {:?} ({}),\n", expr, ty));
                }
                let expr_ty = type_checker.expr_ty(expression);
                let expr_ty = pretty_print_ty(&expr_ty, type_checker);
                s.push_str(&format!("{:?} ({})", pretty_print_typed_expr(expression, type_checker), expr_ty));
                s
            },
            TypedExprKind::Variable { name, ty } => {
                let name = interner.get(name.id);
                let ty = pretty_print_ty(ty, type_checker);
                format!("variable {name}: {ty}")
            },

            TypedExprKind::FunctionCall { func, args, ty } => {
                let mut s = String::new();
                s.push_str(&format!("function call to {} with args: ", func));
                for (name, arg) in args {
                    let name = interner.get(name.id);
                    let arg_ty = type_checker.expr_ty(arg);
                    let arg_ty = pretty_print_ty(&arg_ty, type_checker);
                    s.push_str(&format!("{name}: {}, ", arg_ty));
                }
                let ty = pretty_print_ty(ty, type_checker);
                s.push_str(&format!("returns {ty}"));
                s
            },
            TypedExprKind::TypeConstructor { ty, .. } => format!("type constructor: {}", pretty_print_ty(ty, type_checker)),
            _otherwise => format!("{:?}", typed_expr),
        }
    }

    #[test]
    fn identity_resolution_concrete_type() {
        check(
            r#"
            fn foo(x in 'int) returns 'int x
            "#,
            expect![[r#"
                fn foo: (int → int)
                variable x: int

            "#]],
        );
    }

    // `t4` is the first fresh variable after the four built-in types seeded by `TypeContext::default`.
    #[test]
    fn identity_resolution_generic() {
        check(
            r#"
            fn foo(x in 'A) returns 'A x
            "#,
            expect![[r#"
                fn foo: (t4 → t4)
                variable x: t4

            "#]],
        );
    }

    #[test]
    fn identity_resolution_custom_type() {
        check(
            r#"
            type MyType = A | B
            fn foo(x in 'MyType) returns 'MyType x
            "#,
            expect![[r#"
                type MyType: MyType

                fn A: MyType
                type constructor: MyType

                fn B: MyType
                type constructor: MyType

                fn foo: (MyType → MyType)
                variable x: MyType

            "#]],
        );
    }

    #[test]
    fn identity_resolution_two_custom_types() {
        check(
            r#"
            type MyType = A | B
            type MyComposedType = firstVariant someField 'MyType | secondVariant someField 'int someField2 'MyType someField3 'GenericType
            fn foo(x in 'MyType) returns 'MyComposedType ~firstVariant(x)
            "#,
            expect![[r#"
                type MyType: MyType

                type MyComposedType: MyComposedType

                fn A: MyType
                type constructor: MyType

                fn B: MyType
                type constructor: MyType

                fn firstVariant: (MyType → MyComposedType)
                type constructor: MyComposedType

                fn secondVariant: (int → MyType → t18 → MyComposedType)
                type constructor: MyComposedType

                fn foo: (MyType → MyComposedType)
                function call to functionid2 with args: someField: MyType, returns MyComposedType

            "#]],
        );
    }

    // NOTE(review): no error is recorded for `bar` returning `5` as `'bool`, because function
    // bodies are not currently constrained against the declared return type.
    #[test]
    fn literal_unification_fail() {
        check(
            r#"
            fn foo() returns 'int 5
            fn bar() returns 'bool 5
            "#,
            expect![[r#"
                fn foo: int
                literal: 5

                fn bar: bool
                literal: 5

            "#]],
        );
    }

    #[test]
    fn literal_unification_success() {
        check(
            r#"
            fn foo() returns 'int 5
            fn bar() returns 'bool true
            "#,
            expect![[r#"
                fn foo: int
                literal: 5

                fn bar: bool
                literal: true

            "#]],
        );
    }

    #[test]
    fn pass_zero_arity_func_to_intrinsic() {
        check(
            r#"
            fn string_literal() returns 'string
                "This is a string literal."

            fn my_func() returns 'unit
                @puts(~string_literal)"#,
            expect![[r#"
                fn string_literal: string
                literal: "This is a string literal."

                fn my_func: unit
                intrinsic: @puts(function call to functionid0 with args: )

            "#]],
        );
    }

    #[test]
    fn pass_literal_string_to_intrinsic() {
        check(
            r#"
            fn my_func() returns 'unit
                @puts("test")"#,
            expect![[r#"
                fn my_func: unit
                intrinsic: @puts(literal: "test")

            "#]],
        );
    }

    #[test]
    fn pass_wrong_type_literal_to_intrinsic() {
        check(
            r#"
            fn my_func() returns 'unit
                @puts(true)"#,
            expect![[r#"
                fn my_func: unit
                intrinsic: @puts(literal: true)


                Errors:
                  × Failed to unify types: String, Boolean
                   ╭─[test:2:1]
                 2 │             fn my_func() returns 'unit
                 3 │                 @puts(true)
                   ·                       ──┬─
                   ·                         ╰── Failed to unify types: String, Boolean
                   ╰────

            "#]],
        );
    }

    // NOTE(review): as with `literal_unification_fail`, the `unit` body is not checked against
    // the declared `'bool` return type, so no error is expected here.
    #[test]
    fn intrinsic_and_return_ty_dont_match() {
        check(
            r#"
            fn my_func() returns 'bool
                @puts("test")"#,
            expect![[r#"
                fn my_func: bool
                intrinsic: @puts(literal: "test")

            "#]],
        );
    }

    #[test]
    fn pass_wrong_type_fn_call_to_intrinsic() {
        check(
            r#"
            fn bool_literal() returns 'bool
                true

            fn my_func() returns 'unit
                @puts(~bool_literal)"#,
            expect![[r#"
                fn bool_literal: bool
                literal: true

                fn my_func: unit
                intrinsic: @puts(function call to functionid0 with args: )


                Errors:
                  × Failed to unify types: String, Boolean
                   ╭─[test:5:1]
                 5 │             fn my_func() returns 'unit
                 6 │                 @puts(~bool_literal)
                   ·                       ───────┬──────
                   ·                              ╰── Failed to unify types: String, Boolean
                   ╰────

            "#]],
        );
    }

    #[test]
    fn multiple_calls_to_fn_dont_unify_params_themselves() {
        check(
            r#"
            fn bool_literal(a in 'A, b in 'B) returns 'bool
                true

            fn my_func() returns 'bool
                ~bool_literal(1, 2)

            {- should not unify the parameter types of bool_literal -}
            fn my_second_func() returns 'bool
                ~bool_literal(true, false)
            "#,
            expect![[r#"
                fn bool_literal: (t4 → t5 → bool)
                literal: true

                fn my_func: bool
                function call to functionid0 with args: a: int, b: int, returns bool

                fn my_second_func: bool
                function call to functionid0 with args: a: bool, b: bool, returns bool

            "#]],
        );
    }
    #[test]
    fn list_different_types_type_err() {
        check(
            r#"
            fn my_list() returns 'list [ 1, true ]
            "#,
            expect![[r#"
                fn my_list: t7
                list: [literal: 1, literal: true, ]


                Errors:
                  × Failed to unify types: Integer, Boolean
                   ╭─[test:1:1]
                 1 │
                 2 │             fn my_list() returns 'list [ 1, true ]
                   ·                                              ──┬──
                   ·                                                ╰── Failed to unify types: Integer, Boolean
                 3 │
                   ╰────

            "#]],
        );
    }

    #[test]
    fn incorrect_number_of_args() {
        check(
            r#"
            fn add(a in 'int, b in 'int) returns 'int a

            fn add_five(a in 'int) returns 'int ~add(5)
            "#,
            expect![[r#"
                fn add: (int → int → int)
                variable a: int

                fn add_five: (int → int)
                error recovery


                Errors:
                  × Function add takes 2 arguments, but got 1 arguments.
                   ╭─[test:3:1]
                 3 │
                 4 │             fn add_five(a in 'int) returns 'int ~add(5)
                   ·                                                  ────┬───
                   ·                                                      ╰── Function add takes 2 arguments, but got 1 arguments.
                 5 │
                   ╰────

            "#]],
        );
    }

    #[test]
    fn infer_let_bindings() {
        check(
            r#"
            fn hi(x in 'int, y in 'int) returns 'int
                let a = x;
                    b = y;
                    c = 20;
                    d = 30;
                    e = 42;
                a
fn main() returns 'int ~hi(1, 2)"#,
            expect![[r#"
                fn hi: (int → int → int)
                a: variable: symbolid2 (int),
                b: variable: symbolid4 (int),
                c: literal: 20 (int),
                d: literal: 30 (int),
                e: literal: 42 (int),
                "variable a: int" (int)

                fn main: int
                function call to functionid0 with args: x: int, y: int, returns int

            "#]],
        )
    }
}