miden_core/mast/merger/
mod.rs

1use alloc::{collections::BTreeMap, vec::Vec};
2
3use miden_crypto::{hash::blake::Blake3Digest, utils::collections::KvMap};
4
5use crate::mast::{
6    DecoratorId, MastForest, MastForestError, MastNode, MastNodeFingerprint, MastNodeId,
7    MultiMastForestIteratorItem, MultiMastForestNodeIter,
8};
9
10#[cfg(test)]
11mod tests;
12
13/// A type that allows merging [`MastForest`]s.
14///
15/// This functionality is exposed via [`MastForest::merge`]. See its documentation for more details.
16pub(crate) struct MastForestMerger {
17    mast_forest: MastForest,
18    // Internal indices needed for efficient duplicate checking and MastNodeFingerprint
19    // computation.
20    //
21    // These are always in-sync with the nodes in `mast_forest`, i.e. all nodes added to the
22    // `mast_forest` are also added to the indices.
23    node_id_by_hash: BTreeMap<MastNodeFingerprint, MastNodeId>,
24    hash_by_node_id: BTreeMap<MastNodeId, MastNodeFingerprint>,
25    decorators_by_hash: BTreeMap<Blake3Digest<32>, DecoratorId>,
26    /// Mappings from old decorator and node ids to their new ids.
27    ///
28    /// Any decorator in `mast_forest` is present as the target of some mapping in this map.
29    decorator_id_mappings: Vec<DecoratorIdMap>,
30    /// Mappings from previous `MastNodeId`s to their new ids.
31    ///
32    /// Any `MastNodeId` in `mast_forest` is present as the target of some mapping in this map.
33    node_id_mappings: Vec<MastForestNodeIdMap>,
34}
35
36impl MastForestMerger {
37    /// Creates a new merger with an initially empty forest and merges all provided [`MastForest`]s
38    /// into it.
39    pub(crate) fn merge<'forest>(
40        forests: impl IntoIterator<Item = &'forest MastForest>,
41    ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
42        let forests = forests.into_iter().collect::<Vec<_>>();
43        let decorator_id_mappings = Vec::with_capacity(forests.len());
44        let node_id_mappings = vec![MastForestNodeIdMap::new(); forests.len()];
45
46        let mut merger = Self {
47            node_id_by_hash: BTreeMap::new(),
48            hash_by_node_id: BTreeMap::new(),
49            decorators_by_hash: BTreeMap::new(),
50            mast_forest: MastForest::new(),
51            decorator_id_mappings,
52            node_id_mappings,
53        };
54
55        merger.merge_inner(forests.clone())?;
56
57        let Self { mast_forest, node_id_mappings, .. } = merger;
58
59        let root_maps = MastForestRootMap::from_node_id_map(node_id_mappings, forests);
60
61        Ok((mast_forest, root_maps))
62    }
63
64    /// Merges all `forests` into self.
65    ///
66    /// It does this in three steps:
67    ///
68    /// 1. Merge all advice maps, checking for key collisions.
69    /// 2. Merge all decorators, which is a case of deduplication and creating a decorator id
70    ///    mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the
71    ///    merged forest.
72    /// 3. Merge all nodes of forests.
73    ///    - Similar to decorators, node indices might move during merging, so the merger keeps a
74    ///      node id mapping as it merges nodes.
75    ///    - This is a depth-first traversal over all forests to ensure all children are processed
76    ///      before their parents. See the documentation of [`MultiMastForestNodeIter`] for details
77    ///      on this traversal.
78    ///    - Because all parents are processed after their children, we can use the node id mapping
79    ///      to remap all [`MastNodeId`]s of the children to their potentially new id in the merged
80    ///      forest.
81    ///    - If any external node is encountered during this traversal with a digest `foo` for which
82    ///      a `replacement` node exists in another forest with digest `foo`, then the external node
83    ///      will be replaced by that node. In particular, it means we do not want to add the
84    ///      external node to the merged forest, so it is never yielded from the iterator.
85    ///      - Assuming the simple case, where the `replacement` was not visited yet and is just a
86    ///        single node (not a tree), the iterator would first yield the `replacement` node which
87    ///        means it is going to be merged into the forest.
88    ///      - Next the iterator yields [`MultiMastForestIteratorItem::ExternalNodeReplacement`]
89    ///        which signals that an external node was replaced by another node. In this example,
90    ///        the `replacement_*` indices contained in that variant would point to the
91    ///        `replacement` node. Now we can simply add a mapping from the external node to the
92    ///        `replacement` node in our node id mapping which means all nodes that referenced the
93    ///        external node will point to the `replacement` instead.
94    /// 4. Finally, we merge all roots of all forests. Here we map the existing root indices to
95    ///    their potentially new indices in the merged forest and add them to the forest,
96    ///    deduplicating in the process, too.
97    fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> {
98        for other_forest in forests.iter() {
99            self.merge_advice_map(other_forest)?;
100        }
101        for other_forest in forests.iter() {
102            self.merge_decorators(other_forest)?;
103        }
104        for other_forest in forests.iter() {
105            self.merge_error_codes(other_forest)?;
106        }
107
108        let iterator = MultiMastForestNodeIter::new(forests.clone());
109        for item in iterator {
110            match item {
111                MultiMastForestIteratorItem::Node { forest_idx, node_id } => {
112                    let node = &forests[forest_idx][node_id];
113                    self.merge_node(forest_idx, node_id, node)?;
114                },
115                MultiMastForestIteratorItem::ExternalNodeReplacement {
116                    // forest index of the node which replaces the external node
117                    replacement_forest_idx,
118                    // ID of the node that replaces the external node
119                    replacement_mast_node_id,
120                    // forest index of the external node
121                    replaced_forest_idx,
122                    // ID of the external node
123                    replaced_mast_node_id,
124                } => {
125                    // The iterator is not aware of the merged forest, so the node indices it yields
126                    // are for the existing forests. That means we have to map the ID of the
127                    // replacement to its new location, since it was previously merged and its IDs
128                    // have very likely changed.
129                    let mapped_replacement = self.node_id_mappings[replacement_forest_idx]
130                        .get(&replacement_mast_node_id)
131                        .copied()
132                        .expect("every merged node id should be mapped");
133
134                    // SAFETY: The iterator only yields valid forest indices, so it is safe to index
135                    // directly.
136                    self.node_id_mappings[replaced_forest_idx]
137                        .insert(replaced_mast_node_id, mapped_replacement);
138                },
139            }
140        }
141
142        for (forest_idx, forest) in forests.iter().enumerate() {
143            self.merge_roots(forest_idx, forest)?;
144        }
145
146        Ok(())
147    }
148
149    fn merge_decorators(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
150        let mut decorator_id_remapping = DecoratorIdMap::new(other_forest.decorators.len());
151
152        for (merging_id, merging_decorator) in other_forest.decorators.iter().enumerate() {
153            let merging_decorator_hash = merging_decorator.fingerprint();
154            let new_decorator_id = if let Some(existing_decorator) =
155                self.decorators_by_hash.get(&merging_decorator_hash)
156            {
157                *existing_decorator
158            } else {
159                let new_decorator_id = self.mast_forest.add_decorator(merging_decorator.clone())?;
160                self.decorators_by_hash.insert(merging_decorator_hash, new_decorator_id);
161                new_decorator_id
162            };
163
164            decorator_id_remapping
165                .insert(DecoratorId::new_unchecked(merging_id as u32), new_decorator_id);
166        }
167
168        self.decorator_id_mappings.push(decorator_id_remapping);
169
170        Ok(())
171    }
172
173    fn merge_advice_map(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
174        for (digest, values) in other_forest.advice_map.iter() {
175            if let Some(stored_values) = self.mast_forest.advice_map().get(digest) {
176                if stored_values != values {
177                    return Err(MastForestError::AdviceMapKeyCollisionOnMerge(*digest));
178                }
179            } else {
180                self.mast_forest.advice_map_mut().insert(*digest, values.clone());
181            }
182        }
183        Ok(())
184    }
185
186    fn merge_error_codes(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
187        self.mast_forest.error_codes.extend(other_forest.error_codes.clone());
188        Ok(())
189    }
190
191    fn merge_node(
192        &mut self,
193        forest_idx: usize,
194        merging_id: MastNodeId,
195        node: &MastNode,
196    ) -> Result<(), MastForestError> {
197        // We need to remap the node prior to computing the MastNodeFingerprint.
198        //
199        // This is because the MastNodeFingerprint computation looks up its descendants and
200        // decorators in the internal index, and if we were to pass the original node to
201        // that computation, it would look up the incorrect descendants and decorators
202        // (since the descendant's indices may have changed).
203        //
204        // Remapping at this point is guaranteed to be "complete", meaning all ids of children
205        // will be present in the node id mapping since the DFS iteration guarantees
206        // that all children of this `node` have been processed before this node and
207        // their indices have been added to the mappings.
208        let remapped_node = self.remap_node(forest_idx, node)?;
209
210        let node_fingerprint = MastNodeFingerprint::from_mast_node(
211            &self.mast_forest,
212            &self.hash_by_node_id,
213            &remapped_node,
214        )
215        .expect(
216            "hash_by_node_id should contain the fingerprints of all children of `remapped_node`",
217        );
218
219        match self.lookup_node_by_fingerprint(&node_fingerprint) {
220            Some(matching_node_id) => {
221                // If a node with a matching fingerprint exists, then the merging node is a
222                // duplicate and we remap it to the existing node.
223                self.node_id_mappings[forest_idx].insert(merging_id, matching_node_id);
224            },
225            None => {
226                // If no node with a matching fingerprint exists, then the merging node is
227                // unique and we can add it to the merged forest.
228                let new_node_id = self.mast_forest.add_node(remapped_node)?;
229                self.node_id_mappings[forest_idx].insert(merging_id, new_node_id);
230
231                // We need to update the indices with the newly inserted nodes
232                // since the MastNodeFingerprint computation requires all descendants of a node
233                // to be in this index. Hence when we encounter a node in the merging forest
234                // which has descendants (Call, Loop, Split, ...), then their descendants need to be
235                // in the indices.
236                self.node_id_by_hash.insert(node_fingerprint, new_node_id);
237                self.hash_by_node_id.insert(new_node_id, node_fingerprint);
238            },
239        }
240
241        Ok(())
242    }
243
244    fn merge_roots(
245        &mut self,
246        forest_idx: usize,
247        other_forest: &MastForest,
248    ) -> Result<(), MastForestError> {
249        for root_id in other_forest.roots.iter() {
250            // Map the previous root to its possibly new id.
251            let new_root = self.node_id_mappings[forest_idx]
252                .get(root_id)
253                .expect("all node ids should have an entry");
254            // This takes O(n) where n is the number of roots in the merged forest every time to
255            // check if the root already exists. As the number of roots is relatively low generally,
256            // this should be okay.
257            self.mast_forest.make_root(*new_root);
258        }
259
260        Ok(())
261    }
262
263    /// Remaps a nodes' potentially contained children and decorators to their new IDs according to
264    /// the given maps.
265    fn remap_node(&self, forest_idx: usize, node: &MastNode) -> Result<MastNode, MastForestError> {
266        let map_decorator_id = |decorator_id: &DecoratorId| {
267            self.decorator_id_mappings[forest_idx].get(decorator_id).ok_or_else(|| {
268                MastForestError::DecoratorIdOverflow(
269                    *decorator_id,
270                    self.decorator_id_mappings[forest_idx].len(),
271                )
272            })
273        };
274        let map_decorators = |decorators: &[DecoratorId]| -> Result<Vec<_>, MastForestError> {
275            decorators.iter().map(map_decorator_id).collect()
276        };
277
278        let map_node_id = |node_id: MastNodeId| {
279            self.node_id_mappings[forest_idx]
280                .get(&node_id)
281                .copied()
282                .expect("every node id should have an entry")
283        };
284
285        // Due to DFS postorder iteration all children of node's should have been inserted before
286        // their parents which is why we can `expect` the constructor calls here.
287        let mut mapped_node = match node {
288            MastNode::Join(join_node) => {
289                let first = map_node_id(join_node.first());
290                let second = map_node_id(join_node.second());
291
292                MastNode::new_join(first, second, &self.mast_forest)
293                    .expect("JoinNode children should have been mapped to a lower index")
294            },
295            MastNode::Split(split_node) => {
296                let if_branch = map_node_id(split_node.on_true());
297                let else_branch = map_node_id(split_node.on_false());
298
299                MastNode::new_split(if_branch, else_branch, &self.mast_forest)
300                    .expect("SplitNode children should have been mapped to a lower index")
301            },
302            MastNode::Loop(loop_node) => {
303                let body = map_node_id(loop_node.body());
304                MastNode::new_loop(body, &self.mast_forest)
305                    .expect("LoopNode children should have been mapped to a lower index")
306            },
307            MastNode::Call(call_node) => {
308                let callee = map_node_id(call_node.callee());
309                MastNode::new_call(callee, &self.mast_forest)
310                    .expect("CallNode children should have been mapped to a lower index")
311            },
312            // Other nodes are simply copied.
313            MastNode::Block(basic_block_node) => {
314                MastNode::new_basic_block(
315                    basic_block_node.operations().copied().collect(),
316                    // Operation Indices of decorators stay the same while decorator IDs need to be
317                    // mapped.
318                    Some(
319                        basic_block_node
320                            .decorators()
321                            .iter()
322                            .map(|(idx, decorator_id)| match map_decorator_id(decorator_id) {
323                                Ok(mapped_decorator) => Ok((*idx, mapped_decorator)),
324                                Err(err) => Err(err),
325                            })
326                            .collect::<Result<Vec<_>, _>>()?,
327                    ),
328                )
329                .expect("previously valid BasicBlockNode should still be valid")
330            },
331            MastNode::Dyn(_) => MastNode::new_dyn(),
332            MastNode::External(external_node) => MastNode::new_external(external_node.digest()),
333        };
334
335        // Decorators must be handled specially for basic block nodes.
336        // For other node types we can handle it centrally.
337        if !mapped_node.is_basic_block() {
338            mapped_node.append_before_enter(&map_decorators(node.before_enter())?);
339            mapped_node.append_after_exit(&map_decorators(node.after_exit())?);
340        }
341
342        Ok(mapped_node)
343    }
344
345    // HELPERS
346    // ================================================================================================
347
348    /// Returns a slice of nodes in the merged forest which have the given `mast_root`.
349    fn lookup_node_by_fingerprint(&self, fingerprint: &MastNodeFingerprint) -> Option<MastNodeId> {
350        self.node_id_by_hash.get(fingerprint).copied()
351    }
352}
353
354// MAST FOREST ROOT MAP
355// ================================================================================================
356
357/// A mapping for the new location of the roots of a [`MastForest`] after a merge.
358///
359/// It maps the roots ([`MastNodeId`]s) of a forest to their new [`MastNodeId`] in the merged
360/// forest. See [`MastForest::merge`] for more details.
361#[derive(Debug, Clone, PartialEq, Eq)]
362pub struct MastForestRootMap {
363    root_maps: Vec<BTreeMap<MastNodeId, MastNodeId>>,
364}
365
366impl MastForestRootMap {
367    fn from_node_id_map(id_map: Vec<MastForestNodeIdMap>, forests: Vec<&MastForest>) -> Self {
368        let mut root_maps = vec![BTreeMap::new(); forests.len()];
369
370        for (forest_idx, forest) in forests.into_iter().enumerate() {
371            for root in forest.procedure_roots() {
372                let new_id = id_map[forest_idx]
373                    .get(root)
374                    .copied()
375                    .expect("every node id should be mapped to its new id");
376                root_maps[forest_idx].insert(*root, new_id);
377            }
378        }
379
380        Self { root_maps }
381    }
382
383    /// Maps the given root to its new location in the merged forest, if such a mapping exists.
384    ///
385    /// It is guaranteed that every root of the map's corresponding forest is contained in the map.
386    pub fn map_root(&self, forest_index: usize, root: &MastNodeId) -> Option<MastNodeId> {
387        self.root_maps.get(forest_index).and_then(|map| map.get(root)).copied()
388    }
389}
390
391// DECORATOR ID MAP
392// ================================================================================================
393
394/// A specialized map from [`DecoratorId`] -> [`DecoratorId`].
395///
396/// When mapping Decorator IDs during merging, we always map all IDs of the merging
397/// forest to new ids. Hence it is more efficient to use a `Vec` instead of, say, a `BTreeMap`.
398///
399/// In other words, this type is similar to `BTreeMap<ID, ID>` but takes advantage of the fact that
400/// the keys are contiguous.
401///
402/// This type is meant to encapsulates some guarantees:
403///
404/// - Indexing into the vector for any ID is safe if that ID is valid for the corresponding forest.
405///   Despite that, we still cannot index unconditionally in case a node with invalid
406///   [`DecoratorId`]s is passed to `merge`.
407/// - The entry itself can be either None or Some. However:
408///   - For `DecoratorId`s we iterate and insert all decorators into this map before retrieving any
409///     entry, so all entries contain `Some`. Because of this, we can use `expect` in `get` for the
410///     `Option` value.
411/// - Similarly, inserting any ID from the corresponding forest is safe as the map contains a
412///   pre-allocated `Vec` of the appropriate size.
413struct DecoratorIdMap {
414    inner: Vec<Option<DecoratorId>>,
415}
416
417impl DecoratorIdMap {
418    fn new(num_ids: usize) -> Self {
419        Self { inner: vec![None; num_ids] }
420    }
421
422    /// Maps the given key to the given value.
423    ///
424    /// It is the caller's responsibility to only pass keys that belong to the forest for which this
425    /// map was originally created.
426    fn insert(&mut self, key: DecoratorId, value: DecoratorId) {
427        self.inner[key.as_usize()] = Some(value);
428    }
429
430    /// Retrieves the value for the given key.
431    fn get(&self, key: &DecoratorId) -> Option<DecoratorId> {
432        self.inner
433            .get(key.as_usize())
434            .map(|id| id.expect("every id should have a Some entry in the map when calling get"))
435    }
436
437    fn len(&self) -> usize {
438        self.inner.len()
439    }
440}
441
442/// A type definition for increased readability in function signatures.
443type MastForestNodeIdMap = BTreeMap<MastNodeId, MastNodeId>;