1use alloc::{string::String, vec::Vec};
2use core::{fmt, ops::Deref, slice};
3
4use super::{InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Word};
5use crate::utils::{uninit_vector, word_to_hex};
6
/// A binary Merkle tree whose nodes are stored in a single flat vector.
///
/// For a tree built from `n` leaves the vector holds `2 * n` entries:
/// - index `0` is a padding slot (set to the default digest) and is not part of the tree,
/// - index `1` is the root,
/// - the children of the node at index `i` live at indices `2 * i` and `2 * i + 1`,
/// - the leaves occupy the second half of the vector (indices `n..2 * n`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleTree {
    // Flat, 1-based node storage as described in the type-level documentation.
    nodes: Vec<RpoDigest>,
}
16
17impl MerkleTree {
18 pub fn new<T>(leaves: T) -> Result<Self, MerkleError>
25 where
26 T: AsRef<[Word]>,
27 {
28 let leaves = leaves.as_ref();
29 let n = leaves.len();
30 if n <= 1 {
31 return Err(MerkleError::DepthTooSmall(n as u8));
32 } else if !n.is_power_of_two() {
33 return Err(MerkleError::NumLeavesNotPowerOfTwo(n));
34 }
35
36 let mut nodes = unsafe { uninit_vector(2 * n) };
38 nodes[0] = RpoDigest::default();
39
40 nodes[n..].iter_mut().zip(leaves).for_each(|(node, leaf)| {
42 *node = RpoDigest::from(*leaf);
43 });
44
45 let ptr = nodes.as_ptr() as *const [RpoDigest; 2];
49 let pairs = unsafe { slice::from_raw_parts(ptr, n) };
50
51 for i in (1..n).rev() {
53 nodes[i] = Rpo256::merge(&pairs[i]);
54 }
55
56 Ok(Self { nodes })
57 }
58
59 pub fn root(&self) -> RpoDigest {
64 self.nodes[1]
65 }
66
67 pub fn depth(&self) -> u8 {
71 (self.nodes.len() / 2).ilog2() as u8
72 }
73
74 pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
81 if index.is_root() {
82 return Err(MerkleError::DepthTooSmall(index.depth()));
83 } else if index.depth() > self.depth() {
84 return Err(MerkleError::DepthTooBig(index.depth() as u64));
85 }
86
87 let pos = index.to_scalar_index() as usize;
88 Ok(self.nodes[pos])
89 }
90
91 pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
99 if index.is_root() {
100 return Err(MerkleError::DepthTooSmall(index.depth()));
101 } else if index.depth() > self.depth() {
102 return Err(MerkleError::DepthTooBig(index.depth() as u64));
103 }
104
105 let mut path = Vec::with_capacity(index.depth() as usize);
109 for _ in 0..index.depth() {
110 let sibling = index.sibling().to_scalar_index() as usize;
111 path.push(self.nodes[sibling]);
112 index.move_up();
113 }
114
115 debug_assert!(index.is_root(), "the path walk must go all the way to the root");
116
117 Ok(path.into())
118 }
119
120 pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
125 let leaves_start = self.nodes.len() / 2;
126 self.nodes
127 .iter()
128 .skip(leaves_start)
129 .enumerate()
130 .map(|(i, v)| (i as u64, v.deref()))
131 }
132
133 pub fn inner_nodes(&self) -> InnerNodeIterator {
137 InnerNodeIterator {
138 nodes: &self.nodes,
139 index: 1, }
141 }
142
143 pub fn update_leaf<'a>(&'a mut self, index_value: u64, value: Word) -> Result<(), MerkleError> {
151 let mut index = NodeIndex::new(self.depth(), index_value)?;
152
153 debug_assert_eq!(self.nodes.len() & 1, 0);
157 let n = self.nodes.len() / 2;
158
159 let ptr = self.nodes.as_ptr() as *const [RpoDigest; 2];
165 let pairs: &'a [[RpoDigest; 2]] = unsafe { slice::from_raw_parts(ptr, n) };
166
167 let pos = index.to_scalar_index() as usize;
169 self.nodes[pos] = value.into();
170
171 for _ in 0..index.depth() {
173 index.move_up();
174 let pos = index.to_scalar_index() as usize;
175 let value = Rpo256::merge(&pairs[pos]);
176 self.nodes[pos] = value;
177 }
178
179 Ok(())
180 }
181}
182
183impl TryFrom<&[Word]> for MerkleTree {
187 type Error = MerkleError;
188
189 fn try_from(value: &[Word]) -> Result<Self, Self::Error> {
190 MerkleTree::new(value)
191 }
192}
193
194impl TryFrom<&[RpoDigest]> for MerkleTree {
195 type Error = MerkleError;
196
197 fn try_from(value: &[RpoDigest]) -> Result<Self, Self::Error> {
198 let value: Vec<Word> = value.iter().map(|v| *v.deref()).collect();
199 MerkleTree::new(value)
200 }
201}
202
203pub struct InnerNodeIterator<'a> {
210 nodes: &'a Vec<RpoDigest>,
211 index: usize,
212}
213
214impl Iterator for InnerNodeIterator<'_> {
215 type Item = InnerNodeInfo;
216
217 fn next(&mut self) -> Option<Self::Item> {
218 if self.index < self.nodes.len() / 2 {
219 let value = self.index;
220 let left = self.index * 2;
221 let right = left + 1;
222
223 self.index += 1;
224
225 Some(InnerNodeInfo {
226 value: self.nodes[value],
227 left: self.nodes[left],
228 right: self.nodes[right],
229 })
230 } else {
231 None
232 }
233 }
234}
235
236pub fn tree_to_text(tree: &MerkleTree) -> Result<String, fmt::Error> {
241 let indent = " ";
242 let mut s = String::new();
243 s.push_str(&word_to_hex(&tree.root())?);
244 s.push('\n');
245 for d in 1..=tree.depth() {
246 let entries = 2u64.pow(d.into());
247 for i in 0..entries {
248 let index = NodeIndex::new(d, i).expect("The index must always be valid");
249 let node = tree.get_node(index).expect("The node must always be found");
250
251 for _ in 0..d {
252 s.push_str(indent);
253 }
254 s.push_str(&word_to_hex(&node)?);
255 s.push('\n');
256 }
257 }
258
259 Ok(s)
260}
261
262pub fn path_to_text(path: &MerklePath) -> Result<String, fmt::Error> {
264 let mut s = String::new();
265 s.push('[');
266
267 for el in path.iter() {
268 s.push_str(&word_to_hex(el)?);
269 s.push_str(", ");
270 }
271
272 if !path.is_empty() {
274 s.pop();
275 s.pop();
276 }
277 s.push(']');
278
279 Ok(s)
280}
281
// Unit tests for `MerkleTree`: construction, node and path lookup, leaf updates, inner-node
// iteration, and a property test comparing the raw bytes of a word and of its digest.
#[cfg(test)]
mod tests {
    use core::mem::size_of;

    use proptest::prelude::*;

    use super::*;
    use crate::{
        Felt, WORD_SIZE,
        merkle::{digests_to_words, int_to_leaf, int_to_node},
    };

    // Fixture with four leaves (a tree of depth 2).
    const LEAVES4: [RpoDigest; WORD_SIZE] =
        [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];

    // Fixture with eight leaves (a tree of depth 3).
    const LEAVES8: [RpoDigest; 8] = [
        int_to_node(1),
        int_to_node(2),
        int_to_node(3),
        int_to_node(4),
        int_to_node(5),
        int_to_node(6),
        int_to_node(7),
        int_to_node(8),
    ];

    /// Building a tree stores the leaves in the second half of the node vector and computes
    /// the expected internal nodes and root.
    #[test]
    fn build_merkle_tree() {
        let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
        assert_eq!(8, tree.nodes.len());

        // leaves were copied correctly
        for (a, b) in tree.nodes.iter().skip(4).zip(LEAVES4.iter()) {
            assert_eq!(a, b);
        }

        let (root, node2, node3) = compute_internal_nodes();

        // internal nodes sit at their expected flat indices
        assert_eq!(root, tree.nodes[1]);
        assert_eq!(node2, tree.nodes[2]);
        assert_eq!(node3, tree.nodes[3]);

        assert_eq!(root, tree.root());
    }

    /// `get_node` returns both leaves (depth 2) and internal nodes (depth 1).
    #[test]
    fn get_leaf() {
        let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();

        // check depth 2
        assert_eq!(LEAVES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
        assert_eq!(LEAVES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
        assert_eq!(LEAVES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
        assert_eq!(LEAVES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());

        // check depth 1
        let (_, node2, node3) = compute_internal_nodes();

        assert_eq!(node2, tree.get_node(NodeIndex::make(1, 0)).unwrap());
        assert_eq!(node3, tree.get_node(NodeIndex::make(1, 1)).unwrap());
    }

    /// `get_path` returns the siblings from the requested node upwards, excluding the root.
    #[test]
    fn get_path() {
        let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();

        let (_, node2, node3) = compute_internal_nodes();

        // check depth 2
        assert_eq!(vec![LEAVES4[1], node3], *tree.get_path(NodeIndex::make(2, 0)).unwrap());
        assert_eq!(vec![LEAVES4[0], node3], *tree.get_path(NodeIndex::make(2, 1)).unwrap());
        assert_eq!(vec![LEAVES4[3], node2], *tree.get_path(NodeIndex::make(2, 2)).unwrap());
        assert_eq!(vec![LEAVES4[2], node2], *tree.get_path(NodeIndex::make(2, 3)).unwrap());

        // check depth 1
        assert_eq!(vec![node3], *tree.get_path(NodeIndex::make(1, 0)).unwrap());
        assert_eq!(vec![node2], *tree.get_path(NodeIndex::make(1, 1)).unwrap());
    }

    /// Updating a leaf in place yields the same nodes as rebuilding the tree from scratch
    /// with the modified leaves.
    #[test]
    fn update_leaf() {
        let mut tree = super::MerkleTree::new(digests_to_words(&LEAVES8)).unwrap();

        // update one leaf
        let value = 3;
        let new_node = int_to_leaf(9);
        let mut expected_leaves = digests_to_words(&LEAVES8);
        expected_leaves[value as usize] = new_node;
        let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();

        tree.update_leaf(value, new_node).unwrap();
        assert_eq!(expected_tree.nodes, tree.nodes);

        // update another leaf on top of the first update
        let value = 6;
        let new_node = int_to_leaf(10);
        expected_leaves[value as usize] = new_node;
        let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();

        tree.update_leaf(value, new_node).unwrap();
        assert_eq!(expected_tree.nodes, tree.nodes);
    }

    /// `inner_nodes` yields the root first, followed by the depth-1 nodes, each with its two
    /// children.
    #[test]
    fn nodes() -> Result<(), MerkleError> {
        let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
        let root = tree.root();
        let l1n0 = tree.get_node(NodeIndex::make(1, 0))?;
        let l1n1 = tree.get_node(NodeIndex::make(1, 1))?;
        let l2n0 = tree.get_node(NodeIndex::make(2, 0))?;
        let l2n1 = tree.get_node(NodeIndex::make(2, 1))?;
        let l2n2 = tree.get_node(NodeIndex::make(2, 2))?;
        let l2n3 = tree.get_node(NodeIndex::make(2, 3))?;

        let nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
        let expected = vec![
            InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
            InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
            InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
        ];
        assert_eq!(nodes, expected);

        Ok(())
    }

    proptest! {
        // Converting a word into a digest produces a distinct value whose raw bytes are
        // identical to those of the original word.
        #[test]
        fn arbitrary_word_can_be_represented_as_digest(
            a in prop::num::u64::ANY,
            b in prop::num::u64::ANY,
            c in prop::num::u64::ANY,
            d in prop::num::u64::ANY,
        ) {
            let word = [Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)];
            let digest = RpoDigest::from(word);

            // the two values live at different addresses, so the comparison below is not
            // trivially comparing a buffer with itself
            let word_ptr = word.as_ptr() as *const u8;
            let digest_ptr = digest.as_ptr() as *const u8;
            assert_ne!(word_ptr, digest_ptr);

            // compare the raw byte representation of both values (length and content)
            let word_bytes = unsafe { slice::from_raw_parts(word_ptr, size_of::<Word>()) };
            let digest_bytes = unsafe { slice::from_raw_parts(digest_ptr, size_of::<RpoDigest>()) };
            assert_eq!(word_bytes, digest_bytes);
        }
    }

    // HELPER FUNCTIONS

    /// Computes the internal nodes of the `LEAVES4` tree independently of `MerkleTree`,
    /// returning `(root, node2, node3)` where `node2` and `node3` are the hashes of the
    /// first and second pair of leaves.
    fn compute_internal_nodes() -> (RpoDigest, RpoDigest, RpoDigest) {
        let node2 =
            Rpo256::hash_elements(&[Word::from(LEAVES4[0]), Word::from(LEAVES4[1])].concat());
        let node3 =
            Rpo256::hash_elements(&[Word::from(LEAVES4[2]), Word::from(LEAVES4[3])].concat());
        let root = Rpo256::merge(&[node2, node3]);

        (root, node2, node3)
    }
}