use std::cell::RefCell;
use std::collections::HashMap;
use crate::diagnostic::{Diagnostic, DiagnosticCode, WithErrorInfo};
use crate::pr;
use crate::{Result, Span};
use crate::{printer, utils};
use super::TypeResolver;
/// A single lexical scope on the [`TypeResolver`]'s scope stack.
///
/// Entries live in an append-only vector, which lets methods taking `&self`
/// (such as [`Scope::insert_type_var`]) add entries. The index of an entry is
/// the `offset` used by `pr::Ref::Local` to refer to it.
#[derive(Debug)]
pub struct Scope {
    /// Identifier compared against `pr::Ref::Local::scope` when resolving refs.
    pub(super) id: usize,
    /// Decides whether new type variables are placed in this scope
    /// (see [`Scope::for_ty_vars`]).
    pub(super) kind: ScopeKind,
    /// Entries of this scope; an entry's position is its local-ref `offset`.
    pub(super) names: append_only_vec::AppendOnlyVec<ScopedKind>,
    /// Constraints recorded for type variables of this scope. Kept in a
    /// `RefCell` so constraints can be pushed through `&self`.
    pub(super) ty_var_constraints: RefCell<Vec<TyVarConstraint>>,
}
/// Kind of a [`Scope`], controlling where type variables are inserted.
#[derive(Debug)]
pub enum ScopeKind {
    /// Type variables go into the innermost scope of this kind
    /// (see `TypeResolver::get_ty_var_scope`).
    Isolated,
    /// Skipped when looking for the scope that owns type variables.
    Nested,
}
/// A single entry stored in [`Scope::names`].
#[derive(Debug, Clone)]
pub enum ScopedKind {
    /// A function parameter with its annotated type (see [`Scope::insert_params`]).
    Param {
        ty: pr::Ty,
    },
    /// A local value binding (see [`Scope::insert_local`]).
    Local {
        ty: pr::Ty,
    },
    /// A local type binding; `TypeResolver::get_ty_mat` resolves through it to `ty`.
    LocalTy {
        ty: pr::Ty,
    },
    /// A generic type parameter (see [`Scope::insert_type_params`]).
    TyParam {
        name: String,
        domain: pr::TyDomain,
    },
    /// A type variable created during inference (see [`Scope::insert_type_var`]).
    TyVar(TyVar),
}
/// Offset of a [`ScopedKind::TyVar`] entry within its scope's `names`.
pub type TyVarId = usize;
/// Metadata stored for a type variable.
#[derive(Debug, Clone)]
pub struct TyVar {
    /// Optional human-readable name. `TypeResolver::introduce_ty_into_scope`
    /// passes the generic parameter's name; `introduce_ty_var` passes `None`.
    pub name_hint: Option<String>,
    /// Source location where the variable was introduced; always `Some` when
    /// created through [`Scope::insert_type_var`].
    pub span: Option<Span>,
}
/// A constraint on a type variable, collected in [`Scope::ty_var_constraints`].
#[derive(Debug)]
pub enum TyVarConstraint {
    /// The variable is the given type (recorded by [`Scope::infer_type_var`]).
    IsTy(TyVarId, pr::Ty),
    /// The two variables are equal (recorded by [`Scope::infer_type_vars_equal`]).
    Equals(TyVarId, TyVarId),
    /// The variable must lie in the given domain (recorded by
    /// [`Scope::infer_type_var_in_domain`], and by [`Scope::insert_type_var`]
    /// for any domain other than `Open`).
    InDomain(TyVarId, pr::TyDomain),
}
/// What a `pr::Ref` resolves to, as returned by `TypeResolver::get_ref`.
#[derive(Debug, strum::AsRefStr)]
pub enum Named<'a> {
    /// A global expression definition (its value).
    Expr(&'a pr::Expr),
    /// A global type definition.
    Ty {
        ty: &'a pr::Ty,
        /// Copied from the definition. `TypeResolver::get_ty_mat` stops at
        /// framed types instead of resolving through them.
        is_framed: bool,
        framed_label: Option<&'a str>,
    },
    /// A global module.
    Module,
    /// An entry of a local scope.
    Scoped(&'a ScopedKind),
}
/// Result of `TypeResolver::get_ty_mat`: a type with aliases followed.
#[derive(Debug, Clone)]
pub enum TyRef<'a> {
    /// A non-ident type, or an ident referring to a framed type (kept as-is).
    Ty(&'a pr::Ty),
    /// A type parameter, identified by its offset within its scope.
    Param(usize),
    /// A type variable: `(scope id, offset within that scope)`.
    // NOTE(review): `dead_code` is presumably allowed because the fields are only
    // read via `Debug`; confirm before removing the attribute.
    #[allow(dead_code)] Var(usize, usize),
}
impl Scope {
pub fn new(id: usize, kind: ScopeKind) -> Self {
Self {
id,
kind,
names: Default::default(),
ty_var_constraints: RefCell::new(Vec::new()),
}
}
pub fn for_ty_vars(&self) -> bool {
matches!(self.kind, ScopeKind::Isolated)
}
pub fn insert_type_params(&mut self, type_params: &[pr::TyParam]) {
for gtp in type_params {
let scoped = ScopedKind::TyParam {
name: gtp.name.clone(),
domain: gtp.domain.clone(),
};
self.names.push(scoped);
}
}
pub fn insert_params(&mut self, func: &pr::Func) -> Result<(), Vec<Diagnostic>> {
let mut d = Vec::new();
for param in &func.params {
let Some(ty) = param.ty.clone() else {
d.push(
Diagnostic::new_custom("missing type annotations").with_span(Some(param.span)),
);
continue;
};
let scoped = ScopedKind::Param { ty };
self.names.push(scoped);
}
if d.is_empty() { Ok(()) } else { Err(d) }
}
#[must_use]
pub fn insert_type_var(
&self,
name_hint: Option<String>,
span: Span,
mut domain: pr::TyDomain,
) -> usize {
let type_arg = TyVar {
name_hint,
span: Some(span),
};
let var_id = self.names.push(ScopedKind::TyVar(type_arg));
if let pr::TyDomain::TupleHasFields(fields) = &mut domain {
for f in fields {
f.span = span;
}
}
if !matches!(domain, pr::TyDomain::Open) {
self.ty_var_constraints
.borrow_mut()
.push(TyVarConstraint::InDomain(var_id, domain));
}
var_id
}
pub fn insert_local(&mut self, ty: pr::Ty) -> usize {
let local = ScopedKind::Local { ty };
self.names.push(local)
}
pub fn insert_local_ty(&mut self, ty: pr::Ty) -> usize {
let local = ScopedKind::LocalTy { ty };
self.names.push(local)
}
pub fn infer_type_var(&self, id: TyVarId, ty: pr::Ty) {
tracing::debug!("inferring {id:?} is {}", crate::printer::print_ty(&ty));
let mut constraints = self.ty_var_constraints.borrow_mut();
constraints.push(TyVarConstraint::IsTy(id, ty));
}
pub fn infer_type_var_in_domain(&self, id: TyVarId, domain: pr::TyDomain) {
tracing::debug!("inferring {id:?} in domain {domain:?}");
let mut constraints = self.ty_var_constraints.borrow_mut();
constraints.push(TyVarConstraint::InDomain(id, domain));
}
pub fn infer_type_vars_equal(&self, a: TyVarId, b: TyVarId) {
tracing::debug!("inferring equality between {a:?} and {b:?}");
let mut constraints = self.ty_var_constraints.borrow_mut();
constraints.push(TyVarConstraint::Equals(a, b));
}
}
impl<'a> TypeResolver<'a> {
pub fn get_ty_var_scope(&self) -> &Scope {
let mut stack = self.scopes.iter().rev();
stack.find(|s| s.for_ty_vars()).unwrap()
}
pub fn get_ty_var_scope_mut(&mut self) -> &mut Scope {
let mut stack = self.scopes.iter_mut().rev();
stack.find(|s| s.for_ty_vars()).unwrap()
}
pub(super) fn get_ref(&'a self, target: &pr::Ref) -> Result<Named<'a>> {
tracing::trace!("get_ident: {target:?}");
match target {
pr::Ref::Global(tgt_fq) => {
let def = self.root_mod.get(tgt_fq);
match &def.unwrap_or_else(|| panic!("cannot find {tgt_fq}")).kind {
pr::DefKind::Expr(expr) => Ok(Named::Expr(&expr.value)),
pr::DefKind::Ty(def) => Ok(Named::Ty {
ty: &def.ty,
is_framed: def.is_framed,
framed_label: def.framed_label.as_deref(),
}),
pr::DefKind::Module(_) => Ok(Named::Module),
pr::DefKind::Unresolved(_) | pr::DefKind::Import(_) => unreachable!(),
}
}
pr::Ref::Local { scope, offset } => {
let scope = self
.scopes
.iter()
.find(|s| s.id == *scope)
.ok_or_else(|| panic!("cannot find scope: {scope}"))
.unwrap();
Ok(Named::Scoped(&scope.names[*offset]))
}
}
}
pub fn get_ty_param(&self, param_id: usize) -> (&String, &pr::TyDomain) {
let scope = self.scopes.last().unwrap();
let scoped = &scope.names[param_id];
let ScopedKind::TyParam { name, domain } = scoped else {
panic!()
};
(name, domain)
}
pub fn get_ty_var(&self, id: TyVarId) -> &TyVar {
let scoped = &self.get_ty_var_scope().names[id];
let ScopedKind::TyVar(var) = scoped else {
panic!()
};
var
}
pub fn get_ty_mat<'t>(&'t self, ty: &'t pr::Ty) -> Result<TyRef<'t>> {
let pr::TyKind::Ident(_) = &ty.kind else {
return Ok(TyRef::Ty(ty));
};
let target = ty.target.as_ref().unwrap();
let named = self.get_ref(target).with_span(ty.span)?;
match named {
Named::Ty {
is_framed: true, ..
} => {
Ok(TyRef::Ty(ty))
}
Named::Ty {
is_framed: false,
ty,
..
} => {
self.get_ty_mat(ty)
}
Named::Scoped(scoped) => {
let pr::Ref::Local { scope, offset } = target else {
panic!()
};
match scoped {
ScopedKind::LocalTy { ty } => self.get_ty_mat(ty),
ScopedKind::TyParam { .. } => Ok(TyRef::Param(*offset)),
ScopedKind::TyVar(_) => {
Ok(TyRef::Var(*scope, *offset))
}
ScopedKind::Param { ty, .. } => Err(err_name_kind("a type", "a value")
.push_hint(format!("got param of type `{}`", printer::print_ty(ty)))
.with_span(ty.span)),
ScopedKind::Local { ty, .. } => Err(err_name_kind("a type", "a value")
.push_hint(format!("got local var of type `{}`", printer::print_ty(ty)))
.with_span(ty.span)),
}
}
Named::Expr(_) => Err(err_name_kind("a type", "a value").with_span(ty.span)),
Named::Module => Err(err_name_kind("a type", "a module").with_span(ty.span)),
}
}
pub fn introduce_ty_into_scope(&mut self, ty: pr::Ty, span: Span) -> (pr::Ty, Vec<pr::Ty>) {
let pr::TyKind::Func(mut ty_func) = ty.kind else {
return (ty, Vec::new());
};
if ty_func.ty_params.is_empty() {
return (
pr::Ty {
kind: pr::TyKind::Func(ty_func),
..ty
},
Vec::new(),
);
}
let mut mapping = HashMap::new();
let mut ty_args = Vec::with_capacity(ty_func.ty_params.len());
let scope = self.get_ty_var_scope_mut();
tracing::debug!(
"introducing generics for ty_func={} into scope {}",
crate::printer::print_ty(&pr::Ty::new(ty_func.clone())),
scope.id,
);
for (gtp_position, gtp) in ty_func.ty_params.drain(..).enumerate() {
let gtp_ref = pr::Ref::Local {
scope: ty.scope_id.unwrap(), offset: gtp_position,
};
let mut ty_arg_ident = pr::Ty::new(
pr::Path::new(vec![gtp.name.clone()]),
);
let offset = scope.insert_type_var(Some(gtp.name), span, gtp.domain);
ty_arg_ident.target = Some(pr::Ref::Local {
scope: scope.id, offset,
});
mapping.insert(gtp_ref, ty_arg_ident.clone());
ty_args.push(ty_arg_ident);
}
let ty = pr::Ty {
kind: pr::TyKind::Func(ty_func),
..ty
};
(utils::TypeReplacer::on_ty(ty, mapping), ty_args)
}
pub fn introduce_ty_var(&self, domain: pr::TyDomain, span: Span) -> pr::Ty {
let scope = self.get_ty_var_scope();
let var_id = scope.insert_type_var(None, span, domain);
pr::Ty {
span: Some(span),
target: Some(pr::Ref::Local {
scope: scope.id,
offset: var_id,
}),
..pr::Ty::new(pr::TyKind::Ident(pr::Path::from_name("_")))
}
}
}
/// Builds a `NAME_KIND` diagnostic for a name that resolved to the wrong kind
/// of item, with the message `expected {expected}, found {found}`.
pub(super) fn err_name_kind(expected: &str, found: &str) -> Diagnostic {
    let message = format!("expected {expected}, found {found}");
    Diagnostic::new(message, DiagnosticCode::NAME_KIND)
}