use crate::{
fun::{Book, Definition, Name, Pattern, Rule, Term},
maybe_grow, multi_iterator,
};
use std::collections::{BTreeMap, HashSet};
impl Book {
pub fn float_combinators(&mut self, max_size: usize) {
let book = self.clone();
let mut ctx = FloatCombinatorsCtx::new(&book, max_size);
for (def_name, def) in self.defs.iter_mut() {
let builtin = def.builtin;
let body = &mut def.rule_mut().body;
ctx.reset();
ctx.def_size = body.size();
body.float_combinators(&mut ctx, def_name, builtin);
}
self.defs.extend(ctx.combinators.into_iter().map(|(nam, (_, def))| (nam, def)));
}
}
struct FloatCombinatorsCtx<'b> {
pub combinators: BTreeMap<Name, (bool, Definition)>,
pub name_gen: usize,
pub seen: HashSet<Name>,
pub book: &'b Book,
pub max_size: usize,
pub def_size: usize,
}
impl<'b> FloatCombinatorsCtx<'b> {
fn new(book: &'b Book, max_size: usize) -> Self {
Self {
combinators: Default::default(),
name_gen: 0,
seen: Default::default(),
book,
max_size,
def_size: 0,
}
}
fn reset(&mut self) {
self.def_size = 0;
self.name_gen = 0;
self.seen = Default::default();
}
}
impl Term {
fn float_combinators(&mut self, ctx: &mut FloatCombinatorsCtx, def_name: &Name, builtin: bool) {
maybe_grow(|| {
for child in self.float_children_mut() {
child.float_combinators(ctx, def_name, builtin);
}
let mut size = self.size();
let is_combinator = self.is_combinator();
for child in self.float_children_mut() {
let child_is_safe = child.is_safe(ctx);
let child_size = child.size();
let extract_for_size = if is_combinator { size > ctx.max_size } else { ctx.def_size > ctx.max_size };
if child.is_combinator() && child_size > 0 && (!child_is_safe || extract_for_size) {
ctx.def_size -= child_size;
size -= child_size;
child.float(ctx, def_name, builtin, child_is_safe);
}
}
})
}
fn float(&mut self, ctx: &mut FloatCombinatorsCtx, def_name: &Name, builtin: bool, is_safe: bool) {
let comb_name = Name::new(format!("{}__C{}", def_name, ctx.name_gen));
ctx.name_gen += 1;
let comb_ref = Term::Ref { nam: comb_name.clone() };
let extracted_term = std::mem::replace(self, comb_ref);
let rules = vec![Rule { body: extracted_term, pats: Vec::new() }];
let rule = Definition { name: comb_name.clone(), rules, builtin };
ctx.combinators.insert(comb_name, (is_safe, rule));
}
}
impl Term {
fn is_safe(&self, ctx: &mut FloatCombinatorsCtx) -> bool {
maybe_grow(|| match self {
Term::Num { .. }
| Term::Era
| Term::Err
| Term::Fan { .. }
| Term::App { .. }
| Term::Oper { .. }
| Term::Swt { .. } => self.children().all(|c| c.is_safe(ctx)),
Term::Lam { .. } => self.is_safe_lambda(ctx),
Term::Ref { nam } => {
if ctx.book.ctrs.contains_key(nam) {
return true;
}
if ctx.seen.contains(nam) {
return false;
}
ctx.seen.insert(nam.clone());
let safe = if let Some(def) = ctx.book.defs.get(nam) {
let ref_safe = def.rule().body.is_safe(ctx);
ref_safe
} else if let Some((safe, _)) = ctx.combinators.get(nam) {
*safe
} else {
false
};
ctx.seen.remove(nam);
safe
}
_ => false,
})
}
fn is_safe_lambda(&self, ctx: &mut FloatCombinatorsCtx) -> bool {
let mut current = self;
let mut scope = Vec::new();
while let Term::Lam { bod, .. } = current {
scope.extend(current.pattern().unwrap().binds().filter_map(|x| x.as_ref()));
current = bod;
}
match current {
Term::Var { nam } if scope.contains(&nam) => true,
Term::Ref { .. } => true,
term => term.is_safe(ctx),
}
}
pub fn has_unscoped_diff(&self) -> bool {
let (declared, used) = self.unscoped_vars();
declared.difference(&used).count() != 0 || used.difference(&declared).count() != 0
}
fn is_combinator(&self) -> bool {
self.free_vars().is_empty() && !self.has_unscoped_diff() && !matches!(self, Term::Ref { .. })
}
fn base_size(&self) -> usize {
match self {
Term::Let { pat, .. } => pat.size(),
Term::Fan { els, .. } => els.len() - 1,
Term::Mat { arms, .. } => arms.len(),
Term::Swt { arms, .. } => 2 * (arms.len() - 1),
Term::Lam { .. } => 1,
Term::App { .. } => 1,
Term::Oper { .. } => 1,
Term::Var { .. } => 0,
Term::Link { .. } => 0,
Term::Use { .. } => 0,
Term::Num { .. } => 0,
Term::Ref { .. } => 0,
Term::Era => 0,
Term::Bend { .. }
| Term::Fold { .. }
| Term::Nat { .. }
| Term::Str { .. }
| Term::List { .. }
| Term::Do { .. }
| Term::Ask { .. }
| Term::Open { .. }
| Term::Err => unreachable!(),
}
}
fn size(&self) -> usize {
maybe_grow(|| {
let children_size: usize = self.children().map(|c| c.size()).sum();
self.base_size() + children_size
})
}
pub fn float_children_mut(&mut self) -> impl Iterator<Item = &mut Term> {
multi_iterator!(FloatIter { Zero, Two, Vec, Mat, App, Swt });
match self {
Term::App { .. } => {
let mut next = Some(self);
FloatIter::App(std::iter::from_fn(move || {
let cur = next.take();
if let Some(Term::App { fun, arg, .. }) = cur {
next = Some(&mut *fun);
Some(&mut **arg)
} else {
cur
}
}))
}
Term::Mat { arg, bnd: _, with: _, arms } => {
FloatIter::Mat([arg.as_mut()].into_iter().chain(arms.iter_mut().map(|r| &mut r.2)))
}
Term::Swt { arg, bnd: _, with: _, pred: _, arms } => {
FloatIter::Swt([arg.as_mut()].into_iter().chain(arms.iter_mut()))
}
Term::Fan { els, .. } | Term::List { els } => FloatIter::Vec(els),
Term::Let { val: fst, nxt: snd, .. }
| Term::Use { val: fst, nxt: snd, .. }
| Term::Oper { fst, snd, .. } => FloatIter::Two([fst.as_mut(), snd.as_mut()]),
Term::Lam { bod, .. } => bod.float_children_mut(),
Term::Var { .. }
| Term::Link { .. }
| Term::Num { .. }
| Term::Nat { .. }
| Term::Str { .. }
| Term::Ref { .. }
| Term::Era
| Term::Err => FloatIter::Zero([]),
Term::Do { .. } | Term::Ask { .. } | Term::Bend { .. } | Term::Fold { .. } | Term::Open { .. } => {
unreachable!()
}
}
}
}
impl Pattern {
fn size(&self) -> usize {
match self {
Pattern::Var(_) => 0,
Pattern::Chn(_) => 0,
Pattern::Fan(_, _, pats) => pats.len() - 1 + pats.iter().map(|p| p.size()).sum::<usize>(),
Pattern::Num(_) | Pattern::Lst(_) | Pattern::Str(_) | Pattern::Ctr(_, _) => unreachable!(),
}
}
}