miden_core/mast/merger/mod.rs
1use alloc::{collections::BTreeMap, vec::Vec};
2
3use crate::{
4 Word,
5 mast::{
6 MastForest, MastForestContributor, MastForestError, MastNode, MastNodeBuilder, MastNodeId,
7 MultiMastForestIteratorItem, MultiMastForestNodeIter,
8 },
9 utils::{DenseIdMap, IndexVec},
10};
11
12#[cfg(test)]
13mod tests;
14
15/// A type that allows merging [`MastForest`]s.
16///
17/// This functionality is exposed via [`MastForest::merge`]. See its documentation for more details.
18pub(crate) struct MastForestMerger {
19 mast_forest: MastForest,
20 // Internal indices needed for efficient duplicate checking.
21 //
22 // These are always in-sync with the nodes in `mast_forest`, i.e. all nodes added to the
23 // `mast_forest` are also added to the indices.
24 node_id_by_hash: BTreeMap<Word, MastNodeId>,
25 hash_by_node_id: IndexVec<MastNodeId, Word>,
26 /// Mappings from previous `MastNodeId`s to their new ids.
27 ///
28 /// Any `MastNodeId` in `mast_forest` is present as the target of some mapping in this map.
29 node_id_mappings: Vec<DenseIdMap<MastNodeId, MastNodeId>>,
30}
31
32impl MastForestMerger {
33 /// Creates a new merger with an initially empty forest and merges all provided [`MastForest`]s
34 /// into it.
35 ///
36 /// # Normalizing Behavior
37 ///
38 /// This function performs normalization of the merged forest, which:
39 /// - Remaps all node IDs to maintain the invariant that child node IDs < parent node IDs
40 /// - Creates a clean, deduplicated forest structure
41 /// - Provides consistent node ordering regardless of input
42 ///
43 /// This normalization is idempotent, but it means that even for single-forest merges, the
44 /// resulting forest may have different node IDs and digests than the input. See assembly
45 /// test `issue_1644_single_forest_merge_identity` for detailed explanation of this
46 /// behavior.
47 pub(crate) fn merge<'forest>(
48 forests: impl IntoIterator<Item = &'forest MastForest>,
49 ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
50 let forests = forests.into_iter().collect::<Vec<_>>();
51
52 let node_id_mappings =
53 forests.iter().map(|f| DenseIdMap::with_len(f.nodes().len())).collect();
54
55 let mut merger = Self {
56 node_id_by_hash: BTreeMap::new(),
57 hash_by_node_id: IndexVec::new(),
58 mast_forest: MastForest::new(),
59 node_id_mappings,
60 };
61
62 merger.merge_inner(forests.clone())?;
63
64 let Self { mast_forest, node_id_mappings, .. } = merger;
65
66 let root_maps = MastForestRootMap::from_node_id_map(node_id_mappings, forests);
67
68 Ok((mast_forest, root_maps))
69 }
70
71 /// Merges all `forests` into self.
72 ///
73 /// It does this in three steps:
74 ///
75 /// 1. Merge all advice maps, checking for key collisions.
76 /// 2. Merge all nodes of forests.
77 /// - Node indices might move during merging, so the merger keeps a node id mapping as it
78 /// merges nodes.
79 /// - This is a depth-first traversal over all forests to ensure all children are processed
80 /// before their parents. See the documentation of [`MultiMastForestNodeIter`] for details
81 /// on this traversal.
82 /// - Because all parents are processed after their children, we can use the node id mapping
83 /// to remap all [`MastNodeId`]s of the children to their potentially new id in the merged
84 /// forest.
85 /// - If any external node is encountered during this traversal with a digest `foo` for which
86 /// a `replacement` node exists in another forest with digest `foo`, then the external node
87 /// will be replaced by that node. In particular, it means we do not want to add the
88 /// external node to the merged forest, so it is never yielded from the iterator.
89 /// - Assuming the simple case, where the `replacement` was not visited yet and is just a
90 /// single node (not a tree), the iterator would first yield the `replacement` node which
91 /// means it is going to be merged into the forest.
92 /// - Next the iterator yields [`MultiMastForestIteratorItem::ExternalNodeReplacement`]
93 /// which signals that an external node was replaced by another node. In this example,
94 /// the `replacement_*` indices contained in that variant would point to the
95 /// `replacement` node. Now we can simply add a mapping from the external node to the
96 /// `replacement` node in our node id mapping which means all nodes that referenced the
97 /// external node will point to the `replacement` instead.
98 /// 3. Finally, we merge all roots of all forests. Here we map the existing root indices to
99 /// their potentially new indices in the merged forest and add them to the forest,
100 /// deduplicating in the process, too.
101 fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> {
102 for other_forest in forests.iter() {
103 self.merge_advice_map(other_forest)?;
104 }
105 let iterator = MultiMastForestNodeIter::new(forests.clone());
106 for item in iterator {
107 match item {
108 MultiMastForestIteratorItem::Node { forest_idx, node_id } => {
109 let node = forests[forest_idx][node_id].clone();
110 self.merge_node(forest_idx, node_id, node, &forests)?;
111 },
112 MultiMastForestIteratorItem::ExternalNodeReplacement {
113 // forest index of the node which replaces the external node
114 replacement_forest_idx,
115 // ID of the node that replaces the external node
116 replacement_mast_node_id,
117 // forest index of the external node
118 replaced_forest_idx,
119 // ID of the external node
120 replaced_mast_node_id,
121 } => {
122 // The iterator is not aware of the merged forest, so the node indices it yields
123 // are for the existing forests. That means we have to map the ID of the
124 // replacement to its new location, since it was previously merged and its IDs
125 // have very likely changed.
126 let mapped_replacement = self.node_id_mappings[replacement_forest_idx]
127 .get(replacement_mast_node_id)
128 .expect("every merged node id should be mapped");
129
130 // SAFETY: The iterator only yields valid forest indices, so it is safe to index
131 // directly.
132 self.node_id_mappings[replaced_forest_idx]
133 .insert(replaced_mast_node_id, mapped_replacement);
134 },
135 }
136 }
137
138 for (forest_idx, forest) in forests.iter().enumerate() {
139 self.merge_roots(forest_idx, forest);
140 }
141
142 Ok(())
143 }
144
145 fn merge_advice_map(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
146 self.mast_forest
147 .advice_map
148 .merge(&other_forest.advice_map)
149 .map_err(|((key, _prev), _new)| MastForestError::AdviceMapKeyCollisionOnMerge(key))
150 }
151
152 fn merge_node(
153 &mut self,
154 forest_idx: usize,
155 merging_id: MastNodeId,
156 node: MastNode,
157 original_forests: &[&MastForest],
158 ) -> Result<(), MastForestError> {
159 // We need to remap the node prior to computing the node fingerprint since child IDs may
160 // have changed in the merged forest.
161 //
162 // Remapping at this point is guaranteed to be "complete", meaning all IDs of children
163 // will be present in the node id mapping since the DFS iteration guarantees
164 // that all children of this `node` have been processed before this node and
165 // their indices have been added to the mappings.
166 let remapped_builder = self.build_with_remapped_children(
167 merging_id,
168 node,
169 original_forests[forest_idx],
170 &self.node_id_mappings[forest_idx],
171 )?;
172
173 let node_fingerprint =
174 remapped_builder.fingerprint_for_node(&self.mast_forest, &self.hash_by_node_id)?;
175
176 match self.lookup_node_by_fingerprint(&node_fingerprint) {
177 Some(matching_node_id) => {
178 // If a node with a matching fingerprint exists, then the merging node is a
179 // duplicate and we remap it to the existing node.
180 self.node_id_mappings[forest_idx].insert(merging_id, matching_node_id);
181 },
182 None => {
183 // If no node with a matching fingerprint exists, then the merging node is
184 // unique and we can add it to the merged forest using builders.
185 let new_node_id = remapped_builder.add_to_forest(&mut self.mast_forest)?;
186 self.node_id_mappings[forest_idx].insert(merging_id, new_node_id);
187
188 self.node_id_by_hash.insert(node_fingerprint, new_node_id);
189 let returned_id = self
190 .hash_by_node_id
191 .push(node_fingerprint)
192 .map_err(|_| MastForestError::TooManyNodes)?;
193 debug_assert_eq!(
194 returned_id, new_node_id,
195 "hash_by_node_id push() should return the same node IDs as node_id_by_hash"
196 );
197 },
198 }
199
200 Ok(())
201 }
202
203 fn merge_roots(&mut self, forest_idx: usize, other_forest: &MastForest) {
204 for root_id in other_forest.roots.iter() {
205 // Map the previous root to its possibly new id.
206 let new_root = self.node_id_mappings[forest_idx]
207 .get(*root_id)
208 .expect("all node ids should have an entry");
209 // This takes O(n) where n is the number of roots in the merged forest every time to
210 // check if the root already exists. As the number of roots is relatively low generally,
211 // this should be okay.
212 self.mast_forest.mark_root(new_root);
213 }
214 }
215
216 // HELPERS
217 // ================================================================================================
218
219 /// Returns the ID of the node in the merged forest that matches the given
220 /// fingerprint, if any.
221 fn lookup_node_by_fingerprint(&self, fingerprint: &Word) -> Option<MastNodeId> {
222 self.node_id_by_hash.get(fingerprint).copied()
223 }
224
225 /// Builds a new node with remapped children using the provided mappings.
226 fn build_with_remapped_children(
227 &self,
228 merging_id: MastNodeId,
229 src: MastNode,
230 original_forest: &MastForest,
231 nmap: &DenseIdMap<MastNodeId, MastNodeId>,
232 ) -> Result<MastNodeBuilder, MastForestError> {
233 super::build_node_with_remapped_ids(merging_id, src, original_forest, nmap)
234 }
235}
236
237// MAST FOREST ROOT MAP
238// ================================================================================================
239
240/// A mapping for the new location of the roots of a [`MastForest`] after a merge.
241///
242/// It maps the roots ([`MastNodeId`]s) of a forest to their new [`MastNodeId`] in the merged
243/// forest. See [`MastForest::merge`] for more details.
244#[derive(Debug, Clone, Default, PartialEq, Eq)]
245pub struct MastForestRootMap {
246 node_maps: Vec<BTreeMap<MastNodeId, MastNodeId>>,
247 root_maps: Vec<BTreeMap<MastNodeId, MastNodeId>>,
248}
249
250impl MastForestRootMap {
251 fn from_node_id_map(
252 id_map: Vec<DenseIdMap<MastNodeId, MastNodeId>>,
253 forests: Vec<&MastForest>,
254 ) -> Self {
255 let mut node_maps = vec![BTreeMap::new(); forests.len()];
256 let mut root_maps = vec![BTreeMap::new(); forests.len()];
257
258 for (forest_idx, forest) in forests.into_iter().enumerate() {
259 for (node_idx, _) in forest.nodes().iter().enumerate() {
260 let node_id = MastNodeId::new_unchecked(
261 node_idx.try_into().expect("MastForest node index exceeds u32"),
262 );
263 if let Some(new_id) = id_map[forest_idx].get(node_id) {
264 node_maps[forest_idx].insert(node_id, new_id);
265 }
266 }
267 for root in forest.procedure_roots() {
268 let new_id = id_map[forest_idx]
269 .get(*root)
270 .expect("every node id should be mapped to its new id");
271 root_maps[forest_idx].insert(*root, new_id);
272 }
273 }
274
275 Self { node_maps, root_maps }
276 }
277
278 /// Maps any node from the given input forest to its new location in the merged forest.
279 ///
280 /// This includes non-root nodes, which is required when remapping package-owned source/debug
281 /// graphs after [`MastForest::merge`].
282 pub fn map_node(&self, forest_index: usize, node: &MastNodeId) -> Option<MastNodeId> {
283 self.node_maps.get(forest_index).and_then(|map| map.get(node)).copied()
284 }
285
286 /// Maps the given root to its new location in the merged forest, if such a mapping exists.
287 ///
288 /// It is guaranteed that every root of the map's corresponding forest is contained in the map.
289 pub fn map_root(&self, forest_index: usize, root: &MastNodeId) -> Option<MastNodeId> {
290 self.root_maps.get(forest_index).and_then(|map| map.get(root)).copied()
291 }
292}