use crate::{
diagnostics::{Diagnostics, WarningType, ERR_INDENT_SIZE},
fun::{Adts, Constructors, CtrField, Ctx, MatchRule, Name, Num, Term},
maybe_grow,
};
use std::collections::HashMap;
/// Problems detected while normalizing `match`/`fold` terms in this pass.
///
/// `Ctx::fix_match_terms` reports `AdtMismatch` and `NonExhaustiveMatch` as
/// function errors and every other variant as a function warning.
enum FixMatchErr {
/// An arm names constructor `ctr`, which belongs to ADT `found`, while the
/// first arm of the match selected ADT `expected`.
AdtMismatch { expected: Name, found: Name, ctr: Name },
/// No arm provides a body for constructor `missing` of the matched ADT `typ`.
NonExhaustiveMatch { typ: Name, missing: Name },
/// The first arm's pattern is not a known constructor, so the match can never
/// fail. `var` is the name of that pattern, if it has one.
IrrefutableMatch { var: Option<Name> },
/// An arm whose pattern is not a known constructor (named `var`, if it has a
/// name) is followed by further arms; those later arms are dropped.
UnreachableMatchArms { var: Option<Name> },
/// Constructor `ctr` already had a body from an earlier arm.
RedundantArm { ctr: Name },
}
impl Ctx<'_> {
    /// Runs the match-normalization pass (`Term::fix_match_terms`) over the body
    /// of every rule of every definition in the book.
    ///
    /// Each problem found is attached to the definition it came from:
    /// * `AdtMismatch` and `NonExhaustiveMatch` are recorded as function errors;
    /// * `IrrefutableMatch`, `UnreachableMatchArms` and `RedundantArm` are
    ///   recorded as function warnings of the matching `WarningType`.
    ///
    /// Returns whatever `self.info.fatal(())` returns (presumably `Err` with the
    /// collected diagnostics when at least one error was recorded — confirm
    /// against `Diagnostics::fatal`).
    pub fn fix_match_terms(&mut self) -> Result<(), Diagnostics> {
        for def in self.book.defs.values_mut() {
            for rule in def.rules.iter_mut() {
                for err in rule.body.fix_match_terms(&self.book.ctrs, &self.book.adts) {
                    // `None` means "hard error"; `Some(kind)` is the warning category.
                    let warning_kind = match &err {
                        FixMatchErr::AdtMismatch { .. } | FixMatchErr::NonExhaustiveMatch { .. } => None,
                        FixMatchErr::IrrefutableMatch { .. } => Some(WarningType::IrrefutableMatch),
                        FixMatchErr::UnreachableMatchArms { .. } => Some(WarningType::UnreachableMatch),
                        FixMatchErr::RedundantArm { .. } => Some(WarningType::RedundantMatch),
                    };
                    let def_name = def.name.clone();
                    let def_source = def.source.clone();
                    match warning_kind {
                        Some(kind) => self.info.add_function_warning(err, kind, def_name, def_source),
                        None => self.info.add_function_error(err, def_name, def_source),
                    }
                }
            }
        }
        self.info.fatal(())
    }
}
impl Term {
/// Normalizes every `match`, `fold` and `switch` term inside `self` and
/// returns the problems found along the way.
///
/// Order of operations (the later steps rely on the earlier ones):
/// 1. Recurse into every term yielded by `children_mut()`.
/// 2. If `self` is a `Mat` or `Fold`, rewrite its arms with `fix_match`.
/// 3. For `Def`, recurse into the rule bodies of the definition and into `nxt`.
///    For `Mat`/`Fold`/`Swt`, wrap arms in a `Use` term that rebinds `bnd` to
///    the value matched by that arm.
/// 4. Set `bnd` to `None` on `Mat`/`Swt`/`Fold`.
///
/// # Panics
/// * Under the conditions listed on `fix_match`.
/// * On a `Swt` with no arms when overflow checks are enabled (`arms.len() - 1`).
/// * On a `Swt` with at least one arm whose `pred` is `None`.
fn fix_match_terms(&mut self, ctrs: &Constructors, adts: &Adts) -> Vec<FixMatchErr> {
// NOTE(review): `maybe_grow` presumably guards the recursion against stack
// overflow on deeply nested terms — confirm against its definition.
maybe_grow(|| {
let mut errs = Vec::new();
// Children are normalized first, so arm bodies are already in final form
// when this term's own arms are rewritten below.
for child in self.children_mut() {
let mut e = child.fix_match_terms(ctrs, adts);
errs.append(&mut e);
}
// Resolve the arms of `match`/`fold` against the ADT constructors.
// NOTE(review): `fix_match` may replace `self` with the body of an
// irrefutable arm; the `match self` below then runs on that replacement
// term, which was already processed as a child — confirm this is intended.
if matches!(self, Term::Mat { .. } | Term::Fold { .. }) {
self.fix_match(&mut errs, ctrs, adts);
}
match self {
// Rule bodies of a local definition are visited here explicitly.
// NOTE(review): `nxt` is also visited explicitly; if `children_mut()`
// already yields `nxt` for `Def`, it is processed twice — verify.
Term::Def { def, nxt } => {
for rule in def.rules.iter_mut() {
errs.extend(rule.body.fix_match_terms(ctrs, adts));
}
errs.extend(nxt.fix_match_terms(ctrs, adts));
}
// For every arm that names a constructor, prepend a `Use` that rebinds
// `bnd` to that constructor applied to the arm's (non-`None`) field variables.
Term::Mat { arg: _, bnd, with_bnd: _, with_arg: _, arms }
| Term::Fold { bnd, arg: _, with_bnd: _, with_arg: _, arms } => {
for (ctr, fields, body) in arms {
if let Some(ctr) = ctr {
*body = Term::Use {
nam: bnd.clone(),
val: Box::new(Term::call(
Term::Ref { nam: ctr.clone() },
fields.iter().flatten().cloned().map(|nam| Term::Var { nam }),
)),
nxt: Box::new(std::mem::take(body)),
};
}
}
}
// For `switch`, arm `i` rebinds `bnd` to the number `i`, except the last
// arm, which rebinds it to `Term::add_num(pred, i)`.
Term::Swt { arg: _, bnd, with_bnd: _, with_arg: _, pred, arms } => {
// Index of the last arm; assumes `arms` is non-empty.
let n_nums = arms.len() - 1;
for (i, arm) in arms.iter_mut().enumerate() {
let orig = if i == n_nums {
Term::add_num(Term::Var { nam: pred.clone().unwrap() }, Num::U24(i as u32))
} else {
Term::Num { val: Num::U24(i as u32) }
};
*arm = Term::Use { nam: bnd.clone(), val: Box::new(orig), nxt: Box::new(std::mem::take(arm)) };
}
}
_ => {}
}
// The arms now rebind the matched value through `Use`, so the bound name
// is removed from the term itself.
match self {
Term::Mat { bnd, .. } | Term::Swt { bnd, .. } | Term::Fold { bnd, .. } => *bnd = None,
_ => {}
}
errs
})
}
/// Rewrites the arms of a `Mat`/`Fold` term, pushing any problems onto `errs`.
///
/// * If the first arm's pattern is a constructor known to `ctrs`, the arms are
///   replaced by exactly one arm per constructor of that ADT (in the iteration
///   order of `adts[adt].ctrs`), with field binds named by `match_field`.
///   A constructor left without a body gets `Term::Err` and a
///   `NonExhaustiveMatch` error.
/// * Otherwise an `IrrefutableMatch` is pushed and `self` is replaced by the
///   body of the first arm, passed through `Term::rfold_lams` with `with_bnd`
///   and `Term::call` with `with_arg`. If the first pattern has a name `var`,
///   the result is wrapped in `use bnd = arg; use var = bnd; ...`; if it has no
///   name, `arg` is dropped.
///
/// # Panics
/// Panics if `self` is not `Mat`/`Fold`, if `bnd` is `None`, or if `arms` is empty.
fn fix_match(&mut self, errs: &mut Vec<FixMatchErr>, ctrs: &Constructors, adts: &Adts) {
let (Term::Mat { bnd, arg, with_bnd, with_arg, arms }
| Term::Fold { bnd, arg, with_bnd, with_arg, arms }) = self
else {
unreachable!()
};
let bnd = bnd.clone().unwrap();
// The ADT being matched is decided by the first arm alone.
if let Some(ctr_nam) = &arms[0].0 {
if let Some(adt_nam) = ctrs.get(ctr_nam) {
let adt_ctrs = &adts[adt_nam].ctrs;
// Decide which body handles each constructor of the ADT.
let mut bodies = fixed_match_arms(&bnd, arms, adt_nam, adt_ctrs.keys(), ctrs, adts, errs);
// Rebuild the arm list with one arm per constructor.
let mut new_rules = vec![];
for (ctr_nam, ctr) in adt_ctrs.iter() {
let fields = ctr.fields.iter().map(|f| Some(match_field(&bnd, &f.nam))).collect::<Vec<_>>();
let body = if let Some(Some(body)) = bodies.remove(ctr_nam) {
body
} else {
errs.push(FixMatchErr::NonExhaustiveMatch { typ: adt_nam.clone(), missing: ctr_nam.clone() });
Term::Err
};
new_rules.push((Some(ctr_nam.clone()), fields, body));
}
*arms = new_rules;
return;
}
}
// The first arm does not name a known constructor: the match is irrefutable
// and is replaced by that arm's body.
errs.push(FixMatchErr::IrrefutableMatch { var: arms[0].0.clone() });
// Move everything still needed out of the term before `self` is overwritten.
let match_var = arms[0].0.take();
let arg = std::mem::take(arg);
let with_bnd = std::mem::take(with_bnd);
let with_arg = std::mem::take(with_arg);
*self = std::mem::take(&mut arms[0].2);
// Apply the `with` clause: `rfold_lams` over `with_bnd`, then `call` with `with_arg`.
*self = Term::rfold_lams(std::mem::take(self), with_bnd.into_iter());
*self = Term::call(std::mem::take(self), with_arg);
// Bind the matched value to `bnd`, then alias the pattern variable to it.
if let Some(var) = match_var {
*self = Term::Use {
nam: Some(bnd.clone()),
val: arg,
nxt: Box::new(Term::Use {
nam: Some(var),
val: Box::new(Term::Var { nam: bnd }),
nxt: Box::new(std::mem::take(self)),
}),
}
}
}
}
/// Decides which body handles each constructor of the matched ADT.
///
/// The returned map has one entry per name yielded by `adt_ctrs`; the value is
/// `None` when no rule covers that constructor (a non-exhaustive match).
///
/// Rules are scanned in order:
/// * A rule naming a constructor of `adt_nam` supplies that constructor's body.
///   If the constructor already has one, a `RedundantArm` is pushed instead.
/// * A rule naming a constructor of a different ADT pushes an `AdtMismatch`.
/// * The first rule whose pattern is not a known constructor supplies the body
///   for every constructor still uncovered. When the pattern has a name, the
///   body is prefixed with a `Use` binding that name to the rebuilt constructor
///   (see `rebuild_ctr`). Scanning stops there; if more rules follow, an
///   `UnreachableMatchArms` is pushed and `rules` is truncated after this rule.
///
/// # Panics
/// Panics if a rule names a constructor that `ctrs` maps to `adt_nam` but that
/// `adt_ctrs` did not yield.
fn fixed_match_arms<'a>(
    bnd: &Name,
    rules: &mut Vec<MatchRule>,
    adt_nam: &Name,
    adt_ctrs: impl Iterator<Item = &'a Name>,
    ctrs: &Constructors,
    adts: &Adts,
    errs: &mut Vec<FixMatchErr>,
) -> HashMap<&'a Name, Option<Term>> {
    let mut bodies: HashMap<&'a Name, Option<Term>> = adt_ctrs.map(|ctr| (ctr, None)).collect();
    let total_rules = rules.len();

    // Consume constructor arms until the first arm that is not a known constructor.
    let mut catch_all = None;
    for (idx, (pat, _, arm_body)) in rules.iter().enumerate() {
        let resolved = pat.as_ref().and_then(|nam| Some((nam, ctrs.get(nam)?)));
        let Some((ctr_nam, found_adt)) = resolved else {
            catch_all = Some(idx);
            break;
        };
        if found_adt != adt_nam {
            errs.push(FixMatchErr::AdtMismatch {
                expected: adt_nam.clone(),
                found: found_adt.clone(),
                ctr: ctr_nam.clone(),
            });
            continue;
        }
        let slot = bodies.get_mut(ctr_nam).unwrap();
        if slot.is_some() {
            errs.push(FixMatchErr::RedundantArm { ctr: ctr_nam.clone() });
        } else {
            *slot = Some(arm_body.clone());
        }
    }

    // A variable/wildcard arm covers every constructor that has no body yet.
    if let Some(idx) = catch_all {
        let (pat, _, arm_body) = &rules[idx];
        for (ctr, slot) in bodies.iter_mut().filter(|(_, slot)| slot.is_none()) {
            let plain = arm_body.clone();
            let filled = match pat {
                Some(var) => Term::Use {
                    nam: Some(var.clone()),
                    val: Box::new(rebuild_ctr(bnd, ctr, &adts[adt_nam].ctrs[&**ctr].fields)),
                    nxt: Box::new(plain),
                },
                None => plain,
            };
            *slot = Some(filled);
        }
        // Anything after the catch-all arm can never be reached.
        if idx + 1 != total_rules {
            errs.push(FixMatchErr::UnreachableMatchArms { var: pat.clone() });
            rules.truncate(idx + 1);
        }
    }

    bodies
}
/// Builds the name used for field `field` of the matched value `arg`,
/// formatted as `arg`, a dot, then `field`.
fn match_field(arg: &Name, field: &Name) -> Name {
    let dotted = format!("{arg}.{field}");
    Name::new(dotted)
}
/// Builds the term that re-creates constructor `ctr` from the field variables
/// of the matched value `arg`: a `Term::call` of a reference to `ctr` on one
/// `Term::Var` per entry of `fields`, each named by `match_field`.
fn rebuild_ctr(arg: &Name, ctr: &Name, fields: &[CtrField]) -> Term {
    let field_vars = fields.iter().map(|field| {
        let nam = match_field(arg, &field.nam);
        Term::Var { nam }
    });
    Term::call(Term::Ref { nam: ctr.clone() }, field_vars)
}
impl std::fmt::Display for FixMatchErr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FixMatchErr::AdtMismatch { expected, found, ctr } => write!(
f,
"Type mismatch in 'match' expression: Expected a constructor of type '{expected}', found '{ctr}' of type '{found}'"
),
FixMatchErr::NonExhaustiveMatch { typ, missing } => {
write!(f, "Non-exhaustive 'match' expression of type '{typ}'. Case '{missing}' not covered.")
}
FixMatchErr::IrrefutableMatch { var } => {
writeln!(
f,
"Irrefutable 'match' expression. All cases after variable pattern '{}' will be ignored.",
var.as_ref().unwrap_or(&Name::new("*")),
)?;
writeln!(
f,
"{:ERR_INDENT_SIZE$}Note that to use a 'match' expression, the matched constructors need to be defined in a 'data' definition.",
"",
)?;
write!(
f,
"{:ERR_INDENT_SIZE$}If this is not a mistake, consider using a 'let' expression instead.",
""
)
}
FixMatchErr::UnreachableMatchArms { var } => write!(
f,
"Unreachable arms in 'match' expression. All cases after '{}' will be ignored.",
var.as_ref().unwrap_or(&Name::new("*"))
),
FixMatchErr::RedundantArm { ctr } => {
write!(f, "Redundant arm in 'match' expression. Case '{ctr}' appears more than once.")
}
}
}
}