1use ethrex_crypto::{Crypto, NativeCrypto};
2use ethrex_rlp::encode::RLPEncode;
3
4use crate::ValueRLP;
5use crate::nibbles::Nibbles;
6use crate::node::NodeRemoveResult;
7use crate::node_hash::NodeHash;
8use crate::{
9 TrieDB,
10 error::{ExtensionNodeErrorData, InconsistentTreeError, TrieError},
11};
12
13use super::{BranchNode, Node, NodeRef, ValueOrHash};
14
15#[derive(
18 Debug,
19 Clone,
20 PartialEq,
21 serde::Serialize,
22 serde::Deserialize,
23 rkyv::Serialize,
24 rkyv::Deserialize,
25 rkyv::Archive,
26)]
27pub struct ExtensionNode {
28 pub prefix: Nibbles,
29 pub child: NodeRef,
30}
31
32impl ExtensionNode {
33 pub const fn new(prefix: Nibbles, child: NodeRef) -> Self {
35 Self { prefix, child }
36 }
37
38 pub fn get(&self, db: &dyn TrieDB, mut path: Nibbles) -> Result<Option<ValueRLP>, TrieError> {
40 if path.skip_prefix(&self.prefix) {
43 let child_node = self.child.get_node(db, path.current())?.ok_or_else(|| {
44 TrieError::InconsistentTree(Box::new(
45 InconsistentTreeError::ExtensionNodeChildNotFound(ExtensionNodeErrorData {
46 node_hash: self
47 .child
48 .compute_hash(&NativeCrypto)
49 .finalize(&NativeCrypto),
50 extension_node_hash: self
51 .compute_hash(&NativeCrypto)
52 .finalize(&NativeCrypto),
53 extension_node_prefix: self.prefix.clone(),
54 node_path: path.current(),
55 }),
56 ))
57 })?;
58
59 child_node.get(db, path)
60 } else {
61 Ok(None)
62 }
63 }
64
65 pub fn insert(
68 &mut self,
69 db: &dyn TrieDB,
70 path: Nibbles,
71 value: ValueOrHash,
72 ) -> Result<Option<Node>, TrieError> {
73 let match_index = path.count_prefix(&self.prefix);
84 if match_index == self.prefix.len() {
85 let path = path.offset(match_index);
86 let Some(child_node) = self.child.get_node_mut(db, path.current())? else {
88 return Err(TrieError::InconsistentTree(Box::new(
89 InconsistentTreeError::ExtensionNodeChildNotFound(ExtensionNodeErrorData {
90 node_hash: self
91 .child
92 .compute_hash(&NativeCrypto)
93 .finalize(&NativeCrypto),
94 extension_node_hash: self
95 .compute_hash(&NativeCrypto)
96 .finalize(&NativeCrypto),
97 extension_node_prefix: self.prefix.clone(),
98 node_path: path.current(),
99 }),
100 )));
101 };
102 child_node.insert(db, path, value)?;
103 self.child.clear_hash();
104 Ok(None)
105 } else if match_index == 0 {
106 let mut new_node = if self.prefix.len() == 1 {
107 self.child.clone()
108 } else {
109 Node::from(ExtensionNode::new(
110 self.prefix.offset(1),
111 self.child.clone(),
112 ))
113 .into()
114 };
115 let mut choices = BranchNode::EMPTY_CHOICES;
116 let mut branch_node = if self.prefix.at(0) == 16 {
117 match new_node.get_node_mut(db, path.current())? {
118 Some(Node::Leaf(leaf)) => {
119 BranchNode::new_with_value(choices, leaf.value.clone())
120 }
121 Some(_) => {
122 return Err(TrieError::InconsistentTree(Box::new(
123 InconsistentTreeError::ExtensionNodeChildDiffers(
124 ExtensionNodeErrorData {
125 node_hash: new_node
126 .compute_hash(&NativeCrypto)
127 .finalize(&NativeCrypto),
128 extension_node_hash: self
129 .compute_hash(&NativeCrypto)
130 .finalize(&NativeCrypto),
131 extension_node_prefix: self.prefix.clone(),
132 node_path: path.current(),
133 },
134 ),
135 )));
136 }
137 None => {
138 return Err(TrieError::InconsistentTree(Box::new(
139 InconsistentTreeError::ExtensionNodeChildNotFound(
140 ExtensionNodeErrorData {
141 node_hash: new_node
142 .compute_hash(&NativeCrypto)
143 .finalize(&NativeCrypto),
144 extension_node_hash: self
145 .compute_hash(&NativeCrypto)
146 .finalize(&NativeCrypto),
147 extension_node_prefix: self.prefix.clone(),
148 node_path: path.current(),
149 },
150 ),
151 )));
152 }
153 }
154 } else {
155 choices[self.prefix.at(0)] = new_node;
156 BranchNode::new(choices)
157 };
158 branch_node.insert(db, path, value)?;
159 Ok(Some(branch_node.into()))
160 } else {
161 let mut new_extension =
162 ExtensionNode::new(self.prefix.offset(match_index), self.child.clone());
163 let new_node = new_extension
164 .insert(db, path.offset(match_index), value)?
165 .unwrap_or(new_extension.into());
166 self.prefix = self.prefix.slice(0, match_index);
167 self.child = new_node.into();
168 Ok(None)
169 }
170 }
171
172 pub fn remove(
173 &mut self,
174 db: &dyn TrieDB,
175 mut path: Nibbles,
176 ) -> Result<(Option<NodeRemoveResult>, Option<ValueRLP>), TrieError> {
177 if path.skip_prefix(&self.prefix) {
187 let Some(child_node) = self.child.get_node_mut(db, path.current())? else {
188 return Err(TrieError::InconsistentTree(Box::new(
189 InconsistentTreeError::ExtensionNodeChildNotFound(ExtensionNodeErrorData {
190 node_hash: self
191 .child
192 .compute_hash(&NativeCrypto)
193 .finalize(&NativeCrypto),
194 extension_node_hash: self
195 .compute_hash(&NativeCrypto)
196 .finalize(&NativeCrypto),
197 extension_node_prefix: self.prefix.clone(),
198 node_path: path.current(),
199 }),
200 )));
201 };
202 let (empty_trie, old_value) = child_node.remove(db, path)?;
204 let result = if empty_trie {
206 Ok((None, old_value))
207 } else {
208 let node = match child_node {
209 branch_node @ Node::Branch(_) => {
211 self.child = (*branch_node).clone().into();
212 NodeRemoveResult::Mutated
213 }
214 Node::Extension(extension_node) => {
216 let mut extension_node = extension_node.take();
217 let mut self_node = self.take();
218 self_node.prefix.extend(&extension_node.prefix);
219 extension_node.prefix = self_node.prefix;
220 NodeRemoveResult::New(extension_node.into())
221 }
222 Node::Leaf(leaf_node) => {
224 let mut leaf_node = leaf_node.take();
225 let mut self_node = self.take();
226 self_node.prefix.extend(&leaf_node.partial);
227 leaf_node.partial = self_node.prefix;
228 NodeRemoveResult::New(leaf_node.into())
229 }
230 };
231 Ok((Some(node), old_value))
232 };
233 self.child.clear_hash();
234 result
235 } else {
236 Ok((Some(NodeRemoveResult::Mutated), None))
237 }
238 }
239
240 pub fn compute_hash(&self, crypto: &dyn Crypto) -> NodeHash {
242 self.compute_hash_no_alloc(&mut vec![], crypto)
243 }
244
245 pub fn compute_hash_no_alloc(&self, buf: &mut Vec<u8>, crypto: &dyn Crypto) -> NodeHash {
247 buf.clear();
248 self.encode(buf);
249 let hash = NodeHash::from_encoded(buf, crypto);
250 buf.clear();
251 hash
252 }
253
254 pub fn get_path(
258 &self,
259 db: &dyn TrieDB,
260 mut path: Nibbles,
261 node_path: &mut Vec<Vec<u8>>,
262 ) -> Result<(), TrieError> {
263 let encoded = self.encode_to_vec();
265 if encoded.len() >= 32 {
266 node_path.push(encoded);
267 };
268 if path.skip_prefix(&self.prefix) {
270 let child_node = self.child.get_node(db, path.current())?.ok_or_else(|| {
271 TrieError::InconsistentTree(Box::new(
272 InconsistentTreeError::ExtensionNodeChildNotFound(ExtensionNodeErrorData {
273 node_hash: self
274 .child
275 .compute_hash(&NativeCrypto)
276 .finalize(&NativeCrypto),
277 extension_node_hash: self
278 .compute_hash(&NativeCrypto)
279 .finalize(&NativeCrypto),
280 extension_node_prefix: self.prefix.clone(),
281 node_path: path.current(),
282 }),
283 ))
284 })?;
285 child_node.get_path(db, path, node_path)?;
286 }
287 Ok(())
288 }
289
290 pub fn take(&mut self) -> Self {
294 ExtensionNode {
295 prefix: self.prefix.take(),
296 child: self.child.clone(),
297 }
298 }
299}
300
#[cfg(test)]
mod test {
    use ethrex_crypto::NativeCrypto;
    use ethrex_rlp::{decode::RLPDecode, encode::RLPEncode};

    use super::*;
    use crate::{Trie, node::LeafNode, pmt_node};

    /// A freshly built node exposes exactly the prefix and child it was given.
    #[test]
    fn new() {
        let ExtensionNode { prefix, child } =
            ExtensionNode::new(Nibbles::default(), Default::default());

        assert_eq!(prefix.len(), 0);
        assert_eq!(child, Default::default());
    }

    /// Values stored below the extension are found through it.
    #[test]
    fn get_some() {
        let trie = Trie::new_temp();
        let node = pmt_node! { @(trie)
            extension { [0], branch {
                0 => leaf { vec![16] => vec![0x12, 0x34, 0x56, 0x78] },
                1 => leaf { vec![16] => vec![0x34, 0x56, 0x78, 0x9A] },
            } }
        };

        let first = node
            .get(trie.db.as_ref(), Nibbles::from_bytes(&[0x00]))
            .unwrap();
        assert_eq!(first, Some(vec![0x12, 0x34, 0x56, 0x78]));

        let second = node
            .get(trie.db.as_ref(), Nibbles::from_bytes(&[0x01]))
            .unwrap();
        assert_eq!(second, Some(vec![0x34, 0x56, 0x78, 0x9A]));
    }

    /// Looking up a path with no stored value yields `None`.
    #[test]
    fn get_none() {
        let trie = Trie::new_temp();
        let node = pmt_node! { @(trie)
            extension { [0], branch {
                0 => leaf { vec![16] => vec![0x12, 0x34, 0x56, 0x78] },
                1 => leaf { vec![16] => vec![0x34, 0x56, 0x78, 0x9A] },
            } }
        };

        let missing = node
            .get(trie.db.as_ref(), Nibbles::from_bytes(&[0x02]))
            .unwrap();
        assert_eq!(missing, None);
    }

    /// Inserting below a fully matching prefix leaves the extension in place.
    #[test]
    fn insert_passthrough() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            extension { [0], branch {
                0 => leaf { vec![16] => vec![0x12, 0x34, 0x56, 0x78] },
                1 => leaf { vec![16] => vec![0x34, 0x56, 0x78, 0x9A] },
            } }
        };

        let replacement = node
            .insert(
                trie.db.as_ref(),
                Nibbles::from_bytes(&[0x02]),
                Vec::new().into(),
            )
            .unwrap();

        assert!(replacement.is_none());
        assert_eq!(node.prefix.as_ref(), &[0]);
    }

    /// Inserting a path that diverges at the first nibble of a one-nibble prefix
    /// produces a branch.
    #[test]
    fn insert_branch() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            extension { [0], branch {
                0 => leaf { vec![16] => vec![0x12, 0x34, 0x56, 0x78] },
                1 => leaf { vec![16] => vec![0x34, 0x56, 0x78, 0x9A] },
            } }
        };

        let replacement = node
            .insert(
                trie.db.as_ref(),
                Nibbles::from_bytes(&[0x10]),
                vec![0x20].into(),
            )
            .unwrap();
        let Some(Node::Branch(branch)) = replacement else {
            panic!("expected a branch node");
        };

        let stored = branch
            .get(trie.db.as_ref(), Nibbles::from_bytes(&[0x10]))
            .unwrap();
        assert_eq!(stored, Some(vec![0x20]));
    }

    /// Inserting a path that diverges at the first nibble of a longer prefix also
    /// produces a branch.
    #[test]
    fn insert_branch_extension() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            extension { [0, 0], branch {
                0 => leaf { vec![16] => vec![0x12, 0x34, 0x56, 0x78] },
                1 => leaf { vec![16]=> vec![0x34, 0x56, 0x78, 0x9A] },
            } }
        };

        let replacement = node
            .insert(
                trie.db.as_ref(),
                Nibbles::from_bytes(&[0x10]),
                vec![0x20].into(),
            )
            .unwrap();
        let Some(Node::Branch(branch)) = replacement else {
            panic!("expected a branch node");
        };

        let stored = branch
            .get(trie.db.as_ref(), Nibbles::from_bytes(&[0x10]))
            .unwrap();
        assert_eq!(stored, Some(vec![0x20]));
    }

    /// Inserting a path that shares only part of the prefix keeps the extension and
    /// makes the value reachable through it.
    #[test]
    fn insert_extension_branch() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            extension { [0, 0], branch {
                0 => leaf { vec![16] => vec![0x12, 0x34, 0x56, 0x78] },
                1 => leaf { vec![16] => vec![0x34, 0x56, 0x78, 0x9A] },
            } }
        };

        let key = Nibbles::from_bytes(&[0x01]);
        let payload = vec![0x02];

        let replacement = node
            .insert(trie.db.as_ref(), key.clone(), payload.clone().into())
            .unwrap();
        assert!(replacement.is_none());

        let stored = node.get(trie.db.as_ref(), key).unwrap();
        assert_eq!(stored, Some(payload));
    }

    /// Same scenario as `insert_extension_branch`, with a different value.
    #[test]
    fn insert_extension_branch_extension() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            extension { [0, 0], branch {
                0 => leaf { vec![16] => vec![0x12, 0x34, 0x56, 0x78] },
                1 => leaf { vec![16] => vec![0x34, 0x56, 0x78, 0x9A] },
            } }
        };

        let key = Nibbles::from_bytes(&[0x01]);
        let payload = vec![0x04];

        let replacement = node
            .insert(trie.db.as_ref(), key.clone(), payload.clone().into())
            .unwrap();
        assert!(replacement.is_none());

        let stored = node.get(trie.db.as_ref(), key).unwrap();
        assert_eq!(stored, Some(payload));
    }

    /// Removing a path that holds no value reports a mutated node and no value.
    #[test]
    fn remove_none() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            extension { [0], branch {
                0 => leaf { vec![16] => vec![0x00] },
                1 => leaf { vec![16] => vec![0x01] },
            } }
        };

        let (outcome, removed) = node
            .remove(trie.db.as_ref(), Nibbles::from_bytes(&[0x02]))
            .unwrap();

        assert!(matches!(outcome, Some(NodeRemoveResult::Mutated)));
        assert_eq!(removed, None);
    }

    /// Removing one of two leaves collapses the extension into a leaf.
    #[test]
    fn remove_into_leaf() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            extension { [0], branch {
                0 => leaf { vec![16] => vec![0x00] },
                1 => leaf { vec![16] => vec![0x01] },
            } }
        };

        let (outcome, removed) = node
            .remove(trie.db.as_ref(), Nibbles::from_bytes(&[0x01]))
            .unwrap();

        assert!(matches!(
            outcome,
            Some(NodeRemoveResult::New(Node::Leaf(_)))
        ));
        assert_eq!(removed, Some(vec![0x01]));
    }

    /// Removing the only leaf next to a nested extension merges both extensions.
    #[test]
    fn remove_into_extension() {
        let trie = Trie::new_temp();
        let mut node = pmt_node! { @(trie)
            extension { [0], branch {
                0 => leaf { vec![16] => vec![0x00] },
                1 => extension { [0], branch {
                    0 => leaf { vec![16] => vec![0x01, 0x00] },
                    1 => leaf { vec![16] => vec![0x01, 0x01] },
                } },
            } }
        };

        let (outcome, removed) = node
            .remove(trie.db.as_ref(), Nibbles::from_bytes(&[0x00]))
            .unwrap();

        assert!(matches!(
            outcome,
            Some(NodeRemoveResult::New(Node::Extension(_)))
        ));
        assert_eq!(removed, Some(vec![0x00]));
    }

    /// A node whose encoding is shorter than 32 bytes yields that encoding as its hash.
    #[test]
    fn compute_hash() {
        let leaf_node_a = LeafNode::new(Nibbles::from_hex(vec![0, 16]), vec![0x12, 0x34]);
        let leaf_node_b = LeafNode::new(Nibbles::from_hex(vec![0, 16]), vec![0x56, 0x78]);

        let mut choices = BranchNode::EMPTY_CHOICES;
        choices[0] = leaf_node_a.compute_hash(&NativeCrypto).into();
        choices[1] = leaf_node_b.compute_hash(&NativeCrypto).into();
        let branch_node = BranchNode::new(choices);

        let node = ExtensionNode::new(
            Nibbles::from_hex(vec![0, 0]),
            branch_node.compute_hash(&NativeCrypto).into(),
        );

        let expected: [u8; 30] = [
            0xDD, 0x82, 0x00, 0x00, 0xD9, 0xC4, 0x30, 0x82, 0x12, 0x34, 0xC4, 0x30, 0x82, 0x56,
            0x78, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
            0x80, 0x80,
        ];
        assert_eq!(node.compute_hash(&NativeCrypto).as_ref(), &expected);
    }

    /// A node with a longer encoding yields a 32-byte digest.
    #[test]
    fn compute_hash_long() {
        let leaf_node_a = LeafNode::new(
            Nibbles::from_hex(vec![0, 16]),
            vec![0x12, 0x34, 0x56, 0x78, 0x9A],
        );
        let leaf_node_b = LeafNode::new(
            Nibbles::from_hex(vec![0, 16]),
            vec![0x34, 0x56, 0x78, 0x9A, 0xBC],
        );

        let mut choices = BranchNode::EMPTY_CHOICES;
        choices[0] = leaf_node_a.compute_hash(&NativeCrypto).into();
        choices[1] = leaf_node_b.compute_hash(&NativeCrypto).into();
        let branch_node = BranchNode::new(choices);

        let node = ExtensionNode::new(
            Nibbles::from_hex(vec![0, 0]),
            branch_node.compute_hash(&NativeCrypto).into(),
        );

        let expected: [u8; 32] = [
            0xFA, 0xBA, 0x42, 0x79, 0xB3, 0x9B, 0xCD, 0xEB, 0x7C, 0x53, 0x0F, 0xD7, 0x6E, 0x5A,
            0xA3, 0x48, 0xD3, 0x30, 0x76, 0x26, 0x14, 0x84, 0x55, 0xA0, 0xAE, 0xFE, 0x0F, 0x52,
            0x89, 0x5F, 0x36, 0x06,
        ];
        assert_eq!(node.compute_hash(&NativeCrypto).as_ref(), &expected);
    }

    /// Encoding then decoding a small extension gives back an equal node.
    #[test]
    fn symmetric_encoding_a() {
        let node: Node = pmt_node! { @(trie)
            extension { [0], branch {
                0 => leaf { vec![16] => vec![0x12, 0x34, 0x56, 0x78] },
                1 => leaf { vec![16] => vec![0x34, 0x56, 0x78, 0x9A] },
            } }
        }
        .into();

        let encoded = node.encode_to_vec();
        let decoded = Node::decode(&encoded).unwrap();
        assert_eq!(decoded, node);
    }

    /// Round trip for an extension that contains a nested extension.
    #[test]
    fn symmetric_encoding_b() {
        let node: Node = pmt_node! { @(trie)
            extension { [0], branch {
                0 => leaf { vec![16] => vec![0x00] },
                1 => extension { [0], branch {
                    0 => leaf { vec![16] => vec![0x01, 0x00] },
                    1 => leaf { vec![16] => vec![0x01, 0x01] },
                } },
            } }
        }
        .into();

        let encoded = node.encode_to_vec();
        let decoded = Node::decode(&encoded).unwrap();
        assert_eq!(decoded, node);
    }

    /// Round trip for a nested extension whose branch has many populated choices.
    #[test]
    fn symmetric_encoding_c() {
        let node: Node = pmt_node! { @(trie)
            extension { [0], branch {
                0 => leaf { vec![16] => vec![0x00] },
                1 => extension { [0], branch {
                    0 => leaf { vec![16] => vec![0x01, 0x00] },
                    1 => leaf { vec![16] => vec![0x01, 0x01] },
                    2 => leaf { vec![16] => vec![0x01, 0x00] },
                    3 => leaf { vec![16] => vec![0x03, 0x01] },
                    4 => leaf { vec![16] => vec![0x04, 0x00] },
                    5 => leaf { vec![16] => vec![0x05, 0x01] },
                    6 => leaf { vec![16] => vec![0x06, 0x00] },
                    7 => leaf { vec![16] => vec![0x07, 0x01] },
                    8 => leaf { vec![16] => vec![0x08, 0x00] },
                    9 => leaf { vec![16] => vec![0x09, 0x01] },
                    10 => leaf { vec![16] => vec![0x10, 0x00] },
                    11 => leaf { vec![16] => vec![0x11, 0x01] },
                } },
            } }
        }
        .into();

        let encoded = node.encode_to_vec();
        let decoded = Node::decode(&encoded).unwrap();
        assert_eq!(decoded, node);
    }
}