use rustc_hash::FxHashMap as HashMap;
use syntax::ast::BindingKind;
use syntax::ast::{Annotation, Binding, Expression, Pattern, Span, StructKind};
use syntax::program::{CallKind, Definition, NativeTypeKind};
use syntax::types::{Bound, SubstitutionMap, Type, substitute};
use super::super::Checker;
use super::super::checks::check_binding_pattern;
use super::primitives::contains_deref;
use crate::checker::PostInferenceCheck;
use crate::checker::scopes::UseContext;
use crate::store::ENTRY_MODULE_ID;
/// Reports whether a chain of dot accesses contains a numeric member.
///
/// Starting from `expression` (with surrounding parentheses removed), this
/// walks inward through each `DotAccess` receiver. It returns `true` as soon
/// as a member name parses as a `usize` (e.g. `.0`). It returns `false` once
/// the walk reaches a node that is not a dot access.
fn has_numeric_member_in_chain(expression: &Expression) -> bool {
    let mut node = expression.unwrap_parens();
    loop {
        let Expression::DotAccess {
            expression: receiver,
            member,
            ..
        } = node
        else {
            return false;
        };
        if member.parse::<usize>().is_ok() {
            return true;
        }
        node = receiver.unwrap_parens();
    }
}
impl Checker<'_, '_> {
    /// Type-checks a named function expression against `expected_ty`.
    ///
    /// Returns the expression with inferred parameters, return type and
    /// function type filled in. Along the way it reports:
    /// - `nested_function` when an enclosing scope already has a function
    ///   return type;
    /// - `invalid_main_signature` when `main` in the entry module declares
    ///   parameters or an explicit return annotation.
    ///
    /// Generic parameters and their bounds live in a fresh scope for the
    /// duration of the body. When generics are present, the function type is
    /// wrapped in `Type::Forall`. The result is then instantiated and unified
    /// with `expected_ty`.
    ///
    /// # Panics
    /// Panics (via `unreachable!`) if `expression` is not `Expression::Function`.
    pub(super) fn infer_function(
        &mut self,
        expression: Expression,
        expected_ty: &Type,
    ) -> Expression {
        let Expression::Function {
            doc,
            attributes,
            name,
            name_span,
            generics,
            params,
            return_annotation,
            visibility,
            body,
            span,
            ..
        } = expression
        else {
            unreachable!("infer_function called with non-Function expression");
        };

        // A function return type already in scope means this definition sits
        // inside another function body.
        if self.scopes.lookup_fn_return_type().is_some() {
            self.sink
                .push(diagnostics::infer::nested_function(name_span));
        }

        // `main` in the entry module must take no parameters and must not
        // declare a return type.
        if name == "main"
            && self.cursor.module_id == ENTRY_MODULE_ID
            && (!params.is_empty() || return_annotation != Annotation::Unknown)
        {
            self.sink
                .push(diagnostics::infer::invalid_main_signature(name_span));
        }

        self.scopes.push();
        self.put_in_scope(&generics);

        // Each bound is recorded twice: in the new scope's trait bounds (keyed
        // by the qualified generic name), and as a `Bound` on the function type
        // (keyed by the plain generic name).
        let mut bounds = vec![];
        for g in &generics {
            let qualified_name = self.qualify_name(&g.name);
            for b in &g.bounds {
                let bound_ty = self.convert_to_type(b, &span);
                self.scopes
                    .current_mut()
                    .trait_bounds
                    .get_or_insert_with(HashMap::default)
                    .entry(qualified_name.clone())
                    .or_default()
                    .push(bound_ty.clone());
                bounds.push(Bound {
                    param_name: g.name.clone(),
                    generic: Type::Parameter(g.name.clone()),
                    ty: bound_ty,
                });
            }
        }

        let expected_params = expected_ty.get_function_params().unwrap_or_default();
        let new_params = self.infer_function_params(params, expected_params, true);
        let return_ty =
            self.infer_return_type(&return_annotation, expected_ty, &span, self.type_unit());
        self.scopes.current_mut().fn_return_type = Some(return_ty.clone());

        let base_fn_ty = Type::Function {
            param_mutability: new_params.iter().map(|p| p.mutable).collect(),
            params: new_params.iter().map(|p| p.ty.clone()).collect(),
            bounds,
            return_type: return_ty.clone().into(),
        };

        // Without a return annotation, the body is inferred against
        // `Type::ignored()` rather than against the return type.
        let has_implicit_unit_return = return_annotation == Annotation::Unknown;
        let body_ty = if has_implicit_unit_return {
            Type::ignored()
        } else {
            return_ty.clone()
        };
        let new_body = self.infer_function_body(body, &body_ty, &return_annotation, &return_ty);
        self.scopes.pop();

        self.check_constrained_return_type(
            &return_ty,
            &generics,
            &return_annotation.get_span(),
            &name,
        );
        self.check_unused_type_parameters(&generics, &base_fn_ty);

        // `base_fn_ty` is not used after this point, so it can be moved in
        // both branches.
        let fn_forall_ty = if generics.is_empty() {
            base_fn_ty
        } else {
            Type::Forall {
                vars: generics.iter().map(|g| g.name.clone()).collect(),
                body: Box::new(base_fn_ty),
            }
        };
        let (fn_ty, _) = self.instantiate(&fn_forall_ty);
        self.unify(expected_ty, &fn_ty, &span);

        Expression::Function {
            doc,
            attributes,
            name,
            name_span,
            generics,
            params: new_params,
            return_annotation,
            return_type: return_ty,
            visibility,
            body: new_body.into(),
            ty: fn_ty,
            span,
        }
    }

    /// Type-checks a lambda expression against `expected_ty`.
    ///
    /// Parameter types:
    /// - An unannotated parameter takes its type from the parameter at the
    ///   same index in `expected_ty`, when that type is a function.
    /// - No `self` receiver handling is applied.
    ///
    /// Return type:
    /// - An explicit annotation is converted and used.
    /// - Otherwise it comes from `expected_ty` when that is a function type, or
    ///   a constructor whose underlying type is a function.
    /// - Otherwise it is a fresh type variable.
    ///
    /// The body is inferred against the return type. The scope's loop depth is
    /// reset while doing so and restored afterwards, presumably so `break` and
    /// `continue` inside the lambda cannot target an enclosing loop.
    pub(super) fn infer_lambda(
        &mut self,
        params: Vec<Binding>,
        return_annotation: Annotation,
        body: Box<Expression>,
        span: Span,
        expected_ty: &Type,
    ) -> Expression {
        self.scopes.push();
        let expected_params = expected_ty.get_function_params().unwrap_or_default();
        let new_params = self.infer_function_params(params, expected_params, false);
        let default_return = self.new_type_var();
        let return_ty =
            self.infer_return_type(&return_annotation, expected_ty, &span, default_return);
        self.scopes.current_mut().fn_return_type = Some(return_ty.clone());

        // NOTE(review): lambda parameters are always recorded as non-mutable,
        // whereas `infer_function` uses each binding's `mutable` flag. Confirm
        // this difference is intended.
        let base_fn_ty = Type::Function {
            param_mutability: vec![false; new_params.len()],
            params: new_params.iter().map(|p| p.ty.clone()).collect(),
            bounds: vec![],
            return_type: return_ty.clone().into(),
        };

        let saved_loop_depth = self.scopes.reset_loop_depth();
        let new_body = self.infer_function_body(body, &return_ty, &return_annotation, &return_ty);
        self.scopes.restore_loop_depth(saved_loop_depth);
        self.scopes.pop();

        let (fn_ty, _) = self.instantiate(&base_fn_ty);
        self.unify(expected_ty, &fn_ty, &span);
        Expression::Lambda {
            params: new_params,
            return_annotation,
            body: new_body.into(),
            ty: fn_ty,
            span,
        }
    }

    /// Type-checks a call `expression(args)` with optional explicit
    /// `type_args`, and unifies the callee's return type with `expected_ty`.
    ///
    /// Steps, in order:
    /// 1. Infer the callee in callee context, then instantiate its (possibly
    ///    generic) type, applying explicit type arguments.
    /// 2. Split the instantiated type into parameter types, parameter
    ///    mutability, return type and trait bounds.
    /// 3. Report `panic` used in value context where a non-unit,
    ///    non-ignored, non-never, non-variable type is expected.
    /// 4. For generic callees returning an enum, speculatively unify with the
    ///    expected type before arguments are inferred (result ignored).
    /// 5. Infer and check the arguments, then unify the return type and
    ///    check trait bounds and native mutating calls.
    /// 6. Queue a post-inference check for generic, non-enum calls without
    ///    explicit type arguments.
    /// 7. Record the call's `CallKind` in `self.resolutions`.
    pub(super) fn infer_function_call(
        &mut self,
        expression: Box<Expression>,
        args: Vec<Expression>,
        type_args: Vec<Annotation>,
        span: Span,
        expected_ty: &Type,
    ) -> Expression {
        let callee_ty = self.new_type_var();
        let prev_context = self.scopes.set_callee_context();
        let callee_expression = self.infer_expression(*expression, &callee_ty);
        self.scopes.restore_use_context(prev_context);

        let forall_ty = self.resolve_callee_forall_type(&callee_expression, &type_args);
        let (callee_ty, new_type_args) =
            self.instantiate_callee_type(&forall_ty, &type_args, &callee_expression, &span);
        let (param_types, param_mutability, return_ty, bounds) =
            self.extract_call_signature(callee_ty, args.len(), &callee_expression);

        if self.is_panic_call(&callee_expression)
            && self.scopes.is_value_context()
            && !expected_ty.is_unit()
            && !expected_ty.is_ignored()
            && !expected_ty.is_never()
            && !expected_ty.is_variable()
        {
            self.sink
                .push(diagnostics::infer::panic_in_expression_position(span));
        }

        // Try the expected type against an enum-returning generic call before
        // the arguments are inferred; a failed attempt is deliberately ignored.
        if self.is_generic_callee(&callee_expression)
            && !expected_ty.resolve().is_variable()
            && !expected_ty.is_ignored()
            && self.is_enum_type(&return_ty.resolve())
            && self.has_interface_type_param(expected_ty)
        {
            let _ = self.speculatively(|this| this.try_unify(expected_ty, &return_ty, &span));
        }

        let new_args = self.infer_call_arguments(args, &param_types);
        for arg in &new_args {
            self.check_not_temp_producing(arg);
        }
        self.check_call_arity(&param_types, &new_args, &callee_expression, &span);
        self.check_mut_param_arguments(&new_args, &param_mutability, &callee_expression);

        // Captured before `unify`, which may bind the variable.
        let expected_was_variable = expected_ty.resolve().is_variable();
        self.unify(expected_ty, &return_ty, &span);
        self.unify_trait_bounds(&bounds, &new_args, &span);

        // Outside value context, the result counts as unused when the expected
        // type is unit, is ignored, or was an unresolved variable.
        let result_unused = prev_context != UseContext::Value && {
            let resolved = expected_ty.resolve();
            resolved.is_unit() || resolved.is_ignored() || expected_was_variable
        };
        self.check_native_mutating_call(&callee_expression, result_unused, &span);
        self.check_unconstrained_bounded_type_params(&bounds, &span);

        if self.is_generic_callee(&callee_expression)
            && type_args.is_empty()
            && !self.is_enum_type(&return_ty.resolve())
        {
            self.post_inference_checks
                .push(PostInferenceCheck::GenericCall {
                    return_ty: return_ty.clone(),
                    span,
                });
        }

        // `return_ty` is not used after this point, so it is moved rather
        // than cloned.
        let call_ty = if !expected_ty.is_variable()
            && self.is_generic_container_with_interface(expected_ty)
        {
            expected_ty.clone()
        } else {
            return_ty
        };

        let call_kind = self.classify_call(&callee_expression);
        self.resolutions.mark_call(span, call_kind);

        Expression::Call {
            expression: callee_expression.into(),
            args: new_args,
            type_args: new_type_args,
            ty: call_ty,
            span,
        }
    }

    /// Returns the un-instantiated (possibly `Forall`) type of a callee, so
    /// that explicit type arguments can be applied to it.
    ///
    /// Without type arguments, the callee's inferred type is returned as-is.
    /// With them, the lookup depends on the callee form:
    /// - Identifiers are looked up by name.
    /// - Dot accesses are resolved, in order, through:
    ///   1. the receiver type's methods;
    ///   2. a `"<type id>.<member>"` definition;
    ///   3. for `@import/` receivers, a `"<module>.<member>"` definition.
    ///
    /// Every miss falls back to the expression's inferred type.
    fn resolve_callee_forall_type(
        &mut self,
        expression: &Expression,
        type_args: &[Annotation],
    ) -> Type {
        if type_args.is_empty() {
            return expression.get_type();
        }
        match expression {
            Expression::Identifier { value, .. } => self
                .lookup_type(value)
                .unwrap_or_else(|| expression.get_type()),
            Expression::DotAccess {
                expression: receiver,
                member,
                ..
            } => {
                let receiver_ty = receiver.get_type().resolve();
                if let Some(method_ty) = self
                    .get_all_methods(&receiver_ty.strip_refs())
                    .get(member)
                    .cloned()
                {
                    return method_ty;
                }
                if let Type::Constructor { id, .. } = receiver_ty.strip_refs() {
                    let qualified = format!("{}.{}", id, member);
                    if let Some(definition) = self.store.get_definition(&qualified) {
                        return definition.ty().clone();
                    }
                    if let Some(module_id) = id.strip_prefix("@import/") {
                        let qualified = format!("{}.{}", module_id, member);
                        if let Some(definition) = self.store.get_definition(&qualified) {
                            return definition.ty().clone();
                        }
                    }
                }
                expression.get_type()
            }
            _ => expression.get_type(),
        }
    }

    /// Returns `true` when the callee has a `Forall` type.
    ///
    /// The callee must be an identifier whose looked-up type is `Forall`, or
    /// a method on the (ref-stripped) receiver type whose type is `Forall`.
    fn is_generic_callee(&self, expression: &Expression) -> bool {
        match expression {
            Expression::Identifier { value, .. } => self
                .lookup_type(value)
                .map(|ty| matches!(ty, Type::Forall { .. }))
                .unwrap_or(false),
            Expression::DotAccess {
                expression: receiver,
                member,
                ..
            } => {
                let receiver_ty = receiver.get_type().resolve();
                self.get_all_methods(&receiver_ty.strip_refs())
                    .get(member)
                    .map(|ty| matches!(ty, Type::Forall { .. }))
                    .unwrap_or(false)
            }
            _ => false,
        }
    }

    /// Instantiates a callee type, applying explicit type arguments if any.
    ///
    /// Returns the instantiated type together with the type arguments to
    /// store on the call. The second element is empty when no type arguments
    /// were applied.
    ///
    /// Non-`Forall` types:
    /// - Explicit type arguments report `type_args_on_non_generic`.
    /// - The type is then instantiated normally.
    ///
    /// `Forall` types with explicit type arguments:
    /// - The arguments may cover every variable ("full arity").
    /// - For a method whose receiver has generics, they may instead cover
    ///   only the variables after the receiver's ("method-only arity"). The
    ///   receiver's variables then get fresh type variables.
    /// - Any other count reports `generics_arity_mismatch`, and instantiation
    ///   proceeds from the annotations anyway.
    ///
    /// Dot-access callees with type arguments:
    /// - If the instantiated type has more parameters than the callee
    ///   expression's type, its first parameter is treated as the receiver.
    ///   That parameter is removed and unified with the receiver's type,
    ///   through its inner type when the parameter is a reference and the
    ///   receiver is not.
    /// - The instantiated type is then unified with the callee expression's
    ///   type.
    fn instantiate_callee_type(
        &mut self,
        forall_ty: &Type,
        type_args: &[Annotation],
        callee_expression: &Expression,
        span: &Span,
    ) -> (Type, Vec<Annotation>) {
        let Type::Forall { vars, body } = forall_ty else {
            if !type_args.is_empty() {
                self.sink.push(diagnostics::infer::type_args_on_non_generic(
                    type_args.len(),
                    *span,
                ));
            }
            let (instantiated, _) = self.instantiate(forall_ty);
            return (instantiated.resolve(), vec![]);
        };
        if type_args.is_empty() {
            let (instantiated, _) = self.instantiate(forall_ty);
            return (instantiated.resolve(), vec![]);
        }

        let receiver_generics_count =
            if let Expression::DotAccess { expression, .. } = callee_expression {
                let receiver_ty = expression.get_type().resolve().strip_refs().clone();
                self.get_receiver_generics_count(&receiver_ty)
            } else {
                0
            };
        let method_only_count = vars.len().saturating_sub(receiver_generics_count);
        let is_full_arity = type_args.len() == vars.len();
        let is_method_only_arity =
            receiver_generics_count > 0 && type_args.len() == method_only_count;
        if !is_full_arity && !is_method_only_arity {
            let actual_types: Vec<Type> = type_args
                .iter()
                .map(|arg| self.convert_to_type(arg, span))
                .collect();
            let vars_as_str: Vec<String> = vars.iter().map(|s| s.to_string()).collect();
            self.sink.push(diagnostics::infer::generics_arity_mismatch(
                &vars_as_str,
                type_args,
                &actual_types,
                *span,
            ));
        }

        let mut instantiated = if is_method_only_arity {
            // Receiver generics come first in `vars`; the explicit type
            // arguments fill only the method's own variables after them.
            let mut map: SubstitutionMap = SubstitutionMap::default();
            for var in &vars[..receiver_generics_count] {
                map.insert(var.clone(), self.new_type_var());
            }
            for (var, ann) in vars[receiver_generics_count..].iter().zip(type_args.iter()) {
                map.insert(var.clone(), self.convert_to_type(ann, span));
            }
            substitute(body, &map)
        } else {
            self.instantiate_from_annotations(vars, body, type_args, span)
        };

        if let Expression::DotAccess { expression, .. } = callee_expression {
            let receiver_ty = expression.get_type().resolve();
            let callee_params = callee_expression.get_type().resolve().param_count();
            let instantiated_params = instantiated.param_count();
            // An extra leading parameter in the instantiated type is the
            // receiver, which the callee expression already has bound.
            let has_receiver = instantiated_params > callee_params;
            if has_receiver
                && let Type::Function {
                    ref mut params,
                    ref mut param_mutability,
                    ..
                } = instantiated
                && !params.is_empty()
            {
                let receiver_param = params.remove(0);
                if !param_mutability.is_empty() {
                    param_mutability.remove(0);
                }
                let receiver_ty_stripped = receiver_ty.strip_refs();
                if receiver_param.is_ref() && !receiver_ty.is_ref() {
                    if let Some(inner) = receiver_param.inner() {
                        self.unify(&inner, &receiver_ty_stripped, span);
                    }
                } else {
                    self.unify(&receiver_param, &receiver_ty_stripped, span);
                }
            }
            self.unify(&instantiated, &callee_expression.get_type(), span);
        }
        (instantiated, type_args.to_vec())
    }

    /// Splits a callee type into
    /// `(param_types, param_mutability, return_type, bounds)`.
    ///
    /// - Variadic functions: the trailing variadic parameter is replaced by
    ///   copies of the variadic type until there are `arg_count` parameters.
    /// - Type-variable callees: every argument and the return get fresh type
    ///   variables.
    /// - Error-typed callees: `Type::Error` everywhere, with no new diagnostic.
    /// - Any other non-callable type reports `not_callable` and also yields
    ///   `Type::Error`.
    fn extract_call_signature(
        &mut self,
        callee_ty: Type,
        arg_count: usize,
        callee_expression: &Expression,
    ) -> (Vec<Type>, Vec<bool>, Type, Vec<Bound>) {
        let callee_ty = callee_ty.resolve();
        let bounds = callee_ty.get_bounds().to_vec();
        let param_mutability = callee_ty.get_param_mutability().to_vec();
        let is_variadic = callee_ty.is_variadic();
        let (param_types, return_ty) = match self.extract_function_type(&callee_ty) {
            Some((mut params, return_type)) => {
                if let Some(variadic_ty) = is_variadic {
                    params.pop();
                    while params.len() < arg_count {
                        params.push(variadic_ty.clone());
                    }
                }
                (params, return_type)
            }
            None if callee_ty.is_variable() => {
                let param_types = (0..arg_count).map(|_| self.new_type_var()).collect();
                let return_ty = self.new_type_var();
                (param_types, return_ty)
            }
            None if callee_ty.resolve().is_error() => {
                let param_types = (0..arg_count).map(|_| Type::Error).collect();
                let return_ty = Type::Error;
                (param_types, return_ty)
            }
            None => {
                self.sink.push(diagnostics::infer::not_callable(
                    &callee_ty,
                    callee_expression.get_span(),
                ));
                let param_types = (0..arg_count).map(|_| Type::Error).collect();
                let return_ty = Type::Error;
                (param_types, return_ty)
            }
        };
        (param_types, param_mutability, return_ty, bounds)
    }

    /// Returns `(params, return_type)` when `ty` is callable.
    ///
    /// `ty` counts as callable when it is one of:
    /// - a function type;
    /// - a constructor whose underlying type is a function;
    /// - a constructor naming a type alias whose resolved type is a
    ///   constructor with a function underlying type. The alias's generic
    ///   variables are first substituted with the constructor's params.
    fn extract_function_type(&self, ty: &Type) -> Option<(Vec<Type>, Type)> {
        let fn_type = |ty: &Type| -> Option<(Vec<Type>, Type)> {
            if let Type::Function {
                params,
                return_type,
                ..
            } = ty
            {
                Some((params.clone(), (**return_type).clone()))
            } else {
                None
            }
        };
        if let result @ Some(_) = fn_type(ty) {
            return result;
        }
        if let Type::Constructor {
            underlying_ty: Some(underlying),
            ..
        } = ty
            && let result @ Some(_) = fn_type(underlying)
        {
            return result;
        }
        if let Type::Constructor { id, params, .. } = ty
            && let Some(Definition::TypeAlias { ty: alias_ty, .. }) = self.store.get_definition(id)
        {
            let concrete_alias_ty = match alias_ty {
                Type::Forall { vars, body } => {
                    let map: SubstitutionMap =
                        vars.iter().cloned().zip(params.iter().cloned()).collect();
                    substitute(body, &map)
                }
                other => other.clone(),
            };
            let resolved = concrete_alias_ty.resolve();
            if let Type::Constructor {
                underlying_ty: Some(underlying),
                ..
            } = &resolved
            {
                return fn_type(underlying);
            }
        }
        None
    }

    /// Infers each argument in value context against the parameter type at
    /// the same index.
    ///
    /// Arguments beyond the parameter list get fresh type variables. Arity
    /// itself is checked by the caller via `check_call_arity`.
    fn infer_call_arguments(
        &mut self,
        args: Vec<Expression>,
        param_types: &[Type],
    ) -> Vec<Expression> {
        args.into_iter()
            .enumerate()
            .map(|(i, arg)| {
                let expected_ty = param_types
                    .get(i)
                    .cloned()
                    .unwrap_or_else(|| self.new_type_var());
                self.with_value_context(|s| s.infer_expression(arg, &expected_ty))
            })
            .collect()
    }

    /// Checks each trait bound whose generic has resolved to a concrete type.
    ///
    /// A bound is checked when its generic is no longer a type variable and
    /// its bound type names a known interface. The check calls
    /// `satisfies_interface`; its result is discarded here, so it presumably
    /// reports its own diagnostics.
    ///
    /// The span used is that of the first argument whose type equals the
    /// resolved type, else `fallback_span`.
    fn unify_trait_bounds(&mut self, bounds: &[Bound], args: &[Expression], fallback_span: &Span) {
        for bound in bounds {
            let resolved_ty = bound.generic.resolve();
            if resolved_ty.is_variable() {
                continue;
            }
            let interface_ty = bound.ty.resolve();
            let Type::Constructor { id, params, .. } = interface_ty else {
                continue;
            };
            let Some(interface) = self.store.get_interface(&id).cloned() else {
                continue;
            };
            let span = args
                .iter()
                .find(|arg| arg.get_type().resolve() == resolved_ty)
                .map(|arg| arg.get_span())
                .unwrap_or_else(|| *fallback_span);
            let _ = self.satisfies_interface(&resolved_ty, &interface, &params, &span);
        }
    }

    /// Infers a function or lambda body against `body_ty`.
    ///
    /// Special case: when the body is an empty block and there is an
    /// explicit, non-unit return annotation, this reports
    /// `empty_body_return_mismatch`. It then returns an empty unit-typed block
    /// without inferring anything.
    fn infer_function_body(
        &mut self,
        body: Box<Expression>,
        body_ty: &Type,
        return_annotation: &Annotation,
        return_ty: &Type,
    ) -> Expression {
        if let Expression::Block {
            items,
            span: body_span,
            ..
        } = body.as_ref()
            && items.is_empty()
            && *return_annotation != Annotation::Unknown
            && !return_ty.is_unit()
        {
            self.sink
                .push(diagnostics::infer::empty_body_return_mismatch(
                    return_ty,
                    return_annotation.get_span(),
                ));
            return Expression::Block {
                items: vec![],
                ty: self.type_unit(),
                span: *body_span,
            };
        }
        self.infer_expression(*body, body_ty)
    }

    /// Infers parameter bindings, brings them into scope via `infer_pattern`,
    /// and validates their patterns.
    ///
    /// Each parameter's type is chosen by the first rule that applies:
    /// 1. Unannotated: the entry of `expected_params` at the same index.
    /// 2. Otherwise, with `handle_self_receiver` set, an unannotated `self`
    ///    takes the current impl receiver type.
    /// 3. Otherwise, the converted annotation.
    /// 4. Otherwise, a fresh type variable.
    fn infer_function_params(
        &mut self,
        params: Vec<Binding>,
        expected_params: &[Type],
        handle_self_receiver: bool,
    ) -> Vec<Binding> {
        params
            .into_iter()
            .enumerate()
            .map(|(index, binding)| {
                let expected_param_ty = match binding.annotation {
                    None => expected_params.get(index).cloned(),
                    _ => None,
                };
                let binding_ty = expected_param_ty.unwrap_or_else(|| {
                    let pattern_span = &binding.pattern.get_span();
                    if handle_self_receiver
                        && let Pattern::Identifier { identifier, .. } = &binding.pattern
                        && identifier == "self"
                        && binding.annotation.is_none()
                        && let Some(impl_ty) = &self.inference.impl_receiver_type
                    {
                        return impl_ty.clone();
                    }
                    binding
                        .annotation
                        .as_ref()
                        .map(|a| self.convert_to_type(a, pattern_span))
                        .unwrap_or_else(|| self.new_type_var())
                });
                let (new_pattern, typed_pattern) = self.infer_pattern(
                    binding.pattern,
                    binding_ty.clone(),
                    BindingKind::Parameter {
                        mutable: binding.mutable,
                    },
                );
                check_binding_pattern(self.sink, &new_pattern);
                Binding {
                    pattern: new_pattern,
                    annotation: binding.annotation,
                    typed_pattern: Some(typed_pattern),
                    ty: binding_ty,
                    mutable: binding.mutable,
                }
            })
            .collect()
    }

    /// Determines a function's return type.
    ///
    /// An explicit annotation is converted to a type. Without one, the
    /// return type comes from `expected_ty` when that is a function type, or
    /// a constructor whose underlying type is a function. Otherwise
    /// `default_for_unknown` is used.
    fn infer_return_type(
        &mut self,
        annotation: &Annotation,
        expected_ty: &Type,
        span: &Span,
        default_for_unknown: Type,
    ) -> Type {
        match annotation {
            Annotation::Unknown => {
                if let Type::Function { return_type, .. } = expected_ty {
                    (**return_type).clone()
                } else if let Type::Constructor {
                    underlying_ty: Some(inner),
                    ..
                } = expected_ty
                    && let Type::Function { return_type, .. } = inner.as_ref()
                {
                    (**return_type).clone()
                } else {
                    default_for_unknown
                }
            }
            _ => self.convert_to_type(annotation, span),
        }
    }

    /// Classifies a (parenthesis-stripped) callee into a `CallKind`.
    ///
    /// Dot accesses:
    /// - `UfcsMethod` when `(receiver type id, member)` is in `ufcs_methods`.
    /// - `NativeMethod` when the receiver type maps to a `NativeTypeKind`.
    /// - `TupleStructConstructor` when the receiver is an `@import/` module
    ///   and `<module>.<member>` is a tuple struct.
    ///
    /// Identifiers:
    /// - `AssertType` for an undefined `assert_type`.
    /// - `TupleStructConstructor` for tuple structs.
    /// - `NativeConstructor` for the listed native constructors.
    /// - `NativeMethodIdentifier` for `<NativeType>.<method>`.
    /// - Otherwise whatever `try_classify_receiver_ufcs` finds.
    ///
    /// Everything else is `CallKind::Regular`.
    fn classify_call(&self, callee: &Expression) -> CallKind {
        let callee = callee.unwrap_parens();
        match callee {
            Expression::DotAccess {
                expression: receiver,
                member,
                ..
            } => {
                let receiver_ty = receiver.get_type().resolve().strip_refs();
                if let Type::Constructor { id, .. } = &receiver_ty
                    && self
                        .ufcs_methods
                        .contains(&(id.to_string(), member.to_string()))
                {
                    return CallKind::UfcsMethod;
                }
                if let Some(kind) = NativeTypeKind::from_type(&receiver.get_type()) {
                    return CallKind::NativeMethod(kind);
                }
                if let Type::Constructor { id, .. } = receiver.get_type().resolve()
                    && let Some(module_id) = id.strip_prefix("@import/")
                {
                    let qualified = format!("{}.{}", module_id, member);
                    if matches!(
                        self.store.get_definition(&qualified),
                        Some(Definition::Struct {
                            kind: StructKind::Tuple,
                            ..
                        })
                    ) {
                        return CallKind::TupleStructConstructor;
                    }
                }
            }
            Expression::Identifier { value, .. } => {
                let qualified = self.qualify_name(value);
                let definition = self.store.get_definition(&qualified);
                if definition.is_none() && value == "assert_type" {
                    return CallKind::AssertType;
                }
                if self.is_tuple_struct_definition(definition, callee) {
                    return CallKind::TupleStructConstructor;
                }
                let constructor_kind = match value.as_str() {
                    "Channel.new" | "Channel.buffered" => Some(NativeTypeKind::Channel),
                    "Map.new" => Some(NativeTypeKind::Map),
                    "Slice.new" => Some(NativeTypeKind::Slice),
                    _ => None,
                };
                if let Some(kind) = constructor_kind {
                    return CallKind::NativeConstructor(kind);
                }
                if let Some((prefix, _method)) = value.split_once('.')
                    && let Some(kind) = NativeTypeKind::from_name(prefix)
                {
                    return CallKind::NativeMethodIdentifier(kind);
                }
                if let Some(kind) = self.try_classify_receiver_ufcs(value) {
                    return kind;
                }
            }
            _ => {}
        }
        CallKind::Regular
    }

    /// Classifies an identifier of the form `Type.method` as
    /// `ReceiverMethodUfcs`, or returns `None`.
    ///
    /// All of these must hold:
    /// - `Type` resolves to a struct, enum or type-alias definition;
    /// - that definition has `method`, and it takes at least one parameter;
    /// - `(Type, method)` is not already in `ufcs_methods`.
    ///
    /// `is_public` is the visibility of the `Type.method` definition, and
    /// `false` when that definition is missing.
    fn try_classify_receiver_ufcs(&self, value: &str) -> Option<CallKind> {
        let last_dot = value.rfind('.')?;
        let method = &value[last_dot + 1..];
        let type_part = &value[..last_dot];
        let qualified_name = self.lookup_qualified_name(type_part)?;
        let definition = self.store.get_definition(&qualified_name)?;
        let methods = match definition {
            Definition::Struct { methods, .. } => methods,
            Definition::Enum { methods, .. } => methods,
            Definition::TypeAlias { methods, .. } => methods,
            _ => return None,
        };
        let method_ty = methods.get(method)?;
        let has_self = match method_ty {
            Type::Function { params, .. } => !params.is_empty(),
            Type::Forall { body, .. } => {
                if let Type::Function { params, .. } = body.as_ref() {
                    !params.is_empty()
                } else {
                    false
                }
            }
            _ => false,
        };
        if !has_self {
            return None;
        }
        if self
            .ufcs_methods
            .contains(&(qualified_name.to_string(), method.to_string()))
        {
            return None;
        }
        let is_public = self
            .store
            .get_definition(&format!("{}.{}", qualified_name, method))
            .map(|d| d.visibility().is_public())
            .unwrap_or(false);
        Some(CallKind::ReceiverMethodUfcs { is_public })
    }

    /// Returns `true` when the callee constructs a tuple struct.
    ///
    /// That is the case when `definition` is a tuple struct, or when it is a
    /// type alias and the callee's type is a function returning a tuple
    /// struct.
    fn is_tuple_struct_definition(
        &self,
        definition: Option<&Definition>,
        callee: &Expression,
    ) -> bool {
        if matches!(
            definition,
            Some(Definition::Struct {
                kind: StructKind::Tuple,
                ..
            })
        ) {
            return true;
        }
        if matches!(definition, Some(Definition::TypeAlias { .. })) {
            let ty = callee.get_type().resolve();
            let return_ty = match ty.unwrap_forall() {
                Type::Function { return_type, .. } => return_type.as_ref().clone(),
                _ => return false,
            };
            if let Type::Constructor { id, .. } = return_ty.resolve() {
                return matches!(
                    self.store.get_definition(&id),
                    Some(Definition::Struct {
                        kind: StructKind::Tuple,
                        ..
                    })
                );
            }
        }
        false
    }

    /// Returns `true` when the callee is the bare identifier `panic`.
    fn is_panic_call(&self, expression: &Expression) -> bool {
        match expression {
            Expression::Identifier { value, .. } => value == "panic",
            _ => false,
        }
    }

    /// Returns `true` when the callee is `prefix.member` and `prefix` is an
    /// import alias for a module id starting with `go:`.
    fn is_external_callee(&self, expression: &Expression) -> bool {
        if let Expression::DotAccess {
            expression: base, ..
        } = expression
            && let Expression::Identifier { value, .. } = base.as_ref()
        {
            return self
                .imports
                .prefix_to_module
                .get(value.as_ref())
                .is_some_and(|module_id| module_id.starts_with("go:"));
        }
        false
    }

    /// Checks method calls that mutate a native `Slice` or `Map` receiver.
    ///
    /// First, `append`/`extend` on a `Slice` reached through a map field
    /// reports `map_field_chain_assignment` and stops, unless the receiver
    /// chain contains a numeric member.
    ///
    /// Otherwise these calls count as mutations:
    /// - `append`/`extend` on a `Slice` whose result is unused;
    /// - `delete` on a `Map`.
    ///
    /// For a mutation, the receiver's binding (if found) is marked mutated.
    /// `disallowed_mutation` is reported unless the receiver is dereferenced,
    /// bound to a reference type, or declared mutable.
    fn check_native_mutating_call(
        &mut self,
        callee: &Expression,
        result_unused: bool,
        span: &Span,
    ) {
        let Expression::DotAccess {
            expression: receiver,
            member,
            ..
        } = callee
        else {
            return;
        };
        let receiver_ty = receiver.get_type().resolve().strip_refs();
        if matches!(receiver_ty.get_name(), Some("Slice"))
            && (member == "append" || member == "extend")
            && self.has_map_field_in_chain(receiver)
            && !has_numeric_member_in_chain(receiver)
        {
            self.sink
                .push(diagnostics::infer::map_field_chain_assignment(*span));
            return;
        }
        let is_mutating = match receiver_ty.get_name() {
            Some("Slice") => (member == "append" || member == "extend") && result_unused,
            Some("Map") => member == "delete",
            _ => false,
        };
        if !is_mutating {
            return;
        }
        let Some(var_name) = receiver.get_var_name() else {
            return;
        };
        if let Some(binding_id) = self.scopes.lookup_binding_id(&var_name) {
            self.facts.mark_mutated(binding_id);
        }
        let is_deref = contains_deref(receiver);
        let binding_is_ref = self
            .scopes
            .lookup_value(&var_name)
            .map(|t| t.resolve().is_ref())
            .unwrap_or(false);
        if !is_deref && !binding_is_ref && !self.scopes.lookup_mutable(&var_name) {
            self.sink.push(diagnostics::infer::disallowed_mutation(
                &var_name, *span, None,
            ));
        }
    }

    /// Checks arguments passed to `mut` parameters.
    ///
    /// Only arguments that are plain variables (per `get_var_name`) are
    /// checked. For each one:
    /// - `immutable_argument_to_mut_param` is reported when the variable is
    ///   not mutable, with the callee's external (`go:`) status attached;
    /// - the variable's binding is marked mutated in either case.
    fn check_mut_param_arguments(
        &mut self,
        args: &[Expression],
        param_mutability: &[bool],
        callee: &Expression,
    ) {
        let is_external = self.is_external_callee(callee);
        for (i, arg) in args.iter().enumerate() {
            let is_mut_param = param_mutability.get(i).copied().unwrap_or(false);
            if !is_mut_param {
                continue;
            }
            if let Some(var_name) = arg.get_var_name() {
                if !self.scopes.lookup_mutable(&var_name) {
                    self.sink
                        .push(diagnostics::infer::immutable_argument_to_mut_param(
                            &var_name,
                            arg.get_span(),
                            is_external,
                        ));
                }
                if let Some(binding_id) = self.scopes.lookup_binding_id(&var_name) {
                    self.facts.mark_mutated(binding_id);
                }
            }
        }
    }
}