miden_core/mast/merger/mod.rs
1use alloc::{collections::BTreeMap, vec::Vec};
2
3use miden_crypto::hash::blake::Blake3Digest;
4
5use crate::mast::{
6 BasicBlockNode, CallNode, DecoratorId, DynNode, ExternalNode, JoinNode, LoopNode, MastForest,
7 MastForestError, MastNode, MastNodeErrorContext, MastNodeFingerprint, MastNodeId,
8 MultiMastForestIteratorItem, MultiMastForestNodeIter, SplitNode, node::MastNodeExt,
9};
10
11#[cfg(test)]
12mod tests;
13
14/// A type that allows merging [`MastForest`]s.
15///
16/// This functionality is exposed via [`MastForest::merge`]. See its documentation for more details.
17pub(crate) struct MastForestMerger {
18 mast_forest: MastForest,
19 // Internal indices needed for efficient duplicate checking and MastNodeFingerprint
20 // computation.
21 //
22 // These are always in-sync with the nodes in `mast_forest`, i.e. all nodes added to the
23 // `mast_forest` are also added to the indices.
24 node_id_by_hash: BTreeMap<MastNodeFingerprint, MastNodeId>,
25 hash_by_node_id: BTreeMap<MastNodeId, MastNodeFingerprint>,
26 decorators_by_hash: BTreeMap<Blake3Digest<32>, DecoratorId>,
27 /// Mappings from old decorator and node ids to their new ids.
28 ///
29 /// Any decorator in `mast_forest` is present as the target of some mapping in this map.
30 decorator_id_mappings: Vec<DecoratorIdMap>,
31 /// Mappings from previous `MastNodeId`s to their new ids.
32 ///
33 /// Any `MastNodeId` in `mast_forest` is present as the target of some mapping in this map.
34 node_id_mappings: Vec<MastForestNodeIdMap>,
35}
36
37impl MastForestMerger {
38 /// Creates a new merger with an initially empty forest and merges all provided [`MastForest`]s
39 /// into it.
40 ///
41 /// # Normalizing Behavior
42 ///
43 /// This function performs normalization of the merged forest, which:
44 /// - Remaps all node IDs to maintain the invariant that child node IDs < parent node IDs
45 /// - Creates a clean, deduplicated forest structure
46 /// - Provides consistent node ordering regardless of input
47 ///
48 /// This normalization is idempotent, but it means that even for single-forest merges, the
49 /// resulting forest may have different node IDs and digests than the input. See assembly
50 /// test `issue_1644_single_forest_merge_identity` for detailed explanation of this
51 /// behavior.
52 pub(crate) fn merge<'forest>(
53 forests: impl IntoIterator<Item = &'forest MastForest>,
54 ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
55 let forests = forests.into_iter().collect::<Vec<_>>();
56
57 let decorator_id_mappings = Vec::with_capacity(forests.len());
58 let node_id_mappings = vec![MastForestNodeIdMap::new(); forests.len()];
59
60 let mut merger = Self {
61 node_id_by_hash: BTreeMap::new(),
62 hash_by_node_id: BTreeMap::new(),
63 decorators_by_hash: BTreeMap::new(),
64 mast_forest: MastForest::new(),
65 decorator_id_mappings,
66 node_id_mappings,
67 };
68
69 merger.merge_inner(forests.clone())?;
70
71 let Self { mast_forest, node_id_mappings, .. } = merger;
72
73 let root_maps = MastForestRootMap::from_node_id_map(node_id_mappings, forests);
74
75 Ok((mast_forest, root_maps))
76 }
77
78 /// Merges all `forests` into self.
79 ///
80 /// It does this in three steps:
81 ///
82 /// 1. Merge all advice maps, checking for key collisions.
83 /// 2. Merge all decorators, which is a case of deduplication and creating a decorator id
84 /// mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the
85 /// merged forest.
86 /// 3. Merge all nodes of forests.
87 /// - Similar to decorators, node indices might move during merging, so the merger keeps a
88 /// node id mapping as it merges nodes.
89 /// - This is a depth-first traversal over all forests to ensure all children are processed
90 /// before their parents. See the documentation of [`MultiMastForestNodeIter`] for details
91 /// on this traversal.
92 /// - Because all parents are processed after their children, we can use the node id mapping
93 /// to remap all [`MastNodeId`]s of the children to their potentially new id in the merged
94 /// forest.
95 /// - If any external node is encountered during this traversal with a digest `foo` for which
96 /// a `replacement` node exists in another forest with digest `foo`, then the external node
97 /// will be replaced by that node. In particular, it means we do not want to add the
98 /// external node to the merged forest, so it is never yielded from the iterator.
99 /// - Assuming the simple case, where the `replacement` was not visited yet and is just a
100 /// single node (not a tree), the iterator would first yield the `replacement` node which
101 /// means it is going to be merged into the forest.
102 /// - Next the iterator yields [`MultiMastForestIteratorItem::ExternalNodeReplacement`]
103 /// which signals that an external node was replaced by another node. In this example,
104 /// the `replacement_*` indices contained in that variant would point to the
105 /// `replacement` node. Now we can simply add a mapping from the external node to the
106 /// `replacement` node in our node id mapping which means all nodes that referenced the
107 /// external node will point to the `replacement` instead.
108 /// 4. Finally, we merge all roots of all forests. Here we map the existing root indices to
109 /// their potentially new indices in the merged forest and add them to the forest,
110 /// deduplicating in the process, too.
111 fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> {
112 for other_forest in forests.iter() {
113 self.merge_advice_map(other_forest)?;
114 }
115 for other_forest in forests.iter() {
116 self.merge_decorators(other_forest)?;
117 }
118 for other_forest in forests.iter() {
119 self.merge_error_codes(other_forest)?;
120 }
121
122 let iterator = MultiMastForestNodeIter::new(forests.clone());
123 for item in iterator {
124 match item {
125 MultiMastForestIteratorItem::Node { forest_idx, node_id } => {
126 let node = &forests[forest_idx][node_id];
127 self.merge_node(forest_idx, node_id, node)?;
128 },
129 MultiMastForestIteratorItem::ExternalNodeReplacement {
130 // forest index of the node which replaces the external node
131 replacement_forest_idx,
132 // ID of the node that replaces the external node
133 replacement_mast_node_id,
134 // forest index of the external node
135 replaced_forest_idx,
136 // ID of the external node
137 replaced_mast_node_id,
138 } => {
139 // The iterator is not aware of the merged forest, so the node indices it yields
140 // are for the existing forests. That means we have to map the ID of the
141 // replacement to its new location, since it was previously merged and its IDs
142 // have very likely changed.
143 let mapped_replacement = self.node_id_mappings[replacement_forest_idx]
144 .get(&replacement_mast_node_id)
145 .copied()
146 .expect("every merged node id should be mapped");
147
148 // SAFETY: The iterator only yields valid forest indices, so it is safe to index
149 // directly.
150 self.node_id_mappings[replaced_forest_idx]
151 .insert(replaced_mast_node_id, mapped_replacement);
152 },
153 }
154 }
155
156 for (forest_idx, forest) in forests.iter().enumerate() {
157 self.merge_roots(forest_idx, forest)?;
158 }
159
160 Ok(())
161 }
162
163 fn merge_decorators(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
164 let mut decorator_id_remapping = DecoratorIdMap::new(other_forest.decorators.len());
165
166 for (merging_id, merging_decorator) in other_forest.decorators.iter().enumerate() {
167 let merging_decorator_hash = merging_decorator.fingerprint();
168 let new_decorator_id = if let Some(existing_decorator) =
169 self.decorators_by_hash.get(&merging_decorator_hash)
170 {
171 *existing_decorator
172 } else {
173 let new_decorator_id = self.mast_forest.add_decorator(merging_decorator.clone())?;
174 self.decorators_by_hash.insert(merging_decorator_hash, new_decorator_id);
175 new_decorator_id
176 };
177
178 decorator_id_remapping
179 .insert(DecoratorId::new_unchecked(merging_id as u32), new_decorator_id);
180 }
181
182 self.decorator_id_mappings.push(decorator_id_remapping);
183
184 Ok(())
185 }
186
187 fn merge_advice_map(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
188 self.mast_forest
189 .advice_map
190 .merge(&other_forest.advice_map)
191 .map_err(|((key, _prev), _new)| MastForestError::AdviceMapKeyCollisionOnMerge(key))
192 }
193
194 fn merge_error_codes(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
195 self.mast_forest.error_codes.extend(other_forest.error_codes.clone());
196 Ok(())
197 }
198
199 fn merge_node(
200 &mut self,
201 forest_idx: usize,
202 merging_id: MastNodeId,
203 node: &MastNode,
204 ) -> Result<(), MastForestError> {
205 // We need to remap the node prior to computing the MastNodeFingerprint.
206 //
207 // This is because the MastNodeFingerprint computation looks up its descendants and
208 // decorators in the internal index, and if we were to pass the original node to
209 // that computation, it would look up the incorrect descendants and decorators
210 // (since the descendant's indices may have changed).
211 //
212 // Remapping at this point is guaranteed to be "complete", meaning all ids of children
213 // will be present in the node id mapping since the DFS iteration guarantees
214 // that all children of this `node` have been processed before this node and
215 // their indices have been added to the mappings.
216 let remapped_node = self.remap_node(forest_idx, node)?;
217
218 let node_fingerprint = MastNodeFingerprint::from_mast_node(
219 &self.mast_forest,
220 &self.hash_by_node_id,
221 &remapped_node,
222 )
223 .expect(
224 "hash_by_node_id should contain the fingerprints of all children of `remapped_node`",
225 );
226
227 match self.lookup_node_by_fingerprint(&node_fingerprint) {
228 Some(matching_node_id) => {
229 // If a node with a matching fingerprint exists, then the merging node is a
230 // duplicate and we remap it to the existing node.
231 self.node_id_mappings[forest_idx].insert(merging_id, matching_node_id);
232 },
233 None => {
234 // If no node with a matching fingerprint exists, then the merging node is
235 // unique and we can add it to the merged forest.
236 let new_node_id = self.mast_forest.add_node(remapped_node)?;
237 self.node_id_mappings[forest_idx].insert(merging_id, new_node_id);
238
239 // We need to update the indices with the newly inserted nodes
240 // since the MastNodeFingerprint computation requires all descendants of a node
241 // to be in this index. Hence when we encounter a node in the merging forest
242 // which has descendants (Call, Loop, Split, ...), then their descendants need to be
243 // in the indices.
244 self.node_id_by_hash.insert(node_fingerprint, new_node_id);
245 self.hash_by_node_id.insert(new_node_id, node_fingerprint);
246 },
247 }
248
249 Ok(())
250 }
251
252 fn merge_roots(
253 &mut self,
254 forest_idx: usize,
255 other_forest: &MastForest,
256 ) -> Result<(), MastForestError> {
257 for root_id in other_forest.roots.iter() {
258 // Map the previous root to its possibly new id.
259 let new_root = self.node_id_mappings[forest_idx]
260 .get(root_id)
261 .expect("all node ids should have an entry");
262 // This takes O(n) where n is the number of roots in the merged forest every time to
263 // check if the root already exists. As the number of roots is relatively low generally,
264 // this should be okay.
265 self.mast_forest.make_root(*new_root);
266 }
267
268 Ok(())
269 }
270
271 /// Remaps a nodes' potentially contained children and decorators to their new IDs according to
272 /// the given maps.
273 fn remap_node(&self, forest_idx: usize, node: &MastNode) -> Result<MastNode, MastForestError> {
274 let map_decorator_id = |decorator_id: &DecoratorId| {
275 self.decorator_id_mappings[forest_idx].get(decorator_id).ok_or_else(|| {
276 MastForestError::DecoratorIdOverflow(
277 *decorator_id,
278 self.decorator_id_mappings[forest_idx].len(),
279 )
280 })
281 };
282 let map_decorators = |decorators: &[DecoratorId]| -> Result<Vec<_>, MastForestError> {
283 decorators.iter().map(map_decorator_id).collect()
284 };
285
286 let map_node_id = |node_id: MastNodeId| {
287 self.node_id_mappings[forest_idx]
288 .get(&node_id)
289 .copied()
290 .expect("every node id should have an entry")
291 };
292
293 // Due to DFS postorder iteration all children of node's should have been inserted before
294 // their parents which is why we can `expect` the constructor calls here.
295 let mut mapped_node: MastNode = match node {
296 MastNode::Join(join_node) => {
297 let first = map_node_id(join_node.first());
298 let second = map_node_id(join_node.second());
299
300 JoinNode::new([first, second], &self.mast_forest)
301 .expect("JoinNode children should have been mapped to a lower index")
302 .into()
303 },
304 MastNode::Split(split_node) => {
305 let if_branch = map_node_id(split_node.on_true());
306 let else_branch = map_node_id(split_node.on_false());
307
308 SplitNode::new([if_branch, else_branch], &self.mast_forest)
309 .expect("SplitNode children should have been mapped to a lower index")
310 .into()
311 },
312 MastNode::Loop(loop_node) => {
313 let body = map_node_id(loop_node.body());
314 LoopNode::new(body, &self.mast_forest)
315 .expect("LoopNode children should have been mapped to a lower index")
316 .into()
317 },
318 MastNode::Call(call_node) => {
319 let callee = map_node_id(call_node.callee());
320 CallNode::new(callee, &self.mast_forest)
321 .expect("CallNode children should have been mapped to a lower index")
322 .into()
323 },
324 // Other nodes are simply copied.
325 MastNode::Block(basic_block_node) => {
326 BasicBlockNode::new(
327 basic_block_node.operations().copied().collect(),
328 // Operation Indices of decorators stay the same while decorator IDs need to be
329 // mapped.
330 Some(
331 basic_block_node
332 .decorators()
333 .map(|(idx, decorator_id)| match map_decorator_id(&decorator_id) {
334 Ok(mapped_decorator) => Ok((idx, mapped_decorator)),
335 Err(err) => Err(err),
336 })
337 .collect::<Result<Vec<_>, _>>()?,
338 ),
339 )
340 .expect("previously valid BasicBlockNode should still be valid")
341 .into()
342 },
343 MastNode::Dyn(_) => DynNode::new_dyn().into(),
344 MastNode::External(external_node) => ExternalNode::new(external_node.digest()).into(),
345 };
346
347 // Decorators must be handled specially for basic block nodes.
348 // For other node types we can handle it centrally.
349 if !mapped_node.is_basic_block() {
350 mapped_node.append_before_enter(&map_decorators(node.before_enter())?);
351 mapped_node.append_after_exit(&map_decorators(node.after_exit())?);
352 }
353
354 Ok(mapped_node)
355 }
356
357 // HELPERS
358 // ================================================================================================
359
360 /// Returns the ID of the node in the merged forest that matches the given
361 /// fingerprint, if any.
362 fn lookup_node_by_fingerprint(&self, fingerprint: &MastNodeFingerprint) -> Option<MastNodeId> {
363 self.node_id_by_hash.get(fingerprint).copied()
364 }
365}
366
367// MAST FOREST ROOT MAP
368// ================================================================================================
369
370/// A mapping for the new location of the roots of a [`MastForest`] after a merge.
371///
372/// It maps the roots ([`MastNodeId`]s) of a forest to their new [`MastNodeId`] in the merged
373/// forest. See [`MastForest::merge`] for more details.
374#[derive(Debug, Clone, PartialEq, Eq)]
375pub struct MastForestRootMap {
376 root_maps: Vec<BTreeMap<MastNodeId, MastNodeId>>,
377}
378
379impl MastForestRootMap {
380 fn from_node_id_map(id_map: Vec<MastForestNodeIdMap>, forests: Vec<&MastForest>) -> Self {
381 let mut root_maps = vec![BTreeMap::new(); forests.len()];
382
383 for (forest_idx, forest) in forests.into_iter().enumerate() {
384 for root in forest.procedure_roots() {
385 let new_id = id_map[forest_idx]
386 .get(root)
387 .copied()
388 .expect("every node id should be mapped to its new id");
389 root_maps[forest_idx].insert(*root, new_id);
390 }
391 }
392
393 Self { root_maps }
394 }
395
396 /// Maps the given root to its new location in the merged forest, if such a mapping exists.
397 ///
398 /// It is guaranteed that every root of the map's corresponding forest is contained in the map.
399 pub fn map_root(&self, forest_index: usize, root: &MastNodeId) -> Option<MastNodeId> {
400 self.root_maps.get(forest_index).and_then(|map| map.get(root)).copied()
401 }
402}
403
404// DECORATOR ID MAP
405// ================================================================================================
406
407/// A specialized map from [`DecoratorId`] -> [`DecoratorId`].
408///
409/// When mapping Decorator IDs during merging, we always map all IDs of the merging
410/// forest to new ids. Hence it is more efficient to use a `Vec` instead of, say, a `BTreeMap`.
411///
412/// In other words, this type is similar to `BTreeMap<ID, ID>` but takes advantage of the fact that
413/// the keys are contiguous.
414///
415/// This type is meant to encapsulates some guarantees:
416///
417/// - Indexing into the vector for any ID is safe if that ID is valid for the corresponding forest.
418/// Despite that, we still cannot index unconditionally in case a node with invalid
419/// [`DecoratorId`]s is passed to `merge`.
420/// - The entry itself can be either None or Some. However:
421/// - For `DecoratorId`s we iterate and insert all decorators into this map before retrieving any
422/// entry, so all entries contain `Some`. Because of this, we can use `expect` in `get` for the
423/// `Option` value.
424/// - Similarly, inserting any ID from the corresponding forest is safe as the map contains a
425/// pre-allocated `Vec` of the appropriate size.
426struct DecoratorIdMap {
427 inner: Vec<Option<DecoratorId>>,
428}
429
430impl DecoratorIdMap {
431 fn new(num_ids: usize) -> Self {
432 Self { inner: vec![None; num_ids] }
433 }
434
435 /// Maps the given key to the given value.
436 ///
437 /// It is the caller's responsibility to only pass keys that belong to the forest for which this
438 /// map was originally created.
439 fn insert(&mut self, key: DecoratorId, value: DecoratorId) {
440 self.inner[key.as_usize()] = Some(value);
441 }
442
443 /// Retrieves the value for the given key.
444 fn get(&self, key: &DecoratorId) -> Option<DecoratorId> {
445 self.inner
446 .get(key.as_usize())
447 .map(|id| id.expect("every id should have a Some entry in the map when calling get"))
448 }
449
450 fn len(&self) -> usize {
451 self.inner.len()
452 }
453}
454
455/// A type definition for increased readability in function signatures.
456type MastForestNodeIdMap = BTreeMap<MastNodeId, MastNodeId>;