miden_core/mast/merger/
mod.rs

1use alloc::{collections::BTreeMap, vec::Vec};
2
3use miden_crypto::hash::blake::Blake3Digest;
4
5use crate::mast::{
6    BasicBlockNode, CallNode, DecoratorId, DynNode, ExternalNode, JoinNode, LoopNode, MastForest,
7    MastForestError, MastNode, MastNodeErrorContext, MastNodeFingerprint, MastNodeId,
8    MultiMastForestIteratorItem, MultiMastForestNodeIter, SplitNode, node::MastNodeExt,
9};
10
11#[cfg(test)]
12mod tests;
13
/// A type that allows merging [`MastForest`]s.
///
/// This functionality is exposed via [`MastForest::merge`]. See its documentation for more details.
pub(crate) struct MastForestMerger {
    /// The merged forest under construction; decorators, nodes and roots of the input forests
    /// are added to it as they are merged.
    mast_forest: MastForest,
    // Internal indices needed for efficient duplicate checking and MastNodeFingerprint
    // computation.
    //
    // These are always in-sync with the nodes in `mast_forest`, i.e. all nodes added to the
    // `mast_forest` are also added to the indices.
    node_id_by_hash: BTreeMap<MastNodeFingerprint, MastNodeId>,
    hash_by_node_id: BTreeMap<MastNodeId, MastNodeFingerprint>,
    // Index from a decorator's fingerprint to its id in `mast_forest`, used to deduplicate
    // decorators while merging.
    decorators_by_hash: BTreeMap<Blake3Digest<32>, DecoratorId>,
    /// Mappings from previous [`DecoratorId`]s to their new ids, one map per input forest (indexed
    /// by the forest's position in the input).
    ///
    /// Any decorator in `mast_forest` is present as the target of some mapping in this map.
    decorator_id_mappings: Vec<DecoratorIdMap>,
    /// Mappings from previous `MastNodeId`s to their new ids, one map per input forest (indexed
    /// by the forest's position in the input).
    ///
    /// Any `MastNodeId` in `mast_forest` is present as the target of some mapping in this map.
    node_id_mappings: Vec<MastForestNodeIdMap>,
}
36
37impl MastForestMerger {
38    /// Creates a new merger with an initially empty forest and merges all provided [`MastForest`]s
39    /// into it.
40    ///
41    /// # Normalizing Behavior
42    ///
43    /// This function performs normalization of the merged forest, which:
44    /// - Remaps all node IDs to maintain the invariant that child node IDs < parent node IDs
45    /// - Creates a clean, deduplicated forest structure
46    /// - Provides consistent node ordering regardless of input
47    ///
48    /// This normalization is idempotent, but it means that even for single-forest merges, the
49    /// resulting forest may have different node IDs and digests than the input. See assembly
50    /// test `issue_1644_single_forest_merge_identity` for detailed explanation of this
51    /// behavior.
52    pub(crate) fn merge<'forest>(
53        forests: impl IntoIterator<Item = &'forest MastForest>,
54    ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
55        let forests = forests.into_iter().collect::<Vec<_>>();
56
57        let decorator_id_mappings = Vec::with_capacity(forests.len());
58        let node_id_mappings = vec![MastForestNodeIdMap::new(); forests.len()];
59
60        let mut merger = Self {
61            node_id_by_hash: BTreeMap::new(),
62            hash_by_node_id: BTreeMap::new(),
63            decorators_by_hash: BTreeMap::new(),
64            mast_forest: MastForest::new(),
65            decorator_id_mappings,
66            node_id_mappings,
67        };
68
69        merger.merge_inner(forests.clone())?;
70
71        let Self { mast_forest, node_id_mappings, .. } = merger;
72
73        let root_maps = MastForestRootMap::from_node_id_map(node_id_mappings, forests);
74
75        Ok((mast_forest, root_maps))
76    }
77
78    /// Merges all `forests` into self.
79    ///
80    /// It does this in three steps:
81    ///
82    /// 1. Merge all advice maps, checking for key collisions.
83    /// 2. Merge all decorators, which is a case of deduplication and creating a decorator id
84    ///    mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the
85    ///    merged forest.
86    /// 3. Merge all nodes of forests.
87    ///    - Similar to decorators, node indices might move during merging, so the merger keeps a
88    ///      node id mapping as it merges nodes.
89    ///    - This is a depth-first traversal over all forests to ensure all children are processed
90    ///      before their parents. See the documentation of [`MultiMastForestNodeIter`] for details
91    ///      on this traversal.
92    ///    - Because all parents are processed after their children, we can use the node id mapping
93    ///      to remap all [`MastNodeId`]s of the children to their potentially new id in the merged
94    ///      forest.
95    ///    - If any external node is encountered during this traversal with a digest `foo` for which
96    ///      a `replacement` node exists in another forest with digest `foo`, then the external node
97    ///      will be replaced by that node. In particular, it means we do not want to add the
98    ///      external node to the merged forest, so it is never yielded from the iterator.
99    ///      - Assuming the simple case, where the `replacement` was not visited yet and is just a
100    ///        single node (not a tree), the iterator would first yield the `replacement` node which
101    ///        means it is going to be merged into the forest.
102    ///      - Next the iterator yields [`MultiMastForestIteratorItem::ExternalNodeReplacement`]
103    ///        which signals that an external node was replaced by another node. In this example,
104    ///        the `replacement_*` indices contained in that variant would point to the
105    ///        `replacement` node. Now we can simply add a mapping from the external node to the
106    ///        `replacement` node in our node id mapping which means all nodes that referenced the
107    ///        external node will point to the `replacement` instead.
108    /// 4. Finally, we merge all roots of all forests. Here we map the existing root indices to
109    ///    their potentially new indices in the merged forest and add them to the forest,
110    ///    deduplicating in the process, too.
111    fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> {
112        for other_forest in forests.iter() {
113            self.merge_advice_map(other_forest)?;
114        }
115        for other_forest in forests.iter() {
116            self.merge_decorators(other_forest)?;
117        }
118        for other_forest in forests.iter() {
119            self.merge_error_codes(other_forest)?;
120        }
121
122        let iterator = MultiMastForestNodeIter::new(forests.clone());
123        for item in iterator {
124            match item {
125                MultiMastForestIteratorItem::Node { forest_idx, node_id } => {
126                    let node = &forests[forest_idx][node_id];
127                    self.merge_node(forest_idx, node_id, node)?;
128                },
129                MultiMastForestIteratorItem::ExternalNodeReplacement {
130                    // forest index of the node which replaces the external node
131                    replacement_forest_idx,
132                    // ID of the node that replaces the external node
133                    replacement_mast_node_id,
134                    // forest index of the external node
135                    replaced_forest_idx,
136                    // ID of the external node
137                    replaced_mast_node_id,
138                } => {
139                    // The iterator is not aware of the merged forest, so the node indices it yields
140                    // are for the existing forests. That means we have to map the ID of the
141                    // replacement to its new location, since it was previously merged and its IDs
142                    // have very likely changed.
143                    let mapped_replacement = self.node_id_mappings[replacement_forest_idx]
144                        .get(&replacement_mast_node_id)
145                        .copied()
146                        .expect("every merged node id should be mapped");
147
148                    // SAFETY: The iterator only yields valid forest indices, so it is safe to index
149                    // directly.
150                    self.node_id_mappings[replaced_forest_idx]
151                        .insert(replaced_mast_node_id, mapped_replacement);
152                },
153            }
154        }
155
156        for (forest_idx, forest) in forests.iter().enumerate() {
157            self.merge_roots(forest_idx, forest)?;
158        }
159
160        Ok(())
161    }
162
163    fn merge_decorators(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
164        let mut decorator_id_remapping = DecoratorIdMap::new(other_forest.decorators.len());
165
166        for (merging_id, merging_decorator) in other_forest.decorators.iter().enumerate() {
167            let merging_decorator_hash = merging_decorator.fingerprint();
168            let new_decorator_id = if let Some(existing_decorator) =
169                self.decorators_by_hash.get(&merging_decorator_hash)
170            {
171                *existing_decorator
172            } else {
173                let new_decorator_id = self.mast_forest.add_decorator(merging_decorator.clone())?;
174                self.decorators_by_hash.insert(merging_decorator_hash, new_decorator_id);
175                new_decorator_id
176            };
177
178            decorator_id_remapping
179                .insert(DecoratorId::new_unchecked(merging_id as u32), new_decorator_id);
180        }
181
182        self.decorator_id_mappings.push(decorator_id_remapping);
183
184        Ok(())
185    }
186
187    fn merge_advice_map(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
188        self.mast_forest
189            .advice_map
190            .merge(&other_forest.advice_map)
191            .map_err(|((key, _prev), _new)| MastForestError::AdviceMapKeyCollisionOnMerge(key))
192    }
193
194    fn merge_error_codes(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
195        self.mast_forest.error_codes.extend(other_forest.error_codes.clone());
196        Ok(())
197    }
198
199    fn merge_node(
200        &mut self,
201        forest_idx: usize,
202        merging_id: MastNodeId,
203        node: &MastNode,
204    ) -> Result<(), MastForestError> {
205        // We need to remap the node prior to computing the MastNodeFingerprint.
206        //
207        // This is because the MastNodeFingerprint computation looks up its descendants and
208        // decorators in the internal index, and if we were to pass the original node to
209        // that computation, it would look up the incorrect descendants and decorators
210        // (since the descendant's indices may have changed).
211        //
212        // Remapping at this point is guaranteed to be "complete", meaning all ids of children
213        // will be present in the node id mapping since the DFS iteration guarantees
214        // that all children of this `node` have been processed before this node and
215        // their indices have been added to the mappings.
216        let remapped_node = self.remap_node(forest_idx, node)?;
217
218        let node_fingerprint = MastNodeFingerprint::from_mast_node(
219            &self.mast_forest,
220            &self.hash_by_node_id,
221            &remapped_node,
222        )
223        .expect(
224            "hash_by_node_id should contain the fingerprints of all children of `remapped_node`",
225        );
226
227        match self.lookup_node_by_fingerprint(&node_fingerprint) {
228            Some(matching_node_id) => {
229                // If a node with a matching fingerprint exists, then the merging node is a
230                // duplicate and we remap it to the existing node.
231                self.node_id_mappings[forest_idx].insert(merging_id, matching_node_id);
232            },
233            None => {
234                // If no node with a matching fingerprint exists, then the merging node is
235                // unique and we can add it to the merged forest.
236                let new_node_id = self.mast_forest.add_node(remapped_node)?;
237                self.node_id_mappings[forest_idx].insert(merging_id, new_node_id);
238
239                // We need to update the indices with the newly inserted nodes
240                // since the MastNodeFingerprint computation requires all descendants of a node
241                // to be in this index. Hence when we encounter a node in the merging forest
242                // which has descendants (Call, Loop, Split, ...), then their descendants need to be
243                // in the indices.
244                self.node_id_by_hash.insert(node_fingerprint, new_node_id);
245                self.hash_by_node_id.insert(new_node_id, node_fingerprint);
246            },
247        }
248
249        Ok(())
250    }
251
252    fn merge_roots(
253        &mut self,
254        forest_idx: usize,
255        other_forest: &MastForest,
256    ) -> Result<(), MastForestError> {
257        for root_id in other_forest.roots.iter() {
258            // Map the previous root to its possibly new id.
259            let new_root = self.node_id_mappings[forest_idx]
260                .get(root_id)
261                .expect("all node ids should have an entry");
262            // This takes O(n) where n is the number of roots in the merged forest every time to
263            // check if the root already exists. As the number of roots is relatively low generally,
264            // this should be okay.
265            self.mast_forest.make_root(*new_root);
266        }
267
268        Ok(())
269    }
270
271    /// Remaps a nodes' potentially contained children and decorators to their new IDs according to
272    /// the given maps.
273    fn remap_node(&self, forest_idx: usize, node: &MastNode) -> Result<MastNode, MastForestError> {
274        let map_decorator_id = |decorator_id: &DecoratorId| {
275            self.decorator_id_mappings[forest_idx].get(decorator_id).ok_or_else(|| {
276                MastForestError::DecoratorIdOverflow(
277                    *decorator_id,
278                    self.decorator_id_mappings[forest_idx].len(),
279                )
280            })
281        };
282        let map_decorators = |decorators: &[DecoratorId]| -> Result<Vec<_>, MastForestError> {
283            decorators.iter().map(map_decorator_id).collect()
284        };
285
286        let map_node_id = |node_id: MastNodeId| {
287            self.node_id_mappings[forest_idx]
288                .get(&node_id)
289                .copied()
290                .expect("every node id should have an entry")
291        };
292
293        // Due to DFS postorder iteration all children of node's should have been inserted before
294        // their parents which is why we can `expect` the constructor calls here.
295        let mut mapped_node: MastNode = match node {
296            MastNode::Join(join_node) => {
297                let first = map_node_id(join_node.first());
298                let second = map_node_id(join_node.second());
299
300                JoinNode::new([first, second], &self.mast_forest)
301                    .expect("JoinNode children should have been mapped to a lower index")
302                    .into()
303            },
304            MastNode::Split(split_node) => {
305                let if_branch = map_node_id(split_node.on_true());
306                let else_branch = map_node_id(split_node.on_false());
307
308                SplitNode::new([if_branch, else_branch], &self.mast_forest)
309                    .expect("SplitNode children should have been mapped to a lower index")
310                    .into()
311            },
312            MastNode::Loop(loop_node) => {
313                let body = map_node_id(loop_node.body());
314                LoopNode::new(body, &self.mast_forest)
315                    .expect("LoopNode children should have been mapped to a lower index")
316                    .into()
317            },
318            MastNode::Call(call_node) => {
319                let callee = map_node_id(call_node.callee());
320                CallNode::new(callee, &self.mast_forest)
321                    .expect("CallNode children should have been mapped to a lower index")
322                    .into()
323            },
324            // Other nodes are simply copied.
325            MastNode::Block(basic_block_node) => {
326                BasicBlockNode::new(
327                    basic_block_node.operations().copied().collect(),
328                    // Operation Indices of decorators stay the same while decorator IDs need to be
329                    // mapped.
330                    Some(
331                        basic_block_node
332                            .decorators()
333                            .map(|(idx, decorator_id)| match map_decorator_id(&decorator_id) {
334                                Ok(mapped_decorator) => Ok((idx, mapped_decorator)),
335                                Err(err) => Err(err),
336                            })
337                            .collect::<Result<Vec<_>, _>>()?,
338                    ),
339                )
340                .expect("previously valid BasicBlockNode should still be valid")
341                .into()
342            },
343            MastNode::Dyn(_) => DynNode::new_dyn().into(),
344            MastNode::External(external_node) => ExternalNode::new(external_node.digest()).into(),
345        };
346
347        // Decorators must be handled specially for basic block nodes.
348        // For other node types we can handle it centrally.
349        if !mapped_node.is_basic_block() {
350            mapped_node.append_before_enter(&map_decorators(node.before_enter())?);
351            mapped_node.append_after_exit(&map_decorators(node.after_exit())?);
352        }
353
354        Ok(mapped_node)
355    }
356
357    // HELPERS
358    // ================================================================================================
359
360    /// Returns the ID of the node in the merged forest that matches the given
361    /// fingerprint, if any.
362    fn lookup_node_by_fingerprint(&self, fingerprint: &MastNodeFingerprint) -> Option<MastNodeId> {
363        self.node_id_by_hash.get(fingerprint).copied()
364    }
365}
366
367// MAST FOREST ROOT MAP
368// ================================================================================================
369
370/// A mapping for the new location of the roots of a [`MastForest`] after a merge.
371///
372/// It maps the roots ([`MastNodeId`]s) of a forest to their new [`MastNodeId`] in the merged
373/// forest. See [`MastForest::merge`] for more details.
374#[derive(Debug, Clone, PartialEq, Eq)]
375pub struct MastForestRootMap {
376    root_maps: Vec<BTreeMap<MastNodeId, MastNodeId>>,
377}
378
379impl MastForestRootMap {
380    fn from_node_id_map(id_map: Vec<MastForestNodeIdMap>, forests: Vec<&MastForest>) -> Self {
381        let mut root_maps = vec![BTreeMap::new(); forests.len()];
382
383        for (forest_idx, forest) in forests.into_iter().enumerate() {
384            for root in forest.procedure_roots() {
385                let new_id = id_map[forest_idx]
386                    .get(root)
387                    .copied()
388                    .expect("every node id should be mapped to its new id");
389                root_maps[forest_idx].insert(*root, new_id);
390            }
391        }
392
393        Self { root_maps }
394    }
395
396    /// Maps the given root to its new location in the merged forest, if such a mapping exists.
397    ///
398    /// It is guaranteed that every root of the map's corresponding forest is contained in the map.
399    pub fn map_root(&self, forest_index: usize, root: &MastNodeId) -> Option<MastNodeId> {
400        self.root_maps.get(forest_index).and_then(|map| map.get(root)).copied()
401    }
402}
403
404// DECORATOR ID MAP
405// ================================================================================================
406
407/// A specialized map from [`DecoratorId`] -> [`DecoratorId`].
408///
409/// When mapping Decorator IDs during merging, we always map all IDs of the merging
410/// forest to new ids. Hence it is more efficient to use a `Vec` instead of, say, a `BTreeMap`.
411///
412/// In other words, this type is similar to `BTreeMap<ID, ID>` but takes advantage of the fact that
413/// the keys are contiguous.
414///
415/// This type is meant to encapsulates some guarantees:
416///
417/// - Indexing into the vector for any ID is safe if that ID is valid for the corresponding forest.
418///   Despite that, we still cannot index unconditionally in case a node with invalid
419///   [`DecoratorId`]s is passed to `merge`.
420/// - The entry itself can be either None or Some. However:
421///   - For `DecoratorId`s we iterate and insert all decorators into this map before retrieving any
422///     entry, so all entries contain `Some`. Because of this, we can use `expect` in `get` for the
423///     `Option` value.
424/// - Similarly, inserting any ID from the corresponding forest is safe as the map contains a
425///   pre-allocated `Vec` of the appropriate size.
426struct DecoratorIdMap {
427    inner: Vec<Option<DecoratorId>>,
428}
429
430impl DecoratorIdMap {
431    fn new(num_ids: usize) -> Self {
432        Self { inner: vec![None; num_ids] }
433    }
434
435    /// Maps the given key to the given value.
436    ///
437    /// It is the caller's responsibility to only pass keys that belong to the forest for which this
438    /// map was originally created.
439    fn insert(&mut self, key: DecoratorId, value: DecoratorId) {
440        self.inner[key.as_usize()] = Some(value);
441    }
442
443    /// Retrieves the value for the given key.
444    fn get(&self, key: &DecoratorId) -> Option<DecoratorId> {
445        self.inner
446            .get(key.as_usize())
447            .map(|id| id.expect("every id should have a Some entry in the map when calling get"))
448    }
449
450    fn len(&self) -> usize {
451        self.inner.len()
452    }
453}
454
/// Maps the [`MastNodeId`]s of one input forest to their new [`MastNodeId`]s in the merged forest.
///
/// This is a type alias for increased readability in function signatures.
type MastForestNodeIdMap = BTreeMap<MastNodeId, MastNodeId>;