use super::*;
use crate::{BlockToFunctionRewriter, VariableType};
use leo_ast::{
Type::{Future, Tuple},
*,
};
use leo_errors::{TypeCheckerError, TypeCheckerWarning};
use leo_span::{Span, Symbol, sym};
use itertools::Itertools as _;
/// Outcome of type-checking the left-hand side of an assignment.
///
/// Bundles the type of the written place together with a classification of
/// what kind of place it is.
#[derive(Debug, Clone)]
pub struct AssignTargetInfo {
    /// Type of the assignment target; may be `Type::Err` if checking the target failed.
    pub ty: Type,
    /// Classification of the place being written to.
    pub kind: AssignTargetKind,
}
/// Classification of the place an assignment writes to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AssignTargetKind {
    /// A plain (non-storage) variable, or any array/member/tuple access target.
    Local,
    /// A storage variable referenced without a user-program qualifier.
    Storage,
    /// A storage variable whose path names a user program.
    ExternalStorage,
    /// The target was invalid; an error has already been emitted.
    Err,
}
impl TypeCheckingVisitor<'_> {
    /// Type-checks an expression that appears on the left-hand side of an assignment.
    ///
    /// Only unqualified paths, array accesses, member accesses, and tuple accesses are
    /// accepted as targets. Anything else emits `invalid_assignment_target` and yields
    /// `Type::Err` with `AssignTargetKind::Err`. Futures and mappings are rejected as
    /// targets. The resulting type is recorded in the type table under `input.id()`.
    pub fn visit_expression_assign(&mut self, input: &Expression) -> AssignTargetInfo {
        let assign_target_info = match input {
            Expression::ArrayAccess(array_access) => AssignTargetInfo {
                ty: self.visit_array_access_general(array_access, true, &None),
                kind: AssignTargetKind::Local,
            },
            // Qualified paths are not matched here and fall through to the error arm.
            Expression::Path(path) if path.qualifier().is_empty() => self.visit_path_assign(path),
            Expression::MemberAccess(member_access) => AssignTargetInfo {
                ty: self.visit_member_access_general(member_access, true, &None),
                kind: AssignTargetKind::Local,
            },
            Expression::TupleAccess(tuple_access) => AssignTargetInfo {
                ty: self.visit_tuple_access_general(tuple_access, true, &None),
                kind: AssignTargetKind::Local,
            },
            _ => {
                self.emit_err(TypeCheckerError::invalid_assignment_target(input, input.span()));
                AssignTargetInfo { ty: Type::Err, kind: AssignTargetKind::Err }
            }
        };
        // Assigning to a path whose type is an external record (or a tuple containing one)
        // is an error unless `symbol_in_conditional_scope` returns true for the path's name.
        if let Expression::Path(path) = input
            && !self.symbol_in_conditional_scope(path.identifier().name)
            && (self.is_external_record(&assign_target_info.ty)
                || matches!(&assign_target_info.ty, Type::Tuple(tuple) if tuple.elements().iter().any(|ty| self.is_external_record(ty))))
        {
            if self.is_external_record(&assign_target_info.ty) {
                self.emit_err(TypeCheckerError::assignment_to_external_record_cond(
                    &assign_target_info.ty,
                    input.span(),
                ));
            } else {
                self.emit_err(TypeCheckerError::assignment_to_external_record_tuple_cond(
                    &assign_target_info.ty,
                    input.span(),
                ));
            }
        }
        // Future and mapping variables can never be reassigned.
        match &assign_target_info.ty {
            Type::Future(..) => self.emit_err(TypeCheckerError::cannot_reassign_future_variable(input, input.span())),
            Type::Mapping(_) => self.emit_err(TypeCheckerError::cannot_reassign_mapping(input, input.span())),
            _ => {}
        }
        self.state.type_table.insert(input.id(), assign_target_info.ty.clone());
        assign_target_info
    }

    /// Type-checks an array access `array[index]` and returns the element type.
    ///
    /// When `assign` is true, the array operand is checked as an assignment target via
    /// `visit_expression_assign`; otherwise it is checked as an ordinary expression. An
    /// index whose type is the unresolved `Numeric` placeholder defaults to `u32`.
    /// Returns `Type::Err` when the operand is not an array.
    pub fn visit_array_access_general(&mut self, input: &ArrayAccess, assign: bool, expected: &Option<Type>) -> Type {
        let this_type = if assign {
            self.visit_expression_assign(&input.array).ty
        } else {
            self.visit_expression(&input.array, &None)
        };
        self.assert_array_type(&this_type, input.array.span());
        let mut index_type = self.visit_expression(&input.index, &None);
        // Untyped numeric indices default to `u32`; literals are re-checked against it.
        if index_type == Type::Numeric {
            index_type = Type::Integer(IntegerType::U32);
            if let Expression::Literal(literal) = &input.index {
                self.check_numeric_literal(literal, &index_type);
            }
        }
        self.assert_int_type(&index_type, input.index.span());
        // Overwrite the entry recorded by `visit_expression` with the (possibly defaulted) type.
        self.state.type_table.insert(input.index.id(), index_type.clone());
        let Type::Array(array_type) = this_type else {
            return Type::Err;
        };
        let element_type = array_type.element_type();
        self.maybe_assert_type(element_type, expected, input.span());
        element_type.clone()
    }

    /// Type-checks a member access `inner.name` and returns the member's type.
    ///
    /// The operand must be a composite (struct or record) whose definition contains
    /// `name`. When `assign` is true, the operand is checked as an assignment target and
    /// writing into a member of an external record is reported as an error.
    pub fn visit_member_access_general(&mut self, input: &MemberAccess, assign: bool, expected: &Option<Type>) -> Type {
        let ty = if assign {
            self.visit_expression_assign(&input.inner).ty
        } else {
            self.visit_expression(&input.inner, &None)
        };
        if assign && self.is_external_record(&ty) {
            self.emit_err(TypeCheckerError::assignment_to_external_record_member(&ty, input.span));
        }
        match ty {
            // The operand already failed; don't report a second error.
            Type::Err => Type::Err,
            Type::Composite(ref composite) => {
                let Some(composite) = self.lookup_composite(composite.path.expect_global_location()) else {
                    self.emit_err(TypeCheckerError::undefined_type(ty, input.inner.span()));
                    return Type::Err;
                };
                match composite.members.iter().find(|member| member.name() == input.name.name) {
                    Some(Member { type_, .. }) => {
                        self.maybe_assert_type(type_, expected, input.span());
                        type_.clone()
                    }
                    None => {
                        self.emit_err(TypeCheckerError::invalid_composite_variable(
                            input.name,
                            &composite,
                            input.name.span(),
                        ));
                        Type::Err
                    }
                }
            }
            type_ => {
                self.emit_err(TypeCheckerError::type_should_be2(type_, "a struct or record", input.inner.span()));
                Type::Err
            }
        }
    }

    /// Type-checks a tuple access `tuple.index` on a tuple or a future.
    ///
    /// For tuples, the indexed element type is returned (out-of-range indices are errors).
    /// For futures, the element type comes from the future type stored in the type table
    /// for the operand. Futures without a location are rejected with
    /// `invalid_async_block_future_access`, and elements whose type is `Type::Err` are
    /// reported as well.
    pub fn visit_tuple_access_general(&mut self, input: &TupleAccess, assign: bool, expected: &Option<Type>) -> Type {
        let this_type = if assign {
            self.visit_expression_assign(&input.tuple).ty
        } else {
            self.visit_expression(&input.tuple, &None)
        };
        match this_type {
            Type::Err => Type::Err,
            Type::Tuple(tuple) => {
                let index = input.index.value();
                let Some(actual) = tuple.elements().get(index) else {
                    self.emit_err(TypeCheckerError::tuple_out_of_range(index, tuple.length(), input.span()));
                    return Type::Err;
                };
                self.maybe_assert_type(actual, expected, input.span());
                actual.clone()
            }
            Type::Future(_) => {
                // The operand's entry in the type table was written by the visit above.
                let Some(Type::Future(inferred_f)) = self.state.type_table.get(&input.tuple.id()) else {
                    return Type::Err;
                };
                if inferred_f.location.is_none() {
                    self.emit_err(TypeCheckerError::invalid_async_block_future_access(input.span()));
                    return Type::Err;
                }
                let Some(actual) = inferred_f.inputs().get(input.index.value()) else {
                    self.emit_err(TypeCheckerError::invalid_future_access(
                        input.index.value(),
                        inferred_f.inputs().len(),
                        input.span(),
                    ));
                    return Type::Err;
                };
                if let Type::Err = actual {
                    self.emit_err(TypeCheckerError::future_error_member(input.index.value(), input.span()));
                    return Type::Err;
                }
                self.maybe_assert_type(actual, expected, input.span());
                actual.clone()
            }
            type_ => {
                self.emit_err(TypeCheckerError::type_should_be2(type_, "a tuple or future", input.span()));
                Type::Err
            }
        }
    }

    /// Resolves an unqualified path used as an assignment target.
    ///
    /// Rejects unknown symbols and vector-typed variables (returning `AssignTargetKind::Err`).
    /// Reports writes to constants, const generic parameters, and constant-mode inputs,
    /// but continues checking afterwards. Storage variables return immediately as
    /// `Storage` or `ExternalStorage`, depending on whether the path names a user
    /// program, and skip the async restrictions below. Other variables must be in a
    /// conditional scope when assigned inside an async function or async block, and
    /// inside an async block must be defined within that block.
    pub fn visit_path_assign(&mut self, input: &Path) -> AssignTargetInfo {
        let current_program = self.scope_state.program_name.unwrap();
        let Some(var) = self.state.symbol_table.lookup_path(current_program, input) else {
            self.emit_err(TypeCheckerError::unknown_sym("variable", input, input.span));
            return AssignTargetInfo { ty: Type::Err, kind: AssignTargetKind::Err };
        };
        let ty = var.type_.expect("must be known by now").clone();
        if ty.is_vector() {
            self.emit_err(TypeCheckerError::invalid_assignment_target(input, input.span()));
            return AssignTargetInfo { ty: Type::Err, kind: AssignTargetKind::Err };
        }
        match &var.declaration {
            VariableType::Const => self.emit_err(TypeCheckerError::cannot_assign_to_const_var(input, var.span)),
            VariableType::ConstParameter => {
                self.emit_err(TypeCheckerError::cannot_assign_to_generic_const_function_parameter(input, input.span))
            }
            VariableType::Input(Mode::Constant) => {
                self.emit_err(TypeCheckerError::cannot_assign_to_const_input(input, var.span))
            }
            VariableType::Storage => {
                let kind = if input.user_program().is_some() {
                    AssignTargetKind::ExternalStorage
                } else {
                    AssignTargetKind::Storage
                };
                return AssignTargetInfo { ty, kind };
            }
            VariableType::Mut | VariableType::Input(_) => {}
        }
        if self.scope_state.variant.unwrap().is_async_function()
            && !self.symbol_in_conditional_scope(input.identifier().name)
        {
            self.emit_err(TypeCheckerError::async_cannot_assign_outside_conditional(input, "function", var.span));
        }
        if self.async_block_id.is_some() && !self.symbol_in_conditional_scope(input.identifier().name) {
            self.emit_err(TypeCheckerError::async_cannot_assign_outside_conditional(input, "block", var.span));
        }
        // Inside an async block, only variables declared within that block may be assigned.
        if let Some(async_block_id) = self.async_block_id
            && !self.state.symbol_table.is_defined_in_scope_or_ancestor_until(async_block_id, input.identifier().name)
        {
            self.emit_err(TypeCheckerError::cannot_assign_to_vars_outside_async_block(
                input.identifier().name,
                input.span,
            ));
        }
        AssignTargetInfo { ty, kind: AssignTargetKind::Local }
    }

    /// Visits `expr` with the given expected type. If the result is still the unresolved
    /// `Numeric` placeholder, reports an inference failure and returns `Type::Err`.
    pub(crate) fn visit_expression_reject_numeric(&mut self, expr: &Expression, expected: &Option<Type>) -> Type {
        let mut inferred = self.visit_expression(expr, expected);
        match inferred {
            Type::Numeric => {
                self.emit_inference_failure_error(&mut inferred, expr);
                Type::Err
            }
            _ => inferred,
        }
    }

    /// Visits `expr` with no expected type. An unresolved `Numeric` result defaults to
    /// `u32`, or to `Type::Err` if `expr` is a literal that `check_numeric_literal`
    /// rejects. In either case the type table entry for `expr` is updated.
    pub(crate) fn visit_expression_infer_default_u32(&mut self, expr: &Expression) -> Type {
        let mut inferred = self.visit_expression(expr, &None);
        if inferred == Type::Numeric {
            inferred = Type::Integer(IntegerType::U32);
            if let Expression::Literal(literal) = expr
                && !self.check_numeric_literal(literal, &inferred)
            {
                inferred = Type::Err;
            }
            self.state.type_table.insert(expr.id(), inferred.clone());
        }
        inferred
    }
}
impl AstVisitor for TypeCheckingVisitor<'_> {
type AdditionalInput = Option<Type>;
type Output = Type;
fn visit_array_type(&mut self, input: &ArrayType) {
self.visit_type(&input.element_type);
self.visit_expression_infer_default_u32(&input.length);
}
/// Checks the const arguments attached to a composite type against the composite's
/// declared const parameters.
///
/// If the composite cannot be resolved, any const arguments at all are reported as
/// unexpected. Otherwise the argument count must match, and each argument is visited
/// with its parameter's type as the expected type.
fn visit_composite_type(&mut self, input: &CompositeType) {
    let resolved = self.lookup_composite(input.path.expect_global_location()).clone();
    let Some(definition) = resolved else {
        if !input.const_arguments.is_empty() {
            self.emit_err(TypeCheckerError::unexpected_const_args(input, input.path.span));
        }
        return;
    };

    let param_count = definition.const_parameters.len();
    let arg_count = input.const_arguments.len();
    if param_count != arg_count {
        self.emit_err(TypeCheckerError::incorrect_num_const_args(
            "Composite type",
            param_count,
            arg_count,
            input.path.span,
        ));
    }

    // Extra or missing arguments were reported above; check the overlapping prefix.
    for (param, arg) in definition.const_parameters.iter().zip(input.const_arguments.iter()) {
        let param_ty = Some(param.type_().clone());
        self.visit_expression(arg, &param_ty);
    }
}
/// Dispatches on the expression kind and records the resulting type in the type
/// table under the expression's id.
///
/// Array, member, and tuple accesses go to the shared `*_general` helpers in
/// non-assignment mode (`assign = false`).
fn visit_expression(&mut self, input: &Expression, additional: &Self::AdditionalInput) -> Self::Output {
    let ty = match input {
        // Accesses share their implementation with assignment-target checking.
        Expression::ArrayAccess(a) => self.visit_array_access_general(a, false, additional),
        Expression::MemberAccess(m) => self.visit_member_access_general(m, false, additional),
        Expression::TupleAccess(t) => self.visit_tuple_access_general(t, false, additional),
        // Aggregate constructors.
        Expression::Array(e) => self.visit_array(e, additional),
        Expression::Composite(e) => self.visit_composite_init(e, additional),
        Expression::Repeat(e) => self.visit_repeat(e, additional),
        Expression::Tuple(e) => self.visit_tuple(e, additional),
        Expression::Unit(e) => self.visit_unit(e, additional),
        // Operators.
        Expression::Binary(e) => self.visit_binary(e, additional),
        Expression::Cast(e) => self.visit_cast(e, additional),
        Expression::Ternary(e) => self.visit_ternary(e, additional),
        Expression::Unary(e) => self.visit_unary(e, additional),
        // Calls and async constructs.
        Expression::Async(e) => self.visit_async(e, additional),
        Expression::Call(e) => self.visit_call(e, additional),
        Expression::Intrinsic(e) => self.visit_intrinsic(e, additional),
        // Leaves.
        Expression::Err(e) => self.visit_err(e, additional),
        Expression::Literal(e) => self.visit_literal(e, additional),
        Expression::Path(e) => self.visit_path(e, additional),
    };
    self.state.type_table.insert(input.id(), ty.clone());
    ty
}
/// Must not be reached: `visit_expression` sends array accesses straight to
/// `visit_array_access_general`. This override panics if called.
fn visit_array_access(&mut self, _access: &ArrayAccess, _expected: &Self::AdditionalInput) -> Self::Output {
    panic!("Should not be called.")
}
/// Must not be reached: `visit_expression` sends member accesses straight to
/// `visit_member_access_general`. This override panics if called.
fn visit_member_access(&mut self, _access: &MemberAccess, _expected: &Self::AdditionalInput) -> Self::Output {
    panic!("Should not be called.")
}
/// Must not be reached: `visit_expression` sends tuple accesses straight to
/// `visit_tuple_access_general`. This override panics if called.
fn visit_tuple_access(&mut self, _access: &TupleAccess, _expected: &Self::AdditionalInput) -> Self::Output {
    panic!("Should not be called.")
}
fn visit_array(&mut self, input: &ArrayExpression, additional: &Self::AdditionalInput) -> Self::Output {
let element_type = match additional {
Some(Type::Array(array_ty)) => Some(array_ty.element_type().clone()),
Some(Type::Optional(opt)) => match &*opt.inner {
Type::Array(array_ty) => Some(array_ty.element_type().clone()),
_ => None,
},
_ => None,
};
let inferred_type = if input.elements.is_empty() {
if let Some(ty) = element_type.clone() {
ty
} else {
self.emit_err(TypeCheckerError::could_not_determine_type(input, input.span()));
Type::Err
}
} else {
self.visit_expression_reject_numeric(&input.elements[0], &element_type)
};
if input.elements.len() > self.limits.max_array_elements {
self.emit_err(TypeCheckerError::array_too_large(
input.elements.len(),
self.limits.max_array_elements,
input.span(),
));
}
for expression in input.elements.iter().skip(1) {
let next_type = self.visit_expression_reject_numeric(expression, &element_type);
if next_type == Type::Err {
return Type::Err;
}
if let Some(ref element_type) = element_type {
self.assert_type(&next_type, element_type, expression.span());
} else {
self.assert_type(&next_type, &inferred_type, expression.span());
}
}
if inferred_type == Type::Err {
return Type::Err;
}
let type_ = Type::Array(ArrayType::new(
inferred_type,
Expression::Literal(Literal {
variant: LiteralVariant::Integer(IntegerType::U32, input.elements.len().to_string()),
id: self.state.node_builder.next_id(),
span: Span::default(),
}),
));
self.maybe_assert_type(&type_, additional, input.span());
type_
}
/// Type-checks a repeat expression `[expr; count]`.
///
/// The element type hint is taken from `additional` when it is an array type, or an
/// optional wrapping one. The count's type defaults to `u32` when it cannot otherwise
/// be inferred. A statically known count above `limits.max_array_elements` is reported
/// as `array_too_large`.
fn visit_repeat(&mut self, input: &RepeatExpression, additional: &Self::AdditionalInput) -> Self::Output {
    let expected_element_type = match additional {
        Some(Type::Array(array_ty)) => Some(array_ty.element_type().clone()),
        Some(Type::Optional(opt)) => match &*opt.inner {
            Type::Array(array_ty) => Some(array_ty.element_type().clone()),
            _ => None,
        },
        _ => None,
    };
    let inferred_element_type = self.visit_expression_reject_numeric(&input.expr, &expected_element_type);
    self.visit_expression_infer_default_u32(&input.count);
    // Compare in `usize` (the limit's own type) rather than narrowing the limit with
    // `as u32`, which silently truncates limits above `u32::MAX`. A count that does
    // not fit in `usize` necessarily exceeds any `usize` limit.
    if let Some(count) = input.count.as_u32()
        && usize::try_from(count).map_or(true, |count| count > self.limits.max_array_elements)
    {
        self.emit_err(TypeCheckerError::array_too_large(count, self.limits.max_array_elements, input.span()));
    }
    let type_ = Type::Array(ArrayType::new(inferred_element_type, input.count.clone()));
    self.maybe_assert_type(&type_, additional, input.span());
    type_
}
/// Type-checks an intrinsic call.
///
/// Finalize-only intrinsics are rejected unless they appear inside an async function,
/// a script, or an async block. `FutureAwait` must receive exactly one argument. An
/// unresolvable intrinsic yields `Type::Err`.
fn visit_intrinsic(&mut self, input: &IntrinsicExpression, expected: &Self::AdditionalInput) -> Self::Output {
    let Some(intrinsic) = self.get_intrinsic(input) else {
        return Type::Err;
    };

    let finalize_context = matches!(self.scope_state.variant, Some(Variant::AsyncFunction | Variant::Script))
        || self.async_block_id.is_some();
    if !finalize_context && intrinsic.is_finalize_command() {
        self.emit_err(TypeCheckerError::operation_must_be_in_async_block_or_function(input.span()));
    }

    let return_type = self.check_intrinsic(intrinsic.clone(), &input.arguments, expected, input.span());
    self.maybe_assert_type(&return_type, expected, input.span());

    let awaits_wrong_arity = intrinsic == Intrinsic::FutureAwait && input.arguments.len() != 1;
    if awaits_wrong_arity {
        self.emit_err(TypeCheckerError::can_only_await_one_future_at_a_time(input.span));
    }
    return_type
}
/// Type-checks an `async { ... }` block.
///
/// Reports an async block that is inside a conditional, outside an async transition or
/// script, a second async block in the same scope, or an async block after an async
/// function call. The block is then visited and rewritten into a function, so its input
/// types can be recorded under the `finalize/<function>` location. Returns a future
/// with no inputs and no location.
fn visit_async(&mut self, input: &AsyncExpression, _additional: &Self::AdditionalInput) -> Self::Output {
    // Set before visiting the block so nested checks (assignments, intrinsics, calls)
    // know they are inside an async block; cleared again at the end.
    self.async_block_id = Some(input.block.id);
    if self.scope_state.is_conditional {
        self.emit_err(TypeCheckerError::async_block_in_conditional(input.span));
    }
    if !matches!(self.scope_state.variant, Some(Variant::AsyncTransition) | Some(Variant::Script)) {
        self.emit_err(TypeCheckerError::illegal_async_block_location(input.span));
    }
    if self.scope_state.already_contains_an_async_block {
        self.emit_err(TypeCheckerError::multiple_async_blocks_not_allowed(input.span));
    }
    if self.scope_state.has_called_finalize {
        self.emit_err(TypeCheckerError::conflicting_async_call_and_block(input.span));
    }
    self.visit_block(&input.block);
    self.scope_state.already_contains_an_async_block = true;
    // Rewrite the block into a standalone async function purely to learn its inputs;
    // the function name passed here is a placeholder.
    let mut block_to_function_rewriter =
        BlockToFunctionRewriter::new(self.state, self.scope_state.program_name.unwrap());
    let (new_function, _) =
        block_to_function_rewriter.rewrite_block(&input.block, Symbol::intern("unused"), Variant::AsyncFunction);
    let input_types = new_function.input.iter().map(|Input { type_, .. }| type_.clone()).collect();
    // Same `finalize/<name>` key format that `visit_call` looks up for async transitions.
    self.async_function_input_types.insert(
        Location::new(self.scope_state.program_name.unwrap(), vec![Symbol::intern(&format!(
            "finalize/{}",
            self.scope_state.function.unwrap(),
        ))]),
        input_types,
    );
    self.async_block_id = None;
    // A location-less future; `visit_tuple_access_general` rejects member access on these.
    Type::Future(FutureType::new(Vec::new(), None, false))
}
/// Type-checks a binary expression.
///
/// `destination` is the type expected for the whole expression. For arithmetic and
/// bitwise operators, that type (with any optional wrapper removed by
/// `unwrap_optional_type`) is propagated to the operands as their expected type.
/// Operands left with the unresolved `Numeric` placeholder are inferred from the
/// other side where possible. Comparisons and boolean operators return `Boolean`.
fn visit_binary(&mut self, input: &BinaryExpression, destination: &Self::AdditionalInput) -> Self::Output {
    // Returns the shared operand type. Yields `Type::Err` without a new error if either
    // side already failed, or emits a mismatch error if the types differ.
    let assert_same_type = |slf: &Self, t1: &Type, t2: &Type| -> Type {
        if t1 == &Type::Err || t2 == &Type::Err {
            Type::Err
        } else if !t1.eq_user(t2) {
            slf.emit_err(TypeCheckerError::operation_types_mismatch(input.op, t1, t2, input.span()));
            Type::Err
        } else {
            t1.clone()
        }
    };
    // Resolves `Numeric` placeholders. A numeric side next to a concrete integer, field,
    // group, or scalar adopts that type: the type table is updated and literals are
    // re-checked. Any other numeric side is reported as an inference failure.
    let infer_numeric_types = |slf: &Self, left_type: &mut Type, right_type: &mut Type| {
        use Type::*;
        match (&*left_type, &*right_type) {
            (Numeric, Numeric) => {
                slf.emit_inference_failure_error(left_type, &input.left);
                slf.emit_inference_failure_error(right_type, &input.right);
            }
            (Numeric, Err) => slf.emit_inference_failure_error(left_type, &input.left),
            (Err, Numeric) => slf.emit_inference_failure_error(right_type, &input.right),
            (Integer(_) | Field | Group | Scalar, Numeric) => {
                *right_type = left_type.clone();
                slf.state.type_table.insert(input.right.id(), right_type.clone());
                if let Expression::Literal(literal) = &input.right {
                    slf.check_numeric_literal(literal, right_type);
                }
            }
            (Numeric, Integer(_) | Field | Group | Scalar) => {
                *left_type = right_type.clone();
                slf.state.type_table.insert(input.left.id(), left_type.clone());
                if let Expression::Literal(literal) = &input.left {
                    slf.check_numeric_literal(literal, left_type);
                }
            }
            (Numeric, _) => slf.emit_inference_failure_error(left_type, &input.left),
            (_, Numeric) => slf.emit_inference_failure_error(right_type, &input.right),
            _ => {}
        }
    };
    match input.op {
        // Boolean logic: both operands are checked against `bool`.
        BinaryOperation::And | BinaryOperation::Or | BinaryOperation::Nand | BinaryOperation::Nor => {
            self.maybe_assert_type(&Type::Boolean, destination, input.span());
            self.visit_expression(&input.left, &Some(Type::Boolean));
            self.visit_expression(&input.right, &Some(Type::Boolean));
            Type::Boolean
        }
        // Bitwise operators accept booleans or integers of matching type.
        BinaryOperation::BitwiseAnd | BinaryOperation::BitwiseOr | BinaryOperation::Xor => {
            let operand_expected = self.unwrap_optional_type(destination);
            let mut t1 = self.visit_expression(&input.left, &operand_expected);
            let mut t2 = self.visit_expression(&input.right, &operand_expected);
            infer_numeric_types(self, &mut t1, &mut t2);
            self.assert_bool_int_type(&t1, input.left.span());
            self.assert_bool_int_type(&t2, input.right.span());
            let result_t = assert_same_type(self, &t1, &t2);
            self.maybe_assert_type(&result_t, destination, input.span());
            result_t
        }
        // Addition: field, group, scalar, or integer operands of the same type.
        BinaryOperation::Add => {
            let operand_expected = self.unwrap_optional_type(destination);
            let mut t1 = self.visit_expression(&input.left, &operand_expected);
            let mut t2 = self.visit_expression(&input.right, &operand_expected);
            infer_numeric_types(self, &mut t1, &mut t2);
            let assert_add_type = |type_: &Type, span: Span| {
                if !matches!(type_, Type::Err | Type::Field | Type::Group | Type::Scalar | Type::Integer(_)) {
                    self.emit_err(TypeCheckerError::type_should_be2(
                        type_,
                        "a field, group, scalar, or integer",
                        span,
                    ));
                }
            };
            assert_add_type(&t1, input.left.span());
            assert_add_type(&t2, input.right.span());
            let result_t = assert_same_type(self, &t1, &t2);
            self.maybe_assert_type(&result_t, destination, input.span());
            result_t
        }
        // Subtraction: field, group, or integer operands of the same type.
        BinaryOperation::Sub => {
            let operand_expected = self.unwrap_optional_type(destination);
            let mut t1 = self.visit_expression(&input.left, &operand_expected);
            let mut t2 = self.visit_expression(&input.right, &operand_expected);
            infer_numeric_types(self, &mut t1, &mut t2);
            self.assert_field_group_int_type(&t1, input.left.span());
            self.assert_field_group_int_type(&t2, input.right.span());
            let result_t = assert_same_type(self, &t1, &t2);
            self.maybe_assert_type(&result_t, destination, input.span());
            result_t
        }
        // Multiplication: group*scalar (either order) gives group; otherwise field*field
        // or matching integers.
        BinaryOperation::Mul => {
            let unwrapped_dest = self.unwrap_optional_type(destination);
            // A `group` destination is not forced onto the operands, since one of them
            // may legitimately be a scalar.
            let expected = if matches!(unwrapped_dest, Some(Type::Group)) { &None } else { &unwrapped_dest };
            let mut t1 = self.visit_expression(&input.left, expected);
            let mut t2 = self.visit_expression(&input.right, expected);
            // An untyped operand next to a group is inferred as a scalar, and vice versa.
            match (&t1, &t2) {
                (Type::Group, Type::Numeric) => infer_numeric_types(self, &mut Type::Scalar, &mut t2),
                (Type::Numeric, Type::Group) => infer_numeric_types(self, &mut t1, &mut Type::Scalar),
                (Type::Scalar, Type::Numeric) => infer_numeric_types(self, &mut Type::Group, &mut t2),
                (Type::Numeric, Type::Scalar) => infer_numeric_types(self, &mut t1, &mut Type::Group),
                (_, _) => infer_numeric_types(self, &mut t1, &mut t2),
            }
            let result_t = match (&t1, &t2) {
                (Type::Err, _) | (_, Type::Err) => Type::Err,
                (Type::Group, Type::Scalar) | (Type::Scalar, Type::Group) => Type::Group,
                (Type::Field, Type::Field) => Type::Field,
                (Type::Integer(integer_type1), Type::Integer(integer_type2)) if integer_type1 == integer_type2 => {
                    t1.clone()
                }
                _ => {
                    self.emit_err(TypeCheckerError::mul_types_mismatch(t1, t2, input.span()));
                    Type::Err
                }
            };
            self.maybe_assert_type(&result_t, destination, input.span());
            result_t
        }
        // Division: field or integer operands of the same type.
        BinaryOperation::Div => {
            let operand_expected = self.unwrap_optional_type(destination);
            let mut t1 = self.visit_expression(&input.left, &operand_expected);
            let mut t2 = self.visit_expression(&input.right, &operand_expected);
            infer_numeric_types(self, &mut t1, &mut t2);
            self.assert_field_int_type(&t1, input.left.span());
            self.assert_field_int_type(&t2, input.right.span());
            let result_t = assert_same_type(self, &t1, &t2);
            self.maybe_assert_type(&result_t, destination, input.span());
            result_t
        }
        // Remainder: integer operands of the same type.
        BinaryOperation::Rem | BinaryOperation::RemWrapped => {
            let operand_expected = self.unwrap_optional_type(destination);
            let mut t1 = self.visit_expression(&input.left, &operand_expected);
            let mut t2 = self.visit_expression(&input.right, &operand_expected);
            infer_numeric_types(self, &mut t1, &mut t2);
            self.assert_int_type(&t1, input.left.span());
            self.assert_int_type(&t2, input.right.span());
            let result_t = assert_same_type(self, &t1, &t2);
            self.maybe_assert_type(&result_t, destination, input.span());
            result_t
        }
        // Modulo: unsigned integer operands of the same type.
        BinaryOperation::Mod => {
            let operand_expected = self.unwrap_optional_type(destination);
            let mut t1 = self.visit_expression(&input.left, &operand_expected);
            let mut t2 = self.visit_expression(&input.right, &operand_expected);
            infer_numeric_types(self, &mut t1, &mut t2);
            self.assert_unsigned_type(&t1, input.left.span());
            self.assert_unsigned_type(&t2, input.right.span());
            let result_t = assert_same_type(self, &t1, &t2);
            self.maybe_assert_type(&result_t, destination, input.span());
            result_t
        }
        // Exponentiation: field**field, or an integer base with a u8/u16/u32 exponent.
        // The exponent is visited without an expected type.
        BinaryOperation::Pow => {
            let operand_expected = self.unwrap_optional_type(destination);
            let mut t1 = self.visit_expression(&input.left, &operand_expected);
            let mut t2 = self.visit_expression(&input.right, &None);
            // Numeric inference only applies when the other side is a field.
            if matches!((&t1, &t2), (Type::Field, Type::Numeric) | (Type::Numeric, Type::Field)) {
                infer_numeric_types(self, &mut t1, &mut t2);
            } else {
                if matches!(t1, Type::Numeric) {
                    self.emit_inference_failure_error(&mut t1, &input.left);
                }
                if matches!(t2, Type::Numeric) {
                    self.emit_inference_failure_error(&mut t2, &input.right);
                }
            }
            let ty = match (&t1, &t2) {
                (Type::Err, _) | (_, Type::Err) => Type::Err,
                (Type::Field, Type::Field) => Type::Field,
                (base @ Type::Integer(_), t2) => {
                    if !matches!(
                        t2,
                        Type::Integer(IntegerType::U8)
                            | Type::Integer(IntegerType::U16)
                            | Type::Integer(IntegerType::U32)
                    ) {
                        self.emit_err(TypeCheckerError::pow_types_mismatch(base, t2, input.span()));
                    }
                    base.clone()
                }
                _ => {
                    self.emit_err(TypeCheckerError::pow_types_mismatch(t1, t2, input.span()));
                    Type::Err
                }
            };
            self.maybe_assert_type(&ty, destination, input.span());
            ty
        }
        // Equality: operands of the same type. A `none` literal on either side is
        // visited with the other side's type as its expected type.
        BinaryOperation::Eq | BinaryOperation::Neq => {
            let (mut t1, mut t2) =
                if let Expression::Literal(Literal { variant: LiteralVariant::None, .. }) = input.right {
                    let t1 = self.visit_expression(&input.left, &None);
                    (t1.clone(), self.visit_expression(&input.right, &Some(t1.clone())))
                } else if let Expression::Literal(Literal { variant: LiteralVariant::None, .. }) = input.left {
                    let t2 = self.visit_expression(&input.right, &None);
                    (self.visit_expression(&input.left, &Some(t2.clone())), t2)
                } else {
                    (self.visit_expression(&input.left, &None), self.visit_expression(&input.right, &None))
                };
            infer_numeric_types(self, &mut t1, &mut t2);
            let _ = assert_same_type(self, &t1, &t2);
            self.maybe_assert_type(&Type::Boolean, destination, input.span());
            Type::Boolean
        }
        // Ordering: field, scalar, or integer operands of the same type.
        BinaryOperation::Lt | BinaryOperation::Gt | BinaryOperation::Lte | BinaryOperation::Gte => {
            let mut t1 = self.visit_expression(&input.left, &None);
            let mut t2 = self.visit_expression(&input.right, &None);
            infer_numeric_types(self, &mut t1, &mut t2);
            let assert_compare_type = |type_: &Type, span: Span| {
                if !matches!(type_, Type::Err | Type::Field | Type::Scalar | Type::Integer(_)) {
                    self.emit_err(TypeCheckerError::type_should_be2(type_, "a field, scalar, or integer", span));
                }
            };
            assert_compare_type(&t1, input.left.span());
            assert_compare_type(&t2, input.right.span());
            let _ = assert_same_type(self, &t1, &t2);
            self.maybe_assert_type(&Type::Boolean, destination, input.span());
            Type::Boolean
        }
        // Wrapping arithmetic: integer operands of the same type.
        BinaryOperation::AddWrapped
        | BinaryOperation::SubWrapped
        | BinaryOperation::DivWrapped
        | BinaryOperation::MulWrapped => {
            let operand_expected = self.unwrap_optional_type(destination);
            let mut t1 = self.visit_expression(&input.left, &operand_expected);
            let mut t2 = self.visit_expression(&input.right, &operand_expected);
            infer_numeric_types(self, &mut t1, &mut t2);
            self.assert_int_type(&t1, input.left.span());
            self.assert_int_type(&t2, input.right.span());
            let result_t = assert_same_type(self, &t1, &t2);
            self.maybe_assert_type(&result_t, destination, input.span());
            result_t
        }
        // Shifts and wrapping pow: operands must already have concrete types. The left
        // side must be an integer and the right side a u8/u16/u32 magnitude; the result
        // is the left operand's type.
        // NOTE(review): unlike the other arms, the result is not checked against
        // `destination` with `maybe_assert_type` — confirm this is intended.
        BinaryOperation::Shl
        | BinaryOperation::ShlWrapped
        | BinaryOperation::Shr
        | BinaryOperation::ShrWrapped
        | BinaryOperation::PowWrapped => {
            let operand_expected = self.unwrap_optional_type(destination);
            let t1 = self.visit_expression_reject_numeric(&input.left, &operand_expected);
            let t2 = self.visit_expression_reject_numeric(&input.right, &None);
            self.assert_int_type(&t1, input.left.span());
            if !matches!(
                &t2,
                Type::Err
                    | Type::Integer(IntegerType::U8)
                    | Type::Integer(IntegerType::U16)
                    | Type::Integer(IntegerType::U32)
            ) {
                self.emit_err(TypeCheckerError::shift_type_magnitude(input.op, t2, input.right.span()));
            }
            t1
        }
    }
}
fn visit_call(&mut self, input: &CallExpression, expected: &Self::AdditionalInput) -> Self::Output {
let current_program = self.scope_state.program_name.unwrap();
let callee_location = input.function.expect_global_location();
let callee_program = callee_location.program;
let callee_path = callee_location.path.clone();
let Some(func_symbol) = self.state.symbol_table.lookup_function(current_program, callee_location) else {
self.emit_err(TypeCheckerError::unknown_sym("function", input.function.clone(), input.function.span()));
return Type::Err;
};
let func = func_symbol.function.clone();
match self.scope_state.variant.unwrap() {
Variant::AsyncFunction | Variant::Function if !matches!(func.variant, Variant::Inline) => self.emit_err(
TypeCheckerError::can_only_call_inline_function("a `function`, `inline`, or `constructor`", input.span),
),
Variant::Transition | Variant::AsyncTransition
if matches!(func.variant, Variant::Transition)
&& callee_program == self.scope_state.program_name.unwrap() =>
{
self.emit_err(TypeCheckerError::cannot_invoke_call_to_local_transition_function(input.span))
}
_ => {}
}
if func.variant == Variant::Inline && callee_program != self.scope_state.program_name.unwrap() {
self.emit_err(TypeCheckerError::cannot_call_external_inline_function(input.span));
}
if self.async_block_id.is_some() && !matches!(func.variant, Variant::Inline) {
self.emit_err(TypeCheckerError::can_only_call_inline_function("an async block", input.span));
}
let mut ret = if func.variant == Variant::AsyncFunction {
Type::Future(FutureType::new(Vec::new(), Some(callee_location.clone()), false))
} else if func.variant == Variant::AsyncTransition {
let Some(inputs) =
self.async_function_input_types.get(&Location::new(callee_program, vec![Symbol::intern(&format!(
"finalize/{}",
input.function.identifier().name
))]))
else {
self.emit_err(TypeCheckerError::async_function_not_found(input.function.clone(), input.span));
return Type::Future(FutureType::new(Vec::new(), Some(callee_location.clone()), false));
};
let future_type = Type::Future(FutureType::new(inputs.clone(), Some(callee_location.clone()), true));
let fully_inferred_type = match &func.output_type {
Type::Tuple(tup) => Type::Tuple(TupleType::new(
tup.elements()
.iter()
.map(|t| if matches!(t, Type::Future(_)) { future_type.clone() } else { t.clone() })
.collect::<Vec<Type>>(),
)),
Type::Future(_) => future_type,
_ => panic!("Invalid output type for async transition."),
};
self.assert_and_return_type(fully_inferred_type, expected, input.span())
} else {
self.assert_and_return_type(func.output_type, expected, input.span())
};
if func.input.len() != input.arguments.len() {
self.emit_err(TypeCheckerError::incorrect_num_args_to_call(
func.input.len(),
input.arguments.len(),
input.span(),
));
}
if func.const_parameters.len() != input.const_arguments.len() {
self.emit_err(TypeCheckerError::incorrect_num_const_args(
"Call",
func.const_parameters.len(),
input.const_arguments.len(),
input.span(),
));
}
for (expected, argument) in func.const_parameters.iter().zip(input.const_arguments.iter()) {
self.visit_expression(argument, &Some(expected.type_().clone()));
}
let (mut input_futures, mut inferred_finalize_inputs) = (Vec::new(), Vec::new());
for (expected, argument) in func.input.iter().zip(input.arguments.iter()) {
let ty = self.visit_expression(argument, &Some(expected.type_().clone()));
if ty == Type::Err {
return Type::Err;
}
if func.variant == Variant::AsyncFunction && matches!(expected.type_(), Type::Future(_)) {
let option_name = match argument {
Expression::Path(path) => Some(path.identifier().name),
Expression::TupleAccess(tuple_access) => {
if let Expression::Path(path) = &tuple_access.tuple {
Some(path.identifier().name)
} else {
None
}
}
_ => None,
};
if let Some(name) = option_name {
match self.scope_state.futures.shift_remove(&name) {
Some(future) => {
self.scope_state.call_location = Some(future);
}
None => {
self.emit_err(TypeCheckerError::unknown_future_consumed(name, argument.span()));
}
}
}
match argument {
Expression::Path(_) | Expression::Call(_) | Expression::TupleAccess(_) => {
match &self.scope_state.call_location {
Some(location) => {
input_futures.push(location.clone());
inferred_finalize_inputs.push(ty);
}
None => {
self.emit_err(TypeCheckerError::unknown_future_consumed(argument, argument.span()));
}
}
}
_ => {
self.emit_err(TypeCheckerError::unknown_future_consumed("unknown", argument.span()));
}
}
} else {
inferred_finalize_inputs.push(ty);
}
}
let caller_program =
self.scope_state.program_name.expect("`program_name` is always set before traversing a program scope");
let caller_function = if self.scope_state.is_constructor {
sym::constructor
} else {
self.scope_state.function.expect("`function` is always set before traversing a function scope")
};
let caller_path = self
.scope_state
.module_name
.iter()
.cloned()
.chain(std::iter::once(caller_function))
.collect::<Vec<Symbol>>();
let caller = Location::new(caller_program, caller_path.clone());
let callee = Location::new(callee_program, callee_path.clone());
self.state.call_graph.add_edge(caller, callee);
if func.variant.is_transition() && self.scope_state.variant == Some(Variant::AsyncTransition) {
if self.scope_state.has_called_finalize {
self.emit_err(TypeCheckerError::external_call_after_async("function call", input.span));
}
if self.scope_state.already_contains_an_async_block {
self.emit_err(TypeCheckerError::external_call_after_async("block", input.span));
}
}
if func.variant.is_async_function() {
if self.scope_state.is_conditional {
self.emit_err(TypeCheckerError::async_call_in_conditional(input.span));
}
if !matches!(self.scope_state.variant, Some(Variant::AsyncTransition) | Some(Variant::Script)) {
self.emit_err(TypeCheckerError::async_call_can_only_be_done_from_async_transition(input.span));
}
if self.scope_state.has_called_finalize {
self.emit_err(TypeCheckerError::must_call_async_function_once(input.span));
}
if self.scope_state.already_contains_an_async_block {
self.emit_err(TypeCheckerError::conflicting_async_call_and_block(input.span));
}
if !self.scope_state.futures.is_empty() {
self.emit_err(TypeCheckerError::not_all_futures_consumed(
self.scope_state.futures.iter().map(|(f, _)| f).join(", "),
input.span,
));
}
self.state
.symbol_table
.attach_finalizer(
Location::new(callee_program, caller_path.clone()),
Location::new(callee_program, callee_path.clone()),
input_futures,
inferred_finalize_inputs.clone(),
)
.expect("Failed to attach finalizer");
self.async_function_callers
.entry(Location::new(self.scope_state.program_name.unwrap(), callee_path.clone()))
.or_default()
.insert(self.scope_state.location());
self.scope_state.has_called_finalize = true;
ret = Type::Future(FutureType::new(
inferred_finalize_inputs.clone(),
Some(Location::new(callee_program, callee_path.clone())),
true,
));
self.async_function_input_types.insert(
Location::new(callee_program, vec![Symbol::intern(&format!(
"finalize/{}",
caller_path.last().unwrap()
))]),
inferred_finalize_inputs.clone(),
);
self.assert_and_return_type(ret.clone(), expected, input.span());
}
self.scope_state.call_location = Some(Location::new(callee_program, callee_path.clone()));
ret
}
/// Type-checks a cast expression `expr as T`.
///
/// Both the target type `T` (checked first) and the operand's type (checked second)
/// must be an integer, bool, field, group, scalar, or address; `Type::Err` is accepted
/// silently. The operand is visited without an expected type and unresolved numeric
/// literals are rejected. The result is the target type, asserted against `expected`.
fn visit_cast(&mut self, input: &CastExpression, expected: &Self::AdditionalInput) -> Self::Output {
    let operand_ty = self.visit_expression_reject_numeric(&input.expression, &None);

    let checks = [(&input.type_, input.span()), (&operand_ty, input.expression.span())];
    for (ty, span) in checks {
        let castable = matches!(
            ty,
            Type::Integer(_) | Type::Boolean | Type::Field | Type::Group | Type::Scalar | Type::Address | Type::Err,
        );
        if !castable {
            self.emit_err(TypeCheckerError::type_should_be2(
                ty,
                "an integer, bool, field, group, scalar, or address",
                span,
            ));
        }
    }

    self.maybe_assert_type(&input.type_, expected, input.span());
    input.type_.clone()
}
/// Type-checks a struct or record initializer `Path::[consts] { member: value, ... }`.
///
/// * Resolves the composite definition from the initializer's global path; an unknown
///   path is reported and yields `Type::Err`.
/// * Checks the number and types of const arguments, and the number and types of
///   members, against the definition. Declared members absent from the initializer are
///   reported individually.
/// * A shorthand member (`{ x }`, i.e. no explicit expression) is resolved as a local
///   variable, falling back to a global of the current program, and its type is asserted
///   against the member's declared type. Reading a non-vector, non-mapping storage
///   variable this way goes through `check_access_allowed`.
/// * For records: instantiating another program's record, or creating a record inside an
///   async function or async block, is an error; using `self.caller` as the `owner` emits
///   a warning.
///
/// Returns the composite type built from the initializer's path and const arguments.
fn visit_composite_init(
    &mut self,
    input: &CompositeExpression,
    additional: &Self::AdditionalInput,
) -> Self::Output {
    let composite_location = input.path.expect_global_location();
    let composite = self.lookup_composite(composite_location).clone();
    let Some(composite) = composite else {
        self.emit_err(TypeCheckerError::unknown_sym("struct or record", input.path.clone(), input.path.span()));
        return Type::Err;
    };

    if composite.const_parameters.len() != input.const_arguments.len() {
        self.emit_err(TypeCheckerError::incorrect_num_const_args(
            "Composite expression",
            composite.const_parameters.len(),
            input.const_arguments.len(),
            input.span(),
        ));
    }
    for (expected, argument) in composite.const_parameters.iter().zip(input.const_arguments.iter()) {
        self.visit_expression(argument, &Some(expected.type_().clone()));
    }

    let type_ =
        Type::Composite(CompositeType { path: input.path.clone(), const_arguments: input.const_arguments.clone() });
    self.maybe_assert_type(&type_, additional, input.path.span());

    if composite.members.len() != input.members.len() {
        self.emit_err(TypeCheckerError::incorrect_num_composite_members(
            composite.members.len(),
            input.members.len(),
            input.span(),
        ));
    }

    for Member { identifier, type_, .. } in composite.members.iter() {
        if let Some(actual) = input.members.iter().find(|member| member.identifier.name == identifier.name) {
            match &actual.expression {
                None => {
                    // Shorthand member `{ x }`: `x` must name a local, or a global of this program.
                    let current_program = self.scope_state.program_name.unwrap();
                    let var = self.state.symbol_table.lookup_local(actual.identifier.name).or_else(|| {
                        self.state
                            .symbol_table
                            .lookup_global(
                                current_program,
                                &Location::new(current_program, vec![actual.identifier.name]),
                            )
                            .cloned()
                    });
                    if let Some(var) = var {
                        let ty = var.type_.expect("must be known by now");
                        if var.declaration == VariableType::Storage && !ty.is_vector() && !ty.is_mapping() {
                            self.check_access_allowed("storage access", true, input.span());
                        }
                        self.maybe_assert_type(&ty, &Some(type_.clone()), input.span());
                    } else {
                        // Report the unresolved member name itself, not the whole initializer.
                        self.emit_err(TypeCheckerError::unknown_sym(
                            "variable",
                            &actual.identifier,
                            actual.identifier.span,
                        ));
                    }
                }
                Some(expr) => {
                    self.visit_expression(expr, &Some(type_.clone()));
                }
            };
        } else {
            self.emit_err(TypeCheckerError::missing_composite_member(
                composite.identifier,
                identifier,
                input.span(),
            ));
        };
    }

    if composite.is_record {
        // Records may only be created by the program that declares them.
        if composite_location.program != self.scope_state.program_name.unwrap() {
            self.state
                .handler
                .emit_err(TypeCheckerError::cannot_instantiate_external_record(composite_location, input.span()));
        }
        if self.scope_state.variant == Some(Variant::AsyncFunction) {
            self.state
                .handler
                .emit_err(TypeCheckerError::records_not_allowed_inside_async("function", input.span()));
        }
        if self.async_block_id.is_some() {
            self.state.handler.emit_err(TypeCheckerError::records_not_allowed_inside_async("block", input.span()));
        }
        // Warn when `self.caller` is used as the record owner.
        input.members.iter().filter(|init| init.identifier.name == sym::owner).for_each(|init| {
            if let Some(Expression::Intrinsic(intr)) = &init.expression
                && let IntrinsicExpression { name: sym::_self_caller, .. } = &**intr
            {
                self.emit_warning(TypeCheckerWarning::caller_as_record_owner(input.path.clone(), intr.span()));
            }
        });
    }

    type_
}
/// An error placeholder expression always has type `Type::Err`; no diagnostic is emitted here.
fn visit_err(&mut self, _input: &ErrExpression, _additional: &Self::AdditionalInput) -> Self::Output {
    let err_ty = Type::Err;
    err_ty
}
/// Resolves a variable path in the current program and returns its type.
///
/// Reading a storage variable whose type is neither a vector nor a mapping must be
/// permitted by `check_access_allowed`. The variable's type is asserted against
/// `expected`. An unresolved path is reported and yields `Type::Err`.
fn visit_path(&mut self, input: &Path, expected: &Self::AdditionalInput) -> Self::Output {
    let program = self.scope_state.program_name.unwrap();
    let Some(var) = self.state.symbol_table.lookup_path(program, input) else {
        self.emit_err(TypeCheckerError::unknown_sym("variable", input, input.span()));
        return Type::Err;
    };

    let resolved_ty = var.type_.expect("must be known at this point");
    let reads_plain_storage =
        var.declaration == VariableType::Storage && !resolved_ty.is_vector() && !resolved_ty.is_mapping();
    if reads_plain_storage {
        self.check_access_allowed("storage access", true, input.span());
    }
    self.maybe_assert_type(&resolved_ty, expected, input.span());
    resolved_ty
}
fn visit_literal(&mut self, input: &Literal, expected: &Self::AdditionalInput) -> Self::Output {
let span = input.span();
macro_rules! parse_and_return {
($ty:ty, $variant:expr, $str:expr, $label:expr) => {{
self.parse_integer_literal::<$ty>($str, span, $label);
Type::Integer($variant)
}};
}
let type_ = match &input.variant {
LiteralVariant::Address(..) => Type::Address,
LiteralVariant::Boolean(..) => Type::Boolean,
LiteralVariant::Field(..) => Type::Field,
LiteralVariant::Group(s) => {
let trimmed = s.trim_start_matches('-').trim_start_matches('0');
if !trimmed.is_empty()
&& format!("{trimmed}group")
.parse::<snarkvm::prelude::Group<snarkvm::prelude::TestnetV0>>()
.is_err()
{
self.emit_err(TypeCheckerError::invalid_int_value(trimmed, "group", span));
}
Type::Group
}
LiteralVariant::Integer(kind, string) => match kind {
IntegerType::U8 => parse_and_return!(u8, IntegerType::U8, string, "u8"),
IntegerType::U16 => parse_and_return!(u16, IntegerType::U16, string, "u16"),
IntegerType::U32 => parse_and_return!(u32, IntegerType::U32, string, "u32"),
IntegerType::U64 => parse_and_return!(u64, IntegerType::U64, string, "u64"),
IntegerType::U128 => parse_and_return!(u128, IntegerType::U128, string, "u128"),
IntegerType::I8 => parse_and_return!(i8, IntegerType::I8, string, "i8"),
IntegerType::I16 => parse_and_return!(i16, IntegerType::I16, string, "i16"),
IntegerType::I32 => parse_and_return!(i32, IntegerType::I32, string, "i32"),
IntegerType::I64 => parse_and_return!(i64, IntegerType::I64, string, "i64"),
IntegerType::I128 => parse_and_return!(i128, IntegerType::I128, string, "i128"),
},
LiteralVariant::None => {
if let Some(ty @ Type::Optional(_)) = expected {
ty.clone()
} else if let Some(ty) = expected {
self.emit_err(TypeCheckerError::none_found_non_optional(format!("{ty}"), span));
Type::Err
} else {
self.emit_err(TypeCheckerError::could_not_determine_type(format!("{input}"), span));
Type::Err
}
}
LiteralVariant::Scalar(..) => Type::Scalar,
LiteralVariant::Signature(..) => Type::Signature,
LiteralVariant::String(..) => Type::String,
LiteralVariant::Unsuffixed(_) => match expected {
Some(ty @ Type::Integer(_) | ty @ Type::Field | ty @ Type::Group | ty @ Type::Scalar) => {
self.check_numeric_literal(input, ty);
ty.clone()
}
Some(ty @ Type::Optional(opt)) => {
let inner = &opt.inner;
match &**inner {
Type::Integer(_) | Type::Field | Type::Group | Type::Scalar => {
self.check_numeric_literal(input, inner);
*inner.clone()
}
_ => {
self.emit_err(TypeCheckerError::unexpected_unsuffixed_numeral(
format!("type `{ty}`"),
span,
));
Type::Err
}
}
}
Some(ty) => {
self.emit_err(TypeCheckerError::unexpected_unsuffixed_numeral(format!("type `{ty}`"), span));
Type::Err
}
None => Type::Numeric,
},
};
self.maybe_assert_type(&type_, expected, span);
type_
}
fn visit_ternary(&mut self, input: &TernaryExpression, expected: &Self::AdditionalInput) -> Self::Output {
self.visit_expression(&input.condition, &Some(Type::Boolean));
let (t1, t2) = if expected.is_some() {
(
self.visit_expression_reject_numeric(&input.if_true, expected),
self.visit_expression_reject_numeric(&input.if_false, expected),
)
} else if input.if_false.is_none_expr() {
let t1 = self.visit_expression(&input.if_true, &None);
if matches!(t1, Type::Optional(_)) {
(t1.clone(), self.visit_expression(&input.if_false, &Some(t1.clone())))
} else {
(
t1.clone(),
self.visit_expression(
&input.if_false,
&Some(Type::Optional(OptionalType { inner: Box::new(t1.clone()) })),
),
)
}
} else if input.if_true.is_none_expr() {
let t2 = self.visit_expression(&input.if_false, &None);
if matches!(t2, Type::Optional(_)) {
(t2.clone(), self.visit_expression(&input.if_true, &Some(t2.clone())))
} else {
(
t2.clone(),
self.visit_expression(
&input.if_true,
&Some(Type::Optional(OptionalType { inner: Box::new(t2.clone()) })),
),
)
}
} else {
(
self.visit_expression_reject_numeric(&input.if_true, &None),
self.visit_expression_reject_numeric(&input.if_false, &None),
)
};
let typ = if t1 == Type::Err || t2 == Type::Err {
Type::Err
} else if !t1.can_coerce_to(&t2) && !t2.can_coerce_to(&t1) {
self.emit_err(TypeCheckerError::ternary_branch_mismatch(t1, t2, input.span()));
Type::Err
} else if let Some(expected) = expected {
expected.clone()
} else if t1.can_coerce_to(&t2) {
t2
} else {
t1
};
if self.is_external_record(&typ) {
self.emit_err(TypeCheckerError::ternary_over_external_records(&typ, input.span));
}
if let Type::Tuple(tuple) = &typ
&& tuple.elements().iter().any(|ty| self.is_external_record(ty))
{
self.emit_err(TypeCheckerError::ternary_over_external_records(&typ, input.span));
}
typ
}
fn visit_tuple(&mut self, input: &TupleExpression, expected: &Self::AdditionalInput) -> Self::Output {
if let Some(expected) = expected {
if let Type::Tuple(expected_types) = expected {
if expected_types.length() != input.elements.len() {
self.emit_err(TypeCheckerError::incorrect_tuple_length(
expected_types.length(),
input.elements.len(),
input.span(),
));
}
input.elements.iter().zip(expected_types.elements()).for_each(|(expr, expected_el_ty)| {
if matches!(expr, Expression::Tuple(_)) {
self.emit_err(TypeCheckerError::nested_tuple_expression(expr.span()));
}
self.visit_expression(expr, &Some(expected_el_ty.clone()));
});
expected.clone()
} else {
let field_types = input
.elements
.iter()
.map(|field| {
let ty = self.visit_expression(field, &None);
if ty == Type::Numeric {
self.emit_err(TypeCheckerError::could_not_determine_type(field.clone(), field.span()));
Type::Err
} else {
ty
}
})
.collect::<Vec<_>>();
if field_types.iter().all(|f| *f != Type::Err) {
let tuple_type = Type::Tuple(TupleType::new(field_types));
self.emit_err(TypeCheckerError::type_should_be2(tuple_type, expected, input.span()));
}
expected.clone()
}
} else {
input.elements.iter().for_each(|expr| {
if matches!(expr, Expression::Tuple(_)) {
self.emit_err(TypeCheckerError::nested_tuple_expression(expr.span()));
}
});
Type::Tuple(TupleType::new(
input
.elements
.iter()
.map(|field| {
let ty = self.visit_expression(field, &None);
if ty == Type::Numeric {
self.emit_err(TypeCheckerError::could_not_determine_type(field.clone(), field.span()));
Type::Err
} else {
ty
}
})
.collect::<Vec<_>>(),
))
}
}
fn visit_unary(&mut self, input: &UnaryExpression, destination: &Self::AdditionalInput) -> Self::Output {
let operand_expected = self.unwrap_optional_type(destination);
let assert_signed_int = |slf: &mut Self, type_: &Type| {
if !matches!(
type_,
Type::Err
| Type::Integer(IntegerType::I8)
| Type::Integer(IntegerType::I16)
| Type::Integer(IntegerType::I32)
| Type::Integer(IntegerType::I64)
| Type::Integer(IntegerType::I128)
) {
slf.emit_err(TypeCheckerError::type_should_be2(type_, "a signed integer", input.span()));
}
};
let ty = match input.op {
UnaryOperation::Abs => {
let type_ = self.visit_expression_reject_numeric(&input.receiver, &operand_expected);
assert_signed_int(self, &type_);
type_
}
UnaryOperation::AbsWrapped => {
let type_ = self.visit_expression_reject_numeric(&input.receiver, &operand_expected);
assert_signed_int(self, &type_);
type_
}
UnaryOperation::Double => {
let type_ = self.visit_expression_reject_numeric(&input.receiver, &operand_expected);
if !matches!(&type_, Type::Err | Type::Field | Type::Group) {
self.emit_err(TypeCheckerError::type_should_be2(&type_, "a field or group", input.span()));
}
type_
}
UnaryOperation::Inverse => {
let mut type_ = self.visit_expression(&input.receiver, &operand_expected);
if type_ == Type::Numeric {
type_ = Type::Field;
self.state.type_table.insert(input.receiver.id(), Type::Field);
} else {
self.assert_type(&type_, &Type::Field, input.span());
}
type_
}
UnaryOperation::Negate => {
let type_ = self.visit_expression_reject_numeric(&input.receiver, &operand_expected);
if !matches!(
&type_,
Type::Err
| Type::Integer(IntegerType::I8)
| Type::Integer(IntegerType::I16)
| Type::Integer(IntegerType::I32)
| Type::Integer(IntegerType::I64)
| Type::Integer(IntegerType::I128)
| Type::Group
| Type::Field
) {
self.emit_err(TypeCheckerError::type_should_be2(
&type_,
"a signed integer, group, or field",
input.receiver.span(),
));
}
type_
}
UnaryOperation::Not => {
let type_ = self.visit_expression_reject_numeric(&input.receiver, &operand_expected);
if !matches!(&type_, Type::Err | Type::Boolean | Type::Integer(_)) {
self.emit_err(TypeCheckerError::type_should_be2(&type_, "a bool or integer", input.span()));
}
type_
}
UnaryOperation::Square => {
let mut type_ = self.visit_expression(&input.receiver, &operand_expected);
if type_ == Type::Numeric {
type_ = Type::Field;
self.state.type_table.insert(input.receiver.id(), Type::Field);
} else {
self.assert_type(&type_, &Type::Field, input.span());
}
type_
}
UnaryOperation::SquareRoot => {
let mut type_ = self.visit_expression(&input.receiver, &operand_expected);
if type_ == Type::Numeric {
type_ = Type::Field;
self.state.type_table.insert(input.receiver.id(), Type::Field);
} else {
self.assert_type(&type_, &Type::Field, input.span());
}
type_
}
UnaryOperation::ToXCoordinate | UnaryOperation::ToYCoordinate => {
let _operand_type = self.visit_expression(&input.receiver, &Some(Type::Group));
self.maybe_assert_type(&Type::Field, destination, input.span());
Type::Field
}
};
self.maybe_assert_type(&ty, destination, input.span());
ty
}
/// The unit expression `()` always has type `Type::Unit`, regardless of expectation.
fn visit_unit(&mut self, _input: &UnitExpression, _additional: &Self::AdditionalInput) -> Self::Output {
    let unit_ty = Type::Unit;
    unit_ty
}
/// Dispatches a statement to its specific visitor.
///
/// When the current scope has already returned, the statement is reported as
/// unreachable and is not visited.
fn visit_statement(&mut self, input: &Statement) {
    if self.scope_state.has_return {
        return self.emit_err(TypeCheckerError::unreachable_code_after_return(input.span()));
    }

    match input {
        Statement::Assert(assert_stmt) => self.visit_assert(assert_stmt),
        Statement::Assign(assign_stmt) => self.visit_assign(assign_stmt),
        Statement::Block(block) => self.visit_block(block),
        Statement::Conditional(conditional) => self.visit_conditional(conditional),
        Statement::Const(const_decl) => self.visit_const(const_decl),
        Statement::Definition(definition) => self.visit_definition(definition),
        Statement::Expression(expr_stmt) => self.visit_expression_statement(expr_stmt),
        Statement::Iteration(iteration) => self.visit_iteration(iteration),
        Statement::Return(return_stmt) => self.visit_return(return_stmt),
    }
}
/// Type-checks `assert`, `assert_eq`, and `assert_neq`.
///
/// `assert(e)` requires `e` to be a bool. The equality forms infer both operands (rejecting
/// unresolved numerals) and report a mismatch when neither side is `Type::Err` and the types
/// are not equal under `eq_user`.
fn visit_assert(&mut self, input: &AssertStatement) {
    let (left, right, op) = match &input.variant {
        AssertVariant::Assert(condition) => {
            self.visit_expression(condition, &Some(Type::Boolean));
            return;
        }
        AssertVariant::AssertEq(left, right) => (left, right, "assert_eq"),
        AssertVariant::AssertNeq(left, right) => (left, right, "assert_neq"),
    };

    let lhs = self.visit_expression_reject_numeric(left, &None);
    let rhs = self.visit_expression_reject_numeric(right, &None);
    let compatible = lhs == Type::Err || rhs == Type::Err || lhs.eq_user(&rhs);
    if !compatible {
        self.emit_err(TypeCheckerError::operation_types_mismatch(op, &lhs, &rhs, input.span()));
    }
}
/// Type-checks an assignment `place = value`.
///
/// An invalid target still has its value visited (without an expected type). Writing a
/// storage variable that is neither a vector nor a mapping must be permitted by
/// `check_access_allowed`. Writing external storage is an error. For a storage target of
/// optional type, a non-`none` value is checked against the inner type; in every other case
/// the value is checked against the target's type.
fn visit_assign(&mut self, input: &AssignStatement) {
    let target = self.visit_expression_assign(&input.place);
    let value = &input.value;

    match target.kind {
        AssignTargetKind::Err => {
            self.visit_expression(value, &None);
            return;
        }
        AssignTargetKind::Storage if !target.ty.is_vector() && !target.ty.is_mapping() => {
            self.check_access_allowed("storage write", true, input.place.span());
        }
        AssignTargetKind::ExternalStorage => {
            self.emit_err(TypeCheckerError::cannot_modify_external_storage_variable(input.span()));
        }
        _ => {}
    }

    let rhs_expected = match target.ty {
        Type::Optional(OptionalType { inner })
            if target.kind == AssignTargetKind::Storage && !value.is_none_expr() =>
        {
            *inner
        }
        other => other,
    };
    self.visit_expression(value, &Some(rhs_expected));
}
/// Visits every statement of a block, in order, inside the block's own scope.
fn visit_block(&mut self, input: &Block) {
    self.in_scope(input.id, |slf| {
        for statement in &input.statements {
            slf.visit_statement(statement);
        }
    });
}
/// Type-checks `if cond { ... } else ...`.
///
/// The condition must be a bool. Each branch is visited with `has_return` cleared and
/// `is_conditional` set. Afterwards the enclosing scope counts as returned if it already had
/// returned, or if both branches return (a missing `else` never returns). The previous
/// `is_conditional` value is restored.
fn visit_conditional(&mut self, input: &ConditionalStatement) {
    self.visit_expression(&input.condition, &Some(Type::Boolean));

    let outer_has_return = core::mem::replace(&mut self.scope_state.has_return, false);
    let outer_is_conditional = core::mem::replace(&mut self.scope_state.is_conditional, true);

    self.in_conditional_scope(|slf| slf.visit_block(&input.then));
    let then_returns = self.scope_state.has_return;

    let otherwise_returns = match input.otherwise.as_deref() {
        None => false,
        Some(otherwise) => {
            self.scope_state.has_return = false;
            match otherwise {
                Statement::Block(block) => {
                    self.in_conditional_scope(|slf| slf.visit_block(block));
                }
                Statement::Conditional(nested) => {
                    self.visit_conditional(nested);
                }
                _ => unreachable!("Else-case can only be a block or conditional statement."),
            }
            self.scope_state.has_return
        }
    };

    self.scope_state.has_return = outer_has_return || (then_returns && otherwise_returns);
    self.scope_state.is_conditional = outer_is_conditional;
}
/// Type-checks a `const` declaration.
///
/// The declared type may not contain an optional, may not be unit, and may not be a tuple
/// containing a tuple. The value is checked against the declared type. Inside a function,
/// the declared type is also recorded for the local name.
fn visit_const(&mut self, input: &ConstDeclaration) {
    let declared = &input.type_;
    self.visit_type(declared);
    if self.contains_optional_type(declared) {
        self.emit_err(TypeCheckerError::const_cannot_be_optional(input.span));
    }

    match declared {
        Type::Unit => self.emit_err(TypeCheckerError::lhs_must_be_identifier_or_tuple(input.span)),
        Type::Tuple(tuple) if tuple.length() < 2 => {
            unreachable!("Parsing guarantees that tuple types have at least two elements.")
        }
        Type::Tuple(tuple) => {
            if tuple.elements().iter().any(|element| matches!(element, Type::Tuple(_))) {
                self.emit_err(TypeCheckerError::nested_tuple_type(input.span));
            }
        }
        Type::Mapping(_) | Type::Err => unreachable!(
            "Parsing guarantees that `mapping` and `err` types are not present at this location in the AST."
        ),
        _ => {}
    }

    self.visit_expression(&input.value, &Some(declared.clone()));
    if self.scope_state.function.is_some() {
        self.state.symbol_table.set_local_type(input.place.name, declared.clone());
    }
}
/// Type-checks a `let` definition, either `let x[: T] = value;` or `let (a, b, ...)[: T] = value;`.
///
/// An annotated type is visited and validated; tuple annotations may not nest tuples. The
/// value is inferred against the annotation (unresolved numerals are rejected) and may not
/// be a storage vector.
/// * Single place: the local gets the annotated type, or the inferred type if there is no
///   annotation.
/// * Multiple places: the annotation (or, without one, the inferred type) must be a tuple
///   whose arity matches the number of identifiers; each identifier gets its element type.
///   A non-tuple target is reported, unless it is already `Type::Err`.
fn visit_definition(&mut self, input: &DefinitionStatement) {
    if let Some(ty) = &input.type_ {
        self.visit_type(ty);
        self.assert_type_is_valid(ty, input.span);
    }
    match &input.type_ {
        Some(Type::Tuple(tuple)) => match tuple.length() {
            0 | 1 => unreachable!("Parsing guarantees that tuple types have at least two elements."),
            _ => {
                for type_ in tuple.elements() {
                    if matches!(type_, Type::Tuple(_)) {
                        self.emit_err(TypeCheckerError::nested_tuple_type(input.span))
                    }
                }
            }
        },
        Some(Type::Mapping(_)) | Some(Type::Err) => unreachable!(
            "Parsing guarantees that `mapping` and `err` types are not present at this location in the AST."
        ),
        _ => (),
    }

    let inferred_type = self.visit_expression_reject_numeric(&input.value, &input.type_);
    if inferred_type.is_vector() {
        self.emit_err(TypeCheckerError::storage_vectors_cannot_be_moved_or_assigned(input.value.span()));
    }

    match &input.place {
        DefinitionPlace::Single(identifier) => {
            self.set_local_type(
                Some(inferred_type.clone()),
                identifier,
                input.type_.clone().unwrap_or(inferred_type),
            );
        }
        DefinitionPlace::Multiple(identifiers) => {
            let tuple_type = match (&input.type_, &inferred_type) {
                (Some(Type::Tuple(tuple_type)), _) | (None, Type::Tuple(tuple_type)) => tuple_type.clone(),
                (declared, _) => {
                    // Destructuring requires a tuple. Report a non-tuple target, unless the
                    // target is already `Type::Err`, so that such a definition does not
                    // pass silently with untyped identifiers.
                    let target = declared.as_ref().unwrap_or(&inferred_type);
                    if *target != Type::Err {
                        self.emit_err(TypeCheckerError::type_should_be2(target, "a tuple", input.span()));
                    }
                    return;
                }
            };
            if identifiers.len() != tuple_type.length() {
                return self.emit_err(TypeCheckerError::incorrect_num_tuple_elements(
                    identifiers.len(),
                    tuple_type.length(),
                    input.span(),
                ));
            }
            for (i, identifier) in identifiers.iter().enumerate() {
                // Element `i` of the inferred tuple, if the value inferred to a tuple at all.
                let inferred = if let Type::Tuple(inferred_tuple) = &inferred_type {
                    inferred_tuple.elements().get(i).cloned().unwrap_or_default()
                } else {
                    Type::Err
                };
                self.set_local_type(Some(inferred), identifier, tuple_type.elements()[i].clone());
            }
        }
    }
}
/// Type-checks an expression used as a statement.
///
/// Only calls, intrinsic calls, and the unit expression may stand alone; they are visited
/// without an expected type. Anything else is reported and not visited.
fn visit_expression_statement(&mut self, input: &ExpressionStatement) {
    match &input.expression {
        Expression::Call(_) | Expression::Intrinsic(_) | Expression::Unit(_) => {
            self.visit_expression(&input.expression, &None);
        }
        _ => self.emit_err(TypeCheckerError::expression_statement_must_be_function_call(input.span())),
    }
}
/// Type-checks a `for i[: T] in start..stop { ... }` loop.
///
/// The optional annotation must be an integer type. Both bounds are checked against it,
/// must be integers, and must have the same type. The iterator variable gets the annotated
/// type, or the start bound's type if there is no annotation.
///
/// The body is visited with `has_return`, `has_called_finalize`, and
/// `already_contains_an_async_block` cleared, so only returns, async calls, and async blocks
/// that occur *inside* the body are reported as loop errors. An async call or block that
/// appears before the loop no longer triggers a spurious error. All three flags are restored
/// afterwards.
fn visit_iteration(&mut self, input: &IterationStatement) {
    if let Some(ty) = &input.type_ {
        self.visit_type(ty);
        self.assert_int_type(ty, input.variable.span);
    }

    let start_ty = self.visit_expression(&input.start, &input.type_);
    let stop_ty = self.visit_expression(&input.stop, &input.type_);
    self.assert_int_type(&start_ty, input.start.span());
    self.assert_int_type(&stop_ty, input.stop.span());
    if start_ty != stop_ty {
        self.emit_err(TypeCheckerError::range_bounds_type_mismatch(input.start.span() + input.stop.span()));
    }

    let iterator_ty = input.type_.clone().unwrap_or(start_ty);
    self.state.type_table.insert(input.variable.id(), iterator_ty.clone());

    self.in_scope(input.id(), |slf| {
        slf.state.symbol_table.set_local_type(input.variable.name, iterator_ty.clone());

        // Track the body's returns, async calls, and async blocks in isolation from the enclosing scope.
        let prior_has_return = core::mem::take(&mut slf.scope_state.has_return);
        let prior_has_finalize = core::mem::take(&mut slf.scope_state.has_called_finalize);
        let prior_async_block = core::mem::take(&mut slf.scope_state.already_contains_an_async_block);

        slf.visit_block(&input.block);

        if slf.scope_state.has_return {
            slf.emit_err(TypeCheckerError::loop_body_contains_return(input.span()));
        }
        if slf.scope_state.has_called_finalize {
            slf.emit_err(TypeCheckerError::loop_body_contains_async("function call", input.span()));
        }
        if slf.scope_state.already_contains_an_async_block {
            slf.emit_err(TypeCheckerError::loop_body_contains_async("block expression", input.span()));
        }

        slf.scope_state.has_return = prior_has_return;
        slf.scope_state.has_called_finalize = prior_has_finalize;
        slf.scope_state.already_contains_an_async_block = prior_async_block;
    });
}
/// Type-checks a `return` statement and marks the current scope as having returned.
///
/// * Returning from inside an async block is an error.
/// * Constructors may only return unit.
/// * In an async transition that has called its async function, every `Future` in the
///   declared output type is replaced by the future inferred from the finalizer's inputs.
///   If the declared output contains no future at all, this is an error.
/// * A bare unit return is an error unless the (possibly refined) return type is unit.
///   Otherwise the returned expression is checked against the return type.
fn visit_return(&mut self, input: &ReturnStatement) {
    if self.async_block_id.is_some() {
        return self.emit_err(TypeCheckerError::async_block_cannot_return(input.span()));
    }
    if self.scope_state.is_constructor {
        if !matches!(input.expression, Expression::Unit(..)) {
            self.emit_err(TypeCheckerError::constructor_can_only_return_unit(&input.expression, input.span));
        }
        return;
    }

    let function_name = self.scope_state.function.expect("`self.function` is set every time a function is visited.");
    let function_path: Vec<Symbol> =
        self.scope_state.module_name.iter().cloned().chain(std::iter::once(function_name)).collect();
    let program = self.scope_state.program_name.unwrap();

    let func_symbol = self
        .state
        .symbol_table
        .lookup_function(program, &Location::new(program, function_path.clone()))
        .expect("The symbol table creator should already have visited all functions.");
    let mut return_type = func_symbol.function.output_type.clone();

    let refines_future = self.scope_state.variant == Some(Variant::AsyncTransition) && self.scope_state.has_called_finalize;
    if refines_future {
        let finalizer_inputs =
            func_symbol.finalizer.as_ref().map(|finalizer| finalizer.inferred_inputs.clone()).unwrap_or_default();
        let future_ty = Future(FutureType::new(finalizer_inputs, Some(Location::new(program, function_path)), true));

        let refined = match return_type.clone() {
            Future(_) => future_ty,
            Tuple(tuple) => {
                let elements = tuple
                    .elements()
                    .iter()
                    .map(|element| match element {
                        Future(_) => future_ty.clone(),
                        other => other.clone(),
                    })
                    .collect::<Vec<Type>>();
                Tuple(TupleType::new(elements))
            }
            _ => return self.emit_err(TypeCheckerError::async_transition_missing_future_to_return(input.span())),
        };
        return_type = self.assert_and_return_type(refined, &Some(return_type), input.span());
    }

    if matches!(input.expression, Expression::Unit(..)) && return_type != Type::Unit {
        return self.emit_err(TypeCheckerError::missing_return(input.span()));
    }

    self.visit_expression(&input.expression, &Some(return_type));
    self.scope_state.has_return = true;
}
}