Skip to main content

ethrex_trie/
nibbles.rs

1use std::{cmp, mem};
2
3use ethrex_rlp::{
4    decode::RLPDecode,
5    encode::RLPEncode,
6    error::RLPDecodeError,
7    structs::{Decoder, Encoder},
8};
9
10// ── SIMD nibble expansion ────────────────────────────────────────────────────
11//
// The hot path during block execution converts 32-byte keccak keys to 64
// nibbles on every trie lookup / insert.  We replace the original flat_map
// iterator chain with SIMD kernels that process 16 or 32 input bytes per
// loop iteration, with a scalar fallback for other architectures.
15
16/// Expands each byte in `bytes` into two nibbles (high nibble first),
17/// writing `bytes.len() * 2` bytes to the uninitialized `output` pointer.
18///
19/// # Safety
20/// `output` must be valid for writes of at least `bytes.len() * 2` bytes.
21#[inline]
22#[allow(unsafe_code)]
23unsafe fn expand_bytes_to_nibbles(bytes: &[u8], output: *mut u8) {
24    #[cfg(target_arch = "x86_64")]
25    {
26        // SAFETY: caller guarantees output is writable for bytes.len() * 2 bytes.
27        unsafe { expand_bytes_to_nibbles_x86_64(bytes, output) };
28        return;
29    }
30    #[cfg(target_arch = "aarch64")]
31    {
32        // SAFETY: caller guarantees output is writable for bytes.len() * 2 bytes.
33        unsafe { expand_bytes_to_nibbles_aarch64(bytes, output) };
34        return;
35    }
36    // Portable scalar fallback for other architectures.
37    #[allow(unreachable_code)]
38    // SAFETY: caller guarantees output is writable for bytes.len() * 2 bytes.
39    unsafe {
40        expand_bytes_to_nibbles_scalar(bytes, output)
41    };
42}
43
/// x86_64 kernel for nibble expansion: for every byte of `bytes`, writes the
/// high nibble followed by the low nibble, `bytes.len() * 2` bytes in total.
///
/// The input is consumed in up to three stages: 32-byte AVX2 blocks (compiled
/// only when the build enables `avx2`), 16-byte SSE2 blocks, and finally one
/// byte at a time.
///
/// # Safety
/// `output` must be valid for writes of at least `bytes.len() * 2` bytes.
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_code)]
#[inline]
unsafe fn expand_bytes_to_nibbles_x86_64(bytes: &[u8], output: *mut u8) {
    use std::arch::x86_64::*;

    let src = bytes.as_ptr();
    let total = bytes.len();
    // Number of input bytes already expanded; never exceeds `total`.
    let mut done = 0usize;

    // --- AVX2 stage: 32 input bytes -> 64 nibbles per iteration ---
    #[cfg(target_feature = "avx2")]
    // SAFETY: this block only exists when AVX2 is enabled at compile time.
    // Each iteration reads 32 bytes at `done` (guarded by the loop condition)
    // and writes the 64 bytes at `done * 2`, inside the caller-guaranteed span.
    unsafe {
        let nibble_mask = _mm256_set1_epi8(0x0F);
        while total - done >= 32 {
            let chunk = _mm256_loadu_si256(src.add(done).cast::<__m256i>());
            // The shift works on 16-bit words, so bits from the neighbouring
            // byte leak in; the mask discards them.
            let high = _mm256_and_si256(_mm256_srli_epi16(chunk, 4), nibble_mask);
            let low = _mm256_and_si256(chunk, nibble_mask);
            // The unpack instructions interleave within each 128-bit lane:
            //   lane_low  = [bytes 0-7  | bytes 16-23] as hi/lo pairs
            //   lane_high = [bytes 8-15 | bytes 24-31] as hi/lo pairs
            let lane_low = _mm256_unpacklo_epi8(high, low);
            let lane_high = _mm256_unpackhi_epi8(high, low);
            // Re-pair the lanes so the output is in input order.
            let first_half = _mm256_permute2x128_si256::<0x20>(lane_low, lane_high);
            let second_half = _mm256_permute2x128_si256::<0x31>(lane_low, lane_high);
            let dst = output.add(done * 2);
            _mm256_storeu_si256(dst.cast::<__m256i>(), first_half);
            _mm256_storeu_si256(dst.add(32).cast::<__m256i>(), second_half);
            done += 32;
        }
    }

    // --- SSE2 stage: 16 input bytes -> 32 nibbles per iteration ---
    // SAFETY: SSE2 belongs to the x86_64 baseline. Reads stay below `total`
    // thanks to the loop conditions; writes stay below `total * 2`, which the
    // caller guarantees is writable.
    unsafe {
        let nibble_mask = _mm_set1_epi8(0x0F);
        while total - done >= 16 {
            let chunk = _mm_loadu_si128(src.add(done).cast::<__m128i>());
            let high = _mm_and_si128(_mm_srli_epi16(chunk, 4), nibble_mask);
            let low = _mm_and_si128(chunk, nibble_mask);
            let dst = output.add(done * 2);
            _mm_storeu_si128(dst.cast::<__m128i>(), _mm_unpacklo_epi8(high, low));
            _mm_storeu_si128(dst.add(16).cast::<__m128i>(), _mm_unpackhi_epi8(high, low));
            done += 16;
        }

        // Leftover bytes (at most 15) are expanded one at a time.
        for idx in done..total {
            let byte = *src.add(idx);
            let dst = output.add(idx * 2);
            dst.write(byte >> 4);
            dst.add(1).write(byte & 0x0F);
        }
    }
}
108
/// aarch64 (NEON) kernel for nibble expansion: for every byte of `bytes`,
/// writes the high nibble followed by the low nibble, `bytes.len() * 2` bytes
/// in total. Full 16-byte blocks go through NEON; the rest is handled one
/// byte at a time.
///
/// # Safety
/// `output` must be valid for writes of at least `bytes.len() * 2` bytes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code)]
#[inline]
unsafe fn expand_bytes_to_nibbles_aarch64(bytes: &[u8], output: *mut u8) {
    use std::arch::aarch64::*;

    let src = bytes.as_ptr();
    let total = bytes.len();
    // Number of input bytes already expanded; never exceeds `total`.
    let mut done = 0usize;

    // SAFETY: NEON is enabled for this function. Reads stay below `total`
    // thanks to the loop conditions; writes stay below `total * 2`, which the
    // caller guarantees is writable.
    unsafe {
        let low_mask = vdupq_n_u8(0x0F);
        while total - done >= 16 {
            let chunk = vld1q_u8(src.add(done));
            // A per-byte shift yields the high nibbles directly; no mask needed.
            let high = vshrq_n_u8(chunk, 4);
            let low = vandq_u8(chunk, low_mask);
            let dst = output.add(done * 2);
            // zip1 interleaves input bytes 0-7, zip2 input bytes 8-15.
            vst1q_u8(dst, vzip1q_u8(high, low));
            vst1q_u8(dst.add(16), vzip2q_u8(high, low));
            done += 16;
        }

        // Leftover bytes (at most 15) are expanded one at a time.
        for idx in done..total {
            let byte = *src.add(idx);
            let dst = output.add(idx * 2);
            dst.write(byte >> 4);
            dst.add(1).write(byte & 0x0F);
        }
    }
}
145
/// Portable nibble expansion: for every byte of `bytes`, writes the high
/// nibble followed by the low nibble, `bytes.len() * 2` bytes in total.
///
/// # Safety
/// `output` must be valid for writes of at least `bytes.len() * 2` bytes.
#[allow(unsafe_code)]
#[inline]
unsafe fn expand_bytes_to_nibbles_scalar(bytes: &[u8], output: *mut u8) {
    // SAFETY: the cursor advances by two per input byte, so every write lands
    // inside the `bytes.len() * 2` bytes the caller guarantees; after the last
    // byte it rests one past the end and is never dereferenced.
    unsafe {
        let mut cursor = output;
        for &byte in bytes {
            cursor.write(byte >> 4);
            cursor.add(1).write(byte & 0x0F);
            cursor = cursor.add(2);
        }
    }
}
157
158// ── SIMD nibble packing ──────────────────────────────────────────────────────
159//
160// pack_nibble_pairs combines pairs of nibbles [hi, lo, hi, lo, …] into bytes
161// [(hi<<4)|lo, …].  This is the hot path inside encode_compact, called on every
162// trie node when computing the Merkle root.
163//
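// Strategy (x86_64): SSSE3 _mm_maddubs_epi16(a, b) computes, per 16-bit lane,
//   result[i] = a[2i]*b[2i] + a[2i+1]*b[2i+1]  (a unsigned bytes, b signed bytes).
// With b = [16, 1, 16, 1, …] this is hi*16 + lo, i.e. the packed byte for each
// nibble pair.  The path is compiled in only when the build enables SSSE3
// (as x86-64-v3 does); otherwise the scalar loop handles everything.
// aarch64 uses a NEON de-interleaving load instead.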
168
169/// Packs pairs of nibbles in `nibbles` into bytes, writing to `output`.
170/// `nibbles.len()` must be even.
171///
172/// # Safety
173/// `output` must be writable for `nibbles.len() / 2` bytes.
174#[inline]
175#[allow(unsafe_code)]
176unsafe fn pack_nibble_pairs(nibbles: &[u8], output: *mut u8) {
177    debug_assert!(nibbles.len().is_multiple_of(2));
178    #[cfg(target_arch = "x86_64")]
179    {
180        unsafe { pack_nibble_pairs_x86_64(nibbles, output) };
181        return;
182    }
183    #[cfg(target_arch = "aarch64")]
184    {
185        unsafe { pack_nibble_pairs_aarch64(nibbles, output) };
186        return;
187    }
188    #[allow(unreachable_code)]
189    unsafe {
190        pack_nibble_pairs_scalar(nibbles, output)
191    };
192}
193
/// x86_64 kernel for nibble packing: turns each `[hi, lo]` pair of `nibbles`
/// into one byte at `output`.
///
/// When the build enables SSSE3, blocks of 32 nibbles are packed with
/// `_mm_maddubs_epi16` + `_mm_packus_epi16`; whatever is left (or everything,
/// without SSSE3) goes through a scalar shift-and-or loop. The two paths agree
/// for nibble values below 16. For larger values the SIMD path adds and
/// saturates while the scalar path shifts and ORs, so callers must only pass
/// genuine nibbles.
///
/// # Safety
/// `output` must be writable for `nibbles.len() / 2` bytes.
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_code)]
#[inline]
unsafe fn pack_nibble_pairs_x86_64(nibbles: &[u8], output: *mut u8) {
    let src = nibbles.as_ptr();
    let total = nibbles.len(); // expected to be even
    let mut read = 0usize; // nibbles consumed so far
    let mut written = 0usize; // bytes produced so far (always `read / 2`)

    // SSSE3 stage: 32 nibbles -> 16 output bytes per iteration.
    #[cfg(target_feature = "ssse3")]
    // SAFETY: this block only exists when SSSE3 is enabled at compile time.
    // Each iteration reads 32 bytes at `read` (guarded by the loop condition)
    // and writes 16 bytes at `written`, inside the caller-guaranteed span.
    unsafe {
        use std::arch::x86_64::*;
        // Every 16-bit lane holds the byte pair [16, 1]: the even (high)
        // nibble is multiplied by 16, the odd (low) nibble by 1.
        let weights = _mm_set1_epi16(0x0110);
        while total - read >= 32 {
            let first = _mm_loadu_si128(src.add(read).cast::<__m128i>());
            let second = _mm_loadu_si128(src.add(read + 16).cast::<__m128i>());
            // maddubs yields hi*16 + lo per 16-bit lane; packus then narrows
            // both vectors (with unsigned saturation) into 16 bytes.
            let packed = _mm_packus_epi16(
                _mm_maddubs_epi16(first, weights),
                _mm_maddubs_epi16(second, weights),
            );
            _mm_storeu_si128(output.add(written).cast::<__m128i>(), packed);
            read += 32;
            written += 16;
        }
    }

    // Scalar stage: the pairs the SIMD loop did not take.
    // SAFETY: `read + 1 < total` inside the loop, and `written == read / 2`
    // stays below `total / 2`, which the caller guarantees is writable.
    unsafe {
        while total - read >= 2 {
            let hi = *src.add(read);
            let lo = *src.add(read + 1);
            output.add(written).write((hi << 4) | lo);
            read += 2;
            written += 1;
        }
    }
}
236
/// aarch64 (NEON) kernel for nibble packing: turns each `[hi, lo]` pair of
/// `nibbles` into the byte `(hi << 4) | lo` at `output`. Blocks of 32 nibbles
/// go through NEON; the remaining pairs are packed one at a time.
///
/// # Safety
/// `output` must be writable for `nibbles.len() / 2` bytes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code)]
#[inline]
unsafe fn pack_nibble_pairs_aarch64(nibbles: &[u8], output: *mut u8) {
    use std::arch::aarch64::*;

    let src = nibbles.as_ptr();
    let total = nibbles.len();
    let mut read = 0usize; // nibbles consumed so far
    let mut written = 0usize; // bytes produced so far (always `read / 2`)

    // SAFETY: NEON is enabled for this function. Reads stay below `total`
    // thanks to the loop conditions; writes stay below `total / 2`, which the
    // caller guarantees is writable.
    unsafe {
        while total - read >= 32 {
            // The de-interleaving load separates even-indexed bytes (high
            // nibbles, `.0`) from odd-indexed bytes (low nibbles, `.1`).
            let split = vld2q_u8(src.add(read));
            let merged = vorrq_u8(vshlq_n_u8(split.0, 4), split.1);
            vst1q_u8(output.add(written), merged);
            read += 32;
            written += 16;
        }

        // Remaining pairs, one at a time.
        while total - read >= 2 {
            let hi = *src.add(read);
            let lo = *src.add(read + 1);
            output.add(written).write((hi << 4) | lo);
            read += 2;
            written += 1;
        }
    }
}
267
/// Portable nibble packing: turns each `[hi, lo]` pair of `nibbles` into the
/// byte `(hi << 4) | lo` at `output`. A trailing unpaired nibble is ignored.
///
/// # Safety
/// `output` must be writable for `nibbles.len() / 2` bytes.
#[allow(unsafe_code)]
#[inline]
unsafe fn pack_nibble_pairs_scalar(nibbles: &[u8], output: *mut u8) {
    // SAFETY: exactly one byte is written per complete pair, so the cursor
    // never goes past `nibbles.len() / 2` bytes from `output`.
    unsafe {
        let mut cursor = output;
        for pair in nibbles.chunks_exact(2) {
            cursor.write((pair[0] << 4) | pair[1]);
            cursor = cursor.add(1);
        }
    }
}
283// ─────────────────────────────────────────────────────────────────────────────
284
285// ── SIMD prefix comparison ───────────────────────────────────────────────────
286//
287// count_common_prefix finds the length of the longest common prefix of two
288// byte slices.  The trie uses this on every insert/lookup to navigate branch
289// nodes.  Using SIMD we can compare 16 (SSE2) or 32 (AVX2) bytes at once.
290
291#[allow(unsafe_code)]
292#[inline]
293fn count_common_prefix(a: &[u8], b: &[u8]) -> usize {
294    #[cfg(target_arch = "x86_64")]
295    {
296        // SAFETY: x86_64 SIMD; bounds are maintained within the function.
297        return unsafe { count_common_prefix_x86_64(a, b) };
298    }
299    #[cfg(target_arch = "aarch64")]
300    {
301        // SAFETY: NEON enabled; bounds are maintained within the function.
302        return unsafe { count_common_prefix_aarch64(a, b) };
303    }
304    #[allow(unreachable_code)]
305    count_common_prefix_scalar(a, b)
306}
307
308#[cfg(target_arch = "x86_64")]
309#[allow(unsafe_code)]
310#[inline]
311unsafe fn count_common_prefix_x86_64(a: &[u8], b: &[u8]) -> usize {
312    use std::arch::x86_64::*;
313
314    let n = a.len().min(b.len());
315    let mut i = 0usize;
316
317    #[cfg(target_feature = "avx2")]
318    // SAFETY: AVX2 enabled at compile time; pointer arithmetic stays within bounds.
319    unsafe {
320        while i + 32 <= n {
321            let va = _mm256_loadu_si256(a.as_ptr().add(i).cast::<__m256i>());
322            let vb = _mm256_loadu_si256(b.as_ptr().add(i).cast::<__m256i>());
323            // Compare bytes: equal → 0xFF, else 0x00
324            let eq = _mm256_cmpeq_epi8(va, vb);
325            // Create a 32-bit mask where bit k = 1 iff byte k was equal.
326            let mask = _mm256_movemask_epi8(eq) as u32;
327            if mask != 0xFFFF_FFFF {
328                // First differing byte is at bit position (trailing ones).
329                return i + mask.trailing_ones() as usize;
330            }
331            i += 32;
332        }
333    }
334
335    // SSE2 (16-byte chunks). SSE2 is x86_64 baseline.
336    // SAFETY: SSE2 always available; bounds maintained by loop guard.
337    unsafe {
338        while i + 16 <= n {
339            let va = _mm_loadu_si128(a.as_ptr().add(i).cast::<__m128i>());
340            let vb = _mm_loadu_si128(b.as_ptr().add(i).cast::<__m128i>());
341            let eq = _mm_cmpeq_epi8(va, vb);
342            let mask = _mm_movemask_epi8(eq) as u16;
343            if mask != 0xFFFF {
344                return i + mask.trailing_ones() as usize;
345            }
346            i += 16;
347        }
348    }
349
350    // Scalar tail.
351    i + count_common_prefix_scalar(&a[i..n], &b[i..n])
352}
353
/// aarch64 (NEON) kernel for common-prefix length: returns the number of
/// leading bytes on which `a` and `b` agree, never more than the shorter
/// length. Compares 16 bytes at a time and finishes the last few bytes with
/// the scalar helper.
///
/// # Safety
/// Declared `unsafe` because it uses raw NEON loads; all reads stay within
/// the first `min(a.len(), b.len())` bytes of each slice.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code)]
#[inline]
unsafe fn count_common_prefix_aarch64(a: &[u8], b: &[u8]) -> usize {
    use std::arch::aarch64::*;

    let limit = a.len().min(b.len());
    let (lhs, rhs) = (a.as_ptr(), b.as_ptr());
    // Bytes known to be equal so far; never exceeds `limit`.
    let mut matched = 0usize;

    // SAFETY: NEON is enabled for this function; the loop condition keeps
    // each 16-byte load inside both slices.
    unsafe {
        while limit - matched >= 16 {
            // Lanes that compare equal become 0xFF, the others 0x00.
            let equal = vceqq_u8(vld1q_u8(lhs.add(matched)), vld1q_u8(rhs.add(matched)));
            // A horizontal minimum below 0xFF means at least one lane differs.
            if vminvq_u8(equal) != 0xFF {
                // Spill the lanes and locate the first zero one.
                let mut lanes = [0u8; 16];
                vst1q_u8(lanes.as_mut_ptr(), equal);
                return match lanes.iter().position(|&lane| lane == 0) {
                    Some(offset) => matched + offset,
                    None => unreachable!(),
                };
            }
            matched += 16;
        }
    }

    // Fewer than 16 bytes remain; compare them one by one.
    matched + count_common_prefix_scalar(&a[matched..limit], &b[matched..limit])
}
390
/// Portable common-prefix length: the index of the first position where `a`
/// and `b` differ, or the length of the shorter slice if none does.
#[inline]
fn count_common_prefix_scalar(a: &[u8], b: &[u8]) -> usize {
    let shorter = a.len().min(b.len());
    a.iter()
        .zip(b)
        .position(|(x, y)| x != y)
        .unwrap_or(shorter)
}
395// ─────────────────────────────────────────────────────────────────────────────
396
397// TODO: move path-tracking logic somewhere else
398// PERF: try using a stack-allocated array
/// Struct representing a list of nibbles (half-bytes)
///
/// Each nibble occupies one `u8`. A trailing value of `16` is the leaf flag
/// (see `from_raw` and `is_leaf`). The hand-written `PartialEq`, `Ord` and
/// `Hash` impls in this file look only at `data`.
#[derive(
    Debug,
    Clone,
    Default,
    serde::Serialize,
    serde::Deserialize,
    rkyv::Deserialize,
    rkyv::Serialize,
    rkyv::Archive,
)]
pub struct Nibbles {
    /// The nibbles not yet consumed, in path order. Methods such as `next`
    /// and `skip_prefix` remove elements from the front of this vector.
    data: Vec<u8>,
    /// Parts of the path that have already been consumed (used for tracking
    /// current position when visiting nodes). See `current()`.
    already_consumed: Vec<u8>,
}
416
417// NOTE: custom impls to ignore the `already_consumed` field
418
419impl PartialEq for Nibbles {
420    fn eq(&self, other: &Nibbles) -> bool {
421        self.data == other.data
422    }
423}
424
// Equality is a byte-wise comparison of `data`, which is reflexive, so the
// `Eq` marker is sound.
impl Eq for Nibbles {}
426
impl PartialOrd for Nibbles {
    /// Always returns `Some`: the result is whatever the total order defined
    /// by `Ord::cmp` yields.
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
432
433impl Ord for Nibbles {
434    fn cmp(&self, other: &Self) -> cmp::Ordering {
435        self.data.cmp(&other.data)
436    }
437}
438
439impl std::hash::Hash for Nibbles {
440    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
441        self.data.hash(state);
442    }
443}
444
445impl Nibbles {
446    /// Create `Nibbles` from  hex-encoded nibbles
447    pub const fn from_hex(hex: Vec<u8>) -> Self {
448        Self {
449            data: hex,
450            already_consumed: vec![],
451        }
452    }
453
454    /// Splits incoming bytes into nibbles and appends the leaf flag (a 16 nibble at the end)
455    pub fn from_bytes(bytes: &[u8]) -> Self {
456        Self::from_raw(bytes, true)
457    }
458
459    /// Splits incoming bytes into nibbles and appends the leaf flag (a 16 nibble at the end) if is_leaf is true
460    pub fn from_raw(bytes: &[u8], is_leaf: bool) -> Self {
461        let extra = usize::from(is_leaf);
462        let mut data = Vec::with_capacity(bytes.len() * 2 + extra);
463
464        // SAFETY: we allocated at least `bytes.len() * 2` capacity (plus `extra`),
465        // and we set_len to exactly `bytes.len() * 2` after the SIMD kernel fills them.
466        #[allow(unsafe_code)]
467        unsafe {
468            expand_bytes_to_nibbles(bytes, data.as_mut_ptr());
469            data.set_len(bytes.len() * 2);
470        }
471
472        if is_leaf {
473            data.push(16);
474        }
475
476        Self {
477            data,
478            already_consumed: vec![],
479        }
480    }
481
482    pub fn into_vec(self) -> Vec<u8> {
483        self.data
484    }
485
486    /// Returns the amount of nibbles
487    pub fn len(&self) -> usize {
488        self.data.len()
489    }
490
491    /// Returns true if there are no nibbles
492    pub fn is_empty(&self) -> bool {
493        self.data.is_empty()
494    }
495
496    /// If `prefix` is a prefix of self, move the offset after
497    /// the prefix and return true, otherwise return false.
498    pub fn skip_prefix(&mut self, prefix: &Nibbles) -> bool {
499        if self.len() >= prefix.len() && &self.data[..prefix.len()] == prefix.as_ref() {
500            self.already_consumed.extend_from_slice(&prefix.data);
501            self.data.drain(..prefix.len());
502            true
503        } else {
504            false
505        }
506    }
507
508    /// Compares self to another, comparing prefixes only in case of unequal lengths.
509    pub fn compare_prefix(&self, prefix: &Nibbles) -> cmp::Ordering {
510        if self.len() > prefix.len() {
511            self.data[..prefix.len()].cmp(&prefix.data)
512        } else {
513            self.data[..].cmp(&prefix.data[..self.len()])
514        }
515    }
516
517    /// Compares self to another and returns the shared nibble count (amount of nibbles that are equal, from the start)
518    pub fn count_prefix(&self, other: &Nibbles) -> usize {
519        count_common_prefix(self.as_ref(), other.as_ref())
520    }
521
522    /// Removes and returns the first nibble
523    #[allow(clippy::should_implement_trait)]
524    pub fn next(&mut self) -> Option<u8> {
525        (!self.is_empty()).then(|| {
526            self.already_consumed.push(self.data[0]);
527            self.data.remove(0)
528        })
529    }
530
531    /// Removes and returns the first nibble if it is a suitable choice index (aka < 16)
532    pub fn next_choice(&mut self) -> Option<usize> {
533        self.next().filter(|choice| *choice < 16).map(usize::from)
534    }
535
536    /// Returns the nibbles after the given offset
537    pub fn offset(&self, offset: usize) -> Nibbles {
538        let mut already_consumed = Vec::with_capacity(self.already_consumed.len() + offset);
539        already_consumed.extend_from_slice(&self.already_consumed);
540        already_consumed.extend_from_slice(&self.data[..offset]);
541        Nibbles {
542            data: self.data[offset..].to_vec(),
543            already_consumed,
544        }
545    }
546
547    /// Returns the nibbles beween the start and end indexes
548    pub fn slice(&self, start: usize, end: usize) -> Nibbles {
549        Nibbles::from_hex(self.data[start..end].to_vec())
550    }
551
552    /// Extends the nibbles with another list of nibbles
553    pub fn extend(&mut self, other: &Nibbles) {
554        self.data.extend_from_slice(other.as_ref());
555    }
556
557    /// Return the nibble at the given index, will panic if the index is out of range
558    pub fn at(&self, i: usize) -> usize {
559        self.data[i] as usize
560    }
561
562    /// Inserts a nibble at the start
563    pub fn prepend(&mut self, nibble: u8) {
564        self.data.insert(0, nibble);
565    }
566
567    /// Inserts a nibble at the end
568    pub fn append(&mut self, nibble: u8) {
569        self.data.push(nibble);
570    }
571
572    /// Taken from https://github.com/citahub/cita_trie/blob/master/src/nibbles.rs#L56
573    /// Encodes the nibbles in compact form
574    #[allow(unsafe_code)]
575    pub fn encode_compact(&self) -> Vec<u8> {
576        let is_leaf = self.is_leaf();
577        let mut hex = if is_leaf {
578            &self.data[0..self.data.len() - 1]
579        } else {
580            &self.data[0..]
581        };
582        // node type    path length    |    prefix    hexchar
583        // --------------------------------------------------
584        // extension    even           |    0000      0x0
585        // extension    odd            |    0001      0x1
586        // leaf         even           |    0010      0x2
587        // leaf         odd            |    0011      0x3
588        let prefix_nibble = if hex.len() % 2 == 1 {
589            let v = 0x10 + hex[0];
590            hex = &hex[1..];
591            v
592        } else {
593            0x00
594        };
595
596        let pair_count = hex.len() / 2;
597        let mut compact = Vec::with_capacity(1 + pair_count);
598        compact.push(prefix_nibble + if is_leaf { 0x20 } else { 0x00 });
599
600        // SIMD-accelerated packing of nibble pairs → bytes.
601        // SAFETY: compact has capacity for `pair_count` bytes beyond the one already pushed.
602        // pack_nibble_pairs writes exactly `pair_count` bytes starting at offset 1;
603        // set_len then exposes those initialized bytes.
604        unsafe {
605            let out_ptr = compact.as_mut_ptr().add(1);
606            pack_nibble_pairs(hex, out_ptr);
607            compact.set_len(1 + pair_count);
608        }
609
610        compact
611    }
612
613    /// Encodes the nibbles in compact form
614    pub fn decode_compact(compact: &[u8]) -> Self {
615        Self::from_hex(compact_to_hex(compact))
616    }
617
618    /// Returns true if the nibbles contain the leaf flag (16) at the end
619    pub fn is_leaf(&self) -> bool {
620        if self.is_empty() {
621            false
622        } else {
623            self.data[self.data.len() - 1] == 16
624        }
625    }
626
627    /// Combines the nibbles into bytes, trimming the leaf flag if necessary
628    pub fn to_bytes(&self) -> Vec<u8> {
629        // Trim leaf flag
630        let data = if !self.is_empty() && self.is_leaf() {
631            &self.data[..self.len() - 1]
632        } else {
633            &self.data[..]
634        };
635        // Combine nibbles into bytes
636        data.chunks(2)
637            .map(|chunk| match chunk.len() {
638                1 => chunk[0] << 4,
639                _ => chunk[0] << 4 | chunk[1],
640            })
641            .collect::<Vec<_>>()
642    }
643
644    /// Concatenates self and another Nibbles returning a new Nibbles
645    pub fn concat(&self, other: &Nibbles) -> Nibbles {
646        let mut data = Vec::with_capacity(self.data.len() + other.data.len());
647        data.extend_from_slice(&self.data);
648        data.extend_from_slice(&other.data);
649        Nibbles {
650            data,
651            already_consumed: self.already_consumed.clone(),
652        }
653    }
654
655    /// Returns a copy of self with the nibble added at the end
656    pub fn append_new(&self, nibble: u8) -> Nibbles {
657        let mut data = Vec::with_capacity(self.data.len() + 1);
658        data.extend_from_slice(&self.data);
659        data.push(nibble);
660        Nibbles {
661            data,
662            already_consumed: self.already_consumed.clone(),
663        }
664    }
665
666    /// Return already consumed parts of path
667    pub fn current(&self) -> Nibbles {
668        Nibbles {
669            data: self.already_consumed.clone(),
670            already_consumed: vec![],
671        }
672    }
673
674    /// Empties `self.data` and returns the content
675    pub fn take(&mut self) -> Self {
676        Nibbles {
677            data: mem::take(&mut self.data),
678            already_consumed: mem::take(&mut self.already_consumed),
679        }
680    }
681}
682
683impl AsRef<[u8]> for Nibbles {
684    fn as_ref(&self) -> &[u8] {
685        &self.data
686    }
687}
688
impl RLPEncode for Nibbles {
    /// Writes the RLP form of these nibbles into `buf`.
    ///
    /// Only `data` is handed to the encoder, as a single field;
    /// `already_consumed` is not serialized.
    fn encode(&self, buf: &mut dyn bytes::BufMut) {
        Encoder::new(buf).encode_field(&self.data).finish();
    }
}
694
impl RLPDecode for Nibbles {
    /// Decodes a `Nibbles` from `rlp`, mirroring `encode`: a single field
    /// (labelled "data") is read into `data`, and `already_consumed` starts
    /// out empty.
    ///
    /// Returns the decoded value together with the slice produced by
    /// `Decoder::finish` — presumably the bytes of `rlp` that follow the
    /// encoded item; confirm against `ethrex_rlp`.
    ///
    /// # Errors
    /// Any error from `Decoder::new`, `decode_field` or `finish` is
    /// propagated unchanged.
    fn decode_unfinished(rlp: &[u8]) -> Result<(Self, &[u8]), RLPDecodeError> {
        let decoder = Decoder::new(rlp)?;
        let (data, decoder) = decoder.decode_field("data")?;
        Ok((
            Self {
                data,
                already_consumed: vec![],
            },
            decoder.finish()?,
        ))
    }
}
708
709// Code taken from https://github.com/ethereum/go-ethereum/blob/a1093d98eb3260f2abf340903c2d968b2b891c11/trie/encoding.go#L82
710fn compact_to_hex(compact: &[u8]) -> Vec<u8> {
711    if compact.is_empty() {
712        return vec![];
713    }
714    let mut base = keybytes_to_hex(compact);
715    // delete terminator flag
716    let end = if base[0] < 2 {
717        base.len() - 1
718    } else {
719        base.len()
720    };
721    // apply odd flag
722    let chop = 2 - (base[0] & 1) as usize;
723    base.drain(..chop);
724    base.truncate(end - chop);
725    base
726}
727
728// Code taken from https://github.com/ethereum/go-ethereum/blob/a1093d98eb3260f2abf340903c2d968b2b891c11/trie/encoding.go#L96
729fn keybytes_to_hex(keybytes: &[u8]) -> Vec<u8> {
730    let nibble_count = keybytes.len() * 2;
731    let mut nibbles = Vec::with_capacity(nibble_count + 1);
732
733    // SAFETY: we just allocated `nibble_count` capacity; SIMD kernel fills them.
734    #[allow(unsafe_code)]
735    unsafe {
736        expand_bytes_to_nibbles(keybytes, nibbles.as_mut_ptr());
737        nibbles.set_len(nibble_count);
738    }
739    nibbles.push(16); // leaf terminator
740    nibbles
741}
742
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar reference for expand_bytes_to_nibbles (no SIMD).
    fn expand_bytes_scalar_ref(bytes: &[u8]) -> Vec<u8> {
        bytes.iter().flat_map(|&b| [b >> 4, b & 0x0F]).collect()
    }

    /// Scalar reference for pack_nibble_pairs (no SIMD).
    fn pack_nibble_pairs_scalar_ref(nibbles: &[u8]) -> Vec<u8> {
        nibbles
            .chunks_exact(2)
            .map(|pair| (pair[0] << 4) | pair[1])
            .collect()
    }

    #[test]
    fn expand_bytes_to_nibbles_matches_scalar() {
        // Byte counts on and around the 16- and 32-byte SIMD block sizes.
        for &len in &[0, 1, 2, 15, 16, 17, 31, 32, 33, 48, 64] {
            let input: Vec<u8> = (0..len).map(|i| (i * 37 + 13) as u8).collect();
            let expected = expand_bytes_scalar_ref(&input);

            let mut actual = vec![0u8; input.len() * 2];
            #[allow(unsafe_code)]
            unsafe {
                expand_bytes_to_nibbles(&input, actual.as_mut_ptr());
            }
            assert_eq!(actual, expected, "mismatch at input length {len}");
        }
    }

    #[test]
    fn pack_nibble_pairs_matches_scalar() {
        // Nibble counts (always even) on and around the 32-nibble SIMD block size.
        for &nibble_count in &[0, 2, 4, 14, 16, 30, 32, 34, 48, 64] {
            let input: Vec<u8> = (0..nibble_count).map(|i| (i % 16) as u8).collect();
            let expected = pack_nibble_pairs_scalar_ref(&input);

            let mut actual = vec![0u8; nibble_count / 2];
            #[allow(unsafe_code)]
            unsafe {
                pack_nibble_pairs(&input, actual.as_mut_ptr());
            }
            assert_eq!(actual, expected, "mismatch at nibble count {nibble_count}");
        }
    }

    #[test]
    fn expand_then_pack_roundtrip() {
        for &len in &[0, 1, 16, 32, 33] {
            let input: Vec<u8> = (0..len).map(|i| (i * 53 + 7) as u8).collect();
            let mut nibbles = vec![0u8; input.len() * 2];
            #[allow(unsafe_code)]
            unsafe {
                expand_bytes_to_nibbles(&input, nibbles.as_mut_ptr());
            }

            let mut packed = vec![0u8; input.len()];
            #[allow(unsafe_code)]
            unsafe {
                pack_nibble_pairs(&nibbles, packed.as_mut_ptr());
            }
            assert_eq!(packed, input, "roundtrip failed at length {len}");
        }
    }

    #[test]
    fn count_common_prefix_correctness() {
        // Identical slices
        let a = vec![1u8, 2, 3, 4, 5];
        assert_eq!(count_common_prefix(&a, &a), 5);

        // No common prefix
        assert_eq!(count_common_prefix(&[1, 2, 3], &[4, 5, 6]), 0);

        // Partial match
        assert_eq!(count_common_prefix(&[1, 2, 3, 4], &[1, 2, 5, 6]), 2);

        // Empty
        assert_eq!(count_common_prefix(&[], &[1, 2]), 0);
        assert_eq!(count_common_prefix(&[1, 2], &[]), 0);
        assert_eq!(count_common_prefix(&[], &[]), 0);

        // Long match that ends right after the SIMD-sized blocks
        let long_a: Vec<u8> = (0..33).collect();
        let mut long_b = long_a.clone();
        long_b[32] = 255;
        assert_eq!(count_common_prefix(&long_a, &long_b), 32);
    }

    #[test]
    fn count_common_prefix_mismatch_at_every_position() {
        // A mismatch is placed at every index so that the early return inside
        // each SIMD block is exercised, not just the scalar tail.
        for &len in &[1usize, 15, 16, 17, 31, 32, 33, 48, 64, 65] {
            let base: Vec<u8> = (0..len).map(|i| (i % 251) as u8).collect();
            assert_eq!(count_common_prefix(&base, &base), len);

            for pos in 0..len {
                let mut other = base.clone();
                other[pos] ^= 0xFF;
                assert_eq!(
                    count_common_prefix(&base, &other),
                    pos,
                    "length {len}, mismatch at {pos}"
                );
                assert_eq!(
                    count_common_prefix(&other, &base),
                    pos,
                    "length {len}, mismatch at {pos} (swapped)"
                );
            }

            // A strict prefix matches over its whole length, in either order.
            let half = len / 2;
            assert_eq!(count_common_prefix(&base, &base[..half]), half);
            assert_eq!(count_common_prefix(&base[..half], &base), half);
        }
    }

    #[test]
    fn from_raw_leaf_flag() {
        let bytes = &[0xAB, 0xCD];
        let with_leaf = Nibbles::from_raw(bytes, true);
        let without_leaf = Nibbles::from_raw(bytes, false);

        assert_eq!(with_leaf.data, vec![0x0A, 0x0B, 0x0C, 0x0D, 16]);
        assert_eq!(without_leaf.data, vec![0x0A, 0x0B, 0x0C, 0x0D]);
    }

    #[test]
    fn from_bytes_full_length_key() {
        // A 32-byte key is the common case and fills the SIMD blocks exactly.
        let key: Vec<u8> = (0..32u8).map(|i| i.wrapping_mul(41).wrapping_add(5)).collect();
        let mut expected = expand_bytes_scalar_ref(&key);
        expected.push(16);
        assert_eq!(Nibbles::from_bytes(&key).data, expected);
    }

    #[test]
    fn compact_encoding_matches_reference_and_roundtrips() {
        // Path lengths run past 64 nibbles so the SIMD packing blocks are
        // reached through the public API, for both parities and node types.
        for len in 0..=70usize {
            for &is_leaf in &[false, true] {
                let path: Vec<u8> = (0..len).map(|i| ((i * 5 + 1) % 16) as u8).collect();
                let mut data = path.clone();
                if is_leaf {
                    data.push(16);
                }
                let nibbles = Nibbles::from_hex(data);

                let odd = len % 2 == 1;
                let mut prefix = if is_leaf { 0x20 } else { 0x00 };
                if odd {
                    prefix += 0x10 + path[0];
                }
                let mut expected = vec![prefix];
                expected.extend(pack_nibble_pairs_scalar_ref(&path[usize::from(odd)..]));

                let compact = nibbles.encode_compact();
                assert_eq!(compact, expected, "length {len}, leaf {is_leaf}");
                assert_eq!(
                    Nibbles::decode_compact(&compact),
                    nibbles,
                    "roundtrip failed at length {len}, leaf {is_leaf}"
                );
            }
        }
    }
}