bend/fun/transform/
linearize_vars.rs

1use crate::{
2  fun::{Book, FanKind, Name, Pattern, Tag, Term},
3  maybe_grow, multi_iterator,
4};
5use std::collections::HashMap;
6
7/// Erases variables that weren't used, dups the ones that were used more than once.
8/// Substitutes lets into their variable use.
9/// In details:
10/// For all var declarations:
11///   If they're used 0 times: erase the declaration
12///   If they're used 1 time: leave them as-is
13///   If they're used more times: insert dups to make var use affine
14/// For all let vars:
15///   If they're used 0 times: Discard the let
16///   If they're used 1 time: substitute the body in the var use
17///   If they're use more times: add dups for all the uses, put the let body at the root dup.
18/// Precondition: All variables are bound and have unique names within each definition.
19impl Book {
20  pub fn linearize_vars(&mut self) {
21    for def in self.defs.values_mut() {
22      def.rule_mut().body.linearize_vars();
23    }
24  }
25}
26
27impl Term {
28  pub fn linearize_vars(&mut self) {
29    term_to_linear(self, &mut HashMap::new());
30  }
31}
32
/// Recursively makes every variable in `term` affine (used at most once).
///
/// `var_uses` maps each variable name to the number of its uses visited so far.
/// Each `Term::Var` reached increments its counter and is renamed with `dup_name`,
/// so the first use keeps the original name and later uses become `{nam}_2`,
/// `{nam}_3`, ... . After a subterm has been processed, the binders that scope
/// over it read the accumulated count to decide whether to erase, keep or
/// duplicate the binding.
///
/// Counts are never removed from the map; this relies on the precondition that
/// names are unique within a definition.
fn term_to_linear(term: &mut Term, var_uses: &mut HashMap<Name, u64>) {
  // NOTE(review): `maybe_grow` presumably extends the stack for deeply nested terms.
  maybe_grow(|| {
    // A `let` binding a single named variable reuses the `let` itself instead of
    // wrapping its body in a separate dup (see the generic case below).
    if let Term::Let { pat, val, nxt } = term {
      if let Pattern::Var(Some(nam)) = pat.as_ref() {
        // TODO: This is swapping the order of how the bindings are
        // used, since it's not following the usual AST order (first
        // val, then nxt). Doesn't change behaviour, but looks strange.
        // (It does decide which occurrence of a variable used in both
        // `val` and `nxt` keeps the original, un-suffixed name.)
        term_to_linear(nxt, var_uses);

        // Read before `val` is visited, so only the uses inside `nxt` are counted.
        let uses = get_var_uses(Some(nam), var_uses);
        term_to_linear(val, var_uses);
        match uses {
          // Unused: keep the `let` but bind the value to an eraser.
          0 => {
            let Term::Let { pat, .. } = term else { unreachable!() };
            **pat = Pattern::Var(None);
          }
          // Used once: the single use still has the original name (`dup_name`
          // returns it unchanged for count 1), so inline `val` there and replace
          // the whole `let` with its body.
          1 => {
            nxt.subst(nam, val.as_ref());
            *term = std::mem::take(nxt.as_mut());
          }
          // Used several times: turn the pattern into a dup whose binders match
          // the renamed uses produced while visiting `nxt`.
          _ => {
            let new_pat = duplicate_pat(nam, uses);
            let Term::Let { pat, .. } = term else { unreachable!() };
            *pat = new_pat;
          }
        }
        return;
      }
    }
    // Variable use: count it and rename it after its occurrence index.
    if let Term::Var { nam } = term {
      let instantiated_count = var_uses.entry(nam.clone()).or_default();
      *instantiated_count += 1;
      *nam = dup_name(nam, *instantiated_count);
      return;
    }

    // Any other term: linearize each child first, then fix up the binders that
    // scope over that child using the final use counts.
    for (child, binds) in term.children_mut_with_binds_mut() {
      term_to_linear(child, var_uses);

      for bind in binds {
        let uses = get_var_uses(bind.as_ref(), var_uses);
        match uses {
          // Erase binding
          0 => *bind = None,
          // Keep as-is
          1 => (),
          // Duplicate binding: `let {nam nam_2 ...} = nam; child`.
          // The new `Var` is created after counting, so it is the binder's only use.
          uses => {
            debug_assert!(uses > 1);
            let nam = bind.as_ref().unwrap();
            *child = Term::Let {
              pat: duplicate_pat(nam, uses),
              val: Box::new(Term::Var { nam: nam.clone() }),
              nxt: Box::new(std::mem::take(child)),
            }
          }
        }
      }
    }
  })
}
94
95fn get_var_uses(nam: Option<&Name>, var_uses: &HashMap<Name, u64>) -> u64 {
96  nam.and_then(|nam| var_uses.get(nam).copied()).unwrap_or_default()
97}
98
99fn duplicate_pat(nam: &Name, uses: u64) -> Box<Pattern> {
100  Box::new(Pattern::Fan(
101    FanKind::Dup,
102    Tag::Auto,
103    (1..uses + 1).map(|i| Pattern::Var(Some(dup_name(nam, i)))).collect(),
104  ))
105}
106
107fn dup_name(nam: &Name, uses: u64) -> Name {
108  if uses == 1 {
109    nam.clone()
110  } else {
111    Name::new(format!("{nam}_{uses}"))
112  }
113}
114
impl Term {
  /// Iterates over the direct children of this term, pairing each child with the
  /// binders (`&mut Option<Name>`) that are in scope for that child.
  ///
  /// Because multiple children can share the same binds, this function is very restricted.
  /// Should only be called after desugaring bends/folds/matches/switches.
  ///
  /// # Panics
  /// Panics (`unreachable!`) on `match`, `fold`, `bend`, `open` and `def` terms, which
  /// must have been removed by earlier passes. For `switch`, debug builds additionally
  /// assert that `bnd` and `pred` are `None` and that `with_bnd`/`with_arg` are empty;
  /// release builds do not check this.
  pub fn children_mut_with_binds_mut(
    &mut self,
  ) -> impl DoubleEndedIterator<Item = (&mut Term, impl DoubleEndedIterator<Item = &mut Option<Name>>)> {
    // NOTE(review): `multi_iterator!` presumably defines a local enum with one variant per
    // listed iterator shape, so every match arm can return a single concrete type.
    multi_iterator!(ChildrenIter { Zero, One, Two, Vec, Swt });
    multi_iterator!(BindsIter { Zero, One, Pat });
    match self {
      Term::Swt { arg, bnd, with_bnd, with_arg, pred, arms } => {
        debug_assert!(bnd.is_none());
        debug_assert!(with_bnd.is_empty());
        debug_assert!(with_arg.is_empty());
        debug_assert!(pred.is_none());
        // The scrutinee and every arm are children; none of them receives a binder.
        ChildrenIter::Swt(
          [(arg.as_mut(), BindsIter::Zero([]))]
            .into_iter()
            .chain(arms.iter_mut().map(|x| (x, BindsIter::Zero([])))),
        )
      }
      Term::Fan { els, .. } | Term::List { els } => {
        ChildrenIter::Vec(els.iter_mut().map(|el| (el, BindsIter::Zero([]))))
      }
      // `use nam = val; nxt`: the name scopes over `nxt` only.
      Term::Use { nam, val, nxt } => {
        ChildrenIter::Two([(val.as_mut(), BindsIter::Zero([])), (nxt.as_mut(), BindsIter::One([nam]))])
      }
      // Pattern binders scope over `nxt`, not over the bound value.
      Term::Let { pat, val, nxt, .. } | Term::Ask { pat, val, nxt, .. } => ChildrenIter::Two([
        (val.as_mut(), BindsIter::Zero([])),
        (nxt.as_mut(), BindsIter::Pat(pat.binds_mut())),
      ]),
      Term::App { fun: fst, arg: snd, .. } | Term::Oper { fst, snd, .. } => {
        ChildrenIter::Two([(fst.as_mut(), BindsIter::Zero([])), (snd.as_mut(), BindsIter::Zero([]))])
      }
      Term::Lam { pat, bod, .. } => ChildrenIter::One([(bod.as_mut(), BindsIter::Pat(pat.binds_mut()))]),
      Term::With { bod, .. } => ChildrenIter::One([(bod.as_mut(), BindsIter::Zero([]))]),
      // Leaves: no children at all.
      Term::Var { .. }
      | Term::Link { .. }
      | Term::Num { .. }
      | Term::Nat { .. }
      | Term::Str { .. }
      | Term::Ref { .. }
      | Term::Era
      | Term::Err => ChildrenIter::Zero([]),
      Term::Mat { .. } => unreachable!("'match' should be removed in earlier pass"),
      Term::Fold { .. } => unreachable!("'fold' should be removed in earlier pass"),
      Term::Bend { .. } => unreachable!("'bend' should be removed in earlier pass"),
      Term::Open { .. } => unreachable!("'open' should be removed in earlier pass"),
      Term::Def { .. } => unreachable!("'def' should be removed in earlier pass"),
    }
  }
}