1use std::{cmp, mem};
2
3use ethrex_rlp::{
4 decode::RLPDecode,
5 encode::RLPEncode,
6 error::RLPDecodeError,
7 structs::{Decoder, Encoder},
8};
9
10#[inline]
22#[allow(unsafe_code)]
23unsafe fn expand_bytes_to_nibbles(bytes: &[u8], output: *mut u8) {
24 #[cfg(target_arch = "x86_64")]
25 {
26 unsafe { expand_bytes_to_nibbles_x86_64(bytes, output) };
28 return;
29 }
30 #[cfg(target_arch = "aarch64")]
31 {
32 unsafe { expand_bytes_to_nibbles_aarch64(bytes, output) };
34 return;
35 }
36 #[allow(unreachable_code)]
38 unsafe {
40 expand_bytes_to_nibbles_scalar(bytes, output)
41 };
42}
43
/// x86_64 backend for `expand_bytes_to_nibbles`.
///
/// Processes 32-byte blocks with AVX2 when that feature is enabled at compile time, then
/// 16-byte blocks with SSE2 (part of the x86_64 baseline), and finally any leftover bytes one
/// at a time.
///
/// # Safety
///
/// `output` must be valid for writes of `2 * bytes.len()` bytes.
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_code)]
#[inline]
unsafe fn expand_bytes_to_nibbles_x86_64(bytes: &[u8], output: *mut u8) {
    use std::arch::x86_64::*;

    let len = bytes.len();
    let src = bytes.as_ptr();
    let mut pos = 0usize;

    // SAFETY (every block below): vector loads read `bytes[pos..pos + W]` with `pos + W <= len`,
    // and stores write `output[2 * pos..2 * (pos + W)]`, which lies inside the `2 * len`
    // writable bytes the caller guarantees.
    #[cfg(target_feature = "avx2")]
    unsafe {
        let low_mask = _mm256_set1_epi8(0x0F);
        while pos + 32 <= len {
            let block = _mm256_loadu_si256(src.add(pos).cast::<__m256i>());
            // The 16-bit shift drags bits across byte boundaries; masking keeps only each
            // byte's own high nibble.
            let highs = _mm256_and_si256(_mm256_srli_epi16(block, 4), low_mask);
            let lows = _mm256_and_si256(block, low_mask);
            // AVX2 unpacks work inside each 128-bit lane, so the interleaved halves come out
            // lane-swapped; the two permutes restore input order.
            let mixed_lo = _mm256_unpacklo_epi8(highs, lows);
            let mixed_hi = _mm256_unpackhi_epi8(highs, lows);
            let dst = output.add(2 * pos);
            _mm256_storeu_si256(
                dst.cast::<__m256i>(),
                _mm256_permute2x128_si256::<0x20>(mixed_lo, mixed_hi),
            );
            _mm256_storeu_si256(
                dst.add(32).cast::<__m256i>(),
                _mm256_permute2x128_si256::<0x31>(mixed_lo, mixed_hi),
            );
            pos += 32;
        }
    }

    unsafe {
        let low_mask = _mm_set1_epi8(0x0F);
        while pos + 16 <= len {
            let block = _mm_loadu_si128(src.add(pos).cast::<__m128i>());
            let highs = _mm_and_si128(_mm_srli_epi16(block, 4), low_mask);
            let lows = _mm_and_si128(block, low_mask);
            let dst = output.add(2 * pos);
            _mm_storeu_si128(dst.cast::<__m128i>(), _mm_unpacklo_epi8(highs, lows));
            _mm_storeu_si128(dst.add(16).cast::<__m128i>(), _mm_unpackhi_epi8(highs, lows));
            pos += 16;
        }

        // Scalar tail for the last `len % 16` bytes.
        for (k, &byte) in bytes[pos..].iter().enumerate() {
            let dst = output.add(2 * (pos + k));
            *dst = byte >> 4;
            *dst.add(1) = byte & 0x0F;
        }
    }
}
108
/// aarch64 (NEON) backend for `expand_bytes_to_nibbles`.
///
/// Processes 16-byte blocks with NEON and finishes the leftover bytes one at a time.
///
/// # Safety
///
/// `output` must be valid for writes of `2 * bytes.len()` bytes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code)]
#[inline]
unsafe fn expand_bytes_to_nibbles_aarch64(bytes: &[u8], output: *mut u8) {
    use std::arch::aarch64::*;

    let mut blocks = bytes.chunks_exact(16);
    let mut written = 0usize;

    // SAFETY: loads read whole 16-byte chunks of `bytes`. Each full block writes 32 bytes and
    // each leftover byte writes 2, all at `output[written..]`; in total that is exactly
    // `2 * bytes.len()` bytes, which the caller guarantees are writable.
    unsafe {
        let low_mask = vdupq_n_u8(0x0F);
        for block in &mut blocks {
            let v = vld1q_u8(block.as_ptr());
            // An unsigned shift right by 4 leaves only the high nibble in each byte.
            let highs = vshrq_n_u8(v, 4);
            let lows = vandq_u8(v, low_mask);
            let dst = output.add(written);
            // zip1/zip2 interleave the vectors as [h0, l0, h1, l1, ...].
            vst1q_u8(dst, vzip1q_u8(highs, lows));
            vst1q_u8(dst.add(16), vzip2q_u8(highs, lows));
            written += 32;
        }

        for &byte in blocks.remainder() {
            let dst = output.add(written);
            *dst = byte >> 4;
            *dst.add(1) = byte & 0x0F;
            written += 2;
        }
    }
}
145
/// Portable fallback for `expand_bytes_to_nibbles`: writes `byte >> 4` then `byte & 0x0F` for
/// every input byte.
///
/// # Safety
///
/// `output` must be valid for writes of `2 * bytes.len()` bytes.
#[allow(unsafe_code)]
#[inline]
unsafe fn expand_bytes_to_nibbles_scalar(bytes: &[u8], output: *mut u8) {
    for (index, &byte) in bytes.iter().enumerate() {
        // SAFETY: `2 * index + 1 < 2 * bytes.len()`, inside the caller-guaranteed writable range.
        unsafe {
            let dst = output.add(2 * index);
            dst.write(byte >> 4);
            dst.add(1).write(byte & 0x0F);
        }
    }
}
157
158#[inline]
175#[allow(unsafe_code)]
176unsafe fn pack_nibble_pairs(nibbles: &[u8], output: *mut u8) {
177 debug_assert!(nibbles.len().is_multiple_of(2));
178 #[cfg(target_arch = "x86_64")]
179 {
180 unsafe { pack_nibble_pairs_x86_64(nibbles, output) };
181 return;
182 }
183 #[cfg(target_arch = "aarch64")]
184 {
185 unsafe { pack_nibble_pairs_aarch64(nibbles, output) };
186 return;
187 }
188 #[allow(unreachable_code)]
189 unsafe {
190 pack_nibble_pairs_scalar(nibbles, output)
191 };
192}
193
/// x86_64 backend for `pack_nibble_pairs`.
///
/// When SSSE3 is enabled at compile time, it packs 32 nibbles into 16 bytes per iteration. The
/// remaining pairs are packed one at a time.
///
/// # Safety
///
/// `output` must be valid for writes of `nibbles.len() / 2` bytes.
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_code)]
#[inline]
unsafe fn pack_nibble_pairs_x86_64(nibbles: &[u8], output: *mut u8) {
    // Number of leading nibbles handled by the vector loop (always a multiple of 32).
    #[cfg(target_feature = "ssse3")]
    let vectorized = {
        use std::arch::x86_64::*;

        let mut consumed = 0usize;
        // SAFETY: each iteration reads `nibbles[consumed..consumed + 32]` (in bounds by the loop
        // condition) and writes `output[consumed / 2..consumed / 2 + 16]`, which is inside the
        // `nibbles.len() / 2` writable bytes the caller guarantees.
        unsafe {
            // As little-endian bytes 0x0110 is [0x10, 0x01]: weight 16 on each pair's first
            // (high) nibble and 1 on its second, so `maddubs` yields `hi * 16 + lo` per i16 lane.
            let weights = _mm_set1_epi16(0x0110);
            while consumed + 32 <= nibbles.len() {
                let src = nibbles.as_ptr().add(consumed);
                let first = _mm_maddubs_epi16(_mm_loadu_si128(src.cast::<__m128i>()), weights);
                let second =
                    _mm_maddubs_epi16(_mm_loadu_si128(src.add(16).cast::<__m128i>()), weights);
                // For nibble inputs every lane fits in a byte, so the saturating pack is exact.
                _mm_storeu_si128(
                    output.add(consumed / 2).cast::<__m128i>(),
                    _mm_packus_epi16(first, second),
                );
                consumed += 32;
            }
        }
        consumed
    };
    #[cfg(not(target_feature = "ssse3"))]
    let vectorized = 0usize;

    for (k, pair) in nibbles[vectorized..].chunks_exact(2).enumerate() {
        // SAFETY: `vectorized / 2 + k < nibbles.len() / 2`, inside the caller-guaranteed range.
        unsafe { output.add(vectorized / 2 + k).write((pair[0] << 4) | pair[1]) };
    }
}
236
/// aarch64 (NEON) backend for `pack_nibble_pairs`.
///
/// Packs 32 nibbles into 16 bytes per iteration, then packs leftover pairs one at a time.
///
/// # Safety
///
/// `output` must be valid for writes of `nibbles.len() / 2` bytes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code)]
#[inline]
unsafe fn pack_nibble_pairs_aarch64(nibbles: &[u8], output: *mut u8) {
    use std::arch::aarch64::*;

    let mut blocks = nibbles.chunks_exact(32);
    let mut written = 0usize;

    // SAFETY: loads read whole 32-byte chunks of `nibbles`. Every full block writes 16 bytes
    // and every leftover pair writes 1 byte at `output[written]`, for a total of
    // `nibbles.len() / 2` bytes, which the caller guarantees are writable.
    unsafe {
        for block in &mut blocks {
            // `vld2q_u8` de-interleaves: `.0` holds the even-indexed (high) nibbles and `.1`
            // the odd-indexed (low) ones.
            let halves = vld2q_u8(block.as_ptr());
            let packed = vorrq_u8(vshlq_n_u8(halves.0, 4), halves.1);
            vst1q_u8(output.add(written), packed);
            written += 16;
        }

        for pair in blocks.remainder().chunks_exact(2) {
            output.add(written).write((pair[0] << 4) | pair[1]);
            written += 1;
        }
    }
}
267
/// Portable fallback for `pack_nibble_pairs`: each pair `(hi, lo)` becomes `(hi << 4) | lo`.
/// A trailing unpaired nibble is ignored.
///
/// # Safety
///
/// `output` must be valid for writes of `nibbles.len() / 2` bytes.
#[allow(unsafe_code)]
#[inline]
unsafe fn pack_nibble_pairs_scalar(nibbles: &[u8], output: *mut u8) {
    for (index, pair) in nibbles.chunks_exact(2).enumerate() {
        let packed = (pair[0] << 4) | pair[1];
        // SAFETY: `index < nibbles.len() / 2`, inside the caller-guaranteed writable range.
        unsafe { output.add(index).write(packed) };
    }
}
283#[allow(unsafe_code)]
292#[inline]
293fn count_common_prefix(a: &[u8], b: &[u8]) -> usize {
294 #[cfg(target_arch = "x86_64")]
295 {
296 return unsafe { count_common_prefix_x86_64(a, b) };
298 }
299 #[cfg(target_arch = "aarch64")]
300 {
301 return unsafe { count_common_prefix_aarch64(a, b) };
303 }
304 #[allow(unreachable_code)]
305 count_common_prefix_scalar(a, b)
306}
307
308#[cfg(target_arch = "x86_64")]
309#[allow(unsafe_code)]
310#[inline]
311unsafe fn count_common_prefix_x86_64(a: &[u8], b: &[u8]) -> usize {
312 use std::arch::x86_64::*;
313
314 let n = a.len().min(b.len());
315 let mut i = 0usize;
316
317 #[cfg(target_feature = "avx2")]
318 unsafe {
320 while i + 32 <= n {
321 let va = _mm256_loadu_si256(a.as_ptr().add(i).cast::<__m256i>());
322 let vb = _mm256_loadu_si256(b.as_ptr().add(i).cast::<__m256i>());
323 let eq = _mm256_cmpeq_epi8(va, vb);
325 let mask = _mm256_movemask_epi8(eq) as u32;
327 if mask != 0xFFFF_FFFF {
328 return i + mask.trailing_ones() as usize;
330 }
331 i += 32;
332 }
333 }
334
335 unsafe {
338 while i + 16 <= n {
339 let va = _mm_loadu_si128(a.as_ptr().add(i).cast::<__m128i>());
340 let vb = _mm_loadu_si128(b.as_ptr().add(i).cast::<__m128i>());
341 let eq = _mm_cmpeq_epi8(va, vb);
342 let mask = _mm_movemask_epi8(eq) as u16;
343 if mask != 0xFFFF {
344 return i + mask.trailing_ones() as usize;
345 }
346 i += 16;
347 }
348 }
349
350 i + count_common_prefix_scalar(&a[i..n], &b[i..n])
352}
353
/// aarch64 (NEON) backend for `count_common_prefix`.
///
/// Compares 16-byte blocks. When a block contains a mismatch, its equality lanes are spilled to
/// a local array to locate the first differing byte. Leftover bytes go to the scalar helper.
///
/// # Safety
///
/// No preconditions beyond passing valid slices. Every vector load is bounded by
/// `min(a.len(), b.len())`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code)]
#[inline]
unsafe fn count_common_prefix_aarch64(a: &[u8], b: &[u8]) -> usize {
    use std::arch::aarch64::*;

    let limit = a.len().min(b.len());
    let mut offset = 0usize;

    while offset + 16 <= limit {
        // SAFETY: `offset + 16 <= limit`, so both 16-byte loads stay inside `a` and `b`; the
        // store targets a local 16-byte array.
        unsafe {
            let lanes = vceqq_u8(
                vld1q_u8(a.as_ptr().add(offset)),
                vld1q_u8(b.as_ptr().add(offset)),
            );
            // Equal lanes are 0xFF and unequal lanes 0x00, so the minimum is 0xFF only when
            // every byte of the block matched.
            if vminvq_u8(lanes) != 0xFF {
                let mut flags = [0u8; 16];
                vst1q_u8(flags.as_mut_ptr(), lanes);
                return match flags.iter().position(|&flag| flag == 0) {
                    Some(lane) => offset + lane,
                    None => unreachable!(),
                };
            }
        }
        offset += 16;
    }

    offset + count_common_prefix_scalar(&a[offset..limit], &b[offset..limit])
}
390
/// Scalar common-prefix length: the index of the first differing byte, or the shorter slice's
/// length when one slice is a prefix of the other.
#[inline]
fn count_common_prefix_scalar(a: &[u8], b: &[u8]) -> usize {
    let shorter = a.len().min(b.len());
    a.iter()
        .zip(b)
        .position(|(x, y)| x != y)
        .unwrap_or(shorter)
}
/// A path stored as nibbles, one nibble per byte (normally `0..=15`). A trailing value `16`
/// marks a leaf path (see `Nibbles::is_leaf`).
///
/// `data` holds the nibbles that have not been walked yet. `already_consumed` records the
/// nibbles removed by `next`, `skip_prefix` and `offset`. The `PartialEq`, `Ord` and `Hash`
/// implementations look only at `data`, so two values with different histories still compare
/// equal.
#[derive(
    Debug,
    Clone,
    Default,
    serde::Serialize,
    serde::Deserialize,
    rkyv::Deserialize,
    rkyv::Serialize,
    rkyv::Archive,
)]
pub struct Nibbles {
    /// Remaining (not yet consumed) nibbles, possibly ending with the leaf terminator `16`.
    data: Vec<u8>,
    /// Nibbles consumed so far, in the order they were consumed.
    already_consumed: Vec<u8>,
}
416
417impl PartialEq for Nibbles {
420 fn eq(&self, other: &Nibbles) -> bool {
421 self.data == other.data
422 }
423}
424
425impl Eq for Nibbles {}
426
427impl PartialOrd for Nibbles {
428 fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
429 Some(self.cmp(other))
430 }
431}
432
433impl Ord for Nibbles {
434 fn cmp(&self, other: &Self) -> cmp::Ordering {
435 self.data.cmp(&other.data)
436 }
437}
438
439impl std::hash::Hash for Nibbles {
440 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
441 self.data.hash(state);
442 }
443}
444
445impl Nibbles {
446 pub const fn from_hex(hex: Vec<u8>) -> Self {
448 Self {
449 data: hex,
450 already_consumed: vec![],
451 }
452 }
453
454 pub fn from_bytes(bytes: &[u8]) -> Self {
456 Self::from_raw(bytes, true)
457 }
458
459 pub fn from_raw(bytes: &[u8], is_leaf: bool) -> Self {
461 let extra = usize::from(is_leaf);
462 let mut data = Vec::with_capacity(bytes.len() * 2 + extra);
463
464 #[allow(unsafe_code)]
467 unsafe {
468 expand_bytes_to_nibbles(bytes, data.as_mut_ptr());
469 data.set_len(bytes.len() * 2);
470 }
471
472 if is_leaf {
473 data.push(16);
474 }
475
476 Self {
477 data,
478 already_consumed: vec![],
479 }
480 }
481
482 pub fn into_vec(self) -> Vec<u8> {
483 self.data
484 }
485
486 pub fn len(&self) -> usize {
488 self.data.len()
489 }
490
491 pub fn is_empty(&self) -> bool {
493 self.data.is_empty()
494 }
495
496 pub fn skip_prefix(&mut self, prefix: &Nibbles) -> bool {
499 if self.len() >= prefix.len() && &self.data[..prefix.len()] == prefix.as_ref() {
500 self.already_consumed.extend_from_slice(&prefix.data);
501 self.data.drain(..prefix.len());
502 true
503 } else {
504 false
505 }
506 }
507
508 pub fn compare_prefix(&self, prefix: &Nibbles) -> cmp::Ordering {
510 if self.len() > prefix.len() {
511 self.data[..prefix.len()].cmp(&prefix.data)
512 } else {
513 self.data[..].cmp(&prefix.data[..self.len()])
514 }
515 }
516
517 pub fn count_prefix(&self, other: &Nibbles) -> usize {
519 count_common_prefix(self.as_ref(), other.as_ref())
520 }
521
522 #[allow(clippy::should_implement_trait)]
524 pub fn next(&mut self) -> Option<u8> {
525 (!self.is_empty()).then(|| {
526 self.already_consumed.push(self.data[0]);
527 self.data.remove(0)
528 })
529 }
530
531 pub fn next_choice(&mut self) -> Option<usize> {
533 self.next().filter(|choice| *choice < 16).map(usize::from)
534 }
535
536 pub fn offset(&self, offset: usize) -> Nibbles {
538 let mut already_consumed = Vec::with_capacity(self.already_consumed.len() + offset);
539 already_consumed.extend_from_slice(&self.already_consumed);
540 already_consumed.extend_from_slice(&self.data[..offset]);
541 Nibbles {
542 data: self.data[offset..].to_vec(),
543 already_consumed,
544 }
545 }
546
547 pub fn slice(&self, start: usize, end: usize) -> Nibbles {
549 Nibbles::from_hex(self.data[start..end].to_vec())
550 }
551
552 pub fn extend(&mut self, other: &Nibbles) {
554 self.data.extend_from_slice(other.as_ref());
555 }
556
557 pub fn at(&self, i: usize) -> usize {
559 self.data[i] as usize
560 }
561
562 pub fn prepend(&mut self, nibble: u8) {
564 self.data.insert(0, nibble);
565 }
566
567 pub fn append(&mut self, nibble: u8) {
569 self.data.push(nibble);
570 }
571
572 #[allow(unsafe_code)]
575 pub fn encode_compact(&self) -> Vec<u8> {
576 let is_leaf = self.is_leaf();
577 let mut hex = if is_leaf {
578 &self.data[0..self.data.len() - 1]
579 } else {
580 &self.data[0..]
581 };
582 let prefix_nibble = if hex.len() % 2 == 1 {
589 let v = 0x10 + hex[0];
590 hex = &hex[1..];
591 v
592 } else {
593 0x00
594 };
595
596 let pair_count = hex.len() / 2;
597 let mut compact = Vec::with_capacity(1 + pair_count);
598 compact.push(prefix_nibble + if is_leaf { 0x20 } else { 0x00 });
599
600 unsafe {
605 let out_ptr = compact.as_mut_ptr().add(1);
606 pack_nibble_pairs(hex, out_ptr);
607 compact.set_len(1 + pair_count);
608 }
609
610 compact
611 }
612
613 pub fn decode_compact(compact: &[u8]) -> Self {
615 Self::from_hex(compact_to_hex(compact))
616 }
617
618 pub fn is_leaf(&self) -> bool {
620 if self.is_empty() {
621 false
622 } else {
623 self.data[self.data.len() - 1] == 16
624 }
625 }
626
627 pub fn to_bytes(&self) -> Vec<u8> {
629 let data = if !self.is_empty() && self.is_leaf() {
631 &self.data[..self.len() - 1]
632 } else {
633 &self.data[..]
634 };
635 data.chunks(2)
637 .map(|chunk| match chunk.len() {
638 1 => chunk[0] << 4,
639 _ => chunk[0] << 4 | chunk[1],
640 })
641 .collect::<Vec<_>>()
642 }
643
644 pub fn concat(&self, other: &Nibbles) -> Nibbles {
646 let mut data = Vec::with_capacity(self.data.len() + other.data.len());
647 data.extend_from_slice(&self.data);
648 data.extend_from_slice(&other.data);
649 Nibbles {
650 data,
651 already_consumed: self.already_consumed.clone(),
652 }
653 }
654
655 pub fn append_new(&self, nibble: u8) -> Nibbles {
657 let mut data = Vec::with_capacity(self.data.len() + 1);
658 data.extend_from_slice(&self.data);
659 data.push(nibble);
660 Nibbles {
661 data,
662 already_consumed: self.already_consumed.clone(),
663 }
664 }
665
666 pub fn current(&self) -> Nibbles {
668 Nibbles {
669 data: self.already_consumed.clone(),
670 already_consumed: vec![],
671 }
672 }
673
674 pub fn take(&mut self) -> Self {
676 Nibbles {
677 data: mem::take(&mut self.data),
678 already_consumed: mem::take(&mut self.already_consumed),
679 }
680 }
681}
682
683impl AsRef<[u8]> for Nibbles {
684 fn as_ref(&self) -> &[u8] {
685 &self.data
686 }
687}
688
689impl RLPEncode for Nibbles {
690 fn encode(&self, buf: &mut dyn bytes::BufMut) {
691 Encoder::new(buf).encode_field(&self.data).finish();
692 }
693}
694
695impl RLPDecode for Nibbles {
696 fn decode_unfinished(rlp: &[u8]) -> Result<(Self, &[u8]), RLPDecodeError> {
697 let decoder = Decoder::new(rlp)?;
698 let (data, decoder) = decoder.decode_field("data")?;
699 Ok((
700 Self {
701 data,
702 already_consumed: vec![],
703 },
704 decoder.finish()?,
705 ))
706 }
707}
708
709fn compact_to_hex(compact: &[u8]) -> Vec<u8> {
711 if compact.is_empty() {
712 return vec![];
713 }
714 let mut base = keybytes_to_hex(compact);
715 let end = if base[0] < 2 {
717 base.len() - 1
718 } else {
719 base.len()
720 };
721 let chop = 2 - (base[0] & 1) as usize;
723 base.drain(..chop);
724 base.truncate(end - chop);
725 base
726}
727
728fn keybytes_to_hex(keybytes: &[u8]) -> Vec<u8> {
730 let nibble_count = keybytes.len() * 2;
731 let mut nibbles = Vec::with_capacity(nibble_count + 1);
732
733 #[allow(unsafe_code)]
735 unsafe {
736 expand_bytes_to_nibbles(keybytes, nibbles.as_mut_ptr());
737 nibbles.set_len(nibble_count);
738 }
739 nibbles.push(16); nibbles
741}
742
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Ordering;

    /// Reference implementation of `expand_bytes_to_nibbles`: high nibble, then low nibble.
    fn expand_bytes_scalar_ref(bytes: &[u8]) -> Vec<u8> {
        bytes.iter().flat_map(|&b| [b >> 4, b & 0x0F]).collect()
    }

    /// Reference implementation of `pack_nibble_pairs`.
    fn pack_nibble_pairs_scalar_ref(nibbles: &[u8]) -> Vec<u8> {
        nibbles
            .chunks_exact(2)
            .map(|pair| (pair[0] << 4) | pair[1])
            .collect()
    }

    #[test]
    fn expand_bytes_to_nibbles_matches_scalar() {
        // Lengths straddle the 16-byte (SSE2/NEON) and 32-byte (AVX2) block sizes.
        for &len in &[0, 1, 2, 15, 16, 17, 31, 32, 33, 48, 64, 65, 100] {
            let input: Vec<u8> = (0..len).map(|i| (i * 37 + 13) as u8).collect();
            let expected = expand_bytes_scalar_ref(&input);

            let mut actual = vec![0u8; input.len() * 2];
            #[allow(unsafe_code)]
            unsafe {
                expand_bytes_to_nibbles(&input, actual.as_mut_ptr());
            }
            assert_eq!(actual, expected, "mismatch at input length {len}");
        }
    }

    #[test]
    fn expand_bytes_to_nibbles_handles_every_byte_value() {
        let input: Vec<u8> = (0..=255u8).collect();
        let mut actual = vec![0u8; input.len() * 2];
        #[allow(unsafe_code)]
        unsafe {
            expand_bytes_to_nibbles(&input, actual.as_mut_ptr());
        }
        assert_eq!(actual, expand_bytes_scalar_ref(&input));
    }

    #[test]
    fn pack_nibble_pairs_matches_scalar() {
        for &nibble_count in &[0, 2, 4, 14, 16, 30, 32, 34, 48, 64, 66, 96, 100] {
            let input: Vec<u8> = (0..nibble_count).map(|i| (i % 16) as u8).collect();
            let expected = pack_nibble_pairs_scalar_ref(&input);

            let mut actual = vec![0u8; nibble_count / 2];
            #[allow(unsafe_code)]
            unsafe {
                pack_nibble_pairs(&input, actual.as_mut_ptr());
            }
            assert_eq!(actual, expected, "mismatch at nibble count {nibble_count}");
        }
    }

    #[test]
    fn expand_then_pack_roundtrip() {
        for &len in &[0, 1, 16, 32, 33, 64, 100] {
            let input: Vec<u8> = (0..len).map(|i| (i * 53 + 7) as u8).collect();
            let mut nibbles = vec![0u8; input.len() * 2];
            #[allow(unsafe_code)]
            unsafe {
                expand_bytes_to_nibbles(&input, nibbles.as_mut_ptr());
            }

            let mut packed = vec![0u8; input.len()];
            #[allow(unsafe_code)]
            unsafe {
                pack_nibble_pairs(&nibbles, packed.as_mut_ptr());
            }
            assert_eq!(packed, input, "roundtrip failed at length {len}");
        }
    }

    #[test]
    fn count_common_prefix_correctness() {
        let a = vec![1u8, 2, 3, 4, 5];
        assert_eq!(count_common_prefix(&a, &a), 5);

        assert_eq!(count_common_prefix(&[1, 2, 3], &[4, 5, 6]), 0);

        assert_eq!(count_common_prefix(&[1, 2, 3, 4], &[1, 2, 5, 6]), 2);

        assert_eq!(count_common_prefix(&[], &[1, 2]), 0);
        assert_eq!(count_common_prefix(&[1, 2], &[]), 0);
        assert_eq!(count_common_prefix(&[], &[]), 0);

        let long_a: Vec<u8> = (0..33).collect();
        let mut long_b = long_a.clone();
        long_b[32] = 255;
        assert_eq!(count_common_prefix(&long_a, &long_b), 32);
    }

    #[test]
    fn count_common_prefix_finds_mismatch_inside_simd_blocks() {
        // A mismatch inside a full 16/32-byte block exercises the early-return path of the
        // vector loops rather than the scalar tail.
        let base: Vec<u8> = (0..100).collect();
        for &pos in &[0usize, 1, 7, 15, 16, 20, 31, 32, 47, 63, 64, 95, 99] {
            let mut other = base.clone();
            other[pos] = other[pos].wrapping_add(1);
            assert_eq!(count_common_prefix(&base, &other), pos, "mismatch at {pos}");
        }
    }

    #[test]
    fn count_common_prefix_stops_at_shorter_input() {
        let long: Vec<u8> = (0..80).collect();
        for &len in &[0usize, 15, 16, 33, 64, 79] {
            assert_eq!(count_common_prefix(&long[..len], &long), len);
            assert_eq!(count_common_prefix(&long, &long[..len]), len);
        }
    }

    #[test]
    fn from_raw_leaf_flag() {
        let bytes = &[0xAB, 0xCD];
        let with_leaf = Nibbles::from_raw(bytes, true);
        let without_leaf = Nibbles::from_raw(bytes, false);

        assert_eq!(with_leaf.data, vec![0x0A, 0x0B, 0x0C, 0x0D, 16]);
        assert_eq!(without_leaf.data, vec![0x0A, 0x0B, 0x0C, 0x0D]);
        assert!(with_leaf.is_leaf());
        assert!(!without_leaf.is_leaf());
        assert!(!Nibbles::from_hex(vec![]).is_leaf());
    }

    #[test]
    fn compact_encoding_known_vectors() {
        // (nibbles, compact) pairs: odd/even lengths, with and without the leaf terminator.
        let cases: Vec<(Vec<u8>, Vec<u8>)> = vec![
            (vec![], vec![0x00]),
            (vec![16], vec![0x20]),
            (vec![1, 2, 3, 4, 5], vec![0x11, 0x23, 0x45]),
            (vec![0, 1, 2, 3, 4, 5], vec![0x00, 0x01, 0x23, 0x45]),
            (vec![15, 1, 12, 11, 8, 16], vec![0x3f, 0x1c, 0xb8]),
            (vec![0, 15, 1, 12, 11, 8, 16], vec![0x20, 0x0f, 0x1c, 0xb8]),
        ];
        for (hex, compact) in &cases {
            assert_eq!(
                Nibbles::from_hex(hex.clone()).encode_compact(),
                *compact,
                "encode {hex:?}"
            );
            assert_eq!(
                Nibbles::decode_compact(compact).into_vec(),
                *hex,
                "decode {compact:?}"
            );
        }
    }

    #[test]
    fn compact_roundtrip_long_paths() {
        // Long enough for `pack_nibble_pairs` to run full vector blocks.
        for &len in &[0usize, 1, 31, 32, 33, 64, 65, 100] {
            let path: Vec<u8> = (0..len).map(|i| ((i * 7 + 3) % 16) as u8).collect();

            let plain = Nibbles::from_hex(path.clone());
            assert_eq!(
                Nibbles::decode_compact(&plain.encode_compact()),
                plain,
                "non-leaf path of length {len}"
            );

            let mut leaf_path = path;
            leaf_path.push(16);
            let leaf = Nibbles::from_hex(leaf_path);
            assert_eq!(
                Nibbles::decode_compact(&leaf.encode_compact()),
                leaf,
                "leaf path of length {len}"
            );
        }
    }

    #[test]
    fn to_bytes_inverts_from_bytes() {
        let bytes = [0x00u8, 0x12, 0xAB, 0xFF];
        assert_eq!(Nibbles::from_bytes(&bytes).to_bytes(), bytes.to_vec());
        assert_eq!(Nibbles::from_raw(&bytes, false).to_bytes(), bytes.to_vec());
        // An odd trailing nibble is padded with a zero low half.
        assert_eq!(Nibbles::from_hex(vec![1, 2, 3]).to_bytes(), vec![0x12, 0x30]);
        assert!(Nibbles::from_hex(vec![]).to_bytes().is_empty());
    }

    #[test]
    fn prefix_helpers() {
        let path = Nibbles::from_hex(vec![1, 2, 3]);
        assert_eq!(path.compare_prefix(&Nibbles::from_hex(vec![1, 2])), Ordering::Equal);
        assert_eq!(path.compare_prefix(&Nibbles::from_hex(vec![1, 2, 3, 4])), Ordering::Equal);
        assert_eq!(path.compare_prefix(&Nibbles::from_hex(vec![1, 3])), Ordering::Less);
        assert_eq!(path.compare_prefix(&Nibbles::from_hex(vec![0])), Ordering::Greater);
        assert_eq!(path.count_prefix(&Nibbles::from_hex(vec![1, 2, 9])), 2);

        let mut walker = Nibbles::from_hex(vec![1, 2, 3, 4]);
        assert!(!walker.skip_prefix(&Nibbles::from_hex(vec![2])));
        assert!(walker.skip_prefix(&Nibbles::from_hex(vec![1, 2])));
        assert_eq!(walker, Nibbles::from_hex(vec![3, 4]));
        assert_eq!(walker.current(), Nibbles::from_hex(vec![1, 2]));
    }

    #[test]
    fn next_and_next_choice() {
        let mut path = Nibbles::from_hex(vec![5, 16]);
        assert_eq!(path.next_choice(), Some(5));
        // The terminator is consumed but is not reported as a choice.
        assert_eq!(path.next_choice(), None);
        assert!(path.is_empty());
        assert_eq!(path.next(), None);
        assert_eq!(path.current(), Nibbles::from_hex(vec![5, 16]));
    }
}