egraph_serialize/
algorithms.rs

1use std::collections::HashMap;
2
3use crate::{Class, ClassId, EGraph, Node, NodeId};
4
/// Placeholder written into a node's op string for an argument that has not been inlined,
/// e.g. the `·` in `f(10, ·)`. A later inlining pass substitutes these placeholders in order.
pub const MISSING_ARG_VALUE: &str = "·";
6
7impl EGraph {
8    /// Inline all leaves (e-classes with a single node that has no children) into their parents, so that they
9    /// are added to the function name like f(10, ·).
10    /// Returns the number of leaves inlined.
11    pub fn inline_leaves(&mut self) -> usize {
12        // 1. Create mapping of eclass to nodes as well as nodes to their parents
13        let mut eclass_to_nodes = std::collections::HashMap::new();
14        let mut node_to_parents = std::collections::HashMap::new();
15        for (node_id, node) in &self.nodes {
16            eclass_to_nodes
17                .entry(node.eclass.clone())
18                .or_insert_with(Vec::new)
19                .push((node_id.clone(), node));
20            for child in &node.children {
21                node_to_parents
22                    .entry(child.clone())
23                    .or_insert_with(Vec::new)
24                    .push(node_id.clone());
25            }
26        }
27        // 2. Find all leaves (e-classes with a single node that has no children and also not in root-eclasses)
28        let mut leaves = Vec::new();
29        let mut leave_to_op = std::collections::HashMap::new();
30        for (eclass, nodes) in eclass_to_nodes {
31            if nodes.len() == 1 && nodes[0].1.children.is_empty() {
32                leaves.push((eclass, nodes[0].0.clone()));
33                leave_to_op.insert(nodes[0].0.clone(), nodes[0].1.op.clone());
34            }
35        }
36        // 3. Create mapping from all parents which are updated to the children which are inlined
37        let mut parents_to_children = std::collections::HashMap::new();
38        for (_, node_id) in &leaves {
39            let parents = node_to_parents.get(node_id);
40            // There will be no parents for isolated nodes with no parents or children
41            if let Some(parents) = parents {
42                for parent in parents {
43                    parents_to_children
44                        .entry(parent.clone())
45                        .or_insert_with(Vec::new)
46                        .push(node_id.clone());
47                }
48            }
49        }
50        // 4. Inline leaf nodes into their parents
51        for (parent, leaf_children) in &parents_to_children {
52            let additional_cost = leaf_children
53                .iter()
54                .map(|child| self.nodes.get(child).unwrap().cost)
55                .sum::<ordered_float::NotNan<f64>>();
56            let parent_node = self.nodes.get_mut(parent).unwrap();
57            let args = parent_node
58                .children
59                .iter()
60                .map(|child| {
61                    if leaf_children.contains(child) {
62                        leave_to_op.get(child).unwrap()
63                    } else {
64                        MISSING_ARG_VALUE
65                    }
66                })
67                .collect::<Vec<_>>();
68            // Remove leaf children from children
69            parent_node
70                .children
71                .retain(|child| !leaf_children.contains(child));
72            // If the parent node already had some children replaced, then just replace the remaining children
73            // otherwise, replace the entire op
74            let new_op = if parent_node.op.matches(MISSING_ARG_VALUE).count() == args.len() {
75                // Replace all instances of MISSING_ARG_VALUE with the corresponding arg by interleaving
76                // the op split by MISSING_ARG_VALUE with the args
77                parent_node
78                    .op
79                    .split(MISSING_ARG_VALUE)
80                    .enumerate()
81                    .flat_map(|(i, s)| {
82                        if i == args.len() {
83                            vec![s.to_string()]
84                        } else {
85                            vec![s.to_string(), args[i].to_string()]
86                        }
87                    })
88                    .collect::<String>()
89            } else {
90                format!("{}({})", parent_node.op, args.join(", "))
91            };
92            parent_node.op = new_op;
93            parent_node.cost += additional_cost;
94        }
95        let mut n_inlined = 0;
96        // 5. Remove leaf nodes from egraph, class data, and root eclasses
97        for (eclass, node_id) in &leaves {
98            // If this node has no parents, don't remove it, since it wasn't inlined
99            if !node_to_parents.contains_key(node_id) {
100                continue;
101            }
102            n_inlined += 1;
103            self.nodes.swap_remove(node_id);
104            self.class_data.swap_remove(eclass);
105            self.root_eclasses.retain(|root| root != eclass);
106        }
107        n_inlined
108    }
109
110    /// Inline all leaves (e-classes with a single node that has no children) into their parents, recursively.
111    pub fn saturate_inline_leaves(&mut self) {
112        while self.inline_leaves() > 0 {}
113    }
114
115    /// Given some function `should_split`, after calling this method, all nodes where it is true will have at most
116    /// one other node in their e-class and if they have parents, will no other nodes in their e-class.
117    ///
118    /// All nodes in an e-class shared with a split node will still be in a e-class with that node, but no longer
119    /// in an e-class with each other.
120    ///
121    /// This is used to help make the visualization easier to parse, by breaking cycles and allowing these split nodes
122    /// to be later inlined into their parents with `inline_leaves`, which only applies when there is a single node in an
123    /// e-class.
124    ///
125    /// For example, if we split on all "primitive" nodes or their wrappers, like Int(1), then those nodes
126    /// can then be inlined into their parents and are all nodes equal to them are no longer in single e-class,
127    /// making the graph layout easier to understand.
128    ///
129    /// Note that `should_split` should only ever be true for either a single node, or no nodes, in an e-class.
130    /// If it is true for multiple nodes in an e-class, then this method will panic.
131    ///
132    /// Another way to think about it is that any isomporphic function can be split, since if f(a) = f(b) then a = b,
133    /// in that case.
134    pub fn split_classes(&mut self, should_split: impl Fn(&NodeId, &Node) -> bool) {
135        // run till fixpoint since splitting a node might add more parents and require splitting the child down the line
136        let mut changed = true;
137        while changed {
138            changed = false;
139            // Mapping from class ID to all nodes that point to any node in that e-class
140            let parents: HashMap<ClassId, Vec<(NodeId, usize)>> =
141                self.nodes
142                    .iter()
143                    .fold(HashMap::new(), |mut parents, (node_id, node)| {
144                        for (position, child) in node.children.iter().enumerate() {
145                            let child_class = self.nodes[child].eclass.clone();
146                            parents
147                                .entry(child_class)
148                                .or_default()
149                                .push((node_id.clone(), position));
150                        }
151                        parents
152                    });
153
154            for Class { id, nodes } in self.classes().clone().into_values() {
155                let (unique_nodes, other_nodes): (Vec<_>, Vec<_>) = nodes
156                    .into_iter()
157                    .partition(|node_id| should_split(node_id, &self.nodes[node_id]));
158                if unique_nodes.len() > 1 {
159                    panic!(
160                        "Multiple nodes in one e-class should be split. E-class: {:?} Nodes: {:?}",
161                        id, unique_nodes
162                    );
163                }
164                let unique_node = unique_nodes.into_iter().next();
165                let class_data = self.class_data.get(&id).cloned();
166                if let Some(unique_node_id) = unique_node {
167                    let unique_node = self.nodes[&unique_node_id].clone();
168                    let n_other_nodes = other_nodes.len();
169                    let mut offset = 0;
170                    if n_other_nodes == 0 {
171                        continue;
172                    }
173                    // split out other nodes if there are multiple of them.
174                    // Leave one node in this e-class and make new e-classes for remaining nodes
175                    for other_node_id in other_nodes.into_iter().skip(1) {
176                        changed = true;
177                        // use same ID for new class and new node added to that class
178                        let new_id = format!("split-{}-{}", offset, unique_node_id);
179                        offset += 1;
180                        let new_class_id: ClassId = new_id.clone().into();
181                        // Copy the class data if it exists
182                        if let Some(class_data) = &class_data {
183                            self.class_data
184                                .insert(new_class_id.clone(), class_data.clone());
185                        }
186                        // Change the e-class of the other node
187                        self.nodes[&other_node_id].eclass = new_class_id.clone();
188                        // Create a new unique node with the same data
189                        let mut new_unique_node = unique_node.clone();
190                        new_unique_node.eclass = new_class_id;
191                        self.nodes.insert(new_id.into(), new_unique_node);
192                    }
193                    // If there are other nodes, then make one more copy and point all the parents at that
194                    let parents = parents.get(&id).cloned().unwrap_or_default();
195                    if parents.is_empty() {
196                        continue;
197                    }
198                    changed = true;
199                    let new_id = format!("split-{}-{}", offset, unique_node_id);
200                    let new_class_id: ClassId = new_id.clone().into();
201                    // Copy the class data if it exists
202                    if let Some(class_data) = &class_data {
203                        self.class_data
204                            .insert(new_class_id.clone(), class_data.clone());
205                    }
206                    // Create a new unique node with the same data
207                    let mut new_unique_node = unique_node.clone();
208                    new_unique_node.eclass = new_class_id;
209                    self.nodes.insert(new_id.clone().into(), new_unique_node);
210                    for (parent_id, position) in parents {
211                        // Change the child of the parent to the new node
212                        self.nodes.get_mut(&parent_id).unwrap().children[position] =
213                            new_id.clone().into();
214                    }
215                }
216            }
217            // reset the classes computation
218            self.once_cell_classes.take();
219        }
220    }
221}