use ecow::EcoString;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
use std::collections::BTreeMap;
use syntax::ast::Visibility;
use syntax::program::MethodSignatures;
use syntax::types::{CompoundKind, Symbol, Type};
use crate::store::Store;
/// The kind of member a selector resolved to, together with its type.
#[derive(Clone, Debug)]
pub enum MemberKind {
    /// A struct field: its declared type and declared visibility.
    Field { ty: Type, visibility: Visibility },
    /// A method signature. For members found below depth 0 the receiver
    /// (first parameter) has been rewritten in terms of the outer type.
    Method { ty: Type },
}
/// A selector `outer.name` that resolved to exactly one member.
#[derive(Clone, Debug)]
pub struct ResolvedMember {
    /// The selected member name.
    pub name: EcoString,
    /// Embedding depth of the declaring type; `0` means declared on `outer` itself.
    pub depth: usize,
    /// Names of the embedded fields traversed from `outer` down to the
    /// declaring type, outermost first (empty at depth 0).
    pub embed_path: Vec<EcoString>,
    /// The nominal type that declares the member.
    pub declaring_type: Symbol,
    /// `true` if at least one embed along `embed_path` is a pointer (`&T`) embed.
    pub indirect: bool,
    /// Whether the member is a field or a method, with its type.
    pub kind: MemberKind,
}
/// Result of looking up a selector name on a type.
#[derive(Clone, Debug)]
pub enum Resolution {
    /// A single member is visible at the shallowest depth declaring the name.
    Found(ResolvedMember),
    /// The name is declared more than once at the shallowest depth, or the
    /// declaring type is reachable through several embedding paths there.
    /// `sources` holds the declaring types, sorted and deduplicated.
    Ambiguous { sources: Vec<Symbol> },
    /// No type reachable from the outer type declares the name.
    NotFound,
}
/// Returns `true` when `ty`, after stripping references and resolving
/// aliases, is a nominal type whose own field list contains at least one
/// embedded field.
///
/// Only direct fields are inspected; embeds of embeds are not followed.
/// Non-nominal types and types without a known field list yield `false`.
pub fn has_direct_embed(store: &Store, ty: &Type) -> bool {
    if let Type::Nominal { id, .. } = &store.deep_resolve_alias(&ty.strip_refs()) {
        store
            .fields_of(id.as_str())
            .is_some_and(|fields| fields.iter().any(|field| field.embedded))
    } else {
        false
    }
}
/// Resolves the selector `outer.name` against `outer`'s own members and every
/// member promoted through embedded fields.
///
/// The shallowest declaration wins. A tie at that depth yields
/// [`Resolution::Ambiguous`].
pub fn resolve_selector(store: &Store, outer: &Type, name: &str) -> Resolution {
    resolve_in_entries(store, &walk(store, outer), outer, name)
}
/// Computes the method set of `outer`: its own methods plus every method
/// promoted through embedded fields.
///
/// A name is included only when it resolves unambiguously to a method. Names
/// that resolve to a field, or are ambiguous, are left out. Promoted
/// signatures carry the receiver rewritten by `build_member`.
pub fn promoted_method_set(store: &Store, outer: &Type) -> MethodSignatures {
    let entries = walk(store, outer);

    // Every name declared anywhere in the embedding graph is a potential member.
    let mut candidate_names: HashSet<EcoString> = HashSet::default();
    entries
        .iter()
        .for_each(|entry| collect_member_names(store, &entry.ty, &mut candidate_names));

    let mut methods = MethodSignatures::default();
    for name in candidate_names {
        let Resolution::Found(member) = resolve_in_entries(store, &entries, outer, &name) else {
            continue;
        };
        if let MemberKind::Method { ty } = member.kind {
            methods.insert(name, ty);
        }
    }
    methods
}
/// One nominal type reached while walking the embedding graph of an outer type.
#[derive(Clone)]
struct Entry {
    /// The alias-resolved nominal type at this node.
    ty: Type,
    /// Number of embedding steps from the outer type (the root is 0).
    depth: usize,
    /// Embedded field names traversed from the root to reach this type.
    embed_path: Vec<EcoString>,
    /// `true` if any embed on the path to this node was a pointer embed.
    indirect: bool,
    /// `true` if this node, or an ancestor on its path, was reached through
    /// more than one path at the same depth.
    multiples: bool,
}
/// Breadth-first walk of the embedding graph rooted at `outer`.
///
/// Returns one [`Entry`] per distinct nominal type reached through embedded
/// fields, starting with `outer` itself at depth 0, in non-decreasing depth
/// order. A type already visited at a shallower depth is skipped, which also
/// stops the walk on embedding cycles (e.g. `A` embeds `&B`, `B` embeds `A`).
/// Types reached several times within one depth are merged by [`consolidate`]
/// and flagged with `multiples`.
///
/// Returns an empty list when `outer` does not resolve to a nominal type.
fn walk(store: &Store, outer: &Type) -> Vec<Entry> {
    let mut visited: Vec<Entry> = Vec::new();
    let mut seen: HashSet<String> = HashSet::default();
    let Some(root) = nominal_entry(store, outer.clone(), 0, false, false, Vec::new()) else {
        return visited;
    };
    let mut current = vec![root];
    while !current.is_empty() {
        let mut next: Vec<Entry> = Vec::new();
        // `current` is consumed here and replaced below, so entries can be
        // moved into `visited` instead of cloned.
        for entry in current {
            // A type first reached at a shallower depth shadows this occurrence.
            if !seen.insert(type_key(&entry.ty)) {
                continue;
            }
            push_embedded_children(store, &entry, &mut next);
            visited.push(entry);
        }
        current = consolidate(next);
    }
    visited
}

/// Appends to `out` one [`Entry`], one level deeper than `parent`, for every
/// embedded field of `parent`'s type whose target resolves to a nominal type.
///
/// Interfaces are not expanded. Types without a qualified id or a known field
/// list contribute nothing.
fn push_embedded_children(store: &Store, parent: &Entry, out: &mut Vec<Entry>) {
    let Some(id) = parent.ty.get_qualified_id() else {
        return;
    };
    if store.get_interface(id).is_some() {
        return;
    }
    let Some(fields) = store.fields_of(id) else {
        return;
    };
    for field in fields {
        if !field.embedded {
            continue;
        }
        // An embedded `&T` contributes `T`; the pointer step makes the child indirect.
        let (target, is_pointer) = deref_once(&store.deep_resolve_alias(&field.ty));
        let mut path = parent.embed_path.clone();
        path.push(field.name.clone());
        if let Some(child) = nominal_entry(
            store,
            target,
            parent.depth + 1,
            parent.indirect || is_pointer,
            parent.multiples,
            path,
        ) {
            out.push(child);
        }
    }
}
/// Resolves `name` against the walked `entries` of `outer`.
///
/// Only the shallowest depth at which some entry declares `name` matters,
/// because deeper declarations are shadowed. A single declaration there from
/// a non-`multiples` entry is [`Resolution::Found`]. Anything else is
/// [`Resolution::Ambiguous`], listing the sorted, deduplicated declaring types.
fn resolve_in_entries(store: &Store, entries: &[Entry], outer: &Type, name: &str) -> Resolution {
    let mut buckets: BTreeMap<usize, Vec<(&Entry, Candidate)>> = BTreeMap::new();
    for (entry, candidate) in entries
        .iter()
        .filter_map(|e| entry_candidate(store, &e.ty, name).map(|c| (e, c)))
    {
        buckets
            .entry(entry.depth)
            .or_default()
            .push((entry, candidate));
    }

    // BTreeMap keys are ordered, so the first bucket is the shallowest depth.
    let Some((_, shallowest)) = buckets.pop_first() else {
        return Resolution::NotFound;
    };

    match shallowest.as_slice() {
        [(entry, candidate)] if !entry.multiples => {
            Resolution::Found(build_member(outer, name, entry, candidate))
        }
        competing => {
            let mut sources: Vec<Symbol> = competing
                .iter()
                .map(|(_, c)| c.declaring_type.clone())
                .collect();
            sources.sort();
            sources.dedup();
            Resolution::Ambiguous { sources }
        }
    }
}
/// A declaration of the looked-up name found on one walked type.
struct Candidate {
    /// The nominal type that declares the member.
    declaring_type: Symbol,
    /// Field or method details of the declaration.
    detail: CandidateDetail,
}
/// What a [`Candidate`] declares: a field or a method, with its raw
/// (not yet receiver-adjusted) type.
enum CandidateDetail {
    /// A struct field with its declared type and visibility.
    Field {
        ty: Type,
        visibility: Visibility,
    },
    /// A method with its signature as stored on the declaring type.
    Method {
        ty: Type,
    },
}
/// Looks up `name` directly on the nominal type `ty` (no embedding).
///
/// For interfaces, only the methods reported by `get_all_methods` are
/// consulted. For other types, own methods take precedence over fields.
/// Returns `None` if `ty` has no qualified id or does not declare `name`.
fn entry_candidate(store: &Store, ty: &Type, name: &str) -> Option<Candidate> {
    let id = ty.get_qualified_id()?;

    let detail = if store.get_interface(id).is_some() {
        let methods = store.get_all_methods(ty, &Default::default());
        CandidateDetail::Method {
            ty: methods.get(name)?.clone(),
        }
    } else if let Some(method_ty) = store.get_own_methods(id).and_then(|own| own.get(name)) {
        CandidateDetail::Method {
            ty: method_ty.clone(),
        }
    } else {
        let fields = store.fields_of(id)?;
        let field = fields.iter().find(|f| f.name == name)?;
        CandidateDetail::Field {
            ty: field.ty.clone(),
            visibility: field.visibility,
        }
    };

    Some(Candidate {
        declaring_type: Symbol::from_raw(id),
        detail,
    })
}
fn build_member(outer: &Type, name: &str, entry: &Entry, candidate: &Candidate) -> ResolvedMember {
let kind = match &candidate.detail {
CandidateDetail::Field { ty, visibility } => MemberKind::Field {
ty: ty.clone(),
visibility: *visibility,
},
CandidateDetail::Method { ty } => {
let method_ty = if entry.depth == 0 {
ty.clone()
} else if !entry.indirect && method_has_pointer_receiver(ty) {
ty.with_replaced_first_param(&ref_of(outer))
} else {
ty.with_replaced_first_param(outer)
};
MemberKind::Method { ty: method_ty }
}
};
ResolvedMember {
name: name.into(),
depth: entry.depth,
embed_path: entry.embed_path.clone(),
declaring_type: candidate.declaring_type.clone(),
indirect: entry.indirect,
kind,
}
}
/// Adds to `names` every member name declared directly on the nominal type `ty`.
///
/// For interfaces, these are the keys of `get_all_methods`. For other types,
/// they are the keys of the own-method table plus every field name.
/// Types without a qualified id add nothing.
fn collect_member_names(store: &Store, ty: &Type, names: &mut HashSet<EcoString>) {
    let Some(id) = ty.get_qualified_id() else {
        return;
    };

    if store.get_interface(id).is_some() {
        store
            .get_all_methods(ty, &Default::default())
            .keys()
            .for_each(|key| {
                names.insert(key.clone());
            });
        return;
    }

    if let Some(own_methods) = store.get_own_methods(id) {
        own_methods.keys().for_each(|key| {
            names.insert(key.clone());
        });
    }
    if let Some(fields) = store.fields_of(id) {
        fields.iter().for_each(|field| {
            names.insert(field.name.clone());
        });
    }
}
/// Builds an [`Entry`] for `target` after resolving aliases.
///
/// Returns `None` when the resolved type is not nominal, so such embeds are
/// dropped from the walk.
fn nominal_entry(
    store: &Store,
    target: Type,
    depth: usize,
    indirect: bool,
    multiples: bool,
    embed_path: Vec<EcoString>,
) -> Option<Entry> {
    match store.deep_resolve_alias(&target) {
        ty @ Type::Nominal { .. } => Some(Entry {
            ty,
            depth,
            indirect,
            multiples,
            embed_path,
        }),
        _ => None,
    }
}
/// Merges entries of one depth level that refer to the same type (by
/// `type_key`).
///
/// The first occurrence is kept, in original order, and marked `multiples`
/// if any later duplicate is seen. Later duplicates are dropped.
fn consolidate(list: Vec<Entry>) -> Vec<Entry> {
    let mut merged: Vec<Entry> = Vec::with_capacity(list.len());
    let mut position: HashMap<String, usize> = HashMap::default();
    for entry in list {
        let key = type_key(&entry.ty);
        match position.get(&key).copied() {
            Some(existing) => merged[existing].multiples = true,
            None => {
                position.insert(key, merged.len());
                merged.push(entry);
            }
        }
    }
    merged
}
/// Builds a string key identifying a type, used to detect repeated types
/// while walking the embedding graph (the `seen` set in `walk`, and
/// `consolidate`).
///
/// Nominal types are keyed by their qualified id (`Symbol::as_str`). Generic
/// nominal types append the keys of their type arguments, comma-separated,
/// in angle brackets. Any other type falls back to its `Display` form.
fn type_key(ty: &Type) -> String {
    match ty {
        Type::Nominal { id, params, .. } if params.is_empty() => id.as_str().to_string(),
        Type::Nominal { id, params, .. } => {
            let args: Vec<String> = params.iter().map(type_key).collect();
            // Key by the qualified id here too, so generic and non-generic keys
            // follow one naming scheme; `Symbol`'s `Display` form is not shown
            // to be the same as `as_str`.
            format!("{}<{}>", id.as_str(), args.join(","))
        }
        other => other.to_string(),
    }
}
fn deref_once(ty: &Type) -> (Type, bool) {
if ty.is_ref() {
(ty.inner().unwrap_or(Type::Error), true)
} else {
(ty.clone(), false)
}
}
/// Returns `true` if the method's first parameter (its receiver) is a
/// reference type.
///
/// A `Forall` wrapper is looked through first. Anything that is not a
/// function type, or a function with no parameters, reports `false`.
fn method_has_pointer_receiver(method_ty: &Type) -> bool {
    let signature: &Type = if let Type::Forall { body, .. } = method_ty {
        body.as_ref()
    } else {
        method_ty
    };
    let Type::Function(function) = signature else {
        return false;
    };
    function.params.first().is_some_and(Type::is_ref)
}
/// Wraps `ty` in a reference compound type (`&ty`).
fn ref_of(ty: &Type) -> Type {
    let pointee = ty.clone();
    Type::Compound {
        kind: CompoundKind::Ref,
        args: vec![pointee],
    }
}
/// Unit tests for selector resolution, method-set promotion and
/// `has_direct_embed`.
///
/// Each test registers a few struct/interface definitions in module `m`
/// through `Builder` and queries them via the public entry points.
#[cfg(test)]
mod tests {
    use super::*;
    use syntax::ast::{Annotation, Span, StructFieldDefinition, StructKind};
    use syntax::program::Visibility as ProgVis;
    use syntax::program::{Attributes, Definition, DefinitionBody, Interface};

    /// Name of the single module every test definition is registered in.
    const MODULE: &str = "m";

    /// A non-generic nominal type `m.<name>` with no underlying type.
    fn nominal(name: &str) -> Type {
        Type::Nominal {
            id: Symbol::from_parts(MODULE, name),
            params: vec![],
            underlying_ty: None,
        }
    }

    /// A function type whose first (receiver) parameter is `owner` by value.
    fn value_method(owner: &str) -> Type {
        Type::function(
            vec![nominal(owner)],
            vec![false],
            vec![],
            Box::new(Type::string()),
        )
    }

    /// A function type whose first (receiver) parameter is `&owner`.
    fn pointer_method(owner: &str) -> Type {
        Type::function(
            vec![ref_of(&nominal(owner))],
            vec![false],
            vec![],
            Box::new(Type::string()),
        )
    }

    /// A function type with no parameters, used for interface method
    /// signatures (they carry no receiver in these fixtures).
    fn interface_method() -> Type {
        Type::function(vec![], vec![], vec![], Box::new(Type::string()))
    }

    /// A public struct field `name: ty`; `embedded` marks it as an embed.
    fn field(name: &str, ty: Type, embedded: bool) -> StructFieldDefinition {
        StructFieldDefinition {
            doc: None,
            attributes: vec![],
            name: name.into(),
            name_span: Span::dummy(),
            annotation: Annotation::Unknown,
            visibility: Visibility::Public,
            ty,
            embedded,
        }
    }

    /// Test fixture: a `Store` containing module `m`, plus helpers that
    /// register definitions in it.
    struct Builder {
        store: Store,
    }

    impl Builder {
        /// Creates a store and adds the empty module `m`.
        fn new() -> Self {
            let mut store = Store::new();
            store.add_module(MODULE);
            Builder { store }
        }

        /// Registers a public definition `m.<name>` with the given body.
        fn insert(&mut self, name: &str, body: DefinitionBody) -> &mut Self {
            let def = Definition {
                visibility: ProgVis::Public,
                ty: nominal(name),
                name: Some(name.into()),
                name_span: None,
                doc: None,
                body,
            };
            self.store
                .get_module_mut(MODULE)
                .unwrap()
                .definitions
                .insert(Symbol::from_parts(MODULE, name), def);
            self
        }

        /// Registers a non-generic record struct with the given fields and
        /// `(name, signature)` methods.
        fn struct_(
            &mut self,
            name: &str,
            fields: Vec<StructFieldDefinition>,
            methods: Vec<(&str, Type)>,
        ) -> &mut Self {
            let mut method_map = MethodSignatures::default();
            for (n, t) in methods {
                method_map.insert(n.into(), t);
            }
            self.insert(
                name,
                DefinitionBody::Struct {
                    generics: vec![],
                    fields,
                    kind: StructKind::Record,
                    methods: method_map,
                    constructor: None,
                    attributes: Attributes::default(),
                },
            )
        }

        /// Registers a non-generic interface whose methods all use
        /// `interface_method()`, with the given parent interfaces.
        fn interface(&mut self, name: &str, methods: Vec<&str>, parents: Vec<&str>) -> &mut Self {
            let mut method_map = MethodSignatures::default();
            for n in methods {
                method_map.insert(n.into(), interface_method());
            }
            self.insert(
                name,
                DefinitionBody::Interface {
                    definition: Interface {
                        name: name.into(),
                        generics: vec![],
                        parents: parents.into_iter().map(nominal).collect(),
                        methods: method_map,
                    },
                },
            )
        }
    }

    /// A value embed of `target` (field named after the type).
    fn vembed(target: &str) -> StructFieldDefinition {
        field(target, nominal(target), true)
    }

    /// A pointer embed `&target` (field named after the type).
    fn pembed(target: &str) -> StructFieldDefinition {
        field(target, ref_of(&nominal(target)), true)
    }

    /// Unwraps `Resolution::Found`, panicking with the actual value otherwise.
    fn found(resolution: Resolution) -> ResolvedMember {
        match resolution {
            Resolution::Found(member) => member,
            other => panic!("expected Found, got {other:?}"),
        }
    }

    /// Whether a resolved method's first parameter (receiver) is a reference;
    /// panics if the member is not a method.
    fn is_pointer_receiver(member: &ResolvedMember) -> bool {
        match &member.kind {
            MemberKind::Method { ty } => ty.get_function_params().unwrap()[0].is_ref(),
            other => panic!("expected a method, got {other:?}"),
        }
    }

    /// An own method resolves at depth 0 with its value receiver intact.
    #[test]
    fn direct_method_at_depth_zero() {
        let mut b = Builder::new();
        b.struct_("N0", vec![], vec![("m", value_method("N0"))]);
        let member = found(resolve_selector(&b.store, &nominal("N0"), "m"));
        assert_eq!(member.depth, 0);
        assert!(!is_pointer_receiver(&member));
    }

    /// A value method promoted through a value embed keeps a value receiver.
    #[test]
    fn value_embed_promotes_value_method() {
        let mut b = Builder::new();
        b.struct_("N0", vec![], vec![("m", value_method("N0"))]);
        b.struct_("N1", vec![vembed("N0")], vec![]);
        let member = found(resolve_selector(&b.store, &nominal("N1"), "m"));
        assert_eq!(member.depth, 1);
        assert_eq!(member.embed_path, vec![EcoString::from("N0")]);
        assert!(!member.indirect);
        assert!(!is_pointer_receiver(&member));
    }

    /// A pointer method promoted through a value embed needs `&outer` as receiver.
    #[test]
    fn value_embed_of_pointer_method_is_pointer_only() {
        let mut b = Builder::new();
        b.struct_("N0", vec![], vec![("pm", pointer_method("N0"))]);
        b.struct_("N1", vec![vembed("N0")], vec![]);
        let member = found(resolve_selector(&b.store, &nominal("N1"), "pm"));
        assert!(!member.indirect);
        assert!(is_pointer_receiver(&member));
    }

    /// A pointer method promoted through a pointer embed is callable on the value.
    #[test]
    fn pointer_embed_puts_pointer_method_in_value_set() {
        let mut b = Builder::new();
        b.struct_("N0", vec![], vec![("pm", pointer_method("N0"))]);
        b.struct_("N1", vec![pembed("N0")], vec![]);
        let member = found(resolve_selector(&b.store, &nominal("N1"), "pm"));
        assert!(member.indirect);
        assert!(!is_pointer_receiver(&member));
    }

    /// A value method promoted through a pointer embed keeps a value receiver.
    #[test]
    fn pointer_embed_of_value_method_is_value() {
        let mut b = Builder::new();
        b.struct_("N0", vec![], vec![("m", value_method("N0"))]);
        b.struct_("N1", vec![pembed("N0")], vec![]);
        let member = found(resolve_selector(&b.store, &nominal("N1"), "m"));
        assert!(member.indirect);
        assert!(!is_pointer_receiver(&member));
    }

    /// The same type reached by two paths at the same depth is ambiguous.
    #[test]
    fn diamond_is_ambiguous() {
        let mut b = Builder::new();
        b.struct_("N0", vec![], vec![("m", value_method("N0"))]);
        b.struct_("N1", vec![vembed("N0")], vec![]);
        b.struct_("N2", vec![vembed("N0")], vec![]);
        b.struct_("N3", vec![vembed("N1"), vembed("N2")], vec![]);
        assert!(matches!(
            resolve_selector(&b.store, &nominal("N3"), "m"),
            Resolution::Ambiguous { .. }
        ));
    }

    /// A shallower occurrence of a type shadows a deeper one, so no ambiguity.
    #[test]
    fn shallower_path_shadows_deeper_diamond() {
        let mut b = Builder::new();
        b.struct_("N0", vec![], vec![("m", value_method("N0"))]);
        b.struct_("N1", vec![vembed("N0")], vec![]);
        b.struct_("N3", vec![vembed("N0"), vembed("N1")], vec![]);
        let member = found(resolve_selector(&b.store, &nominal("N3"), "m"));
        assert_eq!(member.depth, 1);
    }

    /// An own method shadows a promoted method of the same name.
    #[test]
    fn own_member_shadows_promoted() {
        let mut b = Builder::new();
        b.struct_("N0", vec![], vec![("m", value_method("N0"))]);
        b.struct_("N1", vec![vembed("N0")], vec![("m", value_method("N1"))]);
        let member = found(resolve_selector(&b.store, &nominal("N1"), "m"));
        assert_eq!(member.depth, 0);
        assert_eq!(member.declaring_type.as_str(), "m.N1");
    }

    /// Non-embedded fields of an embedded struct are promoted.
    #[test]
    fn field_promotes() {
        let mut b = Builder::new();
        b.struct_("N0", vec![field("f", Type::int(), false)], vec![]);
        b.struct_("N1", vec![vembed("N0")], vec![]);
        let member = found(resolve_selector(&b.store, &nominal("N1"), "f"));
        assert_eq!(member.depth, 1);
        assert!(matches!(member.kind, MemberKind::Field { .. }));
    }

    /// A field and a method with the same name at the same depth are ambiguous.
    #[test]
    fn field_and_method_collide_across_embeds() {
        let mut b = Builder::new();
        b.struct_("A", vec![field("x", Type::int(), false)], vec![]);
        b.struct_("B", vec![], vec![("x", value_method("B"))]);
        b.struct_("N2", vec![vembed("A"), vembed("B")], vec![]);
        assert!(matches!(
            resolve_selector(&b.store, &nominal("N2"), "x"),
            Resolution::Ambiguous { .. }
        ));
    }

    /// A cyclic embedding graph terminates and still resolves by depth.
    #[test]
    fn pointer_cycle_terminates_and_resolves() {
        let mut b = Builder::new();
        b.struct_("N0", vec![pembed("N1")], vec![("a", value_method("N0"))]);
        b.struct_("N1", vec![vembed("N0")], vec![("bb", value_method("N1"))]);
        assert_eq!(
            found(resolve_selector(&b.store, &nominal("N0"), "a")).depth,
            0
        );
        assert_eq!(
            found(resolve_selector(&b.store, &nominal("N0"), "bb")).depth,
            1
        );
        assert_eq!(
            found(resolve_selector(&b.store, &nominal("N1"), "a")).depth,
            1
        );
        assert!(matches!(
            resolve_selector(&b.store, &nominal("N0"), "absent"),
            Resolution::NotFound
        ));
    }

    /// Methods of an embedded interface are promoted with a value receiver.
    #[test]
    fn embedded_interface_promotes_value_callable() {
        let mut b = Builder::new();
        b.interface("I", vec!["speak"], vec![]);
        b.struct_("N2", vec![vembed("I")], vec![]);
        let member = found(resolve_selector(&b.store, &nominal("N2"), "speak"));
        assert_eq!(member.depth, 1);
        assert!(!is_pointer_receiver(&member));
    }

    /// An interface and a struct both supplying the same method at one depth
    /// are ambiguous, and the method is left out of the method set.
    #[test]
    fn struct_embedding_interface_and_struct_with_same_method_is_ambiguous() {
        let mut b = Builder::new();
        b.interface("I", vec!["speak"], vec![]);
        b.struct_("S", vec![], vec![("speak", value_method("S"))]);
        b.struct_("N2", vec![vembed("I"), vembed("S")], vec![]);
        assert!(matches!(
            resolve_selector(&b.store, &nominal("N2"), "speak"),
            Resolution::Ambiguous { .. }
        ));
        assert!(!promoted_method_set(&b.store, &nominal("N2")).contains_key("speak"));
    }

    /// The method set contains own and promoted methods, with receivers
    /// adjusted for value vs pointer methods.
    #[test]
    fn method_set_includes_promoted_excludes_ambiguous() {
        let mut b = Builder::new();
        b.struct_(
            "N0",
            vec![],
            vec![("m", value_method("N0")), ("pm", pointer_method("N0"))],
        );
        b.struct_("N1", vec![vembed("N0")], vec![("o", value_method("N1"))]);
        let set = promoted_method_set(&b.store, &nominal("N1"));
        assert!(set.contains_key("o"));
        assert!(set.contains_key("m"));
        assert!(set.contains_key("pm"));
        assert!(!set.get("m").unwrap().get_function_params().unwrap()[0].is_ref());
        assert!(set.get("pm").unwrap().get_function_params().unwrap()[0].is_ref());
    }

    /// An ambiguous diamond member is dropped from the method set.
    #[test]
    fn method_set_drops_ambiguous_diamond_member() {
        let mut b = Builder::new();
        b.struct_("N0", vec![], vec![("m", value_method("N0"))]);
        b.struct_("N1", vec![vembed("N0")], vec![]);
        b.struct_("N2", vec![vembed("N0")], vec![]);
        b.struct_("N3", vec![vembed("N1"), vembed("N2")], vec![]);
        assert!(!promoted_method_set(&b.store, &nominal("N3")).contains_key("m"));
    }

    /// `has_direct_embed` sees direct embeds, including through a reference.
    #[test]
    fn has_direct_embed_detects_embeds() {
        let mut b = Builder::new();
        b.struct_("N0", vec![field("f", Type::int(), false)], vec![]);
        b.struct_("N1", vec![vembed("N0")], vec![]);
        assert!(!has_direct_embed(&b.store, &nominal("N0")));
        assert!(has_direct_embed(&b.store, &nominal("N1")));
        assert!(has_direct_embed(&b.store, &ref_of(&nominal("N1"))));
    }
}