use crate::Result;
use crate::diagnostic::WithErrorInfo;
use crate::resolver::types::{TypeResolver, scope};
use crate::utils::fold::{self, PrFold};
use crate::{pr, utils};
impl fold::PrFold for super::TypeResolver<'_> {
#[tracing::instrument(name = "e", skip(self, node))]
fn fold_expr(&mut self, node: pr::Expr) -> Result<pr::Expr> {
tracing::debug!("{}", node.kind.as_ref());
let span = node.span;
let r = match node.kind {
pr::ExprKind::Ident(ident) => {
tracing::debug!("resolving ident {ident:?}...");
let target = node.target.as_ref().unwrap();
let named = self.get_ref(target).with_span(span)?;
tracing::debug!("... resolved to {}", named.as_ref());
let ty = match named {
scope::Named::Expr(expr) => expr.ty.as_deref().cloned().unwrap(),
scope::Named::Module => {
return Err(scope::err_name_kind("a value", "a module").with_span(span));
}
scope::Named::Ty {
is_framed: false, ..
} => {
return Err(scope::err_name_kind("a value", "a type").with_span(span));
}
scope::Named::Scoped(scoped) => match scoped {
scope::ScopedKind::Param { ty } => ty.clone(),
scope::ScopedKind::Local { ty } => ty.clone(),
scope::ScopedKind::LocalTy { .. }
| scope::ScopedKind::TyParam { .. }
| scope::ScopedKind::TyVar { .. } => {
return Err(scope::err_name_kind("a value", "a type")
.push_hint(format!("scoped = {scoped:?}"))
.with_span(span));
}
},
scope::Named::Ty {
is_framed: true,
ty,
framed_label,
} => {
let ty_framed = pr::Ty {
target: Some(target.clone()),
..pr::Ty::new(ident.clone())
};
let ty_inner = ty.clone();
return Ok(pr::Expr {
kind: pr::ExprKind::Ident(ident),
ty: Some(Box::new(pr::Ty::new(pr::TyFunc {
params: vec![pr::TyFuncParam {
constant: false,
ty: Some(ty_inner),
label: framed_label.map(str::to_string),
}],
body: Some(Box::new(ty_framed)),
ty_params: vec![],
}))),
..node
});
}
};
let (ty, ty_args) = self.introduce_ty_into_scope(ty, span.unwrap());
pr::Expr {
kind: pr::ExprKind::Ident(ident),
ty: Some(Box::new(ty)),
ty_args,
..node
}
}
pr::ExprKind::Lookup { base, lookup } => {
let base = Box::new(self.fold_expr(*base)?);
let base_ty = base.ty.as_deref().unwrap();
let target_ty = self
.resolve_tuple_lookup(base_ty, &lookup, span.unwrap())
.with_span(span)?;
let kind = pr::ExprKind::Lookup { base, lookup };
pr::Expr {
ty: Some(Box::new(target_ty)),
kind,
..node
}
}
pr::ExprKind::Variant(mut variant) => {
let inner_ty = if let Some(inner) = variant.inner {
let inner = self.fold_expr(*inner)?;
let inner_ty = inner.ty.as_deref().cloned().unwrap();
variant.inner = Some(Box::new(inner));
inner_ty
} else {
pr::Ty::new(pr::TyKind::Tuple(vec![]))
};
let enum_domain = pr::TyDomain::EnumVariants(vec![pr::TyDomainEnumVariant {
name: variant.name.clone(),
ty: inner_ty,
}]);
let ty = self.introduce_ty_var(enum_domain, span.unwrap());
pr::Expr {
ty: Some(Box::new(ty)),
..pr::Expr::new(variant)
}
}
pr::ExprKind::Call(pr::Call { subject, args }) => {
let subject = Box::new(self.fold_expr(*subject)?);
let subject_target = subject.target.as_ref().map(|ref_| {
self.get_ref(ref_).unwrap()
});
let is_framing = matches!(subject_target, Some(scope::Named::Ty { .. }));
let resolved = self.resolve_func_call(subject, args, span)?;
if is_framing {
let call = resolved.kind.into_call().unwrap();
let arg = call.args.into_iter().next();
let mut inner = arg.map(|a| a.expr).unwrap_or_else(|| {
pr::Expr::new(pr::ExprKind::Tuple(vec![]))
});
inner.ty = resolved.ty;
inner
} else {
resolved
}
}
pr::ExprKind::Func(func) => {
let func = self
.resolve_func(node.scope_id.unwrap(), func)
.with_span_fallback(span)?;
pr::Expr {
kind: pr::ExprKind::Func(func),
..node
}
}
pr::ExprKind::TypeAnnotation(ann) => {
let ty = self.fold_type(*ann.ty)?;
let mut expr = self.fold_expr(*ann.expr)?;
self.validate_expr_type(&mut expr, &ty, &|| None)?;
let mut ty = ty;
expr.span = span;
ty.span = span;
return Ok(pr::Expr {
ty: Some(Box::new(ty)),
..expr
});
}
pr::ExprKind::Match(_) => self.resolve_match(node)?,
pr::ExprKind::If(_) => self.resolve_if(node)?,
pr::ExprKind::Tuple(_) => self.resolve_tuple_constructor(node)?,
pr::ExprKind::VarBinding(binding) => {
let bound = Box::new(self.fold_expr(*binding.bound)?);
let bound_ty = bound.ty.as_deref().unwrap();
let mut scope = scope::Scope::new(node.scope_id.unwrap(), scope::ScopeKind::Nested);
scope.insert_local(bound_ty.clone());
self.scopes.push(scope);
let main = self.fold_expr(*binding.main)?;
let mapping = self.finalize_type_vars()?;
let main = utils::TypeReplacer::on_expr(main, mapping);
self.scopes.pop().unwrap();
pr::Expr {
ty: main.ty.clone(),
kind: pr::ExprKind::VarBinding(pr::VarBinding {
name: binding.name,
name_span: binding.name_span,
bound,
main: Box::new(main),
}),
..node
}
}
item => pr::Expr {
kind: fold::fold_expr_kind(self, item)?,
..node
},
};
let mut r = r;
r.span = r.span.or(span);
if r.ty.is_none() {
r.ty = Some(Box::new(self.infer_type(&r)?));
}
if let Some(scope_id) = r.scope_id {
r.ty.as_deref_mut().unwrap().scope_id = Some(scope_id);
}
Ok(r)
}
fn fold_type(&mut self, ty: pr::Ty) -> Result<pr::Ty> {
let ty = match ty.kind {
pr::TyKind::Func(mut ty_func) => {
for p in &mut ty_func.params {
if p.ty.is_none() {
p.ty = Some(self.introduce_ty_var(pr::TyDomain::Open, ty.span.unwrap()));
}
}
if ty_func.body.is_none() {
ty_func.body = Some(Box::new(
self.introduce_ty_var(pr::TyDomain::Open, ty.span.unwrap()),
));
}
pr::Ty {
kind: pr::TyKind::Func(ty_func),
..ty
}
}
_ => fold::fold_type(self, ty)?,
};
Ok(ty)
}
fn fold_pattern(&mut self, _pattern: pr::Pattern) -> Result<pr::Pattern> {
unreachable!()
}
}
impl TypeResolver<'_> {
fn resolve_if(&mut self, node: pr::Expr) -> Result<pr::Expr> {
let pr::ExprKind::If(if_else) = node.kind else {
unreachable!()
};
let bool = pr::Ty::new(pr::TyPrimitive::bool);
let mut condition = Box::new(self.fold_expr_or_recover(*if_else.condition, &bool));
self.validate_expr_type(&mut condition, &bool, &|| Some("if".into()))
.unwrap_or_else(self.push_diagnostic());
let then = Box::new(self.fold_expr(*if_else.then)?);
let ty = then.ty.as_deref().cloned().unwrap();
let mut els = Box::new(self.fold_expr(*if_else.els)?);
self.validate_expr_type(&mut els, &ty, &|| None)
.unwrap_or_else(self.push_diagnostic());
Ok(pr::Expr {
kind: pr::ExprKind::If(pr::If {
condition,
then,
els,
}),
ty: Some(Box::new(ty)),
..node
})
}
}