petr_typecheck/
lib.rs

1mod error;
2
3use std::{collections::BTreeMap, rc::Rc};
4
5use error::TypeConstraintError;
6pub use petr_bind::FunctionId;
7use petr_resolve::{Expr, ExprKind, QueryableResolvedItems};
8pub use petr_resolve::{Intrinsic as ResolvedIntrinsic, IntrinsicName, Literal};
9use petr_utils::{idx_map_key, Identifier, IndexMap, Span, SpannedItem, SymbolId, TypeId};
10
11pub type TypeError = SpannedItem<TypeConstraintError>;
12pub type TResult<T> = Result<T, TypeError>;
13
14// TODO return QueryableTypeChecked instead of type checker
15// Clean up API so this is the only function exposed
16pub fn type_check(resolved: QueryableResolvedItems) -> (Vec<TypeError>, TypeChecker) {
17    let mut type_checker = TypeChecker::new(resolved);
18    type_checker.fully_type_check();
19    (type_checker.errors.clone(), type_checker)
20}
21
/// Identifies either a type or a function.
///
/// This is the key type of the type checker's `type_map`, which associates
/// each declared type and each function with a type variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TypeOrFunctionId {
    /// The key refers to a type.
    TypeId(TypeId),
    /// The key refers to a function.
    FunctionId(FunctionId),
}
27
28impl From<TypeId> for TypeOrFunctionId {
29    fn from(type_id: TypeId) -> Self {
30        TypeOrFunctionId::TypeId(type_id)
31    }
32}
33
34impl From<FunctionId> for TypeOrFunctionId {
35    fn from(function_id: FunctionId) -> Self {
36        TypeOrFunctionId::FunctionId(function_id)
37    }
38}
39
40impl From<&TypeId> for TypeOrFunctionId {
41    fn from(type_id: &TypeId) -> Self {
42        TypeOrFunctionId::TypeId(*type_id)
43    }
44}
45
46impl From<&FunctionId> for TypeOrFunctionId {
47    fn from(function_id: &FunctionId) -> Self {
48        TypeOrFunctionId::FunctionId(*function_id)
49    }
50}
51
// Declares `TypeVariable`, the key used to index the `IndexMap` of types held
// by `TypeContext`. Presumably `idx_map_key!` generates a small copyable index
// newtype -- confirm against `petr_utils`.
idx_map_key!(TypeVariable);
53
/// A relationship required between two type variables, together with the
/// source location that gave rise to it.
///
/// Constraints are collected in `TypeContext::constraints` and applied later
/// by `TypeChecker::apply_constraints`.
#[derive(Clone, Copy, Debug)]
pub struct TypeConstraint {
    /// Which relationship is required, and between which type variables
    kind: TypeConstraintKind,
    /// The span from which this type constraint originated
    span: Span,
}
60impl TypeConstraint {
61    fn unify(
62        t1: TypeVariable,
63        t2: TypeVariable,
64        span: Span,
65    ) -> Self {
66        Self {
67            kind: TypeConstraintKind::Unify(t1, t2),
68            span,
69        }
70    }
71
72    fn satisfies(
73        t1: TypeVariable,
74        t2: TypeVariable,
75        span: Span,
76    ) -> Self {
77        Self {
78            kind: TypeConstraintKind::Satisfies(t1, t2),
79            span,
80        }
81    }
82}
83
/// The kinds of relationship a [`TypeConstraint`] can demand between two
/// type variables.
#[derive(Clone, Copy, Debug)]
pub enum TypeConstraintKind {
    /// The two type variables must unify, i.e. collapse into a single type.
    Unify(TypeVariable, TypeVariable),
    // constraint that lhs is a "subtype" or satisfies the typeclass constraints of "rhs"
    Satisfies(TypeVariable, TypeVariable),
}
90
/// Storage for every type known to the type checker plus the constraints
/// collected between them.
pub struct TypeContext {
    // every type, addressed by its type variable
    types:          IndexMap<TypeVariable, PetrType>,
    // constraints collected so far, in the order they were recorded
    constraints:    Vec<TypeConstraint>,
    // known primitive type IDs
    unit_ty:        TypeVariable,
    string_ty:      TypeVariable,
    int_ty:         TypeVariable,
    // type variable holding `PetrType::ErrorRecovery`
    error_recovery: TypeVariable,
}
100
101impl Default for TypeContext {
102    fn default() -> Self {
103        let mut types = IndexMap::default();
104        // instantiate basic primitive types
105        let unit_ty = types.insert(PetrType::Unit);
106        let string_ty = types.insert(PetrType::String);
107        let int_ty = types.insert(PetrType::Integer);
108        let error_recovery = types.insert(PetrType::ErrorRecovery);
109        // insert primitive types
110        TypeContext {
111            types,
112            constraints: Default::default(),
113            unit_ty,
114            string_ty,
115            int_ty,
116            error_recovery,
117        }
118    }
119}
120
121impl TypeContext {
122    fn unify(
123        &mut self,
124        ty1: TypeVariable,
125        ty2: TypeVariable,
126        span: Span,
127    ) {
128        self.constraints.push(TypeConstraint::unify(ty1, ty2, span));
129    }
130
131    fn satisfies(
132        &mut self,
133        ty1: TypeVariable,
134        ty2: TypeVariable,
135        span: Span,
136    ) {
137        self.constraints.push(TypeConstraint::satisfies(ty1, ty2, span));
138    }
139
140    fn new_variable(&mut self) -> TypeVariable {
141        // infer is special -- it knows its own id, mostly for printing
142        let infer_id = self.types.len();
143        self.types.insert(PetrType::Infer(infer_id))
144    }
145
146    /// Update a type variable with a new PetrType
147    fn update_type(
148        &mut self,
149        t1: TypeVariable,
150        known: PetrType,
151    ) {
152        *self.types.get_mut(t1) = known;
153    }
154}
155
/// Drives type checking of a resolved program: collects constraints while
/// walking declarations and expressions, then applies them.
pub struct TypeChecker {
    // all types and collected constraints
    ctx: TypeContext,
    // type variable assigned to each declared type and each function
    type_map: BTreeMap<TypeOrFunctionId, TypeVariable>,
    // type checked form of each function
    typed_functions: BTreeMap<FunctionId, Function>,
    // errors reported so far
    errors: Vec<TypeError>,
    // the resolved program being checked
    resolved: QueryableResolvedItems,
    // stack of scopes, innermost last, mapping names to type variables
    variable_scope: Vec<BTreeMap<Identifier, TypeVariable>>,
}
164
/// The types the type checker reasons about. Each one is stored in the
/// [`TypeContext`] and addressed through a [`TypeVariable`].
#[derive(Clone, PartialEq, Debug, Eq, PartialOrd, Ord)]
pub enum PetrType {
    /// The unit type
    Unit,
    /// The integer type
    Integer,
    /// The boolean type
    Boolean,
    /// a static length string known at compile time
    String,
    /// A reference to another type
    Ref(TypeVariable),
    /// A user-defined type
    UserDefined {
        name:     Identifier,
        variants: Vec<TypeVariant>,
    },
    /// A function type; for functions the checker stores the parameter types
    /// followed by the return type
    Arrow(Vec<TypeVariable>),
    /// Placeholder type used after an error; constraint application accepts
    /// it against any other type without reporting anything
    ErrorRecovery,
    /// A list whose element type is the given type variable
    List(TypeVariable),
    /// the usize is just an identifier for use in rendering the type
    Infer(usize),
}
185
/// One variant of a user-defined type.
#[derive(Clone, PartialEq, Debug, Eq, PartialOrd, Ord)]
pub struct TypeVariant {
    /// The type variable of each field of the variant, in declaration order
    pub fields: Box<[TypeVariable]>,
}
190
191impl TypeChecker {
    /// Stores `ty` in the type context and returns the type variable that
    /// now refers to it.
    ///
    /// No deduplication is performed: every call allocates a new entry.
    pub fn insert_type(
        &mut self,
        ty: PetrType,
    ) -> TypeVariable {
        // TODO: check if type already exists and return that ID instead
        self.ctx.types.insert(ty)
    }
199
    /// Returns the type currently stored for the type variable `ty`.
    ///
    /// The result is returned as stored; a `Ref` or `Infer` is not followed
    /// or resolved here.
    pub fn look_up_variable(
        &self,
        ty: TypeVariable,
    ) -> &PetrType {
        self.ctx.types.get(ty)
    }
206
207    pub fn get_symbol(
208        &self,
209        id: SymbolId,
210    ) -> Rc<str> {
211        self.resolved.interner.get(id).clone()
212    }
213
214    fn with_type_scope<T>(
215        &mut self,
216        f: impl FnOnce(&mut Self) -> T,
217    ) -> T {
218        self.variable_scope.push(Default::default());
219        let res = f(self);
220        self.variable_scope.pop();
221        res
222    }
223
224    fn generic_type(
225        &mut self,
226        id: &Identifier,
227    ) -> TypeVariable {
228        for scope in self.variable_scope.iter().rev() {
229            if let Some(ty) = scope.get(id) {
230                return *ty;
231            }
232        }
233        let fresh_ty = self.fresh_ty_var();
234        match self.variable_scope.last_mut() {
235            Some(entry) => {
236                entry.insert(*id, fresh_ty);
237            },
238            None => {
239                self.errors.push(id.span.with_item(TypeConstraintError::Internal(
240                    "attempted to insert generic type into variable scope when no variable scope existed".into(),
241                )));
242                self.ctx.update_type(fresh_ty, PetrType::ErrorRecovery);
243            },
244        };
245        fresh_ty
246    }
247
248    fn find_variable(
249        &self,
250        id: Identifier,
251    ) -> Option<TypeVariable> {
252        for scope in self.variable_scope.iter().rev() {
253            if let Some(ty) = scope.get(&id) {
254                return Some(*ty);
255            }
256        }
257        None
258    }
259
260    fn fully_type_check(&mut self) {
261        for (id, decl) in self.resolved.types() {
262            let ty = self.fresh_ty_var();
263            let variants = decl
264                .variants
265                .iter()
266                .map(|variant| {
267                    self.with_type_scope(|ctx| {
268                        let fields = variant.fields.iter().map(|field| ctx.to_type_var(&field.ty)).collect::<Vec<_>>();
269                        TypeVariant {
270                            fields: fields.into_boxed_slice(),
271                        }
272                    })
273                })
274                .collect::<Vec<_>>();
275            self.ctx.update_type(ty, PetrType::UserDefined { name: decl.name, variants });
276            self.type_map.insert(id.into(), ty);
277        }
278
279        for (id, func) in self.resolved.functions() {
280            let typed_function = func.type_check(self);
281
282            let ty = self.arrow_type([typed_function.params.iter().map(|(_, b)| *b).collect(), vec![typed_function.return_ty]].concat());
283            self.type_map.insert(id.into(), ty);
284            self.typed_functions.insert(id, typed_function);
285        }
286
287        // we have now collected our constraints and can solve for them
288        self.apply_constraints();
289    }
290
291    /// iterate through each constraint and transform the underlying types to satisfy them
292    /// - unification tries to collapse two types into one
293    /// - satisfaction tries to make one type satisfy the constraints of another, although type
294    ///   constraints don't exist in the language yet
295    fn apply_constraints(&mut self) {
296        let constraints = self.ctx.constraints.clone();
297        for constraint in constraints {
298            match &constraint.kind {
299                TypeConstraintKind::Unify(t1, t2) => {
300                    self.apply_unify_constraint(*t1, *t2, constraint.span);
301                },
302                TypeConstraintKind::Satisfies(t1, t2) => {
303                    self.apply_satisfies_constraint(*t1, *t2, constraint.span);
304                },
305            }
306        }
307    }
308
    /// Attempt to unify two types, recording an error if they cannot be unified
    /// The more specific of the two types will instantiate the more general of the two types.
    ///
    /// `t1` and `t2` are the type variables to unify and `span` is the source
    /// location attached to a reported mismatch. The rules are tried in order,
    /// so the order of the match arms below matters.
    fn apply_unify_constraint(
        &mut self,
        t1: TypeVariable,
        t2: TypeVariable,
        span: Span,
    ) {
        // cloned so that `self.ctx` can be updated while the types are inspected
        let ty1 = self.ctx.types.get(t1).clone();
        let ty2 = self.ctx.types.get(t2).clone();
        use PetrType::*;
        match (ty1, ty2) {
            // identical types are already unified
            (a, b) if a == b => (),
            // an error was already reported for an error recovery type; stay silent
            (ErrorRecovery, _) | (_, ErrorRecovery) => (),
            // follow references and unify against what they point to
            (Ref(a), _) => self.apply_unify_constraint(a, t2, span),
            (_, Ref(b)) => self.apply_unify_constraint(t1, b, span),
            (Infer(id), Infer(id2)) if id != id2 => {
                // if two different inferred types are unified, replace the second with a reference
                // to the first
                self.ctx.update_type(t2, Ref(t1));
            },
            // instantiate the infer type with the known type
            (Infer(_), known) => {
                self.ctx.update_type(t1, known);
            },
            (known, Infer(_)) => {
                self.ctx.update_type(t2, known);
            },
            // lastly, if no unification rule exists for these two types, it is a mismatch
            (a, b) => {
                self.push_error(span.with_item(TypeConstraintError::UnificationFailure(a, b)));
            },
        }
    }
343
    // This function will need to be rewritten when type constraints and bounded polymorphism are
    // implemented.
    /// Applies a "satisfies" constraint: the type behind `t1` must satisfy the
    /// type behind `t2`. A failure is recorded as an error located at `span`.
    ///
    /// The rules are tried in order, so the order of the match arms matters.
    fn apply_satisfies_constraint(
        &mut self,
        t1: TypeVariable,
        t2: TypeVariable,
        span: Span,
    ) {
        let ty1 = self.ctx.types.get(t1);
        let ty2 = self.ctx.types.get(t2);
        use PetrType::*;
        match (ty1, ty2) {
            // identical types trivially satisfy each other
            (a, b) if a == b => (),
            // an error was already reported for an error recovery type; stay silent
            (ErrorRecovery, _) | (_, ErrorRecovery) => (),
            // follow references and check against what they point to
            (Ref(a), _) => self.apply_satisfies_constraint(*a, t2, span),
            (_, Ref(b)) => self.apply_satisfies_constraint(t1, *b, span),
            // t2 is still to be inferred, so it is turned into a reference to t1
            // (t1 is whatever was not handled by the arms above)
            (_known, Infer(_)) => {
                self.ctx.update_type(t2, Ref(t1));
            },
            // disabled rule, kept for reference:
            //            (Infer(_), _) | (_, Infer(_)) => Ok(()),
            // everything else fails, including a `t1` that is still to be inferred
            (a, b) => {
                self.push_error(span.with_item(TypeConstraintError::FailedToSatisfy(a.clone(), b.clone())));
            },
        }
    }
370
371    pub fn new(resolved: QueryableResolvedItems) -> Self {
372        let ctx = TypeContext::default();
373        let mut type_checker = TypeChecker {
374            ctx,
375            type_map: Default::default(),
376            errors: Default::default(),
377            typed_functions: Default::default(),
378            resolved,
379            variable_scope: Default::default(),
380        };
381
382        type_checker.fully_type_check();
383        type_checker
384    }
385
386    pub fn insert_variable(
387        &mut self,
388        id: Identifier,
389        ty: TypeVariable,
390    ) {
391        self.variable_scope
392            .last_mut()
393            .expect("inserted variable when no scope existed")
394            .insert(id, ty);
395    }
396
    /// Allocates a new type variable whose type is yet to be inferred.
    pub fn fresh_ty_var(&mut self) -> TypeVariable {
        self.ctx.new_variable()
    }
400
401    fn arrow_type(
402        &mut self,
403        tys: Vec<TypeVariable>,
404    ) -> TypeVariable {
405        assert!(!tys.is_empty(), "arrow_type: tys is empty");
406
407        if tys.len() == 1 {
408            return tys[0];
409        }
410
411        let ty = PetrType::Arrow(tys);
412        self.ctx.types.insert(ty)
413    }
414
415    pub fn to_type_var(
416        &mut self,
417        ty: &petr_resolve::Type,
418    ) -> TypeVariable {
419        let ty = match ty {
420            petr_resolve::Type::Integer => PetrType::Integer,
421            petr_resolve::Type::Bool => PetrType::Boolean,
422            petr_resolve::Type::Unit => PetrType::Unit,
423            petr_resolve::Type::String => PetrType::String,
424            petr_resolve::Type::ErrorRecovery => {
425                // unifies to anything, fresh var
426                return self.fresh_ty_var();
427            },
428            petr_resolve::Type::Named(ty_id) => PetrType::Ref(*self.type_map.get(&ty_id.into()).expect("type did not exist in type map")),
429            petr_resolve::Type::Generic(generic_name) => {
430                return self.generic_type(generic_name);
431            },
432        };
433        self.ctx.types.insert(ty)
434    }
435
    /// Returns the type variable recorded in the type map for the given type
    /// or function.
    ///
    /// # Panics
    ///
    /// Panics if the key has no entry in the type map.
    pub fn get_type(
        &self,
        key: impl Into<TypeOrFunctionId>,
    ) -> &TypeVariable {
        self.type_map.get(&key.into()).expect("type did not exist in type map")
    }
442
443    fn convert_literal_to_type(
444        &mut self,
445        literal: &petr_resolve::Literal,
446    ) -> TypeVariable {
447        use petr_resolve::Literal::*;
448        let ty = match literal {
449            Integer(_) => PetrType::Integer,
450            Boolean(_) => PetrType::Boolean,
451            String(_) => PetrType::String,
452        };
453        self.ctx.types.insert(ty)
454    }
455
    /// Records the type error `e`.
    fn push_error(
        &mut self,
        e: TypeError,
    ) {
        self.errors.push(e);
    }
462
    /// Records a constraint that `ty1` and `ty2` must unify.
    ///
    /// Nothing is checked at this point; the constraint is only stored, with
    /// `span` as its origin.
    pub fn unify(
        &mut self,
        ty1: TypeVariable,
        ty2: TypeVariable,
        span: Span,
    ) {
        self.ctx.unify(ty1, ty2, span);
    }
471
    /// Records a constraint that `ty1` must satisfy `ty2`.
    ///
    /// Nothing is checked at this point; the constraint is only stored, with
    /// `span` as its origin.
    pub fn satisfies(
        &mut self,
        ty1: TypeVariable,
        ty2: TypeVariable,
        span: Span,
    ) {
        self.ctx.satisfies(ty1, ty2, span);
    }
480
    /// Returns the resolved (not yet type checked) declaration of `function`.
    fn get_untyped_function(
        &self,
        function: FunctionId,
    ) -> &petr_resolve::Function {
        self.resolved.get_function(function)
    }
487
    /// Given a symbol ID, look it up in the interner of the resolved program
    /// and realize it as a string.
    fn realize_symbol(
        &self,
        id: petr_utils::SymbolId,
    ) -> Rc<str> {
        self.resolved.interner.get(id)
    }
496
    /// Returns the type checked function with the given id.
    ///
    /// # Panics
    ///
    /// Panics if no type checked function is stored under `id`.
    pub fn get_function(
        &self,
        id: &FunctionId,
    ) -> &Function {
        self.typed_functions.get(id).expect("invariant: should exist")
    }
503
    // TODO unideal clone
    /// Returns every type checked function together with its id.
    ///
    /// Each function is cloned and the pairs are gathered into a `Vec` before
    /// being handed out as an iterator.
    // NOTE(review): the intermediate `Vec` presumably exists so that the returned
    // iterator does not borrow from `self` -- confirm before removing it.
    pub fn functions(&self) -> impl Iterator<Item = (FunctionId, Function)> {
        self.typed_functions.iter().map(|(a, b)| (*a, b.clone())).collect::<Vec<_>>().into_iter()
    }
508
509    pub fn expr_ty(
510        &self,
511        expr: &TypedExpr,
512    ) -> TypeVariable {
513        use TypedExprKind::*;
514        match &expr.kind {
515            FunctionCall { ty, .. } => *ty,
516            Literal { ty, .. } => *ty,
517            List { ty, .. } => *ty,
518            Unit => self.unit(),
519            Variable { ty, .. } => *ty,
520            Intrinsic { ty, .. } => *ty,
521            ErrorRecovery(..) => self.ctx.error_recovery,
522            ExprWithBindings { expression, .. } => self.expr_ty(expression),
523            TypeConstructor { ty, .. } => *ty,
524        }
525    }
526
    /// Given a type variable `ty`, record a constraint that it unifies with
    /// the type of the given expression.
    ///
    /// The constraint is attributed to the span of `expr`.
    pub fn unify_expr_return(
        &mut self,
        ty: TypeVariable,
        expr: &TypedExpr,
    ) {
        let expr_ty = self.expr_ty(expr);
        self.unify(ty, expr_ty, expr.span());
    }
536
    /// Returns the type variable of the primitive string type.
    pub fn string(&self) -> TypeVariable {
        self.ctx.string_ty
    }
540
    /// Returns the type variable of the primitive unit type.
    pub fn unit(&self) -> TypeVariable {
        self.ctx.unit_ty
    }
544
    /// Returns the type variable of the primitive integer type.
    pub fn int(&self) -> TypeVariable {
        self.ctx.int_ty
    }
548
    /// Records `err` and returns the type variable of the error recovery type.
    ///
    /// To reference an error recovery type, you must provide an error.
    /// This holds the invariant that error recovery types are only generated when
    /// an error occurs.
    pub fn error_recovery(
        &mut self,
        err: TypeError,
    ) -> TypeVariable {
        self.push_error(err);
        self.ctx.error_recovery
    }
559
    /// Returns the type errors recorded so far.
    pub fn errors(&self) -> &[TypeError] {
        &self.errors
    }
563}
564
/// A type checked intrinsic call together with its type checked operands.
#[derive(Clone)]
pub enum Intrinsic {
    /// Takes a single string argument; the call has the unit type
    Puts(Box<TypedExpr>),
    /// Integer addition of the two operands
    Add(Box<TypedExpr>, Box<TypedExpr>),
    /// Integer multiplication of the two operands
    Multiply(Box<TypedExpr>, Box<TypedExpr>),
    /// Integer division of the two operands
    Divide(Box<TypedExpr>, Box<TypedExpr>),
    /// Integer subtraction of the two operands
    Subtract(Box<TypedExpr>, Box<TypedExpr>),
    /// Takes a single integer argument (the allocation size); the call has
    /// the integer type
    Malloc(Box<TypedExpr>),
}
574
575impl std::fmt::Debug for Intrinsic {
576    fn fmt(
577        &self,
578        f: &mut std::fmt::Formatter<'_>,
579    ) -> std::fmt::Result {
580        match self {
581            Intrinsic::Puts(expr) => write!(f, "@puts({:?})", expr),
582            Intrinsic::Add(lhs, rhs) => write!(f, "@add({:?}, {:?})", lhs, rhs),
583            Intrinsic::Multiply(lhs, rhs) => write!(f, "@multiply({:?}, {:?})", lhs, rhs),
584            Intrinsic::Divide(lhs, rhs) => write!(f, "@divide({:?}, {:?})", lhs, rhs),
585            Intrinsic::Subtract(lhs, rhs) => write!(f, "@subtract({:?}, {:?})", lhs, rhs),
586            Intrinsic::Malloc(size) => write!(f, "@malloc({:?})", size),
587        }
588    }
589}
590
/// An expression after type checking.
#[derive(Clone)]
pub struct TypedExpr {
    /// What kind of expression this is, including its type checked children
    pub kind: TypedExprKind,
    // source location of the expression
    span:     Span,
}
596
impl TypedExpr {
    /// Returns the source span of this expression.
    pub fn span(&self) -> Span {
        self.span
    }
}
602
/// The different kinds of type checked expression.
///
/// Where present, the `ty` field is the type variable of the expression.
#[derive(Clone, Debug)]
pub enum TypedExprKind {
    /// A call to `func`; each argument is paired with the name of the
    /// parameter it is passed for
    FunctionCall {
        func: FunctionId,
        args: Vec<(Identifier, TypedExpr)>,
        ty:   TypeVariable,
    },
    /// A literal value
    Literal {
        value: Literal,
        ty:    TypeVariable,
    },
    /// A list of expressions
    List {
        elements: Vec<TypedExpr>,
        ty:       TypeVariable,
    },
    /// The unit expression
    Unit,
    /// A reference to the variable `name`
    Variable {
        ty:   TypeVariable,
        name: Identifier,
    },
    /// A call to an intrinsic
    Intrinsic {
        ty:        TypeVariable,
        intrinsic: Intrinsic,
    },
    /// Stands in for an expression that could not be type checked
    ErrorRecovery(Span),
    /// An expression evaluated with a list of named bindings in scope; its
    /// type is the type of `expression`
    ExprWithBindings {
        bindings:   Vec<(Identifier, TypedExpr)>,
        expression: Box<TypedExpr>,
    },
    /// Construction of a value of a user-defined type from `args`
    TypeConstructor {
        ty:   TypeVariable,
        args: Box<[TypedExpr]>,
    },
}
637
638impl std::fmt::Debug for TypedExpr {
639    fn fmt(
640        &self,
641        f: &mut std::fmt::Formatter<'_>,
642    ) -> std::fmt::Result {
643        use TypedExprKind::*;
644        match &self.kind {
645            FunctionCall { func, args, .. } => {
646                write!(f, "function call to {} with args: ", func)?;
647                for (name, arg) in args {
648                    write!(f, "{}: {:?}, ", name.id, arg)?;
649                }
650                Ok(())
651            },
652            Literal { value, .. } => write!(f, "literal: {}", value),
653            List { elements, .. } => {
654                write!(f, "list: [")?;
655                for elem in elements {
656                    write!(f, "{:?}, ", elem)?;
657                }
658                write!(f, "]")
659            },
660            Unit => write!(f, "unit"),
661            Variable { name, .. } => write!(f, "variable: {}", name.id),
662            Intrinsic { intrinsic, .. } => write!(f, "intrinsic: {:?}", intrinsic),
663            ErrorRecovery(..) => write!(f, "error recovery"),
664            ExprWithBindings { bindings, expression } => {
665                write!(f, "bindings: ")?;
666                for (name, expr) in bindings {
667                    write!(f, "{}: {:?}, ", name.id, expr)?;
668                }
669                write!(f, "expression: {:?}", expression)
670            },
671            TypeConstructor { ty, .. } => write!(f, "type constructor: {:?}", ty),
672        }
673    }
674}
675
676impl TypeCheck for Expr {
677    type Output = TypedExpr;
678
679    fn type_check(
680        &self,
681        ctx: &mut TypeChecker,
682    ) -> Self::Output {
683        let kind = match &self.kind {
684            ExprKind::Literal(lit) => {
685                let ty = ctx.convert_literal_to_type(lit);
686                TypedExprKind::Literal { value: lit.clone(), ty }
687            },
688            ExprKind::List(exprs) => {
689                if exprs.is_empty() {
690                    let ty = ctx.unit();
691                    TypedExprKind::List { elements: vec![], ty }
692                } else {
693                    let type_checked_exprs = exprs.iter().map(|expr| expr.type_check(ctx)).collect::<Vec<_>>();
694                    // unify the type of the first expr against everything else in the list
695                    let first_ty = ctx.expr_ty(&type_checked_exprs[0]);
696                    for expr in type_checked_exprs.iter().skip(1) {
697                        let second_ty = ctx.expr_ty(expr);
698                        ctx.unify(first_ty, second_ty, expr.span());
699                    }
700                    TypedExprKind::List {
701                        elements: type_checked_exprs,
702                        ty:       ctx.insert_type(PetrType::List(first_ty)),
703                    }
704                }
705            },
706            ExprKind::FunctionCall(call) => {
707                // unify args with params
708                // return the func return type
709                let func_decl = ctx.get_untyped_function(call.function).clone();
710                if call.args.len() != func_decl.params.len() {
711                    ctx.push_error(call.span().with_item(TypeConstraintError::ArgumentCountMismatch {
712                        expected: func_decl.params.len(),
713                        got:      call.args.len(),
714                        function: ctx.realize_symbol(func_decl.name.id).to_string(),
715                    }));
716                    return TypedExpr {
717                        kind: TypedExprKind::ErrorRecovery(self.span),
718                        span: self.span,
719                    };
720                }
721                let mut args = Vec::with_capacity(call.args.len());
722                let mut arg_types = Vec::with_capacity(call.args.len());
723
724                for (arg, (param_name, param)) in call.args.iter().zip(func_decl.params.iter()) {
725                    let arg_expr = arg.type_check(ctx);
726                    let param_ty = ctx.to_type_var(param);
727                    let arg_ty = ctx.expr_ty(&arg_expr);
728                    ctx.satisfies(arg_ty, param_ty, arg_expr.span());
729                    arg_types.push(arg_ty);
730                    args.push((*param_name, arg_expr));
731                }
732                TypedExprKind::FunctionCall {
733                    func: call.function,
734                    args,
735                    ty: ctx.to_type_var(&func_decl.return_type),
736                }
737            },
738            ExprKind::Unit => TypedExprKind::Unit,
739            ExprKind::ErrorRecovery => TypedExprKind::ErrorRecovery(self.span),
740            ExprKind::Variable { name, ty } => {
741                // look up variable in scope
742                // find its expr return type
743                let var_ty = ctx.find_variable(*name).expect("variable not found in scope");
744                let ty = ctx.to_type_var(ty);
745
746                ctx.unify(var_ty, ty, name.span());
747
748                TypedExprKind::Variable { ty, name: *name }
749            },
750            ExprKind::Intrinsic(intrinsic) => return self.span.with_item(intrinsic.clone()).type_check(ctx),
751            ExprKind::TypeConstructor(parent_type_id, args) => {
752                // This ExprKind only shows up in the body of type constructor functions, and
753                // is basically a noop. The surrounding function decl will handle type checking for
754                // the type constructor.
755                let args = args.iter().map(|arg| arg.type_check(ctx)).collect::<Vec<_>>();
756                let ty = ctx.get_type(*parent_type_id);
757                TypedExprKind::TypeConstructor {
758                    ty:   *ty,
759                    args: args.into_boxed_slice(),
760                }
761            },
762            ExprKind::ExpressionWithBindings { bindings, expression } => {
763                // for each binding, type check the rhs
764                ctx.with_type_scope(|ctx| {
765                    let mut type_checked_bindings = Vec::with_capacity(bindings.len());
766                    for binding in bindings {
767                        let binding_ty = binding.expression.type_check(ctx);
768                        let binding_expr_return_ty = ctx.expr_ty(&binding_ty);
769                        ctx.insert_variable(binding.name, binding_expr_return_ty);
770                        type_checked_bindings.push((binding.name, binding_ty));
771                    }
772
773                    TypedExprKind::ExprWithBindings {
774                        bindings:   type_checked_bindings,
775                        expression: Box::new(expression.type_check(ctx)),
776                    }
777                })
778            },
779        };
780
781        TypedExpr { kind, span: self.span }
782    }
783}
784
785fn unify_basic_math_op(
786    lhs: &Expr,
787    rhs: &Expr,
788    ctx: &mut TypeChecker,
789) -> (TypedExpr, TypedExpr) {
790    let lhs = lhs.type_check(ctx);
791    let rhs = rhs.type_check(ctx);
792    let lhs_ty = ctx.expr_ty(&lhs);
793    let rhs_ty = ctx.expr_ty(&rhs);
794    let int_ty = ctx.int();
795    ctx.unify(lhs_ty, int_ty, lhs.span());
796    ctx.unify(rhs_ty, int_ty, rhs.span());
797    (lhs, rhs)
798}
799
800impl TypeCheck for SpannedItem<ResolvedIntrinsic> {
801    type Output = TypedExpr;
802
803    fn type_check(
804        &self,
805        ctx: &mut TypeChecker,
806    ) -> Self::Output {
807        use petr_resolve::IntrinsicName::*;
808        let string_ty = ctx.string();
809        let kind = match self.item().intrinsic {
810            Puts => {
811                if self.item().args.len() != 1 {
812                    todo!("puts arg len check");
813                }
814                // puts takes a single string and returns unit
815                let arg = self.item().args[0].type_check(ctx);
816                ctx.unify_expr_return(string_ty, &arg);
817                TypedExprKind::Intrinsic {
818                    intrinsic: Intrinsic::Puts(Box::new(arg)),
819                    ty:        ctx.unit(),
820                }
821            },
822            Add => {
823                if self.item().args.len() != 2 {
824                    todo!("add arg len check");
825                }
826                let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx);
827                TypedExprKind::Intrinsic {
828                    intrinsic: Intrinsic::Add(Box::new(lhs), Box::new(rhs)),
829                    ty:        ctx.int(),
830                }
831            },
832            Subtract => {
833                if self.item().args.len() != 2 {
834                    todo!("sub arg len check");
835                }
836                let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx);
837                TypedExprKind::Intrinsic {
838                    intrinsic: Intrinsic::Subtract(Box::new(lhs), Box::new(rhs)),
839                    ty:        ctx.int(),
840                }
841            },
842            Multiply => {
843                if self.item().args.len() != 2 {
844                    todo!("mult arg len check");
845                }
846
847                let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx);
848                TypedExprKind::Intrinsic {
849                    intrinsic: Intrinsic::Multiply(Box::new(lhs), Box::new(rhs)),
850                    ty:        ctx.int(),
851                }
852            },
853
854            Divide => {
855                if self.item().args.len() != 2 {
856                    todo!("Divide arg len check");
857                }
858
859                let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx);
860                TypedExprKind::Intrinsic {
861                    intrinsic: Intrinsic::Divide(Box::new(lhs), Box::new(rhs)),
862                    ty:        ctx.int(),
863                }
864            },
865            Malloc => {
866                // malloc takes one integer (the number of bytes to allocate)
867                // and returns a pointer to the allocated memory
868                // will return `0` if the allocation fails
869                // in the future, this might change to _words_ of allocation,
870                // depending on the compilation target
871                if self.item().args.len() != 1 {
872                    todo!("malloc arg len check");
873                }
874                let arg = self.item().args[0].type_check(ctx);
875                let arg_ty = ctx.expr_ty(&arg);
876                let int_ty = ctx.int();
877                ctx.unify(arg_ty, int_ty, arg.span());
878                TypedExprKind::Intrinsic {
879                    intrinsic: Intrinsic::Malloc(Box::new(arg)),
880                    ty:        int_ty,
881                }
882            },
883        };
884
885        TypedExpr { kind, span: self.span() }
886    }
887}
888
/// Implemented by resolved items that can be type checked.
trait TypeCheck {
    /// The type checked form produced for this item.
    type Output;
    /// Type checks `self`, recording constraints and errors in `ctx`.
    fn type_check(
        &self,
        ctx: &mut TypeChecker,
    ) -> Self::Output;
}
896
/// A function after type checking.
#[derive(Clone, Debug)]
pub struct Function {
    /// The name of the function
    pub name:      Identifier,
    /// Each parameter's name paired with the type variable of its type
    pub params:    Vec<(Identifier, TypeVariable)>,
    /// The type checked body
    pub body:      TypedExpr,
    /// The type variable of the declared return type
    pub return_ty: TypeVariable,
}
904
905impl TypeCheck for petr_resolve::Function {
906    type Output = Function;
907
908    fn type_check(
909        &self,
910        ctx: &mut TypeChecker,
911    ) -> Self::Output {
912        ctx.with_type_scope(|ctx| {
913            let params = self.params.iter().map(|(name, ty)| (*name, ctx.to_type_var(ty))).collect::<Vec<_>>();
914
915            for (name, ty) in &params {
916                ctx.insert_variable(*name, *ty);
917            }
918
919            // unify types within the body with the parameter
920            let body = self.body.type_check(ctx);
921
922            let declared_return_type = ctx.to_type_var(&self.return_type);
923
924            Function {
925                name: self.name,
926                params,
927                return_ty: declared_return_type,
928                body,
929            }
930        })
931        // in a scope that contains the above names to type variables, check the body
932        // TODO: introduce scopes here, like in the binder, except with type variables
933    }
934}
935
936impl TypeCheck for petr_resolve::FunctionCall {
937    type Output = ();
938
939    fn type_check(
940        &self,
941        ctx: &mut TypeChecker,
942    ) -> Self::Output {
943        let func_type = *ctx.get_type(self.function);
944        let args = self.args.iter().map(|arg| arg.type_check(ctx)).collect::<Vec<_>>();
945
946        let mut arg_types = Vec::with_capacity(args.len());
947
948        for arg in args.iter() {
949            arg_types.push(ctx.expr_ty(arg));
950        }
951
952        let arg_type = ctx.arrow_type(arg_types);
953
954        ctx.unify(func_type, arg_type, self.span());
955    }
956}
957
958#[cfg(test)]
959mod tests {
960    use expect_test::{expect, Expect};
961    use petr_resolve::resolve_symbols;
962    use petr_utils::{render_error, SourceId};
963
964    use super::*;
965    fn check(
966        input: impl Into<String>,
967        expect: Expect,
968    ) {
969        let input = input.into();
970        let parser = petr_parse::Parser::new(vec![("test", input)]);
971        let (ast, errs, interner, source_map) = parser.into_result();
972        if !errs.is_empty() {
973            errs.into_iter().for_each(|err| eprintln!("{:?}", render_error(&source_map, err)));
974            panic!("fmt failed: code didn't parse");
975        }
976        let (errs, resolved) = resolve_symbols(ast, interner, Default::default());
977        assert!(errs.is_empty(), "can't typecheck: unresolved symbols");
978        let type_checker = TypeChecker::new(resolved);
979        let res = pretty_print_type_checker(type_checker, &source_map);
980
981        expect.assert_eq(&res);
982    }
983
984    fn pretty_print_type_checker(
985        type_checker: TypeChecker,
986        source_map: &IndexMap<SourceId, (&'static str, &'static str)>,
987    ) -> String {
988        let mut s = String::new();
989        for (id, ty) in &type_checker.type_map {
990            let text = match id {
991                TypeOrFunctionId::TypeId(id) => {
992                    let ty = type_checker.resolved.get_type(*id);
993
994                    let name = type_checker.resolved.interner.get(ty.name.id);
995                    format!("type {}", name)
996                },
997                TypeOrFunctionId::FunctionId(id) => {
998                    let func = type_checker.resolved.get_function(*id);
999
1000                    let name = type_checker.resolved.interner.get(func.name.id);
1001
1002                    format!("fn {}", name)
1003                },
1004            };
1005            s.push_str(&text);
1006            s.push_str(": ");
1007            s.push_str(&pretty_print_ty(ty, &type_checker));
1008
1009            s.push('\n');
1010            match id {
1011                TypeOrFunctionId::TypeId(_) => (),
1012                TypeOrFunctionId::FunctionId(func) => {
1013                    let func = type_checker.typed_functions.get(func).unwrap();
1014                    let body = &func.body;
1015                    s.push_str(&pretty_print_typed_expr(body, &type_checker));
1016                    s.push('\n');
1017                },
1018            }
1019
1020            s.push('\n');
1021        }
1022
1023        if !type_checker.errors.is_empty() {
1024            s.push_str("\nErrors:\n");
1025            for error in type_checker.errors {
1026                let rendered = render_error(source_map, error);
1027                s.push_str(&format!("{:?}\n", rendered));
1028            }
1029        }
1030        s
1031    }
1032
1033    fn pretty_print_ty(
1034        ty: &TypeVariable,
1035        type_checker: &TypeChecker,
1036    ) -> String {
1037        let mut ty = type_checker.look_up_variable(*ty);
1038        while let PetrType::Ref(t) = ty {
1039            ty = type_checker.look_up_variable(*t);
1040        }
1041        match ty {
1042            PetrType::Unit => "unit".to_string(),
1043            PetrType::Integer => "int".to_string(),
1044            PetrType::Boolean => "bool".to_string(),
1045            PetrType::String => "string".to_string(),
1046            PetrType::Ref(ty) => pretty_print_ty(ty, type_checker),
1047            PetrType::UserDefined { name, variants: _ } => {
1048                let name = type_checker.resolved.interner.get(name.id);
1049                name.to_string()
1050            },
1051            PetrType::Arrow(tys) => {
1052                let mut s = String::new();
1053                s.push('(');
1054                for (ix, ty) in tys.iter().enumerate() {
1055                    let is_last = ix == tys.len() - 1;
1056
1057                    s.push_str(&pretty_print_ty(ty, type_checker));
1058                    if !is_last {
1059                        s.push_str(" → ");
1060                    }
1061                }
1062                s.push(')');
1063                s
1064            },
1065            PetrType::ErrorRecovery => "error recovery".to_string(),
1066            PetrType::List(ty) => format!("[{}]", pretty_print_ty(ty, type_checker)),
1067            PetrType::Infer(id) => format!("t{id}"),
1068        }
1069    }
1070
1071    fn pretty_print_typed_expr(
1072        typed_expr: &TypedExpr,
1073        type_checker: &TypeChecker,
1074    ) -> String {
1075        let interner = &type_checker.resolved.interner;
1076        match &typed_expr.kind {
1077            TypedExprKind::ExprWithBindings { bindings, expression } => {
1078                let mut s = String::new();
1079                for (name, expr) in bindings {
1080                    let ident = interner.get(name.id);
1081                    let ty = type_checker.expr_ty(expr);
1082                    let ty = pretty_print_ty(&ty, type_checker);
1083                    s.push_str(&format!("{ident}: {:?} ({}),\n", expr, ty));
1084                }
1085                let expr_ty = type_checker.expr_ty(expression);
1086                let expr_ty = pretty_print_ty(&expr_ty, type_checker);
1087                s.push_str(&format!("{:?} ({})", pretty_print_typed_expr(expression, type_checker), expr_ty));
1088                s
1089            },
1090            TypedExprKind::Variable { name, ty } => {
1091                let name = interner.get(name.id);
1092                let ty = pretty_print_ty(ty, type_checker);
1093                format!("variable {name}: {ty}")
1094            },
1095
1096            TypedExprKind::FunctionCall { func, args, ty } => {
1097                let mut s = String::new();
1098                s.push_str(&format!("function call to {} with args: ", func));
1099                for (name, arg) in args {
1100                    let name = interner.get(name.id);
1101                    let arg_ty = type_checker.expr_ty(arg);
1102                    let arg_ty = pretty_print_ty(&arg_ty, type_checker);
1103                    s.push_str(&format!("{name}: {}, ", arg_ty));
1104                }
1105                let ty = pretty_print_ty(ty, type_checker);
1106                s.push_str(&format!("returns {ty}"));
1107                s
1108            },
1109            TypedExprKind::TypeConstructor { ty, .. } => format!("type constructor: {}", pretty_print_ty(ty, type_checker)),
1110            _otherwise => format!("{:?}", typed_expr),
1111        }
1112    }
1113
1114    #[test]
1115    fn identity_resolution_concrete_type() {
1116        check(
1117            r#"
1118            fn foo(x in 'int) returns 'int x
1119            "#,
1120            expect![[r#"
1121                fn foo: (int → int)
1122                variable x: int
1123
1124            "#]],
1125        );
1126    }
1127
1128    #[test]
1129    fn identity_resolution_generic() {
1130        check(
1131            r#"
1132            fn foo(x in 'A) returns 'A x
1133            "#,
1134            expect![[r#"
1135                fn foo: (t4 → t4)
1136                variable x: t4
1137
1138            "#]],
1139        );
1140    }
1141
1142    #[test]
1143    fn identity_resolution_custom_type() {
1144        check(
1145            r#"
1146            type MyType = A | B
1147            fn foo(x in 'MyType) returns 'MyType x
1148            "#,
1149            expect![[r#"
1150                type MyType: MyType
1151
1152                fn A: MyType
1153                type constructor: MyType
1154
1155                fn B: MyType
1156                type constructor: MyType
1157
1158                fn foo: (MyType → MyType)
1159                variable x: MyType
1160
1161            "#]],
1162        );
1163    }
1164
1165    #[test]
1166    fn identity_resolution_two_custom_types() {
1167        check(
1168            r#"
1169            type MyType = A | B
1170            type MyComposedType = firstVariant someField 'MyType | secondVariant someField 'int someField2 'MyType someField3 'GenericType
1171            fn foo(x in 'MyType) returns 'MyComposedType ~firstVariant(x)
1172            "#,
1173            expect![[r#"
1174                type MyType: MyType
1175
1176                type MyComposedType: MyComposedType
1177
1178                fn A: MyType
1179                type constructor: MyType
1180
1181                fn B: MyType
1182                type constructor: MyType
1183
1184                fn firstVariant: (MyType → MyComposedType)
1185                type constructor: MyComposedType
1186
1187                fn secondVariant: (int → MyType → t18 → MyComposedType)
1188                type constructor: MyComposedType
1189
1190                fn foo: (MyType → MyComposedType)
1191                function call to functionid2 with args: someField: MyType, returns MyComposedType
1192
1193            "#]],
1194        );
1195    }
1196
1197    #[test]
1198    fn literal_unification_fail() {
1199        check(
1200            r#"
1201            fn foo() returns 'int 5
1202            fn bar() returns 'bool 5
1203            "#,
1204            expect![[r#"
1205                fn foo: int
1206                literal: 5
1207
1208                fn bar: bool
1209                literal: 5
1210
1211            "#]],
1212        );
1213    }
1214
1215    #[test]
1216    fn literal_unification_success() {
1217        check(
1218            r#"
1219            fn foo() returns 'int 5
1220            fn bar() returns 'bool true
1221            "#,
1222            expect![[r#"
1223                fn foo: int
1224                literal: 5
1225
1226                fn bar: bool
1227                literal: true
1228
1229            "#]],
1230        );
1231    }
1232
1233    #[test]
1234    fn pass_zero_arity_func_to_intrinsic() {
1235        check(
1236            r#"
1237        fn string_literal() returns 'string
1238          "This is a string literal."
1239
1240        fn my_func() returns 'unit
1241          @puts(~string_literal)"#,
1242            expect![[r#"
1243                fn string_literal: string
1244                literal: "This is a string literal."
1245
1246                fn my_func: unit
1247                intrinsic: @puts(function call to functionid0 with args: )
1248
1249            "#]],
1250        );
1251    }
1252
1253    #[test]
1254    fn pass_literal_string_to_intrinsic() {
1255        check(
1256            r#"
1257        fn my_func() returns 'unit
1258          @puts("test")"#,
1259            expect![[r#"
1260                fn my_func: unit
1261                intrinsic: @puts(literal: "test")
1262
1263            "#]],
1264        );
1265    }
1266
1267    #[test]
1268    fn pass_wrong_type_literal_to_intrinsic() {
1269        check(
1270            r#"
1271        fn my_func() returns 'unit
1272          @puts(true)"#,
1273            expect![[r#"
1274                fn my_func: unit
1275                intrinsic: @puts(literal: true)
1276
1277
1278                Errors:
1279                  × Failed to unify types: String, Boolean
1280                   ╭─[test:2:1]
1281                 2 │         fn my_func() returns 'unit
1282                 3 │           @puts(true)
1283                   ·                 ──┬─
1284                   ·                   ╰── Failed to unify types: String, Boolean
1285                   ╰────
1286
1287            "#]],
1288        );
1289    }
1290
1291    #[test]
1292    fn intrinsic_and_return_ty_dont_match() {
1293        check(
1294            r#"
1295        fn my_func() returns 'bool
1296          @puts("test")"#,
1297            expect![[r#"
1298                fn my_func: bool
1299                intrinsic: @puts(literal: "test")
1300
1301            "#]],
1302        );
1303    }
1304
1305    #[test]
1306    fn pass_wrong_type_fn_call_to_intrinsic() {
1307        check(
1308            r#"
1309        fn bool_literal() returns 'bool
1310            true
1311
1312        fn my_func() returns 'unit
1313          @puts(~bool_literal)"#,
1314            expect![[r#"
1315                fn bool_literal: bool
1316                literal: true
1317
1318                fn my_func: unit
1319                intrinsic: @puts(function call to functionid0 with args: )
1320
1321
1322                Errors:
1323                  × Failed to unify types: String, Boolean
1324                   ╭─[test:5:1]
1325                 5 │         fn my_func() returns 'unit
1326                 6 │           @puts(~bool_literal)
1327                   ·                 ───────┬──────
1328                   ·                        ╰── Failed to unify types: String, Boolean
1329                   ╰────
1330
1331            "#]],
1332        );
1333    }
1334
1335    #[test]
1336    fn multiple_calls_to_fn_dont_unify_params_themselves() {
1337        check(
1338            r#"
1339        fn bool_literal(a in 'A, b in 'B) returns 'bool
1340            true
1341
1342        fn my_func() returns 'bool
1343            ~bool_literal(1, 2)
1344
1345        {- should not unify the parameter types of bool_literal -}
1346        fn my_second_func() returns 'bool
1347            ~bool_literal(true, false)
1348        "#,
1349            expect![[r#"
1350                fn bool_literal: (t4 → t5 → bool)
1351                literal: true
1352
1353                fn my_func: bool
1354                function call to functionid0 with args: a: int, b: int, returns bool
1355
1356                fn my_second_func: bool
1357                function call to functionid0 with args: a: bool, b: bool, returns bool
1358
1359            "#]],
1360        );
1361    }
1362    #[test]
1363    fn list_different_types_type_err() {
1364        check(
1365            r#"
1366                fn my_list() returns 'list [ 1, true ]
1367            "#,
1368            expect![[r#"
1369                fn my_list: t7
1370                list: [literal: 1, literal: true, ]
1371
1372
1373                Errors:
1374                  × Failed to unify types: Integer, Boolean
1375                   ╭─[test:1:1]
1376                 1 │ 
1377                 2 │                 fn my_list() returns 'list [ 1, true ]
1378                   ·                                                ──┬──
1379                   ·                                                  ╰── Failed to unify types: Integer, Boolean
1380                 3 │             
1381                   ╰────
1382
1383            "#]],
1384        );
1385    }
1386
1387    #[test]
1388    fn incorrect_number_of_args() {
1389        check(
1390            r#"
1391                fn add(a in 'int, b in 'int) returns 'int a
1392
1393                fn add_five(a in 'int) returns 'int ~add(5)
1394            "#,
1395            expect![[r#"
1396                fn add: (int → int → int)
1397                variable a: int
1398
1399                fn add_five: (int → int)
1400                error recovery
1401
1402
1403                Errors:
1404                  × Function add takes 2 arguments, but got 1 arguments.
1405                   ╭─[test:3:1]
1406                 3 │ 
1407                 4 │                 fn add_five(a in 'int) returns 'int ~add(5)
1408                   ·                                                    ────┬───
1409                   ·                                                        ╰── Function add takes 2 arguments, but got 1 arguments.
1410                 5 │             
1411                   ╰────
1412
1413            "#]],
1414        );
1415    }
1416
1417    #[test]
1418    fn infer_let_bindings() {
1419        check(
1420            r#"
1421            fn hi(x in 'int, y in 'int) returns 'int
1422    let a = x;
1423        b = y;
1424        c = 20;
1425        d = 30;
1426        e = 42;
1427    a
1428fn main() returns 'int ~hi(1, 2)"#,
1429            expect![[r#"
1430                fn hi: (int → int → int)
1431                a: variable: symbolid2 (int),
1432                b: variable: symbolid4 (int),
1433                c: literal: 20 (int),
1434                d: literal: 30 (int),
1435                e: literal: 42 (int),
1436                "variable a: int" (int)
1437
1438                fn main: int
1439                function call to functionid0 with args: x: int, y: int, returns int
1440
1441            "#]],
1442        )
1443    }
1444}