use std::collections::HashMap;
use lutra_bin::ident;
use crate::diagnostic::{Diagnostic, DiagnosticCode, WithErrorInfo};
use crate::pr::{self, Ty};
use crate::utils::fold::PrFold;
use crate::{Result, Span, printer, utils};
use super::scope;
impl super::TypeResolver<'_> {
    /// Resolves a tuple constructor expression (`pr::ExprKind::Tuple`).
    ///
    /// Folds each field expression and builds the tuple type from the folded fields' types.
    /// When a field has no explicit name, a name is inferred from its expression
    /// (see `infer_tuple_field_name`). The inferred name is used only in the resulting type;
    /// the expression keeps the field name exactly as written.
    ///
    /// Unpacked fields (`unpack: true`) are checked to be tuple-like:
    /// - a concrete type must be a tuple,
    /// - a type parameter must have the `TupleHasFields` domain,
    /// - an unresolved type variable is constrained to the `TupleHasFields` domain.
    ///
    /// Errors from individual fields are accumulated with `collect_diag`. They are returned
    /// only after all fields have been processed, and a failing field is skipped.
    ///
    /// # Panics
    /// If `node` is not a tuple expression, or if a folded field expression has no type.
    pub fn resolve_tuple_constructor(&mut self, node: pr::Expr) -> Result<pr::Expr> {
        let pr::ExprKind::Tuple(fields_in) = node.kind else {
            unreachable!()
        };

        let mut fields = Vec::with_capacity(fields_in.len());
        let mut ty_fields: Vec<pr::TyTupleField> = Vec::with_capacity(fields_in.len());
        let mut diag = None;

        for f in fields_in {
            // The inferred name only ends up in the type; the expression keeps `f.name`.
            let name = (f.name.clone()).or_else(|| self.infer_tuple_field_name(&f.expr));

            let expr = match self.fold_expr(f.expr) {
                Ok(e) => e,
                Err(d) => {
                    self.collect_diag(&mut diag, d);
                    continue;
                }
            };
            let ty = expr.ty.as_deref().cloned().unwrap();

            if f.unpack {
                // Materialization can fail: report it like any other field error
                // instead of panicking.
                let ty_ref = match self.get_ty_mat(&ty) {
                    Ok(ty_ref) => ty_ref,
                    Err(d) => {
                        self.collect_diag(&mut diag, d);
                        continue;
                    }
                };
                match ty_ref {
                    scope::TyRef::Ty(t) => {
                        if !t.kind.is_tuple() {
                            self.collect_diag(
                                &mut diag,
                                Diagnostic::new(
                                    "only tuples can be unpacked",
                                    DiagnosticCode::TYPE,
                                )
                                .with_span(expr.span)
                                .push_hint(format!("got type {}", printer::print_ty(t))),
                            );
                            continue;
                        }
                    }
                    scope::TyRef::Param(id) => {
                        let (param_name, domain) = self.get_ty_param(id);
                        let pr::TyDomain::TupleHasFields(_) = domain else {
                            self.collect_diag(
                                &mut diag,
                                error_lookup_into_unpack_of_ty_param(param_name),
                            );
                            continue;
                        };
                    }
                    scope::TyRef::Var(_, o) => {
                        // The var must be a tuple, but no particular fields are required.
                        let domain = pr::TyDomain::TupleHasFields(vec![]);
                        let scope = self.get_ty_var_scope();
                        scope.infer_type_var_in_domain(o, domain);
                    }
                };
            }

            fields.push(pr::TupleField {
                name: f.name,
                unpack: f.unpack,
                expr,
            });
            ty_fields.push(pr::TyTupleField {
                name,
                unpack: f.unpack,
                ty,
            });
        }

        if let Some(d) = diag {
            return Err(d);
        }

        let kind = pr::ExprKind::Tuple(fields);
        let ty = pr::Ty::new(pr::TyKind::Tuple(ty_fields));
        Ok(pr::Expr {
            kind,
            ty: Some(Box::new(ty)),
            ..node
        })
    }

    /// Infers an implicit tuple field name from a field expression.
    ///
    /// A named lookup (`x.name`) yields `name`, and an identifier yields its last segment.
    /// Positional lookups and all other expressions yield no name.
    fn infer_tuple_field_name(&self, field: &pr::Expr) -> Option<String> {
        match &field.kind {
            pr::ExprKind::Lookup { base: _, lookup } => match lookup {
                pr::Lookup::Name(name) => Some(name.clone()),
                pr::Lookup::Position(_) => None,
            },
            pr::ExprKind::Ident(ident) => Some(ident.last().to_string()),
            _ => None,
        }
    }

    /// Resolves the type of `lookup` (a `.name` or `.N` access) into a value of type `base`.
    ///
    /// The base type is materialized first:
    /// - a concrete type is handled by `lookup_in_tuple`,
    /// - a type parameter is looked up in its `TupleHasFields` domain,
    /// - an unresolved type variable is constrained to a `TupleHasFields` domain containing
    ///   the looked-up field. The field's type is a fresh, open type variable, which is
    ///   returned.
    pub fn resolve_tuple_lookup(
        &mut self,
        base: &Ty,
        lookup: &pr::Lookup,
        span: Span,
    ) -> Result<pr::Ty> {
        let base_ref = self.get_ty_mat(base)?;
        match base_ref {
            scope::TyRef::Ty(b) => {
                // `b` presumably borrows from `self`, while `lookup_in_tuple` needs
                // `&mut self`, hence the clone.
                let base = b.clone();
                self.lookup_in_tuple(&base, lookup, span)
            }
            scope::TyRef::Param(id) => self.lookup_name_in_ty_param(lookup, id),
            scope::TyRef::Var(_, o) => {
                let field_ty = self.introduce_ty_var(pr::TyDomain::Open, span);
                let domain = pr::TyDomain::TupleHasFields(vec![pr::TyDomainTupleField {
                    location: lookup.clone(),
                    ty: field_ty.clone(),
                    span,
                }]);
                let scope = self.get_ty_var_scope();
                scope.infer_type_var_in_domain(o, domain);
                Ok(field_ty)
            }
        }
    }

    /// Resolves `lookup` within an already materialized type `base`.
    ///
    /// - `Tuple`: searched by name or by flattened position (unpacked tuples are inlined).
    ///   A negative position never matches.
    /// - `TupleComprehension`: the lookup is resolved in `comp.tuple`. The result then
    ///   replaces `Ref::Local { scope: base.scope_id, offset: 0 }` in the body type.
    ///   Name lookups are rejected when the body has no name.
    /// - `Ident`: must refer to a framed type. Its inner type is reachable only through
    ///   `.0` or through the frame label, if there is one.
    /// - anything else: error, since only tuples support lookups.
    ///
    /// All "field does not exist" / "expected a tuple" errors carry `span`.
    pub fn lookup_in_tuple(
        &mut self,
        base: &Ty,
        lookup: &pr::Lookup,
        span: Span,
    ) -> Result<pr::Ty> {
        match &base.kind {
            pr::TyKind::Tuple(fields) => {
                let r = match lookup {
                    pr::Lookup::Name(name) => self.lookup_name_in_tuple(fields, name)?,
                    pr::Lookup::Position(pos) => match usize::try_from(*pos) {
                        Ok(pos) => self.lookup_position_in_tuple(fields, pos, 0)?.ok(),
                        // A negative position can never address a field; `as usize` would
                        // silently wrap it into a huge index instead.
                        Err(_) => None,
                    },
                };
                r.ok_or_else(|| error_no_field(base, lookup).with_span(Some(span)))
            }
            pr::TyKind::TupleComprehension(comp) => {
                if comp.body_name.is_none() && lookup.is_name() {
                    return Err(error_no_field(base, lookup).with_span(Some(span)));
                }

                let var_input = self.resolve_tuple_lookup(&comp.tuple, lookup, span)?;
                let var_ref = pr::Ref::Local {
                    scope: base.scope_id.unwrap(),
                    offset: 0,
                };
                let mapping = HashMap::from_iter(Some((var_ref, var_input)));
                Ok(utils::TypeReplacer::on_ty(*comp.body_ty.clone(), mapping))
            }
            pr::TyKind::Ident(_) => {
                let target = base.target.as_ref().unwrap();
                let scope::Named::Ty {
                    ty,
                    is_framed,
                    framed_label,
                } = self.get_ref(target).unwrap()
                else {
                    unreachable!();
                };
                assert!(is_framed);

                match lookup {
                    pr::Lookup::Name(n) if Some(n.as_str()) == framed_label => {}
                    pr::Lookup::Position(0) => {}
                    _ => {
                        let label_hint = framed_label
                            .map(|l| format!("`.{l}` or "))
                            .unwrap_or_default();
                        return Err(error_no_field(base, lookup)
                            .with_span(Some(span))
                            .push_hint(format!(
                                "{} is a framed type. Inner value can be accessed with {label_hint}`.0`",
                                printer::print_ty(base)
                            )));
                    }
                }
                Ok(ty.clone())
            }
            _ => Err(Diagnostic::new(
                format!("lookup expected a tuple, found {}", printer::print_ty(base)),
                DiagnosticCode::TYPE,
            )
            .with_span(Some(span))),
        }
    }

    /// Looks up a field by `name` in `fields`, in declaration order, descending into
    /// unpacked tuple fields. Returns `Ok(None)` when no field matches.
    ///
    /// An unpacked type parameter ends the search, and its domain decides the result. An
    /// unpacked unresolved type variable is an error, because it is unknown whether it
    /// contains the field.
    fn lookup_name_in_tuple(
        &self,
        fields: &[pr::TyTupleField],
        name: &str,
    ) -> Result<Option<pr::Ty>> {
        for field in fields {
            if !field.unpack {
                if field.matches_name(name) {
                    return Ok(Some(field.ty.clone()));
                }
            } else {
                let base_ref = self.get_ty_mat(&field.ty)?;
                let base = match base_ref {
                    scope::TyRef::Ty(b) => b,
                    scope::TyRef::Param(id) => {
                        let lookup = pr::Lookup::Name(name.to_string());
                        return self.lookup_name_in_ty_param(&lookup, id).map(Some);
                    }
                    scope::TyRef::Var(_, _) => {
                        return Err(error_lookup_into_unpack_of_ty_var());
                    }
                };

                let pr::TyKind::Tuple(fields) = &base.kind else {
                    // Unpacked fields are expected to be tuples; report a diagnostic
                    // instead of crashing the compiler if that does not hold.
                    return Err(Diagnostic::new(
                        "only tuples can be unpacked",
                        DiagnosticCode::TYPE,
                    )
                    .push_hint(format!("got type {}", printer::print_ty(base))));
                };
                if let Some(target) = self.lookup_name_in_tuple(fields, name)? {
                    return Ok(Some(target));
                }
            }
        }
        Ok(None)
    }

    /// Looks up the field at the flattened `position` in `ty_fields`. Fields of unpacked
    /// tuples are counted as if they were inlined.
    ///
    /// `passed_parent` is the number of positions already consumed before `ty_fields`;
    /// it is used when recursing into an unpacked tuple.
    ///
    /// Returns `Ok(Ok(ty))` when the field is found. Otherwise returns `Ok(Err(passed))`,
    /// the total number of positions consumed, including `passed_parent`. An unpacked type
    /// parameter ends the search: the remaining relative position is looked up in the
    /// parameter's domain. An unpacked unresolved type variable is an error.
    fn lookup_position_in_tuple(
        &self,
        ty_fields: &[pr::TyTupleField],
        position: usize,
        passed_parent: usize,
    ) -> Result<Result<pr::Ty, usize>> {
        let mut passed = passed_parent;
        for f in ty_fields {
            if !f.unpack {
                if passed == position {
                    return Ok(Ok(f.ty.clone()));
                }
                passed += 1;
            } else {
                // Propagate materialization errors instead of panicking
                // (consistent with `lookup_name_in_tuple`).
                let ty = self.get_ty_mat(&f.ty)?;
                let ty = match ty {
                    scope::TyRef::Ty(t) => t,
                    scope::TyRef::Param(param_id) => {
                        // `passed <= position` holds here: `passed` only grows past a
                        // field after failing the equality check above.
                        let pos = position - passed;
                        return self.lookup_position_in_ty_param(param_id, pos).map(Ok);
                    }
                    scope::TyRef::Var(_, _) => {
                        return Err(error_lookup_into_unpack_of_ty_var());
                    }
                };

                let pr::TyKind::Tuple(ty_fields) = &ty.kind else {
                    // Unpacked fields are expected to be tuples; report a diagnostic
                    // instead of crashing the compiler if that does not hold.
                    return Err(Diagnostic::new(
                        "only tuples can be unpacked",
                        DiagnosticCode::TYPE,
                    )
                    .push_hint(format!("got type {}", printer::print_ty(ty))));
                };
                match self.lookup_position_in_tuple(ty_fields, position, passed)? {
                    Ok(t) => return Ok(Ok(t)),
                    Err(p) => passed = p,
                }
            }
        }
        Ok(Err(passed))
    }

    /// Looks up `position` (relative to the start of the unpacked parameter) in the
    /// `TupleHasFields` domain of type parameter `param_id`.
    fn lookup_position_in_ty_param(&self, param_id: usize, position: usize) -> Result<Ty> {
        let (param_name, param_domain) = self.get_ty_param(param_id);
        let pr::TyDomain::TupleHasFields(fields) = param_domain else {
            return Err(error_lookup_into_unpack_of_ty_param(param_name));
        };

        let lookup = pr::Lookup::Position(position as i64);
        Self::lookup_in_tuple_domain(fields, &lookup).map_err(|_| {
            Diagnostic::new_custom("cannot do positional lookup into unpack of this type param")
        })
    }

    /// Looks up `lookup` (by name or by position) in the `TupleHasFields` domain of type
    /// parameter `id`.
    fn lookup_name_in_ty_param(&self, lookup: &pr::Lookup, id: usize) -> Result<pr::Ty> {
        let (param_name, param) = self.get_ty_param(id);
        let pr::TyDomain::TupleHasFields(fields) = param else {
            return Err(error_lookup_into_unpack_of_ty_param(param_name));
        };
        Self::lookup_in_tuple_domain(fields, lookup)
    }

    /// Returns the type of the domain field whose location equals `lookup`, or a type
    /// error if the domain declares no such field.
    pub fn lookup_in_tuple_domain(
        fields: &[pr::TyDomainTupleField],
        lookup: &pr::Lookup,
    ) -> Result<pr::Ty> {
        let Some(field) = fields.iter().find(|f| &f.location == lookup) else {
            return Err(Diagnostic::new(
                format!("field {} does not exist", print_lookup(lookup)),
                DiagnosticCode::TYPE,
            ));
        };
        Ok(field.ty.clone())
    }
}
/// Builds the type error reported when `lookup` addresses a field that `base` lacks.
///
/// The message names both the lookup (e.g. `.name` or `.0`) and the printed base type.
pub fn error_no_field(base: &Ty, lookup: &pr::Lookup) -> Diagnostic {
    let field = print_lookup(lookup);
    let ty = printer::print_ty(base);
    let message = format!("field {field} does not exist in type {ty}");
    Diagnostic::new(message, DiagnosticCode::TYPE)
}
/// Error for a lookup that has to pass through an unpacked value whose type is still an
/// unresolved type variable, so its fields are unknown.
fn error_lookup_into_unpack_of_ty_var() -> Diagnostic {
    let diagnostic = Diagnostic::new_custom("ambiguous lookup into unpack of an unknown type");
    diagnostic.push_hint("consider annotating the unpacked expression")
}
/// Error for a lookup into (or an unpack of) a type parameter whose domain does not
/// restrict it to tuples. The hints suggest adding a tuple constraint to the parameter.
fn error_lookup_into_unpack_of_ty_param(param_name: &str) -> Diagnostic {
    let message = format!("lookup expected a tuple, found type parameter {param_name}");
    let not_constrained = format!("{param_name} is not constrained to tuples only");
    let suggestion = format!("add `{param_name}: {{}}` to constrain it to tuples");

    Diagnostic::new_custom(message)
        .push_hint(not_constrained)
        .push_hint(suggestion)
}
/// Renders a lookup the way it is written in source: `.name` or `.N`.
///
/// Names go through `ident::display`; positions are printed as plain integers.
fn print_lookup(lookup: &pr::Lookup) -> String {
    let rendered = match lookup {
        pr::Lookup::Name(name) => ident::display(name).to_string(),
        pr::Lookup::Position(position) => position.to_string(),
    };
    format!(".{rendered}")
}