use crate::Result;
use crate::Span;
use crate::diagnostic::{Diagnostic, WithErrorInfo};
use crate::pr;
use crate::project::TargetSpan;
use crate::resolver::NS_STD;
use crate::utils::IdGenerator;
use crate::utils::fold::{self, PrFold};
use super::Scope;
pub struct NameResolver<'a> {
pub root: &'a pr::ModuleDef,
pub def_module_path: &'a [String],
pub scopes: Vec<Scope>,
pub refs: Vec<pr::Path>,
pub scope_id_gen: &'a mut IdGenerator<usize>,
pub allow_recursive: bool,
pub target_spans: &'a mut Vec<(Span, crate::project::TargetSpan)>,
}
impl NameResolver<'_> {
pub fn fold_def_kind(&mut self, def: pr::DefKind) -> Result<pr::DefKind> {
Ok(match def {
pr::DefKind::Expr(var_def) => pr::DefKind::Expr(self.fold_expr_def(var_def)?),
pr::DefKind::Ty(ty_def) => pr::DefKind::Ty(self.fold_type_def(ty_def)?),
pr::DefKind::Module(_) | pr::DefKind::Import(_) | pr::DefKind::Unresolved(_) => {
unreachable!()
}
})
}
pub fn fold_import_def(
&mut self,
import_def: pr::ImportDef,
) -> Result<pr::ImportDef, Diagnostic> {
let target = import_def.as_simple().unwrap();
let target_fq = self.lookup_in_mod_tree(self.def_module_path, target)?;
Ok(pr::ImportDef::new_simple(target_fq, import_def.span))
}
}
impl fold::PrFold for NameResolver<'_> {
fn fold_def(&mut self, _def: pr::Def) -> Result<pr::Def> {
unreachable!()
}
fn fold_expr(&mut self, expr: pr::Expr) -> Result<pr::Expr> {
Ok(match expr.kind {
pr::ExprKind::Ident(ident) => {
let (target, target_ref) =
self.lookup_ident(&ident).with_span_fallback(expr.span)?;
if let (Some(span), Some(tr)) = (expr.span, target_ref) {
self.target_spans.push((span, tr));
}
pr::Expr {
kind: pr::ExprKind::Ident(ident),
target: Some(target),
..expr
}
}
pr::ExprKind::Lookup { .. } => pr::Expr {
kind: fold::fold_expr_kind(self, expr.kind)?,
..expr
},
pr::ExprKind::Func(func) => {
let scope_id = self.scope_id_gen.next();
let scope = Scope::new_of_func(scope_id, &func)?;
self.scopes.push(scope);
let func = fold::fold_func(self, *func);
self.scopes.pop();
pr::Expr {
kind: pr::ExprKind::Func(Box::new(func?)),
scope_id: Some(scope_id),
..expr
}
}
pr::ExprKind::Match(match_) => {
let subject = Box::new(self.fold_expr(*match_.subject)?);
let mut branches = Vec::with_capacity(match_.branches.len());
for branch in match_.branches {
let scope_id = self.scope_id_gen.next();
let scope = Scope::new_empty(scope_id);
self.scopes.push(scope);
let pattern = self.fold_pattern(branch.pattern)?;
let mut value = Box::new(self.fold_expr(*branch.value)?);
self.scopes.pop();
value.scope_id = Some(scope_id);
branches.push(pr::MatchBranch { pattern, value })
}
pr::Expr {
kind: pr::ExprKind::Match(pr::Match { subject, branches }),
..expr
}
}
pr::ExprKind::VarBinding(binding) => {
let bound = Box::new(self.fold_expr(*binding.bound)?);
let scope_id = self.scope_id_gen.next();
let mut scope = Scope::new_empty(scope_id);
scope.insert_local(binding.name.clone(), Some(binding.name_span));
self.scopes.push(scope);
let main = Box::new(self.fold_expr(*binding.main)?);
self.scopes.pop();
pr::Expr {
kind: pr::ExprKind::VarBinding(pr::VarBinding {
name: binding.name,
name_span: binding.name_span,
bound,
main,
}),
scope_id: Some(scope_id),
..expr
}
}
_ => pr::Expr {
kind: fold::fold_expr_kind(self, expr.kind)?,
..expr
},
})
}
fn fold_pattern(&mut self, pattern: pr::Pattern) -> Result<pr::Pattern> {
let binds = collect_pattern_binds(&pattern)?;
let scope = self.scopes.last_mut().unwrap();
for (name, span) in binds {
scope.insert_local(name, Some(span));
}
Ok(pattern)
}
fn fold_type(&mut self, ty: pr::Ty) -> Result<pr::Ty> {
Ok(match ty.kind {
pr::TyKind::Ident(ident) => {
let (target, target_span) =
self.lookup_ident(&ident).with_span_fallback(ty.span)?;
if let (Some(span), Some(tr)) = (ty.span, target_span) {
self.target_spans.push((span, tr));
}
pr::Ty {
kind: pr::TyKind::Ident(ident),
target: Some(target),
..ty
}
}
pr::TyKind::Func(ty_func) => {
if self.scopes.is_empty() {
let scope_id = self.scope_id_gen.next();
let scope = Scope::new_of_ty_func(scope_id, &ty_func)?;
self.scopes.push(scope);
let r = fold::fold_ty_func(self, ty_func);
self.scopes.pop();
pr::Ty {
kind: pr::TyKind::Func(r?),
scope_id: Some(scope_id),
..ty
}
} else {
if let Some(param) = ty_func.ty_params.first() {
return Err(Diagnostic::new_custom(
"generic type parameters are not allowed here",
)
.with_span(param.span));
}
let ty = pr::Ty {
kind: pr::TyKind::Func(ty_func),
..ty
};
fold::fold_type(self, ty)?
}
}
pr::TyKind::TupleComprehension(ref comp) => {
if comp
.body_name
.as_ref()
.is_some_and(|b| b != &comp.variable_name)
{
return Err(Diagnostic::new_custom(format!(
"expected field to be named {}",
comp.variable_name
)));
}
let scope_id = self.scope_id_gen.next();
let mut scope = Scope::new_empty(scope_id);
scope.insert_ty_local(&comp.variable_ty);
self.scopes.push(scope);
let mut ty = fold::fold_type(self, ty)?;
ty.scope_id = Some(scope_id);
self.scopes.pop();
ty
}
_ => fold::fold_type(self, ty)?,
})
}
}
impl NameResolver<'_> {
fn lookup_ident(&mut self, ident: &pr::Path) -> Result<(pr::Ref, Option<TargetSpan>)> {
for scope in self.scopes.iter().rev() {
if let Some((scope, offset, def_span)) = scope.get(ident.first()) {
if ident.len() != 1 {
return Err(Diagnostic::new_custom(format!(
"{} is a param, not a module",
ident.first()
)));
}
let target = pr::Ref::Local { scope, offset };
let span = def_span.map(TargetSpan::Span);
return Ok((target, span));
}
}
let fq = (self.lookup_in_mod_tree(self.def_module_path, ident)).inspect_err(|_| {
tracing::debug!("scopes: {:?}", self.scopes);
})?;
self.refs.push(fq.clone());
let span = Some(TargetSpan::Global(fq.clone()));
Ok((pr::Ref::Global(fq), span))
}
#[tracing::instrument("lookup", skip_all, fields(from = pr::Path::new(def_mod_fq).to_string()))]
fn lookup_in_mod_tree(&self, def_mod_fq: &[String], ident: &pr::Path) -> Result<pr::Path> {
tracing::debug!("lookup for {ident}");
let steps = ident.as_steps();
let (base_path, relative) = match ident.first() {
"project" => (pr::Path::empty(), pr::Path::new(&steps[1..])),
"module" => (pr::Path::new(def_mod_fq), pr::Path::new(&steps[1..])),
"super" => {
let mut path = pr::Path::new(def_mod_fq);
path.pop();
(path, pr::Path::new(&steps[1..]))
}
NS_STD => (pr::Path::from_name(NS_STD), pr::Path::new(&steps[1..])),
_ => (pr::Path::new(def_mod_fq), pr::Path::new(steps)),
};
let base_def = (self.root.get_submodule(base_path.as_steps()))
.ok_or_else(|| Diagnostic::new_custom("unknown name"))?;
self.lookup_in_module(base_def, base_path, relative)
}
fn lookup_in_module(
&self,
base_mod: &pr::ModuleDef,
mut base_fq: pr::Path,
mut steps: pr::Path,
) -> Result<pr::Path> {
let Some(first) = steps.pop_first() else {
return Ok(base_fq);
};
let Some(def) = base_mod.defs.get(&first) else {
return Err(Diagnostic::new_custom("name does not exist".to_string()));
};
base_fq.push(first);
match &def.kind {
pr::DefKind::Module(sub_module) => self.lookup_in_module(sub_module, base_fq, steps),
pr::DefKind::Import(import) => {
let mut new_path = import.as_simple().unwrap().clone();
new_path.extend(steps);
self.lookup_in_mod_tree(&[], &new_path)
}
pr::DefKind::Unresolved(Some(def_kind)) if def_kind.is_import() => {
let import = def_kind.as_import().unwrap();
let import_target = import.as_simple().unwrap();
let import_fq = self
.lookup_in_mod_tree(base_fq.parent(), import_target)
.with_span_fallback(def.span)?;
tracing::debug!("resolved import to {import_fq:?}, steps={steps}");
let mut new_path = import_fq;
new_path.extend(steps);
self.lookup_in_mod_tree(&[], &new_path)
}
pr::DefKind::Unresolved(None) => {
if steps.is_empty() && self.allow_recursive {
Ok(base_fq)
} else {
Err(Diagnostic::new_custom("recursive path"))
}
}
_ if steps.is_empty() => Ok(base_fq),
_ => Err(Diagnostic::new_custom("unknown name")),
}
}
}
fn collect_pattern_binds(pattern: &pr::Pattern) -> Result<Vec<(String, Span)>> {
match &pattern.kind {
pr::PatternKind::Literal(_) => Ok(vec![]),
pr::PatternKind::Bind(name) => Ok(vec![(name.clone(), pattern.span)]),
pr::PatternKind::Enum(_, inner) => Ok(inner
.as_ref()
.map(|p| collect_pattern_binds(p))
.transpose()?
.unwrap_or_default()),
pr::PatternKind::AnyOf(branches) => {
assert!(branches.len() >= 2);
let mut result: Option<Vec<(String, Span)>> = None;
for b in branches {
let binds = collect_pattern_binds(b)?;
if let Some(result) = &result {
let names_match = result.len() == binds.len()
&& result.iter().zip(&binds).all(|((a, _), (b, _))| a == b);
if !names_match {
return Err(Diagnostic::new_custom(
"patterns introduce different variable names",
)
.with_span(Some(b.span)));
}
} else {
result = Some(binds);
}
}
Ok(result.unwrap())
}
}
}