use std::collections::HashMap;
use super::super::module_resolver::ModuleResolver;
use super::super::sem_type::SemType;
use super::super::SemanticAnalyzer;
use crate::ast::{Expr, File};
fn substitute_named_in_sem(
ty: &SemType,
bindings: &HashMap<String, SemType>,
param_names: &std::collections::HashSet<String>,
) -> SemType {
let mut out = ty.clone();
for (name, concrete) in bindings {
if !param_names.contains(name) {
continue;
}
out = out.substitute_named(name, concrete);
}
out
}
/// Structurally matches a declared type `pattern` against an inferred type
/// `concrete`. For each `Named` leaf of `pattern`, records in `out` the part of
/// `concrete` it lines up with.
///
/// Every `Named` leaf is recorded, not only generic parameters. Callers filter
/// the result down to the parameter names they care about. When the shapes
/// disagree (different constructors, or `Generic` types with different bases)
/// the mismatching sub-tree is skipped silently. Tuple fields, generic arguments
/// and closure parameters are compared pairwise, up to the shorter length.
///
/// The first informative binding for a name wins. A binding to
/// `SemType::Unknown` is only provisional: a later occurrence of the same name
/// that unifies against a known type replaces it. This way, one argument whose
/// type could not be inferred does not hide the type supplied by another.
fn unify_sem(pattern: &SemType, concrete: &SemType, out: &mut HashMap<String, SemType>) {
    match (pattern, concrete) {
        (SemType::Named(name), other) => match out.get_mut(name) {
            Some(existing) => {
                // Upgrade a provisional `Unknown`; otherwise keep the first binding.
                if matches!(existing, SemType::Unknown) && !matches!(other, SemType::Unknown) {
                    *existing = other.clone();
                }
            }
            None => {
                out.insert(name.clone(), other.clone());
            }
        },
        (SemType::Array(p_inner), SemType::Array(c_inner))
        | (SemType::Optional(p_inner), SemType::Optional(c_inner)) => {
            unify_sem(p_inner, c_inner, out);
        }
        (SemType::Tuple(p_fields), SemType::Tuple(c_fields)) => {
            for ((_, pt), (_, ct)) in p_fields.iter().zip(c_fields.iter()) {
                unify_sem(pt, ct, out);
            }
        }
        (
            SemType::Dictionary {
                key: p_key,
                value: p_val,
            },
            SemType::Dictionary {
                key: c_key,
                value: c_val,
            },
        ) => {
            unify_sem(p_key, c_key, out);
            unify_sem(p_val, c_val, out);
        }
        (
            SemType::Generic {
                base: p_base,
                args: p_args,
            },
            SemType::Generic {
                base: c_base,
                args: c_args,
            },
        ) if p_base == c_base => {
            for (pa, ca) in p_args.iter().zip(c_args.iter()) {
                unify_sem(pa, ca, out);
            }
        }
        (
            SemType::Closure {
                params: p_params,
                return_ty: p_ret,
            },
            SemType::Closure {
                params: c_params,
                return_ty: c_ret,
            },
        ) => {
            for (pt, ct) in p_params.iter().zip(c_params.iter()) {
                unify_sem(pt, ct, out);
            }
            unify_sem(p_ret, c_ret, out);
        }
        // Shapes do not line up: nothing can be learned from this sub-tree.
        _ => {}
    }
}
impl<R: ModuleResolver> SemanticAnalyzer<R> {
    /// Infers the type produced by invoking `path` with `args`.
    ///
    /// Resolution order:
    /// 1. A single-segment path is first looked up as a closure: in the
    ///    inference scope frames, innermost first, then in the local `let`
    ///    bindings. If found, the call yields the closure's return type.
    /// 2. A struct name yields the struct type. Generic arguments come from
    ///    `type_args`, or are inferred from the labelled field arguments when
    ///    none are written. If inference fails, the plain `Named` type is used.
    /// 3. A known function yields its declared return type (`Nil` when none is
    ///    declared), specialised for the call's argument types when the
    ///    function is generic.
    /// 4. A multi-segment path is tried, in order, as:
    ///    - a trait method on a struct named by the first segment;
    ///    - a variant of an enum named by the first segment;
    ///    - a function reached through module segments.
    ///
    /// Anything else is `SemType::Unknown`.
    pub(super) fn infer_type_invocation(
        &self,
        path: &[crate::ast::Ident],
        type_args: &[crate::ast::Type],
        args: &[(Option<crate::ast::Ident>, Expr)],
        file: &File,
    ) -> SemType {
        let name = path
            .iter()
            .map(|id| id.name.as_str())
            .collect::<Vec<_>>()
            .join("::");
        if path.len() == 1 {
            // Innermost inference frame first, then the local `let` bindings.
            let scope_lookup = {
                let stack = self.inference_scope_stack.borrow();
                stack
                    .iter()
                    .rev()
                    .find_map(|frame| frame.get(&name).cloned())
            };
            let resolved_ty =
                scope_lookup.or_else(|| self.local_let_bindings.get(&name).map(|(t, _)| t.clone()));
            if let Some(SemType::Closure { return_ty, .. }) = resolved_ty {
                return *return_ty;
            }
        }
        if self.symbols.is_struct_qualified(&name) {
            if type_args.is_empty() {
                if let Some(inferred) = self.infer_struct_type_args(&name, args, file) {
                    SemType::Generic {
                        base: name,
                        args: inferred,
                    }
                } else {
                    SemType::Named(name)
                }
            } else {
                SemType::Generic {
                    base: name,
                    args: type_args.iter().map(SemType::from_ast).collect(),
                }
            }
        } else if let Some(func_info) = self.symbols.get_function(&name) {
            let raw = func_info
                .return_type
                .as_ref()
                .map_or(SemType::Nil, SemType::from_ast);
            self.specialise_generic_return(func_info, raw, args, file)
        } else if path.len() >= 2 {
            let (Some(first), Some(last)) = (path.first(), path.last()) else {
                return SemType::Unknown;
            };
            let receiver = &first.name;
            let method_name = &last.name;
            if self.symbols.is_struct(receiver) {
                if let Some(ret) = self.infer_method_return_from_impls(receiver, method_name) {
                    return ret;
                }
            }
            if self.symbols.get_enum_variants(receiver).is_some() {
                return SemType::Named(receiver.clone());
            }
            if let Some(ret) = self.lookup_qualified_function_return(path) {
                return ret;
            }
            SemType::Unknown
        } else {
            SemType::Unknown
        }
    }

    /// Infers the generic arguments of struct `struct_name` from the labelled
    /// field arguments of a construction expression.
    ///
    /// Each labelled argument whose label names a field is unified against that
    /// field's declared type. Unlabelled arguments and labels that match no
    /// field are skipped. Returns `None` when the struct is unknown, is not
    /// generic, or when any generic parameter is left unbound. Otherwise
    /// returns the bindings in declaration order of the generic parameters.
    pub(super) fn infer_struct_type_args(
        &self,
        struct_name: &str,
        args: &[(Option<crate::ast::Ident>, Expr)],
        file: &File,
    ) -> Option<Vec<SemType>> {
        let info = self.symbols.get_struct_qualified(struct_name)?;
        if info.generics.is_empty() {
            return None;
        }
        let param_names: Vec<String> = info
            .generics
            .iter()
            .map(|p| p.name.name.clone())
            .collect();
        let fields: Vec<(String, crate::ast::Type)> = info
            .fields
            .iter()
            .map(|f| (f.name.clone(), f.ty.clone()))
            .collect();
        let mut bindings: HashMap<String, SemType> = HashMap::new();
        for (name_opt, expr) in args {
            let Some(arg_name) = name_opt.as_ref() else {
                continue;
            };
            let Some((_, declared_ast)) = fields.iter().find(|(n, _)| n == &arg_name.name) else {
                continue;
            };
            let declared_sem = SemType::from_ast(declared_ast);
            let arg_sem = self.infer_type_sem(expr, file);
            unify_sem(&declared_sem, &arg_sem, &mut bindings);
        }
        // All-or-nothing: a single unbound parameter makes the whole result `None`.
        param_names
            .iter()
            .map(|name| bindings.remove(name))
            .collect()
    }

    /// Reports whether `infer_struct_type_args` can bind every generic
    /// parameter of `struct_name` from `args`.
    pub(in crate::semantic) fn can_infer_struct_type_args(
        &self,
        struct_name: &str,
        args: &[(Option<crate::ast::Ident>, Expr)],
        file: &File,
    ) -> bool {
        self.infer_struct_type_args(struct_name, args, file).is_some()
    }

    /// Specialises a generic function's declared return type `raw_ret` using
    /// the argument types at a call site.
    ///
    /// Each declared parameter with a type annotation is matched to an
    /// argument: first an argument labelled with the parameter's name, else
    /// the argument at the same position. The declared type is then unified
    /// against the argument's inferred type. Non-generic functions return
    /// `raw_ret` unchanged.
    fn specialise_generic_return(
        &self,
        func_info: &super::super::symbol_table::FunctionInfo,
        raw_ret: SemType,
        args: &[(Option<crate::ast::Ident>, Expr)],
        file: &File,
    ) -> SemType {
        if func_info.generics.is_empty() {
            return raw_ret;
        }
        let param_names: std::collections::HashSet<String> = func_info
            .generics
            .iter()
            .map(|g| g.name.name.clone())
            .collect();
        let mut bindings: HashMap<String, SemType> = HashMap::new();
        for (i, param) in func_info.params.iter().enumerate() {
            let Some(declared_ast) = &param.ty else {
                continue;
            };
            let arg_expr = args
                .iter()
                .find_map(|(n, e)| {
                    n.as_ref()
                        .filter(|name| name.name == param.name.name)
                        .map(|_| e)
                })
                .or_else(|| args.get(i).map(|(_, e)| e));
            let Some(arg) = arg_expr else { continue };
            let declared_sem = SemType::from_ast(declared_ast);
            let arg_sem = self.infer_type_sem(arg, file);
            unify_sem(&declared_sem, &arg_sem, &mut bindings);
        }
        substitute_named_in_sem(&raw_ret, &bindings, &param_names)
    }

    /// Resolves `a::b::f` by walking the module segments `a`, `b` to reach
    /// function `f`, and returns its declared return type (`Nil` if none).
    ///
    /// The walk starts from the local symbol table, then from each
    /// module-cache symbol table in turn. The first hit wins. Generic return
    /// types are not specialised here.
    fn lookup_qualified_function_return(&self, path: &[crate::ast::Ident]) -> Option<SemType> {
        let last = path.last()?;
        let segments: Vec<&str> = path
            .iter()
            .take(path.len().saturating_sub(1))
            .map(|i| i.name.as_str())
            .collect();
        let look = |symbols: &super::super::symbol_table::SymbolTable| -> Option<SemType> {
            let mut current = symbols;
            for part in &segments {
                match current.modules.get(*part) {
                    Some(info) => current = &info.symbols,
                    None => return None,
                }
            }
            current.get_function(&last.name).map(|f| {
                f.return_type
                    .as_ref()
                    .map_or(SemType::Nil, SemType::from_ast)
            })
        };
        if let Some(ty) = look(&self.symbols) {
            return Some(ty);
        }
        for (_, symbols) in self.module_cache.values() {
            if let Some(ty) = look(symbols) {
                return Some(ty);
            }
        }
        None
    }

    /// Builds the binding frame for a `match` arm whose scrutinee has type
    /// `scrutinee_ty`.
    ///
    /// For an `Optional` scrutinee, only a `some(x)` variant pattern binds
    /// anything: `x` gets the wrapped type. Every other scrutinee is reduced to
    /// an enum name plus generic arguments (when it is `Generic`) and handed to
    /// `build_match_arm_scope`.
    pub(super) fn build_match_arm_scope_for_type(
        &self,
        scrutinee_ty: &SemType,
        pattern: &crate::ast::Pattern,
    ) -> HashMap<String, SemType> {
        use crate::ast::Pattern;
        if let SemType::Optional(inner) = scrutinee_ty {
            let mut frame = HashMap::new();
            if let Pattern::Variant { name, bindings } = pattern {
                if name.name == "some" {
                    if let Some(b) = bindings.first() {
                        frame.insert(b.name.clone(), (**inner).clone());
                    }
                }
            }
            return frame;
        }
        let stripped = scrutinee_ty.strip_optional();
        // Exhaustive on purpose: a new `SemType` variant should force a decision here.
        let (enum_name, receiver_args): (String, Vec<SemType>) = match &stripped {
            SemType::Generic { base, args } => (base.clone(), args.clone()),
            SemType::Named(n) => (n.clone(), Vec::new()),
            SemType::Primitive(_)
            | SemType::Array(_)
            | SemType::Optional(_)
            | SemType::Tuple(_)
            | SemType::Dictionary { .. }
            | SemType::Closure { .. }
            | SemType::Unknown
            | SemType::InferredEnum
            | SemType::Nil => (stripped.display(), Vec::new()),
        };
        self.build_match_arm_scope(&enum_name, pattern, &receiver_args)
    }

    /// Builds the binding frame for a variant pattern over enum `enum_name`.
    ///
    /// The i-th binding of the pattern gets the type of the variant's i-th
    /// field. Extra bindings with no matching field are left out. When the
    /// scrutinee carried generic arguments (`receiver_args`), the enum's
    /// generic parameters in those field types are replaced positionally by
    /// the arguments. Non-variant patterns bind nothing.
    pub(super) fn build_match_arm_scope(
        &self,
        enum_name: &str,
        pattern: &crate::ast::Pattern,
        receiver_args: &[SemType],
    ) -> HashMap<String, SemType> {
        use crate::ast::Pattern;
        let mut frame = HashMap::new();
        let Pattern::Variant { name, bindings } = pattern else {
            return frame;
        };
        let variant_field_tys = self
            .lookup_enum_variant_field_types(enum_name, &name.name)
            .unwrap_or_default();
        let generic_param_names: Vec<String> = if receiver_args.is_empty() {
            Vec::new()
        } else {
            // Variant fields may come from the module cache (see
            // `lookup_enum_variant_field_types`), so look the generics up
            // there too. Otherwise imported generic enums are never specialised.
            self.symbols
                .get_generics(enum_name)
                .or_else(|| {
                    self.module_cache
                        .values()
                        .find_map(|(_, symbols)| symbols.get_generics(enum_name))
                })
                .map(|g| g.iter().map(|p| p.name.name.clone()).collect())
                .unwrap_or_default()
        };
        for (ident, ty) in bindings.iter().zip(&variant_field_tys) {
            let sem = Self::substitute_enum_params_simultaneously(
                ty,
                &generic_param_names,
                receiver_args,
            );
            frame.insert(ident.name.clone(), sem);
        }
        frame
    }

    /// Replaces each name in `params` inside `ty` with the entry of `args` at
    /// the same position (pairs beyond the shorter slice are ignored).
    ///
    /// The replacement is simultaneous. An argument that itself names another
    /// parameter, e.g. `Result<E, Int>` for `enum Result<T, E>`, is not rewritten
    /// again, which sequential `substitute_named` calls would do.
    fn substitute_enum_params_simultaneously(
        ty: &SemType,
        params: &[String],
        args: &[SemType],
    ) -> SemType {
        let pairs: Vec<(&String, &SemType)> = params.iter().zip(args).collect();
        // The leading NUL is intended to keep placeholders distinct from any
        // source-level type name (NOTE(review): assumes the lexer rejects NUL).
        let placeholders: Vec<String> = (0..pairs.len())
            .map(|i| format!("\u{0}enum_param#{i}"))
            .collect();
        let mut out = ty.clone();
        for (&(param, _), placeholder) in pairs.iter().zip(&placeholders) {
            out = out.substitute_named(param, &SemType::Named(placeholder.clone()));
        }
        for (&(_, arg), placeholder) in pairs.iter().zip(&placeholders) {
            out = out.substitute_named(placeholder, arg);
        }
        out
    }

    /// Returns the declared field types of `variant_name` in enum `enum_name`.
    ///
    /// Searches the local symbol table first, then each module-cache symbol
    /// table. The first match wins.
    fn lookup_enum_variant_field_types(
        &self,
        enum_name: &str,
        variant_name: &str,
    ) -> Option<Vec<SemType>> {
        if let Some(info) = self.symbols.enums.get(enum_name) {
            if let Some(fields) = info.variant_fields.get(variant_name) {
                return Some(fields.iter().map(|f| SemType::from_ast(&f.ty)).collect());
            }
        }
        for (_, symbols) in self.module_cache.values() {
            if let Some(info) = symbols.enums.get(enum_name) {
                if let Some(fields) = info.variant_fields.get(variant_name) {
                    return Some(fields.iter().map(|f| SemType::from_ast(&f.ty)).collect());
                }
            }
        }
        None
    }

    /// Finds the declared return type (`Nil` if none) of `method_name`.
    ///
    /// Only the methods of traits associated with `struct_name` are searched,
    /// as reported by `get_all_traits_for_struct`. The first trait that
    /// declares the method wins.
    fn infer_method_return_from_impls(
        &self,
        struct_name: &str,
        method_name: &str,
    ) -> Option<SemType> {
        let trait_names = self.symbols.get_all_traits_for_struct(struct_name);
        for trait_name in trait_names {
            if let Some(trait_info) = self.symbols.get_trait(&trait_name) {
                for m in &trait_info.methods {
                    if m.name.name == method_name {
                        return Some(
                            m.return_type
                                .as_ref()
                                .map_or(SemType::Nil, SemType::from_ast),
                        );
                    }
                }
            }
        }
        None
    }
}