bend/fun/transform/definition_merge.rs

1use crate::{
2  fun::{Book, Definition, Name, Rule, Term},
3  maybe_grow,
4};
5use indexmap::{IndexMap, IndexSet};
6use itertools::Itertools;
7use std::collections::BTreeMap;
8
/// Separator placed between the names of definitions that are merged into a
/// single generated definition; merging `foo` and `bar` yields `foo__M_bar`.
pub const MERGE_SEPARATOR: &str = "__M_";
10
11impl Book {
12  /// Merges definitions that have the same structure into one definition.
13  /// Expects variables to be linear.
14  ///
15  /// Some of the origins of the rules will be lost in this stage,
16  /// Should not be preceded by passes that cares about the origins.
17  pub fn merge_definitions(&mut self) {
18    let defs: Vec<_> = self.defs.keys().cloned().collect();
19    self.merge(defs.into_iter());
20  }
21
22  /// Checks and merges identical definitions given by `defs`.
23  /// We never merge the entrypoint function with something else.
24  fn merge(&mut self, defs: impl Iterator<Item = Name>) {
25    let name = self.entrypoint.clone();
26    // Sets of definitions that are identical, indexed by the body term.
27    let equal_terms =
28      self.collect_terms(defs.filter(|def_name| !name.as_ref().is_some_and(|m| m == def_name)));
29
30    // Map of old name to new merged name
31    let mut name_map = BTreeMap::new();
32
33    for (term, equal_defs) in equal_terms {
34      // def1_$_def2_$_def3
35      let new_name = Name::new(equal_defs.iter().join(MERGE_SEPARATOR));
36
37      if equal_defs.len() > 1 {
38        // Merging some defs
39
40        // The source of the generated definition will be based on the first one we get from `equal_defs`.
41        // In the future, we might want to change this to point to every source of every definition
42        // it's based on.
43        // This could be done by having SourceKind::Generated contain a Vec<Source> or Vec<Definition>.
44        let any_def_name = equal_defs.iter().next().unwrap(); // we know we can unwrap since equal_defs.len() > 1
45
46        // Add the merged def
47        let source = self.defs[any_def_name].source.clone();
48        let rules = vec![Rule { pats: vec![], body: term }];
49        // Note: This will erase types, so type checking needs to come before this.
50        let new_def = Definition::new_gen(new_name.clone(), rules, source, false);
51        self.defs.insert(new_name.clone(), new_def);
52        // Remove the old ones and write the map of old names to new ones.
53        for name in equal_defs {
54          self.defs.swap_remove(&name);
55          name_map.insert(name, new_name.clone());
56        }
57      } else {
58        // Not merging, just put the body back
59        let def_name = equal_defs.into_iter().next().unwrap();
60        let def = self.defs.get_mut(&def_name).unwrap();
61        def.rule_mut().body = term;
62      }
63    }
64    self.update_refs(&name_map);
65  }
66
67  fn collect_terms(&mut self, def_entries: impl Iterator<Item = Name>) -> IndexMap<Term, IndexSet<Name>> {
68    let mut equal_terms: IndexMap<Term, IndexSet<Name>> = IndexMap::new();
69
70    for def_name in def_entries {
71      let def = self.defs.get_mut(&def_name).unwrap();
72      let term = std::mem::take(&mut def.rule_mut().body);
73      equal_terms.entry(term).or_default().insert(def_name);
74    }
75
76    equal_terms
77  }
78
79  fn update_refs(&mut self, name_map: &BTreeMap<Name, Name>) {
80    let mut updated_defs = Vec::new();
81
82    for def in self.defs.values_mut() {
83      if Term::subst_ref_to_ref(&mut def.rule_mut().body, name_map) {
84        updated_defs.push(def.name.clone());
85      }
86    }
87
88    if !updated_defs.is_empty() {
89      self.merge(updated_defs.into_iter());
90    }
91  }
92}
93
94impl Term {
95  /// Performs reference substitution within a term replacing any references found in
96  /// `ref_map` with their corresponding targets.
97  pub fn subst_ref_to_ref(term: &mut Term, ref_map: &BTreeMap<Name, Name>) -> bool {
98    maybe_grow(|| match term {
99      Term::Ref { nam: def_name } => {
100        if let Some(target_name) = ref_map.get(def_name) {
101          *def_name = target_name.clone();
102          true
103        } else {
104          false
105        }
106      }
107
108      _ => {
109        let mut subst = false;
110        for child in term.children_mut() {
111          subst |= Term::subst_ref_to_ref(child, ref_map);
112        }
113        subst
114      }
115    })
116  }
117}