miden_core/mast/merger/mod.rs
1use alloc::{collections::BTreeMap, vec::Vec};
2
3use miden_crypto::hash::blake::Blake3Digest;
4
5use crate::{
6 DenseIdMap, IndexVec,
7 mast::{
8 BasicBlockNode, CallNode, DecoratorId, DynNode, ExternalNode, JoinNode, LoopNode,
9 MastForest, MastForestError, MastNode, MastNodeFingerprint, MastNodeId,
10 MultiMastForestIteratorItem, MultiMastForestNodeIter, SplitNode, node::MastNodeExt,
11 },
12};
13
14#[cfg(test)]
15mod tests;
16
/// A type that allows merging [`MastForest`]s.
///
/// This functionality is exposed via [`MastForest::merge`]. See its documentation for more details.
///
/// The merger owns the forest being built together with the lookup indices and per-input-forest
/// id mappings that are filled in while merging.
pub(crate) struct MastForestMerger {
    /// The merged forest under construction.
    mast_forest: MastForest,
    // Internal indices needed for efficient duplicate checking and MastNodeFingerprint
    // computation.
    //
    // These are always in-sync with the nodes in `mast_forest`, i.e. all nodes added to the
    // `mast_forest` are also added to the indices.
    node_id_by_hash: BTreeMap<MastNodeFingerprint, MastNodeId>,
    hash_by_node_id: IndexVec<MastNodeId, MastNodeFingerprint>,
    // Decorators already present in `mast_forest`, keyed by their fingerprint; used to
    // deduplicate decorators across the input forests.
    decorators_by_hash: BTreeMap<Blake3Digest<32>, DecoratorId>,
    /// Mappings from previous [`DecoratorId`]s to their new ids, one map per input forest (in
    /// the order the forests were merged).
    ///
    /// Any decorator in `mast_forest` is present as the target of some mapping in this map.
    decorator_id_mappings: Vec<DenseIdMap<DecoratorId, DecoratorId>>,
    /// Mappings from previous `MastNodeId`s to their new ids, one map per input forest.
    ///
    /// Any `MastNodeId` in `mast_forest` is present as the target of some mapping in this map.
    node_id_mappings: Vec<DenseIdMap<MastNodeId, MastNodeId>>,
}
39
40impl MastForestMerger {
41 /// Creates a new merger with an initially empty forest and merges all provided [`MastForest`]s
42 /// into it.
43 ///
44 /// # Normalizing Behavior
45 ///
46 /// This function performs normalization of the merged forest, which:
47 /// - Remaps all node IDs to maintain the invariant that child node IDs < parent node IDs
48 /// - Creates a clean, deduplicated forest structure
49 /// - Provides consistent node ordering regardless of input
50 ///
51 /// This normalization is idempotent, but it means that even for single-forest merges, the
52 /// resulting forest may have different node IDs and digests than the input. See assembly
53 /// test `issue_1644_single_forest_merge_identity` for detailed explanation of this
54 /// behavior.
55 pub(crate) fn merge<'forest>(
56 forests: impl IntoIterator<Item = &'forest MastForest>,
57 ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
58 let forests = forests.into_iter().collect::<Vec<_>>();
59
60 let decorator_id_mappings = Vec::with_capacity(forests.len());
61 let node_id_mappings =
62 forests.iter().map(|f| DenseIdMap::with_len(f.nodes().len())).collect();
63
64 let mut merger = Self {
65 node_id_by_hash: BTreeMap::new(),
66 hash_by_node_id: IndexVec::new(),
67 decorators_by_hash: BTreeMap::new(),
68 mast_forest: MastForest::new(),
69 decorator_id_mappings,
70 node_id_mappings,
71 };
72
73 merger.merge_inner(forests.clone())?;
74
75 let Self { mast_forest, node_id_mappings, .. } = merger;
76
77 let root_maps = MastForestRootMap::from_node_id_map(node_id_mappings, forests);
78
79 Ok((mast_forest, root_maps))
80 }
81
82 /// Merges all `forests` into self.
83 ///
84 /// It does this in three steps:
85 ///
86 /// 1. Merge all advice maps, checking for key collisions.
87 /// 2. Merge all decorators, which is a case of deduplication and creating a decorator id
88 /// mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the
89 /// merged forest.
90 /// 3. Merge all nodes of forests.
91 /// - Similar to decorators, node indices might move during merging, so the merger keeps a
92 /// node id mapping as it merges nodes.
93 /// - This is a depth-first traversal over all forests to ensure all children are processed
94 /// before their parents. See the documentation of [`MultiMastForestNodeIter`] for details
95 /// on this traversal.
96 /// - Because all parents are processed after their children, we can use the node id mapping
97 /// to remap all [`MastNodeId`]s of the children to their potentially new id in the merged
98 /// forest.
99 /// - If any external node is encountered during this traversal with a digest `foo` for which
100 /// a `replacement` node exists in another forest with digest `foo`, then the external node
101 /// will be replaced by that node. In particular, it means we do not want to add the
102 /// external node to the merged forest, so it is never yielded from the iterator.
103 /// - Assuming the simple case, where the `replacement` was not visited yet and is just a
104 /// single node (not a tree), the iterator would first yield the `replacement` node which
105 /// means it is going to be merged into the forest.
106 /// - Next the iterator yields [`MultiMastForestIteratorItem::ExternalNodeReplacement`]
107 /// which signals that an external node was replaced by another node. In this example,
108 /// the `replacement_*` indices contained in that variant would point to the
109 /// `replacement` node. Now we can simply add a mapping from the external node to the
110 /// `replacement` node in our node id mapping which means all nodes that referenced the
111 /// external node will point to the `replacement` instead.
112 /// 4. Finally, we merge all roots of all forests. Here we map the existing root indices to
113 /// their potentially new indices in the merged forest and add them to the forest,
114 /// deduplicating in the process, too.
115 fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> {
116 for other_forest in forests.iter() {
117 self.merge_advice_map(other_forest)?;
118 }
119 for other_forest in forests.iter() {
120 self.merge_decorators(other_forest)?;
121 }
122 for other_forest in forests.iter() {
123 self.merge_error_codes(other_forest)?;
124 }
125
126 let iterator = MultiMastForestNodeIter::new(forests.clone());
127 for item in iterator {
128 match item {
129 MultiMastForestIteratorItem::Node { forest_idx, node_id } => {
130 let node = &forests[forest_idx][node_id];
131 self.merge_node(forest_idx, node_id, node)?;
132 },
133 MultiMastForestIteratorItem::ExternalNodeReplacement {
134 // forest index of the node which replaces the external node
135 replacement_forest_idx,
136 // ID of the node that replaces the external node
137 replacement_mast_node_id,
138 // forest index of the external node
139 replaced_forest_idx,
140 // ID of the external node
141 replaced_mast_node_id,
142 } => {
143 // The iterator is not aware of the merged forest, so the node indices it yields
144 // are for the existing forests. That means we have to map the ID of the
145 // replacement to its new location, since it was previously merged and its IDs
146 // have very likely changed.
147 let mapped_replacement = self.node_id_mappings[replacement_forest_idx]
148 .get(replacement_mast_node_id)
149 .expect("every merged node id should be mapped");
150
151 // SAFETY: The iterator only yields valid forest indices, so it is safe to index
152 // directly.
153 self.node_id_mappings[replaced_forest_idx]
154 .insert(replaced_mast_node_id, mapped_replacement);
155 },
156 }
157 }
158
159 for (forest_idx, forest) in forests.iter().enumerate() {
160 self.merge_roots(forest_idx, forest)?;
161 }
162
163 Ok(())
164 }
165
166 fn merge_decorators(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
167 let mut decorator_id_remapping = DenseIdMap::with_len(other_forest.decorators.len());
168
169 for (merging_id, merging_decorator) in other_forest.decorators.iter().enumerate() {
170 let merging_decorator_hash = merging_decorator.fingerprint();
171 let new_decorator_id = if let Some(existing_decorator) =
172 self.decorators_by_hash.get(&merging_decorator_hash)
173 {
174 *existing_decorator
175 } else {
176 let new_decorator_id = self.mast_forest.add_decorator(merging_decorator.clone())?;
177 self.decorators_by_hash.insert(merging_decorator_hash, new_decorator_id);
178 new_decorator_id
179 };
180
181 decorator_id_remapping
182 .insert(DecoratorId::new_unchecked(merging_id as u32), new_decorator_id);
183 }
184
185 self.decorator_id_mappings.push(decorator_id_remapping);
186
187 Ok(())
188 }
189
190 fn merge_advice_map(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
191 self.mast_forest
192 .advice_map
193 .merge(&other_forest.advice_map)
194 .map_err(|((key, _prev), _new)| MastForestError::AdviceMapKeyCollisionOnMerge(key))
195 }
196
197 fn merge_error_codes(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
198 self.mast_forest.error_codes.extend(other_forest.error_codes.clone());
199 Ok(())
200 }
201
202 fn merge_node(
203 &mut self,
204 forest_idx: usize,
205 merging_id: MastNodeId,
206 node: &MastNode,
207 ) -> Result<(), MastForestError> {
208 // We need to remap the node prior to computing the MastNodeFingerprint.
209 //
210 // This is because the MastNodeFingerprint computation looks up its descendants and
211 // decorators in the internal index, and if we were to pass the original node to
212 // that computation, it would look up the incorrect descendants and decorators
213 // (since the descendant's indices may have changed).
214 //
215 // Remapping at this point is guaranteed to be "complete", meaning all ids of children
216 // will be present in the node id mapping since the DFS iteration guarantees
217 // that all children of this `node` have been processed before this node and
218 // their indices have been added to the mappings.
219 let remapped_node = self.remap_node(forest_idx, node)?;
220
221 let node_fingerprint = MastNodeFingerprint::from_mast_node(
222 &self.mast_forest,
223 &self.hash_by_node_id,
224 &remapped_node,
225 )
226 .expect(
227 "hash_by_node_id should contain the fingerprints of all children of `remapped_node`",
228 );
229
230 match self.lookup_node_by_fingerprint(&node_fingerprint) {
231 Some(matching_node_id) => {
232 // If a node with a matching fingerprint exists, then the merging node is a
233 // duplicate and we remap it to the existing node.
234 self.node_id_mappings[forest_idx].insert(merging_id, matching_node_id);
235 },
236 None => {
237 // If no node with a matching fingerprint exists, then the merging node is
238 // unique and we can add it to the merged forest.
239 let new_node_id = self.mast_forest.add_node(remapped_node)?;
240 self.node_id_mappings[forest_idx].insert(merging_id, new_node_id);
241
242 // We need to update the indices with the newly inserted nodes
243 // since the MastNodeFingerprint computation requires all descendants of a node
244 // to be in this index. Hence when we encounter a node in the merging forest
245 // which has descendants (Call, Loop, Split, ...), then their descendants need to be
246 // in the indices.
247 self.node_id_by_hash.insert(node_fingerprint, new_node_id);
248 let returned_id = self
249 .hash_by_node_id
250 .push(node_fingerprint)
251 .map_err(|_| MastForestError::TooManyNodes)?;
252 debug_assert_eq!(
253 returned_id, new_node_id,
254 "hash_by_node_id push() should return the same node IDs as node_id_by_hash"
255 );
256 },
257 }
258
259 Ok(())
260 }
261
262 fn merge_roots(
263 &mut self,
264 forest_idx: usize,
265 other_forest: &MastForest,
266 ) -> Result<(), MastForestError> {
267 for root_id in other_forest.roots.iter() {
268 // Map the previous root to its possibly new id.
269 let new_root = self.node_id_mappings[forest_idx]
270 .get(*root_id)
271 .expect("all node ids should have an entry");
272 // This takes O(n) where n is the number of roots in the merged forest every time to
273 // check if the root already exists. As the number of roots is relatively low generally,
274 // this should be okay.
275 self.mast_forest.make_root(new_root);
276 }
277
278 Ok(())
279 }
280
281 /// Remaps a nodes' potentially contained children and decorators to their new IDs according to
282 /// the given maps.
283 fn remap_node(&self, forest_idx: usize, node: &MastNode) -> Result<MastNode, MastForestError> {
284 self.build_node_with_remapped_children(
285 node,
286 &self.node_id_mappings[forest_idx],
287 &self.decorator_id_mappings[forest_idx],
288 )
289 }
290
291 // HELPERS
292 // ================================================================================================
293
294 /// Remaps a child node ID using the node ID map, returning the original ID if not found.
295 fn remap_child(
296 &self,
297 child_id: MastNodeId,
298 nmap: &DenseIdMap<MastNodeId, MastNodeId>,
299 ) -> MastNodeId {
300 nmap.get(child_id).expect("every node id should have an entry")
301 }
302
303 /// Returns the ID of the node in the merged forest that matches the given
304 /// fingerprint, if any.
305 fn lookup_node_by_fingerprint(&self, fingerprint: &MastNodeFingerprint) -> Option<MastNodeId> {
306 self.node_id_by_hash.get(fingerprint).copied()
307 }
308
309 /// Builds a new node with remapped children and decorators using the provided mappings.
310 fn build_node_with_remapped_children(
311 &self,
312 src: &MastNode,
313 nmap: &DenseIdMap<MastNodeId, MastNodeId>,
314 dmap: &DenseIdMap<DecoratorId, DecoratorId>,
315 ) -> Result<MastNode, MastForestError> {
316 let map_decorator_id = |decorator_id: DecoratorId| {
317 dmap.get(decorator_id)
318 .ok_or_else(|| MastForestError::DecoratorIdOverflow(decorator_id, dmap.len()))
319 };
320
321 let map_decorators = |decorators: &[DecoratorId]| -> Result<Vec<_>, MastForestError> {
322 decorators.iter().copied().map(map_decorator_id).collect()
323 };
324
325 let mut mapped_node: MastNode = match src {
326 MastNode::Join(join_node) => {
327 let first = self.remap_child(join_node.first(), nmap);
328 let second = self.remap_child(join_node.second(), nmap);
329
330 JoinNode::new([first, second], &self.mast_forest)
331 .expect("JoinNode children should have been mapped to a lower index")
332 .into()
333 },
334 MastNode::Split(split_node) => {
335 let if_branch = self.remap_child(split_node.on_true(), nmap);
336 let else_branch = self.remap_child(split_node.on_false(), nmap);
337
338 SplitNode::new([if_branch, else_branch], &self.mast_forest)
339 .expect("SplitNode children should have been mapped to a lower index")
340 .into()
341 },
342 MastNode::Loop(loop_node) => {
343 let body = self.remap_child(loop_node.body(), nmap);
344 LoopNode::new(body, &self.mast_forest)
345 .expect("LoopNode children should have been mapped to a lower index")
346 .into()
347 },
348 MastNode::Call(call_node) => {
349 let callee = self.remap_child(call_node.callee(), nmap);
350 CallNode::new(callee, &self.mast_forest)
351 .expect("CallNode children should have been mapped to a lower index")
352 .into()
353 },
354 MastNode::Block(basic_block_node) => BasicBlockNode::new(
355 basic_block_node.operations().copied().collect(),
356 basic_block_node
357 .indexed_decorator_iter()
358 .map(|(idx, decorator_id)| {
359 let mapped_decorator = map_decorator_id(decorator_id)?;
360 Ok((idx, mapped_decorator))
361 })
362 .collect::<Result<Vec<_>, _>>()?,
363 )
364 .expect("previously valid BasicBlockNode should still be valid")
365 .into(),
366 MastNode::Dyn(_) => DynNode::new_dyn().into(),
367 MastNode::External(external_node) => ExternalNode::new(external_node.digest()).into(),
368 };
369
370 // Decorators must be handled specially for basic block nodes.
371 // For other node types we can handle it centrally.
372 {
373 mapped_node.append_before_enter(&map_decorators(src.before_enter())?);
374 mapped_node.append_after_exit(&map_decorators(src.after_exit())?);
375 }
376
377 Ok(mapped_node)
378 }
379}
380
381// MAST FOREST ROOT MAP
382// ================================================================================================
383
/// A mapping for the new location of the roots of a [`MastForest`] after a merge.
///
/// It maps the roots ([`MastNodeId`]s) of a forest to their new [`MastNodeId`] in the merged
/// forest. See [`MastForest::merge`] for more details.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MastForestRootMap {
    // One map per merged input forest, stored at that forest's position in the merge input.
    // Each map goes from a root's id in the input forest to its id in the merged forest.
    root_maps: Vec<BTreeMap<MastNodeId, MastNodeId>>,
}
392
393impl MastForestRootMap {
394 fn from_node_id_map(
395 id_map: Vec<DenseIdMap<MastNodeId, MastNodeId>>,
396 forests: Vec<&MastForest>,
397 ) -> Self {
398 let mut root_maps = vec![BTreeMap::new(); forests.len()];
399
400 for (forest_idx, forest) in forests.into_iter().enumerate() {
401 for root in forest.procedure_roots() {
402 let new_id = id_map[forest_idx]
403 .get(*root)
404 .expect("every node id should be mapped to its new id");
405 root_maps[forest_idx].insert(*root, new_id);
406 }
407 }
408
409 Self { root_maps }
410 }
411
412 /// Maps the given root to its new location in the merged forest, if such a mapping exists.
413 ///
414 /// It is guaranteed that every root of the map's corresponding forest is contained in the map.
415 pub fn map_root(&self, forest_index: usize, root: &MastNodeId) -> Option<MastNodeId> {
416 self.root_maps.get(forest_index).and_then(|map| map.get(root)).copied()
417 }
418}