1mod branch;
2mod extension;
3mod leaf;
4
5use std::sync::Arc;
6#[cfg(not(all(feature = "eip-8025", target_arch = "riscv64")))]
7use std::sync::OnceLock;
8
/// Minimal, unsynchronized stand-in for [`std::sync::OnceLock`], compiled only for
/// `riscv64` builds with the `eip-8025` feature enabled (every other build imports the std
/// type instead).
///
/// The cell is a plain `UnsafeCell<Option<T>>` with no atomics or locks, so it is only sound
/// as long as it is never accessed from more than one thread at a time.
// NOTE(review): this presumes the riscv64 + eip-8025 target runs single-threaded — confirm
// before introducing threads there.
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
pub struct OnceLock<T>(core::cell::UnsafeCell<Option<T>>);

// SAFETY: the cell does no synchronization, so sharing `&OnceLock<T>` is only sound under the
// single-threaded assumption documented on the type. The bounds mirror
// `std::sync::OnceLock` (`T: Sync + Send`): through a shared reference one can both read a
// `&T` (needs `Sync`) and move a freshly produced `T` into the cell (needs `Send`), so this
// stays a drop-in replacement for the std type on every target.
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
unsafe impl<T: Sync + Send> Sync for OnceLock<T> {}
21
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
impl<T> OnceLock<T> {
    /// Creates an empty (uninitialized) cell.
    #[inline]
    fn new() -> Self {
        Self(core::cell::UnsafeCell::new(None))
    }

    /// Returns the stored value, or `None` if the cell has not been initialized yet.
    #[inline]
    fn get(&self) -> Option<&T> {
        // SAFETY: only a shared reference is created here. A `&mut` into the slot is only
        // ever taken in `try_insert`, after checking the slot is empty (so no `&T` returned by
        // this method can exist yet); `take` goes through `&mut self` and needs no `unsafe`.
        unsafe { &*self.0.get() }.as_ref()
    }

    /// Returns the stored value, initializing the cell with `f()` first if it is empty.
    ///
    /// # Panics
    /// Panics if `f` re-entrantly initializes this same cell.
    #[inline]
    fn get_or_init(&self, f: impl FnOnce() -> T) -> &T {
        match self.get_or_try_init(|| Ok::<T, core::convert::Infallible>(f())) {
            Ok(val) => val,
            Err(e) => match e {},
        }
    }

    /// Like `get_or_init`, but `f` may fail: on `Err` the cell stays empty and the error is
    /// returned.
    ///
    /// # Panics
    /// Panics if `f` re-entrantly initializes this same cell.
    #[inline]
    fn get_or_try_init<E>(&self, f: impl FnOnce() -> Result<T, E>) -> Result<&T, E> {
        if let Some(val) = self.get() {
            return Ok(val);
        }
        self.try_init(f)
    }

    /// Stores `value` if the cell is empty; otherwise hands `value` back as `Err`.
    #[inline]
    fn set(&self, value: T) -> Result<(), T> {
        match self.try_insert(value) {
            Ok(_) => Ok(()),
            Err((_, value)) => Err(value),
        }
    }

    /// Stores `value` if the cell is empty and returns a reference to it; otherwise returns
    /// the existing value together with the rejected `value`.
    #[inline]
    fn try_insert(&self, value: T) -> Result<&T, (&T, T)> {
        if let Some(old) = self.get() {
            return Err((old, value));
        }
        // SAFETY: the slot is empty, so no `&T` into it has been handed out, and (under the
        // single-threaded assumption on this type) nothing else is accessing it right now.
        let slot = unsafe { &mut *self.0.get() };
        Ok(slot.insert(value))
    }

    /// Runs `f` and stores its result in the (previously empty) cell.
    ///
    /// `f` may itself initialize this cell (re-entrant init). A blind write would overwrite
    /// that value and invalidate any `&T` the nested call already returned. Going through
    /// `try_insert` detects that case in every build profile and panics, matching
    /// `core::cell::OnceCell`.
    #[inline]
    fn try_init<E>(&self, f: impl FnOnce() -> Result<T, E>) -> Result<&T, E> {
        let val = f()?;
        match self.try_insert(val) {
            Ok(val) => Ok(val),
            Err(_) => panic!("reentrant init"),
        }
    }

    /// Moves the value out, leaving the cell empty.
    #[inline]
    fn take(&mut self) -> Option<T> {
        self.0.get_mut().take()
    }
}
80
/// Two cells are equal when both are empty, or when both hold equal values.
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
impl<T: PartialEq> PartialEq for OnceLock<T> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        match (self.get(), other.get()) {
            (Some(lhs), Some(rhs)) => lhs == rhs,
            (None, None) => true,
            _ => false,
        }
    }
}
88
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
impl<T> Default for OnceLock<T> {
    /// Returns an empty, uninitialized cell (same as `OnceLock::new`).
    #[inline]
    fn default() -> Self {
        OnceLock(core::cell::UnsafeCell::new(None))
    }
}
96
/// Equality only compares the stored values, so it is total whenever `T`'s is.
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
impl<T> Eq for OnceLock<T> where T: Eq {}
99
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
impl<T: Clone> Clone for OnceLock<T> {
    /// Produces an independent cell: initialized with a clone of the stored value if there is
    /// one, empty otherwise.
    #[inline]
    fn clone(&self) -> OnceLock<T> {
        self.get()
            .map_or_else(OnceLock::new, |value| OnceLock::from(value.clone()))
    }
}
110
/// Formats as `OnceLock(<value>)`, or `OnceLock(<uninit>)` when the cell is empty.
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
impl<T: std::fmt::Debug> std::fmt::Debug for OnceLock<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("OnceLock");
        if let Some(inner) = self.get() {
            tuple.field(inner);
        } else {
            tuple.field(&format_args!("<uninit>"));
        }
        tuple.finish()
    }
}
122
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
impl<T> From<T> for OnceLock<T> {
    /// Creates a cell that is already initialized with `value`.
    #[inline]
    fn from(value: T) -> Self {
        // Tuple constructor instead of the `{ 0: ... }` numbered-field literal
        // (clippy `init_numbered_fields`).
        Self(core::cell::UnsafeCell::new(Some(value)))
    }
}
132
133pub use branch::BranchNode;
134use ethrex_rlp::{decode::RLPDecode, encode::RLPEncode};
135pub use extension::ExtensionNode;
136pub use leaf::LeafNode;
137use rkyv::{
138 de::Pooling,
139 rancor::Source,
140 ser::{Allocator, Sharing, Writer},
141 validation::{ArchiveContext, SharedContext},
142 with::Skip,
143};
144
145use ethrex_crypto::{Crypto, NativeCrypto};
146
147use crate::{NodeRLP, TrieDB, error::TrieError, nibbles::Nibbles};
148
149use super::{ValueRLP, node_hash::NodeHash};
150
151#[derive(
156 Clone,
157 Debug,
158 serde::Serialize,
159 serde::Deserialize,
160 rkyv::Serialize,
161 rkyv::Deserialize,
162 rkyv::Archive,
163)]
164#[rkyv(serialize_bounds(__S: Writer + Allocator + Sharing, __S::Error: Source))]
165#[rkyv(deserialize_bounds(__D: Pooling, __D::Error: Source))]
166#[rkyv(bytecheck(bounds(__C: ArchiveContext + SharedContext)))]
167pub enum NodeRef {
168 Node(
170 #[rkyv(omit_bounds)] Arc<Node>,
171 #[rkyv(with = Skip)]
172 #[serde(skip)]
173 OnceLock<NodeHash>,
174 ),
175 Hash(NodeHash),
177}
178
179impl NodeRef {
180 pub fn get_node(&self, db: &dyn TrieDB, path: Nibbles) -> Result<Option<Arc<Node>>, TrieError> {
184 match self {
185 NodeRef::Node(node, _) => Ok(Some(node.clone())),
186 NodeRef::Hash(hash @ NodeHash::Inline(_)) => {
187 Ok(Some(Arc::new(Node::decode(hash.as_ref())?)))
188 }
189 NodeRef::Hash(_) => db
190 .get(path)?
191 .filter(|rlp| !rlp.is_empty())
192 .map(|rlp| Ok(Arc::new(Node::decode(&rlp)?)))
193 .transpose(),
194 }
195 }
196
197 pub fn get_node_checked(
205 &self,
206 db: &dyn TrieDB,
207 path: Nibbles,
208 ) -> Result<Option<Arc<Node>>, TrieError> {
209 match self {
210 NodeRef::Node(node, _) => Ok(Some(node.clone())),
211 NodeRef::Hash(hash @ NodeHash::Inline(_)) => {
212 Ok(Some(Arc::new(Node::decode(hash.as_ref())?)))
213 }
214 NodeRef::Hash(hash @ NodeHash::Hashed(_)) => {
215 db.get(path)?
216 .filter(|rlp| !rlp.is_empty())
217 .and_then(|rlp| match Node::decode(&rlp) {
218 Ok(node) => (node.compute_hash(&NativeCrypto) == *hash)
219 .then_some(Ok(Arc::new(node))),
220 Err(err) => Some(Err(TrieError::RLPDecode(err))),
221 })
222 .transpose()
223 }
224 }
225 }
226
227 pub(crate) fn get_node_mut(
234 &mut self,
235 db: &dyn TrieDB,
236 path: Nibbles,
237 ) -> Result<Option<&mut Node>, TrieError> {
238 match self {
239 NodeRef::Node(node, _) => Ok(Some(Arc::make_mut(node))),
240 NodeRef::Hash(hash @ NodeHash::Inline(_)) => {
241 let node = Node::decode(hash.as_ref())?;
242 *self = NodeRef::Node(Arc::new(node), OnceLock::from(*hash));
243 self.get_node_mut(db, path)
244 }
245 NodeRef::Hash(hash @ NodeHash::Hashed(_)) => {
246 let Some(node) = db
247 .get(path.clone())?
248 .filter(|rlp| !rlp.is_empty())
249 .map(|rlp| Node::decode(&rlp).map_err(TrieError::RLPDecode))
250 .transpose()?
251 else {
252 return Ok(None);
253 };
254 *self = NodeRef::Node(Arc::new(node), OnceLock::from(*hash));
255 self.get_node_mut(db, path)
256 }
257 }
258 }
259
260 pub fn is_valid(&self) -> bool {
261 match self {
262 NodeRef::Node(_, _) => true,
263 NodeRef::Hash(hash) => hash.is_valid(),
264 }
265 }
266
267 pub fn commit(
268 &mut self,
269 path: Nibbles,
270 acc: &mut Vec<(Nibbles, Vec<u8>)>,
271 crypto: &dyn Crypto,
272 ) -> NodeHash {
273 match *self {
274 NodeRef::Node(ref mut node, ref mut hash) => {
275 if let Some(hash) = hash.get() {
276 return *hash;
277 }
278 match Arc::make_mut(node) {
279 Node::Branch(node) => {
280 for (choice, node) in &mut node.choices.iter_mut().enumerate() {
281 node.commit(path.append_new(choice as u8), acc, crypto);
282 }
283 }
284 Node::Extension(node) => {
285 node.child.commit(path.concat(&node.prefix), acc, crypto);
286 }
287 Node::Leaf(_) => {}
288 }
289 let mut buf = Vec::new();
290 node.encode(&mut buf);
291 let hash = *hash.get_or_init(|| NodeHash::from_encoded(&buf, crypto));
292 if let Node::Leaf(leaf) = node.as_ref() {
293 acc.push((path.concat(&leaf.partial), leaf.value.clone()));
294 }
295 acc.push((path, buf));
296
297 hash
298 }
299 NodeRef::Hash(hash) => hash,
300 }
301 }
302
303 pub fn compute_hash(&self, crypto: &dyn Crypto) -> NodeHash {
304 *self.compute_hash_ref(crypto)
305 }
306
307 pub fn compute_hash_ref(&self, crypto: &dyn Crypto) -> &NodeHash {
308 match self {
309 NodeRef::Node(node, hash) => hash.get_or_init(|| node.compute_hash(crypto)),
310 NodeRef::Hash(hash) => hash,
311 }
312 }
313
314 pub fn compute_hash_no_alloc(&self, buf: &mut Vec<u8>, crypto: &dyn Crypto) -> &NodeHash {
315 match self {
316 NodeRef::Node(node, hash) => {
317 hash.get_or_init(|| node.compute_hash_no_alloc(buf, crypto))
318 }
319 NodeRef::Hash(hash) => hash,
320 }
321 }
322
323 pub fn memoize_hashes(&self, buf: &mut Vec<u8>, crypto: &dyn Crypto) {
324 if let NodeRef::Node(node, hash) = &self
325 && hash.get().is_none()
326 {
327 node.memoize_hashes(buf, crypto);
328 let _ = hash.set(node.compute_hash_no_alloc(buf, crypto));
329 }
330 }
331
332 pub fn clear_hash(&mut self) {
337 if let NodeRef::Node(_, hash) = self {
338 hash.take();
339 }
340 }
341}
342
343impl Default for NodeRef {
344 fn default() -> Self {
345 Self::Hash(NodeHash::default())
346 }
347}
348
349impl From<Node> for NodeRef {
350 fn from(value: Node) -> Self {
351 Self::Node(Arc::new(value), OnceLock::new())
352 }
353}
354
355impl From<NodeHash> for NodeRef {
356 fn from(value: NodeHash) -> Self {
357 Self::Hash(value)
358 }
359}
360
361impl From<Arc<Node>> for NodeRef {
362 fn from(value: Arc<Node>) -> Self {
363 Self::Node(value, OnceLock::new())
364 }
365}
366
367impl PartialEq for NodeRef {
368 fn eq(&self, other: &Self) -> bool {
369 let mut buf = Vec::new();
370 self.compute_hash_no_alloc(&mut buf, &NativeCrypto)
371 == other.compute_hash_no_alloc(&mut buf, &NativeCrypto)
372 }
373}
374
375pub enum ValueOrHash {
376 Value(ValueRLP),
377 Hash(NodeHash),
378}
379
380impl From<ValueRLP> for ValueOrHash {
381 fn from(value: ValueRLP) -> Self {
382 Self::Value(value)
383 }
384}
385
386impl From<NodeHash> for ValueOrHash {
387 fn from(value: NodeHash) -> Self {
388 Self::Hash(value)
389 }
390}
391
392#[derive(
393 Debug,
394 Clone,
395 PartialEq,
396 serde::Serialize,
397 serde::Deserialize,
398 rkyv::Deserialize,
399 rkyv::Serialize,
400 rkyv::Archive,
401)]
402pub enum Node {
404 Branch(Box<BranchNode>),
405 Extension(ExtensionNode),
406 Leaf(LeafNode),
407}
408
409impl Default for Node {
410 fn default() -> Self {
411 Self::Leaf(LeafNode {
413 partial: Nibbles::from_bytes(&[]),
414 value: Vec::new(),
415 })
416 }
417}
418
419impl From<Box<BranchNode>> for Node {
420 fn from(val: Box<BranchNode>) -> Self {
421 Node::Branch(val)
422 }
423}
424
425impl From<BranchNode> for Node {
426 fn from(val: BranchNode) -> Self {
427 Node::Branch(Box::new(val))
428 }
429}
430
431impl From<ExtensionNode> for Node {
432 fn from(val: ExtensionNode) -> Self {
433 Node::Extension(val)
434 }
435}
436
437impl From<LeafNode> for Node {
438 fn from(val: LeafNode) -> Self {
439 Node::Leaf(val)
440 }
441}
442
443impl Node {
444 pub fn get(&self, db: &dyn TrieDB, path: Nibbles) -> Result<Option<ValueRLP>, TrieError> {
446 match self {
447 Node::Branch(n) => n.get(db, path),
448 Node::Extension(n) => n.get(db, path),
449 Node::Leaf(n) => n.get(path),
450 }
451 }
452
453 pub fn insert(
455 &mut self,
456 db: &dyn TrieDB,
457 path: Nibbles,
458 value: impl Into<ValueOrHash>,
459 ) -> Result<(), TrieError> {
460 let new_node = match self {
461 Node::Branch(n) => {
462 n.insert(db, path, value.into())?;
463 Ok(None)
464 }
465 Node::Extension(n) => n.insert(db, path, value.into()),
466 Node::Leaf(n) => n.insert(path, value.into()),
467 };
468 if let Some(new_node) = new_node? {
469 *self = new_node;
470 }
471 Ok(())
472 }
473
474 pub fn remove(
477 &mut self,
478 db: &dyn TrieDB,
479 path: Nibbles,
480 ) -> Result<(bool, Option<ValueRLP>), TrieError> {
481 let (new_root, value) = match self {
482 Node::Branch(n) => n.remove(db, path),
483 Node::Extension(n) => n.remove(db, path),
484 Node::Leaf(n) => n.remove(path),
485 }?;
486
487 let is_trie_empty = new_root.is_none();
488 if let Some(NodeRemoveResult::New(new_root)) = new_root {
489 *self = new_root;
490 }
491 Ok((is_trie_empty, value))
492 }
493
494 pub fn get_path(
498 &self,
499 db: &dyn TrieDB,
500 path: Nibbles,
501 node_path: &mut Vec<Vec<u8>>,
502 ) -> Result<(), TrieError> {
503 match self {
504 Node::Branch(n) => n.get_path(db, path, node_path),
505 Node::Extension(n) => n.get_path(db, path, node_path),
506 Node::Leaf(n) => n.get_path(node_path),
507 }
508 }
509
510 pub fn compute_hash(&self, crypto: &dyn Crypto) -> NodeHash {
512 let mut buf = Vec::new();
513 self.memoize_hashes(&mut buf, crypto);
514 match self {
515 Node::Branch(n) => n.compute_hash_no_alloc(&mut buf, crypto),
516 Node::Extension(n) => n.compute_hash_no_alloc(&mut buf, crypto),
517 Node::Leaf(n) => n.compute_hash_no_alloc(&mut buf, crypto),
518 }
519 }
520
521 pub fn compute_hash_no_alloc(&self, buf: &mut Vec<u8>, crypto: &dyn Crypto) -> NodeHash {
523 self.memoize_hashes(buf, crypto);
524 match self {
525 Node::Branch(n) => n.compute_hash_no_alloc(buf, crypto),
526 Node::Extension(n) => n.compute_hash_no_alloc(buf, crypto),
527 Node::Leaf(n) => n.compute_hash_no_alloc(buf, crypto),
528 }
529 }
530
531 pub fn memoize_hashes(&self, buf: &mut Vec<u8>, crypto: &dyn Crypto) {
534 match self {
535 Node::Branch(n) => {
536 for child in &n.choices {
537 child.memoize_hashes(buf, crypto);
538 }
539 }
540 Node::Extension(n) => n.child.memoize_hashes(buf, crypto),
541 _ => {}
542 }
543 }
544
545 pub fn encode_subtrie(&self, encoded: &mut Vec<NodeRLP>) -> Result<(), TrieError> {
550 match self {
551 Node::Branch(node) => {
552 for choice in &node.choices {
553 if let NodeRef::Node(choice, _) = choice {
554 choice.encode_subtrie(encoded)?;
555 }
556 }
557 }
558 Node::Extension(node) => {
559 if let NodeRef::Node(child, _) = &node.child {
560 child.encode_subtrie(encoded)?;
561 }
562 }
563 Node::Leaf(_) => {}
564 };
565
566 encoded.push(self.encode_to_vec());
567 Ok(())
568 }
569}
570
571pub enum NodeRemoveResult {
575 Mutated,
576 New(Node),
577}