Skip to main content

miden_core/mast/merger/
mod.rs

1use alloc::{collections::BTreeMap, vec::Vec};
2
3use crate::{
4    crypto::hash::Blake3Digest,
5    mast::{
6        DecoratorId, MastForest, MastForestContributor, MastForestError, MastNode, MastNodeBuilder,
7        MastNodeFingerprint, MastNodeId, MultiMastForestIteratorItem, MultiMastForestNodeIter,
8    },
9    utils::{DenseIdMap, IndexVec},
10};
11
12#[cfg(test)]
13mod tests;
14
15/// A type that allows merging [`MastForest`]s.
16///
17/// This functionality is exposed via [`MastForest::merge`]. See its documentation for more details.
18pub(crate) struct MastForestMerger {
19    mast_forest: MastForest,
20    // Internal indices needed for efficient duplicate checking and MastNodeFingerprint
21    // computation.
22    //
23    // These are always in-sync with the nodes in `mast_forest`, i.e. all nodes added to the
24    // `mast_forest` are also added to the indices.
25    node_id_by_hash: BTreeMap<MastNodeFingerprint, MastNodeId>,
26    hash_by_node_id: IndexVec<MastNodeId, MastNodeFingerprint>,
27    decorators_by_hash: BTreeMap<Blake3Digest<32>, DecoratorId>,
28    /// Mappings from old decorator and node ids to their new ids.
29    ///
30    /// Any decorator in `mast_forest` is present as the target of some mapping in this map.
31    decorator_id_mappings: Vec<DenseIdMap<DecoratorId, DecoratorId>>,
32    /// Mappings from previous `MastNodeId`s to their new ids.
33    ///
34    /// Any `MastNodeId` in `mast_forest` is present as the target of some mapping in this map.
35    node_id_mappings: Vec<DenseIdMap<MastNodeId, MastNodeId>>,
36}
37
38impl MastForestMerger {
39    /// Creates a new merger with an initially empty forest and merges all provided [`MastForest`]s
40    /// into it.
41    ///
42    /// # Normalizing Behavior
43    ///
44    /// This function performs normalization of the merged forest, which:
45    /// - Remaps all node IDs to maintain the invariant that child node IDs < parent node IDs
46    /// - Creates a clean, deduplicated forest structure
47    /// - Provides consistent node ordering regardless of input
48    ///
49    /// This normalization is idempotent, but it means that even for single-forest merges, the
50    /// resulting forest may have different node IDs and digests than the input. See assembly
51    /// test `issue_1644_single_forest_merge_identity` for detailed explanation of this
52    /// behavior.
53    pub(crate) fn merge<'forest>(
54        forests: impl IntoIterator<Item = &'forest MastForest>,
55    ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
56        let forests = forests.into_iter().collect::<Vec<_>>();
57
58        let decorator_id_mappings = Vec::with_capacity(forests.len());
59        let node_id_mappings =
60            forests.iter().map(|f| DenseIdMap::with_len(f.nodes().len())).collect();
61
62        let mut merger = Self {
63            node_id_by_hash: BTreeMap::new(),
64            hash_by_node_id: IndexVec::new(),
65            decorators_by_hash: BTreeMap::new(),
66            mast_forest: MastForest::new(),
67            decorator_id_mappings,
68            node_id_mappings,
69        };
70
71        merger.merge_inner(forests.clone())?;
72
73        let Self { mast_forest, node_id_mappings, .. } = merger;
74
75        let root_maps = MastForestRootMap::from_node_id_map(node_id_mappings, forests);
76
77        Ok((mast_forest, root_maps))
78    }
79
80    /// Merges all `forests` into self.
81    ///
82    /// It does this in three steps:
83    ///
84    /// 1. Merge all advice maps, checking for key collisions.
85    /// 2. Merge all decorators, which is a case of deduplication and creating a decorator id
86    ///    mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the
87    ///    merged forest.
88    /// 3. Merge all nodes of forests.
89    ///    - Similar to decorators, node indices might move during merging, so the merger keeps a
90    ///      node id mapping as it merges nodes.
91    ///    - This is a depth-first traversal over all forests to ensure all children are processed
92    ///      before their parents. See the documentation of [`MultiMastForestNodeIter`] for details
93    ///      on this traversal.
94    ///    - Because all parents are processed after their children, we can use the node id mapping
95    ///      to remap all [`MastNodeId`]s of the children to their potentially new id in the merged
96    ///      forest.
97    ///    - If any external node is encountered during this traversal with a digest `foo` for which
98    ///      a `replacement` node exists in another forest with digest `foo`, then the external node
99    ///      will be replaced by that node. In particular, it means we do not want to add the
100    ///      external node to the merged forest, so it is never yielded from the iterator.
101    ///      - Assuming the simple case, where the `replacement` was not visited yet and is just a
102    ///        single node (not a tree), the iterator would first yield the `replacement` node which
103    ///        means it is going to be merged into the forest.
104    ///      - Next the iterator yields [`MultiMastForestIteratorItem::ExternalNodeReplacement`]
105    ///        which signals that an external node was replaced by another node. In this example,
106    ///        the `replacement_*` indices contained in that variant would point to the
107    ///        `replacement` node. Now we can simply add a mapping from the external node to the
108    ///        `replacement` node in our node id mapping which means all nodes that referenced the
109    ///        external node will point to the `replacement` instead.
110    /// 4. Finally, we merge all roots of all forests. Here we map the existing root indices to
111    ///    their potentially new indices in the merged forest and add them to the forest,
112    ///    deduplicating in the process, too.
113    fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> {
114        for other_forest in forests.iter() {
115            self.merge_advice_map(other_forest)?;
116        }
117        for other_forest in forests.iter() {
118            self.merge_decorators(other_forest)?;
119        }
120        for other_forest in forests.iter() {
121            self.merge_error_codes(other_forest)?;
122        }
123
124        let iterator = MultiMastForestNodeIter::new(forests.clone());
125        for item in iterator {
126            match item {
127                MultiMastForestIteratorItem::Node { forest_idx, node_id } => {
128                    let node = forests[forest_idx][node_id].clone();
129                    self.merge_node(forest_idx, node_id, node, &forests)?;
130                },
131                MultiMastForestIteratorItem::ExternalNodeReplacement {
132                    // forest index of the node which replaces the external node
133                    replacement_forest_idx,
134                    // ID of the node that replaces the external node
135                    replacement_mast_node_id,
136                    // forest index of the external node
137                    replaced_forest_idx,
138                    // ID of the external node
139                    replaced_mast_node_id,
140                } => {
141                    // The iterator is not aware of the merged forest, so the node indices it yields
142                    // are for the existing forests. That means we have to map the ID of the
143                    // replacement to its new location, since it was previously merged and its IDs
144                    // have very likely changed.
145                    let mapped_replacement = self.node_id_mappings[replacement_forest_idx]
146                        .get(replacement_mast_node_id)
147                        .expect("every merged node id should be mapped");
148
149                    // SAFETY: The iterator only yields valid forest indices, so it is safe to index
150                    // directly.
151                    self.node_id_mappings[replaced_forest_idx]
152                        .insert(replaced_mast_node_id, mapped_replacement);
153                },
154            }
155        }
156
157        for (forest_idx, forest) in forests.iter().enumerate() {
158            self.merge_roots(forest_idx, forest)?;
159        }
160
161        Ok(())
162    }
163
164    fn merge_decorators(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
165        let mut decorator_id_remapping =
166            DenseIdMap::with_len(other_forest.debug_info.num_decorators());
167
168        for (merging_id, merging_decorator) in
169            other_forest.debug_info.decorators().iter().enumerate()
170        {
171            let merging_decorator_hash = merging_decorator.fingerprint();
172            let new_decorator_id = if let Some(existing_decorator) =
173                self.decorators_by_hash.get(&merging_decorator_hash)
174            {
175                *existing_decorator
176            } else {
177                let new_decorator_id = self.mast_forest.add_decorator(merging_decorator.clone())?;
178                self.decorators_by_hash.insert(merging_decorator_hash, new_decorator_id);
179                new_decorator_id
180            };
181
182            decorator_id_remapping
183                .insert(DecoratorId::new_unchecked(merging_id as u32), new_decorator_id);
184        }
185
186        self.decorator_id_mappings.push(decorator_id_remapping);
187
188        Ok(())
189    }
190
191    fn merge_advice_map(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
192        self.mast_forest
193            .advice_map
194            .merge(&other_forest.advice_map)
195            .map_err(|((key, _prev), _new)| MastForestError::AdviceMapKeyCollisionOnMerge(key))
196    }
197
198    fn merge_error_codes(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
199        self.mast_forest.debug_info.extend_error_codes(
200            other_forest.debug_info.error_codes().map(|(k, v)| (*k, v.clone())),
201        );
202        Ok(())
203    }
204
205    fn merge_node(
206        &mut self,
207        forest_idx: usize,
208        merging_id: MastNodeId,
209        node: MastNode,
210        original_forests: &[&MastForest],
211    ) -> Result<(), MastForestError> {
212        // We need to remap the node prior to computing the MastNodeFingerprint.
213        //
214        // This is because the MastNodeFingerprint computation looks up its descendants and
215        // decorators in the internal index, and if we were to pass the original node to
216        // that computation, it would look up the incorrect descendants and decorators
217        // (since the descendant's indices may have changed).
218        //
219        // Remapping at this point is guaranteed to be "complete", meaning all ids of children
220        // will be present in the node id mapping since the DFS iteration guarantees
221        // that all children of this `node` have been processed before this node and
222        // their indices have been added to the mappings.
223        let remapped_builder = self.build_with_remapped_children(
224            merging_id,
225            node,
226            original_forests[forest_idx],
227            &self.node_id_mappings[forest_idx],
228            &self.decorator_id_mappings[forest_idx],
229        )?;
230
231        let node_fingerprint =
232            remapped_builder.fingerprint_for_node(&self.mast_forest, &self.hash_by_node_id)?;
233
234        match self.lookup_node_by_fingerprint(&node_fingerprint) {
235            Some(matching_node_id) => {
236                // If a node with a matching fingerprint exists, then the merging node is a
237                // duplicate and we remap it to the existing node.
238                self.node_id_mappings[forest_idx].insert(merging_id, matching_node_id);
239            },
240            None => {
241                // If no node with a matching fingerprint exists, then the merging node is
242                // unique and we can add it to the merged forest using builders.
243                let new_node_id = remapped_builder.add_to_forest(&mut self.mast_forest)?;
244                self.node_id_mappings[forest_idx].insert(merging_id, new_node_id);
245
246                // We need to update the indices with the newly inserted nodes
247                // since the MastNodeFingerprint computation requires all descendants of a node
248                // to be in this index. Hence when we encounter a node in the merging forest
249                // which has descendants (Call, Loop, Split, ...), then their descendants need to be
250                // in the indices.
251                self.node_id_by_hash.insert(node_fingerprint, new_node_id);
252                let returned_id = self
253                    .hash_by_node_id
254                    .push(node_fingerprint)
255                    .map_err(|_| MastForestError::TooManyNodes)?;
256                debug_assert_eq!(
257                    returned_id, new_node_id,
258                    "hash_by_node_id push() should return the same node IDs as node_id_by_hash"
259                );
260            },
261        }
262
263        Ok(())
264    }
265
266    fn merge_roots(
267        &mut self,
268        forest_idx: usize,
269        other_forest: &MastForest,
270    ) -> Result<(), MastForestError> {
271        for root_id in other_forest.roots.iter() {
272            // Map the previous root to its possibly new id.
273            let new_root = self.node_id_mappings[forest_idx]
274                .get(*root_id)
275                .expect("all node ids should have an entry");
276            // This takes O(n) where n is the number of roots in the merged forest every time to
277            // check if the root already exists. As the number of roots is relatively low generally,
278            // this should be okay.
279            self.mast_forest.make_root(new_root);
280        }
281
282        Ok(())
283    }
284
285    // HELPERS
286    // ================================================================================================
287
288    /// Returns the ID of the node in the merged forest that matches the given
289    /// fingerprint, if any.
290    fn lookup_node_by_fingerprint(&self, fingerprint: &MastNodeFingerprint) -> Option<MastNodeId> {
291        self.node_id_by_hash.get(fingerprint).copied()
292    }
293
294    /// Builds a new node with remapped children and decorators using the provided mappings.
295    fn build_with_remapped_children(
296        &self,
297        merging_id: MastNodeId,
298        src: MastNode,
299        original_forest: &MastForest,
300        nmap: &DenseIdMap<MastNodeId, MastNodeId>,
301        dmap: &DenseIdMap<DecoratorId, DecoratorId>,
302    ) -> Result<MastNodeBuilder, MastForestError> {
303        super::build_node_with_remapped_ids(merging_id, src, original_forest, nmap, dmap)
304    }
305}
306
307// MAST FOREST ROOT MAP
308// ================================================================================================
309
310/// A mapping for the new location of the roots of a [`MastForest`] after a merge.
311///
312/// It maps the roots ([`MastNodeId`]s) of a forest to their new [`MastNodeId`] in the merged
313/// forest. See [`MastForest::merge`] for more details.
314#[derive(Debug, Clone, PartialEq, Eq)]
315pub struct MastForestRootMap {
316    root_maps: Vec<BTreeMap<MastNodeId, MastNodeId>>,
317}
318
319impl MastForestRootMap {
320    fn from_node_id_map(
321        id_map: Vec<DenseIdMap<MastNodeId, MastNodeId>>,
322        forests: Vec<&MastForest>,
323    ) -> Self {
324        let mut root_maps = vec![BTreeMap::new(); forests.len()];
325
326        for (forest_idx, forest) in forests.into_iter().enumerate() {
327            for root in forest.procedure_roots() {
328                let new_id = id_map[forest_idx]
329                    .get(*root)
330                    .expect("every node id should be mapped to its new id");
331                root_maps[forest_idx].insert(*root, new_id);
332            }
333        }
334
335        Self { root_maps }
336    }
337
338    /// Maps the given root to its new location in the merged forest, if such a mapping exists.
339    ///
340    /// It is guaranteed that every root of the map's corresponding forest is contained in the map.
341    pub fn map_root(&self, forest_index: usize, root: &MastNodeId) -> Option<MastNodeId> {
342        self.root_maps.get(forest_index).and_then(|map| map.get(root)).copied()
343    }
344}