nam_sparse_merkle_tree/
internal_key.rs1#[cfg(feature = "borsh")]
2use borsh::{BorshDeserialize, BorshSerialize};
3#[cfg(feature = "borsh")]
4use core::convert::TryInto;
5use std::fmt::Debug;
6use std::io::Read;
7#[cfg(feature = "borsh")]
8use std::io::Write;
9
/// A fixed-size key of `N` bytes, i.e. `8 * N` individually addressable bits.
///
/// The bit accessors on this type treat byte `0` as the most significant
/// byte (bit `8 * N - 1` lives there). Because of that, the derived `Ord`,
/// which compares the byte arrays lexicographically, orders keys by their
/// numeric value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct InternalKey<const N: usize>([u8; N]);
13
#[cfg(feature = "borsh")]
impl<const N: usize> BorshSerialize for InternalKey<N> {
    /// Writes the key as a borsh byte sequence: a length prefix followed by
    /// the `N` key bytes. It is not written as a fixed-size array, so the
    /// encoding carries its own length.
    fn serialize<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
        // borsh encodes a `[u8]` slice exactly like a `Vec<u8>` (length-prefixed).
        // Serializing the borrowed slice therefore keeps the wire format while
        // avoiding a temporary heap copy of the key.
        let bytes: &[u8] = &self.0;
        BorshSerialize::serialize(bytes, writer)
    }
}
21
#[cfg(feature = "borsh")]
impl<const N: usize> BorshDeserialize for InternalKey<N> {
    /// Reads a borsh-encoded `Vec<u8>` and accepts it only if it holds
    /// exactly `N` bytes.
    ///
    /// # Errors
    /// Propagates any error from decoding the byte vector. Returns an
    /// `ErrorKind::InvalidData` error if the decoded vector's length is not
    /// `N`, whether it is too long or too short.
    fn deserialize_reader<R: Read>(reader: &mut R) -> std::io::Result<Self> {
        use std::io::ErrorKind;
        let bytes: Vec<u8> = BorshDeserialize::deserialize_reader(reader)?;
        // `TryFrom<Vec<u8>> for [u8; N]` hands the vector back on a length
        // mismatch, so its actual length can be reported.
        let bytes: [u8; N] = bytes.try_into().map_err(|rejected: Vec<u8>| {
            std::io::Error::new(
                ErrorKind::InvalidData,
                format!(
                    "Input byte vector has length {}, expected {}",
                    rejected.len(),
                    N
                ),
            )
        })?;
        Ok(InternalKey(bytes))
    }
}
33
34const BYTE_SIZE: usize = 8;
35
36impl<const N: usize> InternalKey<N> {
37 pub fn new(array: [u8; N]) -> Self {
38 Self(array)
39 }
40
41 pub fn as_slice(&self) -> &[u8] {
42 &self.0[..]
43 }
44
45 pub const fn zero() -> Self {
46 InternalKey([0u8; N])
47 }
48
49 pub const fn max_index() -> usize {
50 N - 1
51 }
52
53 #[inline]
54 pub fn get_bit(&self, i: usize) -> bool {
55 if i / BYTE_SIZE > Self::max_index() {
56 println!("Hey");
57 }
58 let byte_pos = Self::max_index() - i / BYTE_SIZE;
59 let bit_pos = i % BYTE_SIZE;
60 let bit = self.0[byte_pos] >> bit_pos & 1;
61 bit != 0
62 }
63
64 #[inline]
65 pub fn set_bit(&mut self, i: usize) {
66 let byte_pos = Self::max_index() - i / BYTE_SIZE;
67 let bit_pos = i % BYTE_SIZE;
68 self.0[byte_pos as usize] |= 1 << bit_pos as u8;
69 }
70
71 #[inline]
72 pub fn clear_bit(&mut self, i: usize) {
73 let byte_pos = Self::max_index() - i / BYTE_SIZE;
74 let bit_pos = i % BYTE_SIZE;
75 self.0[byte_pos as usize] &= !((1 << bit_pos) as u8);
76 }
77
78 pub fn fork_height(&self, key: &InternalKey<N>) -> usize {
82 let max = (BYTE_SIZE * N) as usize;
83 for h in (0..max).rev() {
84 if self.get_bit(h) != key.get_bit(h) {
85 return h;
86 }
87 }
88 0
89 }
90
91 pub fn parent_path(&self, height: usize) -> Self {
94 height
95 .checked_add(1)
96 .map(|i| self.copy_bits(i..))
97 .unwrap_or_else(InternalKey::zero)
98 }
99
100 pub fn copy_bits(&self, range: impl core::ops::RangeBounds<usize>) -> Self {
102 let array_size = N;
103 let max = 8 * N;
104 use core::ops::Bound;
105
106 let mut target = InternalKey::zero();
107 let start = match range.start_bound() {
108 Bound::Included(&i) => i as usize,
109 Bound::Excluded(&i) => panic!("do not allows excluded start: {}", i),
110 Bound::Unbounded => 0,
111 };
112
113 let mut end = match range.end_bound() {
114 Bound::Included(&i) => i.saturating_add(1) as usize,
115 Bound::Excluded(&i) => i as usize,
116 Bound::Unbounded => max,
117 };
118
119 if start >= max {
120 return target;
121 } else if end > max {
122 end = max;
123 }
124
125 if end < start {
126 panic!("end can't less than start: start {} end {}", start, end);
127 }
128
129 let end_byte = {
130 let remain = if start % BYTE_SIZE != 0 { 1 } else { 0 };
131 array_size - start / BYTE_SIZE - remain
132 };
133 let start_byte = array_size - end / BYTE_SIZE;
134 if start_byte < self.0.len() && start_byte <= end_byte {
136 target.0[start_byte..end_byte].copy_from_slice(&self.0[start_byte..end_byte]);
137 }
138
139 for i in (start..core::cmp::min((array_size - end_byte) * BYTE_SIZE, end))
141 .chain(core::cmp::max((array_size - start_byte) * BYTE_SIZE, start)..end)
142 {
143 if self.get_bit(i) {
144 target.set_bit(i)
145 }
146 }
147 target
148 }
149}
150
151impl<const N: usize> From<[u8; N]> for InternalKey<N> {
152 fn from(v: [u8; N]) -> Self {
153 Self::new(v)
154 }
155}
156
157impl<const N: usize> From<InternalKey<N>> for [u8; N] {
158 fn from(v: InternalKey<N>) -> Self {
159 v.0
160 }
161}