miden_core/mast/merger/
mod.rs

1use alloc::{collections::BTreeMap, vec::Vec};
2
3use miden_crypto::hash::blake::Blake3Digest;
4
5use crate::{
6    DenseIdMap, IndexVec,
7    mast::{
8        BasicBlockNode, CallNode, DecoratorId, DynNode, ExternalNode, JoinNode, LoopNode,
9        MastForest, MastForestError, MastNode, MastNodeFingerprint, MastNodeId,
10        MultiMastForestIteratorItem, MultiMastForestNodeIter, SplitNode, node::MastNodeExt,
11    },
12};
13
14#[cfg(test)]
15mod tests;
16
17/// A type that allows merging [`MastForest`]s.
18///
19/// This functionality is exposed via [`MastForest::merge`]. See its documentation for more details.
20pub(crate) struct MastForestMerger {
21    mast_forest: MastForest,
22    // Internal indices needed for efficient duplicate checking and MastNodeFingerprint
23    // computation.
24    //
25    // These are always in-sync with the nodes in `mast_forest`, i.e. all nodes added to the
26    // `mast_forest` are also added to the indices.
27    node_id_by_hash: BTreeMap<MastNodeFingerprint, MastNodeId>,
28    hash_by_node_id: IndexVec<MastNodeId, MastNodeFingerprint>,
29    decorators_by_hash: BTreeMap<Blake3Digest<32>, DecoratorId>,
30    /// Mappings from old decorator and node ids to their new ids.
31    ///
32    /// Any decorator in `mast_forest` is present as the target of some mapping in this map.
33    decorator_id_mappings: Vec<DenseIdMap<DecoratorId, DecoratorId>>,
34    /// Mappings from previous `MastNodeId`s to their new ids.
35    ///
36    /// Any `MastNodeId` in `mast_forest` is present as the target of some mapping in this map.
37    node_id_mappings: Vec<DenseIdMap<MastNodeId, MastNodeId>>,
38}
39
40impl MastForestMerger {
41    /// Creates a new merger with an initially empty forest and merges all provided [`MastForest`]s
42    /// into it.
43    ///
44    /// # Normalizing Behavior
45    ///
46    /// This function performs normalization of the merged forest, which:
47    /// - Remaps all node IDs to maintain the invariant that child node IDs < parent node IDs
48    /// - Creates a clean, deduplicated forest structure
49    /// - Provides consistent node ordering regardless of input
50    ///
51    /// This normalization is idempotent, but it means that even for single-forest merges, the
52    /// resulting forest may have different node IDs and digests than the input. See assembly
53    /// test `issue_1644_single_forest_merge_identity` for detailed explanation of this
54    /// behavior.
55    pub(crate) fn merge<'forest>(
56        forests: impl IntoIterator<Item = &'forest MastForest>,
57    ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
58        let forests = forests.into_iter().collect::<Vec<_>>();
59
60        let decorator_id_mappings = Vec::with_capacity(forests.len());
61        let node_id_mappings =
62            forests.iter().map(|f| DenseIdMap::with_len(f.nodes().len())).collect();
63
64        let mut merger = Self {
65            node_id_by_hash: BTreeMap::new(),
66            hash_by_node_id: IndexVec::new(),
67            decorators_by_hash: BTreeMap::new(),
68            mast_forest: MastForest::new(),
69            decorator_id_mappings,
70            node_id_mappings,
71        };
72
73        merger.merge_inner(forests.clone())?;
74
75        let Self { mast_forest, node_id_mappings, .. } = merger;
76
77        let root_maps = MastForestRootMap::from_node_id_map(node_id_mappings, forests);
78
79        Ok((mast_forest, root_maps))
80    }
81
82    /// Merges all `forests` into self.
83    ///
84    /// It does this in three steps:
85    ///
86    /// 1. Merge all advice maps, checking for key collisions.
87    /// 2. Merge all decorators, which is a case of deduplication and creating a decorator id
88    ///    mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the
89    ///    merged forest.
90    /// 3. Merge all nodes of forests.
91    ///    - Similar to decorators, node indices might move during merging, so the merger keeps a
92    ///      node id mapping as it merges nodes.
93    ///    - This is a depth-first traversal over all forests to ensure all children are processed
94    ///      before their parents. See the documentation of [`MultiMastForestNodeIter`] for details
95    ///      on this traversal.
96    ///    - Because all parents are processed after their children, we can use the node id mapping
97    ///      to remap all [`MastNodeId`]s of the children to their potentially new id in the merged
98    ///      forest.
99    ///    - If any external node is encountered during this traversal with a digest `foo` for which
100    ///      a `replacement` node exists in another forest with digest `foo`, then the external node
101    ///      will be replaced by that node. In particular, it means we do not want to add the
102    ///      external node to the merged forest, so it is never yielded from the iterator.
103    ///      - Assuming the simple case, where the `replacement` was not visited yet and is just a
104    ///        single node (not a tree), the iterator would first yield the `replacement` node which
105    ///        means it is going to be merged into the forest.
106    ///      - Next the iterator yields [`MultiMastForestIteratorItem::ExternalNodeReplacement`]
107    ///        which signals that an external node was replaced by another node. In this example,
108    ///        the `replacement_*` indices contained in that variant would point to the
109    ///        `replacement` node. Now we can simply add a mapping from the external node to the
110    ///        `replacement` node in our node id mapping which means all nodes that referenced the
111    ///        external node will point to the `replacement` instead.
112    /// 4. Finally, we merge all roots of all forests. Here we map the existing root indices to
113    ///    their potentially new indices in the merged forest and add them to the forest,
114    ///    deduplicating in the process, too.
115    fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> {
116        for other_forest in forests.iter() {
117            self.merge_advice_map(other_forest)?;
118        }
119        for other_forest in forests.iter() {
120            self.merge_decorators(other_forest)?;
121        }
122        for other_forest in forests.iter() {
123            self.merge_error_codes(other_forest)?;
124        }
125
126        let iterator = MultiMastForestNodeIter::new(forests.clone());
127        for item in iterator {
128            match item {
129                MultiMastForestIteratorItem::Node { forest_idx, node_id } => {
130                    let node = &forests[forest_idx][node_id];
131                    self.merge_node(forest_idx, node_id, node)?;
132                },
133                MultiMastForestIteratorItem::ExternalNodeReplacement {
134                    // forest index of the node which replaces the external node
135                    replacement_forest_idx,
136                    // ID of the node that replaces the external node
137                    replacement_mast_node_id,
138                    // forest index of the external node
139                    replaced_forest_idx,
140                    // ID of the external node
141                    replaced_mast_node_id,
142                } => {
143                    // The iterator is not aware of the merged forest, so the node indices it yields
144                    // are for the existing forests. That means we have to map the ID of the
145                    // replacement to its new location, since it was previously merged and its IDs
146                    // have very likely changed.
147                    let mapped_replacement = self.node_id_mappings[replacement_forest_idx]
148                        .get(replacement_mast_node_id)
149                        .expect("every merged node id should be mapped");
150
151                    // SAFETY: The iterator only yields valid forest indices, so it is safe to index
152                    // directly.
153                    self.node_id_mappings[replaced_forest_idx]
154                        .insert(replaced_mast_node_id, mapped_replacement);
155                },
156            }
157        }
158
159        for (forest_idx, forest) in forests.iter().enumerate() {
160            self.merge_roots(forest_idx, forest)?;
161        }
162
163        Ok(())
164    }
165
166    fn merge_decorators(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
167        let mut decorator_id_remapping = DenseIdMap::with_len(other_forest.decorators.len());
168
169        for (merging_id, merging_decorator) in other_forest.decorators.iter().enumerate() {
170            let merging_decorator_hash = merging_decorator.fingerprint();
171            let new_decorator_id = if let Some(existing_decorator) =
172                self.decorators_by_hash.get(&merging_decorator_hash)
173            {
174                *existing_decorator
175            } else {
176                let new_decorator_id = self.mast_forest.add_decorator(merging_decorator.clone())?;
177                self.decorators_by_hash.insert(merging_decorator_hash, new_decorator_id);
178                new_decorator_id
179            };
180
181            decorator_id_remapping
182                .insert(DecoratorId::new_unchecked(merging_id as u32), new_decorator_id);
183        }
184
185        self.decorator_id_mappings.push(decorator_id_remapping);
186
187        Ok(())
188    }
189
190    fn merge_advice_map(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
191        self.mast_forest
192            .advice_map
193            .merge(&other_forest.advice_map)
194            .map_err(|((key, _prev), _new)| MastForestError::AdviceMapKeyCollisionOnMerge(key))
195    }
196
197    fn merge_error_codes(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
198        self.mast_forest.error_codes.extend(other_forest.error_codes.clone());
199        Ok(())
200    }
201
202    fn merge_node(
203        &mut self,
204        forest_idx: usize,
205        merging_id: MastNodeId,
206        node: &MastNode,
207    ) -> Result<(), MastForestError> {
208        // We need to remap the node prior to computing the MastNodeFingerprint.
209        //
210        // This is because the MastNodeFingerprint computation looks up its descendants and
211        // decorators in the internal index, and if we were to pass the original node to
212        // that computation, it would look up the incorrect descendants and decorators
213        // (since the descendant's indices may have changed).
214        //
215        // Remapping at this point is guaranteed to be "complete", meaning all ids of children
216        // will be present in the node id mapping since the DFS iteration guarantees
217        // that all children of this `node` have been processed before this node and
218        // their indices have been added to the mappings.
219        let remapped_node = self.remap_node(forest_idx, node)?;
220
221        let node_fingerprint = MastNodeFingerprint::from_mast_node(
222            &self.mast_forest,
223            &self.hash_by_node_id,
224            &remapped_node,
225        )
226        .expect(
227            "hash_by_node_id should contain the fingerprints of all children of `remapped_node`",
228        );
229
230        match self.lookup_node_by_fingerprint(&node_fingerprint) {
231            Some(matching_node_id) => {
232                // If a node with a matching fingerprint exists, then the merging node is a
233                // duplicate and we remap it to the existing node.
234                self.node_id_mappings[forest_idx].insert(merging_id, matching_node_id);
235            },
236            None => {
237                // If no node with a matching fingerprint exists, then the merging node is
238                // unique and we can add it to the merged forest.
239                let new_node_id = self.mast_forest.add_node(remapped_node)?;
240                self.node_id_mappings[forest_idx].insert(merging_id, new_node_id);
241
242                // We need to update the indices with the newly inserted nodes
243                // since the MastNodeFingerprint computation requires all descendants of a node
244                // to be in this index. Hence when we encounter a node in the merging forest
245                // which has descendants (Call, Loop, Split, ...), then their descendants need to be
246                // in the indices.
247                self.node_id_by_hash.insert(node_fingerprint, new_node_id);
248                let returned_id = self
249                    .hash_by_node_id
250                    .push(node_fingerprint)
251                    .map_err(|_| MastForestError::TooManyNodes)?;
252                debug_assert_eq!(
253                    returned_id, new_node_id,
254                    "hash_by_node_id push() should return the same node IDs as node_id_by_hash"
255                );
256            },
257        }
258
259        Ok(())
260    }
261
262    fn merge_roots(
263        &mut self,
264        forest_idx: usize,
265        other_forest: &MastForest,
266    ) -> Result<(), MastForestError> {
267        for root_id in other_forest.roots.iter() {
268            // Map the previous root to its possibly new id.
269            let new_root = self.node_id_mappings[forest_idx]
270                .get(*root_id)
271                .expect("all node ids should have an entry");
272            // This takes O(n) where n is the number of roots in the merged forest every time to
273            // check if the root already exists. As the number of roots is relatively low generally,
274            // this should be okay.
275            self.mast_forest.make_root(new_root);
276        }
277
278        Ok(())
279    }
280
281    /// Remaps a nodes' potentially contained children and decorators to their new IDs according to
282    /// the given maps.
283    fn remap_node(&self, forest_idx: usize, node: &MastNode) -> Result<MastNode, MastForestError> {
284        self.build_node_with_remapped_children(
285            node,
286            &self.node_id_mappings[forest_idx],
287            &self.decorator_id_mappings[forest_idx],
288        )
289    }
290
291    // HELPERS
292    // ================================================================================================
293
294    /// Remaps a child node ID using the node ID map, returning the original ID if not found.
295    fn remap_child(
296        &self,
297        child_id: MastNodeId,
298        nmap: &DenseIdMap<MastNodeId, MastNodeId>,
299    ) -> MastNodeId {
300        nmap.get(child_id).expect("every node id should have an entry")
301    }
302
303    /// Returns the ID of the node in the merged forest that matches the given
304    /// fingerprint, if any.
305    fn lookup_node_by_fingerprint(&self, fingerprint: &MastNodeFingerprint) -> Option<MastNodeId> {
306        self.node_id_by_hash.get(fingerprint).copied()
307    }
308
309    /// Builds a new node with remapped children and decorators using the provided mappings.
310    fn build_node_with_remapped_children(
311        &self,
312        src: &MastNode,
313        nmap: &DenseIdMap<MastNodeId, MastNodeId>,
314        dmap: &DenseIdMap<DecoratorId, DecoratorId>,
315    ) -> Result<MastNode, MastForestError> {
316        let map_decorator_id = |decorator_id: DecoratorId| {
317            dmap.get(decorator_id)
318                .ok_or_else(|| MastForestError::DecoratorIdOverflow(decorator_id, dmap.len()))
319        };
320
321        let map_decorators = |decorators: &[DecoratorId]| -> Result<Vec<_>, MastForestError> {
322            decorators.iter().copied().map(map_decorator_id).collect()
323        };
324
325        let mut mapped_node: MastNode = match src {
326            MastNode::Join(join_node) => {
327                let first = self.remap_child(join_node.first(), nmap);
328                let second = self.remap_child(join_node.second(), nmap);
329
330                JoinNode::new([first, second], &self.mast_forest)
331                    .expect("JoinNode children should have been mapped to a lower index")
332                    .into()
333            },
334            MastNode::Split(split_node) => {
335                let if_branch = self.remap_child(split_node.on_true(), nmap);
336                let else_branch = self.remap_child(split_node.on_false(), nmap);
337
338                SplitNode::new([if_branch, else_branch], &self.mast_forest)
339                    .expect("SplitNode children should have been mapped to a lower index")
340                    .into()
341            },
342            MastNode::Loop(loop_node) => {
343                let body = self.remap_child(loop_node.body(), nmap);
344                LoopNode::new(body, &self.mast_forest)
345                    .expect("LoopNode children should have been mapped to a lower index")
346                    .into()
347            },
348            MastNode::Call(call_node) => {
349                let callee = self.remap_child(call_node.callee(), nmap);
350                CallNode::new(callee, &self.mast_forest)
351                    .expect("CallNode children should have been mapped to a lower index")
352                    .into()
353            },
354            MastNode::Block(basic_block_node) => BasicBlockNode::new(
355                basic_block_node.operations().copied().collect(),
356                basic_block_node
357                    .indexed_decorator_iter()
358                    .map(|(idx, decorator_id)| {
359                        let mapped_decorator = map_decorator_id(decorator_id)?;
360                        Ok((idx, mapped_decorator))
361                    })
362                    .collect::<Result<Vec<_>, _>>()?,
363            )
364            .expect("previously valid BasicBlockNode should still be valid")
365            .into(),
366            MastNode::Dyn(_) => DynNode::new_dyn().into(),
367            MastNode::External(external_node) => ExternalNode::new(external_node.digest()).into(),
368        };
369
370        // Decorators must be handled specially for basic block nodes.
371        // For other node types we can handle it centrally.
372        {
373            mapped_node.append_before_enter(&map_decorators(src.before_enter())?);
374            mapped_node.append_after_exit(&map_decorators(src.after_exit())?);
375        }
376
377        Ok(mapped_node)
378    }
379}
380
381// MAST FOREST ROOT MAP
382// ================================================================================================
383
384/// A mapping for the new location of the roots of a [`MastForest`] after a merge.
385///
386/// It maps the roots ([`MastNodeId`]s) of a forest to their new [`MastNodeId`] in the merged
387/// forest. See [`MastForest::merge`] for more details.
388#[derive(Debug, Clone, PartialEq, Eq)]
389pub struct MastForestRootMap {
390    root_maps: Vec<BTreeMap<MastNodeId, MastNodeId>>,
391}
392
393impl MastForestRootMap {
394    fn from_node_id_map(
395        id_map: Vec<DenseIdMap<MastNodeId, MastNodeId>>,
396        forests: Vec<&MastForest>,
397    ) -> Self {
398        let mut root_maps = vec![BTreeMap::new(); forests.len()];
399
400        for (forest_idx, forest) in forests.into_iter().enumerate() {
401            for root in forest.procedure_roots() {
402                let new_id = id_map[forest_idx]
403                    .get(*root)
404                    .expect("every node id should be mapped to its new id");
405                root_maps[forest_idx].insert(*root, new_id);
406            }
407        }
408
409        Self { root_maps }
410    }
411
412    /// Maps the given root to its new location in the merged forest, if such a mapping exists.
413    ///
414    /// It is guaranteed that every root of the map's corresponding forest is contained in the map.
415    pub fn map_root(&self, forest_index: usize, root: &MastNodeId) -> Option<MastNodeId> {
416        self.root_maps.get(forest_index).and_then(|map| map.get(root)).copied()
417    }
418}