use crate::{
diagnostics::{Diagnostics, WarningType},
fun::{builtins, Adts, Constructors, Ctx, Definition, FanKind, Name, Num, Pattern, Rule, Tag, Term},
maybe_grow,
};
use itertools::Itertools;
use std::collections::{BTreeSet, HashSet};
/// A problem found while compiling a definition's pattern-matching rules.
///
/// `Ctx::desugar_match_defs` reports `AdtNotExhaustive`, `NumMissingDefault` and
/// `TypeMismatch` as function errors, and `RepeatedBind` / `UnreachableRule` as warnings.
#[derive(Debug, Clone)]
pub enum DesugarMatchDefErr {
  /// No rule covers constructor `ctr` of the ADT `adt`.
  AdtNotExhaustive { adt: Name, ctr: Name },
  /// A column of number patterns has no wildcard (default) row.
  NumMissingDefault,
  /// A pattern's type disagrees with the type inferred from earlier patterns in its column.
  TypeMismatch { expected: Type, found: Type, pat: Pattern },
  /// `bind` is bound more than once in the same rule; the extra occurrences were erased.
  RepeatedBind { bind: Name },
  /// Rule number `idx` (0-based) of definition `nam`, with patterns `pats`, is never selected.
  UnreachableRule { idx: usize, nam: Name, pats: Vec<Pattern> },
}
impl Ctx<'_> {
  /// Desugars the pattern-matching rules of every definition in the book and records
  /// every problem found in `self.info`.
  ///
  /// `RepeatedBind` and `UnreachableRule` are recorded as warnings (with the matching
  /// `WarningType`); all other problems are recorded as function errors. The result is
  /// whatever `self.info.fatal(())` yields (presumably `Err` once an error was recorded).
  pub fn desugar_match_defs(&mut self) -> Result<(), Diagnostics> {
    for (def_name, def) in self.book.defs.iter_mut() {
      for err in def.desugar_match_def(&self.book.ctrs, &self.book.adts) {
        // `None` means the problem is a hard error rather than a warning.
        let warning_kind = match &err {
          DesugarMatchDefErr::RepeatedBind { .. } => Some(WarningType::RepeatedBind),
          DesugarMatchDefErr::UnreachableRule { .. } => Some(WarningType::UnreachableMatch),
          DesugarMatchDefErr::AdtNotExhaustive { .. }
          | DesugarMatchDefErr::NumMissingDefault
          | DesugarMatchDefErr::TypeMismatch { .. } => None,
        };

        if let Some(kind) = warning_kind {
          self.info.add_function_warning(err, kind, def_name.clone(), def.source.clone());
        } else {
          self.info.add_function_error(err, def_name.clone(), def.source.clone());
        }
      }
    }
    self.info.fatal(())
  }
}
impl Definition {
pub fn desugar_match_def(&mut self, ctrs: &Constructors, adts: &Adts) -> Vec<DesugarMatchDefErr> {
let mut errs = vec![];
for rule in self.rules.iter_mut() {
desugar_inner_match_defs(&mut rule.body, ctrs, adts, &mut errs);
}
let repeated_bind_errs = fix_repeated_binds(&mut self.rules);
errs.extend(repeated_bind_errs);
let args = (0..self.arity()).map(|i| Name::new(format!("%arg{i}"))).collect::<Vec<_>>();
let rules = std::mem::take(&mut self.rules);
let idx = (0..rules.len()).collect::<Vec<_>>();
let mut used = BTreeSet::new();
match simplify_rule_match(args.clone(), rules.clone(), idx.clone(), vec![], &mut used, ctrs, adts) {
Ok(body) => {
let body = Term::rfold_lams(body, args.into_iter().map(Some));
self.rules = vec![Rule { pats: vec![], body }];
for i in idx {
if !used.contains(&i) {
let e = DesugarMatchDefErr::UnreachableRule {
idx: i,
nam: self.name.clone(),
pats: rules[i].pats.clone(),
};
errs.push(e);
}
}
}
Err(e) => errs.push(e),
}
errs
}
}
/// Walks `term` and desugars the rules of every definition nested in it through
/// `Term::Def`, appending the problems found to `errs`.
fn desugar_inner_match_defs(
  term: &mut Term,
  ctrs: &Constructors,
  adts: &Adts,
  errs: &mut Vec<DesugarMatchDefErr>,
) {
  maybe_grow(|| {
    if let Term::Def { def, nxt } = term {
      // `desugar_match_def` itself descends into the nested definition's rule bodies.
      errs.extend(def.desugar_match_def(ctrs, adts));
      desugar_inner_match_defs(nxt, ctrs, adts, errs);
    } else {
      term.children_mut().for_each(|child| desugar_inner_match_defs(child, ctrs, adts, errs));
    }
  })
}
fn fix_repeated_binds(rules: &mut [Rule]) -> Vec<DesugarMatchDefErr> {
let mut errs = vec![];
for rule in rules {
let mut binds = HashSet::new();
rule.pats.iter_mut().flat_map(|p| p.binds_mut()).rev().for_each(|nam| {
if binds.contains(nam) {
if let Some(nam) = nam {
errs.push(DesugarMatchDefErr::RepeatedBind { bind: nam.clone() });
}
*nam = None;
} else {
binds.insert(&*nam);
}
});
}
errs
}
/// Compiles a pattern-matching problem into a term, one column (argument) at a time.
///
/// Parameters:
/// - `args` names the values still to be matched, one per remaining column of `rules`.
/// - `idx[k]` is the original index of `rules[k]`; the index of every rule that gets
///   selected is inserted into `used`.
/// - `with` lists variables that the nested `match`/`switch` terms forward to their arms.
///
/// Dispatch:
/// - The first remaining rule is selected outright when no columns are left, or when all
///   of its patterns are wildcards.
/// - Otherwise the inferred type of the first column picks the splitting strategy.
///
/// # Errors
/// `TypeMismatch` when the first column mixes incompatible patterns, plus any error
/// raised while splitting.
fn simplify_rule_match(
  args: Vec<Name>,
  rules: Vec<Rule>,
  idx: Vec<usize>,
  with: Vec<Name>,
  used: &mut BTreeSet<usize>,
  ctrs: &Constructors,
  adts: &Adts,
) -> Result<Term, DesugarMatchDefErr> {
  // Nothing left to match on: the topmost remaining rule wins.
  if args.is_empty() {
    used.insert(idx[0]);
    let first = rules.into_iter().next().unwrap();
    return Ok(first.body);
  }

  // A first row of only wildcards matches everything, shadowing every row below it.
  if rules[0].pats.iter().all(|pat| pat.is_wildcard()) {
    let first = rules.into_iter().next().unwrap();
    return Ok(irrefutable_fst_row_rule(args, first, idx[0], used));
  }

  match Type::infer_from_def_arg(&rules, 0, ctrs)? {
    Type::Any => var_rule(args, rules, idx, with, used, ctrs, adts),
    Type::Num => num_rule(args, rules, idx, with, used, ctrs, adts),
    Type::Fan(fan, tag, len) => fan_rule(args, rules, idx, with, used, fan, tag, len, ctrs, adts),
    Type::Adt(adt_name) => switch_rule(args, rules, idx, with, adt_name, used, ctrs, adts),
  }
}
fn irrefutable_fst_row_rule(args: Vec<Name>, rule: Rule, idx: usize, used: &mut BTreeSet<usize>) -> Term {
let mut term = rule.body;
for (arg, pat) in args.into_iter().zip(rule.pats.into_iter()) {
match pat {
Pattern::Var(None) => {}
Pattern::Var(Some(var)) => {
term = Term::Use { nam: Some(var), val: Box::new(Term::Var { nam: arg }), nxt: Box::new(term) };
}
Pattern::Chn(var) => {
term = Term::Let {
pat: Box::new(Pattern::Chn(var)),
val: Box::new(Term::Var { nam: arg }),
nxt: Box::new(term),
};
}
_ => unreachable!(),
}
}
used.insert(idx);
term
}
/// Handles a first column that contains only variable patterns.
///
/// Each named variable is bound to the first argument with `use`, and the column is
/// dropped. The argument is then pushed onto `with` before the remaining columns are
/// compiled.
fn var_rule(
  mut args: Vec<Name>,
  rules: Vec<Rule>,
  idx: Vec<usize>,
  mut with: Vec<Name>,
  used: &mut BTreeSet<usize>,
  ctrs: &Constructors,
  adts: &Adts,
) -> Result<Term, DesugarMatchDefErr> {
  let arg = args[0].clone();
  let rest_args = args.split_off(1);

  let new_rules: Vec<Rule> = rules
    .into_iter()
    .map(|Rule { mut pats, body }| {
      let rest_pats = pats.split_off(1);
      let body = match pats.pop() {
        Some(Pattern::Var(Some(nam))) => Term::Use {
          nam: Some(nam),
          val: Box::new(Term::Var { nam: arg.clone() }),
          nxt: Box::new(body),
        },
        _ => body,
      };
      Rule { pats: rest_pats, body }
    })
    .collect();

  with.push(arg);
  simplify_rule_match(rest_args, new_rules, idx, with, used, ctrs, adts)
}
/// Handles a first column of fan (tuple/dup) patterns with the given kind, tag and arity.
///
/// The first argument is destructured once with a `let` into fresh names `{arg}.0`,
/// `{arg}.1`, …, which replace it as new columns:
/// - fan rows splice their sub-patterns into those columns;
/// - variable rows get one variable per element, and a named variable is rebound with
///   `use` to the re-built fan.
#[allow(clippy::too_many_arguments)]
fn fan_rule(
  mut args: Vec<Name>,
  rules: Vec<Rule>,
  idx: Vec<usize>,
  with: Vec<Name>,
  used: &mut BTreeSet<usize>,
  fan: FanKind,
  tag: Tag,
  len: usize,
  ctrs: &Constructors,
  adts: &Adts,
) -> Result<Term, DesugarMatchDefErr> {
  let arg = args[0].clone();
  let rest_args = args.split_off(1);
  let elems: Vec<Name> = (0..len).map(|i| Name::new(format!("{arg}.{i}"))).collect();
  let elem_pats = || elems.iter().cloned().map(|nam| Pattern::Var(Some(nam)));

  let mut new_rules = Vec::with_capacity(rules.len());
  for Rule { mut pats, body } in rules {
    let rest_pats = pats.split_off(1);
    let head = pats.pop().unwrap();
    let (mut row_pats, body) = match head {
      Pattern::Fan(.., sub_pats) => (sub_pats, body),
      Pattern::Var(None) => (elem_pats().collect(), body),
      Pattern::Var(Some(var)) => {
        let rebuilt =
          Term::Fan { fan, tag: tag.clone(), els: elems.iter().cloned().map(|nam| Term::Var { nam }).collect() };
        let body = Term::Use { nam: Some(var), val: Box::new(rebuilt), nxt: Box::new(body) };
        (elem_pats().collect(), body)
      }
      _ => unreachable!(),
    };
    row_pats.extend(rest_pats);
    new_rules.push(Rule { pats: row_pats, body });
  }

  let bnd = elem_pats().collect();
  let sub_args: Vec<Name> = elems.into_iter().chain(rest_args).collect();
  let nxt = simplify_rule_match(sub_args, new_rules, idx, with, used, ctrs, adts)?;
  Ok(Term::Let {
    pat: Box::new(Pattern::Fan(fan, tag, bnd)),
    val: Box::new(Term::Var { nam: arg }),
    nxt: Box::new(nxt),
  })
}
/// Handles a first column of number-literal and variable patterns.
///
/// Builds a chain of nested `switch` terms, one per distinct literal in the column (in
/// ascending order); the innermost `_` arm is the default case:
/// - In the arm for literal `n`, the kept rows are those whose first pattern is `n` or a
///   variable. A named variable is bound to the literal `n` with `use`.
/// - The default arm keeps only the variable rows.
///
/// Each arm is compiled recursively on the remaining columns.
///
/// Rows whose first pattern is anything else are skipped.
/// NOTE(review): a `Chn` row satisfies the `is_wildcard` default check below but is not a
/// `Pattern::Var`, so it would be dropped from every arm. Confirm that channel patterns
/// cannot reach this function.
///
/// # Errors
/// `NumMissingDefault` if no row has a wildcard first pattern; errors from the recursive
/// calls are propagated.
fn num_rule(
mut args: Vec<Name>,
rules: Vec<Rule>,
idx: Vec<usize>,
with: Vec<Name>,
used: &mut BTreeSet<usize>,
ctrs: &Constructors,
adts: &Adts,
) -> Result<Term, DesugarMatchDefErr> {
// The switch chain always ends in a `_` arm, so some row must be a catch-all here.
if !rules.iter().any(|r| r.pats[0].is_wildcard()) {
return Err(DesugarMatchDefErr::NumMissingDefault);
}
let arg = args[0].clone();
let args = args.split_off(1);
// Name given to the predecessor bound in each `_` arm, e.g. `%arg0-1`.
let pred_var = Name::new(format!("{arg}-1"));
// Distinct literals of the column, sorted ascending by going through a BTreeSet.
let nums = rules
.iter()
.filter_map(|r| if let Pattern::Num(n) = r.pats[0] { Some(n) } else { None })
.collect::<BTreeSet<_>>()
.into_iter()
.collect::<Vec<_>>();
// One compiled arm body per literal, in the same order as `nums`.
let mut num_bodies = vec![];
for num in nums.iter() {
let mut new_rules = vec![];
let mut new_idx = vec![];
for (rule, &idx) in rules.iter().zip(&idx) {
match &rule.pats[0] {
// Exactly this literal: just drop the column.
Pattern::Num(n) if n == num => {
let body = rule.body.clone();
let rule = Rule { pats: rule.pats[1..].to_vec(), body };
new_rules.push(rule);
new_idx.push(idx);
}
// A variable also matches this literal; inside this arm its value is the literal itself.
Pattern::Var(var) => {
let mut body = rule.body.clone();
if let Some(var) = var {
body = Term::Use {
nam: Some(var.clone()),
val: Box::new(Term::Num { val: Num::U24(*num) }),
nxt: Box::new(std::mem::take(&mut body)),
};
}
let rule = Rule { pats: rule.pats[1..].to_vec(), body };
new_rules.push(rule);
new_idx.push(idx);
}
_ => (),
}
}
let body = simplify_rule_match(args.clone(), new_rules, new_idx, with.clone(), used, ctrs, adts)?;
num_bodies.push(body);
}
// Default arm: only the variable rows reach it.
let mut new_rules = vec![];
let mut new_idx = vec![];
for (rule, &idx) in rules.into_iter().zip(&idx) {
if let Pattern::Var(var) = &rule.pats[0] {
let mut body = rule.body.clone();
if let Some(var) = var {
// In the innermost `_` arm, `pred_var` is expected to hold `arg - (last_num + 1)`.
// This assumes a switch's `_` arm binds `pred` to the scrutinee minus one, and the
// chain built below relies on the same assumption. Adding `cur_num` back recovers
// the matched value.
// NOTE(review): if `last_num` is the largest 24-bit value, `cur_num` does not fit
// in a `Num::U24`. Confirm whether that case is possible or handled elsewhere.
let last_num = *nums.last().unwrap();
let cur_num = 1 + last_num;
let var_recovered = Term::add_num(Term::Var { nam: pred_var.clone() }, Num::U24(cur_num));
body = Term::Use { nam: Some(var.clone()), val: Box::new(var_recovered), nxt: Box::new(body) };
// Rewrite `var - cur_num` subterms of the body to read `pred_var` directly.
fast_pred_access(&mut body, cur_num, var, &pred_var);
}
let rule = Rule { pats: rule.pats[1..].to_vec(), body };
new_rules.push(rule);
new_idx.push(idx);
}
}
// Nested matches inside the default arm must also forward `pred_var`.
let mut default_with = with.clone();
default_with.push(pred_var.clone());
let default_body = simplify_rule_match(args.clone(), new_rules, new_idx, default_with, used, ctrs, adts)?;
// Every switch in the chain forwards the outer `with` names plus the remaining arguments.
let with = with.into_iter().chain(args).collect::<Vec<_>>();
let with_bnd = with.iter().cloned().map(Some).collect::<Vec<_>>();
let with_arg = with.iter().cloned().map(|nam| Term::Var { nam }).collect::<Vec<_>>();
// Build the chain from the innermost (default) switch outwards.
// - The outermost switch tests `arg - nums[0]`.
// - Switch `i > 0` sits in the `_` arm of switch `i-1`, where `pred_var = arg - nums[i-1] - 1`.
//   So it tests `pred_var - (nums[i] - 1 - nums[i-1])`, i.e. `arg - nums[i]`.
// `nums` is strictly ascending, so that u32 subtraction cannot underflow.
let term = num_bodies.into_iter().enumerate().rfold(default_body, |term, (i, body)| {
let val = if i > 0 {
Term::sub_num(Term::Var { nam: pred_var.clone() }, Num::U24(nums[i] - 1 - nums[i - 1]))
} else {
Term::sub_num(Term::Var { nam: arg.clone() }, Num::U24(nums[i]))
};
Term::Swt {
arg: Box::new(val),
bnd: Some(arg.clone()),
with_bnd: with_bnd.clone(),
with_arg: with_arg.clone(),
pred: Some(pred_var.clone()),
arms: vec![body, term],
}
});
Ok(term)
}
/// Replaces every subterm of `body` of the form `var - cur_num` with a reference to
/// `pred_var`. The match is an `Op::SUB` whose left side is the variable `var` and whose
/// right side is the `U24` literal `cur_num`. The traversal then recurses into all
/// children.
///
/// `num_rule` calls this on default-arm bodies so that such subterms read the switch's
/// predecessor binding directly.
///
/// NOTE(review): the traversal does not track binders. A `var` re-bound (shadowed) inside
/// `body` would be rewritten as well. Confirm whether that can occur and is acceptable.
fn fast_pred_access(body: &mut Term, cur_num: u32, var: &Name, pred_var: &Name) {
  maybe_grow(|| {
    let is_var_minus_cur = matches!(
      &*body,
      Term::Oper { opr: crate::fun::Op::SUB, fst, snd }
        if matches!(&**fst, Term::Var { nam } if nam == var)
          && matches!(&**snd, Term::Num { val: crate::fun::Num::U24(n) } if *n == cur_num)
    );
    if is_var_minus_cur {
      *body = Term::Var { nam: pred_var.clone() };
    }
    body.children_mut().for_each(|child| fast_pred_access(child, cur_num, var, pred_var));
  })
}
/// Handles a first column of constructor and variable patterns of the ADT `adt_name`.
///
/// Produces a `match` on the first argument with one arm per constructor of the ADT, in
/// the ADT's iteration order. For each constructor, the kept rows are:
/// - rows whose first pattern is that constructor, with its sub-patterns spliced in as
///   new columns;
/// - variable rows, expanded to one fresh variable `{arg}.{field}` per field. A named
///   variable is bound with `use` to the constructor re-applied to those fields.
///
/// The remaining arguments are forwarded through the `match`'s `with` bindings.
///
/// # Errors
/// `AdtNotExhaustive` for the first constructor with no matching row. Errors from the
/// recursive compilation are propagated.
#[allow(clippy::too_many_arguments)]
fn switch_rule(
  mut args: Vec<Name>,
  rules: Vec<Rule>,
  idx: Vec<usize>,
  with: Vec<Name>,
  adt_name: Name,
  used: &mut BTreeSet<usize>,
  ctrs: &Constructors,
  adts: &Adts,
) -> Result<Term, DesugarMatchDefErr> {
  let scrutinee = args[0].clone();
  let rest_args = args.split_off(1);

  let mut arms = vec![];
  for (ctr_nam, ctr) in &adts[&adt_name].ctrs {
    let fields: Vec<Name> =
      ctr.fields.iter().map(|field| Name::new(format!("{}.{}", scrutinee, field.nam))).collect();

    let mut arm_rules = vec![];
    let mut arm_idx = vec![];
    for (rule, &row) in rules.iter().zip(&idx) {
      let rest_pats = &rule.pats[1..];
      let specialized = match &rule.pats[0] {
        Pattern::Ctr(found_ctr, sub_pats) if found_ctr == ctr_nam => {
          let pats = sub_pats.iter().chain(rest_pats).cloned().collect();
          Some(Rule { pats, body: rule.body.clone() })
        }
        Pattern::Var(var) => {
          let pats = fields
            .iter()
            .cloned()
            .map(|nam| Pattern::Var(Some(nam)))
            .chain(rest_pats.iter().cloned())
            .collect();
          let body = match var {
            Some(var) => {
              let rebuilt =
                Term::call(Term::Ref { nam: ctr_nam.clone() }, fields.iter().cloned().map(|nam| Term::Var { nam }));
              Term::Use { nam: Some(var.clone()), val: Box::new(rebuilt), nxt: Box::new(rule.body.clone()) }
            }
            None => rule.body.clone(),
          };
          Some(Rule { pats, body })
        }
        _ => None,
      };
      if let Some(specialized) = specialized {
        arm_rules.push(specialized);
        arm_idx.push(row);
      }
    }

    if arm_rules.is_empty() {
      return Err(DesugarMatchDefErr::AdtNotExhaustive { adt: adt_name, ctr: ctr_nam.clone() });
    }

    let arm_args: Vec<Name> = fields.iter().cloned().chain(rest_args.iter().cloned()).collect();
    let body = simplify_rule_match(arm_args, arm_rules, arm_idx, with.clone(), used, ctrs, adts)?;
    arms.push((Some(ctr_nam.clone()), fields.into_iter().map(Some).collect(), body));
  }

  let with: Vec<Name> = with.into_iter().chain(rest_args).collect();
  let with_bnd = with.iter().cloned().map(Some).collect::<Vec<_>>();
  let with_arg = with.into_iter().map(|nam| Term::Var { nam }).collect::<Vec<_>>();
  Ok(Term::Mat {
    arg: Box::new(Term::Var { nam: scrutinee.clone() }),
    bnd: Some(scrutinee),
    with_bnd,
    with_arg,
    arms,
  })
}
/// The kind of value a column of patterns is matched against, as inferred by
/// `Type::infer_from_def_arg` from the patterns in that column.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
  /// Only variable/channel patterns so far: the column accepts anything.
  Any,
  /// Fan patterns (tuple or dup) with this kind, tag and number of elements.
  Fan(FanKind, Tag, usize),
  /// Number-literal patterns.
  Num,
  /// Constructor patterns (or list/string patterns) of the ADT with this name.
  Adt(Name),
}
impl Type {
  /// Infers the type of column `arg_idx` of `rules` from its patterns, top to bottom.
  ///
  /// `Type::Any` patterns never constrain the result; the first non-`Any` pattern fixes
  /// the type, and every later non-`Any` pattern must have that same type.
  ///
  /// # Errors
  /// `TypeMismatch` carrying the first pattern whose type disagrees.
  fn infer_from_def_arg(
    rules: &[Rule],
    arg_idx: usize,
    ctrs: &Constructors,
  ) -> Result<Type, DesugarMatchDefErr> {
    rules.iter().map(|rule| &rule.pats[arg_idx]).try_fold(Type::Any, |known, pat| {
      match (known, pat.to_type(ctrs)) {
        (Type::Any, found) => Ok(found),
        (known, Type::Any) => Ok(known),
        (known, found) if known == found => Ok(known),
        (expected, found) => Err(DesugarMatchDefErr::TypeMismatch { expected, found, pat: pat.clone() }),
      }
    })
  }
}
impl Pattern {
  /// Returns the [`Type`] of value this pattern matches.
  ///
  /// List and string patterns map to the builtin list/string ADTs, and variables and
  /// channels match anything.
  ///
  /// # Panics
  /// If a constructor pattern names a constructor that is not in `ctrs`.
  fn to_type(&self, ctrs: &Constructors) -> Type {
    match self {
      Pattern::Num(_) => Type::Num,
      Pattern::Lst(..) => Type::Adt(Name::new(builtins::LIST)),
      Pattern::Str(..) => Type::Adt(Name::new(builtins::STRING)),
      Pattern::Fan(kind, tag, els) => Type::Fan(*kind, tag.clone(), els.len()),
      Pattern::Ctr(ctr_nam, _) => match ctrs.get(ctr_nam) {
        Some(adt_nam) => Type::Adt(adt_nam.clone()),
        None => panic!("Unknown constructor '{ctr_nam}'"),
      },
      Pattern::Var(_) | Pattern::Chn(_) => Type::Any,
    }
  }
}
impl std::fmt::Display for Type {
  /// Renders the type for diagnostics: `any`, `number`, the ADT name, or
  /// `<padded tag><arity>-tuple` / `<padded tag><arity>-dup` for fans.
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    match self {
      Type::Fan(kind, tag, arity) => {
        let suffix = match kind {
          FanKind::Tup => "tuple",
          FanKind::Dup => "dup",
        };
        write!(f, "{}{arity}-{suffix}", tag.display_padded())
      }
      Type::Adt(nam) => write!(f, "{nam}"),
      Type::Num => f.write_str("number"),
      Type::Any => f.write_str("any"),
    }
  }
}
impl std::fmt::Display for DesugarMatchDefErr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DesugarMatchDefErr::AdtNotExhaustive { adt, ctr } => {
write!(f, "Non-exhaustive pattern matching rule. Constructor '{ctr}' of type '{adt}' not covered")
}
DesugarMatchDefErr::TypeMismatch { expected, found, pat } => {
write!(
f,
"Type mismatch in pattern matching rule. Expected a constructor of type '{}', found '{}' with type '{}'.",
expected, pat, found
)
}
DesugarMatchDefErr::NumMissingDefault => {
write!(f, "Non-exhaustive pattern matching rule. Default case of number type not covered.")
}
DesugarMatchDefErr::RepeatedBind { bind } => {
write!(f, "Repeated bind in pattern matching rule: '{bind}'.")
}
DesugarMatchDefErr::UnreachableRule { idx, nam, pats } => {
write!(
f,
"Unreachable pattern matching rule '({}{})' (rule index {idx}).",
nam,
pats.iter().map(|p| format!(" {p}")).join("")
)
}
}
}
}