miden_crypto/merkle/partial_mt/
mod.rs1use alloc::{
2 collections::{BTreeMap, BTreeSet},
3 string::String,
4 vec::Vec,
5};
6use core::fmt;
7
8use super::{
9 EMPTY_WORD, InnerNodeInfo, MerkleError, MerklePath, MerkleProof, NodeIndex, Rpo256, Word,
10};
11use crate::utils::{
12 ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, word_to_hex,
13};
14
15#[cfg(test)]
16mod tests;
17
18const ROOT_INDEX: NodeIndex = NodeIndex::root();
23
24const EMPTY_DIGEST: Word = EMPTY_WORD;
26
/// A partial Merkle tree: a subset of the nodes of a binary Merkle tree, populated either from
/// authentication paths (`PartialMerkleTree::add_path`) or from a set of leaves
/// (`PartialMerkleTree::with_leaves`). Nodes that are not stored are unknown to the tree.
///
/// NOTE(review): the optional serde derive restores the fields as-is, without the consistency
/// checks performed by `Deserializable::read_from` (which rebuilds via `with_leaves`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct PartialMerkleTree {
    /// Depth of the deepest path or leaf added to the tree (0 for an empty tree).
    max_depth: u8,
    /// Every stored node (provided leaves, path siblings, computed inner nodes and the root),
    /// keyed by its index.
    nodes: BTreeMap<NodeIndex, Word>,
    /// Indices of the nodes treated as leaves; expected to be a subset of the keys of `nodes`.
    leaves: BTreeSet<NodeIndex>,
}
41
42impl Default for PartialMerkleTree {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl PartialMerkleTree {
49 pub const MIN_DEPTH: u8 = 1;
54
55 pub const MAX_DEPTH: u8 = 64;
57
58 pub fn new() -> Self {
63 PartialMerkleTree {
64 max_depth: 0,
65 nodes: BTreeMap::new(),
66 leaves: BTreeSet::new(),
67 }
68 }
69
70 pub fn with_paths<I>(paths: I) -> Result<Self, MerkleError>
74 where
75 I: IntoIterator<Item = (u64, Word, MerklePath)>,
76 {
77 let tree = PartialMerkleTree::new();
79
80 paths.into_iter().try_fold(tree, |mut tree, (index, value, path)| {
81 tree.add_path(index, value, path)?;
82 Ok(tree)
83 })
84 }
85
86 pub fn with_leaves<R, I>(entries: R) -> Result<Self, MerkleError>
95 where
96 R: IntoIterator<IntoIter = I>,
97 I: Iterator<Item = (NodeIndex, Word)> + ExactSizeIterator,
98 {
99 let mut layers: BTreeMap<u8, Vec<u64>> = BTreeMap::new();
100 let mut leaves = BTreeSet::new();
101 let mut nodes = BTreeMap::new();
102
103 for (node_index, hash) in entries.into_iter() {
106 leaves.insert(node_index);
107 nodes.insert(node_index, hash);
108 layers
109 .entry(node_index.depth())
110 .and_modify(|layer_vec| layer_vec.push(node_index.value()))
111 .or_insert(vec![node_index.value()]);
112 }
113
114 if let Some(last_layer) = layers.last_entry() {
116 let last_layer_depth = *last_layer.key();
117 if last_layer_depth > 64 {
118 return Err(MerkleError::TooManyEntries(last_layer_depth));
119 }
120 }
121
122 let max_depth = *layers.keys().next_back().unwrap_or(&0);
124
125 for depth in 0..max_depth {
127 layers.entry(depth).or_default();
128 }
129
130 let mut layer_iter = layers.into_values().rev();
131 let mut parent_layer = layer_iter.next().unwrap();
132 let mut current_layer;
133
134 for depth in (1..max_depth + 1).rev() {
135 current_layer = layer_iter.next().unwrap();
137 core::mem::swap(&mut current_layer, &mut parent_layer);
138
139 for index_value in current_layer {
140 let parent_node = NodeIndex::new(depth - 1, index_value / 2)?;
142
143 if !parent_layer.contains(&parent_node.value()) {
146 let index = NodeIndex::new(depth, index_value)?;
148
149 let node =
151 nodes.get(&index).ok_or(MerkleError::NodeIndexNotFoundInTree(index))?;
152 let sibling = nodes
154 .get(&index.sibling())
155 .ok_or(MerkleError::NodeIndexNotFoundInTree(index.sibling()))?;
156 let parent = Rpo256::merge(&index.build_node(*node, *sibling));
158
159 parent_layer.push(parent_node.value());
161 nodes.insert(parent_node, parent);
163 }
164 }
165 }
166
167 Ok(PartialMerkleTree { max_depth, nodes, leaves })
168 }
169
170 pub fn root(&self) -> Word {
175 self.nodes.get(&ROOT_INDEX).cloned().unwrap_or(EMPTY_DIGEST)
176 }
177
178 pub fn max_depth(&self) -> u8 {
180 self.max_depth
181 }
182
183 pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
188 self.nodes
189 .get(&index)
190 .ok_or(MerkleError::NodeIndexNotFoundInTree(index))
191 .copied()
192 }
193
194 pub fn is_leaf(&self, index: NodeIndex) -> bool {
196 self.leaves.contains(&index)
197 }
198
199 pub fn to_paths(&self) -> Vec<(NodeIndex, MerkleProof)> {
201 let mut paths = Vec::new();
202 self.leaves.iter().for_each(|&leaf| {
203 paths.push((
204 leaf,
205 MerkleProof {
206 value: self.get_node(leaf).expect("Failed to get leaf node"),
207 path: self.get_path(leaf).expect("Failed to get path"),
208 },
209 ));
210 });
211 paths
212 }
213
214 pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
224 if index.is_root() {
225 return Err(MerkleError::DepthTooSmall(index.depth()));
226 } else if index.depth() > self.max_depth() {
227 return Err(MerkleError::DepthTooBig(index.depth() as u64));
228 }
229
230 if !self.nodes.contains_key(&index) {
231 return Err(MerkleError::NodeIndexNotFoundInTree(index));
232 }
233
234 let mut path = Vec::new();
235 for _ in 0..index.depth() {
236 let sibling_index = index.sibling();
237 index.move_up();
238 let sibling =
239 self.nodes.get(&sibling_index).cloned().expect("Sibling node not in the map");
240 path.push(sibling);
241 }
242 Ok(MerklePath::new(path))
243 }
244
245 pub fn leaves(&self) -> impl Iterator<Item = (NodeIndex, Word)> + '_ {
250 self.leaves.iter().map(|&leaf| {
251 (
252 leaf,
253 self.get_node(leaf)
254 .unwrap_or_else(|_| panic!("Leaf with {leaf} is not in the nodes map")),
255 )
256 })
257 }
258
259 pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
261 let inner_nodes = self.nodes.iter().filter(|(index, _)| !self.leaves.contains(index));
262 inner_nodes.map(|(index, digest)| {
263 let left_hash =
264 self.nodes.get(&index.left_child()).expect("Failed to get left child hash");
265 let right_hash =
266 self.nodes.get(&index.right_child()).expect("Failed to get right child hash");
267 InnerNodeInfo {
268 value: *digest,
269 left: *left_hash,
270 right: *right_hash,
271 }
272 })
273 }
274
275 pub fn add_path(
287 &mut self,
288 index_value: u64,
289 value: Word,
290 path: MerklePath,
291 ) -> Result<(), MerkleError> {
292 let index_value = NodeIndex::new(path.len() as u8, index_value)?;
293
294 Self::check_depth(index_value.depth())?;
295 self.update_depth(index_value.depth());
296
297 self.leaves.insert(index_value);
299 let sibling_node_index = index_value.sibling();
300 self.leaves.insert(sibling_node_index);
301
302 self.nodes.insert(index_value, value);
304 self.nodes.insert(sibling_node_index, path[0]);
305
306 let mut index_value = index_value;
308 let node = Rpo256::merge(&index_value.build_node(value, path[0]));
309 let root = path.iter().skip(1).copied().fold(node, |node, hash| {
310 index_value.move_up();
311 self.nodes.insert(index_value, node);
313
314 self.leaves.remove(&index_value);
316
317 let sibling_node = index_value.sibling();
318
319 if self.nodes.insert(sibling_node, hash).is_none() {
331 self.leaves.insert(sibling_node);
332 }
333
334 Rpo256::merge(&index_value.build_node(node, hash))
335 });
336
337 if self.root() == EMPTY_DIGEST {
340 self.nodes.insert(ROOT_INDEX, root);
341 } else if self.root() != root {
342 return Err(MerkleError::ConflictingRoots {
343 expected_root: self.root(),
344 actual_root: root,
345 });
346 }
347
348 Ok(())
349 }
350
351 pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<Word, MerkleError> {
363 let mut node_index = NodeIndex::new(self.max_depth(), index)?;
364
365 for _ in 0..node_index.depth() {
367 if !self.leaves.contains(&node_index) {
368 node_index.move_up();
369 }
370 }
371
372 let old_value = self
374 .nodes
375 .insert(node_index, value)
376 .ok_or(MerkleError::NodeIndexNotFoundInTree(node_index))?;
377
378 if value == old_value {
380 return Ok(old_value);
381 }
382
383 let mut value = value;
384 for _ in 0..node_index.depth() {
385 let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist");
386 value = Rpo256::merge(&node_index.build_node(value, *sibling));
387 node_index.move_up();
388 self.nodes.insert(node_index, value);
389 }
390
391 Ok(old_value)
392 }
393
394 pub fn print(&self) -> Result<String, fmt::Error> {
399 let indent = " ";
400 let mut s = String::new();
401 s.push_str("root: ");
402 s.push_str(&word_to_hex(&self.root())?);
403 s.push('\n');
404 for d in 1..=self.max_depth() {
405 let entries = 2u64.pow(d.into());
406 for i in 0..entries {
407 let index = NodeIndex::new(d, i).expect("The index must always be valid");
408 let node = self.get_node(index);
409 let node = match node {
410 Err(_) => continue,
411 Ok(node) => node,
412 };
413
414 for _ in 0..d {
415 s.push_str(indent);
416 }
417 s.push_str(&format!("({}, {}): ", index.depth(), index.value()));
418 s.push_str(&word_to_hex(&node)?);
419 s.push('\n');
420 }
421 }
422
423 Ok(s)
424 }
425
426 fn update_depth(&mut self, new_depth: u8) {
431 self.max_depth = new_depth.max(self.max_depth);
432 }
433
434 fn check_depth(depth: u8) -> Result<(), MerkleError> {
436 if depth < Self::MIN_DEPTH {
438 return Err(MerkleError::DepthTooSmall(depth));
439 } else if Self::MAX_DEPTH < depth {
440 return Err(MerkleError::DepthTooBig(depth as u64));
441 }
442 Ok(())
443 }
444}
445
446impl Serializable for PartialMerkleTree {
450 fn write_into<W: ByteWriter>(&self, target: &mut W) {
451 target.write_u64(self.leaves.len() as u64);
453 for leaf_index in self.leaves.iter() {
454 leaf_index.write_into(target);
455 self.get_node(*leaf_index).expect("Leaf hash not found").write_into(target);
456 }
457 }
458}
459
460impl Deserializable for PartialMerkleTree {
461 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
462 let leaves_len = source.read_u64()? as usize;
463 let mut leaf_nodes = Vec::with_capacity(leaves_len);
464
465 for _ in 0..leaves_len {
467 let index = NodeIndex::read_from(source)?;
468 let hash = Word::read_from(source)?;
469 leaf_nodes.push((index, hash));
470 }
471
472 let pmt = PartialMerkleTree::with_leaves(leaf_nodes).map_err(|_| {
473 DeserializationError::InvalidValue("Invalid data for PartialMerkleTree creation".into())
474 })?;
475
476 Ok(pmt)
477 }
478}