bend/fun/transform/unique_names.rs

1// Pass to give all variables in a definition unique names.
2
3use crate::{
4  fun::{Book, Name, Term},
5  maybe_grow,
6};
7use std::collections::HashMap;
8
9impl Book {
10  /// Makes all variables in each definition have a new unique name.
11  /// Skips unbound variables.
12  /// Precondition: Definition references have been resolved.
13  pub fn make_var_names_unique(&mut self) {
14    for def in self.defs.values_mut() {
15      def.rule_mut().body.make_var_names_unique();
16    }
17  }
18}
19
20impl Term {
21  pub fn make_var_names_unique(&mut self) {
22    UniqueNameGenerator::default().unique_names_in_term(self);
23  }
24}
25
26type VarId = u64;
27
28#[derive(Default)]
29pub struct UniqueNameGenerator {
30  name_map: HashMap<Name, Vec<VarId>>,
31  name_count: VarId,
32}
33
34impl UniqueNameGenerator {
35  // Recursively assign an id to each variable in the term, then convert each id into a unique name.
36  pub fn unique_names_in_term(&mut self, term: &mut Term) {
37    // Note: we can't use the children iterators here because we mutate the binds,
38    // which are shared across multiple children.
39    maybe_grow(|| match term {
40      Term::Var { nam } => *nam = self.use_var(nam),
41
42      Term::Mat { bnd, arg, with_bnd, with_arg, arms }
43      | Term::Fold { bnd, arg, with_bnd, with_arg, arms } => {
44        // Process args
45        self.unique_names_in_term(arg);
46        for arg in with_arg {
47          self.unique_names_in_term(arg);
48        }
49
50        // Add binds shared by all arms
51        self.push(bnd.as_ref());
52        for bnd in with_bnd.iter() {
53          self.push(bnd.as_ref());
54        }
55
56        // Process arms
57        for arm in arms {
58          // Add binds unique to each arm
59          for bnd in arm.1.iter() {
60            self.push(bnd.as_ref());
61          }
62
63          // Process arm body
64          self.unique_names_in_term(&mut arm.2);
65
66          // Remove binds unique to each arm
67          for bnd in arm.1.iter_mut() {
68            *bnd = self.pop(bnd.as_ref());
69          }
70        }
71
72        // Remove binds shared by all arms
73        for bnd in with_bnd {
74          *bnd = self.pop(bnd.as_ref());
75        }
76        *bnd = self.pop(bnd.as_ref());
77      }
78
79      Term::Swt { bnd, arg, with_bnd, with_arg, pred, arms } => {
80        self.unique_names_in_term(arg);
81        for arg in with_arg {
82          self.unique_names_in_term(arg);
83        }
84
85        self.push(bnd.as_ref());
86        for bnd in with_bnd.iter() {
87          self.push(bnd.as_ref());
88        }
89
90        let (succ, nums) = arms.split_last_mut().unwrap();
91        for arm in nums.iter_mut() {
92          self.unique_names_in_term(arm);
93        }
94
95        self.push(pred.as_ref());
96        self.unique_names_in_term(succ);
97        *pred = self.pop(pred.as_ref());
98
99        for bnd in with_bnd {
100          *bnd = self.pop(bnd.as_ref());
101        }
102        *bnd = self.pop(bnd.as_ref());
103      }
104
105      Term::Bend { bnd, arg, cond, step, base } => {
106        for arg in arg {
107          self.unique_names_in_term(arg);
108        }
109        for bnd in bnd.iter() {
110          self.push(bnd.as_ref());
111        }
112        self.unique_names_in_term(cond);
113        self.unique_names_in_term(step);
114        self.unique_names_in_term(base);
115        for bnd in bnd {
116          *bnd = self.pop(bnd.as_ref());
117        }
118      }
119
120      Term::Let { pat, val, nxt } | Term::Ask { pat, val, nxt } => {
121        self.unique_names_in_term(val);
122        for bnd in pat.binds() {
123          self.push(bnd.as_ref());
124        }
125        self.unique_names_in_term(nxt);
126        for bind in pat.binds_mut() {
127          *bind = self.pop(bind.as_ref());
128        }
129      }
130      Term::Use { nam, val, nxt } => {
131        self.unique_names_in_term(val);
132        self.push(nam.as_ref());
133        self.unique_names_in_term(nxt);
134        *nam = self.pop(nam.as_ref());
135      }
136      Term::Lam { tag: _, pat, bod } => {
137        for bind in pat.binds() {
138          self.push(bind.as_ref());
139        }
140        self.unique_names_in_term(bod);
141        for bind in pat.binds_mut() {
142          *bind = self.pop(bind.as_ref());
143        }
144      }
145      Term::Fan { fan: _, tag: _, els } | Term::List { els } => {
146        for el in els {
147          self.unique_names_in_term(el);
148        }
149      }
150      Term::App { tag: _, fun: fst, arg: snd } | Term::Oper { opr: _, fst, snd } => {
151        self.unique_names_in_term(fst);
152        self.unique_names_in_term(snd);
153      }
154      Term::With { typ: _, bod } => {
155        self.unique_names_in_term(bod);
156      }
157      Term::Link { .. }
158      | Term::Num { .. }
159      | Term::Nat { .. }
160      | Term::Str { .. }
161      | Term::Ref { .. }
162      | Term::Era
163      | Term::Err => {}
164      Term::Open { .. } => unreachable!("'open' should be removed in earlier pass"),
165      Term::Def { .. } => unreachable!("'def' should be removed in earlier pass"),
166    })
167  }
168
169  fn push(&mut self, nam: Option<&Name>) {
170    if let Some(name) = nam {
171      if let Some(ids) = self.name_map.get_mut(name) {
172        ids.push(self.name_count);
173      } else {
174        self.name_map.insert(name.clone(), vec![self.name_count]);
175      }
176      self.name_count += 1;
177    }
178  }
179
180  fn pop(&mut self, nam: Option<&Name>) -> Option<Name> {
181    if let Some(name) = nam {
182      let var_id = self.name_map.get_mut(name).unwrap().pop().unwrap();
183      if self.name_map[name].is_empty() {
184        self.name_map.remove(name);
185      }
186      Some(Name::from(var_id))
187    } else {
188      None
189    }
190  }
191
192  fn use_var(&self, nam: &Name) -> Name {
193    if let Some(vars) = self.name_map.get(nam) {
194      let var_id = *vars.last().unwrap();
195      Name::from(var_id)
196    } else {
197      // Skip unbound variables.
198      // With this, we can use this function before checking for unbound vars.
199      nam.clone()
200    }
201  }
202}