miden_core/mast/merger/mod.rs
1use alloc::{collections::BTreeMap, vec::Vec};
2
3use miden_crypto::hash::blake::Blake3Digest;
4
5use crate::mast::{
6 DecoratorId, MastForest, MastForestError, MastNode, MastNodeFingerprint, MastNodeId,
7 MultiMastForestIteratorItem, MultiMastForestNodeIter,
8};
9
10#[cfg(test)]
11mod tests;
12
13/// A type that allows merging [`MastForest`]s.
14///
15/// This functionality is exposed via [`MastForest::merge`]. See its documentation for more details.
16pub(crate) struct MastForestMerger {
17 mast_forest: MastForest,
18 // Internal indices needed for efficient duplicate checking and MastNodeFingerprint
19 // computation.
20 //
21 // These are always in-sync with the nodes in `mast_forest`, i.e. all nodes added to the
22 // `mast_forest` are also added to the indices.
23 node_id_by_hash: BTreeMap<MastNodeFingerprint, MastNodeId>,
24 hash_by_node_id: BTreeMap<MastNodeId, MastNodeFingerprint>,
25 decorators_by_hash: BTreeMap<Blake3Digest<32>, DecoratorId>,
26 /// Mappings from old decorator and node ids to their new ids.
27 ///
28 /// Any decorator in `mast_forest` is present as the target of some mapping in this map.
29 decorator_id_mappings: Vec<DecoratorIdMap>,
30 /// Mappings from previous `MastNodeId`s to their new ids.
31 ///
32 /// Any `MastNodeId` in `mast_forest` is present as the target of some mapping in this map.
33 node_id_mappings: Vec<MastForestNodeIdMap>,
34}
35
36impl MastForestMerger {
37 /// Creates a new merger with an initially empty forest and merges all provided [`MastForest`]s
38 /// into it.
39 pub(crate) fn merge<'forest>(
40 forests: impl IntoIterator<Item = &'forest MastForest>,
41 ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
42 let forests = forests.into_iter().collect::<Vec<_>>();
43 let decorator_id_mappings = Vec::with_capacity(forests.len());
44 let node_id_mappings = vec![MastForestNodeIdMap::new(); forests.len()];
45
46 let mut merger = Self {
47 node_id_by_hash: BTreeMap::new(),
48 hash_by_node_id: BTreeMap::new(),
49 decorators_by_hash: BTreeMap::new(),
50 mast_forest: MastForest::new(),
51 decorator_id_mappings,
52 node_id_mappings,
53 };
54
55 merger.merge_inner(forests.clone())?;
56
57 let Self { mast_forest, node_id_mappings, .. } = merger;
58
59 let root_maps = MastForestRootMap::from_node_id_map(node_id_mappings, forests);
60
61 Ok((mast_forest, root_maps))
62 }
63
64 /// Merges all `forests` into self.
65 ///
66 /// It does this in three steps:
67 ///
68 /// 1. Merge all advice maps, checking for key collisions.
69 /// 2. Merge all decorators, which is a case of deduplication and creating a decorator id
70 /// mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the
71 /// merged forest.
72 /// 3. Merge all nodes of forests.
73 /// - Similar to decorators, node indices might move during merging, so the merger keeps a
74 /// node id mapping as it merges nodes.
75 /// - This is a depth-first traversal over all forests to ensure all children are processed
76 /// before their parents. See the documentation of [`MultiMastForestNodeIter`] for details
77 /// on this traversal.
78 /// - Because all parents are processed after their children, we can use the node id mapping
79 /// to remap all [`MastNodeId`]s of the children to their potentially new id in the merged
80 /// forest.
81 /// - If any external node is encountered during this traversal with a digest `foo` for which
82 /// a `replacement` node exists in another forest with digest `foo`, then the external node
83 /// will be replaced by that node. In particular, it means we do not want to add the
84 /// external node to the merged forest, so it is never yielded from the iterator.
85 /// - Assuming the simple case, where the `replacement` was not visited yet and is just a
86 /// single node (not a tree), the iterator would first yield the `replacement` node which
87 /// means it is going to be merged into the forest.
88 /// - Next the iterator yields [`MultiMastForestIteratorItem::ExternalNodeReplacement`]
89 /// which signals that an external node was replaced by another node. In this example,
90 /// the `replacement_*` indices contained in that variant would point to the
91 /// `replacement` node. Now we can simply add a mapping from the external node to the
92 /// `replacement` node in our node id mapping which means all nodes that referenced the
93 /// external node will point to the `replacement` instead.
94 /// 4. Finally, we merge all roots of all forests. Here we map the existing root indices to
95 /// their potentially new indices in the merged forest and add them to the forest,
96 /// deduplicating in the process, too.
97 fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> {
98 for other_forest in forests.iter() {
99 self.merge_advice_map(other_forest)?;
100 }
101 for other_forest in forests.iter() {
102 self.merge_decorators(other_forest)?;
103 }
104 for other_forest in forests.iter() {
105 self.merge_error_codes(other_forest)?;
106 }
107
108 let iterator = MultiMastForestNodeIter::new(forests.clone());
109 for item in iterator {
110 match item {
111 MultiMastForestIteratorItem::Node { forest_idx, node_id } => {
112 let node = &forests[forest_idx][node_id];
113 self.merge_node(forest_idx, node_id, node)?;
114 },
115 MultiMastForestIteratorItem::ExternalNodeReplacement {
116 // forest index of the node which replaces the external node
117 replacement_forest_idx,
118 // ID of the node that replaces the external node
119 replacement_mast_node_id,
120 // forest index of the external node
121 replaced_forest_idx,
122 // ID of the external node
123 replaced_mast_node_id,
124 } => {
125 // The iterator is not aware of the merged forest, so the node indices it yields
126 // are for the existing forests. That means we have to map the ID of the
127 // replacement to its new location, since it was previously merged and its IDs
128 // have very likely changed.
129 let mapped_replacement = self.node_id_mappings[replacement_forest_idx]
130 .get(&replacement_mast_node_id)
131 .copied()
132 .expect("every merged node id should be mapped");
133
134 // SAFETY: The iterator only yields valid forest indices, so it is safe to index
135 // directly.
136 self.node_id_mappings[replaced_forest_idx]
137 .insert(replaced_mast_node_id, mapped_replacement);
138 },
139 }
140 }
141
142 for (forest_idx, forest) in forests.iter().enumerate() {
143 self.merge_roots(forest_idx, forest)?;
144 }
145
146 Ok(())
147 }
148
149 fn merge_decorators(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
150 let mut decorator_id_remapping = DecoratorIdMap::new(other_forest.decorators.len());
151
152 for (merging_id, merging_decorator) in other_forest.decorators.iter().enumerate() {
153 let merging_decorator_hash = merging_decorator.fingerprint();
154 let new_decorator_id = if let Some(existing_decorator) =
155 self.decorators_by_hash.get(&merging_decorator_hash)
156 {
157 *existing_decorator
158 } else {
159 let new_decorator_id = self.mast_forest.add_decorator(merging_decorator.clone())?;
160 self.decorators_by_hash.insert(merging_decorator_hash, new_decorator_id);
161 new_decorator_id
162 };
163
164 decorator_id_remapping
165 .insert(DecoratorId::new_unchecked(merging_id as u32), new_decorator_id);
166 }
167
168 self.decorator_id_mappings.push(decorator_id_remapping);
169
170 Ok(())
171 }
172
173 fn merge_advice_map(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
174 self.mast_forest
175 .advice_map
176 .merge(&other_forest.advice_map)
177 .map_err(|((key, _prev), _new)| MastForestError::AdviceMapKeyCollisionOnMerge(key))
178 }
179
180 fn merge_error_codes(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
181 self.mast_forest.error_codes.extend(other_forest.error_codes.clone());
182 Ok(())
183 }
184
185 fn merge_node(
186 &mut self,
187 forest_idx: usize,
188 merging_id: MastNodeId,
189 node: &MastNode,
190 ) -> Result<(), MastForestError> {
191 // We need to remap the node prior to computing the MastNodeFingerprint.
192 //
193 // This is because the MastNodeFingerprint computation looks up its descendants and
194 // decorators in the internal index, and if we were to pass the original node to
195 // that computation, it would look up the incorrect descendants and decorators
196 // (since the descendant's indices may have changed).
197 //
198 // Remapping at this point is guaranteed to be "complete", meaning all ids of children
199 // will be present in the node id mapping since the DFS iteration guarantees
200 // that all children of this `node` have been processed before this node and
201 // their indices have been added to the mappings.
202 let remapped_node = self.remap_node(forest_idx, node)?;
203
204 let node_fingerprint = MastNodeFingerprint::from_mast_node(
205 &self.mast_forest,
206 &self.hash_by_node_id,
207 &remapped_node,
208 )
209 .expect(
210 "hash_by_node_id should contain the fingerprints of all children of `remapped_node`",
211 );
212
213 match self.lookup_node_by_fingerprint(&node_fingerprint) {
214 Some(matching_node_id) => {
215 // If a node with a matching fingerprint exists, then the merging node is a
216 // duplicate and we remap it to the existing node.
217 self.node_id_mappings[forest_idx].insert(merging_id, matching_node_id);
218 },
219 None => {
220 // If no node with a matching fingerprint exists, then the merging node is
221 // unique and we can add it to the merged forest.
222 let new_node_id = self.mast_forest.add_node(remapped_node)?;
223 self.node_id_mappings[forest_idx].insert(merging_id, new_node_id);
224
225 // We need to update the indices with the newly inserted nodes
226 // since the MastNodeFingerprint computation requires all descendants of a node
227 // to be in this index. Hence when we encounter a node in the merging forest
228 // which has descendants (Call, Loop, Split, ...), then their descendants need to be
229 // in the indices.
230 self.node_id_by_hash.insert(node_fingerprint, new_node_id);
231 self.hash_by_node_id.insert(new_node_id, node_fingerprint);
232 },
233 }
234
235 Ok(())
236 }
237
238 fn merge_roots(
239 &mut self,
240 forest_idx: usize,
241 other_forest: &MastForest,
242 ) -> Result<(), MastForestError> {
243 for root_id in other_forest.roots.iter() {
244 // Map the previous root to its possibly new id.
245 let new_root = self.node_id_mappings[forest_idx]
246 .get(root_id)
247 .expect("all node ids should have an entry");
248 // This takes O(n) where n is the number of roots in the merged forest every time to
249 // check if the root already exists. As the number of roots is relatively low generally,
250 // this should be okay.
251 self.mast_forest.make_root(*new_root);
252 }
253
254 Ok(())
255 }
256
257 /// Remaps a nodes' potentially contained children and decorators to their new IDs according to
258 /// the given maps.
259 fn remap_node(&self, forest_idx: usize, node: &MastNode) -> Result<MastNode, MastForestError> {
260 let map_decorator_id = |decorator_id: &DecoratorId| {
261 self.decorator_id_mappings[forest_idx].get(decorator_id).ok_or_else(|| {
262 MastForestError::DecoratorIdOverflow(
263 *decorator_id,
264 self.decorator_id_mappings[forest_idx].len(),
265 )
266 })
267 };
268 let map_decorators = |decorators: &[DecoratorId]| -> Result<Vec<_>, MastForestError> {
269 decorators.iter().map(map_decorator_id).collect()
270 };
271
272 let map_node_id = |node_id: MastNodeId| {
273 self.node_id_mappings[forest_idx]
274 .get(&node_id)
275 .copied()
276 .expect("every node id should have an entry")
277 };
278
279 // Due to DFS postorder iteration all children of node's should have been inserted before
280 // their parents which is why we can `expect` the constructor calls here.
281 let mut mapped_node = match node {
282 MastNode::Join(join_node) => {
283 let first = map_node_id(join_node.first());
284 let second = map_node_id(join_node.second());
285
286 MastNode::new_join(first, second, &self.mast_forest)
287 .expect("JoinNode children should have been mapped to a lower index")
288 },
289 MastNode::Split(split_node) => {
290 let if_branch = map_node_id(split_node.on_true());
291 let else_branch = map_node_id(split_node.on_false());
292
293 MastNode::new_split(if_branch, else_branch, &self.mast_forest)
294 .expect("SplitNode children should have been mapped to a lower index")
295 },
296 MastNode::Loop(loop_node) => {
297 let body = map_node_id(loop_node.body());
298 MastNode::new_loop(body, &self.mast_forest)
299 .expect("LoopNode children should have been mapped to a lower index")
300 },
301 MastNode::Call(call_node) => {
302 let callee = map_node_id(call_node.callee());
303 MastNode::new_call(callee, &self.mast_forest)
304 .expect("CallNode children should have been mapped to a lower index")
305 },
306 // Other nodes are simply copied.
307 MastNode::Block(basic_block_node) => {
308 MastNode::new_basic_block(
309 basic_block_node.operations().copied().collect(),
310 // Operation Indices of decorators stay the same while decorator IDs need to be
311 // mapped.
312 Some(
313 basic_block_node
314 .decorators()
315 .iter()
316 .map(|(idx, decorator_id)| match map_decorator_id(decorator_id) {
317 Ok(mapped_decorator) => Ok((*idx, mapped_decorator)),
318 Err(err) => Err(err),
319 })
320 .collect::<Result<Vec<_>, _>>()?,
321 ),
322 )
323 .expect("previously valid BasicBlockNode should still be valid")
324 },
325 MastNode::Dyn(_) => MastNode::new_dyn(),
326 MastNode::External(external_node) => MastNode::new_external(external_node.digest()),
327 };
328
329 // Decorators must be handled specially for basic block nodes.
330 // For other node types we can handle it centrally.
331 if !mapped_node.is_basic_block() {
332 mapped_node.append_before_enter(&map_decorators(node.before_enter())?);
333 mapped_node.append_after_exit(&map_decorators(node.after_exit())?);
334 }
335
336 Ok(mapped_node)
337 }
338
339 // HELPERS
340 // ================================================================================================
341
342 /// Returns a slice of nodes in the merged forest which have the given `mast_root`.
343 fn lookup_node_by_fingerprint(&self, fingerprint: &MastNodeFingerprint) -> Option<MastNodeId> {
344 self.node_id_by_hash.get(fingerprint).copied()
345 }
346}
347
348// MAST FOREST ROOT MAP
349// ================================================================================================
350
351/// A mapping for the new location of the roots of a [`MastForest`] after a merge.
352///
353/// It maps the roots ([`MastNodeId`]s) of a forest to their new [`MastNodeId`] in the merged
354/// forest. See [`MastForest::merge`] for more details.
355#[derive(Debug, Clone, PartialEq, Eq)]
356pub struct MastForestRootMap {
357 root_maps: Vec<BTreeMap<MastNodeId, MastNodeId>>,
358}
359
360impl MastForestRootMap {
361 fn from_node_id_map(id_map: Vec<MastForestNodeIdMap>, forests: Vec<&MastForest>) -> Self {
362 let mut root_maps = vec![BTreeMap::new(); forests.len()];
363
364 for (forest_idx, forest) in forests.into_iter().enumerate() {
365 for root in forest.procedure_roots() {
366 let new_id = id_map[forest_idx]
367 .get(root)
368 .copied()
369 .expect("every node id should be mapped to its new id");
370 root_maps[forest_idx].insert(*root, new_id);
371 }
372 }
373
374 Self { root_maps }
375 }
376
377 /// Maps the given root to its new location in the merged forest, if such a mapping exists.
378 ///
379 /// It is guaranteed that every root of the map's corresponding forest is contained in the map.
380 pub fn map_root(&self, forest_index: usize, root: &MastNodeId) -> Option<MastNodeId> {
381 self.root_maps.get(forest_index).and_then(|map| map.get(root)).copied()
382 }
383}
384
385// DECORATOR ID MAP
386// ================================================================================================
387
388/// A specialized map from [`DecoratorId`] -> [`DecoratorId`].
389///
390/// When mapping Decorator IDs during merging, we always map all IDs of the merging
391/// forest to new ids. Hence it is more efficient to use a `Vec` instead of, say, a `BTreeMap`.
392///
393/// In other words, this type is similar to `BTreeMap<ID, ID>` but takes advantage of the fact that
394/// the keys are contiguous.
395///
396/// This type is meant to encapsulates some guarantees:
397///
398/// - Indexing into the vector for any ID is safe if that ID is valid for the corresponding forest.
399/// Despite that, we still cannot index unconditionally in case a node with invalid
400/// [`DecoratorId`]s is passed to `merge`.
401/// - The entry itself can be either None or Some. However:
402/// - For `DecoratorId`s we iterate and insert all decorators into this map before retrieving any
403/// entry, so all entries contain `Some`. Because of this, we can use `expect` in `get` for the
404/// `Option` value.
405/// - Similarly, inserting any ID from the corresponding forest is safe as the map contains a
406/// pre-allocated `Vec` of the appropriate size.
407struct DecoratorIdMap {
408 inner: Vec<Option<DecoratorId>>,
409}
410
411impl DecoratorIdMap {
412 fn new(num_ids: usize) -> Self {
413 Self { inner: vec![None; num_ids] }
414 }
415
416 /// Maps the given key to the given value.
417 ///
418 /// It is the caller's responsibility to only pass keys that belong to the forest for which this
419 /// map was originally created.
420 fn insert(&mut self, key: DecoratorId, value: DecoratorId) {
421 self.inner[key.as_usize()] = Some(value);
422 }
423
424 /// Retrieves the value for the given key.
425 fn get(&self, key: &DecoratorId) -> Option<DecoratorId> {
426 self.inner
427 .get(key.as_usize())
428 .map(|id| id.expect("every id should have a Some entry in the map when calling get"))
429 }
430
431 fn len(&self) -> usize {
432 self.inner.len()
433 }
434}
435
436/// A type definition for increased readability in function signatures.
437type MastForestNodeIdMap = BTreeMap<MastNodeId, MastNodeId>;