1use super::{super::TrieMask, RlpNode, CHILD_INDEX_RANGE};
2use alloy_primitives::{hex, B256};
3use alloy_rlp::{length_of_length, Buf, BufMut, Decodable, Encodable, Header, EMPTY_STRING_CODE};
4use core::{fmt, ops::Range, slice::Iter};
5
6#[allow(unused_imports)]
7use alloc::vec::Vec;
8
/// An owned branch node.
///
/// Its RLP encoding is a list with one item per index in `CHILD_INDEX_RANGE`, followed by a
/// branch value that is always the empty string (values are not supported by the decoder).
#[derive(PartialEq, Eq, Clone, Default)]
pub struct BranchNode {
    /// Raw RLP encodings of the present children, in ascending child-index order.
    ///
    /// The children are the last `state_mask.count_ones()` entries; when viewed through
    /// `as_ref`, any entries before those are ignored.
    pub stack: Vec<RlpNode>,
    /// Bit `i` is set when the child slot at index `i` is occupied.
    pub state_mask: TrieMask,
}
21
22impl fmt::Debug for BranchNode {
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 f.debug_struct("BranchNode")
25 .field("stack", &self.stack.iter().map(hex::encode).collect::<Vec<_>>())
26 .field("state_mask", &self.state_mask)
27 .field("first_child_index", &self.as_ref().first_child_index())
28 .finish()
29 }
30}
31
32impl Encodable for BranchNode {
33 #[inline]
34 fn encode(&self, out: &mut dyn BufMut) {
35 self.as_ref().encode(out)
36 }
37
38 #[inline]
39 fn length(&self) -> usize {
40 self.as_ref().length()
41 }
42}
43
44impl Decodable for BranchNode {
45 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
46 let mut bytes = Header::decode_bytes(buf, true)?;
47
48 let mut stack = Vec::new();
49 let mut state_mask = TrieMask::default();
50 for index in CHILD_INDEX_RANGE {
51 if bytes.len() <= 1 {
53 return Err(alloy_rlp::Error::InputTooShort);
54 }
55
56 if bytes[0] == EMPTY_STRING_CODE {
57 bytes.advance(1);
58 continue;
59 }
60
61 let Header { payload_length, .. } = Header::decode(&mut &bytes[..])?;
63 let len = payload_length + length_of_length(payload_length);
64 stack.push(RlpNode::from_raw_rlp(&bytes[..len])?);
65 bytes.advance(len);
66 state_mask.set_bit(index);
67 }
68
69 let bytes = Header::decode_bytes(&mut bytes, false)?;
71 if !bytes.is_empty() {
72 return Err(alloy_rlp::Error::Custom("branch values not supported"));
73 }
74 debug_assert!(bytes.is_empty(), "bytes {}", alloy_primitives::hex::encode(bytes));
75
76 Ok(Self { stack, state_mask })
77 }
78}
79
80impl BranchNode {
81 pub const fn new(stack: Vec<RlpNode>, state_mask: TrieMask) -> Self {
83 Self { stack, state_mask }
84 }
85
86 pub fn as_ref(&self) -> BranchNodeRef<'_> {
88 BranchNodeRef::new(&self.stack, self.state_mask)
89 }
90}
91
/// A borrowed view of a branch node.
///
/// `stack` may contain more entries than this node has children. Only the last
/// `state_mask.count_ones()` entries belong to this node (see `first_child_index`).
#[derive(Clone)]
pub struct BranchNodeRef<'a> {
    /// RLP-encoded nodes. Only the trailing entries that belong to this node are read.
    pub stack: &'a [RlpNode],
    /// Bit `i` is set when the child slot at index `i` is occupied.
    pub state_mask: TrieMask,
}
105
106impl fmt::Debug for BranchNodeRef<'_> {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 f.debug_struct("BranchNodeRef")
109 .field("stack", &self.stack.iter().map(hex::encode).collect::<Vec<_>>())
110 .field("state_mask", &self.state_mask)
111 .field("first_child_index", &self.first_child_index())
112 .finish()
113 }
114}
115
116impl Encodable for BranchNodeRef<'_> {
120 #[inline]
121 fn encode(&self, out: &mut dyn BufMut) {
122 Header { list: true, payload_length: self.rlp_payload_length() }.encode(out);
123
124 for (_, child) in self.children() {
126 if let Some(child) = child {
127 out.put_slice(child);
128 } else {
129 out.put_u8(EMPTY_STRING_CODE);
130 }
131 }
132
133 out.put_u8(EMPTY_STRING_CODE);
134 }
135
136 #[inline]
137 fn length(&self) -> usize {
138 let payload_length = self.rlp_payload_length();
139 payload_length + length_of_length(payload_length)
140 }
141}
142
143impl<'a> BranchNodeRef<'a> {
144 #[inline]
146 pub const fn new(stack: &'a [RlpNode], state_mask: TrieMask) -> Self {
147 Self { stack, state_mask }
148 }
149
150 #[inline]
157 pub fn first_child_index(&self) -> usize {
158 self.stack.len().checked_sub(self.state_mask.count_ones() as usize).unwrap()
159 }
160
161 #[inline]
163 pub fn children(&self) -> impl Iterator<Item = (u8, Option<&RlpNode>)> + '_ {
164 BranchChildrenIter::new(self)
165 }
166
167 #[inline]
170 pub fn child_hashes(&self, hash_mask: TrieMask) -> impl Iterator<Item = B256> + '_ {
171 self.children()
172 .filter_map(|(i, c)| c.map(|c| (i, c)))
173 .filter(move |(index, _)| hash_mask.is_bit_set(*index))
174 .map(|(_, child)| B256::from_slice(&child[1..]))
175 }
176
177 #[inline]
179 pub fn rlp(&self, rlp: &mut Vec<u8>) -> RlpNode {
180 self.encode(rlp);
181 RlpNode::from_rlp(rlp)
182 }
183
184 #[inline]
186 fn rlp_payload_length(&self) -> usize {
187 let mut payload_length = 1;
188 for (_, child) in self.children() {
189 if let Some(child) = child {
190 payload_length += child.len();
191 } else {
192 payload_length += 1;
193 }
194 }
195 payload_length
196 }
197}
198
/// Iterator over the child slots of a [`BranchNodeRef`].
///
/// It yields `(index, child)` for every index in `CHILD_INDEX_RANGE`.
#[derive(Debug)]
struct BranchChildrenIter<'a> {
    /// Child indices not yet visited.
    range: Range<u8>,
    /// Bit `i` is set when index `i` takes its child from `stack_iter`.
    state_mask: TrieMask,
    /// Remaining child encodings; one is consumed per set bit of `state_mask`.
    stack_iter: Iter<'a, RlpNode>,
}
206
207impl<'a> BranchChildrenIter<'a> {
208 fn new(node: &BranchNodeRef<'a>) -> Self {
210 Self {
211 range: CHILD_INDEX_RANGE,
212 state_mask: node.state_mask,
213 stack_iter: node.stack[node.first_child_index()..].iter(),
214 }
215 }
216}
217
218impl<'a> Iterator for BranchChildrenIter<'a> {
219 type Item = (u8, Option<&'a RlpNode>);
220
221 #[inline]
222 fn next(&mut self) -> Option<Self::Item> {
223 let i = self.range.next()?;
224 let value = if self.state_mask.is_bit_set(i) {
225 Some(unsafe { self.stack_iter.next().unwrap_unchecked() })
228 } else {
229 None
230 };
231 Some((i, value))
232 }
233
234 #[inline]
235 fn size_hint(&self) -> (usize, Option<usize>) {
236 let len = self.len();
237 (len, Some(len))
238 }
239}
240
241impl core::iter::FusedIterator for BranchChildrenIter<'_> {}
242
243impl ExactSizeIterator for BranchChildrenIter<'_> {
244 #[inline]
245 fn len(&self) -> usize {
246 self.range.len()
247 }
248}
249
/// A compact branch node: child masks plus the hashes of selected children.
///
/// `BranchNodeCompact::new` checks these invariants; direct field construction does not:
/// - `tree_mask` and `hash_mask` are subsets of `state_mask`;
/// - `hashes` holds one entry per bit set in `hash_mask`.
///
/// NOTE: field declaration order determines the derived `PartialOrd`/`Ord` and the serde
/// field order. Do not reorder the fields.
#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BranchNodeCompact {
    /// Mask of child slots. The other two masks must be subsets of it.
    /// NOTE(review): presumably marks which children exist; confirm against callers.
    pub state_mask: TrieMask,
    /// Subset of `state_mask`.
    /// NOTE(review): presumably marks children whose subtrees are stored separately;
    /// confirm against callers.
    pub tree_mask: TrieMask,
    /// Subset of `state_mask` marking the children whose hash is kept in `hashes`.
    pub hash_mask: TrieMask,
    /// Child hashes, one per bit set in `hash_mask`, in ascending nibble order
    /// (as indexed by `hash_for_nibble`).
    pub hashes: Vec<B256>,
    /// Optional root hash.
    /// NOTE(review): presumably the hash of this node itself; confirm against callers.
    pub root_hash: Option<B256>,
}
284
285impl BranchNodeCompact {
286 pub fn new(
288 state_mask: impl Into<TrieMask>,
289 tree_mask: impl Into<TrieMask>,
290 hash_mask: impl Into<TrieMask>,
291 hashes: Vec<B256>,
292 root_hash: Option<B256>,
293 ) -> Self {
294 let (state_mask, tree_mask, hash_mask) =
295 (state_mask.into(), tree_mask.into(), hash_mask.into());
296 assert!(
297 tree_mask.is_subset_of(state_mask),
298 "state mask: {state_mask:?} tree mask: {tree_mask:?}"
299 );
300 assert!(
301 hash_mask.is_subset_of(state_mask),
302 "state_mask {state_mask:?} hash_mask: {hash_mask:?}"
303 );
304 assert_eq!(hash_mask.count_ones() as usize, hashes.len());
305 Self { state_mask, tree_mask, hash_mask, hashes, root_hash }
306 }
307
308 pub fn hash_for_nibble(&self, nibble: u8) -> B256 {
310 let mask = *TrieMask::from_nibble(nibble) - 1;
311 let index = (*self.hash_mask & mask).count_ones();
312 self.hashes[index as usize]
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nodes::{ExtensionNode, LeafNode};
    use nybbles::Nibbles;

    /// Encodes `node` with `alloy_rlp` and asserts that decoding the bytes yields `node`.
    fn assert_roundtrip(node: &BranchNode) {
        let encoded = alloy_rlp::encode(node);
        let decoded = BranchNode::decode(&mut &encoded[..]).unwrap();
        assert_eq!(&decoded, node);
    }

    #[test]
    fn rlp_branch_node_roundtrip() {
        // No children at all.
        assert_roundtrip(&BranchNode::default());

        // Two hash children, at indices 2 and 6.
        let hashes = vec![
            RlpNode::word_rlp(&B256::repeat_byte(1)),
            RlpNode::word_rlp(&B256::repeat_byte(2)),
        ];
        assert_roundtrip(&BranchNode::new(hashes, TrieMask::new(0b1000100)));

        // A single embedded leaf child at index 1.
        let leaf = LeafNode::new(Nibbles::from_nibbles(hex!("0203")), hex!("1234").to_vec());
        let leaf_rlp = leaf.as_ref().rlp(&mut Vec::new());
        assert_roundtrip(&BranchNode::new(vec![leaf_rlp.clone()], TrieMask::new(0b0010)));

        // A single embedded extension child at index 5.
        let extension = ExtensionNode::new(Nibbles::from_nibbles(hex!("0203")), leaf_rlp);
        let extension_rlp = extension.as_ref().rlp(&mut Vec::new());
        assert_roundtrip(&BranchNode::new(vec![extension_rlp], TrieMask::new(0b00000100000)));

        // Every child slot occupied.
        let all_children = core::iter::repeat(RlpNode::word_rlp(&B256::repeat_byte(23)))
            .take(16)
            .collect();
        assert_roundtrip(&BranchNode::new(all_children, TrieMask::new(u16::MAX)));
    }
}