miden_core/mast/merger/mod.rs
1use alloc::{collections::BTreeMap, vec::Vec};
2
3use crate::{
4 crypto::hash::Blake3Digest,
5 mast::{
6 AsmOpId, DebugVarId, DecoratorId, MastForest, MastForestContributor, MastForestError,
7 MastNode, MastNodeBuilder, MastNodeFingerprint, MastNodeId, MultiMastForestIteratorItem,
8 MultiMastForestNodeIter,
9 },
10 serde::Serializable,
11 utils::{DenseIdMap, IndexVec},
12};
13
14#[cfg(test)]
15mod tests;
16
17/// A type that allows merging [`MastForest`]s.
18///
19/// This functionality is exposed via [`MastForest::merge`]. See its documentation for more details.
20pub(crate) struct MastForestMerger {
21 mast_forest: MastForest,
22 // Internal indices needed for efficient duplicate checking and MastNodeFingerprint
23 // computation.
24 //
25 // These are always in-sync with the nodes in `mast_forest`, i.e. all nodes added to the
26 // `mast_forest` are also added to the indices.
27 node_id_by_hash: BTreeMap<MastNodeFingerprint, MastNodeId>,
28 hash_by_node_id: IndexVec<MastNodeId, MastNodeFingerprint>,
29 decorators_by_hash: BTreeMap<Blake3Digest<32>, DecoratorId>,
30 /// Mappings from old decorator and node ids to their new ids.
31 ///
32 /// Any decorator in `mast_forest` is present as the target of some mapping in this map.
33 decorator_id_mappings: Vec<DenseIdMap<DecoratorId, DecoratorId>>,
34 /// Mappings from previous `MastNodeId`s to their new ids.
35 ///
36 /// Any `MastNodeId` in `mast_forest` is present as the target of some mapping in this map.
37 node_id_mappings: Vec<DenseIdMap<MastNodeId, MastNodeId>>,
38}
39
40impl MastForestMerger {
41 /// Creates a new merger with an initially empty forest and merges all provided [`MastForest`]s
42 /// into it.
43 ///
44 /// # Normalizing Behavior
45 ///
46 /// This function performs normalization of the merged forest, which:
47 /// - Remaps all node IDs to maintain the invariant that child node IDs < parent node IDs
48 /// - Creates a clean, deduplicated forest structure
49 /// - Provides consistent node ordering regardless of input
50 ///
51 /// This normalization is idempotent, but it means that even for single-forest merges, the
52 /// resulting forest may have different node IDs and digests than the input. See assembly
53 /// test `issue_1644_single_forest_merge_identity` for detailed explanation of this
54 /// behavior.
55 pub(crate) fn merge<'forest>(
56 forests: impl IntoIterator<Item = &'forest MastForest>,
57 ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
58 let forests = forests.into_iter().collect::<Vec<_>>();
59
60 let decorator_id_mappings = Vec::with_capacity(forests.len());
61 let node_id_mappings =
62 forests.iter().map(|f| DenseIdMap::with_len(f.nodes().len())).collect();
63
64 let mut merger = Self {
65 node_id_by_hash: BTreeMap::new(),
66 hash_by_node_id: IndexVec::new(),
67 decorators_by_hash: BTreeMap::new(),
68 mast_forest: MastForest::new(),
69 decorator_id_mappings,
70 node_id_mappings,
71 };
72
73 merger.merge_inner(forests.clone())?;
74
75 let Self { mast_forest, node_id_mappings, .. } = merger;
76
77 let root_maps = MastForestRootMap::from_node_id_map(node_id_mappings, forests);
78
79 Ok((mast_forest, root_maps))
80 }
81
82 /// Merges all `forests` into self.
83 ///
84 /// It does this in three steps:
85 ///
86 /// 1. Merge all advice maps, checking for key collisions.
87 /// 2. Merge all decorators, which is a case of deduplication and creating a decorator id
88 /// mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the
89 /// merged forest.
90 /// 3. Merge all nodes of forests.
91 /// - Similar to decorators, node indices might move during merging, so the merger keeps a
92 /// node id mapping as it merges nodes.
93 /// - This is a depth-first traversal over all forests to ensure all children are processed
94 /// before their parents. See the documentation of [`MultiMastForestNodeIter`] for details
95 /// on this traversal.
96 /// - Because all parents are processed after their children, we can use the node id mapping
97 /// to remap all [`MastNodeId`]s of the children to their potentially new id in the merged
98 /// forest.
99 /// - If any external node is encountered during this traversal with a digest `foo` for which
100 /// a `replacement` node exists in another forest with digest `foo`, then the external node
101 /// will be replaced by that node. In particular, it means we do not want to add the
102 /// external node to the merged forest, so it is never yielded from the iterator.
103 /// - Assuming the simple case, where the `replacement` was not visited yet and is just a
104 /// single node (not a tree), the iterator would first yield the `replacement` node which
105 /// means it is going to be merged into the forest.
106 /// - Next the iterator yields [`MultiMastForestIteratorItem::ExternalNodeReplacement`]
107 /// which signals that an external node was replaced by another node. In this example,
108 /// the `replacement_*` indices contained in that variant would point to the
109 /// `replacement` node. Now we can simply add a mapping from the external node to the
110 /// `replacement` node in our node id mapping which means all nodes that referenced the
111 /// external node will point to the `replacement` instead.
112 /// 4. Finally, we merge all roots of all forests. Here we map the existing root indices to
113 /// their potentially new indices in the merged forest and add them to the forest,
114 /// deduplicating in the process, too.
115 fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> {
116 for other_forest in forests.iter() {
117 self.merge_advice_map(other_forest)?;
118 }
119 for other_forest in forests.iter() {
120 self.merge_decorators(other_forest)?;
121 }
122 for other_forest in forests.iter() {
123 self.merge_error_codes(other_forest)?;
124 }
125
126 let iterator = MultiMastForestNodeIter::new(forests.clone());
127 for item in iterator {
128 match item {
129 MultiMastForestIteratorItem::Node { forest_idx, node_id } => {
130 let node = forests[forest_idx][node_id].clone();
131 self.merge_node(forest_idx, node_id, node, &forests)?;
132 },
133 MultiMastForestIteratorItem::ExternalNodeReplacement {
134 // forest index of the node which replaces the external node
135 replacement_forest_idx,
136 // ID of the node that replaces the external node
137 replacement_mast_node_id,
138 // forest index of the external node
139 replaced_forest_idx,
140 // ID of the external node
141 replaced_mast_node_id,
142 } => {
143 // The iterator is not aware of the merged forest, so the node indices it yields
144 // are for the existing forests. That means we have to map the ID of the
145 // replacement to its new location, since it was previously merged and its IDs
146 // have very likely changed.
147 let mapped_replacement = self.node_id_mappings[replacement_forest_idx]
148 .get(replacement_mast_node_id)
149 .expect("every merged node id should be mapped");
150
151 // SAFETY: The iterator only yields valid forest indices, so it is safe to index
152 // directly.
153 self.node_id_mappings[replaced_forest_idx]
154 .insert(replaced_mast_node_id, mapped_replacement);
155 },
156 }
157 }
158
159 for (forest_idx, forest) in forests.iter().enumerate() {
160 self.merge_roots(forest_idx, forest)?;
161 }
162
163 self.merge_debug_metadata(&forests)?;
164
165 Ok(())
166 }
167
168 fn merge_decorators(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
169 let mut decorator_id_remapping =
170 DenseIdMap::with_len(other_forest.debug_info.num_decorators());
171
172 for (merging_id, merging_decorator) in
173 other_forest.debug_info.decorators().iter().enumerate()
174 {
175 let merging_decorator_hash = merging_decorator.fingerprint();
176 let new_decorator_id = if let Some(existing_decorator) =
177 self.decorators_by_hash.get(&merging_decorator_hash)
178 {
179 *existing_decorator
180 } else {
181 let new_decorator_id = self.mast_forest.add_decorator(merging_decorator.clone())?;
182 self.decorators_by_hash.insert(merging_decorator_hash, new_decorator_id);
183 new_decorator_id
184 };
185
186 decorator_id_remapping
187 .insert(DecoratorId::new_unchecked(merging_id as u32), new_decorator_id);
188 }
189
190 self.decorator_id_mappings.push(decorator_id_remapping);
191
192 Ok(())
193 }
194
195 fn merge_advice_map(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
196 self.mast_forest
197 .advice_map
198 .merge(&other_forest.advice_map)
199 .map_err(|((key, _prev), _new)| MastForestError::AdviceMapKeyCollisionOnMerge(key))
200 }
201
202 fn merge_error_codes(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
203 self.mast_forest.debug_info.extend_error_codes(
204 other_forest.debug_info.error_codes().map(|(k, v)| (*k, v.clone())),
205 );
206 Ok(())
207 }
208
209 fn merge_node(
210 &mut self,
211 forest_idx: usize,
212 merging_id: MastNodeId,
213 node: MastNode,
214 original_forests: &[&MastForest],
215 ) -> Result<(), MastForestError> {
216 // We need to remap the node prior to computing the MastNodeFingerprint.
217 //
218 // This is because the MastNodeFingerprint computation looks up its descendants and
219 // decorators in the internal index, and if we were to pass the original node to
220 // that computation, it would look up the incorrect descendants and decorators
221 // (since the descendant's indices may have changed).
222 //
223 // Remapping at this point is guaranteed to be "complete", meaning all ids of children
224 // will be present in the node id mapping since the DFS iteration guarantees
225 // that all children of this `node` have been processed before this node and
226 // their indices have been added to the mappings.
227 let remapped_builder = self.build_with_remapped_children(
228 merging_id,
229 node,
230 original_forests[forest_idx],
231 &self.node_id_mappings[forest_idx],
232 &self.decorator_id_mappings[forest_idx],
233 )?;
234
235 let base_fingerprint =
236 remapped_builder.fingerprint_for_node(&self.mast_forest, &self.hash_by_node_id)?;
237
238 // Augment with the source node's debug vars so same-ops/different-vars
239 // blocks from different forests are not collapsed.
240 let debug_var_data =
241 serialize_debug_var_content_for_node(original_forests[forest_idx], merging_id);
242 let asm_op_data =
243 serialize_asm_op_content_for_node(original_forests[forest_idx], merging_id);
244 let node_fingerprint = base_fingerprint
245 .augment_with_data(&debug_var_data)
246 .augment_with_data(&asm_op_data);
247
248 match self.lookup_node_by_fingerprint(&node_fingerprint) {
249 Some(matching_node_id) => {
250 // If a node with a matching fingerprint exists, then the merging node is a
251 // duplicate and we remap it to the existing node.
252 self.node_id_mappings[forest_idx].insert(merging_id, matching_node_id);
253 },
254 None => {
255 // If no node with a matching fingerprint exists, then the merging node is
256 // unique and we can add it to the merged forest using builders.
257 let new_node_id = remapped_builder.add_to_forest(&mut self.mast_forest)?;
258 self.node_id_mappings[forest_idx].insert(merging_id, new_node_id);
259
260 // We need to update the indices with the newly inserted nodes
261 // since the MastNodeFingerprint computation requires all descendants of a node
262 // to be in this index. Hence when we encounter a node in the merging forest
263 // which has descendants (Call, Loop, Split, ...), then their descendants need to be
264 // in the indices.
265 self.node_id_by_hash.insert(node_fingerprint, new_node_id);
266 let returned_id = self
267 .hash_by_node_id
268 .push(node_fingerprint)
269 .map_err(|_| MastForestError::TooManyNodes)?;
270 debug_assert_eq!(
271 returned_id, new_node_id,
272 "hash_by_node_id push() should return the same node IDs as node_id_by_hash"
273 );
274 },
275 }
276
277 Ok(())
278 }
279
280 fn merge_roots(
281 &mut self,
282 forest_idx: usize,
283 other_forest: &MastForest,
284 ) -> Result<(), MastForestError> {
285 for root_id in other_forest.roots.iter() {
286 // Map the previous root to its possibly new id.
287 let new_root = self.node_id_mappings[forest_idx]
288 .get(*root_id)
289 .expect("all node ids should have an entry");
290 // This takes O(n) where n is the number of roots in the merged forest every time to
291 // check if the root already exists. As the number of roots is relatively low generally,
292 // this should be okay.
293 self.mast_forest.make_root(new_root);
294 }
295
296 Ok(())
297 }
298
299 /// Transfers procedure names, asm ops, and debug vars from the source forests
300 /// into the merged forest, remapping all IDs along the way.
301 ///
302 /// Procedure names are merged separately by digest. Per-node asm-op and debug-var
303 /// metadata are remapped by node ID, and when two source nodes map to the same merged
304 /// node (dedup), the first forest's per-node metadata wins.
305 fn merge_debug_metadata(&mut self, forests: &[&MastForest]) -> Result<(), MastForestError> {
306 // Procedure names are keyed by digest. First name wins so that a later
307 // forest cannot silently rename an already-registered procedure.
308 for forest in forests.iter() {
309 for (digest, name) in forest.debug_info.procedure_names() {
310 if self.mast_forest.debug_info.procedure_name(&digest).is_none() {
311 self.mast_forest.debug_info.insert_procedure_name(digest, name.clone());
312 }
313 }
314 }
315
316 // Collect per-node asm-op and debug-var registrations across all forests.
317 // BTreeMap gives us sorted-by-node-id iteration, which the CSR requires.
318 let mut asm_entries: BTreeMap<MastNodeId, Vec<(usize, AsmOpId)>> = BTreeMap::new();
319 let mut dbg_entries: BTreeMap<MastNodeId, Vec<(usize, DebugVarId)>> = BTreeMap::new();
320
321 for (forest_idx, forest) in forests.iter().enumerate() {
322 // Copy AssemblyOp objects and build old→new AsmOpId remapping.
323 let mut asm_id_map: BTreeMap<AsmOpId, AsmOpId> = BTreeMap::new();
324 for (raw, asm_op) in forest.debug_info.asm_ops().iter().enumerate() {
325 let old_id = AsmOpId::new(raw as u32);
326 let new_id = self.mast_forest.debug_info.add_asm_op(asm_op.clone())?;
327 asm_id_map.insert(old_id, new_id);
328 }
329
330 // Copy DebugVarInfo objects and build old→new DebugVarId remapping.
331 let mut dbg_id_map: BTreeMap<DebugVarId, DebugVarId> = BTreeMap::new();
332 for (raw, dvar) in forest.debug_info.debug_vars().iter().enumerate() {
333 let old_id = DebugVarId::from(raw as u32);
334 let new_id = self.mast_forest.debug_info.add_debug_var(dvar.clone())?;
335 dbg_id_map.insert(old_id, new_id);
336 }
337
338 // For each source node, remap and store entries. First forest wins per node.
339 for old_raw in 0..forest.num_nodes() {
340 let old_id = MastNodeId::new_unchecked(old_raw);
341 let new_id = match self.node_id_mappings[forest_idx].get(old_id) {
342 Some(id) => id,
343 None => continue,
344 };
345
346 if let alloc::collections::btree_map::Entry::Vacant(e) = asm_entries.entry(new_id) {
347 let ops = forest.debug_info.asm_ops_for_node(old_id);
348 if !ops.is_empty() {
349 let remapped =
350 ops.into_iter().map(|(idx, id)| (idx, asm_id_map[&id])).collect();
351 e.insert(remapped);
352 }
353 }
354
355 if let alloc::collections::btree_map::Entry::Vacant(e) = dbg_entries.entry(new_id) {
356 let vars = forest.debug_info.debug_vars_for_node(old_id);
357 if !vars.is_empty() {
358 let remapped =
359 vars.into_iter().map(|(idx, id)| (idx, dbg_id_map[&id])).collect();
360 e.insert(remapped);
361 }
362 }
363 }
364 }
365
366 // Register in node-ID order (CSR sequential constraint).
367 for (node_id, entries) in asm_entries {
368 let num_ops = match &self.mast_forest[node_id] {
369 MastNode::Block(block) => block.num_operations() as usize,
370 _ => entries.iter().map(|(idx, _)| idx + 1).max().unwrap_or(0),
371 };
372 self.mast_forest
373 .debug_info
374 .register_asm_ops(node_id, num_ops, entries)
375 .map_err(|_| MastForestError::TooManyNodes)?;
376 }
377
378 for (node_id, entries) in dbg_entries {
379 self.mast_forest
380 .debug_info
381 .register_op_indexed_debug_vars(node_id, entries)
382 .map_err(|_| MastForestError::TooManyNodes)?;
383 }
384
385 Ok(())
386 }
387
388 // HELPERS
389 // ================================================================================================
390
391 /// Returns the ID of the node in the merged forest that matches the given
392 /// fingerprint, if any.
393 fn lookup_node_by_fingerprint(&self, fingerprint: &MastNodeFingerprint) -> Option<MastNodeId> {
394 self.node_id_by_hash.get(fingerprint).copied()
395 }
396
397 /// Builds a new node with remapped children and decorators using the provided mappings.
398 fn build_with_remapped_children(
399 &self,
400 merging_id: MastNodeId,
401 src: MastNode,
402 original_forest: &MastForest,
403 nmap: &DenseIdMap<MastNodeId, MastNodeId>,
404 dmap: &DenseIdMap<DecoratorId, DecoratorId>,
405 ) -> Result<MastNodeBuilder, MastForestError> {
406 super::build_node_with_remapped_ids(merging_id, src, original_forest, nmap, dmap)
407 }
408}
409
410// HELPERS
411// ================================================================================================
412
413/// Serializes the actual debug var *content* (name, location, etc.) for a node,
414/// producing a stable byte sequence suitable for fingerprint augmentation.
415///
416/// Unlike the assembler's `serialize_debug_vars` (which serializes `(op_idx, DebugVarId)` pairs),
417/// this serializes the resolved DebugVarInfo so that two forests assigning different DebugVarIds
418/// to identical variables still produce the same fingerprint contribution.
419fn serialize_debug_var_content_for_node(forest: &MastForest, node_id: MastNodeId) -> Vec<u8> {
420 let entries = forest.debug_info().debug_vars_for_node(node_id);
421 if entries.is_empty() {
422 return Vec::new();
423 }
424
425 let mut data = Vec::new();
426 for (op_idx, var_id) in entries {
427 data.extend_from_slice(&op_idx.to_le_bytes());
428 if let Some(info) = forest.debug_info().debug_var(var_id) {
429 info.write_into(&mut data);
430 }
431 }
432 data
433}
434
435/// Serializes the actual asm-op content for a node, producing a stable byte
436/// sequence suitable for fingerprint augmentation.
437///
438/// This ensures that nodes with identical structure but different source-mapping
439/// metadata do not collapse during merge.
440fn serialize_asm_op_content_for_node(forest: &MastForest, node_id: MastNodeId) -> Vec<u8> {
441 let entries = forest.debug_info().asm_ops_for_node(node_id);
442 if entries.is_empty() {
443 return Vec::new();
444 }
445
446 let mut data = Vec::new();
447 for (op_idx, asm_op_id) in entries {
448 data.extend_from_slice(&op_idx.to_le_bytes());
449 if let Some(asm_op) = forest.debug_info().asm_op(asm_op_id) {
450 asm_op.context_name().write_into(&mut data);
451 asm_op.op().write_into(&mut data);
452 asm_op.num_cycles().write_into(&mut data);
453 match asm_op.location() {
454 Some(location) => {
455 data.push(1);
456 location.uri.write_into(&mut data);
457 data.extend_from_slice(&u32::from(location.start).to_le_bytes());
458 data.extend_from_slice(&u32::from(location.end).to_le_bytes());
459 },
460 None => data.push(0),
461 }
462 }
463 }
464 data
465}
466
467// MAST FOREST ROOT MAP
468// ================================================================================================
469
470/// A mapping for the new location of the roots of a [`MastForest`] after a merge.
471///
472/// It maps the roots ([`MastNodeId`]s) of a forest to their new [`MastNodeId`] in the merged
473/// forest. See [`MastForest::merge`] for more details.
474#[derive(Debug, Clone, PartialEq, Eq)]
475pub struct MastForestRootMap {
476 root_maps: Vec<BTreeMap<MastNodeId, MastNodeId>>,
477}
478
479impl MastForestRootMap {
480 fn from_node_id_map(
481 id_map: Vec<DenseIdMap<MastNodeId, MastNodeId>>,
482 forests: Vec<&MastForest>,
483 ) -> Self {
484 let mut root_maps = vec![BTreeMap::new(); forests.len()];
485
486 for (forest_idx, forest) in forests.into_iter().enumerate() {
487 for root in forest.procedure_roots() {
488 let new_id = id_map[forest_idx]
489 .get(*root)
490 .expect("every node id should be mapped to its new id");
491 root_maps[forest_idx].insert(*root, new_id);
492 }
493 }
494
495 Self { root_maps }
496 }
497
498 /// Maps the given root to its new location in the merged forest, if such a mapping exists.
499 ///
500 /// It is guaranteed that every root of the map's corresponding forest is contained in the map.
501 pub fn map_root(&self, forest_index: usize, root: &MastNodeId) -> Option<MastNodeId> {
502 self.root_maps.get(forest_index).and_then(|map| map.get(root)).copied()
503 }
504}