use ecow::EcoString;
use syntax::ast::{
Annotation, Expression, Generic, ParentInterface, Pattern, Span,
Visibility as SyntacticVisibility,
};
use syntax::program::{Definition, Interface, Visibility};
use syntax::types::Type;
use super::{extract_attribute_flags, has_recursive_instantiation, wrap_with_impl_generics};
use crate::checker::Checker;
impl Checker<'_, '_> {
pub(super) fn try_register_instance_method(
&mut self,
module_id: &str,
receiver_qualified_name: &str,
type_name: &str,
fn_name: &str,
fn_name_span: Span,
method_ty: &Type,
) -> bool {
let module = self
.store
.get_module_mut(module_id)
.expect("current module must exist in store");
let Some(definition) = module.definitions.get_mut(receiver_qualified_name) else {
return false;
};
if let Definition::Struct { fields, .. } = &*definition
&& fields.iter().any(|f| f.name == fn_name)
{
self.sink.push(diagnostics::infer::method_shadows_field(
type_name,
fn_name,
fn_name_span,
));
}
if let Definition::Enum { variants, .. } = &*definition {
for variant in variants {
if variant.fields.is_struct() && variant.fields.iter().any(|f| f.name == fn_name) {
self.sink.push(diagnostics::infer::method_shadows_field(
type_name,
fn_name,
fn_name_span,
));
break;
}
}
}
if let Some(methods) = definition.methods_mut() {
methods.insert(fn_name.into(), method_ty.clone());
}
!matches!(definition, Definition::ValueEnum { .. })
}
fn check_duplicate_method(
&self,
module_id: &str,
receiver_qualified_name: &str,
type_name: &str,
fn_name: &str,
fn_name_span: Span,
impl_generics_empty: bool,
) {
let module_qualified_name = format!("{module_id}.{type_name}.{fn_name}");
let module = self
.store
.get_module(module_id)
.expect("current module must exist in store");
if !module
.definitions
.contains_key(module_qualified_name.as_str())
{
return;
}
let is_cross_specialization = impl_generics_empty
&& matches!(
module.definitions.get(receiver_qualified_name),
Some(Definition::Struct { generics: struct_generics, .. })
if !struct_generics.is_empty()
);
if is_cross_specialization {
let struct_generic_names: Vec<String> =
match module.definitions.get(receiver_qualified_name) {
Some(Definition::Struct { generics: g, .. }) => {
g.iter().map(|g| g.name.to_string()).collect()
}
_ => vec![],
};
self.sink.push(
diagnostics::infer::duplicate_method_across_specialized_impls(
fn_name,
type_name,
&struct_generic_names,
fn_name_span,
),
);
} else {
self.sink.push(diagnostics::infer::duplicate_impl_item(
fn_name,
type_name,
fn_name_span,
));
}
}
pub(super) fn populate_impl_methods(
&mut self,
annotation: &Annotation,
generics: &[Generic],
functions: &[Expression],
span: &Span,
) {
self.scopes.push();
self.put_in_scope(generics);
self.check_undeclared_impl_type_params(annotation, generics);
let receiver_ty = self.convert_to_type(annotation, span);
let Some(type_name) = receiver_ty.get_name() else {
self.scopes.pop();
return;
};
let receiver_qualified_name = receiver_ty.get_qualified_name();
let module_id = self.cursor.module_id.clone();
if let Some(type_module) = self
.store
.module_for_qualified_name(&receiver_qualified_name)
&& type_module != module_id
{
self.sink.push(diagnostics::infer::impl_on_foreign_type(
type_name,
type_module,
*span,
));
self.scopes.pop();
return;
}
let mut impl_bounds: Vec<syntax::types::Bound> = Vec::new();
for g in generics {
for b in &g.bounds {
let bound_ty = self.convert_to_type(b, span);
impl_bounds.push(syntax::types::Bound {
param_name: g.name.clone(),
generic: Type::Parameter(g.name.clone()),
ty: bound_ty,
});
}
}
let mut static_methods: Vec<(String, Type)> = Vec::new();
for function in functions {
let fn_attrs = if let Expression::Function { attributes, .. } = function {
attributes.as_slice()
} else {
&[]
};
let fn_visibility = if let Expression::Function { visibility, .. } = function
&& (*visibility == SyntacticVisibility::Public || self.is_d_lis())
{
Visibility::Public
} else {
Visibility::Private
};
let fn_sig = function.to_function_signature();
let mut fn_ty = self.extract_function_signature(&fn_sig, span);
let qualified_name = format!("{}.{}", type_name, fn_sig.name);
let module_qualified_name = format!("{}.{}", module_id, qualified_name);
let is_instance_method = fn_sig.params.first().is_some_and(|p| {
matches!(p.pattern, Pattern::Identifier { ref identifier, .. } if identifier == "self")
});
let has_unannotated_self = fn_sig
.params
.first()
.is_some_and(|p| p.annotation.is_none());
if is_instance_method && has_unannotated_self {
fn_ty = fn_ty.with_replaced_first_param(&receiver_ty);
}
let method_ty = wrap_with_impl_generics(&fn_ty, generics, &impl_bounds);
if !generics.is_empty()
&& self.impl_has_simple_type_params(&receiver_ty, generics)
&& has_recursive_instantiation(&receiver_qualified_name, &fn_ty)
{
self.sink
.push(diagnostics::infer::recursive_generic_instantiation(
type_name,
fn_sig.name_span,
));
}
if is_instance_method {
if !self.try_register_instance_method(
&module_id,
&receiver_qualified_name,
type_name,
&fn_sig.name,
fn_sig.name_span,
&method_ty,
) {
continue;
}
} else {
static_methods.push((qualified_name, method_ty.clone()));
}
self.check_duplicate_method(
&module_id,
&receiver_qualified_name,
type_name,
&fn_sig.name,
fn_sig.name_span,
generics.is_empty(),
);
let module = self
.store
.get_module_mut(&module_id)
.expect("current module must exist in store");
module.definitions.insert(
module_qualified_name.into(),
Definition::Value {
visibility: fn_visibility.clone(),
ty: method_ty,
name_span: Some(fn_sig.name_span),
allowed_lints: extract_attribute_flags(fn_attrs, "allow"),
go_hints: extract_attribute_flags(fn_attrs, "go"),
go_name: None,
doc: None,
},
);
}
self.scopes.pop();
let scope = self.scopes.current_mut();
for (name, ty) in static_methods {
scope.values.insert(name, ty);
}
}
#[allow(clippy::too_many_arguments)]
pub(super) fn populate_interface(
&mut self,
interface_name: &str,
name_span: &Span,
generics: &[Generic],
parents: &[ParentInterface],
fn_expressions: &[Expression],
span: &Span,
doc: &Option<String>,
) {
self.scopes.push();
self.put_in_scope(generics);
self.validate_generic_bounds(generics, span);
let new_parents = parents
.iter()
.map(|s| self.convert_to_type(&s.annotation, &s.span))
.collect();
let mut method_defs: Vec<(EcoString, Type, Vec<String>)> = Vec::new();
let methods = fn_expressions
.iter()
.map(|fe| {
let fn_attrs = if let Expression::Function { attributes, .. } = fe {
attributes.as_slice()
} else {
&[]
};
let method_sig = fe.to_function_signature();
let fn_ty = self.extract_function_signature(&method_sig, span);
let fn_ty = match &fn_ty {
Type::Forall { body, .. } => body.as_ref().clone(),
_ => fn_ty,
};
let has_self_receiver = method_sig.params.first().is_some_and(|p| {
matches!(p.pattern, Pattern::Identifier { ref identifier, .. } if identifier == "self")
&& p.annotation.is_none()
});
let fn_ty = if has_self_receiver {
match fn_ty {
Type::Function {
params,
param_mutability,
bounds,
return_type,
} => Type::Function {
params: params[1..].to_vec(),
param_mutability: if param_mutability.is_empty() {
vec![]
} else {
param_mutability[1..].to_vec()
},
bounds,
return_type,
},
other => other,
}
} else {
fn_ty
};
method_defs.push((
method_sig.name.clone(),
fn_ty.clone(),
extract_attribute_flags(fn_attrs, "go"),
));
(method_sig.name, fn_ty)
})
.collect();
self.scopes.pop();
let qualified_name = self.qualify_name(interface_name);
let interface_ty = self
.store
.get_type(&qualified_name)
.expect("interface type scheme must exist")
.clone();
let interface = Interface {
name: interface_name.into(),
generics: generics.to_owned(),
parents: new_parents,
methods,
};
let visibility = self
.store
.get_module(&self.cursor.module_id)
.expect("current module must exist in store")
.definitions
.get(qualified_name.as_str())
.map(|definition| definition.visibility().clone())
.unwrap_or(Visibility::Private);
let module = self
.store
.get_module_mut(&self.cursor.module_id)
.expect("current module must exist in store");
module.definitions.insert(
qualified_name.clone().into(),
Definition::Interface {
visibility: visibility.clone(),
ty: interface_ty,
name_span: *name_span,
definition: interface,
doc: doc.clone(),
},
);
let module_id = self.cursor.module_id.clone();
for (method_name, method_ty, go_hints) in method_defs {
let method_qualified_name = format!("{}.{}.{}", module_id, interface_name, method_name);
module.definitions.insert(
method_qualified_name.into(),
Definition::Value {
visibility: visibility.clone(),
ty: method_ty,
name_span: None, allowed_lints: vec![],
go_hints,
go_name: None,
doc: None,
},
);
}
self.check_interface_embedding(&qualified_name, interface_name, name_span);
}
fn check_interface_embedding(
&mut self,
qualified_name: &str,
interface_name: &str,
span: &Span,
) {
let interface = match self.store.get_interface(qualified_name) {
Some(iface) => iface,
None => return,
};
for parent_ty in &interface.parents {
if let Some(parent_id) = parent_ty.get_qualified_id()
&& parent_id == qualified_name
{
self.sink.push(diagnostics::infer::interface_self_embedding(
interface_name,
*span,
));
return; }
}
let mut visited = rustc_hash::FxHashSet::default();
let mut path = vec![qualified_name.to_string()];
visited.insert(qualified_name.to_string());
for parent_ty in &interface.parents {
if let Some(parent_id) = parent_ty.get_qualified_id()
&& let Some(cycle) = self.detect_interface_cycle(parent_id, &mut visited, &mut path)
{
self.sink
.push(diagnostics::infer::interface_embedding_cycle(&cycle, *span));
return; }
}
let mut inherited_methods: Vec<(String, Type, String)> = Vec::new();
let mut method_visited = rustc_hash::FxHashSet::default();
for parent_ty in &interface.parents {
if let Some(parent_id) = parent_ty.get_qualified_id() {
let parent_simple_name = parent_id.rsplit('.').next().unwrap_or(parent_id);
self.collect_interface_methods(
parent_id,
parent_simple_name,
&mut inherited_methods,
&mut method_visited,
);
}
}
let mut seen: rustc_hash::FxHashMap<String, (Type, String)> =
rustc_hash::FxHashMap::default();
for (method_name, method_ty, source) in &inherited_methods {
if let Some((existing_ty, existing_source)) = seen.get(method_name) {
if existing_ty != method_ty {
self.sink
.push(diagnostics::infer::interface_method_conflict(
interface_name,
method_name,
existing_source,
source,
*span,
));
}
} else {
seen.insert(method_name.clone(), (method_ty.clone(), source.clone()));
}
}
}
fn detect_interface_cycle(
&self,
current_id: &str,
visited: &mut rustc_hash::FxHashSet<String>,
path: &mut Vec<String>,
) -> Option<Vec<String>> {
if !visited.insert(current_id.to_string()) {
let simple = |id: &str| -> String { id.rsplit('.').next().unwrap_or(id).to_string() };
if let Some(position) = path.iter().position(|p| p == current_id) {
let mut cycle: Vec<String> = path[position..].iter().map(|p| simple(p)).collect();
cycle.push(simple(current_id));
return Some(cycle);
}
return None;
}
path.push(current_id.to_string());
if let Some(interface) = self.store.get_interface(current_id) {
for parent_ty in &interface.parents {
if let Some(parent_id) = parent_ty.get_qualified_id()
&& let Some(cycle) = self.detect_interface_cycle(parent_id, visited, path)
{
path.pop();
return Some(cycle);
}
}
}
path.pop();
visited.remove(current_id); None
}
fn collect_interface_methods(
&self,
interface_id: &str,
source_name: &str,
methods: &mut Vec<(String, Type, String)>,
visited: &mut rustc_hash::FxHashSet<String>,
) {
if !visited.insert(interface_id.to_string()) {
return;
}
if let Some(interface) = self.store.get_interface(interface_id) {
for (method_name, method_ty) in &interface.methods {
methods.push((
method_name.to_string(),
method_ty.clone(),
source_name.to_string(),
));
}
for parent_ty in &interface.parents {
if let Some(parent_id) = parent_ty.get_qualified_id() {
let parent_simple = parent_id.rsplit('.').next().unwrap_or(parent_id);
self.collect_interface_methods(parent_id, parent_simple, methods, visited);
}
}
}
}
fn impl_has_simple_type_params(&self, receiver_ty: &Type, generics: &[Generic]) -> bool {
let params = match receiver_ty {
Type::Constructor { params, .. } => params,
_ => return false,
};
if params.len() != generics.len() {
return false;
}
params
.iter()
.zip(generics.iter())
.all(|(param, generic)| matches!(param, Type::Parameter(name) if *name == generic.name))
}
}