Skip to main content

ethrex_trie/
rlp.rs

1use std::array;
2
3// Contains RLP encoding and decoding implementations for Trie Nodes
4// This encoding is only used to store the nodes in the DB, it is not the encoding used for hash computation
5use ethrex_rlp::{
6    constants::RLP_NULL,
7    decode::{RLPDecode, decode_bytes},
8    encode::{RLPEncode, encode_length},
9    error::RLPDecodeError,
10    structs::{Decoder, Encoder},
11};
12
13use ethrex_crypto::NativeCrypto;
14
15use super::node::{BranchNode, ExtensionNode, LeafNode, Node};
16use crate::{Nibbles, NodeHash};
17
18// SAFETY: `NativeCrypto` is used here instead of a `&dyn Crypto` parameter because
19// `RLPEncode` is a fixed trait signature that cannot accept extra parameters.
20// This is safe in the `commit()` path: `NodeRef::commit()` recursively populates
21// child `OnceLock` hashes before calling `encode()`, so `compute_hash_ref` returns
22// cached values without invoking keccak. If `encode()` were called on uncommitted
23// nodes (e.g. from `put_batch_no_alloc`), `NativeCrypto` would be used and the
24// result stored in the `OnceLock` — but this only happens in native storage paths
25// where `NativeCrypto` is the correct provider.
26impl RLPEncode for BranchNode {
27    fn encode(&self, buf: &mut dyn bytes::BufMut) {
28        let value_len = <[u8] as RLPEncode>::length(&self.value);
29        let payload_len = self.choices.iter().fold(value_len, |acc, child| {
30            acc + RLPEncode::length(child.compute_hash_ref(&NativeCrypto))
31        });
32
33        encode_length(payload_len, buf);
34        for child in self.choices.iter() {
35            match child.compute_hash_ref(&NativeCrypto) {
36                NodeHash::Hashed(hash) => hash.0.encode(buf),
37                NodeHash::Inline((_, 0)) => buf.put_u8(RLP_NULL),
38                NodeHash::Inline((encoded, len)) => buf.put_slice(&encoded[..*len as usize]),
39            }
40        }
41        <[u8] as RLPEncode>::encode(&self.value, buf);
42    }
43
44    // Duplicated to prealloc the buffer and avoid calculating the payload length twice
45    fn encode_to_vec(&self) -> Vec<u8> {
46        let value_len = <[u8] as RLPEncode>::length(&self.value);
47        let choices_len = self.choices.iter().fold(0, |acc, child| {
48            acc + RLPEncode::length(child.compute_hash_ref(&NativeCrypto))
49        });
50        let payload_len = choices_len + value_len;
51
52        let mut buf: Vec<u8> = Vec::with_capacity(payload_len + 3); // 3 byte prefix headroom
53
54        encode_length(payload_len, &mut buf);
55        for child in self.choices.iter() {
56            match child.compute_hash_ref(&NativeCrypto) {
57                NodeHash::Hashed(hash) => hash.0.encode(&mut buf),
58                NodeHash::Inline((_, 0)) => buf.push(RLP_NULL),
59                NodeHash::Inline((encoded, len)) => {
60                    buf.extend_from_slice(&encoded[..*len as usize])
61                }
62            }
63        }
64        <[u8] as RLPEncode>::encode(&self.value, &mut buf);
65
66        buf
67    }
68}
69
70impl RLPEncode for ExtensionNode {
71    fn encode(&self, buf: &mut dyn bytes::BufMut) {
72        let mut encoder = Encoder::new(buf).encode_bytes(&self.prefix.encode_compact());
73        encoder = self.child.compute_hash(&NativeCrypto).encode(encoder);
74        encoder.finish();
75    }
76}
77
78impl RLPEncode for LeafNode {
79    fn encode(&self, buf: &mut dyn bytes::BufMut) {
80        Encoder::new(buf)
81            .encode_bytes(&self.partial.encode_compact())
82            .encode_bytes(&self.value)
83            .finish()
84    }
85}
86
87impl RLPEncode for Node {
88    fn encode(&self, buf: &mut dyn bytes::BufMut) {
89        match self {
90            Node::Branch(n) => n.encode(buf),
91            Node::Extension(n) => n.encode(buf),
92            Node::Leaf(n) => n.encode(buf),
93        }
94    }
95}
96
97impl RLPDecode for Node {
98    fn decode_unfinished(rlp: &[u8]) -> Result<(Self, &[u8]), RLPDecodeError> {
99        let mut rlp_items_len = 0;
100        let mut rlp_items: [Option<&[u8]>; 17] = Default::default();
101        let mut decoder = Decoder::new(rlp)?;
102        let mut item;
103        // Get encoded fields
104
105        // Check if we reached the end or if we decoded more items than the ones we need
106        while !decoder.is_done() && rlp_items_len < 17 {
107            (item, decoder) = decoder.get_encoded_item_ref()?;
108            rlp_items[rlp_items_len] = Some(item);
109            rlp_items_len += 1;
110        }
111        if !decoder.is_done() {
112            return Err(RLPDecodeError::Custom(
113                "Invalid arg count for Node, expected 2 or 17, got more than 17".to_string(),
114            ));
115        }
116        // Deserialize into node depending on the available fields
117        let node = match rlp_items_len {
118            // Leaf or Extension Node
119            2 => {
120                let (path, _) = decode_bytes(rlp_items[0].expect("we already checked the length"))?;
121                let path = Nibbles::decode_compact(path);
122                if path.is_leaf() {
123                    // Decode as Leaf
124                    let (value, _) =
125                        decode_bytes(rlp_items[1].expect("we already checked the length"))?;
126                    LeafNode {
127                        partial: path,
128                        value: value.to_vec(),
129                    }
130                    .into()
131                } else {
132                    // Decode as Extension
133                    ExtensionNode {
134                        prefix: path,
135                        child: decode_child(rlp_items[1].expect("we already checked the length"))
136                            .into(),
137                    }
138                    .into()
139                }
140            }
141            // Branch Node
142            17 => {
143                let choices = array::from_fn(|i| {
144                    decode_child(rlp_items[i].expect("we already checked the length")).into()
145                });
146                let (value, _) =
147                    decode_bytes(rlp_items[16].expect("we already checked the length"))?;
148                BranchNode {
149                    choices,
150                    value: value.to_vec(),
151                }
152                .into()
153            }
154            n => {
155                return Err(RLPDecodeError::Custom(format!(
156                    "Invalid arg count for Node, expected 2 or 17, got {n}"
157                )));
158            }
159        };
160        Ok((node, decoder.finish()?))
161    }
162}
163
164fn decode_child(rlp: &[u8]) -> NodeHash {
165    match decode_bytes(rlp) {
166        Ok((hash, &[])) if hash.len() == 32 => NodeHash::from_slice(hash),
167        Ok((&[], &[])) => NodeHash::default(),
168        _ => NodeHash::from_slice(rlp),
169    }
170}