bend/fun/transform/
linearize_matches.rs

1use crate::{
2  fun::{Book, Name, Pattern, Term},
3  maybe_grow,
4};
5use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
6
7/* Linearize preceding binds */
8
9impl Book {
10  /// Linearization of binds preceding match/switch terms, up to the
11  /// first bind used in either the scrutinee or the bind.
12  ///
13  /// Example:
14  /// ```hvm
15  /// @a @b @c let d = (b c); switch a {
16  ///   0: (A b c d)
17  ///   _: (B a-1 b c d)
18  /// }
19  /// // Since `b`, `c` and `d` would be eta-reducible if linearized,
20  /// // they get pushed inside the match.
21  /// @a switch a {
22  ///   0: @b @c let d = (b c); (A b c d)
23  ///   _: @b @c let d = (b c); (B a-1 b c d)
24  /// }
25  /// ```
26  pub fn linearize_match_binds(&mut self) {
27    for def in self.defs.values_mut() {
28      for rule in def.rules.iter_mut() {
29        rule.body.linearize_match_binds();
30      }
31    }
32  }
33}
34
35impl Term {
36  /// Linearize any binds preceding a match/switch term, up to the
37  /// first bind used in either the scrutinee or the bind.
38  pub fn linearize_match_binds(&mut self) {
39    self.linearize_match_binds_go(vec![]);
40  }
41
42  fn linearize_match_binds_go(&mut self, mut bind_terms: Vec<Term>) {
43    maybe_grow(|| match self {
44      // Binding terms
45      // Extract them in case they are preceding a match.
46      Term::Lam { pat, bod, .. } if !pat.has_unscoped() => {
47        let bod = std::mem::take(bod.as_mut());
48        let term = std::mem::replace(self, bod);
49        bind_terms.push(term);
50        self.linearize_match_binds_go(bind_terms);
51      }
52      Term::Let { val, nxt, .. } | Term::Use { val, nxt, .. } => {
53        val.linearize_match_binds_go(vec![]);
54        if val.has_unscoped() {
55          // Terms with unscoped can't be linearized since their names must be unique.
56          nxt.linearize_match_binds_go(vec![]);
57          self.wrap_with_bind_terms(bind_terms);
58        } else {
59          let nxt = std::mem::take(nxt.as_mut());
60          let term = std::mem::replace(self, nxt);
61          bind_terms.push(term);
62          self.linearize_match_binds_go(bind_terms);
63        }
64      }
65
66      // Matching terms
67      Term::Mat { .. } | Term::Swt { .. } => {
68        self.linearize_binds_single_match(bind_terms);
69      }
70
71      // Others
72      // Not a match preceded by binds, so put the extracted terms back.
73      term => {
74        for child in term.children_mut() {
75          child.linearize_match_binds_go(vec![]);
76        }
77        // Recover the extracted terms
78        term.wrap_with_bind_terms(bind_terms);
79      }
80    })
81  }
82
83  fn linearize_binds_single_match(&mut self, mut bind_terms: Vec<Term>) {
84    let (used_vars, with_bnd, with_arg, arms) = match self {
85      Term::Mat { arg, bnd: _, with_bnd, with_arg, arms } => {
86        let vars = arg.free_vars().into_keys().collect::<HashSet<_>>();
87        let arms = arms.iter_mut().map(|arm| &mut arm.2).collect::<Vec<_>>();
88        (vars, with_bnd, with_arg, arms)
89      }
90      Term::Swt { arg, bnd: _, with_bnd, with_arg, pred: _, arms } => {
91        let vars = arg.free_vars().into_keys().collect::<HashSet<_>>();
92        let arms = arms.iter_mut().collect();
93        (vars, with_bnd, with_arg, arms)
94      }
95      _ => unreachable!(),
96    };
97
98    // Add 'with' args as lets that can be moved
99    for (bnd, arg) in with_bnd.iter().zip(with_arg.iter()) {
100      let term = Term::Let {
101        pat: Box::new(Pattern::Var(bnd.clone())),
102        val: Box::new(arg.clone()),
103        nxt: Box::new(Term::Err),
104      };
105      bind_terms.push(term)
106    }
107
108    let (mut non_linearized, linearized) = fixed_and_linearized_terms(used_vars, bind_terms);
109
110    // Add the linearized terms to the arms and recurse
111    for arm in arms {
112      arm.wrap_with_bind_terms(linearized.clone());
113      arm.linearize_match_binds_go(vec![]);
114    }
115
116    // Remove the linearized binds from the with clause
117    let linearized_binds = linearized
118      .iter()
119      .flat_map(|t| match t {
120        Term::Lam { pat, .. } | Term::Let { pat, .. } => pat.binds().flatten().cloned().collect::<Vec<_>>(),
121        Term::Use { nam, .. } => {
122          if let Some(nam) = nam {
123            vec![nam.clone()]
124          } else {
125            vec![]
126          }
127        }
128        _ => unreachable!(),
129      })
130      .collect::<BTreeSet<_>>();
131    update_with_clause(with_bnd, with_arg, &linearized_binds);
132
133    // Remove the non-linearized 'with' binds from the terms that need
134    // to be added back (since we didn't move them).
135    non_linearized.retain(|term| {
136      if let Term::Let { pat, .. } = term {
137        if let Pattern::Var(bnd) = pat.as_ref() {
138          if with_bnd.contains(bnd) {
139            return false;
140          }
141        }
142      }
143      true
144    });
145
146    // Add the non-linearized terms back to before the match
147    self.wrap_with_bind_terms(non_linearized);
148  }
149
150  /// Given a term `self` and a sequence of `bind_terms`, wrap `self` with those binds.
151  ///
152  /// Example:
153  /// ```hvm
154  /// self = X
155  /// match_terms = [λb *, let c = (a b); *, λd *]
156  /// ```
157  ///
158  /// becomes
159  ///
160  /// ```hvm
161  /// self = λb let c = (a b); λd X
162  /// ```
163  fn wrap_with_bind_terms(
164    &mut self,
165    bind_terms: impl IntoIterator<IntoIter = impl DoubleEndedIterator<Item = Term>>,
166  ) {
167    *self = bind_terms.into_iter().rfold(std::mem::take(self), |acc, mut term| {
168      match &mut term {
169        Term::Lam { bod: nxt, .. } | Term::Let { nxt, .. } | Term::Use { nxt, .. } => {
170          *nxt.as_mut() = acc;
171        }
172        _ => unreachable!(),
173      }
174      term
175    });
176  }
177}
178
179/// Separates the bind terms surround the match in two partitions,
180/// one to be linearized, one to stay where they where.
181///
182/// We try to move down any binds that would become eta-reducible with linearization
183/// and that will not introduce extra duplications.
184///
185/// This requires the bind to follow some rules:
186/// * Can only depend on binds that will be moved
187/// * Can't come before any bind that will not be moved.
188/// * Must be a scoped bind.
189///
190/// Examples:
191///
192/// ```hvm
193/// @a @b @c switch b { 0: c; _: (c b-1) }
194/// // Will linearize `c` but not `a` since it comes before a lambda that can't be moved
195/// // Becomes
196/// @a @b switch b { 0: @c c; _: @c (c b-1) }
197/// ```
198///
199/// ```hvm
200/// @a let b = a; @c let e = b; let d = c; switch a { 0: X; _: Y }
201/// // Will not linearize `let b = a` since it would duplicate `a`
202/// // Will linearize `c` since it's a lambda that is not depended on by the argument
203/// // Will not linearize `let e = b` since it would duplicate `b`
204/// // Will linearize `let d = c` since it depends only on variables that will be moved
205/// // and is not depended on by the argument
206/// ```
207fn fixed_and_linearized_terms(used_in_arg: HashSet<Name>, bind_terms: Vec<Term>) -> (Vec<Term>, Vec<Term>) {
208  let fixed_binds = binds_fixed_by_dependency(used_in_arg, &bind_terms);
209
210  let mut fixed = VecDeque::new();
211  let mut linearized = VecDeque::new();
212  let mut stop = false;
213  for term in bind_terms.into_iter().rev() {
214    let to_linearize = match &term {
215      Term::Use { nam, .. } => nam.as_ref().map_or(true, |nam| !fixed_binds.contains(nam)),
216      Term::Let { pat, .. } => pat.binds().flatten().all(|nam| !fixed_binds.contains(nam)),
217      Term::Lam { pat, .. } => pat.binds().flatten().all(|nam| !fixed_binds.contains(nam)),
218      _ => unreachable!(),
219    };
220    let to_linearize = to_linearize && !stop;
221    if to_linearize {
222      linearized.push_front(term);
223    } else {
224      if matches!(term, Term::Lam { .. }) {
225        stop = true;
226      }
227      fixed.push_front(term);
228    }
229  }
230  (fixed.into_iter().collect(), linearized.into_iter().collect())
231}
232
233/// Get which binds are fixed because they are in the dependency graph
234/// of a free var or of a var used in the match arg.
235fn binds_fixed_by_dependency(used_in_arg: HashSet<Name>, bind_terms: &[Term]) -> HashSet<Name> {
236  let mut fixed_binds = used_in_arg;
237
238  // Find the use dependencies of each bind
239  let mut binds = vec![];
240  let mut dependency_digraph = HashMap::new();
241  for term in bind_terms {
242    // Gather what are the binds of this term and what vars it is directly using
243    let (term_binds, term_uses) = match term {
244      Term::Lam { pat, .. } => {
245        let binds = pat.binds().flatten().cloned().collect::<Vec<_>>();
246        (binds, vec![])
247      }
248      Term::Let { pat, val, .. } => {
249        let binds = pat.binds().flatten().cloned().collect::<Vec<_>>();
250        let uses = val.free_vars().into_keys().collect();
251        (binds, uses)
252      }
253      Term::Use { nam, val, .. } => {
254        let binds = if let Some(nam) = nam { vec![nam.clone()] } else { vec![] };
255        let uses = val.free_vars().into_keys().collect();
256        (binds, uses)
257      }
258      _ => unreachable!(),
259    };
260
261    for bind in term_binds {
262      dependency_digraph.insert(bind.clone(), term_uses.clone());
263      binds.push(bind);
264    }
265  }
266
267  // Mark binds that depend on free vars as fixed
268  for (bind, deps) in dependency_digraph.iter() {
269    if deps.iter().any(|dep| !binds.contains(dep)) {
270      fixed_binds.insert(bind.clone());
271    }
272  }
273
274  // Convert to undirected graph
275  let mut dependency_graph: HashMap<Name, HashSet<Name>> =
276    HashMap::from_iter(binds.iter().map(|k| (k.clone(), HashSet::new())));
277  for (bind, deps) in dependency_digraph {
278    for dep in deps {
279      if !binds.contains(&dep) {
280        dependency_graph.insert(dep.clone(), HashSet::new());
281      }
282      dependency_graph.get_mut(&dep).unwrap().insert(bind.clone());
283      dependency_graph.get_mut(&bind).unwrap().insert(dep);
284    }
285  }
286
287  // Find which binds are connected to the vars used in the match arg or to free vars.
288  let mut used_component = HashSet::new();
289  let mut visited = HashSet::new();
290  let mut to_visit = fixed_binds.iter().collect::<Vec<_>>();
291  while let Some(node) = to_visit.pop() {
292    if visited.contains(node) {
293      continue;
294    }
295    used_component.insert(node.clone());
296    visited.insert(node);
297
298    // Add these dependencies to be checked (if it's not a free var in the match arg)
299    if let Some(deps) = dependency_graph.get(node) {
300      to_visit.extend(deps);
301    }
302  }
303
304  // Mark lambdas that come before a fixed lambda as also fixed
305  let mut fixed_start = false;
306  let mut fixed_lams = HashSet::new();
307  for term in bind_terms.iter().rev() {
308    if let Term::Lam { pat, .. } = term {
309      if pat.binds().flatten().any(|p| used_component.contains(p)) {
310        fixed_start = true;
311      }
312      if fixed_start {
313        for bind in pat.binds().flatten() {
314          fixed_lams.insert(bind.clone());
315        }
316      }
317    }
318  }
319
320  let mut fixed_binds = used_component;
321
322  // Mark binds that depend on fixed lambdas as also fixed.
323  let mut visited = HashSet::new();
324  let mut to_visit = fixed_lams.iter().collect::<Vec<_>>();
325  while let Some(node) = to_visit.pop() {
326    if visited.contains(node) {
327      continue;
328    }
329    fixed_binds.insert(node.clone());
330    visited.insert(node);
331
332    // Add these dependencies to be checked (if it's not a free var in the match arg)
333    if let Some(deps) = dependency_graph.get(node) {
334      to_visit.extend(deps);
335    }
336  }
337
338  fixed_binds
339}
340
341fn update_with_clause(
342  with_bnd: &mut Vec<Option<Name>>,
343  with_arg: &mut Vec<Term>,
344  vars_to_lift: &BTreeSet<Name>,
345) {
346  let mut to_remove = Vec::new();
347  for i in 0..with_bnd.len() {
348    if let Some(with_bnd) = &with_bnd[i] {
349      if vars_to_lift.contains(with_bnd) {
350        to_remove.push(i);
351      }
352    }
353  }
354  for (removed, to_remove) in to_remove.into_iter().enumerate() {
355    with_bnd.remove(to_remove - removed);
356    with_arg.remove(to_remove - removed);
357  }
358}
359/* Linearize all used vars */
360
361impl Book {
362  /// Linearizes all variables used in a matches' arms.
363  pub fn linearize_matches(&mut self) {
364    for def in self.defs.values_mut() {
365      for rule in def.rules.iter_mut() {
366        rule.body.linearize_matches();
367      }
368    }
369  }
370}
371
372impl Term {
373  fn linearize_matches(&mut self) {
374    maybe_grow(|| {
375      for child in self.children_mut() {
376        child.linearize_matches();
377      }
378
379      if matches!(self, Term::Mat { .. } | Term::Swt { .. }) {
380        lift_match_vars(self);
381      }
382    })
383  }
384}
385
386/// Converts free vars inside the match arms into lambdas with
387/// applications around the match to pass them the external value.
388///
389/// Makes the rules extractable and linear (no need for dups even
390/// when a variable is used in multiple rules).
391///
392/// Obs: This does not modify unscoped variables.
393pub fn lift_match_vars(match_term: &mut Term) -> &mut Term {
394  // Collect match arms with binds
395  let (with_bnd, with_arg, arms) = match match_term {
396    Term::Mat { arg: _, bnd: _, with_bnd, with_arg, arms: rules } => {
397      let args =
398        rules.iter().map(|(_, binds, body)| (binds.iter().flatten().cloned().collect(), body)).collect();
399      (with_bnd.clone(), with_arg.clone(), args)
400    }
401    Term::Swt { arg: _, bnd: _, with_bnd, with_arg, pred, arms } => {
402      let (succ, nums) = arms.split_last_mut().unwrap();
403      let mut arms = nums.iter().map(|body| (vec![], body)).collect::<Vec<_>>();
404      arms.push((vec![pred.clone().unwrap()], succ));
405      (with_bnd.clone(), with_arg.clone(), arms)
406    }
407    _ => unreachable!(),
408  };
409
410  // Collect all free vars in the match arms
411  let mut free_vars = Vec::<Vec<_>>::new();
412  for (binds, body) in arms {
413    let mut arm_free_vars = body.free_vars();
414    for bind in binds {
415      arm_free_vars.shift_remove(&bind);
416    }
417    free_vars.push(arm_free_vars.into_keys().collect());
418  }
419
420  // Collect the vars to lift
421  // We need consistent iteration order.
422  let vars_to_lift: BTreeSet<Name> = free_vars.into_iter().flatten().collect();
423
424  // Add lambdas to the arms
425  match match_term {
426    Term::Mat { arg: _, bnd: _, with_bnd, with_arg, arms } => {
427      update_with_clause(with_bnd, with_arg, &vars_to_lift);
428      for arm in arms {
429        let old_body = std::mem::take(&mut arm.2);
430        arm.2 = Term::rfold_lams(old_body, vars_to_lift.iter().cloned().map(Some));
431      }
432    }
433    Term::Swt { arg: _, bnd: _, with_bnd, with_arg, pred: _, arms } => {
434      update_with_clause(with_bnd, with_arg, &vars_to_lift);
435      for arm in arms {
436        let old_body = std::mem::take(arm);
437        *arm = Term::rfold_lams(old_body, vars_to_lift.iter().cloned().map(Some));
438      }
439    }
440    _ => unreachable!(),
441  }
442
443  // Add apps to the match
444  let args = vars_to_lift
445    .into_iter()
446    .map(|nam| {
447      if let Some(idx) = with_bnd.iter().position(|x| x == &nam) {
448        with_arg[idx].clone()
449      } else {
450        Term::Var { nam }
451      }
452    })
453    .collect::<Vec<_>>();
454  let term = Term::call(std::mem::take(match_term), args);
455  *match_term = term;
456
457  get_match_reference(match_term)
458}
459
460/// Get a reference to the match again
461/// It returns a reference and not an owned value because we want
462/// to keep the new surrounding Apps but still modify the match further.
463fn get_match_reference(mut match_term: &mut Term) -> &mut Term {
464  loop {
465    match match_term {
466      Term::App { tag: _, fun, arg: _ } => match_term = fun.as_mut(),
467      Term::Swt { .. } | Term::Mat { .. } => return match_term,
468      _ => unreachable!(),
469    }
470  }
471}
472
473/* Linearize `with` vars  */
474
475impl Book {
476  /// Linearizes all variables specified in the `with` clauses of match terms.
477  pub fn linearize_match_with(&mut self) {
478    for def in self.defs.values_mut() {
479      for rule in def.rules.iter_mut() {
480        rule.body.linearize_match_with();
481      }
482    }
483  }
484}
485
486impl Term {
487  fn linearize_match_with(&mut self) {
488    maybe_grow(|| {
489      for child in self.children_mut() {
490        child.linearize_match_with();
491      }
492    });
493    match self {
494      Term::Mat { arg: _, bnd: _, with_bnd, with_arg, arms } => {
495        for rule in arms {
496          rule.2 = Term::rfold_lams(std::mem::take(&mut rule.2), with_bnd.clone().into_iter());
497        }
498        *with_bnd = vec![];
499        let call_args = std::mem::take(with_arg).into_iter();
500        *self = Term::call(std::mem::take(self), call_args);
501      }
502      Term::Swt { arg: _, bnd: _, with_bnd, with_arg, pred: _, arms } => {
503        for rule in arms {
504          *rule = Term::rfold_lams(std::mem::take(rule), with_bnd.clone().into_iter());
505        }
506        *with_bnd = vec![];
507        let call_args = std::mem::take(with_arg).into_iter();
508        *self = Term::call(std::mem::take(self), call_args);
509      }
510      _ => {}
511    }
512  }
513}