bend/fun/transform/
lift_local_defs.rs1use std::collections::BTreeSet;
2
3use indexmap::IndexMap;
4
5use crate::{
6 fun::{Book, Definition, Name, Pattern, Rule, Term},
7 maybe_grow,
8};
9
10impl Book {
11 pub fn lift_local_defs(&mut self) {
12 let mut defs = IndexMap::new();
13 for (name, def) in self.defs.iter_mut() {
14 let mut gen = 0;
15 for rule in def.rules.iter_mut() {
16 rule.body.lift_local_defs(name, def.check, &mut defs, &mut gen);
17 }
18 }
19 self.defs.extend(defs);
20 }
21}
22
23impl Rule {
24 pub fn binds(&self) -> impl DoubleEndedIterator<Item = &Option<Name>> + Clone {
25 self.pats.iter().flat_map(Pattern::binds)
26 }
27}
28
29impl Term {
30 pub fn lift_local_defs(
31 &mut self,
32 parent: &Name,
33 check: bool,
34 defs: &mut IndexMap<Name, Definition>,
35 gen: &mut usize,
36 ) {
37 maybe_grow(|| match self {
38 Term::Def { def, nxt } => {
39 let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, def.name));
40 for rule in def.rules.iter_mut() {
41 rule.body.lift_local_defs(&local_name, check, defs, gen);
42 }
43 nxt.lift_local_defs(parent, check, defs, gen);
44 *gen += 1;
45
46 let inner_defs =
47 defs.keys().filter(|name| name.starts_with(local_name.as_ref())).cloned().collect::<BTreeSet<_>>();
48 let (r#use, fvs, mut rules) =
49 gen_use(inner_defs, &local_name, &def.name, nxt, std::mem::take(&mut def.rules));
50 let source = std::mem::take(&mut def.source);
51 *self = r#use;
52
53 apply_closure(&mut rules, &fvs);
54
55 let new_def = Definition::new_gen(local_name.clone(), rules, source, check);
56 defs.insert(local_name.clone(), new_def);
57 }
58 _ => {
59 for child in self.children_mut() {
60 child.lift_local_defs(parent, check, defs, gen);
61 }
62 }
63 })
64 }
65}
66
67fn gen_use(
68 inner_defs: BTreeSet<Name>,
69 local_name: &Name,
70 nam: &Name,
71 nxt: &mut Box<Term>,
72 mut rules: Vec<Rule>,
73) -> (Term, BTreeSet<Name>, Vec<Rule>) {
74 let mut fvs = BTreeSet::<Name>::new();
75 for rule in rules.iter() {
76 fvs.extend(rule.body.free_vars().into_keys().collect::<BTreeSet<_>>());
77 }
78 fvs.retain(|fv| !inner_defs.contains(fv));
79 for rule in rules.iter() {
80 for bind in rule.binds().flatten() {
81 fvs.remove(bind);
82 }
83 }
84 fvs.remove(nam);
85
86 let call = Term::call(
87 Term::Ref { nam: local_name.clone() },
88 fvs.iter().cloned().map(|nam| Term::Var { nam }).collect::<Vec<_>>(),
89 );
90
91 for rule in rules.iter_mut() {
92 let slf = std::mem::take(&mut rule.body);
93 rule.body = Term::Use { nam: Some(nam.clone()), val: Box::new(call.clone()), nxt: Box::new(slf) };
94 }
95
96 let r#use = Term::Use { nam: Some(nam.clone()), val: Box::new(call.clone()), nxt: std::mem::take(nxt) };
97
98 (r#use, fvs, rules)
99}
100
101fn apply_closure(rules: &mut [Rule], fvs: &BTreeSet<Name>) {
102 for rule in rules.iter_mut() {
103 let captured = fvs.iter().cloned().map(Some).collect::<Vec<_>>();
104 rule.body = Term::rfold_lams(std::mem::take(&mut rule.body), captured.into_iter());
105 }
106}