use super::TypeCheckingVisitor;
use crate::{VariableSymbol, VariableType};
use leo_ast::{DiGraphError, Type, *};
use leo_errors::{Label, TypeCheckerError};
use leo_span::{Symbol, sym};
use itertools::Itertools;
use snarkvm::prelude::{CanaryV0, MainnetV0, TestnetV0};
use std::collections::{BTreeMap, HashMap};
impl ProgramVisitor for TypeCheckingVisitor<'_> {
fn visit_program(&mut self, input: &Program) {
input.stubs.iter().for_each(|(symbol, stub)| {
if symbol != &stub.stub_id.name.name {
self.emit_err(TypeCheckerError::stub_name_mismatch(
symbol,
stub.stub_id.name,
stub.stub_id.network.span,
));
}
self.visit_stub(stub)
});
self.scope_state.is_stub = false;
input.modules.values().for_each(|module| self.visit_module(module));
input.program_scopes.values().for_each(|scope| self.visit_program_scope(scope));
}
fn visit_program_scope(&mut self, input: &ProgramScope) {
let program_name = input.program_id.name;
self.scope_state.program_name = Some(program_name.name);
let record_info: BTreeMap<String, leo_span::Span> = input
.structs
.iter()
.filter(|(_, c)| c.is_record)
.map(|(_, r)| (r.name().to_string(), r.identifier.span))
.collect();
for ((prev_name, _), (curr_name, curr_span)) in record_info.iter().tuple_windows() {
if curr_name.starts_with(prev_name) {
self.state
.handler
.emit_err(TypeCheckerError::record_prefixed_by_other_record(curr_name, prev_name, *curr_span));
}
}
input.consts.iter().for_each(|(_, c)| self.visit_const(c));
input.structs.iter().for_each(|(_, function)| self.visit_struct(function));
if let Err(DiGraphError::CycleDetected(path)) = self.state.struct_graph.post_order() {
self.emit_err(TypeCheckerError::cyclic_struct_dependency(
path.iter().map(|p| p.iter().format("::")).collect(),
));
}
let mut mapping_count = 0;
for (_, mapping) in input.mappings.iter() {
self.visit_mapping(mapping);
mapping_count += 1;
}
for (_, storage_variable) in input.storage_variables.iter() {
self.visit_storage_variable(storage_variable);
}
if mapping_count > self.limits.max_mappings {
self.emit_err(TypeCheckerError::too_many_mappings(
self.limits.max_mappings,
input.program_id.name.span + input.program_id.network.span,
));
}
let mut transition_count = 0;
for (_, function) in input.functions.iter() {
self.visit_function(function);
if function.variant.is_transition() {
transition_count += 1;
}
}
if let Some(constructor) = &input.constructor {
self.visit_constructor(constructor);
}
if let Err(DiGraphError::CycleDetected(path)) = self.state.call_graph.post_order() {
self.emit_err(TypeCheckerError::cyclic_function_dependency(path));
}
if transition_count > self.limits.max_functions {
self.emit_err(TypeCheckerError::too_many_transitions(
self.limits.max_functions,
input.program_id.name.span + input.program_id.network.span,
));
}
else if transition_count == 0 {
self.emit_err(TypeCheckerError::no_transitions(input.program_id.name.span + input.program_id.network.span));
}
}
fn visit_module(&mut self, input: &Module) {
let parent_module = self.scope_state.module_name.clone();
self.scope_state.program_name = Some(input.program_name);
self.scope_state.module_name = input.path.clone();
input.consts.iter().for_each(|(_, c)| self.visit_const(c));
input.structs.iter().for_each(|(_, function)| self.visit_struct(function));
for (_, function) in input.functions.iter() {
self.visit_function(function);
}
self.scope_state.module_name = parent_module;
}
fn visit_stub(&mut self, input: &Stub) {
self.scope_state.program_name = Some(input.stub_id.name.name);
self.scope_state.is_stub = true;
if !input.consts.is_empty() {
self.emit_err(TypeCheckerError::stubs_cannot_have_const_declarations(input.consts.first().unwrap().1.span));
}
input.structs.iter().for_each(|(_, function)| self.visit_struct_stub(function));
input.functions.iter().for_each(|(_, function)| self.visit_function_stub(function));
}
fn visit_struct(&mut self, input: &Composite) {
self.in_conditional_scope(|slf| {
slf.in_scope(input.id, |slf| {
if input.is_record && !input.const_parameters.is_empty() {
slf.emit_err(TypeCheckerError::unexpected_record_const_parameters(input.span));
} else {
input
.const_parameters
.iter()
.for_each(|const_param| slf.insert_symbol_conditional_scope(const_param.identifier.name));
for const_param in &input.const_parameters {
slf.visit_type(const_param.type_());
if !matches!(
const_param.type_(),
Type::Boolean | Type::Integer(_) | Type::Address | Type::Scalar | Type::Group | Type::Field
) {
slf.emit_err(TypeCheckerError::bad_const_generic_type(
const_param.type_(),
const_param.span(),
));
}
if let Err(err) = slf.state.symbol_table.insert_variable(
slf.scope_state.program_name.unwrap(),
&[const_param.identifier().name],
VariableSymbol {
type_: const_param.type_().clone(),
span: const_param.identifier.span(),
declaration: VariableType::ConstParameter,
},
) {
slf.state.handler.emit_err(err);
}
slf.state.type_table.insert(const_param.identifier().id(), const_param.type_().clone());
}
}
input.members.iter().for_each(|member| slf.visit_type(&member.type_));
})
});
let mut used = HashMap::new();
for Member { identifier, type_, span, .. } in &input.members {
self.assert_type_is_valid(type_, *span);
if let Some(first_span) = used.get(&identifier.name) {
self.emit_err(if input.is_record {
TypeCheckerError::duplicate_record_variable(identifier.name, *span).with_labels(vec![
Label::new(format!("`{}` first declared here", identifier.name), *first_span)
.with_color(leo_errors::Color::Blue),
Label::new("record variable already declared", *span),
])
} else {
TypeCheckerError::duplicate_struct_member(identifier.name, *span).with_labels(vec![
Label::new(format!("`{}` first declared here", identifier.name), *first_span)
.with_color(leo_errors::Color::Blue),
Label::new("struct field already declared", *span),
])
});
} else {
used.insert(identifier.name, *span);
}
}
if input.is_record {
let check_has_field =
|need, expected_ty: Type| match input.members.iter().find_map(|Member { identifier, type_, .. }| {
(identifier.name == need).then_some((identifier, type_))
}) {
Some((_, actual_ty)) if expected_ty.eq_flat_relaxed(actual_ty) => {} Some((field, _)) => {
self.emit_err(TypeCheckerError::record_var_wrong_type(field, expected_ty, input.span()));
}
None => {
self.emit_err(TypeCheckerError::required_record_variable(need, expected_ty, input.span()));
}
};
check_has_field(sym::owner, Type::Address);
for Member { identifier, type_, span, .. } in input.members.iter() {
if self.contains_optional_type(type_) {
self.emit_err(TypeCheckerError::record_field_cannot_be_optional(identifier, type_, *span));
}
}
}
else if input.members.is_empty() {
self.emit_err(TypeCheckerError::empty_struct(input.span()));
}
if !(input.is_record && self.scope_state.is_stub) {
for Member { mode, identifier, type_, span, .. } in input.members.iter() {
if matches!(type_, Type::Tuple(_)) {
self.emit_err(TypeCheckerError::composite_data_type_cannot_contain_tuple(
if input.is_record { "record" } else { "struct" },
identifier.span,
));
} else if matches!(type_, Type::Future(..)) {
self.emit_err(TypeCheckerError::composite_data_type_cannot_contain_future(
if input.is_record { "record" } else { "struct" },
identifier.span,
));
}
self.assert_member_is_not_record(identifier.span, input.identifier.name, type_);
let composite_path = self
.scope_state
.module_name
.iter()
.cloned()
.chain(std::iter::once(input.identifier.name))
.collect::<Vec<Symbol>>();
if let Type::Composite(struct_member_type) = type_ {
self.state.struct_graph.add_edge(composite_path, struct_member_type.path.absolute_path().to_vec());
} else if let Type::Array(array_type) = type_ {
let base_element_type = array_type.base_element_type();
if let Type::Composite(member_type) = base_element_type {
self.state.struct_graph.add_edge(composite_path, member_type.path.absolute_path().to_vec());
}
}
if !input.is_record && !matches!(mode, Mode::None) {
self.emit_err(TypeCheckerError::struct_cannot_have_member_mode(*span));
}
}
}
}
fn visit_mapping(&mut self, input: &Mapping) {
self.visit_type(&input.key_type);
self.visit_type(&input.value_type);
self.assert_type_is_valid(&input.key_type, input.span);
match input.key_type.clone() {
Type::Future(_) => self.emit_err(TypeCheckerError::invalid_mapping_type("key", "future", input.span)),
Type::Tuple(_) => self.emit_err(TypeCheckerError::invalid_mapping_type("key", "tuple", input.span)),
Type::Composite(struct_type) => {
if let Some(comp) = self.lookup_struct(
struct_type.program.or(self.scope_state.program_name),
&struct_type.path.absolute_path(),
) {
if comp.is_record {
self.emit_err(TypeCheckerError::invalid_mapping_type("key", "record", input.span));
}
} else {
self.emit_err(TypeCheckerError::undefined_type(&input.key_type, input.span));
}
}
Type::Mapping(_) => self.emit_err(TypeCheckerError::invalid_mapping_type("key", "mapping", input.span)),
_ => {}
}
if self.contains_optional_type(&input.key_type) {
self.emit_err(TypeCheckerError::optional_type_not_allowed_in_mapping(
input.key_type.clone(),
"key",
input.span,
))
}
self.assert_type_is_valid(&input.value_type, input.span);
match input.value_type.clone() {
Type::Future(_) => self.emit_err(TypeCheckerError::invalid_mapping_type("value", "future", input.span)),
Type::Tuple(_) => self.emit_err(TypeCheckerError::invalid_mapping_type("value", "tuple", input.span)),
Type::Composite(struct_type) => {
if let Some(comp) = self.lookup_struct(
struct_type.program.or(self.scope_state.program_name),
&struct_type.path.absolute_path(),
) {
if comp.is_record {
self.emit_err(TypeCheckerError::invalid_mapping_type("value", "record", input.span));
}
} else {
self.emit_err(TypeCheckerError::undefined_type(&input.value_type, input.span));
}
}
Type::Mapping(_) => self.emit_err(TypeCheckerError::invalid_mapping_type("value", "mapping", input.span)),
_ => {}
}
if self.contains_optional_type(&input.value_type) {
self.emit_err(TypeCheckerError::optional_type_not_allowed_in_mapping(
input.value_type.clone(),
"value",
input.span,
))
}
}
fn visit_storage_variable(&mut self, input: &StorageVariable) {
self.visit_type(&input.type_);
let storage_type = if let Type::Vector(VectorType { element_type }) = &input.type_ {
*element_type.clone()
} else {
input.type_.clone()
};
self.assert_storage_type_is_valid(&storage_type, input.span);
}
fn visit_function(&mut self, function: &Function) {
self.scope_state.reset();
self.scope_state.variant = Some(function.variant);
for annotation in function.annotations.iter() {
if !matches!(annotation.identifier.name, sym::test | sym::should_fail) {
self.emit_err(TypeCheckerError::unknown_annotation(annotation, annotation.span))
}
}
let get = |symbol: Symbol| -> &Annotation {
function.annotations.iter().find(|ann| ann.identifier.name == symbol).unwrap()
};
let check_annotation = |symbol: Symbol, allowed_keys: &[Symbol]| -> bool {
let count = function.annotations.iter().filter(|ann| ann.identifier.name == symbol).count();
if count > 0 {
let annotation = get(symbol);
for key in annotation.map.keys() {
if !allowed_keys.contains(key) {
self.emit_err(TypeCheckerError::annotation_error(
format_args!("Invalid key `{key}` for annotation @{symbol}"),
annotation.span,
));
}
}
if count > 1 {
self.emit_err(TypeCheckerError::annotation_error(
format_args!("Duplicate annotation @{symbol}"),
annotation.span,
));
}
}
count > 0
};
let has_test = check_annotation(sym::test, &[sym::private_key]);
let has_should_fail = check_annotation(sym::should_fail, &[]);
if has_test && !self.state.is_test {
self.emit_err(TypeCheckerError::annotation_error(
format_args!("Test annotation @test appears outside of tests"),
get(sym::test).span,
));
}
if has_should_fail && !self.state.is_test {
self.emit_err(TypeCheckerError::annotation_error(
format_args!("Test annotation @should_fail appears outside of tests"),
get(sym::should_fail).span,
));
}
if has_should_fail && !has_test {
self.emit_err(TypeCheckerError::annotation_error(
format_args!("Annotation @should_fail appears without @test"),
get(sym::should_fail).span,
));
}
if has_test
&& !self.scope_state.variant.unwrap().is_script()
&& !self.scope_state.variant.unwrap().is_transition()
{
self.emit_err(TypeCheckerError::annotation_error(
format_args!("Annotation @test may appear only on scripts and transitions"),
get(sym::test).span,
));
}
if (has_test) && !function.input.is_empty() {
self.emit_err(TypeCheckerError::annotation_error(
"A test procedure cannot have inputs",
function.input[0].span,
));
}
self.in_conditional_scope(|slf| {
slf.in_scope(function.id, |slf| {
function
.const_parameters
.iter()
.for_each(|const_param| slf.insert_symbol_conditional_scope(const_param.identifier.name));
function.input.iter().for_each(|input| slf.insert_symbol_conditional_scope(input.identifier.name));
slf.scope_state.function = Some(function.name());
slf.check_function_signature(function, false);
if function.variant == Variant::Function && function.input.is_empty() {
slf.emit_err(TypeCheckerError::empty_function_arglist(function.span));
}
slf.visit_block(&function.block);
if function.output_type != Type::Unit && !slf.scope_state.has_return {
slf.emit_err(TypeCheckerError::missing_return(function.span));
}
})
});
if self.scope_state.variant == Some(Variant::AsyncTransition)
&& !self.scope_state.has_called_finalize
&& !self.scope_state.already_contains_an_async_block
{
self.emit_err(TypeCheckerError::missing_async_operation_in_async_transition(function.span));
}
self.scope_state.reset();
}
fn visit_constructor(&mut self, constructor: &Constructor) {
self.scope_state.reset();
self.scope_state.function = Some(sym::constructor);
self.scope_state.variant = Some(Variant::AsyncFunction);
self.scope_state.is_constructor = true;
let result = match self.state.network {
NetworkName::CanaryV0 => constructor.get_upgrade_variant::<CanaryV0>(),
NetworkName::TestnetV0 => constructor.get_upgrade_variant::<TestnetV0>(),
NetworkName::MainnetV0 => constructor.get_upgrade_variant::<MainnetV0>(),
};
let upgrade_variant = match result {
Ok(upgrade_variant) => upgrade_variant,
Err(e) => {
self.emit_err(TypeCheckerError::custom(e, constructor.span));
return;
}
};
match (&upgrade_variant, constructor.block.statements.is_empty()) {
(UpgradeVariant::Custom, true) => {
self.emit_err(TypeCheckerError::custom("A 'custom' constructor cannot be empty", constructor.span));
}
(UpgradeVariant::NoUpgrade | UpgradeVariant::Admin { .. } | UpgradeVariant::Checksum { .. }, false) => {
self.emit_err(TypeCheckerError::custom("A 'noupgrade', 'admin', or 'checksum' constructor must be empty. The Leo compiler will insert the appropriate code.", constructor.span));
}
_ => {}
}
if let UpgradeVariant::Checksum { mapping, key, key_type } = &upgrade_variant {
let Some(VariableSymbol { type_: Type::Mapping(mapping_type), .. }) =
self.state.symbol_table.lookup_global(mapping)
else {
self.emit_err(TypeCheckerError::custom(
format!("The mapping '{mapping}' does not exist. Please ensure that it is imported or defined in your program."),
constructor.annotations[0].span,
));
return;
};
if *mapping_type.key != *key_type {
self.emit_err(TypeCheckerError::custom(
format!(
"The mapping '{}' key type '{}' does not match the key '{}' in the `@checksum` annotation",
mapping, mapping_type.key, key
),
constructor.annotations[0].span,
));
}
let check_value_type = |type_: &Type| -> bool {
if let Type::Array(array_type) = type_ {
if !matches!(array_type.element_type.as_ref(), &Type::Integer(_)) {
return false;
}
if let Some(length) = array_type.length.as_u32() {
return length == 32;
}
return false;
}
false
};
if !check_value_type(&mapping_type.value) {
self.emit_err(TypeCheckerError::custom(
format!("The mapping '{}' value type '{}' must be a '[u8; 32]'", mapping, mapping_type.value),
constructor.annotations[0].span,
));
}
}
self.in_conditional_scope(|slf| {
slf.in_scope(constructor.id, |slf| {
slf.visit_block(&constructor.block);
})
});
if self.scope_state.has_called_finalize {
self.emit_err(TypeCheckerError::custom("The constructor cannot call `finalize`.", constructor.span));
}
if self.scope_state.already_contains_an_async_block {
self.emit_err(TypeCheckerError::custom("The constructor cannot have an `async` block.", constructor.span));
}
self.scope_state.reset();
}
fn visit_function_stub(&mut self, input: &FunctionStub) {
if input.variant == Variant::Inline {
self.emit_err(TypeCheckerError::stub_functions_must_not_be_inlines(input.span));
}
if input.variant == Variant::AsyncFunction {
let finalize_input_map = &mut self.async_function_input_types;
let resolved_inputs: Vec<Type> = input
.input
.iter()
.map(|input| {
match &input.type_ {
Type::Future(f) => {
Type::Future(FutureType::new(
finalize_input_map.get(f.location.as_ref().unwrap()).unwrap().clone(),
f.location.clone(),
true,
))
}
_ => input.clone().type_,
}
})
.collect();
finalize_input_map.insert(
Location::new(self.scope_state.program_name.unwrap(), vec![input.identifier.name]),
resolved_inputs,
);
}
self.check_function_signature(&Function::from(input.clone()), true);
}
fn visit_struct_stub(&mut self, input: &Composite) {
self.visit_struct(input);
}
}