Skip to main content

ethrex_trie/
node.rs

1mod branch;
2mod extension;
3mod leaf;
4
5use std::sync::Arc;
6#[cfg(not(all(feature = "eip-8025", target_arch = "riscv64")))]
7use std::sync::OnceLock;
8
/// `OnceLock` replacement for zkVM guest gated on `eip-8025` feature
///
/// `std::sync::OnceLock` atomics are pure overhead in zkVM guest.
/// This struct copies the methods from `once_cell::unsync::OnceCell` and uses unsafe
/// to get around the Sync requirement.
///
/// It mirrors the names of the `std::sync::OnceLock` methods this module relies on,
/// so the rest of the file compiles unchanged against whichever implementation the
/// `cfg` selects. The inner slot is `None` until the cell is initialized.
///
/// This code is only sound because the guest is guaranteed to be single-threaded.
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
pub struct OnceLock<T>(core::cell::UnsafeCell<Option<T>>);
18
// SAFETY: this type is only compiled for the zkVM guest, which is single-threaded,
// so the inner `UnsafeCell` can never be accessed from two threads at once. The impl
// lets containing types such as `NodeRef` stay `Sync`, like the `std::sync::OnceLock`
// this type replaces.
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
unsafe impl<T> Sync for OnceLock<T> where T: Sync {}
21
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
impl<T> OnceLock<T> {
    /// Creates an empty (uninitialized) cell.
    #[inline]
    fn new() -> Self {
        Self(core::cell::UnsafeCell::new(None))
    }

    /// Returns the stored value, or `None` if the cell has not been initialized.
    #[inline]
    fn get(&self) -> Option<&T> {
        // SAFETY: the guest is single-threaded, and a mutable borrow of the slot is
        // only created by `try_insert` after it has checked that the slot is empty
        // (so no `&T` is outstanding), or by `take`, which requires `&mut self`.
        unsafe { &*self.0.get() }.as_ref()
    }

    /// Returns the stored value, initializing it with `f` first if the cell is empty.
    ///
    /// # Panics
    ///
    /// Panics if `f` reentrantly initializes this same cell.
    #[inline]
    fn get_or_init(&self, f: impl FnOnce() -> T) -> &T {
        match self.get_or_try_init(|| Ok::<T, core::convert::Infallible>(f())) {
            Ok(val) => val,
            Err(e) => match e {},
        }
    }

    /// Fallible version of `get_or_init`: if `f` returns `Err`, the cell stays
    /// empty and the error is propagated.
    ///
    /// # Panics
    ///
    /// Panics if `f` reentrantly initializes this same cell.
    #[inline]
    fn get_or_try_init<E>(&self, f: impl FnOnce() -> Result<T, E>) -> Result<&T, E> {
        if let Some(val) = self.get() {
            return Ok(val);
        }
        self.try_init(f)
    }

    /// Stores `value` if the cell is empty; otherwise hands `value` back in `Err`.
    #[inline]
    fn set(&self, value: T) -> Result<(), T> {
        match self.try_insert(value) {
            Ok(_) => Ok(()),
            Err((_, value)) => Err(value),
        }
    }

    /// Stores `value` if the cell is empty and returns a reference to it.
    ///
    /// If the cell is already initialized, returns the existing value together
    /// with the rejected `value`.
    #[inline]
    fn try_insert(&self, value: T) -> Result<&T, (&T, T)> {
        if let Some(old) = self.get() {
            return Err((old, value));
        }
        // SAFETY: the slot is empty, so no reference into it has been handed out,
        // and the guest is single-threaded, so this is the only live borrow.
        let slot = unsafe { &mut *self.0.get() };
        Ok(slot.insert(value))
    }

    /// Runs `f` and stores its result in the cell.
    ///
    /// # Panics
    ///
    /// Panics with "reentrant init" if `f` itself initialized this cell. Silently
    /// overwriting the slot in that case would invalidate the `&T` already returned
    /// to the inner caller, which is undefined behavior. The check is unconditional
    /// (not a `debug_assert!`) so release builds are protected too, matching
    /// `once_cell::unsync::OnceCell`.
    #[inline]
    fn try_init<E>(&self, f: impl FnOnce() -> Result<T, E>) -> Result<&T, E> {
        let val = f()?;
        match self.try_insert(val) {
            Ok(val) => Ok(val),
            Err(_) => panic!("reentrant init"),
        }
    }

    /// Moves the value out of the cell, leaving it empty.
    #[inline]
    fn take(&mut self) -> Option<T> {
        self.0.get_mut().take()
    }
}
80
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
impl<T: PartialEq> PartialEq for OnceLock<T> {
    /// Two cells are equal when both are empty, or both hold equal values.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        match (self.get(), other.get()) {
            (Some(lhs), Some(rhs)) => lhs == rhs,
            (None, None) => true,
            _ => false,
        }
    }
}
88
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
impl<T> Default for OnceLock<T> {
    /// Returns an empty cell.
    #[inline]
    fn default() -> Self {
        Self(core::cell::UnsafeCell::new(None))
    }
}
96
// Equality is total whenever it is total for the stored value.
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
impl<T> Eq for OnceLock<T> where T: Eq {}
99
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
impl<T: Clone> Clone for OnceLock<T> {
    /// Clones the cell: an empty cell stays empty, an initialized one gets a
    /// clone of its value.
    #[inline]
    fn clone(&self) -> OnceLock<T> {
        let Some(inner) = self.get() else {
            return OnceLock::new();
        };
        OnceLock::from(inner.clone())
    }
}
110
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
impl<T: std::fmt::Debug> std::fmt::Debug for OnceLock<T> {
    /// Formats as `OnceLock(value)`, or `OnceLock(<uninit>)` for an empty cell.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("OnceLock");
        if let Some(inner) = self.get() {
            tuple.field(inner);
        } else {
            tuple.field(&format_args!("<uninit>"));
        }
        tuple.finish()
    }
}
122
#[cfg(all(feature = "eip-8025", target_arch = "riscv64"))]
impl<T> From<T> for OnceLock<T> {
    /// Creates a cell that is already initialized with `value`.
    #[inline]
    fn from(value: T) -> Self {
        // Tuple-constructor syntax instead of `OnceLock { 0: ... }`
        // (clippy::init_numbered_fields).
        Self(core::cell::UnsafeCell::new(Some(value)))
    }
}
132
133pub use branch::BranchNode;
134use ethrex_rlp::{decode::RLPDecode, encode::RLPEncode};
135pub use extension::ExtensionNode;
136pub use leaf::LeafNode;
137use rkyv::{
138    de::Pooling,
139    rancor::Source,
140    ser::{Allocator, Sharing, Writer},
141    validation::{ArchiveContext, SharedContext},
142    with::Skip,
143};
144
145use ethrex_crypto::{Crypto, NativeCrypto};
146
147use crate::{NodeRLP, TrieDB, error::TrieError, nibbles::Nibbles};
148
149use super::{ValueRLP, node_hash::NodeHash};
150
/// A reference to a node.
///
/// Either the node itself (shared through an `Arc`, together with a lazily
/// memoized hash) or only the node's hash. With a hash, the node is obtained by
/// decoding the hash bytes (inline hashes) or by loading it from the database
/// (see `get_node`).
///
/// Explicit rkyv bounds are needed because this is a recursive type, whose
/// bounds can't be automatically resolved.
#[derive(
    Clone,
    Debug,
    serde::Serialize,
    serde::Deserialize,
    rkyv::Serialize,
    rkyv::Deserialize,
    rkyv::Archive,
)]
#[rkyv(serialize_bounds(__S: Writer + Allocator + Sharing, __S::Error: Source))]
#[rkyv(deserialize_bounds(__D: Pooling, __D::Error: Source))]
#[rkyv(bytecheck(bounds(__C: ArchiveContext + SharedContext)))]
pub enum NodeRef {
    /// The node is embedded within the reference.
    ///
    /// The second field memoizes the node's hash. It is skipped by both rkyv and
    /// serde, so it is not persisted and starts out empty after deserialization;
    /// it is filled on demand by the hashing methods.
    Node(
        #[rkyv(omit_bounds)] Arc<Node>,
        #[rkyv(with = Skip)]
        #[serde(skip)]
        OnceLock<NodeHash>,
    ),
    /// The node is in the database, referenced by its hash.
    Hash(NodeHash),
}
178
179impl NodeRef {
180    /// Gets a shared reference to the inner node.
181    /// Requires that the trie is in a consistent state, ie that all leaves being pointed are in the database.
182    /// Outside of snapsync this should always be the case.
183    pub fn get_node(&self, db: &dyn TrieDB, path: Nibbles) -> Result<Option<Arc<Node>>, TrieError> {
184        match self {
185            NodeRef::Node(node, _) => Ok(Some(node.clone())),
186            NodeRef::Hash(hash @ NodeHash::Inline(_)) => {
187                Ok(Some(Arc::new(Node::decode(hash.as_ref())?)))
188            }
189            NodeRef::Hash(_) => db
190                .get(path)?
191                .filter(|rlp| !rlp.is_empty())
192                .map(|rlp| Ok(Arc::new(Node::decode(&rlp)?)))
193                .transpose(),
194        }
195    }
196
197    /// Gets a shared reference to the inner node, checking its hash.
198    /// Returns `Ok(None)` if the hash is invalid.
199    ///
200    /// Uses `NativeCrypto` directly because this function is only reachable from
201    /// native storage/sync paths (`get_root_node`, `get_proof`, `validate`,
202    /// `verify_range`, trie iterator) — never from the guest program path, which
203    /// traverses via `Node::get()`.
204    pub fn get_node_checked(
205        &self,
206        db: &dyn TrieDB,
207        path: Nibbles,
208    ) -> Result<Option<Arc<Node>>, TrieError> {
209        match self {
210            NodeRef::Node(node, _) => Ok(Some(node.clone())),
211            NodeRef::Hash(hash @ NodeHash::Inline(_)) => {
212                Ok(Some(Arc::new(Node::decode(hash.as_ref())?)))
213            }
214            NodeRef::Hash(hash @ NodeHash::Hashed(_)) => {
215                db.get(path)?
216                    .filter(|rlp| !rlp.is_empty())
217                    .and_then(|rlp| match Node::decode(&rlp) {
218                        Ok(node) => (node.compute_hash(&NativeCrypto) == *hash)
219                            .then_some(Ok(Arc::new(node))),
220                        Err(err) => Some(Err(TrieError::RLPDecode(err))),
221                    })
222                    .transpose()
223            }
224        }
225    }
226
227    /// Gets a mutable shared reference to the inner node.
228    ///
229    /// # Caution
230    ///
231    /// 1. If more than one strong reference exists to this node, it will be cloned (see `Arc::make_mut`).
232    /// 2. Mutating the inner node without updating parents can lead to trie inconsistencies.
233    pub(crate) fn get_node_mut(
234        &mut self,
235        db: &dyn TrieDB,
236        path: Nibbles,
237    ) -> Result<Option<&mut Node>, TrieError> {
238        match self {
239            NodeRef::Node(node, _) => Ok(Some(Arc::make_mut(node))),
240            NodeRef::Hash(hash @ NodeHash::Inline(_)) => {
241                let node = Node::decode(hash.as_ref())?;
242                *self = NodeRef::Node(Arc::new(node), OnceLock::from(*hash));
243                self.get_node_mut(db, path)
244            }
245            NodeRef::Hash(hash @ NodeHash::Hashed(_)) => {
246                let Some(node) = db
247                    .get(path.clone())?
248                    .filter(|rlp| !rlp.is_empty())
249                    .map(|rlp| Node::decode(&rlp).map_err(TrieError::RLPDecode))
250                    .transpose()?
251                else {
252                    return Ok(None);
253                };
254                *self = NodeRef::Node(Arc::new(node), OnceLock::from(*hash));
255                self.get_node_mut(db, path)
256            }
257        }
258    }
259
260    pub fn is_valid(&self) -> bool {
261        match self {
262            NodeRef::Node(_, _) => true,
263            NodeRef::Hash(hash) => hash.is_valid(),
264        }
265    }
266
267    pub fn commit(
268        &mut self,
269        path: Nibbles,
270        acc: &mut Vec<(Nibbles, Vec<u8>)>,
271        crypto: &dyn Crypto,
272    ) -> NodeHash {
273        match *self {
274            NodeRef::Node(ref mut node, ref mut hash) => {
275                if let Some(hash) = hash.get() {
276                    return *hash;
277                }
278                match Arc::make_mut(node) {
279                    Node::Branch(node) => {
280                        for (choice, node) in &mut node.choices.iter_mut().enumerate() {
281                            node.commit(path.append_new(choice as u8), acc, crypto);
282                        }
283                    }
284                    Node::Extension(node) => {
285                        node.child.commit(path.concat(&node.prefix), acc, crypto);
286                    }
287                    Node::Leaf(_) => {}
288                }
289                let mut buf = Vec::new();
290                node.encode(&mut buf);
291                let hash = *hash.get_or_init(|| NodeHash::from_encoded(&buf, crypto));
292                if let Node::Leaf(leaf) = node.as_ref() {
293                    acc.push((path.concat(&leaf.partial), leaf.value.clone()));
294                }
295                acc.push((path, buf));
296
297                hash
298            }
299            NodeRef::Hash(hash) => hash,
300        }
301    }
302
303    pub fn compute_hash(&self, crypto: &dyn Crypto) -> NodeHash {
304        *self.compute_hash_ref(crypto)
305    }
306
307    pub fn compute_hash_ref(&self, crypto: &dyn Crypto) -> &NodeHash {
308        match self {
309            NodeRef::Node(node, hash) => hash.get_or_init(|| node.compute_hash(crypto)),
310            NodeRef::Hash(hash) => hash,
311        }
312    }
313
314    pub fn compute_hash_no_alloc(&self, buf: &mut Vec<u8>, crypto: &dyn Crypto) -> &NodeHash {
315        match self {
316            NodeRef::Node(node, hash) => {
317                hash.get_or_init(|| node.compute_hash_no_alloc(buf, crypto))
318            }
319            NodeRef::Hash(hash) => hash,
320        }
321    }
322
323    pub fn memoize_hashes(&self, buf: &mut Vec<u8>, crypto: &dyn Crypto) {
324        if let NodeRef::Node(node, hash) = &self
325            && hash.get().is_none()
326        {
327            node.memoize_hashes(buf, crypto);
328            let _ = hash.set(node.compute_hash_no_alloc(buf, crypto));
329        }
330    }
331
332    /// Resets the memoized hash of this Node
333    ///
334    /// This is used when mutating a node in place, in which case the memoized hash
335    /// is not valid anymore.
336    pub fn clear_hash(&mut self) {
337        if let NodeRef::Node(_, hash) = self {
338            hash.take();
339        }
340    }
341}
342
343impl Default for NodeRef {
344    fn default() -> Self {
345        Self::Hash(NodeHash::default())
346    }
347}
348
349impl From<Node> for NodeRef {
350    fn from(value: Node) -> Self {
351        Self::Node(Arc::new(value), OnceLock::new())
352    }
353}
354
355impl From<NodeHash> for NodeRef {
356    fn from(value: NodeHash) -> Self {
357        Self::Hash(value)
358    }
359}
360
361impl From<Arc<Node>> for NodeRef {
362    fn from(value: Arc<Node>) -> Self {
363        Self::Node(value, OnceLock::new())
364    }
365}
366
367impl PartialEq for NodeRef {
368    fn eq(&self, other: &Self) -> bool {
369        let mut buf = Vec::new();
370        self.compute_hash_no_alloc(&mut buf, &NativeCrypto)
371            == other.compute_hash_no_alloc(&mut buf, &NativeCrypto)
372    }
373}
374
/// Payload accepted by `Node::insert`: either a raw value or a node hash.
pub enum ValueOrHash {
    /// A raw value, stored as `ValueRLP`.
    Value(ValueRLP),
    /// A node hash.
    Hash(NodeHash),
}
379
380impl From<ValueRLP> for ValueOrHash {
381    fn from(value: ValueRLP) -> Self {
382        Self::Value(value)
383    }
384}
385
386impl From<NodeHash> for ValueOrHash {
387    fn from(value: NodeHash) -> Self {
388        Self::Hash(value)
389    }
390}
391
#[derive(
    Debug,
    Clone,
    PartialEq,
    serde::Serialize,
    serde::Deserialize,
    rkyv::Deserialize,
    rkyv::Serialize,
    rkyv::Archive,
)]
/// A Node in an Ethereum Compatible Patricia Merkle Trie
pub enum Node {
    /// A branch node; its children are the `NodeRef`s in `choices`.
    Branch(Box<BranchNode>),
    /// An extension node: a shared `prefix` followed by a single `child`.
    Extension(ExtensionNode),
    /// A leaf node: the remaining `partial` path and the stored `value`.
    Leaf(LeafNode),
}
408
409impl Default for Node {
410    fn default() -> Self {
411        // empty leaf node as a placeholder
412        Self::Leaf(LeafNode {
413            partial: Nibbles::from_bytes(&[]),
414            value: Vec::new(),
415        })
416    }
417}
418
419impl From<Box<BranchNode>> for Node {
420    fn from(val: Box<BranchNode>) -> Self {
421        Node::Branch(val)
422    }
423}
424
425impl From<BranchNode> for Node {
426    fn from(val: BranchNode) -> Self {
427        Node::Branch(Box::new(val))
428    }
429}
430
431impl From<ExtensionNode> for Node {
432    fn from(val: ExtensionNode) -> Self {
433        Node::Extension(val)
434    }
435}
436
437impl From<LeafNode> for Node {
438    fn from(val: LeafNode) -> Self {
439        Node::Leaf(val)
440    }
441}
442
443impl Node {
444    /// Retrieves a value from the subtrie originating from this node given its path
445    pub fn get(&self, db: &dyn TrieDB, path: Nibbles) -> Result<Option<ValueRLP>, TrieError> {
446        match self {
447            Node::Branch(n) => n.get(db, path),
448            Node::Extension(n) => n.get(db, path),
449            Node::Leaf(n) => n.get(path),
450        }
451    }
452
453    /// Inserts a value into the subtrie originating from this node.
454    pub fn insert(
455        &mut self,
456        db: &dyn TrieDB,
457        path: Nibbles,
458        value: impl Into<ValueOrHash>,
459    ) -> Result<(), TrieError> {
460        let new_node = match self {
461            Node::Branch(n) => {
462                n.insert(db, path, value.into())?;
463                Ok(None)
464            }
465            Node::Extension(n) => n.insert(db, path, value.into()),
466            Node::Leaf(n) => n.insert(path, value.into()),
467        };
468        if let Some(new_node) = new_node? {
469            *self = new_node;
470        }
471        Ok(())
472    }
473
474    /// Removes a value from the subtrie originating from this node given its path
475    /// Returns a bool indicating if the new subtrie is empty, and the removed value if it existed in the subtrie
476    pub fn remove(
477        &mut self,
478        db: &dyn TrieDB,
479        path: Nibbles,
480    ) -> Result<(bool, Option<ValueRLP>), TrieError> {
481        let (new_root, value) = match self {
482            Node::Branch(n) => n.remove(db, path),
483            Node::Extension(n) => n.remove(db, path),
484            Node::Leaf(n) => n.remove(path),
485        }?;
486
487        let is_trie_empty = new_root.is_none();
488        if let Some(NodeRemoveResult::New(new_root)) = new_root {
489            *self = new_root;
490        }
491        Ok((is_trie_empty, value))
492    }
493
494    /// Traverses own subtrie until reaching the node containing `path`
495    /// Appends all encoded nodes traversed to `node_path` (including self)
496    /// Only nodes with encoded len over or equal to 32 bytes are included
497    pub fn get_path(
498        &self,
499        db: &dyn TrieDB,
500        path: Nibbles,
501        node_path: &mut Vec<Vec<u8>>,
502    ) -> Result<(), TrieError> {
503        match self {
504            Node::Branch(n) => n.get_path(db, path, node_path),
505            Node::Extension(n) => n.get_path(db, path, node_path),
506            Node::Leaf(n) => n.get_path(node_path),
507        }
508    }
509
510    /// Computes the node's hash
511    pub fn compute_hash(&self, crypto: &dyn Crypto) -> NodeHash {
512        let mut buf = Vec::new();
513        self.memoize_hashes(&mut buf, crypto);
514        match self {
515            Node::Branch(n) => n.compute_hash_no_alloc(&mut buf, crypto),
516            Node::Extension(n) => n.compute_hash_no_alloc(&mut buf, crypto),
517            Node::Leaf(n) => n.compute_hash_no_alloc(&mut buf, crypto),
518        }
519    }
520
521    /// Computes the node's hash
522    pub fn compute_hash_no_alloc(&self, buf: &mut Vec<u8>, crypto: &dyn Crypto) -> NodeHash {
523        self.memoize_hashes(buf, crypto);
524        match self {
525            Node::Branch(n) => n.compute_hash_no_alloc(buf, crypto),
526            Node::Extension(n) => n.compute_hash_no_alloc(buf, crypto),
527            Node::Leaf(n) => n.compute_hash_no_alloc(buf, crypto),
528        }
529    }
530
531    /// Recursively memoizes the hashes of all nodes of the subtrie that has
532    /// `self` as root (post-order traversal)
533    pub fn memoize_hashes(&self, buf: &mut Vec<u8>, crypto: &dyn Crypto) {
534        match self {
535            Node::Branch(n) => {
536                for child in &n.choices {
537                    child.memoize_hashes(buf, crypto);
538                }
539            }
540            Node::Extension(n) => n.child.memoize_hashes(buf, crypto),
541            _ => {}
542        }
543    }
544
545    /// Recursively encodes all embedded nodes of the subtrie that has
546    /// `self` as root.
547    ///
548    /// This won't encode nodes which are not embedded in `self`.
549    pub fn encode_subtrie(&self, encoded: &mut Vec<NodeRLP>) -> Result<(), TrieError> {
550        match self {
551            Node::Branch(node) => {
552                for choice in &node.choices {
553                    if let NodeRef::Node(choice, _) = choice {
554                        choice.encode_subtrie(encoded)?;
555                    }
556                }
557            }
558            Node::Extension(node) => {
559                if let NodeRef::Node(child, _) = &node.child {
560                    child.encode_subtrie(encoded)?;
561                }
562            }
563            Node::Leaf(_) => {}
564        };
565
566        encoded.push(self.encode_to_vec());
567        Ok(())
568    }
569}
570
/// Used as return type for `Node` remove operations that may resolve into either:
/// - a mutation of the `Node`
/// - a new `Node`
pub enum NodeRemoveResult {
    /// The node was updated in place and is kept by its owner.
    Mutated,
    /// The owner must replace the node with the contained one (see `Node::remove`).
    New(Node),
}