alloy_trie/nodes/
branch.rs

1use super::{super::TrieMask, RlpNode, CHILD_INDEX_RANGE};
2use alloy_primitives::{hex, B256};
3use alloy_rlp::{length_of_length, Buf, BufMut, Decodable, Encodable, Header, EMPTY_STRING_CODE};
4use core::{fmt, ops::Range, slice::Iter};
5
6#[allow(unused_imports)]
7use alloc::vec::Vec;
8
9/// A branch node in an Ethereum Merkle Patricia Trie.
10///
11/// Branch node is a 17-element array consisting of 16 slots that correspond to each hexadecimal
12/// character and an additional slot for a value. We do exclude the node value since all paths have
13/// a fixed size.
14#[derive(PartialEq, Eq, Clone, Default)]
15pub struct BranchNode {
16    /// The collection of RLP encoded children.
17    pub stack: Vec<RlpNode>,
18    /// The bitmask indicating the presence of children at the respective nibble positions
19    pub state_mask: TrieMask,
20}
21
22impl fmt::Debug for BranchNode {
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        f.debug_struct("BranchNode")
25            .field("stack", &self.stack.iter().map(hex::encode).collect::<Vec<_>>())
26            .field("state_mask", &self.state_mask)
27            .field("first_child_index", &self.as_ref().first_child_index())
28            .finish()
29    }
30}
31
32impl Encodable for BranchNode {
33    #[inline]
34    fn encode(&self, out: &mut dyn BufMut) {
35        self.as_ref().encode(out)
36    }
37
38    #[inline]
39    fn length(&self) -> usize {
40        self.as_ref().length()
41    }
42}
43
44impl Decodable for BranchNode {
45    fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
46        let mut bytes = Header::decode_bytes(buf, true)?;
47
48        let mut stack = Vec::new();
49        let mut state_mask = TrieMask::default();
50        for index in CHILD_INDEX_RANGE {
51            // The buffer must contain empty string code for value.
52            if bytes.len() <= 1 {
53                return Err(alloy_rlp::Error::InputTooShort);
54            }
55
56            if bytes[0] == EMPTY_STRING_CODE {
57                bytes.advance(1);
58                continue;
59            }
60
61            // Decode without advancing
62            let Header { payload_length, .. } = Header::decode(&mut &bytes[..])?;
63            let len = payload_length + length_of_length(payload_length);
64            stack.push(RlpNode::from_raw_rlp(&bytes[..len])?);
65            bytes.advance(len);
66            state_mask.set_bit(index);
67        }
68
69        // Consume empty string code for branch node value.
70        let bytes = Header::decode_bytes(&mut bytes, false)?;
71        if !bytes.is_empty() {
72            return Err(alloy_rlp::Error::Custom("branch values not supported"));
73        }
74        debug_assert!(bytes.is_empty(), "bytes {}", alloy_primitives::hex::encode(bytes));
75
76        Ok(Self { stack, state_mask })
77    }
78}
79
80impl BranchNode {
81    /// Creates a new branch node with the given stack and state mask.
82    pub const fn new(stack: Vec<RlpNode>, state_mask: TrieMask) -> Self {
83        Self { stack, state_mask }
84    }
85
86    /// Return branch node as [BranchNodeRef].
87    pub fn as_ref(&self) -> BranchNodeRef<'_> {
88        BranchNodeRef::new(&self.stack, self.state_mask)
89    }
90}
91
92/// A reference to [BranchNode] and its state mask.
93/// NOTE: The stack may contain more items that specified in the state mask.
94#[derive(Clone)]
95pub struct BranchNodeRef<'a> {
96    /// Reference to the collection of RLP encoded nodes.
97    /// NOTE: The referenced stack might have more items than the number of children
98    /// for this node. We should only ever access items starting from
99    /// [BranchNodeRef::first_child_index].
100    pub stack: &'a [RlpNode],
101    /// Reference to bitmask indicating the presence of children at
102    /// the respective nibble positions.
103    pub state_mask: TrieMask,
104}
105
106impl fmt::Debug for BranchNodeRef<'_> {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        f.debug_struct("BranchNodeRef")
109            .field("stack", &self.stack.iter().map(hex::encode).collect::<Vec<_>>())
110            .field("state_mask", &self.state_mask)
111            .field("first_child_index", &self.first_child_index())
112            .finish()
113    }
114}
115
116/// Implementation of RLP encoding for branch node in Ethereum Merkle Patricia Trie.
117/// Encode it as a 17-element list consisting of 16 slots that correspond to
118/// each child of the node (0-f) and an additional slot for a value.
119impl Encodable for BranchNodeRef<'_> {
120    #[inline]
121    fn encode(&self, out: &mut dyn BufMut) {
122        Header { list: true, payload_length: self.rlp_payload_length() }.encode(out);
123
124        // Extend the RLP buffer with the present children
125        for (_, child) in self.children() {
126            if let Some(child) = child {
127                out.put_slice(child);
128            } else {
129                out.put_u8(EMPTY_STRING_CODE);
130            }
131        }
132
133        out.put_u8(EMPTY_STRING_CODE);
134    }
135
136    #[inline]
137    fn length(&self) -> usize {
138        let payload_length = self.rlp_payload_length();
139        payload_length + length_of_length(payload_length)
140    }
141}
142
143impl<'a> BranchNodeRef<'a> {
144    /// Create a new branch node from the stack of nodes.
145    #[inline]
146    pub const fn new(stack: &'a [RlpNode], state_mask: TrieMask) -> Self {
147        Self { stack, state_mask }
148    }
149
150    /// Returns the stack index of the first child for this node.
151    ///
152    /// # Panics
153    ///
154    /// If the stack length is less than number of children specified in state mask.
155    /// Means that the node is in inconsistent state.
156    #[inline]
157    pub fn first_child_index(&self) -> usize {
158        self.stack.len().checked_sub(self.state_mask.count_ones() as usize).unwrap()
159    }
160
161    /// Returns an iterator over children of the branch node.
162    #[inline]
163    pub fn children(&self) -> impl Iterator<Item = (u8, Option<&RlpNode>)> + '_ {
164        BranchChildrenIter::new(self)
165    }
166
167    /// Given the hash mask of children, return an iterator over stack items
168    /// that match the mask.
169    #[inline]
170    pub fn child_hashes(&self, hash_mask: TrieMask) -> impl Iterator<Item = B256> + '_ {
171        self.children()
172            .filter_map(|(i, c)| c.map(|c| (i, c)))
173            .filter(move |(index, _)| hash_mask.is_bit_set(*index))
174            .map(|(_, child)| B256::from_slice(&child[1..]))
175    }
176
177    /// RLP-encodes the node and returns either `rlp(node)` or `rlp(keccak(rlp(node)))`.
178    #[inline]
179    pub fn rlp(&self, rlp: &mut Vec<u8>) -> RlpNode {
180        self.encode(rlp);
181        RlpNode::from_rlp(rlp)
182    }
183
184    /// Returns the length of RLP encoded fields of branch node.
185    #[inline]
186    fn rlp_payload_length(&self) -> usize {
187        let mut payload_length = 1;
188        for (_, child) in self.children() {
189            if let Some(child) = child {
190                payload_length += child.len();
191            } else {
192                payload_length += 1;
193            }
194        }
195        payload_length
196    }
197}
198
199/// Iterator over branch node children.
200#[derive(Debug)]
201struct BranchChildrenIter<'a> {
202    range: Range<u8>,
203    state_mask: TrieMask,
204    stack_iter: Iter<'a, RlpNode>,
205}
206
207impl<'a> BranchChildrenIter<'a> {
208    /// Create new iterator over branch node children.
209    fn new(node: &BranchNodeRef<'a>) -> Self {
210        Self {
211            range: CHILD_INDEX_RANGE,
212            state_mask: node.state_mask,
213            stack_iter: node.stack[node.first_child_index()..].iter(),
214        }
215    }
216}
217
218impl<'a> Iterator for BranchChildrenIter<'a> {
219    type Item = (u8, Option<&'a RlpNode>);
220
221    #[inline]
222    fn next(&mut self) -> Option<Self::Item> {
223        let i = self.range.next()?;
224        let value = if self.state_mask.is_bit_set(i) {
225            // SAFETY: `first_child_index` guarantees that `stack` is exactly
226            // `state_mask.count_ones()` long.
227            Some(unsafe { self.stack_iter.next().unwrap_unchecked() })
228        } else {
229            None
230        };
231        Some((i, value))
232    }
233
234    #[inline]
235    fn size_hint(&self) -> (usize, Option<usize>) {
236        let len = self.len();
237        (len, Some(len))
238    }
239}
240
241impl core::iter::FusedIterator for BranchChildrenIter<'_> {}
242
243impl ExactSizeIterator for BranchChildrenIter<'_> {
244    #[inline]
245    fn len(&self) -> usize {
246        self.range.len()
247    }
248}
249
250/// A struct representing a branch node in an Ethereum trie.
251///
252/// A branch node can have up to 16 children, each corresponding to one of the possible nibble
253/// values (`0` to `f`) in the trie's path.
254///
255/// The masks in a BranchNode are used to efficiently represent and manage information about the
256/// presence and types of its children. They are bitmasks, where each bit corresponds to a nibble
257/// (half-byte, or 4 bits) value from `0` to `f`.
258#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)]
259#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
260pub struct BranchNodeCompact {
261    /// The bitmask indicating the presence of children at the respective nibble positions in the
262    /// trie. If the bit at position i (counting from the right) is set (1), it indicates that a
263    /// child exists for the nibble value i. If the bit is unset (0), it means there is no child
264    /// for that nibble value.
265    pub state_mask: TrieMask,
266    /// The bitmask representing the children at the respective nibble positions in the trie that
267    /// are also stored in the database. If the bit at position `i` (counting from the right)
268    /// is set (1) and also present in the state_mask, it indicates that the corresponding
269    /// child at the nibble value `i` is stored in the database. If the bit is unset (0), it means
270    /// the child is not stored in the database.
271    pub tree_mask: TrieMask,
272    /// The bitmask representing the hashed branch children nodes at the respective nibble
273    /// positions in the trie. If the bit at position `i` (counting from the right) is set (1)
274    /// and also present in the state_mask, it indicates that the corresponding child at the
275    /// nibble value `i` is a hashed branch child node. If the bit is unset (0), it means the child
276    /// is not a hashed branch child node.
277    pub hash_mask: TrieMask,
278    /// Collection of hashes associated with the children of the branch node.
279    /// Each child hash is calculated by hashing two consecutive sub-branch roots.
280    pub hashes: Vec<B256>,
281    /// An optional root hash of the subtree rooted at this branch node.
282    pub root_hash: Option<B256>,
283}
284
285impl BranchNodeCompact {
286    /// Creates a new [BranchNodeCompact] from the given parameters.
287    pub fn new(
288        state_mask: impl Into<TrieMask>,
289        tree_mask: impl Into<TrieMask>,
290        hash_mask: impl Into<TrieMask>,
291        hashes: Vec<B256>,
292        root_hash: Option<B256>,
293    ) -> Self {
294        let (state_mask, tree_mask, hash_mask) =
295            (state_mask.into(), tree_mask.into(), hash_mask.into());
296        assert!(
297            tree_mask.is_subset_of(state_mask),
298            "state mask: {state_mask:?} tree mask: {tree_mask:?}"
299        );
300        assert!(
301            hash_mask.is_subset_of(state_mask),
302            "state_mask {state_mask:?} hash_mask: {hash_mask:?}"
303        );
304        assert_eq!(hash_mask.count_ones() as usize, hashes.len());
305        Self { state_mask, tree_mask, hash_mask, hashes, root_hash }
306    }
307
308    /// Returns the hash associated with the given nibble.
309    pub fn hash_for_nibble(&self, nibble: u8) -> B256 {
310        let mask = *TrieMask::from_nibble(nibble) - 1;
311        let index = (*self.hash_mask & mask).count_ones();
312        self.hashes[index as usize]
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nodes::{ExtensionNode, LeafNode};
    use nybbles::Nibbles;

    /// Encodes `node`, decodes the bytes again and checks that the result equals the input.
    fn assert_roundtrip(node: BranchNode) {
        let encoded = alloy_rlp::encode(&node);
        assert_eq!(BranchNode::decode(&mut &encoded[..]).unwrap(), node);
    }

    #[test]
    fn rlp_branch_node_roundtrip() {
        // No children at all.
        assert_roundtrip(BranchNode::default());

        // Two hashed children, at nibbles 2 and 6.
        let hashed_children =
            vec![RlpNode::word_rlp(&B256::repeat_byte(1)), RlpNode::word_rlp(&B256::repeat_byte(2))];
        assert_roundtrip(BranchNode::new(hashed_children, TrieMask::new(0b1000100)));

        // A single leaf child.
        let leaf = LeafNode::new(Nibbles::from_nibbles(hex!("0203")), hex!("1234").to_vec());
        let mut leaf_buf = vec![];
        let leaf_rlp = leaf.as_ref().rlp(&mut leaf_buf);
        assert_roundtrip(BranchNode::new(vec![leaf_rlp.clone()], TrieMask::new(0b0010)));

        // A single extension child that wraps the leaf.
        let extension = ExtensionNode::new(Nibbles::from_nibbles(hex!("0203")), leaf_rlp);
        let mut extension_buf = vec![];
        let extension_rlp = extension.as_ref().rlp(&mut extension_buf);
        assert_roundtrip(BranchNode::new(vec![extension_rlp], TrieMask::new(0b00000100000)));

        // All 16 children present.
        let all_children = vec![RlpNode::word_rlp(&B256::repeat_byte(23)); 16];
        assert_roundtrip(BranchNode::new(all_children, TrieMask::new(u16::MAX)));
    }
}