1use std::mem;
2
3use ethrex_crypto::{Crypto, NativeCrypto};
4use ethrex_rlp::encode::RLPEncode;
5
6use crate::{
7 InconsistentTreeError, TrieDB, ValueRLP, error::TrieError, nibbles::Nibbles,
8 node::NodeRemoveResult, node_hash::NodeHash,
9};
10
11use super::{ExtensionNode, LeafNode, Node, NodeRef, ValueOrHash};
12
/// Branch node of the trie: sixteen child slots plus a value stored at the node itself.
///
/// The methods on this type index `choices` with the value returned by
/// `Nibbles::next_choice`, and treat an empty `value` as "no value present".
#[derive(
    Debug,
    Clone,
    PartialEq,
    Default,
    serde::Serialize,
    serde::Deserialize,
    rkyv::Serialize,
    rkyv::Deserialize,
    rkyv::Archive,
)]
pub struct BranchNode {
    /// Child references, one per slot; a slot whose reference is not valid is unoccupied.
    pub choices: [NodeRef; 16],
    /// Value held by this node; an empty value means the node stores nothing itself.
    pub value: ValueRLP,
}
30
31impl BranchNode {
32 const EMPTY_REF: NodeRef = NodeRef::Hash(NodeHash::Inline(([0; 31], 0)));
33
34 pub const EMPTY_CHOICES: [NodeRef; 16] = [Self::EMPTY_REF; 16];
36
37 pub fn new(choices: [NodeRef; 16]) -> Self {
39 Self {
40 choices,
41 value: Default::default(),
42 }
43 }
44
45 pub const fn new_with_value(choices: [NodeRef; 16], value: ValueRLP) -> Self {
47 Self { choices, value }
48 }
49
50 pub fn update(&mut self, new_value: ValueRLP) {
52 self.value = new_value;
53 }
54
55 pub fn get(&self, db: &dyn TrieDB, mut path: Nibbles) -> Result<Option<ValueRLP>, TrieError> {
57 if let Some(choice) = path.next_choice() {
60 let child_ref = &self.choices[choice];
62 if child_ref.is_valid() {
63 let child_node = child_ref.get_node(db, path.current())?.ok_or_else(|| {
64 TrieError::InconsistentTree(Box::new(
65 InconsistentTreeError::NodeNotFoundOnBranchNode(
66 child_ref
67 .compute_hash(&NativeCrypto)
68 .finalize(&NativeCrypto),
69 self.compute_hash(&NativeCrypto).finalize(&NativeCrypto),
70 path.current(),
71 ),
72 ))
73 })?;
74 child_node.get(db, path)
75 } else {
76 Ok(None)
77 }
78 } else {
79 Ok((!self.value.is_empty()).then_some(self.value.clone()))
81 }
82 }
83
84 pub fn insert(
86 &mut self,
87 db: &dyn TrieDB,
88 mut path: Nibbles,
89 value: ValueOrHash,
90 ) -> Result<(), TrieError> {
91 if let Some(choice) = path.next_choice() {
94 match (&mut self.choices[choice], value) {
95 (choice_ref, ValueOrHash::Value(value)) if !choice_ref.is_valid() => {
97 let new_leaf = LeafNode::new(path, value);
98 *choice_ref = Node::from(new_leaf).into()
99 }
100 (choice_ref, ValueOrHash::Value(value)) => {
102 let Some(choice_node) = choice_ref.get_node_mut(db, path.current())? else {
103 return Err(TrieError::InconsistentTree(Box::new(
104 InconsistentTreeError::NodeNotFoundOnBranchNode(
105 choice_ref
106 .compute_hash(&NativeCrypto)
107 .finalize(&NativeCrypto),
108 self.compute_hash(&NativeCrypto).finalize(&NativeCrypto),
109 path.current(),
110 ),
111 )));
112 };
113
114 choice_node.insert(db, path, value)?;
115 choice_ref.clear_hash();
116 }
117 (choice_ref, value @ ValueOrHash::Hash(hash)) => {
119 if !choice_ref.is_valid() {
120 *choice_ref = hash.into();
121 } else if path.is_empty() {
122 return Err(TrieError::Verify(
123 "attempt to override proof node with external hash".to_string(),
124 ));
125 } else {
126 let Some(choice_node) = choice_ref.get_node_mut(db, path.current())? else {
127 return Err(TrieError::InconsistentTree(Box::new(
128 InconsistentTreeError::NodeNotFoundOnBranchNode(
129 choice_ref
130 .compute_hash(&NativeCrypto)
131 .finalize(&NativeCrypto),
132 self.compute_hash(&NativeCrypto).finalize(&NativeCrypto),
133 path.current(),
134 ),
135 )));
136 };
137 choice_node.insert(db, path, value)?;
138 choice_ref.clear_hash();
139 }
140 }
141 }
142 } else if let ValueOrHash::Value(value) = value {
143 self.update(value);
145 } else {
146 todo!("handle override case (error?)")
148 }
149
150 Ok(())
151 }
152
153 pub fn remove(
157 &mut self,
158 db: &dyn TrieDB,
159 mut path: Nibbles,
160 ) -> Result<(Option<NodeRemoveResult>, Option<ValueRLP>), TrieError> {
161 let base_path = path.clone();
179
180 let value = if let Some(choice_index) = path.next_choice() {
183 if self.choices[choice_index].is_valid() {
184 let Some(child_node) =
185 self.choices[choice_index].get_node_mut(db, path.current())?
186 else {
187 return Err(TrieError::InconsistentTree(Box::new(
188 InconsistentTreeError::NodeNotFoundOnBranchNode(
189 self.choices[choice_index]
190 .compute_hash(&NativeCrypto)
191 .finalize(&NativeCrypto),
192 self.compute_hash(&NativeCrypto).finalize(&NativeCrypto),
193 path.current(),
194 ),
195 )));
196 };
197
198 let (empty_trie, old_value) = child_node.remove(db, path.clone())?;
200 if empty_trie {
201 self.choices[choice_index] = NodeHash::default().into();
203 }
204 self.choices[choice_index].clear_hash();
205 old_value
206 } else {
207 None
208 }
209 } else {
210 if !self.value.is_empty() {
212 Some(mem::take(&mut self.value))
213 } else {
214 None
215 }
216 };
217
218 let mut children = self
220 .choices
221 .iter_mut()
222 .enumerate()
223 .filter(|(_, child)| child.is_valid())
224 .collect::<Vec<_>>();
225 let new_node = match (children.len(), !self.value.is_empty()) {
226 (0, true) => NodeRemoveResult::New(
228 LeafNode::new(Nibbles::from_hex(vec![16]), mem::take(&mut self.value)).into(),
229 ),
230 (1, false) => {
232 let (choice_index, child_ref) = children.get_mut(0).unwrap();
233 let Some(child) = child_ref
234 .get_node_mut(db, base_path.current().append_new(*choice_index as u8))?
235 else {
236 return Err(TrieError::InconsistentTree(Box::new(
237 InconsistentTreeError::NodeNotFoundOnBranchNode(
238 child_ref
239 .compute_hash(&NativeCrypto)
240 .finalize(&NativeCrypto),
241 self.compute_hash(&NativeCrypto).finalize(&NativeCrypto),
242 base_path.current(),
243 ),
244 )));
245 };
246
247 let node = match child {
248 Node::Branch(_) => ExtensionNode::new(
250 Nibbles::from_hex(vec![*choice_index as u8]),
251 child_ref.clone(),
252 )
253 .into(),
254 Node::Extension(extension_node) => {
256 let mut extension_node = extension_node.take();
257 extension_node.prefix.prepend(*choice_index as u8);
258 extension_node.into()
259 }
260 Node::Leaf(leaf) => {
261 let mut leaf = leaf.take();
262 leaf.partial.prepend(*choice_index as u8);
263 leaf.into()
264 }
265 };
266 NodeRemoveResult::New(node)
267 }
268 _ => NodeRemoveResult::Mutated,
270 };
271 Ok((Some(new_node), value))
272 }
273
274 pub fn compute_hash(&self, crypto: &dyn Crypto) -> NodeHash {
276 self.compute_hash_no_alloc(&mut vec![], crypto)
277 }
278
279 pub fn compute_hash_no_alloc(&self, buf: &mut Vec<u8>, crypto: &dyn Crypto) -> NodeHash {
281 buf.clear();
282 self.encode(buf);
283 let hash = NodeHash::from_encoded(buf, crypto);
284 buf.clear();
285 hash
286 }
287
288 pub fn get_path(
292 &self,
293 db: &dyn TrieDB,
294 mut path: Nibbles,
295 node_path: &mut Vec<Vec<u8>>,
296 ) -> Result<(), TrieError> {
297 let encoded = self.encode_to_vec();
299 if encoded.len() >= 32 {
300 node_path.push(encoded);
301 };
302 if let Some(choice) = path.next_choice() {
304 let child_ref = &self.choices[choice];
306 if child_ref.is_valid() {
307 let child_node = child_ref.get_node(db, path.current())?.ok_or_else(|| {
308 TrieError::InconsistentTree(Box::new(
309 InconsistentTreeError::NodeNotFoundOnBranchNode(
310 child_ref
311 .compute_hash(&NativeCrypto)
312 .finalize(&NativeCrypto),
313 self.compute_hash(&NativeCrypto).finalize(&NativeCrypto),
314 path.current(),
315 ),
316 ))
317 })?;
318 child_node.get_path(db, path, node_path)?;
319 }
320 }
321 Ok(())
322 }
323}
324
// Unit tests for `BranchNode`: construction, lookup, insertion, removal, hashing and
// RLP encode/decode round-trips. Fixtures are built with the `pmt_node!` macro.
#[cfg(test)]
mod test {
    use ethereum_types::H256;
    use ethrex_crypto::NativeCrypto;
    use ethrex_rlp::{decode::RLPDecode, encode::RLPEncode};

    use super::*;

    use crate::{Trie, pmt_node};

    // `new` keeps the supplied children as-is; untouched slots compare equal to the
    // default `NodeRef`.
    #[test]
    fn new() {
        let node = BranchNode::new({
            let mut choices = BranchNode::EMPTY_CHOICES;

            choices[2] = NodeHash::Hashed(H256([2; 32])).into();
            choices[5] = NodeHash::Hashed(H256([5; 32])).into();

            choices
        });

        assert_eq!(
            node.choices,
            [
                Default::default(),
                Default::default(),
                NodeHash::Hashed(H256([2; 32])).into(),
                Default::default(),
                Default::default(),
                NodeHash::Hashed(H256([5; 32])).into(),
                Default::default(),
                Default::default(),
                Default::default(),
                Default::default(),
                Default::default(),
                Default::default(),
                Default::default(),
                Default::default(),
                Default::default(),
                Default::default(),
            ],
        );
    }

    // Lookups that land on an occupied slot return the leaf's value.
    #[test]
    fn get_some() {
        let trie = Trie::new_temp();
        let node = pmt_node! { @(trie)
            branch {
                0 => leaf { vec![0,16] => vec![0x12, 0x34, 0x56, 0x78] },
                1 => leaf { vec![0,16] => vec![0x34, 0x56, 0x78, 0x9A] },
            }
        };

        assert_eq!(
            node.get(trie.db.as_ref(), Nibbles::from_bytes(&[0x00]))
                .unwrap(),
            Some(vec![0x12, 0x34, 0x56, 0x78]),
        );
        assert_eq!(
            node.get(trie.db.as_ref(), Nibbles::from_bytes(&[0x10]))
                .unwrap(),
            Some(vec![0x34, 0x56, 0x78, 0x9A]),
        );
    }

    // A lookup that selects an unoccupied slot returns `None`.
    #[test]
    fn get_none() {
        let trie = Trie::new_temp();
        let node = pmt_node! { @(trie)
            branch {
                0 => leaf { vec![0,16] => vec![0x12, 0x34, 0x56, 0x78] },
                1 => leaf { vec![0,16] => vec![0x34, 0x56, 0x78, 0x9A] },
            }
        };

        assert_eq!(
            node.get(trie.db.as_ref(), Nibbles::from_bytes(&[0x20]))
                .unwrap(),
            None,
        );
    }

    // A value inserted under the path built from byte `2` can be read back.
    #[test]
    fn insert_self() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            branch {
                0 => leaf { vec![0, 16] => vec![0x12, 0x34, 0x56, 0x78] },
                1 => leaf { vec![0, 16] => vec![0x34, 0x56, 0x78, 0x9A] },
            }
        };
        let path = Nibbles::from_bytes(&[2]);
        let value = vec![0x3];

        node.insert(trie.db.as_ref(), path.clone(), value.clone().into())
            .unwrap();

        assert_eq!(node.get(trie.db.as_ref(), path).unwrap(), Some(value));
    }

    // A value inserted under the path built from byte `0x20` can be read back.
    #[test]
    fn insert_choice() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            branch {
                0 => leaf { vec![0, 16] => vec![0x12, 0x34, 0x56, 0x78] },
                1 => leaf { vec![0, 16] => vec![0x34, 0x56, 0x78, 0x9A] },
            }
        };

        let path = Nibbles::from_bytes(&[0x20]);
        let value = vec![0x21];

        node.insert(trie.db.as_ref(), path.clone(), value.clone().into())
            .unwrap();

        assert_eq!(node.get(trie.db.as_ref(), path).unwrap(), Some(value));
    }

    // With the path offset by two, the insert leaves the children untouched and sets the
    // branch's own value.
    #[test]
    fn insert_passthrough() {
        let trie = Trie::new_temp();
        let node = pmt_node! { @(trie)
            branch {
                0 => leaf { vec![0, 16] => vec![0x12, 0x34, 0x56, 0x78] },
                1 => leaf { vec![0, 16] => vec![0x34, 0x56, 0x78, 0x9A] },
            }
        };

        let path = Nibbles::from_bytes(&[0x00]).offset(2);
        let value = vec![0x1];

        let mut new_node = node.clone();
        new_node
            .insert(trie.db.as_ref(), path, value.clone().into())
            .unwrap();

        assert_eq!(new_node.choices, node.choices);
        assert_eq!(new_node.value, value);
    }

    // Removing one of two children (no branch value) collapses the branch into a leaf.
    #[test]
    fn remove_choice_into_inner() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            branch {
                0 => leaf { vec![0, 16] => vec![0x00] },
                1 => leaf { vec![0, 16] => vec![0x10] },
            }
        };

        let (node, value) = node
            .remove(trie.db.as_ref(), Nibbles::from_bytes(&[0x00]))
            .unwrap();

        assert!(matches!(node, Some(NodeRemoveResult::New(Node::Leaf(_)))));
        assert_eq!(value, Some(vec![0x00]));
    }

    // Removing one of three children keeps the branch; it is only reported as mutated.
    #[test]
    fn remove_choice() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            branch {
                0 => leaf { vec![0, 16] => vec![0x00] },
                1 => leaf { vec![0, 16] => vec![0x10] },
                2 => leaf { vec![0, 16] => vec![0x10] },
            }
        };

        let (node, value) = node
            .remove(trie.db.as_ref(), Nibbles::from_bytes(&[0x00]))
            .unwrap();

        assert!(matches!(node, Some(NodeRemoveResult::Mutated)));
        assert_eq!(value, Some(vec![0x00]));
    }

    // Removing the only child while a branch value remains collapses the branch into a leaf.
    #[test]
    fn remove_choice_into_value() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            branch {
                0 => leaf { vec![0, 16] => vec![0x00] },
            } with_leaf { &[0x01] => vec![0xFF] }
        };

        let (node, value) = node
            .remove(trie.db.as_ref(), Nibbles::from_bytes(&[0x00]))
            .unwrap();

        assert!(matches!(node, Some(NodeRemoveResult::New(Node::Leaf(_)))));
        assert_eq!(value, Some(vec![0x00]));
    }

    // Removing the branch value while a single child remains collapses the branch into a leaf.
    #[test]
    fn remove_value_into_inner() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            branch {
                0 => leaf { vec![0, 16] => vec![0x00] },
            } with_leaf { &[0x1] => vec![0xFF] }
        };

        let (node, value) = node
            .remove(trie.db.as_ref(), Nibbles::from_bytes(&[]))
            .unwrap();

        assert!(matches!(node, Some(NodeRemoveResult::New(Node::Leaf(_)))));
        assert_eq!(value, Some(vec![0xFF]));
    }

    // Removing the branch value while two children remain keeps the branch (mutated only).
    #[test]
    fn remove_value() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            branch {
                0 => leaf { vec![0, 16] => vec![0x00] },
                1 => leaf { vec![0, 16] => vec![0x10] },
            } with_leaf { &[0x1] => vec![0xFF] }
        };

        let (node, value) = node
            .remove(trie.db.as_ref(), Nibbles::from_bytes(&[]))
            .unwrap();

        assert!(matches!(node, Some(NodeRemoveResult::Mutated)));
        assert_eq!(value, Some(vec![0xFF]));
    }

    // Two children: the expected result is 22 bytes long rather than a 32-byte digest
    // (presumably the short encoding is kept inline — confirm against `NodeHash`).
    #[test]
    fn compute_hash_two_choices() {
        let node = pmt_node! { @(trie)
            branch {
                2 => leaf { vec![0, 16] => vec![0x20] },
                4 => leaf { vec![0, 16] => vec![0x40] },
            }
        };

        assert_eq!(
            node.compute_hash(&NativeCrypto).as_ref(),
            &[
                0xD5, 0x80, 0x80, 0xC2, 0x30, 0x20, 0x80, 0xC2, 0x30, 0x40, 0x80, 0x80, 0x80, 0x80,
                0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
            ],
        );
    }

    // All sixteen children: the expected result is a pinned 32-byte value.
    #[test]
    fn compute_hash_all_choices() {
        let node = pmt_node! { @(trie)
            branch {
                0x0 => leaf { vec![0, 16] => vec![0x00] },
                0x1 => leaf { vec![0, 16] => vec![0x10] },
                0x2 => leaf { vec![0, 16] => vec![0x20] },
                0x3 => leaf { vec![0, 16] => vec![0x30] },
                0x4 => leaf { vec![0, 16] => vec![0x40] },
                0x5 => leaf { vec![0, 16] => vec![0x50] },
                0x6 => leaf { vec![0, 16] => vec![0x60] },
                0x7 => leaf { vec![0, 16] => vec![0x70] },
                0x8 => leaf { vec![0, 16] => vec![0x80] },
                0x9 => leaf { vec![0, 16] => vec![0x90] },
                0xA => leaf { vec![0, 16] => vec![0xA0] },
                0xB => leaf { vec![0, 16] => vec![0xB0] },
                0xC => leaf { vec![0, 16] => vec![0xC0] },
                0xD => leaf { vec![0, 16] => vec![0xD0] },
                0xE => leaf { vec![0, 16] => vec![0xE0] },
                0xF => leaf { vec![0, 16] => vec![0xF0] },
            }
        };

        assert_eq!(
            node.compute_hash(&NativeCrypto).as_ref(),
            &[
                0x0A, 0x3C, 0x06, 0x2D, 0x4A, 0xE3, 0x61, 0xEC, 0xC4, 0x82, 0x07, 0xB3, 0x2A, 0xDB,
                0x6A, 0x3A, 0x3F, 0x3E, 0x98, 0x33, 0xC8, 0x9C, 0x9A, 0x71, 0x66, 0x3F, 0x4E, 0xB5,
                0x61, 0x72, 0xD4, 0x9D,
            ],
        );
    }

    // Two children plus a branch value: same 22-byte shape as the two-children case, with
    // the value as the final byte.
    #[test]
    fn compute_hash_one_choice_with_value() {
        let node = pmt_node! { @(trie)
            branch {
                2 => leaf { vec![0, 16] => vec![0x20] },
                4 => leaf { vec![0, 16] => vec![0x40] },
            } with_leaf { &[0x1] => vec![0x1] }
        };

        assert_eq!(
            node.compute_hash(&NativeCrypto).as_ref(),
            &[
                0xD5, 0x80, 0x80, 0xC2, 0x30, 0x20, 0x80, 0xC2, 0x30, 0x40, 0x80, 0x80, 0x80, 0x80,
                0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01,
            ],
        );
    }

    // All sixteen children plus a branch value: the expected result is a pinned 32-byte value.
    #[test]
    fn compute_hash_all_choices_with_value() {
        let node = pmt_node! { @(trie)
            branch {
                0x0 => leaf { vec![0, 16] => vec![0x00] },
                0x1 => leaf { vec![0, 16] => vec![0x10] },
                0x2 => leaf { vec![0, 16] => vec![0x20] },
                0x3 => leaf { vec![0, 16] => vec![0x30] },
                0x4 => leaf { vec![0, 16] => vec![0x40] },
                0x5 => leaf { vec![0, 16] => vec![0x50] },
                0x6 => leaf { vec![0, 16] => vec![0x60] },
                0x7 => leaf { vec![0, 16] => vec![0x70] },
                0x8 => leaf { vec![0, 16] => vec![0x80] },
                0x9 => leaf { vec![0, 16] => vec![0x90] },
                0xA => leaf { vec![0, 16] => vec![0xA0] },
                0xB => leaf { vec![0, 16] => vec![0xB0] },
                0xC => leaf { vec![0, 16] => vec![0xC0] },
                0xD => leaf { vec![0, 16] => vec![0xD0] },
                0xE => leaf { vec![0, 16] => vec![0xE0] },
                0xF => leaf { vec![0, 16] => vec![0xF0] },
            } with_leaf { &[0x1] => vec![0x1] }
        };

        assert_eq!(
            node.compute_hash(&NativeCrypto).as_ref(),
            &[
                0x2A, 0x85, 0x67, 0xC5, 0x63, 0x4A, 0x87, 0xBA, 0x19, 0x6F, 0x2C, 0x65, 0x15, 0x16,
                0x66, 0x37, 0xE0, 0x9A, 0x34, 0xE6, 0xC9, 0xB0, 0x4D, 0xA5, 0x6F, 0xC4, 0x70, 0x4E,
                0x38, 0x61, 0x7D, 0x8E
            ],
        );
    }

    // Encode/decode round-trip for a branch with two leaf children.
    #[test]
    fn symmetric_encoding_a() {
        let node: Node = pmt_node! { @(trie)
            branch {
                0 => leaf { vec![0,16] => vec![0x12, 0x34, 0x56, 0x78] },
                1 => leaf { vec![0,16] => vec![0x34, 0x56, 0x78, 0x9A] },
            }
        }
        .into();
        assert_eq!(Node::decode(&node.encode_to_vec()).unwrap(), node)
    }

    // Encode/decode round-trip for a branch that also contains a nested extension + branch.
    #[test]
    fn symmetric_encoding_b() {
        let node: Node = pmt_node! { @(trie)
            branch {
                0 => leaf { vec![0, 16] => vec![0x00] },
                1 => leaf { vec![0, 16] => vec![0x10] },
                3 => extension { [0], branch {
                    0 => leaf { vec![16] => vec![0x01, 0x00] },
                    1 => leaf { vec![16] => vec![0x01, 0x01] },
                } },
            }
        }
        .into();
        assert_eq!(Node::decode(&node.encode_to_vec()).unwrap(), node)
    }

    // Encode/decode round-trip for a fully populated branch that also carries a value.
    #[test]
    fn symmetric_encoding_c() {
        let node: Node = pmt_node! { @(trie)
            branch {
                0x0 => leaf { vec![0, 16] => vec![0x00] },
                0x1 => leaf { vec![0, 16] => vec![0x10] },
                0x2 => leaf { vec![0, 16] => vec![0x20] },
                0x3 => leaf { vec![0, 16] => vec![0x30] },
                0x4 => leaf { vec![0, 16] => vec![0x40] },
                0x5 => leaf { vec![0, 16] => vec![0x50] },
                0x6 => leaf { vec![0, 16] => vec![0x60] },
                0x7 => leaf { vec![0, 16] => vec![0x70] },
                0x8 => leaf { vec![0, 16] => vec![0x80] },
                0x9 => leaf { vec![0, 16] => vec![0x90] },
                0xA => leaf { vec![0, 16] => vec![0xA0] },
                0xB => leaf { vec![0, 16] => vec![0xB0] },
                0xC => leaf { vec![0, 16] => vec![0xC0] },
                0xD => leaf { vec![0, 16] => vec![0xD0] },
                0xE => leaf { vec![0, 16] => vec![0xE0] },
                0xF => leaf { vec![0, 16] => vec![0xF0] },
            } with_leaf { &[0x1] => vec![0x1] }
        }
        .into();
        assert_eq!(Node::decode(&node.encode_to_vec()).unwrap(), node)
    }
}