1use alloc::{string::String, vec::Vec};
2use core::{fmt, slice};
3
4use super::{InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, Word};
5use crate::utils::{uninit_vector, word_to_hex};
6
7#[derive(Debug, Clone, PartialEq, Eq)]
12#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
13pub struct MerkleTree {
14 nodes: Vec<Word>,
15}
16
17impl MerkleTree {
18 pub fn new<T>(leaves: T) -> Result<Self, MerkleError>
25 where
26 T: AsRef<[Word]>,
27 {
28 let leaves = leaves.as_ref();
29 let n = leaves.len();
30 if n <= 1 {
31 return Err(MerkleError::DepthTooSmall(n as u8));
32 } else if !n.is_power_of_two() {
33 return Err(MerkleError::NumLeavesNotPowerOfTwo(n));
34 }
35
36 let mut nodes = unsafe { uninit_vector(2 * n) };
38 nodes[0] = Word::default();
39
40 nodes[n..].iter_mut().zip(leaves).for_each(|(node, leaf)| {
42 *node = *leaf;
43 });
44
45 let ptr = nodes.as_ptr() as *const [Word; 2];
49 let pairs = unsafe { slice::from_raw_parts(ptr, n) };
50
51 for i in (1..n).rev() {
53 nodes[i] = Rpo256::merge(&pairs[i]);
54 }
55
56 Ok(Self { nodes })
57 }
58
59 pub fn root(&self) -> Word {
64 self.nodes[1]
65 }
66
67 pub fn depth(&self) -> u8 {
71 (self.nodes.len() / 2).ilog2() as u8
72 }
73
74 pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
81 if index.is_root() {
82 return Err(MerkleError::DepthTooSmall(index.depth()));
83 } else if index.depth() > self.depth() {
84 return Err(MerkleError::DepthTooBig(index.depth() as u64));
85 }
86
87 let pos = index.to_scalar_index() as usize;
88 Ok(self.nodes[pos])
89 }
90
91 pub fn get_path(&self, index: NodeIndex) -> Result<MerklePath, MerkleError> {
99 if index.is_root() {
100 return Err(MerkleError::DepthTooSmall(index.depth()));
101 } else if index.depth() > self.depth() {
102 return Err(MerkleError::DepthTooBig(index.depth() as u64));
103 }
104
105 Ok(MerklePath::from(Vec::from_iter(
106 index.proof_indices().map(|index| self.get_node(index).unwrap()),
107 )))
108 }
109
110 pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
115 let leaves_start = self.nodes.len() / 2;
116 self.nodes.iter().skip(leaves_start).enumerate().map(|(i, v)| (i as u64, v))
117 }
118
119 pub fn inner_nodes(&self) -> InnerNodeIterator<'_> {
123 InnerNodeIterator {
124 nodes: &self.nodes,
125 index: 1, }
127 }
128
129 pub fn update_leaf<'a>(&'a mut self, index_value: u64, value: Word) -> Result<(), MerkleError> {
137 let mut index = NodeIndex::new(self.depth(), index_value)?;
138
139 debug_assert_eq!(self.nodes.len() & 1, 0);
143 let n = self.nodes.len() / 2;
144
145 let ptr = self.nodes.as_ptr() as *const [Word; 2];
151 let pairs: &'a [[Word; 2]] = unsafe { slice::from_raw_parts(ptr, n) };
152
153 let pos = index.to_scalar_index() as usize;
155 self.nodes[pos] = value;
156
157 for _ in 0..index.depth() {
159 index.move_up();
160 let pos = index.to_scalar_index() as usize;
161 let value = Rpo256::merge(&pairs[pos]);
162 self.nodes[pos] = value;
163 }
164
165 Ok(())
166 }
167}
168
169impl TryFrom<&[Word]> for MerkleTree {
173 type Error = MerkleError;
174
175 fn try_from(value: &[Word]) -> Result<Self, Self::Error> {
176 MerkleTree::new(value)
177 }
178}
179
180pub struct InnerNodeIterator<'a> {
187 nodes: &'a Vec<Word>,
188 index: usize,
189}
190
191impl Iterator for InnerNodeIterator<'_> {
192 type Item = InnerNodeInfo;
193
194 fn next(&mut self) -> Option<Self::Item> {
195 if self.index < self.nodes.len() / 2 {
196 let value = self.index;
197 let left = self.index * 2;
198 let right = left + 1;
199
200 self.index += 1;
201
202 Some(InnerNodeInfo {
203 value: self.nodes[value],
204 left: self.nodes[left],
205 right: self.nodes[right],
206 })
207 } else {
208 None
209 }
210 }
211}
212
213pub fn tree_to_text(tree: &MerkleTree) -> Result<String, fmt::Error> {
218 let indent = " ";
219 let mut s = String::new();
220 s.push_str(&word_to_hex(&tree.root())?);
221 s.push('\n');
222 for d in 1..=tree.depth() {
223 let entries = 2u64.pow(d.into());
224 for i in 0..entries {
225 let index = NodeIndex::new(d, i).expect("The index must always be valid");
226 let node = tree.get_node(index).expect("The node must always be found");
227
228 for _ in 0..d {
229 s.push_str(indent);
230 }
231 s.push_str(&word_to_hex(&node)?);
232 s.push('\n');
233 }
234 }
235
236 Ok(s)
237}
238
239pub fn path_to_text(path: &MerklePath) -> Result<String, fmt::Error> {
241 let mut s = String::new();
242 s.push('[');
243
244 for el in path.iter() {
245 s.push_str(&word_to_hex(el)?);
246 s.push_str(", ");
247 }
248
249 if !path.is_empty() {
251 s.pop();
252 s.pop();
253 }
254 s.push(']');
255
256 Ok(s)
257}
258
#[cfg(test)]
mod tests {
    use core::mem::size_of;

    use proptest::prelude::*;

    use super::*;
    use crate::{
        Felt,
        merkle::{int_to_leaf, int_to_node},
    };

    /// Four leaves, producing a tree of depth 2.
    const LEAVES4: [Word; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];

    /// Eight leaves, producing a tree of depth 3.
    const LEAVES8: [Word; 8] = [
        int_to_node(1),
        int_to_node(2),
        int_to_node(3),
        int_to_node(4),
        int_to_node(5),
        int_to_node(6),
        int_to_node(7),
        int_to_node(8),
    ];

    #[test]
    fn build_merkle_tree() {
        let tree = super::MerkleTree::new(LEAVES4).unwrap();
        assert_eq!(8, tree.nodes.len());

        // the leaves occupy the second half of the flat storage
        for (a, b) in tree.nodes.iter().skip(4).zip(LEAVES4.iter()) {
            assert_eq!(a, b);
        }

        let (root, node2, node3) = compute_internal_nodes();

        assert_eq!(root, tree.nodes[1]);
        assert_eq!(node2, tree.nodes[2]);
        assert_eq!(node3, tree.nodes[3]);

        assert_eq!(root, tree.root());
    }

    #[test]
    fn get_leaf() {
        let tree = super::MerkleTree::new(LEAVES4).unwrap();

        assert_eq!(LEAVES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
        assert_eq!(LEAVES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
        assert_eq!(LEAVES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
        assert_eq!(LEAVES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());

        let (_, node2, node3) = compute_internal_nodes();

        assert_eq!(node2, tree.get_node(NodeIndex::make(1, 0)).unwrap());
        assert_eq!(node3, tree.get_node(NodeIndex::make(1, 1)).unwrap());
    }

    #[test]
    fn get_node_rejects_root_and_too_deep_indices() {
        let tree = super::MerkleTree::new(LEAVES4).unwrap();

        assert!(matches!(
            tree.get_node(NodeIndex::make(0, 0)),
            Err(MerkleError::DepthTooSmall(0))
        ));
        assert!(matches!(
            tree.get_node(NodeIndex::make(3, 0)),
            Err(MerkleError::DepthTooBig(3))
        ));
    }

    #[test]
    fn get_path() {
        let tree = super::MerkleTree::new(LEAVES4).unwrap();

        let (_, node2, node3) = compute_internal_nodes();

        assert_eq!(vec![LEAVES4[1], node3], *tree.get_path(NodeIndex::make(2, 0)).unwrap());
        assert_eq!(vec![LEAVES4[0], node3], *tree.get_path(NodeIndex::make(2, 1)).unwrap());
        assert_eq!(vec![LEAVES4[3], node2], *tree.get_path(NodeIndex::make(2, 2)).unwrap());
        assert_eq!(vec![LEAVES4[2], node2], *tree.get_path(NodeIndex::make(2, 3)).unwrap());

        assert_eq!(vec![node3], *tree.get_path(NodeIndex::make(1, 0)).unwrap());
        assert_eq!(vec![node2], *tree.get_path(NodeIndex::make(1, 1)).unwrap());
    }

    #[test]
    fn new_rejects_invalid_leaf_counts() {
        assert!(matches!(
            super::MerkleTree::new(&LEAVES4[..0]),
            Err(MerkleError::DepthTooSmall(0))
        ));
        assert!(matches!(
            super::MerkleTree::new(&LEAVES4[..1]),
            Err(MerkleError::DepthTooSmall(1))
        ));
        assert!(matches!(
            super::MerkleTree::new(&LEAVES4[..3]),
            Err(MerkleError::NumLeavesNotPowerOfTwo(3))
        ));
    }

    #[test]
    fn leaves_and_depth() {
        let tree = super::MerkleTree::new(LEAVES8).unwrap();
        assert_eq!(3, tree.depth());

        let leaves: Vec<(u64, Word)> = tree.leaves().map(|(i, leaf)| (i, *leaf)).collect();
        let expected: Vec<(u64, Word)> =
            LEAVES8.iter().enumerate().map(|(i, leaf)| (i as u64, *leaf)).collect();
        assert_eq!(expected, leaves);
    }

    #[test]
    fn update_leaf() {
        let mut tree = super::MerkleTree::new(LEAVES8).unwrap();

        // update one leaf
        let value = 3;
        let new_node = int_to_leaf(9);
        let mut expected_leaves = LEAVES8.to_vec();
        expected_leaves[value as usize] = new_node;
        let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();

        tree.update_leaf(value, new_node).unwrap();
        assert_eq!(expected_tree.nodes, tree.nodes);

        // update another leaf
        let value = 6;
        let new_node = int_to_leaf(10);
        expected_leaves[value as usize] = new_node;
        let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();

        tree.update_leaf(value, new_node).unwrap();
        assert_eq!(expected_tree.nodes, tree.nodes);
    }

    #[test]
    fn update_every_leaf_matches_rebuilt_tree() {
        let mut tree = super::MerkleTree::new(LEAVES8).unwrap();
        let mut expected_leaves = LEAVES8.to_vec();

        // each in-place update must leave the tree identical to one built from scratch
        for i in 0..8u64 {
            let new_leaf = int_to_leaf(100 + i);
            expected_leaves[i as usize] = new_leaf;
            tree.update_leaf(i, new_leaf).unwrap();

            let expected_tree = super::MerkleTree::new(&expected_leaves).unwrap();
            assert_eq!(expected_tree.nodes, tree.nodes);
            assert_eq!(expected_tree.root(), tree.root());
        }
    }

    #[test]
    fn nodes() -> Result<(), MerkleError> {
        let tree = super::MerkleTree::new(LEAVES4).unwrap();
        let root = tree.root();
        let l1n0 = tree.get_node(NodeIndex::make(1, 0))?;
        let l1n1 = tree.get_node(NodeIndex::make(1, 1))?;
        let l2n0 = tree.get_node(NodeIndex::make(2, 0))?;
        let l2n1 = tree.get_node(NodeIndex::make(2, 1))?;
        let l2n2 = tree.get_node(NodeIndex::make(2, 2))?;
        let l2n3 = tree.get_node(NodeIndex::make(2, 3))?;

        let nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
        let expected = vec![
            InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
            InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
            InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
        ];
        assert_eq!(nodes, expected);

        Ok(())
    }

    proptest! {
        #[test]
        fn arbitrary_word_can_be_represented_as_digest(
            a in prop::num::u64::ANY,
            b in prop::num::u64::ANY,
            c in prop::num::u64::ANY,
            d in prop::num::u64::ANY,
        ) {
            let word = [Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)];
            let digest = Word::from(word);

            // the conversion must copy into distinct storage...
            let word_ptr = word.as_ptr() as *const u8;
            let digest_ptr = digest.as_ptr() as *const u8;
            assert_ne!(word_ptr, digest_ptr);

            // ...while preserving the exact byte representation
            let word_bytes = unsafe { slice::from_raw_parts(word_ptr, size_of::<Word>()) };
            let digest_bytes = unsafe { slice::from_raw_parts(digest_ptr, size_of::<Word>()) };
            assert_eq!(word_bytes, digest_bytes);
        }
    }

    /// Computes the root and both depth-1 nodes of the tree built from `LEAVES4` by hand.
    fn compute_internal_nodes() -> (Word, Word, Word) {
        let node2 = Rpo256::hash_elements(&[*LEAVES4[0], *LEAVES4[1]].concat());
        let node3 = Rpo256::hash_elements(&[*LEAVES4[2], *LEAVES4[3]].concat());
        let root = Rpo256::merge(&[node2, node3]);

        (root, node2, node3)
    }
}