Skip to main content

miden_core/mast/merger/
mod.rs

1use alloc::{collections::BTreeMap, vec::Vec};
2
3use crate::{
4    crypto::hash::Blake3Digest,
5    mast::{
6        AsmOpId, DebugVarId, DecoratorId, MastForest, MastForestContributor, MastForestError,
7        MastNode, MastNodeBuilder, MastNodeFingerprint, MastNodeId, MultiMastForestIteratorItem,
8        MultiMastForestNodeIter,
9    },
10    serde::Serializable,
11    utils::{DenseIdMap, IndexVec},
12};
13
14#[cfg(test)]
15mod tests;
16
17/// A type that allows merging [`MastForest`]s.
18///
19/// This functionality is exposed via [`MastForest::merge`]. See its documentation for more details.
20pub(crate) struct MastForestMerger {
21    mast_forest: MastForest,
22    // Internal indices needed for efficient duplicate checking and MastNodeFingerprint
23    // computation.
24    //
25    // These are always in-sync with the nodes in `mast_forest`, i.e. all nodes added to the
26    // `mast_forest` are also added to the indices.
27    node_id_by_hash: BTreeMap<MastNodeFingerprint, MastNodeId>,
28    hash_by_node_id: IndexVec<MastNodeId, MastNodeFingerprint>,
29    decorators_by_hash: BTreeMap<Blake3Digest<32>, DecoratorId>,
30    /// Mappings from old decorator and node ids to their new ids.
31    ///
32    /// Any decorator in `mast_forest` is present as the target of some mapping in this map.
33    decorator_id_mappings: Vec<DenseIdMap<DecoratorId, DecoratorId>>,
34    /// Mappings from previous `MastNodeId`s to their new ids.
35    ///
36    /// Any `MastNodeId` in `mast_forest` is present as the target of some mapping in this map.
37    node_id_mappings: Vec<DenseIdMap<MastNodeId, MastNodeId>>,
38}
39
40impl MastForestMerger {
41    /// Creates a new merger with an initially empty forest and merges all provided [`MastForest`]s
42    /// into it.
43    ///
44    /// # Normalizing Behavior
45    ///
46    /// This function performs normalization of the merged forest, which:
47    /// - Remaps all node IDs to maintain the invariant that child node IDs < parent node IDs
48    /// - Creates a clean, deduplicated forest structure
49    /// - Provides consistent node ordering regardless of input
50    ///
51    /// This normalization is idempotent, but it means that even for single-forest merges, the
52    /// resulting forest may have different node IDs and digests than the input. See assembly
53    /// test `issue_1644_single_forest_merge_identity` for detailed explanation of this
54    /// behavior.
55    pub(crate) fn merge<'forest>(
56        forests: impl IntoIterator<Item = &'forest MastForest>,
57    ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
58        let forests = forests.into_iter().collect::<Vec<_>>();
59
60        let decorator_id_mappings = Vec::with_capacity(forests.len());
61        let node_id_mappings =
62            forests.iter().map(|f| DenseIdMap::with_len(f.nodes().len())).collect();
63
64        let mut merger = Self {
65            node_id_by_hash: BTreeMap::new(),
66            hash_by_node_id: IndexVec::new(),
67            decorators_by_hash: BTreeMap::new(),
68            mast_forest: MastForest::new(),
69            decorator_id_mappings,
70            node_id_mappings,
71        };
72
73        merger.merge_inner(forests.clone())?;
74
75        let Self { mast_forest, node_id_mappings, .. } = merger;
76
77        let root_maps = MastForestRootMap::from_node_id_map(node_id_mappings, forests);
78
79        Ok((mast_forest, root_maps))
80    }
81
82    /// Merges all `forests` into self.
83    ///
84    /// It does this in three steps:
85    ///
86    /// 1. Merge all advice maps, checking for key collisions.
87    /// 2. Merge all decorators, which is a case of deduplication and creating a decorator id
88    ///    mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the
89    ///    merged forest.
90    /// 3. Merge all nodes of forests.
91    ///    - Similar to decorators, node indices might move during merging, so the merger keeps a
92    ///      node id mapping as it merges nodes.
93    ///    - This is a depth-first traversal over all forests to ensure all children are processed
94    ///      before their parents. See the documentation of [`MultiMastForestNodeIter`] for details
95    ///      on this traversal.
96    ///    - Because all parents are processed after their children, we can use the node id mapping
97    ///      to remap all [`MastNodeId`]s of the children to their potentially new id in the merged
98    ///      forest.
99    ///    - If any external node is encountered during this traversal with a digest `foo` for which
100    ///      a `replacement` node exists in another forest with digest `foo`, then the external node
101    ///      will be replaced by that node. In particular, it means we do not want to add the
102    ///      external node to the merged forest, so it is never yielded from the iterator.
103    ///      - Assuming the simple case, where the `replacement` was not visited yet and is just a
104    ///        single node (not a tree), the iterator would first yield the `replacement` node which
105    ///        means it is going to be merged into the forest.
106    ///      - Next the iterator yields [`MultiMastForestIteratorItem::ExternalNodeReplacement`]
107    ///        which signals that an external node was replaced by another node. In this example,
108    ///        the `replacement_*` indices contained in that variant would point to the
109    ///        `replacement` node. Now we can simply add a mapping from the external node to the
110    ///        `replacement` node in our node id mapping which means all nodes that referenced the
111    ///        external node will point to the `replacement` instead.
112    /// 4. Finally, we merge all roots of all forests. Here we map the existing root indices to
113    ///    their potentially new indices in the merged forest and add them to the forest,
114    ///    deduplicating in the process, too.
115    fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> {
116        for other_forest in forests.iter() {
117            self.merge_advice_map(other_forest)?;
118        }
119        for other_forest in forests.iter() {
120            self.merge_decorators(other_forest)?;
121        }
122        for other_forest in forests.iter() {
123            self.merge_error_codes(other_forest)?;
124        }
125
126        let iterator = MultiMastForestNodeIter::new(forests.clone());
127        for item in iterator {
128            match item {
129                MultiMastForestIteratorItem::Node { forest_idx, node_id } => {
130                    let node = forests[forest_idx][node_id].clone();
131                    self.merge_node(forest_idx, node_id, node, &forests)?;
132                },
133                MultiMastForestIteratorItem::ExternalNodeReplacement {
134                    // forest index of the node which replaces the external node
135                    replacement_forest_idx,
136                    // ID of the node that replaces the external node
137                    replacement_mast_node_id,
138                    // forest index of the external node
139                    replaced_forest_idx,
140                    // ID of the external node
141                    replaced_mast_node_id,
142                } => {
143                    // The iterator is not aware of the merged forest, so the node indices it yields
144                    // are for the existing forests. That means we have to map the ID of the
145                    // replacement to its new location, since it was previously merged and its IDs
146                    // have very likely changed.
147                    let mapped_replacement = self.node_id_mappings[replacement_forest_idx]
148                        .get(replacement_mast_node_id)
149                        .expect("every merged node id should be mapped");
150
151                    // SAFETY: The iterator only yields valid forest indices, so it is safe to index
152                    // directly.
153                    self.node_id_mappings[replaced_forest_idx]
154                        .insert(replaced_mast_node_id, mapped_replacement);
155                },
156            }
157        }
158
159        for (forest_idx, forest) in forests.iter().enumerate() {
160            self.merge_roots(forest_idx, forest)?;
161        }
162
163        self.merge_debug_metadata(&forests)?;
164
165        Ok(())
166    }
167
168    fn merge_decorators(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
169        let mut decorator_id_remapping =
170            DenseIdMap::with_len(other_forest.debug_info.num_decorators());
171
172        for (merging_id, merging_decorator) in
173            other_forest.debug_info.decorators().iter().enumerate()
174        {
175            let merging_decorator_hash = merging_decorator.fingerprint();
176            let new_decorator_id = if let Some(existing_decorator) =
177                self.decorators_by_hash.get(&merging_decorator_hash)
178            {
179                *existing_decorator
180            } else {
181                let new_decorator_id = self.mast_forest.add_decorator(merging_decorator.clone())?;
182                self.decorators_by_hash.insert(merging_decorator_hash, new_decorator_id);
183                new_decorator_id
184            };
185
186            decorator_id_remapping
187                .insert(DecoratorId::new_unchecked(merging_id as u32), new_decorator_id);
188        }
189
190        self.decorator_id_mappings.push(decorator_id_remapping);
191
192        Ok(())
193    }
194
195    fn merge_advice_map(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
196        self.mast_forest
197            .advice_map
198            .merge(&other_forest.advice_map)
199            .map_err(|((key, _prev), _new)| MastForestError::AdviceMapKeyCollisionOnMerge(key))
200    }
201
202    fn merge_error_codes(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
203        self.mast_forest.debug_info.extend_error_codes(
204            other_forest.debug_info.error_codes().map(|(k, v)| (*k, v.clone())),
205        );
206        Ok(())
207    }
208
209    fn merge_node(
210        &mut self,
211        forest_idx: usize,
212        merging_id: MastNodeId,
213        node: MastNode,
214        original_forests: &[&MastForest],
215    ) -> Result<(), MastForestError> {
216        // We need to remap the node prior to computing the MastNodeFingerprint.
217        //
218        // This is because the MastNodeFingerprint computation looks up its descendants and
219        // decorators in the internal index, and if we were to pass the original node to
220        // that computation, it would look up the incorrect descendants and decorators
221        // (since the descendant's indices may have changed).
222        //
223        // Remapping at this point is guaranteed to be "complete", meaning all ids of children
224        // will be present in the node id mapping since the DFS iteration guarantees
225        // that all children of this `node` have been processed before this node and
226        // their indices have been added to the mappings.
227        let remapped_builder = self.build_with_remapped_children(
228            merging_id,
229            node,
230            original_forests[forest_idx],
231            &self.node_id_mappings[forest_idx],
232            &self.decorator_id_mappings[forest_idx],
233        )?;
234
235        let base_fingerprint =
236            remapped_builder.fingerprint_for_node(&self.mast_forest, &self.hash_by_node_id)?;
237
238        // Augment with the source node's debug vars so same-ops/different-vars
239        // blocks from different forests are not collapsed.
240        let debug_var_data =
241            serialize_debug_var_content_for_node(original_forests[forest_idx], merging_id);
242        let asm_op_data =
243            serialize_asm_op_content_for_node(original_forests[forest_idx], merging_id);
244        let node_fingerprint = base_fingerprint
245            .augment_with_data(&debug_var_data)
246            .augment_with_data(&asm_op_data);
247
248        match self.lookup_node_by_fingerprint(&node_fingerprint) {
249            Some(matching_node_id) => {
250                // If a node with a matching fingerprint exists, then the merging node is a
251                // duplicate and we remap it to the existing node.
252                self.node_id_mappings[forest_idx].insert(merging_id, matching_node_id);
253            },
254            None => {
255                // If no node with a matching fingerprint exists, then the merging node is
256                // unique and we can add it to the merged forest using builders.
257                let new_node_id = remapped_builder.add_to_forest(&mut self.mast_forest)?;
258                self.node_id_mappings[forest_idx].insert(merging_id, new_node_id);
259
260                // We need to update the indices with the newly inserted nodes
261                // since the MastNodeFingerprint computation requires all descendants of a node
262                // to be in this index. Hence when we encounter a node in the merging forest
263                // which has descendants (Call, Loop, Split, ...), then their descendants need to be
264                // in the indices.
265                self.node_id_by_hash.insert(node_fingerprint, new_node_id);
266                let returned_id = self
267                    .hash_by_node_id
268                    .push(node_fingerprint)
269                    .map_err(|_| MastForestError::TooManyNodes)?;
270                debug_assert_eq!(
271                    returned_id, new_node_id,
272                    "hash_by_node_id push() should return the same node IDs as node_id_by_hash"
273                );
274            },
275        }
276
277        Ok(())
278    }
279
280    fn merge_roots(
281        &mut self,
282        forest_idx: usize,
283        other_forest: &MastForest,
284    ) -> Result<(), MastForestError> {
285        for root_id in other_forest.roots.iter() {
286            // Map the previous root to its possibly new id.
287            let new_root = self.node_id_mappings[forest_idx]
288                .get(*root_id)
289                .expect("all node ids should have an entry");
290            // This takes O(n) where n is the number of roots in the merged forest every time to
291            // check if the root already exists. As the number of roots is relatively low generally,
292            // this should be okay.
293            self.mast_forest.make_root(new_root);
294        }
295
296        Ok(())
297    }
298
299    /// Transfers procedure names, asm ops, and debug vars from the source forests
300    /// into the merged forest, remapping all IDs along the way.
301    ///
302    /// Procedure names are merged separately by digest. Per-node asm-op and debug-var
303    /// metadata are remapped by node ID, and when two source nodes map to the same merged
304    /// node (dedup), the first forest's per-node metadata wins.
305    fn merge_debug_metadata(&mut self, forests: &[&MastForest]) -> Result<(), MastForestError> {
306        // Procedure names are keyed by digest. First name wins so that a later
307        // forest cannot silently rename an already-registered procedure.
308        for forest in forests.iter() {
309            for (digest, name) in forest.debug_info.procedure_names() {
310                if self.mast_forest.debug_info.procedure_name(&digest).is_none() {
311                    self.mast_forest.debug_info.insert_procedure_name(digest, name.clone());
312                }
313            }
314        }
315
316        // Collect per-node asm-op and debug-var registrations across all forests.
317        // BTreeMap gives us sorted-by-node-id iteration, which the CSR requires.
318        let mut asm_entries: BTreeMap<MastNodeId, Vec<(usize, AsmOpId)>> = BTreeMap::new();
319        let mut dbg_entries: BTreeMap<MastNodeId, Vec<(usize, DebugVarId)>> = BTreeMap::new();
320
321        for (forest_idx, forest) in forests.iter().enumerate() {
322            // Copy AssemblyOp objects and build old→new AsmOpId remapping.
323            let mut asm_id_map: BTreeMap<AsmOpId, AsmOpId> = BTreeMap::new();
324            for (raw, asm_op) in forest.debug_info.asm_ops().iter().enumerate() {
325                let old_id = AsmOpId::new(raw as u32);
326                let new_id = self.mast_forest.debug_info.add_asm_op(asm_op.clone())?;
327                asm_id_map.insert(old_id, new_id);
328            }
329
330            // Copy DebugVarInfo objects and build old→new DebugVarId remapping.
331            let mut dbg_id_map: BTreeMap<DebugVarId, DebugVarId> = BTreeMap::new();
332            for (raw, dvar) in forest.debug_info.debug_vars().iter().enumerate() {
333                let old_id = DebugVarId::from(raw as u32);
334                let new_id = self.mast_forest.debug_info.add_debug_var(dvar.clone())?;
335                dbg_id_map.insert(old_id, new_id);
336            }
337
338            // For each source node, remap and store entries. First forest wins per node.
339            for old_raw in 0..forest.num_nodes() {
340                let old_id = MastNodeId::new_unchecked(old_raw);
341                let new_id = match self.node_id_mappings[forest_idx].get(old_id) {
342                    Some(id) => id,
343                    None => continue,
344                };
345
346                if let alloc::collections::btree_map::Entry::Vacant(e) = asm_entries.entry(new_id) {
347                    let ops = forest.debug_info.asm_ops_for_node(old_id);
348                    if !ops.is_empty() {
349                        let remapped =
350                            ops.into_iter().map(|(idx, id)| (idx, asm_id_map[&id])).collect();
351                        e.insert(remapped);
352                    }
353                }
354
355                if let alloc::collections::btree_map::Entry::Vacant(e) = dbg_entries.entry(new_id) {
356                    let vars = forest.debug_info.debug_vars_for_node(old_id);
357                    if !vars.is_empty() {
358                        let remapped =
359                            vars.into_iter().map(|(idx, id)| (idx, dbg_id_map[&id])).collect();
360                        e.insert(remapped);
361                    }
362                }
363            }
364        }
365
366        // Register in node-ID order (CSR sequential constraint).
367        for (node_id, entries) in asm_entries {
368            let num_ops = match &self.mast_forest[node_id] {
369                MastNode::Block(block) => block.num_operations() as usize,
370                _ => entries.iter().map(|(idx, _)| idx + 1).max().unwrap_or(0),
371            };
372            self.mast_forest
373                .debug_info
374                .register_asm_ops(node_id, num_ops, entries)
375                .map_err(|_| MastForestError::TooManyNodes)?;
376        }
377
378        for (node_id, entries) in dbg_entries {
379            self.mast_forest
380                .debug_info
381                .register_op_indexed_debug_vars(node_id, entries)
382                .map_err(|_| MastForestError::TooManyNodes)?;
383        }
384
385        Ok(())
386    }
387
388    // HELPERS
389    // ================================================================================================
390
391    /// Returns the ID of the node in the merged forest that matches the given
392    /// fingerprint, if any.
393    fn lookup_node_by_fingerprint(&self, fingerprint: &MastNodeFingerprint) -> Option<MastNodeId> {
394        self.node_id_by_hash.get(fingerprint).copied()
395    }
396
397    /// Builds a new node with remapped children and decorators using the provided mappings.
398    fn build_with_remapped_children(
399        &self,
400        merging_id: MastNodeId,
401        src: MastNode,
402        original_forest: &MastForest,
403        nmap: &DenseIdMap<MastNodeId, MastNodeId>,
404        dmap: &DenseIdMap<DecoratorId, DecoratorId>,
405    ) -> Result<MastNodeBuilder, MastForestError> {
406        super::build_node_with_remapped_ids(merging_id, src, original_forest, nmap, dmap)
407    }
408}
409
410// HELPERS
411// ================================================================================================
412
413/// Serializes the actual debug var *content* (name, location, etc.) for a node,
414/// producing a stable byte sequence suitable for fingerprint augmentation.
415///
416/// Unlike the assembler's `serialize_debug_vars` (which serializes `(op_idx, DebugVarId)` pairs),
417/// this serializes the resolved DebugVarInfo so that two forests assigning different DebugVarIds
418/// to identical variables still produce the same fingerprint contribution.
419fn serialize_debug_var_content_for_node(forest: &MastForest, node_id: MastNodeId) -> Vec<u8> {
420    let entries = forest.debug_info().debug_vars_for_node(node_id);
421    if entries.is_empty() {
422        return Vec::new();
423    }
424
425    let mut data = Vec::new();
426    for (op_idx, var_id) in entries {
427        data.extend_from_slice(&op_idx.to_le_bytes());
428        if let Some(info) = forest.debug_info().debug_var(var_id) {
429            info.write_into(&mut data);
430        }
431    }
432    data
433}
434
435/// Serializes the actual asm-op content for a node, producing a stable byte
436/// sequence suitable for fingerprint augmentation.
437///
438/// This ensures that nodes with identical structure but different source-mapping
439/// metadata do not collapse during merge.
440fn serialize_asm_op_content_for_node(forest: &MastForest, node_id: MastNodeId) -> Vec<u8> {
441    let entries = forest.debug_info().asm_ops_for_node(node_id);
442    if entries.is_empty() {
443        return Vec::new();
444    }
445
446    let mut data = Vec::new();
447    for (op_idx, asm_op_id) in entries {
448        data.extend_from_slice(&op_idx.to_le_bytes());
449        if let Some(asm_op) = forest.debug_info().asm_op(asm_op_id) {
450            asm_op.context_name().write_into(&mut data);
451            asm_op.op().write_into(&mut data);
452            asm_op.num_cycles().write_into(&mut data);
453            match asm_op.location() {
454                Some(location) => {
455                    data.push(1);
456                    location.uri.write_into(&mut data);
457                    data.extend_from_slice(&u32::from(location.start).to_le_bytes());
458                    data.extend_from_slice(&u32::from(location.end).to_le_bytes());
459                },
460                None => data.push(0),
461            }
462        }
463    }
464    data
465}
466
467// MAST FOREST ROOT MAP
468// ================================================================================================
469
470/// A mapping for the new location of the roots of a [`MastForest`] after a merge.
471///
472/// It maps the roots ([`MastNodeId`]s) of a forest to their new [`MastNodeId`] in the merged
473/// forest. See [`MastForest::merge`] for more details.
474#[derive(Debug, Clone, PartialEq, Eq)]
475pub struct MastForestRootMap {
476    root_maps: Vec<BTreeMap<MastNodeId, MastNodeId>>,
477}
478
479impl MastForestRootMap {
480    fn from_node_id_map(
481        id_map: Vec<DenseIdMap<MastNodeId, MastNodeId>>,
482        forests: Vec<&MastForest>,
483    ) -> Self {
484        let mut root_maps = vec![BTreeMap::new(); forests.len()];
485
486        for (forest_idx, forest) in forests.into_iter().enumerate() {
487            for root in forest.procedure_roots() {
488                let new_id = id_map[forest_idx]
489                    .get(*root)
490                    .expect("every node id should be mapped to its new id");
491                root_maps[forest_idx].insert(*root, new_id);
492            }
493        }
494
495        Self { root_maps }
496    }
497
498    /// Maps the given root to its new location in the merged forest, if such a mapping exists.
499    ///
500    /// It is guaranteed that every root of the map's corresponding forest is contained in the map.
501    pub fn map_root(&self, forest_index: usize, root: &MastNodeId) -> Option<MastNodeId> {
502        self.root_maps.get(forest_index).and_then(|map| map.get(root)).copied()
503    }
504}