use crate::{CallGraph, StructGraph, SymbolTable};
use leo_ast::{CoreFunction, Identifier, IntegerType, MappingType, Node, Type, Variant};
use leo_errors::{emitter::Handler, TypeCheckerError};
use leo_span::{Span, Symbol};
use itertools::Itertools;
use std::cell::RefCell;
/// Type checker state: the symbol table of the current scope, the struct and call graphs,
/// the error handler, and flags describing the context currently being checked.
pub struct TypeChecker<'a> {
/// Symbol table of the scope currently being checked; replaced by `enter_scope`/`exit_scope`.
pub(crate) symbol_table: RefCell<SymbolTable>,
/// Graph built in `new` from the struct names in the initial symbol table.
/// NOTE(review): presumably used to detect cyclic struct definitions — confirm at use sites.
pub(crate) struct_graph: StructGraph,
/// Graph built in `new` from the function names in the initial symbol table.
/// NOTE(review): presumably used to detect recursive calls — confirm at use sites.
pub(crate) call_graph: CallGraph,
/// Error handler; `emit_err` forwards every type checker error here.
pub(crate) handler: &'a Handler,
/// NOTE(review): presumably the name of the function currently being checked. Starts as `None`.
pub(crate) function: Option<Symbol>,
/// NOTE(review): presumably the variant of the function currently being checked. Starts as `None`.
pub(crate) variant: Option<Variant>,
/// NOTE(review): presumably whether the current function has a return statement. Starts `false`.
pub(crate) has_return: bool,
/// NOTE(review): presumably whether the current function has a finalize block. Starts `false`.
pub(crate) has_finalize: bool,
/// Whether checking is inside a finalize block. When `false`, `Mapping::*` core function calls
/// emit `invalid_operation_outside_finalize`. Starts `false`.
pub(crate) is_finalize: bool,
/// NOTE(review): presumably whether the item being checked is imported. Starts `false`.
pub(crate) is_imported: bool,
/// NOTE(review): presumably whether a return statement is currently being checked. Starts `false`.
pub(crate) is_return: bool,
}
// Type sets used by the `assert_*_type` helpers, both for the membership check and to build
// the list of accepted types shown in `expected_one_type_of` errors.

/// The boolean type.
const BOOLEAN_TYPE: Type = Type::Boolean;
/// The field type.
const FIELD_TYPE: Type = Type::Field;
/// The group type.
const GROUP_TYPE: Type = Type::Group;
/// The scalar type.
const SCALAR_TYPE: Type = Type::Scalar;
/// Every integer type: the signed ones (`I8` through `I128`) followed by the unsigned ones
/// (`U8` through `U128`).
const INT_TYPES: [Type; 10] = [
Type::Integer(IntegerType::I8),
Type::Integer(IntegerType::I16),
Type::Integer(IntegerType::I32),
Type::Integer(IntegerType::I64),
Type::Integer(IntegerType::I128),
Type::Integer(IntegerType::U8),
Type::Integer(IntegerType::U16),
Type::Integer(IntegerType::U32),
Type::Integer(IntegerType::U64),
Type::Integer(IntegerType::U128),
];
/// The signed integer types `I8` through `I128`.
const SIGNED_INT_TYPES: [Type; 5] = [
Type::Integer(IntegerType::I8),
Type::Integer(IntegerType::I16),
Type::Integer(IntegerType::I32),
Type::Integer(IntegerType::I64),
Type::Integer(IntegerType::I128),
];
/// The unsigned integer types `U8` through `U128`.
const UNSIGNED_INT_TYPES: [Type; 5] = [
Type::Integer(IntegerType::U8),
Type::Integer(IntegerType::U16),
Type::Integer(IntegerType::U32),
Type::Integer(IntegerType::U64),
Type::Integer(IntegerType::U128),
];
/// The unsigned integer types `U8`, `U16` and `U32`, accepted by `assert_magnitude_type`.
/// NOTE(review): presumably the types allowed as a magnitude operand (e.g. shift amount or
/// exponent) — confirm at call sites.
const MAGNITUDE_TYPES: [Type; 3] =
[Type::Integer(IntegerType::U8), Type::Integer(IntegerType::U16), Type::Integer(IntegerType::U32)];
impl<'a> TypeChecker<'a> {
    /// Returns a new type checker for `symbol_table` that reports errors to `handler`.
    ///
    /// The struct graph and call graph are seeded with the names of every struct and function
    /// currently registered in `symbol_table`. The current function and variant start as `None`,
    /// and every context flag starts as `false`.
    pub fn new(symbol_table: SymbolTable, handler: &'a Handler) -> Self {
        let struct_names = symbol_table.structs.keys().cloned().collect();
        let function_names = symbol_table.functions.keys().cloned().collect();

        Self {
            symbol_table: RefCell::new(symbol_table),
            struct_graph: StructGraph::new(struct_names),
            call_graph: CallGraph::new(function_names),
            handler,
            function: None,
            variant: None,
            has_return: false,
            has_finalize: false,
            is_finalize: false,
            is_imported: false,
            is_return: false,
        }
    }

    /// Enters the child scope stored at `index` of the current symbol table.
    ///
    /// The child table is moved out of its parent and becomes the current table. The previous
    /// current table is stored as the child's `parent`, so `exit_scope` with the same `index`
    /// can undo this.
    ///
    /// # Panics
    ///
    /// Panics if the current symbol table has no scope at `index`.
    pub(crate) fn enter_scope(&mut self, index: usize) {
        // Take ownership of the current table, leaving a default (empty) table in its place.
        let previous_symbol_table = std::mem::take(&mut self.symbol_table);
        // Swap that placeholder with the child scope: the child becomes current, and the
        // placeholder occupies the child's slot in the parent.
        self.symbol_table.swap(previous_symbol_table.borrow().lookup_scope_by_index(index).unwrap());
        self.symbol_table.borrow_mut().parent = Some(Box::new(previous_symbol_table.into_inner()));
    }

    /// Creates a new scope via `SymbolTable::insert_block`, enters it, and returns its index.
    pub(crate) fn create_child_scope(&mut self) -> usize {
        let scope_index = self.symbol_table.borrow_mut().insert_block();
        self.enter_scope(scope_index);
        scope_index
    }

    /// Exits the current scope, which must have been entered at `index`.
    ///
    /// The current table is moved back into slot `index` of its parent, and the parent becomes
    /// the current table again.
    ///
    /// # Panics
    ///
    /// Panics if the current table has no parent, or if the parent has no scope at `index`.
    pub(crate) fn exit_scope(&mut self, index: usize) {
        let previous_symbol_table = *self.symbol_table.borrow_mut().parent.take().unwrap();
        // Put the current table back into its slot. Whatever occupied that slot (the
        // placeholder left by `enter_scope`) comes out into `self.symbol_table`.
        self.symbol_table.swap(previous_symbol_table.lookup_scope_by_index(index).unwrap());
        // Drop the placeholder and make the parent the current table again.
        self.symbol_table = RefCell::new(previous_symbol_table);
    }

    /// Emits a type checker error through the handler.
    pub(crate) fn emit_err(&self, err: TypeCheckerError) {
        self.handler.emit_err(err);
    }

    /// Emits an `expected_one_type_of` error if `type_` is known and fails `is_valid`.
    ///
    /// `error_string` is the human-readable list of accepted types used in the diagnostic.
    /// A `None` type is accepted silently.
    /// NOTE(review): `None` presumably means "unknown because an error was already reported".
    fn check_type(&self, is_valid: impl Fn(&Type) -> bool, error_string: String, type_: &Option<Type>, span: Span) {
        if let Some(type_) = type_ {
            if !is_valid(type_) {
                self.emit_err(TypeCheckerError::expected_one_type_of(error_string, type_, span));
            }
        }
    }

    /// Emits an error if `t1` and `t2` do not agree.
    ///
    /// Two known types are compared with `Type::eq_flat`. If exactly one side is unknown, a
    /// `type_should_be("no type", ..)` error is emitted. Two unknown types are accepted.
    pub(crate) fn check_eq_types(&self, t1: &Option<Type>, t2: &Option<Type>, span: Span) {
        match (t1, t2) {
            (Some(t1), Some(t2)) if !Type::eq_flat(t1, t2) => {
                self.emit_err(TypeCheckerError::type_should_be(t1, t2, span))
            }
            (Some(type_), None) | (None, Some(type_)) => {
                self.emit_err(TypeCheckerError::type_should_be("no type", type_, span))
            }
            _ => {}
        }
    }

    /// Returns `actual`. If an `expected` type is given and `actual` does not flat-equal it,
    /// a `type_should_be` error is emitted first.
    pub(crate) fn assert_and_return_type(&self, actual: Type, expected: &Option<Type>, span: Span) -> Type {
        if let Some(expected) = expected {
            if !actual.eq_flat(expected) {
                // Format through a reference; no need to clone `actual` just for the message.
                self.emit_err(TypeCheckerError::type_should_be(&actual, expected, span));
            }
        }
        actual
    }

    /// Emits an error if `actual` is known and does not flat-equal `expected`.
    pub(crate) fn assert_type(&self, actual: &Option<Type>, expected: &Type, span: Span) {
        self.check_type(|actual: &Type| actual.eq_flat(expected), expected.to_string(), actual, span)
    }

    /// Emits an error if `type_` is known and is not `bool`.
    pub(crate) fn assert_bool_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(|type_: &Type| BOOLEAN_TYPE.eq(type_), BOOLEAN_TYPE.to_string(), type_, span)
    }

    /// Emits an error if `type_` is known and is not `field`.
    pub(crate) fn assert_field_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(|type_: &Type| FIELD_TYPE.eq(type_), FIELD_TYPE.to_string(), type_, span)
    }

    /// Emits an error if `type_` is known and is not `group`.
    pub(crate) fn assert_group_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(|type_: &Type| GROUP_TYPE.eq(type_), GROUP_TYPE.to_string(), type_, span)
    }

    /// Emits an error if `type_` is known and is not `scalar`.
    pub(crate) fn assert_scalar_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(|type_: &Type| SCALAR_TYPE.eq(type_), SCALAR_TYPE.to_string(), type_, span)
    }

    /// Emits an error if `type_` is known and is not an integer type.
    pub(crate) fn assert_int_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(|type_: &Type| INT_TYPES.contains(type_), types_to_string(&INT_TYPES), type_, span)
    }

    /// Emits an error if `type_` is known and is not a signed integer type.
    pub(crate) fn assert_signed_int_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(
            |type_: &Type| SIGNED_INT_TYPES.contains(type_),
            types_to_string(&SIGNED_INT_TYPES),
            type_,
            span,
        )
    }

    /// Emits an error if `type_` is known and is not an unsigned integer type.
    pub(crate) fn assert_unsigned_int_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(
            |type_: &Type| UNSIGNED_INT_TYPES.contains(type_),
            types_to_string(&UNSIGNED_INT_TYPES),
            type_,
            span,
        )
    }

    /// Emits an error if `type_` is known and is not one of `MAGNITUDE_TYPES` (`u8`, `u16`, `u32`).
    pub(crate) fn assert_magnitude_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(|type_: &Type| MAGNITUDE_TYPES.contains(type_), types_to_string(&MAGNITUDE_TYPES), type_, span)
    }

    /// Emits an error if `type_` is known and is neither `bool` nor an integer type.
    pub(crate) fn assert_bool_int_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(
            |type_: &Type| BOOLEAN_TYPE.eq(type_) || INT_TYPES.contains(type_),
            format!("{BOOLEAN_TYPE}, {}", types_to_string(&INT_TYPES)),
            type_,
            span,
        )
    }

    /// Emits an error if `type_` is known and is neither `field` nor an integer type.
    pub(crate) fn assert_field_int_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(
            |type_: &Type| FIELD_TYPE.eq(type_) || INT_TYPES.contains(type_),
            format!("{FIELD_TYPE}, {}", types_to_string(&INT_TYPES)),
            type_,
            span,
        )
    }

    /// Emits an error if `type_` is known and is neither `field` nor `group`.
    pub(crate) fn assert_field_group_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(
            |type_: &Type| FIELD_TYPE.eq(type_) || GROUP_TYPE.eq(type_),
            format!("{FIELD_TYPE}, {GROUP_TYPE}"),
            type_,
            span,
        )
    }

    /// Emits an error if `type_` is known and is not `field`, `group`, or an integer type.
    pub(crate) fn assert_field_group_int_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(
            |type_: &Type| FIELD_TYPE.eq(type_) || GROUP_TYPE.eq(type_) || INT_TYPES.contains(type_),
            format!("{FIELD_TYPE}, {GROUP_TYPE}, {}", types_to_string(&INT_TYPES)),
            type_,
            span,
        )
    }

    /// Emits an error if `type_` is known and is not `field`, `group`, or a signed integer type.
    pub(crate) fn assert_field_group_signed_int_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(
            |type_: &Type| FIELD_TYPE.eq(type_) || GROUP_TYPE.eq(type_) || SIGNED_INT_TYPES.contains(type_),
            format!("{FIELD_TYPE}, {GROUP_TYPE}, {}", types_to_string(&SIGNED_INT_TYPES)),
            type_,
            span,
        )
    }

    /// Emits an error if `type_` is known and is not `field`, `scalar`, or an integer type.
    pub(crate) fn assert_field_scalar_int_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(
            |type_: &Type| FIELD_TYPE.eq(type_) || SCALAR_TYPE.eq(type_) || INT_TYPES.contains(type_),
            format!("{FIELD_TYPE}, {SCALAR_TYPE}, {}", types_to_string(&INT_TYPES)),
            type_,
            span,
        )
    }

    /// Emits an error if `type_` is known and is not `field`, `group`, `scalar`, or an integer type.
    pub(crate) fn assert_field_group_scalar_int_type(&self, type_: &Option<Type>, span: Span) {
        self.check_type(
            |type_: &Type| {
                FIELD_TYPE.eq(type_) || GROUP_TYPE.eq(type_) || SCALAR_TYPE.eq(type_) || INT_TYPES.contains(type_)
            },
            format!("{FIELD_TYPE}, {GROUP_TYPE}, {SCALAR_TYPE}, {}", types_to_string(&INT_TYPES)),
            type_,
            span,
        )
    }

    /// Resolves a call of the form `struct_::function` to a core function.
    ///
    /// Returns `None` without emitting an error if `struct_` is not a `Type::Identifier`.
    /// Returns `None` and emits `invalid_core_function` if the identifier/function pair is not
    /// recognized by `CoreFunction::from_symbols`.
    pub(crate) fn get_core_function_call(&self, struct_: &Type, function: &Identifier) -> Option<CoreFunction> {
        if let Type::Identifier(ident) = struct_ {
            match CoreFunction::from_symbols(ident.name, function.name) {
                Some(core_instruction) => return Some(core_instruction),
                None => {
                    self.emit_err(TypeCheckerError::invalid_core_function(ident.name, function.name, ident.span()));
                }
            }
        }
        None
    }

    /// Type checks a call to `core_function` and returns the call's result type.
    ///
    /// `arguments` holds each argument's type (`None` if unknown) together with its span.
    /// A wrong number of arguments is reported and yields `None`. For the `Mapping::*`
    /// operations, a first argument that is not a mapping also yields `None`. All other
    /// argument type errors are reported, and the function's result type is still returned.
    pub(crate) fn check_core_function_call(
        &self,
        core_function: CoreFunction,
        arguments: &[(Option<Type>, Span)],
        function_span: Span,
    ) -> Option<Type> {
        // The argument indexing below relies on this arity check.
        if arguments.len() != core_function.num_args() {
            self.emit_err(TypeCheckerError::incorrect_num_args_to_call(
                core_function.num_args(),
                arguments.len(),
                function_span,
            ));
            return None;
        }

        // Accepts any type except mappings, tuples, the error type, and unit.
        let check_not_mapping_tuple_err_unit = |type_: &Option<Type>, span: &Span| {
            self.check_type(
                |type_: &Type| !matches!(type_, Type::Mapping(_) | Type::Tuple(_) | Type::Err | Type::Unit),
                "address, boolean, field, group, struct, integer, scalar, string".to_string(),
                type_,
                *span,
            );
        };

        // Accepts booleans, strings, and integers of at most 64 bits.
        let check_pedersen_64_bit_input = |type_: &Option<Type>, span: &Span| {
            self.check_type(
                |type_: &Type| {
                    matches!(
                        type_,
                        Type::Boolean
                            | Type::Integer(IntegerType::I8)
                            | Type::Integer(IntegerType::I16)
                            | Type::Integer(IntegerType::I32)
                            | Type::Integer(IntegerType::I64)
                            | Type::Integer(IntegerType::U8)
                            | Type::Integer(IntegerType::U16)
                            | Type::Integer(IntegerType::U32)
                            | Type::Integer(IntegerType::U64)
                            | Type::String
                    )
                },
                "boolean, integer (up to 64 bits), string".to_string(),
                type_,
                *span,
            );
        };

        // Accepts booleans, strings, and integers of any width.
        let check_pedersen_128_bit_input = |type_: &Option<Type>, span: &Span| {
            self.check_type(
                |type_: &Type| matches!(type_, Type::Boolean | Type::Integer(_) | Type::String),
                "boolean, integer, string".to_string(),
                type_,
                *span,
            );
        };

        match core_function {
            CoreFunction::BHP256Commit
            | CoreFunction::BHP512Commit
            | CoreFunction::BHP768Commit
            | CoreFunction::BHP1024Commit => {
                check_not_mapping_tuple_err_unit(&arguments[0].0, &arguments[0].1);
                // The second argument is the commitment randomizer.
                self.assert_scalar_type(&arguments[1].0, arguments[1].1);
                Some(Type::Field)
            }
            CoreFunction::BHP256Hash
            | CoreFunction::BHP512Hash
            | CoreFunction::BHP768Hash
            | CoreFunction::BHP1024Hash => {
                check_not_mapping_tuple_err_unit(&arguments[0].0, &arguments[0].1);
                Some(Type::Field)
            }
            CoreFunction::Pedersen64Commit => {
                check_pedersen_64_bit_input(&arguments[0].0, &arguments[0].1);
                self.assert_scalar_type(&arguments[1].0, arguments[1].1);
                Some(Type::Group)
            }
            CoreFunction::Pedersen64Hash => {
                check_pedersen_64_bit_input(&arguments[0].0, &arguments[0].1);
                Some(Type::Field)
            }
            CoreFunction::Pedersen128Commit => {
                check_pedersen_128_bit_input(&arguments[0].0, &arguments[0].1);
                self.assert_scalar_type(&arguments[1].0, arguments[1].1);
                Some(Type::Group)
            }
            CoreFunction::Pedersen128Hash => {
                check_pedersen_128_bit_input(&arguments[0].0, &arguments[0].1);
                Some(Type::Field)
            }
            CoreFunction::Poseidon2Hash | CoreFunction::Poseidon4Hash | CoreFunction::Poseidon8Hash => {
                check_not_mapping_tuple_err_unit(&arguments[0].0, &arguments[0].1);
                Some(Type::Field)
            }
            CoreFunction::MappingGet => {
                // Mapping operations are only valid inside a finalize block. The error is
                // reported, and checking of the arguments continues.
                if !self.is_finalize {
                    self.emit_err(TypeCheckerError::invalid_operation_outside_finalize("Mapping::get", function_span))
                }
                if let Some(mapping_type) = self.assert_mapping_type(&arguments[0].0, arguments[0].1) {
                    self.assert_type(&arguments[1].0, &mapping_type.key, arguments[1].1);
                    Some(*mapping_type.value)
                } else {
                    None
                }
            }
            CoreFunction::MappingGetOrInit => {
                if !self.is_finalize {
                    self.emit_err(TypeCheckerError::invalid_operation_outside_finalize(
                        "Mapping::get_or_init",
                        function_span,
                    ))
                }
                if let Some(mapping_type) = self.assert_mapping_type(&arguments[0].0, arguments[0].1) {
                    self.assert_type(&arguments[1].0, &mapping_type.key, arguments[1].1);
                    // The default value must have the mapping's value type.
                    self.assert_type(&arguments[2].0, &mapping_type.value, arguments[2].1);
                    Some(*mapping_type.value)
                } else {
                    None
                }
            }
            CoreFunction::MappingSet => {
                if !self.is_finalize {
                    self.emit_err(TypeCheckerError::invalid_operation_outside_finalize("Mapping::set", function_span))
                }
                if let Some(mapping_type) = self.assert_mapping_type(&arguments[0].0, arguments[0].1) {
                    self.assert_type(&arguments[1].0, &mapping_type.key, arguments[1].1);
                    self.assert_type(&arguments[2].0, &mapping_type.value, arguments[2].1);
                    Some(Type::Unit)
                } else {
                    None
                }
            }
        }
    }

    /// Returns `Type::Identifier(struct_)`. If an `expected` type is given and does not
    /// flat-equal it, a `type_should_be` error is emitted first.
    pub(crate) fn check_expected_struct(&mut self, struct_: Identifier, expected: &Option<Type>, span: Span) -> Type {
        if let Some(expected) = expected {
            if !Type::Identifier(struct_).eq_flat(expected) {
                self.emit_err(TypeCheckerError::type_should_be(struct_.name, expected, span));
            }
        }
        Type::Identifier(struct_)
    }

    /// Emits `struct_or_record_cannot_contain_record` if `type_` names a record.
    ///
    /// Tuple types are checked element by element, recursively. `parent` is reported in the
    /// error as the containing type. Identifiers that are not found in the symbol table are
    /// ignored here.
    pub(crate) fn assert_member_is_not_record(&self, span: Span, parent: Symbol, type_: &Type) {
        match type_ {
            Type::Identifier(identifier)
                if self
                    .symbol_table
                    .borrow()
                    .lookup_struct(identifier.name)
                    .map_or(false, |struct_| struct_.is_record) =>
            {
                self.emit_err(TypeCheckerError::struct_or_record_cannot_contain_record(parent, identifier.name, span))
            }
            Type::Tuple(tuple_type) => {
                for type_ in tuple_type.iter() {
                    self.assert_member_is_not_record(span, parent, type_)
                }
            }
            _ => {}
        }
    }

    /// Returns whether `type_` is defined, emitting an error for each undefined part.
    ///
    /// Strings are rejected as unsupported, and identifiers must name a struct in the symbol
    /// table. Tuple and mapping types are checked component by component. Because `&=` does
    /// not short-circuit, every component is visited and every error is reported.
    pub(crate) fn assert_type_is_defined(&self, type_: &Type, span: Span) -> bool {
        let mut is_defined = true;
        match type_ {
            Type::String => {
                is_defined = false;
                self.emit_err(TypeCheckerError::strings_are_not_supported(span));
            }
            Type::Identifier(identifier) if self.symbol_table.borrow().lookup_struct(identifier.name).is_none() => {
                is_defined = false;
                self.emit_err(TypeCheckerError::undefined_type(identifier.name, span));
            }
            Type::Tuple(tuple_type) => {
                for type_ in tuple_type.iter() {
                    is_defined &= self.assert_type_is_defined(type_, span)
                }
            }
            Type::Mapping(mapping_type) => {
                is_defined &= self.assert_type_is_defined(&mapping_type.key, span);
                is_defined &= self.assert_type_is_defined(&mapping_type.value, span);
            }
            _ => {}
        }
        is_defined
    }

    /// Emits an error if `type_` is known and is not a mapping type.
    ///
    /// Returns a clone of the mapping type if `type_` is a mapping, and `None` otherwise.
    pub(crate) fn assert_mapping_type(&self, type_: &Option<Type>, span: Span) -> Option<MappingType> {
        self.check_type(|type_| matches!(type_, Type::Mapping(_)), "mapping".to_string(), type_, span);
        match type_ {
            Some(Type::Mapping(mapping_type)) => Some(mapping_type.clone()),
            _ => None,
        }
    }
}
/// Renders `types` as one string: each type's `Display` form, separated by `", "`.
/// Used to list the accepted types in `expected_one_type_of` diagnostics.
fn types_to_string(types: &[Type]) -> String {
    let mut rendered = types.iter().map(ToString::to_string);
    rendered.join(", ")
}