bend/fun/transform/expand_main.rs

1use crate::{
2  fun::{Book, Name, Pattern, Term},
3  maybe_grow,
4};
5use std::collections::HashMap;
6
7impl Book {
8  /// Expands the main function so that it is not just a reference.
9  /// While technically correct, directly returning a reference is never what users want.
10  pub fn expand_main(&mut self) {
11    if self.entrypoint.is_none() {
12      return;
13    }
14
15    let main = self.defs.get_mut(self.entrypoint.as_ref().unwrap()).unwrap();
16    let mut main_bod = std::mem::take(&mut main.rule_mut().body);
17
18    let mut seen = vec![self.entrypoint.as_ref().unwrap().clone()];
19    main_bod.expand_ref_return(self, &mut seen, &mut 0);
20
21    // Undo the `float_combinators` pass for main, to recover the strictness of the main function.
22    main_bod.expand_floated_combinators(self);
23
24    let main = self.defs.get_mut(self.entrypoint.as_ref().unwrap()).unwrap();
25    main.rule_mut().body = main_bod;
26  }
27}
28
29impl Term {
30  /// Expands references in the main function that are in "return" position.
31  ///
32  /// This applies to:
33  /// - When main returns a reference.
34  /// - When main returns a lambda whose body is a reference.
35  /// - When main returns a pair or superposition and one of its elements is a reference.
36  ///
37  /// Only expand recursive functions once.
38  fn expand_ref_return(&mut self, book: &Book, seen: &mut Vec<Name>, globals_count: &mut usize) {
39    maybe_grow(|| match self {
40      Term::Ref { nam } => {
41        if seen.contains(nam) {
42          // Don't expand recursive references
43        } else if let Some(def) = book.defs.get(nam) {
44          // Regular function, expand
45          seen.push(nam.clone());
46          let mut body = def.rule().body.clone();
47          body.rename_unscoped(globals_count, &mut HashMap::new());
48          *self = body;
49          self.expand_ref_return(book, seen, globals_count);
50          seen.pop().unwrap();
51        } else {
52          // Not a regular function, don't expand
53        }
54      }
55      Term::Fan { els, .. } | Term::List { els } => {
56        for el in els {
57          el.expand_ref_return(book, seen, globals_count);
58        }
59      }
60      // If an application is just a constructor, we expand the arguments.
61      // That way we can write programs like
62      // `main = [do_thing1, do_thing2, do_thing3]`
63      Term::App { .. } => {
64        let (fun, args) = self.multi_arg_app();
65        if let Term::Ref { nam } = fun {
66          if book.ctrs.contains_key(nam) {
67            for arg in args {
68              // If the argument is a 0-ary constructor, we don't need to expand it.
69              if let Term::Ref { nam } = arg {
70                if let Some(adt_nam) = book.ctrs.get(nam) {
71                  if book.adts.get(adt_nam).unwrap().ctrs.get(nam).unwrap().fields.is_empty() {
72                    continue;
73                  }
74                }
75              }
76              // Otherwise, we expand the argument.
77              arg.expand_ref_return(book, seen, globals_count);
78            }
79          }
80        }
81      }
82      Term::Lam { bod: nxt, .. }
83      | Term::With { bod: nxt, .. }
84      | Term::Open { bod: nxt, .. }
85      | Term::Let { nxt, .. }
86      | Term::Ask { nxt, .. }
87      | Term::Use { nxt, .. } => nxt.expand_ref_return(book, seen, globals_count),
88      Term::Var { .. }
89      | Term::Link { .. }
90      | Term::Num { .. }
91      | Term::Nat { .. }
92      | Term::Str { .. }
93      | Term::Oper { .. }
94      | Term::Mat { .. }
95      | Term::Swt { .. }
96      | Term::Fold { .. }
97      | Term::Bend { .. }
98      | Term::Def { .. }
99      | Term::Era
100      | Term::Err => {}
101    })
102  }
103
104  fn expand_floated_combinators(&mut self, book: &Book) {
105    maybe_grow(|| {
106      if let Term::Ref { nam } = self {
107        if nam.contains(super::float_combinators::NAME_SEP) {
108          *self = book.defs.get(nam).unwrap().rule().body.clone();
109        }
110      }
111      for child in self.children_mut() {
112        child.expand_floated_combinators(book);
113      }
114    })
115  }
116
117  /// Read the term as an n-ary application.
118  fn multi_arg_app(&mut self) -> (&mut Term, Vec<&mut Term>) {
119    fn go<'a>(term: &'a mut Term, args: &mut Vec<&'a mut Term>) -> &'a mut Term {
120      match term {
121        Term::App { fun, arg, .. } => {
122          args.push(arg);
123          go(fun, args)
124        }
125        _ => term,
126      }
127    }
128    let mut args = vec![];
129    let fun = go(self, &mut args);
130    (fun, args)
131  }
132}
133
134impl Term {
135  /// Since expanded functions can contain unscoped variables, and
136  /// unscoped variable names must be unique, we need to rename them
137  /// to avoid conflicts.
138  fn rename_unscoped(&mut self, unscoped_count: &mut usize, unscoped_map: &mut HashMap<Name, Name>) {
139    match self {
140      Term::Let { pat, .. } | Term::Lam { pat, .. } => pat.rename_unscoped(unscoped_count, unscoped_map),
141      Term::Link { nam } => rename_unscoped(nam, unscoped_count, unscoped_map),
142      _ => {
143        // Isn't an unscoped bind or use, do nothing, just recurse.
144      }
145    }
146    for child in self.children_mut() {
147      child.rename_unscoped(unscoped_count, unscoped_map);
148    }
149  }
150}
151
152impl Pattern {
153  fn rename_unscoped(&mut self, unscoped_count: &mut usize, unscoped_map: &mut HashMap<Name, Name>) {
154    maybe_grow(|| {
155      match self {
156        Pattern::Chn(nam) => rename_unscoped(nam, unscoped_count, unscoped_map),
157        _ => {
158          // Pattern isn't an unscoped bind, just recurse.
159        }
160      }
161      for child in self.children_mut() {
162        child.rename_unscoped(unscoped_count, unscoped_map);
163      }
164    })
165  }
166}
167
168/// Generates a new name for an unscoped variable.
169fn rename_unscoped(nam: &mut Name, unscoped_count: &mut usize, unscoped_map: &mut HashMap<Name, Name>) {
170  if let Some(new_nam) = unscoped_map.get(nam) {
171    *nam = new_nam.clone();
172  } else {
173    let new_nam = Name::new(format!("{nam}%{}", unscoped_count));
174    unscoped_map.insert(nam.clone(), new_nam.clone());
175    *unscoped_count += 1;
176    *nam = new_nam;
177  }
178}