bend/fun/transform/
lift_local_defs.rs

1use std::collections::BTreeSet;
2
3use indexmap::IndexMap;
4
5use crate::{
6  fun::{Book, Definition, Name, Pattern, Rule, Term},
7  maybe_grow,
8};
9
10impl Book {
11  pub fn lift_local_defs(&mut self) {
12    let mut defs = IndexMap::new();
13    for (name, def) in self.defs.iter_mut() {
14      let mut gen = 0;
15      for rule in def.rules.iter_mut() {
16        rule.body.lift_local_defs(name, def.check, &mut defs, &mut gen);
17      }
18    }
19    self.defs.extend(defs);
20  }
21}
22
23impl Rule {
24  pub fn binds(&self) -> impl DoubleEndedIterator<Item = &Option<Name>> + Clone {
25    self.pats.iter().flat_map(Pattern::binds)
26  }
27}
28
29impl Term {
30  pub fn lift_local_defs(
31    &mut self,
32    parent: &Name,
33    check: bool,
34    defs: &mut IndexMap<Name, Definition>,
35    gen: &mut usize,
36  ) {
37    maybe_grow(|| match self {
38      Term::Def { def, nxt } => {
39        let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, def.name));
40        for rule in def.rules.iter_mut() {
41          rule.body.lift_local_defs(&local_name, check, defs, gen);
42        }
43        nxt.lift_local_defs(parent, check, defs, gen);
44        *gen += 1;
45
46        let inner_defs =
47          defs.keys().filter(|name| name.starts_with(local_name.as_ref())).cloned().collect::<BTreeSet<_>>();
48        let (r#use, fvs, mut rules) =
49          gen_use(inner_defs, &local_name, &def.name, nxt, std::mem::take(&mut def.rules));
50        let source = std::mem::take(&mut def.source);
51        *self = r#use;
52
53        apply_closure(&mut rules, &fvs);
54
55        let new_def = Definition::new_gen(local_name.clone(), rules, source, check);
56        defs.insert(local_name.clone(), new_def);
57      }
58      _ => {
59        for child in self.children_mut() {
60          child.lift_local_defs(parent, check, defs, gen);
61        }
62      }
63    })
64  }
65}
66
67fn gen_use(
68  inner_defs: BTreeSet<Name>,
69  local_name: &Name,
70  nam: &Name,
71  nxt: &mut Box<Term>,
72  mut rules: Vec<Rule>,
73) -> (Term, BTreeSet<Name>, Vec<Rule>) {
74  let mut fvs = BTreeSet::<Name>::new();
75  for rule in rules.iter() {
76    fvs.extend(rule.body.free_vars().into_keys().collect::<BTreeSet<_>>());
77  }
78  fvs.retain(|fv| !inner_defs.contains(fv));
79  for rule in rules.iter() {
80    for bind in rule.binds().flatten() {
81      fvs.remove(bind);
82    }
83  }
84  fvs.remove(nam);
85
86  let call = Term::call(
87    Term::Ref { nam: local_name.clone() },
88    fvs.iter().cloned().map(|nam| Term::Var { nam }).collect::<Vec<_>>(),
89  );
90
91  for rule in rules.iter_mut() {
92    let slf = std::mem::take(&mut rule.body);
93    rule.body = Term::Use { nam: Some(nam.clone()), val: Box::new(call.clone()), nxt: Box::new(slf) };
94  }
95
96  let r#use = Term::Use { nam: Some(nam.clone()), val: Box::new(call.clone()), nxt: std::mem::take(nxt) };
97
98  (r#use, fvs, rules)
99}
100
101fn apply_closure(rules: &mut [Rule], fvs: &BTreeSet<Name>) {
102  for rule in rules.iter_mut() {
103    let captured = fvs.iter().cloned().map(Some).collect::<Vec<_>>();
104    rule.body = Term::rfold_lams(std::mem::take(&mut rule.body), captured.into_iter());
105  }
106}