use crate::diagnostic::{Diagnostic, WithErrorInfo};
use crate::{pr, utils};
use crate::Result;
use crate::resolver::types::scope::ScopeKind;
use crate::utils::fold::PrFold;
use super::scope::{Scope, TyRef};
impl<'a> super::TypeResolver<'a> {
pub fn resolve_match(&mut self, expr: pr::Expr) -> Result<pr::Expr> {
let pr::ExprKind::Match(match_) = expr.kind else {
panic!()
};
let subject = Box::new(self.fold_expr(*match_.subject)?);
let subject_ty = subject.ty.as_ref().unwrap();
let mut branches = Vec::with_capacity(match_.branches.len());
let mut ty = None;
for branch in match_.branches {
self.scopes
.push(Scope::new(branch.scope_id.unwrap(), ScopeKind::Nested));
let (pattern, bound_tys) = self.resolve_pattern(subject_ty, branch.pattern)?;
let scope = self.scopes.last_mut().unwrap();
for bound_ty in bound_tys {
scope.insert_local(bound_ty.clone());
}
let value = self.fold_expr(*branch.value)?;
let mapping = self.finalize_type_vars()?;
let mut value = utils::TypeReplacer::on_expr(value, mapping);
let scope = self.scopes.pop().unwrap();
match &ty {
None => {
ty = value.ty.clone();
}
Some(ty) => {
self.validate_expr_type(&mut value, ty, &|| Some("match".into()))?;
}
}
let value = Box::new(value);
branches.push(pr::MatchBranch {
scope_id: Some(scope.id),
pattern,
value,
})
}
let ty = ty.unwrap();
Ok(pr::Expr {
kind: pr::ExprKind::Match(pr::Match { subject, branches }),
ty: Some(ty),
..expr
})
}
pub fn resolve_pattern(
&mut self,
subject_ty: &pr::Ty,
pattern: pr::Pattern,
) -> Result<(pr::Pattern, Vec<pr::Ty>)> {
match pattern.kind {
pr::PatternKind::Enum(variant_name, inner) => {
let subject_ty_mat = self.get_ty_mat(subject_ty)?;
let (tag, variant_ty) = match &subject_ty_mat {
TyRef::Ty(t) => {
let pr::TyKind::Enum(variants) = &t.kind else {
return Err(Diagnostic::new_custom("expected an enum")
.with_span(Some(pattern.span)));
};
let (tag, variant) = lookup_variant(variants, &variant_name)
.with_span(Some(pattern.span))?;
(Some(tag), variant.ty.clone())
}
TyRef::Param(_) => {
return Err(Diagnostic::new_custom("expected an enum")
.push_hint("found type parameter, which might not be an enum")
.with_span(Some(pattern.span)));
}
TyRef::Var(_, id) => {
let variant_ty = if inner.is_some() {
self.introduce_ty_var(pr::TyDomain::Open, pattern.span)
} else {
pr::Ty::new(pr::TyKind::Tuple(vec![]))
};
let restriction =
pr::TyDomain::EnumVariants(vec![pr::TyDomainEnumVariant {
name: variant_name.clone(),
ty: variant_ty.clone(),
}]);
let scope = self.get_ty_var_scope();
scope.infer_type_var_in_domain(*id, restriction);
(None, variant_ty)
}
};
let (inner, bound_tys) = if let Some(inner) = inner {
let (inner, bound_tys) = self.resolve_pattern(&variant_ty, *inner)?;
(Some(Box::new(inner)), bound_tys)
} else {
(None, Vec::new())
};
let pattern = pr::Pattern {
kind: pr::PatternKind::Enum(variant_name, inner),
variant_tag: tag,
..pattern
};
Ok((pattern, bound_tys))
}
pr::PatternKind::AnyOf(branches) => {
assert!(!branches.is_empty());
let mut res = Vec::with_capacity(branches.len());
let mut res_bound_tys = None;
for br in branches {
let span = br.span;
let (pat, bound_tys) = self.resolve_pattern(subject_ty, br)?;
res.push(pat);
if let Some(res) = &res_bound_tys {
for (e, f) in std::iter::zip(res, bound_tys) {
self.validate_type(&f, e, &|| None)
.with_span_fallback(f.span)
.with_span_fallback(Some(span))?;
}
} else {
res_bound_tys = Some(bound_tys);
}
}
let pattern = pr::Pattern {
kind: pr::PatternKind::AnyOf(res),
..pattern
};
Ok((pattern, res_bound_tys.unwrap()))
}
pr::PatternKind::Bind(_) => {
let mut bound_ty = subject_ty.clone();
bound_ty.span = Some(pattern.span);
Ok((pattern, vec![bound_ty]))
}
pr::PatternKind::Literal(l) => {
let found_ty = self.infer_type_of_literal(&l, Some(pattern.span));
self.validate_type(&found_ty, subject_ty, &|| Some("match".into()))
.unwrap_or_else(self.push_diagnostic());
let pattern = pr::Pattern {
kind: pr::PatternKind::Literal(l),
..pattern
};
Ok((pattern, vec![]))
}
}
}
}
/// Finds the enum variant named `variant_name`.
///
/// Returns the position of the first matching variant within `variants`
/// together with a reference to it.
///
/// # Errors
///
/// Returns a "variant does not exist" diagnostic (without a span) when no
/// variant has that name.
pub fn lookup_variant<'a>(
    variants: &'a [pr::TyEnumVariant],
    variant_name: &str,
) -> Result<(usize, &'a pr::TyEnumVariant), Diagnostic> {
    for (tag, variant) in variants.iter().enumerate() {
        if variant.name == variant_name {
            return Ok((tag, variant));
        }
    }
    Err(Diagnostic::new_custom("variant does not exist"))
}
/// Finds the type-domain enum variant named `variant_name`.
///
/// Returns the position of the first matching variant within `variants`
/// together with a reference to it.
///
/// # Errors
///
/// Returns a "variant does not exist" diagnostic (without a span) when no
/// variant has that name.
pub fn lookup_variant_in_domain<'a>(
    variants: &'a [pr::TyDomainEnumVariant],
    variant_name: &str,
) -> Result<(usize, &'a pr::TyDomainEnumVariant), Diagnostic> {
    for (index, candidate) in variants.iter().enumerate() {
        if candidate.name == variant_name {
            return Ok((index, candidate));
        }
    }
    Err(Diagnostic::new_custom("variant does not exist"))
}