bend/fun/transform/
float_combinators.rs

1use crate::{
2  fun::{Book, Definition, Name, Pattern, Rule, Source, Term},
3  maybe_grow, multi_iterator,
4};
5use std::collections::{BTreeMap, HashSet};
6
/// Separator inserted between the name of the originating definition and the
/// numeric index when naming a definition generated for a floated combinator
/// (the generated name is `<def_name>` + `NAME_SEP` + `<index>`).
pub const NAME_SEP: &str = "__C";
8
9impl Book {
10  /// Extracts combinator terms into new definitions.
11  ///
12  /// Precondition: Variables must have been sanitized.
13  ///
14  /// The floating algorithm follows these rules:
15  /// For each child of the term:
16  /// - Recursively float every grandchild term.
17  /// - If the child is a combinator:
18  ///   * If the child is not "safe", extract it.
19  ///   * If the term is a combinator and it's "safe":
20  ///     - If the term is currently larger than `max_size`, extract the child.
21  ///   * Otherwise, always extract the child to a new definition.
22  /// - If the child is not a combinator, we can't extract it since
23  ///   it would generate an invalid term.
24  ///
25  /// Terms are considered combinators if they have no free vars,
26  /// no unmatched unscoped binds/vars and are not references (to
27  /// avoid infinite recursion).
28  ///
29  /// See [`Term::is_safe`] for what is considered safe here.
30  ///
31  /// See [`Term::size`] for the measurement of size.
32  /// It should more or less correspond to the compiled inet size.
33  pub fn float_combinators(&mut self, max_size: usize) {
34    let book = self.clone();
35    let mut ctx = FloatCombinatorsCtx::new(&book, max_size);
36
37    for (def_name, def) in self.defs.iter_mut() {
38      // Don't float combinators in the main entrypoint.
39      // This avoids making programs unexpectedly too lazy,
40      // returning just a reference without executing anything.
41      if let Some(main) = self.entrypoint.as_ref() {
42        if def_name == main {
43          continue;
44        }
45      }
46
47      let source = def.source.clone();
48      let check = def.check;
49      let body = &mut def.rule_mut().body;
50      ctx.reset();
51      ctx.def_size = body.size();
52      body.float_combinators(&mut ctx, def_name, source, check);
53    }
54
55    self.defs.extend(ctx.combinators.into_iter().map(|(nam, (_, def))| (nam, def)));
56  }
57}
58
59struct FloatCombinatorsCtx<'b> {
60  pub combinators: BTreeMap<Name, (bool, Definition)>,
61  pub name_gen: usize,
62  pub seen: HashSet<Name>,
63  pub book: &'b Book,
64  pub max_size: usize,
65  pub def_size: usize,
66}
67
68impl<'b> FloatCombinatorsCtx<'b> {
69  fn new(book: &'b Book, max_size: usize) -> Self {
70    Self {
71      combinators: Default::default(),
72      name_gen: 0,
73      seen: Default::default(),
74      book,
75      max_size,
76      def_size: 0,
77    }
78  }
79
80  fn reset(&mut self) {
81    self.def_size = 0;
82    self.name_gen = 0;
83    self.seen = Default::default();
84  }
85}
86
87impl Term {
88  fn float_combinators(
89    &mut self,
90    ctx: &mut FloatCombinatorsCtx,
91    def_name: &Name,
92    source: Source,
93    check: bool,
94  ) {
95    maybe_grow(|| {
96      // Recursively float the grandchildren terms.
97      for child in self.float_children_mut() {
98        child.float_combinators(ctx, def_name, source.clone(), check);
99      }
100
101      let mut size = self.size();
102      let is_combinator = self.is_combinator();
103
104      // Float unsafe children and children that make the term too big.
105      for child in self.float_children_mut() {
106        let child_is_safe = child.is_safe(ctx);
107        let child_size = child.size();
108
109        let extract_for_size = if is_combinator { size > ctx.max_size } else { ctx.def_size > ctx.max_size };
110
111        if child.is_combinator() && child_size > 0 && (!child_is_safe || extract_for_size) {
112          ctx.def_size -= child_size;
113          size -= child_size;
114          child.float(ctx, def_name, source.clone(), check, child_is_safe);
115        }
116      }
117    })
118  }
119
120  /// Inserts a new definition for the given term in the combinators map.
121  fn float(
122    &mut self,
123    ctx: &mut FloatCombinatorsCtx,
124    def_name: &Name,
125    source: Source,
126    check: bool,
127    is_safe: bool,
128  ) {
129    let comb_name = Name::new(format!("{}{}{}", def_name, NAME_SEP, ctx.name_gen));
130    ctx.name_gen += 1;
131
132    let comb_ref = Term::Ref { nam: comb_name.clone() };
133    let extracted_term = std::mem::replace(self, comb_ref);
134
135    let rules = vec![Rule { body: extracted_term, pats: Vec::new() }];
136    let rule = Definition::new_gen(comb_name.clone(), rules, source, check);
137    ctx.combinators.insert(comb_name, (is_safe, rule));
138  }
139}
140
141impl Term {
142  /// A term can be considered safe if it is:
143  /// - A Number or an Eraser.
144  /// - A Tuple or Superposition where all elements are safe.
145  /// - An application or numeric operation where all arguments are safe.
146  /// - A safe Lambda, e.g. a nullary constructor or a lambda with safe body.
147  /// - A Reference with a safe body.
148  ///
149  /// A reference to a recursive definition (or mutually recursive) is not safe.
150  fn is_safe(&self, ctx: &mut FloatCombinatorsCtx) -> bool {
151    maybe_grow(|| match self {
152      Term::Num { .. }
153      | Term::Era
154      | Term::Err
155      | Term::Fan { .. }
156      | Term::App { .. }
157      | Term::Oper { .. }
158      | Term::Swt { .. } => self.children().all(|c| c.is_safe(ctx)),
159      Term::Lam { .. } => self.is_safe_lambda(ctx),
160      Term::Ref { nam } => {
161        // Constructors are safe
162        if ctx.book.ctrs.contains_key(nam) {
163          return true;
164        }
165        // If recursive, not safe
166        if ctx.seen.contains(nam) {
167          return false;
168        }
169        ctx.seen.insert(nam.clone());
170
171        // Check if the function it's referring to is safe
172        let safe = if let Some(def) = ctx.book.defs.get(nam) {
173          def.rule().body.is_safe(ctx)
174        } else if let Some((safe, _)) = ctx.combinators.get(nam) {
175          *safe
176        } else {
177          false
178        };
179
180        ctx.seen.remove(nam);
181        safe
182      }
183      // TODO: Variables can be safe depending on how they're used
184      // For example, in a well-typed numop they're safe.
185      _ => false,
186    })
187  }
188
189  /// Checks if the term is a lambda sequence with a safe body.
190  /// If the body is a variable bound in the lambdas, it's a nullary constructor.
191  /// If the body is a reference, it's in inactive position, so always safe.
192  fn is_safe_lambda(&self, ctx: &mut FloatCombinatorsCtx) -> bool {
193    let mut current = self;
194    let mut scope = Vec::new();
195
196    while let Term::Lam { pat, bod, .. } = current {
197      scope.extend(pat.binds().filter_map(|x| x.as_ref()));
198      current = bod;
199    }
200
201    match current {
202      Term::Var { nam } if scope.contains(&nam) => true,
203      Term::Ref { .. } => true,
204      term => term.is_safe(ctx),
205    }
206  }
207
208  pub fn has_unscoped_diff(&self) -> bool {
209    let (declared, used) = self.unscoped_vars();
210    declared.difference(&used).count() != 0 || used.difference(&declared).count() != 0
211  }
212
213  fn is_combinator(&self) -> bool {
214    self.free_vars().is_empty() && !self.has_unscoped_diff() && !matches!(self, Term::Ref { .. })
215  }
216
217  fn base_size(&self) -> usize {
218    match self {
219      Term::Let { pat, .. } => pat.size(),
220      Term::Fan { els, .. } => els.len() - 1,
221      Term::Mat { arms, .. } => arms.len(),
222      Term::Swt { arms, .. } => 2 * (arms.len() - 1),
223      Term::Lam { .. } => 1,
224      Term::App { .. } => 1,
225      Term::Oper { .. } => 1,
226      Term::Var { .. } => 0,
227      Term::Link { .. } => 0,
228      Term::Use { .. } => 0,
229      Term::Num { .. } => 0,
230      Term::Ref { .. } => 0,
231      Term::Era => 0,
232      Term::Bend { .. }
233      | Term::Fold { .. }
234      | Term::Nat { .. }
235      | Term::Str { .. }
236      | Term::List { .. }
237      | Term::With { .. }
238      | Term::Ask { .. }
239      | Term::Open { .. }
240      | Term::Def { .. }
241      | Term::Err => unreachable!(),
242    }
243  }
244
245  fn size(&self) -> usize {
246    maybe_grow(|| {
247      let children_size: usize = self.children().map(|c| c.size()).sum();
248      self.base_size() + children_size
249    })
250  }
251
252  pub fn float_children_mut(&mut self) -> impl Iterator<Item = &mut Term> {
253    multi_iterator!(FloatIter { Zero, Two, Vec, Mat, App, Swt });
254    match self {
255      Term::App { .. } => {
256        let mut next = Some(self);
257        FloatIter::App(std::iter::from_fn(move || {
258          let cur = next.take();
259          if let Some(Term::App { fun, arg, .. }) = cur {
260            next = Some(&mut *fun);
261            Some(&mut **arg)
262          } else {
263            cur
264          }
265        }))
266      }
267      Term::Mat { arg, bnd: _, with_bnd: _, with_arg, arms } => FloatIter::Mat(
268        [arg.as_mut()].into_iter().chain(with_arg.iter_mut()).chain(arms.iter_mut().map(|r| &mut r.2)),
269      ),
270      Term::Swt { arg, bnd: _, with_bnd: _, with_arg, pred: _, arms } => {
271        FloatIter::Swt([arg.as_mut()].into_iter().chain(with_arg.iter_mut()).chain(arms.iter_mut()))
272      }
273      Term::Fan { els, .. } | Term::List { els } => FloatIter::Vec(els),
274      Term::Let { val: fst, nxt: snd, .. }
275      | Term::Use { val: fst, nxt: snd, .. }
276      | Term::Oper { fst, snd, .. } => FloatIter::Two([fst.as_mut(), snd.as_mut()]),
277      Term::Lam { bod, .. } => bod.float_children_mut(),
278      Term::Var { .. }
279      | Term::Link { .. }
280      | Term::Num { .. }
281      | Term::Nat { .. }
282      | Term::Str { .. }
283      | Term::Ref { .. }
284      | Term::Era
285      | Term::Err => FloatIter::Zero([]),
286      Term::With { .. }
287      | Term::Ask { .. }
288      | Term::Bend { .. }
289      | Term::Fold { .. }
290      | Term::Open { .. }
291      | Term::Def { .. } => {
292        unreachable!()
293      }
294    }
295  }
296}
297
298impl Pattern {
299  fn size(&self) -> usize {
300    match self {
301      Pattern::Var(_) => 0,
302      Pattern::Chn(_) => 0,
303      Pattern::Fan(_, _, pats) => pats.len() - 1 + pats.iter().map(|p| p.size()).sum::<usize>(),
304
305      Pattern::Num(_) | Pattern::Lst(_) | Pattern::Str(_) | Pattern::Ctr(_, _) => unreachable!(),
306    }
307  }
308}